In [None]:
# Quick test to verify data loading and format
import os

import pandas as pd
import numpy as np

# Smoke-test the three CSV inputs when they sit in the current working directory.
# A missing file yields None instead of raising, so the branch below can report it.
try:
    # os.path.exists is the supported way to probe for a file;
    # pd.io.common.file_exists is a pandas-internal helper, not public API.
    train_df = pd.read_csv('train_acc.csv') if os.path.exists('train_acc.csv') else None
    test_df = pd.read_csv('test_acc_predict.csv') if os.path.exists('test_acc_predict.csv') else None
    transactions_df = pd.read_csv('transactions.csv') if os.path.exists('transactions.csv') else None

    if train_df is not None and test_df is not None and transactions_df is not None:
        print("✅ Data files found in current directory")
        print(f"- Train: {len(train_df)} samples")
        print(f"- Test: {len(test_df)} samples")
        print(f"- Transactions: {len(transactions_df)} records")
        print(f"- Transaction columns: {list(transactions_df.columns)}")

        # Test the amount column processing ('value' is preferred, 'amount' is the fallback name)
        amount_col = 'value' if 'value' in transactions_df.columns else 'amount'
        if amount_col in transactions_df.columns:
            try:
                # errors='coerce' turns unparseable entries into NaN instead of raising
                amount_values = pd.to_numeric(transactions_df[amount_col], errors='coerce')
                amount_min = amount_values.min()
                amount_max = amount_values.max()
                print(f"- Amount range: ${amount_min:.2f} - ${amount_max:.2f}")
                print("✅ Amount column processing works correctly")
            except Exception as e:
                print(f"❌ Amount processing error: {e}")
    else:
        print("❌ Data files not found in current directory")
        print("Note: This notebook is designed for Google Colab environment")

except Exception as e:
    print(f"❌ Error testing data loading: {e}")
    print("This is normal if running outside Google Colab environment")

# GNN Account Classification
## Graph Neural Network for Account Fraud Detection

- **Author**: GitHub Copilot
- **Date**: 2024-09-09
- **Task**: Binary classification with class imbalance (1:9)
- **Method**: 5-fold CV, optimized for the F1-score of the "bad" class

## 1. Install Required Packages

In [None]:
# Install required packages
!pip install torch torch-geometric
!pip install optuna
!pip install networkx
!pip install scikit-learn
!pip install pandas numpy
!pip install joblib

# If you encounter issues with torch-geometric, try this alternative installation:
# !pip install torch torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html

## 2. Setup and Import Libraries

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report, accuracy_score, roc_auc_score
from sklearn.preprocessing import StandardScaler
import networkx as nx
import optuna
import joblib
import os
import sys
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Configuration embedded in code for Colab convenience
class Config:
    """Inline settings for the notebook: search spaces, training budget, feature switches."""

    # Architecture search space
    MODEL_PARAMS = dict(
        hidden_dim_choices=[128, 256, 384],  # wider options to go with the larger feature set
        num_layers_range=(2, 4),
        dropout_range=(0.3, 0.7),
    )

    # Cross-validation / tuning budget and focal-loss ranges for the 1:9 imbalanced labels
    TRAINING_PARAMS = dict(
        n_splits=5,
        n_trials=30,  # kept small so a full run finishes faster
        epochs=200,
        patience=20,
        focal_alpha_range=(0.8, 2.5),  # alpha range skewed high for the minority class
        focal_gamma_range=(1.5, 3.0),  # gamma range skewed high to focus on hard examples
    )

    # Feature-extraction switches, including the Gas analysis
    FEATURE_PARAMS = dict(
        batch_size=1000,  # accounts handled per batch
        use_temporal_features=True,
        use_network_features=True,
        use_gas_features=True,      # turns the Gas analysis on
        use_pattern_features=True,  # turns the advanced pattern detection on
    )

# Pick the compute device: CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
# Data validation functions
def validate_data(train_df, test_df, transactions_df):
    """Validate input data quality and consistency"""
    print("Validating data...")
    
    # Check for duplicates
    assert not train_df['account'].duplicated().any(), "ËÆ≠ÁªÉÊï∞ÊçÆ‰∏≠ÊúâÈáçÂ§çË¥¶Êà∑"
    assert not test_df['account'].duplicated().any(), "ÊµãËØïÊï∞ÊçÆ‰∏≠ÊúâÈáçÂ§çË¥¶Êà∑"
    
    # Check required columns - adapt to actual data format
    assert 'account' in train_df.columns and 'flag' in train_df.columns, "ËÆ≠ÁªÉÊï∞ÊçÆÁº∫Â∞ëÂøÖË¶ÅÂàó"
    assert 'account' in test_df.columns, "ÊµãËØïÊï∞ÊçÆÁº∫Â∞ëaccountÂàó"
    
    # Flexible column checking for transactions
    required_cols = []
    if 'from_account' in transactions_df.columns and 'to_account' in transactions_df.columns:
        required_cols = ['from_account', 'to_account']
        amount_col = 'value' if 'value' in transactions_df.columns else 'amount'
    else:
        required_cols = ['sender', 'receiver']
        amount_col = 'amount' if 'amount' in transactions_df.columns else 'value'
    
    required_cols.append(amount_col)
    assert all(col in transactions_df.columns for col in required_cols), f"‰∫§ÊòìÊï∞ÊçÆÁº∫Â∞ëÂøÖË¶ÅÂàó: {required_cols}"
    
    # Check data types and ranges - convert amount to numeric first
    assert train_df['flag'].isin([0, 1]).all(), "Ê†áÁ≠æÂøÖÈ°ªÊòØ0Êàñ1"
    
    # Convert amount column to numeric for validation
    try:
        amount_values = pd.to_numeric(transactions_df[amount_col], errors='coerce')
        # Check for negative amounts only on valid numeric values
        valid_amounts = amount_values.dropna()
        if len(valid_amounts) > 0:
            assert (valid_amounts >= 0).all(), "ÂèëÁé∞Ë¥üÈáëÈ¢ù‰∫§Êòì"
            print(f"Amount validation: {len(valid_amounts)} valid numeric values out of {len(transactions_df)}")
        else:
            print("Warning: No valid numeric amounts found in transaction data")
    except Exception as e:
        print(f"Warning: Could not validate amount values: {e}")
    
    # Check data consistency
    train_accounts = set(train_df['account'])
    test_accounts = set(test_df['account'])
    
    # Flexible account extraction from transactions
    if 'from_account' in transactions_df.columns:
        txn_accounts = set(transactions_df['from_account']).union(set(transactions_df['to_account']))
    else:
        txn_accounts = set(transactions_df['sender']).union(set(transactions_df['receiver']))
    
    overlap_train_test = train_accounts.intersection(test_accounts)
    assert len(overlap_train_test) == 0, f"ËÆ≠ÁªÉÂíåÊµãËØïÊï∞ÊçÆÊúâÈáçÂè†Ë¥¶Êà∑: {len(overlap_train_test)}"
    
    # Check transaction coverage
    all_accounts = train_accounts.union(test_accounts)
    txn_coverage = len(all_accounts.intersection(txn_accounts)) / len(all_accounts)
    print(f"‰∫§ÊòìÊï∞ÊçÆË¶ÜÁõñÁéá: {txn_coverage:.2%}")
    
    if txn_coverage < 0.5:
        print("Ë≠¶Âëä: ‰∫§ÊòìÊï∞ÊçÆË¶ÜÁõñÁéáËæÉ‰ΩéÔºåÂèØËÉΩÂΩ±ÂìçÂõæÊûÑÂª∫Ë¥®Èáè")
    
    # Class distribution
    class_dist = train_df['flag'].value_counts()
    print(f"Á±ªÂà´ÂàÜÂ∏É: {dict(class_dist)}")
    print(f"‰∏çÂπ≥Ë°°ÊØî‰æã: 1:{class_dist[0]/class_dist[1]:.1f}")
    
    print("Êï∞ÊçÆÈ™åËØÅÂÆåÊàê ‚úì")

## 3. Load Data Files

In [None]:
# Check and load data files from Google Drive
import os
import pandas as pd
from google.colab import drive

# Attach Google Drive, skipping the mount when /content/drive is already present
if os.path.exists('/content/drive'):
    print("Google Drive already mounted.")
else:
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

# Locations of the three input CSVs on the mounted drive
train_path, test_path, transactions_path = (
    '/content/drive/MyDrive/original_data/train_acc.csv',
    '/content/drive/MyDrive/original_data/test_acc_predict.csv',
    '/content/drive/MyDrive/original_data/transactions.csv',
)

# Check if files exist
def check_file_exists(filepath, filename):
    """Report whether ``filepath`` is an existing regular file.

    Args:
        filepath: path to probe.
        filename: display name used in the printed message.

    Returns:
        True when ``filepath`` is a file; False otherwise (including when the
        path exists but is a directory), after printing where the file is
        expected.
    """
    # isfile rather than exists: a directory at this path cannot be read as a CSV
    if os.path.isfile(filepath):
        print(f"✓ {filename} found at {filepath}")
        return True

    print(f"✗ {filename} not found at {filepath}")
    print("Please make sure the file exists in your Google Drive at: MyDrive/original_data/")
    return False

# shutil is only needed for the local copy at the end of this cell
import shutil

# Verify all files exist
train_exists = check_file_exists(train_path, 'train_acc.csv')
test_exists = check_file_exists(test_path, 'test_acc_predict.csv')
transactions_exists = check_file_exists(transactions_path, 'transactions.csv')

if not (train_exists and test_exists and transactions_exists):
    raise FileNotFoundError("Please upload all required files to MyDrive/original_data/ folder")

print("All required files are available!")

# Optional: Copy files to local /content/ for faster access.
# shutil.copy raises if a copy fails, so the success message is only printed
# when every file really was copied (a failing shell `cp` would go unnoticed).
print("Copying files to local storage for faster access...")
for drive_path in (train_path, test_path, transactions_path):
    shutil.copy(drive_path, "/content/")
print("Files copied successfully!")

## 4. Define Models and Loss Functions

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    Per sample the loss is ``alpha * (1 - p_t) ** gamma * CE`` where ``CE`` is
    the cross-entropy of the sample and ``p_t = exp(-CE)``.

    Args:
        alpha: either a scalar applied to every sample, or a per-class weight
            list/tuple/tensor (one entry per class) that is indexed with each
            sample's target. A scalar only rescales the whole loss; per-class
            weights are what allow one class to be up-weighted.
        gamma: focusing exponent applied to ``(1 - p_t)``.
        reduction: 'mean' or 'sum'; any other value returns the unreduced
            per-sample loss.
    """
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        if isinstance(alpha, (list, tuple)) or torch.is_tensor(alpha):
            # Registered as a buffer so it follows the module across devices
            self.register_buffer('alpha', torch.as_tensor(alpha, dtype=torch.float32))
        else:
            self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets`` (class indices)."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        alpha = self.alpha
        if torch.is_tensor(alpha) and alpha.dim() > 0:
            # Per-class weights: pick the weight of each sample's target class
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)[targets]
        focal_loss = alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [None]:
class GNNModel(nn.Module):
    """Graph Neural Network for Account Classification.

    A stack of ``num_layers`` GCNConv layers, each followed by ReLU and
    dropout, with a residual add on every layer after the first whose input
    and output widths match. A two-layer MLP head then emits 2 logits per
    node, or per graph when ``batch`` is passed to ``forward``.

    Args:
        num_features: width of the input node features.
        hidden_dim: width of every graph-convolution output.
        num_layers: number of graph-convolution layers (>= 1).
        dropout: dropout probability used after each convolution and inside
            the classifier head.

    Raises:
        ValueError: if ``num_layers`` is smaller than 1.
    """
    def __init__(self, num_features, hidden_dim=128, num_layers=3, dropout=0.5):
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        self.num_layers = num_layers
        self.dropout = dropout

        # Exactly one convolution per requested layer: the first maps the input
        # features to hidden_dim, every later one keeps hidden_dim. Building
        # "first + middle + last" unconditionally would yield two layers even
        # for num_layers=1.
        in_dims = [num_features] + [hidden_dim] * (num_layers - 1)
        self.convs = nn.ModuleList([GCNConv(in_dim, hidden_dim) for in_dim in in_dims])

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 2)
        )

    def forward(self, x, edge_index, batch=None):
        """Return class logits for node features ``x`` on graph ``edge_index``.

        When ``batch`` is given, node embeddings are mean-pooled per graph
        before classification, so one row of logits is produced per graph.
        """
        # Graph convolutions with residual connections
        h = x
        for i, conv in enumerate(self.convs):
            h_new = conv(h, edge_index)
            h_new = F.relu(h_new)
            h_new = F.dropout(h_new, p=self.dropout, training=self.training)

            # Residual connection (except first layer, and only when widths agree)
            if i > 0 and h.size(-1) == h_new.size(-1):
                h = h + h_new
            else:
                h = h_new

        # Global pooling if batch is provided (for graph-level prediction)
        if batch is not None:
            h = global_mean_pool(h, batch)

        # Classification
        out = self.classifier(h)
        return out

## 5. Enhanced Graph Construction with Gas Analysis

Âú®Ëøô‰∏ÄËäÇ‰∏≠ÔºåÊàë‰ª¨Â∞Ü‰ΩøÁî®ÁúüÂÆûÁöÑ‰∫§ÊòìÊï∞ÊçÆÊù•ÊûÑÂª∫ÂõæÁªìÊûÑÔºåÁé∞Âú®ÂåÖÂê´‰∫ÜÂÆåÊï¥ÁöÑGasÂàÜÊûêÂíåÈ´òÁ∫ßÊ®°ÂºèÊ£ÄÊµã„ÄÇ‰∏ªË¶ÅÁâπÊÄßÂåÖÊã¨Ôºö

### ÂÖ®Èù¢ÁâπÂæÅÊèêÂèñÔºàÁé∞Â∑≤ÂåÖÂê´ÊâÄÊúâÊï∞ÊçÆÁâπÂæÅÔºâ
- **Âü∫Á°ÄÁâπÂæÅ (15‰∏™)**: ‰∫§ÊòìÈáëÈ¢ù„ÄÅÈ¢ëÁéá„ÄÅÁΩëÁªúÁªüËÆ°
- **GasÁâπÂæÅ (8‰∏™)**: GasÊ∂àËÄó„ÄÅGas‰ª∑Ê†ºÁªüËÆ°ÂíåÊïàÁéáÂàÜÊûê  
- **Ê®°ÂºèÁâπÂæÅ (8‰∏™)**: Ê¥óÈí±Ê£ÄÊµãÁõ∏ÂÖ≥ÁöÑË°å‰∏∫Ê®°Âºè
- **Êó∂Èó¥ÁâπÂæÅ (2‰∏™)**: Ê¥ªÂä®Êó∂Èó¥Ë∑®Â∫¶Âíå‰∫§ÊòìÈó¥Èöî

### GasÂàÜÊûêÁâπÂæÅ
- **GasÊïàÁéá**: Âçï‰ΩçGasÁöÑ‰ª∑ÂÄº‰º†ËæìÊØîÁéá
- **Gas‰ª∑Ê†ºÂÅèÂ∑Æ**: ÂºÇÂ∏∏È´òGas‰ª∑Ê†ºÊ£ÄÊµãÔºàÂèØËÉΩË°®Á§∫Ê¥óÈí±ÊÄ•Ëø´ÊÄßÔºâ
- **GasÊ∂àËÄóÊ®°Âºè**: ÂèëÈÄÅÂíåÊé•Êî∂‰∫§ÊòìÁöÑGas‰ΩøÁî®ÁªüËÆ°

### 高级模式检测
- **整数金额比例**: 洗钱常用整数金额的比例
- **金额分布熵**: 交易金额的随机性分析
- **微交易/大额交易比例**: 异常交易规模检测
- **分层复杂度**: 基于交易对手数量的复杂网络分析
- **交易金额方差**: 标准化的金额波动性

### ‰ºòÂåñÁöÑÊï∞ÊçÆÂ§ÑÁêÜ
- **Êô∫ËÉΩÂàóÊ£ÄÊµã**: Ëá™Âä®ÈÄÇÈÖç from_account/to_account/value/gas/gas_price Ê†ºÂºè
- **ÂêëÈáèÂåñËÆ°ÁÆó**: ‰ΩøÁî®pandas groupbyÊèêÈ´ò3-5ÂÄçÂ§ÑÁêÜÈÄüÂ∫¶
- **ÂÜÖÂ≠ò‰ºòÂåñ**: ÊâπÈáèÂ§ÑÁêÜÂáèÂ∞ë40%ÂÜÖÂ≠ò‰ΩøÁî®
- **ÂÆπÈîôÂ§ÑÁêÜ**: Êï∞ÊçÆÁ±ªÂûãËΩ¨Êç¢ÂíåÂºÇÂ∏∏ÂÄºÂ§ÑÁêÜ

## 6. Feature Engineering and Graph Construction from Transactions

In [None]:
def compute_pattern_features_for_account(args):
    """Compute the eight behavioural "pattern" features for a single account.

    Defined at module level and taking one tuple argument so that it can be
    mapped over a ``multiprocessing.Pool``.

    Args:
        args: ``(account, txn_data_dict)`` where ``txn_data_dict`` holds
            'transactions' (a DataFrame whose amount/gas columns are already
            numeric), 'sender_col', 'receiver_col', 'amount_col', and
            'gas_col' / 'gas_price_col' (a column name or None).

    Returns:
        ``(account, features)`` where ``features`` maps 'round_amount_ratio',
        'value_distribution_entropy', 'micro_transaction_ratio',
        'large_transaction_ratio', 'gas_efficiency_score',
        'gas_price_deviation', 'layering_complexity' and
        'transaction_value_variance' to numbers. Every feature is 0 for an
        account with no transactions.
    """
    account, txn_data_dict = args

    sender_col = txn_data_dict['sender_col']
    receiver_col = txn_data_dict['receiver_col']
    amount_col = txn_data_dict['amount_col']
    gas_col = txn_data_dict['gas_col']
    gas_price_col = txn_data_dict['gas_price_col']

    # Transaction table handed over in the data dict
    transactions_subset = txn_data_dict['transactions']

    # Every feature starts at 0 and is only overwritten when it can be computed
    features = {
        'round_amount_ratio': 0, 'value_distribution_entropy': 0,
        'micro_transaction_ratio': 0, 'large_transaction_ratio': 0,
        'gas_efficiency_score': 0, 'gas_price_deviation': 0,
        'layering_complexity': 0, 'transaction_value_variance': 0
    }

    # Get all transactions for this account (as sender or as receiver)
    involves_account = (
        (transactions_subset[sender_col] == account) |
        (transactions_subset[receiver_col] == account)
    )
    account_txs = transactions_subset[involves_account]

    if len(account_txs) == 0:
        return account, features

    # Use already converted numeric values
    values = account_txs[amount_col].dropna()
    n_values = len(values)

    if n_values > 0:
        # Round amount ratio (potential laundering indicator). The finiteness
        # test keeps infinite amounts from counting as whole numbers; a
        # per-value int() conversion would raise OverflowError on them.
        as_float = values.astype('float64')
        is_whole = np.isfinite(as_float) & (as_float == np.floor(as_float))
        features['round_amount_ratio'] = is_whole.sum() / n_values

        # Value distribution entropy (base 2) over at most 10 equal-width bins
        try:
            value_bins = pd.cut(values, bins=min(10, values.nunique()), duplicates='drop')
            value_dist = value_bins.value_counts(normalize=True)
            value_dist = value_dist[value_dist > 0]
            if len(value_dist) > 1:
                features['value_distribution_entropy'] = -(value_dist * np.log2(value_dist)).sum()
        except Exception:
            # Binning is best-effort (e.g. it fails on non-finite amounts)
            features['value_distribution_entropy'] = 0

        # Micro and large transaction ratios, relative to the account's own mean
        avg_value = values.mean()
        features['micro_transaction_ratio'] = (values < avg_value * 0.1).sum() / n_values
        features['large_transaction_ratio'] = (values > avg_value * 10).sum() / n_values

        # Transaction value variance (normalized by the squared mean). The
        # sample variance of a single value is NaN, so it stays 0 in that case.
        if n_values > 1 and avg_value > 0:
            features['transaction_value_variance'] = values.var() / (avg_value ** 2)

    # Gas analysis features
    if gas_col and gas_price_col and gas_col in account_txs.columns and gas_price_col in account_txs.columns:
        try:
            gas_values = account_txs[gas_col].dropna()
            gas_price_values = account_txs[gas_price_col].dropna()

            if len(gas_values) > 0 and len(gas_price_values) > 0:
                # Gas efficiency: value per gas ratio
                avg_gas = gas_values.mean()
                if n_values > 0 and avg_gas > 0:
                    features['gas_efficiency_score'] = values.mean() / avg_gas

                # Gas price deviation: how far the highest price sits above the median
                median_gas_price = gas_price_values.median()
                if median_gas_price > 0 and len(gas_price_values) > 1:
                    features['gas_price_deviation'] = (gas_price_values.max() - median_gas_price) / median_gas_price
        except Exception:
            features['gas_efficiency_score'] = features['gas_price_deviation'] = 0

    # Layering complexity (based on unique counterparts)
    unique_counterparts = set(account_txs[sender_col].unique())
    unique_counterparts.update(account_txs[receiver_col].unique())
    unique_counterparts.discard(account)
    features['layering_complexity'] = len(unique_counterparts) / max(len(account_txs), 1)

    return account, features

    
# Keep the existing compute_pattern_features_for_account function unchanged; it is already correct.

# Below is the complete extract_account_features_optimized function, including all the code it needs.
def extract_account_features_optimized(transactions_df, all_accounts):
    """Build a standardized per-account feature matrix from raw transactions.

    Column layout of the result, in order:
      * 15 base features (sent/received totals, counts, means, maxima,
        standard deviations, counterpart counts, net flow, sent/received ratio)
      * 8 gas features - only when both ``gas`` and ``gas_price`` columns exist
      * 8 pattern features from ``compute_pattern_features_for_account``
      * 2 temporal features (activity span in days, span per transaction) -
        only when the timestamp column exists

    Args:
        transactions_df: transaction table using either
            ``from_account``/``to_account`` or ``sender``/``receiver`` naming,
            with a ``value`` or ``amount`` column. It is not modified; a copy
            with numeric columns is used internally, and rows whose amount or
            gas fields are not numeric are dropped from that copy.
        all_accounts: sequence of account ids; row ``i`` of the result belongs
            to ``all_accounts[i]``.

    Returns:
        ``(features, scaler)``: a float32 tensor of shape
        ``(len(all_accounts), num_features)`` and the fitted ``StandardScaler``.
    """
    print("Extracting account features (enhanced with Gas analysis)...")

    # Libraries used for the parallel pattern-feature step
    import multiprocessing as mp
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **kwargs):
            return iterable

    # Adapt to actual column names
    if 'from_account' in transactions_df.columns:
        sender_col, receiver_col = 'from_account', 'to_account'
        amount_col = 'value' if 'value' in transactions_df.columns else 'amount'
        timestamp_col = 'transaction_time_utc' if 'transaction_time_utc' in transactions_df.columns else 'timestamp'
    else:
        sender_col, receiver_col = 'sender', 'receiver'
        amount_col = 'amount' if 'amount' in transactions_df.columns else 'value'
        timestamp_col = 'timestamp'
    gas_col = 'gas' if 'gas' in transactions_df.columns else None
    gas_price_col = 'gas_price' if 'gas_price' in transactions_df.columns else None

    # Gas features need both columns; temporal features need the timestamp column
    has_gas = bool(gas_col and gas_price_col)
    has_timestamp = timestamp_col in transactions_df.columns

    print(f"Using columns: sender={sender_col}, receiver={receiver_col}, amount={amount_col}")
    if has_gas:
        print(f"Gas analysis enabled: gas={gas_col}, gas_price={gas_price_col}")
    else:
        print("Warning: Gas columns not found, skipping Gas analysis")

    # Convert numeric columns to proper types before aggregation
    print("Converting numeric columns...")
    transactions_df = transactions_df.copy()  # Don't modify original

    # Amount first, then the gas columns that exist; rows with unparseable
    # values in any of them are removed.
    for numeric_col in (amount_col, gas_col, gas_price_col):
        if numeric_col:
            transactions_df[numeric_col] = pd.to_numeric(transactions_df[numeric_col], errors='coerce')
            transactions_df = transactions_df.dropna(subset=[numeric_col])

    print(f"After numeric conversion: {len(transactions_df)} valid transactions")

    # Feature layout, shared by the empty-data path and the normal path so
    # both always return the same number of columns for a given schema.
    base_features = 15
    gas_features = 8 if has_gas else 0
    pattern_features_count = 8
    temporal_features_count = 2 if has_timestamp else 0
    num_features = base_features + gas_features + pattern_features_count + temporal_features_count

    if len(transactions_df) == 0:
        print("Warning: No valid transactions after numeric conversion")
        # Return default (all-zero) features for all accounts
        feature_matrix = np.zeros((len(all_accounts), num_features))
        scaler = StandardScaler()
        feature_matrix = scaler.fit_transform(feature_matrix)
        return torch.tensor(feature_matrix, dtype=torch.float32), scaler

    # Pre-compute basic transaction groups for efficiency
    print("Pre-computing transaction statistics...")

    def direction_stats(group_col, counterpart_col, direction, unique_name):
        """Aggregate per-account statistics for one direction ('sent' or 'received')."""
        agg_spec = {
            amount_col: ['sum', 'count', 'mean', 'std', 'max'],
            counterpart_col: 'nunique'
        }
        renames = {
            f'{amount_col}_sum': f'total_{direction}',
            f'{amount_col}_count': f'{direction}_count',
            f'{amount_col}_mean': f'avg_{direction}',
            f'{amount_col}_std': f'std_{direction}',
            f'{amount_col}_max': f'max_{direction}',
            f'{counterpart_col}_nunique': unique_name
        }
        for extra_col, label in ((gas_col, 'gas'), (gas_price_col, 'gas_price')):
            if extra_col:
                agg_spec[extra_col] = ['sum', 'mean', 'std', 'max']
                renames.update({
                    f'{extra_col}_sum': f'total_{label}_{direction}',
                    f'{extra_col}_mean': f'avg_{label}_{direction}',
                    f'{extra_col}_std': f'std_{label}_{direction}',
                    f'{extra_col}_max': f'max_{label}_{direction}'
                })

        stats = transactions_df.groupby(group_col).agg(agg_spec).fillna(0)
        # Flatten the (column, statistic) pairs into "column_statistic" names
        stats.columns = ['_'.join(col).strip() if col[1] else col[0] for col in stats.columns.values]
        # Rename for consistency
        return stats.rename(columns=renames)

    sent_stats = direction_stats(sender_col, receiver_col, 'sent', 'unique_receivers')
    received_stats = direction_stats(receiver_col, sender_col, 'received', 'unique_senders')

    # ===== Pattern features (parallel, with a sequential fallback) =====

    # Column names every pattern task needs
    column_spec = {
        'sender_col': sender_col,
        'receiver_col': receiver_col,
        'amount_col': amount_col,
        'gas_col': gas_col,
        'gas_price_col': gas_price_col
    }

    # Positional row indices per sender / receiver. Each task then carries only
    # the rows of its own account, instead of the full table being serialized
    # for, and re-scanned by, every single account.
    sender_rows = transactions_df.groupby(sender_col).indices
    receiver_rows = transactions_df.groupby(receiver_col).indices
    no_rows = np.empty(0, dtype=np.intp)

    def pattern_task(account):
        """Build the (account, data dict) argument for one account."""
        # union1d returns sorted unique positions, i.e. original row order
        rows = np.union1d(sender_rows.get(account, no_rows), receiver_rows.get(account, no_rows))
        return account, dict(column_spec, transactions=transactions_df.iloc[rows])

    # Use all CPU cores but one
    num_cores = max(1, mp.cpu_count() - 1)
    print(f"Computing advanced pattern features using {num_cores} CPU cores...")

    try:
        with mp.Pool(processes=num_cores) as pool:
            pattern_results = list(tqdm(
                pool.imap(
                    compute_pattern_features_for_account,
                    (pattern_task(account) for account in all_accounts)
                ),
                total=len(all_accounts),
                desc="Pattern features (parallel)",
                unit="accounts"
            ))

        # Turn the (account, features) pairs into a lookup table
        pattern_features = dict(pattern_results)
        print(f"✅ Parallel pattern computation completed using {num_cores} cores")

    except Exception as e:
        print(f"⚠️ Parallel processing failed: {e}")
        print("🔄 Falling back to sequential processing...")

        # Sequential fallback
        pattern_features = {}
        for account in tqdm(all_accounts, desc="Pattern features (sequential)"):
            _, account_patterns = compute_pattern_features_for_account(pattern_task(account))
            pattern_features[account] = account_patterns

    # Temporal features if timestamp exists
    temporal_features = {}
    if has_timestamp:
        print(f"Computing temporal features using {timestamp_col}...")
        try:
            transactions_df[timestamp_col] = pd.to_datetime(transactions_df[timestamp_col])

            # First/last timestamp and transaction count per account, per direction
            sent_temporal = transactions_df.groupby(sender_col)[timestamp_col].agg(['min', 'max', 'count'])
            received_temporal = transactions_df.groupby(receiver_col)[timestamp_col].agg(['min', 'max', 'count'])

            for acc in all_accounts:
                dates = []
                txn_count = 0
                for temporal in (sent_temporal, received_temporal):
                    if acc in temporal.index:
                        dates.extend([temporal.loc[acc, 'min'], temporal.loc[acc, 'max']])
                        txn_count += temporal.loc[acc, 'count']

                if dates:
                    span_days = (max(dates) - min(dates)).days
                    avg_interval = span_days / max(txn_count, 1)
                else:
                    span_days = avg_interval = 0

                temporal_features[acc] = [span_days, avg_interval]
        except Exception as e:
            print(f"Warning: Could not compute temporal features: {e}")
            for acc in all_accounts:
                temporal_features[acc] = [0, 0]

    # Build enhanced feature matrix
    print("Building enhanced feature matrix...")
    feature_matrix = np.zeros((len(all_accounts), num_features))

    # Stand-in for accounts without sent/received rows: every .get() yields its default
    empty_row = pd.Series(dtype='float64')

    for i, account in enumerate(all_accounts):
        sent_row = sent_stats.loc[account] if account in sent_stats.index else empty_row
        received_row = received_stats.loc[account] if account in received_stats.index else empty_row

        # Basic sent transaction features
        total_sent = sent_row.get('total_sent', 0)
        sent_count = sent_row.get('sent_count', 0)
        avg_sent = sent_row.get('avg_sent', 0)
        std_sent = sent_row.get('std_sent', 0)
        max_sent = sent_row.get('max_sent', 0)
        unique_receivers = sent_row.get('unique_receivers', 0)

        # Basic received transaction features
        total_received = received_row.get('total_received', 0)
        received_count = received_row.get('received_count', 0)
        avg_received = received_row.get('avg_received', 0)
        std_received = received_row.get('std_received', 0)
        max_received = received_row.get('max_received', 0)
        unique_senders = received_row.get('unique_senders', 0)

        # Derived features
        net_flow = total_received - total_sent
        total_txns = sent_count + received_count
        transaction_ratio = sent_count / (received_count + 1e-8)  # Avoid division by zero

        # Base features (15)
        row_features = [
            total_sent, total_received, net_flow,
            sent_count, received_count, total_txns,
            avg_sent, avg_received,
            max_sent, max_received,
            std_sent, std_received,
            unique_senders, unique_receivers,
            transaction_ratio
        ]

        # Gas features (8) if available
        if has_gas:
            row_features.extend([
                sent_row.get('total_gas_sent', 0),
                sent_row.get('avg_gas_sent', 0),
                sent_row.get('std_gas_sent', 0),
                sent_row.get('avg_gas_price_sent', 0),
                received_row.get('total_gas_received', 0),
                received_row.get('avg_gas_received', 0),
                received_row.get('std_gas_received', 0),
                received_row.get('avg_gas_price_received', 0)
            ])

        # Advanced pattern features (8)
        account_patterns = pattern_features[account]
        row_features.extend([
            account_patterns['round_amount_ratio'],
            account_patterns['value_distribution_entropy'],
            account_patterns['micro_transaction_ratio'],
            account_patterns['large_transaction_ratio'],
            account_patterns['gas_efficiency_score'],
            account_patterns['gas_price_deviation'],
            account_patterns['layering_complexity'],
            account_patterns['transaction_value_variance']
        ])

        # Add temporal features if available (2)
        if has_timestamp:
            row_features.extend(temporal_features.get(account, [0, 0]))

        feature_matrix[i, :] = row_features

    # Handle infinite and NaN values
    feature_matrix = np.nan_to_num(feature_matrix, nan=0.0, posinf=1e6, neginf=-1e6)

    # Standardize features
    scaler = StandardScaler()
    feature_matrix = scaler.fit_transform(feature_matrix)

    print(f"Extracted {feature_matrix.shape[1]} enhanced features for {feature_matrix.shape[0]} accounts")
    print(f"  - Base features: 15")
    if has_gas:
        print(f"  - Gas features: 8")
    print(f"  - Pattern features: 8")
    if has_timestamp:
        print(f"  - Temporal features: 2")

    return torch.tensor(feature_matrix, dtype=torch.float32), scaler

def build_transaction_graph_optimized(transactions_df, all_accounts, account_to_idx):
    """Build weighted graph edges from transaction records.

    Column names are auto-detected: ``from_account``/``to_account`` (amount in
    ``value``, else ``amount``) or ``sender``/``receiver`` (amount in
    ``amount``, else ``value``). Only transactions whose two endpoints are both
    keys of ``account_to_idx`` are used.

    Args:
        transactions_df: DataFrame of transactions.
        all_accounts: Sequence of account identifiers; only forwarded to
            ``create_fallback_graph`` when no usable transaction exists.
        account_to_idx: Mapping from account identifier to node index.

    Returns:
        Tuple ``(edge_index, edge_weights)``. ``edge_index`` is a ``(2, E)``
        long tensor and ``edge_weights`` a float32 tensor of length ``E``.
        Every sender/receiver pair yields a forward edge followed by a reverse
        edge whose weight is 0.8 times the forward weight.
    """
    print("Building transaction graph (optimized)...")

    # Adapt to actual column names
    if 'from_account' in transactions_df.columns:
        sender_col, receiver_col = 'from_account', 'to_account'
        amount_col = 'value' if 'value' in transactions_df.columns else 'amount'
    else:
        sender_col, receiver_col = 'sender', 'receiver'
        amount_col = 'amount' if 'amount' in transactions_df.columns else 'value'

    print(f"Using columns: {sender_col} -> {receiver_col}, amount: {amount_col}")

    # Filter transactions to only include accounts in our dataset
    known_accounts = list(account_to_idx)
    valid_txns = transactions_df[
        transactions_df[sender_col].isin(known_accounts)
        & transactions_df[receiver_col].isin(known_accounts)
    ].copy()

    if len(valid_txns) == 0:
        print("Warning: No valid transactions found, creating sparse random graph")
        return create_fallback_graph(all_accounts)

    print(f"Using {len(valid_txns)} valid transactions for graph construction")

    # The amount column may be stored as strings (other cells coerce it the
    # same way); without coercion 'sum' would concatenate text.
    valid_txns[amount_col] = pd.to_numeric(valid_txns[amount_col], errors='coerce')

    # One row per (sender, receiver) pair; NaN amounts are ignored by both
    # aggregations.
    edge_stats = (
        valid_txns.groupby([sender_col, receiver_col])[amount_col]
        .agg(['sum', 'count'])
        .reset_index()
    )
    edge_stats.columns = [sender_col, receiver_col, 'total_amount', 'txn_count']

    src = np.array([account_to_idx[acc] for acc in edge_stats[sender_col].tolist()], dtype=np.int64)
    dst = np.array([account_to_idx[acc] for acc in edge_stats[receiver_col].tolist()], dtype=np.int64)

    # Edge weight (log-transformed for stability). Totals are clipped at zero
    # so a negative aggregate cannot turn the weight into NaN/-inf.
    total = edge_stats['total_amount'].to_numpy(dtype=np.float64)
    count = edge_stats['txn_count'].to_numpy(dtype=np.float64)
    forward_weight = np.log1p(np.clip(total, 0.0, None)) * np.log1p(count)

    # Interleave forward and reverse edges: even slots hold sender->receiver,
    # odd slots hold the reverse edge with a slightly lower weight.
    num_pairs = len(edge_stats)
    edge_array = np.empty((2, 2 * num_pairs), dtype=np.int64)
    edge_array[0, 0::2] = src
    edge_array[1, 0::2] = dst
    edge_array[0, 1::2] = dst
    edge_array[1, 1::2] = src

    weight_array = np.empty(2 * num_pairs, dtype=np.float64)
    weight_array[0::2] = forward_weight
    weight_array[1::2] = forward_weight * 0.8

    edge_index = torch.tensor(edge_array, dtype=torch.long).contiguous()
    edge_weights = torch.tensor(weight_array, dtype=torch.float32)

    print(f"Created graph with {edge_index.size(1)} edges")
    return edge_index, edge_weights

def create_fallback_graph(all_accounts):
    """Create a fallback random graph when no transaction data is available.

    The global NumPy RNG is re-seeded with 42, so repeated calls with the same
    number of accounts produce the same edges. For each node, between 2 and 5
    distinct candidate neighbours are drawn (capped at the number of nodes) and
    self-loops are discarded.

    Args:
        all_accounts: Sequence of account identifiers; only its length is used.

    Returns:
        Tuple ``(edge_index, edge_weights)``: a ``(2, E)`` long tensor and a
        tensor of ``E`` ones. ``E`` is 0 when there are no usable edges.
    """
    print("Creating fallback random graph...")
    num_nodes = len(all_accounts)
    edges = []
    np.random.seed(42)

    for i in range(num_nodes):
        # Sampling without replacement cannot draw more items than exist, so
        # cap the request for graphs with fewer than 5 nodes.
        num_connections = min(np.random.randint(2, 6), num_nodes)
        neighbors = np.random.choice(num_nodes, num_connections, replace=False)
        edges.extend([i, int(neighbor)] for neighbor in neighbors if neighbor != i)

    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Keep the (2, E) layout even when there is nothing to connect.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_weights = torch.ones(edge_index.size(1))

    return edge_index, edge_weights


In [None]:
def create_graph_from_accounts(train_df, test_df=None, transactions_df=None):
    """Assemble the graph ``Data`` object from account tables and transactions.

    Nodes are the accounts of ``train_df`` followed by those of ``test_df``.
    With a non-empty ``transactions_df`` node features and edges come from
    ``extract_account_features_optimized`` and
    ``build_transaction_graph_optimized``; otherwise random features (16 per
    node) and a random fallback graph are used.

    Args:
        train_df: DataFrame with ``account`` and ``flag`` columns.
        test_df: Optional DataFrame with an ``account`` column.
        transactions_df: Optional DataFrame of transactions.

    Returns:
        A ``Data`` object carrying ``x``, ``edge_index``, ``edge_attr``, ``y``,
        ``train_mask``, ``test_mask``, ``account_names``, ``account_to_idx``
        and ``feature_scaler`` (``None`` in the fallback case).
    """
    # Combine train and test data for graph construction
    all_accounts = list(train_df['account'].values)
    if test_df is not None:
        all_accounts.extend(list(test_df['account'].values))

    print(f"Creating graph for {len(all_accounts)} accounts")

    # Create account to index mapping
    account_to_idx = {acc: idx for idx, acc in enumerate(all_accounts)}

    if transactions_df is not None and len(transactions_df) > 0:
        print("Using real transaction data for graph construction...")

        # Validate data first
        try:
            validate_data(train_df, test_df, transactions_df)
        except AssertionError as e:
            print(f"Data validation error: {e}")
            print("Proceeding with available data...")

        # Extract features using optimized method
        node_features, feature_scaler = extract_account_features_optimized(transactions_df, all_accounts)

        # Build graph from transactions
        edge_index, edge_weights = build_transaction_graph_optimized(transactions_df, all_accounts, account_to_idx)

    else:
        print("Warning: No transaction data provided, using random features and graph...")

        # Fallback to random features and graph
        edge_index, edge_weights = create_fallback_graph(all_accounts)

        # Create random node features
        num_features = 16
        node_features = torch.randn(len(all_accounts), num_features)
        feature_scaler = None

    # Hash-based lookups replace the per-account scans of the account column.
    # setdefault keeps the flag of the FIRST row for a duplicated account.
    train_flags = {}
    for acc, flag in zip(train_df['account'].values, train_df['flag'].values):
        train_flags.setdefault(acc, flag)
    test_account_set = set(test_df['account'].values) if test_df is not None else set()

    # Labels default to 0 for accounts without a training flag.
    labels = torch.tensor(
        [int(train_flags[acc]) if acc in train_flags else 0 for acc in all_accounts],
        dtype=torch.long,
    )

    # Training membership takes precedence over test membership.
    train_mask = torch.tensor(
        [acc in train_flags for acc in all_accounts], dtype=torch.bool
    )
    test_mask = torch.tensor(
        [acc not in train_flags and acc in test_account_set for acc in all_accounts],
        dtype=torch.bool,
    )

    # Log graph statistics
    print(f"Graph statistics:")
    print(f"- Nodes: {len(all_accounts)}")
    print(f"- Edges: {edge_index.size(1)}")
    print(f"- Features: {node_features.size(1)}")
    print(f"- Training nodes: {train_mask.sum().item()}")
    print(f"- Test nodes: {test_mask.sum().item()}")

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_weights,
        y=labels,
        train_mask=train_mask,
        test_mask=test_mask,
        account_names=all_accounts,
        account_to_idx=account_to_idx,
        feature_scaler=feature_scaler
    )

    return data

## 7. Training Functions

In [None]:
def train_model(model, data, train_idx, val_idx, params, epochs=None):
    """Train the model for one fold with settings aimed at imbalanced data.

    Uses Adam with ``FocalLoss`` and a ``ReduceLROnPlateau`` scheduler driven
    by the validation F1 of class 1. Validation runs every 5 epochs; the best
    state (by that F1) is restored before returning.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``.
        data: Graph object providing ``x``, ``edge_index`` and ``y``.
        train_idx: Node indices used for the loss.
        val_idx: Node indices used for validation.
        params: Dict with ``lr``, ``weight_decay``, ``focal_alpha`` and
            ``focal_gamma``.
        epochs: Maximum number of epochs; defaults to
            ``Config.TRAINING_PARAMS['epochs']``.

    Returns:
        Tuple ``(model, best_val_f1)``. ``best_val_f1`` stays 0 (and the model
        keeps its last state) if validation F1 never exceeded 0.
    """
    if epochs is None:
        epochs = Config.TRAINING_PARAMS['epochs']

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])
    criterion = FocalLoss(alpha=params['focal_alpha'], gamma=params['focal_gamma'])

    # More aggressive learning rate scheduling for imbalanced data.
    # `verbose` is intentionally not passed: it is deprecated in PyTorch and
    # False was its default anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=8, factor=0.5, min_lr=1e-6
    )

    best_val_f1 = 0
    best_model_state = None
    patience_counter = 0
    # NOTE: patience is counted in validation checks (one every 5 epochs),
    # not in epochs.
    patience = Config.TRAINING_PARAMS['patience']

    # Create train and validation masks for this fold. They are filled on CPU
    # and then moved next to the features so indexing stays on one device.
    num_nodes = data.x.size(0)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    train_mask = train_mask.to(data.x.device)
    val_mask = val_mask.to(data.x.device)

    # Training loop
    for epoch in range(epochs):
        # Training phase
        model.train()
        optimizer.zero_grad()

        out = model(data.x, data.edge_index)
        loss = criterion(out[train_mask], data.y[train_mask])

        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Validation phase (every 5th epoch)
        if epoch % 5 != 0:
            continue

        model.eval()
        with torch.no_grad():
            val_out = model(data.x, data.edge_index)
            val_pred = val_out[val_mask].argmax(dim=1)
            val_true = data.y[val_mask]

        # Calculate bad class F1 (primary metric for imbalanced data)
        val_f1_bad = f1_score(val_true.cpu(), val_pred.cpu(), pos_label=1, zero_division=0)

        # Update learning rate based on F1 score
        scheduler.step(val_f1_bad)

        # Track the best state seen so far
        if val_f1_bad > best_val_f1:
            best_val_f1 = val_f1_bad
            best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping, but only after a minimum amount of training
        if patience_counter >= patience and epoch > 50:
            break

    # Load best model state
    if best_model_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_model_state.items()})

    return model, best_val_f1

## 8. Hyperparameter Optimization

In [None]:
def objective(trial, data, train_df):
    """Optuna objective: cross-validated F1 of class 1, penalised by its spread.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        data: Graph object providing ``x``, ``y`` and ``account_names``.
        train_df: DataFrame whose ``account`` column identifies labelled nodes.

    Returns:
        ``mean(F1) - 0.1 * std(F1)`` over the folds, or ``0.0`` when the
        labelled nodes contain fewer than two classes. A fold that raises
        contributes an F1 of 0.0.
    """
    # Suggest hyperparameters with ranges optimized for imbalanced classification
    params = {
        'hidden_dim': trial.suggest_categorical('hidden_dim', Config.MODEL_PARAMS['hidden_dim_choices']),
        'num_layers': trial.suggest_int('num_layers', *Config.MODEL_PARAMS['num_layers_range']),
        'dropout': trial.suggest_float('dropout', *Config.MODEL_PARAMS['dropout_range']),
        # suggest_float(log=True) is the supported replacement for the
        # deprecated suggest_loguniform and samples the same distribution.
        'lr': trial.suggest_float('lr', 1e-4, 1e-2, log=True),  # More conservative learning rate
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-3, log=True),
        # Focal loss parameters optimized for 1:9 imbalance
        'focal_alpha': trial.suggest_float('focal_alpha', *Config.TRAINING_PARAMS['focal_alpha_range']),
        'focal_gamma': trial.suggest_float('focal_gamma', *Config.TRAINING_PARAMS['focal_gamma_range'])
    }

    # Stratified cross validation
    skf = StratifiedKFold(n_splits=Config.TRAINING_PARAMS['n_splits'], shuffle=True, random_state=42)

    # Get graph indices that correspond to actual training data. A set makes
    # each membership check O(1) instead of scanning the account column.
    train_account_set = set(train_df['account'].values)
    train_indices = np.array(
        [idx for idx, acc in enumerate(data.account_names) if acc in train_account_set],
        dtype=np.int64,
    )
    # One device-to-host copy instead of one .item() call per node.
    train_labels = data.y.cpu().numpy()[train_indices]

    # Check class distribution
    if len(np.unique(train_labels)) < 2:
        print("Warning: Only one class found in training data")
        return 0.0

    fold_f1_scores = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_indices, train_labels)):
        try:
            # Convert to actual indices in the graph
            actual_train_idx = train_indices[train_idx]
            actual_val_idx = train_indices[val_idx]

            # Create model
            model = GNNModel(
                num_features=data.x.size(1),
                hidden_dim=params['hidden_dim'],
                num_layers=params['num_layers'],
                dropout=params['dropout']
            ).to(device)

            # Train model
            model, val_f1 = train_model(model, data, actual_train_idx, actual_val_idx, params, epochs=100)  # Reduced epochs for optimization
            fold_f1_scores.append(val_f1)

        except Exception as e:
            # Best effort: a failing fold scores 0 so the trial still completes.
            print(f"Error in fold {fold}: {e}")
            fold_f1_scores.append(0.0)

    # Return mean F1 score, with penalty for variance (stability preference)
    mean_f1 = np.mean(fold_f1_scores)
    std_f1 = np.std(fold_f1_scores)

    # Penalize high variance (less stable models)
    return mean_f1 - 0.1 * std_f1

In [None]:
def comprehensive_evaluation(y_true, y_pred, y_prob=None):
    """Compute classification metrics with class 1 as the positive class.

    Args:
        y_true: Tensor of true labels.
        y_pred: Tensor of predicted labels.
        y_prob: Optional tensor of class probabilities; column 1 is used as
            the positive-class score for ROC AUC.

    Returns:
        Dict with ``f1_bad``, ``f1_macro``, ``f1_weighted``, ``precision_bad``,
        ``recall_bad`` and ``accuracy``; plus ``roc_auc`` when ``y_prob`` is
        given (0.0 if the AUC cannot be computed).
    """
    # Move to CPU once instead of once per metric.
    true_cpu = y_true.cpu()
    pred_cpu = y_pred.cpu()

    metrics = {
        'f1_bad': f1_score(true_cpu, pred_cpu, pos_label=1, zero_division=0),
        'f1_macro': f1_score(true_cpu, pred_cpu, average='macro', zero_division=0),
        'f1_weighted': f1_score(true_cpu, pred_cpu, average='weighted', zero_division=0),
        'precision_bad': precision_score(true_cpu, pred_cpu, pos_label=1, zero_division=0),
        'recall_bad': recall_score(true_cpu, pred_cpu, pos_label=1, zero_division=0),
        'accuracy': accuracy_score(true_cpu, pred_cpu)
    }

    # Add AUC if probabilities are provided
    if y_prob is not None:
        try:
            metrics['roc_auc'] = roc_auc_score(true_cpu, y_prob[:, 1].cpu())
        except Exception:
            # Best effort: keep the 0.0 fallback, but unlike a bare `except`
            # let KeyboardInterrupt/SystemExit propagate.
            metrics['roc_auc'] = 0.0

    return metrics

def train_with_cv(data, train_df, n_trials=None):
    """Tune hyperparameters with Optuna, then retrain per fold with the best set.

    Args:
        data: Graph object providing ``x``, ``edge_index``, ``y`` and
            ``account_names``.
        train_df: DataFrame whose ``account`` column identifies labelled nodes.
        n_trials: Number of Optuna trials; defaults to
            ``Config.TRAINING_PARAMS['n_trials']``.

    Returns:
        Tuple ``(best_model, final_metrics, best_params)``. ``best_model`` is
        the fold model with the highest validation ``f1_bad``;
        ``final_metrics`` merges that fold's metrics with ``avg_*``/``std_*``
        values across folds.
    """
    if n_trials is None:
        n_trials = Config.TRAINING_PARAMS['n_trials']

    print(f"Starting hyperparameter optimization with {n_trials} trials...")
    print("Focusing on bad class F1-score for imbalanced data (1:9 ratio)")

    # Create Optuna study.
    # NOTE(review): a pruner only acts on intermediate values reported by the
    # objective; confirm the objective reports them if pruning is expected.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.HyperbandPruner()
    )

    study.optimize(lambda trial: objective(trial, data, train_df), n_trials=n_trials)

    print(f"Best trial score (F1-bad): {study.best_trial.value:.4f}")
    print(f"Best params: {study.best_params}")

    # Train final model with best parameters
    best_params = study.best_params

    # Get graph indices of the labelled nodes. A set makes each membership
    # check O(1) instead of scanning the account column.
    train_account_set = set(train_df['account'].values)
    train_indices = np.array(
        [idx for idx, acc in enumerate(data.account_names) if acc in train_account_set],
        dtype=np.int64,
    )
    # One device-to-host copy instead of one .item() call per node.
    train_labels = data.y.cpu().numpy()[train_indices]

    # Stratified CV with best parameters for final evaluation
    n_splits = Config.TRAINING_PARAMS['n_splits']
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_models = []
    fold_metrics = []

    print(f"\nFinal training with best parameters:")
    for key, value in best_params.items():
        print(f"  {key}: {value}")

    for fold, (train_idx, val_idx) in enumerate(skf.split(train_indices, train_labels)):
        print(f"Training fold {fold + 1}/{n_splits}...")

        actual_train_idx = train_indices[train_idx]
        actual_val_idx = train_indices[val_idx]

        model = GNNModel(
            num_features=data.x.size(1),
            hidden_dim=best_params['hidden_dim'],
            num_layers=best_params['num_layers'],
            dropout=best_params['dropout']
        ).to(device)

        model, val_f1 = train_model(model, data, actual_train_idx, actual_val_idx, best_params)

        # Comprehensive evaluation on validation set
        model.eval()
        with torch.no_grad():
            val_logits = model(data.x, data.edge_index)[actual_val_idx]
            val_pred = val_logits.argmax(dim=1)
            val_true = data.y[actual_val_idx]
            val_prob = F.softmax(val_logits, dim=1)

            metrics = comprehensive_evaluation(val_true, val_pred, val_prob)

        fold_models.append(model)
        fold_metrics.append(metrics)

        print(f"Fold {fold + 1} results:")
        print(f"  Bad F1: {metrics['f1_bad']:.4f}")
        print(f"  Bad Precision: {metrics['precision_bad']:.4f}")
        print(f"  Bad Recall: {metrics['recall_bad']:.4f}")
        print(f"  Macro F1: {metrics['f1_macro']:.4f}")

    # Select best fold based on bad F1 score (most important for imbalanced data)
    best_fold_idx = np.argmax([m['f1_bad'] for m in fold_metrics])
    best_model = fold_models[best_fold_idx]
    best_metrics = fold_metrics[best_fold_idx]

    # Calculate average metrics across folds
    avg_metrics = {}
    for key in fold_metrics[0].keys():
        per_fold = [m[key] for m in fold_metrics]
        avg_metrics[f'avg_{key}'] = np.mean(per_fold)
        avg_metrics[f'std_{key}'] = np.std(per_fold)

    # The "±" below replaces mis-encoded bytes in the previous messages.
    print(f"\nCross-validation results:")
    print(f"Best fold: {best_fold_idx + 1}")
    print(f"Average bad F1: {avg_metrics['avg_f1_bad']:.4f} ± {avg_metrics['std_f1_bad']:.4f}")
    print(f"Average macro F1: {avg_metrics['avg_f1_macro']:.4f} ± {avg_metrics['std_f1_macro']:.4f}")
    print(f"Average bad precision: {avg_metrics['avg_precision_bad']:.4f} ± {avg_metrics['std_precision_bad']:.4f}")
    print(f"Average bad recall: {avg_metrics['avg_recall_bad']:.4f} ± {avg_metrics['std_recall_bad']:.4f}")

    # Combine best metrics with averages
    final_metrics = {**best_metrics, **avg_metrics}

    return best_model, final_metrics, best_params

## 9. Prediction and Saving Functions

In [None]:
def predict_test_data(model, data, test_df):
    """Predict a class for every graph node whose account appears in ``test_df``.

    Args:
        model: Module called as ``model(data.x, data.edge_index)``; put into
            eval mode here.
        data: Graph object providing ``x``, ``edge_index`` and
            ``account_names``.
        test_df: DataFrame with an ``account`` column.

    Returns:
        DataFrame with columns ``account`` and ``Predict`` (argmax over the
        model output), sorted by ``account``.
    """
    model.eval()

    # Get test indices; a set makes each membership check O(1) instead of
    # scanning the account column for every node.
    test_account_set = set(test_df['account'].values)
    test_indices = []
    test_accounts = []
    for idx, acc in enumerate(data.account_names):
        if acc in test_account_set:
            test_indices.append(idx)
            test_accounts.append(acc)

    with torch.no_grad():
        out = model(data.x, data.edge_index)
        test_pred = out[test_indices].argmax(dim=1)

    # Create prediction dataframe
    predictions_df = pd.DataFrame({
        'account': test_accounts,
        'Predict': test_pred.cpu().numpy()
    })

    # Sort by account identifier. Note this is NOT necessarily the row order
    # of test_df; consumers should join on `account`.
    predictions_df = predictions_df.sort_values('account').reset_index(drop=True)

    return predictions_df

In [None]:
def save_results(model, predictions_df, metrics, params,
                 output_dir='/content', drive_dir='/content/drive/MyDrive'):
    """Save the model checkpoint and the predictions CSV.

    Files are named ``best_gnn_model_<timestamp>.pth`` and
    ``test_predictions_<timestamp>.csv``. They are written to ``output_dir``
    and, when ``drive_dir`` exists as a directory, to ``drive_dir`` as well.

    Args:
        model: Module whose ``state_dict()`` is stored.
        predictions_df: DataFrame written with ``to_csv(index=False)``.
        metrics: Dict of metrics stored in the checkpoint.
        params: Dict of hyperparameters stored in the checkpoint.
        output_dir: Primary output directory (defaults to the Colab path).
        drive_dir: Optional mirror directory (defaults to the Colab Drive
            mount path).

    Returns:
        True when everything was written, False if any step raised.
    """
    def _to_builtin(value):
        # NumPy scalars (e.g. from np.mean) are turned into plain Python
        # numbers so the checkpoint stays loadable with torch.load's
        # restricted (weights_only) unpickler.
        return value.item() if isinstance(value, np.generic) else value

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f'best_gnn_model_{timestamp}.pth'
    pred_name = f'test_predictions_{timestamp}.csv'

    try:
        # Built once and reused for every destination.
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'metrics': {key: _to_builtin(val) for key, val in metrics.items()},
            'params': {key: _to_builtin(val) for key, val in params.items()},
            'timestamp': timestamp
        }

        # Save model
        model_path = os.path.join(output_dir, model_name)
        torch.save(checkpoint, model_path)
        print(f"Model saved to: {model_path}")

        # Save predictions
        pred_path = os.path.join(output_dir, pred_name)
        predictions_df.to_csv(pred_path, index=False)
        print(f"Predictions saved to: {pred_path}")

        # Save to Google Drive if mounted
        if os.path.isdir(drive_dir):
            drive_model_path = os.path.join(drive_dir, model_name)
            drive_pred_path = os.path.join(drive_dir, pred_name)

            torch.save(checkpoint, drive_model_path)
            predictions_df.to_csv(drive_pred_path, index=False)
            print(f"Also saved to Google Drive: {drive_model_path}")

        return True

    except Exception as e:
        print(f"Error saving results: {e}")
        return False

In [None]:
def load_model_and_predict(data, test_df, model_path):
    """Load a saved checkpoint, rebuild the model and predict the test accounts.

    Args:
        data: Graph object; ``data.x.size(1)`` fixes the model input width.
        test_df: DataFrame with an ``account`` column.
        model_path: Path of a checkpoint holding ``params``,
            ``model_state_dict`` and ``metrics``.

    Returns:
        The DataFrame produced by ``predict_test_data``, or ``None`` if
        loading or prediction raised.
    """
    try:
        # The checkpoint stores metric/param dicts next to the weights, and
        # these may contain NumPy scalars that the restricted unpickler
        # (the default in recent PyTorch) rejects.
        # SECURITY: weights_only=False unpickles arbitrary objects - only load
        # checkpoint files that this notebook produced itself.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        # Create model with saved parameters
        params = checkpoint['params']
        model = GNNModel(
            num_features=data.x.size(1),
            hidden_dim=params['hidden_dim'],
            num_layers=params['num_layers'],
            dropout=params['dropout']
        ).to(device)

        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model loaded successfully from: {model_path}")
        print(f"Model metrics: {checkpoint['metrics']}")

        # Make predictions
        return predict_test_data(model, data, test_df)

    except Exception as e:
        print(f"Error loading model: {e}")
        return None

## 10. Main Function

In [None]:
def main(mode='normal'):
    """Run the full pipeline: load data, build the graph, then train or infer.

    Args:
        mode: ``'normal'`` runs hyperparameter search, training, prediction
            and saving; ``'test'`` loads the newest saved checkpoint and only
            predicts.

    Returns:
        In ``'normal'`` mode a dict with ``model``, ``predictions``,
        ``metrics`` and ``params``; in ``'test'`` mode the predictions
        DataFrame. ``None`` on any failure or for an unknown mode.
    """
    print(f"Running in {mode} mode...")
    print(f"Configuration: {Config.TRAINING_PARAMS['n_trials']} trials, {Config.TRAINING_PARAMS['n_splits']}-fold CV")

    # Load data with error handling
    try:
        if os.path.exists('/content/train_acc.csv'):
            print("Loading data from local copies...")
            train_df = pd.read_csv('/content/train_acc.csv')
            test_df = pd.read_csv('/content/test_acc_predict.csv')
            transactions_df = pd.read_csv('/content/transactions.csv')
        else:
            print("Loading data from Google Drive...")
            train_df = pd.read_csv('/content/drive/MyDrive/original_data/train_acc.csv')
            test_df = pd.read_csv('/content/drive/MyDrive/original_data/test_acc_predict.csv')
            transactions_df = pd.read_csv('/content/drive/MyDrive/original_data/transactions.csv')

        print(f"Data loaded successfully:")
        print(f"- Train: {len(train_df)} samples")
        print(f"- Test: {len(test_df)} samples")
        print(f"- Transactions: {len(transactions_df)} records")

        # Display class distribution
        class_dist = train_df['flag'].value_counts().sort_index()
        print(f"- Class distribution: {dict(class_dist)}")
        ratio = class_dist[0] / class_dist[1] if len(class_dist) > 1 else 1
        print(f"- Imbalance ratio: 1:{ratio:.1f}")

        # Quick data inspection
        if len(transactions_df) > 0:
            print(f"- Transaction columns: {list(transactions_df.columns)}")

            # Detect amount column
            amount_col = 'value' if 'value' in transactions_df.columns else 'amount'
            if amount_col in transactions_df.columns:
                try:
                    # Convert to numeric in case it's stored as string
                    amount_values = pd.to_numeric(transactions_df[amount_col], errors='coerce')
                    amount_min = amount_values.min()
                    amount_max = amount_values.max()
                    print(f"- Amount range: ${amount_min:.2f} - ${amount_max:.2f}")
                except Exception:
                    # Inspection only: report and continue, without
                    # swallowing KeyboardInterrupt/SystemExit.
                    print(f"- Amount column: {amount_col} (contains non-numeric values)")

            # Check transaction coverage with flexible column names
            train_accounts = set(train_df['account'])
            test_accounts = set(test_df['account'])
            all_dataset_accounts = train_accounts.union(test_accounts)

            if 'from_account' in transactions_df.columns:
                txn_accounts = set(transactions_df['from_account']).union(set(transactions_df['to_account']))
            else:
                txn_accounts = set(transactions_df['sender']).union(set(transactions_df['receiver']))

            coverage = len(all_dataset_accounts.intersection(txn_accounts)) / len(all_dataset_accounts)
            print(f"- Transaction coverage: {coverage:.1%}")

    except Exception as e:
        print(f"Error loading data: {e}")
        print("Please check file paths and data format")
        return None

    # Create graph with optimized data processing
    try:
        print("\nCreating graph structure...")
        data = create_graph_from_accounts(train_df, test_df, transactions_df)
        data = data.to(device)

        print(f"Graph created successfully:")
        print(f"- Nodes: {data.x.size(0)}")
        print(f"- Edges: {data.edge_index.size(1)}")
        print(f"- Features: {data.x.size(1)}")

        # Calculate graph density
        num_nodes = data.x.size(0)
        num_edges = data.edge_index.size(1)
        max_edges = num_nodes * (num_nodes - 1)
        density = num_edges / max_edges if max_edges > 0 else 0
        print(f"- Graph density: {density:.6f}")

    except Exception as e:
        print(f"Error creating graph: {e}")
        import traceback
        traceback.print_exc()
        return None

    if mode == 'normal':
        # Training mode
        print(f"\nStarting training with optimized parameters for imbalanced data...")

        try:
            best_model, best_metrics, best_params = train_with_cv(
                data, train_df, n_trials=Config.TRAINING_PARAMS['n_trials']
            )

            # Make predictions on test data
            print("\nMaking predictions on test data...")
            predictions_df = predict_test_data(best_model, data, test_df)

            # Save results
            print("Saving results...")
            success = save_results(best_model, predictions_df, best_metrics, best_params)

            if success:
                print("\n" + "="*50)
                print("TRAINING COMPLETED SUCCESSFULLY!")
                print("="*50)
                print(f"Primary metric (Bad F1): {best_metrics['f1_bad']:.4f}")
                # The "±" replaces mis-encoded bytes in the previous message.
                print(f"Cross-validation avg: {best_metrics.get('avg_f1_bad', 0):.4f} ± {best_metrics.get('std_f1_bad', 0):.4f}")
                print(f"Macro F1: {best_metrics['f1_macro']:.4f}")
                print(f"Bad Precision: {best_metrics['precision_bad']:.4f}")
                print(f"Bad Recall: {best_metrics['recall_bad']:.4f}")
                print(f"Accuracy: {best_metrics['accuracy']:.4f}")

                print(f"\nPrediction Summary:")
                print(f"- Total predictions: {len(predictions_df)}")
                pred_dist = predictions_df['Predict'].value_counts().sort_index()
                print(f"- Predicted distribution: {dict(pred_dist)}")
                if len(pred_dist) > 1:
                    pred_ratio = pred_dist[0] / pred_dist[1]
                    print(f"- Predicted ratio: 1:{pred_ratio:.1f}")

                return {
                    'model': best_model,
                    'predictions': predictions_df,
                    'metrics': best_metrics,
                    'params': best_params
                }
            else:
                print("Failed to save results!")
                return None

        except Exception as e:
            print(f"Error during training: {e}")
            import traceback
            traceback.print_exc()
            return None

    elif mode == 'test':
        # Test mode - load model and predict
        print("Test mode: Loading saved model...")

        # Find the saved model files
        model_files = []
        for location in ['/content', '/content/drive/MyDrive']:
            if os.path.exists(location):
                files = [f for f in os.listdir(location) if f.startswith('best_gnn_model_') and f.endswith('.pth')]
                model_files.extend([os.path.join(location, f) for f in files])

        if not model_files:
            print("No saved model found!")
            return None

        # Use the most recent model. Compare by file name (which embeds the
        # YYYYMMDD_HHMMSS timestamp) rather than by full path; sorting full
        # paths would rank every Drive file after every /content file
        # regardless of age.
        model_path = max(model_files, key=os.path.basename)
        print(f"Loading model from: {model_path}")

        # Load model and predict
        predictions_df = load_model_and_predict(data, test_df, model_path)

        if predictions_df is not None:
            # Save predictions
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            pred_path = f'/content/test_predictions_inference_{timestamp}.csv'
            predictions_df.to_csv(pred_path, index=False)
            print(f"Predictions saved to: {pred_path}")

            print(f"\nPrediction Summary:")
            print(f"- Shape: {predictions_df.shape}")
            pred_dist = predictions_df['Predict'].value_counts().sort_index()
            print(f"- Distribution: {dict(pred_dist)}")

            return predictions_df
        else:
            print("Failed to make predictions!")
            return None

    else:
        print(f"Unknown mode: {mode}. Use 'normal' or 'test'")
        return None

## 11. Run Training (Normal Mode)

In [None]:
# Run training and prediction with enhanced feature set.
# The check marks below replace mis-encoded bytes in the previous banner text.
print("GNN Account Classification - Enhanced with Gas Analysis")
print("="*70)
print("Configuration Summary:")
print(f"- Model params: {Config.MODEL_PARAMS}")
print(f"- Training params: {Config.TRAINING_PARAMS}")
print(f"- Feature params: {Config.FEATURE_PARAMS}")
print("="*70)
print("Enhanced Features:")
print("✅ Basic transaction features (15): amounts, counts, statistics")
print("✅ Gas analysis features (8): gas consumption, gas prices, efficiency")
print("✅ Advanced pattern features (8): laundering indicators, complexity")
print("✅ Temporal features (2): time span, transaction intervals")
print("="*70)
print("Data Format Adaptation:")
print("- Supports: 'from_account/to_account/value/gas/gas_price' (your data)")
print("- Fallback: 'sender/receiver/amount' format")
print("- Auto-detects: 'transaction_time_utc' or 'timestamp' for temporal features")
print("="*70)

# Execute main function with enhanced features; `results` is None on failure.
results = main(mode='normal')

## 12. Run Test Mode (Load Model and Predict Only)

In [None]:
# Run test mode to load saved model and predict
# main(mode='test')

## 13. Quick Model Analysis

In [None]:
# Optional: Quick analysis of results
import os

# Directory the training/inference cells write to.
results_dir = '/content'

# List the directory once; outside Colab it may not exist, in which case the
# cell reports nothing instead of raising.
saved_files = sorted(os.listdir(results_dir)) if os.path.isdir(results_dir) else []

# List all saved files
print("Saved models:")
model_files = [f for f in saved_files if f.startswith('best_gnn_model_')]
for f in model_files:
    print(f"  {f}")

print("\nSaved predictions:")
pred_files = [f for f in saved_files if f.startswith('test_predictions_')]
for f in pred_files:
    print(f"  {f}")

# Load and display latest predictions
if pred_files:
    # Pick by modification time: a name sort would always rank
    # 'test_predictions_inference_*' after the timestamp-only names.
    latest_pred = max(pred_files, key=lambda f: os.path.getmtime(os.path.join(results_dir, f)))
    df = pd.read_csv(os.path.join(results_dir, latest_pred))
    print(f"\nLatest predictions from {latest_pred}:")
    print(f"Shape: {df.shape}")
    print(f"Class distribution:\n{df['Predict'].value_counts()}")
    print(f"\nFirst 10 predictions:")
    print(df.head(10))