In [88]:
import os
import yaml

from pathlib import Path

import torch
import torch.nn.functional as F


from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_hetero


from mscproject.metrics import EvalMetrics, EvalMetricsTuple
from mscproject import models
from mscproject.datasets import CompanyBeneficialOwners

# TODO: regularisation like https://stackoverflow.com/questions/42704283/l1-l2-regularization-in-pytorch
# TODO: follow this example https://github.com/pyg-team/pytorch_geometric/issues/3958

# Move the working directory to the project root, i.e. the nearest directory
# (starting from the current one and walking upwards) that contains an entry
# named "data". All later relative paths ("config/conf.yaml", "data/pyg/")
# are resolved against it.
#
# The walk is bounded by the filesystem root: `os.chdir("..")` at the root is
# a no-op, so an unbounded `while` loop would spin forever when the notebook
# is started outside the project tree.
_start_dir = Path.cwd()
for _candidate in (_start_dir, *_start_dir.parents):
    # `lexists` matches any directory entry called "data" (file, directory or
    # symlink), mirroring a membership test against `iterdir()`.
    if os.path.lexists(_candidate / "data"):
        os.chdir(_candidate)
        break
else:
    raise FileNotFoundError(
        f"Could not find a 'data' entry in {_start_dir} or any of its parents."
    )

In [89]:
# Load the project configuration and the heterogeneous graph dataset.
conf_dict = yaml.safe_load(Path("config/conf.yaml").read_text())
dataset_path = "data/pyg/"

dataset = CompanyBeneficialOwners(dataset_path, to_undirected=True)

# The graph object that train()/test() read features, labels and masks from.
input_data = dataset[0]  # type: ignore
input_metadata = dataset.metadata()

model = models.GAT(
    in_channels=-1,  # -1: input sizes are resolved by the lazy-init forward pass below
    hidden_channels=16,
    num_layers=3,
    out_channels=1,
    jk="last",
    # heads=1,
    # concat=True,
    v2=True,
    add_self_loops=False,
)

# Convert the homogeneous model into one that operates on the node/edge types
# described by `input_metadata`.
model = to_hetero(model, metadata=input_metadata, aggr="sum")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train() and test() run the model on `input_data`, so it has to live on the
# same device as the model. Without this the graph stays on the CPU and every
# forward pass fails with a device mismatch as soon as CUDA is available.
input_data = input_data.to(device)

# NOTE: from here on `dataset` no longer refers to the dataset wrapper but to
# its underlying data object (moved to `device`); later cells index it by node
# type, e.g. dataset["company"].
dataset, model = dataset.data.to(device), model.to(device)

with torch.no_grad():  # Initialize lazy modules.
    out = model(dataset.x_dict, dataset.edge_index_dict)

optimizer = torch.optim.Adam(
    model.parameters(), lr=0.01, weight_decay=5e-4, amsgrad=False
)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=0.05, weight_decay=0)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0, nesterov=False, weight_decay=0)

In [90]:
def train(pos_weight_multiplier: float = 10) -> float:
    """Run one full-batch optimisation step and return the training loss.

    The module-level ``model`` is evaluated on ``input_data``; predictions and
    labels of the "company" and "person" nodes selected by their ``train_mask``
    are concatenated and scored with a weighted binary cross-entropy.

    Args:
        pos_weight_multiplier: Loss weight given to samples whose label is 1,
            relative to a weight of 1 for samples whose label is 0. The
            default of 10 matches the multiplier used by ``test()``.

    Returns:
        The loss value computed before the optimiser step, as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    out = model(input_data.x_dict, input_data.edge_index_dict)

    masked_outputs = []
    masked_labels = []
    for node_type in ("company", "person"):
        train_mask = input_data[node_type].train_mask
        masked_outputs.append(out[node_type][train_mask])
        masked_labels.append(input_data.y_dict[node_type][train_mask])

    # Flatten to 1-D so predictions and labels line up element by element.
    out_tensor = torch.cat(masked_outputs, dim=0).float().flatten()
    y_tensor = torch.cat(masked_labels, dim=0).float().flatten()

    # Weight is `pos_weight_multiplier` where y == 1 and 1 where y == 0.
    importance = y_tensor * (pos_weight_multiplier - 1) + 1

    # NOTE(review): binary_cross_entropy expects probabilities in [0, 1]; this
    # presumably relies on models.GAT applying a sigmoid - confirm.
    loss = F.binary_cross_entropy(out_tensor, y_tensor, weight=importance)
    loss.backward()
    optimizer.step()

    return float(loss)

In [91]:
from typing import NamedTuple


class EvalResult(NamedTuple):
    """Pair of evaluation metrics returned by ``test()``, one entry per data split."""

    train: EvalMetrics  # metrics over the nodes selected by `train_mask`
    val: EvalMetrics  # metrics over the nodes selected by `val_mask`


@torch.no_grad()
def test(pos_weight_multiplier: float = 10) -> EvalResult:
    """Evaluate the model on the training and validation splits.

    The module-level ``model`` is run once on ``input_data`` in eval mode. For
    each split, predictions and labels of the "company" and "person" nodes
    selected by that split's mask are concatenated and passed to
    ``EvalMetrics.from_tensors``.

    Args:
        pos_weight_multiplier: Forwarded unchanged to
            ``EvalMetrics.from_tensors``. The default of 10 is the value this
            function has always used.

    Returns:
        An ``EvalResult`` holding the metrics for the ``train_mask`` split and
        the ``val_mask`` split.
    """
    model.eval()

    prediction_dict = model(input_data.x_dict, input_data.edge_index_dict)

    metrics_by_split = {}

    for split in ("train_mask", "val_mask"):
        predictions = []
        actuals = []

        for node_type in ("company", "person"):
            mask = input_data[node_type][split]
            predictions.append(prediction_dict[node_type][mask])
            actuals.append(input_data.y_dict[node_type][mask])

        combined_predictions = torch.cat(predictions, dim=0).squeeze()
        combined_actuals = torch.cat(actuals, dim=0).squeeze()

        metrics_by_split[split] = EvalMetrics.from_tensors(
            combined_predictions,
            combined_actuals,
            pos_weight_multiplier=pos_weight_multiplier,
        )

    # Build the result by keyword so it does not depend on iteration order.
    return EvalResult(
        train=metrics_by_split["train_mask"],
        val=metrics_by_split["val_mask"],
    )

In [92]:
import numpy as np

# Early Stopping Callback
class EarlyStopping:
    """Signal that training should stop once validation loss stops improving.

    Call the instance once per epoch with that epoch's validation loss. After
    ``patience`` consecutive calls without an improvement, ``early_stop`` is
    set to True.

    Attributes:
        epoch: Number of times the instance has been called.
        counter: Consecutive calls without an improvement.
        best_score: Negated best validation loss seen so far (None before the
            first call).
        best_loss: Best validation loss seen so far (None before the first call).
        best_epoch: 1-based epoch at which ``best_loss`` was recorded (None
            before the first call).
        val_loss_min: Same value as ``best_loss``, but ``inf`` before the
            first call.
        early_stop: True once patience has been exhausted.
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints the patience counter each time the
                            validation loss fails to improve.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.epoch = 0
        self.counter = 0
        self.best_score = None
        self.best_loss = None
        self.best_epoch = None
        self.early_stop = False
        # Plain float instead of `np.Inf`, which was removed in NumPy 2.0.
        self.val_loss_min = float("inf")

    def __call__(self, val_loss):
        """Record one epoch's validation loss.

        Args:
            val_loss: Validation loss of the epoch that just finished.

        Returns:
            bool: The current value of ``early_stop``.
        """
        self.epoch += 1
        score = -val_loss

        no_improvement = (
            self.best_score is not None and score < self.best_score + self.delta
        )

        if no_improvement:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Also taken on the very first call, so the first epoch is
            # recorded as the initial best instead of leaving
            # best_loss/best_epoch unset.
            self.best_score = score
            self.best_loss = val_loss
            self.val_loss_min = val_loss
            self.best_epoch = self.epoch
            self.counter = 0

        return self.early_stop

In [93]:
# Train until the validation loss has not improved for `patience` epochs or
# `max_epochs` is reached, keeping the per-epoch metrics for later inspection.
early_stopping = EarlyStopping(patience=10, verbose=False)

metrics_history = []
max_epochs = 200

while not early_stopping.early_stop and early_stopping.epoch < max_epochs:
    loss = train()
    eval_metrics = test()
    val_loss = eval_metrics.val.loss
    # The stop flag is read from `early_stopping.early_stop` in the loop
    # condition, so the call's return value is not needed here.
    early_stopping(val_loss)
    metrics_history.append(eval_metrics)
    print(f"Epoch: {early_stopping.epoch:03d}")
    print(f"Train: {eval_metrics.train}")
    print(f"Val: {eval_metrics.val}")
    print("-" * 79)

# Derive the summary from the recorded history rather than from attributes of
# `early_stopping`, so it is also correct when the first epoch was the best
# one. Ties resolve to the earliest epoch.
best_epoch = None
best_val_loss = None
if metrics_history:
    best_index = min(
        range(len(metrics_history)),
        key=lambda i: metrics_history[i].val.loss,
    )
    best_epoch = best_index + 1  # epochs are reported 1-based
    best_val_loss = metrics_history[best_index].val.loss

print()
print("-" * 79)
print("Training complete!")
print(f"Best epoch: {best_epoch}")
print(f"Best validation loss: {best_val_loss}")
print("-" * 79)

Epoch: 001
Train: loss: 1.329, acc: 0.898, prc: 0.107, rec: 0.856, f1: 0.190, auc: 0.527, aprc: 0.108
Val: loss: 1.252, acc: 0.912, prc: 0.094, rec: 0.887, f1: 0.171, auc: 0.533, aprc: 0.093
-------------------------------------------------------------------------------
Epoch: 002
Train: loss: 1.323, acc: 0.898, prc: 0.107, rec: 0.930, f1: 0.191, auc: 0.545, aprc: 0.113
Val: loss: 1.250, acc: 0.912, prc: 0.094, rec: 0.952, f1: 0.171, auc: 0.550, aprc: 0.098
-------------------------------------------------------------------------------
Epoch: 003
Train: loss: 1.311, acc: 0.898, prc: 0.108, rec: 0.920, f1: 0.194, auc: 0.568, aprc: 0.124
Val: loss: 1.242, acc: 0.912, prc: 0.094, rec: 0.923, f1: 0.170, auc: 0.559, aprc: 0.103
-------------------------------------------------------------------------------
Epoch: 004
Train: loss: 1.305, acc: 0.898, prc: 0.111, rec: 0.865, f1: 0.197, auc: 0.582, aprc: 0.138
Val: loss: 1.237, acc: 0.912, prc: 0.097, rec: 0.877, f1: 0.174, auc: 0.556, aprc: 0.

In [94]:
# `dataset` was rebound in the setup cell to the underlying data object (not
# the dataset wrapper), so it is indexed by node type here.
# Show the feature names stored on the "company" node type.
dataset["company"].feature_names

['onehotencoder__CompanyStatus_Active - Proposal to Strike off__processed',
 'onehotencoder__CompanyStatus_None__processed',
 'onehotencoder__CompanyStatus_infrequent_sklearn__processed',
 'onehotencoder__Accounts_AccountCategory_DORMANT__processed',
 'onehotencoder__Accounts_AccountCategory_FULL__processed',
 'onehotencoder__Accounts_AccountCategory_GROUP__processed',
 'onehotencoder__Accounts_AccountCategory_MICRO ENTITY__processed',
 'onehotencoder__Accounts_AccountCategory_NO ACCOUNTS FILED__processed',
 'onehotencoder__Accounts_AccountCategory_SMALL__processed',
 'onehotencoder__Accounts_AccountCategory_TOTAL EXEMPTION FULL__processed',
 'onehotencoder__Accounts_AccountCategory_UNAUDITED ABRIDGED__processed',
 'onehotencoder__Accounts_AccountCategory_None__processed',
 'onehotencoder__Accounts_AccountCategory_infrequent_sklearn__processed',
 'onehotencoder__SICCode_SicText_1_41202 - Construction of domestic buildings__processed',
 'onehotencoder__SICCode_SicText_1_64209 - Activiti

In [95]:
# `dataset` was rebound in the setup cell to the underlying data object (not
# the dataset wrapper), so it is indexed by node type here.
# Show the feature names stored on the "person" node type.
dataset["person"].feature_names

['onehotencoder__nationality_BE__processed',
 'onehotencoder__nationality_CA__processed',
 'onehotencoder__nationality_CH__processed',
 'onehotencoder__nationality_DE__processed',
 'onehotencoder__nationality_ES__processed',
 'onehotencoder__nationality_GB__processed',
 'onehotencoder__nationality_IE__processed',
 'onehotencoder__nationality_PH__processed',
 'onehotencoder__nationality_PL__processed',
 'onehotencoder__nationality_ZA__processed',
 'onehotencoder__nationality_None__processed',
 'onehotencoder__nationality_infrequent_sklearn__processed',
 'birthDate__processed',
 'indegree__processed',
 'outdegree__processed',
 'closeness__processed',
 'clustering__processed',
 'pagerank__processed']