# 1. Introduction:
---

Money laundering is a multi-billion dollar issue. Detection of laundering is very difficult. Most automated algorithms have a high false positive rate: legitimate transactions incorrectly flagged as laundering. The converse is also a major problem -- false negatives, i.e. undetected laundering transactions. Naturally, criminals work hard to cover their tracks.

Access to real financial transaction data is highly restricted, for both proprietary and privacy reasons. Even when access is possible, it is problematic to provide a correct tag (laundering or legitimate) to each transaction, as noted above. 

In this project we use a synthetic transaction dataset from IBM that avoids these problems (Altman et al., 2023).


**The paper that introduced this synthetic dataset is available [here](https://arxiv.org/abs/2306.16424).**

The data provided here is based on a virtual world inhabited by individuals, companies, and banks. Individuals interact with other individuals and companies. Likewise, companies interact with other companies and with individuals. These interactions can take many forms, e.g. purchase of consumer goods and services, purchase orders for industrial supplies, payment of salaries, repayment of loans, and more. These financial transactions are generally conducted via banks, i.e. the payer and receiver both have accounts, with accounts taking multiple forms from checking to credit cards to bitcoin.

Some (small) fraction of the individuals and companies in the generator model engage in criminal behavior -- such as smuggling, illegal gambling, extortion, and more. Criminals obtain funds from these illicit activities, and then try to hide the source of these illicit funds via a series of financial transactions. Such financial transactions to hide illicit funds constitute laundering. Thus, the data available here is labelled and can be used for training and testing AML (Anti Money Laundering) models and for other purposes.

The data generator that created the data here not only models illicit activity, but also tracks funds derived from illicit activity through arbitrarily many transactions -- thus creating the ability to label laundering transactions many steps removed from their illicit source. With this foundation, it is straightforward for the generator to label individual transactions as laundering or legitimate.

Note that this IBM generator models the entire money laundering cycle:

*   **Placement**: Sources like smuggling of illicit funds.
*   **Layering**: Mixing the illicit funds into the financial system.
*   **Integration**: Spending the illicit funds.


As another capability possible only with synthetic data, note that a real bank or other institution typically has access to only a portion of the transactions involved in laundering: the transactions involving that bank. Transactions happening at other banks or between other banks are not seen. Thus, models built on real transactions from one institution can have only a limited view of the world.

By contrast, these synthetic transactions contain an entire financial ecosystem. Thus, it may be possible to create laundering detection models that understand the broad sweep of transactions across institutions, but apply those models to make inferences only about transactions at a particular bank.
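
As a rough illustration of the single-institution view described above, the sketch below filters a handful of made-up transactions down to those a given bank would see. The field names `from_bank` and `to_bank` and the helper `visible_to_bank` are hypothetical and are not taken from the dataset.

```python
# Hypothetical toy transactions; the values are invented for illustration only.
transactions = [
    {"from_bank": 10, "to_bank": 10, "amount": 120.0},
    {"from_bank": 10, "to_bank": 20, "amount": 75.5},
    {"from_bank": 20, "to_bank": 30, "amount": 990.0},
    {"from_bank": 30, "to_bank": 10, "amount": 15.0},
]

def visible_to_bank(rows, bank_id):
    """Return the transactions in which bank_id is the sending or receiving bank."""
    return [r for r in rows if bank_id in (r["from_bank"], r["to_bank"])]

bank_view = visible_to_bank(transactions, bank_id=20)
print(f"Bank 20 sees {len(bank_view)} of {len(transactions)} transactions")
```

A model trained on the full synthetic ecosystem could then be evaluated only on such a per-bank subset.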

## 1.1. Importing Libraries
---

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pathlib
import zipfile


from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import RobustScaler, OrdinalEncoder
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import GridSearchCV

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix, roc_auc_score, roc_curve
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier

from sklearn.metrics import classification_report  # confusion_matrix is already imported above
from imblearn.over_sampling import SMOTE
from collections import Counter

import warnings
warnings.filterwarnings("ignore")

## 1.2. Verify if Data is Present
---

In [None]:
pathlib.Path("data").mkdir(parents=True, exist_ok=True)
PATH = str(pathlib.Path.cwd())
file_path = pathlib.Path("data/HI-Large_Trans.csv")

if not file_path.is_file():
    with zipfile.ZipFile("./data.zip", 'r') as zf:
        zf.extractall("./data/")

# 2. Exploratory Data Analysis (EDA)
---

## 2.1. Reading the HI-Small_Trans file

In [3]:
import pandas as pd

full_df = pd.read_csv("./data/HI-Small_Trans.csv")

full_df.shape

(5078345, 11)

### 2.1.1. Sampling a Portion of the Original DataFrame
---

In [None]:
df = full_df.sample(n=500000, random_state=42)

df.shape

In [None]:
df.head()

### 2.1.2. About the Features
---

In [None]:
df.info()

## 2.2. Basic Statistics of the Numerical Features
---

In [None]:
df.select_dtypes(exclude='object').describe()

In [20]:
# def feature_values_changer(col, zero, one):
#     for i in range(col.shape[0]):
#         if col.values[i] == zero:
#             col.values[i] = 0
#         elif col.values[i] == one:
#             col.values[i] = 1
#         else:
#             col.values[i] = 2
    
#     return col

Reading `HI-Large_Trans.csv` in chunks of 1,000,000 rows, keeping only the rows where `Is Laundering` == 1.

In [None]:
# dfs = []
# count = 1
# for df in pd.read_csv('./data/HI-Large_Trans.csv', chunksize=1000000):
#     df = df[df['Is Laundering'] == 1]
    
#     del df['Timestamp']
#     dfs.append(df)
    
#     if count % 10 == 0:
#         print(f"{(count / 180)*100:.2f}% complete")
#     count += 1

In [None]:
# df_full_1 = pd.concat(dfs)
# del dfs

# ones_count = df_full_1.shape[0]
# print("Number of rows with 'Is Laundering' == 1:", ones_count)

Reading `HI-Large_Trans.csv` in chunks of 15,000 rows, keeping only the rows where `Is Laundering` == 0, until they reach a 1:1 ratio with the `Is Laundering` == 1 rows.

In [37]:
# dfs = []
# current = 0
# for df in pd.read_csv('./data/HI-Large_Trans.csv', chunksize=15000):
#     df = df[df['Is Laundering'] == 0]
#     current += df.shape[0]
    
#     del df['Timestamp']
#     dfs.append(df)

#     if current >= ones_count:
#         break

In [None]:
# df_full_0 = pd.concat(dfs)
# df_full = pd.concat([df_full_0, df_full_1])
# del dfs

# zeros_count = df_full_0.shape[0]
# print("Number of rows with 'Is Laundering' == 0:", zeros_count)

In [None]:
# df_full.head()

In [None]:
# df_full.info()

In [None]:
# print("Unique values for feature:")
# {feature:len(df_full[feature].unique()) for feature in df_full.columns}

In [None]:
# df_full.describe()

In [None]:
import pandas as pd

full_df = pd.read_csv("./data/HI-Small_Trans.csv")

full_df.shape

In [None]:
full_df.head()

There are two columns for the paid and received amounts of each transaction. It is unclear whether splitting the amount into two columns is necessary when both share the same value, unless transaction fees or transactions between different currencies are involved. Let's find out.

In [None]:
print('Amount Received equals to Amount Paid:')
print(full_df['Amount Received'].equals(full_df['Amount Paid']))
print('Receiving Currency equals to Payment Currency:')
print(full_df['Receiving Currency'].equals(full_df['Payment Currency']))

In [None]:
not_equal1 = full_df.loc[~(full_df['Amount Received'] == full_df['Amount Paid'])]
not_equal2 = full_df.loc[~(full_df['Receiving Currency'] == full_df['Payment Currency'])]
print("Transactions with different amount received and paid")
display(not_equal1.head())
print('---------------------------------------------------------------------------')
print("Transactions with different currency received and paid")
display(not_equal2.head())

Checking if the values of `Receiving Currency` and `Payment Currency` match

In [None]:
print(sorted(full_df['Receiving Currency'].unique()))
print(sorted(full_df['Payment Currency'].unique()))

In the data preprocessing, we perform the following transformations:
1. Transform the Timestamp with min-max normalization.
2. Create a unique ID for each account by combining the bank code with the account number.
3. Create receiving_df with the receiving accounts, received amounts and currencies.
4. Create paying_df with the payer accounts, paid amounts and currencies.
5. Create a list of the currencies used across all transactions.
6. Encode 'Payment Format', 'Payment Currency' and 'Receiving Currency' as integer classes with scikit-learn's OrdinalEncoder.
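
A minimal sketch of steps 1, 2 and 6 on toy values, written in plain Python. The sample rows, the `_` separator and the helper names `min_max` and `ordinal_codes` are assumptions made for illustration. `ordinal_codes` only mimics the category-to-integer mapping that `OrdinalEncoder` produces by default for a single column (sorted categories mapped to consecutive integers).

```python
def min_max(values):
    """Scale values linearly to the [0, 1] range."""
    lo, hi = min(values), max(values)
    span = hi - lo
    return [(v - lo) / span if span else 0.0 for v in values]

def ordinal_codes(labels):
    """Map each distinct label to an integer, with categories in sorted order."""
    mapping = {label: code for code, label in enumerate(sorted(set(labels)))}
    return [mapping[label] for label in labels], mapping

# Toy rows: (timestamp in seconds, bank code, account number, payment format)
rows = [
    (1000, "070", "100428660", "Cheque"),
    (1600, "070", "100428660", "ACH"),
    (2200, "011", "8000EBD30", "Wire"),
]

scaled_ts = min_max([r[0] for r in rows])
# The separator is an assumption; it keeps "bank + account" keys unambiguous.
account_ids = [f"{bank}_{account}" for _, bank, account, _ in rows]
format_codes, format_mapping = ordinal_codes([r[3] for r in rows])

print(scaled_ts)      # [0.0, 0.5, 1.0]
print(account_ids)
print(format_codes)   # [1, 0, 2]
```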

# New Approach, now using GNN
---

## Preprocess Step:
---

In [1]:
import pandas as pd
from datetime import datetime

def get_dict_val(name, collection):
    if name in collection:
        val = collection[name]
    else:
        val = len(collection)
        collection[name] = val
    return val

def format_timestamp(timestamp):
    firstTs = -1
    timestamps = []
    for i in timestamp:
        dt_ts = datetime.strptime(i, '%Y/%m/%d %H:%M')
        ts = dt_ts.timestamp()
        if firstTs == -1:
            day = dt_ts.day
            month = dt_ts.month
            year = dt_ts.year
            startTime = datetime(year, month, day)
            firstTs = startTime.timestamp() - 10
        ts = ts - firstTs
        timestamps.append(ts)

    return timestamps


df_edges = pd.read_csv("./data/HI-Small_Trans.csv")

currency = dict()
payment_format = dict()
# One mapping shared by senders and receivers, so that an account
# gets the same node ID whether it appears as payer or as payee.
account_ids = dict()

df_edges["Timestamp"] = format_timestamp(df_edges["Timestamp"])
df_edges["Received Currency"] = df_edges['Receiving Currency'].apply(lambda x: get_dict_val(x, currency))
df_edges["Sent Currency"] = df_edges['Payment Currency'].apply(lambda x: get_dict_val(x, currency))
df_edges["Payment Format"] = df_edges['Payment Format'].apply(lambda x: get_dict_val(x, payment_format))
df_edges["temp"] = df_edges["From Bank"].astype(str) + df_edges["Account"].astype(str)
df_edges["from_id"] = df_edges["temp"].apply(lambda x: get_dict_val(x, account_ids))
df_edges["temp"] = df_edges["To Bank"].astype(str) + df_edges["Account.1"].astype(str)
df_edges["to_id"] = df_edges["temp"].apply(lambda x: get_dict_val(x, account_ids))

df_edges.reset_index(drop=True, inplace=True)
df_edges["EdgeID"] = df_edges.index

df_edges.rename(columns={"Amount Paid":"Amount Sent"}, inplace=True)

df_edges.drop(columns=["temp", "From Bank", "Account",
                       "To Bank", "Account.1", "Receiving Currency",
                       "Payment Currency"], inplace=True)

df_edges = df_edges.reindex(columns=["EdgeID","from_id","to_id","Timestamp",
                                     "Amount Sent","Sent Currency","Amount Received",
                                     "Received Currency","Payment Format","Is Laundering"])

df_edges["Timestamp"] = df_edges["Timestamp"] - df_edges["Timestamp"].min()
df_edges = df_edges.sort_values(by="Timestamp")
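
To make the ID assignment concrete, the following sketch applies the same first-seen numbering idea to a few invented account strings. The helper `dense_id` is a hypothetical stand-in that mirrors `get_dict_val`, and sharing one dictionary between the sender and receiver columns is an assumption made here so that an account keeps a single node ID in both roles.

```python
def dense_id(name, collection):
    """Return the integer ID for name, assigning the next free one on first sight."""
    if name not in collection:
        collection[name] = len(collection)
    return collection[name]

# Invented "bank + account" keys for three transactions.
senders = ["10A", "10A", "20B"]
receivers = ["20B", "30C", "10A"]

node_ids = {}  # one mapping for both columns
edges = [(dense_id(s, node_ids), dense_id(r, node_ids))
         for s, r in zip(senders, receivers)]

print(edges)     # [(0, 1), (0, 2), (1, 0)]
print(node_ids)  # {'10A': 0, '20B': 1, '30C': 2}
```

Because the IDs are dense (0 to N-1), they can be used directly as row indices into a node feature matrix.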

In [99]:
import torch
import numpy as np

max_n_id = df_edges.loc[:, ['from_id', 'to_id']].to_numpy().max() + 1
df_nodes = pd.DataFrame({'NodeID': np.arange(max_n_id), 'Feature': np.ones(max_n_id)})
timestamps = torch.Tensor(df_edges['Timestamp'].to_numpy())
y = torch.LongTensor(df_edges['Is Laundering'].to_numpy())

edge_features = ['Timestamp', 'Amount Received', 'Received Currency', 'Payment Format']
node_features = ['Feature']

X = torch.tensor(df_nodes.loc[:, node_features].to_numpy()).float()
edge_index = torch.LongTensor(df_edges.loc[:, ['from_id', 'to_id']].to_numpy().T)
edge_attr = torch.tensor(df_edges.loc[:, edge_features].to_numpy()).float()

n_days = int(timestamps.max() / (3600 * 24) + 1)
n_samples = y.shape[0]

In [165]:
import itertools

#data splitting
daily_irs, weighted_daily_irs, daily_inds, daily_trans = [], [], [], [] # irs = illicit ratios

for day in range(n_days):
        l = day * 24 * 3600
        r = (day + 1) * 24 * 3600
        day_inds = torch.where((timestamps >= l) & (timestamps < r))[0]
        daily_irs.append(y[day_inds].float().mean())
        weighted_daily_irs.append(y[day_inds].float().mean() * day_inds.shape[0] / n_samples)
        daily_inds.append(day_inds)
        daily_trans.append(day_inds.shape[0])

# Recommended split_percentages for train, validation and test. 
split_per = [0.6, 0.2, 0.2]
daily_totals = np.array(daily_trans)
d_ts = daily_totals
I = list(range(len(d_ts)))
split_scores = dict()

# Iterates over all days combination ranges and stores the score at split_scores
for i,j in itertools.combinations(I, 2):
    if j >= i:
        split_totals = [d_ts[:i].sum(), d_ts[i:j].sum(), d_ts[j:].sum()]
        split_totals_sum = np.sum(split_totals)
        split_props = [v/split_totals_sum for v in split_totals] # proportion of each split compared to the total transactions
        split_error = [abs(v-t)/t for v,t in zip(split_props, split_per)] 
        score = max(split_error) #- (split_totals_sum/total) + 1
        split_scores[(i,j)] = score
    else:
        continue
i,j = min(split_scores, key=split_scores.get) # get the best i,j from split_scores

# split contains a list for each split (train, validation and test) and each list contains the days that are part of the respective split
split = [list(range(i)), list(range(i, j)), list(range(j, len(daily_totals)))]

# separate the transactions based on their indices in the timestamp array
split_inds = {k: [] for k in range(3)}
for i in range(3):
    for day in split[i]:
        split_inds[i].append(daily_inds[day]) #split_inds contains a list for each split (tr,val,te) which contains the indices of each day separately

tr_inds = torch.cat(split_inds[0])
val_inds = torch.cat(split_inds[1])
te_inds = torch.cat(split_inds[2])

tr_x, val_x, te_x = X, X, X # sets the placeholder (ones) to the variables
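
The split search above can be condensed into a small standalone function. The sketch below is a simplified illustration on invented daily counts, not the notebook's exact code: the name `best_temporal_split` is hypothetical, and it scores each pair of day boundaries by its worst relative deviation from the target shares, as the loop above does.

```python
import itertools

def best_temporal_split(daily_counts, targets=(0.6, 0.2, 0.2)):
    """Find day boundaries (i, j) so that days [0, i), [i, j), [j, n) best match targets."""
    total = sum(daily_counts)
    best, best_score = None, float("inf")
    for i, j in itertools.combinations(range(len(daily_counts)), 2):
        parts = (sum(daily_counts[:i]), sum(daily_counts[i:j]), sum(daily_counts[j:]))
        # Score a candidate by its worst relative deviation from the target shares.
        score = max(abs(p / total - t) / t for p, t in zip(parts, targets))
        if score < best_score:
            best, best_score = (i, j), score
    return best, best_score

# Invented number of transactions per day for a 10-day window.
daily_counts = [100] * 10
(i, j), score = best_temporal_split(daily_counts)
print(i, j, score)  # 6 8 0.0
```

Splitting by whole days keeps the train, validation and test sets in chronological order, so later transactions never leak into training.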

In [None]:
print(f"Total train samples: {tr_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
        f"{y[tr_inds].float().mean() * 100 :.2f}% || Train days: {split[0][:5]}")
print(f"Total val samples: {val_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
    f"{y[val_inds].float().mean() * 100:.2f}% || Val days: {split[1][:5]}")
print(f"Total test samples: {te_inds.shape[0] / y.shape[0] * 100 :.2f}% || IR: "
    f"{y[te_inds].float().mean() * 100:.2f}% || Test days: {split[2][:5]}")

# IR stands for Illicit Ratio!

In [186]:
e_tr = tr_inds.numpy() # Edge train array
e_val = np.concatenate([tr_inds, val_inds]) # Edge validation (train + val) array

# Train
tr_edge_index, tr_edge_attr, tr_y, tr_edge_times = edge_index[:,e_tr],  edge_attr[e_tr],  y[e_tr],  timestamps[e_tr]

# Validation (tr + val)
val_edge_index, val_edge_attr, val_y, val_edge_times = edge_index[:,e_val], edge_attr[e_val], y[e_val], timestamps[e_val]

# Test (tr + val + te)
te_edge_index, te_edge_attr, te_y, te_edge_times = edge_index, edge_attr, y, timestamps
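
The nesting of the three edge sets can be illustrated on a toy index range. In this sketch the edge indices and split boundaries are invented. The point is only that the validation graph contains the train and validation edges and the test graph contains everything, while validation and test are each scored on their own new edges.

```python
# Invented edge indices, already sorted by time.
tr_inds = [0, 1, 2, 3, 4, 5]
val_inds = [6, 7]
te_inds = [8, 9]

# Edges available for message passing in each phase (cumulative in time).
train_graph_edges = tr_inds
val_graph_edges = tr_inds + val_inds
test_graph_edges = tr_inds + val_inds + te_inds

# Edges that are actually scored in each phase (the "seed" edges).
scored = {"train": tr_inds, "val": val_inds, "test": te_inds}

for name, graph_edges in [("train", train_graph_edges),
                          ("val", val_graph_edges),
                          ("test", test_graph_edges)]:
    print(f"{name}: {len(graph_edges)} edges in graph, {len(scored[name])} scored")
```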

In [187]:
import data_util

tr_data = data_util.GraphData(x=tr_x, y=tr_y, edge_index=tr_edge_index, edge_attr=tr_edge_attr, timestamps=tr_edge_times)
val_data = data_util.GraphData(x=val_x, y=val_y, edge_index=val_edge_index, edge_attr=val_edge_attr, timestamps=val_edge_times)
te_data = data_util.GraphData(x=te_x, y=te_y, edge_index=te_edge_index, edge_attr=te_edge_attr, timestamps=te_edge_times)

In [5]:
import json

with open('args.json', 'r') as config_file:
        data_config = json.load(config_file)

## Models:

In [7]:
import torch.nn as nn
from torch_geometric.nn import GINEConv, BatchNorm, Linear, GATConv, PNAConv, RGCNConv
import torch.nn.functional as F
import torch
import logging

class GINe(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2, 
                n_hidden=100, edge_updates=False, residual=True, 
                edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()
        self.n_hidden = n_hidden
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.final_dropout = final_dropout

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(self.num_gnn_layers):
            conv = GINEConv(nn.Sequential(
                nn.Linear(self.n_hidden, self.n_hidden), 
                nn.ReLU(), 
                nn.Linear(self.n_hidden, self.n_hidden)
                ), edge_dim=self.n_hidden)
            if self.edge_updates: self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
                              Linear(25, n_classes))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates: 
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        out = x
        
        return self.mlp(out)
    
class GATe(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2, n_hidden=100, n_heads=4, edge_updates=False, edge_dim=None, dropout=0.0, final_dropout=0.5):
        super().__init__()
        # GAT specific code
        tmp_out = n_hidden // n_heads
        n_hidden = tmp_out * n_heads

        self.n_hidden = n_hidden
        self.n_heads = n_heads
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.dropout = dropout
        self.final_dropout = final_dropout
        
        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)
        
        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        for _ in range(self.num_gnn_layers):
            conv = GATConv(self.n_hidden, tmp_out, self.n_heads, concat = True, dropout = self.dropout, add_self_loops = True, edge_dim=self.n_hidden)
            if self.edge_updates: self.emlps.append(nn.Sequential(nn.Linear(3 * self.n_hidden, self.n_hidden),nn.ReLU(),nn.Linear(self.n_hidden, self.n_hidden),))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))
                
        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(25, n_classes))
            
    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index
        
        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)
        
        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates:
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2
                    
        logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        logging.debug(f"x.shape = {x.shape}")
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        logging.debug(f"x.shape = {x.shape}")
        out = x

        return self.mlp(out)
    
class PNA(torch.nn.Module):
    def __init__(self, num_features, num_gnn_layers, n_classes=2, 
                n_hidden=100, edge_updates=True,
                edge_dim=None, dropout=0.0, final_dropout=0.5, deg=None):
        super().__init__()
        n_hidden = int((n_hidden // 5) * 5)
        self.n_hidden = n_hidden
        self.num_gnn_layers = num_gnn_layers
        self.edge_updates = edge_updates
        self.final_dropout = final_dropout

        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.emlps = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for _ in range(self.num_gnn_layers):
            conv = PNAConv(in_channels=n_hidden, out_channels=n_hidden,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=n_hidden, towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            if self.edge_updates: self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(n_hidden))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout),Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
                              Linear(25, n_classes))

    def forward(self, x, edge_index, edge_attr):
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x = (x + F.relu(self.batch_norms[i](self.convs[i](x, edge_index, edge_attr)))) / 2
            if self.edge_updates: 
                edge_attr = edge_attr + self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)) / 2

        logging.debug(f"x.shape = {x.shape}, x[edge_index.T].shape = {x[edge_index.T].shape}")
        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        logging.debug(f"x.shape = {x.shape}")
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        logging.debug(f"x.shape = {x.shape}")
        out = x
        return self.mlp(out)
    
class RGCN(nn.Module):
    def __init__(self, num_features, edge_dim, num_relations, num_gnn_layers, n_classes=2, 
                n_hidden=100, edge_update=False,
                residual=True,
                dropout=0.0, final_dropout=0.5, n_bases=-1):
        super(RGCN, self).__init__()

        self.num_features = num_features
        self.num_gnn_layers = num_gnn_layers
        self.n_hidden = n_hidden
        self.residual = residual
        self.dropout = dropout
        self.final_dropout = final_dropout
        self.n_classes = n_classes
        self.edge_update = edge_update
        self.num_relations = num_relations
        self.n_bases = n_bases

        self.node_emb = nn.Linear(num_features, n_hidden)
        self.edge_emb = nn.Linear(edge_dim, n_hidden)

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.mlp = nn.ModuleList()

        if self.edge_update:
            self.emlps = nn.ModuleList()
            self.emlps.append(nn.Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            ))
        
        for _ in range(self.num_gnn_layers):
            conv = RGCNConv(self.n_hidden, self.n_hidden, num_relations, num_bases=self.n_bases)
            self.convs.append(conv)
            self.bns.append(nn.BatchNorm1d(self.n_hidden))

            if self.edge_update:
                self.emlps.append(nn.Sequential(
                    nn.Linear(3 * self.n_hidden, self.n_hidden),
                    nn.ReLU(),
                    nn.Linear(self.n_hidden, self.n_hidden),
                ))

        self.mlp = nn.Sequential(Linear(n_hidden*3, 50), nn.ReLU(), nn.Dropout(self.final_dropout), Linear(50, 25), nn.ReLU(), nn.Dropout(self.final_dropout),
                              Linear(25, n_classes))

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.reset_parameters()
            elif isinstance(m, RGCNConv):
                m.reset_parameters()
            elif isinstance(m, nn.BatchNorm1d):
                m.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        edge_type = edge_attr[:, -1].long()
        #edge_attr = edge_attr[:, :-1]
        src, dst = edge_index

        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)

        for i in range(self.num_gnn_layers):
            x =  (x + F.relu(self.bns[i](self.convs[i](x, edge_index, edge_type)))) / 2
            if self.edge_update:
                edge_attr = (edge_attr + F.relu(self.emlps[i](torch.cat([x[src], x[dst], edge_attr], dim=-1)))) / 2
        
        x = x[edge_index.T].reshape(-1, 2 * self.n_hidden).relu()
        x = torch.cat((x, edge_attr.view(-1, edge_attr.shape[1])), 1)
        x = self.mlp(x)
        out = x

        return x
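
All four models share the same edge-level readout: for every edge, the embeddings of its source node, its destination node and the edge itself are concatenated, which is why the final MLP takes `3 * n_hidden` inputs. The sketch below reproduces that concatenation with plain Python lists and a made-up hidden size of 2. It omits the ReLU applied to the node part and is meant as an illustration of the shapes only.

```python
n_hidden = 2  # made-up hidden size, for illustration only

# Node embeddings after message passing: one row per node.
x = [[0.1, 0.2],   # node 0
     [0.3, 0.4],   # node 1
     [0.5, 0.6]]   # node 2

# Edge list as (source, destination) pairs and one embedding per edge.
edges = [(0, 1), (2, 0)]
edge_attr = [[1.0, 1.1],
             [2.0, 2.1]]

# Per-edge readout input: [source embedding | destination embedding | edge embedding].
readout = [x[src] + x[dst] + e for (src, dst), e in zip(edges, edge_attr)]

for row in readout:
    print(row)
print(len(readout[0]) == 3 * n_hidden)  # True
```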

## Train_utils.py

In [17]:
import torch
import tqdm
from torch_geometric.transforms import BaseTransform
from typing import Union
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import f1_score
import json

class AddEgoIds(BaseTransform):
    r"""Add IDs to the centre nodes of the batch.
    """
    def __init__(self):
        pass

    def __call__(self, data: Union[Data, HeteroData]):
        x = data.x if not isinstance(data, HeteroData) else data['node'].x
        device = x.device
        ids = torch.zeros((x.shape[0], 1), device=device)
        if not isinstance(data, HeteroData):
            nodes = torch.unique(data.edge_label_index.view(-1)).to(device)
        else:
            nodes = torch.unique(data['node', 'to', 'node'].edge_label_index.view(-1)).to(device)
        ids[nodes] = 1
        if not isinstance(data, HeteroData):
            data.x = torch.cat([x, ids], dim=1)
        else: 
            data['node'].x = torch.cat([x, ids], dim=1)
        
        return data

def extract_param(parameter_name: str, args) -> float:
    """
    Extract the value of the specified parameter for the given model.
    
    Args:
    - parameter_name (str): Name of the parameter (e.g., "lr").
    - args (argparser): Arguments given to this specific run.
    
    Returns:
    - float: Value of the specified parameter.
    """
    file_path = './model_settings.json'
    with open(file_path, "r") as file:
        data = json.load(file)

    return data.get(args["model"], {}).get("params", {}).get(parameter_name, None)

def add_arange_ids(data_list):
    '''
    Add the index as an id to the edge features to find seed edges in training, validation and testing.

    Args:
    - data_list (list): List of tr_data, val_data and te_data.
    '''
    for data in data_list:
        if isinstance(data, HeteroData):
            data['node', 'to', 'node'].edge_attr = torch.cat([torch.arange(data['node', 'to', 'node'].edge_attr.shape[0]).view(-1, 1), data['node', 'to', 'node'].edge_attr], dim=1)
            offset = data['node', 'to', 'node'].edge_attr.shape[0]
            data['node', 'rev_to', 'node'].edge_attr = torch.cat([torch.arange(offset, data['node', 'rev_to', 'node'].edge_attr.shape[0] + offset).view(-1, 1), data['node', 'rev_to', 'node'].edge_attr], dim=1)
        else:
            data.edge_attr = torch.cat([torch.arange(data.edge_attr.shape[0]).view(-1, 1), data.edge_attr], dim=1)

def get_loaders(tr_data, val_data, te_data, tr_inds, val_inds, te_inds, transform, args):
    if isinstance(tr_data, HeteroData):
        tr_edge_label_index = tr_data['node', 'to', 'node'].edge_index
        tr_edge_label = tr_data['node', 'to', 'node'].y


        tr_loader =  LinkNeighborLoader(tr_data, num_neighbors=args["num_neighs"], 
                                    edge_label_index=(('node', 'to', 'node'), tr_edge_label_index), 
                                    edge_label=tr_edge_label, batch_size=args["batch_size"], shuffle=True, transform=transform)
        
        val_edge_label_index = val_data['node', 'to', 'node'].edge_index[:,val_inds]
        val_edge_label = val_data['node', 'to', 'node'].y[val_inds]


        val_loader =  LinkNeighborLoader(val_data, num_neighbors=args["num_neighs"], 
                                    edge_label_index=(('node', 'to', 'node'), val_edge_label_index), 
                                    edge_label=val_edge_label, batch_size=args["batch_size"], shuffle=False, transform=transform)
        
        te_edge_label_index = te_data['node', 'to', 'node'].edge_index[:,te_inds]
        te_edge_label = te_data['node', 'to', 'node'].y[te_inds]


        te_loader =  LinkNeighborLoader(te_data, num_neighbors=args["num_neighs"], 
                                    edge_label_index=(('node', 'to', 'node'), te_edge_label_index), 
                                    edge_label=te_edge_label, batch_size=args["batch_size"], shuffle=False, transform=transform)
    else:
        tr_loader =  LinkNeighborLoader(tr_data, num_neighbors=args["num_neighs"], batch_size=args["batch_size"], shuffle=True, transform=transform)
        val_loader = LinkNeighborLoader(val_data,num_neighbors=args["num_neighs"], edge_label_index=val_data.edge_index[:, val_inds],
                                        edge_label=val_data.y[val_inds], batch_size=args["batch_size"], shuffle=False, transform=transform)
        te_loader =  LinkNeighborLoader(te_data,num_neighbors=args["num_neighs"], edge_label_index=te_data.edge_index[:, te_inds],
                                edge_label=te_data.y[te_inds], batch_size=args["batch_size"], shuffle=False, transform=transform)
        
    return tr_loader, val_loader, te_loader

@torch.no_grad()
def evaluate_homo(loader, inds, model, data, device, args):
    '''Evaluates the model performance for homogeneous graph data.'''
    preds = []
    ground_truths = []
    for batch in tqdm.tqdm(loader, disable=not args["tqdm"]):
        #select the seed edges from which the batch was created
        inds = inds.detach().cpu()
        batch_edge_inds = inds[batch.input_id.detach().cpu()]
        batch_edge_ids = loader.data.edge_attr.detach().cpu()[batch_edge_inds, 0]
        mask = torch.isin(batch.edge_attr[:, 0].detach().cpu(), batch_edge_ids)

        #add the seed edges that have not been sampled to the batch
        missing = ~torch.isin(batch_edge_ids, batch.edge_attr[:, 0].detach().cpu())

        if missing.sum() != 0 and (args["data"] == 'Small_J' or args["data"] == 'Small_Q'):
            missing_ids = batch_edge_ids[missing].int()
            n_ids = batch.n_id
            add_edge_index = data.edge_index[:, missing_ids].detach().clone()
            node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
            add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
            add_edge_attr = data.edge_attr[missing_ids, :].detach().clone()
            add_y = data.y[missing_ids].detach().clone()
        
            batch.edge_index = torch.cat((batch.edge_index, add_edge_index), 1)
            batch.edge_attr = torch.cat((batch.edge_attr, add_edge_attr), 0)
            batch.y = torch.cat((batch.y, add_y), 0)

            mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))

        #remove the unique edge id from the edge features, as it's no longer needed
        batch.edge_attr = batch.edge_attr[:, 1:]
        
        with torch.no_grad():
            batch.to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            out = out[mask]
            pred = out.argmax(dim=-1)
            preds.append(pred)
            ground_truths.append(batch.y[mask])
    pred = torch.cat(preds, dim=0).cpu().numpy()
    ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
    f1 = f1_score(ground_truth, pred)

    return f1

@torch.no_grad()
def evaluate_hetero(loader, inds, model, data, device, args):
    '''Evaluates the model performance for heterogeneous graph data.'''
    preds = []
    ground_truths = []
    for batch in tqdm.tqdm(loader, disable=not args["tqdm"]):
        #select the seed edges from which the batch was created
        inds = inds.detach().cpu()
        batch_edge_inds = inds[batch['node', 'to', 'node'].input_id.detach().cpu()]
        batch_edge_ids = loader.data['node', 'to', 'node'].edge_attr.detach().cpu()[batch_edge_inds, 0]
        mask = torch.isin(batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu(), batch_edge_ids)

        #add the seed edges that have not been sampled to the batch
        missing = ~torch.isin(batch_edge_ids, batch['node', 'to', 'node'].edge_attr[:, 0].detach().cpu())

        if missing.sum() != 0 and (args["data"] == 'Small_J' or args["data"] == 'Small_Q'):
            missing_ids = batch_edge_ids[missing].int()
            n_ids = batch['node'].n_id
            add_edge_index = data['node', 'to', 'node'].edge_index[:, missing_ids].detach().clone()
            node_mapping = {value.item(): idx for idx, value in enumerate(n_ids)}
            add_edge_index = torch.tensor([[node_mapping[val.item()] for val in row] for row in add_edge_index])
            add_edge_attr = data['node', 'to', 'node'].edge_attr[missing_ids, :].detach().clone()
            add_y = data['node', 'to', 'node'].y[missing_ids].detach().clone()
        
            batch['node', 'to', 'node'].edge_index = torch.cat((batch['node', 'to', 'node'].edge_index, add_edge_index), 1)
            batch['node', 'to', 'node'].edge_attr = torch.cat((batch['node', 'to', 'node'].edge_attr, add_edge_attr), 0)
            batch['node', 'to', 'node'].y = torch.cat((batch['node', 'to', 'node'].y, add_y), 0)

            mask = torch.cat((mask, torch.ones(add_y.shape[0], dtype=torch.bool)))

        #remove the unique edge id from the edge features, as it's no longer needed
        batch['node', 'to', 'node'].edge_attr = batch['node', 'to', 'node'].edge_attr[:, 1:]
        batch['node', 'rev_to', 'node'].edge_attr = batch['node', 'rev_to', 'node'].edge_attr[:, 1:]
        
        with torch.no_grad():
            batch.to(device)
            out = model(batch.x_dict, batch.edge_index_dict, batch.edge_attr_dict)
            out = out[('node', 'to', 'node')]
            out = out[mask]
            pred = out.argmax(dim=-1)
            preds.append(pred)
            ground_truths.append(batch['node', 'to', 'node'].y[mask])
    pred = torch.cat(preds, dim=0).cpu().numpy()
    ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
    f1 = f1_score(ground_truth, pred)

    return f1

def save_model(model, optimizer, epoch, args, data_config):
    # Save the model in a dictionary
    torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
                }, f'{data_config["paths"]["model_to_save"]}/checkpoint_{args["unique_name"]}{"" if not args["finetune"] else "_finetuned"}.tar')
    
def load_model(model, device, args, config, data_config):
    checkpoint = torch.load(f'{data_config["paths"]["model_to_load"]}/checkpoint_{args["unique_name"]}.tar')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optimizer
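
The evaluation functions rely on the ID column that `add_arange_ids` prepends to the edge features: it lets each sampled batch be masked down to its seed edges before the column is dropped again. A minimal sketch of that bookkeeping with plain lists follows. The feature values and the sampled batch are invented, and the variable names are hypothetical.

```python
# Invented edge features: one row per edge, two features each.
edge_attr = [[5.0, 0.0], [7.5, 1.0], [2.0, 0.0], [9.9, 2.0]]

# Step 1: prepend the running index as an ID column (as add_arange_ids does).
edge_attr_with_id = [[float(i)] + row for i, row in enumerate(edge_attr)]

# Step 2: a sampled batch contains seed edges plus neighbouring context edges.
seed_ids = {1.0, 3.0}
batch = [edge_attr_with_id[k] for k in (0, 1, 3)]

# Step 3: mark which batch rows are seed edges; only these are scored.
mask = [row[0] in seed_ids for row in batch]

# Step 4: drop the ID column before the features reach the model.
model_input = [row[1:] for row in batch]

print(mask)         # [False, True, True]
print(model_input)  # [[5.0, 0.0], [7.5, 1.0], [9.9, 2.0]]
```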