In [None]:
import os
from tqdm.notebook import tqdm
from torch_geometric.datasets import Flickr
import pandas as pd
import multiprocessing as mp
from functools import partial
import pickle as pkl
from torch_geometric.utils import k_hop_subgraph
from utils import *

%load_ext autoreload
%autoreload 2

In [None]:
# Path of the experiment folder; must be filled in before running the notebook.
exp_path = ''
# True recomputes the neighbourhoods; set to False to avoid recomputing.
compute_nbhds = True

In [None]:
# Flickr dataset rooted at data/Flickr; its first element is used as the graph.
dataset = Flickr('data/Flickr')
data = dataset[0]

# Features and labels restricted to the test mask, converted to NumPy arrays.
test_x, test_y = (field[data.test_mask].numpy() for field in (data.x, data.y))

# Pickled predictions stored in the experiment folder.
# NOTE(review): unpickling can execute arbitrary code -- only point exp_path
# at folders you trust.
preds_path = os.path.join(exp_path, 'preds.pkl')
with open(preds_path, 'rb') as preds_file:
    preds = pkl.load(preds_file)

In [None]:
# Build (or reload from the pickle cache) the k-hop neighbourhood of every
# node in the test subgraph. `nbs[i]` is an array of node indices within
# num_hops of node i, with i itself removed.
num_hops = 2
nb_fstr = f'data/Flickr/{num_hops}_hop_nbhds.pkl'
if compute_nbhds:
    # Restrict the graph to the nodes selected by the test mask.
    test_subgraph = data.subgraph(data.test_mask)
    # num_nodes is passed explicitly: when omitted, k_hop_subgraph infers the
    # node count from the largest index present in edge_index, which is too
    # small if the highest-numbered nodes have no edges in the subgraph.
    nbs = [
        k_hop_subgraph(
            i,
            num_hops=num_hops,
            edge_index=test_subgraph.edge_index,
            num_nodes=test_subgraph.num_nodes,
        )[0]
        for i in tqdm(range(test_subgraph.num_nodes))
    ]

    nbs = [np.array(nbhd) for nbhd in nbs]
    # Remove each node from its own neighbourhood; `total` lets the progress
    # bar show completion even though enumerate() has no length.
    nbs = [nbhd[nbhd != i] for i, nbhd in tqdm(enumerate(nbs), total=len(nbs))]
    with open(nb_fstr, 'wb') as f:
        pkl.dump(nbs, f)
else:
    # Reuse the neighbourhoods cached by an earlier run.
    # NOTE(review): pickle.load runs arbitrary code; only load trusted caches.
    with open(nb_fstr, 'rb') as f:
        nbs = pkl.load(f)

In [None]:
# Sanity check: number of neighbourhood entries held in nbs.
len(nbs)

In [None]:
# Optional filter: keep only the nodes whose neighbourhood has more than
# `cutoff` members; node_pool holds their indices into nbs.
cutoff = 50
node_pool = []
for node_idx, nbhd in enumerate(nbs):
    if len(nbhd) > cutoff:
        node_pool.append(node_idx)

In [None]:
# Number of nodes that passed the neighbourhood-size cutoff.
len(node_pool)

In [None]:
# Display the labels of the test split.
test_y

In [None]:
# Repeated-trial comparison of neighbourhood-based vs. full calibration.
# Each trial samples n_calib nodes from node_pool, computes a quantile per node
# with both schemes, builds prediction sets and records the evaluation stats.
alpha = 0.1      # passed to calibrate / calibrate_full (presumably the miscoverage level -- confirm in utils)
n_trials = 100   # number of independent resampling trials
n_calib = 1000   # nodes drawn (without replacement) per trial
n_workers = 12   # worker processes used for the full calibration
seed = 0         # fixed seed so that re-running the cell reproduces the same trials
rng = np.random.default_rng(seed)

nb_stats = []
full_stats = []
# Bind the arguments shared by all nodes so that Pool.map only passes the node id.
calib_partial = partial(calibrate_full, probs=preds, labels=test_y, alpha=alpha)
q_nb = []
q_full = []
with mp.Pool(n_workers) as p:
    for _ in tqdm(range(n_trials)):
        # Seeded generator instead of the unseeded global np.random state.
        nodes = rng.choice(node_pool, n_calib, replace=False)
        # Neighbourhood calibration: each node is calibrated on its own
        # neighbourhood (which excludes the node itself).
        quantiles_nb = [calibrate(preds[nbs[i]], test_y[nbs[i]], alpha) for i in nodes]
        quantiles_nb = np.concatenate(quantiles_nb)
        q_nb.append(quantiles_nb)
        # [:, None] turns the quantiles into a column, one row per sampled node.
        sets_nb = predict(preds[nodes], quantiles_nb[:, None])
        nb_stats.append(evaluate_predictions(sets_nb, test_x[nodes], test_y[nodes]))
        # Full calibration, farmed out to the worker pool.
        quantiles_full = p.map(calib_partial, nodes, chunksize=10)
        quantiles_full = np.concatenate(quantiles_full)
        q_full.append(quantiles_full)
        sets_full = predict(preds[nodes], quantiles_full[:, None])
        full_stats.append(evaluate_predictions(sets_full, test_x[nodes], test_y[nodes]))


In [None]:
# Per-trial statistics of the neighbourhood-calibrated prediction sets.
nb_df = pd.DataFrame(nb_stats, columns=['coverage', 'set_size', 'cc_set_size'])
# Title and x label make the figure self-explanatory and distinguishable from
# the full-calibration histogram.
ax = nb_df['coverage'].plot(kind='hist', bins=30, title='Coverage per trial: neighbourhood calibration')
ax.set_xlabel('coverage')
nb_df.describe()

In [None]:
# Per-trial statistics of the fully-calibrated prediction sets.
full_df = pd.DataFrame(full_stats, columns=['coverage', 'set_size', 'cc_set_size'])
# Title and x label make the figure self-explanatory and distinguishable from
# the neighbourhood-calibration histogram.
ax = full_df['coverage'].plot(kind='hist', bins=30, title='Coverage per trial: full calibration')
ax.set_xlabel('coverage')
full_df.describe()