In [1]:
import os
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch
from tqdm import tqdm
from torch_geometric.transforms import AddSelfLoops
from mscproject.transforms import RemoveSelfLoops

from mscproject.datasets import CompanyBeneficialOwners
from mscproject.transforms import RemoveSelfLoops
import mscproject.models as mod
import mscproject.experiment as exp

while not Path("data") in Path(".").iterdir():
    os.chdir("..")

In [2]:
MODEL_DIR = Path("data/models/pyg/unregularised/")
OPTUNA_DB = Path("data/optuna-03.db")
DATASET_PATH = Path("data/pyg")
PREDICTION_DIR = Path("data/predictions")

model_names = [x.stem for x in MODEL_DIR.iterdir()]
model_names

['GraphSAGE', 'GCN']

In [3]:
def get_best_trial(model_name):
    study = optuna.load_study(
        study_name=f"pyg_model_selection_{model_name}_DESIGN",
        storage=f"sqlite:///{OPTUNA_DB}",
    )
    model_params = study.best_params
    user_attrs = study.best_trial.user_attrs
    return model_params, user_attrs

In [None]:
get_best_trial("GraphSAGE")

({'act': 'gelu',
  'add_self_loops': False,
  'edge_aggr': 'max',
  'hidden_channels_log2': 8,
  'jk': 'none',
  'num_layers': 3},
 {'acc': 0.9271753430366516,
  'aprc': 0.7167404890060425,
  'auc': 0.9486182332038879,
  'best_epoch': 445,
  'f1': 0.5279343724250793,
  'learning_rate': 0.01,
  'loss': 0.860064685344696,
  'n_hidden': 256,
  'precision': 0.3665480315685272,
  'recall': 0.9432234168052673,
  'total_epochs': 645})

In [4]:
get_best_trial("GCN")

({'act': 'gelu',
  'add_self_loops': True,
  'bias': False,
  'edge_aggr': 'min',
  'gcn_aggr': 'min',
  'hidden_channels_log2': 7,
  'jk': 'last',
  'num_layers': 4},
 {'acc': 0.9289156794548035,
  'aprc': 0.9664060473442078,
  'auc': 0.9966164231300354,
  'best_epoch': 503,
  'f1': 0.8083209991455078,
  'learning_rate': 0.01,
  'loss': 0.12577208876609802,
  'n_hidden': 128,
  'precision': 0.6800000071525574,
  'recall': 0.9963369965553284,
  'total_epochs': 703})

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_metrics = {}

# Load and evaluate models
for model_name in model_names:

    print("Evaluating model:", model_name)

    model_path = MODEL_DIR / f"{model_name}.pt"
    if not model_path.exists():
        print(f"Model {model_name} does not exist, skipping")
        continue

    model_params, user_attrs = get_best_trial(model_name)
    user_attrs["model_type"] = model_name

    dataset = CompanyBeneficialOwners(DATASET_PATH, to_undirected=True)
    dataset = dataset.data.to(device)

    model, optimiser, _ = exp.build_experiment_from_trial_params(
        model_params, user_attrs, dataset
    )
    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))

    model.to(device)

    if model_params["add_self_loops"]:
        dataset = AddSelfLoops(fill_value=1.0)(dataset)
    else:
        dataset = RemoveSelfLoops()(dataset)

    eval_metrics = exp.evaluate(
        model, dataset, on_train=False, on_val=False, on_test=True
    )

    model_metrics[model_name] = eval_metrics.test
    print(model_name, eval_metrics.test)

    print("Model structure:")
    print(model)

    print("Making predictions...")
    prediction_dict = model(dataset.x_dict, dataset.edge_index_dict)
    prediction_df_list = []

    print("Saving predictions...")
    for node_type in dataset.node_types:
        prediction = (
            prediction_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        actual = (
            dataset.y_dict[node_type][dataset[node_type].test_mask]
            .cpu()
            .detach()
            .numpy()
            .flatten()
        )
        df = pd.DataFrame({"pred_proba": prediction, "actual": actual})
        prediction_df_list.append(df)

    prediction_df = pd.concat(prediction_df_list)
    prediction_df.to_csv(PREDICTION_DIR / f"{model_name}.csv", index=False)

Evaluating model: GraphSAGE
GraphSAGE loss: 1.194, acc: 0.926, prc: 0.597, rec: 0.891, f1: 0.715, auc: 0.953, aprc: 0.766
Model structure:
GraphModule(
  (convs): ModuleList(
    (0): ModuleDict(
      (company__owns__company): SAGEConv(-1, 256, aggr=mean)
      (person__owns__company): SAGEConv(-1, 256, aggr=mean)
      (company__rev_owns__company): SAGEConv(-1, 256, aggr=mean)
      (company__rev_owns__person): SAGEConv(-1, 256, aggr=mean)
    )
    (1): ModuleDict(
      (company__owns__company): SAGEConv(256, 256, aggr=mean)
      (person__owns__company): SAGEConv(256, 256, aggr=mean)
      (company__rev_owns__company): SAGEConv(256, 256, aggr=mean)
      (company__rev_owns__person): SAGEConv(256, 256, aggr=mean)
    )
    (2): ModuleDict(
      (company__owns__company): SAGEConv(256, 1, aggr=mean)
      (person__owns__company): SAGEConv(256, 1, aggr=mean)
      (company__rev_owns__company): SAGEConv(256, 1, aggr=mean)
      (company__rev_owns__person): SAGEConv(256, 1, aggr=mean)


In [10]:
performance_comparison = pd.DataFrame.from_dict(model_metrics, orient="index")
performance_comparison.to_csv("reports/test-performance-pyg.csv", index_label="model")