In [None]:
ignore_warnings = True

# ------------------------------------------------------------------------------
# Custom installations
!pip install torchmetrics
!pip install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install torch-scatter torch-sparse pyg-lib torch-geometric \
  -f https://data.pyg.org/whl/torch-2.5.1+cu118.html

# ------------------------------------------------------------------------------
# Setup imports
import logging
import os
import shutil
import sys
import warnings
from rich.logging import RichHandler
from google.colab import drive

# ------------------------------------------------------------------------------
# Logging
#
# The root logger has to be reconfigured explicitly (at least in Colab): any
# handlers already attached to it are dropped, then a single `rich` handler is
# installed and the level is set to INFO.
root_logger = logging.getLogger()
del root_logger.handlers[:]

rich_handler = RichHandler(markup=True, rich_tracebacks=True)
root_logger.addHandler(rich_handler)
root_logger.setLevel(logging.INFO)

# ------------------------------------------------------------------------------
# Colab setup

# Google Colaboratory executes in an environment with a file system
# that has a Linux topography, but where the user should work under
# the `/content` directory
COLAB_ROOT = "/content"

REPO_URL = "https://github.com/engie4800/dsi-capstone-spring-2025-TD-anti-money-laundering.git"
REPO_ROOT = os.path.join(COLAB_ROOT, REPO_URL.split("/")[-1].split(".")[0])
REPO_BRANCH = "main"


# Clones the repository at `/content/dsi-capstone-spring-2025-TD-anti-money-laundering`
# no matter what
if os.path.exists(REPO_ROOT):
    shutil.rmtree(REPO_ROOT)
os.chdir(COLAB_ROOT)
!git clone {REPO_URL}

# Pulls the latest code from the provided branch and adds the
# analysis pipeline source code to the Python system path
os.chdir(REPO_ROOT)
!git fetch origin {REPO_BRANCH}
!git checkout {REPO_BRANCH}
sys.path.append(os.path.join(REPO_ROOT, "Code/src"))
os.chdir(COLAB_ROOT)

# Get ready to load the data
CONTENT_BASE = os.path.join(COLAB_ROOT, "drive")
DATA_DIR = os.path.join(CONTENT_BASE, "My Drive/capstone/data")
DATASET = "HI-Tiny_Trans.csv"  # tiny is the 25% dataset Sophie created
DATASET_PATH = os.path.join(DATA_DIR, DATASET)

# Mount the drive
drive.mount(CONTENT_BASE)

# Fail early, with the expected location spelled out, if the shared data
# folder is not where this notebook expects it
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f"Data directory not found after mounting the drive: {DATA_DIR}"
    )

files = os.listdir(DATA_DIR)
logging.info("\nData files available:")
# File names are arbitrary text: turn console markup off for this record so
# bracketed text in a name is not interpreted by a markup-enabled rich handler
logging.info(files, extra={"markup": False})

if DATASET not in files:
    logging.warning("Dataset %s was not found in %s", DATASET, DATA_DIR)

# ------------------------------------------------------------------------------
# Project imports
#
# Cached copies of the project modules are removed before importing them, to
# try to force them to be updated if we've pushed changes through git and are
# re-running the notebook. Submodules (e.g. `pipeline.<something>`) are purged
# as well: leaving them in `sys.modules` would let a freshly imported package
# keep using stale submodule code.
import importlib

for _name in list(sys.modules):
    if _name.split(".", 1)[0] in ("helpers", "pipeline", "plotting"):
        del sys.modules[_name]

# The repository directory is deleted and re-cloned earlier in this cell, so
# ask the import system to drop any cached directory listings before importing
importlib.invalidate_caches()

import helpers
import pipeline
import plotting

helpers.add_cell_timer()

# ------------------------------------------------------------------------------
# Warnings
#
# When `ignore_warnings` is truthy, install a blanket "ignore" filter. It is
# registered at this point in the cell, so it only affects warnings raised
# from here on.
if ignore_warnings:
    warnings.filterwarnings(action="ignore")

In [None]:
# PREPROCESSING

# ------------------------------------------------------------------------------
# Pipeline setup
p = pipeline.GNNModelPipeline(dataset_path=DATASET_PATH)

# Initial preprocessing, run on the data before the train/test/validation
# split. The steps are invoked one after another in exactly the order listed.
pre_split_steps = (
    "rename_columns",
    "drop_duplicates",
    "check_for_null",
    "extract_currency_features",
    "extract_time_features",
    "create_unique_ids",
    "extract_additional_time_features",
    "cyclical_encoding",
    "apply_label_encoding",
    "apply_one_hot_encoding",
)
for step_name in pre_split_steps:
    getattr(p, step_name)()

# Train/test/validation split, followed by the steps that run after it
split_options = {
    "split_type": "temporal_agg",
    "test_size": 0.2,
    "val_size": 0.2,
}
p.split_train_test_val(**split_options)

p.compute_split_specific_node_features()
p.scale_node_data_frames()
p.split_train_test_val_graph()
p.get_data_loaders()

# ------------------------------------------------------------------------------
# Command landfill (commandfill): scratch commands kept for ad-hoc inspection
# p.df.head()
# plotting.plot_column_imbalances(p)
# p.df.columns

In [None]:
# TRAINING

# ------------------------------------------------------------------------------
# Training setup
#
# Values forwarded to `p.train`; named here so they are easy to find and tune
TRAIN_THRESHOLD = 0.5
TRAIN_EPOCHS = 50

p.initialize_training()
p.train(
    threshold=TRAIN_THRESHOLD,
    epochs=TRAIN_EPOCHS,
)