In [1]:
!pip install -q torch_geometric
!pip install -q torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.3.0+cu121.html

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m96.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m67.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m949.6/949.6 kB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
"""
ConFraudCLR (Contrastive Fraud Contrastive Learning Representation)
Two-stage pipeline:
  1) Self-supervised contrastive pretraining (NT-Xent)
  2) Supervised fine-tuning for fraud classification

Requirements: torch, numpy, pandas, scikit-learn, imblearn, torch_geometric
Adjust FILE_PATH and hyperparams as needed.
"""

import os
import hashlib
import random
import time
import pickle
import copy

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from imblearn.over_sampling import BorderlineSMOTE

# PyG imports
try:
    from torch_geometric.data import HeteroData
    from torch_geometric.nn import HeteroConv, SAGEConv
except Exception as e:
    raise ImportError(
        "torch_geometric not available. Install PyG for your torch/cuda version. "
        "See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html") from e

from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, precision_recall_curve, auc as sklearn_auc,
                             confusion_matrix, classification_report)

  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing
  import torch_geometric.typing


In [15]:
# ---------------------------
# Hyperparams and device
# ---------------------------
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Contrastive pretraining hyperparams
PRETRAIN_EPOCHS = 5
PRETRAIN_BATCH_SIZE = 2048
CONTRASTIVE_LR = 3e-4
TEMPERATURE = 0.2
EDGE_DROP_P = 0.2          # edge dropout probability for augmentation
FEAT_MASK_P = 0.15         # fraction of tx features to mask
GAUSSIAN_NOISE_STD = 0.01  # noise added to node features for augmentation

# Fine-tuning hyperparams
FINETUNE_EPOCHS = 5
FINETUNE_BATCH_SIZE = 2048
FINETUNE_LR = 1e-3
FINETUNE_UNFREEZE_ENCODER = True  # whether to fine-tune encoder weights or only classifier

# Shared model hyperparams
EMBED_DIM = 64
NUM_GNN_LAYERS = 1
PROJ_DIM = 64
LR = 1e-3

# Training stability
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [16]:
# ---------------------------
# Load & preprocess dataset
# ---------------------------
FILE_PATH = '/content/drive/MyDrive/credit card /creditcard.csv'  # change if needed
df = pd.read_csv(FILE_PATH)

# numeric columns only
df_numeric = df.select_dtypes(include=['number']).copy()
if 'Time' not in df_numeric.columns:
    df_numeric['Time'] = np.arange(len(df_numeric))

df_numeric = df_numeric.sort_values('Time').reset_index(drop=True)

y = df_numeric['Class'].astype(int).values
X = df_numeric.drop('Class', axis=1).values.astype(np.float32)
feature_names = df_numeric.drop('Class', axis=1).columns.tolist()

scaler = MinMaxScaler()
X_scaled = scaler.fit_transform(X)

# splits (stratified)
X_train, X_temp, y_train, y_temp, idx_train, idx_temp = train_test_split(
    X_scaled, y, np.arange(len(y)), test_size=0.2, random_state=SEED, shuffle=True, stratify=y
)
X_valid, X_test, y_valid, y_test, idx_valid, idx_test = train_test_split(
    X_temp, y_temp, idx_temp, test_size=0.5, random_state=SEED, shuffle=True, stratify=y_temp
)

# optional oversampling on train
sm = BorderlineSMOTE(random_state=SEED)
try:
    X_train_over, y_train_over = sm.fit_resample(X_train, y_train)
    # note: for contrastive pretraining we still use all transactions; for supervised finetune we use resampled train if desired
except Exception:
    X_train_over, y_train_over = X_train, y_train

X_all = X_scaled
y_all = y


In [17]:
# ---------------------------
# Synthetic IDs (same approach)
# ---------------------------
n_cards = 2000
n_merchants = 500
n_devices = 1500

def deterministic_hash_mod(val, mod):
    s = str(val).encode('utf-8')
    h = int(hashlib.md5(s).hexdigest()[:8], 16)
    return int(h % mod)

card_ids = np.array([deterministic_hash_mod(t, n_cards) for t in df_numeric['Time'].values])
amounts = df_numeric['Amount'].values
amount_bins = pd.qcut(amounts, q=min(50, len(amounts)//100+1), duplicates='drop').astype(str)
merchant_ids = np.array([deterministic_hash_mod(b, n_merchants) for b in amount_bins])
if all(c in df_numeric.columns for c in ['V1','V2','V3','V4']):
    device_basis = (df_numeric['V1'].astype(str) + '_' + df_numeric['V2'].astype(str) + '_' +
                    df_numeric['V3'].astype(str) + '_' + df_numeric['V4'].astype(str)).values
else:
    device_basis = df_numeric['Time'].astype(str).values
device_ids = np.array([deterministic_hash_mod(b, n_devices) for b in device_basis])

assert len(card_ids) == len(merchant_ids) == len(device_ids) == len(X_all)

In [18]:
# ---------------------------
# Build hetero graph (base)
# ---------------------------
def build_base_hetero(X_all, card_ids, merchant_ids, device_ids, n_cards, n_merchants, n_devices, y_all):
    data = HeteroData()
    data['card'].num_nodes = n_cards
    data['merchant'].num_nodes = n_merchants
    data['device'].num_nodes = n_devices

    edge_index_card_merchant = np.vstack([card_ids, merchant_ids]).astype(np.int64)
    edge_index_card_device = np.vstack([card_ids, device_ids]).astype(np.int64)

    data['card', 'to_merchant', 'merchant'].edge_index = torch.tensor(edge_index_card_merchant, dtype=torch.long)
    data['card', 'to_device', 'device'].edge_index = torch.tensor(edge_index_card_device, dtype=torch.long)

    edge_attr_tx = torch.tensor(X_all, dtype=torch.float32)
    edge_label = torch.tensor(y_all, dtype=torch.long)
    # attach per-edge attributes (same attr attached to both relations for simplicity)
    data['card', 'to_merchant', 'merchant'].edge_attr = edge_attr_tx.clone()
    data['card', 'to_merchant', 'merchant'].edge_label = edge_label.clone()
    data['card', 'to_device', 'device'].edge_attr = edge_attr_tx.clone()
    data['card', 'to_device', 'device'].edge_label = edge_label.clone()

    # init node features by aggregating transaction means
    data['card'].x = torch.zeros((n_cards, X_all.shape[1]), dtype=torch.float32)
    data['merchant'].x = torch.zeros((n_merchants, X_all.shape[1]), dtype=torch.float32)
    data['device'].x = torch.zeros((n_devices, X_all.shape[1]), dtype=torch.float32)

    card_sums = np.zeros((n_cards, X_all.shape[1]), dtype=np.float32)
    card_counts = np.zeros((n_cards,), dtype=np.int32)
    for i, c in enumerate(card_ids):
        card_sums[c] += X_all[i]; card_counts[c] += 1
    nonzero = card_counts > 0
    data['card'].x[nonzero] = torch.tensor(card_sums[nonzero] / card_counts[nonzero, None], dtype=torch.float32)

    merch_sums = np.zeros((n_merchants, X_all.shape[1]), dtype=np.float32)
    merch_counts = np.zeros((n_merchants,), dtype=np.int32)
    for i, m in enumerate(merchant_ids):
        merch_sums[m] += X_all[i]; merch_counts[m] += 1
    nonzero = merch_counts > 0
    data['merchant'].x[nonzero] = torch.tensor(merch_sums[nonzero] / merch_counts[nonzero, None], dtype=torch.float32)

    dev_sums = np.zeros((n_devices, X_all.shape[1]), dtype=np.float32)
    dev_counts = np.zeros((n_devices,), dtype=np.int32)
    for i, d in enumerate(device_ids):
        dev_sums[d] += X_all[i]; dev_counts[d] += 1
    nonzero = dev_counts > 0
    data['device'].x[nonzero] = torch.tensor(dev_sums[nonzero] / dev_counts[nonzero, None], dtype=torch.float32)

    return data

data_base = build_base_hetero(X_all, card_ids, merchant_ids, device_ids, n_cards, n_merchants, n_devices, y_all)

In [19]:
# ---------------------------
# Model definitions (encoder + projection + classifier)
# ---------------------------
class HeteroEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels=EMBED_DIM, num_layers=NUM_GNN_LAYERS, out_channels=EMBED_DIM):
        super().__init__()
        self.forward_rels = [
            ('card', 'to_merchant', 'merchant'),
            ('card', 'to_device', 'device'),
        ]
        convs = []
        for layer in range(num_layers):
            in_c = in_channels if layer == 0 else hidden_channels
            conv_dict = {}
            for (src, rel, dst) in self.forward_rels:
                conv_dict[(src, rel, dst)] = SAGEConv(in_c, hidden_channels)
                rev_rel = f"rev_{rel}"
                conv_dict[(dst, rev_rel, src)] = SAGEConv(in_c, hidden_channels)
            convs.append(HeteroConv(conv_dict, aggr='mean'))
        self.convs = nn.ModuleList(convs)
        self.node_mlps = nn.ModuleDict({
            'card': nn.Linear(hidden_channels, out_channels),
            'merchant': nn.Linear(hidden_channels, out_channels),
            'device': nn.Linear(hidden_channels, out_channels),
        })

    def forward(self, x_dict, edge_index_dict, edge_attr_dict=None):
        h = {k: v for k, v in x_dict.items()}
        local_edge_index = {}
        for key, eidx in edge_index_dict.items():
            if not isinstance(eidx, torch.Tensor):
                eidx = torch.tensor(eidx, dtype=torch.long)
            if eidx.dim() != 2 or eidx.size(0) != 2:
                raise ValueError(f"edge_index for key {key} must be shape (2,E), got {tuple(eidx.shape)}")
            local_edge_index[key] = eidx
        for (src, rel, dst) in self.forward_rels:
            fkey = (src, rel, dst)
            if fkey not in local_edge_index:
                raise KeyError(f"Expected forward edge_index for {fkey} but it's missing.")
            eidx = local_edge_index[fkey]
            rev_key = (dst, f"rev_{rel}", src)
            if rev_key not in local_edge_index:
                local_edge_index[rev_key] = torch.stack([eidx[1], eidx[0]], dim=0)
        for conv in self.convs:
            h = conv(h, local_edge_index)
            for ntype, feat in h.items():
                h[ntype] = F.gelu(feat)
        out = {}
        for ntype in ['card', 'merchant', 'device']:
            if ntype not in h:
                out[ntype] = self.node_mlps[ntype](x_dict[ntype].to(next(iter(h.values())).device))
            else:
                out[ntype] = self.node_mlps[ntype](h[ntype])
        return out

class GraphFENEncoderWrapper(nn.Module):
    """
    Wraps HeteroEncoder and produces concatenated transaction embedding
    (card_emb || merchant_emb || device_emb || optionally tx_feat)
    """
    def __init__(self, node_feature_dim, embed_dim=EMBED_DIM, proj_dim=PROJ_DIM):
        super().__init__()
        self.encoder = HeteroEncoder(in_channels=node_feature_dim, hidden_channels=embed_dim,
                                     num_layers=NUM_GNN_LAYERS, out_channels=embed_dim)
        # projection head for contrastive learning
        self.proj = nn.Sequential(
            nn.Linear(embed_dim * 3 + node_feature_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, proj_dim)
        )

    def forward_node_emb(self, data_device):
        x_dict = {ntype: data_device[ntype].x.to(DEVICE) for ntype in ['card','merchant','device']}
        edge_index_dict = {
            ('card','to_merchant','merchant'): data_device['card','to_merchant','merchant'].edge_index.to(DEVICE),
            ('card','to_device','device'): data_device['card','to_device','device'].edge_index.to(DEVICE),
        }
        node_emb = self.encoder(x_dict, edge_index_dict)
        return node_emb

    def txn_embed_and_proj(self, node_emb, card_idx_batch, merchant_idx_batch, device_idx_batch, tx_feat_batch):
        # node_emb: dict of node embeddings
        card_e = node_emb['card'][card_idx_batch]
        merch_e = node_emb['merchant'][merchant_idx_batch]
        dev_e = node_emb['device'][device_idx_batch]
        cat = torch.cat([card_e, merch_e, dev_e], dim=1)
        # concatenate tx features so projection sees transaction-level info
        full = torch.cat([cat, tx_feat_batch], dim=1)
        z = self.proj(full)  # projection
        z = F.normalize(z, dim=1)
        return z, full  # return both projection and full embedding (for classification later)

class ClassifierHead(nn.Module):
    def __init__(self, embed_plus_tx_dim, embed_dim=EMBED_DIM, num_classes=2):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(embed_plus_tx_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(embed_dim, num_classes)
        )
    def forward(self, emb):
        return self.classifier(emb)

# instantiate
node_feat_dim = X_all.shape[1]
encoder_wrapper = GraphFENEncoderWrapper(node_feature_dim=node_feat_dim, embed_dim=EMBED_DIM, proj_dim=PROJ_DIM).to(DEVICE)
classifier_head = ClassifierHead(embed_plus_tx_dim=(EMBED_DIM*3 + node_feat_dim), embed_dim=EMBED_DIM, num_classes=2).to(DEVICE)


In [20]:
# ---------------------------
# Utility: batching & augmentations
# ---------------------------
all_tx_count = X_all.shape[0]
tx_indices = np.arange(all_tx_count)

def make_edge_batch(idx_batch, tx_feat_source=None, labels_source=None):
    """
    Create tensors for the given batch indices.

    Parameters
    ----------
    idx_batch : array-like of global transaction indices (e.g. [100, 234, 512, ...])
    tx_feat_source : None or numpy/torch array
        - If None: use global X_all and index by idx_batch.
        - If provided and its first dim == len(idx_batch): treat it as per-batch features and use directly.
        - If provided and its first dim == len(X_all): index by idx_batch.
    labels_source : None or numpy/torch array
        - Same semantics as tx_feat_source; defaults to global y_all.

    Returns
    -------
    card_idx, merchant_idx, device_idx, tx_feat_tensor, labels_tensor
    """
    batch_len = len(idx_batch)

    # node indices (always from global mapping)
    card_idx = torch.tensor(card_ids[idx_batch], dtype=torch.long, device=DEVICE)
    merchant_idx = torch.tensor(merchant_ids[idx_batch], dtype=torch.long, device=DEVICE)
    device_idx = torch.tensor(device_ids[idx_batch], dtype=torch.long, device=DEVICE)

    # tx features
    if tx_feat_source is None:
        tx_feats_np = X_all  # global
    else:
        tx_feat_np = np.asarray(tx_feat_source)
        if tx_feat_np.shape[0] == batch_len:
            tx_feats_np = tx_feat_np  # per-batch already aligned
        elif tx_feat_np.shape[0] == X_all.shape[0]:
            tx_feats_np = tx_feat_np  # global, we'll index below
        else:
            raise ValueError(f"tx_feat_source has incompatible size {tx_feat_np.shape[0]} for batch_len {batch_len}")

    # Choose between indexing (global) or taking as-is (per-batch)
    if isinstance(tx_feat_source, (np.ndarray, list)) and np.asarray(tx_feat_source).shape[0] == batch_len:
        tx_feat = torch.tensor(np.asarray(tx_feat_source), dtype=torch.float32, device=DEVICE)
    else:
        # either tx_feat_source was None (so we want X_all indexed) or it's global array
        tx_feat = torch.tensor(np.asarray(tx_feats_np)[idx_batch], dtype=torch.float32, device=DEVICE)

    # labels
    if labels_source is None:
        labels_np = y_all
    else:
        labels_src_np = np.asarray(labels_source)
        if labels_src_np.shape[0] == batch_len:
            labels_np = labels_src_np
        elif labels_src_np.shape[0] == y_all.shape[0]:
            labels_np = labels_src_np
        else:
            raise ValueError(f"labels_source has incompatible size {labels_src_np.shape[0]} for batch_len {batch_len}")

    if isinstance(labels_source, (np.ndarray, list)) and np.asarray(labels_source).shape[0] == batch_len:
        labels = torch.tensor(np.asarray(labels_source), dtype=torch.long, device=DEVICE)
    else:
        labels = torch.tensor(np.asarray(labels_np)[idx_batch], dtype=torch.long, device=DEVICE)

    return card_idx, merchant_idx, device_idx, tx_feat, labels
def augment_tx_features(tx_feats_np, mask_p=FEAT_MASK_P, noise_std=GAUSSIAN_NOISE_STD):
    """
    Create an augmented view of transaction features by:
      - random feature masking (set to zero) with prob mask_p per feature
      - adding gaussian noise
    tx_feats_np: numpy array (batch, feat_dim) or torch tensor; returns numpy array
    """
    if isinstance(tx_feats_np, torch.Tensor):
        t = tx_feats_np.cpu().numpy()
    else:
        t = tx_feats_np.copy()
    mask = np.random.rand(*t.shape) > mask_p
    t_masked = t * mask
    noise = np.random.randn(*t.shape).astype(np.float32) * noise_std
    t_masked = t_masked + noise
    return t_masked.astype(np.float32)

def augment_graph_data(data: HeteroData, edge_drop_p=EDGE_DROP_P, node_noise_std=GAUSSIAN_NOISE_STD):
    """
    Returns a *shallow copy* of input HeteroData with:
      - random edge dropout applied to forward relations,
      - small gaussian noise added to node features.
    Note: we keep shapes consistent; dropout is per-edge.
    """
    new = HeteroData()
    # copy num_nodes
    new['card'].num_nodes = data['card'].num_nodes
    new['merchant'].num_nodes = data['merchant'].num_nodes
    new['device'].num_nodes = data['device'].num_nodes

    # edge dropout: keep subset of edges uniformly
    for (src, rel, dst) in [('card','to_merchant','merchant'), ('card','to_device','device')]:
        eidx = data[src, rel, dst].edge_index.cpu().numpy()
        E = eidx.shape[1]
        if E == 0:
            new[src, rel, dst].edge_index = torch.tensor(eidx, dtype=torch.long)
        else:
            keep_mask = np.random.rand(E) > edge_drop_p
            if keep_mask.sum() == 0:
                # keep at least one edge to avoid empty relations
                keep_mask[np.random.randint(0, E)] = True
            eidx_kept = eidx[:, keep_mask]
            new[src, rel, dst].edge_index = torch.tensor(eidx_kept, dtype=torch.long)

    # node features: add small gaussian noise
    for ntype in ['card','merchant','device']:
        x = data[ntype].x.cpu().numpy()
        noise = np.random.randn(*x.shape).astype(np.float32) * node_noise_std
        new[ntype].x = torch.tensor((x + noise).astype(np.float32), dtype=torch.float32)

    return new


In [21]:
# ---------------------------
# Contrastive loss (NT-Xent)
# ---------------------------
def nt_xent_loss(z_a, z_b, temperature=TEMPERATURE):
    """
    z_a, z_b: (N, D) tensor, both normalized
    Computes NT-Xent loss across pairs (i,i).
    """
    N = z_a.size(0)
    z = torch.cat([z_a, z_b], dim=0)  # (2N, D)
    sim = torch.matmul(z, z.t())  # (2N,2N)
    # divide by temperature
    sim = sim / temperature
    # mask to remove similarity with itself
    diag_mask = torch.eye(2*N, dtype=torch.bool, device=z.device)
    sim_masked = sim.masked_fill(diag_mask, -9e15)

    # positive pairs: i <-> i+N
    positives = torch.cat([torch.diag(sim, N), torch.diag(sim, -N)])  # (2N,)
    positives = positives.unsqueeze(1)

    # denominator logsumexp across other examples
    logsumexp = torch.logsumexp(sim_masked, dim=1, keepdim=True)  # (2N,1)
    loss = - positives + logsumexp
    loss = loss.mean()
    return loss


In [22]:
# ---------------------------
# Pretraining loop (contrastive)
# ---------------------------
contrastive_optimizer = torch.optim.Adam(list(encoder_wrapper.encoder.parameters()) + list(encoder_wrapper.proj.parameters()), lr=CONTRASTIVE_LR)
contrastive_history = {'pretrain_loss': []}
print("Starting contrastive pretraining...")

for epoch in range(PRETRAIN_EPOCHS):
    t0 = time.time()
    encoder_wrapper.train()
    running_loss = 0.0
    iters = 0
    perm = np.random.permutation(tx_indices)
    pbar = tqdm(range(0, len(perm), PRETRAIN_BATCH_SIZE), desc=f"Pretrain Epoch {epoch+1}/{PRETRAIN_EPOCHS}")
    for start in pbar:
        iters += 1
        batch_idx = perm[start:start+PRETRAIN_BATCH_SIZE]
        # two augmented views of the global graph (we augment globally for encoder computation)
        data_a = augment_graph_data(data_base, edge_drop_p=EDGE_DROP_P, node_noise_std=GAUSSIAN_NOISE_STD)
        data_b = augment_graph_data(data_base, edge_drop_p=EDGE_DROP_P, node_noise_std=GAUSSIAN_NOISE_STD)

        # compute node embeddings for both views (single forward per view)
        data_a_dev = HeteroData()
        data_a_dev['card'].x = data_a['card'].x.to(DEVICE)
        data_a_dev['merchant'].x = data_a['merchant'].x.to(DEVICE)
        data_a_dev['device'].x = data_a['device'].x.to(DEVICE)
        data_a_dev['card','to_merchant','merchant'].edge_index = data_a['card','to_merchant','merchant'].edge_index.to(DEVICE)
        data_a_dev['card','to_device','device'].edge_index = data_a['card','to_device','device'].edge_index.to(DEVICE)

        data_b_dev = HeteroData()
        data_b_dev['card'].x = data_b['card'].x.to(DEVICE)
        data_b_dev['merchant'].x = data_b['merchant'].x.to(DEVICE)
        data_b_dev['device'].x = data_b['device'].x.to(DEVICE)
        data_b_dev['card','to_merchant','merchant'].edge_index = data_b['card','to_merchant','merchant'].edge_index.to(DEVICE)
        data_b_dev['card','to_device','device'].edge_index = data_b['card','to_device','device'].edge_index.to(DEVICE)

        with torch.no_grad():
            # (optional) you could precompute but we want gradients through encoder, so do not torch.no_grad() here
            pass

        # compute node embeddings (with grad)
        node_emb_a = encoder_wrapper.forward_node_emb(data_a_dev)
        node_emb_b = encoder_wrapper.forward_node_emb(data_b_dev)

        # transaction-level augmentations for tx features (two views)
        tx_feat_view_a = augment_tx_features(X_all[batch_idx], mask_p=FEAT_MASK_P, noise_std=GAUSSIAN_NOISE_STD)
        tx_feat_view_b = augment_tx_features(X_all[batch_idx], mask_p=FEAT_MASK_P, noise_std=GAUSSIAN_NOISE_STD)

        card_idx_b, mer_idx_b, dev_idx_b, tx_feat_a_t, _ = make_edge_batch(batch_idx, tx_feat_view_a)
        _, _, _, tx_feat_b_t, _ = make_edge_batch(batch_idx, tx_feat_view_b)

        encoder_wrapper.train()
        contrastive_optimizer.zero_grad()
        # obtain projections
        z_a, _ = encoder_wrapper.txn_embed_and_proj(node_emb_a, card_idx_b, mer_idx_b, dev_idx_b, tx_feat_a_t)
        z_b, _ = encoder_wrapper.txn_embed_and_proj(node_emb_b, card_idx_b, mer_idx_b, dev_idx_b, tx_feat_b_t)
        loss = nt_xent_loss(z_a, z_b, temperature=TEMPERATURE)

        loss.backward()
        contrastive_optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'pretrain_loss': running_loss / iters})

    avg_pretrain_loss = running_loss / max(1, iters)
    contrastive_history['pretrain_loss'].append(avg_pretrain_loss)
    epoch_time = time.time() - t0
    print(f"\nPretrain Epoch {epoch+1}/{PRETRAIN_EPOCHS} — time: {epoch_time:.1f}s — loss: {avg_pretrain_loss:.6f}")

# Save pretrained encoder
torch.save({'encoder_state_dict': encoder_wrapper.encoder.state_dict(),
            'proj_state_dict': encoder_wrapper.proj.state_dict()}, "confraudclr_pretrained.pth")
print("Saved pretrained encoder to confraudclr_pretrained.pth")


Starting contrastive pretraining...


Pretrain Epoch 1/5:   0%|          | 0/140 [00:00<?, ?it/s]


Pretrain Epoch 1/5 — time: 167.1s — loss: 8.078935


Pretrain Epoch 2/5:   0%|          | 0/140 [00:00<?, ?it/s]


Pretrain Epoch 2/5 — time: 154.9s — loss: 7.955515


Pretrain Epoch 3/5:   0%|          | 0/140 [00:00<?, ?it/s]


Pretrain Epoch 3/5 — time: 153.7s — loss: 7.895561


Pretrain Epoch 4/5:   0%|          | 0/140 [00:00<?, ?it/s]


Pretrain Epoch 4/5 — time: 159.6s — loss: 7.789713


Pretrain Epoch 5/5:   0%|          | 0/140 [00:00<?, ?it/s]


Pretrain Epoch 5/5 — time: 164.9s — loss: 7.245100
Saved pretrained encoder to confraudclr_pretrained.pth


In [23]:
# ---------------------------
# Fine-tuning (supervised) using pretrained encoder
# ---------------------------
# create optimizer for classifier + optionally encoder
if FINETUNE_UNFREEZE_ENCODER:
    finetune_params = list(encoder_wrapper.parameters()) + list(classifier_head.parameters())
else:
    # freeze encoder parameters
    for p in encoder_wrapper.parameters():
        p.requires_grad = False
    finetune_params = list(classifier_head.parameters())

finetune_optimizer = torch.optim.Adam(finetune_params, lr=FINETUNE_LR)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

history = {'train_loss': [], 'val_acc': [], 'val_prec': [], 'val_rec': [], 'val_f1': [], 'val_roc': [], 'val_pr_auc': [], 'val_loss': []}
best_val_f1 = 0.0

def batched_eval_supervised(encoder_wrapper, classifier_head, data, idx_list, batch_size=2048, device=DEVICE, compute_loss=True):
    encoder_wrapper.eval()
    classifier_head.eval()
    all_preds, all_probs, all_labels = [], [], []
    total_loss = 0.0
    total_samples = 0

    # prepare data on device
    data_device = HeteroData()
    data_device['card'].x = data['card'].x.to(device)
    data_device['merchant'].x = data['merchant'].x.to(device)
    data_device['device'].x = data['device'].x.to(device)
    data_device['card','to_merchant','merchant'].edge_index = data['card','to_merchant','merchant'].edge_index.to(device)
    data_device['card','to_device','device'].edge_index = data['card','to_device','device'].edge_index.to(device)

    with torch.no_grad():
        node_emb = encoder_wrapper.forward_node_emb(data_device)

    n = len(idx_list)
    for start in range(0, n, batch_size):
        batch_idx = idx_list[start:start+batch_size]
        c_idx_b, m_idx_b, d_idx_b, tx_feat_b, labels_b = make_edge_batch(batch_idx, X_all)
        with torch.no_grad():
            # produce full embedding (concatenated cat + tx_feat)
            card_e = node_emb['card'][c_idx_b]
            merch_e = node_emb['merchant'][m_idx_b]
            dev_e = node_emb['device'][d_idx_b]
            cat = torch.cat([card_e, merch_e, dev_e], dim=1)
            full = torch.cat([cat, tx_feat_b], dim=1)
            logits = classifier_head(full)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = probs.argmax(axis=1)
            if compute_loss:
                loss = criterion(logits, labels_b)
                total_loss += loss.item() * labels_b.size(0)
                total_samples += labels_b.size(0)
        all_preds.append(preds); all_probs.append(probs); all_labels.append(labels_b.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_probs = np.concatenate(all_probs)
    all_labels = np.concatenate(all_labels)

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)

    if len(np.unique(all_labels)) == 2:
        try:
            roc = roc_auc_score(all_labels, all_probs[:,1])
            p, r, _ = precision_recall_curve(all_labels, all_probs[:,1])
            pr_auc = sklearn_auc(r, p)
        except Exception:
            roc = float('nan'); pr_auc = float('nan')
    else:
        roc = float('nan'); pr_auc = float('nan')

    loss_avg = (total_loss / total_samples) if compute_loss and total_samples > 0 else float('nan')
    cm = confusion_matrix(all_labels, all_preds)
    return {'acc': acc, 'prec': prec, 'rec': rec, 'f1': f1, 'roc': roc, 'pr_auc': pr_auc, 'cm': cm, 'loss': loss_avg}

print("Starting supervised fine-tuning...")

for epoch in range(FINETUNE_EPOCHS):
    t0 = time.time()
    encoder_wrapper.train()
    classifier_head.train()

    # compute node embeddings once per epoch (we want to reuse embeddings; if unfreezing encoder we'd recompute per batch in strict fine-tune)
    # If FINETUNE_UNFREEZE_ENCODER==True we will recompute node_emb per mini-batch with gradients.
    perm = np.random.permutation(idx_train if 'idx_train' in globals() else np.arange(len(X_train)))
    running_loss = 0.0
    iters = 0
    pbar = tqdm(range(0, len(perm), FINETUNE_BATCH_SIZE), desc=f"Finetune Epoch {epoch+1}/{FINETUNE_EPOCHS}")
    for start in pbar:
        iters += 1
        batch_idx = perm[start:start+FINETUNE_BATCH_SIZE]

        # prepare node embeddings:
        if FINETUNE_UNFREEZE_ENCODER:
            # recompute node embeddings on the fly with potential gradient flow
            data_device = HeteroData()
            data_device['card'].x = data_base['card'].x.to(DEVICE)
            data_device['merchant'].x = data_base['merchant'].x.to(DEVICE)
            data_device['device'].x = data_base['device'].x.to(DEVICE)
            data_device['card','to_merchant','merchant'].edge_index = data_base['card','to_merchant','merchant'].edge_index.to(DEVICE)
            data_device['card','to_device','device'].edge_index = data_base['card','to_device','device'].edge_index.to(DEVICE)
            node_emb = encoder_wrapper.forward_node_emb(data_device)
        else:
            with torch.no_grad():
                data_device = HeteroData()
                data_device['card'].x = data_base['card'].x.to(DEVICE)
                data_device['merchant'].x = data_base['merchant'].x.to(DEVICE)
                data_device['device'].x = data_base['device'].x.to(DEVICE)
                data_device['card','to_merchant','merchant'].edge_index = data_base['card','to_merchant','merchant'].edge_index.to(DEVICE)
                data_device['card','to_device','device'].edge_index = data_base['card','to_device','device'].edge_index.to(DEVICE)
                node_emb = encoder_wrapper.forward_node_emb(data_device)

        card_idx_b, merch_idx_b, dev_idx_b, tx_feat_b, labels_b = make_edge_batch(batch_idx, X_all)

        finetune_optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=torch.cuda.is_available()):
            # make full embedding
            card_e = node_emb['card'][card_idx_b]
            merch_e = node_emb['merchant'][merch_idx_b]
            dev_e = node_emb['device'][dev_idx_b]
            cat = torch.cat([card_e, merch_e, dev_e], dim=1)
            full = torch.cat([cat, tx_feat_b], dim=1)
            logits = classifier_head(full)
            loss = criterion(logits, labels_b)

        scaler.scale(loss).backward()
        scaler.step(finetune_optimizer)
        scaler.update()

        running_loss += loss.item()
        pbar.set_postfix({'train_loss': running_loss / iters})

    avg_train_loss = running_loss / max(1, iters)
    history['train_loss'].append(avg_train_loss)

    # validation
    val_metrics = batched_eval_supervised(encoder_wrapper, classifier_head, data_base, idx_valid, batch_size=FINETUNE_BATCH_SIZE, device=DEVICE, compute_loss=True)
    history['val_acc'].append(val_metrics['acc'])
    history['val_prec'].append(val_metrics['prec'])
    history['val_rec'].append(val_metrics['rec'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_roc'].append(val_metrics['roc'])
    history['val_pr_auc'].append(val_metrics['pr_auc'])
    history['val_loss'].append(val_metrics['loss'])

    epoch_time = time.time() - t0
    print(f"\nFinetune Epoch {epoch+1}/{FINETUNE_EPOCHS} — time: {epoch_time:.1f}s — train_loss: {avg_train_loss:.6f}")
    print(f" Val -> acc: {val_metrics['acc']*100:.4f}%  prec: {val_metrics['prec']*100:.4f}%  rec: {val_metrics['rec']*100:.4f}%  f1: {val_metrics['f1']*100:.4f}%  roc: {val_metrics['roc']:.6f}  pr_auc: {val_metrics['pr_auc']:.6f}  val_loss: {val_metrics['loss']:.6f}")

    if val_metrics['f1'] > best_val_f1:
        best_val_f1 = val_metrics['f1']
        torch.save({
            'encoder_state_dict': encoder_wrapper.encoder.state_dict(),
            'proj_state_dict': encoder_wrapper.proj.state_dict(),
            'classifier_state_dict': classifier_head.state_dict()
        }, "confraudclr_best_finetuned.pth")
        print("Saved best fine-tuned model.")

# ---------------------------
# Final test evaluation
# ---------------------------
test_metrics = batched_eval_supervised(encoder_wrapper, classifier_head, data_base, idx_test, batch_size=FINETUNE_BATCH_SIZE, device=DEVICE, compute_loss=True)
print("\n===== Final Test results =====")
print(f"Test loss: {test_metrics['loss']:.6f}")
print(f"Acc: {test_metrics['acc']*100:.4f}%")
print(f"Precision: {test_metrics['prec']*100:.4f}%")
print(f"Recall: {test_metrics['rec']*100:.4f}%")
print(f"F1: {test_metrics['f1']*100:.4f}%")
print(f"ROC AUC: {test_metrics['roc']:.6f}")
print(f"PR AUC: {test_metrics['pr_auc']:.6f}")
print("Confusion matrix:\n", test_metrics['cm'])

preds_all = []
labels_all = []
def collect_preds_supervised(encoder_wrapper, classifier_head, data, idx_list, batch_size=2048, device=DEVICE):
    encoder_wrapper.eval(); classifier_head.eval()
    all_preds, all_labels = [], []
    data_device = HeteroData()
    data_device['card'].x = data['card'].x.to(device)
    data_device['merchant'].x = data['merchant'].x.to(device)
    data_device['device'].x = data['device'].x.to(device)
    data_device['card','to_merchant','merchant'].edge_index = data['card','to_merchant','merchant'].edge_index.to(device)
    data_device['card','to_device','device'].edge_index = data['card','to_device','device'].edge_index.to(device)
    with torch.no_grad():
        node_emb = encoder_wrapper.forward_node_emb(data_device)
    n = len(idx_list)
    for start in range(0, n, batch_size):
        batch_idx = idx_list[start:start+batch_size]
        c_idx_b, m_idx_b, d_idx_b, tx_feat_b, labels_b = make_edge_batch(batch_idx, X_all)
        with torch.no_grad():
            card_e = node_emb['card'][c_idx_b]
            merch_e = node_emb['merchant'][m_idx_b]
            dev_e = node_emb['device'][d_idx_b]
            cat = torch.cat([card_e, merch_e, dev_e], dim=1)
            full = torch.cat([cat, tx_feat_b], dim=1)
            logits = classifier_head(full)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = probs.argmax(axis=1)
        all_preds.append(preds); all_labels.append(labels_b.cpu().numpy())
    return np.concatenate(all_preds), np.concatenate(all_labels)

preds_all, labels_all = collect_preds_supervised(encoder_wrapper, classifier_head, data_base, idx_test, batch_size=FINETUNE_BATCH_SIZE, device=DEVICE)
print("\nClassification report (test):\n", classification_report(labels_all, preds_all, digits=4, zero_division=0))

# ---------------------------
# Save final artifacts
# ---------------------------
torch.save({
    'encoder_state_dict': encoder_wrapper.encoder.state_dict(),
    'proj_state_dict': encoder_wrapper.proj.state_dict(),
    'classifier_state_dict': classifier_head.state_dict()
}, "confraudclr_final.pth")

with open("confraudclr_artifacts.pkl", "wb") as f:
    pickle.dump({
        'scaler': scaler,
        'card_ids': card_ids,
        'merchant_ids': merchant_ids,
        'device_ids': device_ids,
        'feature_names': feature_names,
        'contrastive_history': contrastive_history,
        'finetune_history': history
    }, f)

print("\nSaved model to confraudclr_final.pth and artifacts to confraudclr_artifacts.pkl")

Starting supervised fine-tuning...


  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


Finetune Epoch 1/5:   0%|          | 0/112 [00:00<?, ?it/s]


Finetune Epoch 1/5 — time: 34.4s — train_loss: 0.041832
 Val -> acc: 99.8280%  prec: 0.0000%  rec: 0.0000%  f1: 0.0000%  roc: 0.866223  pr_auc: 0.355214  val_loss: 0.012648


Finetune Epoch 2/5:   0%|          | 0/112 [00:00<?, ?it/s]


Finetune Epoch 2/5 — time: 34.0s — train_loss: 0.013363
 Val -> acc: 99.8280%  prec: 0.0000%  rec: 0.0000%  f1: 0.0000%  roc: 0.881621  pr_auc: 0.407424  val_loss: 0.012552


Finetune Epoch 3/5:   0%|          | 0/112 [00:00<?, ?it/s]


Finetune Epoch 3/5 — time: 32.8s — train_loss: 0.013151
 Val -> acc: 99.8280%  prec: 0.0000%  rec: 0.0000%  f1: 0.0000%  roc: 0.897457  pr_auc: 0.478955  val_loss: 0.012494


Finetune Epoch 4/5:   0%|          | 0/112 [00:00<?, ?it/s]


Finetune Epoch 4/5 — time: 34.5s — train_loss: 0.012949
 Val -> acc: 99.8280%  prec: 0.0000%  rec: 0.0000%  f1: 0.0000%  roc: 0.911462  pr_auc: 0.547916  val_loss: 0.012469


Finetune Epoch 5/5:   0%|          | 0/112 [00:00<?, ?it/s]


Finetune Epoch 5/5 — time: 33.9s — train_loss: 0.012927
 Val -> acc: 99.8280%  prec: 0.0000%  rec: 0.0000%  f1: 0.0000%  roc: 0.924836  pr_auc: 0.605780  val_loss: 0.012384

===== Final Test results =====
Test loss: 0.012375
Acc: 99.8280%
Precision: 0.0000%
Recall: 0.0000%
F1: 0.0000%
ROC AUC: 0.939633
PR AUC: 0.553029
Confusion matrix:
 [[28432     0]
 [   49     0]]

Classification report (test):
               precision    recall  f1-score   support

           0     0.9983    1.0000    0.9991     28432
           1     0.0000    0.0000    0.0000        49

    accuracy                         0.9983     28481
   macro avg     0.4991    0.5000    0.4996     28481
weighted avg     0.9966    0.9983    0.9974     28481


Saved model to confraudclr_final.pth and artifacts to confraudclr_artifacts.pkl
