In [1]:
!uv sync

[2mResolved [1m64 packages[0m [2min 5ms[0m[0m
[2mUninstalled [1m25 packages[0m [2min 21.38s[0m[0m
 [31m-[39m [1masttokens[0m[2m==3.0.1[0m
 [31m-[39m [1mcomm[0m[2m==0.2.3[0m
 [31m-[39m [1mdebugpy[0m[2m==1.8.17[0m
 [31m-[39m [1mdecorator[0m[2m==5.2.1[0m
 [31m-[39m [1mexecuting[0m[2m==2.2.1[0m
 [31m-[39m [1mipykernel[0m[2m==7.1.0[0m
 [31m-[39m [1mipython[0m[2m==9.7.0[0m
 [31m-[39m [1mipython-pygments-lexers[0m[2m==1.1.1[0m
 [31m-[39m [1mjedi[0m[2m==0.19.2[0m
 [31m-[39m [1mjupyter-client[0m[2m==8.6.3[0m
 [31m-[39m [1mjupyter-core[0m[2m==5.9.1[0m
 [31m-[39m [1mmatplotlib-inline[0m[2m==0.2.1[0m
 [31m-[39m [1mnest-asyncio[0m[2m==1.6.0[0m
 [31m-[39m [1mparso[0m[2m==0.8.5[0m
 [31m-[39m [1mpexpect[0m[2m==4.9.0[0m
 [31m-[39m [1mplatformdirs[0m[2m==4.5.0[0m
 [31m-[39m [1mprompt-toolkit[0m[2m==3.0.52[0m
 [31m-[39m [1mptyprocess[0m[2m==0.7.0[0m
 [31m-[39m [1mpure-eval[0m[2m==0

## CUDA Testing

In [2]:
# save as test_cuda.py and run: python3 test_cuda.py

import platform

print("=== Environment ===")
print("Platform:", platform.platform())
print("Python:", platform.python_version())

try:
    import torch
except ImportError as e:
    print("\nPyTorch is not installed or not in this Python environment.")
    raise SystemExit(e)

print("\n=== PyTorch / CUDA Info ===")
print("torch.__version__:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)

cuda_available = torch.cuda.is_available()
print("torch.cuda.is_available():", cuda_available)

if not cuda_available:
    print("\nCUDA is NOT available to PyTorch in this environment.")
else:
    # Number of devices
    device_count = torch.cuda.device_count()
    print("torch.cuda.device_count():", device_count)

    for i in range(device_count):
        print(f"  device {i}: {torch.cuda.get_device_name(i)}")

    # Simple tensor test on GPU
    try:
        x = torch.rand(3, 3, device="cuda")
        y = torch.rand(3, 3, device="cuda")
        z = x @ y
        print("\nSuccessfully ran a matrix multiply on CUDA.")
        print("z.device:", z.device)
    except Exception as e:
        print("\nERROR: Allocation or compute on CUDA failed:")
        print(e)

print("\n=== Test Complete ===")

=== Environment ===
Platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
Python: 3.12.3

=== PyTorch / CUDA Info ===
torch.__version__: 2.8.0+cu128
torch.version.cuda: 12.8
torch.cuda.is_available(): True
torch.cuda.device_count(): 1
  device 0: NVIDIA GeForce RTX 4070 Laptop GPU

Successfully ran a matrix multiply on CUDA.
z.device: cuda:0

=== Test Complete ===


In [None]:
import torch


print(torch.__version__)
print(torch.version.cuda)
print(torch.backends.cudnn.version())

2.9.1+cu128
12.8
91002


In [1]:
!uv pip install pyg-lib -f https://data.pyg.org/whl/torch-2.8.0+cu128.html

[2mAudited [1m1 package[0m [2min 189ms[0m[0m


## Diagnostic: Check `pyg-lib` availability
Run the next cell in this notebook kernel to verify whether `pyg_lib` (or `torch_sparse`) is importable here and whether `LinkNeighborLoader` works. If it fails, compare the printed `sys.executable` to your terminal environment.

In [2]:
import sys, importlib, traceback
print('--- Kernel executable ---')
print(sys.executable)
print('\n--- sys.path (first 6 entries) ---')
for p in sys.path[:6]:
    print(' ', p)

def try_find(name):
    try:
        spec = importlib.util.find_spec(name)
        print(f"find_spec('{name}') ->", spec)
        if spec is not None:
            origin = getattr(spec, 'origin', None)
            loader = getattr(spec, 'loader', None)
            print(f"  origin={origin} loader={loader}")
    except Exception as e:
        print(f"Error while find_spec('{name}'):", e)

try_find('pyg_lib')
try_find('torch_sparse')
try_find('torch_geometric')

print('\n--- Import attempts ---')
try:
    import torch
    print('torch.__version__:', torch.__version__)
    print('torch.version.cuda:', torch.version.cuda)
    print('torch.cuda.is_available():', torch.cuda.is_available())
except Exception:
    print('Failed to import torch:\n', traceback.format_exc())

try:
    import pyg_lib
    print('\nImported pyg_lib; version / repr ->', getattr(pyg_lib, '__version__', repr(pyg_lib)))
except Exception:
    print('\npyg_lib import failed:\n', traceback.format_exc())

try:
    import torch_sparse
    print('\nImported torch_sparse; version / repr ->', getattr(torch_sparse, '__version__', repr(torch_sparse)))
except Exception:
    print('\ntorch_sparse import failed:\n', traceback.format_exc())

print('\n--- Try a minimal LinkNeighborLoader run (if torch_geometric available) ---')
try:
    from torch_geometric.data import Data
    from torch_geometric.loader import LinkNeighborLoader
    import torch

    edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)
    num_nodes = 3
    x = torch.randn((num_nodes, 4))
    edge_attr = torch.randn((edge_index.size(1), 3))
    edge_label_index = edge_index
    edge_label = torch.tensor([0, 1], dtype=torch.long)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    print('Constructed tiny Data:', data)

    try:
        loader = LinkNeighborLoader(
            data,
            num_neighbors=[2, 2],
            batch_size=1,
            edge_label_index=edge_label_index,
            edge_label=edge_label,
            shuffle=False,
            neg_sampling_ratio=0.0,
        )
        print('LinkNeighborLoader created successfully. Iterating one batch...')
        for b_idx, batch in enumerate(loader):
            print('  Batch keys:', list(batch.keys()))
            if hasattr(batch, 'edge_label'):
                print('  batch.edge_label ->', batch.edge_label)
            else:
                print('  batch has no edge_label attribute')
            break
    except Exception:
        print('LinkNeighborLoader construction/iteration failed:\n', traceback.format_exc())

except Exception:
    print('torch_geometric import or LinkNeighborLoader not available:\n', traceback.format_exc())

print('\n--- Diagnostic complete ---')

--- Kernel executable ---
/mnt/d/SFSU/CSC871/csc871-anti-money-laundering-ibm-gnn/.venv/bin/python

--- sys.path (first 6 entries) ---
  /usr/lib/python312.zip
  /usr/lib/python3.12
  /usr/lib/python3.12/lib-dynload
  
  /mnt/d/SFSU/CSC871/csc871-anti-money-laundering-ibm-gnn/.venv/lib/python3.12/site-packages
find_spec('pyg_lib') -> ModuleSpec(name='pyg_lib', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7623ef33bc20>, origin='/mnt/d/SFSU/CSC871/csc871-anti-money-laundering-ibm-gnn/.venv/lib/python3.12/site-packages/pyg_lib/__init__.py', submodule_search_locations=['/mnt/d/SFSU/CSC871/csc871-anti-money-laundering-ibm-gnn/.venv/lib/python3.12/site-packages/pyg_lib'])
  origin=/mnt/d/SFSU/CSC871/csc871-anti-money-laundering-ibm-gnn/.venv/lib/python3.12/site-packages/pyg_lib/__init__.py loader=<_frozen_importlib_external.SourceFileLoader object at 0x7623ef33bc20>
find_spec('torch_sparse') -> None
find_spec('torch_geometric') -> ModuleSpec(name='torch_geometric', loader=

  from .autonotebook import tqdm as notebook_tqdm


Constructed tiny Data: Data(x=[3, 4], edge_index=[2, 2], edge_attr=[2, 3])
LinkNeighborLoader created successfully. Iterating one batch...
  Batch keys: ['num_sampled_edges', 'x', 'n_id', 'num_sampled_nodes', 'edge_label', 'edge_attr', 'edge_label_index', 'edge_index', 'input_id', 'e_id']
  batch.edge_label -> tensor([0])

--- Diagnostic complete ---


## Dataset Loading

In [3]:
import pandas as pd
from pathlib import Path

# Path to the small transactions CSV (relative to this notebook).
DATA_PATH = Path("dataset") / "HI-Small_Trans.csv"

# Load into a DataFrame
small_trans = pd.read_csv(DATA_PATH)

# Quick summary and preview
print(f"Loaded {len(small_trans)} rows; columns: {list(small_trans.columns)}")
small_trans.head()

Loaded 5078345 rows; columns: ['Timestamp', 'From Bank', 'Account', 'To Bank', 'Account.1', 'Amount Received', 'Receiving Currency', 'Amount Paid', 'Payment Currency', 'Payment Format', 'Is Laundering']


Unnamed: 0,Timestamp,From Bank,Account,To Bank,Account.1,Amount Received,Receiving Currency,Amount Paid,Payment Currency,Payment Format,Is Laundering
0,2022/09/01 00:20,10,8000EBD30,10,8000EBD30,3697.34,US Dollar,3697.34,US Dollar,Reinvestment,0
1,2022/09/01 00:20,3208,8000F4580,1,8000F5340,0.01,US Dollar,0.01,US Dollar,Cheque,0
2,2022/09/01 00:00,3209,8000F4670,3209,8000F4670,14675.57,US Dollar,14675.57,US Dollar,Reinvestment,0
3,2022/09/01 00:02,12,8000F5030,12,8000F5030,2806.97,US Dollar,2806.97,US Dollar,Reinvestment,0
4,2022/09/01 00:06,10,8000F5200,10,8000F5200,36682.97,US Dollar,36682.97,US Dollar,Reinvestment,0


In [4]:
# Use the full transaction table for modeling and keep its imbalance statistics
import numpy as np

LABEL_COL = 'Is Laundering'
RANDOM_SEED = 17

working_trans = small_trans.copy().reset_index(drop=True)

pos_count = int(working_trans[LABEL_COL].sum())
neg_count = len(working_trans) - pos_count
fraud_ratio = pos_count / max(len(working_trans), 1)
print(f'Full dataset loaded: {len(working_trans)} edges')
print(f'Fraud edges: {pos_count}, Non-fraud edges: {neg_count}, base rate {fraud_ratio:.6f}')

Full dataset loaded: 5078345 edges
Fraud edges: 5177, Non-fraud edges: 5073168, base rate 0.001019


# Imbalance Visualization Overview (IBM AML Dataset)

This section adds visual summaries of the strong class imbalance (fraud vs non‑fraud) and related distributions. Run in order after the dataset has been loaded into `small_trans`.



In [5]:
# Text-based imbalance summary for IBM AML dataset
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import pointbiserialr

assert 'working_trans' in globals(), "Run the balanced sampling cell to create 'working_trans'."
df = working_trans.copy()
label_col = 'Is Laundering'
if label_col not in df.columns:
    raise KeyError(f"Expected column '{label_col}' in the dataset.")

print('=' * 72)
print('OVERALL LABEL DISTRIBUTION')
print('=' * 72)
label_counts = df[label_col].value_counts().sort_index()
total = label_counts.sum()
for label, count in label_counts.items():
    pct = 100.0 * count / total
    label_name = 'Fraud (1)' if label == 1 else 'Non-fraud (0)'
    print(f"{label_name:<15}: {count:>7} ({pct:6.3f}%)")
fraud_ratio = label_counts.get(1, 0) / max(total, 1)
print(f"Fraud ratio overall: {fraud_ratio:.5f}")

print('\n' + '=' * 72)
print('NUMERIC AMOUNT SUMMARY PER CLASS')
print('=' * 72)
amount_cols = [c for c in df.columns if any(k in c.lower() for k in ('amount', 'amt', 'value'))]
if amount_cols:
    for col in amount_cols:
        series = pd.to_numeric(df[col], errors='coerce')
        summary = df.groupby(label_col)[col].agg(['count', 'mean', 'median', 'std', 'min', 'max'])
        print(f"Column: {col}")
        print(summary.fillna(0).round(4).to_string())
        print('-' * 40)
else:
    print('No amount-like columns detected for summary.')

print('\n' + '=' * 72)
print('TEMPORAL FRAUD RATES (first 10 windows)')
print('=' * 72)
time_col = next((c for c in df.columns if any(k in c.lower() for k in ('time', 'date', 'timestamp'))), None)
if time_col:
    ts = pd.to_datetime(df[time_col], errors='coerce')
    temp_df = pd.DataFrame({'ts': ts, 'label': df[label_col]}).dropna(subset=['ts'])
    span_days = (temp_df['ts'].max() - temp_df['ts'].min()).days
    freq = 'D' if span_days >= 2 else 'H'
    counts = temp_df.set_index('ts').groupby('label').resample(freq).size().unstack(0).fillna(0)
    counts.columns = [f'label_{c}' for c in counts.columns]
    counts['total'] = counts.sum(axis=1)
    counts['fraud_rate'] = counts.get('label_1', 0) / counts['total'].replace(0, np.nan)
    print(f"Using frequency: {freq}")
    preview = counts[['label_0', 'label_1', 'total', 'fraud_rate']].head(10).fillna(0)
    print(preview.round({'fraud_rate': 4}).to_string())
else:
    print('No timestamp/date column detected for temporal summary.')

print('\n' + '=' * 72)
print('ACCOUNT PARTICIPATION SNAPSHOT (top 10)')
print('=' * 72)
sender_col = next((c for c in df.columns if any(k in c.lower() for k in ('sender', 'originator', 'from', 'account'))), None)
receiver_col = next((c for c in df.columns if any(k in c.lower() for k in ('receiver', 'beneficiary', 'to', 'account.1', 'account_1'))), None)
if sender_col and receiver_col:
    part_df = df[[sender_col, receiver_col, label_col]].copy()
    top_senders = part_df.groupby(sender_col).size().sort_values(ascending=False).head(10)
    top_receivers = part_df.groupby(receiver_col).size().sort_values(ascending=False).head(10)
    fraud_senders = part_df.groupby(sender_col)[label_col].sum().sort_values(ascending=False).head(10)
    fraud_receivers = part_df.groupby(receiver_col)[label_col].sum().sort_values(ascending=False).head(10)
    print(f"Top senders by volume ({sender_col}):\n{top_senders.to_string()}\n")
    print(f"Top receivers by volume ({receiver_col}):\n{top_receivers.to_string()}\n")
    print(f"Top senders by fraud count:\n{fraud_senders.to_string()}\n")
    print(f"Top receivers by fraud count:\n{fraud_receivers.to_string()}\n")
else:
    print('Could not identify sender/receiver columns for participation snapshot.')

print('\n' + '=' * 72)
print('NUMERIC FEATURE CORRELATIONS WITH FRAUD (top 15 abs(r))')
print('=' * 72)
num_cols = [c for c in df.select_dtypes(include=[np.number]).columns if c != label_col]
if num_cols:
    corrs = []
    y = df[label_col].values
    for c in num_cols:
        x = pd.to_numeric(df[c], errors='coerce').fillna(0).values
        try:
            r, p = pointbiserialr(y, x)
        except Exception:
            r, p = np.nan, np.nan
        corrs.append({'feature': c, 'r': r, 'p_value': p})
    corr_df = pd.DataFrame(corrs)
    corr_df['abs_r'] = corr_df['r'].abs()
    corr_df = corr_df.sort_values(by='abs_r', ascending=False).head(15)
    print(corr_df[['feature', 'r', 'p_value']].round(5).to_string(index=False))
else:
    print('No numeric columns (besides label) available for correlation analysis.')

OVERALL LABEL DISTRIBUTION
Non-fraud (0)  : 5073168 (99.898%)
Fraud (1)      :    5177 ( 0.102%)
Fraud ratio overall: 0.00102

NUMERIC AMOUNT SUMMARY PER CLASS
Column: Amount Received
                 count          mean   median           std     min           max
Is Laundering                                                                    
0              5073168  5.957962e+06  1407.51  1.036563e+09  0.0000  1.046302e+12
1                 5177  3.613531e+07  8667.21  1.527919e+09  0.0032  8.485314e+10
----------------------------------------
Column: Amount Paid
                 count          mean   median           std     min           max
Is Laundering                                                                    
0              5073168  4.477000e+06  1410.99  8.688463e+08  0.0000  1.046302e+12
1                 5177  3.613531e+07  8667.21  1.527919e+09  0.0032  8.485314e+10
----------------------------------------

TEMPORAL FRAUD RATES (first 10 windows)
Column: Amount Re

  counts = temp_df.set_index('ts').groupby('label').resample(freq).size().unstack(0).fillna(0)


Using frequency: D
            label_0  label_1    total  fraud_rate
ts                                               
2022-09-01  1114599      322  1114921      0.0003
2022-09-02   754041      408   754449      0.0005
2022-09-03   206991      391   207382      0.0019
2022-09-04   207023      407   207430      0.0020
2022-09-05   482179      471   482650      0.0010
2022-09-06   481558      531   482089      0.0011
2022-09-07   482254      497   482751      0.0010
2022-09-08   482234      539   482773      0.0011
2022-09-09   653953      514   654467      0.0008
2022-09-10   207883      442   208325      0.0021

ACCOUNT PARTICIPATION SNAPSHOT (top 10)
Top senders by volume (From Bank):
From Bank
70     449859
10      81629
12      79754
1       62211
15      52511
220     52417
20      41008
3       38413
7       31086
211     30451

Top receivers by volume (To Bank):
To Bank
10     42547
12     41872
15     38721
220    30625
1      30115
3      25627
7      23029
20     22048
28     

## Dataset Pre-processing

In [6]:
import numpy as np

# convert hex account numbers to int
hex_to_int = np.vectorize(lambda x: int(x, 16))

# create adjacency lists to represent the graph
source = hex_to_int(working_trans['Account'])
target = hex_to_int(working_trans['Account.1'])

In [7]:
from torch_geometric.data import Data
import torch

# Map account IDs to a compact 0..N-1 index space to avoid huge sparse IDs
# Concatenate unique accounts from source/target and factorize
all_accounts = np.concatenate([source, target])
unique_accounts, inverse_idx = np.unique(all_accounts, return_inverse=True)
num_nodes = unique_accounts.shape[0]
# Rebuild source/target as compact indices
source_idx = inverse_idx[:source.shape[0]]
target_idx = inverse_idx[source.shape[0]:]

# Build edge_index
edge_index = torch.tensor(np.vstack([source_idx, target_idx]), dtype=torch.long)

# Create Data object
data = Data(edge_index=edge_index, num_nodes=num_nodes)
print('num_nodes:', num_nodes, 'num_edges:', edge_index.size(1))
print(data)

num_nodes: 515080 num_edges: 5078345
Data(edge_index=[2, 5078345], num_nodes=515080)


In [8]:
import numpy as np
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data
import torch

# extract individual edge features
time = pd.to_datetime(working_trans['Timestamp']).astype('int64') / 1e9
amount_paid = working_trans['Amount Paid'].to_numpy()
amount_received = working_trans['Amount Received'].to_numpy()

# combine edge features into single tensor (standardised numeric block)
numeric_features = np.column_stack([time, amount_paid, amount_received])
scaler = StandardScaler()
numeric_scaled = scaler.fit_transform(numeric_features)
edge_features = torch.from_numpy(numeric_scaled).float()

# create edge labels
fraud_label = torch.tensor(working_trans['Is Laundering'].to_numpy(), dtype=torch.long)

# attach features and labels to PyG Data
data.edge_attr = edge_features
data.edge_label = fraud_label
print(data)


Data(edge_index=[2, 5078345], num_nodes=515080, edge_attr=[5078345, 3], edge_label=[5078345])


In [9]:
# Hyperparameters


# Increase edge_batch_size if you have ample memory and want fewer edge chunks per epoch.
edge_batch_size = 1024
# Toggle GPU usage; set to False to keep everything on CPU even if CUDA is visible.
use_gpu = torch.cuda.is_available()
# Ratio of sampled negatives to each positive edge during fallback training.
neg_pos_ratio = 6.0
# Scale factor applied to the empirical class imbalance when computing pos_weight.
pos_weight_scale = 0.05
# Optional manual override for pos_weight (set to a float to force a value).
pos_weight_override = 1.0
# Number of epochs to train for.
epochs = 20
# Hidden dimension for the GNN and edge classifier.
num_hid = 64
# Smaller learning rate to keep updates stable on imbalanced data.
learn_rate = 3e-4
# Weight decay for the optimizer.
decay = 1e-4
# Gradient clipping threshold (set <=0 to disable).
grad_clip = 1.0
# False positive rate target used when calibrating the decision threshold on validation data.
fpr_target = 0.02
# Minimum epochs before enabling regular recalibration so the loss can settle.
calibrate_warmup = 6
# How often (in epochs) to re-fit the validation ROC and refresh the threshold after warmup.
calibrate_every = 2
# Blend factor applied when updating the threshold (0=no change, 1=replace).
threshold_blend = 0.5
# Hard floor on the decision threshold to avoid runaway false positives.
threshold_floor = 0.55
# Hard ceiling on the decision threshold for numerical safety.
threshold_ceiling = 0.995
# Maximum allowed validation positive fraction before skipping a threshold update.
max_val_pos_frac = 0.25
# Maximum allowed validation FPR before skipping a threshold update.
max_val_fpr = 0.15
# Maximum allowed FPR on a raw-distribution sample when updating thresholds.
max_full_sample_fpr = 0.06
# Number of raw transactions sampled for full-distribution FPR checks.
full_val_sample_size = 50000

In [10]:
# Build a raw-distribution sample for calibration guardrails
import numpy as np
from torch_geometric.data import Data as PyGData

# Use the same RNG seed to keep sampling deterministic across runs
raw_base_df = working_trans
raw_sample_size = min(int(full_val_sample_size), len(raw_base_df))
if raw_sample_size <= 0:
    raise ValueError('Dataset is empty; cannot create raw-sample guard for calibration.')
raw_sample_idx = np.random.default_rng(RANDOM_SEED).choice(len(raw_base_df), size=raw_sample_size, replace=False)
raw_sample_df = raw_base_df.iloc[raw_sample_idx].reset_index(drop=True)

raw_source = hex_to_int(raw_sample_df['Account'])
raw_target = hex_to_int(raw_sample_df['Account.1'])

raw_all_accounts = np.concatenate([raw_source, raw_target])
raw_unique_accounts, raw_inverse_idx = np.unique(raw_all_accounts, return_inverse=True)
raw_num_nodes = raw_unique_accounts.shape[0]
raw_source_idx = raw_inverse_idx[:raw_source.shape[0]]
raw_target_idx = raw_inverse_idx[raw_source.shape[0]:]

raw_edge_index = torch.tensor(np.vstack([raw_source_idx, raw_target_idx]), dtype=torch.long)
raw_data = PyGData(edge_index=raw_edge_index, num_nodes=raw_num_nodes)

raw_deg = torch.zeros((raw_num_nodes, 1), dtype=torch.float)
raw_deg.scatter_add_(0, raw_edge_index[0].view(-1, 1), torch.ones((raw_edge_index.size(1), 1)))
raw_data.x = raw_deg

raw_time = pd.to_datetime(raw_sample_df['Timestamp']).astype('int64') / 1e9
raw_amount_paid = raw_sample_df['Amount Paid'].to_numpy()
raw_amount_received = raw_sample_df['Amount Received'].to_numpy()
raw_numeric = np.column_stack([raw_time, raw_amount_paid, raw_amount_received])
raw_numeric_scaled = scaler.transform(raw_numeric)
raw_edge_attr = torch.from_numpy(raw_numeric_scaled).float()
raw_edge_label = torch.tensor(raw_sample_df[LABEL_COL].to_numpy(), dtype=torch.long)

raw_data.edge_attr = raw_edge_attr
raw_data.edge_label = raw_edge_label

print(f'Raw-sample guard set: {raw_sample_size} edges, fraud ratio {raw_edge_label.float().mean().item():.4f}')

Raw-sample guard set: 50000 edges, fraud ratio 0.0012


In [11]:
# Stratified 60/20/20 split to keep the raw class imbalance in each subset
import numpy as np
from sklearn.model_selection import train_test_split
import torch

num_edges = data.edge_index.size(1)
all_indices = np.arange(num_edges)
labels_np = data.edge_label.cpu().numpy()

train_idx, temp_idx = train_test_split(
    all_indices,
    test_size=0.4,
    stratify=labels_np,
    random_state=RANDOM_SEED,
    shuffle=True,
    )

val_idx, test_idx = train_test_split(
    temp_idx,
    test_size=0.5,
    stratify=labels_np[temp_idx],
    random_state=RANDOM_SEED,
    shuffle=True,
    )

train_mask = torch.zeros(num_edges, dtype=torch.bool)
val_mask = torch.zeros(num_edges, dtype=torch.bool)
test_mask = torch.zeros(num_edges, dtype=torch.bool)

train_mask[torch.from_numpy(train_idx)] = True
val_mask[torch.from_numpy(val_idx)] = True
test_mask[torch.from_numpy(test_idx)] = True

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

print('Masks set:', train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())
print('Train positive ratio:', data.edge_label[train_mask].float().mean().item())
print('Val positive ratio:', data.edge_label[val_mask].float().mean().item())
print('Test positive ratio:', data.edge_label[test_mask].float().mean().item())

Masks set: 3047007 1015669 1015669
Train positive ratio: 0.001019360963255167
Val positive ratio: 0.0010190327884629369
Test positive ratio: 0.0010200173128396273


In [17]:
# PyG GNN model and edge classification training (batched)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
from torch_geometric.data import Data
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Ensure data object exists with edge_index, edge_attr, edge_label, and masks
assert data is not None, 'PyG Data not constructed yet'
num_nodes = data.num_nodes
num_edges = data.edge_index.size(1)
if getattr(data, 'edge_attr', None) is None:
    raise RuntimeError('Edge features missing. Run Cell 12 (edge feature construction) before this cell.')
edge_feat_dim = data.edge_attr.size(1)

# Create simple node features if none exist (e.g., degree or identity)
if getattr(data, 'x', None) is None:
    deg = torch.zeros((num_nodes, 1), dtype=torch.float)
    deg.scatter_add_(0, data.edge_index[0].view(-1,1), torch.ones((num_edges,1)))
    data.x = deg  # use degree as a simple node feature

class GNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, edge_dim):
        super().__init__()
        self.conv1 = TransformerConv(in_channels, hidden_channels, edge_dim=edge_dim)
        self.conv2 = TransformerConv(hidden_channels, hidden_channels, edge_dim=edge_dim)
    def forward(self, x, edge_index, edge_attr):
        x = self.conv1(x, edge_index, edge_attr=edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr=edge_attr)
        return x

class EdgeClassifier(nn.Module):
    def __init__(self, node_hidden, edge_feat_dim, hidden=num_hid):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_hidden*2 + edge_feat_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
    def forward(self, x, edge_index, edge_attr):
        u, v = edge_index
        h = torch.cat([x[u], x[v], edge_attr], dim=1)
        return self.mlp(h).squeeze(-1)
    def score_pairs(self, x, u, v, edge_attr):
        h = torch.cat([x[u], x[v], edge_attr], dim=1)
        return self.mlp(h).squeeze(-1)


In [None]:
import importlib, math, torch, sys, subprocess, os
from torch_geometric.loader import LinkNeighborLoader

# Detect if neighbor sampling backend (pyg_lib or torch_sparse) is available
backend_ok = bool(importlib.util.find_spec("pyg_lib") or importlib.util.find_spec("torch_sparse"))
fallback_splits = {}
if not backend_ok:
    print("Neighbor sampling backend missing: install 'pyg-lib' (preferred) or 'torch-sparse'.")
    torch_ver = torch.__version__.split('+')[0]
    cuda_ver = torch.version.cuda
    if cuda_ver is None:
        cuda_tag = 'cpu'
    else:
        cuda_tag = 'cu' + cuda_ver.replace('.', '')
    index_url = f'https://data.pyg.org/whl/torch-{torch_ver}+{cuda_tag}.html'
    print('Suggested install command (run in a terminal):')
    print(f"{sys.executable} -m pip install pyg-lib torch-sparse -f {index_url}")
    print('Fallback to full-graph edge training will be used until a backend is installed.')
    print('Note: pyg_lib currently does not support some newer PyTorch versions; fallback will be used if install fails.')

requested_gpu = use_gpu and torch.cuda.is_available()
device = torch.device('cuda') if requested_gpu else torch.device('cpu')
if device.type == 'cuda':
    dev_index = device.index if device.index is not None else torch.cuda.current_device()
    print(f'Using GPU device: {torch.cuda.get_device_name(dev_index)} (index {dev_index}).')
else:
    if use_gpu and not torch.cuda.is_available():
        print('CUDA requested but not available; falling back to CPU.')
    else:
        print('Using CPU for training (set use_gpu=True and ensure CUDA availability to use GPU).')

gnn = GNN(in_channels=data.x.size(1), hidden_channels=num_hid, edge_dim=edge_feat_dim).to(device)
clf = EdgeClassifier(node_hidden=num_hid, edge_feat_dim=edge_feat_dim, hidden=num_hid*2).to(device)

params = list(gnn.parameters()) + list(clf.parameters())
optimizer = torch.optim.Adam(params, lr=learn_rate, weight_decay=decay)

# Build edge_label_index, edge_label, and edge_label_attr tensors for each split
train_edge_label_index = data.edge_index[:, data.train_mask]
train_edge_label = data.edge_label[data.train_mask]
train_edge_attr = data.edge_attr[data.train_mask]

val_edge_label_index = data.edge_index[:, data.val_mask]
val_edge_label = data.edge_label[data.val_mask]
val_edge_attr = data.edge_attr[data.val_mask]

test_edge_label_index = data.edge_index[:, data.test_mask]
test_edge_label = data.edge_label[data.test_mask]
test_edge_attr = data.edge_attr[data.test_mask]

split_edge_label_attrs = {
    'train': train_edge_attr,
    'val': val_edge_attr,
    'test': test_edge_attr,
}

def _make_edge_label_attr_transform(edge_attr_tensor):
    def _transform(batch):
        input_id = getattr(batch, 'input_id', None)
        if input_id is None:
            batch.edge_label_attr = None
            return batch
        idx = input_id.to(torch.long)
        if idx.device.type != 'cpu':
            idx = idx.cpu()
        batch.edge_label_attr = edge_attr_tensor[idx]
        return batch
    return _transform

train_pos = int(train_edge_label.sum().item())
train_total = int(train_edge_label.numel())
train_neg = max(train_total - train_pos, 0)
base_pos_weight = (train_neg / max(train_pos, 1)) if train_pos > 0 else 1.0
if pos_weight_override is not None:
    pos_weight_value = float(pos_weight_override)
else:
    pos_weight_value = max(base_pos_weight * pos_weight_scale, 1.0)
pos_weight_tensor = torch.tensor(pos_weight_value, dtype=torch.float, device=device)
print(f'Train edges: {train_total} | positives: {train_pos} ({train_pos / max(train_total,1):.6f}) | pos_weight {pos_weight_value:.2f} | neg/pos ratio target {neg_pos_ratio:.1f}:1')

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

# Safer defaults for batch size & neighbors to reduce per-batch time/memory
batch_size = edge_batch_size  # configurable via hyperparameter cell
num_neighbors = [10, 5]  # fewer neighbors for smaller subgraphs

fallback_mode = not backend_ok
if backend_ok:
    print(f'Backend OK: {backend_ok} | batch_size: {batch_size} | num_neighbors: {num_neighbors}')
    train_loader = LinkNeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        edge_label_index=train_edge_label_index,
        edge_label=train_edge_label,
        shuffle=True,
        neg_sampling_ratio=0.0,
        transform_sampler_output=_make_edge_label_attr_transform(train_edge_attr)
    )
    val_loader = LinkNeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        edge_label_index=val_edge_label_index,
        edge_label=val_edge_label,
        shuffle=False,
        neg_sampling_ratio=0.0,
        transform_sampler_output=_make_edge_label_attr_transform(val_edge_attr)
    )
    test_loader = LinkNeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        edge_label_index=test_edge_label_index,
        edge_label=test_edge_label,
        shuffle=False,
        neg_sampling_ratio=0.0,
        transform_sampler_output=_make_edge_label_attr_transform(test_edge_attr)
    )
else:
    def _split_edges(mask):
        return {
            'edge_label_index': data.edge_index[:, mask],
            'edge_label': data.edge_label[mask],
            'edge_attr': data.edge_attr[mask]
        }
    fallback_splits = {
        'train': _split_edges(data.train_mask),
        'val': _split_edges(data.val_mask),
        'test': _split_edges(data.test_mask)
    }
    train_loader = fallback_splits['train']
    val_loader = fallback_splits['val']
    test_loader = fallback_splits['test']
    train_count = fallback_splits['train']['edge_label'].numel()
    val_count = fallback_splits['val']['edge_label'].numel()
    test_count = fallback_splits['test']['edge_label'].numel()
    chunk_size = max(int(edge_batch_size), 1)
    print(f'Fallback mode active on {device.type.upper()} device: full-graph embeddings with edge chunks of {chunk_size} (train edges {train_count}, val {val_count}, test {test_count}).')
    print('edge_batch_size controls chunking in this mode; install pyg-lib or torch-sparse to enable true neighbor sampling.')

Using GPU device: NVIDIA GeForce RTX 4070 Laptop GPU (index 0).
Train edges: 3047007 | positives: 3106 (0.001019) | pos_weight 1.00 | neg/pos ratio target 6.0:1
Backend OK: True | batch_size: 1024 | num_neighbors: [10, 5]


In [None]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve
from tqdm import tqdm
import time
import math
import gc
from contextlib import nullcontext


def _gather_label_edge_attr(batch, source_attr=None):
    if hasattr(batch, 'edge_label_attr') and batch.edge_label_attr is not None:
        target_device = batch.edge_label_index.device
        return batch.edge_label_attr.to(target_device)
    if source_attr is not None and hasattr(batch, 'input_id') and batch.input_id is not None:
        idx = batch.input_id.to(torch.long)
        if idx.device.type != 'cpu':
            idx_cpu = idx.cpu()
        else:
            idx_cpu = idx
        gathered = source_attr[idx_cpu]
        return gathered.to(batch.edge_label_index.device)
    # Fallback path: reconstruct positions when LinkNeighborLoader did not supply edge_label_attr
    e_u = batch.edge_index[0].tolist()
    e_v = batch.edge_index[1].tolist()
    pos_map = {(eu, ev): i for i, (eu, ev) in enumerate(zip(e_u, e_v))}
    lu = batch.edge_label_index[0].tolist()
    lv = batch.edge_label_index[1].tolist()
    idx = [pos_map[(u, v)] for u, v in zip(lu, lv)]
    return batch.edge_attr[idx].to(batch.edge_label_index.device)


def _iter_fallback_chunks(split, chunk_size, index_subset=None):
    edge_index = split['edge_label_index']
    edge_attr = split['edge_attr']
    edge_label = split['edge_label']
    if index_subset is None:
        indices = torch.arange(edge_label.size(0))
    else:
        indices = index_subset
    total = indices.numel()
    for start in range(0, total, chunk_size):
        sel = indices[start:min(start + chunk_size, total)]
        yield edge_index[:, sel], edge_attr[sel], edge_label[sel]


def _sample_balanced_indices(labels, ratio):
    pos_idx = torch.nonzero(labels == 1, as_tuple=False).view(-1)
    neg_idx = torch.nonzero(labels == 0, as_tuple=False).view(-1)
    if pos_idx.numel() == 0:
        return torch.zeros(0, dtype=torch.long)
    neg_needed = int(math.ceil(pos_idx.numel() * ratio))
    if neg_idx.numel() == 0:
        combined = pos_idx
    else:
        neg_needed = min(max(neg_needed, pos_idx.numel()), neg_idx.numel())
        perm = torch.randperm(neg_idx.numel())
        sampled_neg = neg_idx[perm[:neg_needed]]
        combined = torch.cat([pos_idx, sampled_neg])
    shuffle = torch.randperm(combined.numel())
    return combined[shuffle]


def _select_threshold(y_true, probs, target_fpr=0.02):
    if y_true.size == 0:
        return 0.5
    fpr, tpr, thresholds = roc_curve(y_true, probs)
    if np.isnan(thresholds).all():
        return 0.5
    # Remove infinities for stability
    finite_mask = np.isfinite(thresholds)
    fpr, tpr, thresholds = fpr[finite_mask], tpr[finite_mask], thresholds[finite_mask]
    if thresholds.size == 0:
        return 0.5
    if target_fpr is not None:
        ok = np.where(fpr <= target_fpr)[0]
        if ok.size > 0:
            idx = ok[np.argmax(tpr[ok])]
        else:
            idx = np.argmin(fpr)
    else:
        youden = tpr - fpr
        idx = np.argmax(youden)
    thr = thresholds[idx]
    if np.isnan(thr):
        thr = 0.5
    return float(np.clip(thr, 1e-6, 1 - 1e-6))


def _summarise_predictions(y_true, probs, threshold):
    preds = (probs >= threshold).astype(np.int64)
    if preds.size == 0:
        return {
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'fpr': 0.0,
            'acc': 0.0,
            'pos_frac': 0.0,
            'threshold': threshold,
            'preds': preds,
            'probs': probs,
            'labels': y_true
        }
    acc = (preds == y_true).mean()
    pr, rc, f1, _ = precision_recall_fscore_support(y_true, preds, average='binary', zero_division=0)
    cm = confusion_matrix(y_true, preds, labels=[0, 1])
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        fpr = fp / max(tn + fp, 1)
    else:
        fpr = 0.0
    pos_frac = preds.mean()
    return {
        'precision': float(pr),
        'recall': float(rc),
        'f1': float(f1),
        'fpr': float(fpr),
        'acc': float(acc),
        'pos_frac': float(pos_frac),
        'threshold': float(threshold),
        'preds': preds,
        'probs': probs,
        'labels': y_true
    }


def _refine_threshold_with_fpr_limit(probs, labels, max_fpr, fallback_threshold):
    if probs.size == 0:
        return float('nan')
    fpr, tpr, thresholds = roc_curve(labels, probs)
    finite_mask = np.isfinite(thresholds)
    fpr, thresholds = fpr[finite_mask], thresholds[finite_mask]
    if thresholds.size == 0:
        return float('nan')
    valid_idx = np.where(fpr <= max_fpr)[0]
    if valid_idx.size == 0:
        tightened = max(fallback_threshold, np.max(thresholds))
        tightened = float(np.clip(tightened, 1e-6, 0.999999))
        return tightened
    candidate = thresholds[valid_idx].max()
    if not np.isfinite(candidate):
        return float('nan')
    return float(np.clip(candidate, 1e-6, 1 - 1e-6))


def _evaluate_raw_sample(threshold):
    if 'raw_data' not in globals():
        return None
    gnn.eval(); clf.eval()
    with torch.no_grad():
        raw_batch = raw_data.clone().to(device)
        raw_logits = clf.score_pairs(
            gnn(raw_batch.x, raw_batch.edge_index, raw_batch.edge_attr),
            raw_batch.edge_index[0],
            raw_batch.edge_index[1],
            raw_batch.edge_attr
        )
        probs = torch.sigmoid(raw_logits).cpu().numpy().astype(np.float32)
        labels = raw_batch.edge_label.cpu().numpy().astype(np.int64)
    return _summarise_predictions(labels, probs, threshold)


def train_one_epoch(use_amp=True, log_every=200, neg_ratio=3.0):
    gnn.train(); clf.train()
    amp_enabled = use_amp and (device.type == 'cuda') and not fallback_mode
    scaler = torch.amp.GradScaler('cuda') if amp_enabled else None
    clip_enabled = (grad_clip is not None) and (grad_clip > 0)
    if fallback_mode:
        optimizer.zero_grad()
        x = gnn(data.x.to(device), data.edge_index.to(device), data.edge_attr.to(device))
        split = fallback_splits['train']
        labels_cpu = split['edge_label'].cpu()
        selected_indices = _sample_balanced_indices(labels_cpu, neg_ratio)
        if selected_indices.numel() == 0:
            print('No positive edges found in training split; cannot update model.')
            optimizer.zero_grad(set_to_none=True)
            return float('nan')
        total_edges = selected_indices.numel()
        chunk_size = max(int(edge_batch_size), 1)
        loss_terms = []
        for edge_idx_chunk, edge_attr_chunk, label_chunk in _iter_fallback_chunks(split, chunk_size, selected_indices):
            u = edge_idx_chunk[0].to(device)
            v = edge_idx_chunk[1].to(device)
            edge_attr = torch.nan_to_num(edge_attr_chunk, nan=0.0, posinf=0.0, neginf=0.0).to(device)
            edge_label = label_chunk.to(device).float()
            logits = clf.score_pairs(x, u, v, edge_attr)
            loss = criterion(logits, edge_label)
            if not torch.isfinite(loss):
                print('Non-finite loss encountered in fallback chunk; try lowering pos_weight or learning rate.')
                return float('nan')
            loss_terms.append(loss * edge_label.numel())
        if not loss_terms:
            print('Balanced sampling produced no batches; skipping epoch.')
            optimizer.zero_grad(set_to_none=True)
            return 0.0
        loss_total = torch.stack(loss_terms).sum() / max(total_edges, 1)
        loss_total.backward()
        if clip_enabled:
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
        optimizer.step()
        return float(loss_total.detach().cpu())
    # Neighbor-sampling path (backend available)
    total_loss = 0.0
    total_count = 0
    t0 = time.time()
    for i, batch in enumerate(tqdm(train_loader, desc='train_batches'), 1):
        optimizer.zero_grad()
        batch = batch.to(device)
        context = torch.amp.autocast('cuda') if scaler is not None else nullcontext()
        with context:
            x = gnn(batch.x, batch.edge_index, batch.edge_attr)
            label_edge_attr = _gather_label_edge_attr(batch, split_edge_label_attrs['train'])
            logits = clf.score_pairs(x, batch.edge_label_index[0], batch.edge_label_index[1], label_edge_attr)
            labels_float = batch.edge_label.float()
            loss = criterion(logits, labels_float)
        if not torch.isfinite(loss):
            # print(f'  batch {i} produced non-finite loss; skipping update.')
            optimizer.zero_grad(set_to_none=True)
            continue
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if clip_enabled:
                torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
            optimizer.step()
        batch_size_local = batch.edge_label.numel()
        total_loss += loss.item() * batch_size_local
        total_count += batch_size_local
        if i % max(log_every, 1) == 0:
            print(f'  batch {i} | batch_loss {loss.item():.4f} | elapsed {time.time()-t0:.1f}s')
    return total_loss / max(total_count, 1)


def evaluate_split(split_name, threshold=None, calibrate=False):
    gnn.eval(); clf.eval()
    chunk_size = max(int(edge_batch_size), 1)
    labels_list = []
    probs_list = []
    if fallback_mode:
        split = fallback_splits[split_name]
        with torch.no_grad():
            x = gnn(data.x.to(device), data.edge_index.to(device), data.edge_attr.to(device))
            for edge_idx_chunk, edge_attr_chunk, label_chunk in _iter_fallback_chunks(split, chunk_size):
                u = edge_idx_chunk[0].to(device)
                v = edge_idx_chunk[1].to(device)
                edge_attr = torch.nan_to_num(edge_attr_chunk, nan=0.0, posinf=0.0, neginf=0.0).to(device)
                logits = clf.score_pairs(x, u, v, edge_attr)
                probs = torch.sigmoid(logits).detach().cpu()
                probs_list.append(probs)
                labels_list.append(label_chunk.detach().cpu())
    else:
        loader_map = {
            'train': (train_loader, split_edge_label_attrs['train']),
            'val': (val_loader, split_edge_label_attrs['val']),
            'test': (test_loader, split_edge_label_attrs['test']),
        }
        loader, source_attr = loader_map[split_name]
        with torch.no_grad():
            for batch in tqdm(loader, desc=f'{split_name}_batches'):
                batch = batch.to(device)
                x = gnn(batch.x, batch.edge_index, batch.edge_attr)
                label_edge_attr = _gather_label_edge_attr(batch, source_attr)
                logits = clf.score_pairs(x, batch.edge_label_index[0], batch.edge_label_index[1], label_edge_attr)
                probs = torch.sigmoid(logits).detach().cpu()
                probs_list.append(probs)
                labels_list.append(batch.edge_label.detach().cpu())
    if not labels_list:
        empty = np.array([])
        return {
            'acc': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0,
            'fpr': 0.0,
            'pos_frac': 0.0,
            'threshold': threshold if threshold is not None else 0.5,
            'preds': empty,
            'probs': empty,
            'labels': empty
        }
    labels = torch.cat(labels_list).numpy().astype(np.int64)
    probs = torch.cat(probs_list).numpy().astype(np.float32)
    if calibrate or threshold is None:
        chosen_threshold = _select_threshold(labels, probs, target_fpr=fpr_target)
    else:
        chosen_threshold = threshold
    summary = _summarise_predictions(labels, probs, chosen_threshold)
    return summary


log_every = 200 if not fallback_mode else 0
calibrated_threshold = 0.5
for epoch in range(1, epochs+1):
    epoch_t0 = time.time()
    avg_loss = train_one_epoch(use_amp=(device.type == 'cuda'), log_every=log_every, neg_ratio=neg_pos_ratio)
    epoch_time = time.time() - epoch_t0
    if not math.isfinite(avg_loss):
        print(f'Epoch {epoch:02d} skipped due to non-finite loss.')
        continue
    warmup_ready = epoch >= calibrate_warmup
    should_calibrate = (epoch == 1) or (warmup_ready and ((epoch - calibrate_warmup) % max(calibrate_every, 1) == 0))
    train_metrics = evaluate_split('train', threshold=calibrated_threshold)
    val_metrics = evaluate_split('val', threshold=calibrated_threshold, calibrate=should_calibrate)
    if should_calibrate:
        new_threshold = val_metrics['threshold']
        allow_update = np.isfinite(new_threshold)
        if allow_update and val_metrics['pos_frac'] > max_val_pos_frac:
            print(f"  skip threshold update: val_pos_frac {val_metrics['pos_frac']:.3f} exceeds {max_val_pos_frac:.3f}")
            allow_update = False
        if allow_update and val_metrics['fpr'] > max_val_fpr:
            print(f"  skip threshold update: val_fpr {val_metrics['fpr']:.3f} exceeds {max_val_fpr:.3f}")
            allow_update = False
        raw_guard_metrics = None
        if allow_update and 'raw_data' in globals():
            baseline_guard = _evaluate_raw_sample(new_threshold)
            if baseline_guard is not None and baseline_guard['probs'].size > 0:
                print(
                    f"  raw-sample baseline -> precision {baseline_guard['precision']:.3f}, ",
                    f"recall {baseline_guard['recall']:.3f}, fpr {baseline_guard['fpr']:.4f} at threshold {new_threshold:.4f}"
                )
                refined_threshold = _refine_threshold_with_fpr_limit(
                    baseline_guard['probs'],
                    baseline_guard['labels'],
                    max_full_sample_fpr,
                    new_threshold
                )
                if not np.isfinite(refined_threshold):
                    print(
                        f"  skip threshold update: raw_sample_fpr {baseline_guard['fpr']:.3f} exceeds {max_full_sample_fpr:.3f} and no tighter threshold meets the limit"
                    )
                    allow_update = False
                else:
                    refined_threshold = float(np.clip(refined_threshold, threshold_floor, threshold_ceiling))
                    if abs(refined_threshold - new_threshold) < 1e-6:
                        raw_guard_metrics = baseline_guard
                    else:
                        raw_guard_metrics = _summarise_predictions(
                            baseline_guard['labels'],
                            baseline_guard['probs'],
                            refined_threshold
                        )
                    if raw_guard_metrics['fpr'] > max_full_sample_fpr:
                        print(
                            f"  raw-sample FPR {raw_guard_metrics['fpr']:.3f} still above {max_full_sample_fpr:.3f}; applying best available threshold {refined_threshold:.4f} and flagging for retraining."
                        )
                    elif refined_threshold > new_threshold + 1e-6:
                        print(
                            f"  raw-sample threshold raised from {new_threshold:.4f} to {refined_threshold:.4f} to respect FPR <= {max_full_sample_fpr:.3f}"
                        )
                    new_threshold = refined_threshold
            elif baseline_guard is None or baseline_guard['probs'].size == 0:
                print('  raw-sample guard skipped: sample empty or unavailable')
        if allow_update:
            if epoch == 1:
                calibrated_threshold = float(np.clip(new_threshold, threshold_floor, threshold_ceiling))
            else:
                blended = ((1.0 - threshold_blend) * calibrated_threshold) + (threshold_blend * new_threshold)
                calibrated_threshold = float(np.clip(blended, threshold_floor, threshold_ceiling))
            if raw_guard_metrics is not None:
                print(
                    f"  raw-sample metrics -> precision {raw_guard_metrics['precision']:.3f}, ",
                    f"recall {raw_guard_metrics['recall']:.3f}, fpr {raw_guard_metrics['fpr']:.4f}"
                )
    print(
        f"Epoch {epoch:02d} | loss {avg_loss:.4f} | time {epoch_time:.1f}s | ",
        f"train_acc {train_metrics['acc']:.3f} | val_acc {val_metrics['acc']:.3f} | ",
        f"val_precision {val_metrics['precision']:.3f} | val_recall {val_metrics['recall']:.3f} | ",
        f"val_f1 {val_metrics['f1']:.3f} | val_fpr {val_metrics['fpr']:.4f} | val_thresh {calibrated_threshold:.4f} | ",
        f"val_pos_frac {val_metrics['pos_frac']:.5f}"
    )
    if device.type == 'cuda':
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    gc.collect()

train_batches:   0%|          | 0/2976 [00:00<?, ?it/s]

train_batches:   0%|          | 0/2976 [00:00<?, ?it/s]



KeyError: (500, 499)

In [None]:
!uv pip install matplotlib matplotlib-inline

[2K[2mResolved [1m13 packages[0m [2min 157ms[0m[0m                                        [0m
         If the cache and target directories are on different filesystems, hardlinking may not be supported.
[2K[2mInstalled [1m2 packages[0m [2min 136ms[0m[0m                               [0m
 [32m+[39m [1mmatplotlib-inline[0m[2m==0.2.1[0m
 [32m+[39m [1mtraitlets[0m[2m==5.14.3[0m


In [None]:
import gc

test_metrics = evaluate_split('test', threshold=calibrated_threshold)
labels = test_metrics['labels']
preds = test_metrics['preds']
print('Test accuracy:', round(test_metrics['acc'], 6))
print('Test precision:', round(test_metrics['precision'], 6))
print('Test recall:', round(test_metrics['recall'], 6))
print('Test F1:', round(test_metrics['f1'], 6))
print('False positive rate:', round(test_metrics['fpr'], 6))
print('Predicted positive fraction:', round(test_metrics['pos_frac'], 6))
print('Decision threshold:', round(test_metrics['threshold'], 6))

if labels.size > 0 and preds.size > 0:
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    disp = ConfusionMatrixDisplay(cm, display_labels=['Non-fraud', 'Fraud'])
    disp.plot(values_format='d')
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        print(f'TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}')
    else:
        print(f'Confusion matrix shape unexpected: {cm.shape}')
else:
    print('Not enough classes in test to compute CM/recall/FPR.')

if device.type == 'cuda':
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
gc.collect()

Test accuracy: 0.840266
Test precision: 0.001324
Test recall: 0.206564
Test F1: 0.002631
False positive rate: 0.159087
Predicted positive fraction: 0.159136
Decision threshold: 0.873357


ValueError: Key backend: 'module://matplotlib_inline.backend_inline' is not a valid value for backend; supported values are ['gtk3agg', 'gtk3cairo', 'gtk4agg', 'gtk4cairo', 'macosx', 'nbagg', 'notebook', 'qtagg', 'qtcairo', 'qt5agg', 'qt5cairo', 'tkagg', 'tkcairo', 'webagg', 'wx', 'wxagg', 'wxcairo', 'agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template']

In [None]:
# Evaluate on the full dataset using the trained model (reuse the in-memory graph)
import numpy as np
import torch

assert 'data' in globals(), "Construct the PyG graph before running this cell."
assert 'gnn' in globals() and 'clf' in globals(), "Train the model before running a full-dataset evaluation."

full_batch = data.clone().to(device)

with torch.no_grad():
    x_full = gnn(full_batch.x, full_batch.edge_index, full_batch.edge_attr)
    logits_full = clf.score_pairs(
        x_full,
        full_batch.edge_index[0],
        full_batch.edge_index[1],
        full_batch.edge_attr
    )
    probs_full = torch.sigmoid(logits_full).cpu().numpy().astype(np.float32)

labels_full = full_batch.edge_label.cpu().numpy().astype(np.int64)
full_metrics = _summarise_predictions(labels_full, probs_full, calibrated_threshold)

print('\n=== Full-dataset evaluation ===')
print('Accuracy:', round(full_metrics['acc'], 6))
print('Precision:', round(full_metrics['precision'], 6))
print('Recall:', round(full_metrics['recall'], 6))
print('F1:', round(full_metrics['f1'], 6))
print('False positive rate:', round(full_metrics['fpr'], 6))
print('Predicted positive fraction:', round(full_metrics['pos_frac'], 6))
print('Decision threshold:', round(full_metrics['threshold'], 6))

if labels_full.size > 0 and full_metrics['preds'].size > 0:
    cm_full = confusion_matrix(labels_full, full_metrics['preds'], labels=[0, 1])
    disp_full = ConfusionMatrixDisplay(cm_full, display_labels=['Non-fraud', 'Fraud'])
    disp_full.plot(values_format='d')
    if cm_full.shape == (2, 2):
        tn, fp, fn, tp = cm_full.ravel()
        print(f'TN: {tn}, FP: {fp}, FN: {fn}, TP: {tp}')
    else:
        print(f'Confusion matrix shape unexpected: {cm_full.shape}')
else:
    print('Not enough classes to compute confusion matrix metrics on full dataset.')

if device.type == 'cuda':
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
gc.collect()


=== Full-dataset evaluation ===
Accuracy: 0.840537
Precision: 0.0013
Recall: 0.202627
F1: 0.002584
False positive rate: 0.158812
Predicted positive fraction: 0.158856
Decision threshold: 0.873357


ValueError: Key backend: 'module://matplotlib_inline.backend_inline' is not a valid value for backend; supported values are ['gtk3agg', 'gtk3cairo', 'gtk4agg', 'gtk4cairo', 'macosx', 'nbagg', 'notebook', 'qtagg', 'qtcairo', 'qt5agg', 'qt5cairo', 'tkagg', 'tkcairo', 'webagg', 'wx', 'wxagg', 'wxcairo', 'agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template']