In [None]:
from gnnboundary.utils.boundary_baseline import BaselineGenerator
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from gnnboundary import *

In [None]:
# Single seed handed to every dataset constructor below.
seed = 12345

# Dataset registry keyed by short name; the same names are used further down
# to look up a dataset and to locate its checkpoint file.
datasets = dict(
    collab=CollabDataset(seed=seed),
    motif=MotifDataset(seed=seed),
    enzymes=ENZYMESDataset(seed=seed),
)


In [None]:
from gnnboundary.utils import BaselineGenerator

def baseline_class_probabilities(dataset_name, model, class_pair, num_samples=500):
    """Run the model on baseline samples for a class pair and return its probabilities.

    Args:
        dataset_name: Key into the notebook-level ``datasets`` registry.
        model: Classifier whose output is a mapping with a ``'probs'`` entry.
        class_pair: Pair of class indices passed to ``BaselineGenerator``.
        num_samples: How many baseline samples to draw (default 500).

    Returns:
        The ``'probs'`` entry of the model output for the sampled batch.
        The result is computed without gradient tracking.

    Raises:
        KeyError: If ``dataset_name`` is not present in ``datasets``.
    """
    dataset = datasets[dataset_name]

    generator = BaselineGenerator(dataset.split_by_class(), class_pair)
    samples = generator.sample(num_samples)

    model.eval()
    # Pure inference: disable autograd so no graph is built and the returned
    # tensor carries no grad_fn. Calling the module instance (rather than
    # .forward) also runs any registered hooks.
    with torch.no_grad():
        return model(dataset.convert(samples))['probs']

def get_model(dataset_name):
    """Build the GCN classifier for a dataset and load its saved checkpoint.

    Args:
        dataset_name: Key into the notebook-level ``datasets`` registry; also
            used as the checkpoint file stem (``ckpts/<dataset_name>.pt``).

    Returns:
        A ``GCNClassifier`` with the checkpoint's state dict loaded.

    Raises:
        KeyError: If ``dataset_name`` is not present in ``datasets``.
    """
    dataset = datasets[dataset_name]

    # Architecture hyperparameters per dataset. A name that is not listed
    # contributes nothing, so the classifier is then built from the two size
    # arguments alone.
    architectures = {
        "collab": {"hidden_channels": 64, "num_layers": 5},
        "motif": {"hidden_channels": 6, "num_layers": 3},
        "enzymes": {"hidden_channels": 32, "num_layers": 3},
    }

    config = {
        "node_features": len(dataset.NODE_CLS),
        "num_classes": len(dataset.GRAPH_CLS),
        **architectures.get(dataset_name, {}),
    }

    model = GCNClassifier(**config)
    # map_location="cpu" lets a checkpoint that was saved from a GPU be
    # deserialized on a machine without CUDA; load_state_dict then copies the
    # values into the freshly constructed model's parameters.
    state_dict = torch.load(f"ckpts/{dataset_name}.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    return model


In [None]:
# Baseline probabilities for the motif dataset: load its classifier, then
# report per-class mean and standard deviation for each listed class pair.
dataset_name = 'motif'
model = get_model(dataset_name)
adjacent_class_pairs = [[0, 1], [0, 2], [1, 3]]

for class_pair in adjacent_class_pairs:
    class_probabilities = baseline_class_probabilities(dataset_name, model, class_pair)
    # One print call, newline-separated, emits the same three lines.
    print(
        f'Class pair {class_pair}',
        f'    --- Mean class probabilities {class_probabilities.mean(dim=0)}',
        f'    --- Std class probabilities {class_probabilities.std(dim=0)}',
        sep='\n',
    )

In [None]:
# Baseline probabilities for the collab dataset: load its classifier, then
# report per-class mean and standard deviation for each listed class pair.
dataset_name = 'collab'
model = get_model(dataset_name)
adjacent_class_pairs = [[0, 1], [0, 2]]

for class_pair in adjacent_class_pairs:
    class_probabilities = baseline_class_probabilities(dataset_name, model, class_pair)
    # One print call, newline-separated, emits the same three lines.
    print(
        f'Class pair {class_pair}',
        f'    --- Mean class probabilities {class_probabilities.mean(dim=0)}',
        f'    --- Std class probabilities {class_probabilities.std(dim=0)}',
        sep='\n',
    )

In [None]:
# Baseline probabilities for the enzymes dataset: load its classifier, then
# report per-class mean and standard deviation for each listed class pair.
dataset_name = 'enzymes'
model = get_model(dataset_name)
adjacent_class_pairs = [[0, 3], [0, 4], [0, 5], [1, 2], [1, 5], [2, 4], [3, 4], [4, 5]]

for class_pair in adjacent_class_pairs:
    class_probabilities = baseline_class_probabilities(dataset_name, model, class_pair)
    # One print call, newline-separated, emits the same three lines.
    print(
        f'Class pair {class_pair}',
        f'    --- Mean class probabilities {class_probabilities.mean(dim=0)}',
        f'    --- Std class probabilities {class_probabilities.std(dim=0)}',
        sep='\n',
    )