In [1]:
from itertools import product

import argparse
from datasets import get_dataset
from train_eval import cross_validation_with_val_set

from gcn import GCN, GCNWithJK
from graph_sage import GraphSAGE, GraphSAGEWithJK
from gin import GIN0, GIN0WithJK, GIN, GINWithJK
from graclus import Graclus
from top_k import TopK
from sag_pool import SAGPool
from diff_pool import DiffPool
from edge_pool import EdgePool
from global_attention import GlobalAttentionNet
from set2set import Set2SetNet
from sort_pool import SortPool
from asap import ASAP
import torch

In [2]:
import easydict 
# Training hyper-parameters shared by every run of the grid search below.
# EasyDict exposes the keys as attributes (e.g. ``args.epochs``), mirroring
# what an argparse.Namespace would provide.
args = easydict.EasyDict(dict(
    batch_size=128,
    epochs=160,
    lr=0.01,
    lr_decay_factor=0.5,
    lr_decay_step_size=50,
))


In [3]:
# Hyper-parameter grid searched for every (dataset, architecture) pair.
layers = [3, 4, 5]
hiddens = [32, 64, 128]
datasets = ['PROTEINS']

# Architectures to benchmark. Every entry (including the commented-out
# optional ones) ends with a comma, so an optional net can be re-enabled by
# uncommenting a single line without producing a syntax error.
nets = [
    GCN,
    GraphSAGE,
    GIN,
    SAGPool,
    SortPool,
    TopK,
    # Graclus,
    # Set2SetNet,
]

# Seed torch's RNG (on all devices) so that torch-driven randomness is
# reproducible across a Restart & Run All of the notebook.
SEED = 12345
torch.manual_seed(SEED)

In [4]:
# for num_layers, hidden in product(layers, hiddens):
#         print(f'--\n{dataset_name} - {Net.__name__} - {num_layers} - {hidden}')

def logger(info, log_every=80):
    """Print a one-line progress summary for a cross-validation epoch.

    Passed as the ``logger`` callback of ``cross_validation_with_val_set``.

    Args:
        info: dict with at least the keys ``'fold'`` (0-based fold index),
            ``'epoch'``, ``'val_loss'`` and ``'test_acc'``.
        log_every: print only when ``epoch`` is a multiple of this value
            (default 80; with 160 epochs this logs epochs 80 and 160).
            A value <= 0 logs every epoch instead of dividing by zero.
    """
    epoch = info['epoch']
    if log_every > 0 and epoch % log_every != 0:
        return
    fold = info['fold'] + 1  # display folds 1-based
    val_loss, test_acc = info['val_loss'], info['test_acc']
    print(f'{fold:02d}/{epoch:03d}: Val Loss: {val_loss:.4f}, '
          f'Test Accuracy: {test_acc:.3f}')


# Grid search: for every (dataset, architecture) pair, evaluate each
# (num_layers, hidden) combination with cross_validation_with_val_set and
# keep the combination with the lowest returned validation loss.
results = []
for dataset_name, Net in product(datasets, nets):
    # The dataset depends only on its name and on whether the architecture
    # needs the dense format (DiffPool), so load it once per (dataset, Net)
    # rather than once per grid point.
    # NOTE(review): assumes cross_validation_with_val_set does not mutate
    # the dataset it is given — confirm in train_eval.
    dataset = get_dataset(dataset_name, sparse=Net != DiffPool)
    best_result = (float('inf'), 0, 0)  # (loss, acc, std)
    best_config = None  # (num_layers, hidden) that produced best_result
    for num_layers, hidden in product(layers, hiddens):
        print(f'--\n{dataset_name} - {Net.__name__} - Num layers:{num_layers} -Hidden layer:{hidden}')
        model = Net(dataset, num_layers, hidden)
        loss, acc, std = cross_validation_with_val_set(
            dataset,
            model,
            folds=5,
            epochs=args.epochs,
            batch_size=args.batch_size,
            lr=args.lr,
            lr_decay_factor=args.lr_decay_factor,
            lr_decay_step_size=args.lr_decay_step_size,
            weight_decay=0,
            logger=logger,
        )
        # A NaN loss compares False here, so it can never be selected.
        if loss < best_result[0]:
            best_result = (loss, acc, std)
            best_config = (num_layers, hidden)
        # Release cached GPU memory between runs; calling it after every grid
        # point also covers the end of the inner loop.
        torch.cuda.empty_cache()
    if best_config is None:
        # Empty grid or no finite validation loss: nothing meaningful to report.
        desc = 'no valid result'
        label = Net.__name__
    else:
        desc = f'{best_result[1]:.3f} ± {best_result[2]:.3f}'
        best_layers, best_hidden = best_config
        label = f'{Net.__name__} (layers={best_layers}, hidden={best_hidden})'
    print(f'Best result - {desc}')
    # Name the architecture explicitly instead of interpolating `model`,
    # which is only the *last* configuration trained (and is unbound when
    # the grid is empty).
    results += [f'{dataset_name} - {label}: {desc}']
results = '\n'.join(results)
print(f'--\n{results}')


--
PROTEINS - GCN - Num layers:3 -Hidden layer:32
01/080: Val Loss: 0.5330, Test Accuracy: 0.691
01/160: Val Loss: 0.5248, Test Accuracy: 0.722
02/080: Val Loss: 0.6174, Test Accuracy: 0.735
02/160: Val Loss: 0.6181, Test Accuracy: 0.735
03/080: Val Loss: 0.5639, Test Accuracy: 0.722
03/160: Val Loss: 0.5551, Test Accuracy: 0.722
04/080: Val Loss: 0.5529, Test Accuracy: 0.698
04/160: Val Loss: 0.5606, Test Accuracy: 0.703
05/080: Val Loss: 0.5939, Test Accuracy: 0.757
05/160: Val Loss: 0.5909, Test Accuracy: 0.761
Val Loss: 0.5609, Test Accuracy: 0.726 ± 0.035, Duration: 18.248
--
PROTEINS - GCN - Num layers:3 -Hidden layer:64
01/080: Val Loss: 0.5298, Test Accuracy: 0.682
01/160: Val Loss: 0.5340, Test Accuracy: 0.695
02/080: Val Loss: 0.6172, Test Accuracy: 0.735
02/160: Val Loss: 0.6499, Test Accuracy: 0.731
03/080: Val Loss: 0.5544, Test Accuracy: 0.717
03/160: Val Loss: 0.5509, Test Accuracy: 0.726
04/080: Val Loss: 0.5651, Test Accuracy: 0.707
04/160: Val Loss: 0.5642, Test Accur