# Simple Enhanced Preprocessing for IBM AML Multi-GNN

This notebook implements a simplified version of the enhanced preprocessing pipeline for the IBM AML Synthetic Dataset.

## Focus Areas:
1. **Enhanced Node Features**: comprehensive account features (20+ planned; this simplified version builds 15)
2. **Enhanced Edge Features**: transaction features with cyclic temporal encoding (15+ planned; this simplified version builds 12)
3. **Class Imbalance Handling**: SMOTE + cost-sensitive learning
4. **Memory Optimization**: Chunked processing for large datasets

## Implementation Strategy:
- Start with core features (Phase 1)
- Test with sample data first
- Gradually scale to full dataset
- Monitor memory usage and performance


In [None]:
# Simple Enhanced Preprocessing for IBM AML Multi-GNN
print("=" * 60)
print("Simple Enhanced AML Preprocessing Pipeline")
print("=" * 60)

# Import required libraries
import pandas as pd
import numpy as np
import networkx as nx
import torch
import torch_geometric
from torch_geometric.data import Data, DataLoader
from sklearn.preprocessing import StandardScaler, RobustScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
import json
import os
import gc
from datetime import datetime
import time
import pickle
import shutil
import warnings
warnings.filterwarnings('ignore')

print("✓ Libraries imported successfully")

# Progress estimation function
def estimate_processing_time(sample_size):
    """Return a rough, human-readable runtime estimate for a sample size.

    The estimate is a fixed lookup: the first tier whose inclusive upper
    bound is not exceeded by ``sample_size`` wins; anything larger than the
    last tier falls through to the longest estimate.
    """
    # (inclusive upper bound, estimate), in ascending order of bound.
    tiers = (
        (1000, "1-2 minutes"),
        (10000, "5-10 minutes"),
        (100000, "30-60 minutes"),
    )
    for upper_bound, estimate in tiers:
        if sample_size <= upper_bound:
            return estimate
    return "2-4 hours"

# Checkpoint management functions
def create_checkpoint_dir(base_path, sample_size):
    """Ensure the checkpoint folder for this run exists and return its path.

    The folder is ``<base_path>/checkpoints_<sample_size>``; an already
    existing folder is reused as-is.
    """
    folder_name = f"checkpoints_{sample_size}"
    target = os.path.join(base_path, folder_name)
    os.makedirs(target, exist_ok=True)
    return target

def save_checkpoint(checkpoint_dir, step_name, data, metadata=None):
    """Pickle ``data`` with bookkeeping to ``<checkpoint_dir>/<step_name>.pkl``.

    The stored payload is a dict with keys ``data``, ``metadata`` (an empty
    dict when none is given), ``timestamp`` (ISO string of the save time) and
    ``step``.

    The payload is written to ``<file>.tmp`` first and then renamed over the
    final file with ``os.replace``, so an interrupted save never leaves a
    truncated checkpoint under the final name. If writing fails, the
    temporary file is removed and the original exception propagates.
    """
    checkpoint_file = os.path.join(checkpoint_dir, f"{step_name}.pkl")
    checkpoint_data = {
        'data': data,
        'metadata': metadata or {},
        'timestamp': datetime.now().isoformat(),
        'step': step_name
    }

    temp_file = checkpoint_file + '.tmp'
    try:
        with open(temp_file, 'wb') as f:
            pickle.dump(checkpoint_data, f)
        # os.replace overwrites an existing checkpoint in one step, unlike
        # shutil.move, which may fall back to copy-then-delete.
        os.replace(temp_file, checkpoint_file)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise unchanged.
        try:
            os.remove(temp_file)
        except OSError:
            pass
        raise

    print(f"✓ Checkpoint saved: {step_name}")

def load_checkpoint(checkpoint_dir, step_name):
    """Read ``<checkpoint_dir>/<step_name>.pkl`` and return ``(data, metadata)``.

    Returns ``(None, None)`` when the checkpoint file does not exist.
    """
    path = os.path.join(checkpoint_dir, f"{step_name}.pkl")
    if not os.path.exists(path):
        return None, None

    # NOTE: unpickling can execute arbitrary code; only load checkpoints that
    # this pipeline wrote itself.
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    print(f"✓ Checkpoint loaded: {step_name}")
    return payload['data'], payload['metadata']

def list_available_checkpoints(checkpoint_dir):
    """Return the sorted step names of all ``*.pkl`` files in ``checkpoint_dir``.

    Only the trailing ``.pkl`` extension is stripped, so every returned name
    maps back to its file as ``<name>.pkl``. Returns an empty list when the
    directory does not exist.
    """
    if not os.path.exists(checkpoint_dir):
        return []

    suffix = '.pkl'
    return sorted(
        entry[:-len(suffix)]
        for entry in os.listdir(checkpoint_dir)
        if entry.endswith(suffix)
    )

def resume_from_checkpoint(checkpoint_dir, target_step):
    """Load ``target_step`` if it is among the saved checkpoints.

    Returns what ``load_checkpoint`` returns for that step, or
    ``(None, None)`` after printing the available names when the step has
    not been saved.
    """
    known_steps = list_available_checkpoints(checkpoint_dir)
    if target_step not in known_steps:
        print(f"⚠️  Checkpoint {target_step} not found. Available: {known_steps}")
        return None, None

    print(f"✓ Resuming from checkpoint: {target_step}")
    return load_checkpoint(checkpoint_dir, target_step)

# Chunked processing functions
def process_data_in_chunks(data, chunk_size, process_func, checkpoint_dir, step_name):
    """Apply ``process_func`` to consecutive ``chunk_size`` slices of ``data``.

    ``data`` is sliced with ``.iloc``. Each chunk's result is checkpointed
    as ``<step_name>_chunk_<index>`` together with its index range, and a
    garbage collection is forced after every chunk. Returns the per-chunk
    results in order.
    """
    n_items = len(data)
    total_chunks = (n_items + chunk_size - 1) // chunk_size  # ceiling division

    print(f"📦 Processing {n_items} items in {total_chunks} chunks of {chunk_size}")

    results = []
    for chunk_idx in range(total_chunks):
        start_idx = chunk_idx * chunk_size
        end_idx = min(start_idx + chunk_size, n_items)
        chunk_data = data.iloc[start_idx:end_idx]

        print(f"  Processing chunk {chunk_idx + 1}/{total_chunks} (items {start_idx}-{end_idx-1})")

        chunk_result = process_func(chunk_data)
        results.append(chunk_result)

        # Persist this chunk's result along with where it came from.
        chunk_meta = {
            'chunk_idx': chunk_idx,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'chunk_size': len(chunk_data)
        }
        save_checkpoint(checkpoint_dir, f"{step_name}_chunk_{chunk_idx}", chunk_result, chunk_meta)

        # Release per-chunk temporaries before the next slice.
        gc.collect()

    return results

def combine_chunk_results(chunk_results, combine_func):
    """Merge per-chunk results by handing the whole list to ``combine_func``."""
    n_chunks = len(chunk_results)
    print(f"🔗 Combining {n_chunks} chunk results...")
    merged = combine_func(chunk_results)
    return merged

def load_chunk_checkpoints(checkpoint_dir, step_name):
    """Load every ``<step_name>_chunk_<index>.pkl`` checkpoint in index order.

    Chunks are ordered by their numeric index, not by file name, so chunk 10
    comes after chunk 2. Files whose index part is not a number, or that do
    not end in ``.pkl`` (for example leftover ``.tmp`` files), are ignored.
    Returns the list of chunk payloads; chunks that fail to load are skipped.
    """
    prefix = f"{step_name}_chunk_"
    suffix = '.pkl'

    indexed_names = []
    for entry in os.listdir(checkpoint_dir):
        if not (entry.startswith(prefix) and entry.endswith(suffix)):
            continue
        index_text = entry[len(prefix):-len(suffix)]
        if index_text.isdigit():
            indexed_names.append((int(index_text), entry[:-len(suffix)]))

    chunk_results = []
    for _, checkpoint_name in sorted(indexed_names):
        chunk_data, _metadata = load_checkpoint(checkpoint_dir, checkpoint_name)
        if chunk_data is not None:
            chunk_results.append(chunk_data)

    return chunk_results

print("✓ Chunked processing functions defined")


In [None]:
# Load and Validate Data
def load_aml_data(data_path, sample_size=10000):
    """Load the transaction and account tables from ``data_path``.

    Args:
        data_path: Directory containing ``HI-Small_Trans.csv`` and,
            optionally, ``HI-Small_accounts.csv``.
        sample_size: Maximum number of transaction rows to read. This is
            passed to ``read_csv`` as ``nrows``, so it takes the first rows
            of the file rather than a random sample.

    Returns:
        ``(transactions, accounts)`` as DataFrames. When the accounts file is
        missing, placeholder accounts are built from the distinct values of
        the ``Account`` and ``Account.1`` columns, in first-seen order.

    Raises:
        FileNotFoundError: If the transaction file does not exist.
    """
    print("Loading IBM AML dataset...")

    # Load transactions
    trans_file = os.path.join(data_path, 'HI-Small_Trans.csv')
    if not os.path.exists(trans_file):
        raise FileNotFoundError(f"Transaction file not found: {trans_file}")
    transactions = pd.read_csv(trans_file, nrows=sample_size)
    print(f"✓ Loaded {len(transactions)} transactions (sampled)")

    # Load accounts
    accounts_file = os.path.join(data_path, 'HI-Small_accounts.csv')
    if os.path.exists(accounts_file):
        accounts = pd.read_csv(accounts_file)
        print(f"✓ Loaded {len(accounts)} accounts")
    else:
        print("Creating accounts from transaction data...")
        # First-seen order keeps the account list (and anything indexed by
        # it) identical between runs, which set iteration does not.
        account_ids = pd.concat(
            [transactions['Account'], transactions['Account.1']],
            ignore_index=True,
        ).unique().tolist()
        positions = range(len(account_ids))

        accounts = pd.DataFrame({
            'Account Number': account_ids,
            'Bank Name': [f"Bank_{i}" for i in positions],
            'Bank ID': [f"B{i}" for i in positions],
            'Entity ID': [f"E{i}" for i in positions],
            'Entity Name': [f"Entity_{i}" for i in positions]
        })
        print(f"✓ Created {len(accounts)} accounts from transactions")

    # Validate data
    print("\nData Validation:")
    print(f"Missing values - Transactions: {transactions.isnull().sum().sum()}")
    print(f"Missing values - Accounts: {accounts.isnull().sum().sum()}")

    if 'Is Laundering' in transactions.columns:
        class_dist = transactions['Is Laundering'].value_counts()
        print(f"Class distribution: {class_dist}")
        # A small sample may contain no positive rows at all, and the file
        # may yield no rows; neither should abort loading.
        positive_count = class_dist.get(1, 0)
        sar_rate = positive_count / len(transactions) if len(transactions) else 0.0
        print(f"SAR rate: {sar_rate:.4f}")

    return transactions, accounts

print("✓ Data loading function defined")


In [None]:
# Enhanced Node Features
def create_enhanced_node_features(transactions, accounts):
    """Build a 15-entry feature dict for every account in ``accounts``.

    Args:
        transactions: DataFrame with the columns ``Timestamp``, ``Account``
            (sender), ``Account.1`` (receiver), ``Amount Paid``,
            ``Amount Received``, ``Payment Currency`` and ``To Bank``.
        accounts: DataFrame with an ``Account Number`` column.

    Returns:
        Dict mapping each account number to a feature dict. Every feature
        dict has the same keys in the same order. Accounts that appear in no
        transaction get all-zero features.

    Features are computed over all transactions where the account is sender
    or receiver, except ``total_sent`` (sender rows only) and
    ``total_received`` (receiver rows only). ``temporal_span`` is in whole
    days and ``transaction_frequency`` is transactions per day, with a span
    of less than one day counted as one day.
    """
    print("Creating enhanced node features...")

    # The progress bar is cosmetic; fall back to plain iteration without it.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    total_accounts = len(accounts)
    print(f"Processing {total_accounts} accounts...")

    # Loop-invariant work: fetch columns and parse timestamps once instead of
    # once per account.
    sender_col = transactions['Account']
    receiver_col = transactions['Account.1']
    paid_col = transactions['Amount Paid']
    received_col = transactions['Amount Received']
    currency_col = transactions['Payment Currency']
    to_bank_col = transactions['To Bank']
    all_timestamps = pd.to_datetime(transactions['Timestamp'])
    night_hours = [22, 23, 0, 1, 2, 3, 4, 5, 6]
    weekend_days = [5, 6]

    node_features = {}
    for account_id in tqdm(accounts['Account Number'], total=total_accounts, desc="Node Features"):
        sent_mask = (sender_col == account_id).to_numpy()
        received_mask = (receiver_col == account_id).to_numpy()
        involved = sent_mask | received_mask
        transaction_count = int(involved.sum())

        if transaction_count == 0:
            # Default features for accounts with no transactions
            node_features[account_id] = {
                'transaction_count': 0, 'total_sent': 0, 'total_received': 0,
                'avg_amount': 0, 'max_amount': 0, 'min_amount': 0,
                'temporal_span': 0, 'transaction_frequency': 0,
                'currency_diversity': 0, 'bank_diversity': 0,
                'night_ratio': 0, 'weekend_ratio': 0,
                'is_crypto_bank': 0, 'is_international': 0, 'is_high_frequency': 0
            }
            continue

        # Amount features
        total_sent = paid_col[sent_mask].sum() if sent_mask.any() else 0
        total_received = received_col[received_mask].sum() if received_mask.any() else 0
        paid = paid_col[involved]

        # Temporal features
        timestamps = all_timestamps[involved]
        temporal_span = (timestamps.max() - timestamps.min()).days
        transaction_frequency = transaction_count / max(1, temporal_span)

        # Diversity measures
        currency_diversity = currency_col[involved].nunique()
        bank_diversity = to_bank_col[involved].nunique()

        # Time-based features
        night_transactions = timestamps.dt.hour.isin(night_hours).sum()
        weekend_transactions = timestamps.dt.weekday.isin(weekend_days).sum()

        # Risk indicators
        # NOTE(review): 'Crytpo' looks like a typo of 'Crypto', and the match
        # is against the account number rather than a bank name. Left
        # unchanged until confirmed against the data.
        is_crypto_bank = 'Crytpo' in str(account_id)
        is_international = currency_diversity > 1
        is_high_frequency = transaction_frequency > 1.0

        node_features[account_id] = {
            'transaction_count': transaction_count,
            'total_sent': total_sent, 'total_received': total_received,
            'avg_amount': paid.mean(), 'max_amount': paid.max(), 'min_amount': paid.min(),
            'temporal_span': temporal_span, 'transaction_frequency': transaction_frequency,
            'currency_diversity': currency_diversity, 'bank_diversity': bank_diversity,
            'night_ratio': night_transactions / transaction_count,
            'weekend_ratio': weekend_transactions / transaction_count,
            'is_crypto_bank': int(is_crypto_bank), 'is_international': int(is_international),
            'is_high_frequency': int(is_high_frequency)
        }

    return node_features

print("✓ Enhanced node feature function defined")


In [None]:
# Enhanced Edge Features
def _encode_labels(values):
    """Return integer codes: each value's rank among the sorted unique values."""
    _, codes = np.unique(np.asarray(values), return_inverse=True)
    return codes.reshape(-1)


def _cyclic_pair(values, period):
    """Return ``(sin, cos)`` of ``values`` placed on a circle of ``period``."""
    angle = 2 * np.pi * values / period
    return np.sin(angle), np.cos(angle)


def create_enhanced_edge_features(transactions):
    """Build a 12-column feature matrix and a label vector, one row per transaction.

    Args:
        transactions: DataFrame with the columns ``Timestamp``,
            ``Amount Paid``, ``Amount Received``, ``Payment Currency``,
            ``Payment Format``, ``From Bank`` and ``Is Laundering``.

    Returns:
        ``(features, labels)``. ``features`` is a float array of shape
        ``(len(transactions), 12)`` with columns, in order: sin/cos of hour
        (period 24), sin/cos of day of week (period 7), sin/cos of month
        (period 12), ``log1p`` of amount paid, ``log1p`` of amount received,
        amount paid divided by ``max(amount received, 1)``, then integer
        codes for payment currency, payment format and sending bank.
        ``labels`` is the ``Is Laundering`` column as an array.

    The categorical codes are the rank of each value among the sorted
    distinct values of *this* frame (the same codes a label encoder fitted
    on this frame would give). Codes from two different frames are therefore
    not comparable, so call this on the whole table rather than on chunks.
    """
    print("Creating enhanced edge features...")

    total_transactions = len(transactions)
    print(f"Processing {total_transactions} transactions...")

    # Categorical features: encode each column in a single pass.
    print("Preparing encoders...")
    currency_encoded = _encode_labels(transactions['Payment Currency'].to_numpy())
    format_encoded = _encode_labels(transactions['Payment Format'].to_numpy())
    bank_encoded = _encode_labels(transactions['From Bank'].to_numpy())
    print("✓ Encoders prepared")

    # Cyclic temporal encoding keeps hour 23 next to hour 0, Sunday next to
    # Monday and December next to January.
    timestamps = pd.to_datetime(transactions['Timestamp'])
    hour_sin, hour_cos = _cyclic_pair(timestamps.dt.hour.to_numpy(), 24)
    day_sin, day_cos = _cyclic_pair(timestamps.dt.dayofweek.to_numpy(), 7)
    month_sin, month_cos = _cyclic_pair(timestamps.dt.month.to_numpy(), 12)

    # Amount features
    amount_paid = transactions['Amount Paid'].to_numpy(dtype=float)
    amount_received = transactions['Amount Received'].to_numpy(dtype=float)
    amount_paid_log = np.log1p(amount_paid)
    amount_received_log = np.log1p(amount_received)
    # The floor of 1 on the denominator avoids dividing by zero.
    amount_ratio = amount_paid / np.maximum(amount_received, 1)

    edge_features = np.column_stack([
        hour_sin, hour_cos, day_sin, day_cos, month_sin, month_cos,
        amount_paid_log, amount_received_log, amount_ratio,
        currency_encoded, format_encoded, bank_encoded,
    ]).astype(float)
    edge_labels = transactions['Is Laundering'].to_numpy()

    return edge_features, edge_labels

print("✓ Enhanced edge feature function defined")


In [None]:
# Class Imbalance Handling
def handle_class_imbalance(X, y, strategy='smote'):
    """Optionally oversample the minority class of ``(X, y)``.

    Args:
        X: Feature matrix.
        y: Non-negative integer class labels, one per row of ``X``.
        strategy: ``'smote'`` to oversample with SMOTE, ``'none'`` to return
            the inputs unchanged.

    Returns:
        ``(X_resampled, y_resampled)``. The inputs are returned unchanged
        when ``strategy`` is ``'none'``, when the rarest class has fewer than
        two samples (including a class that is absent altogether), or when
        SMOTE itself fails.

    Raises:
        ValueError: If ``strategy`` is not one of the supported values.
    """
    print(f"Handling class imbalance using {strategy}...")

    if strategy not in ('smote', 'none'):
        raise ValueError(f"Unknown imbalance strategy: {strategy!r} (expected 'smote' or 'none')")

    # minlength=2 makes an absent positive class count as 0 rather than
    # being left out of the distribution.
    class_counts = np.bincount(y, minlength=2)
    minority_count = min(class_counts)
    majority_count = max(class_counts)

    print(f"  - Original samples: {len(X)}")
    print(f"  - Original distribution: {class_counts}")
    print(f"  - Minority class: {minority_count} samples")
    print(f"  - Majority class: {majority_count} samples")

    if strategy == 'none':
        print("✓ No resampling applied")
        return X, y

    if minority_count < 2:
        print("⚠️  Too few minority samples for SMOTE (need at least 2)")
        print("🔄 Falling back to cost-sensitive learning only")
        print("✓ Cost-sensitive learning applied (no resampling)")
        return X, y

    # SMOTE needs k_neighbors < minority_count; 3 is used when there are
    # enough samples.
    reduced_neighbors = minority_count < 4
    k_neighbors = min(3, minority_count - 1)
    if reduced_neighbors:
        print("⚠️  Very few minority samples for SMOTE (need at least 4)")
        print("🔄 Using reduced k_neighbors for SMOTE")
    else:
        print("Applying SMOTE oversampling...")

    try:
        smote = SMOTE(random_state=42, k_neighbors=k_neighbors)
        X_resampled, y_resampled = smote.fit_resample(X, y)
        print(f"  - Resampled samples: {len(X_resampled)}")
        print(f"  - Resampled distribution: {np.bincount(y_resampled)}")
    except Exception as e:
        # Deliberate best-effort: resampling is optional, so any SMOTE
        # failure degrades to the untouched data instead of aborting.
        print(f"⚠️  SMOTE failed: {e}")
        print("🔄 Falling back to cost-sensitive learning only")
        print("✓ Cost-sensitive learning applied (no resampling)")
        return X, y

    if reduced_neighbors:
        print("✓ SMOTE completed with reduced k_neighbors")
    else:
        print("✓ SMOTE completed")
    return X_resampled, y_resampled

def create_cost_sensitive_weights(y):
    """Return ``{class_label: weight}`` from balanced class weights.

    The balanced weights of the classes present in ``y`` are multiplied by a
    factor chosen from the size of the rarest class: 100 when it has fewer
    than 10 samples (including a class that is absent altogether), 50 when
    fewer than 100, otherwise 10.

    Args:
        y: Non-negative integer class labels.
    """
    print("Creating cost-sensitive class weights...")

    # minlength=2 makes an absent positive class count as 0, so a
    # single-class sample is reported as extreme rather than moderate
    # imbalance.
    class_counts = np.bincount(y, minlength=2)
    minority_count = min(class_counts)
    majority_count = max(class_counts)

    print(f"  - Class distribution: {class_counts}")
    print(f"  - Minority class: {minority_count} samples")
    print(f"  - Majority class: {majority_count} samples")

    # Compute balanced class weights
    classes = np.unique(y)
    class_weights = compute_class_weight('balanced', classes=classes, y=y)

    # Pick the multiplier from the first tier the minority count falls under.
    if minority_count < 10:
        cost_multiplier = 100.0  # Very high cost for extreme imbalance
        print(f"  - Extreme imbalance detected: using {cost_multiplier}x cost multiplier")
    elif minority_count < 100:
        cost_multiplier = 50.0   # High cost for severe imbalance
        print(f"  - Severe imbalance detected: using {cost_multiplier}x cost multiplier")
    else:
        cost_multiplier = 10.0   # Standard cost for moderate imbalance
        print(f"  - Moderate imbalance detected: using {cost_multiplier}x cost multiplier")

    # NOTE(review): the multiplier is applied to every class, so it rescales
    # the weights without changing their ratio. If the intent is an extra
    # penalty for missed illicit transactions, only the positive class
    # should be scaled - confirm before changing.
    adjusted_weights = class_weights * cost_multiplier

    weight_dict = dict(zip(classes, adjusted_weights))
    print(f"  - Final class weights: {weight_dict}")

    return weight_dict

print("✓ Class imbalance handling functions defined")


In [None]:
# Main Preprocessing Pipeline
# Checkpoint names in pipeline order. Resuming from one of them makes every
# step up to and including it reuse its saved checkpoint.
_PIPELINE_STEPS = (
    'data_loaded', 'node_features', 'edge_features', 'normalized',
    'imbalanced', 'weights', 'graph_created',
)


def _build_transaction_graph(node_features, transactions, edge_features, edge_labels):
    """Build a directed graph with one node per account and transaction edges.

    Row ``i`` of ``edge_features`` / ``edge_labels`` must describe row ``i``
    of ``transactions``. Transactions touching an account that is not in
    ``node_features`` are skipped.
    """
    # The progress bar is cosmetic; fall back to plain iteration without it.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    # NOTE(review): a DiGraph holds one edge per (sender, receiver) pair, so
    # repeated transactions between the same two accounts overwrite each
    # other and only the last one survives. Confirm whether a multigraph is
    # wanted.
    G = nx.DiGraph()

    print("Adding nodes...")
    for account_id, features in node_features.items():
        G.add_node(account_id, **features)

    print("Adding edges...")
    account_pairs = zip(transactions['Account'], transactions['Account.1'])
    for i, (from_account, to_account) in enumerate(
            tqdm(account_pairs, total=len(transactions), desc="Adding Edges")):
        if from_account in G and to_account in G:
            G.add_edge(from_account, to_account,
                       features=edge_features[i],
                       label=edge_labels[i])
    return G


def run_simple_preprocessing(data_path, sample_size=10000, resume_from=None, chunk_size=1000):
    """Run the seven-step preprocessing pipeline with checkpointing.

    Steps, in order: load data, node features, edge features, node feature
    matrix, class-imbalance handling, class weights, graph construction.
    Each step saves a checkpoint under ``<data_path>/checkpoints_<sample_size>``.

    Args:
        data_path: Directory holding the raw CSV files; checkpoints are
            written beneath it.
        sample_size: Number of transaction rows to load.
        resume_from: Optional checkpoint name (one of ``data_loaded``,
            ``node_features``, ``edge_features``, ``normalized``,
            ``imbalanced``, ``weights``, ``graph_created``). Every step up to
            and including it loads its checkpoint instead of recomputing. A
            step whose checkpoint is missing is recomputed and saved.
        chunk_size: Accounts are processed in chunks of this size when there
            are more accounts than this.

    Returns:
        ``(G, node_features, edge_features, edge_labels, class_weights)``.

    Raises:
        ValueError: If a step has to be recomputed while resuming and the
            ``data_loaded`` checkpoint is not available.
    """
    print("=" * 60)
    print("Simple Enhanced AML Preprocessing Pipeline")
    print("=" * 60)

    start_time = time.time()

    # Create checkpoint directory
    checkpoint_dir = create_checkpoint_dir(data_path, sample_size)
    print(f"📁 Checkpoint directory: {checkpoint_dir}")

    # Check for existing checkpoints
    available_checkpoints = list_available_checkpoints(checkpoint_dir)
    if available_checkpoints:
        print(f"📋 Available checkpoints: {available_checkpoints}")
        if resume_from and resume_from in available_checkpoints:
            print(f"🔄 Resuming from checkpoint: {resume_from}")
        elif resume_from:
            print(f"⚠️  Checkpoint {resume_from} not found. Available: {available_checkpoints}")
        else:
            print("⚠️  Checkpoints found but resume_from not specified. Starting fresh.")

    # Position of the resume point in the pipeline; -1 means "run everything".
    resume_index = _PIPELINE_STEPS.index(resume_from) if resume_from in _PIPELINE_STEPS else -1

    def reuse(step_name):
        """True when ``step_name`` lies at or before the resume point."""
        return resume_index >= _PIPELINE_STEPS.index(step_name)

    transactions, accounts = None, None

    def ensure_raw_data():
        """Load the raw tables from the data checkpoint if not in memory."""
        nonlocal transactions, accounts
        if transactions is None or accounts is None:
            data_checkpoint, _ = load_checkpoint(checkpoint_dir, 'data_loaded')
            if not data_checkpoint:
                raise ValueError("No data available and no checkpoint found")
            transactions = data_checkpoint['transactions']
            accounts = data_checkpoint['accounts']

    # ------------------------------------------------------------------
    print("\n📊 STEP 1: Loading Data")
    print("-" * 30)
    step_start = time.time()

    if reuse('data_loaded'):
        # The tables are pulled from the checkpoint later, only if needed.
        print("⏭️  Skipping data loading (resuming from later step)")
    else:
        transactions, accounts = load_aml_data(data_path, sample_size)
        save_checkpoint(checkpoint_dir, 'data_loaded', {
            'transactions': transactions,
            'accounts': accounts,
            'sample_size': sample_size
        }, {'step': 'data_loading', 'sample_size': sample_size})

    step_time = time.time() - step_start
    print(f"✓ Data loading completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n🔧 STEP 2: Creating Node Features")
    print("-" * 30)
    step_start = time.time()

    node_features = None
    if reuse('node_features'):
        print("⏭️  Loading node features from checkpoint...")
        node_features, _ = load_checkpoint(checkpoint_dir, 'node_features')
        if node_features is None:
            print("⚠️  Node features checkpoint not found, creating new...")

    if node_features is None:
        ensure_raw_data()

        # Process in chunks for large datasets. Every chunk sees the full
        # transaction table, so chunked and unchunked results agree.
        if len(accounts) > chunk_size:
            print(f"📦 Processing {len(accounts)} accounts in chunks of {chunk_size}")
            chunk_dicts = process_data_in_chunks(
                accounts, chunk_size,
                lambda chunk: create_enhanced_node_features(transactions, chunk),
                checkpoint_dir, 'node_features'
            )
            node_features = combine_chunk_results(
                chunk_dicts,
                lambda chunks: {k: v for chunk in chunks for k, v in chunk.items()}
            )
        else:
            node_features = create_enhanced_node_features(transactions, accounts)

        save_checkpoint(checkpoint_dir, 'node_features', node_features, {
            'step': 'node_features',
            'num_accounts': len(accounts),
            'num_features': len(node_features)
        })

    step_time = time.time() - step_start
    print(f"✓ Node features completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n🔧 STEP 3: Creating Edge Features")
    print("-" * 30)
    step_start = time.time()

    edge_features = edge_labels = None
    if reuse('edge_features'):
        print("⏭️  Loading edge features from checkpoint...")
        edge_data, _ = load_checkpoint(checkpoint_dir, 'edge_features')
        if edge_data:
            edge_features, edge_labels = edge_data['features'], edge_data['labels']
        else:
            print("⚠️  Edge features checkpoint not found, creating new...")

    if edge_features is None:
        ensure_raw_data()

        # Edge features are built on the whole table in one call: the
        # categorical codes are fitted on the frame passed in, so encoding
        # chunk by chunk would give the same currency/format/bank different
        # codes in different chunks.
        edge_features, edge_labels = create_enhanced_edge_features(transactions)

        save_checkpoint(checkpoint_dir, 'edge_features', {
            'features': edge_features,
            'labels': edge_labels
        }, {
            'step': 'edge_features',
            'num_transactions': len(transactions),
            'num_features': len(edge_features)
        })

    step_time = time.time() - step_start
    print(f"✓ Edge features completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n📊 STEP 4: Normalizing Features")
    print("-" * 30)
    step_start = time.time()

    # NOTE(review): despite its name, this step only stacks the node feature
    # dicts into a matrix; no scaling is applied. Decide whether a scaler
    # belongs here.
    node_feature_matrix = None
    if reuse('normalized'):
        print("⏭️  Loading normalized features from checkpoint...")
        normalized_data, _ = load_checkpoint(checkpoint_dir, 'normalized')
        if normalized_data:
            node_feature_matrix = normalized_data['node_feature_matrix']
        else:
            print("⚠️  Normalized features checkpoint not found, creating new...")

    if node_feature_matrix is None:
        node_feature_matrix = np.array([list(features.values()) for features in node_features.values()])
        save_checkpoint(checkpoint_dir, 'normalized', {
            'node_feature_matrix': node_feature_matrix
        }, {
            'step': 'normalization',
            'matrix_shape': node_feature_matrix.shape
        })

    step_time = time.time() - step_start
    print(f"✓ Feature normalization completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n⚖️ STEP 5: Handling Class Imbalance")
    print("-" * 30)
    step_start = time.time()

    X_resampled = y_resampled = None
    if reuse('imbalanced'):
        print("⏭️  Loading imbalance handling from checkpoint...")
        imbalance_data, _ = load_checkpoint(checkpoint_dir, 'imbalanced')
        if imbalance_data:
            X_resampled, y_resampled = imbalance_data['X_resampled'], imbalance_data['y_resampled']
        else:
            print("⚠️  Imbalance handling checkpoint not found, creating new...")

    if X_resampled is None:
        # Try SMOTE first, fall back to the untouched data if it fails.
        try:
            X_resampled, y_resampled = handle_class_imbalance(edge_features, edge_labels, strategy='smote')
        except Exception as e:
            print(f"⚠️  SMOTE failed completely: {e}")
            print("🔄 Using cost-sensitive learning only")
            X_resampled, y_resampled = edge_features, edge_labels
            print("✓ Cost-sensitive learning applied (no resampling)")

        save_checkpoint(checkpoint_dir, 'imbalanced', {
            'X_resampled': X_resampled,
            'y_resampled': y_resampled
        }, {
            'step': 'imbalance_handling',
            'original_shape': edge_features.shape,
            'resampled_shape': X_resampled.shape
        })

    step_time = time.time() - step_start
    print(f"✓ Class imbalance handling completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n🎯 STEP 6: Creating Cost-Sensitive Weights")
    print("-" * 30)
    step_start = time.time()

    class_weights = None
    if reuse('weights'):
        print("⏭️  Loading cost-sensitive weights from checkpoint...")
        weights_data, _ = load_checkpoint(checkpoint_dir, 'weights')
        if weights_data:
            class_weights = weights_data['class_weights']
        else:
            print("⚠️  Weights checkpoint not found, creating new...")

    if class_weights is None:
        # Weights come from the original labels, not the resampled ones.
        class_weights = create_cost_sensitive_weights(edge_labels)
        save_checkpoint(checkpoint_dir, 'weights', {
            'class_weights': class_weights
        }, {
            'step': 'weights_creation',
            'weights': class_weights
        })

    step_time = time.time() - step_start
    print(f"✓ Cost-sensitive weights completed in {step_time:.2f} seconds")

    # ------------------------------------------------------------------
    print("\n🕸️ STEP 7: Creating Graph Structure")
    print("-" * 30)
    step_start = time.time()

    G = None
    if reuse('graph_created'):
        print("⏭️  Loading graph from checkpoint...")
        graph_data, _ = load_checkpoint(checkpoint_dir, 'graph_created')
        if graph_data:
            G = graph_data['graph']
        else:
            print("⚠️  Graph checkpoint not found, creating new...")

    if G is None:
        ensure_raw_data()
        G = _build_transaction_graph(node_features, transactions, edge_features, edge_labels)
        save_checkpoint(checkpoint_dir, 'graph_created', {
            'graph': G
        }, {
            'step': 'graph_creation',
            'num_nodes': G.number_of_nodes(),
            'num_edges': G.number_of_edges()
        })

    step_time = time.time() - step_start
    print(f"✓ Graph structure completed in {step_time:.2f} seconds")

    total_time = time.time() - start_time
    print(f"\n🎉 Preprocessing completed successfully!")
    print(f"⏱️ Total time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
    print(f"📊 Results:")
    print(f"  - Nodes: {G.number_of_nodes()}")
    print(f"  - Edges: {G.number_of_edges()}")
    print(f"  - Node features: {node_feature_matrix.shape[1]}")
    print(f"  - Edge features: {edge_features.shape[1]}")
    print(f"  - Class distribution: {np.bincount(edge_labels)}")

    return G, node_features, edge_features, edge_labels, class_weights

print("✓ Main preprocessing pipeline defined")


In [None]:
# Entry cell: configure one preprocessing run, print what is about to
# happen, run the pipeline, then print a short summary of the results.
print("Starting simple enhanced preprocessing with checkpointing...")

# Configuration
data_path = "/content/drive/MyDrive/LaunDetection/data/raw"
sample_size = 10000  # small sample first; raise once the run is verified
chunk_size = 1000    # items handled per chunk on large inputs
resume_from = None   # or one of: 'data_loaded', 'node_features', 'edge_features', 'normalized', 'imbalanced', 'weights', 'graph_created'

# Announce the run parameters and the expected cost.
estimated_time = estimate_processing_time(sample_size)
for banner_line in (
    f"📊 Sample size: {sample_size:,} transactions",
    f"⏱️ Estimated processing time: {estimated_time}",
    f"💾 Expected memory usage: ~2-4 GB",
    f"🔧 Features to create: 15 node features + 12 edge features per transaction",
    f"📦 Chunk size: {chunk_size:,} items per chunk",
    f"🔄 Resume from: {resume_from or 'Start from beginning'}",
):
    print(banner_line)

try:
    # Run preprocessing with checkpointing
    G, node_features, edge_features, edge_labels, class_weights = run_simple_preprocessing(
        data_path, sample_size, resume_from=resume_from, chunk_size=chunk_size
    )

    separator = "=" * 60
    print("\n" + separator)
    print("Preprocessing Summary")
    print(separator)
    print(f"✓ Graph created: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    print(f"✓ Node features: {len(node_features)} accounts with enhanced features")
    print(f"✓ Edge features: {edge_features.shape[0]} transactions with {edge_features.shape[1]} features")
    print(f"✓ Class weights: {class_weights}")

    # Peek at the first five features of the first account.
    print(f"\nSample node features for first account:")
    first_account = list(node_features)[0]
    print(f"Account: {first_account}")
    for key, value in list(node_features[first_account].items())[:5]:
        print(f"  {key}: {value}")

    # Peek at the first transaction, grouped by feature family.
    print(f"\nSample edge features:")
    print(f"  Temporal features: {edge_features[0][:6]}")
    print(f"  Amount features: {edge_features[0][6:9]}")
    print(f"  Categorical features: {edge_features[0][9:12]}")

except Exception as e:
    # Top-level boundary of the notebook cell: report the failure with its
    # traceback instead of letting the cell die with a bare exception.
    print(f"✗ Error during preprocessing: {e}")
    import traceback
    traceback.print_exc()
