In [2]:
import os
import sys
directory_path = os.path.abspath(os.path.join('..'))
if directory_path not in sys.path:
    sys.path.append(directory_path)
  
from trainer import Trainer
from dataset import get_ogb_data
from torch_geometric import seed_everything
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.loader import NeighborLoader
from ogb.nodeproppred import Evaluator
from gnn_model import GNN
from predict import evaluate_test
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DEEP_DEPTH = 16
SHALLOW_DEPTH = 2 
depths_list = [DEEP_DEPTH, SHALLOW_DEPTH]

In [4]:
d_name = 'arxiv'
data, split_idx, num_classes = get_ogb_data('arxiv')
evaluator = Evaluator(name=f'ogbn-{d_name}')

def load_model(model_type, depth, data, num_classes):
    
    max_depth = 2 if depth == 'shallow' else 16
    
    if model_type=='SAGE':
        model = GNN(SAGEConv, data.x.shape[1], data.x.shape[1], num_classes, n_layers=max_depth)
    elif model_type=='GCN':
        model = GNN(GCNConv, data.x.shape[1], data.x.shape[1], num_classes, n_layers=max_depth)
    elif model_type=='GAT':
        model = GNN(GATConv, data.x.shape[1], data.x.shape[1], num_classes, n_layers=max_depth)

    model_dir = f'{depth}_{model_type}'
    best_pth_fn = os.path.join(model_dir,'model_best.pth')
    if os.path.exists(best_pth_fn):
        checkpoint = torch.load(best_pth_fn)
        model.load_state_dict(checkpoint['network_state_dict'], strict=False)
        print(f'{depth}_{model_type} loaded')
    return model

def load_test_loader(data, test_idx):
    test_loader = NeighborLoader(data, input_nodes=test_idx, num_neighbors=[-1],
                                    batch_size=128, shuffle=True)
    print('Test loader loaded')
    return test_loader

Number of nodes in the graph: 169343
Number of edges in the graph: 1166243
Number of training nodes: 90941
Number of validation nodes: 29799
Number of test nodes: 48603
Node feature matrix with shape: torch.Size([169343, 128])
Graph connectivity in COO format with shape: torch.Size([2, 1166243])
Target to train against : torch.Size([169343, 1])
Node feature length 128
number of target categories: 40


In [5]:
model_types = ['GAT', 'GCN']
depth_types = ['both']

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_idx = split_idx['test']
num_permutations = 1000

In [32]:
for model_type in model_types:
    for depth_type in depth_types:
        model = load_model(model_type, depth_type, data, num_classes)
        model.to(device)
        test_loader = load_test_loader(data, test_idx)

        nods_results_depths_dict = {}
        ids_col = {'id': test_idx}
        id_depths_df = pd.DataFrame(data=ids_col)
        id_results_df = pd.DataFrame(data=ids_col)

        for perm in tqdm(range(num_permutations)):
            test_acc, inference_depths, ids, correctness = evaluate_test(model, test_loader, evaluator, data, depths_list, device)
            
            temp_df = pd.DataFrame(data={'id': ids, f'depth_{perm}': inference_depths})
            temp_df.sort_values(by=['id'], inplace=True)
            id_depths_df = id_depths_df.merge(temp_df, on='id')
            
            temp_df = pd.DataFrame(data={'id': ids, f'correctness_{perm}': correctness})
            temp_df.sort_values(by=['id'], inplace=True)
            id_results_df = id_results_df.merge(temp_df, on = 'id')
            
        sorted_ids = ids_col['id'].tolist()
        all_correctnesses = list(id_results_df.drop(['id'], axis=1).to_numpy().flatten())
        all_depths = list(id_depths_df.drop(['id'], axis=1).to_numpy().flatten())
        all_sorted_ids = list(np.repeat(sorted_ids, num_permutations))
        all_results_df = pd.DataFrame(data={'id': all_sorted_ids, 'correctness': all_correctnesses, 'depth': all_depths})
        
        id_depths_df.to_csv(f'id_depths_{depth_type}_{model_type}.csv', index=False)
        id_results_df.to_csv(f'id_results_{depth_type}_{model_type}.csv', index=False)
        all_results_df.to_csv(f'all_results_{depth_type}_{model_type}.csv', index=False)

shallow_GAT loaded
Test loader loaded


100%|██████████| 1000/1000 [22:06<00:00,  1.33s/it]


deep_GAT loaded
Test loader loaded


100%|██████████| 1000/1000 [58:19<00:00,  3.50s/it]


both_GAT loaded
Test loader loaded


100%|██████████| 1000/1000 [58:26<00:00,  3.51s/it]


shallow_GCN loaded
Test loader loaded


100%|██████████| 1000/1000 [16:42<00:00,  1.00s/it]


deep_GCN loaded
Test loader loaded


100%|██████████| 1000/1000 [32:27<00:00,  1.95s/it]


both_GCN loaded
Test loader loaded


100%|██████████| 1000/1000 [32:26<00:00,  1.95s/it]


In [6]:
from scipy.stats import permutation_test

def statistic(x, y, axis):
    return np.mean(x, axis=axis) - np.mean(y, axis=axis)

results_folder = 'results'
for model_type in model_types:
    for depth_type in depth_types:
        print(f'{model_type} {depth_type}')
        all_results_df = pd.read_csv(os.path.join(results_folder, f'all_results_{depth_type}_{model_type}.csv'))
        ids_to_pvalue = {}
        for id in tqdm(np.array(test_idx)):
            id_df = all_results_df[all_results_df['id'] == id].copy()
            x = np.array(id_df[id_df['depth'] == 2]['correctness'])
            y = np.array(id_df[id_df['depth'] == 16]['correctness'])
            try:
                res_less = permutation_test((x, y), statistic, vectorized=True,
                                        n_resamples=1000, alternative='less')
            
                res_greater = permutation_test((x, y), statistic, vectorized=True,
                                    n_resamples=1000, alternative='greater')
            except:
                continue
            ids_to_pvalue[id] = (res_less.pvalue, res_greater.pvalue)
        # create dataframe with pvalues
        df = pd.DataFrame.from_dict(ids_to_pvalue, orient='index', columns=['pvalue_less', 'pvalue_greater'])
        df.index.name = 'ids'                           
        df.to_csv(f'results/ids_to_pvalue_{model_type}_{depth_type}.csv', index=True)

GAT both


100%|██████████| 48603/48603 [59:27<00:00, 13.63it/s] 


GCN both


100%|██████████| 48603/48603 [56:48<00:00, 14.26it/s] 
