In [1]:
import torch
from torch import nn
from torch_geometric.nn import GCNConv, GATConv, ChebConv, SAGEConv
from torch.nn import Linear
import torch.nn.functional as F
from GNNNestedCVEvaluation import GNNNestedCVEvaluation
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import add_self_loops
from hyperopt import hp
import numpy as np
from tqdm.notebook import tqdm

  _torch_pytree._register_pytree_node(


In [2]:
class GCN(nn.Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> GCNConv.

    Dropout is applied to the input features and again to the hidden
    representation. The second convolution's output is returned as-is
    (no final activation).

    Args:
        in_dim: number of input node features.
        hidden_dim: hidden width; cast to ``int`` because hyperopt's
            quantized samplers yield floats (e.g. ``208.0``).
        out_dim: number of output channels.
        dropout: dropout probability used before each convolution.
        normalize: forwarded to both ``GCNConv`` layers.
        add_self_loops: forwarded to both ``GCNConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=.2, normalize=False, add_self_loops=True):
        super().__init__()
        width = int(hidden_dim)
        conv_options = dict(normalize=normalize, add_self_loops=add_self_loops)
        self.conv1 = GCNConv(in_dim, width, **conv_options)
        self.conv2 = GCNConv(width, out_dim, **conv_options)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return per-node outputs of shape ``(num_nodes, out_dim)``."""
        hidden = F.relu(self.conv1(self.dropout(x), edge_index))
        return self.conv2(self.dropout(hidden), edge_index)

In [3]:
class GAT(nn.Module):
    """Two-layer graph attention network: GATConv -> ReLU -> GATConv.

    The hidden layer concatenates its attention heads (``hidden_dim * heads``
    features); the output layer averages them (``concat=False``), so its
    width stays ``out_dim``. ``dropout`` is used both as GATConv attention
    dropout and as feature dropout before each layer.

    Args:
        in_dim: number of input node features.
        hidden_dim: per-head hidden width (cast to ``int``).
        out_dim: number of output channels.
        dropout: dropout probability.
        heads: number of attention heads in both layers (cast to ``int``,
            since hyperopt's quantized samplers yield floats).
        add_self_loops: forwarded to both ``GATConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=.2, heads=1, add_self_loops=True):
        super().__init__()
        width, n_heads = int(hidden_dim), int(heads)
        shared = dict(add_self_loops=add_self_loops, dropout=dropout, heads=n_heads)
        self.conv1 = GATConv(in_dim, width, concat=True, **shared)
        # Input width of the second layer is the concatenation of all heads.
        self.conv2 = GATConv(width * n_heads, out_dim, concat=False, **shared)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return per-node outputs of shape ``(num_nodes, out_dim)``."""
        hidden = F.relu(self.conv1(self.dropout(x), edge_index))
        return self.conv2(self.dropout(hidden), edge_index)

In [4]:
class Cheb(nn.Module):
    """Two-layer Chebyshev spectral GNN: ChebConv -> ReLU -> ChebConv.

    Args:
        in_dim: number of input node features.
        hidden_dim: hidden width (cast to ``int``).
        out_dim: number of output channels.
        dropout: dropout probability used before each convolution.
        K: Chebyshev filter size for both layers; cast to ``int`` because
            hyperopt's quantized samplers yield floats.
        normalization: Laplacian normalization forwarded to ``ChebConv``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=.2, K=2, normalization="sym"):
        super().__init__()
        width, order = int(hidden_dim), int(K)
        self.conv1 = ChebConv(in_dim, width, K=order, normalization=normalization)
        self.conv2 = ChebConv(width, out_dim, K=order, normalization=normalization)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return per-node outputs of shape ``(num_nodes, out_dim)``."""
        hidden = F.relu(self.conv1(self.dropout(x), edge_index))
        return self.conv2(self.dropout(hidden), edge_index)

In [5]:
class SAGE(nn.Module):
    """Two-layer GraphSAGE network: SAGEConv -> ReLU -> SAGEConv.

    Args:
        in_dim: number of input node features.
        hidden_dim: hidden width (cast to ``int``).
        out_dim: number of output channels.
        dropout: dropout probability used before each convolution.
        normalize, project, root_weight: forwarded unchanged to both
            ``SAGEConv`` layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=.2, normalize=False, project=True, root_weight=True):
        super().__init__()
        width = int(hidden_dim)
        conv_options = dict(normalize=normalize, project=project, root_weight=root_weight)
        self.conv1 = SAGEConv(in_dim, width, **conv_options)
        self.conv2 = SAGEConv(width, out_dim, **conv_options)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Return per-node outputs of shape ``(num_nodes, out_dim)``."""
        hidden = F.relu(self.conv1(self.dropout(x), edge_index))
        return self.conv2(self.dropout(hidden), edge_index)

In [6]:
# Cora with the public split; NormalizeFeatures is attached as the dataset's
# transform, so it is applied each time a graph is read from the dataset.
dataset_name = 'Cora'
split = "public"
dataset = Planetoid(root='data/', name=dataset_name, split=split, transform=T.NormalizeFeatures())

In [7]:
class GNNSpace():
    """Base hyperopt search space shared by every GNN architecture.

    ``self.gnn_space`` maps hyperparameter names to hyperopt expressions:
    ``out_dim`` (fixed to the dataset's class count via ``hp.choice``),
    ``lr`` and ``weight_decay`` (log-uniform), ``dropout`` (uniform) and
    ``hidden_dim`` (quantized log-uniform, step 16). Subclasses extend it
    through the ``add_*`` helpers in their ``get_space()``.

    Args:
        dataset: object exposing ``num_classes`` (e.g. a Planetoid dataset).
    """

    def __init__(self, dataset):
        # (low, high) bounds for each sampled hyperparameter.
        self.hidden_dim_limits = (8, 1024)
        self.dropout_limits = (0.0, 0.8)
        self.weight_decay_limits = (1e-5, 1e-2)
        self.lr_limits = (1e-4, 1e-1)
        # One-element list so the fixed value can be exposed through hp.choice.
        self.out_dim = [dataset.num_classes]
        self.gnn_space = None
        self.initialize_space()

    def initialize_space(self):
        """Reset ``self.gnn_space`` to just the shared hyperparameters."""
        self.gnn_space = {'out_dim': hp.choice('out_dim', self.out_dim)}
        self.add_loguniform('lr', self.lr_limits)
        self.add_loguniform('weight_decay', self.weight_decay_limits)
        self.add_uniform('dropout', self.dropout_limits)
        # NOTE(review): qloguniform rounds exp(sample) to a multiple of q, so with
        # q=16 the smallest width actually sampled is 16, not the nominal 8.
        self.add_qloguniform('hidden_dim', self.hidden_dim_limits, 16)

    def add_choice(self, key, items):
        """Register ``key`` as a categorical choice among ``items``."""
        self.gnn_space[key] = hp.choice(key, items)

    def add_uniform(self, key, limits: tuple):
        """Register ``key`` as uniform on ``[limits[0], limits[1]]``."""
        self.gnn_space[key] = hp.uniform(key, limits[0], limits[1])

    def add_loguniform(self, key, limits: tuple):
        """Register ``key`` as log-uniform between ``limits[0]`` and ``limits[1]``."""
        self.gnn_space[key] = hp.loguniform(key, np.log(limits[0]), np.log(limits[1]))

    def add_qloguniform(self, key, limits: tuple, q):
        """Register ``key`` as log-uniform on ``limits``, rounded to multiples of ``q``."""
        self.gnn_space[key] = hp.qloguniform(key, low=np.log(limits[0]), high=np.log(limits[1]), q=q)

class GCNSpace(GNNSpace):
    """Search space for ``GCN``: the shared parameters plus GCNConv flags."""

    def get_space(self):
        """Add GCN-specific options and return the full hyperopt space dict."""
        # Only normalized propagation is searched; self-loops are toggled.
        for key, options in (('normalize', [True]), ('add_self_loops', [True, False])):
            self.add_choice(key, options)
        return self.gnn_space

class GATSpace(GNNSpace):
    """Search space for ``GAT``: the shared parameters plus heads and self-loops."""

    def get_space(self):
        """Add GAT-specific options and return the full hyperopt space dict."""
        # NOTE(review): qloguniform on (1, 8) with q=2 rounds to even values,
        # so heads is drawn from {2, 4, 6, 8}; heads=1 is never sampled.
        # Confirm that is intended.
        self.add_qloguniform('heads', (1, 8), 2)
        self.add_choice('add_self_loops', [True, False])
        return self.gnn_space

class ChebSpace(GNNSpace):
    """Search space for ``Cheb``: the shared parameters plus K and normalization."""

    def get_space(self):
        """Add ChebConv-specific options and return the full hyperopt space dict."""
        # NOTE(review): qloguniform on (1, 4) with q=2 only yields K in {2, 4};
        # K=1 and K=3 are never sampled. Confirm that is intended.
        self.add_qloguniform('K', (1, 4), 2)
        self.add_choice('normalization', ["sym", "rw", None])
        return self.gnn_space

class SAGESpace(GNNSpace):
    """Search space for ``SAGE``: the shared parameters plus SAGEConv flags."""

    def get_space(self):
        """Add SAGEConv boolean flags and return the full hyperopt space dict."""
        for flag in ('normalize', 'project', 'root_weight'):
            self.add_choice(flag, [True, False])
        return self.gnn_space

In [21]:
# First (and only visible) graph of the dataset; indexing applies the dataset's transform.
data = dataset[0]

In [9]:
# Use the first CUDA GPU when present, otherwise fall back to CPU.
# A bare torch.device("cuda:0") is created lazily and only fails later,
# when a model or tensor is moved to it on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
# One search-space builder per architecture; each takes out_dim from Cora's class count.
gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

In [12]:
# Model classes and their hyperopt spaces, aligned by position.
gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [13]:
# Cora results keyed by model class name: outer-fold scores and best params per fold.
score_store, param_store = {}, {}

In [15]:
# Nested CV for every architecture on Cora. The evaluation budget is
# 20 evaluations per entry in the model's search space.
# Zip the two lists (instead of enumerate + indexing) so tqdm knows the total
# and shows real progress rather than "0it [00:00, ?it/s]".
assert len(gnns) == len(gnn_spaces), "every model needs exactly one search space"
for gnn, space in tqdm(zip(gnns, gnn_spaces), total=len(gnns), desc="Cora"):
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn, data, max_evals=len(space) * 20)
    # Arguments presumably (outer folds, inner folds, space); confirm against GNNNestedCVEvaluation.
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    cv_result = gnn_nestedCV_evaluation.nested_transd_cv
    score_store[gnn.__name__] = cv_result.outer_scores
    param_store[gnn.__name__] = cv_result.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [16]:
# Mean +- (population, ddof=0) std of each model's outer-fold scores on Cora.
for model_name, fold_scores in score_store.items():
    print(f"{model_name}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.8792458375295004 +- 0.004806010591741205
GAT: 0.8770338694254557 +- 0.009405334532765292
Cheb: 0.8792450030644735 +- 0.005537561156159736
SAGE: 0.877399722735087 +- 0.008213056437160679


In [17]:
# Raw per-fold outer scores for each Cora model.
score_store

{'GCN': array([0.8748616 , 0.88593578, 0.87694013]),
 'GAT': array([0.86378741, 0.88261354, 0.88470066]),
 'Cheb': array([0.87596899, 0.88704318, 0.87472284]),
 'SAGE': array([0.86821705, 0.88815063, 0.87583148])}

In [18]:
# Best hyperparameters selected in each outer fold, per Cora model.
param_store

{'GCN': [{'add_self_loops': True,
   'dropout': 0.7830821518165071,
   'hidden_dim': 208.0,
   'lr': 0.0026296407267124198,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 9.53314742679954e-05},
  {'add_self_loops': True,
   'dropout': 0.5361831509471188,
   'hidden_dim': 80.0,
   'lr': 0.007534386754083397,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 6.782625665826144e-05},
  {'add_self_loops': True,
   'dropout': 0.691592115826571,
   'hidden_dim': 160.0,
   'lr': 0.003976187118461275,
   'normalize': True,
   'out_dim': 7,
   'weight_decay': 1.7857147267971386e-05}],
 'GAT': [{'add_self_loops': True,
   'dropout': 0.5477675890361065,
   'heads': 4.0,
   'hidden_dim': 32.0,
   'lr': 0.018006349199088747,
   'out_dim': 7,
   'weight_decay': 7.754420752713125e-05},
  {'add_self_loops': True,
   'dropout': 0.6764701952981045,
   'heads': 4.0,
   'hidden_dim': 32.0,
   'lr': 0.006405105129334616,
   'out_dim': 7,
   'weight_decay': 2.782719770741545e-05},
  {'add_

In [22]:
# CiteSeer with the public split and the same feature normalization as Cora.
citeseer_dataset = Planetoid(root='data/', name='CiteSeer', split="public", transform=T.NormalizeFeatures())
citeseer_data = citeseer_dataset[0]

In [23]:
# Rebuild the search-space builders so out_dim matches CiteSeer's class count.
gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(citeseer_dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

In [24]:
# Model classes and their CiteSeer hyperopt spaces, aligned by position.
gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [25]:
citeseer_score_store = {}
citeseer_param_store = {}
# Same nested-CV protocol as for Cora, now on CiteSeer. Zip the lists so tqdm
# knows the total (enumerate alone gives a length-less "0it" bar).
assert len(gnns) == len(gnn_spaces), "every model needs exactly one search space"
for gnn, space in tqdm(zip(gnns, gnn_spaces), total=len(gnns), desc="CiteSeer"):
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn, citeseer_data, max_evals=len(space) * 20)
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    cv_result = gnn_nestedCV_evaluation.nested_transd_cv
    citeseer_score_store[gnn.__name__] = cv_result.outer_scores
    citeseer_param_store[gnn.__name__] = cv_result.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [26]:
# Mean +- (population, ddof=0) std of each model's outer-fold scores on CiteSeer.
for model_name, fold_scores in citeseer_score_store.items():
    print(f"{model_name}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.7673579851786295 +- 0.01766988857757331
GAT: 0.7670574386914571 +- 0.019005028452958674
Cheb: 0.7676585714022318 +- 0.02205463126466205
SAGE: 0.7631499965985616 +- 0.01197013610035922


In [27]:
# PubMed: same loading, normalization and search spaces as Cora/CiteSeer.
pubmed_dataset = Planetoid(root='data/', name='PubMed', split="public", transform=T.NormalizeFeatures())
pubmed_data = pubmed_dataset[0]

gcn_space, gat_space, cheb_space, sage_space = (
    space_cls(pubmed_dataset) for space_cls in (GCNSpace, GATSpace, ChebSpace, SAGESpace)
)

gnns = [GCN, GAT, Cheb, SAGE]
gnn_spaces = [builder.get_space() for builder in (gcn_space, gat_space, cheb_space, sage_space)]

In [28]:
pubmed_score_store = {}
pubmed_param_store = {}
# Same nested-CV protocol on PubMed. Zip the lists so tqdm knows the total
# (enumerate alone gives a length-less "0it" bar).
assert len(gnns) == len(gnn_spaces), "every model needs exactly one search space"
for gnn, space in tqdm(zip(gnns, gnn_spaces), total=len(gnns), desc="PubMed"):
    gnn_nestedCV_evaluation = GNNNestedCVEvaluation(device, gnn, pubmed_data, max_evals=len(space) * 20)
    gnn_nestedCV_evaluation.nested_cross_validate(3, 3, space)
    cv_result = gnn_nestedCV_evaluation.nested_transd_cv
    pubmed_score_store[gnn.__name__] = cv_result.outer_scores
    pubmed_param_store[gnn.__name__] = cv_result.best_params_per_fold

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

0it [00:00, ?it/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register

In [29]:
# Mean +- (population, ddof=0) std of each model's outer-fold scores on PubMed.
for model_name, fold_scores in pubmed_score_store.items():
    print(f"{model_name}: {fold_scores.mean()} +- {fold_scores.std()}")

GCN: 0.8839073578516642 +- 0.0011108321189197857
GAT: 0.87888636191686 +- 0.00223852307064744
Cheb: 0.896434485912323 +- 0.001438275443485223
SAGE: 0.8933915893236796 +- 0.001545616407310158


In [31]:
# Best hyperparameters per outer fold for each PubMed model. Iterate the
# param store itself (not the score store's keys) so a model missing from
# either dict cannot raise KeyError; this matches the CiteSeer/Cora cells.
for model_name, fold_params in pubmed_param_store.items():
    print(f"{model_name}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.33312126156085836, 'hidden_dim': 816.0, 'lr': 0.009705973076591805, 'normalize': True, 'out_dim': 3, 'weight_decay': 1.0052408258228317e-05}, {'add_self_loops': True, 'dropout': 0.3267152968099746, 'hidden_dim': 16.0, 'lr': 0.09609450370667662, 'normalize': True, 'out_dim': 3, 'weight_decay': 1.7660471439858415e-05}, {'add_self_loops': True, 'dropout': 0.21961043860703638, 'hidden_dim': 16.0, 'lr': 0.06716571971330859, 'normalize': True, 'out_dim': 3, 'weight_decay': 1.7968064575771056e-05}]
GAT: [{'add_self_loops': True, 'dropout': 0.22156755726440325, 'heads': 4.0, 'hidden_dim': 32.0, 'lr': 0.03300233026241423, 'out_dim': 3, 'weight_decay': 1.540994040014641e-05}, {'add_self_loops': True, 'dropout': 0.20938281143090182, 'heads': 6.0, 'hidden_dim': 48.0, 'lr': 0.019535213027491612, 'out_dim': 3, 'weight_decay': 1.4236420424359435e-05}, {'add_self_loops': True, 'dropout': 0.28876932376518377, 'heads': 8.0, 'hidden_dim': 32.0, 'lr': 0.03168595

In [36]:
# Best hyperparameters per outer fold for each CiteSeer model.
for model_name, fold_params in citeseer_param_store.items():
    print(f"{model_name}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.6334741357327276, 'hidden_dim': 96.0, 'lr': 0.00352274757165745, 'normalize': True, 'out_dim': 6, 'weight_decay': 0.00015535727839034922}, {'add_self_loops': True, 'dropout': 0.34300080896221785, 'hidden_dim': 704.0, 'lr': 0.00638587676223848, 'normalize': True, 'out_dim': 6, 'weight_decay': 0.00034756657872084533}, {'add_self_loops': True, 'dropout': 0.2859010842385907, 'hidden_dim': 976.0, 'lr': 0.0006431113325014272, 'normalize': True, 'out_dim': 6, 'weight_decay': 0.0001679202460683939}]
GAT: [{'add_self_loops': True, 'dropout': 0.6821493483566271, 'heads': 4.0, 'hidden_dim': 736.0, 'lr': 0.0017631842941992998, 'out_dim': 6, 'weight_decay': 5.445394804550826e-05}, {'add_self_loops': True, 'dropout': 0.32525300024484394, 'heads': 2.0, 'hidden_dim': 1008.0, 'lr': 0.0009607728723961251, 'out_dim': 6, 'weight_decay': 0.0003110786481327837}, {'add_self_loops': True, 'dropout': 0.5513408031025707, 'heads': 6.0, 'hidden_dim': 128.0, 'lr': 0.0015

In [37]:
# Best hyperparameters per outer fold for each Cora model.
for model_name, fold_params in param_store.items():
    print(f"{model_name}: {fold_params}")

GCN: [{'add_self_loops': True, 'dropout': 0.7830821518165071, 'hidden_dim': 208.0, 'lr': 0.0026296407267124198, 'normalize': True, 'out_dim': 7, 'weight_decay': 9.53314742679954e-05}, {'add_self_loops': True, 'dropout': 0.5361831509471188, 'hidden_dim': 80.0, 'lr': 0.007534386754083397, 'normalize': True, 'out_dim': 7, 'weight_decay': 6.782625665826144e-05}, {'add_self_loops': True, 'dropout': 0.691592115826571, 'hidden_dim': 160.0, 'lr': 0.003976187118461275, 'normalize': True, 'out_dim': 7, 'weight_decay': 1.7857147267971386e-05}]
GAT: [{'add_self_loops': True, 'dropout': 0.5477675890361065, 'heads': 4.0, 'hidden_dim': 32.0, 'lr': 0.018006349199088747, 'out_dim': 7, 'weight_decay': 7.754420752713125e-05}, {'add_self_loops': True, 'dropout': 0.6764701952981045, 'heads': 4.0, 'hidden_dim': 32.0, 'lr': 0.006405105129334616, 'out_dim': 7, 'weight_decay': 2.782719770741545e-05}, {'add_self_loops': True, 'dropout': 0.6858013470623977, 'heads': 6.0, 'hidden_dim': 16.0, 'lr': 0.0063609578869