In [1]:
import torch
from torch import nn
from torch_geometric.nn import GCNConv, GATConv, ChebConv, SAGEConv
from torch.nn import Linear
import torch.nn.functional as F
from GNNNestedCVEvaluation import GNNNestedCVEvaluation
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import add_self_loops
from hyperopt import hp
import numpy as np
from tqdm.notebook import tqdm

  _torch_pytree._register_pytree_node(


In [2]:
class GCN(nn.Module):
    """Two-layer graph network built from ``GCNConv`` layers.

    The forward pass is ``conv1 -> ReLU -> dropout -> conv2`` and returns the
    raw output of the second convolution (no final activation is applied here).

    Args:
        in_dim: Input feature size given to the first convolution.
        hidden_dim: Hidden width; cast with ``int`` so float values such as
            ``144.0`` (as reported by the hyper-parameter search) are accepted.
        out_dim: Output feature size of the second convolution.
        dropout: Probability of the ``nn.Dropout`` placed between the layers.
        normalize: Forwarded unchanged to both ``GCNConv`` layers.
        add_self_loops: Forwarded unchanged to both ``GCNConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout = .2, normalize = False, add_self_loops = True):
        super().__init__()
        width = int(hidden_dim)
        # Both layers share the same convolution options.
        conv_options = dict(normalize=normalize, add_self_loops=add_self_loops)
        self.conv1 = GCNConv(in_dim, width, **conv_options)
        self.conv2 = GCNConv(width, out_dim, **conv_options)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the second-layer output for features ``x`` and ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = self.dropout(hidden)
        return self.conv2(hidden, edge_index)

In [3]:
class GAT(nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer is created with ``concat=True``, so the second layer takes
    ``hidden_dim * heads`` input features; the second layer is created with
    ``concat=False``. The forward pass is ``conv1 -> ReLU -> dropout -> conv2``.

    Args:
        in_dim: Input feature size given to the first attention layer.
        hidden_dim: Per-head hidden width; cast with ``int``.
        out_dim: Output feature size of the second attention layer.
        dropout: Used both as the ``dropout`` argument of each ``GATConv`` and
            as the probability of the ``nn.Dropout`` between the layers.
        heads: Number of attention heads for both layers; cast with ``int``.
        add_self_loops: Forwarded unchanged to both ``GATConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout = .2, heads = 1, add_self_loops = True):
        super().__init__()
        width, n_heads = int(hidden_dim), int(heads)
        # Options common to both attention layers; only ``concat`` differs.
        shared = dict(add_self_loops=add_self_loops, dropout=dropout, heads=n_heads)
        self.conv1 = GATConv(in_dim, width, concat=True, **shared)
        self.conv2 = GATConv(width * n_heads, out_dim, concat=False, **shared)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the second-layer output for features ``x`` and ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = self.dropout(hidden)
        return self.conv2(hidden, edge_index)

In [4]:
class Cheb(nn.Module):
    """Two-layer graph network built from ``ChebConv`` layers.

    The forward pass is ``conv1 -> ReLU -> dropout -> conv2`` and returns the
    raw output of the second convolution.

    Args:
        in_dim: Input feature size given to the first convolution.
        hidden_dim: Hidden width; cast with ``int``.
        out_dim: Output feature size of the second convolution.
        dropout: Probability of the ``nn.Dropout`` placed between the layers.
        K: ``K`` argument of both ``ChebConv`` layers; cast with ``int``.
        normalization: Forwarded unchanged to both ``ChebConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout = .2, K = 2, normalization = "sym"):
        super().__init__()
        width = int(hidden_dim)
        # Both layers share the same filter size and normalization setting.
        conv_options = dict(K=int(K), normalization=normalization)
        self.conv1 = ChebConv(in_dim, width, **conv_options)
        self.conv2 = ChebConv(width, out_dim, **conv_options)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the second-layer output for features ``x`` and ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = self.dropout(hidden)
        return self.conv2(hidden, edge_index)

In [5]:
class SAGE(nn.Module):
    """Two-layer graph network built from ``SAGEConv`` layers.

    The forward pass is ``conv1 -> ReLU -> dropout -> conv2`` and returns the
    raw output of the second convolution.

    Args:
        in_dim: Input feature size given to the first convolution.
        hidden_dim: Hidden width; cast with ``int``.
        out_dim: Output feature size of the second convolution.
        dropout: Probability of the ``nn.Dropout`` placed between the layers.
        normalize: Forwarded unchanged to both ``SAGEConv`` layers.
        project: Forwarded unchanged to both ``SAGEConv`` layers.
        root_weight: Forwarded unchanged to both ``SAGEConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout = .2, normalize = False, project = True, root_weight = True):
        super().__init__()
        width = int(hidden_dim)
        # Both layers share the same convolution options.
        conv_options = dict(normalize=normalize, project=project, root_weight=root_weight)
        self.conv1 = SAGEConv(in_dim, width, **conv_options)
        self.conv2 = SAGEConv(width, out_dim, **conv_options)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return the second-layer output for features ``x`` and ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = self.dropout(hidden)
        return self.conv2(hidden, edge_index)

In [6]:
# Load the Cora Planetoid dataset with its "public" split from ./data/.
# The NormalizeFeatures transform is attached when the dataset is constructed.
dataset_name, split = 'Cora', "public"
dataset = Planetoid(root='data/', name=dataset_name, split=split, transform=T.NormalizeFeatures())

In [7]:
class GNNSpace():
    """Base hyperopt search space shared by all GNN architectures.

    On construction the space holds five entries: ``out_dim`` (a single-option
    choice fixed to ``dataset.num_classes``), ``lr``, ``weight_decay``,
    ``dropout`` and ``hidden_dim``. The ``add_*`` helpers let subclasses
    register further, model-specific dimensions in ``self.gnn_space``.
    """

    def __init__(self, dataset):
        """Store the sampling limits and build the shared part of the space.

        Args:
            dataset: Object exposing ``num_classes``; it is the only option
                offered for ``out_dim``.
        """
        # (low, high) bounds for every shared hyper-parameter.
        self.hidden_dim_limits = (8, 1024)
        self.dropout_limits = (0.0, 0.8)
        self.weight_decay_limits = (1e-5, 1e-2)
        self.lr_limits = (1e-4, 1e-1)
        self.out_dim = [dataset.num_classes]
        self.gnn_space = None
        self.initialize_space()

    def initialize_space(self):
        """(Re)create ``self.gnn_space`` with the five shared dimensions."""
        lr_low, lr_high = self.lr_limits[0], self.lr_limits[1]
        wd_low, wd_high = self.weight_decay_limits[0], self.weight_decay_limits[1]
        width_low, width_high = self.hidden_dim_limits[0], self.hidden_dim_limits[1]

        shared_space = {'out_dim': hp.choice('out_dim', self.out_dim)}
        # Learning rate and weight decay are sampled on a log scale.
        shared_space['lr'] = hp.loguniform('lr', np.log(lr_low), np.log(lr_high))
        shared_space['weight_decay'] = hp.loguniform('weight_decay', np.log(wd_low), np.log(wd_high))
        shared_space['dropout'] = hp.uniform('dropout', self.dropout_limits[0], self.dropout_limits[1])
        # Hidden width: quantised log-uniform with q=16.
        shared_space['hidden_dim'] = hp.qloguniform(
            'hidden_dim', low=np.log(width_low), high=np.log(width_high), q=16
        )
        self.gnn_space = shared_space

    def add_choice(self, key, items):
        """Register ``key`` as an ``hp.choice`` over ``items``."""
        self.gnn_space[key] = hp.choice(key, items)

    def add_uniform(self, key, limits: tuple):
        """Register ``key`` as an ``hp.uniform`` between ``limits[0]`` and ``limits[1]``."""
        lower, upper = limits[0], limits[1]
        self.gnn_space[key] = hp.uniform(key, lower, upper)

    def add_loguniform(self, key, limits: tuple):
        """Register ``key`` as an ``hp.loguniform``; ``limits`` are given in linear scale."""
        log_lower, log_upper = np.log(limits[0]), np.log(limits[1])
        self.gnn_space[key] = hp.loguniform(key, log_lower, log_upper)

    def add_qloguniform(self, key, limits, q):
        """Register ``key`` as an ``hp.qloguniform`` with step ``q``; ``limits`` are in linear scale."""
        log_lower, log_upper = np.log(limits[0]), np.log(limits[1])
        self.gnn_space[key] = hp.qloguniform(key, low=log_lower, high=log_upper, q=q)

class GCNSpace(GNNSpace):
    """Search space for the ``GCN`` model.

    Construction is inherited unchanged from ``GNNSpace`` (``GCNSpace(dataset)``).
    """

    def get_space(self):
        """Add the GCN-specific choices and return the full space dict.

        ``normalize`` offers the single option ``True``; ``add_self_loops``
        offers ``True`` and ``False``.
        """
        for label, options in (('normalize', [True]), ('add_self_loops', [True, False])):
            self.add_choice(label, options)
        return self.gnn_space

class GATSpace(GNNSpace):
    """Search space for the ``GAT`` model.

    Construction is inherited unchanged from ``GNNSpace`` (``GATSpace(dataset)``).
    """

    def get_space(self):
        """Add the GAT-specific dimensions and return the full space dict.

        ``heads`` is registered through ``add_qloguniform`` with limits
        ``(1, 8)`` and ``q=2``; ``add_self_loops`` is a True/False choice.
        """
        head_limits, head_step = (1, 8), 2
        self.add_qloguniform('heads', head_limits, head_step)
        self.add_choice('add_self_loops', [True, False])
        return self.gnn_space

class ChebSpace(GNNSpace):
    """Search space for the ``Cheb`` model.

    Construction is inherited unchanged from ``GNNSpace`` (``ChebSpace(dataset)``).
    """

    def get_space(self):
        """Add the Cheb-specific dimensions and return the full space dict.

        ``K`` is registered through ``add_qloguniform`` with limits ``(1, 4)``
        and ``q=2``; ``normalization`` is a choice among ``"sym"``, ``"rw"``
        and ``None``.
        """
        order_limits, order_step = (1, 4), 2
        self.add_qloguniform('K', order_limits, order_step)
        self.add_choice('normalization', ["sym", "rw", None])
        return self.gnn_space

class SAGESpace(GNNSpace):
    """Search space for the ``SAGE`` model.

    Construction is inherited unchanged from ``GNNSpace`` (``SAGESpace(dataset)``).
    """

    def get_space(self):
        """Add the three boolean SAGE switches and return the full space dict."""
        # Each flag is searched over both True and False, in this order.
        for flag in ('normalize', 'project', 'root_weight'):
            self.add_choice(flag, [True, False])
        return self.gnn_space

In [8]:
# First item of the Cora dataset; this is the object passed to GNNNestedCVEvaluation below.
data = dataset[0]

In [9]:
## TODO: Evaluate performance on CPU, for fairness to sklearn?

In [10]:
# Device handed to every GNNNestedCVEvaluation run. The experiments target the
# third GPU ("cuda:2"); when the host exposes fewer than three CUDA devices
# (device_count() is 0 without CUDA) fall back to the CPU so the notebook
# still runs end-to-end instead of failing once the device is first used.
device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cpu")

In [11]:
# One search-space builder per architecture, each constructed from the Cora dataset
# (the builders read dataset.num_classes for the out_dim choice).
gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

In [12]:
# Model classes and their hyperopt spaces, kept in the same order so that
# position k in one list matches position k in the other.
gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [13]:
# Result containers for the Cora experiments, keyed by model class name:
# outer-fold scores and best hyper-parameters per fold, respectively.
score_store, param_store = dict(), dict()

In [14]:
# Nested cross-validation on Cora for every architecture; stores the outer-fold
# scores and the best hyper-parameters per fold under the model's class name.
# tqdm wraps the list itself (not the length-less enumerate object) so the
# progress bar knows the total and shows "k/4" instead of a bare "0it" counter.
for i, space in enumerate(tqdm(gnn_spaces, desc="Cora")):
    gnn_cls = gnns[i]
    # Search budget scales with the size of the space: 20 evaluations per entry.
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn_cls, data, max_evals=len(space) * 20)
    # NOTE(review): the two 3s are presumably the outer/inner fold counts —
    # confirm against GNNNestedCVEvaluation.nested_cross_validate.
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    score_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.outer_scores
    param_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/07/11 15:50:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [15]:
# Mean and standard deviation of the Cora outer-fold scores for each model.
for key, fold_scores in score_store.items():
    print(f"{key}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.8722293178240458 +- 0.005838102617249471
GAT: 0.8777684569358826 +- 0.005007765279293293
Cheb: 0.8733387589454651 +- 0.0013324528694526733
SAGE: 0.8777725497881571 +- 0.011455236431865783


In [16]:
# Display the raw outer-fold scores stored per model for Cora.
score_store

{'GCN': array([0.86710966, 0.88039869, 0.86917961]),
 'GAT': array([0.87375414, 0.88482839, 0.87472284]),
 'Cheb': array([0.87153929, 0.87375414, 0.87472284]),
 'SAGE': array([0.86157256, 0.88593578, 0.8858093 ])}

In [17]:
# Display the best hyper-parameters stored per fold and per model for Cora.
param_store

{'GCN': [{'add_self_loops': True,
   'dropout': 0.47424160371604307,
   'hidden_dim': 144.0,
   'lr': 0.021272837679656945,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 2.1772420021592394e-05},
  {'add_self_loops': True,
   'dropout': 0.537811056200838,
   'hidden_dim': 80.0,
   'lr': 0.058222676194743604,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 3.2384324365308864e-05},
  {'add_self_loops': True,
   'dropout': 0.2637331666881179,
   'hidden_dim': 96.0,
   'lr': 0.013496083012736382,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 1.006404989082683e-05}],
 'GAT': [{'add_self_loops': True,
   'dropout': 0.6895628641088936,
   'heads': 4.0,
   'hidden_dim': 80.0,
   'lr': 0.01795645123205303,
   'out_dim': 7,
   'weight_decay': 1.3457235596951733e-05},
  {'add_self_loops': True,
   'dropout': 0.39026223827683415,
   'heads': 4.0,
   'hidden_dim': 16.0,
   'lr': 0.04339016312002795,
   'out_dim': 7,
   'weight_decay': 6.484918193992611e-05},
  {'add

In [18]:
# Load the CiteSeer Planetoid dataset ("public" split) with NormalizeFeatures
# attached at construction, then take its first item for the experiments.
citeseer_dataset = Planetoid(root='data/', name='CiteSeer', split="public", transform=T.NormalizeFeatures())
citeseer_data = citeseer_dataset[0]

In [19]:
# Rebuild the four search-space builders for CiteSeer; this rebinds the same
# names that were used for the Cora builders.
gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(citeseer_dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

In [20]:
# Model classes and their CiteSeer hyperopt spaces, kept in matching order.
gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [21]:
# Result containers for CiteSeer, keyed by model class name.
citeseer_score_store = {}
citeseer_param_store = {}
# Nested cross-validation on CiteSeer for every architecture. tqdm wraps the
# list itself (not the length-less enumerate object) so the progress bar knows
# the total instead of showing a bare "0it" counter.
for i, space in enumerate(tqdm(gnn_spaces, desc="CiteSeer")):
    gnn_cls = gnns[i]
    # Search budget scales with the size of the space: 20 evaluations per entry.
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn_cls, citeseer_data, max_evals=len(space) * 20)
    # NOTE(review): the two 3s are presumably the outer/inner fold counts —
    # confirm against GNNNestedCVEvaluation.nested_cross_validate.
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    citeseer_score_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.outer_scores
    citeseer_param_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [22]:
# Mean and standard deviation of the CiteSeer outer-fold scores for each model.
for key, fold_scores in citeseer_score_store.items():
    print(f"{key}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.7586414217948914 +- 0.02141872646009377
GAT: 0.7589419881502787 +- 0.023792611182515678
Cheb: 0.7613465785980225 +- 0.02056661917717496
SAGE: 0.7526300152142843 +- 0.026559326758019863


In [23]:
# PubMed setup in one cell: dataset ("public" split, NormalizeFeatures attached
# at construction), its first item, the four search-space builders and the
# matching lists of model classes and hyperopt spaces.
pubmed_dataset = Planetoid(root='data/', name='PubMed', split="public", transform=T.NormalizeFeatures())
pubmed_data = pubmed_dataset[0]

gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(pubmed_dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [24]:
# Result containers for PubMed, keyed by model class name.
pubmed_score_store = {}
pubmed_param_store = {}
# Nested cross-validation on PubMed for every architecture. tqdm wraps the
# list itself (not the length-less enumerate object) so the progress bar knows
# the total instead of showing a bare "0it" counter.
for i, space in enumerate(tqdm(gnn_spaces, desc="PubMed")):
    gnn_cls = gnns[i]
    # Search budget scales with the size of the space: 20 evaluations per entry.
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn_cls, pubmed_data, max_evals=len(space) * 20)
    # NOTE(review): the two 3s are presumably the outer/inner fold counts —
    # confirm against GNNNestedCVEvaluation.nested_cross_validate.
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    pubmed_score_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.outer_scores
    pubmed_param_store[gnn_cls.__name__] = gnn_nestedCV_evaluation.nested_transd_cv.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [25]:
# Mean and standard deviation of the PubMed outer-fold scores for each model.
for key, fold_scores in pubmed_score_store.items():
    print(f"{key}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.8725467125574747 +- 0.00289239374290525
GAT: 0.8681343992551168 +- 0.0057158039541598095
Cheb: 0.8941524227460226 +- 0.002923933153628512
SAGE: 0.884212056795756 +- 0.0066458837470739366


In [26]:
# Best hyper-parameters per fold for each model on PubMed. Iterate the
# parameter store itself (as the CiteSeer and Cora cells do) rather than the
# keys of the score store, so the printout cannot hit a KeyError if the two
# dicts ever hold different models.
for key, fold_params in pubmed_param_store.items():
    print(f"{key}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.4556927805939384, 'hidden_dim': 896.0, 'lr': 0.06233173239342549, 'normalize': True, 'out_dim': 3, 'weight_decay': 1.6283791718207615e-05}, {'add_self_loops': True, 'dropout': 0.6527405492907251, 'hidden_dim': 48.0, 'lr': 0.09593362782798921, 'normalize': True, 'out_dim': 3, 'weight_decay': 4.85583129410908e-05}, {'add_self_loops': True, 'dropout': 0.10280714642195694, 'hidden_dim': 528.0, 'lr': 0.05518335977213315, 'normalize': True, 'out_dim': 3, 'weight_decay': 1.2298512419921563e-05}]
GAT: [{'add_self_loops': True, 'dropout': 0.5057827782559263, 'heads': 8.0, 'hidden_dim': 16.0, 'lr': 0.05270726194179401, 'out_dim': 3, 'weight_decay': 2.1907183200679608e-05}, {'add_self_loops': True, 'dropout': 0.4905219112044247, 'heads': 8.0, 'hidden_dim': 32.0, 'lr': 0.060790365936704824, 'out_dim': 3, 'weight_decay': 1.8242465880828062e-05}, {'add_self_loops': True, 'dropout': 0.2257741434998842, 'heads': 2.0, 'hidden_dim': 144.0, 'lr': 0.030085450863

In [27]:
# Best hyper-parameters per fold for each model on CiteSeer.
for key, fold_params in citeseer_param_store.items():
    print(f"{key}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.7851715415606595, 'hidden_dim': 96.0, 'lr': 0.08024489398851882, 'normalize': True, 'out_dim': 6, 'weight_decay': 5.69183240961598e-05}, {'add_self_loops': True, 'dropout': 0.08797559362505392, 'hidden_dim': 208.0, 'lr': 0.014539829638287207, 'normalize': True, 'out_dim': 6, 'weight_decay': 0.0003462086942167347}, {'add_self_loops': True, 'dropout': 0.38806165574913687, 'hidden_dim': 144.0, 'lr': 0.028691345829894827, 'normalize': True, 'out_dim': 6, 'weight_decay': 7.173951140296107e-05}]
GAT: [{'add_self_loops': True, 'dropout': 0.4333100747873734, 'heads': 4.0, 'hidden_dim': 192.0, 'lr': 0.01226185951601327, 'out_dim': 6, 'weight_decay': 1.6104813864862927e-05}, {'add_self_loops': True, 'dropout': 0.4914647997981434, 'heads': 4.0, 'hidden_dim': 720.0, 'lr': 0.005733510025359498, 'out_dim': 6, 'weight_decay': 0.0001643711258758383}, {'add_self_loops': True, 'dropout': 0.5119548332427251, 'heads': 2.0, 'hidden_dim': 240.0, 'lr': 0.0280903211

In [28]:
# Best hyper-parameters per fold for each model on Cora.
for key, fold_params in param_store.items():
    print(f"{key}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.47424160371604307, 'hidden_dim': 144.0, 'lr': 0.021272837679656945, 'normalize': True, 'out_dim': 7, 'weight_decay': 2.1772420021592394e-05}, {'add_self_loops': True, 'dropout': 0.537811056200838, 'hidden_dim': 80.0, 'lr': 0.058222676194743604, 'normalize': True, 'out_dim': 7, 'weight_decay': 3.2384324365308864e-05}, {'add_self_loops': True, 'dropout': 0.2637331666881179, 'hidden_dim': 96.0, 'lr': 0.013496083012736382, 'normalize': True, 'out_dim': 7, 'weight_decay': 1.006404989082683e-05}]
GAT: [{'add_self_loops': True, 'dropout': 0.6895628641088936, 'heads': 4.0, 'hidden_dim': 80.0, 'lr': 0.01795645123205303, 'out_dim': 7, 'weight_decay': 1.3457235596951733e-05}, {'add_self_loops': True, 'dropout': 0.39026223827683415, 'heads': 4.0, 'hidden_dim': 16.0, 'lr': 0.04339016312002795, 'out_dim': 7, 'weight_decay': 6.484918193992611e-05}, {'add_self_loops': True, 'dropout': 0.5311029935837578, 'heads': 4.0, 'hidden_dim': 736.0, 'lr': 0.00551493591