# Dependencies

In [1]:
!python -c "import torch; print(torch.__version__)"

1.13.0+cu116


In [2]:
%env TORCH=1.13.0+cu116
!pip install torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install torch-geometric
!pip install ogb
!pip install networkx

env: TORCH=1.13.0+cu116
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (9.4 MB)
[K     |████████████████████████████████| 9.4 MB 43.2 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.0+pt113cu116
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.13.0+cu116.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.16%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (4.5 MB)
[K     |████████████████████████████████| 4.5 MB 20.1 MB/s 
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.16+pt113cu116
Looking in i

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting ogb
  Downloading ogb-1.3.5-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 6.2 MB/s 
Collecting outdated>=0.2.0
  Downloading outdated-0.2.2-py2.py3-none-any.whl (7.5 kB)
Collecting littleutils
  Downloading littleutils-0.2.2.tar.gz (6.6 kB)
Building wheels for collected packages: littleutils
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone
  Created wheel for littleutils: filename=littleutils-0.2.2-py3-none-any.whl size=7047 sha256=f2d5638f025a0b51d4861fed16dee6d0b8bbe654c4749c55ac75ffaf86c14836
  Stored in directory: /root/.cache/pip/wheels/6a/33/c4/0ef84d7f5568c2823e3d63a6e08988852fb9e4bc822034870a
Successfully built littleutils
Installing collected packages: littleutils, outdated, ogb
Successfully installed littleutils-0.2.2 ogb-1.3.5 outdated-0.2.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels

In [3]:
import os
import numpy as np
import random 
from tqdm import tqdm
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.data.data import Data 
import torch_geometric.utils as U
from torch_geometric.datasets import TUDataset
import torch_geometric.transforms as T
from torch_geometric.nn import MessagePassing, GATConv, global_mean_pool
from torch_geometric.utils.convert import from_scipy_sparse_matrix, from_networkx

from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder

import networkx as nx

# Data utilities

In [4]:
def get_dataset(dataset_name='MUTAG'):
    """
    Returns lists of adjacency matrices, node features, and graph labels for a 
    dataset in TUDataset.
    """
    raw_dataset = TUDataset(root='data/TUDataset', name=dataset_name)
    A = []
    X = []
    Y = []
    for graph in raw_dataset:
        adj_matrix = U.to_dense_adj(graph.edge_index).squeeze(0)
        A.append(adj_matrix)
        X.append(graph.x)
        Y.append(graph.y)

    return A, X, Y

def num_cayley_nodes(n):
    """
    Calculate number of nodes in a Cayley graph generated by the symmetric group 
    S_n using eq. (10) in https://arxiv.org/pdf/2210.02997.pdf.
    """
    prod = 1
    for p in range(2, n+1):
        prime = True
        for i in range(2, p):
            if (p % i) == 0:
                prime = False
                break
        
        if prime and (n % p) == 0:
            prod *= (1-1/(p**2))

    return n**3 * prod


def graph_mini_batch(adj_matrix_list,  x_list,  y_list, batch_size=32):
    """
    Iterator which outputs:
        A_B: Block-diag adjacency matrix for subgraphs in batch B
        FA_B: Block-diag adjacency matrix for fully-connected rewirings of 
              subgraphs in batch B
        cayley_B: Block-diag adjacency matrix for expander graph rewirings of 
                  subgraphs in batch B
        X_B: Batched node features list
        Y_B: Batched node labels list
        Batch: Vector that maps each node to its repsective graph in batch B
    """
    n = len(x_list)

    FA_adj_matrix_list = [
        torch.ones((len(adj), len(adj))) for adj in adj_matrix_list
    ]

    num_batches = n // batch_size + 1
    for i in range(num_batches):
        start = batch_size * i
        end = min(batch_size * (i+1), n)

        A_B = torch.block_diag(*adj_matrix_list[start:end])
        X_B = torch.cat(x_list[start:end], dim=0)
        Y_B = torch.cat(y_list[start:end], dim=0)
        batch_lists = [
            torch.LongTensor([j for k in range(len(g))]) for j, g in 
            enumerate(x_list[start:end])
        ]
        Batch = torch.cat(batch_lists, dim=-1)

        #Fully-connected
        FA_B = torch.block_diag(*FA_adj_matrix_list[start:end])

        #Expander
        cayley_adj_matrix_list = []
        for g in x_list[start:end]:
            l = len(g)

            n_opt = 1
            n_cayley = 2
            while n_cayley < l:
                n_opt += 1
                n_cayley = num_cayley_nodes(n_opt)

            # For the MUTAG and ENZYNES datasets, the optimal Cayley graph sizes
            # are either 3, 4, or 5
            if n_opt == 3:
                cayley_adj_matrix_list.append(
                    torch.tensor(np.load('cayley_3.npy'))[:l , :l])
            elif n_opt == 4:
                cayley_adj_matrix_list.append(
                    torch.tensor(np.load('cayley_4.npy'))[:l , :l])
            elif n_opt == 5:
                cayley_adj_matrix_list.append(
                    torch.tensor(np.load('cayley_5.npy'))[:l , :l])

        cayley_A_B = torch.block_diag(*cayley_adj_matrix_list)
    
        yield A_B, FA_B, cayley_A_B, X_B, Y_B, Batch


# GAT rewirings

In [5]:
class GAT(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, dropout=0.5):
        super(GAT, self).__init__()
        """
        Args:
            input_dim: input feaure dimension
            hidden_dim: hidden feature dimension
            output_dim: output dimensions
            n_layers: number of layers
            dropout: dropout_ratio
        """
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.conv_layers = nn.ModuleList([GATConv(input_dim, hidden_dim)])
        for i in range(1,n_layers):
            self.conv_layers.append(GATConv(hidden_dim, hidden_dim))

        self.linear = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj, FA_adj, cayley_adj, batch) -> torch.Tensor:
        edge_index = U.dense_to_sparse(adj)[0].to(self.device)
    
        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index)
         
            if i < self.n_layers - 1:
                x = F.relu(x) #Remove ReLU for the last layer

            x = F.dropout(x, p=self.dropout, training = self.training)

        x = global_mean_pool(x, batch)
        x = self.linear(x)
        return x

In [6]:
class GAT_FA(GAT):
    def forward(self, x, adj, FA_adj, cayley_adj, batch) -> torch.Tensor:
        edge_index = U.dense_to_sparse(adj)[0].to(self.device)
        FA_edge_index = U.dense_to_sparse(FA_adj)[0].to(self.device)

        for i in range(self.n_layers):
            if i < self.n_layers - 1:
                x = self.conv_layers[i](x, edge_index)
            else:
                x = self.conv_layers[i](x, FA_edge_index)
        
            if i < self.n_layers - 1:
                x = F.relu(x) #Remove ReLU for the last layer

            x = F.dropout(x, p=self.dropout, training = self.training)

        x = global_mean_pool(x, batch)
        x = self.linear(x)
        return x

In [7]:
class GAT_AllFA(GAT):
    def forward(self, x, adj, FA_adj, cayley_adj, batch) -> torch.Tensor:
        FA_edge_index = U.dense_to_sparse(FA_adj)[0].to(self.device)

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, FA_edge_index)
        
            if i < self.n_layers - 1:
                x = F.relu(x) #Remove ReLU for the last layer

            x = F.dropout(x, p=self.dropout, training = self.training)

        x = global_mean_pool(x, batch)
        x = self.linear(x)
        return x
       

In [8]:
class GAT_Expander(GAT):
   def forward(self, x, adj, FA_adj, cayley_adj, batch) -> torch.Tensor:
        edge_index = U.dense_to_sparse(adj)[0].to(self.device)
        expander_edge_index = U.dense_to_sparse(cayley_adj)[0].to(self.device)

        for i in range(self.n_layers):
            if i % 2 == 0:
                x = self.conv_layers[i](x, edge_index)
            else:
                x = self.conv_layers[i](x, expander_edge_index)

            if i < self.n_layers - 1:
              x = F.relu(x) #Remove ReLU for the last layer

            x = F.dropout(x, p=self.dropout, training = self.training)

        x = global_mean_pool(x, batch)
        x = self.linear(x)
        return x

In [9]:
class GAT_Virtual(GAT):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, dropout=0.5):
        super().__init__(input_dim, hidden_dim, output_dim, n_layers, dropout)

        # Set the initial virtual node embedding to 0
        self.virtualnode_embedding = torch.nn.Embedding(1, hidden_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)

        # List of MLPs to transform virtual node at every layer
        self.mlp_virtualnode_list = torch.nn.ModuleList()
        for layer in range(n_layers - 1):
            self.mlp_virtualnode_list.append(
                torch.nn.Sequential(torch.nn.Linear(hidden_dim, 2*hidden_dim), \
                torch.nn.ReLU(), torch.nn.Linear(2*hidden_dim, hidden_dim), \
                torch.nn.ReLU()))
                

    def forward(self, x, adj, FA_adj, cayley_adj, batch) -> torch.Tensor:
        edge_index = U.dense_to_sparse(adj)[0].to(self.device)

        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(batch[-1].item()+1).to(edge_index.dtype).to(self.device))

        for i in range(self.n_layers):
            if i > 0:
                x = x + virtualnode_embedding[batch]

            x = self.conv_layers[i](x, edge_index)
       
            if i < self.n_layers - 1:
                x = F.relu(x) #Remove ReLU for the last layer

            x = F.dropout(x, p=self.dropout, training = self.training)

            if i < self.n_layers - 1:
                # add message from graph nodes to virtual nodes
                virtualnode_embedding_temp = global_mean_pool(x, batch) + \
                                             virtualnode_embedding
                # transform virtual nodes using MLP
                virtualnode_embedding = F.dropout(
                    self.mlp_virtualnode_list[i](
                        virtualnode_embedding_temp),self.dropout,training=self.training)

        x = global_mean_pool(x, batch)
        x = self.linear(x)
        return x, att_weight_list

# Training and evaluation setup

In [19]:
def train(A, X, Y, batch_size, model, optimizer, criterion, device):
    model.train()

    losses = []
    correct = 0
    total_num = 0
    for adj, FA_adj, cayley_adj, x, y, batch in graph_mini_batch(A, X, Y, batch_size):
        adj = adj.to(device)
        FA_adj = FA_adj.to(device)
        cayley_adj = cayley_adj.to(device)
        x = x.to(device)
        y = y.to(device)
        batch=batch.to(device)

        optimizer.zero_grad()

        pred = model(x, adj, FA_adj, cayley_adj, batch)

        loss = criterion(pred, y)
        losses.append(loss.item())

        correct += (pred.argmax(dim=-1) == y).sum()
        total_num += len(y)
        
        loss.backward()
        optimizer.step()

    return np.mean(losses), correct/total_num 

def eval(A, X, Y, batch_size, model, device):
    model.eval()
   
    correct = 0
    total_num = 0
    for adj, FA_adj, cayley_adj, x, y, batch in graph_mini_batch(A, X, Y, batch_size):
        adj = adj.to(device)
        FA_adj = FA_adj.to(device)
        cayley_adj = cayley_adj.to(device)
        x = x.to(device)
        y = y.to(device)
        batch=batch.to(device)
       
        pred = model(x, adj, FA_adj, cayley_adj, batch)

        correct += (pred.argmax(dim=-1) == y).sum()
        total_num += len(y)

    return correct/total_num 


In [11]:
class EarlyStopper:
    def __init__(self, patience):
        self.patience = patience
        self.counter = 0
        self.prev_valid_score = 0
        self.best_model = None

    def early_stop(self, valid_score):
        if valid_score >= self.prev_valid_score:
            self.counter = 0
        else:
            self.counter += 1
            if self.counter == self.patience:
                return True
        self.prev_valid_score = valid_score
        return False

In [12]:
def train_model(A_train,X_train,Y_train,A_test,X_test,Y_test,params,verbose):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = params['model'](params['input_dim'], 
                                params['hidden_dim'], 
                                params['output_dim'], 
                                params['n_layers'],
                                params['dropout']).to(device)
    
    optimizer = optim.Adam(model.parameters(), 
                           lr=params['lr'], 
                           weight_decay=params['weight_decay'])
    
    criterion = nn.CrossEntropyLoss() 

    early_stopper = EarlyStopper(params['max_patience'])

    train_losses = []
    train_scores = []
    test_scores = []

    for epoch in tqdm(range(1, params['epochs']+1)):
        train_loss, train_acc = train(A_train,
                                      X_train, 
                                      Y_train, 
                                      params['batch_size'], 
                                      model, 
                                      optimizer, 
                                      criterion, 
                                      device)
        test_acc = eval(A_test, 
                        X_test, 
                        Y_test, 
                        params['batch_size'], 
                        model, 
                        device).item()
       
        epoch_len = len(str(params['epochs']))

        print_msg = (f'[{epoch:>{epoch_len}}/{params["epochs"]:>{epoch_len}}] ' +
                        f'loss: {train_loss:.5f} ' +
                        f'train acc: {train_acc:.5f} ' +
                        f'test acc: {test_acc:.5f} '
                        )
        if verbose:
            print(print_msg)

        train_losses.append(train_loss)
        train_scores.append(train_acc.item())
        test_scores.append(test_acc)

        if early_stopper.early_stop(train_acc):   
            print('Stopped early at epoch {}.'.format(epoch))          
            break

    best_epoch = np.argmax(train_scores)
    print('Best train acc: {}'.format(train_scores[best_epoch]), 
          'Test acc: {}'.format(test_scores[best_epoch]))
    return train_scores[best_epoch], test_scores[best_epoch]


In [13]:
def cross_val(A, X, Y, params, verbose=False):
    """
    10-fold cross-validation
    """
    group_size = len(A)//10+1

    train_accs = []
    test_accs = []
    for i in range(0, len(A), group_size):
        print('Run {}/10...'.format(i//group_size + 1))

        A_test = A[i: i+group_size]
        X_test = X[i: i+group_size]
        Y_test = Y[i: i+group_size]
        A_train = A[:i] + A[i+group_size:]
        X_train = X[:i] + X[i+group_size:]
        Y_train = Y[:i] + Y[i+group_size:]

        train_acc, test_acc = train_model(A_train, 
                                          X_train, 
                                          Y_train, 
                                          A_test, 
                                          X_test, 
                                          Y_test, 
                                          params, 
                                          verbose)
        train_accs.append(train_acc)
        test_accs.append(test_acc)
    print('Train accuracy:', np.mean(train_accs), '+- ', np.std(train_accs))
    print('Test accuracy:', np.mean(test_accs), '+- ', np.std(test_accs))

# Experiments

In [14]:
torch.manual_seed(42) 

<torch._C.Generator at 0x7fa2bde8f490>

## MUTAG

In [34]:
A, X, Y = get_dataset(dataset_name='MUTAG')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/MUTAG.zip
Extracting data/TUDataset/MUTAG/MUTAG.zip
Processing...
Done!


In [35]:
# Shuffle before cross-val
ds = list(zip(A, X, Y))
random.shuffle(ds)
A, X, Y = zip(*ds)

### GAT

In [None]:
params = {
    'input_dim': 7,
    'hidden_dim': 64,
    'output_dim': 2,
    'n_layers': 4,
    'epochs': 500,
    'model':GAT,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


100%|██████████| 500/500 [01:37<00:00,  5.14it/s]


Best train acc: 0.88165682554245 Test acc: 0.6842105388641357
Run 2/10...


100%|██████████| 500/500 [01:35<00:00,  5.22it/s]


Best train acc: 0.8520710468292236 Test acc: 0.6842105388641357
Run 3/10...


100%|██████████| 500/500 [01:35<00:00,  5.23it/s]


Best train acc: 0.8934911489486694 Test acc: 0.6315789222717285
Run 4/10...


100%|██████████| 500/500 [01:36<00:00,  5.18it/s]


Best train acc: 0.8875739574432373 Test acc: 0.7894737124443054
Run 5/10...


100%|██████████| 500/500 [01:35<00:00,  5.23it/s]


Best train acc: 0.8757396340370178 Test acc: 0.7368420958518982
Run 6/10...


100%|██████████| 500/500 [01:35<00:00,  5.23it/s]


Best train acc: 0.8520710468292236 Test acc: 0.7894737124443054
Run 7/10...


100%|██████████| 500/500 [01:36<00:00,  5.19it/s]


Best train acc: 0.8757396340370178 Test acc: 0.8947368264198303
Run 8/10...


100%|██████████| 500/500 [01:35<00:00,  5.24it/s]


Best train acc: 0.8757396340370178 Test acc: 0.8421052694320679
Run 9/10...


100%|██████████| 500/500 [01:36<00:00,  5.17it/s]


Best train acc: 0.8757396340370178 Test acc: 0.6842105388641357
Run 10/10...


100%|██████████| 500/500 [01:35<00:00,  5.22it/s]


Best train acc: 0.859649121761322 Test acc: 0.9411764740943909
Train accuracy: 0.8729471683502197 +-  0.013384732449527443
Test accuracy: 0.7678018629550933 +-  0.09649316048792862


### +FA

In [None]:
params = {
    'input_dim': 7,
    'hidden_dim': 64,
    'output_dim': 2,
    'n_layers': 4,
    'epochs': 500,
    'model':GAT_FA,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


100%|██████████| 500/500 [01:36<00:00,  5.17it/s]


Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 2/10...


100%|██████████| 500/500 [01:37<00:00,  5.15it/s]


Best train acc: 0.88165682554245 Test acc: 0.8421052694320679
Run 3/10...


100%|██████████| 500/500 [01:36<00:00,  5.18it/s]


Best train acc: 0.9112426042556763 Test acc: 0.6842105388641357
Run 4/10...


100%|██████████| 500/500 [01:37<00:00,  5.15it/s]


Best train acc: 0.9053254723548889 Test acc: 0.7368420958518982
Run 5/10...


100%|██████████| 500/500 [01:36<00:00,  5.17it/s]


Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 6/10...


100%|██████████| 500/500 [01:36<00:00,  5.20it/s]


Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 7/10...


100%|██████████| 500/500 [01:37<00:00,  5.13it/s]


Best train acc: 0.8875739574432373 Test acc: 0.8421052694320679
Run 8/10...


100%|██████████| 500/500 [01:36<00:00,  5.18it/s]


Best train acc: 0.9053254723548889 Test acc: 0.8421052694320679
Run 9/10...


100%|██████████| 500/500 [01:37<00:00,  5.12it/s]


Best train acc: 0.9053254723548889 Test acc: 0.7894737124443054
Run 10/10...


100%|██████████| 500/500 [01:36<00:00,  5.18it/s]

Best train acc: 0.8888888955116272 Test acc: 1.0
Train accuracy: 0.8983563542366028 +-  0.008943422299746128
Test accuracy: 0.7947368443012237 +-  0.0863222094991247





### Virtual node

In [None]:
params = {
    'input_dim': 7,
    'hidden_dim': 64,
    'output_dim': 2,
    'n_layers': 4,
    'epochs': 500,
    'model':GAT_Virtual,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


100%|██████████| 500/500 [02:00<00:00,  4.14it/s]


Best train acc: 0.8994082808494568 Test acc: 0.6842105388641357
Run 2/10...


100%|██████████| 500/500 [01:58<00:00,  4.20it/s]


Best train acc: 0.9053254723548889 Test acc: 0.8947368264198303
Run 3/10...


100%|██████████| 500/500 [02:00<00:00,  4.16it/s]


Best train acc: 0.9349112510681152 Test acc: 0.7368420958518982
Run 4/10...


100%|██████████| 500/500 [01:59<00:00,  4.18it/s]


Best train acc: 0.9053254723548889 Test acc: 0.7368420958518982
Run 5/10...


100%|██████████| 500/500 [01:59<00:00,  4.17it/s]


Best train acc: 0.9053254723548889 Test acc: 0.7368420958518982
Run 6/10...


 38%|███▊      | 190/500 [00:45<01:14,  4.19it/s]


Stopped early at epoch 191.
Best train acc: 0.8757396340370178 Test acc: 0.6842105388641357
Run 7/10...


 76%|███████▌  | 381/500 [01:30<00:28,  4.19it/s]


Stopped early at epoch 382.
Best train acc: 0.9171597957611084 Test acc: 0.7894737124443054
Run 8/10...


100%|██████████| 500/500 [01:58<00:00,  4.22it/s]


Best train acc: 0.9230769276618958 Test acc: 0.7894737124443054
Run 9/10...


100%|██████████| 500/500 [01:58<00:00,  4.22it/s]


Best train acc: 0.9171597957611084 Test acc: 0.7894737124443054
Run 10/10...


100%|██████████| 500/500 [01:57<00:00,  4.27it/s]

Best train acc: 0.9005848169326782 Test acc: 1.0
Train accuracy: 0.9084016919136048 +-  0.01519444773241756
Test accuracy: 0.7842105329036713 +-  0.09251787013694499





### Expander graph propagation

In [None]:
params = {
    'input_dim': 7,
    'hidden_dim': 64,
    'output_dim': 2,
    'n_layers': 4,
    'epochs': 500,
    'model':GAT_Expander,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


100%|██████████| 500/500 [01:37<00:00,  5.14it/s]


Best train acc: 0.88165682554245 Test acc: 0.6842105388641357
Run 2/10...


100%|██████████| 500/500 [01:36<00:00,  5.20it/s]


Best train acc: 0.8698225021362305 Test acc: 0.8421052694320679
Run 3/10...


100%|██████████| 500/500 [01:36<00:00,  5.20it/s]


Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 4/10...


100%|██████████| 500/500 [01:37<00:00,  5.15it/s]


Best train acc: 0.8934911489486694 Test acc: 0.7368420958518982
Run 5/10...


100%|██████████| 500/500 [01:36<00:00,  5.21it/s]


Best train acc: 0.88165682554245 Test acc: 0.7368420958518982
Run 6/10...


100%|██████████| 500/500 [01:37<00:00,  5.13it/s]


Best train acc: 0.88165682554245 Test acc: 0.8947368264198303
Run 7/10...


100%|██████████| 500/500 [01:36<00:00,  5.19it/s]


Best train acc: 0.8639053106307983 Test acc: 0.8947368264198303
Run 8/10...


100%|██████████| 500/500 [01:36<00:00,  5.19it/s]


Best train acc: 0.8757396340370178 Test acc: 0.8421052694320679
Run 9/10...


100%|██████████| 500/500 [01:38<00:00,  5.06it/s]


Best train acc: 0.8757396340370178 Test acc: 0.6842105388641357
Run 10/10...


100%|██████████| 500/500 [01:36<00:00,  5.20it/s]

Best train acc: 0.8830409646034241 Test acc: 0.9411764740943909
Train accuracy: 0.8806117951869965 +-  0.009855729984466282
Test accuracy: 0.7993808031082154 +-  0.08955666584198103





### All FA

In [None]:
params = {
    'input_dim': 7,
    'hidden_dim': 64,
    'output_dim': 2,
    'n_layers': 4,
    'epochs': 500,
    'model':GAT_AllFA,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)


Run 1/10...


100%|██████████| 500/500 [01:36<00:00,  5.18it/s]


Best train acc: 0.9053254723548889 Test acc: 0.7368420958518982
Run 2/10...


100%|██████████| 500/500 [01:36<00:00,  5.20it/s]


Best train acc: 0.88165682554245 Test acc: 0.8947368264198303
Run 3/10...


100%|██████████| 500/500 [01:34<00:00,  5.27it/s]


Best train acc: 0.9112426042556763 Test acc: 0.7368420958518982
Run 4/10...


100%|██████████| 500/500 [01:35<00:00,  5.22it/s]


Best train acc: 0.8934911489486694 Test acc: 0.7368420958518982
Run 5/10...


100%|██████████| 500/500 [01:35<00:00,  5.23it/s]


Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 6/10...


 85%|████████▍ | 423/500 [01:20<00:14,  5.25it/s]


Stopped early at epoch 424.
Best train acc: 0.8994082808494568 Test acc: 0.7368420958518982
Run 7/10...


100%|██████████| 500/500 [01:35<00:00,  5.23it/s]


Best train acc: 0.8934911489486694 Test acc: 0.8947368264198303
Run 8/10...


100%|██████████| 500/500 [01:34<00:00,  5.27it/s]


Best train acc: 0.8934911489486694 Test acc: 0.8421052694320679
Run 9/10...


100%|██████████| 500/500 [01:36<00:00,  5.19it/s]


Best train acc: 0.8934911489486694 Test acc: 0.6842105388641357
Run 10/10...


100%|██████████| 500/500 [01:35<00:00,  5.24it/s]

Best train acc: 0.9005848169326782 Test acc: 0.9411764740943909
Train accuracy: 0.8971590876579285 +-  0.007622492165058771
Test accuracy: 0.7941176414489746 +-  0.08524058029054181





## ENZYMES

In [15]:
A, X, Y = get_dataset(dataset_name='ENZYMES')

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Extracting data/TUDataset/ENZYMES/ENZYMES.zip
Processing...
Done!


In [16]:
# Shuffle before cross-val
ds = list(zip(A, X, Y))
random.shuffle(ds)
A, X, Y = zip(*ds)

### GAT

In [33]:
params = {
    'input_dim': 3,
    'hidden_dim': 64,
    'output_dim': 6,
    'n_layers': 4,
    'epochs': 500,
    'model': GAT,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.2,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


 56%|█████▌    | 279/500 [02:26<01:55,  1.91it/s]


Stopped early at epoch 280.
Best train acc: 0.4025973975658417 Test acc: 0.2950819432735443
Run 2/10...


100%|██████████| 500/500 [04:21<00:00,  1.91it/s]


Best train acc: 0.4582560360431671 Test acc: 0.3442622721195221
Run 3/10...


100%|██████████| 500/500 [04:20<00:00,  1.92it/s]


Best train acc: 0.43784788250923157 Test acc: 0.37704914808273315
Run 4/10...


100%|██████████| 500/500 [04:20<00:00,  1.92it/s]


Best train acc: 0.44341373443603516 Test acc: 0.2950819432735443
Run 5/10...


100%|██████████| 500/500 [04:19<00:00,  1.92it/s]


Best train acc: 0.44155845046043396 Test acc: 0.2786885201931
Run 6/10...


100%|██████████| 500/500 [04:19<00:00,  1.93it/s]


Best train acc: 0.41001856327056885 Test acc: 0.4098360538482666
Run 7/10...


 45%|████▍     | 224/500 [01:56<02:24,  1.92it/s]


Stopped early at epoch 225.
Best train acc: 0.38033396005630493 Test acc: 0.21311473846435547
Run 8/10...


100%|██████████| 500/500 [04:20<00:00,  1.92it/s]


Best train acc: 0.45083487033843994 Test acc: 0.32786881923675537
Run 9/10...


100%|██████████| 500/500 [04:22<00:00,  1.91it/s]


Best train acc: 0.4322820007801056 Test acc: 0.3606557250022888
Run 10/10...


100%|██████████| 500/500 [04:24<00:00,  1.89it/s]

Best train acc: 0.44990891218185425 Test acc: 0.37254902720451355
Train accuracy: 0.4307051807641983 +-  0.023708829945163035
Test accuracy: 0.32741881906986237 +-  0.054965911348480716





### +FA

In [31]:
params = {
    'input_dim': 3,
    'hidden_dim': 64,
    'output_dim': 6,
    'n_layers': 4,
    'epochs': 500,
    'model': GAT_FA,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.2,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


 53%|█████▎    | 267/500 [02:23<02:05,  1.85it/s]


Stopped early at epoch 268.
Best train acc: 0.36920222640037537 Test acc: 0.2950819432735443
Run 2/10...


100%|██████████| 500/500 [04:27<00:00,  1.87it/s]


Best train acc: 0.44897958636283875 Test acc: 0.2786885201931
Run 3/10...


100%|██████████| 500/500 [04:29<00:00,  1.86it/s]


Best train acc: 0.43784788250923157 Test acc: 0.3934426009654999
Run 4/10...


100%|██████████| 500/500 [04:26<00:00,  1.88it/s]


Best train acc: 0.44897958636283875 Test acc: 0.32786881923675537
Run 5/10...


100%|██████████| 500/500 [04:26<00:00,  1.88it/s]


Best train acc: 0.44526901841163635 Test acc: 0.24590162932872772
Run 6/10...


 66%|██████▌   | 331/500 [02:55<01:29,  1.88it/s]


Stopped early at epoch 332.
Best train acc: 0.3562152087688446 Test acc: 0.32786881923675537
Run 7/10...


100%|██████████| 500/500 [04:25<00:00,  1.88it/s]


Best train acc: 0.43970316648483276 Test acc: 0.21311473846435547
Run 8/10...


100%|██████████| 500/500 [04:26<00:00,  1.88it/s]


Best train acc: 0.4341372847557068 Test acc: 0.2950819432735443
Run 9/10...


100%|██████████| 500/500 [04:25<00:00,  1.89it/s]


Best train acc: 0.4582560360431671 Test acc: 0.44262292981147766
Run 10/10...


 51%|█████     | 253/500 [02:16<02:13,  1.85it/s]

Stopped early at epoch 254.
Best train acc: 0.375227689743042 Test acc: 0.3333333432674408
Train accuracy: 0.4213817685842514 +-  0.03650025377065803
Test accuracy: 0.3153005287051201 +-  0.06354089081703616





### Virtual node

In [None]:
params = {
    'input_dim': 3,
    'hidden_dim': 64,
    'output_dim': 6,
    'n_layers': 4,
    'epochs': 500,
    'model': GAT_Virtual,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


 61%|██████    | 305/500 [04:01<02:34,  1.26it/s]


Stopped early at epoch 306.
Best train acc: 0.7625231742858887 Test acc: 0.3606557250022888
Run 2/10...


 51%|█████     | 254/500 [03:18<03:12,  1.28it/s]


Stopped early at epoch 255.
Best train acc: 0.706864595413208 Test acc: 0.31147539615631104
Run 3/10...


 55%|█████▌    | 276/500 [03:35<02:54,  1.28it/s]


Stopped early at epoch 277.
Best train acc: 0.706864595413208 Test acc: 0.44262292981147766
Run 4/10...


 24%|██▍       | 119/500 [01:33<04:59,  1.27it/s]


Stopped early at epoch 120.
Best train acc: 0.5157699584960938 Test acc: 0.4590163826942444
Run 5/10...


 50%|█████     | 252/500 [03:15<03:12,  1.29it/s]


Stopped early at epoch 253.
Best train acc: 0.7365491986274719 Test acc: 0.42622947692871094
Run 6/10...


 42%|████▏     | 209/500 [02:41<03:45,  1.29it/s]


Stopped early at epoch 210.
Best train acc: 0.6122449040412903 Test acc: 0.3442622721195221
Run 7/10...


 45%|████▌     | 227/500 [02:56<03:31,  1.29it/s]


Stopped early at epoch 228.
Best train acc: 0.6883116960525513 Test acc: 0.4754098057746887
Run 8/10...


 47%|████▋     | 236/500 [03:03<03:24,  1.29it/s]


Stopped early at epoch 237.
Best train acc: 0.6474953889846802 Test acc: 0.44262292981147766
Run 9/10...


 76%|███████▌  | 380/500 [04:55<01:33,  1.29it/s]


Stopped early at epoch 381.
Best train acc: 0.85899817943573 Test acc: 0.42622947692871094
Run 10/10...


100%|██████████| 500/500 [06:34<00:00,  1.27it/s]

Best train acc: 0.8269581198692322 Test acc: 0.4705882668495178
Train accuracy: 0.7062579810619354 +-  0.0953317440649498
Test accuracy: 0.41591126620769503 +-  0.05396412021721359





### Expander graph propagation

In [32]:
params = {
    'input_dim': 3,
    'hidden_dim': 64,
    'output_dim': 6,
    'n_layers': 4,
    'epochs': 500,
    'model': GAT_Expander,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


 90%|████████▉ | 449/500 [03:57<00:26,  1.89it/s]


Stopped early at epoch 450.
Best train acc: 0.5194805264472961 Test acc: 0.3606557250022888
Run 2/10...


100%|██████████| 500/500 [04:22<00:00,  1.90it/s]


Best train acc: 0.4953617751598358 Test acc: 0.31147539615631104
Run 3/10...


100%|██████████| 500/500 [04:21<00:00,  1.91it/s]


Best train acc: 0.48051947355270386 Test acc: 0.49180325865745544
Run 4/10...


 61%|██████    | 306/500 [02:41<01:42,  1.90it/s]


Stopped early at epoch 307.
Best train acc: 0.4304267168045044 Test acc: 0.37704914808273315
Run 5/10...


100%|██████████| 500/500 [04:24<00:00,  1.89it/s]


Best train acc: 0.5046381950378418 Test acc: 0.37704914808273315
Run 6/10...


 51%|█████     | 256/500 [02:18<02:12,  1.85it/s]


Stopped early at epoch 257.
Best train acc: 0.41929498314857483 Test acc: 0.44262292981147766
Run 7/10...


100%|██████████| 500/500 [04:21<00:00,  1.91it/s]


Best train acc: 0.5120593905448914 Test acc: 0.21311473846435547
Run 8/10...


100%|██████████| 500/500 [04:23<00:00,  1.90it/s]


Best train acc: 0.4730983376502991 Test acc: 0.31147539615631104
Run 9/10...


100%|██████████| 500/500 [04:23<00:00,  1.90it/s]


Best train acc: 0.48608535528182983 Test acc: 0.2950819432735443
Run 10/10...


100%|██████████| 500/500 [04:26<00:00,  1.87it/s]

Best train acc: 0.46265938878059387 Test acc: 0.3921568691730499
Train accuracy: 0.4783624142408371 +-  0.031513433691532276
Test accuracy: 0.357248455286026 +-  0.07499179098616035





### All FA

In [None]:
params = {
    'input_dim': 3,
    'hidden_dim': 64,
    'output_dim': 6,
    'n_layers': 4,
    'epochs': 500,
    'model': GAT_AllFA,
    'lr': 0.001,
    'weight_decay': 0,
    'max_patience': 5,
    'dropout': 0.3,
    'batch_size': 32
}

cross_val(A, X, Y, params, verbose=False)

Run 1/10...


100%|██████████| 500/500 [05:18<00:00,  1.57it/s]


Best train acc: 0.38033396005630493 Test acc: 0.19672130048274994
Run 2/10...


100%|██████████| 500/500 [05:14<00:00,  1.59it/s]


Best train acc: 0.3636363744735718 Test acc: 0.2295081913471222
Run 3/10...


100%|██████████| 500/500 [05:15<00:00,  1.58it/s]


Best train acc: 0.37847867608070374 Test acc: 0.19672130048274994
Run 4/10...


100%|██████████| 500/500 [05:13<00:00,  1.59it/s]


Best train acc: 0.3858998119831085 Test acc: 0.3606557250022888
Run 5/10...


100%|██████████| 500/500 [05:13<00:00,  1.59it/s]


Best train acc: 0.36734694242477417 Test acc: 0.2950819432735443
Run 6/10...


100%|██████████| 500/500 [05:14<00:00,  1.59it/s]


Best train acc: 0.36920222640037537 Test acc: 0.2950819432735443
Run 7/10...


100%|██████████| 500/500 [05:13<00:00,  1.59it/s]


Best train acc: 0.37476807832717896 Test acc: 0.19672130048274994
Run 8/10...


100%|██████████| 500/500 [05:12<00:00,  1.60it/s]


Best train acc: 0.37105751037597656 Test acc: 0.31147539615631104
Run 9/10...


100%|██████████| 500/500 [05:14<00:00,  1.59it/s]


Best train acc: 0.38218924403190613 Test acc: 0.37704914808273315
Run 10/10...


100%|██████████| 500/500 [05:17<00:00,  1.57it/s]

Best train acc: 0.3570127487182617 Test acc: 0.4117647111415863
Train accuracy: 0.3729925572872162 +-  0.008551016667978207
Test accuracy: 0.28707809597253797 +-  0.07588537837119338



