In [1]:
# %%
import numpy as np
from tqdm import tqdm
import copy
import torch
from data import get_school_data

from ugnn.networks import Dynamic_Network, Unfolded_Network
from ugnn.utils.masks import non_zero_degree_mask, mask_split, pad_unfolded_mask
from ugnn.gnns import GCN, train, valid
from ugnn.conformal import get_prediction_sets

### Load data

In [2]:
# %%
As, node_labels, all_labels = get_school_data(return_all_labels=True)
T = len(As)
n = As[0].shape[0]
num_classes = len(np.unique(node_labels))

# Convert to a torch geometric dataset containing T graphs
dyn_network = Dynamic_Network(As, node_labels)

# Unfold the T graphs into a single graph
unf_network = Unfolded_Network(dyn_network)[0]


Number of time windows: 9
Number of nodes: 236


### Set up masks for the specified regime

In [3]:
# %%
# See https://arxiv.org/abs/2405.19230 for details on different regimes
regime = "temporal transductive"
data_mask = non_zero_degree_mask(As, n, T)
train_mask, valid_mask, calib_mask, test_mask = mask_split(
    data_mask, split_props=[0.2, 0.1, 0.35, 0.35], regime=regime
)

# Pad masks to include anchor nodes (required when unfolding)
train_mask = pad_unfolded_mask(train_mask, n)
valid_mask = pad_unfolded_mask(valid_mask, n)
calib_mask = pad_unfolded_mask(calib_mask, n)
test_mask = pad_unfolded_mask(test_mask, n)


### Train a UGCN

In [4]:
# %%
model = GCN(
    num_nodes=unf_network.num_nodes, num_channels=16, num_classes=num_classes, seed=123
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

max_valid_acc = 0
for epoch in tqdm(range(200)):
    _ = train(model, unf_network, train_mask, optimizer)
    valid_acc = valid(model, unf_network, valid_mask)

    if valid_acc > max_valid_acc:
        max_valid_acc = valid_acc
        best_model = copy.deepcopy(model)

test_acc = valid(model, unf_network, test_mask)
print(f"Test accuracy: {test_acc:0.3f}")


100%|██████████| 200/200 [00:04<00:00, 43.44it/s]

Test accuracy: 0.950





### Compute conformal prediction sets

In [5]:
# %%
output = best_model(unf_network.x, unf_network.edge_index, unf_network.edge_weight)
all_pred_sets = get_prediction_sets(
    output, unf_network, calib_mask, test_mask, alpha=0.1
)

for calib_node_idx in np.random.randint(len(all_pred_sets), size=10):
    node_pred_set = all_pred_sets[calib_node_idx]
    node_idx = np.where(calib_mask == 1)[0][calib_node_idx]

    # Convert predictions to label names
    possible_labels_for_node = [
        str(all_labels[pred]) for pred in np.where(node_pred_set == 1)[0]
    ]
    print(
        f"Node {node_idx} (True label {str(all_labels[unf_network.y[node_idx]])}): {possible_labels_for_node}"
    )

Node 1937 (True label 2A): ['1B', '2A', '2B', '3A', '4B']
Node 2168 (True label 1B): ['1B', '2A', '2B']
Node 1735 (True label 2B): ['2A', '3A', '3B', '4A', '4B', '5A', '5B', 'Teachers']
Node 1988 (True label 3A): ['1A', '3A', '3B', '4B', '5B']
Node 2110 (True label 5B): ['1A', '3A', '3B', '4B']
Node 1663 (True label 1A): []
Node 1998 (True label 3A): ['1A', '3A', '3B']
Node 1840 (True label 5A): ['2B', '3A', '4B', '5A', '5B', 'Teachers']
Node 2199 (True label 2B): ['1B', '3A', '3B', '4A', '4B', '5A', '5B', 'Teachers']
Node 2020 (True label 3B): ['4A']


In [8]:
from ugnn.conformal import new_get_prediction_sets

# %%
output = best_model(unf_network.x, unf_network.edge_index, unf_network.edge_weight)
all_pred_sets = new_get_prediction_sets(
    output, unf_network, calib_mask, test_mask, score_function="RAPS", alpha=0.1
)

for calib_node_idx in np.random.randint(len(all_pred_sets), size=10):
    node_pred_set = all_pred_sets[calib_node_idx]
    node_idx = np.where(calib_mask == 1)[0][calib_node_idx]

    # Convert predictions to label names
    possible_labels_for_node = [
        str(all_labels[pred]) for pred in np.where(node_pred_set == 1)[0]
    ]
    print(
        f"Node {node_idx} (True label {str(all_labels[unf_network.y[node_idx]])}): {possible_labels_for_node}"
    )

Penalty: 0.0, Avg size: 3.5548961424332344
Penalty: 1e-05, Avg size: 3.5548961424332344
Penalty: 0.0001, Avg size: 3.5548961424332344
Penalty: 0.001, Avg size: 3.5548961424332344
Penalty: 0.01, Avg size: 3.5548961424332344
Penalty: 0.1, Avg size: 3.5548961424332344
Penalty: 0.5, Avg size: 3.5548961424332344

Best penalty: 0.0
Node 2206 (True label 2B): ['1B', '3A', '3B', '4B']
Node 1742 (True label 2B): ['1A', '2A', '3A', '3B', '4B', 'Teachers']
Node 2081 (True label 5A): ['4A', '4B', '5A', '5B', 'Teachers']
Node 1765 (True label 3A): ['1A', '1B', '3A', '3B', '4B']
Node 1897 (True label 1A): ['4B']
Node 1957 (True label 2B): ['1B', '2A', '2B']
Node 2156 (True label 1B): ['1B', '2A', '2B', '3A']
Node 1672 (True label 1A): ['1B', '2B']
Node 1995 (True label 3A): ['1A', '3A', '3B', '4B']
Node 1906 (True label 1A): ['1A', '3A', '3B', '4B']
