In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Climb out of any directory whose path mentions 'notebooks' so the rest of
# the notebook runs from a parent directory -- presumably the project root,
# so that the `src.*` imports below resolve (TODO confirm).
# NOTE(review): this is a substring test on the full cwd string, so it also
# fires for ancestors such as '/home/notebooks_user/...' and keeps climbing
# until no path component contains 'notebooks'.
while 'notebooks' in os.getcwd():
    os.chdir('..')

import pandas as pd
import numpy as np
from sklearn.model_selection import ParameterGrid

from src.train.structural_omega.graph_sage import StructuralOmegaGraphSageCosSim
from src.train.structural_omega.gcn import StructuralOmegaGCNCosSim
from src.train.structural_omega.mlp import StructuralOmegaMLP
from src.train.structural_omega.gat import StructuralOmegaGATCosSim
from src.train.positional_omega.graph_sage import PositionalOmegaGraphSageCosSim
from src.train.positional_omega.node2vec import PositionalOmegaNode2Vec

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hyper-parameter search space for every model whose stored metrics are
# summarised below. Each entry holds:
#   'model'           -- class exposing `read_metrics(dataset, **params)`
#   'model_name'      -- LaTeX label used in the result tables
#   'parameter_range' -- mapping fed to sklearn's ParameterGrid
#
# The labels are raw strings: '\O' is not a valid escape sequence, so the
# non-raw form emits a DeprecationWarning/SyntaxWarning (and is slated to
# become an error). The runtime values are unchanged.
model_parameters = [
    {
        'model': StructuralOmegaMLP,
        'model_name': r"$\Omega_{s}MLP$",
        'parameter_range': {
            'n_layers': list(range(1, 6))
        }
    },
    {
        'model': StructuralOmegaGraphSageCosSim,
        'model_name': r"$\Omega_{s}GraphSage$",
        'parameter_range': {
            'n_layers_graph_sage': list(range(1, 4))
        }
    },
    {
        'model': StructuralOmegaGCNCosSim,
        'model_name': r"$\Omega_{s}GCN$",
        'parameter_range': {
            'n_layers_gcn': list(range(1, 4))
        }
    },
    {
        'model': StructuralOmegaGATCosSim,
        'model_name': r"$\Omega_{s}GAT$",
        'parameter_range': {
            'n_layers_gat': list(range(1, 4))
        }
    },
    {
        'model': PositionalOmegaGraphSageCosSim,
        'model_name': r"$\Omega_{p}GraphSage$",
        'parameter_range': {
            'n_layers': list(range(1, 4))
        }
    },
    {
        'model': PositionalOmegaNode2Vec,
        'model_name': r"$\Omega_{p}Node2Vec$",
        'parameter_range': {
            # Five log-spaced values from 0.1 to 10: [0.1, 0.32, 1.0, 3.16, 10.0]
            'p': np.logspace(-1, 1, 5).round(2).tolist(),
            'q': np.logspace(-1, 1, 5).round(2).tolist()
        }
    }
]

In [15]:
# Build one summary row per (dataset, model, hyper-parameter combination).
#
# For each combination the stored metrics are read, and for every run the row
# with the highest validation AUC is kept (model selection on validation).
# The train/val/test AUCs of those selected rows are then averaged over runs.
AUC_COLUMNS = ['auc_train', 'auc_val', 'auc_test']

summary_records = []
for dataset in ['ogbn-arxiv', 'cora', 'pubmed', 'citeseer']:
    for model_dict in model_parameters:
        model = model_dict['model']
        model_name = model_dict['model_name']
        for params in ParameterGrid(model_dict['parameter_range']):
            metrics = model.read_metrics(dataset, **params)

            # groupby instead of `query(f'run == {run}')`: works for any run
            # label type (no string interpolation into an expression) and
            # skips NaN run labels instead of producing an empty sub-frame.
            # sort=False keeps the runs in order of first appearance.
            aucs_list = [
                run_df.loc[run_df['auc_val'].idxmax(), AUC_COLUMNS].rename(run)
                for run, run_df in metrics.groupby('run', sort=False)
            ]

            # Rows = runs, columns = the three AUC metrics.
            aucs_df = pd.DataFrame(aucs_list).rename_axis(index='run')
            aucs_mean = aucs_df.mean()
            aucs_std = aucs_df.std()

            summary_records.append({
                'dataset': dataset,
                'model_name': model_name,
                'params': params,
                'mean_auc_train': aucs_mean['auc_train'],
                'std_auc_train': aucs_std['auc_train'],
                'mean_auc_val': aucs_mean['auc_val'],
                'std_auc_val': aucs_std['auc_val'],
                'mean_auc_test': aucs_mean['auc_test'],
                'std_auc_test': aucs_std['auc_test'],
            })

# Building the frame from a list of dicts keeps the metric columns numeric;
# concatenating Series and transposing would turn every column into `object`.
# The default RangeIndex matches the 0..N-1 labels of the previous approach.
summary_df = pd.DataFrame(summary_records)

In [16]:
# Row label in `summary_df` of the configuration with the highest mean
# validation AUC across all datasets and models.
mean_val_auc = summary_df['mean_auc_val'].astype(float)
mean_val_auc.idxmax()

91

In [17]:
def get_max(model_df):
    """Return the metrics of the row with the highest mean validation AUC.

    Parameters
    ----------
    model_df : pandas.DataFrame
        Summary rows to choose from. Must contain a 'mean_auc_val' column
        castable to float; every column other than 'dataset', 'model_name'
        and 'params' must be castable to float as well.

    Returns
    -------
    pandas.Series
        The selected row as floats, without the 'dataset', 'model_name' and
        'params' entries.
    """
    val_auc = model_df['mean_auc_val'].astype(float)
    best_row = model_df.loc[val_auc.idxmax()]
    # errors='ignore': depending on the pandas version / `include_groups`
    # setting, groupby.apply may or may not hand the grouping columns
    # ('dataset', 'model_name') to this function. Dropping only the labels
    # that are present keeps the function working in both cases.
    metrics_only = best_row.drop(
        index=['dataset', 'model_name', 'params'], errors='ignore')
    return metrics_only.astype(float)


# For every (dataset, model) pair keep the hyper-parameter combination that
# `get_max` selects, giving one row of mean/std AUCs per pair.
per_model_groups = summary_df.groupby(
    ['dataset', 'model_name'], group_keys=False)
best_df = per_model_groups.apply(get_max)

# Colour each column independently (axis=0) and show three decimals.
best_styler = best_df.style.background_gradient(axis=0)
best_styler.format(lambda x: f'{x:.3f}')


Unnamed: 0_level_0,Unnamed: 1_level_0,mean_auc_train,std_auc_train,mean_auc_val,std_auc_val,mean_auc_test,std_auc_test
dataset,model_name,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
citeseer,$\Omega_{p}GraphSage$,0.582,0.017,0.627,0.013,0.583,0.013
citeseer,$\Omega_{p}Node2Vec$,0.477,0.031,0.535,0.017,0.483,0.011
citeseer,$\Omega_{s}GAT$,0.924,0.03,0.697,0.01,0.631,0.018
citeseer,$\Omega_{s}GCN$,0.917,0.003,0.708,0.002,0.683,0.002
citeseer,$\Omega_{s}GraphSage$,1.0,0.0,0.786,0.002,0.751,0.005
citeseer,$\Omega_{s}MLP$,0.989,0.006,0.676,0.002,0.663,0.005
cora,$\Omega_{p}GraphSage$,0.713,0.008,0.791,0.005,0.73,0.005
cora,$\Omega_{p}Node2Vec$,0.617,0.011,0.635,0.012,0.596,0.011
cora,$\Omega_{s}GAT$,0.945,0.009,0.853,0.012,0.789,0.015
cora,$\Omega_{s}GCN$,0.995,0.0,0.878,0.002,0.819,0.002


In [19]:
# Emit the per-(dataset, model) table as LaTeX source, three decimals per
# value. `print` is intentional here: the raw text is meant to be copied.
latex_table = best_df.style.format(lambda x: f'{x:.3f}').to_latex()
print(latex_table)

\begin{tabular}{llrrrrrr}
 &  & mean_auc_train & std_auc_train & mean_auc_val & std_auc_val & mean_auc_test & std_auc_test \\
dataset & model_name &  &  &  &  &  &  \\
\multirow[c]{6}{*}{citeseer} & $\Omega_{p}GraphSage$ & 0.582 & 0.017 & 0.627 & 0.013 & 0.583 & 0.013 \\
 & $\Omega_{p}Node2Vec$ & 0.477 & 0.031 & 0.535 & 0.017 & 0.483 & 0.011 \\
 & $\Omega_{s}GAT$ & 0.924 & 0.030 & 0.697 & 0.010 & 0.631 & 0.018 \\
 & $\Omega_{s}GCN$ & 0.917 & 0.003 & 0.708 & 0.002 & 0.683 & 0.002 \\
 & $\Omega_{s}GraphSage$ & 1.000 & 0.000 & 0.786 & 0.002 & 0.751 & 0.005 \\
 & $\Omega_{s}MLP$ & 0.989 & 0.006 & 0.676 & 0.002 & 0.663 & 0.005 \\
\multirow[c]{6}{*}{cora} & $\Omega_{p}GraphSage$ & 0.713 & 0.008 & 0.791 & 0.005 & 0.730 & 0.005 \\
 & $\Omega_{p}Node2Vec$ & 0.617 & 0.011 & 0.635 & 0.012 & 0.596 & 0.011 \\
 & $\Omega_{s}GAT$ & 0.945 & 0.009 & 0.853 & 0.012 & 0.789 & 0.015 \\
 & $\Omega_{s}GCN$ & 0.995 & 0.000 & 0.878 & 0.002 & 0.819 & 0.002 \\
 & $\Omega_{s}GraphSage$ & 1.000 & 0.000 & 0.871 & 0

\texttt{ogbn-arxiv} & \Omega_{p}^{GraphSage} & 0.675 \pm 0.003 & 0.698 \pm 0.004 & 0.690 \pm 0.004 \\
 & \Omega_{p}^{Node2Vec} & 0.622 \pm 0.000 & 0.638 \pm 0.002 & 0.634 \pm 0.001 \\
 & \Omega_{s}^{GAT} & 0.866 \pm 0.002 & \mathbf{0.849 \pm 0.002} & 0.820 \pm 0.003 \\
 & \Omega_{s}^{GCN} & 0.868 \pm 0.002 & 0.847 \pm 0.001 & \mathbf{0.823 \pm 0.002} \\
 & \Omega_{s}^{GraphSage} & 0.890 \pm 0.002 & \mathbf{0.849 \pm 0.002} & 0.817 \pm 0.001 \\
 & \Omega_{s}^{MLP} & 0.788 \pm 0.004 & 0.736 \pm 0.002 & 0.700 \pm 0.002 \\

\cmidrule(lr){1-5}
CiteSeer & \Omega_{p}^{GraphSage} & 0.582 \pm 0.017 & 0.627 \pm 0.013 & 0.583 \pm 0.013 \\
 & \Omega_{p}^{Node2Vec} & 0.477 \pm 0.031 & 0.535 \pm 0.017 & 0.483 \pm 0.011 \\
 & \Omega_{s}^{GAT} & 0.924 \pm 0.030 & 0.697 \pm 0.010 & 0.631 \pm 0.018 \\
 & \Omega_{s}^{GCN} & 0.917 \pm 0.003 & 0.708 \pm 0.002 & 0.683 \pm 0.002 \\
 & \Omega_{s}^{GraphSage} & 1.000 \pm 0.000 & \mathbf{0.786 \pm 0.002} & \mathbf{0.751 \pm 0.005} \\
 & \Omega_{s}^{MLP} & 0.989 \pm 0.006 & 0.676 \pm 0.002 & 0.663 \pm 0.005 \\

\cmidrule(lr){1-5}
Cora & \Omega_{p}^{GraphSage} & 0.713 \pm 0.008 & 0.791 \pm 0.005 & 0.730 \pm 0.005 \\
 & \Omega_{p}^{Node2Vec} & 0.617 \pm 0.011 & 0.635 \pm 0.012 & 0.596 \pm 0.011 \\
 & \Omega_{s}^{GAT} & 0.945 \pm 0.009 & 0.853 \pm 0.012 & 0.789 \pm 0.015 \\
 & \Omega_{s}^{GCN} & 0.995 \pm 0.000 & \mathbf{0.878 \pm 0.002} & 0.819 \pm 0.002 \\
 & \Omega_{s}^{GraphSage} & 1.000 \pm 0.000 & 0.871 \pm 0.003 & \mathbf{0.822 \pm 0.002} \\
 & \Omega_{s}^{MLP} & 0.926 \pm 0.002 & 0.711 \pm 0.001 & 0.685 \pm 0.000 \\

\cmidrule(lr){1-5}
Pubmed & \Omega_{p}^{GraphSage} & 0.543 \pm 0.018 & 0.535 \pm 0.013 & 0.560 \pm 0.025 \\
 & \Omega_{p}^{Node2Vec} & 0.514 \pm 0.012 & 0.547 \pm 0.025 & 0.526 \pm 0.029 \\
 & \Omega_{s}^{GAT} & 0.971 \pm 0.005 & 0.863 \pm 0.013 & 0.852 \pm 0.013 \\
 & \Omega_{s}^{GCN} & 0.983 \pm 0.004 & 0.883 \pm 0.011 & 0.874 \pm 0.009 \\
 & \Omega_{s}^{GraphSage} & 0.999 \pm 0.001 & \mathbf{0.903 \pm 0.010} & \mathbf{0.881 \pm 0.021} \\
 & \Omega_{s}^{MLP} & 0.870 \pm 0.017 & 0.757 \pm 0.004 & 0.754 \pm 0.003 \\