Cell 1: Imports

In [1]:
import pandas as pd
import numpy as np
import networkx as nx
import kuzu
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, precision_recall_curve, auc
from sklearn.model_selection import KFold, StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import shutil
import scipy.sparse as sp

Cell 2: Synthetic Data

In [2]:
def generate_synthetic_data(num_entities=5000, num_transactions=25000):
    np.random.seed(42)
    # Step 1: Benign data
    countries = np.random.choice(['US', 'EU', 'ASIA', 'HIGH_RISK'], num_entities, p=[0.4, 0.3, 0.2, 0.1])
    entities = pd.DataFrame({
        'entity_id': [f'E{i:04d}' for i in range(num_entities)],
        'profile_type': np.random.choice(['individual_low', 'individual_high', 'business_small', 'business_large'], num_entities, p=[0.4, 0.1, 0.3, 0.2]),
        'country': countries,
        'agent_id': np.random.randint(1, 100, num_entities),
        'kyc_risk_score': np.random.uniform(0, 0.3, num_entities),
        'dormancy_period': np.random.randint(0, 365, num_entities)
    })
    G = nx.watts_strogatz_graph(num_entities, k=10, p=0.3, seed=42)  # Watts-Strogatz for cycles
    edges = [(f'E{u:04d}', f'E{v:04d}') for u, v in G.edges()]
    transactions = pd.DataFrame({
        'sender_id': [edges[i % len(edges)][0] for i in range(num_transactions)],
        'receiver_id': [edges[i % len(edges)][1] for i in range(num_transactions)],
        'amount': np.random.exponential(50000, num_transactions).clip(100, 1500000),
        'timestamp': np.random.randint(0, 30, num_transactions),
        'ml_flag': np.zeros(num_transactions, dtype=int),
        'typology': ['benign'] * num_transactions
    })
    transactions['is_cross_border'] = transactions.apply(lambda row: entities.loc[entities['entity_id'] == row['sender_id'], 'country'].values[0] != entities.loc[entities['entity_id'] == row['receiver_id'], 'country'].values[0], axis=1)
    transactions['high_risk_jurisdiction'] = transactions.apply(lambda row: 'HIGH_RISK' in [entities.loc[entities['entity_id'] == row['sender_id'], 'country'].values[0], entities.loc[entities['entity_id'] == row['receiver_id'], 'country'].values[0]], axis=1)
    transactions['flagged_receiver'] = np.random.choice([True, False], num_transactions, p=[0.02, 0.98])

    # Step 2: Inject typologies (2% suspicious entities)
    num_suspicious = int(0.02 * num_entities)
    suspicious_ids = np.random.choice(entities['entity_id'], num_suspicious, replace=False)
    for sid in suspicious_ids:
        entities.loc[entities['entity_id'] == sid, 'kyc_risk_score'] = np.random.uniform(0.6, 0.9)
        entities.loc[entities['entity_id'] == sid, 'dormancy_period'] = np.random.randint(200, 365)
        ts = np.random.randint(0, 3)
        
        # Typology 1: Smurfing
        large_amount = np.random.uniform(200000, 800000)
        num_smurfs = np.random.randint(10, 20)
        small_amounts = np.full(num_smurfs, large_amount / num_smurfs)
        receivers = np.random.choice(entities['entity_id'], num_smurfs, replace=False)
        smurf_tx = pd.DataFrame({
            'sender_id': [sid] * num_smurfs,
            'receiver_id': receivers,
            'amount': small_amounts,
            'timestamp': [ts] * num_smurfs,
            'ml_flag': 1,
            'is_cross_border': np.random.choice([True, False], num_smurfs, p=[0.7, 0.3]),
            'high_risk_jurisdiction': np.random.choice([True, False], num_smurfs, p=[0.5, 0.5]),
            'flagged_receiver': np.random.choice([True, False], num_smurfs, p=[0.3, 0.7]),
            'typology': ['smurfing'] * num_smurfs
        })
        transactions = pd.concat([transactions, smurf_tx], ignore_index=True)
        
        # Typology 2: Money Mules
        num_mules = np.random.randint(3, 6)
        mules = np.random.choice(entities['entity_id'], num_mules, replace=False)
        mule_amount = np.random.uniform(5000, 50000)
        for mule in mules:
            final = np.random.choice(entities[entities['country'] == 'HIGH_RISK']['entity_id'], 1)[0] if np.random.rand() > 0.5 else sid
            mule_tx = pd.DataFrame({
                'sender_id': [sid, mule],
                'receiver_id': [mule, final],
                'amount': [mule_amount] * 2,
                'timestamp': [ts, ts + 1],
                'ml_flag': 1,
                'is_cross_border': [False, True],
                'high_risk_jurisdiction': [False, True],
                'flagged_receiver': [True, False],
                'typology': ['money_mule'] * 2
            })
            transactions = pd.concat([transactions, mule_tx], ignore_index=True)
        
        # Typology 3: CLS
        layers = np.random.randint(4, 7)
        chain = [sid] + list(np.random.choice(entities['entity_id'], layers - 1, replace=False)) + [sid if np.random.rand() > 0.5 else np.random.choice(entities['entity_id'])]
        layer_amount = np.random.uniform(10000, 100000)
        for j in range(len(chain) - 1):
            cls_tx = pd.DataFrame({
                'sender_id': [chain[j]],
                'receiver_id': [chain[j+1]],
                'amount': [layer_amount],
                'timestamp': [ts + j],
                'ml_flag': 1,
                'is_cross_border': np.random.choice([True, False], 1, p=[0.8, 0.2]),
                'high_risk_jurisdiction': np.random.choice([True, False], 1, p=[0.6, 0.4]),
                'flagged_receiver': np.random.choice([True, False], 1, p=[0.4, 0.6]),
                'typology': ['cls']
            })
            transactions = pd.concat([transactions, cls_tx], ignore_index=True)

    # Step 3: Refine realism
    noise_idx = np.random.choice(transactions[transactions['typology'] == 'benign'].index, int(0.1 * num_transactions))
    transactions.loc[noise_idx, 'amount'] *= np.random.uniform(0.8, 1.2)
    transactions.loc[noise_idx, 'is_cross_border'] = np.random.choice([True, False], len(noise_idx), p=[0.4, 0.6])
    high_risk_noise = transactions[transactions['high_risk_jurisdiction']]
    if not high_risk_noise.empty:
        high_risk_noise_idx = np.random.choice(high_risk_noise.index, int(0.2 * len(high_risk_noise)))
        transactions.loc[high_risk_noise_idx, 'flagged_receiver'] = True

    entities.to_csv('entities.csv', index=False)
    transactions.to_csv('transactions.csv', index=False)
    print(f"Generated {num_entities} entities and {len(transactions)} transactions (after typologies)")
    print(transactions['typology'].value_counts())
    print(transactions['amount'].describe())
    print(transactions[['is_cross_border', 'ml_flag', 'flagged_receiver', 'high_risk_jurisdiction']].value_counts())
    print(entities[['kyc_risk_score', 'dormancy_period']].describe())
    print(f"Proportion of ml_flag = 1: {transactions['ml_flag'].mean():.4f}")
    return entities, transactions

entities, transactions = generate_synthetic_data()

Generated 5000 entities and 27714 transactions (after typologies)
typology
benign        25000
smurfing       1424
money_mule      798
cls             492
Name: count, dtype: int64
count     27714.000000
mean      47739.861370
std       47348.665181
min          81.681115
25%       15102.939172
50%       33895.386598
75%       64331.302178
max      524429.154457
Name: amount, dtype: float64
is_cross_border  ml_flag  flagged_receiver  high_risk_jurisdiction
True             0        False             False                     12443
False            0        False             False                      7625
True             0        False             True                       3260
                          True              True                        820
                 1        False             True                        713
False            1        True              False                       475
True             1        False             False                       474
       

Cell 3:

In [7]:
import pandas as pd
import numpy as np
import networkx as nx
import scipy.sparse as sp
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_curve

# Load data
entities = pd.read_csv('entities.csv')
transactions = pd.read_csv('transactions.csv')

# Validate input data
if entities.empty or transactions.empty:
    raise ValueError("Entities or transactions DataFrame is empty.")
print(f"Loaded {len(entities)} entities and {len(transactions)} transactions")

# Check for invalid or missing IDs in transactions
invalid_transactions = transactions[transactions['sender_id'].isna() | transactions['receiver_id'].isna()]
if not invalid_transactions.empty:
    print(f"Warning: {len(invalid_transactions)} transactions with missing sender_id or receiver_id")
    transactions = transactions.dropna(subset=['sender_id', 'receiver_id'])
    print(f"After dropping invalid transactions, {len(transactions)} transactions remain")

# Check for duplicate transactions
transactions['edge'] = transactions['sender_id'] + '->' + transactions['receiver_id']
duplicate_edges = transactions[transactions.duplicated(subset=['edge'], keep=False)]
if not duplicate_edges.empty:
    print(f"Warning: {len(duplicate_edges)} duplicate edges found; keeping first occurrence")
    transactions = transactions.drop_duplicates(subset=['edge'], keep='first')
    print(f"After dropping duplicates, {len(transactions)} transactions remain")

# Ensure all transaction nodes are in entities
all_entity_ids = set(entities['entity_id']).union(transactions['sender_id'], transactions['receiver_id'])
missing_entities = all_entity_ids - set(entities['entity_id'])
if missing_entities:
    print(f"Adding {len(missing_entities)} missing entities to entities DataFrame")
    missing_df = pd.DataFrame({
        'entity_id': list(missing_entities),
        'profile_type': 'UNKNOWN',
        'country': 'UNKNOWN',
        'agent_id': -1,
        'kyc_risk_score': 0,
        'dormancy_period': 0
    })
    entities = pd.concat([entities, missing_df], ignore_index=True)
    entities.to_csv('entities.csv', index=False)

# Build directed graph
G = nx.DiGraph()
for _, row in entities.iterrows():
    G.add_node(row['entity_id'], **row.to_dict())
for _, row in transactions.iterrows():
    G.add_edge(row['sender_id'], row['receiver_id'], **row.to_dict())

# Verify graph
num_nodes = G.number_of_nodes()
num_edges = G.number_of_edges()
print(f"Graph built with {num_nodes} nodes and {num_edges} edges")
if num_nodes != len(entities):
    raise ValueError(f"Graph has {num_nodes} nodes, but {len(entities)} entities expected")
if num_edges != len(transactions):
    print(f"Error: Graph has {num_edges} edges, but {len(transactions)} transactions expected")
    # Debug mismatched transactions
    graph_edges = set((u, v) for u, v in G.edges())
    transaction_edges = set((row['sender_id'], row['receiver_id']) for _, row in transactions.iterrows())
    missing_edges = transaction_edges - graph_edges
    print(f"Missing edges: {missing_edges}")
    raise ValueError(f"Graph edge mismatch; {len(missing_edges)} edges not added")

# Feature extraction
degree_dict = dict(G.out_degree())
degree_df = pd.DataFrame({'entity_id': list(degree_dict.keys()), 'degree': list(degree_dict.values())}).fillna(0)
in_degree_dict = dict(G.in_degree())
in_degree_df = pd.DataFrame({'entity_id': list(in_degree_dict.keys()), 'in_degree': list(in_degree_dict.values())}).fillna(0)
small_tx = {}
for node in G.nodes():
    small_tx[node] = sum(1 for _, _, data in G.out_edges(node, data=True) if data.get('amount', 0) < 1000)
small_tx_df = pd.DataFrame({'entity_id': list(small_tx.keys()), 'small_tx_count': list(small_tx.values())}).fillna(0)
undirected_G = G.to_undirected()
clustering_dict = nx.clustering(undirected_G)
cluster_df = pd.DataFrame({'entity_id': list(clustering_dict.keys()), 'clustering_coeff': list(clustering_dict.values())}).fillna(0)
tx_count = {node: G.out_degree(node) for node in G.nodes()}
tx_count_df = pd.DataFrame({'entity_id': list(tx_count.keys()), 'tx_count': list(tx_count.values())}).fillna(0)
mean_tx_count = tx_count_df['tx_count'].mean()
tx_count_df['tx_freq_variance'] = np.sqrt((tx_count_df['tx_count'] - mean_tx_count) ** 2).fillna(0)
tx_freq_df = tx_count_df[['entity_id', 'tx_freq_variance']]
amount_data = {}
for node in G.nodes():
    amounts = [data.get('amount', 0) for _, _, data in G.out_edges(node, data=True)]
    if len(amounts) > 0:
        avg = np.mean(amounts)
        var = np.var(amounts, ddof=1) if len(amounts) > 1 else 0
        skew = (((np.array(amounts) - avg) ** 3).mean()) / (np.std(amounts, ddof=1) ** 3 + 1e-7) if len(amounts) > 1 else 0
        high_ratio = sum(a > 100000 for a in amounts) / len(amounts)
        cycle_score = np.log1p(sum(a**2 for a in amounts) / len(amounts)) / np.log1p(max(amounts)) if max(amounts) > 0 else 0
    else:
        avg = var = skew = high_ratio = cycle_score = 0
    amount_data[node] = {'avg_amount': avg, 'amount_variance': var, 'amount_skewness': skew, 'high_value_ratio': high_ratio, 'cycle_score': cycle_score}
var_skew_df = pd.DataFrame.from_dict(amount_data, orient='index').reset_index().rename(columns={'index': 'entity_id'}).fillna(0)
time_data = {}
for node in G.nodes():
    timestamps = [data.get('timestamp', 0) for _, _, data in G.out_edges(node, data=True)]
    if len(timestamps) > 0:
        min_t, max_t = min(timestamps), max(timestamps)
        velocity = len(timestamps) / (max_t - min_t + 1e-7)
        burst = np.var(timestamps, ddof=1) * 10 if len(timestamps) > 1 else 0
        conc = 1 if (max_t - min_t < 3) else 0
    else:
        velocity = burst = conc = 0
    time_data[node] = {'tx_velocity': velocity, 'burstiness': burst, 'temporal_concentration': conc}
time_agg_df = pd.DataFrame.from_dict(time_data, orient='index').reset_index().rename(columns={'index': 'entity_id'}).fillna(0)
features = {}
for node in G.nodes():
    node_data = G.nodes[node]
    out_edges = list(G.out_edges(node, data=True))
    in_edges = list(G.in_edges(node, data=True))
    cross_border_amount = np.log1p(sum(data.get('amount', 0) for _, _, data in out_edges if data.get('is_cross_border', False)))
    tx_frequency = len(out_edges)
    directionality_ratio = len(in_edges) / (len(in_edges) + tx_frequency + 1e-7)
    features[node] = {
        'kyc_risk_score': node_data.get('kyc_risk_score', 0),
        'dormancy_period': node_data.get('dormancy_period', 0),
        'cross_border_amount': cross_border_amount,
        'in_degree': len(in_edges),
        'tx_frequency': tx_frequency,
        'directionality_ratio': directionality_ratio
    }
features_df = pd.DataFrame.from_dict(features, orient='index').reset_index().rename(columns={'index': 'entity_id'}).fillna(0)
round_trip = {}
for node in G.nodes():
    count = 0
    for neighbor in G.neighbors(node):
        for next_neighbor in G.neighbors(neighbor):
            if next_neighbor == node and neighbor != node:
                count += 1
    round_trip[node] = count
round_trip_df = pd.DataFrame({'entity_id': list(round_trip.keys()), 'round_trip_count': list(round_trip.values())}).fillna(0)

# Merge features
all_entities = pd.DataFrame({'entity_id': entities['entity_id']})
data_df = all_entities.merge(degree_df, on='entity_id', how='left')\
                     .merge(small_tx_df, on='entity_id', how='left')\
                     .merge(var_skew_df, on='entity_id', how='left')\
                     .merge(cluster_df, on='entity_id', how='left')\
                     .merge(tx_freq_df, on='entity_id', how='left')\
                     .merge(time_agg_df, on='entity_id', how='left')\
                     .merge(features_df, on='entity_id', how='left')\
                     .merge(round_trip_df, on='entity_id', how='left')\
                     .merge(in_degree_df, on='entity_id', how='left')\
                     .merge(entities[['entity_id', 'country', 'profile_type', 'agent_id']], on='entity_id', how='left')

# One-hot encode 'country'
data_df = pd.get_dummies(data_df, columns=['country'], prefix='country')

# Fill missing values
data_df.fillna({
    'degree': 0, 'small_tx_count': 0, 'avg_amount': 0, 'amount_variance': 0,
    'amount_skewness': 0, 'high_value_ratio': 0, 'cycle_score': 0, 'clustering_coeff': 0,
    'tx_freq_variance': 0, 'tx_velocity': 0, 'burstiness': 0,
    'temporal_concentration': 0, 'kyc_risk_score': 0, 'dormancy_period': 0,
    'cross_border_amount': 0, 'in_degree': 0, 'tx_frequency': 0,
    'directionality_ratio': 0, 'round_trip_count': 0,
    'profile_type': 'UNKNOWN', 'agent_id': -1
}, inplace=True)

# Compute ml_flag_score for anomaly labeling
ml_flag_dict = {}
for node in entities['entity_id']:
    ml_flags = [data.get('ml_flag', 0) for u, v, data in G.out_edges(node, data=True)]
    ml_flag_dict[node] = np.mean(ml_flags) if ml_flags else 0
ml_flag_df = pd.DataFrame({'entity_id': list(ml_flag_dict.keys()), 'ml_flag_score': list(ml_flag_dict.values())}).fillna(0)
data_df = data_df.merge(ml_flag_df, on='entity_id', how='left').fillna({'ml_flag_score': 0})

# Debug merge
if 'ml_flag_score' not in data_df.columns:
    raise ValueError("ml_flag_score column missing after merge")
print(f"ml_flag_score stats: mean={data_df['ml_flag_score'].mean():.4f}, non-zero={data_df['ml_flag_score'].gt(0).sum()}")

# Define anomaly labels
data_df['is_anomaly'] = (data_df['ml_flag_score'] > 0).astype(int)
if data_df['is_anomaly'].sum() > 1:
    fpr, tpr, thresholds = roc_curve(data_df['is_anomaly'], data_df['ml_flag_score'])
    anomaly_threshold = thresholds[np.argmax(tpr - fpr)]
    data_df['is_anomaly'] = (data_df['ml_flag_score'] > anomaly_threshold).astype(int)
else:
    print("Warning: Too few positive samples; using 95th percentile threshold")
    anomaly_threshold = np.percentile(data_df['ml_flag_score'], 95)
    data_df['is_anomaly'] = (data_df['ml_flag_score'] > anomaly_threshold).astype(int)

# Generate adjacency matrix
adj = nx.adjacency_matrix(G, nodelist=data_df['entity_id'].tolist(), weight='amount')
adj = sp.csr_matrix(adj)

# Validate features
print("Merged data_df shape:", data_df.shape)
print("NaN counts:\n", data_df.isna().sum())
print("Data types:\n", data_df.dtypes)
print("Feature stats:\n", data_df.describe())
numeric_cols = [col for col in data_df.columns if col not in ['entity_id', 'is_anomaly', 'profile_type', 'agent_id']]
if not np.all(np.isfinite(data_df[numeric_cols].select_dtypes(include=[np.number]))):
    print("Warning: Non-finite values detected; replacing with 0")
    data_df[numeric_cols] = data_df[numeric_cols].fillna(0)

# Check for leakage with robust correlation
non_onehot_cols = [col for col in numeric_cols if not col.startswith('country_') and col != 'ml_flag_score']
if non_onehot_cols:
    correlations = data_df[non_onehot_cols + ['ml_flag_score']].corr(method='spearman')['ml_flag_score'].drop('ml_flag_score', errors='ignore').fillna(0)
    print("Feature correlations with ml_flag_score:\n", correlations)
else:
    print("Warning: No non-onehot numeric columns for correlation analysis")

# Save for debugging
data_df.to_csv('merged_features.csv', index=False)
sp.save_npz('adjacency_matrix.npz', adj)

Loaded 5000 entities and 27714 transactions
After dropping duplicates, 27709 transactions remain
Graph built with 5000 nodes and 27709 edges
ml_flag_score stats: mean=0.0425, non-zero=818
Merged data_df shape: (5000, 29)
NaN counts:
 entity_id                 0
degree                    0
small_tx_count            0
avg_amount                0
amount_variance           0
amount_skewness           0
high_value_ratio          0
cycle_score               0
clustering_coeff          0
tx_freq_variance          0
tx_velocity               0
burstiness                0
temporal_concentration    0
kyc_risk_score            0
dormancy_period           0
cross_border_amount       0
in_degree_x               0
tx_frequency              0
directionality_ratio      0
round_trip_count          0
in_degree_y               0
profile_type              0
agent_id                  0
country_ASIA              0
country_EU                0
country_HIGH_RISK         0
country_US                0
ml_flag_sc

Cell 4: Neural Networks

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, precision_recall_curve, auc
import numpy as np
import pandas as pd
import scipy.sparse as sp

# Prepare data for GNN
numeric_cols = [col for col in data_df.columns if col not in ['entity_id', 'is_anomaly', 'ml_flag_score', 'profile_type', 'agent_id']]
X = data_df[numeric_cols].values
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_tensor = torch.FloatTensor(X_scaled)
y_tensor = torch.FloatTensor(data_df['is_anomaly'].values).unsqueeze(1)

# Create adjacency matrix with proper edge handling
G = nx.DiGraph()
for _, row in transactions.iterrows():
    G.add_edge(row['sender_id'], row['receiver_id'], amount=row['amount'])
if G.number_of_edges() != len(transactions):
    print(f"Warning: Graph edges ({G.number_of_edges()}) do not match transactions ({len(transactions)})")
adj = nx.adjacency_matrix(G, nodelist=data_df['entity_id'].tolist(), weight='amount')
adj = sp.load_npz('adjacency_matrix.npz')
adj_tensor = torch.sparse_csr_tensor(
    torch.LongTensor(adj.indptr),
    torch.LongTensor(adj.indices),
    torch.FloatTensor(adj.data),
    size=adj.shape
).to_dense()

# Define Enhanced GCN (GraphSAGE-inspired)
class EnhancedGCN(nn.Module):
    def __init__(self, in_features, hidden_size1, hidden_size2, out_features=1):
        super(EnhancedGCN, self).__init__()
        self.conv1 = nn.Linear(in_features * 2, hidden_size1)  # Aggregate self + neighbors
        self.conv2 = nn.Linear(hidden_size1 * 2, hidden_size2)
        self.conv3 = nn.Linear(hidden_size2, out_features)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x, adj):
        # First layer: aggregate neighbors
        neighbor_agg = torch.mm(adj, x) / (torch.sum(adj, dim=1, keepdim=True) + 1e-7)
        x = torch.cat([x, neighbor_agg], dim=1)
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        # Second layer
        neighbor_agg = torch.mm(adj, x) / (torch.sum(adj, dim=1, keepdim=True) + 1e-7)
        x = torch.cat([x, neighbor_agg], dim=1)
        x = F.relu(self.conv2(x))
        x = self.dropout(x)
        # Output layer
        x = self.conv3(x)
        return torch.sigmoid(x)

# Initialize model
input_dim = X.shape[1]
model = EnhancedGCN(input_dim, 64, 16)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = nn.BCELoss()

# K-fold cross-validation with stratification
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
metrics = {'precision': [], 'recall': [], 'f1': [], 'auprc': [], 'auc': []}

for fold, (train_idx, test_idx) in enumerate(skf.split(X, data_df['is_anomaly'])):
    print(f"Training fold {fold + 1}/5")
    X_train, X_test = X_tensor[train_idx], X_tensor[test_idx]
    y_train, y_test = y_tensor[train_idx], y_tensor[test_idx]
    adj_train = adj_tensor[train_idx][:, train_idx]
    adj_test = adj_tensor[test_idx][:, test_idx]

    # Define institution mapping based on training data
    train_df = data_df.iloc[train_idx].copy()
    institution_cols = [col for col in train_df.columns if col.startswith('country_')]
    institution_data = train_df[institution_cols].idxmax(axis=1).map(lambda x: x.replace('country_', '')).to_dict()
    train_df['institution'] = train_df['entity_id'].map({k: v for k, v in institution_data.items() if v in ['US', 'EU', 'ASIA', 'HIGH_RISK']}).fillna('UNKNOWN')

    # Local models per institution with weighted aggregation
    local_models = {}
    inst_weights = {'US': 0.3, 'EU': 0.3, 'ASIA': 0.2, 'HIGH_RISK': 0.2, 'UNKNOWN': 0.0}
    for inst in train_df['institution'].unique():
        inst_idx = train_df[train_df['institution'] == inst].index
        if len(inst_idx) < 2:  # Skip small institutions
            continue
        relative_idx = np.searchsorted(train_idx, np.array(inst_idx))
        relative_idx = relative_idx[relative_idx < len(train_idx)]
        if len(relative_idx) < 2:
            continue
        local_X = X_train[relative_idx]
        local_y = y_train[relative_idx]
        local_adj = adj_train[relative_idx][:, relative_idx]
        local_model = EnhancedGCN(input_dim, 64, 16)
        local_opt = optim.Adam(local_model.parameters(), lr=0.001)
        for epoch in range(200):  # Increased epochs
            local_model.train()
            outputs = local_model(local_X, local_adj)
            loss = criterion(outputs, local_y)
            local_opt.zero_grad()
            loss.backward()
            local_opt.step()
        local_models[inst] = (local_model, inst_weights.get(inst, 0.0))

    # Aggregate weights with institution weighting
    for param_name in model.state_dict().keys():
        global_weight = torch.zeros_like(model.state_dict()[param_name])
        total_weight = 0.0
        for inst, (local_model, weight) in local_models.items():
            global_weight += weight * local_model.state_dict()[param_name]
            total_weight += weight
        if total_weight > 0:
            global_weight /= total_weight
        model.state_dict()[param_name].copy_(global_weight)

    # Train global model
    model.train()
    for epoch in range(300):  # Increased epochs
        outputs = model(X_train, adj_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    # Evaluate
    model.eval()
    with torch.no_grad():
        test_outputs = model(X_test, adj_test)
        preds = test_outputs.numpy().flatten()
        y_true = y_test.numpy().flatten()

    # Dynamic threshold using Youden’s J
    if y_true.sum() > 1:
        fpr, tpr, thresholds = roc_curve(y_true, preds)
        optimal_threshold = thresholds[np.argmax(tpr - fpr)]
    else:
        print(f"Warning: Too few positive samples in fold {fold + 1}; using 95th percentile")
        optimal_threshold = np.percentile(preds, 95)
    pred_labels = (preds > optimal_threshold).astype(int)

    # Normalize anomaly scores
    preds = (preds - preds.min()) / (preds.max() - preds.min() + 1e-7)

    # Metrics
    precision = precision_score(y_true, pred_labels, zero_division=0)
    recall = recall_score(y_true, pred_labels, zero_division=0)
    f1 = f1_score(y_true, pred_labels, zero_division=0)
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, preds)
    auprc = auc(recall_vals, precision_vals)
    fpr, tpr, _ = roc_curve(y_true, preds)
    roc_auc = auc(fpr, tpr)

    metrics['precision'].append(precision)
    metrics['recall'].append(recall)
    metrics['f1'].append(f1)
    metrics['auprc'].append(auprc)
    metrics['auc'].append(roc_auc)

    print(f"Fold {fold + 1} Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, AUPRC: {auprc:.4f}, AUC: {roc_auc:.4f}")
    print(f"Validation Anomaly Scores (first 10): {preds[:10]}")

# Summary
print("\nCross-validation results:")
for metric, values in metrics.items():
    print(f"{metric.capitalize()}: {np.mean(values):.4f} (±{np.std(values):.4f})")

# Save predictions
data_df['anomaly_score'] = np.zeros(len(data_df))
data_df.loc[test_idx, 'anomaly_score'] = preds
data_df['predicted_anomaly'] = (data_df['anomaly_score'] > optimal_threshold).astype(int)
data_df.to_csv('predictions.csv', index=False)

Training fold 1/5
Epoch 0, Loss: 0.6931
Epoch 20, Loss: 0.6865
Epoch 40, Loss: 0.6799
Epoch 60, Loss: 0.6735
Epoch 80, Loss: 0.6672
Epoch 100, Loss: 0.6611
Epoch 120, Loss: 0.6551
Epoch 140, Loss: 0.6493
Epoch 160, Loss: 0.6435
Epoch 180, Loss: 0.6380
Epoch 200, Loss: 0.6325
Epoch 220, Loss: 0.6272
Epoch 240, Loss: 0.6220
Epoch 260, Loss: 0.6169
Epoch 280, Loss: 0.6120
Fold 1 Precision: 0.0000, Recall: 0.0000, F1: 0.0000, AUPRC: 0.5786, AUC: 0.5011
Validation Anomaly Scores (first 10): [0.22959778 0.22959778 0.22959778 0.22959778 0.22959778 0.22959778
 0.22959778 0.22959778 0.22959778 0.22959778]
Training fold 2/5
Epoch 0, Loss: 0.6931
Epoch 20, Loss: 0.6863
Epoch 40, Loss: 0.6792
Epoch 60, Loss: 0.6722
Epoch 80, Loss: 0.6655
Epoch 100, Loss: 0.6591
Epoch 120, Loss: 0.6528
Epoch 140, Loss: 0.6468
Epoch 160, Loss: 0.6409
Epoch 180, Loss: 0.6353
Epoch 200, Loss: 0.6297
Epoch 220, Loss: 0.6244
Epoch 240, Loss: 0.6192
Epoch 260, Loss: 0.6142
Epoch 280, Loss: 0.6093
Fold 2 Precision: 0.0000