In [1]:
import os
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch
from tqdm import tqdm
from torch_geometric.transforms import AddSelfLoops
from mscproject.transforms import RemoveSelfLoops

from mscproject.datasets import CompanyBeneficialOwners
from mscproject.transforms import RemoveSelfLoops
import mscproject.models as mod
import mscproject.experiment as exp

while not Path("data") in Path(".").iterdir():
    os.chdir("..")

In [2]:
MODEL_DIR = Path("data/models/pyg/weights-unregularised/")
OPTUNA_DB = Path("data/optuna-06.db")
DATASET_PATH = Path("data/pyg")
PREDICTION_DIR = Path("data/predictions")

model_names = [x.stem for x in MODEL_DIR.iterdir()]
model_names

['GraphSAGE', 'KGNN']

In [3]:
dataset = CompanyBeneficialOwners(DATASET_PATH, to_undirected=True)
dataset = dataset.data.to("cpu")

In [4]:
def get_best_trial(model_name):
    study = optuna.load_study(
        study_name=f"pyg_model_selection_{model_name}_ARCHITECTURE",
        storage=f"sqlite:///{OPTUNA_DB}",
    )
    model_params = study.best_params
    user_attrs = study.best_trial.user_attrs
    return model_params, user_attrs

In [5]:
model_name = "GraphSAGE"
study = optuna.load_study(
    study_name=f"pyg_model_selection_{model_name}_ARCHITECTURE",
    storage=f"sqlite:///{OPTUNA_DB}",
)

In [11]:
study.trials_dataframe().sort_values("value", ascending=False)[:10].T

Unnamed: 0,71,2,90,57,65,92,42,22,25,56
number,71,2,90,57,65,92,42,22,25,56
value,0.974547,0.964633,0.964578,0.953541,0.890489,0.846144,0.804536,0.798571,0.772965,0.749134
datetime_start,2023-01-09 13:51:44.689671,2023-01-08 23:45:58.579411,2023-01-09 14:37:38.644384,2023-01-09 13:03:53.627998,2023-01-09 13:38:23.969175,2023-01-09 14:43:26.845169,2023-01-09 01:23:31.508692,2023-01-09 00:31:55.916797,2023-01-09 00:38:23.500434,2023-01-09 12:59:41.820345
datetime_complete,2023-01-09 13:55:06.312907,2023-01-09 00:01:36.145491,2023-01-09 14:41:50.938074,2023-01-09 13:07:27.693139,2023-01-09 13:40:37.164632,2023-01-09 14:45:21.441809,2023-01-09 01:27:14.505445,2023-01-09 00:35:09.474524,2023-01-09 00:41:06.815786,2023-01-09 13:03:53.590702
duration,0 days 00:03:21.623236,0 days 00:15:37.566080,0 days 00:04:12.293690,0 days 00:03:34.065141,0 days 00:02:13.195457,0 days 00:01:54.596640,0 days 00:03:42.996753,0 days 00:03:13.557727,0 days 00:02:43.315352,0 days 00:04:11.770357
params_act,leaky_relu,gelu,gelu,relu,leaky_relu,gelu,leaky_relu,leaky_relu,leaky_relu,leaky_relu
params_add_self_loops,False,True,False,True,False,False,False,False,False,False
params_bias,True,False,True,True,True,True,True,True,True,True
params_gnn_aggr,max,max,min,min,min,min,mean,mean,mean,mean
params_hidden_channels_log2,8,8,7,7,7,6,8,8,8,8


In [5]:
get_best_trial("GraphSAGE")

({'act': 'leaky_relu',
  'add_self_loops': False,
  'bias': True,
  'gnn_aggr': 'max',
  'hidden_channels_log2': 8,
  'jk': 'none',
  'num_layers': 4,
  'to_hetero_aggr': 'min'},
 {'acc': 0.9309237003326416,
  'aprc': 0.9745465517044067,
  'aprc_history': [0.08378168940544128,
   0.05609657242894173,
   0.06113765388727188,
   0.10603241622447968,
   0.11127810180187225,
   0.11036204546689987,
   0.10628073662519455,
   0.09745337069034576,
   0.09453783929347992,
   0.10052155703306198,
   0.11373697966337204,
   0.12952886521816254,
   0.1178245022892952,
   0.15287569165229797,
   0.16045233607292175,
   0.17180752754211426,
   0.15863114595413208,
   0.1991516649723053,
   0.222223162651062,
   0.2511574327945709,
   0.2615694999694824,
   0.24503661692142487,
   0.33745962381362915,
   0.29018062353134155,
   0.2849925756454468,
   0.3031104803085327,
   0.4120880365371704,
   0.4282169044017792,
   0.41041746735572815,
   0.4452163279056549,
   0.4496076703071594,
   0.459260731

In [6]:
get_best_trial("KGNN")

({'act': 'gelu',
  'add_self_loops': False,
  'bias': True,
  'gnn_aggr': 'max',
  'hidden_channels_log2': 8,
  'jk': 'last',
  'num_layers': 3,
  'to_hetero_aggr': 'min'},
 {'acc': 0.9291834235191345,
  'aprc': 0.9709917306900024,
  'aprc_history': [0.08662247657775879,
   0.05535983294248581,
   0.07431593537330627,
   0.09581813216209412,
   0.08236677944660187,
   0.08700983971357346,
   0.09674164652824402,
   0.09904752671718597,
   0.09727895259857178,
   0.10366372764110565,
   0.11654387414455414,
   0.12331970781087875,
   0.13008642196655273,
   0.1216675266623497,
   0.141188383102417,
   0.1429957002401352,
   0.13772374391555786,
   0.13665030896663666,
   0.16232596337795258,
   0.15622681379318237,
   0.1951773464679718,
   0.17062173783779144,
   0.21084186434745789,
   0.23359593749046326,
   0.23868553340435028,
   0.2407713681459427,
   0.29477977752685547,
   0.29266539216041565,
   0.30065950751304626,
   0.3373481035232544,
   0.31420618295669556,
   0.2984260320

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_metrics = {}

# Load and evaluate models
for model_name in model_names:

    print("Evaluating model:", model_name)

    model_path = MODEL_DIR / f"{model_name}.pt"

    if not model_path.exists():
        print(f"Model {model_name} does not exist, skipping")
        continue
    else:
        print(f"Loading model from: {model_path}")

    model_params, user_attrs = get_best_trial(model_name)
    user_attrs["model_type"] = model_name
    del user_attrs["aprc_history"]

    print(f"Using model params: {model_params}")
    print(f"Using user attrs: {user_attrs}")

    dataset = CompanyBeneficialOwners(DATASET_PATH, to_undirected=True)
    dataset = dataset.data.to(device)

    model, optimiser, _ = exp.build_experiment_from_trial_params(
        model_params, user_attrs, dataset
    )
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))

    model.to(device)

    if model_params["add_self_loops"]:
        dataset = AddSelfLoops(fill_value=1.0)(dataset)
    else:
        dataset = RemoveSelfLoops()(dataset)

    eval_metrics = exp.evaluate(
        model, dataset, on_train=False, on_val=False, on_test=True
    )

    model_metrics[model_name] = eval_metrics.test
    print(model_name, eval_metrics.test)

    print("Making predictions...")
    prediction_dict = model(dataset.x_dict, dataset.edge_index_dict)
    prediction_df_list = []

    print("Saving predictions...")
    for node_type in dataset.node_types:
        prediction = (
            prediction_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        actual = (
            dataset.y_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        df = pd.DataFrame({"pred_proba": prediction, "actual": actual})
        prediction_df_list.append(df)

    prediction_df = pd.concat(prediction_df_list)
    prediction_df.to_csv(PREDICTION_DIR / f"{model_name}.csv", index=False)

Evaluating model: GraphSAGE
GraphSAGE loss: 1.194, acc: 0.926, prc: 0.597, rec: 0.891, f1: 0.715, auc: 0.953, aprc: 0.766
Model structure:
GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (company__owns__company): SAGEConv(-1, 256, aggr=mean)
      (person__owns__company): SAGEConv(-1, 256, aggr=mean)
      (company__rev_owns__company): SAGEConv(-1, 256, aggr=mean)
      (company__rev_owns__person): SAGEConv(-1, 256, aggr=mean)
    )
    (1): ModuleDict(
      (company__owns__company): SAGEConv(256, 256, aggr=mean)
      (person__owns__company): SAGEConv(256, 256, aggr=mean)
      (company__rev_owns__company): SAGEConv(256, 256, aggr=mean)
      (company__rev_owns__person): SAGEConv(256, 256, aggr=mean)
    )
    (2): ModuleDict(
      (company__owns__company): SAGEConv(256, 1, aggr=mean)
      (person__owns__company): SAGEConv(256, 1, aggr=mean)
      (company__rev_owns__company): SAGEConv(256, 1, aggr=mean)
      (company__rev_owns__person): SAGEConv(256, 1, aggr=mean)


In [10]:
performance_comparison = pd.DataFrame.from_dict(model_metrics, orient="index")
performance_comparison.to_csv("reports/test-performance-pyg.csv", index_label="model")