# Mamba-GNN Training Notebook (EstraNet-Aligned Configuration)

**⚠️ IMPORTANT**: This notebook uses the **corrected configuration** that matches EstraNet for fair comparison.

## Key Configuration ✅
- **Loss Function**: Cross-Entropy (not Focal Loss)
- **Learning Rate**: 2.5e-4 (not 2e-3)
- **Batch Size**: 256 train / 32 eval (not 64)
- **Training**: 100k steps (not 50-100 epochs)
- **Optimizer**: Adam (not AdamW)
- **Scheduler**: Cosine Decay (not OneCycleLR)
- **Evaluation**: 100-trial Guessing Entropy (not single trial)
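
To relate the step-based budget to the epoch counts used in the old notebook, the arithmetic below converts steps to epochs. This is a rough sketch: the profiling-set size of 50,000 traces is an assumption (check the size of your own `ASCAD.h5`), and `steps_to_epochs` is a hypothetical helper, not part of the training script.

```python
def steps_to_epochs(train_steps, batch_size, num_profiling_traces):
    """Number of passes over the profiling set implied by a step-based budget."""
    return train_steps * batch_size / num_profiling_traces

# Assumption: 50,000 profiling traces; adjust to the size of your dataset.
epochs = steps_to_epochs(train_steps=100_000, batch_size=256, num_profiling_traces=50_000)
print(f"{epochs:.0f} epochs")
```

Under that assumption, 100k steps at batch size 256 amounts to several hundred passes over the data, so the two budgets are not interchangeable.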

## DO NOT USE `final_best_gnn_mamba_teacher.ipynb` ❌
That notebook uses an incorrect configuration, so its results cannot be fairly compared with EstraNet!

## Setup Environment

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Check GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# Add paths
sys.path.append(str(Path.cwd() / 'mamba-gnn-scripts'))
sys.path.append(str(Path.cwd()))

# Configuration
DATA_PATH = 'data/ASCAD.h5'
CHECKPOINT_DIR = 'checkpoints/mamba_gnn_estranet'
RESULT_DIR = 'results'

# Create directories
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

print(f"\n✓ Environment ready")
print(f"  Data path: {DATA_PATH}")
print(f"  Checkpoint dir: {CHECKPOINT_DIR}")
print(f"  Result dir: {RESULT_DIR}")

## Configuration Comparison

### ❌ Old Notebook (final_best_gnn_mamba_teacher.ipynb)
```python
# WRONG CONFIGURATION - DO NOT USE
loss = FocalLoss(gamma=2.5)
optimizer = AdamW(lr=2e-3)
batch_size = 64
epochs = (50, 100)  # anywhere from 50 to 100 epochs
scheduler = OneCycleLR(max_lr=2e-3)
evaluation = single_trial_key_rank()
```

### ✅ New Configuration (This Notebook)
```python
# CORRECT CONFIGURATION - Matches EstraNet
loss = CrossEntropyLoss()
optimizer = Adam(lr=2.5e-4)
train_batch_size, eval_batch_size = 256, 32
train_steps = 100000
scheduler = CosineLRSchedule(max_lr=2.5e-4)
evaluation = compute_ge_key_rank(num_trials=100)
```
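
The exact form of the cosine schedule is not shown in this notebook. As a minimal sketch, assuming a linear warmup over `warmup_steps` followed by cosine decay to zero at `train_steps`, the learning rate at a given step could be computed as below. The function name `cosine_lr` and the decay-to-zero endpoint are assumptions; the training script may differ.

```python
import math

def cosine_lr(step, max_lr=2.5e-4, warmup_steps=1000, total_steps=100_000):
    """Learning rate at `step`: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at zero past the end of training
    return 0.5 * max_lr * (1.0 + math.cos(math.pi * progress))

for step in (0, 500, 1000, 50_500, 100_000):
    print(f"step {step:>7,}: lr = {cosine_lr(step):.2e}")
```

With these defaults the rate peaks at 2.5e-4 at step 1,000 and is half of that midway through the decay phase.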

## Option 1: Train Using PyTorch Script (Recommended)

This uses the corrected `train_mamba_gnn.py` script with EstraNet-aligned configuration.

In [None]:
# Training configuration (EstraNet-aligned)
config = {
    'data_path': DATA_PATH,
    'checkpoint_dir': CHECKPOINT_DIR,
    'target_byte': 2,
    'train_batch_size': 256,      # ✓ Matches EstraNet
    'eval_batch_size': 32,        # ✓ Matches EstraNet
    'train_steps': 100000,        # ✓ Matches EstraNet
    'learning_rate': 2.5e-4,      # ✓ Matches EstraNet
    'd_model': 128,               # ✓ Matches EstraNet
    'mamba_layers': 4,
    'gnn_layers': 3,
    'k_neighbors': 8,
    'dropout': 0.1,               # ✓ Matches EstraNet
    'iterations': 500,
    'eval_steps': 500,
    'save_steps': 10000,
    'clip': 0.25,                 # ✓ Matches EstraNet
    'warmup_steps': 1000,
}

print("✓ Configuration (EstraNet-aligned):")
for key, value in config.items():
    print(f"  {key:20s}: {value}")

In [None]:
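# Build training command.
# Note: iterations, eval_steps, save_steps, clip, and warmup_steps from `config`
# are not passed below, so the script's own defaults apply to them.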
train_cmd = f"""python mamba-gnn-scripts/train_mamba_gnn.py \
    --data_path={config['data_path']} \
    --checkpoint_dir={config['checkpoint_dir']} \
    --target_byte={config['target_byte']} \
    --train_batch_size={config['train_batch_size']} \
    --eval_batch_size={config['eval_batch_size']} \
    --train_steps={config['train_steps']} \
    --learning_rate={config['learning_rate']} \
    --d_model={config['d_model']} \
    --mamba_layers={config['mamba_layers']} \
    --gnn_layers={config['gnn_layers']} \
    --k_neighbors={config['k_neighbors']} \
    --dropout={config['dropout']} \
    --do_train
"""

print("Training command:")
print(train_cmd)
print("\n✓ Run the cell below to start training")

In [None]:
# Execute training
import subprocess
import sys

print("="*80)
print("Starting Mamba-GNN Training (EstraNet-Aligned Configuration)")
print("="*80)

# Run training script
result = subprocess.run(
    train_cmd.split(),
    capture_output=False,
    text=True,
    shell=False
)

if result.returncode == 0:
    print("\n" + "="*80)
    print("✓ Training completed successfully!")
    print("="*80)
else:
    print("\n" + "="*80)
    print("✗ Training failed with error code:", result.returncode)
    print("="*80)
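
Splitting the command string on whitespace breaks if any path contains a space. One alternative, sketched below, is to build the argument list directly from a dictionary. `build_argv` is a hypothetical helper, and the example path is made up to show an argument containing a space.

```python
import sys

def build_argv(script, options, flags=()):
    """Return an argv list: interpreter, script, --key=value pairs, bare flags."""
    argv = [sys.executable, script]
    argv += [f"--{key}={value}" for key, value in options.items()]
    argv += [f"--{flag}" for flag in flags]
    return argv

# Hypothetical usage with a subset of the options from `config`.
example_options = {"data_path": "data/my traces/ASCAD.h5", "target_byte": 2}
argv = build_argv("mamba-gnn-scripts/train_mamba_gnn.py", example_options, flags=["do_train"])
print(argv[1:])
```

The resulting list can be passed straight to `subprocess.run(argv)`. Using `sys.executable` also ensures the script runs under the same interpreter as the notebook kernel.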

## Option 2: Train Using TensorFlow (for TFLite Conversion)

If you need to deploy on mobile/edge devices, use the TensorFlow version.

In [None]:
# TensorFlow training command
tf_train_cmd = f"""python scripts/train_mamba_gnn_tf.py \
    --data_path={config['data_path']} \
    --checkpoint_dir=checkpoints/mamba_gnn_tf \
    --target_byte={config['target_byte']} \
    --train_batch_size={config['train_batch_size']} \
    --eval_batch_size={config['eval_batch_size']} \
    --train_steps={config['train_steps']} \
    --learning_rate={config['learning_rate']} \
    --d_model={config['d_model']} \
    --do_train
"""

print("TensorFlow training command:")
print(tf_train_cmd)
print("\n✓ This will automatically export to TFLite after training")
print("  Output: checkpoints/mamba_gnn_tf/mamba_gnn.tflite")

In [None]:
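# Execute TensorFlow training (uncomment to run)
# import subprocess
# result = subprocess.run(tf_train_cmd.split(), capture_output=False)
# if result.returncode == 0:
#     print("✓ Training complete. TFLite model saved.")
# else:
#     print("✗ Training failed with error code:", result.returncode)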

## Monitor Training Progress

View training loss, learning rate, and gradient norms.

In [None]:
import pickle
from pathlib import Path

# Load training history
loss_path = Path(CHECKPOINT_DIR) / 'loss.pkl'

if loss_path.exists():
    with open(loss_path, 'rb') as f:
        history = pickle.load(f)
    
    steps = sorted(history.keys())

    def metric_series(key):
        """Return (steps, values) for only the steps that logged `key`, kept aligned."""
        logged = [s for s in steps if key in history[s]]
        return logged, [history[s][key] for s in logged]

    loss_steps, train_losses = metric_series('train_loss')
    lr_steps, lrs = metric_series('lr')
    grad_steps, grad_norms = metric_series('grad_norm')
    
    # Plot training metrics
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Training loss
    axes[0].plot(loss_steps, train_losses, 'b-', linewidth=2)
    axes[0].set_xlabel('Training Steps', fontsize=12)
    axes[0].set_ylabel('Loss (Cross-Entropy)', fontsize=12)
    axes[0].set_title('Training Loss', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)
    
    # Learning rate
    axes[1].plot(lr_steps, lrs, 'g-', linewidth=2)
    axes[1].set_xlabel('Training Steps', fontsize=12)
    axes[1].set_ylabel('Learning Rate', fontsize=12)
    axes[1].set_title('Learning Rate Schedule (Cosine Decay)', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3)
    
    # Gradient norm
    axes[2].plot(grad_steps, grad_norms, 'r-', linewidth=2)
    axes[2].set_xlabel('Training Steps', fontsize=12)
    axes[2].set_ylabel('Gradient Norm', fontsize=12)
    axes[2].set_title('Gradient Norm (Clipped at 0.25)', fontsize=14, fontweight='bold')
    axes[2].axhline(y=0.25, color='orange', linestyle='--', label='Clip threshold')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'{RESULT_DIR}/training_progress.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\n✓ Training Progress:")
    print(f"  Total steps: {max(steps):,}")
    print(f"  Latest loss: {train_losses[-1]:.4f}")
    print(f"  Latest LR: {lrs[-1]:.6f}")
    print(f"  Latest grad norm: {grad_norms[-1]:.4f}")
    print(f"\n  Plot saved: {RESULT_DIR}/training_progress.png")
else:
    print("⚠ No training history found. Start training first.")
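
For reading the gradient-norm panel, it helps to recall what clipping at 0.25 does. The sketch below shows clipping by global L2 norm on plain Python lists. It assumes the training script clips the combined norm of all gradients, which this notebook does not confirm, and `clip_by_global_norm` is a hypothetical name.

```python
import math

def clip_by_global_norm(grads, max_norm=0.25):
    """Scale all gradients so their combined L2 norm is at most `max_norm`.

    grads: list of lists of floats (one inner list per parameter tensor).
    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / total_norm) if total_norm > 0 else 1.0
    clipped = [[g * scale for g in tensor] for tensor in grads]
    return clipped, total_norm

clipped, norm_before = clip_by_global_norm([[3.0], [4.0]], max_norm=0.25)
norm_after = math.sqrt(sum(g * g for tensor in clipped for g in tensor))
print(f"before: {norm_before:.2f}, after: {norm_after:.2f}")
```

If the script logs the value returned by `torch.nn.utils.clip_grad_norm_`, that value is the norm before clipping, so the plotted curve can legitimately sit above the 0.25 line.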

## Evaluate Model (Guessing Entropy)

Evaluate using 100-trial Guessing Entropy methodology (matches EstraNet).

In [None]:
# Evaluation configuration
checkpoint_idx = 100000  # Which checkpoint to evaluate (0 = latest)

eval_cmd = f"""python mamba-gnn-scripts/train_mamba_gnn.py \
    --data_path={config['data_path']} \
    --checkpoint_dir={config['checkpoint_dir']} \
    --target_byte={config['target_byte']} \
    --d_model={config['d_model']} \
    --mamba_layers={config['mamba_layers']} \
    --gnn_layers={config['gnn_layers']} \
    --k_neighbors={config['k_neighbors']} \
    --dropout={config['dropout']} \
    --checkpoint_idx={checkpoint_idx} \
    --result_path={RESULT_DIR}/mamba_gnn_eval
"""

print("Evaluation command:")
print(eval_cmd)
print("\n✓ This will compute 100-trial Guessing Entropy")
print("  (Takes ~10-15 minutes)")

In [None]:
# Execute evaluation
import subprocess

print("="*80)
print("Evaluating Mamba-GNN (100-trial Guessing Entropy)")
print("="*80)

result = subprocess.run(eval_cmd.split(), capture_output=False)

if result.returncode == 0:
    print("\n✓ Evaluation complete!")
else:
    print("\n✗ Evaluation failed")
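
The evaluation script's internals are not shown here. As a minimal sketch of multi-trial guessing entropy, the function below shuffles the attack traces in each trial, accumulates per-key log-likelihoods, and records the rank of the true key after every trace. It assumes the model's 256 class outputs have already been mapped to key candidates (for ASCAD this typically goes through the S-box and the plaintext byte, which is omitted). `guessing_entropy` and its arguments are hypothetical names.

```python
import numpy as np

def guessing_entropy(log_probs, true_key, num_trials=100, seed=0):
    """Mean and std of the true key's rank after each trace, over shuffled trials.

    log_probs: array of shape (n_traces, n_keys); entry [i, k] is the
               log-likelihood that trace i assigns to key candidate k.
    """
    rng = np.random.default_rng(seed)
    n_traces = log_probs.shape[0]
    ranks = np.empty((num_trials, n_traces))
    for t in range(num_trials):
        order = rng.permutation(n_traces)
        scores = np.cumsum(log_probs[order], axis=0)  # accumulated score per key
        # Rank 0 means no other candidate scores higher than the true key.
        ranks[t] = (scores > scores[:, [true_key]]).sum(axis=1)
    return ranks.mean(axis=0), ranks.std(axis=0)

# Synthetic example: 4 key candidates, candidate 2 is always the most likely.
demo = np.zeros((50, 4))
demo[:, 2] = 1.0
mean_rank, std_rank = guessing_entropy(demo, true_key=2, num_trials=5)
print(mean_rank[:3], std_rank[:3])
```

Averaging over many shuffles is what distinguishes this from a single-trial key rank, which depends on one arbitrary trace ordering.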

## Plot Guessing Entropy Curve

In [None]:
# Load and plot evaluation results
result_file = f'{RESULT_DIR}/mamba_gnn_eval.txt'

if Path(result_file).exists():
    with open(result_file, 'r') as f:
        lines = f.readlines()
        mean_ranks = np.array([float(x) for x in lines[0].strip().split('\t')])
        std_ranks = np.array([float(x) for x in lines[1].strip().split('\t')])
    
    traces = np.arange(1, len(mean_ranks) + 1)
    
    # Create plot
    plt.figure(figsize=(14, 7))
    
    # Main line
    plt.plot(traces, mean_ranks, 'b-', linewidth=2.5, label='Mean Key Rank (GE)')
    
    # Confidence interval
    plt.fill_between(traces, 
                     mean_ranks - std_ranks, 
                     mean_ranks + std_ranks,
                     alpha=0.3, color='blue', label='±1 Std Dev')
    
    # Key recovered line
    plt.axhline(y=0, color='r', linestyle='--', linewidth=2, label='Key Recovered (Rank=0)')
    
    # Labels and formatting
    plt.xlabel('Number of Traces', fontsize=14)
    plt.ylabel('Key Rank (Guessing Entropy)', fontsize=14)
    plt.title('Mamba-GNN: Guessing Entropy Evaluation (100 trials)\nEstraNet-Aligned Configuration', 
              fontsize=16, fontweight='bold')
    plt.legend(fontsize=12, loc='upper right')
    plt.grid(True, alpha=0.3)
    
    # Add milestones
    milestones = [100, 500, 1000, 2000, 5000]
    for m in milestones:
        if m < len(mean_ranks):
            plt.axvline(x=m, color='gray', linestyle=':', alpha=0.5)
            plt.text(m, plt.ylim()[1] * 0.95, f'{m}', 
                    ha='center', fontsize=9, color='gray')
    
    plt.tight_layout()
    plt.savefig(f'{RESULT_DIR}/guessing_entropy_curve.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    # Print summary statistics
    print("="*80)
    print("GUESSING ENTROPY RESULTS")
    print("="*80)
    print(f"\nTarget byte: {config['target_byte']}")
    print(f"\nKey Rank (Mean ± Std):")
    for n in (100, 500, 1000, 2000, 5000, 10000):
        # Skip milestones beyond the number of evaluated traces
        if len(mean_ranks) >= n:
            label = f"  {n} traces:"
            print(f"{label:<16}{mean_ranks[n-1]:.2f} ± {std_ranks[n-1]:.2f}")
    
    # Find recovery point
    recovered_idx = np.where(mean_ranks == 0)[0]
    if len(recovered_idx) > 0:
        print(f"\n✓ Key RECOVERED at {recovered_idx[0]+1} traces")
    else:
        print(f"\n✗ Key NOT recovered (final rank: {mean_ranks[-1]:.2f}, best rank: {mean_ranks.min():.2f})")
    
    print("\n" + "="*80)
    print(f"Plot saved: {RESULT_DIR}/guessing_entropy_curve.png")
    print("="*80)
else:
    print("⚠ No evaluation results found. Run evaluation first.")
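
The recovery point above is the first trace count at which the mean rank reaches zero, but a mean rank can rise again afterwards. A stricter definition, sketched below, is the smallest trace count from which the mean rank stays at or below a threshold. `traces_to_recovery` is a hypothetical helper, and the threshold is an assumption to tune.

```python
def traces_to_recovery(mean_ranks, threshold=0.0):
    """Smallest trace count from which the mean rank stays <= threshold, else None."""
    last_bad = -1  # index of the last entry above the threshold
    for i, rank in enumerate(mean_ranks):
        if rank > threshold:
            last_bad = i
    if last_bad == len(mean_ranks) - 1:
        return None  # never settles (also covers an empty input)
    return last_bad + 2  # next index, converted to a 1-based trace count

print(traces_to_recovery([5, 2, 0, 1, 0, 0]))
```

In this example the rank first touches zero at 3 traces but only stays there from 5 traces onward, so the function reports 5.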

## Compare with EstraNet

Compare Mamba-GNN with EstraNet Transformer/GNN models.

In [None]:
# Compare with EstraNet results
estranet_result = 'results/estranet_transformer_eval.txt'  # Update this path

compare_cmd = f"""python scripts/compare_results.py \
    --mamba_results={result_file} \
    --estranet_results={estranet_result} \
    --output={RESULT_DIR}/model_comparison.png
"""

print("Comparison command:")
print(compare_cmd)
print("\n✓ Make sure you have EstraNet results first:")
print("  python scripts/train_trans.py --model_type=transformer --do_train")

In [None]:
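# Execute comparison (uncomment when EstraNet results are ready)
# import subprocess
# result = subprocess.run(compare_cmd.split())
# if result.returncode == 0:
#     print("✓ Comparison plot created")
# else:
#     print("✗ Comparison failed with error code:", result.returncode)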

## Load and Use Trained Model

Load checkpoint for inference or further analysis.

In [None]:
import torch
import h5py
from sklearn.preprocessing import StandardScaler

# Import model
sys.path.append('models')
from mamba_gnn_model import OptimizedMambaGNN

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model (EstraNet-aligned configuration)
model = OptimizedMambaGNN(
    trace_length=700,
    d_model=config['d_model'],
    mamba_layers=config['mamba_layers'],
    gnn_layers=config['gnn_layers'],
    num_classes=256,
    k_neighbors=config['k_neighbors'],
    dropout=config['dropout']
).to(device)

# Load checkpoint
ckpt_path = f"{CHECKPOINT_DIR}/mamba_gnn-100000.pth"

if Path(ckpt_path).exists():
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    total_params = sum(p.numel() for p in model.parameters())
    
    print(f"\n✓ Model loaded successfully")
    print(f"  Checkpoint: {ckpt_path}")
    print(f"  Training step: {checkpoint['global_step']:,}")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Model ready for inference")
else:
    print(f"⚠ Checkpoint not found: {ckpt_path}")
    print("  Train the model first or point ckpt_path at a different checkpoint")

In [None]:
# Example: Get predictions on attack traces
if Path(ckpt_path).exists() and Path(DATA_PATH).exists():
    print("Running inference on sample traces...")
    
    # Load attack traces and a profiling sample for normalization
    with h5py.File(DATA_PATH, 'r') as f:
        X_attack = f['Attack_traces/traces'][:100]  # First 100 traces
        X_train_sample = f['Profiling_traces/traces'][:1000]
    
    # Normalize. The scaler is re-fit here on the first 1000 profiling traces,
    # so it matches training only if the training script fit its scaler the same way.
    scaler = StandardScaler()
    scaler.fit(X_train_sample)
    X_attack_norm = scaler.transform(X_attack)
    
    # Get predictions (cast to float32 to match the model's parameters)
    X_tensor = torch.as_tensor(X_attack_norm, dtype=torch.float32).to(device)
    
    with torch.no_grad():
        logits = model(X_tensor)
        probs = torch.softmax(logits, dim=1)
    
    predicted_classes = torch.argmax(probs, dim=1).cpu().numpy()
    
    print(f"\n✓ Inference complete")
    print(f"  Traces processed: {len(X_attack)}")
    print(f"  Prediction shape: {probs.shape}")
    print(f"  Top predicted classes: {predicted_classes[:10]}")
    print(f"  Confidence (first trace): {probs[0].max().item():.4f}")
else:
    print("⚠ Model or data not available")
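
The normalization step matters because the attack traces must be scaled with statistics taken from profiling data, not from the attack set itself. A minimal sketch of the per-sample-point standardization that `StandardScaler` performs by default is shown below, using synthetic data in place of real traces. `fit_standardizer` and `apply_standardizer` are hypothetical names.

```python
import numpy as np

def fit_standardizer(X):
    """Per-column mean and std; a zero std is replaced by 1 to avoid dividing by zero."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std = np.where(std == 0, 1.0, std)
    return mean, std

def apply_standardizer(X, mean, std):
    """Standardize X with previously fitted statistics."""
    return (X - mean) / std

rng = np.random.default_rng(0)
profiling = rng.normal(loc=10.0, scale=3.0, size=(1000, 700))  # synthetic stand-in
attack = rng.normal(loc=10.0, scale=3.0, size=(100, 700))      # synthetic stand-in
mean, std = fit_standardizer(profiling)
attack_norm = apply_standardizer(attack, mean, std)
print(attack_norm.shape)
```

Saving `mean` and `std` alongside the checkpoint would let inference reuse exactly the statistics seen during training.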

## Summary & Configuration Verification

**✅ This notebook uses the CORRECT configuration:**

In [None]:
# Summarize the EstraNet-aligned settings
# (a hard-coded table for reference; it does not inspect `config` at runtime)
print("="*80)
print("CONFIGURATION VERIFICATION")
print("="*80)

verification = {
    'Loss Function': ('Cross-Entropy', '✓'),
    'Learning Rate': ('2.5e-4', '✓'),
    'Train Batch Size': ('256', '✓'),
    'Eval Batch Size': ('32', '✓'),
    'Training Steps': ('100,000', '✓'),
    'Optimizer': ('Adam', '✓'),
    'LR Schedule': ('Cosine Decay', '✓'),
    'Model Dimension': ('128', '✓'),
    'Dropout': ('0.1', '✓'),
    'Gradient Clipping': ('0.25', '✓'),
    'Evaluation Method': ('100-trial GE', '✓'),
}

print("\nEstraNet Alignment Check:")
for param, (value, status) in verification.items():
    print(f"  {status} {param:25s}: {value}")

print("\n" + "="*80)
print("✓ LISTED PARAMETERS MATCH ESTRANET")
print("✓ SETTINGS ALIGNED FOR A FAIR COMPARISON")
print("="*80)

print("\n⚠️ DO NOT USE: final_best_gnn_mamba_teacher.ipynb")
print("   That notebook has an incompatible configuration:")
print("   - FocalLoss instead of Cross-Entropy")
print("   - Learning rate 8x too high")
print("   - Batch size 4x too small")
print("   - Single-trial evaluation instead of 100-trial GE")

print("\n✓ USE THIS NOTEBOOK for training and evaluation")
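
A printed summary cannot catch a setting that was edited after the fact. As a sketch of a runtime check, the function below compares a config dictionary against the reference values listed in this notebook and reports any differences. `ESTRANET_REFERENCE` and `find_mismatches` are hypothetical names, and the reference values are simply those quoted above.

```python
import math

# Reference values as listed in this notebook for EstraNet.
ESTRANET_REFERENCE = {
    "train_batch_size": 256,
    "eval_batch_size": 32,
    "train_steps": 100_000,
    "learning_rate": 2.5e-4,
    "d_model": 128,
    "dropout": 0.1,
    "clip": 0.25,
}

def find_mismatches(config, reference=ESTRANET_REFERENCE):
    """Return {key: (actual, expected)} for every reference key that differs."""
    mismatches = {}
    for key, expected in reference.items():
        actual = config.get(key)
        same = (
            isinstance(actual, (int, float))
            and math.isclose(actual, expected, rel_tol=1e-9)
        )
        if not same:
            mismatches[key] = (actual, expected)
    return mismatches

# Example: the old notebook's learning rate and batch size are flagged.
old_style = dict(ESTRANET_REFERENCE, learning_rate=2e-3, train_batch_size=64)
print(find_mismatches(old_style))
```

Calling `find_mismatches(config)` before training and stopping on a non-empty result would turn the alignment claim into an enforced check.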

## Next Steps

### 1. Train Model
Run the training cell above to start training with the correct configuration.

### 2. Monitor Progress
Check training loss, learning rate, and gradient norms periodically.

### 3. Evaluate Model
After training, run evaluation to get Guessing Entropy curves.

### 4. Compare with EstraNet
Train EstraNet models and compare results using the comparison script.

### 5. Expected Results
- **After 100k steps**: Loss ~4.8-5.0
- **Key recovery**: Within 1000-2000 traces
- **Performance**: Comparable to EstraNet Transformer (~1200 traces)

### Files Created
- **Training script**: `mamba-gnn-scripts/train_mamba_gnn.py`
- **TensorFlow script**: `scripts/train_mamba_gnn_tf.py`
- **PowerShell runner**: `mamba-gnn-scripts/train_mamba_gnn.ps1`
- **Bash runner**: `mamba-gnn-scripts/train_mamba_gnn.sh`
- **Comparison tool**: `scripts/compare_results.py`
- **This notebook**: `train_mamba_gnn_notebook.ipynb`

### Documentation
- **Quick start**: `QUICKSTART.md`
- **TFLite guide**: `TENSORFLOW_TFLITE_GUIDE.md`
- **TF implementation**: `TENSORFLOW_IMPLEMENTATION_SUMMARY.md`
- **Config comparison**: `NOTEBOOK_VS_SCRIPT_COMPARISON.md`

**Ready to train with fair comparison to EstraNet!** 🎯