# Mamba-GNN — Cross‑Entropy (Phase 1)
# EstraNet‑aligned Cross‑Entropy training, evaluation (100‑trial GE), plotting, and export.

"""
Purpose: Phase‑1 (fair comparison) notebook using Cross‑Entropy loss.
- EstraNet‑aligned defaults
- Training cell (safe-run toggle)
- 100‑trial Guessing Entropy evaluation
- Plots + export utilities
"""

In [None]:
# Section 1 — Clone repository & workspace cleanup
import os, subprocess, shutil
from pathlib import Path

FORCE_RECLONE = False
REPO_URL = 'https://github.com/loshithan/EstraNet.git'
REPO_DIR = Path('EstraNet')

# Re-running this cell after it has already chdir'ed into the checkout would
# find no ./EstraNet below the new cwd and clone a nested EstraNet/EstraNet copy.
# When the cwd is itself an EstraNet git checkout, step back to its parent first
# so the cell behaves the same on every run.
if Path.cwd().name == REPO_DIR.name and (Path.cwd() / '.git').exists():
    os.chdir(Path.cwd().parent)
    print('Already inside an EstraNet checkout — stepping back to its parent')

ROOT = Path.cwd()
print(f"Workspace cwd: {ROOT}")

# Safe behavior: only remove/re-clone if FORCE_RECLONE=True
if FORCE_RECLONE and REPO_DIR.exists():
    print('Removing existing EstraNet folder (FORCE_RECLONE=True)')
    shutil.rmtree(REPO_DIR)

if not REPO_DIR.exists():
    print('Cloning EstraNet repository...')
    # check=True: a failed clone raises CalledProcessError instead of passing silently
    subprocess.run(['git', 'clone', REPO_URL], check=True)
else:
    print('EstraNet already present — skipping clone')

# If cloned, change into the repo directory
if REPO_DIR.exists():
    os.chdir(REPO_DIR)
    print(f"Changed cwd -> {os.getcwd()}")
else:
    print('Warning: EstraNet folder not found in workspace — ensure you are in the project root')

In [None]:
# Section 2 — Setup environment, packages and project paths
import sys
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Report interpreter and torch versions (nothing is installed from here)
import torch
print(f"python: {sys.version.split()[0]}, torch: {torch.__version__}")

# Make the project root and its script folder importable
ROOT = Path.cwd()
sys.path.extend([str(ROOT / 'mamba-gnn-scripts'), str(ROOT)])

# Working folders used by later cells
for _folder in ('checkpoints/mamba_gnn_estranet', 'results', 'logs'):
    os.makedirs(_folder, exist_ok=True)

print('\n✓ Environment folders created')

In [None]:
# Section 3 — Verify GPU availability and device config
import torch

# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    try:
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
    except Exception as exc:
        # Best-effort diagnostic only: keep going, but say why the name is missing
        # instead of hiding the failure.
        print(f"Could not query CUDA device name: {exc}")

In [None]:
# Section 4 — Define EstraNet‑aligned training configuration
# Training hyper-parameters shared by every later cell (insertion order is kept).
config = dict(
    data_path='data/ASCAD.h5',
    checkpoint_dir='checkpoints/mamba_gnn_estranet',
    target_byte=2,
    train_batch_size=256,
    eval_batch_size=32,
    train_steps=50000,
    learning_rate=2.5e-4,
    d_model=64,
    mamba_layers=2,
    gnn_layers=2,
    k_neighbors=8,
    eval_steps=250,
    save_steps=5000,
    warmup_steps=1000,
    dropout=0.2,
    weight_decay=0.01,
    label_smoothing=0.02,
    augment_noise=0.0,
    augment_shift=0,
    early_stopping=30,
    loss_function='cross_entropy',
    # focal params (kept for Phase‑2 but unused here)
    focal_gamma=2.5,
    focal_alpha=1.0,
)

print('EstraNet-aligned config (Phase 1 - CrossEntropy):')
# One "name: value" row per setting, names padded to 20 columns
print('\n'.join(f"  {name:20s}: {value}" for name, value in config.items()))

In [None]:
# Section 5 — Configuration comparison (wrong vs correct)
# Show a known-bad setup next to the values actually held in `config`.
print(
    '--- Incorrect (example) ---',
    'loss=FocalLoss, lr=2e-3, batch_size=64, single-trial eval (WRONG)',
    sep='\n',
)

print('\n--- Correct (EstraNet‑aligned) ---')
_aligned_keys = (
    'loss_function',
    'learning_rate',
    'train_batch_size',
    'eval_batch_size',
    'train_steps',
    'd_model',
)
print('\n'.join(f"{key:20s}: {config.get(key)}" for key in _aligned_keys))

In [None]:
# Section 6 — Toggle loss function (Phase 1 vs Phase 2)
# Default is Phase 1: Cross-Entropy (safe/fair comparison)
LOSS_FUNCTION = 'cross_entropy'  # change to 'focal_loss' to prepare Phase 2 cells

# Both spellings of the focal option select the same mode.
_wants_focal = LOSS_FUNCTION in ('focal', 'focal_loss')
if not _wants_focal:
    config['loss_function'] = 'cross_entropy'
    print('Using Cross-Entropy (Phase 1)')
else:
    # Focal mode also switches label smoothing off
    config.update(loss_function='focal', label_smoothing=0.0)
    print('Switched to FocalLoss mode (label_smoothing disabled)')

print('loss_function ->', config['loss_function'])

In [None]:
# Section 7 — Build CLI args and checkpoint/result directories
from pathlib import Path

# Ensure checkpoint dir matches phase
PHASE1_CKPT = Path('checkpoints/mamba_gnn_estranet')
PHASE2_CKPT = Path('checkpoints/mamba_gnn_phase2_focal')
# Both phase folders are created up front, whichever phase is active
for _phase_dir in (PHASE1_CKPT, PHASE2_CKPT):
    _phase_dir.mkdir(parents=True, exist_ok=True)

# Point the config at the folder that belongs to the selected loss
if config['loss_function'] == 'cross_entropy':
    config['checkpoint_dir'] = str(PHASE1_CKPT)
else:
    config['checkpoint_dir'] = str(PHASE2_CKPT)

# Build CLI argument string from config (selective)
def build_train_cmd(cfg):
    """Return the training command line for ``cfg`` as a single string.

    The command always starts with
    ``python mamba-gnn-scripts/train_mamba_gnn.py --do_train --loss_type=<loss_function>``
    followed by ``--key=value`` for a fixed selection of config keys. Keys that
    are absent from ``cfg`` or set to ``None`` are skipped.

    The string is meant to be handed to a shell, so on POSIX every argument is
    shell-quoted; arguments made only of safe characters come out unchanged.

    Args:
        cfg: Mapping of configuration names to values. Must contain
            ``'loss_function'``.

    Returns:
        The space-joined command string.

    Raises:
        KeyError: If ``cfg`` has no ``'loss_function'`` entry.
    """
    import os
    import shlex

    forwarded_keys = (
        'data_path', 'checkpoint_dir', 'train_steps', 'train_batch_size',
        'eval_batch_size', 'learning_rate', 'd_model', 'mamba_layers',
        'gnn_layers', 'k_neighbors', 'dropout', 'weight_decay',
        'label_smoothing', 'save_steps', 'eval_steps',
    )
    parts = [
        'python',
        'mamba-gnn-scripts/train_mamba_gnn.py',
        '--do_train',
        f"--loss_type={cfg['loss_function']}",
    ]
    parts.extend(
        f"--{key}={cfg[key]}" for key in forwarded_keys if cfg.get(key) is not None
    )

    if os.name == 'nt':
        # shlex.quote emits POSIX quoting, which cmd.exe does not understand;
        # keep the plain join there so existing Windows paths keep working.
        return ' '.join(parts)
    # Quote each argument so paths with spaces or shell metacharacters stay one token
    return ' '.join(shlex.quote(part) for part in parts)

# Show the command that the training cell would launch (nothing is executed here)
train_cmd_preview = build_train_cmd(config)
print('Training command preview:', train_cmd_preview, sep='\n')


In [None]:
# Section 8 — Run training with live logging and checkpoint monitor
import subprocess
import time
from datetime import datetime
from pathlib import Path

RUN_TRAIN = False  # set True to actually start training from this cell
LOG_DIR = Path('logs')
LOG_DIR.mkdir(exist_ok=True)

if RUN_TRAIN:
    cmd = build_train_cmd(config)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = LOG_DIR / f"train_ce_{timestamp}.log"
    print('Starting training — logging to', log_file)

    # Unbuffered python so output streams live. Only the leading interpreter
    # token is rewritten; later occurrences of 'python ' (e.g. in a path) stay intact.
    unbuffered_cmd = cmd.replace('python ', 'python -u ', 1)
    with open(log_file, 'w') as f:
        # The context manager closes the pipe and waits for the child on exit.
        with subprocess.Popen(
            unbuffered_cmd,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as process:
            # Iterating to EOF drains everything the child wrote, including
            # lines still sitting in the pipe after the process has exited.
            for line in process.stdout:
                print(line, end='')
                f.write(line)
                f.flush()
    print('Training finished with returncode', process.returncode)
else:
    print('RUN_TRAIN is False — set to True to start training from this cell.')

# One-shot report of the most recently modified checkpoint, if any
ckpt_dir = Path(config['checkpoint_dir'])
ckpts = sorted(ckpt_dir.glob('*.pth'), key=lambda p: p.stat().st_mtime)
if not ckpts:
    print('No checkpoints found yet in', ckpt_dir)
else:
    latest = ckpts[-1]  # sorted by mtime ascending, so the last entry is the newest
    _size_mb = latest.stat().st_size / (1024 * 1024)
    print('Latest checkpoint:', latest.name, ' — size:', _size_mb, 'MB')

In [None]:
# Section 9 — Patch training script to fix KeyError in `loss_history`
script_path = Path('mamba-gnn-scripts') / 'train_mamba_gnn.py'
pattern = 'loss_history[global_step].update({'
replacement = 'loss_history.setdefault(global_step, {}).update({'

# Rewrite the script in place only when the original indexing expression is present.
if not script_path.exists():
    print('train_mamba_gnn.py not found at', script_path)
else:
    text = script_path.read_text()
    if pattern not in text:
        print('Pattern not found — file may already be patched or is different')
    else:
        text = text.replace(pattern, replacement)
        script_path.write_text(text)
        print('Patched train_mamba_gnn.py — replaced loss_history update with setdefault variant')

In [None]:
# Section 10 — Interactive checkpoint directory monitor (run in separate cell/session)
import time
from pathlib import Path

# Explicit on/off switch, in the same spirit as RUN_TRAIN / RUN_EVAL, so live
# polling is enabled by flipping a flag rather than editing the loop condition.
WATCH_CHECKPOINTS = False  # set True to poll the checkpoint folder continuously
POLL_SECONDS = 5

watch_dir = Path(config['checkpoint_dir'])
seen = set()  # names of checkpoints already reported

if WATCH_CHECKPOINTS:
    print(f"Watching checkpoints in: {watch_dir} (Ctrl+C to stop)")
    try:
        while True:
            for f in sorted(watch_dir.glob('*.pth')):
                if f.name in seen:
                    continue
                print(f"[{time.strftime('%H:%M:%S')}] New checkpoint: {f.name} ({f.stat().st_size/1e6:.1f} MB)")
                seen.add(f.name)
            time.sleep(POLL_SECONDS)
    except KeyboardInterrupt:
        print('Monitor stopped')
else:
    # Nothing is being watched, so say that rather than announcing a monitor.
    print('Note: WATCH_CHECKPOINTS is False — set it to True to enable live polling in this cell (not recommended inside long-running notebooks).')

In [None]:
# Section 11 — Optional: TensorFlow training + TFLite export (template)
# Template only — uncomment to run TensorFlow training script (if available)
# tf_cmd = (
#     "python scripts/train_mamba_gnn_tf.py "
#     f"--data_path={config['data_path']} "
#     "--checkpoint_dir=checkpoints/mamba_gnn_tf "
#     f"--train_batch_size={config['train_batch_size']} "
#     f"--eval_batch_size={config['eval_batch_size']} "
#     "--train_steps=50000 --do_train"
# )
# print('TF command (template):')
# print(tf_cmd)


In [None]:
# Section 12 — Plot training metrics (loss, LR, gradient norms)
import pickle
from pathlib import Path
import matplotlib.pyplot as plt

loss_path = Path(config['checkpoint_dir']) / 'loss.pkl'
if not loss_path.exists():
    print('No training history found at', loss_path)
else:
    with open(loss_path, 'rb') as f:
        history = pickle.load(f)

    # One series per metric, ordered by step; entries lacking a metric become None
    steps = sorted(history.keys())
    train_losses = [history[s].get('train_loss') for s in steps]
    lrs = [history[s].get('lr') for s in steps]
    grad_norms = [history[s].get('grad_norm') for s in steps]

    fig, axes = plt.subplots(1, 3, figsize=(18, 4))
    _panels = (
        (train_losses, 'train_loss', 'Train Loss'),
        (lrs, 'lr', 'LR'),
        (grad_norms, 'grad_norm', 'Grad Norm'),
    )
    for _ax, (_series, _label, _title) in zip(axes, _panels):
        _ax.plot(steps, _series, label=_label)
        _ax.set_title(_title)

    plt.tight_layout()
    out = Path('results') / 'training_progress.png'
    plt.savefig(out, dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved:', out)

In [None]:
# Section 13 — Evaluate model with 100-trial Guessing Entropy
# By default this will evaluate the latest checkpoint (checkpoint_idx=0)
RUN_EVAL = False  # set True to run evaluation from this cell
checkpoint_idx = 0  # 0 -> latest
result_path = Path('results') / 'cross_entropy_eval.txt'

# Keep the command as an argument list: each option stays a single argv entry
# even when a path contains whitespace, which a later str.split() would break.
eval_args = [
    'python',
    'mamba-gnn-scripts/train_mamba_gnn.py',
    f"--data_path={config['data_path']}",
    f"--checkpoint_dir={config['checkpoint_dir']}",
    f"--d_model={config['d_model']}",
    f"--mamba_layers={config['mamba_layers']}",
    f"--gnn_layers={config['gnn_layers']}",
    f"--k_neighbors={config['k_neighbors']}",
    f"--dropout={config['dropout']}",
    f"--checkpoint_idx={checkpoint_idx}",
    f"--result_path={result_path}",
]
# Human-readable preview only; execution below uses eval_args directly.
eval_cmd = ' '.join(eval_args)

print('Eval command (preview):')
print(eval_cmd)

if RUN_EVAL:
    import subprocess
    rc = subprocess.run(eval_args).returncode
    print('Eval returncode =', rc)
else:
    print('RUN_EVAL is False — set to True to run the 100-trial GE from this cell')

In [None]:
# Section 14 — Plot Guessing Entropy curve and save figure
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

result_file = Path('results') / 'cross_entropy_eval.txt'
if not result_file.exists():
    print('No GE results found at', result_file, '\nRun evaluation first (Section 13)')
else:
    with open(result_file, 'r') as f:
        # Ignore blank lines so a stray empty line does not shift the rows
        lines = [ln for ln in f.read().splitlines() if ln.strip()]

    if len(lines) < 2:
        # The plot needs two rows (read below as mean ranks, then std ranks).
        # A truncated file is reported instead of raising IndexError.
        print('GE results at', result_file, 'are incomplete (expected a mean row and a std row)',
              '\nRe-run evaluation (Section 13)')
    else:
        # split() accepts tabs as well as any other whitespace between values
        mean_ranks = np.array([float(x) for x in lines[0].split()])
        std_ranks = np.array([float(x) for x in lines[1].split()])

        traces = np.arange(1, len(mean_ranks) + 1)
        plt.figure(figsize=(12,6))
        plt.plot(traces, mean_ranks, label='Mean GE', lw=2)
        plt.fill_between(traces, mean_ranks-std_ranks, mean_ranks+std_ranks, alpha=0.25)
        plt.axhline(0, color='red', linestyle='--', label='Key recovered')
        plt.xlabel('Number of traces')
        plt.ylabel('Key rank (GE)')
        plt.title('Guessing Entropy — Cross‑Entropy (Phase 1)')
        plt.legend()
        out = Path('results') / 'guessing_entropy_cross_entropy.png'
        plt.savefig(out, dpi=150, bbox_inches='tight')
        plt.show()
        print('Saved GE plot to', out)

In [None]:
# Section 15 — Compare Mamba-GNN results with EstraNet
estranet_result = Path('results') / 'trans_long-11.txt'  # update if needed
mamba_result = Path('results') / 'cross_entropy_eval.txt'

# Assemble the comparison command piece by piece; it is only previewed here
compare_cmd = ' '.join([
    'python scripts/compare_results.py',
    f'--mamba_results={mamba_result}',
    f'--estranet_results={estranet_result}',
    '--output=results/comparison_ce_vs_estranet.png',
])
print('Preview compare command:', compare_cmd, sep='\n')

# To run uncomment the next lines
# import subprocess
# subprocess.run(compare_cmd.split())
# print('Comparison plot saved to results/comparison_ce_vs_estranet.png')

In [None]:
# Section 16 — Load model for inference (patch imports if needed)
import os
import sys
from pathlib import Path
import torch

# Patch relative imports in model file if they cause import errors
# The rewrite below must happen before the import further down, because it
# edits the very file that import loads.
model_file = Path('models') / 'mamba_gnn_model.py'
if model_file.exists():
    text = model_file.read_text()
    if 'from .mamba_block' in text:
        # Turn the package-relative imports into absolute `models.` imports
        for _module, _symbol in (
            ('mamba_block', 'OptimizedMambaBlock'),
            ('gat_layer', 'EnhancedGAT'),
            ('patch_embedding', 'CNNPatchEmbedding'),
        ):
            text = text.replace(
                f'from .{_module} import {_symbol}',
                f'from models.{_module} import {_symbol}',
            )
        model_file.write_text(text)
        print('Patched relative imports in models/mamba_gnn_model.py')

# Import model class
sys.path.append(str(Path.cwd()))
from models.mamba_gnn_model import OptimizedMambaGNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = OptimizedMambaGNN(
    trace_length=700,
    d_model=config['d_model'],
    mamba_layers=config['mamba_layers'],
    gnn_layers=config['gnn_layers'],
    num_classes=256,
    k_neighbors=config['k_neighbors'],
    dropout=config['dropout'],
).to(device)
# Switch to inference mode unconditionally: the inference cell runs whenever
# `model` exists, so it must not be left in training mode when no checkpoint
# is found.
model.eval()

# Load latest checkpoint (if available)
ckpt_dir = Path(config['checkpoint_dir'])
ckpts = sorted(ckpt_dir.glob('*.pth'), key=lambda p: p.stat().st_mtime)
if ckpts:
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    ckpt = torch.load(ckpts[-1], map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    print('Loaded model from', ckpts[-1].name)
else:
    print('No checkpoint found to load — model has untrained (random) weights')

In [None]:
# Section 17 — Run inference on sample traces
import h5py
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch

data_path = Path(config['data_path'])
# Proceed only when the dataset file is present and a `model` object was created earlier
_can_infer = data_path.exists() and 'model' in globals() and hasattr(model, 'eval')
if not _can_infer:
    print('Data or model not available — ensure checkpoint is loaded and ASCAD dataset exists')
else:
    with h5py.File(data_path, 'r') as f:
        X_attack = f['Attack_traces/traces'][:100]
        X_train_sample = f['Profiling_traces/traces'][:1000]

    # Standardise the attack traces with statistics fitted on the profiling sample
    scaler = StandardScaler()
    scaler.fit(X_train_sample)
    X_attack_norm = scaler.transform(X_attack)

    X_tensor = torch.FloatTensor(X_attack_norm).to(device)
    with torch.no_grad():
        logits = model(X_tensor)
        probs = torch.softmax(logits, dim=1).cpu().numpy()

    preds = probs.argmax(axis=1)
    print('Processed', len(X_attack), 'attack traces')
    print('Top predictions (first 10):', preds[:10])
    print('Confidence (first trace):', probs[0].max())

In [None]:
# Section 18 — Export / download checkpoints
import shutil
from pathlib import Path
ckpt_folder = Path(config['checkpoint_dir'])
zip_name = Path.cwd() / 'mamba_gnn_ce_checkpoints'
if not ckpt_folder.exists():
    print('Checkpoint folder not found:', ckpt_folder)
else:
    # make_archive appends the '.zip' suffix to the base name itself
    shutil.make_archive(str(zip_name), 'zip', root_dir=str(ckpt_folder))
    _archive = str(zip_name) + '.zip'
    print('Created:', _archive)
    print('Path:', _archive)

# Note: in Colab, call google.colab.files.download(str(zip_name) + '.zip') to download the zip file

In [None]:
# Section 19 — Phase comparison (Phase 1 vs Phase 2)
import numpy as np
from pathlib import Path

def load_eval(path):
    """Load an evaluation result file as two NumPy arrays.

    The first non-blank line is parsed as the mean values and the second as the
    std values; numbers may be separated by tabs or any other whitespace.

    Args:
        path: Location of the result file (str or Path).

    Returns:
        ``(mean, std)`` as 1-D float arrays, or ``(None, None)`` when the file
        does not exist or holds fewer than two non-blank lines (for example a
        result file that was only partially written).

    Raises:
        ValueError: If a token on either line is not a valid float.
    """
    p = Path(path)
    if not p.exists():
        return None, None

    with open(p, 'r') as f:
        rows = [row for row in f.read().splitlines() if row.strip()]

    # An incomplete file is treated like a missing one instead of raising IndexError
    if len(rows) < 2:
        return None, None

    mean = np.array([float(x) for x in rows[0].split()])
    std = np.array([float(x) for x in rows[1].split()])
    return mean, std

# Fetch both phase results; each pair is (None, None) when its file is absent
ce_mean, ce_std = load_eval('results/cross_entropy_eval.txt')
f_mean, f_std = load_eval('results/focal_loss_eval.txt')

if ce_mean is not None:
    print('Loaded Cross-Entropy eval — length', len(ce_mean))
else:
    print('Cross-entropy eval missing — run Section 13')

if f_mean is None:
    print('FocalLoss eval missing — run Phase 2 experiments to populate results/focal_loss_eval.txt')

# Overlay the two curves only when both phase results were loaded
if not (ce_mean is None or f_mean is None):
    import matplotlib.pyplot as plt
    # Compare over the common prefix of the two curves
    L = min(len(ce_mean), len(f_mean))
    traces = np.arange(1, L+1)
    plt.figure(figsize=(12,6))
    for _curve, _label in (
        (ce_mean, 'Cross-Entropy (Phase 1)'),
        (f_mean, 'FocalLoss (Phase 2)'),
    ):
        plt.plot(traces, _curve[:L], label=_label)
    plt.xlabel('Traces')
    plt.ylabel('Key rank (GE)')
    plt.legend()
    plt.grid(True)
    plt.savefig('results/phase_comparison_ce_vs_focal.png', dpi=150)
    plt.show()

    def recovery_point(mean):
        """Return the 1-based position of the first exact zero in `mean`, or None."""
        zero_positions = np.where(mean==0)[0]
        if len(zero_positions) == 0:
            return None
        return zero_positions[0]+1

    print('Phase 1 recovery:', recovery_point(ce_mean))
    print('Phase 2 recovery:', recovery_point(f_mean))

In [None]:
# Section 20 — Utilities: helpers for history / quick checks
from pathlib import Path
import pickle


def load_loss_history(ckpt_dir=None):
    """Load the pickled loss history stored as ``loss.pkl`` in a checkpoint folder.

    Args:
        ckpt_dir: Folder to read from. When omitted, the *current* value of
            ``config['checkpoint_dir']`` is used. It is looked up at call time,
            so later changes to the config (e.g. switching phase) are honoured
            rather than the value frozen when this function was defined.

    Returns:
        The unpickled object, or an empty dict when ``loss.pkl`` does not exist.
    """
    if ckpt_dir is None:
        ckpt_dir = config['checkpoint_dir']
    p = Path(ckpt_dir) / 'loss.pkl'
    if not p.exists():
        return {}
    with open(p, 'rb') as f:
        # NOTE(review): pickle executes code on load — only read files you trust.
        return pickle.load(f)


def find_checkpoint_by_step(ckpt_dir=None, step=None):
    """Find a ``*.pth`` checkpoint in a folder, optionally by training step.

    Args:
        ckpt_dir: Folder to search. When omitted, the *current* value of
            ``config['checkpoint_dir']`` is used (looked up at call time).
        step: ``None`` to get the most recently modified checkpoint. A whole
            number (int or digit string) selects the oldest-modified file whose
            name contains that number as a complete run of digits, so ``500``
            matches ``ckpt_500.pth`` and ``ckpt_000500.pth`` but not
            ``ckpt_5000.pth``. Any other value falls back to a plain substring
            match on the file name.

    Returns:
        The matching ``Path``, or ``None`` when the folder holds no
        checkpoints or none matches ``step``.
    """
    import re

    if ckpt_dir is None:
        ckpt_dir = config['checkpoint_dir']

    files = sorted(Path(ckpt_dir).glob('*.pth'), key=lambda p: p.stat().st_mtime)
    if not files:
        return None
    if step is None:
        return files[-1]

    step_text = str(step)
    wanted = int(step_text) if step_text.isdecimal() else None
    for f in files:
        if wanted is None:
            # Non-numeric selector (e.g. a tag): keep the substring behaviour
            if step_text in f.name:
                return f
        elif any(int(run) == wanted for run in re.findall(r'\d+', f.name)):
            # Compare whole digit runs so 500 cannot match inside 5000 or 1500
            return f
    return None

print('Utilities loaded: load_loss_history, find_checkpoint_by_step')