# Multi-Dataset AML Training Pipeline

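This notebook runs the complete multi-dataset AML (anti-money-laundering) training pipeline on Google Colab. The stages are multi-dataset preprocessing (with memory management for large datasets), building a combined dataset in PyTorch Geometric format, model training, and evaluation.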


In [None]:
# Install required packages (Colab runtime; `%pip` targets the running kernel's environment).
# NOTE(review): no versions are pinned, so a later re-run may pull in releases with breaking
# API changes (e.g. the deprecated `torch_geometric.data.DataLoader` imported below).
# TODO: pin a known-good version set once confirmed.
# NOTE(review): reinstalling `torch` on Colab may replace the preinstalled CUDA build —
# verify that `torch.cuda.is_available()` still holds after this cell.
%pip install torch torch-geometric networkx pandas numpy scikit-learn imbalanced-learn tqdm matplotlib seaborn plotly --quiet


In [None]:
# Mount Google Drive (Colab only) so the project files under /content/drive/MyDrive
# are reachable from this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


In [None]:
# Import required libraries
import sys
import os
import pandas as pd
import numpy as np
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
import pickle
import gc
import time
from tqdm import tqdm
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix, roc_auc_score
import warnings
warnings.filterwarnings('ignore')

print("🚀 Multi-Dataset AML Training Pipeline")
print("=" * 60)


In [None]:
# Make the project's helper modules, stored on the mounted Google Drive, importable.
PROJECT_DIR = '/content/drive/MyDrive/LaunDetection'

# Fail with a clear message instead of a confusing ModuleNotFoundError when Drive
# is not mounted or the project folder has moved.
if not os.path.isdir(PROJECT_DIR):
    raise FileNotFoundError(
        f"Project directory not found: {PROJECT_DIR} (is Google Drive mounted?)"
    )

# Guard against appending a duplicate sys.path entry every time this cell is re-run.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

# Import our custom modules
from multi_dataset_preprocessing import MultiDatasetPreprocessor
from multi_dataset_training import MultiDatasetTrainer

print("✅ All modules imported successfully")


In [None]:
# Step 1: Multi-Dataset Preprocessing with Memory Management
print("🔄 STEP 1: MULTI-DATASET PREPROCESSING")
print("=" * 50)

preprocessor = MultiDatasetPreprocessor()
processed_data = preprocessor.run_full_preprocessing()

# A falsy result means preprocessing produced nothing. Raise a specific exception type
# with a descriptive message so that "Run All" stops here rather than continuing into
# the training cells.
if not processed_data:
    raise RuntimeError(
        "Multi-dataset preprocessing failed: run_full_preprocessing() returned no data"
    )

print("✅ Multi-dataset preprocessing completed successfully!")


In [None]:
# Step 2: Multi-Dataset Training
print("🚀 STEP 2: MULTI-DATASET TRAINING")
print("=" * 50)

trainer = MultiDatasetTrainer()

# Load processed datasets (the outputs of Step 1)
datasets = trainer.load_processed_datasets()

# Stop "Run All" with a specific, actionable error rather than a bare Exception.
if not datasets:
    raise RuntimeError(
        "No processed datasets found — run Step 1 (preprocessing) first"
    )

print(f"✅ Loaded {len(datasets)} processed datasets")


In [None]:
# Merge the loaded datasets into one combined dataset, then convert it to the
# PyTorch Geometric object used by the training and evaluation cells.
print("🔄 Creating combined multi-dataset...")
combined_data = trainer.create_combined_dataset(datasets)

print("🔄 Converting to PyTorch Geometric format...")
data = trainer.create_pytorch_data(combined_data)

node_count, edge_count = data.num_nodes, data.num_edges
print(f"✅ Combined dataset ready: {node_count:,} nodes, {edge_count:,} edges")


In [None]:
# Train the multi-dataset model.
# Hyperparameters are named here so they are easy to find and tune.
EPOCHS = 100
LEARNING_RATE = 0.001

print("🚀 Starting Multi-Dataset Training...")
train_start = time.perf_counter()
model, best_f1 = trainer.train_multi_dataset_model(
    data, epochs=EPOCHS, learning_rate=LEARNING_RATE
)
# Report the wall-clock cost so readers know how expensive a full re-run of this cell is.
train_seconds = time.perf_counter() - train_start

print(f"✅ Training completed in {train_seconds / 60:.1f} min! Best F1: {best_f1:.4f}")


In [None]:
# Evaluate the trained model.
# NOTE(review): only `data` is passed, not `model`; presumably the trainer keeps the
# trained (or best) model internally — confirm which weights are evaluated.

# Thresholds for the qualitative verdict below (strict greater-than comparisons).
AML_F1_EXCELLENT = 0.5
AML_F1_GOOD = 0.3

print("📊 Evaluating Multi-Dataset Model...")
metrics, aml_metrics, cm = trainer.evaluate_multi_dataset_model(data)

print("\n🎉 MULTI-DATASET TRAINING COMPLETE!")
print("=" * 50)
print(f"✅ Overall F1: {metrics['f1_weighted']:.4f}")
print(f"✅ AML F1: {aml_metrics['aml_f1']:.4f}")
print(f"✅ ROC-AUC: {metrics['roc_auc']:.4f}")

# Weighted F1 can look strong even when most minority-class (AML) cases are missed,
# so show the confusion matrix to make hits and misses visible.
print("Confusion matrix:")
print(cm)

aml_f1 = aml_metrics['aml_f1']
if aml_f1 > AML_F1_EXCELLENT:
    print("🎉 EXCELLENT! AML detection significantly improved!")
elif aml_f1 > AML_F1_GOOD:
    print("✅ GOOD! AML detection improved!")
else:
    print("⚠️ AML detection needs further improvement")
