# Practice GNN

## Notebook configuration

In [1]:
import random
import networkx as nx
import pandas as pd
import numpy as np
import ipywidgets as widgets
import os
import sys
import matplotlib.pyplot as plt
import warnings
from tabulate import tabulate
from tqdm import trange
from IPython import get_ipython
from IPython.display import display
from time import monotonic
from pprint import pprint
from google.colab import drive
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss, Sequential, Linear, ReLU
!pip install torch==2.5.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install torch-scatter torch-sparse pyg-lib torch-geometric \
  -f https://data.pyg.org/whl/torch-2.5.1+cu118.html
from torch_geometric.nn import GINEConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, LinkNeighborLoader
!pip install torchmetrics
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryAveragePrecision
from sklearn.model_selection import train_test_split

warnings.filterwarnings('ignore')

content_base = "/content/drive"
drive.mount(content_base)

# Project data
data_dir = os.path.join(content_base, "My Drive/Capstone/data")
data_file = os.path.join(data_dir, "subset_transactions2.csv")

Looking in indexes: https://download.pytorch.org/whl/cu118
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu118.html
Collecting torchmetrics
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.2-py3-none-any.whl.metadata (5.6 kB)
Downloading torchmetrics-1.7.0-py3-none-any.whl (960 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m960.9/960.9 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.14.2-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.14.2 torchmetrics-1.7.0
Mounted at /content/drive


In [2]:
# Google Colaboratory executes in an environment with a file system
# that has a Linux topography, but where the user should work under
# the `/content` directory
COLAB_ROOT = "/content"

REPO_URL = "https://github.com/engie4800/dsi-capstone-spring-2025-TD-anti-money-laundering.git"
REPO_ROOT = os.path.join(COLAB_ROOT, REPO_URL.split("/")[-1].split(".")[0])
REPO_BRANCH = "sophie"

# Clones the repository at `/content/dsi-capstone-spring-2025-TD-anti-money-laundering`
if not os.path.exists(REPO_ROOT):
  os.chdir(COLAB_ROOT)
  !git clone {REPO_URL}

# Pulls the latest code from the provided branch and adds the
# analysis pipeline source code to the Python system path
os.chdir(REPO_ROOT)
!git pull
!git checkout {REPO_BRANCH}
sys.path.append(os.path.join(REPO_ROOT, "Code/src"))
os.chdir(COLAB_ROOT)

Cloning into 'dsi-capstone-spring-2025-TD-anti-money-laundering'...
remote: Enumerating objects: 643, done.[K
remote: Counting objects: 100% (114/114), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 643 (delta 56), reused 44 (delta 35), pack-reused 529 (from 2)[K
Receiving objects: 100% (643/643), 26.53 MiB | 54.01 MiB/s, done.
Resolving deltas: 100% (320/320), done.
Already up to date.
Branch 'sophie' set up to track remote branch 'sophie' from 'origin'.
Switched to a new branch 'sophie'


In [3]:
from helpers import add_cell_timer
from pipeline import ModelPipeline
add_cell_timer()

## Data preprocessing

Run initial full-dataset preprocessing

In [11]:
pl = ModelPipeline(data_file)
pl.run_preprocessing()


⏱️ Execution time: 22.27s


In [12]:
node_features = [
    # TODO
    # A list of tuples with this structure >>>
    # (column to include, treatment/method, column rename)

    ('from_bank', 'first', None),
]

pl.extract_nodes(node_features, add_graph_features=False)


⏱️ Execution time: 0.02s


In [13]:
pl.df = pl.df.sort_values(by='timestamp_int')
pl.df = pl.df.reset_index(drop=True)
pl.df['edge_id'] = pl.df.index
X_cols = ['edge_id','from_bank', 'to_bank', 'received_amount', 'received_currency',
       'sent_amount', 'sent_currency', 'payment_type', 'from_account_idx',
       'to_account_idx', 'sent_amount_usd', 'received_amount_usd',
       'hour_of_day', 'day_of_week', 'seconds_since_midnight', 'timestamp_int',
       'timestamp_scaled', 'day_sin', 'day_cos', 'time_of_day_sin',
       'time_of_day_cos', 'is_weekend']
y_col = 'is_laundering'
X_train, X_val, X_test, y_train, y_val, y_test = pl.split_train_test_val(X_cols, y_col, test_size=0.15, val_size=0.15, split_type='temporal_agg')


⏱️ Execution time: 1.16s


In [14]:
numerical_feats = ['sent_amount_usd', 'received_amount_usd', 'timestamp_scaled']
X_train, X_test, X_val = pl.numerical_scaling(numerical_feats)


⏱️ Execution time: 0.08s


In [15]:
edge_features = ['edge_id','received_amount', 'received_currency','sent_amount',
                 'sent_currency', 'payment_type', 'sent_amount_usd',
                 'hour_of_day', 'day_of_week', 'seconds_since_midnight',
                 'timestamp_scaled']
node_features = ['from_bank'] #,'degree_centrality_sent_amount','pagerank_sent_amount', 'degree_centrality_received_amount', 'pagerank_received_amount']
train_data, val_data, test_data = pl.generate_tensors(edge_features,node_features)


⏱️ Execution time: 0.25s


## GNNs

Cannot use GCN or GAT!

* The Graph Convolutional Network (GCN), implemented with `GCNConv`, only aggregates features from neighboring nodes and does not use edge attributes in its message passing.
* Graph Attention Networks (GAT), implemented with `GATConv`, allows edge attention weights, which can indirectly incorporate edge attributes. Problem: If all nodes have the same feature vector (e.g., initialized to 1), then the computed attention scores will be the same for all edges. We'd need to modify GAT to use edge features meaningfully in the attention computation.

`GINeConv`
* Directly includes edge attributes in message passing using an MLP-based edge transformation.

`EdgeConv`
* dynamically computes edge embeddings and updates node features based on edges

**We'll be using GINeConv moving forward.**

In [16]:
# If on GPU, do as below
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda

⏱️ Execution time: 0.0s


### Model

In [17]:
class EdgeGINE(nn.Module):
    def __init__(self, n_node_feats, n_edge_feats, n_hidden=64):
        super(EdgeGINE, self).__init__()

        self.n_hidden = n_hidden
        self.n_node_feats = n_node_feats
        self.n_edge_feats = n_edge_feats

        # Linear to embed node and edges
        self.node_emb = nn.Linear(self.n_node_feats, self.n_hidden) # [num_nodes, n_hidden]
        self.edge_emb = nn.Linear(self.n_edge_feats, self.n_hidden) # [num_edges, n_hidden]

        # MLP that processes edge features, passed into GINEConv
        nn_edge = Sequential(Linear(self.n_hidden, self.n_hidden), ReLU(), Linear(self.n_hidden, self.n_hidden))

        # Two GINEConv layers using nn_edge when it needs to process edge attributes
        self.gine1 = GINEConv(nn_edge, edge_dim=self.n_hidden, train_eps=True)
        self.gine2 = GINEConv(nn_edge, edge_dim=self.n_hidden, train_eps=True)

        # Edge updates MLPs
        self.emlp1 = Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            )
        self.emlp2 = Sequential(
                nn.Linear(3 * self.n_hidden, self.n_hidden),
                nn.ReLU(),
                nn.Linear(self.n_hidden, self.n_hidden),
            )

        # MLP for edge classification
        self.mlp = nn.Sequential(
            nn.Linear(3 * self.n_hidden, 128), # src, dest, edge
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x, edge_index, edge_attr):
        """
        x: Node features (or placeholder embeddings if None)
        edge_index: Edge list (2, n_edges)
        edge_attr: Edge features (n_edges, self.n_edge_feats)
        """
        src, dest = edge_index
        if x is None:  # If no node features, use trainable embeddings
            x = torch.ones((edge_index.max().item() + 1, 1), device=device)

        # Create some initial embeddings for nodes and edges
        x = self.node_emb(x) # MLP
        edge_attr = self.edge_emb(edge_attr) # MLP
        x, edge_attr, edge_index = x.to(device), edge_attr.to(device), edge_index.to(device)

        # Pass nodes and edges through GINE layer1
        x = x + F.relu(self.gine1(x, edge_index, edge_attr))

        # Update edges with MLP1
        edge_attr = edge_attr + self.emlp1(torch.cat([x[src], x[dest], edge_attr], dim=-1)) / 2

        # Pass nodes and edges through GINE layer2
        x = F.relu(self.gine1(x, edge_index, edge_attr))

        # Update edges with MLP2
        edge_attr = edge_attr + self.emlp2(torch.cat([x[src], x[dest], edge_attr], dim=-1)) / 2

        # Get output for classification
        src_embed, dest_embed = x[src], x[dest]
        edge_inputs = torch.cat([src_embed, dest_embed, edge_attr], dim=1)
        edge_logits = self.mlp(edge_inputs).squeeze(1)

        return edge_logits



⏱️ Execution time: 0.0s


### Create data loaders with `LinkNeighborLoader`
Goal: Create data loaders - split into batches using `LinkNeighborLoader`, incorporating masking in the loading process & batching

**LinkNeighborLoader:**

**num_neighbors:** how many neighbors are sampled per node -- only sampling a subgraph around each edge in a batch. it is size [x,y] because we have 2 layers (sample x nodes in layer 1 and y nodes in layer 2).
- Let’s say your batch contains 100 edges, and each edge touches two nodes (source and destination). Then LinkNeighborLoader will:
  - Identify all unique nodes from those 100 edges
  - For each of those nodes:
      - Sample 10 neighbors (for layer 1)
      - Then, for each of those neighbors, sample another 10 neighbors (for layer 2)
  - Build a mini subgraph for this batch using only those sampled nodes and edges
- Imagine you're doing link prediction for a social network:
  - batch_size = 1024 means you're analyzing 1024 friend requests at a time
  - num_neighbors = [10, 10] means for each person in the request, you look at:
    - Their 10 direct friends
    - And 10 friends-of-friends per direct friend

In [49]:
# Move data to GPU if using
tr_data = train_data.to(device)
val_data = val_data.to(device)
te_data = test_data.to(device)

batch_size=8192
num_neighbors=[100,100]

t1 = int(len(pl.df) * 0.7)
t2 = int(len(pl.df) * 0.85)

# Indices of *labels* we will evaluate for each loader
tr_inds = torch.tensor(np.arange(0, t1), device=device)
val_inds = torch.tensor(np.arange(t1, t2), device=device)
te_inds = torch.tensor(np.arange(t2, len(pl.df)), device=device)

# Create data loaders, and restrict evaluation to correct edges for val and test
tr_loader = LinkNeighborLoader(data=tr_data, dataset=[tr_data], edge_label_index=tr_data.edge_index, edge_label=tr_data.y,
                               num_neighbors=num_neighbors, batch_size=batch_size, shuffle=True)
val_loader = LinkNeighborLoader(data=val_data,dataset=[val_data], num_neighbors=num_neighbors, edge_label_index=val_data.edge_index[:, val_inds],
                                edge_label=val_data.y[val_inds], batch_size=batch_size, shuffle=False)
te_loader =  LinkNeighborLoader(data=te_data,dataset=[te_data], num_neighbors=num_neighbors, edge_label_index=te_data.edge_index[:, te_inds],
                        edge_label=te_data.y[te_inds], batch_size=batch_size, shuffle=False)


⏱️ Execution time: 0.05s


In [48]:
# Initialize model & optimizer
num_edge_features = len(edge_features)-1 # num edge feats - edge_id
num_node_features = 1
model = EdgeGINE(num_node_features, num_edge_features).to(device)
optimizer = Adam(model.parameters(), lr=0.01)
criterion = BCEWithLogitsLoss(pos_weight=torch.tensor([10.0], device=device))

print(sum(p.numel() for p in model.parameters() if p.requires_grad))
model

83523


EdgeGINE(
  (node_emb): Linear(in_features=1, out_features=64, bias=True)
  (edge_emb): Linear(in_features=10, out_features=64, bias=True)
  (gine1): GINEConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (gine2): GINEConv(nn=Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  ))
  (emlp1): Sequential(
    (0): Linear(in_features=192, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  )
  (emlp2): Sequential(
    (0): Linear(in_features=192, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=192, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (


⏱️ Execution time: 0.01s


### Train model

In [52]:
def train(model, optimizer, loss_fn, tr_loader, val_loader, threshold=0.5, epochs=20):

    # Metrics
    acc_fn = BinaryAccuracy(threshold=threshold).to(device)
    prec_fn = BinaryPrecision(threshold=threshold).to(device)
    rec_fn = BinaryRecall(threshold=threshold).to(device)
    f1_fn = BinaryF1Score(threshold=threshold).to(device)
    pr_auc_fn = BinaryAveragePrecision().to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_preds, train_targets, train_probs = [], [], []

        for batch in tqdm(tr_loader, desc=f"Epoch {epoch+1} Training"):
            batch = batch.to(device)
            optimizer.zero_grad()

            # Get indices of seed edges for this batch
            batch_input_ids = batch.input_id.detach().cpu()  # maps to tr_inds
            batch_edge_inds = tr_inds[batch_input_ids]       # global edge indices
            batch_edge_ids = tr_data.edge_attr[batch_edge_inds, 0].cpu()  # edge ID column

            # Find which edges in this batch are seed edges
            edge_ids_in_batch = batch.edge_attr[:, 0].detach().cpu()
            mask = torch.isin(edge_ids_in_batch, batch_edge_ids).to(device)

            # Remove edge_id from features before forward pass
            batch.edge_attr = batch.edge_attr[:, 1:]

            # Forward pass of model
            logits = model(batch.x, batch.edge_index, batch.edge_attr)
            logits = logits[mask] # Restrict to seed edges
            target = batch.y[mask] # Restrict to seed edges
            batch_probs = torch.sigmoid(logits)
            batch_preds = (batch_probs > threshold).long()

            # Calculate batch loss & backpropagate
            loss = loss_fn(logits, target.float())
            loss.backward()
            optimizer.step()

            # Running results
            train_loss += loss.item() * logits.size(0)
            train_preds.append(batch_preds)
            train_targets.append(target)
            train_probs.append(batch_probs)

        # Concatenate all training results
        train_preds = torch.cat(train_preds)
        train_targets = torch.cat(train_targets)
        train_probs = torch.cat(train_probs)
        train_loss /= len(train_targets)

        # Compute training metrics
        train_acc = acc_fn(train_preds, train_targets)
        train_prec = prec_fn(train_preds, train_targets)
        train_rec = rec_fn(train_preds, train_targets)
        train_f1 = f1_fn(train_preds, train_targets)
        train_pr_auc = pr_auc_fn(train_probs, train_targets)

        # === Validation ===
        model.eval()
        val_loss = 0
        val_preds, val_targets, val_probs = [], [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                batch = batch.to(device)

                # Get indices of seed edges for this batch
                batch_input_ids = batch.input_id.detach().cpu()
                batch_edge_inds = val_inds[batch_input_ids]
                batch_edge_ids = val_data.edge_attr[batch_edge_inds, 0].cpu()

                # Find which edges in this batch are seed edges
                edge_ids_in_batch = batch.edge_attr[:, 0].detach().cpu()
                mask = torch.isin(edge_ids_in_batch, batch_edge_ids).to(device)

                # Remove edge_id from features before forward pass
                batch.edge_attr = batch.edge_attr[:, 1:]

                # Forward pass of model
                logits = model(batch.x, batch.edge_index, batch.edge_attr)
                logits = logits[mask]
                target = batch.y[mask]
                batch_probs = torch.sigmoid(logits)
                batch_preds = (batch_probs > threshold).long()

                # Calculate loss
                loss = criterion(logits, target.float())

                # Running results
                val_loss += loss.item() * logits.size(0)
                val_preds.append(batch_preds)
                val_targets.append(target)
                val_probs.append(batch_probs)

        val_preds = torch.cat(val_preds)
        val_targets = torch.cat(val_targets)
        val_probs = torch.cat(val_probs)
        val_loss /= len(val_targets)

        val_acc = acc_fn(val_preds, val_targets)
        val_prec = prec_fn(val_preds, val_targets)
        val_rec = rec_fn(val_preds, val_targets)
        val_f1 = f1_fn(val_preds, val_targets)
        val_pr_auc = pr_auc_fn(val_probs, val_targets)

        # Print every epoch
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f}")
        print(f"Train PR-AUC: {train_pr_auc:.4f} | Val PR-AUC: {val_pr_auc:.4f}")
        print(f"Train Prec: {train_prec:.4f} | Val Prec: {val_prec:.4f}")
        print(f"Train Rec: {train_rec:.4f} | Val Rec: {val_rec:.4f}")
        print("-" * 80)


⏱️ Execution time: 0.0s


In [None]:
# Run Training
train(model, optimizer, criterion, tr_loader, val_loader, threshold=0.5, epochs=50)

Epoch 1 Training: 100%|██████████| 107/107 [00:20<00:00,  5.33it/s]
Epoch 1 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.85it/s]


Epoch 1/50
Train Loss: 56177.4435 | Val Loss: 48.6337
Train Acc: 0.9510 | Val Acc: 0.9929
Train F1: 0.0063 | Val F1: 0.0074
Train PR-AUC: 0.0033 | Val PR-AUC: 0.0058
Train Prec: 0.0034 | Val Prec: 0.0081
Train Rec: 0.0474 | Val Rec: 0.0069
--------------------------------------------------------------------------------


Epoch 2 Training: 100%|██████████| 107/107 [00:20<00:00,  5.34it/s]
Epoch 2 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.86it/s]


Epoch 2/50
Train Loss: 39.1981 | Val Loss: 0.3005
Train Acc: 0.9950 | Val Acc: 0.9959
Train F1: 0.0050 | Val F1: 0.0000
Train PR-AUC: 0.0042 | Val PR-AUC: 0.0040
Train Prec: 0.0075 | Val Prec: 0.0000
Train Rec: 0.0038 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 3 Training: 100%|██████████| 107/107 [00:20<00:00,  5.31it/s]
Epoch 3 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.83it/s]


Epoch 3/50
Train Loss: 0.3008 | Val Loss: 0.2005
Train Acc: 0.9966 | Val Acc: 0.9960
Train F1: 0.0007 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0040
Train Prec: 0.0175 | Val Prec: 0.0000
Train Rec: 0.0003 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 4 Training: 100%|██████████| 107/107 [00:20<00:00,  5.32it/s]
Epoch 4 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.86it/s]


Epoch 4/50
Train Loss: 0.2036 | Val Loss: 0.1790
Train Acc: 0.9966 | Val Acc: 0.9960
Train F1: 0.0007 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0040
Train Prec: 0.0233 | Val Prec: 0.0000
Train Rec: 0.0003 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 5 Training: 100%|██████████| 107/107 [00:20<00:00,  5.31it/s]
Epoch 5 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.67it/s]


Epoch 5/50
Train Loss: 0.1651 | Val Loss: 0.1698
Train Acc: 0.9967 | Val Acc: 0.9961
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0040
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 6 Training: 100%|██████████| 107/107 [00:20<00:00,  5.31it/s]
Epoch 6 Validation: 100%|██████████| 23/23 [00:02<00:00, 11.33it/s]


Epoch 6/50
Train Loss: 0.1523 | Val Loss: 0.4439
Train Acc: 0.9967 | Val Acc: 0.9961
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0040
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 7 Training: 100%|██████████| 107/107 [00:20<00:00,  5.30it/s]
Epoch 7 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.79it/s]


Epoch 7/50
Train Loss: 0.1804 | Val Loss: 0.1701
Train Acc: 0.9966 | Val Acc: 0.9959
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0039
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 8 Training: 100%|██████████| 107/107 [00:20<00:00,  5.30it/s]
Epoch 8 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.76it/s]


Epoch 8/50
Train Loss: 0.1485 | Val Loss: 0.1686
Train Acc: 0.9967 | Val Acc: 0.9960
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0033 | Val PR-AUC: 0.0039
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 9 Training: 100%|██████████| 107/107 [00:20<00:00,  5.28it/s]
Epoch 9 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.88it/s]


Epoch 9/50
Train Loss: 0.1467 | Val Loss: 0.1680
Train Acc: 0.9967 | Val Acc: 0.9960
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0034 | Val PR-AUC: 0.0039
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 10 Training: 100%|██████████| 107/107 [00:20<00:00,  5.26it/s]
Epoch 10 Validation: 100%|██████████| 23/23 [00:01<00:00, 11.83it/s]


Epoch 10/50
Train Loss: 0.1464 | Val Loss: 0.1675
Train Acc: 0.9967 | Val Acc: 0.9960
Train F1: 0.0000 | Val F1: 0.0000
Train PR-AUC: 0.0035 | Val PR-AUC: 0.0039
Train Prec: 0.0000 | Val Prec: 0.0000
Train Rec: 0.0000 | Val Rec: 0.0000
--------------------------------------------------------------------------------


Epoch 11 Training:  70%|███████   | 75/107 [00:14<00:06,  5.32it/s]

# Alternative loss functions

In [None]:
# # Use focal loss to focus on rare positives
# class FocalLoss(torch.nn.Module):
#     def __init__(self, gamma=2.0, alpha=0.25):
#         super().__init__()
#         self.gamma = gamma
#         self.alpha = alpha

#     def forward(self, logits, targets):
#         bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
#         pt = torch.exp(-bce_loss)  # Probabilities of correct classification
#         focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
#         return focal_loss.mean()

# class HybridLoss(torch.nn.Module):
#     """Hybrid Loss that balances BCE (for accuracy) and Focal Loss (for recall)"""
#     def __init__(self, alpha=0.25, gamma=2.0, focal_weight=0.5):
#         super().__init__()
#         self.bce = torch.nn.BCEWithLogitsLoss()
#         self.alpha = alpha
#         self.gamma = gamma
#         self.focal_weight = focal_weight  # Weighting factor between BCE and Focal Loss

#     def forward(self, logits, targets):
#         # BCE Loss
#         bce_loss = self.bce(logits, targets.float())

#         # Focal Loss
#         probs = torch.sigmoid(logits)
#         bce_loss_per_sample = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction="none")
#         focal_loss = self.alpha * (1 - torch.exp(-bce_loss_per_sample)) ** self.gamma * bce_loss_per_sample
#         focal_loss = focal_loss.mean()

#         # Combine BCE and Focal Loss
#         total_loss = (1 - self.focal_weight) * bce_loss + self.focal_weight * focal_loss
#         return total_loss

# criterion = FocalLoss(gamma=2.0, alpha=0.25)
#criterion = HybridLoss(focal_weight=0.3)  # Adjust weight (0.3-0.6 works well)