In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `gnnboundary` package) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import gnnboundary
import torch
import json
import time
import os
import numpy as np

In [None]:
def get_dataset_setup(
        dataset_name,
        use_gat=False,
        retrained=False
    ):
    """Build a dataset and load the matching pre-trained classifier.

    Args:
        dataset_name: One of "motif", "collab" or "enzymes".
        use_gat: If True, select the "<name>_gat" model variant.
        retrained: If True, select the "<name>_retrained" model variant.

    Returns:
        A ``(dataset, model)`` tuple. The dataset is created with
        ``seed=12345`` and the model has its weights loaded from the
        checkpoint path associated with the selected variant (paths are
        relative to the current working directory).

    Raises:
        KeyError: If ``dataset_name`` is unknown, or if no model variant is
            registered for the requested flag combination (for example
            "enzymes" with ``use_gat=True``, or ``use_gat`` together with
            ``retrained``).
    """
    datasets = {
        "motif": gnnboundary.MotifDataset,
        "collab": gnnboundary.CollabDataset,
        "enzymes": gnnboundary.ENZYMESDataset,
    }
    # Variant key -> (extra GCNClassifier keyword arguments, checkpoint path).
    # `use_gat` is only passed for the GAT variants so that the other variants
    # keep relying on the classifier's own default.
    model_specs = {
        "motif": (dict(hidden_channels=6, num_layers=3), "ckpts/motif.pt"),
        "collab": (dict(hidden_channels=64, num_layers=5), "ckpts/collab.pt"),
        "enzymes": (dict(hidden_channels=32, num_layers=3), "ckpts/enzymes.pt"),
        "motif_gat": (dict(hidden_channels=6, num_layers=3, use_gat=True), "ckpts/motif_gat.pt"),
        "collab_gat": (dict(hidden_channels=64, num_layers=5, use_gat=True), "ckpts/collab_gat.pt"),
        "motif_retrained": (dict(hidden_channels=6, num_layers=3), "ckpts/motif_retrained.pt"),
    }

    if dataset_name not in datasets:
        raise KeyError(f"Unknown dataset {dataset_name!r}; expected one of {sorted(datasets)}")

    # Build the variant key in a separate local instead of overwriting the
    # `dataset_name` argument.
    model_key = dataset_name
    if use_gat:
        model_key += "_gat"
    if retrained:
        model_key += "_retrained"
    # Validate before constructing the dataset so an unsupported combination
    # fails immediately and with an explanatory message.
    if model_key not in model_specs:
        raise KeyError(
            f"No model registered for {model_key!r} "
            f"(dataset={dataset_name!r}, use_gat={use_gat}, retrained={retrained}); "
            f"available: {sorted(model_specs)}"
        )

    dataset = datasets[dataset_name](seed=12345)

    extra_kwargs, path = model_specs[model_key]
    model = gnnboundary.GCNClassifier(
        node_features=len(dataset.NODE_CLS),
        num_classes=len(dataset.GRAPH_CLS),
        **extra_kwargs,
    )
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines; load_state_dict copies the values into the model's existing
    # parameters, so the model's own device is unaffected.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    print(path)
    return dataset, model

In [None]:
def get_trainer(cls_idx, dataset_name, use_gat=False, use_retrained=False, sampler_path=None):
    """Assemble a ``gnnboundary.Trainer`` for the boundary between the classes in ``cls_idx``.

    Args:
        cls_idx: Iterable of class indices handed to the boundary criterion.
        dataset_name: Dataset key forwarded to ``get_dataset_setup``.
        use_gat: Forwarded to ``get_dataset_setup`` as ``use_gat``.
        use_retrained: Forwarded to ``get_dataset_setup`` as ``retrained``.
        sampler_path: Optional checkpoint; when given, it is loaded into the
            sampler after the trainer has been constructed.

    Returns:
        The configured trainer.
    """
    dataset, discriminator = get_dataset_setup(dataset_name, use_gat=use_gat, retrained=use_retrained)

    graph_sampler = gnnboundary.GraphSampler(
        max_nodes=25,
        temperature=0.15,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=True,
    )

    # Active loss terms: the boundary criterion on the logits plus L1 and L2
    # norm penalties on "omega".
    boundary_term = dict(
        key="logits",
        criterion=gnnboundary.DynamicBalancingBoundaryCriterion(
            classes=list(cls_idx), alpha=1, beta=1
        ),
        weight=25,
    )
    omega_l1_term = dict(key="omega", criterion=gnnboundary.NormPenalty(order=1), weight=1)
    omega_l2_term = dict(key="omega", criterion=gnnboundary.NormPenalty(order=2), weight=1)
    # Disabled terms kept for reference:
    # dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), weight=0),
    # dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_2]), weight=0),
    # dict(key="logits", criterion=gnnboundary.MeanPenalty(), weight=1),
    # dict(key="xi", criterion=NormPenalty(order=1), weight=0),
    # dict(key="xi", criterion=NormPenalty(order=2), weight=0),
    # dict(key="eta", criterion=NormPenalty(order=1), weight=0),
    # dict(key="eta", criterion=NormPenalty(order=2), weight=0),
    # dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True), weight=0),
    criterion = gnnboundary.WeightedCriterion([boundary_term, omega_l1_term, omega_l2_term])

    def make_optimizer(sampler):
        """Return ``([optimizer], scheduler)`` for the given sampler's parameters."""
        sgd = torch.optim.SGD(sampler.parameters(), lr=1)
        # gamma=1 keeps the learning rate constant. Alternatives left disabled:
        # torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1, total_steps=500)
        # torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        lr_schedule = torch.optim.lr_scheduler.ExponentialLR(sgd, gamma=1)
        return [sgd], lr_schedule

    trainer = gnnboundary.Trainer(
        sampler=graph_sampler,
        discriminator=discriminator,
        criterion=criterion,
        optim_factory=make_optimizer,
        dataset=dataset,
        budget_penalty=gnnboundary.BudgetPenalty(budget=10, order=2, beta=1),
    )

    # Restore sampler weights only after the trainer exists, as before.
    if sampler_path is None:
        return trainer
    graph_sampler.load(sampler_path)
    return trainer


def train_eval(cls_idx, dataset_name, num_runs, num_samples, train_args, use_gat=False, use_retrained=False, show_runs=False):
    """Run ``num_runs`` boundary-sampler trainings for a class pair and summarise them.

    Args:
        cls_idx: Pair of class indices; ``cls_idx[0]`` and ``cls_idx[1]`` are read.
        dataset_name: Dataset key forwarded to ``get_trainer``.
        num_runs: Passed to ``trainer.batch_generate`` as ``total``.
        num_samples: Passed to ``trainer.batch_generate`` as ``num_boundary_samples``.
        train_args: Extra keyword arguments for ``trainer.batch_generate``. Must
            contain ``"target_probs"``; that single value is applied to both
            classes. The caller's dict is NOT modified.
        use_gat: Forwarded to ``get_trainer``.
        use_retrained: Forwarded to ``get_trainer``.
        show_runs: Forwarded to ``trainer.batch_generate``.

    Returns:
        ``(scores, logs)``. ``logs`` is whatever ``batch_generate`` returned.
        ``scores`` always has ``"convergence_rate"`` and ``"time"``; when at
        least one run converged it also has ``"train"`` and ``"eval"`` entries
        with ``mean_mean``, ``mean_std``, ``best_idx``, ``best_mean`` and
        ``best_std``. ``best_idx`` indexes the list of *converged* runs, not
        ``logs`` itself.
    """
    start = time.time()

    # Copy before rewriting "target_probs": the notebook shares one module-level
    # `train_args` dict across calls, and mutating it in place would feed an
    # already-expanded {class: ...} mapping into the next call.
    run_args = dict(train_args)
    target_probs = run_args["target_probs"]
    run_args["target_probs"] = {cls_idx[0]: target_probs, cls_idx[1]: target_probs}

    trainer = get_trainer(cls_idx, dataset_name, use_gat=use_gat, use_retrained=use_retrained)
    logs = trainer.batch_generate(cls_idx, total=num_runs, num_boundary_samples=num_samples, show_runs=show_runs, **run_args)

    converged = [(l["train_scores"], l["eval_scores"]) for l in logs if l["converged"]]
    scores = {}
    classes = list(cls_idx)

    # zip(*converged) is empty when nothing converged, so the loop is skipped
    # and no "train"/"eval" entries are produced.
    for label, score_list in zip(["train", "eval"], zip(*converged)):
        means = torch.stack([s["mean"] for s in score_list])
        stds = torch.stack([s["std"] for s in score_list])

        # Best run: class probabilities closest to 0.5 with the smallest spread.
        loss = (means[:, classes] - 0.5).abs() + stds[:, classes]
        best_idx = loss.sum(dim=1).argmin().item()

        scores[label] = {"mean_mean": means.mean(dim=0).tolist(),
                         "mean_std": stds.mean(dim=0).tolist(),
                         "best_idx": best_idx,
                         "best_mean": means[best_idx].tolist(),
                         "best_std": stds[best_idx].tolist()}

    # Guard against num_runs == 0 instead of raising ZeroDivisionError.
    convergence_rate = len(converged) / num_runs if num_runs else 0.0
    scores["convergence_rate"] = convergence_rate
    total_time = time.time() - start
    scores["time"] = total_time

    print(f"Time: {total_time} seconds")
    print(f"Classes: {cls_idx}", f"Num runs: {num_runs}, num samples: {num_samples}", sep="\n", end="\n\n")
    print(f"Convergence rate: {convergence_rate}")
    if len(converged) > 0:
        print(f"""Train - mean: {scores["train"]["mean_mean"]}, std: {scores["train"]["mean_std"]}
            best_idx: {scores["train"]["best_idx"]},
            best: {scores["train"]["best_mean"]}, std: {scores["train"]["best_std"]}""")
        print(f"""Eval - mean: {scores["eval"]["mean_mean"]}, std: {scores["eval"]["mean_std"]}
            best_idx: {scores["eval"]["best_idx"]},
            best: {scores["eval"]["best_mean"]}, std: {scores["eval"]["best_std"]}""")

    return scores, logs

In [None]:
# Keyword arguments that train_eval forwards to trainer.batch_generate.
# "target_probs" is given once here; train_eval applies the same value to both
# classes of the pair being trained.
train_args = {
    "iterations": 500,
    "target_probs": (0.45, 0.55),
    "show_progress": False,
    "target_size": 60,
    "w_budget_init": 1,
    "w_budget_inc": 1.15,
    "w_budget_dec": 0.98,
    "k_samples": 32,
}

In [None]:
def get_result_paths(dataset_name, cls_idx, save_dir, use_json=False):
    """Return ``(scores_path, logs_path)`` for one dataset / class pair.

    ``save_dir`` is created if it does not exist yet. File names follow
    ``<dataset_name>_<cls_idx[0]>-<cls_idx[1]>_{scores,logs}.<ext>`` where the
    extension is ``json`` when ``use_json`` is true and ``pt`` otherwise.
    """
    os.makedirs(save_dir, exist_ok=True)

    if use_json:
        suffix = "json"
    else:
        suffix = "pt"

    stem = f"{dataset_name}_{cls_idx[0]}-{cls_idx[1]}"
    scores_file, logs_file = (
        os.path.join(save_dir, f"{stem}_{kind}.{suffix}") for kind in ("scores", "logs")
    )
    return scores_file, logs_file

In [None]:
# Train boundary samplers for one MOTIF class pair (retrained classifier) and
# persist the summary scores and the raw per-run logs with torch.save.
#
# NOTE(review): cls_idx is an empty tuple here, but train_eval and
# get_result_paths both read cls_idx[0] and cls_idx[1]. Set it to the class
# pair to train (e.g. a 2-tuple of class indices) before running this cell.
cls_idx = ()
dataset_name = "motif"
save_dir = "./sampler_ckpts/MOTIF"

scores, logs = train_eval(
    cls_idx=cls_idx,
    dataset_name=dataset_name,
    num_runs=1000,
    num_samples=500,
    train_args=train_args,
    use_retrained=True,
    show_runs=False,
)

scores_path, logs_path = get_result_paths(dataset_name, cls_idx, save_dir)
for payload, destination in ((scores, scores_path), (logs, logs_path)):
    torch.save(payload, destination)

In [None]:
# bpws = [(x['bpws'], x['converged']) for x in logs]
# for weights, converged in bpws:
#     plt.plot(weights, "b" if converged else "r")
#     plt.ylim(1, 2)
        
# weight_inc = lambda x: any(w > 1 for w in x[0])
# unstable = list(filter(weight_inc, bpws))
# stable = list(filter(lambda x: not weight_inc(x), bpws))

# print(f"stable: {len([x for _, x in stable if x])}/{len(stable)}")
# print(f"unstable: {len([x for _, x in unstable if x])}/{len(unstable)}")


In [None]:
def get_best_ckpt_sampler(scores, logs, class_pair, dataset, use_train=False):
    """Locate the sampler checkpoint of the best converged run.

    The best run is ``scores["train"]["best_idx"]`` when ``use_train`` is true,
    otherwise ``scores["eval"]["best_idx"]``; it indexes the converged entries
    of ``logs``. If that log entry carries a non-None ``"save_path"`` it is
    returned directly. Otherwise the non-hidden regular files in
    ``sampler_ckpts/<dataset.name>/<a>-<b>`` are ordered by modification time
    and the entry at ``best_idx - number_of_converged_runs`` is returned, i.e.
    counting back from the newest file.
    """
    split = "train" if use_train else "eval"
    run_idx = scores[split]["best_idx"]
    converged = [entry for entry in logs if entry["converged"]]

    directory = f"sampler_ckpts/{dataset.name}/{class_pair[0]}-{class_pair[1]}"

    recorded_path = converged[run_idx].get("save_path")
    if recorded_path is not None:
        return recorded_path

    # Fallback: infer the checkpoint from the files on disk.
    candidates = []
    for entry_name in os.listdir(directory):
        full_path = os.path.join(directory, entry_name)
        if os.path.isfile(full_path) and not entry_name.startswith('.'):
            candidates.append(full_path)

    by_age = sorted(candidates, key=os.path.getmtime)
    return by_age[run_idx - len(converged)]

def evaluate_sampler(adjacent_class_pairs,
                     dataset_name,
                     num_samples,
                     sampler_ckpt_dir,
                     get_ckpt_from_logs=False,
                     use_json=False,
                     use_train_sampling=False,
                     from_retrained_model=False,
                     sampler_ckpt_paths=None):
    """Evaluate trained boundary samplers and write figures plus a JSON summary.

    Args:
        adjacent_class_pairs: Sequence of class-index pairs to evaluate.
        dataset_name: Dataset key understood by ``get_dataset_setup``.
        num_samples: Forwarded to ``gnnboundary.evaluate_boundary``.
        sampler_ckpt_dir: Directory holding the saved scores/logs files (only
            read when ``get_ckpt_from_logs`` is true).
        get_ckpt_from_logs: If true, look up the best checkpoint of every class
            pair from its saved scores/logs via ``get_best_ckpt_sampler``.
        use_json: Scores/logs are JSON files instead of ``torch.save`` files.
        use_train_sampling: Pick the best run by train scores rather than eval.
        from_retrained_model: Use the retrained classifier, both as the
            discriminator of every trainer and as the evaluated model.
        sampler_ckpt_paths: Optional explicit checkpoint paths, in the same
            order as ``adjacent_class_pairs``. The passed list is not modified.

    Raises:
        ValueError: If the number of checkpoint paths does not match the number
            of class pairs.

    Side effects:
        Calls ``gnnboundary.draw_matrix`` for the boundary margin and boundary
        thickness matrices and writes
        ``./figures/<dataset_name>/boundary_complexity.json``.
    """
    # Fresh list on every call: a mutable default that gets appended to would
    # carry one call's checkpoint paths over into the next call.
    ckpt_paths = list(sampler_ckpt_paths) if sampler_ckpt_paths is not None else []

    # Load dataset and classifier once, honouring `from_retrained_model`.
    dataset, model = get_dataset_setup(dataset_name, retrained=from_retrained_model)

    if get_ckpt_from_logs:
        for class_pair in adjacent_class_pairs:
            scores_path, logs_path = get_result_paths(dataset_name, class_pair, save_dir=sampler_ckpt_dir, use_json=use_json)
            if use_json:
                with open(scores_path, "r") as f:
                    scores = json.load(f)
                with open(logs_path, "r") as f:
                    logs = json.load(f)
            else:
                scores = torch.load(scores_path)
                logs = torch.load(logs_path)
            ckpt_paths.append(get_best_ckpt_sampler(scores, logs, class_pair, dataset, use_train=use_train_sampling))

    # zip() would silently drop the surplus side and evaluate the wrong (or too
    # few) samplers, so insist on a one-to-one pairing.
    if len(ckpt_paths) != len(adjacent_class_pairs):
        raise ValueError(
            f"Got {len(ckpt_paths)} sampler checkpoint path(s) for "
            f"{len(adjacent_class_pairs)} class pair(s); provide exactly one path per "
            "pair (and do not combine explicit paths with get_ckpt_from_logs=True)."
        )

    trainers = [
        get_trainer(class_pair,
                    dataset_name=dataset_name,
                    use_retrained=from_retrained_model,
                    sampler_path=sampler_path)
        for sampler_path, class_pair in zip(ckpt_paths, adjacent_class_pairs)
    ]

    evaluation = gnnboundary.evaluate_boundary(dataset,
                                               trainers,
                                               adjacent_class_pairs,
                                               model,
                                               num_samples)

    save_path = f'./figures/{dataset_name}'
    # Make sure the output directory exists before anything is written to it.
    os.makedirs(save_path, exist_ok=True)

    # Zero entries are replaced by NaN before the matrices are drawn.
    evaluation['boundary_margin'][evaluation['boundary_margin'] == 0] = np.nan
    evaluation['boundary_thickness'][evaluation['boundary_thickness'] == 0] = np.nan

    gnnboundary.draw_matrix(
        evaluation['boundary_margin'],
        dataset.GRAPH_CLS.values(),
        xlabel='Decision boundary',
        ylabel='Decision region',
        file_name=f'{dataset_name}_boundary_margin.png',
        save_path=save_path
    )
    gnnboundary.draw_matrix(
        evaluation['boundary_thickness'],
        dataset.GRAPH_CLS.values(),
        xlabel='Decision boundary',
        ylabel='Decision region',
        file_name=f'{dataset_name}_boundary_thickness.png',
        save_path=save_path
    )

    # JSON keys must be strings, so each class pair (a, b) is stored as "a_b".
    boundary_complexity = {
        f'{class_pair[0]}_{class_pair[1]}': complexity
        for class_pair, complexity in evaluation['boundary_complexity'].items()
    }
    with open(f'{save_path}/boundary_complexity.json', 'w') as f:
        json.dump(boundary_complexity, f)


In [None]:
# Evaluate the COLLAB boundary samplers; the best checkpoint of every class
# pair is looked up from the scores/logs saved in `sampler_ckpt_dir`.
dataset_name = 'collab'
sampler_ckpt_dir = './sampler_ckpts/COLLAB'
adjacent_class_pairs = [(0, 1), (0, 2), (1, 2)]
num_samples = 500

# To evaluate specific checkpoints instead, pass
# sampler_ckpt_paths=[...] (same order as adjacent_class_pairs) and set
# get_ckpt_from_logs=False.
evaluate_sampler(
    adjacent_class_pairs=adjacent_class_pairs,
    dataset_name=dataset_name,
    num_samples=num_samples,
    sampler_ckpt_dir=sampler_ckpt_dir,
    get_ckpt_from_logs=True,
)

In [None]:
# Evaluate the ENZYMES boundary samplers; scores/logs for this dataset are
# read from JSON files in `sampler_ckpt_dir`.
dataset_name = 'enzymes'
sampler_ckpt_dir = './sampler_ckpts/ENZYMES'
adjacent_class_pairs = [(0, 5), (1, 5)]
num_samples = 500

# To evaluate specific checkpoints instead, pass
# sampler_ckpt_paths=[...] (same order as adjacent_class_pairs) and set
# get_ckpt_from_logs=False.
evaluate_sampler(
    adjacent_class_pairs=adjacent_class_pairs,
    dataset_name=dataset_name,
    num_samples=num_samples,
    sampler_ckpt_dir=sampler_ckpt_dir,
    get_ckpt_from_logs=True,
    use_json=True,
)