# Fine‑tune Mamba‑GNN with FocalLoss (warm‑start)

This notebook fine‑tunes an existing Cross‑Entropy (CE) checkpoint with FocalLoss.
- Warm‑start from your best CE checkpoint
- Quick (10‑trial) guessing entropy (GE) check for fast feedback, and full (100‑trial) GE for final validation
- Plots, checkpoint monitor and export utilities

Run cells in order. Cells that actually start training/evaluation are disabled by default — change the RUN_* flags to True when ready.

In [None]:
# 1) Imports & environment
import os
import sys
import shutil
import subprocess
from pathlib import Path
import importlib.util
import torch
import numpy as np
import matplotlib.pyplot as plt

# Paths: all working directories live directly under the notebook's CWD.
ROOT = Path.cwd()
CHECKPOINTS = ROOT / 'checkpoints'
RESULTS = ROOT / 'results'
LOGS = ROOT / 'logs'
for p in (CHECKPOINTS, RESULTS, LOGS):
    p.mkdir(exist_ok=True)

# Device: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Make sure local scripts are importable.
# The membership test keeps this cell idempotent: notebook cells get re-run
# often, and an unconditional append would add duplicate entries every time.
for extra_path in (str(ROOT / 'mamba-gnn-scripts'), str(ROOT)):
    if extra_path not in sys.path:
        sys.path.append(extra_path)


In [None]:
# 2) Fine‑tune configuration (edit as needed)

# --- Checkpoints ---
SOURCE_CKPT = 'checkpoints/mamba_gnn_estranet/mamba_gnn-50000.pth'  # CE checkpoint used for the warm start
TARGET_DIR = 'checkpoints/mamba_gnn_finetune_from_ce'  # handed to the fine-tune script as --target_dir

# --- Schedule (step counts) ---
EXTRA_STEPS = 25_000
SAVE_STEPS = 5_000
EVAL_STEPS = 250

# --- Optimisation ---
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
DROPOUT = 0.3
TRAIN_BATCH_SIZE = 256
EVAL_BATCH_SIZE = 32

# --- FocalLoss hyper-parameters ---
FOCAL_GAMMA = 2.5
FOCAL_ALPHA = 1.0

# --- Safety flags: each gated step is skipped until its flag is set to True ---
RUN_PREPARE = False   # copy checkpoint + infer arch
RUN_FINETUNE = False  # launch the fine-tune job
RUN_QUICK_GE = False  # 10-trial GE check
RUN_FULL_GE = False   # 100-trial GE for final validation

print('Config ready — change flags to run steps')

In [None]:
# 3) Prepare target folder and infer architecture (dry run)
import importlib.util

# Load the fine-tune wrapper directly from its file path so its helpers
# (copy_ckpt_to_target, infer_arch_from_state) can be called from this cell.
spec = importlib.util.spec_from_file_location('finetune_wrapper', Path('mamba-gnn-scripts') / 'finetune_mamba_focal.py')
fw = importlib.util.module_from_spec(spec)
spec.loader.exec_module(fw)

src = Path(SOURCE_CKPT)
tgt = Path(TARGET_DIR)

print('Source exists:', src.exists())
if RUN_PREPARE:
    ck = fw.copy_ckpt_to_target(src, tgt)
    print(f'Copied to {tgt}/checkpoint_latest.pth')
    # `or 0` also covers a checkpoint that stores global_step=None.
    src_step = int(ck.get('global_step', 0) or 0)
    print('Source global_step:', src_step)
    try:
        d_model, mamba_layers, gnn_layers = fw.infer_arch_from_state(ck.get('model_state_dict') or ck)
    except Exception as e:
        # Best-effort step: report why inference failed and actually bind the
        # defaults the message announces, so the names exist after this cell.
        d_model, mamba_layers, gnn_layers = 64, 2, 2
        print(f'Arch inference failed ({e!r}) — using defaults (d_model=64,mamba=2,gnn=2)')
    else:
        print('Inferred arch ->', d_model, mamba_layers, gnn_layers)
else:
    print('RUN_PREPARE is False — set True to copy checkpoint and infer architecture')


In [None]:
# 4) Launch fine‑tune (warm_start -> continues from source global_step)
import subprocess

if RUN_FINETUNE:
    # Run the wrapper script as a child process, forwarding the notebook config
    # as CLI flags. No total step count is passed: the wrapper is expected to
    # compute it from the source checkpoint + EXTRA_STEPS.
    cmd = [
        sys.executable, 'mamba-gnn-scripts/finetune_mamba_focal.py',
        f'--source_ckpt={SOURCE_CKPT}',
        f'--target_dir={TARGET_DIR}',
        f'--extra_steps={EXTRA_STEPS}',
        f'--learning_rate={LEARNING_RATE}',
        f'--focal_gamma={FOCAL_GAMMA}',
        f'--focal_alpha={FOCAL_ALPHA}',
        f'--train_batch_size={TRAIN_BATCH_SIZE}',
        f'--eval_batch_size={EVAL_BATCH_SIZE}',
        f'--save_steps={SAVE_STEPS}',
        f'--eval_steps={EVAL_STEPS}',
        f'--dropout={DROPOUT}',
        f'--weight_decay={WEIGHT_DECAY}'
    ]

    print('Launching fine‑tune command:')
    print(' '.join(map(str, cmd)))
    # Stream the child's merged stdout/stderr into the notebook output.
    # Leaving the `with` block closes the pipe and waits for the child, which
    # is what sets proc.returncode; reading stdout to EOF alone leaves it None.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
        try:
            for line in proc.stdout:
                print(line, end='')
        except KeyboardInterrupt:
            # Interrupting the cell should stop the training job too.
            proc.terminate()
            print('\nCancelled by user')
    print('\nFine‑tune process exited with code', proc.returncode)
else:
    print('RUN_FINETUNE is False — set True to run the fine‑tune job from this cell')


In [None]:
# 5) Checkpoint monitor (one-shot)
from pathlib import Path

def latest_ckpt_info(ckpt_dir):
    """Describe the most recently modified ``*.pth`` file in *ckpt_dir*.

    Returns a ``(file_name, size_in_MiB, mtime)`` tuple, or ``None`` when the
    directory does not exist or contains no ``.pth`` files.
    """
    folder = Path(ckpt_dir)
    if not folder.exists():
        return None

    candidates = list(folder.glob('*.pth'))
    if not candidates:
        return None

    # Oldest first, so the newest checkpoint ends up in the last slot.
    candidates.sort(key=lambda c: c.stat().st_mtime)
    newest = candidates[-1]
    newest_stat = newest.stat()
    return newest.name, newest_stat.st_size / (1024 ** 2), newest_stat.st_mtime

# One-shot status report for the fine-tune output directory.
info = latest_ckpt_info(TARGET_DIR)
if not info:
    print('No checkpoints found in', TARGET_DIR)
else:
    # info is (file name, size in MiB, modification time)
    print('Latest checkpoint:', info)

print('\nTip: re-run this cell while training to refresh the latest checkpoint info')

In [None]:
# 6) Quick GE (10‑trial) or Full GE (100‑trial) evaluator
# This imports training utilities and runs GE evaluation on the latest checkpoint in TARGET_DIR
import importlib.util
import sys

spec = importlib.util.spec_from_file_location('train_mod', Path('mamba-gnn-scripts') / 'train_mamba_gnn.py')
train_mod = importlib.util.module_from_spec(spec)
# module_from_spec() does not add the module to sys.modules. Register it under
# the alias before executing it; otherwise the `from train_mod import ...`
# below raises ModuleNotFoundError, because no file named train_mod.py exists.
sys.modules['train_mod'] = train_mod
spec.loader.exec_module(train_mod)

from train_mod import OptimizedMambaGNN, ASCADDataset, load_ascad_data, evaluate_model_ge

def run_ge(ckpt_path, num_trials=10, max_traces=10000):
    """Rebuild the model stored at *ckpt_path* and run ``evaluate_model_ge`` on it.

    The architecture (d_model and the two layer counts) is guessed from the
    checkpoint's state-dict keys, falling back to 64 / 2 / 2 when a key is
    missing. Attack data is read from ``data/ASCAD.h5``.

    Args:
        ckpt_path: Checkpoint file to load (converted with ``str`` for torch).
        num_trials: Forwarded to ``evaluate_model_ge``.
        max_traces: Forwarded to ``evaluate_model_ge``.

    Returns:
        The ``(mean_ranks, std_ranks)`` pair returned by ``evaluate_model_ge``.
    """
    print('Loading checkpoint:', ckpt_path)
    checkpoint = torch.load(str(ckpt_path), map_location='cpu')
    # Accept both a full training checkpoint and a bare state dict.
    state = checkpoint.get('model_state_dict') or checkpoint

    # Model width: read it from the classifier weight, else from the
    # positional encoding, else keep the default.
    d_model = 64
    if 'classifier.0.weight' in state:
        d_model = state['classifier.0.weight'].shape[1]
    elif 'pos_encoding' in state:
        d_model = state['pos_encoding'].shape[2]

    def count_blocks(prefix):
        # Best-effort: number of distinct second key components under
        # `prefix`; 2 when the checkpoint has no such keys.
        indices = {key.split('.')[1] for key in state if key.startswith(prefix)}
        return len(indices) or 2

    mamba_layers = count_blocks('mamba_blocks.')
    gnn_layers = count_blocks('gnn_layers.')

    print('Instantiating model ->', d_model, mamba_layers, gnn_layers)
    model = OptimizedMambaGNN(
        trace_length=700,
        d_model=d_model,
        mamba_layers=mamba_layers,
        gnn_layers=gnn_layers,
        num_classes=256,
        k_neighbors=8,
        dropout=0.3,
    )
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    # Attack split only; the first two values returned by the loader are unused here.
    _, _, X_attack, y_attack, m_attack = load_ascad_data('data/ASCAD.h5', target_byte=2)
    attack_loader = torch.utils.data.DataLoader(
        ASCADDataset(X_attack, y_attack),
        batch_size=32,
        shuffle=False,
        num_workers=0,
    )

    mean_ranks, std_ranks = evaluate_model_ge(
        model, attack_loader, m_attack, 2, device,
        num_trials=num_trials, max_traces=max_traces,
    )
    return mean_ranks, std_ranks

# Run quick/full GE when flags set
from pathlib import Path


def _write_tsv_rows(path, *rows):
    # One tab-separated line per row, each value rendered with str().
    with open(path, 'w') as f:
        for row in rows:
            f.write('\t'.join(map(str, row)) + '\n')


# Prefer checkpoint_latest.pth; otherwise fall back to the most recently
# modified mamba_gnn-*.pth file in TARGET_DIR.
ckpt = Path(TARGET_DIR) / 'checkpoint_latest.pth'
if not ckpt.exists():
    ckpts = list(Path(TARGET_DIR).glob('mamba_gnn-*.pth'))
    ckpts.sort(key=lambda c: c.stat().st_mtime)
    ckpt = ckpts[-1] if ckpts else None

if ckpt is not None:
    if RUN_QUICK_GE:
        print('\nRunning QUICK GE (10 trials)')
        mean10, std10 = run_ge(ckpt, num_trials=10)
        out = Path('results') / f'finetune_quick_{ckpt.stem}.txt'
        _write_tsv_rows(out, mean10, std10)
        print('Saved quick GE ->', out)

    if RUN_FULL_GE:
        print('\nRunning FULL GE (100 trials) — this will take time')
        mean100, std100 = run_ge(ckpt, num_trials=100)
        out = Path('results') / f'finetune_full_{ckpt.stem}.txt'
        _write_tsv_rows(out, mean100, std100)
        print('Saved full GE ->', out)
else:
    print('No checkpoint found to evaluate. Run fine‑tune first or ensure TARGET_DIR is correct.')


In [None]:
# 7) Plot GE results (if present)
from pathlib import Path

# Order result files by modification time so [-1] really is the newest one.
# A plain name sort misorders unpadded step numbers (e.g. "...-100000" sorts
# before "...-75000"); the checkpoint cells also pick "latest" by mtime.
full_files = sorted(Path('results').glob('finetune_full_*.txt'), key=lambda p: p.stat().st_mtime)
quick_files = sorted(Path('results').glob('finetune_quick_*.txt'), key=lambda p: p.stat().st_mtime)

# prefer full GE result if available
res_file = full_files[-1] if full_files else (quick_files[-1] if quick_files else None)

if res_file:
    # File layout: line 1 = tab-separated means, line 2 = tab-separated stds.
    with open(res_file,'r') as f:
        lines = f.readlines()
        mean = np.array([float(x) for x in lines[0].strip().split('\t')])
        std  = np.array([float(x) for x in lines[1].strip().split('\t')])
    # x axis is the 1-based position of each value in the result file.
    traces = np.arange(1, len(mean)+1)
    plt.figure(figsize=(12,5))
    plt.plot(traces, mean, label='Mean GE')
    plt.fill_between(traces, mean-std, mean+std, alpha=0.25)
    plt.axhline(0, color='red', linestyle='--', label='Recovered')
    plt.xlabel('Traces'); plt.ylabel('Key rank (GE)')
    plt.title(f'Guessing Entropy — {res_file.name}')
    plt.legend(); plt.grid(True)
    out = Path('results') / f'ge_plot_{res_file.stem}.png'
    # Save before show() so the figure is written while it is still current.
    plt.savefig(out, dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved plot ->', out)
else:
    print('No GE result files found in results/. Run quick/full GE first.')


In [None]:
# 8) Export checkpoints (zip)
import shutil
from pathlib import Path

# Archive base name; make_archive appends the ".zip" suffix itself.
zip_path = Path.cwd() / 'mamba_gnn_finetune_from_ce_checkpoints'
ckpt_dir = Path(TARGET_DIR)
if not ckpt_dir.exists():
    print('Target checkpoint folder not found:', TARGET_DIR)
else:
    shutil.make_archive(str(zip_path), 'zip', root_dir=str(ckpt_dir))
    print('Created:', str(zip_path) + '.zip')

print('Use VS Code or Colab file downloader to fetch the zip file')

# 9) Notes & recommended workflow
#
# Recommended quick flow:
#   1) Set RUN_PREPARE=True and run the "Prepare" cell to copy the CE checkpoint.
#   2) Set RUN_FINETUNE=True and run the "Launch fine-tune" cell to start
#      warm-start training (use a GPU).
#   3) Once the first SAVE_STEPS interval has completed, set RUN_QUICK_GE=True
#      and run the GE cell for a 10-trial check.
#   4) If the quick GE looks promising, set RUN_FULL_GE=True and run the GE
#      cell again for the 100-trial evaluation.

print('Notebook ready — edit config flags above and run cells in order')