In [1]:
import os
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch
from tqdm import tqdm
from torch_geometric.transforms import AddSelfLoops
from mscproject.transforms import RemoveSelfLoops

from mscproject.datasets import CompanyBeneficialOwners
from mscproject.transforms import RemoveSelfLoops
import mscproject.models as mod
import mscproject.experiment as exp

while not Path("data") in Path(".").iterdir():
    os.chdir("..")

In [2]:
MODEL_DIR = Path("data/models/pyg/weights-unregularised/")
OPTUNA_DB = Path("data/optuna-06.db")
DATASET_PATH = Path("data/pyg")
PREDICTION_DIR = Path("data/predictions")

model_names = [x.stem for x in MODEL_DIR.iterdir()]
model_names

['GraphSAGE', 'KGNN', 'KGNN_aprc_history', 'GraphSAGE_aprc_history']

In [3]:
dataset = CompanyBeneficialOwners(DATASET_PATH, to_undirected=True)
dataset = dataset.data.to("cpu")

In [4]:
def get_best_trial(model_name):
    study = optuna.load_study(
        study_name=f"pyg_model_selection_{model_name}_ARCHITECTURE",
        storage=f"sqlite:///{OPTUNA_DB}",
    )
    model_params = study.best_params
    user_attrs = study.best_trial.user_attrs
    return model_params, user_attrs

In [5]:
model_name = "GraphSAGE"
study = optuna.load_study(
    study_name=f"pyg_model_selection_{model_name}_ARCHITECTURE",
    storage=f"sqlite:///{OPTUNA_DB}",
)

In [6]:
study.trials_dataframe().sort_values("value", ascending=False)[:10].T

Unnamed: 0,16,0,46,34,45,49,47,4,43,42
number,16,0,46,34,45,49,47,4,43,42
value,0.80028,0.592597,0.54846,0.54487,0.515676,0.495268,0.476724,0.442385,0.437794,0.410043
datetime_start,2023-01-13 00:42:10.372377,2023-01-13 00:12:38.146701,2023-01-13 03:00:44.772426,2023-01-13 02:05:07.497213,2023-01-13 02:58:48.600167,2023-01-13 03:05:13.811092,2023-01-13 03:02:40.070519,2023-01-13 00:18:09.159703,2023-01-13 02:45:59.710231,2023-01-13 02:29:47.081945
datetime_complete,2023-01-13 00:47:55.923364,2023-01-13 00:14:47.617278,2023-01-13 03:02:40.046598,2023-01-13 02:09:14.893556,2023-01-13 03:00:44.749204,2023-01-13 03:08:17.236204,2023-01-13 03:04:10.539683,2023-01-13 00:25:25.907928,2023-01-13 02:52:17.525233,2023-01-13 02:45:59.687279
duration,0 days 00:05:45.550987,0 days 00:02:09.470577,0 days 00:01:55.274172,0 days 00:04:07.396343,0 days 00:01:56.149037,0 days 00:03:03.425112,0 days 00:01:30.469164,0 days 00:07:16.748225,0 days 00:06:17.815002,0 days 00:16:12.605334
params_act,leaky_relu,leaky_relu,leaky_relu,relu,leaky_relu,relu,leaky_relu,leaky_relu,leaky_relu,leaky_relu
params_add_self_loops,False,False,False,False,False,False,False,False,False,False
params_bias,False,False,True,True,True,True,True,False,True,True
params_gnn_aggr,min,min,mean,min,mean,mean,mean,sum,min,min
params_hidden_channels_log2,7,5,6,7,6,6,6,8,7,8


In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_metrics = {}

# Load and evaluate models
for model_name in model_names:

    print("Evaluating model:", model_name)

    model_path = MODEL_DIR / f"{model_name}.pt"

    if not model_path.exists():
        print(f"Model {model_name} does not exist, skipping")
        continue
    else:
        print(f"Loading model from: {model_path}")

    model_params, user_attrs = get_best_trial(model_name)
    user_attrs["model_type"] = model_name

    pd.DataFrame(user_attrs["aprc_history"]).to_csv(
        MODEL_DIR / f"{model_name}_aprc_history.csv", index=False
    )
    del user_attrs["aprc_history"]

    print(f"Using model params: {model_params}")
    print(f"Using user attrs: {user_attrs}")

    dataset = CompanyBeneficialOwners(DATASET_PATH, to_undirected=True)
    dataset = dataset.data.to(device)

    model, optimiser, _ = exp.build_experiment_from_trial_params(
        model_params, user_attrs, dataset
    )
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))

    model.to(device)

    if model_params["add_self_loops"]:
        dataset = AddSelfLoops(fill_value=1.0)(dataset)
    else:
        dataset = RemoveSelfLoops()(dataset)

    eval_metrics = exp.evaluate(
        model, dataset, on_train=False, on_val=False, on_test=True
    )

    model_metrics[model_name] = eval_metrics.test
    print(model_name, eval_metrics.test)

    print("Making predictions...")
    prediction_dict = model(dataset.x_dict, dataset.edge_index_dict)
    prediction_df_list = []

    print("Saving predictions...")
    for node_type in dataset.node_types:
        prediction = (
            prediction_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        actual = (
            dataset.y_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        df = pd.DataFrame({"pred_proba": prediction, "actual": actual})
        prediction_df_list.append(df)

    prediction_df = pd.concat(prediction_df_list)
    prediction_df.to_csv(PREDICTION_DIR / f"{model_name}.csv", index=False)

Evaluating model: GraphSAGE
Loading model from: data/models/pyg/weights-unregularised/GraphSAGE.pt
Using model params: {'act': 'leaky_relu', 'add_self_loops': False, 'bias': False, 'gnn_aggr': 'min', 'hidden_channels_log2': 7, 'jk': 'last', 'num_layers': 7, 'to_hetero_aggr': 'min'}
Using user attrs: {'acc': 0.9346719980239868, 'aprc': 0.8002802133560181, 'auc': 0.9660666584968567, 'best_epoch': 522, 'f1': 0.6975089311599731, 'learning_rate': 0.01, 'loss': 1.704654574394226, 'n_hidden': 128, 'precision': 0.5704307556152344, 'recall': 0.8974359035491943, 'total_epochs': 722, 'model_type': 'GraphSAGE'}
GraphSAGE loss: 3.312, acc: 0.932, prc: 0.529, rec: 0.929, f1: 0.674, auc: 0.945, aprc: 0.775
Making predictions...
Saving predictions...
Evaluating model: KGNN
Loading model from: data/models/pyg/weights-unregularised/KGNN.pt
Using model params: {'act': 'relu', 'add_self_loops': False, 'bias': True, 'gnn_aggr': 'max', 'hidden_channels_log2': 7, 'jk': 'none', 'num_layers': 5, 'to_hetero_agg

In [8]:
performance_comparison = pd.DataFrame.from_dict(model_metrics, orient="index")
performance_comparison.to_csv("reports/test-performance-pyg.csv", index_label="model")

In [9]:
model

GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (company__owns__company): GraphConv(-1, 128)
      (person__owns__company): GraphConv(-1, 128)
      (company__rev_owns__company): GraphConv(-1, 128)
      (company__rev_owns__person): GraphConv(-1, 128)
    )
    (1): ModuleDict(
      (company__owns__company): GraphConv(128, 128)
      (person__owns__company): GraphConv(128, 128)
      (company__rev_owns__company): GraphConv(128, 128)
      (company__rev_owns__person): GraphConv(128, 128)
    )
    (2): ModuleDict(
      (company__owns__company): GraphConv(128, 128)
      (person__owns__company): GraphConv(128, 128)
      (company__rev_owns__company): GraphConv(128, 128)
      (company__rev_owns__person): GraphConv(128, 128)
    )
    (3): ModuleDict(
      (company__owns__company): GraphConv(128, 128)
      (person__owns__company): GraphConv(128, 128)
      (company__rev_owns__company): GraphConv(128, 128)
      (company__rev_owns__person): GraphConv(128, 128)
    )
    