In [1]:
import os
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch
from tqdm import tqdm

from mscproject.datasets import CompanyBeneficialOwners
import mscproject.models as mod
import mscproject.experiment as exp

while not Path("data") in Path(".").iterdir():
    os.chdir("..")

In [2]:
study_names = (
    "pyg_model_selection_ALL",
    "pyg_model_selection_GCN",
    "pyg_model_selection_GraphSAGE",
    "pyg_model_selection_GAT",
    "pyg_model_selection_HGT",
    "pyg_model_selection_HAN",
)

In [3]:
from IPython.display import display

In [4]:
trials_dfs = [
    study.trials_dataframe().assign(study_name=study.study_name)
    for study in (
        optuna.load_study(study_name=study_name, storage="sqlite:///data/optuna.db")
        for study_name in study_names
    )
]
eval_df = pd.concat(trials_dfs, join="inner", axis=0)

In [5]:
for study in (
    optuna.load_study(study_name=study_name, storage="sqlite:///data/optuna.db")
    for study_name in study_names
):
    print(study.study_name)
    # Plot the results.
    mean_top_10_loss = study.trials_dataframe()["value"].sort_values().head(10).mean()
    print("Mean of top 10 best loss:", mean_top_10_loss)
    optuna.visualization.plot_optimization_history(study).show()
    # optuna.visualization.plot_contour(study).show()
    optuna.visualization.plot_slice(study).show()
    optuna.visualization.plot_param_importances(study).show()
    print()
    print()

pyg_model_selection_ALL
Mean of top 10 best loss: 1.093110477924347




pyg_model_selection_GCN
Mean of top 10 best loss: 1.061334228515625




pyg_model_selection_GraphSAGE
Mean of top 10 best loss: 1.07492436170578




pyg_model_selection_GAT
Mean of top 10 best loss: 1.105336558818817




pyg_model_selection_HGT
Mean of top 10 best loss: 1.120419991016388




pyg_model_selection_HAN
Mean of top 10 best loss: 1.0775286436080933






In [6]:
best = study.best_params

In [7]:
metric = "user_attrs_aprc"
first = True

best_trials = {}

for study_name, df in zip(study_names, trials_dfs):
    if first:
        first = False
        continue
    print(study_name)
    top = df.sort_values(metric, ascending=False)[:10]
    param_columns = [x for x in top.columns if x.startswith("params")]
    display(top[["value", metric, *param_columns]])

    model_type = study_name.split("_")[-1]
    best_trials[model_type] = top.iloc[0].to_dict()
    best_trials[model_type]["model_type"] = model_type

    print()

pyg_model_selection_GCN


Unnamed: 0,value,user_attrs_aprc,params_act,params_bias,params_dropout,params_edge_aggr,params_gcn_aggr,params_hidden_channels_log2,params_jk,params_n_layers,params_weight_decay
154,1.068621,0.300115,gelu,False,0.275822,mean,min,8,none,4,0.000446
146,1.053777,0.29865,gelu,False,0.308114,mean,min,8,none,4,0.000233
150,1.058767,0.296211,gelu,False,0.262661,mean,min,9,none,4,0.00023
180,1.070995,0.293452,gelu,False,0.180217,mean,min,8,none,4,0.000125
121,1.056951,0.292615,gelu,True,0.482475,min,min,8,none,4,0.000401
178,1.065492,0.292496,gelu,False,0.194984,mean,min,8,none,4,0.000146
172,1.084868,0.2916,gelu,False,0.244064,mean,min,8,none,4,3.7e-05
155,1.062847,0.291305,gelu,False,0.21559,mean,min,8,none,4,0.000356
183,1.059647,0.290809,gelu,False,0.140067,mean,min,8,none,4,0.000547
176,1.074374,0.290174,gelu,False,0.283315,mean,min,8,none,4,0.000123



pyg_model_selection_GraphSAGE


Unnamed: 0,value,user_attrs_aprc,params_act,params_dropout,params_edge_aggr,params_hidden_channels_log2,params_jk,params_n_layers,params_weight_decay
377,1.081923,0.301158,gelu,0.001452,max,7,none,6,3.9e-05
299,1.079405,0.298031,relu,0.018137,max,6,none,9,2.8e-05
467,1.076838,0.296652,gelu,0.059616,max,6,none,10,2e-05
460,1.076581,0.295407,gelu,0.090728,max,6,none,10,3.9e-05
400,1.083387,0.293883,gelu,0.06192,max,7,none,5,0.000121
434,1.093487,0.293736,gelu,0.02043,max,6,none,4,4.4e-05
424,1.070499,0.29302,gelu,0.028091,max,6,none,10,2.1e-05
407,1.083028,0.292475,gelu,0.049849,min,7,none,6,0.000139
145,1.074698,0.291238,gelu,0.127093,sum,7,none,5,0.000117
457,1.097094,0.290115,gelu,0.053415,max,6,none,10,4.7e-05



pyg_model_selection_GAT


Unnamed: 0,value,user_attrs_aprc,params_act,params_concat,params_dropout,params_edge_aggr,params_heads,params_hidden_channels_log2,params_jk,params_n_layers,params_weight_decay
129,1.108123,0.279968,gelu,True,0.018349,sum,8,5,last,3,0.000548
193,1.091401,0.279189,gelu,True,0.094579,mean,8,5,last,3,0.000259
199,1.092709,0.277845,gelu,True,0.071797,mean,8,5,last,3,0.000158
158,1.119289,0.276867,gelu,True,0.049397,sum,8,5,last,3,0.000315
150,1.109862,0.275258,gelu,True,0.036164,sum,8,5,last,3,0.000348
107,1.122961,0.275176,gelu,True,0.014991,sum,8,5,last,3,0.000417
112,1.111793,0.274302,gelu,True,0.017208,sum,8,5,last,3,0.000545
196,1.102922,0.271922,gelu,True,0.100224,mean,8,5,last,3,0.000215
167,1.115588,0.2708,gelu,True,0.100173,sum,8,5,last,3,7.2e-05
190,1.113185,0.269868,gelu,True,0.05298,mean,8,5,last,3,1.7e-05



pyg_model_selection_HGT


Unnamed: 0,value,user_attrs_aprc,params_act,params_dropout,params_edge_aggr,params_group,params_heads,params_hidden_channels_log2,params_jk,params_n_layers,params_weight_decay
197,1.117464,0.258622,relu,0.510325,mean,mean,16,4,last,1,9.3e-05
124,1.120706,0.2551,gelu,0.464841,mean,mean,16,4,none,1,0.000411
161,1.121687,0.253846,gelu,0.46304,mean,mean,16,4,none,1,0.000236
119,1.124539,0.252403,gelu,0.496201,mean,mean,16,4,none,1,0.000285
123,1.121111,0.250331,gelu,0.456334,mean,mean,16,4,none,1,0.00053
181,1.121398,0.247546,gelu,0.433734,mean,mean,16,4,none,1,0.000243
189,1.115456,0.246076,gelu,0.525532,mean,mean,16,4,none,1,0.000133
156,1.119323,0.245978,gelu,0.469792,mean,mean,16,4,none,1,0.000163
155,1.127324,0.245895,gelu,0.4725,mean,mean,16,4,none,1,0.00043
195,1.122329,0.245527,gelu,0.507693,mean,mean,16,4,none,1,0.000202



pyg_model_selection_HAN


Unnamed: 0,value,user_attrs_aprc,params_act,params_dropout,params_edge_aggr,params_han_dropout,params_heads,params_hidden_channels_log2,params_jk,params_n_layers,params_negative_slope,params_weight_decay
175,1.06572,0.294649,relu,0.676268,mean,0.739011,8,7,last,3,0.899143,7.700657e-05
169,1.08444,0.293395,relu,0.742444,mean,0.793413,8,7,last,3,0.857914,0.0001880315
193,1.070493,0.290482,relu,0.624195,mean,0.766473,8,7,last,3,0.853534,2.992104e-05
178,1.065092,0.28936,gelu,0.661528,mean,0.831568,8,7,last,3,0.902302,9.514345e-05
171,1.092896,0.286839,relu,0.656755,mean,0.794544,8,7,last,3,0.885047,0.000265046
150,1.08294,0.286418,relu,0.758099,mean,0.846684,8,7,last,3,0.856413,0.0001194791
173,1.091174,0.285874,relu,0.704942,mean,0.797343,8,7,last,3,0.846063,0.0003055602
155,1.0866,0.285226,relu,0.722577,mean,0.786789,8,7,last,3,0.863275,7.469814e-05
132,1.080382,0.284849,relu,0.637325,mean,0.760159,8,6,last,3,0.911293,0.0003098422
191,1.078625,0.284602,relu,0.619023,mean,0.765217,8,7,last,3,0.86975,1.881953e-07





In [8]:
best_trials

{'GCN': {'number': 154,
  'value': 1.0686213970184326,
  'datetime_start': Timestamp('2022-09-07 19:25:41.290970'),
  'datetime_complete': Timestamp('2022-09-07 19:26:07.215206'),
  'duration': Timedelta('0 days 00:00:25.924236'),
  'params_act': 'gelu',
  'params_bias': False,
  'params_dropout': 0.2758218102706141,
  'params_edge_aggr': 'mean',
  'params_gcn_aggr': 'min',
  'params_hidden_channels_log2': 8,
  'params_jk': 'none',
  'params_n_layers': 4,
  'params_weight_decay': 0.0004460764689980147,
  'user_attrs_acc': 0.9115312099456787,
  'user_attrs_aprc': 0.3001147210597992,
  'user_attrs_auc': 0.6966859698295593,
  'user_attrs_best_epoch': 43.0,
  'user_attrs_f1': 0.21509107947349548,
  'user_attrs_learning_rate': 0.01,
  'user_attrs_n_hidden': 256,
  'user_attrs_precision': 0.12685422599315643,
  'user_attrs_recall': 0.7065526843070984,
  'user_attrs_total_epochs': 53.0,
  'state': 'COMPLETE',
  'study_name': 'pyg_model_selection_GCN',
  'model_type': 'GCN'},
 'GraphSAGE': {'n

In [9]:
import mscproject.experiment as exp

In [10]:
# remove prefix from string
def remove_prefix(text, prefix):
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text

In [73]:
def build_experiment_from_trial_params(trial_params, dataset, verbose=False):
    param_dict = {
        remove_prefix(k, "params_"): v
        for k, v in trial_params.items()
        if k.startswith("params")
    }
    # Rename key from "n_layers" to "num_layers"
    if "n_layers" in param_dict:
        param_dict["num_layers"] = param_dict.pop("n_layers")
    param_dict["in_channels"] = -1
    param_dict["out_channels"] = 1
    param_dict["act_first"] = True
    param_dict["add_self_loops"] = True
    param_dict["model_type"] = mod.get_model(trial_params["model_type"])
    param_dict["v2"] = True
    lr = trial_params["user_attrs_learning_rate"]
    param_dict["jk"] = None if param_dict["jk"] == "none" else param_dict["jk"]
    if verbose:
        print(param_dict)
    return exp.get_model_and_optimiser(param_dict, dataset, lr)

In [74]:
dataset_path = "data/pyg/"

# Set the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {str(device).upper()}")

# Load the dataset.
dataset = CompanyBeneficialOwners(dataset_path, to_undirected=True)
dataset = dataset.data.to(device)

model_metrics = {}

models_dir = Path("models/pyg")
models_dir.mkdir(parents=True, exist_ok=True)

for model_name in best_trials.keys():
    print("Training model:", model_name)
    trial_dict = best_trials[model_name]
    model, optimiser = build_experiment_from_trial_params(
        trial_dict, dataset, verbose=True
    )

    # Train and evaluate the model.
    best_epoch = int(trial_dict["user_attrs_best_epoch"])

    progress = tqdm(range(best_epoch))

    for epoch in progress:
        loss = exp.train(model, dataset, optimiser, on_val=True)
        progress.set_description(f"Train loss: {loss:.4f}")

    eval_metrics = exp.evaluate(
        model, dataset, on_train=False, on_val=False, on_test=True
    )

    model_metrics[model_name] = eval_metrics.test
    print(eval_metrics.test)

    # Save the trained model.
    torch.save(model, f"models/pyg/{model_name}.pt")
    print()

In [69]:
import dataclasses as dc

In [70]:
pd.DataFrame.from_dict(model_metrics, orient="index")

Unnamed: 0,loss,accuracy,precision,recall,f1,auroc,average_precision
GCN,2.893258,0.908909,0.093082,0.422254,0.152538,0.561894,0.100976
GraphSAGE,7.443916,0.887946,0.093534,0.7903,0.167271,0.523927,0.090558
GAT,1.252797,0.912528,0.089199,0.500713,0.151424,0.501229,0.08956
HGT,1.241533,0.912528,0.088436,0.997147,0.162464,0.531433,0.10611
HAN,1.240035,0.912528,0.088799,0.998573,0.163094,0.567514,0.097231


In [44]:
pd.DataFrame.from_dict(dc.asdict(eval_metrics.test), orient="columns")

ValueError: If using all scalar values, you must pass an index