In [None]:
from sympy import bernoulli
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import numpy as np

In [None]:
from gnnboundary import *

# COLLAB

In [None]:
# Seed every global RNG up front, not just the dataset. The boundary samplers
# trained below (temperature / k_samples) presumably draw from torch's and
# numpy's global generators. TODO confirm which RNGs gnnboundary actually uses.
SEED = 12345
torch.manual_seed(SEED)
np.random.seed(SEED)
dataset = CollabDataset(seed=SEED)

In [None]:
# GCN classifier sized from the dataset's node/graph class counts, restored
# from the COLLAB checkpoint.
model = GCNClassifier(node_features=len(dataset.NODE_CLS),
                      num_classes=len(dataset.GRAPH_CLS),
                      hidden_channels=64,
                      num_layers=5)
# The model is built here without being moved to a device. Mapping the checkpoint
# to CPU lets a state dict saved from a GPU run load on CPU-only machines as
# well; load_state_dict copies the tensors into the model's parameters anyway.
model.load_state_dict(torch.load('ckpts/collab.pt', map_location='cpu'))

In [None]:
# Two per-class partitions of the dataset: by class label (`_gt`) and by the
# model's predicted class (`_pred`), judging by the method names.
dataset_list_gt, dataset_list_pred = dataset.split_by_class(), dataset.split_by_pred(model)

In [None]:
# Evaluate the loaded model; keep the result (used for the matrix plot) and display it.
(evaluation := dataset.model_evaluate(model))

In [None]:
# Plot evaluation["cm"] (presumably the confusion matrix) with integer cell labels.
draw_matrix(evaluation["cm"], names=dataset.GRAPH_CLS.values(), fmt="d")

In [None]:
# Mean "embeds" vector of each split in dataset_list_gt. Later cells index this
# list by class id (mean_embeds[cls_1]), so the split order is assumed to
# follow class ids. TODO confirm.
mean_embeds = [
    class_split.model_transform(model, key="embeds").mean(dim=0)
    for class_split in dataset_list_gt
]

In [None]:
# Pairwise boundary analysis over the prediction-based splits; plot the ratio matrix.
adj_ratio_mat, boundary_info = pairwise_boundary_analysis(model, dataset_list_pred)
draw_matrix(adj_ratio_mat,
            fmt='.2f',
            names=dataset.GRAPH_CLS.values())

In [None]:
# Registries keyed by the (cls_1, cls_2) class pair.
# NOTE(review): `sampler` is never populated in the visible cells (each sampler
# lives only in `s`); remove it if nothing later uses it.
trainer, sampler = {}, {}

# Decision boundary between classes 0 & 1

In [None]:
cls_1, cls_2 = 0, 1
# Learnable graph sampler for this pair. It is kept in the module-level `s` so
# the optimizer below can register its parameters.
s = GraphSampler(
    max_nodes=25,
    temperature=0.2,
    num_node_cls=len(dataset.NODE_CLS),
    learn_node_feat=True,
)
trainer[cls_1, cls_2] = Trainer(
    sampler=s,
    discriminator=model,
    # Non-zero weights: the boundary criterion on "logits" (25) and L1/L2 norm
    # penalties on "omega" (2, 1). The weight-0 entries presumably contribute nothing.
    criterion=WeightedCriterion([
        {"key": "logits", "criterion": DynamicBalancingBoundaryCriterion(classes=[cls_1, cls_2]), "weight": 25},
        {"key": "embeds", "criterion": EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), "weight": 0},
        {"key": "embeds", "criterion": EmbeddingCriterion(target_embedding=mean_embeds[cls_2]), "weight": 0},
        {"key": "logits", "criterion": MeanPenalty(), "weight": 0},
        {"key": "omega", "criterion": NormPenalty(order=1), "weight": 2},
        {"key": "omega", "criterion": NormPenalty(order=2), "weight": 1},
        # Disabled: NormPenalty(order=1) and NormPenalty(order=2) on "xi" and "eta", weight 0.
        {"key": "theta_pairs", "criterion": KLDivergencePenalty(binary=True), "weight": 0},
    ]),
    optimizer=(o := torch.optim.SGD(s.parameters(), lr=1)),
    # gamma=1 leaves the learning rate unchanged at every scheduler step.
    scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
    dataset=dataset,
    budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
)

In [None]:
cls_1, cls_2 = 0, 1
# Both classes of the pair share the same (0.4, 0.6) target probability window.
# NOTE(review): target_size=30 is larger than the sampler's max_nodes=25; if
# target_size counts nodes, confirm this is intended.
trainer[cls_1, cls_2].train(
    iterations=2000,
    target_probs={cls: (0.4, 0.6) for cls in (cls_1, cls_2)},
    target_size=30,
    w_budget_init=1,
    w_budget_inc=1.1,
    w_budget_dec=0.95,
    k_samples=16,
)

# Decision boundary between classes 0 & 2

In [None]:
cls_1, cls_2 = 0, 2
# Learnable graph sampler for this pair. It is kept in the module-level `s` so
# the optimizer below can register its parameters.
s = GraphSampler(
    max_nodes=25,
    temperature=0.2,
    num_node_cls=len(dataset.NODE_CLS),
    learn_node_feat=True,
)
trainer[cls_1, cls_2] = Trainer(
    sampler=s,
    discriminator=model,
    # Non-zero weights: the boundary criterion on "logits" (25) and L1/L2 norm
    # penalties on "omega" (2, 1). The weight-0 entries presumably contribute nothing.
    criterion=WeightedCriterion([
        {"key": "logits", "criterion": DynamicBalancingBoundaryCriterion(classes=[cls_1, cls_2]), "weight": 25},
        {"key": "embeds", "criterion": EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), "weight": 0},
        {"key": "embeds", "criterion": EmbeddingCriterion(target_embedding=mean_embeds[cls_2]), "weight": 0},
        {"key": "logits", "criterion": MeanPenalty(), "weight": 0},
        {"key": "omega", "criterion": NormPenalty(order=1), "weight": 2},
        {"key": "omega", "criterion": NormPenalty(order=2), "weight": 1},
        # Disabled: NormPenalty(order=1) and NormPenalty(order=2) on "xi" and "eta", weight 0.
        {"key": "theta_pairs", "criterion": KLDivergencePenalty(binary=True), "weight": 0},
    ]),
    optimizer=(o := torch.optim.SGD(s.parameters(), lr=1)),
    # gamma=1 leaves the learning rate unchanged at every scheduler step.
    scheduler=torch.optim.lr_scheduler.ExponentialLR(o, gamma=1),
    dataset=dataset,
    budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
)

In [None]:
cls_1, cls_2 = 0, 2
# Both classes of the pair share the same (0.4, 0.6) target probability window.
# NOTE(review): target_size=30 is larger than the sampler's max_nodes=25; if
# target_size counts nodes, confirm this is intended.
trainer[cls_1, cls_2].train(
    iterations=2000,
    target_probs={cls: (0.4, 0.6) for cls in (cls_1, cls_2)},
    target_size=30,
    w_budget_init=1,
    w_budget_inc=1.1,
    w_budget_dec=0.95,
    k_samples=16,
)

In [None]:
# Evaluate the trainer for the (0, 2) pair with show=True (presumably displays its sampled graphs).
trainer[(0, 2)].evaluate(show=True, threshold=0.5)

In [None]:
# Boundary complexity, margin and thickness for each trained adjacent pair.
# Margin and thickness are computed separately for each direction, filling both
# the (c1, c2) and (c2, c1) entries. Pairs not listed here stay at 0.
adjacent_class_pairs = [(0, 1), (0, 2)]
num_classes = len(dataset.GRAPH_CLS)
num_samples = 128
boundary_margin = np.zeros((num_classes, num_classes))
boundary_thickness = np.zeros((num_classes, num_classes))
# Complexity per pair, keyed like `trainer`, kept so it can be inspected later.
boundary_complexity = {}

for class_pair in adjacent_class_pairs:
    c1, c2 = class_pair

    cur_trainer = trainer[c1, c2]

    boundary_complexity[c1, c2] = get_model_complexity(cur_trainer,
                                                       original_class_idx=c1,
                                                       adjacent_class_idx=c2,
                                                       num_samples=num_samples)
    # Label each line with its pair; with several pairs an unlabeled value is ambiguous.
    print(f'Complexity ({c1}, {c2}): {boundary_complexity[c1, c2]}')

    # The same pair trainer is measured from each side: (c1 -> c2), then (c2 -> c1).
    for src, dst in ((c1, c2), (c2, c1)):
        margin = get_model_boundary_margin(cur_trainer,
                                           dataset_list_pred,
                                           model,
                                           original_class_idx=src,
                                           adjacent_class_idx=dst,
                                           num_samples=num_samples,
                                           from_best_boundary_graph=False)

        thickness = get_model_boundary_thickness(cur_trainer,
                                                 dataset_list_pred,
                                                 model,
                                                 original_class_idx=src,
                                                 adjacent_class_idx=dst,
                                                 num_samples=num_samples,
                                                 from_best_boundary_graph=False)

        boundary_margin[src, dst] = margin
        boundary_thickness[src, dst] = thickness


draw_matrix(boundary_margin, dataset.GRAPH_CLS.values())
draw_matrix(boundary_thickness, dataset.GRAPH_CLS.values())