In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [62]:
import os

# Climb one directory at a time until the working directory's path no longer
# contains the substring 'notebooks' (presumably so the `src.*` imports that
# follow resolve from the repository root -- confirm against the repo layout).
cwd = os.getcwd()
while 'notebooks' in cwd:
    os.chdir('..')
    cwd = os.getcwd()

import pandas as pd
import numpy as np
from sklearn.model_selection import ParameterGrid

from src.train.structural_omega.graph_sage import StructuralOmegaGraphSageCosSim
from src.train.structural_omega.gcn import StructuralOmegaGCNCosSim
from src.train.structural_omega.mlp import StructuralOmegaMLP
from src.train.structural_omega.gat import StructuralOmegaGATCosSim
from src.train.positional_omega.graph_sage import PositionalOmegaGraphSageCosSim
from src.train.positional_omega.node2vec import PositionalOmegaNode2Vec

In [120]:
# Hyper-parameter sweep definition, one entry per model family:
#   'model'           -- class whose `read_metrics(dataset, **params)` is called
#                        by the summary cell,
#   'model_name'      -- LaTeX label shown in the result tables,
#   'parameter_range' -- grid expanded with sklearn's ParameterGrid; each key is
#                        forwarded as a keyword argument to `read_metrics`.
#
# The labels are raw strings: "\O" is not a valid escape sequence, so a plain
# string literal triggers a DeprecationWarning/SyntaxWarning on recent Python
# versions. The resulting runtime value is unchanged.
model_parameters = [
    {
        'model': StructuralOmegaMLP,
        'model_name': r"$\Omega_{s}MLP$",
        'parameter_range': {
            'n_layers': list(range(1, 6))
        }
    },
    {
        'model': StructuralOmegaGraphSageCosSim,
        'model_name': r"$\Omega_{s}GraphSage$",
        'parameter_range': {
            'n_layers_graph_sage': list(range(1, 4))
        }
    },
    {
        'model': StructuralOmegaGCNCosSim,
        'model_name': r"$\Omega_{s}GCN$",
        'parameter_range': {
            'n_layers_gcn': list(range(1, 4))
        }
    },
    {
        'model': StructuralOmegaGATCosSim,
        'model_name': r"$\Omega_{s}GAT$",
        'parameter_range': {
            'n_layers_gat': list(range(1, 4))
        }
    },
    {
        'model': PositionalOmegaGraphSageCosSim,
        'model_name': r"$\Omega_{p}GraphSage$",
        'parameter_range': {
            'n_layers': list(range(1, 4))
        }
    },
    {
        'model': PositionalOmegaNode2Vec,
        'model_name': r"$\Omega_{p}Node2Vec$",
        'parameter_range': {
            # Five log-spaced values per axis: 0.1, 0.32, 1.0, 3.16, 10.0.
            'p': np.logspace(-1, 1, 5).round(2).tolist(),
            'q': np.logspace(-1, 1, 5).round(2).tolist()
        }
    }
]

In [127]:
# Build `summary_df`: one row per (dataset, model, hyper-parameter setting) with
# the mean and standard deviation, across runs, of the train/val/test AUC taken
# at each run's best-validation row.
auc_columns = ['auc_train', 'auc_val', 'auc_test']

summary_records = []
for dataset in ['ogbn-arxiv', 'cora', 'pubmed']:
    for model_dict in model_parameters:
        model = model_dict['model']
        model_name = model_dict['model_name']
        for params in ParameterGrid(model_dict['parameter_range']):
            metrics = model.read_metrics(dataset, **params)

            # Model selection per run: keep the row with the highest validation
            # AUC. Grouping (instead of a `query` built from an f-string) works
            # for any run label type; sort=False keeps first-appearance order.
            best_rows = {
                run: run_df.loc[run_df['auc_val'].idxmax(), auc_columns]
                for run, run_df in metrics.groupby('run', sort=False)
            }
            if not best_rows:
                raise ValueError(
                    f'no metrics rows for model {model_name!r} on dataset '
                    f'{dataset!r} with params {params}'
                )

            # Rows are runs, columns are the three AUC metrics.
            aucs_df = pd.DataFrame(best_rows).T.rename_axis(index='run')
            aucs_mean = aucs_df.mean()
            # Sample standard deviation (pandas default ddof=1): NaN for a single run.
            aucs_std = aucs_df.std()

            summary_records.append({
                'dataset': dataset,
                'model_name': model_name,
                'params': params,
                'mean_auc_train': aucs_mean['auc_train'],
                'std_auc_train': aucs_std['auc_train'],
                'mean_auc_val': aucs_mean['auc_val'],
                'std_auc_val': aucs_std['auc_val'],
                'mean_auc_test': aucs_mean['auc_test'],
                'std_auc_test': aucs_std['auc_test'],
            })

# A single DataFrame built from records keeps the metric columns numeric;
# concatenating mixed-type Series and transposing made every column `object`.
summary_df = pd.DataFrame(summary_records)

In [128]:
# Row label of the configuration with the highest mean validation AUC across
# all datasets and models.
mean_val_auc = summary_df['mean_auc_val'].astype(float)
mean_val_auc.idxmax()

91

In [129]:
def get_max(model_df):
    """Return the metric values of the row with the highest mean validation AUC.

    Parameters
    ----------
    model_df : pandas.DataFrame
        Rows to choose from; must contain a ``mean_auc_val`` column. The
        identifying columns ``dataset``, ``model_name`` and ``params`` may or
        may not be present.

    Returns
    -------
    pandas.Series
        The selected row cast to float, without the identifying columns.
    """
    best_idx = model_df['mean_auc_val'].astype(float).idxmax()
    # errors='ignore': depending on the pandas version, `groupby(...).apply`
    # may not hand the grouping columns to this function, in which case an
    # unconditional drop would raise KeyError.
    return (
        model_df.loc[best_idx]
        .drop(index=['dataset', 'model_name', 'params'], errors='ignore')
        .astype(float)
    )


# For every (dataset, model) pair keep the hyper-parameter setting selected by
# `get_max`, then render the table with a per-column colour gradient and
# three-decimal formatting.
best_per_model = summary_df.groupby(
    ['dataset', 'model_name'], group_keys=False
).apply(get_max)

best_per_model.style.background_gradient(axis=0).format(lambda x: f'{x:.3f}')


Unnamed: 0_level_0,Unnamed: 1_level_0,mean_auc_train,std_auc_train,mean_auc_val,std_auc_val,mean_auc_test,std_auc_test
dataset,model_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
cora,$\Omega_{p}GraphSage$,0.713,0.008,0.791,0.005,0.73,0.005
cora,$\Omega_{p}Node2Vec$,0.617,0.011,0.635,0.012,0.596,0.011
cora,$\Omega_{s}GAT$,0.945,0.009,0.853,0.012,0.789,0.015
cora,$\Omega_{s}GCN$,0.995,0.0,0.878,0.002,0.819,0.002
cora,$\Omega_{s}GraphSage$,1.0,0.0,0.871,0.003,0.822,0.002
cora,$\Omega_{s}MLP$,0.926,0.002,0.711,0.001,0.685,0.0
ogbn-arxiv,$\Omega_{p}GraphSage$,0.675,0.003,0.698,0.004,0.69,0.004
ogbn-arxiv,$\Omega_{p}Node2Vec$,0.622,0.0,0.638,0.002,0.634,0.001
ogbn-arxiv,$\Omega_{s}GAT$,0.866,0.002,0.849,0.002,0.82,0.003
ogbn-arxiv,$\Omega_{s}GCN$,0.868,0.002,0.847,0.001,0.823,0.002
