In [1]:
import os
import sys
# Make the notebook's parent directory importable so the local modules imported
# below (trainer, dataset, predict) can be resolved. The membership check keeps
# re-runs of this cell from appending the same entry again.
# Note: os.path.join('..') with a single argument is just '..'.
directory_path = os.path.abspath('..')
if directory_path not in sys.path:
    sys.path.append(directory_path)
    
from trainer import Trainer
from dataset import get_ogb_data
from torch_geometric import seed_everything
from ogb.nodeproppred import Evaluator
from predict import evaluate_test
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Short dataset name; the OGB evaluator below is created for `ogbn-<d_name>`.
d_name = 'arxiv'
# Pass d_name (not a separate literal) so the loaded data and the evaluator
# always refer to the same dataset when d_name is changed.
data, split_idx, num_classes = get_ogb_data(d_name)
evaluator = Evaluator(name=f'ogbn-{d_name}')

Number of nodes in the graph: 169343
Number of edges in the graph: 1166243
Number of training nodes: 90941
Number of validation nodes: 29799
Number of test nodes: 48603
Node feature matrix with shape: torch.Size([169343, 128])
Graph connectivity in COO format with shape: torch.Size([2, 1166243])
Target to train against : torch.Size([169343, 1])
Node feature length 128
number of target categories: 40


In [3]:
# The two model depths compared in this notebook. depths_list is handed to both
# Trainer.run (below) and evaluate_test (later); it is ordered deep-first.
DEEP_DEPTH = 5
SHALLOW_DEPTH = 2
depths_list = [DEEP_DEPTH, SHALLOW_DEPTH]

# Epoch count forwarded to Trainer.run. NOTE(review): the recorded output shows the
# trainer resuming from a checkpoint ("resuming from epoch 175"), so how this value
# interacts with a resumed run is decided inside Trainer — confirm there.
epochs = 10

trainer = Trainer()
trainer.run(data, split_idx, num_classes, evaluator, depths_list, epochs)

==> resuming from epoch 175, best para 0.6115


In [8]:
model = trainer.model
test_loader = trainer.test_loader
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_idx = split_idx['test']

# Number of repeated evaluation passes over the test set. Each pass records, per test
# node, the depth used at inference and whether the prediction was correct.
num_permutations = 15

# Seed once before the passes so the whole sequence is reproducible; successive passes
# still draw different random numbers. NOTE(review): assumes evaluate_test's randomness
# comes from the Python/NumPy/torch RNGs that seed_everything controls — confirm.
SEED = 42
seed_everything(SEED)

# One row per test node; every pass adds a `depth_{perm}` / `correctness_{perm}` column.
ids_col = {'id': test_idx}
id_depths_df = pd.DataFrame(data=ids_col)
id_results_df = pd.DataFrame(data=ids_col)

for perm in tqdm(range(num_permutations)):
    test_acc, inference_depths, ids, correctness = evaluate_test(
        model, test_loader, evaluator, data, depths_list, device)

    # Join on id rather than position: the ids returned by evaluate_test need not follow
    # test_idx order. An inner merge keeps the left (test_idx) row order, so no sort is
    # needed; validate= raises if a node is reported more than once in a pass.
    depth_df = pd.DataFrame(data={'id': ids, f'depth_{perm}': inference_depths})
    id_depths_df = id_depths_df.merge(depth_df, on='id', validate='one_to_one')

    result_df = pd.DataFrame(data={'id': ids, f'correctness_{perm}': correctness})
    id_results_df = id_results_df.merge(result_df, on='id', validate='one_to_one')

# An inner merge silently drops test nodes that a pass did not return; fail loudly instead,
# because the next cell relies on one row per test node in test_idx order.
assert len(id_depths_df) == len(id_results_df) == len(test_idx), (
    'evaluate_test did not return every test node exactly once')

100%|██████████| 15/15 [00:20<00:00,  1.35s/it]


In [9]:
# Reshape the wide per-pass tables (one column per pass) into a long table with one row
# per (test node, pass). Flattening is row-major, so values come out id-major (all passes
# of the first node, then all passes of the next), which is exactly what np.repeat yields.
# Assumes the rows of both wide tables follow ids_col order (an inner merge keeps the left
# order as long as no node was dropped).
sorted_ids = ids_col['id'].tolist()

correctness_matrix = id_results_df.drop(columns=['id']).to_numpy()
depth_matrix = id_depths_df.drop(columns=['id']).to_numpy()

all_correctnesses = list(correctness_matrix.ravel())
all_depths = list(depth_matrix.ravel())
all_sorted_ids = list(np.repeat(sorted_ids, num_permutations))

all_results_df = pd.DataFrame(data={
    'id': all_sorted_ids,
    'correctness': all_correctnesses,
    'depth': all_depths,
})

In [11]:
from scipy.stats import permutation_test


def statistic(x, y, axis):
    """Return mean(x) - mean(y) reduced along `axis`.

    Used as a vectorized statistic for `scipy.stats.permutation_test(..., vectorized=True)`,
    which passes batches of resampled data and reduces along `axis`.
    """
    return np.mean(x, axis=axis) - np.mean(y, axis=axis)


N_RESAMPLES = 9999
# scipy.stats.permutation_test raises ValueError unless each sample has two or more
# observations; nodes below this count at either depth are skipped explicitly instead of
# hiding every possible error behind a bare `except`.
MIN_OBS_PER_DEPTH = 2

# node id -> one-sided p-values for the difference mean(correctness @ SHALLOW_DEPTH)
# - mean(correctness @ DEEP_DEPTH):
#   'less'    -> H1: shallow correctness is lower than deep correctness
#   'greater' -> H1: shallow correctness is higher than deep correctness
ids_to_pvalue = {}

# Index the long table once (id -> row positions) instead of rescanning every row of
# all_results_df for every test node, which made the loop quadratic.
correctness_values = all_results_df['correctness'].to_numpy()
depth_values = all_results_df['depth'].to_numpy()
rows_by_id = all_results_df.groupby('id').indices
no_rows = np.array([], dtype=np.intp)

for node_id in tqdm(np.array(test_idx)):
    rows = rows_by_id.get(node_id, no_rows)
    node_correctness = correctness_values[rows]
    node_depths = depth_values[rows]
    x = node_correctness[node_depths == SHALLOW_DEPTH]
    y = node_correctness[node_depths == DEEP_DEPTH]

    # Too few passes used one of the two depths for this node: no test is possible.
    if len(x) < MIN_OBS_PER_DEPTH or len(y) < MIN_OBS_PER_DEPTH:
        continue

    res_less = permutation_test((x, y), statistic, vectorized=True,
                                n_resamples=N_RESAMPLES, alternative='less')
    res_greater = permutation_test((x, y), statistic, vectorized=True,
                                   n_resamples=N_RESAMPLES, alternative='greater')

    ids_to_pvalue[int(node_id)] = {'less': res_less.pvalue,
                                   'greater': res_greater.pvalue}
    

 10%|█         | 4876/48603 [05:18<45:07, 16.15it/s]  