In [158]:
import os
from pathlib import Path

import numpy as np
import optuna
import pandas as pd
import torch

from mscproject.datasets import CompanyBeneficialOwners
import mscproject.models as mod

while not Path("data") in Path(".").iterdir():
    os.chdir("..")

In [159]:
study_names = (
    "pyg_model_selection_ALL",
    "pyg_model_selection_GCN",
    "pyg_model_selection_GraphSAGE",
    "pyg_model_selection_GAT",
    "pyg_model_selection_HGT",
    "pyg_model_selection_HAN",
)

In [160]:
from IPython.display import display

In [161]:
trials_dfs = [
    study.trials_dataframe().assign(study_name=study.study_name)
    for study in (
        optuna.load_study(study_name=study_name, storage="sqlite:///data/optuna.db")
        for study_name in study_names
    )
]
eval_df = pd.concat(trials_dfs, join="inner", axis=0)


In [162]:
# Delete outlier trials
# study = optuna.load_study(study_name="pyg_model_selection_GraphSAGE", storage="sqlite:///data/optuna.db")
# study.tell(trial=118, state=optuna.trial.TrialState.FAIL)


In [163]:
study.trials_dataframe().query("number == 118")

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_act,params_dropout,params_edge_aggr,params_han_dropout,params_heads,...,user_attrs_aprc,user_attrs_auc,user_attrs_best_epoch,user_attrs_f1,user_attrs_learning_rate,user_attrs_n_hidden,user_attrs_precision,user_attrs_recall,user_attrs_total_epochs,state
118,118,1.141197,2022-09-07 21:51:51.882301,2022-09-07 21:52:04.448549,0 days 00:00:12.566248,relu,0.798258,mean,0.864482,8,...,0.263534,0.663456,67.0,0.205607,0.01,64,0.116446,0.877493,77,COMPLETE


In [171]:
for study in (
    optuna.load_study(study_name=study_name, storage="sqlite:///data/optuna.db")
    for study_name in study_names
):
    print(study.study_name)
    # Plot the results.
    mean_top_10_loss = study.trials_dataframe()["value"].sort_values().head(10).mean()
    print("Mean of top 10 best loss:", mean_top_10_loss)
    optuna.visualization.plot_optimization_history(study).show()
    # optuna.visualization.plot_contour(study).show()
    optuna.visualization.plot_slice(study).show()
    optuna.visualization.plot_param_importances(study).show()
    print()
    print()

pyg_model_selection_ALL
Mean of top 10 best loss:




pyg_model_selection_GCN
Mean of top 10 best loss:




pyg_model_selection_GraphSAGE
Mean of top 10 best loss:




pyg_model_selection_GAT
Mean of top 10 best loss:




pyg_model_selection_HGT
Mean of top 10 best loss:


In [165]:
best = study.best_params

In [166]:
best

{'act': 'gelu',
 'dropout': 0.6615282053968176,
 'edge_aggr': 'mean',
 'han_dropout': 0.8315675825340431,
 'heads': 8,
 'hidden_channels_log2': 7,
 'jk': 'last',
 'n_layers': 3,
 'negative_slope': 0.9023018164793171,
 'weight_decay': 9.51434471016692e-05}

In [167]:
dataset_path = "data/pyg/"

# Set the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the dataset.
dataset = CompanyBeneficialOwners(dataset_path, to_undirected=True)
dataset = dataset.data.to(device)

In [168]:
model = mod.get_model("HAN")

In [169]:
model(in_channels=-1, hidden_channels=64, num_layers=3, metadata=dataset.metadata())

KeyError: 'heads'

In [None]:
dataset

HeteroData(
  [1mcompany[0m={
    x=[96530, 32],
    y=[96530],
    train_mask=[96530],
    val_mask=[96530],
    test_mask=[96530],
    feature_names=[32]
  },
  [1mperson[0m={
    x=[32609, 18],
    y=[32609],
    train_mask=[32609],
    val_mask=[32609],
    test_mask=[32609],
    feature_names=[18]
  },
  [1m(company, owns, company)[0m={
    edge_index=[2, 54607],
    edge_attr=[54607, 1],
    train_mask=[54607],
    val_mask=[54607],
    test_mask=[54607]
  },
  [1m(person, owns, company)[0m={
    edge_index=[2, 80219],
    edge_attr=[80219, 1],
    train_mask=[80219],
    val_mask=[80219],
    test_mask=[80219]
  },
  [1m(company, rev_owns, company)[0m={
    edge_index=[2, 54607],
    edge_attr=[54607, 1],
    train_mask=[54607],
    val_mask=[54607],
    test_mask=[54607]
  },
  [1m(company, rev_owns, person)[0m={
    edge_index=[2, 80219],
    edge_attr=[80219, 1],
    train_mask=[80219],
    val_mask=[80219],
    test_mask=[80219]
  }
)