## Preliminary

In [1]:
!pip uninstall -y torch-geometric torch-sparse torch-scatter torch-cluster pyg-lib

[0m

In [2]:
!pip install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

Looking in indexes: https://download.pytorch.org/whl/cu124


In [3]:
!pip install torch-geometric \
  torch-sparse \
  torch-scatter \
  torch-cluster \
  pyg-lib \
  -f https://data.pyg.org/whl/torch-2.5.1+cu124.html

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m81.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_cluster-1.6.3%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (3.4 M

In [81]:
import pandas as pd
import numpy as np
import itertools
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import GCNConv, GINEConv, BatchNorm, Linear, GATConv, PNAConv, RGCNConv, summary
from torch_geometric.data import Data, HeteroData
from torch_geometric.typing import OptTensor
from torch_geometric.utils import degree
from torch_geometric.transforms import BaseTransform
from torch_geometric.loader import LinkNeighborLoader

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score

import tqdm
import os
import sys
import random
import json
from typing import Union
from google.colab import drive

content_base = "/content/drive"
drive.mount(content_base)

data_dir = os.path.join(content_base, "My Drive/Capstone/data")
data_file = os.path.join(data_dir, "HI-Small_25.csv")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from types import SimpleNamespace

args = SimpleNamespace(
    # Adaptations
    emlps=False,
    ports=False,
    tds=False,
    ego=False,

    # Model parameters
    batch_size=8192,
    n_epochs=100,
    num_neighs=[100, 100],

    # Misc
    seed=1,
    tqdm=False,
    data='Small_HI',
    model='gin',
    testing=False,
    save_model=False,
    unique_name=False,
    finetune=False,
    inference=False,
    avg_tps=False
)


## Formatting Data

Do not run this part if data formatting process has been performed. Load the formatted data in the next part.

In [None]:
inPath = data_file
outPath = os.path.join(data_dir, "Formatted-HI-Small_25.csv")

In [None]:
raw = pd.read_csv(inPath, dtype=str)
raw.shape

In [None]:
currency = dict()
paymentFormat = dict()
bankAcc = dict()
account = dict()

def get_dict_val(name, collection):
    if name in collection:
        val = collection[name]
    else:
        val = len(collection)
        collection[name] = val
    return val

header = "EdgeID,from_id,to_id,Timestamp,\
Amount Sent,Sent Currency,Amount Received,Received Currency,\
Payment Format,Is Laundering\n"

firstTs = -1

In [None]:
with open(outPath, 'w') as writer:
    writer.write(header)

    for i, row in raw.iterrows():
        datetime_object = datetime.strptime(row["Timestamp"], '%Y/%m/%d %H:%M')

        # Extracting timestamp elements
        ts = datetime_object.timestamp()
        day = datetime_object.day
        month = datetime_object.month
        year = datetime_object.year
        hour = datetime_object.hour
        minute = datetime_object.minute

        if firstTs == -1:
            startTime = datetime(year, month, day)
            firstTs = startTime.timestamp() - 10

        ts = ts - firstTs

        cur1 = get_dict_val(row["Receiving Currency"], currency)
        cur2 = get_dict_val(row["Payment Currency"], currency)

        fmt = get_dict_val(row["Payment Format"], paymentFormat)

        fromAccIdStr = row["From Bank"] + row.iloc[2]
        fromId = get_dict_val(fromAccIdStr, account)

        toAccIdStr = row["To Bank"] + row.iloc[4]
        toId = get_dict_val(toAccIdStr, account)

        amountReceivedOrig = float(row["Amount Received"])
        amountPaidOrig = float(row["Amount Paid"])

        isl = int(row["Is Laundering"])

        line = f'{i},{fromId},{toId},{ts},{amountPaidOrig},{cur2},{amountReceivedOrig},{cur1},{fmt},{isl}\n'
        writer.write(line)

formatted = pd.read_csv(outPath)
formatted = formatted.sort_values(by="Timestamp")
formatted.to_csv(outPath, index=False)

## Custom Workflow

In [60]:
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import json

class CustomPreprocessingPipeline:

    def __init__(self, dataset_path: str):
        """
        Initialize pipeline with dataset
        """
        self.dataset_path = dataset_path
        self.df = pd.read_csv(self.dataset_path, dtype=str)
        self.currency_map = dict()
        self.pmt_map = dict()
        self.currency_conv = dict()

        # Track if preprocessing steps have been completed
        self.preprocessed = {
            "formatted": False,
            "duplicates_removed": False,
            "currency_normalized": False,
            "time_features_extracted": False,
            "cyclical_encoded": False,
            "weekend_encoded": False,
        }

    def df_summary(self):
        print("DATA HEAD")
        display(self.df.head())
        print("\nFEATURE TYPE")
        display(self.df.info())

    def format_transactions(self, formatted=False):
        if formatted:
            print("Fetching formatted transactions...")
            self.df = pd.read_csv(os.path.join(data_dir, "Formatted-HI-Small_25.csv"))

            with open(os.path.join(data_dir, "currency_map.json"), "r") as f:
                self.currency_map = json.load(f)
            with open(os.path.join(data_dir, "pmt_map.json"), "r") as f:
                self.pmt_map = json.load(f)
        else:
            print("Formatting transactions...")
            account = dict()

            def get_dict_val(name, collection):
                if name in collection:
                    val = collection[name]
                else:
                    val = len(collection)
                    collection[name] = val
                return val

            firstTs = -1
            processed_rows = []

            for i, row in self.df.iterrows():
                datetime_object = datetime.strptime(row["Timestamp"], '%Y/%m/%d %H:%M')

                ts = datetime_object.timestamp()
                day = datetime_object.day
                month = datetime_object.month
                year = datetime_object.year

                if firstTs == -1:
                    startTime = datetime(year, month, day)
                    firstTs = startTime.timestamp() - 10

                ts = ts - firstTs

                cur_received = get_dict_val(row["Receiving Currency"], self.currency_map)
                cur_sent = get_dict_val(row["Payment Currency"], self.currency_map)
                fmt = get_dict_val(row["Payment Format"], self.pmt_map)

                from_account_str = row["From Bank"] + row.iloc[2]
                to_account_str = row["To Bank"] + row.iloc[4]

                from_id = get_dict_val(from_account_str, account)
                to_id = get_dict_val(to_account_str, account)

                amount_received = float(row["Amount Received"])
                amount_sent = float(row["Amount Paid"])
                is_laundering = int(row["Is Laundering"])

                processed_rows.append([
                    i, from_id, to_id, ts, amount_sent, cur_sent,
                    amount_received, cur_received, fmt, is_laundering
                ])

            self.df = pd.DataFrame(processed_rows, columns=[
                          "edge_id", "from_id", "to_id", "timestamp", "sent_amount",
                          "sent_currency", "received_amount", "received_currency",
                          "payment_type", "is_laundering"
                      ]).sort_values(by="timestamp").reset_index(drop=True)

            self.df.to_csv(os.path.join(data_dir, "Formatted-HI-Small_25.csv"), index=False)

            with open(os.path.join(data_dir, "currency_map.json"), "w") as f:
                json.dump(self.currency_map, f, indent=4)

            with open(os.path.join(data_dir, "pmt_map.json"), "w") as f:
                json.dump(self.pmt_map, f, indent=4)

        self.preprocessed["formatted"] = True

    def drop_duplicates(self):
        self.df.drop_duplicates(inplace=True)
        self.preprocessed["duplicates_removed"] = True

    def get_usd_conversion(self) -> dict[str, float]:

        currencies = set()
        currency_conversion = {}

        with open(self.dataset_path, "r", encoding="utf-8") as file:
            header = True

            for line in file:
                if header:
                    header = False
                    continue

                columns = line.strip().split(",")

                sent_amount = columns[7]
                sent_currency = columns[8]
                received_amount = columns[5]
                received_currency = columns[6]

                currencies.add(sent_currency)
                currencies.add(received_currency)

                conversion_rate = float(received_amount) / float(sent_amount)

                if sent_currency not in currency_conversion:
                    currency_conversion[sent_currency] = {sent_currency: 1.0}

                currency_conversion[sent_currency][received_currency] = conversion_rate

        usd_conversion = currency_conversion.get("US Dollar", {})

        if set(usd_conversion.keys()) == currencies:
            return usd_conversion

    def currency_normalization(self):
        print("Normalizing currency...")
        if "sent_currency" not in self.df.columns or "received_currency" not in self.df.columns:
            raise KeyError(
                "Currency columns missing. Need to run 'rename_columns' "
                "preprocessing step first."
            )

        self.currency_conv = self.get_usd_conversion()
        reverse_currency_map = {v: k for k, v in self.currency_map.items()}

        self.df["sent_amount_usd"] = self.df.apply(
            lambda row: row["sent_amount"] * self.currency_conv.get(reverse_currency_map[row["sent_currency"]], 1),
            axis=1,
        )
        self.df["received_amount_usd"] = self.df.apply(
            lambda row: row["received_amount"] * self.currency_conv.get(reverse_currency_map[row["received_currency"]], 1),
            axis=1,
        )
        self.preprocessed["currency_normalized"] = True

    def extract_time_features(self):
        print("Extracting time features...")
        if "timestamp" not in self.df.columns:
            raise KeyError(
                "Missing 'timestamp' column, were columns renamed properly?"
            )
        if not isinstance(self.df["timestamp"], datetime):
            self.df["timestamp_copy"] = pd.to_datetime(self.df["timestamp"])

        # Extract items from timestamp
        self.df["hour_of_day"] = self.df["timestamp_copy"].dt.hour
        self.df["day_of_week"] = self.df["timestamp_copy"].dt.weekday # 0=Monday,...,6=Sunday
        self.df["seconds_since_midnight"] = (
            self.df["timestamp_copy"].dt.hour * 3600 +  # Convert hours to seconds
            self.df["timestamp_copy"].dt.minute * 60 +  # Convert minutes to seconds
            self.df["timestamp_copy"].dt.second         # Keep seconds
        )

        # Transform timestamp to raw int unix
        # self.df["timestamp_int"] = self.df["timestamp"].astype(int) / 10**9

        # Just a temp assignment, will be scaled later on
        # self.df["timestamp_scaled"] = self.df["timestamp"].astype(int) / 10**9

        self.df.drop(columns=["timestamp_copy"], inplace= True)

        self.preprocessed["time_features_extracted"] = True

    def cyclical_encoding(self):
        print("Adding cyclical encoding to time feats...")

        if not self.preprocessed["time_features_extracted"]:
            raise RuntimeError("Time features missing, run `extract_time_features` first.")

        self.df["day_sin"] = np.sin(2 * np.pi * self.df["day_of_week"] / 7)
        self.df["day_cos"] = np.cos(2 * np.pi * self.df["day_of_week"] / 7)
        self.df["time_of_day_sin"] = np.sin(2 * np.pi * self.df["seconds_since_midnight"] / 86400)
        self.df["time_of_day_cos"] = np.cos(2 * np.pi * self.df["seconds_since_midnight"] / 86400)

        self.preprocessed["cyclical_encoded"] = True

    def binary_weekend(self):
        if "day_of_week" not in self.df.columns:
            raise KeyError("Day-of-week feature missing. Run `extract_time_features` first.")
        self.df["is_weekend"] = self.df["day_of_week"].isin([5, 6]).astype(int)
        self.preprocessed["weekend_encoded"] = True

    def run_preprocessing(self, formatted=False):
        """Runs all preprocessing steps in the correct order.
           Option to not include graph_feats calculation (takes long time)
        """
        print("Running preprocessing pipeline...\n")

        try:
            self.format_transactions(formatted)
            self.drop_duplicates()
            self.currency_normalization()
            self.extract_time_features()
            self.cyclical_encoding()
            self.binary_weekend()

            print("Preprocessing completed successfully!")
            print(self.preprocessed)

        except Exception as e:
            print(f"Error in preprocessing: {e}")


In [61]:
pl = CustomPreprocessingPipeline(os.path.join(data_dir, "HI-Small_25.csv"))
pl.run_preprocessing(formatted=True)

In [62]:
pl.df.head(10)

Unnamed: 0,edge_id,from_id,to_id,timestamp,sent_amount,sent_currency,received_amount,received_currency,payment_type,is_laundering,sent_amount_usd,received_amount_usd,hour_of_day,day_of_week,seconds_since_midnight,day_sin,day_cos,time_of_day_sin,time_of_day_cos,is_weekend
0,20203,16239,16239,10.0,1015540.22,0,1015540.22,0,0,0,1015540.0,1015540.0,0,3,0,0.433884,-0.900969,0.0,1.0,0
1,31178,25443,25443,10.0,71498.91,1,71498.91,1,0,0,61019.11,61019.11,0,3,0,0.433884,-0.900969,0.0,1.0,0
2,57430,47062,47062,10.0,780.42,2,780.42,2,0,0,1102.577,1102.577,0,3,0,0.433884,-0.900969,0.0,1.0,0
3,2638,2035,2035,10.0,39994.79,0,39994.79,0,0,0,39994.79,39994.79,0,3,0,0.433884,-0.900969,0.0,1.0,0
4,2647,2043,2043,10.0,4785.2,0,4785.2,0,0,0,4785.2,4785.2,0,3,0,0.433884,-0.900969,0.0,1.0,0
5,57414,47049,47049,10.0,95321.25,2,95321.25,2,0,0,134669.9,134669.9,0,3,0,0.433884,-0.900969,0.0,1.0,0
6,2584,82,1998,10.0,14777.01,0,14777.01,0,2,0,14777.01,14777.01,0,3,0,0.433884,-0.900969,0.0,1.0,0
7,57478,47102,47103,10.0,40.19,2,40.19,2,2,0,56.78043,56.78043,0,3,0,0.433884,-0.900969,0.0,1.0,0
8,20215,6729,16248,10.0,314074.32,0,314074.32,0,5,0,314074.3,314074.3,0,3,0,0.433884,-0.900969,0.0,1.0,0
9,20216,16249,16250,10.0,246.0,0,246.0,0,1,0,246.0,246.0,0,3,0,0.433884,-0.900969,0.0,1.0,0


In [63]:
pl.df.columns

Index(['edge_id', 'from_id', 'to_id', 'timestamp', 'sent_amount',
       'sent_currency', 'received_amount', 'received_currency', 'payment_type',
       'is_laundering', 'sent_amount_usd', 'received_amount_usd',
       'hour_of_day', 'day_of_week', 'seconds_since_midnight', 'day_sin',
       'day_cos', 'time_of_day_sin', 'time_of_day_cos', 'is_weekend'],
      dtype='object')

In [64]:
pl.df.to_csv(os.path.join(data_dir, "Custom-Formatted-HI-Small_25.csv"), index=False)

## Preprocessing and Data Loading

Load the formatted data if it exists in your drive.

In [65]:
def set_seed(seed: int = 0) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

In [66]:
def to_adj_nodes_with_times(data):
    num_nodes = data.num_nodes
    timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1))
    edges = torch.cat((data.edge_index.T, timestamps), dim=1)
    adj_list_out = dict([(i, []) for i in range(num_nodes)])
    adj_list_in = dict([(i, []) for i in range(num_nodes)])
    for u,v,t in edges:
        u,v,t = int(u), int(v), int(t)
        adj_list_out[u] += [(v, t)]
        adj_list_in[v] += [(u, t)]
    return adj_list_in, adj_list_out

def to_adj_edges_with_times(data):
    num_nodes = data.num_nodes
    timestamps = torch.zeros((data.edge_index.shape[1], 1)) if data.timestamps is None else data.timestamps.reshape((-1,1))
    edges = torch.cat((data.edge_index.T, timestamps), dim=1)
    # calculate adjacent edges with times per node
    adj_edges_out = dict([(i, []) for i in range(num_nodes)])
    adj_edges_in = dict([(i, []) for i in range(num_nodes)])
    for i, (u,v,t) in enumerate(edges):
        u,v,t = int(u), int(v), int(t)
        adj_edges_out[u] += [(i, v, t)]
        adj_edges_in[v] += [(i, u, t)]
    return adj_edges_in, adj_edges_out

def ports(edge_index, adj_list):
    ports = torch.zeros(edge_index.shape[1], 1)
    ports_dict = {}
    for v, nbs in adj_list.items():
        if len(nbs) < 1: continue
        a = np.array(nbs)
        a = a[a[:, -1].argsort()]
        _, idx = np.unique(a[:,[0]],return_index=True,axis=0)
        nbs_unique = a[np.sort(idx)][:,0]
        for i, u in enumerate(nbs_unique):
            ports_dict[(u,v)] = i
    for i, e in enumerate(edge_index.T):
        ports[i] = ports_dict[tuple(e.numpy())]
    return ports

def time_deltas(data, adj_edges_list):
    time_deltas = torch.zeros(data.edge_index.shape[1], 1)
    if data.timestamps is None:
        return time_deltas
    for v, edges in adj_edges_list.items():
        if len(edges) < 1: continue
        a = np.array(edges)
        a = a[a[:, -1].argsort()]
        a_tds = [0] + [a[i+1,-1] - a[i,-1] for i in range(a.shape[0]-1)]
        tds = np.hstack((a[:,0].reshape(-1,1), np.array(a_tds).reshape(-1,1)))
        for i,td in tds:
            time_deltas[i] = td
    return time_deltas

class GraphData(Data):
    '''This is the homogenous graph object we use for GNN training if reverse MP is not enabled'''
    def __init__(
        self, x: OptTensor = None, edge_index: OptTensor = None, edge_attr: OptTensor = None, y: OptTensor = None, pos: OptTensor = None,
        readout: str = 'edge',
        num_nodes: int = None,
        timestamps: OptTensor = None,
        node_timestamps: OptTensor = None,
        **kwargs
      ):

        super().__init__(x, edge_index, edge_attr, y, pos, **kwargs)
        self.readout = readout
        self.loss_fn = 'ce'
        self.num_nodes = int(self.x.shape[0])
        self.node_timestamps = node_timestamps
        if timestamps is not None:
            self.timestamps = timestamps
        elif edge_attr is not None:
            self.timestamps = edge_attr[:,0].clone()
        else:
            self.timestamps = None

    def add_ports(self):
        '''Adds port numberings to the edge features'''
        reverse_ports = True

        adj_list_in, adj_list_out = to_adj_nodes_with_times(self)
        in_ports = ports(self.edge_index, adj_list_in)
        out_ports = [ports(self.edge_index.flipud(), adj_list_out)] if reverse_ports else []

        self.edge_attr = torch.cat([self.edge_attr, in_ports] + out_ports, dim=1)

        return self

    def add_time_deltas(self):
        '''Adds time deltas (i.e. the time between subsequent transactions) to the edge features'''
        reverse_tds = True

        adj_list_in, adj_list_out = to_adj_edges_with_times(self)
        in_tds = time_deltas(self, adj_list_in)
        out_tds = [time_deltas(self, adj_list_out)] if reverse_tds else []

        self.edge_attr = torch.cat([self.edge_attr, in_tds] + out_tds, dim=1)

        return self

def z_norm(data):
    std = data.std(0).unsqueeze(0)
    std = torch.where(std == 0, torch.tensor(1, dtype=torch.float32).cpu(), std)
    return (data - data.mean(0).unsqueeze(0)) / std

In [73]:
def get_data(args):
    '''Loads the AML transaction data.

    1. The data is loaded from the csv and the necessary features are chosen.
    2. The data is split into training, validation and test data.
    3. PyG Data objects are created with the respective data splits.
    '''

    # Load the formatted data
    formatted_data_file = os.path.join(data_dir, "Custom-Formatted-HI-Small_25.csv")
    df_edges = pd.read_csv(formatted_data_file)

    # Normalize timestamp
    # df_edges['Timestamp'] = df_edges['Timestamp'] - df_edges['Timestamp'].min()

    # Building data object (nodes, edges)
    max_n_id = df_edges.loc[:, ['from_id', 'to_id']].to_numpy().max() + 1
    df_nodes = pd.DataFrame({'NodeID': np.arange(max_n_id), 'Feature': np.ones(max_n_id)})

    timestamps = torch.Tensor(df_edges['timestamp'].to_numpy())
    y = torch.LongTensor(df_edges['is_laundering'].to_numpy())

    print(f"Illicit ratio = {sum(y)} / {len(y)} = {sum(y) / len(y) * 100:.2f}%")
    print(f"Number of nodes (holdings doing transcations) = {df_nodes.shape[0]}")
    print(f"Number of transactions = {df_edges.shape[0]}")

    # edge_features = ['Timestamp', 'Amount Received', 'Received Currency', 'Payment Format']
    edge_features = ['timestamp', 'sent_amount',
       'sent_currency', 'received_amount', 'received_currency', 'payment_type',
       'is_laundering', 'sent_amount_usd', 'received_amount_usd',
       'hour_of_day', 'day_of_week', 'seconds_since_midnight', 'day_sin',
       'day_cos', 'time_of_day_sin', 'time_of_day_cos', 'is_weekend']

    node_features = ['Feature']

    print(f'\nEdge features being used: {edge_features}')
    print(f'Node features being used: {node_features} ("Feature" is a placeholder feature of all 1s)')

    x = torch.tensor(df_nodes.loc[:, node_features].to_numpy()).float()
    edge_index = torch.LongTensor(df_edges.loc[:, ['from_id', 'to_id']].to_numpy().T)
    edge_attr = torch.tensor(df_edges.loc[:, edge_features].to_numpy()).float()

    n_days = int(timestamps.max() / (3600 * 24) + 1)
    n_samples = y.shape[0]
    print(f'\nnumber of days and transactions in the data: {n_days} days, {n_samples} transactions')

    # Data Splitting (temporal aggregation)
    daily_irs, weighted_daily_irs, daily_inds, daily_trans = [], [], [], [] #irs = illicit ratios, inds = indices, trans = transactions
    for day in range(n_days):
        l = day * 24 * 3600
        r = (day + 1) * 24 * 3600
        day_inds = torch.where((timestamps >= l) & (timestamps < r))[0]
        daily_irs.append(y[day_inds].float().mean())
        weighted_daily_irs.append(y[day_inds].float().mean() * day_inds.shape[0] / n_samples)
        daily_inds.append(day_inds)
        daily_trans.append(day_inds.shape[0])

    split_per = [0.6, 0.2, 0.2]
    daily_totals = np.array(daily_trans)
    d_ts = daily_totals
    I = list(range(len(d_ts)))
    split_scores = dict()

    for i,j in itertools.combinations(I, 2):
        if j >= i:
            split_totals = [d_ts[:i].sum(), d_ts[i:j].sum(), d_ts[j:].sum()]
            split_totals_sum = np.sum(split_totals)
            split_props = [v/split_totals_sum for v in split_totals]
            split_error = [abs(v-t)/t for v,t in zip(split_props, split_per)]
            score = max(split_error) #- (split_totals_sum/total) + 1
            split_scores[(i,j)] = score
        else:
            continue

    i, j = min(split_scores, key=split_scores.get)
    # split contains a list for each split (train, validation and test) and each list contains the days that are part of the respective split
    split = [list(range(i)), list(range(i, j)), list(range(j, len(daily_totals)))]
    print(f'\nCalculate split: {split}')

    # Now, we seperate the transactions based on their indices in the timestamp array
    split_inds = {k: [] for k in range(3)}
    for i in range(3):
        for day in split[i]:
            split_inds[i].append(daily_inds[day]) #split_inds contains a list for each split (tr,val,te) which contains the indices of each day seperately

    tr_inds = torch.cat(split_inds[0])
    val_inds = torch.cat(split_inds[1])
    te_inds = torch.cat(split_inds[2])

    print(f"\nTrain indicces shape: {tr_inds.shape}")
    print(f"Validation indicces shape: {val_inds.shape}")
    print(f"Test indicces shape: {te_inds.shape}")

    print(f"\nTotal train samples: {tr_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
            f"{y[tr_inds].float().mean() * 100 :.2f}% || Train days: {split[0][:5]}")
    print(f"Total val samples: {val_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
        f"{y[val_inds].float().mean() * 100:.2f}% || Val days: {split[1][:5]}")
    print(f"Total test samples: {te_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
        f"{y[te_inds].float().mean() * 100:.2f}% || Test days: {split[2][:5]}")

    # Creating the final data objects
    tr_x, val_x, te_x = x, x, x
    e_tr = tr_inds.numpy()
    e_val = np.concatenate([tr_inds, val_inds])

    tr_edge_index,  tr_edge_attr,  tr_y,  tr_edge_times  = edge_index[:,e_tr],  edge_attr[e_tr],  y[e_tr],  timestamps[e_tr]
    val_edge_index, val_edge_attr, val_y, val_edge_times = edge_index[:,e_val], edge_attr[e_val], y[e_val], timestamps[e_val]
    te_edge_index,  te_edge_attr,  te_y,  te_edge_times  = edge_index,          edge_attr,        y,        timestamps

    tr_data = GraphData (x=tr_x,  y=tr_y,  edge_index=tr_edge_index,  edge_attr=tr_edge_attr,  timestamps=tr_edge_times )
    val_data = GraphData(x=val_x, y=val_y, edge_index=val_edge_index, edge_attr=val_edge_attr, timestamps=val_edge_times)
    te_data = GraphData (x=te_x,  y=te_y,  edge_index=te_edge_index,  edge_attr=te_edge_attr,  timestamps=te_edge_times )

    # Adding ports and time-deltas if applicable
    if args.ports:
        print(f"\nStart: adding ports")
        tr_data.add_ports()
        val_data.add_ports()
        te_data.add_ports()
        print(f"Done: adding ports")

    if args.tds:
        print(f"\nStart: adding time-deltas")
        tr_data.add_time_deltas()
        val_data.add_time_deltas()
        te_data.add_time_deltas()
        print(f"Done: adding time-deltas")

    # Normalize data
    tr_data.x = val_data.x = te_data.x = z_norm(tr_data.x)
    if not args.model == 'rgcn':
        tr_data.edge_attr, val_data.edge_attr, te_data.edge_attr = z_norm(tr_data.edge_attr), z_norm(val_data.edge_attr), z_norm(te_data.edge_attr)
    else:
        tr_data.edge_attr[:, :-1], val_data.edge_attr[:, :-1], te_data.edge_attr[:, :-1] = z_norm(tr_data.edge_attr[:, :-1]), z_norm(val_data.edge_attr[:, :-1]), z_norm(te_data.edge_attr[:, :-1])

    print(f'\ntrain data object: {tr_data}')
    print(f'validation data object: {val_data}')
    print(f'test data object: {te_data}')

    return tr_data, val_data, te_data, tr_inds, val_inds, te_inds

## Models

In [74]:
class GINe(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2,
                n_hidden=100, edge_updates=False, residual=True,
                edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()

        self.n_hidden = n_hidden
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.final_dropout = final_dropout

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for _ in range(self.num_gnn_layers):
            conv = GINEConv(nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden)
                ), edge_dim=self.n_hidden)

            if self.edge_updates: self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50),
                                 nn.ReLU(),
                                 nn.Dropout(self.final_dropout),
                                 Linear(50, 25),
                                 nn.ReLU(),
                                 nn.Dropout(self.final_dropout),
                                 Linear(25, n_classes))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates:
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        out = x

        return self.mlp(out)

class GATe(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2, n_hidden=100, n_heads=4, edge_updates=False, edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()
        # GAT specific code
        tmp_out = n_hidden // n_heads
        n_hidden = tmp_out * n_heads

        self.n_hidden = n_hidden
        self.n_heads = n_heads
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.dropout = dropout
        self.final_dropout = final_dropout

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for _ in range(self.num_gnn_layers):
            conv = GATConv(self.n_hidden, tmp_out, self.n_heads, concat = True, dropout = self.dropout, add_self_loops = True, edge_dim=self.n_hidden)
            if self.edge_updates: self.emlps.append(nn.Sequential(nn.Linear(3 * self.n_hidden, self.n_hidden),nn.ReLU(),nn.Linear(self.n_hidden, self.n_hidden),))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(25, n_classes))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates:
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

        # logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        # logging.debug(f"x.shape = {x.shape}")
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        # logging.debug(f"x.shape = {x.shape}")
        out = x

        return self.mlp(out)

class PNA(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2,
                n_hidden=100, edge_updates=True,
                edge_dim=None, dropout=0.0, final_dropout=0.5, deg=None):
        super().__init__()
        n_hidden = int((n_hidden // 5) * 5)
        self.n_hidden = n_hidden
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.final_dropout = final_dropout

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(self.num_gnn_layers):
            conv = PNAConv(in_channels=n_hidden, out_channels=n_hidden,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=n_hidden, towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            if self.edge_updates: self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
                              Linear(25, n_classes))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates:
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

        # logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        # logging.debug(f"x.shape = {x.shape}")
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        # logging.debug(f"x.shape = {x.shape}")
        out = x
        return self.mlp(out)

class RGCN(nn.Module):
    def __init__(self, num_features, edge_dim, num_relations, num_gnn_layers, n_classes=2,
                n_hidden=100, edge_update=False,
                residual=True,
                dropout=0.0, final_dropout=0.5, n_bases=-1):
        super(RGCN, self).__init__()

        self.num_features = num_features
        self.num_gnn_layers = num_gnn_layers
        self.n_hidden = n_hidden
        self.residual = residual
        self.dropout = dropout
        self.final_dropout = final_dropout
        self.n_classes = n_classes
        self.edge_update = edge_update
        self.num_relations = num_relations
        self.n_bases = n_bases

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.mlp = nn.ModuleList()

        if self.edge_update:
            self.emlps = nn.ModuleList()
            self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))

        for _ in range(self.num_gnn_layers):
            conv = RGCNConv(self.n_hidden, self.n_hidden, num_relations, num_bases=self.n_bases)
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(self.n_hidden))

            if self.edge_update:
                self.emlps.append(nn.Sequential(
                    nn.Linear(3 * self.n_hidden, self.n_hidden),
                    nn.ReLU(),
                    nn.Linear(self.n_hidden, self.n_hidden),
                ))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout), Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
                              Linear(25, n_classes))

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.reset_parameters()
            elif isinstance(m, RGCNConv):
                m.reset_parameters()
            elif isinstance(m, nn.BatchNorm1d):
                m.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        edge_type = edge_attr[:, -1].long()
        #edge_attr = edge_attr[:, :-1]
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x =  (x + F.relu(self.bns[i](self.convs[i](x, edge_index, edge_type)))) / 2
            if self.edge_update:
                edge_attr = (edge_attr + F.relu(self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)))) / 2

        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        x = self.mlp(x)
        out = x

        return x

## Training

### Train Utility Functions

In [82]:
class AddEgoIds(BaseTransform):
    r"""Add IDs to the centre nodes of the batch.
    """
    def __init__(self):
        pass

    def __call__(self, data: Union[Data, HeteroData]):
        x = data.x
        device = x.device

        ids = torch.zeros((x.shape[0], 1), device=device)
        nodes = torch.unique(data.edge_label_index.view(-1)).to(device)
        ids[nodes] = 1
        data.x = torch.cat([x, ids], dim=1)

        return data

def extract_param(parameter_name: str, args) -> float:
    """
    Extract the value of the specified parameter for the given model.

    Args:
    - parameter_name (str): Name of the parameter (e.g., "lr").
    - args (argparser): Arguments given to this specific run.

    Returns:
    - float: Value of the specified parameter.
    """
    # file_path = './model_settings.json'
    # with open(file_path, "r") as file:
    #     data = json.load(file)
    data = {
      "gin": {
        "params": {
          "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce",
          "w_ce1": 1.0000182882773443, "w_ce2": 6.275014431494497, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
        },
        "bayes_opt_params": {
          "lr": [0.002, 0.007], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
          "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
        },
        "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
      },
      "pna": {
        "params": {
          "lr": 0.0006116418195373612, "n_hidden": 20, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0003967674742307,
          "w_ce2": 7.077633468006714, "norm_method": "z_normalize", "dropout": 0.08340440094051481, "final_dropout": 0.28812979737686323
        },
        "bayes_opt_params": {
          "lr": [0.0001, 0.001], "n_hidden": [16, 64], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.00, 2.01], "loss": [0.0, 0.1],
          "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.1], "dropout": [0.0, 0.2], "final_dropout": [0.0, 0.4]
        },
        "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
      },
      "gat": {
        "params": {
          "lr": 0.006, "n_hidden": 64, "n_heads": 4, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1, "w_ce2": 6,
          "norm_method": "z_normalize", "dropout": 0.009, "final_dropout": 0.1
        },
        "bayes_opt_params": {
          "lr": [0.01, 0.04], "n_hidden": [4, 24], "n_heads": [1.5, 4.5], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [3, 7],
          "loss": [0, 0.1], "w_ce1": [1, 1.001], "w_ce2": [1, 10], "norm_method": [0, 0.1], "dropout": [0, 0.5], "final_dropout": [0, 0.8]
        },
        "header": "run,tb,lr,n_hidden,n_heads,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
      },
      "mlp": {
        "params": {
          "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443,
          "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
        },
        "bayes_opt_params": {
          "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
          "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
        },
        "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
      },
      "rgcn": {
        "params": {
          "lr": 0.006213266113989207, "n_hidden": 66.00315515631006, "n_mlp_layers": 1, "n_gnn_layers": 2, "loss": "ce", "w_ce1": 1.0000182882773443,
          "w_ce2": 9.23, "norm_method": "z_normalize", "dropout": 0.00983468338330501, "final_dropout": 0.10527690625126304
        },
        "bayes_opt_params": {
          "lr": [0.006, 0.0064], "n_hidden": [66.0, 66.01], "n_mlp_layers": [1, 1.001], "n_gnn_layers": [2.0, 2.001], "loss": [0.0, 0.1],
          "w_ce1": [1, 1.001], "w_ce2": [6, 12], "norm_method": [0, 0.001], "dropout": [0, 0.05], "final_dropout": [0, 0.2]
        },
        "header": "run,tb,lr,n_hidden,n_mlp_layers,n_gnn_layers,loss,w_ce1,w_ce2,norm_method,dropout,final_dropout,epoch,tr_acc,tr_prec,tr_rec,tr_f1,tr_auc,val_acc,val_prec,val_rec,val_f1,val_auc,te_acc,te_prec,te_rec,te_f1,te_auc\n"
      }
    }

    return data.get(args.model, {}).get("params", {}).get(parameter_name, None)

def add_arange_ids(data_list):
    '''
    Add the index as an id to the edge features to find seed edges in training, validation and testing.

    Args:
    - data_list (str): List of tr_data, val_data and te_data.
    '''
    for data in data_list:
        data.edge_attr = torch.cat([torch.arange(data.edge_attr.shape[0]).view(-1, 1), data.edge_attr], dim=1)

def get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args):
    tr_loader =  LinkNeighborLoader(tr_data,
                                    num_neighbors=args.num_neighs,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    transform=transform)

    val_loader = LinkNeighborLoader(val_data,
                                    num_neighbors=args.num_neighs,
                                    edge_label_index=val_data.edge_index[:, val_inds],
                                    edge_label=val_data.y[val_inds],
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    transform=transform)

    te_loader =  LinkNeighborLoader(te_data,
                                    num_neighbors=args.num_neighs,
                                    edge_label_index=te_data.edge_index[:, te_inds],
                                    edge_label=te_data.y[te_inds],
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    transform=transform)

    return tr_loader, val_loader, te_loader

def compute_metrics(ground_truth, pred, prob=None):
    """Computes and returns classification metrics."""
    metrics = {
        "accuracy": accuracy_score(ground_truth, pred),
        "f1": f1_score(ground_truth, pred)
    }

    if prob is not None:
        try:
            metrics["roc_auc"] = roc_auc_score(ground_truth, prob)
            metrics["pr_auc"] = average_precision_score(ground_truth, prob)
        except ValueError:
            metrics["roc_auc"] = None
            metrics["pr_auc"] = None
    else:
        metrics["roc_auc"] = None
        metrics["pr_auc"] = None

    return metrics

def save_model(model, optimizer, epoch, args, data_config):
    # Save the model in a dictionary
    torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, f'{data_config["paths"]["model_to_save"]}/checkpoint_{args.unique_name}{"" if not args.finetune else "_finetuned"}.tar')

def load_model(model, device, args, config, data_config):
    checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optimizer

### Training Workflow

In [83]:
@torch.no_grad()
def evaluate_homo(loader, inds, model, data, device, args):
    preds = []
    ground_truths = []
    probs = []

    for batch in tqdm.tqdm(loader, disable=not args.tqdm):
        inds = inds.detach().cpu()
        batch_edge_inds = inds[batch.input_id.detach().cpu()]

        batch_edge_ids = loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
        mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)

        missing = ~torch.isin(batch_edge_ids, batch.edge_attr[:, 0].detach().cpu())

        if missing.sum() != 0 and (args.data == 'Small_J' or args.data == 'Small_Q'):
            missing_ids = batch_edge_ids[missing].int()
            n_ids = batch.n_id
            add_edge_index = data.edge_index[:, missing_ids].detach().clone()
            node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
            add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
            add_edge_attr = data.edge_attr[missing_ids, :].detach().clone()
            add_y = data.y[missing_ids].detach().clone()

            batch.edge_index = torch.cat((batch.edge_index, add_edge_index), 1)
            batch.edge_attr = torch.cat((batch.edge_attr, add_edge_attr), 0)
            batch.y = torch.cat((batch.y, add_y), 0)

            mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))

        batch.edge_attr = batch.edge_attr[:, 1:]

        batch.to(device)
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        out = out[mask]
        prob = out.softmax(dim=-1)[:, 1]
        pred = out.argmax(dim=-1)

        preds.append(pred)
        probs.append(prob)
        ground_truths.append(batch.y[mask])

    pred = torch.cat(preds).cpu().numpy()
    prob = torch.cat(probs).cpu().numpy()
    ground_truth = torch.cat(ground_truths).cpu().numpy()

    return compute_metrics(ground_truth, pred, prob)

def train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config):
    best_val_f1 = 0

    for epoch in range(config.epochs):
        total_loss = total_examples = 0
        preds = []
        ground_truths = []
        probs = []

        for batch in tqdm.tqdm(tr_loader, disable=not args.tqdm):
            optimizer.zero_grad()

            inds = tr_inds.detach().cpu()
            batch_edge_inds = inds[batch.input_id.detach().cpu()]

            batch_edge_ids = tr_loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
            mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)

            batch.edge_attr = batch.edge_attr[:, 1:]

            batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            pred = out[mask]
            ground_truth = batch.y[mask]

            probs.append(pred.softmax(dim=-1)[:, 1])
            preds.append(pred.argmax(dim=-1))
            ground_truths.append(ground_truth)

            loss = loss_fn(pred, ground_truth)
            loss.backward()
            optimizer.step()

            total_loss += float(loss) * pred.numel()
            total_examples += pred.numel()

        pred_np = torch.cat(preds, dim=0).detach().cpu().numpy()
        prob_np = torch.cat(probs, dim=0).detach().cpu().numpy()
        ground_truth_np = torch.cat(ground_truths, dim=0).detach().cpu().numpy()

        train_metrics = compute_metrics(ground_truth_np, pred_np, prob_np)

        print(f'\nEpoch: {epoch}')
        print(f'Train Metrics: {train_metrics}')

        val_metrics = evaluate_homo(val_loader, val_inds, model, val_data, device, args)
        te_metrics = evaluate_homo(te_loader, te_inds, model, te_data, device, args)

        print(f'Validation Metrics: {val_metrics}')
        print(f'Test Metrics: {te_metrics}')

        if epoch == 0:
            print({"best_test_f1": te_metrics['f1']})

        elif val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            print({"best_test_f1": te_metrics["f1"]})

            if args.save_model:
                save_model(model, optimizer, epoch, args, data_config)

    return model


In [84]:
def get_model(sample_batch, config, args):
    n_feats = sample_batch.x.shape[1]
    e_dim = (sample_batch.edge_attr.shape[1] - 1)

    if args.model == "gin":
        model = GINe(
                num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
                n_hidden=round(config.n_hidden), residual=False, edge_updates=args.emlps, edge_dim=e_dim,
                dropout=config.dropout, final_dropout=config.final_dropout
                )
    elif args.model == "gat":
        model = GATe(
                num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
                n_hidden=round(config.n_hidden), n_heads=round(config.n_heads),
                edge_updates=args.emlps, edge_dim=e_dim,
                dropout=config.dropout, final_dropout=config.final_dropout
                )
    elif args.model == "pna":
        index = torch.cat((sample_batch['node', 'to', 'node'].edge_index[1], sample_batch['node', 'rev_to', 'node'].edge_index[1]), 0)
        d = degree(index, dtype=torch.long)
        deg = torch.bincount(d, minlength=1)
        model = PNA(
            num_features=n_feats, num_gnn_layers=config.n_gnn_layers, n_classes=2,
            n_hidden=round(config.n_hidden), edge_updates=args.emlps, edge_dim=e_dim,
            dropout=config.dropout, deg=deg, final_dropout=config.final_dropout
            )
    elif config.model == "rgcn":
        model = RGCN(
            num_features=n_feats, edge_dim=e_dim, num_relations=8, num_gnn_layers=round(config.n_gnn_layers),
            n_classes=2, n_hidden=round(config.n_hidden),
            edge_update=args.emlps, dropout=config.dropout, final_dropout=config.final_dropout, n_bases=None #(maybe)
        )

    return model

def train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    config = SimpleNamespace(
        epochs=args.n_epochs,
        batch_size=args.batch_size,
        model=args.model,
        data=args.data,
        num_neighbors=args.num_neighs,

        lr=extract_param("lr", args),
        n_hidden=extract_param("n_hidden", args),
        n_gnn_layers=extract_param("n_gnn_layers", args),

        loss="ce",
        w_ce1=extract_param("w_ce1", args),
        w_ce2=extract_param("w_ce2", args),

        dropout=extract_param("dropout", args),
        final_dropout=extract_param("final_dropout", args),
        n_heads=extract_param("n_heads", args) if args.model == 'gat' else None
    )

    # Set the transform if ego ids should be used
    if args.ego:
        transform = AddEgoIds()
    else:
        transform = None

    # Add the unique ids to later find the seed edges
    add_arange_ids([tr_data, val_data, te_data])

    tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args)

    # Get the model
    sample_batch = next(iter(tr_loader))
    model = get_model(sample_batch, config, args)

    if args.finetune:
        model, optimizer = load_model(model, device, args, config, data_config)
    else:
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    sample_batch.to(device)
    sample_x = sample_batch.x
    sample_edge_index = sample_batch.edge_index
    sample_batch.edge_attr = sample_batch.edge_attr[:, 1:]
    sample_edge_attr = sample_batch.edge_attr

    print(summary(model, sample_x, sample_edge_index, sample_edge_attr))

    loss_fn = torch.nn.CrossEntropyLoss(weight=torch.FloatTensor([config.w_ce1, config.w_ce2]).to(device))

    model = train_homo(tr_loader, val_loader, te_loader, tr_inds, val_inds, te_inds, model, optimizer, loss_fn, args, config, device, val_data, te_data, data_config)


### Train Initiation

Run these two code chunks everytime we want to test a new model / component

In [85]:
args.tds = False
args.ports = False
args.ego = False

tr_data, val_data, te_data, tr_inds, val_inds, te_inds = get_data(args)

print('\n')
print(tr_data)
print(val_data)
print(te_data)

In [86]:
print(tr_data.edge_attr)

In [87]:
data_config = {
  "paths": {
    "aml_data": "/path/to/aml_data",
    "model_to_load": "/path/to/model_you_want_to_load (e.g for inference or fine-tuning)",
    "model_to_save": "./model"
  }
}

args.model = "gin"
args.emlps = True
args.n_epochs = 10

train_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config)

KeyboardInterrupt: 

## Inference

args.inference has to be set to True and model training checkpoints have to be saved.

In [None]:
import time

script_start = time.time()

def infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config):
    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    config = SimpleNamespace(
        epochs=args.n_epochs,
        batch_size=args.batch_size,
        model=args.model,
        data=args.data,
        num_neighbors=args.num_neighs,

        lr=extract_param("lr", args),
        n_hidden=extract_param("n_hidden", args),
        n_gnn_layers=extract_param("n_gnn_layers", args),

        loss="ce",
        w_ce1=extract_param("w_ce1", args),
        w_ce2=extract_param("w_ce2", args),

        dropout=extract_param("dropout", args),
        final_dropout=extract_param("final_dropout", args),
        n_heads=extract_param("n_heads", args) if args.model == 'gat' else None
    )

    # Define a model config dictionary and wandb logging at the same time
    # wandb.init(
    #     mode="disabled" if args.testing else "online",
    #     project="your_proj_name",

    #     config={
    #         "epochs": args.n_epochs,
    #         "batch_size": args.batch_size,
    #         "model": args.model,
    #         "data": args.data,
    #         "num_neighbors": args.num_neighs,
    #         "lr": extract_param("lr", args),
    #         "n_hidden": extract_param("n_hidden", args),
    #         "n_gnn_layers": extract_param("n_gnn_layers", args),
    #         "loss": "ce",
    #         "w_ce1": extract_param("w_ce1", args),
    #         "w_ce2": extract_param("w_ce2", args),
    #         "dropout": extract_param("dropout", args),
    #         "final_dropout": extract_param("final_dropout", args),
    #         "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None
    #     }
    # )

    # config = wandb.config

    # Set the transform if ego ids should be used
    if args.ego:
        transform = AddEgoIds()
    else:
        transform = None

    # Add the unique ids to later find the seed edges
    add_arange_ids([tr_data, val_data, te_data])

    tr_loader, val_loader, te_loader = get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args)

    # Get the model
    sample_batch = next(iter(tr_loader))
    model = get_model(sample_batch, config, args)

    if not args.finetune:
        command = " ".join(sys.argv)
        name = ""
        name = '-'.join(name.split('-')[3:])
        args.unique_name = name

    print("=> loading model checkpoint")
    checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args.unique_name}.tar')
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    print("=> loaded checkpoint (epoch {})".format(start_epoch))

    te_f1, te_prec, te_rec = evaluate_homo(te_loader, te_inds, model, te_data, device, args, precrec=True)

In [None]:
infer_gnn(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, args, data_config)