## 1. Setup and Imports

In [None]:
# Standard imports
import os
import sys
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# PyTorch
import torch
import torch.nn as nn

# Add project to path
sys.path.insert(0, str(Path.cwd()))

# Project imports
from token_malice_prediction.data import (
    TokenPreprocessor,
    TransactionGraphBuilder,
    build_graphs_from_processed_data,
    TokenGraphDatasetList,
    create_data_loaders
)
from token_malice_prediction.models import TokenGATClassifier, create_model
from token_malice_prediction.training import Trainer, compute_metrics
from token_malice_prediction.evaluation import compute_classification_metrics
from token_malice_prediction.utils import set_seed, get_device, load_config

# Set style
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8-*".
# Try the new name first and fall back to the legacy one so this cell runs on
# either side of that rename instead of raising OSError.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
else:
    # Neither style sheet is available: use seaborn's own whitegrid theme.
    sns.set_style('whitegrid')
sns.set_palette('husl')

# Configuration
# NOTE(review): the seed is hard-coded here, while the data split later reads
# config.seed -- confirm the two are meant to agree.
set_seed(42)
device = get_device()
print(f"Using device: {device}")

## 2. Data Overview

### 2.1 Load and Preprocess Token Data

In [None]:
# Read the experiment configuration
config = load_config('config.yaml')

# Build the preprocessor from the data section of the configuration
data_cfg = config.data
preprocessor = TokenPreprocessor(
    data_dir=data_cfg.data_dir,
    malice_threshold=data_cfg.malice_threshold,
    min_transactions=data_cfg.min_transactions,
)

# Run preprocessing; every entry is unpacked below as a (frame, label, name) triple
processed_data = preprocessor.process_directory()

# Report dataset size and class balance
banner = "=" * 60
print(banner)
print("DATASET STATISTICS")
print(banner)
print(f"\nTotal tokens processed: {len(processed_data)}")

# One label per processed token, in processing order
labels = [token_label for _, token_label, _ in processed_data]
print(f"\nClass distribution:")
for class_name, class_id in (("Benign", 0), ("Malicious", 1)):
    tag = f"{class_name} ({class_id}):"
    print(f"  {tag:<15}{labels.count(class_id)}")
print(f"\nMalicious rate: {sum(labels)/len(labels):.2%}")

# Preview the first few tokens: truncated name, transaction count and label
print(f"\nSample token transaction counts:")
for df, label, name in processed_data[:5]:
    short_name = name[:20]
    n_transactions = len(df)
    print(f"  {short_name:20s}: {n_transactions:5d} transactions, label={label}")

### 2.2 Build Transaction Graphs

In [None]:
# Turn every processed token into a transaction graph; keep only the graph objects
graphs_with_names = build_graphs_from_processed_data(processed_data)
graphs = [graph for graph, _ in graphs_with_names]

banner = "=" * 60
print(banner)
print("GRAPH STATISTICS")
print(banner)

# Per-graph sizes and labels
num_nodes = [g.num_nodes for g in graphs]
num_edges = [g.edge_index.shape[1] for g in graphs]
graph_labels = [g.y.item() for g in graphs]

print(f"\nTotal graphs: {len(graphs)}")

# Same four summary numbers for node counts and for edge counts
for noun, counts in (("Node", num_nodes), ("Edge", num_edges)):
    print(f"\n{noun} statistics:")
    print(f"  Min:  {min(counts)}")
    print(f"  Max:  {max(counts)}")
    print(f"  Mean: {np.mean(counts):.1f}")
    print(f"  Std:  {np.std(counts):.1f}")

# Three panels: node histogram, edge histogram, nodes-vs-edges scatter
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, noun, counts in zip(axes[:2], ("Node", "Edge"), (num_nodes, num_edges)):
    ax.hist(counts, bins=30, edgecolor='black', alpha=0.7)
    ax.set_xlabel(f'Number of {noun}s')
    ax.set_ylabel('Count')
    ax.set_title(f'{noun} Count Distribution')

# Scatter of graph size, coloured by the graph label
scatter_ax = axes[2]
scatter_ax.scatter(num_nodes, num_edges, c=graph_labels, cmap='coolwarm', alpha=0.6)
scatter_ax.set_xlabel('Number of Nodes')
scatter_ax.set_ylabel('Number of Edges')
scatter_ax.set_title('Nodes vs Edges (colored by label)')

plt.tight_layout()
plt.show()

### 2.3 Create Dataset and Data Loaders

In [None]:
# Wrap the graph list in the project's dataset type
dataset = TokenGraphDatasetList(graphs)

# Split ratios, batch size and seed all come from the configuration
# (presumably a stratified split -- verify in create_data_loaders)
loader_kwargs = dict(
    train_ratio=config.data.train_ratio,
    val_ratio=config.data.val_ratio,
    test_ratio=config.data.test_ratio,
    batch_size=config.data.batch_size,
    random_seed=config.seed,
)
train_loader, val_loader, test_loader = create_data_loaders(dataset=dataset, **loader_kwargs)

# Per-class weights, passed to the trainer later on
class_weights = dataset.get_class_weights()

banner = "=" * 60
print(banner)
print("DATA LOADERS")
print(banner)

# Number of batches in each split
leading = "\n"
for caption, loader in (
    ("Train batches:", train_loader),
    ("Val batches:  ", val_loader),
    ("Test batches: ", test_loader),
):
    print(f"{leading}{caption} {len(loader)}")
    leading = ""
print(f"\nClass weights: {class_weights.tolist()}")

## 3. Model Architecture

### 3.1 GAT Classifier Architecture Overview

In [None]:
# Collect the constructor arguments: feature sizes come from the graph builder,
# everything else from the model section of the configuration
model_kwargs = dict(
    node_dim=TransactionGraphBuilder.get_node_feature_dim(),
    edge_dim=TransactionGraphBuilder.get_edge_feature_dim(),
    hidden_dim=config.model.hidden_dim,
    num_classes=config.model.num_classes,
    num_layers=config.model.num_layers,
    num_heads=config.model.num_heads,
    dropout=config.model.dropout,
    pooling=config.model.pooling,
    use_edge_features=config.model.use_edge_features,
)
model = create_model(**model_kwargs)

# Show the module tree
banner = "=" * 60
print(banner)
print("GAT CLASSIFIER MODEL ARCHITECTURE")
print(banner)
print(model)
print(f"\n{banner}")

# Parameter counts: all parameters, and the subset that receives gradients
param_sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 4. Training Results

### 4.1 Training History

In [None]:
# Setup training: AdamW with the configured learning rate and weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.training.learning_rate,
    weight_decay=config.training.weight_decay
)

# Halve the learning rate when the monitored quantity stops decreasing for 5 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    device=str(device),
    scheduler=scheduler,
    class_weights=class_weights,
    gradient_clip=config.training.gradient_clip
)

# The checkpoint and the figure below are both written under outputs/.
# Create the directory BEFORE training: Trainer's save behaviour is not visible
# here, and if it does not create missing parent folders the first checkpoint
# save would fail on a fresh checkout.
os.makedirs('outputs', exist_ok=True)

# Train the model
print("="*60)
print("TRAINING")
print("="*60)

history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=config.training.num_epochs,
    early_stopping_patience=config.training.early_stopping_patience,
    save_path='outputs/best_model.pt'
)

# Plot training curves (one x tick per recorded epoch, starting at 1)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

epochs = range(1, len(history['train_loss']) + 1)

# Loss
axes[0].plot(epochs, history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(epochs, history['val_loss'], label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
axes[1].plot(epochs, history['train_acc'], label='Train Accuracy', linewidth=2)
axes[1].plot(epochs, history['val_acc'], label='Val Accuracy', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Evaluation Results

### 5.1 Classification Metrics

In [None]:
# Run the trained model over the held-out test loader
test_results = trainer.evaluate(test_loader, desc='Test Evaluation')

# Derive summary metrics from the three result arrays
metric_inputs = {
    field: test_results[field]
    for field in ('predictions', 'labels', 'probabilities')
}
metrics = compute_metrics(**metric_inputs)

# Headline metrics, aligned in a 13-character label column
banner = "=" * 60
print(banner)
print("CLASSIFICATION METRICS")
print(banner)

leading = "\n"
for caption, key in (
    ("Accuracy:", 'accuracy'),
    ("Precision:", 'precision'),
    ("Recall:", 'recall'),
    ("F1 Score:", 'f1'),
):
    print(f"{leading}{caption:<13}{metrics[key]:.4f}")
    leading = ""
# AUC-ROC is optional in the metrics dict
if 'auc_roc' in metrics:
    print(f"{'AUC-ROC:':<13}{metrics['auc_roc']:.4f}")

# Confusion-matrix counts, or 'N/A' when a count is absent from the metrics dict
tn = metrics.get('true_negatives', 'N/A')
fp = metrics.get('false_positives', 'N/A')
fn = metrics.get('false_negatives', 'N/A')
tp = metrics.get('true_positives', 'N/A')
print("\nConfusion Matrix:")
print(f"  TN: {tn:4}  FP: {fp:4}")
print(f"  FN: {fn:4}  TP: {tp:4}")

### 5.2 ROC and Precision-Recall Curves

In [None]:
from sklearn.metrics import roc_curve, precision_recall_curve, auc

# Ground truth and scores for the test split
# (y_prob is used as the positive-class score -- presumably P(malicious); confirm
# against Trainer.evaluate)
y_true = np.array(test_results['labels'])
y_prob = np.array(test_results['probabilities'])

# Compute curves
fpr, tpr, _ = roc_curve(y_true, y_prob)
precision, recall, _ = precision_recall_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)
# Trapezoidal area under the precision-recall curve. This is NOT average
# precision (which is a step-wise sum), so the legend below labels it "PR AUC".
pr_auc = auc(recall, precision)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ROC Curve, with the chance diagonal for reference
axes[0].plot(fpr, tpr, 'b-', linewidth=2, label=f'GAT Classifier (AUC = {roc_auc:.3f})')
axes[0].plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random Classifier')
axes[0].fill_between(fpr, tpr, alpha=0.2)
axes[0].set_xlabel('False Positive Rate', fontsize=12)
axes[0].set_ylabel('True Positive Rate', fontsize=12)
axes[0].set_title('ROC Curve', fontsize=14)
axes[0].legend(loc='lower right')
axes[0].grid(True, alpha=0.3)
axes[0].set_xlim([0, 1])
axes[0].set_ylim([0, 1])

# Precision-Recall Curve; the baseline is the positive-class prevalence
axes[1].plot(recall, precision, 'g-', linewidth=2, label=f'GAT Classifier (PR AUC = {pr_auc:.3f})')
baseline = y_true.mean()
axes[1].axhline(y=baseline, color='k', linestyle='--', linewidth=1, label=f'Baseline ({baseline:.3f})')
axes[1].fill_between(recall, precision, alpha=0.2, color='green')
axes[1].set_xlabel('Recall', fontsize=12)
axes[1].set_ylabel('Precision', fontsize=12)
axes[1].set_title('Precision-Recall Curve', fontsize=14)
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim([0, 1])
axes[1].set_ylim([0, 1])

plt.tight_layout()
plt.savefig('outputs/roc_pr_curves.png', dpi=150, bbox_inches='tight')
plt.show()

### 5.3 Confusion Matrix Visualization

In [None]:
from sklearn.metrics import confusion_matrix

# Re-read both arrays here so the cell does not depend on variables left
# behind by an earlier cell.
y_true = np.array(test_results['labels'])
y_pred = np.array(test_results['predictions'])

# Compute confusion matrix.
# labels=[0, 1] pins the shape to 2x2 even when the test split (or the
# predictions) contain a single class; without it the matrix can be 1x1 and
# the 2x2 annotation loop below raises IndexError.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Plot
plt.figure(figsize=(8, 6))

# Row-normalize for percentages. A class with no true samples has a zero row
# total; leave its percentages at 0 instead of dividing by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

# Create annotations: raw count on the first line, row percentage below it
annot = np.array([[f'{cm[i,j]}\n({cm_norm[i,j]:.1%})' for j in range(2)] for i in range(2)])

sns.heatmap(
    cm,
    annot=annot,
    fmt='',  # annotations are already formatted strings
    cmap='Blues',
    xticklabels=['Benign', 'Malicious'],
    yticklabels=['Benign', 'Malicious'],
    square=True,
    linewidths=0.5,
    cbar_kws={'label': 'Count'}
)

plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix', fontsize=14)
plt.tight_layout()
plt.savefig('outputs/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Graph Analysis

### 6.1 Graph Feature Analysis by Label

In [None]:
# Collect one row of structural statistics per graph
graph_stats = []
for g in graphs:
    n_nodes = g.num_nodes
    n_edges = g.edge_index.shape[1]
    stats = dict(
        label=g.y.item(),
        num_nodes=n_nodes,
        num_edges=n_edges,
        # Edges over ordered node pairs; the epsilon avoids division by zero
        density=n_edges / (n_nodes * (n_nodes - 1) + 1e-8),
        avg_node_degree=n_edges / (n_nodes + 1e-8),
    )
    # Node feature statistics.
    # NOTE(review): columns 0/1/2 are read as in-degree / out-degree / total
    # degree, going by the key names -- verify against TransactionGraphBuilder.
    node_features = g.x
    if node_features is not None:
        stats.update(
            avg_in_degree=node_features[:, 0].mean().item(),
            avg_out_degree=node_features[:, 1].mean().item(),
            max_total_degree=node_features[:, 2].max().item(),
        )
    graph_stats.append(stats)

stats_df = pd.DataFrame(graph_stats)

# Side-by-side descriptive statistics for each class
banner = "=" * 60
print(banner)
print("GRAPH STATISTICS BY LABEL")
print(banner)
for heading, class_id in (
    ("\nBenign tokens (label=0):", 0),
    ("\nMalicious tokens (label=1):", 1),
):
    print(heading)
    print(stats_df[stats_df['label'] == class_id].describe())

# Overlaid, density-normalized histograms: one panel per feature, one colour per class
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

features_to_plot = ['num_nodes', 'num_edges', 'density', 'avg_node_degree']
class_styles = [(0, 'steelblue', 'Benign'), (1, 'crimson', 'Malicious')]
for ax, feat in zip(axes.ravel(), features_to_plot):
    for label, color, name in class_styles:
        data = stats_df.loc[stats_df['label'] == label, feat]
        ax.hist(data, bins=20, alpha=0.6, color=color, label=name, density=True)
    ax.set_xlabel(feat.replace('_', ' ').title())
    ax.set_ylabel('Density')
    ax.legend()
    ax.set_title(f'{feat} by Label')

plt.tight_layout()
plt.savefig('outputs/graph_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Attention Analysis

In [None]:
# Inspect the attention weights the model assigns on one example graph
sample_idx = 0
sample_graph = graphs[sample_idx]

model.eval()
with torch.no_grad():
    sample_graph = sample_graph.to(device)

    # Best-effort: any failure below is reported rather than raised
    try:
        # Ask the model for the attention of its last layer (layer_idx=-1)
        attention_inputs = dict(
            x=sample_graph.x,
            edge_index=sample_graph.edge_index,
            edge_attr=sample_graph.edge_attr,
            layer_idx=-1,
        )
        edge_index, attention = model.get_attention_weights(**attention_inputs)

        # Move to NumPy for plotting and summary statistics
        attention_np = attention.cpu().numpy()

        plt.figure(figsize=(10, 4))

        # Left panel: every individual attention weight
        plt.subplot(1, 2, 1)
        all_weights = attention_np.ravel()
        plt.hist(all_weights, bins=50, edgecolor='black', alpha=0.7)
        plt.xlabel('Attention Weight')
        plt.ylabel('Frequency')
        plt.title('Attention Weight Distribution')

        # Right panel: weights averaged over axis 1
        # (presumably the head axis, per the panel title -- confirm)
        plt.subplot(1, 2, 2)
        per_edge_mean = attention_np.mean(axis=1)
        plt.hist(per_edge_mean, bins=30, edgecolor='black', alpha=0.7)
        plt.xlabel('Mean Attention per Edge')
        plt.ylabel('Frequency')
        plt.title('Mean Attention per Edge (across heads)')

        plt.tight_layout()
        plt.savefig('outputs/attention_analysis.png', dpi=150, bbox_inches='tight')
        plt.show()

        print(f"Attention statistics:")
        for caption, value in (
            ("Min: ", attention_np.min()),
            ("Max: ", attention_np.max()),
            ("Mean:", attention_np.mean()),
            ("Std: ", attention_np.std()),
        ):
            print(f"  {caption} {value:.4f}")
    except Exception as e:
        print(f"Could not extract attention weights: {e}")
        print("This is expected if using the simple model variant.")

## 8. Prediction Analysis

In [None]:
# Analyze prediction confidence
y_prob = np.array(test_results['probabilities'])
y_true = np.array(test_results['labels'])
y_pred = np.array(test_results['predictions'])

# Separate correct and incorrect predictions
correct_mask = y_pred == y_true
incorrect_mask = ~correct_mask

# Confidence = probability assigned to the class that was actually predicted.
# y_prob is treated as P(malicious) (see the x-axis label below), so for a
# benign prediction the confidence is 1 - y_prob. Averaging raw y_prob would
# mix both classes and is not a confidence.
confidence = np.where(y_pred == 1, y_prob, 1.0 - y_prob)

plt.figure(figsize=(14, 5))

# Confidence distribution.
# Skip an empty group: a density-normalized histogram of no samples is undefined.
plt.subplot(1, 2, 1)
if correct_mask.any():
    plt.hist(y_prob[correct_mask], bins=30, alpha=0.6, label='Correct', color='green', density=True)
if incorrect_mask.any():
    plt.hist(y_prob[incorrect_mask], bins=30, alpha=0.6, label='Incorrect', color='red', density=True)
plt.xlabel('Predicted Probability (Malicious)')
plt.ylabel('Density')
plt.title('Prediction Confidence Distribution')
plt.legend()

# Calibration plot
plt.subplot(1, 2, 2)
n_bins = 10
bin_edges = np.linspace(0, 1, n_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

# Calculate actual positive rate per bin (NaN marks an empty bin)
actual_pos_rate = []
predicted_pos_rate = []
for i in range(n_bins):
    lower, upper = bin_edges[i], bin_edges[i + 1]
    if i == n_bins - 1:
        # The last bin is closed on the right so that y_prob == 1.0 is counted.
        mask = (y_prob >= lower) & (y_prob <= upper)
    else:
        mask = (y_prob >= lower) & (y_prob < upper)
    if mask.sum() > 0:
        actual_pos_rate.append(y_true[mask].mean())
        predicted_pos_rate.append(y_prob[mask].mean())
    else:
        actual_pos_rate.append(np.nan)
        predicted_pos_rate.append(np.nan)

# Plot only populated bins so an empty bin does not break the curve into pieces
calib_predicted = np.array(predicted_pos_rate, dtype=float)
calib_actual = np.array(actual_pos_rate, dtype=float)
populated = ~np.isnan(calib_predicted)

plt.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
plt.plot(calib_predicted[populated], calib_actual[populated], 'bo-', label='Model')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Actual Positive Rate')
plt.title('Calibration Plot')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('outputs/prediction_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPrediction Statistics:")
print(f"  Correct predictions:   {correct_mask.sum()} ({correct_mask.mean():.1%})")
print(f"  Incorrect predictions: {incorrect_mask.sum()} ({incorrect_mask.mean():.1%})")
# Each mean is printed only when its group is non-empty (no NaN, no blank line)
if correct_mask.any():
    print(f"  Mean confidence (correct):   {confidence[correct_mask].mean():.3f}")
if incorrect_mask.any():
    print(f"  Mean confidence (incorrect): {confidence[incorrect_mask].mean():.3f}")

## 9. Error Analysis

In [None]:
# Analyze misclassified samples
false_positives = (y_pred == 1) & (y_true == 0)
false_negatives = (y_pred == 0) & (y_true == 1)

print("="*60)
print("ERROR ANALYSIS")
print("="*60)

print(f"\nFalse Positives (benign predicted as malicious): {false_positives.sum()}")
print(f"False Negatives (malicious predicted as benign): {false_negatives.sum()}")

# Get indices of misclassified samples in the test set
# (Note: These are indices in the test_results, not the original dataset)

if false_positives.sum() > 0:
    print(f"\nFalse Positive prediction probabilities:")
    fp_probs = y_prob[false_positives]
    print(f"  Mean: {fp_probs.mean():.3f}")
    print(f"  Std:  {fp_probs.std():.3f}")
    print(f"  Min:  {fp_probs.min():.3f}")
    print(f"  Max:  {fp_probs.max():.3f}")

if false_negatives.sum() > 0:
    print(f"\nFalse Negative prediction probabilities:")
    fn_probs = y_prob[false_negatives]
    print(f"  Mean: {fn_probs.mean():.3f}")
    print(f"  Std:  {fn_probs.std():.3f}")
    print(f"  Min:  {fn_probs.min():.3f}")
    print(f"  Max:  {fn_probs.max():.3f}")

# Plot error distribution
if false_positives.sum() > 0 or false_negatives.sum() > 0:
    plt.figure(figsize=(10, 4))
    
    if false_positives.sum() > 0:
        plt.hist(y_prob[false_positives], bins=20, alpha=0.6, label='False Positives', color='orange')
    if false_negatives.sum() > 0:
        plt.hist(y_prob[false_negatives], bins=20, alpha=0.6, label='False Negatives', color='purple')
    
    plt.axvline(x=0.5, color='k', linestyle='--', label='Decision Threshold')
    plt.xlabel('Predicted Probability (Malicious)')
    plt.ylabel('Count')
    plt.title('Misclassified Samples by Prediction Probability')
    plt.legend()
    plt.tight_layout()
    plt.savefig('outputs/error_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

## 10. Summary and Conclusions

In [None]:
# Final text summary assembled from values computed in the earlier cells.
print("="*70)
print("SUMMARY OF RESULTS")
print("="*70)

# AUC-ROC is optional in the metrics dict. Format it with the same 4 decimals
# as the other metrics when present, and fall back to 'N/A' otherwise.
auc_roc = metrics.get('auc_roc')
auc_roc_text = f"{auc_roc:.4f}" if auc_roc is not None else 'N/A'

# NOTE: section 5 (KEY FINDINGS) is fixed narrative text; it is not derived
# from the numbers computed in this run.
print(f"""
1. DATASET
   - Total tokens processed: {len(processed_data)}
   - Malicious tokens: {sum(labels)}
   - Benign tokens: {len(labels) - sum(labels)}
   - Malicious rate: {sum(labels)/len(labels):.1%}

2. GRAPH CONSTRUCTION
   - Total graphs: {len(graphs)}
   - Average nodes per graph: {np.mean(num_nodes):.1f}
   - Average edges per graph: {np.mean(num_edges):.1f}
   - Node features: {TransactionGraphBuilder.get_node_feature_dim()}
   - Edge features: {TransactionGraphBuilder.get_edge_feature_dim()}

3. MODEL ARCHITECTURE
   - GAT-based classifier with GATv2Conv layers
   - Hidden dimension: {config.model.hidden_dim}
   - Number of layers: {config.model.num_layers}
   - Attention heads: {config.model.num_heads}
   - Total parameters: {total_params:,}

4. CLASSIFICATION PERFORMANCE
   - Accuracy:  {metrics['accuracy']:.4f}
   - Precision: {metrics['precision']:.4f}
   - Recall:    {metrics['recall']:.4f}
   - F1 Score:  {metrics['f1']:.4f}
   - AUC-ROC:   {auc_roc_text}

5. KEY FINDINGS
   - Transaction graph structure captures malicious patterns
   - Edge features (amount, value, time) improve predictions
   - GAT attention mechanism highlights suspicious transactions
""")

### 10.1 Export Results

In [None]:
# Create outputs directory
os.makedirs('outputs', exist_ok=True)


def _json_default(obj):
    """Convert NumPy / PyTorch values that the json module cannot encode natively.

    Used as the ``default`` hook of ``json.dump``: NumPy scalars become plain
    Python scalars, arrays and tensors become nested lists. Any other type
    raises TypeError, which is what ``json`` expects from the hook.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save results to JSON
final_results = {
    'dataset': {
        'total_tokens': len(processed_data),
        'malicious_count': sum(labels),
        'benign_count': len(labels) - sum(labels),
        'malicious_rate': sum(labels) / len(labels)
    },
    'graphs': {
        'total': len(graphs),
        'avg_nodes': float(np.mean(num_nodes)),
        'avg_edges': float(np.mean(num_edges))
    },
    'model': {
        'hidden_dim': config.model.hidden_dim,
        'num_layers': config.model.num_layers,
        'num_heads': config.model.num_heads,
        'total_params': total_params
    },
    # Numeric metrics are stored as floats; anything else is passed through and
    # handled by _json_default if it is a NumPy/PyTorch value.
    'classification_metrics': {k: float(v) if isinstance(v, (int, float, np.floating, np.integer)) else v
                               for k, v in metrics.items()},
    # Per-epoch series are stored as float lists; a non-sequence entry (if the
    # trainer records one) is passed through instead of crashing the export.
    'training_history': {
        k: [float(x) for x in v] if isinstance(v, (list, tuple, np.ndarray)) else v
        for k, v in history.items()
    }
}

with open('outputs/final_results.json', 'w', encoding='utf-8') as f:
    json.dump(final_results, f, indent=2, default=_json_default)

print("Results exported to outputs/final_results.json")
print("\nGenerated files:")
# A separate loop name avoids shadowing the (closed) file handle `f` above
for filename in sorted(os.listdir('outputs')):
    print(f"  - outputs/{filename}")