# Load the trained models and compute the brier score on the test set. Save for plotting.

This is needed bc didn't record Brier Score on the training runs. Thus compute Brier Score and save to dict for plotting in another file.

In [32]:
import os
import time
import warnings

import yaml
import time
import argparse
import os

import numpy as np

import torch
import torch.nn.functional as F

from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges, negative_sampling

from gnn import GCNEncoder, GAE, VGAE, VariationalMLPDecoder, MLPDecoder, posterior_predictive_metrics

from sklearn.metrics import average_precision_score, roc_auc_score, f1_score

warnings.filterwarnings("ignore", category=UserWarning, message="'train_test_split_edges' is deprecated")
warnings.filterwarnings("ignore", category=FutureWarning, module="torch") # supress annoying warning which is internal to pytorch-geometric loading (I think)

from config import GNN_DIR, DATA_DIR
log_dir =  GNN_DIR + "logs/"
model_dir = GNN_DIR + "models/"


In [33]:
# load a single model on a single dataset on a single run and compute Brier Score

dataset_name = "Cora"
#model = "ks"
run = 0
seed = 5 * run

# Define model hyperparameters (should match your training configuration)
latent_distrib = 'beta'  # change if using a different distribution
pre_str = 'edge_' if latent_distrib in ['ks', 'beta', 'tanh-normal'] else ''

model_path = f"{model_dir}" + pre_str + f"{latent_distrib}_{dataset_name}_{run}.pt"

# Define model hyperparameters (should match your training configuration)
hidden_channels = 32
out_channels = 32

In [34]:
def encode_and_decode(model, x, train_pos_edge_index, neg_edge_index):
    embed = model(x, train_pos_edge_index)
    distrib_pos = model.decode(embed, train_pos_edge_index)
    distrib_neg = model.decode(embed, neg_edge_index)
    return distrib_pos, distrib_neg

def posterior_predictive_metrics(post_pred_samples, labels, print_metrics=False):
    #labels = torch.cat([torch.ones(pos_edge_index.shape[1]), torch.zeros(neg_edge_index.shape[1])], dim=0)

    #post_pred_samples = torch.stack(post_pred_samples) # shape (num_samples, num_edges)
    mean, std = post_pred_samples.mean(dim=0), post_pred_samples.std(dim=0)
    error = (mean - labels).abs()

    #### metrics
    # quality of the uncertainty
    pearson_corr = torch.corrcoef(torch.stack((error, std)))[0, 1] # shape (1)

    # quality of uncertainty over positive and negative edges
    pos_mask, neg_mask = (labels == 1), (labels == 0)
    #assert len(pos_mask) == len(neg_mask), "Positive and negative masks must be the same length"
    num_unique_error = torch.unique(error).shape[0]
    num_unique_std = torch.unique(std).shape[0]
    #print(f"\tUnique error values: {num_unique_error}, unique std values: {num_unique_std}")
    
    pearson_corr_pos = torch.corrcoef(torch.stack((std[pos_mask], error[pos_mask])))[0, 1]
    pearson_corr_neg = torch.corrcoef(torch.stack((std[neg_mask], error[neg_mask])))[0, 1]

    # brier score
    bs = (mean - labels)**2
    bs_pos = bs[pos_mask]
    bs_neg = bs[neg_mask]

    pred_binary = (mean >= 0.5).int()  # Threshold at 0.5 to get binary predictions
    # quality of mean of posterior predictive as predictor.
    # auc and ap may not make sense, as the mean is not really a probability.
    auc, ap, f1 = roc_auc_score(labels, mean), average_precision_score(labels, mean), f1_score(labels, pred_binary)

    if print_metrics:
        print(f"\tMean predictor performance: AUC: {auc:.4f}, AP: {ap:.4f}, F1: {f1:.4f}")
        print(f"\tPearson correlation: {pearson_corr:.4f}, pos: {pearson_corr_pos:.4f}, neg: {pearson_corr_neg:.4f}")
        print(f"\tBrier Score: {bs.mean().float():.4f}, pos: {bs_pos.mean().float():.4f}, neg: {bs_neg.mean().float():.4f}")

    metrics = {
        'abs_error': error.mean().float(),
        'brier_score': bs.mean().float(),
        'std': std.mean().float(),
        'pearson_corr': pearson_corr.float(),
        'pearson_corr_pos': pearson_corr_pos.float(),
        'pearson_corr_neg': pearson_corr_neg.float(),
        'auc': auc,
        'ap': ap,
        'f1': f1
    }
    return metrics

def test_posterior_predictive(model, x, pos_edge_index, neg_edge_index, num_samples=30, viz=False):
    model.eval()
    with torch.no_grad():
        post_pred_samples = []
        labels = torch.cat([torch.ones(pos_edge_index.shape[1]), torch.zeros(neg_edge_index.shape[1])], dim=0)

        for i in range(num_samples):
            # encode and decode
            distrib_pos, distrib_neg = encode_and_decode(model, x, pos_edge_index, neg_edge_index)

            # sample
            edge_prob_sample_pos = distrib_pos.rsample()
            edge_prob_sample_neg = distrib_neg.rsample()

            # sample the bernoulli likelihood
            post_pred_samples.append(torch.cat([
                torch.bernoulli(edge_prob_sample_pos), 
                torch.bernoulli(edge_prob_sample_neg)], 
                dim=0))
            
    post_pred_samples = torch.stack(post_pred_samples) # shape (num_samples, num_edges)
    metrics = posterior_predictive_metrics(post_pred_samples, labels)
    return metrics

In [70]:
# Set the device to GPU if available, otherwise use CPU
def compute_test_metrics(dataset_name, latent_distrib, run, seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #print(f"Using device: {device}")

    dataset = Planetoid(DATA_DIR, dataset_name, transform=T.NormalizeFeatures())
    data = dataset[0]
    # Remove masks if any, then create a train/test split
    data.train_mask = data.val_mask = data.test_mask = None
    torch.manual_seed(seed)
    data_tts = train_test_split_edges(data)

    model = GAE(
        GCNEncoder(dataset.num_features, hidden_channels, out_channels),
        VariationalMLPDecoder(out_channels, out_channels, latent_distrib=latent_distrib)
    )
    model = model.to(device)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)

    #print(f"Edge Latent: {latent_distrib} on {dataset_name} - Run {run}")

    # Instantiate the model
    model = GAE(
        GCNEncoder(dataset.num_features, hidden_channels, out_channels),
        VariationalMLPDecoder(out_channels, out_channels, latent_distrib=latent_distrib)
    ).to(device)

    # Load the trained model state
    pre_str = 'edge_' if latent_distrib in ['ks', 'beta', 'tanh-normal'] else ''
    model_path = f"{model_dir}" + pre_str + f"{latent_distrib}_{dataset_name}_{run}.pt"
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Prepare test edges and features
    x = data.x.to(device)
    pos_edge_index = data.test_pos_edge_index.to(device)
    neg_edge_index = data.test_neg_edge_index.to(device)

    metrics = test_posterior_predictive(model, x, pos_edge_index, neg_edge_index)
    return metrics
#print(metrics)

from collections import defaultdict
brier_scores = defaultdict(dict)
for dataset_name in ["Cora", "Citeseer", "Pubmed"]:
    brier_scores[dataset_name] = defaultdict(list)
    print(f"Dataset: {dataset_name}")
    for latent_distrib in ["ks", "beta", "tanh-normal"]:
        print(f"\tEdge Latent: {latent_distrib} on {dataset_name}")
        for run in [0, 1, 2, 3, 4]:
            metrics = compute_test_metrics(dataset_name, latent_distrib, run, 5*run)
            brier_scores[dataset_name][latent_distrib].append(metrics['brier_score'].item())

        print(f"\tBrier Score across Runs: {np.mean(brier_scores[dataset_name][latent_distrib]):.5f} +/- {np.std(brier_scores[dataset_name][latent_distrib]):.5f}")

Dataset: Cora
	Edge Latent: ks on Cora
	Brier Score across Runs: 0.08209 +/- 0.01029
	Edge Latent: beta on Cora
	Brier Score across Runs: 0.11650 +/- 0.01380
	Edge Latent: tanh-normal on Cora
	Brier Score across Runs: 0.11521 +/- 0.04466
Dataset: Citeseer
	Edge Latent: ks on Citeseer
	Brier Score across Runs: 0.10402 +/- 0.00802
	Edge Latent: beta on Citeseer
	Brier Score across Runs: 0.14022 +/- 0.01847
	Edge Latent: tanh-normal on Citeseer
	Brier Score across Runs: 0.11059 +/- 0.00429
Dataset: Pubmed
	Edge Latent: ks on Pubmed
	Brier Score across Runs: 0.06075 +/- 0.00430
	Edge Latent: beta on Pubmed
	Brier Score across Runs: 0.06916 +/- 0.00249
	Edge Latent: tanh-normal on Pubmed
	Brier Score across Runs: 0.12017 +/- 0.04162


In [71]:
print(brier_scores)

defaultdict(<class 'dict'>, {'Cora': defaultdict(<class 'list'>, {'ks': [0.08901433646678925, 0.07873287051916122, 0.09745624661445618, 0.06736031174659729, 0.07789163291454315], 'beta': [0.10884460061788559, 0.09924519807100296, 0.13080644607543945, 0.10886358469724655, 0.1347469985485077], 'tanh-normal': [0.08107948303222656, 0.17799915373325348, 0.07496626675128937, 0.08115749061107635, 0.16086654365062714]}), 'Citeseer': defaultdict(<class 'list'>, {'ks': [0.10626861453056335, 0.11789742857217789, 0.10029914230108261, 0.10188400000333786, 0.09374602884054184], 'beta': [0.11231989413499832, 0.16957508027553558, 0.13816359639167786, 0.1460866928100586, 0.13496580719947815], 'tanh-normal': [0.11061416566371918, 0.11421610414981842, 0.11055921763181686, 0.11478753387928009, 0.10277044028043747]}), 'Pubmed': defaultdict(<class 'list'>, {'ks': [0.06285273283720016, 0.05707894638180733, 0.06074433773756027, 0.05549626424908638, 0.06759877502918243], 'beta': [0.06824583560228348, 0.0694205

In [72]:
brier_scores.keys()

dict_keys(['Cora', 'Citeseer', 'Pubmed'])

In [64]:
# run point estimator GNN

# Set the device to GPU if available, otherwise use CPU
def compute_test_metrics_base(dataset_name, run, seed):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #print(f"Using device: {device}")

    dataset = Planetoid(DATA_DIR, dataset_name, transform=T.NormalizeFeatures())
    data = dataset[0]
    # Remove masks if any, then create a train/test split
    data.train_mask = data.val_mask = data.test_mask = None
    torch.manual_seed(seed)
    data_tts = train_test_split_edges(data)

    model = GAE(GCNEncoder(dataset.num_features, hidden_channels, out_channels),
                MLPDecoder(out_channels, out_channels)
                )
    model = model.to(device)
    x = data.x.to(device)
    train_pos_edge_index = data.train_pos_edge_index.to(device)

    #print(f"Edge Latent: {latent_distrib} on {dataset_name} - Run {run}")

    # Instantiate the model
    model = GAE(
        GCNEncoder(dataset.num_features, hidden_channels, out_channels),
        VariationalMLPDecoder(out_channels, out_channels, latent_distrib=latent_distrib)
    ).to(device)

    # Load the trained model state
    pre_str = 'edge_' if latent_distrib in ['ks', 'beta', 'tanh-normal'] else ''
    model_path = f"{model_dir}" + pre_str + f"{latent_distrib}_{dataset_name}_{run}.pt"
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # Prepare test edges and features
    x = data.x.to(device)
    pos_edge_index = data.test_pos_edge_index.to(device)
    neg_edge_index = data.test_neg_edge_index.to(device)

    metrics = test_posterior_predictive(model, x, pos_edge_index, neg_edge_index)
    return metrics
#print(metrics)

In [73]:
for dataset_name in ["Cora", "Citeseer", "Pubmed"]:
    print(f"Dataset: {dataset_name}")
    for run in [0, 1, 2, 3, 4]:
        metrics = compute_test_metrics_base(dataset_name, run, 5*run)
        brier_scores[dataset_name]['base'].append(metrics['brier_score'].item())

    print(f"\tBrier Score across Runs: {np.mean(brier_scores[dataset_name]['base']):.5f} +/- {np.std(brier_scores[dataset_name]['base']):.5f}")

Dataset: Cora
	Brier Score across Runs: 0.11504 +/- 0.04522
Dataset: Citeseer
	Brier Score across Runs: 0.10954 +/- 0.00428
Dataset: Pubmed
	Brier Score across Runs: 0.12070 +/- 0.04181


In [86]:
print(brier_scores)
print(brier_scores.keys())
print(brier_scores.values())
for ds in brier_scores:
    print(f"Dataset {ds}")
    for m in brier_scores[ds]:
        print(f"\tmodel: {m} - {brier_scores[ds][m]}")

defaultdict(<class 'dict'>, {'Cora': defaultdict(<class 'list'>, {'ks': [0.08901433646678925, 0.07873287051916122, 0.09745624661445618, 0.06736031174659729, 0.07789163291454315], 'beta': [0.10884460061788559, 0.09924519807100296, 0.13080644607543945, 0.10886358469724655, 0.1347469985485077], 'tanh-normal': [0.08107948303222656, 0.17799915373325348, 0.07496626675128937, 0.08115749061107635, 0.16086654365062714], 'base': [0.08150433003902435, 0.1790006160736084, 0.07329749315977097, 0.08068837970495224, 0.16070735454559326]}), 'Citeseer': defaultdict(<class 'list'>, {'ks': [0.10626861453056335, 0.11789742857217789, 0.10029914230108261, 0.10188400000333786, 0.09374602884054184], 'beta': [0.11231989413499832, 0.16957508027553558, 0.13816359639167786, 0.1460866928100586, 0.13496580719947815], 'tanh-normal': [0.11061416566371918, 0.11421610414981842, 0.11055921763181686, 0.11478753387928009, 0.10277044028043747], 'base': [0.11054332554340363, 0.11482906341552734, 0.11001341789960861, 0.11062

In [90]:
log_dir + 'calibration_metrics.pkl'

'/Users/maxw/projects/pathwise_grad_kumar/experiments/gnn/logs/calibration_metrics.pkl'

In [95]:
brier_scores

defaultdict(dict,
            {'Cora': defaultdict(list,
                         {'ks': [0.08901433646678925,
                           0.07873287051916122,
                           0.09745624661445618,
                           0.06736031174659729,
                           0.07789163291454315],
                          'beta': [0.10884460061788559,
                           0.09924519807100296,
                           0.13080644607543945,
                           0.10886358469724655,
                           0.1347469985485077],
                          'tanh-normal': [0.08107948303222656,
                           0.17799915373325348,
                           0.07496626675128937,
                           0.08115749061107635,
                           0.16086654365062714],
                          'base': [0.08150433003902435,
                           0.1790006160736084,
                           0.07329749315977097,
                           0.080688379704

In [94]:
log_dir + 'calibration_metrics.pkl'

'/Users/maxw/projects/pathwise_grad_kumar/experiments/gnn/logs/calibration_metrics.pkl'

In [96]:
# save the calibration metrics to a file
import pickle
with open(log_dir + 'calibration_metrics.pkl', 'wb') as f:
    pickle.dump(brier_scores, f)
    f.flush()  # Ensure data is written to disk
    os.fsync(f.fileno())  # Force write to disk (optional)