# EstraNet Training on Google Colab

This notebook trains the EstraNet Transformer model for side-channel analysis on the ASCAD dataset.

**Steps:**
1. Set up the environment and clone the repository
2. Download (or upload) the ASCAD dataset
3. Configure training
4. Train the model
5. Monitor training (optional)
6. Evaluate the results

## 1. Set Up the Environment

In [None]:
# Check GPU availability
import tensorflow as tf
print("TensorFlow version:", tf.__version__)
print("GPU Available:", tf.config.list_physical_devices('GPU'))

# If no GPU is listed, enable one via Runtime → Change runtime type → GPU

In [None]:
# Mount Google Drive (to save checkpoints)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone your repository
!git clone https://github.com/YOUR_USERNAME/EstraNet.git
%cd EstraNet

In [None]:
# Install dependencies
!pip install -r requirements.txt

## 2. Download ASCAD Dataset

Run the download cell below (or upload `ASCAD.h5` to the `data/` folder manually), then run the verification cell after it:

In [None]:
import os
import gdown

# Create data directory
os.makedirs('data', exist_ok=True)

# ASCADf dataset configuration
file_id = "1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M"
DATASET_PATH = "data/ASCAD.h5"

if not os.path.exists(DATASET_PATH):
    print("📥 Downloading ASCADf dataset from Google Drive...")
    print("   This may take a few minutes (~1.5 GB)\n")
    
    # Download using gdown
    gdown.download(f"https://drive.google.com/uc?id={file_id}", DATASET_PATH, quiet=False)
    
    print("\n✅ Dataset downloaded successfully!")
else:
    print("✅ Dataset already exists")

# Verify dataset
import h5py
with h5py.File(DATASET_PATH, 'r') as f:
    print("\n📊 Dataset info:")
    print(f"  Keys: {list(f.keys())}")
    if 'Profiling_traces' in f:
        print(f"  Profiling traces shape: {f['Profiling_traces/traces'].shape}")
    if 'Attack_traces' in f:
        print(f"  Attack traces shape: {f['Attack_traces/traces'].shape}")
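If the download is interrupted or Google Drive returns an error page instead of the file, `data/ASCAD.h5` can exist without being a valid HDF5 file, and the `os.path.exists` check above will then skip re-downloading it. A minimal stdlib sketch that looks for the HDF5 format signature before opening the file; `looks_like_hdf5` is a hypothetical helper, not part of the repository.

```python
import os

# The 8-byte signature that starts every HDF5 superblock.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"

def looks_like_hdf5(path):
    """Return True if the file at `path` carries the HDF5 signature.

    The superblock normally starts at offset 0, but the HDF5 format allows a
    user block in front of it, in which case the signature sits at offset
    512, 1024, 2048, ... instead.
    """
    size = os.path.getsize(path)
    with open(path, "rb") as f:
        offset = 0
        while offset + len(HDF5_SIGNATURE) <= size:
            f.seek(offset)
            if f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE:
                return True
            offset = 512 if offset == 0 else offset * 2
    return False

path = "data/ASCAD.h5"
if os.path.exists(path):
    if looks_like_hdf5(path):
        print(f"{path}: HDF5 signature found ({os.path.getsize(path) / 1e9:.2f} GB)")
    else:
        print(f"{path} is not an HDF5 file; delete it and run the download cell again.")
```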

In [None]:
# Verify dataset
import h5py
import os

if os.path.exists('data/ASCAD.h5'):
    with h5py.File('data/ASCAD.h5', 'r') as f:
        print("✓ ASCAD.h5 found!")
        print("  Available keys:", list(f.keys()))
        print("  Profiling traces shape:", f['Profiling_traces']['traces'].shape)
        print("  Attack traces shape:", f['Attack_traces']['traces'].shape)
else:
    print("✗ ASCAD.h5 not found. Please upload the dataset.")

## 3. Configure Training

Edit these settings as needed:

In [None]:
# Training configuration
CONFIG = {
    # Paths
    'data_path': 'data/ASCAD.h5',
    'checkpoint_dir': '/content/drive/MyDrive/estranet_checkpoints',  # Save to Google Drive
    'result_path': 'results',
    
    # Training settings (adjusted for Colab)
    'train_steps': 50000,        # Reduced from 4M for faster training
    'warmup_steps': 5000,        # Reduced from 1M
    'save_steps': 5000,          # Save every 5k steps
    'iterations': 1000,          # Log every 1k steps
    'train_batch_size': 32,      # Adjust based on GPU memory
    'eval_batch_size': 32,
    
    # Data settings
    'input_length': 10000,       # Use 10000 for faster training
    'data_desync': 200,          # Data augmentation
    
    # Model architecture
    'n_layer': 2,
    'd_model': 128,
    'd_head': 32,
    'n_head': 8,
    'd_inner': 256,
    'n_head_softmax': 8,
    'd_head_softmax': 16,
    'dropout': 0.05,
    'n_conv_layer': 2,
    'pool_size': 20,
    
    # Optimization
    'learning_rate': 2.5e-4,
    'clip': 0.25,
    'min_lr_ratio': 0.004,
}

# Create the checkpoint and result directories
import os
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(CONFIG['result_path'], exist_ok=True)

print("Configuration ready!")
print(f"Checkpoints will be saved to: {CONFIG['checkpoint_dir']}")
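The optimization flag names in `CONFIG` (`warmup_steps`, `min_lr_ratio`, `clip`, `iterations`) match those of the Transformer-XL TensorFlow training script, whose schedule warms up linearly to `learning_rate` and then applies cosine decay down to `learning_rate * min_lr_ratio`. Assuming `train_trans.py` keeps that schedule (worth confirming in the script), the stdlib sketch below previews the learning rate at a few steps; `lr_at_step` is a hypothetical helper.

```python
import math

def lr_at_step(step, learning_rate, warmup_steps, train_steps, min_lr_ratio):
    """Learning rate at `step` under the assumed warmup + cosine schedule.

    Ramps linearly from 0 to learning_rate over warmup_steps, then follows a
    cosine curve down to learning_rate * min_lr_ratio at train_steps.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return learning_rate * step / warmup_steps
    decay_steps = max(train_steps - warmup_steps, 1)
    # Same shape as tf.compat.v1.train.cosine_decay with alpha=min_lr_ratio;
    # progress is clamped so the rate stays at the floor after train_steps.
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return learning_rate * ((1.0 - min_lr_ratio) * cosine + min_lr_ratio)

# Values copied from CONFIG above; keep them in sync if you edit CONFIG.
schedule = dict(learning_rate=2.5e-4, warmup_steps=5000,
                train_steps=50000, min_lr_ratio=0.004)

for step in (0, 2500, 5000, 27500, 50000):
    print(f"step {step:>6}: lr = {lr_at_step(step, **schedule):.3e}")
```

With the reduced `train_steps`, the decay phase is much shorter than in the full-length run, so the rate falls to its floor far sooner.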

## 4. Train the Model

In [None]:
# Build the training command
train_cmd = f"""
python train_trans.py \
    --data_path={CONFIG['data_path']} \
    --checkpoint_dir={CONFIG['checkpoint_dir']} \
    --dataset=ASCAD \
    --input_length={CONFIG['input_length']} \
    --data_desync={CONFIG['data_desync']} \
    --train_batch_size={CONFIG['train_batch_size']} \
    --eval_batch_size={CONFIG['eval_batch_size']} \
    --train_steps={CONFIG['train_steps']} \
    --warmup_steps={CONFIG['warmup_steps']} \
    --iterations={CONFIG['iterations']} \
    --save_steps={CONFIG['save_steps']} \
    --n_layer={CONFIG['n_layer']} \
    --d_model={CONFIG['d_model']} \
    --d_head={CONFIG['d_head']} \
    --n_head={CONFIG['n_head']} \
    --d_inner={CONFIG['d_inner']} \
    --n_head_softmax={CONFIG['n_head_softmax']} \
    --d_head_softmax={CONFIG['d_head_softmax']} \
    --dropout={CONFIG['dropout']} \
    --conv_kernel_size=3 \
    --n_conv_layer={CONFIG['n_conv_layer']} \
    --pool_size={CONFIG['pool_size']} \
    --d_kernel_map=512 \
    --beta_hat_2=150 \
    --model_normalization=preLC \
    --head_initialization=forward \
    --softmax_attn=True \
    --learning_rate={CONFIG['learning_rate']} \
    --clip={CONFIG['clip']} \
    --min_lr_ratio={CONFIG['min_lr_ratio']} \
    --max_eval_batch=100 \
    --do_train=True
"""

print("Starting training...")
print("="*60)
!{train_cmd}
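Because Colab runs one cell at a time, the training cell above blocks the notebook until the script exits. If you want to inspect progress while training runs, one option is to start the script as a background process that writes to a log file. A minimal stdlib sketch; `launch_in_background`, `tail` and the log file name are hypothetical choices. Use it instead of the `!{train_cmd}` cell, not in addition to it, since two runs would write to the same checkpoint directory.

```python
import shlex
import subprocess

def launch_in_background(cmd, log_path):
    """Start `cmd` without blocking; stdout and stderr both go to log_path."""
    with open(log_path, "w") as log:
        # The child inherits the file descriptor, so closing our copy is safe.
        return subprocess.Popen(shlex.split(cmd), stdout=log,
                                stderr=subprocess.STDOUT)

def tail(log_path, n=10):
    """Return the last n lines of the log as a list of strings."""
    with open(log_path) as f:
        return f.readlines()[-n:]

# Hypothetical usage, replacing the `!{train_cmd}` cell. The backslash-newlines
# in the f-string are line continuations, so train_cmd is already one line.
#   proc = launch_in_background(train_cmd, "train.log")
#   print("".join(tail("train.log")))  # run from any later cell
#   proc.poll()                        # None while training is still running
```

Python buffers output written to a file, so the log can lag behind; starting the script with `python -u` instead of `python` disables that buffering.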

## 5. Monitor Training (Optional)

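Run this cell to plot the losses recorded so far. Colab runs one cell at a time, so this cell executes only after the training cell finishes or is interrupted (Runtime → Interrupt execution):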

In [None]:
# View training progress
import pickle
import matplotlib.pyplot as plt

loss_file = f"{CONFIG['checkpoint_dir']}/loss.pkl"

try:
    with open(loss_file, 'rb') as f:
        loss_dict = pickle.load(f)
    
    steps = sorted(loss_dict.keys())
    train_losses = [loss_dict[s]['train_loss'] for s in steps]
    test_losses = [loss_dict[s]['test_loss'] for s in steps]
    
    plt.figure(figsize=(10, 5))
    plt.plot(steps, train_losses, label='Train Loss')
    plt.plot(steps, test_losses, label='Test Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training Progress')
    plt.grid(True)
    plt.show()
    
    print(f"Latest step: {steps[-1]}")
    print(f"Train loss: {train_losses[-1]:.4f}")
    print(f"Test loss: {test_losses[-1]:.4f}")
except FileNotFoundError:
    print("Loss file not found yet. Training hasn't started saving checkpoints.")
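The losses logged every `iterations` steps can be noisy, which makes the trend hard to read. A trailing moving average is one simple way to smooth the curves before plotting; this stdlib sketch is not from the repository, and `moving_average` and its default window of 5 points are arbitrary choices.

```python
def moving_average(values, window=5):
    """Trailing moving average; the first points average over fewer values."""
    smoothed = []
    total = 0.0
    for i, value in enumerate(values):
        total += value
        if i >= window:
            total -= values[i - window]  # drop the value that left the window
        smoothed.append(total / min(i + 1, window))
    return smoothed

# Usage with the lists from the monitoring cell above:
#   plt.plot(steps, moving_average(train_losses), label='Train Loss (smoothed)')
#   plt.plot(steps, moving_average(test_losses), label='Test Loss (smoothed)')
print(moving_average([1, 2, 3, 4, 5], window=2))
```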

## 6. Evaluate the Model

In [None]:
# Run evaluation to get the key rank.
# The architecture flags should match the ones used for training so the checkpoint restores correctly.
eval_cmd = f"""
python train_trans.py \
    --data_path={CONFIG['data_path']} \
    --checkpoint_dir={CONFIG['checkpoint_dir']} \
    --dataset=ASCAD \
    --input_length={CONFIG['input_length']} \
    --eval_batch_size={CONFIG['eval_batch_size']} \
    --n_layer={CONFIG['n_layer']} \
    --d_model={CONFIG['d_model']} \
    --d_head={CONFIG['d_head']} \
    --n_head={CONFIG['n_head']} \
    --d_inner={CONFIG['d_inner']} \
    --n_head_softmax={CONFIG['n_head_softmax']} \
    --d_head_softmax={CONFIG['d_head_softmax']} \
    --dropout={CONFIG['dropout']} \
    --conv_kernel_size=3 \
    --n_conv_layer={CONFIG['n_conv_layer']} \
    --pool_size={CONFIG['pool_size']} \
    --d_kernel_map=512 \
    --beta_hat_2=150 \
    --model_normalization=preLC \
    --head_initialization=forward \
    --softmax_attn=True \
    --result_path={CONFIG['result_path']}/eval_results \
    --do_train=False
"""

print("Starting evaluation...")
!{eval_cmd}
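The evaluation command has to describe the same architecture as the training run, and keeping two hand-written flag lists in sync is error-prone. One way to avoid drift is to build both commands from a single dictionary. A sketch; `build_cmd`, `SHARED_FLAGS` and `eval_cmd_shared` are hypothetical names, and the flag names and values are the ones already used in the cells above.

```python
import shlex

def build_cmd(flags, script="train_trans.py"):
    """Render {name: value} as 'python <script> --name=value ...', shell-quoted."""
    parts = ["python", script] + [f"--{name}={value}" for name, value in flags.items()]
    return " ".join(shlex.quote(part) for part in parts)

# Flags both runs need; values copied from CONFIG and the training cell above.
SHARED_FLAGS = {
    "data_path": "data/ASCAD.h5",
    "checkpoint_dir": "/content/drive/MyDrive/estranet_checkpoints",
    "dataset": "ASCAD", "input_length": 10000,
    "n_layer": 2, "d_model": 128, "d_head": 32, "n_head": 8, "d_inner": 256,
    "n_head_softmax": 8, "d_head_softmax": 16, "dropout": 0.05,
    "conv_kernel_size": 3, "n_conv_layer": 2, "pool_size": 20,
    "d_kernel_map": 512, "beta_hat_2": 150, "model_normalization": "preLC",
    "head_initialization": "forward", "softmax_attn": True,
}

# Evaluation-only flags are layered on top of the shared ones.
eval_cmd_shared = build_cmd({
    **SHARED_FLAGS,
    "eval_batch_size": 32,
    "result_path": "results/eval_results",
    "do_train": False,
})
print(eval_cmd_shared)  # run it in the notebook with !{eval_cmd_shared}
```

The training command would be built the same way from `SHARED_FLAGS` plus the training-only flags (`train_steps`, `warmup_steps`, `learning_rate`, ...), so an architecture change only has to be made once. Python booleans render as `True`/`False`, the same spelling the existing commands use.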

In [None]:
# Plot key rank results
import numpy as np
import matplotlib.pyplot as plt

results_file = f"{CONFIG['result_path']}/eval_results.txt"

try:
    with open(results_file, 'r') as f:
        lines = f.readlines()
    
    # Last line contains mean ranks
    mean_ranks = np.array([float(x) for x in lines[-1].strip().split()])
    
    plt.figure(figsize=(12, 6))
    plt.plot(mean_ranks)
    plt.xlabel('Number of Traces')
    plt.ylabel('Key Rank')
    plt.title('Key Recovery Performance')
    plt.grid(True)
    plt.yscale('symlog', linthresh=1)  # log scale above 1, linear near 0, so rank 0 stays visible
    plt.show()
    
    # Find the first point at which the mean rank reaches 0
    rank_0_idx = np.where(mean_ranks == 0)[0]
    if len(rank_0_idx) > 0:
        print(f"✓ Key recovered with {rank_0_idx[0]} traces!")
    else:
        print(f"Key not fully recovered. Final mean rank: {mean_ranks[-1]:.2f} (best: {mean_ranks.min():.2f})")
        
except FileNotFoundError:
    print("Results file not found. Run evaluation first.")
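For reference, a key rank curve like the one plotted above is conventionally computed by summing, for each key-byte guess, the model's log-probability of the class that guess predicts for every attack trace, then counting how many guesses score higher than the correct key (rank 0 means the correct key is ranked first). The numpy sketch below illustrates that procedure on synthetic data; it is not taken from `train_trans.py`, whose averaging over attacks may differ. The toy leakage `p ^ k` stands in for ASCAD's S-box model `Sbox[p ^ k]` (for ASCAD you would pass `lambda p, k: SBOX[p ^ k]` with a 256-entry AES S-box array, not defined here). It also reports the number of traces after which the rank stays at 0, which is less sensitive to a lucky early dip than the first index at which it hits 0.

```python
import numpy as np

def key_rank_curve(log_probs, plaintexts, true_key, leak):
    """Rank of the true key byte after 1, 2, ..., n attack traces.

    log_probs:  (n_traces, 256) array of per-class log-probabilities
    plaintexts: (n_traces,) integer array of plaintext bytes
    leak:       function (plaintext array, key guess) -> class labels
    """
    rows = np.arange(log_probs.shape[0])
    # scores[t, k] = log-probability of the class key guess k predicts for trace t
    scores = np.stack([log_probs[rows, leak(plaintexts, k)] for k in range(256)], axis=1)
    cumulative = np.cumsum(scores, axis=0)
    # rank = number of guesses whose accumulated score beats the true key's
    return (cumulative > cumulative[:, [true_key]]).sum(axis=1)

def traces_to_stable_rank0(ranks):
    """Traces needed for the rank to reach 0 and stay there, or None if it never does.

    Entry i of `ranks` is assumed to correspond to i + 1 traces.
    """
    ranks = np.asarray(ranks)
    if ranks.size == 0:
        return None
    nonzero = np.flatnonzero(ranks != 0)
    if nonzero.size == 0:
        return 1
    if nonzero[-1] == ranks.size - 1:
        return None
    return int(nonzero[-1]) + 2

# Synthetic example: a fake "model" that favours the true class of each trace.
rng = np.random.default_rng(0)
n, key = 50, 0x2B
pts = rng.integers(0, 256, size=n)
logits = rng.normal(size=(n, 256))
logits[np.arange(n), pts ^ key] += 3.0
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

ranks = key_rank_curve(log_probs, pts, key, lambda p, k: p ^ k)
print(ranks[:10], traces_to_stable_rank0(ranks))
```

Plugging in the real model requires per-trace class log-probabilities, which this notebook does not expose, so treat the sketch as an explanation of the metric rather than a drop-in replacement for the evaluation cell.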

## Tips

1. **Training Time**: With the default config (50k steps), training takes roughly 2-4 hours on a Colab GPU
2. **Checkpoints**: Saved to Google Drive, so they persist across sessions
3. **Resume Training**: Add `--warm_start=True` to the training command to resume from the last checkpoint
4. **Experiment**: Try different `input_length` (700, 10000, 40000) and model sizes
5. **Memory Issues**: Reduce `train_batch_size` if you get OOM errors
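Before resuming with `--warm_start=True`, it can help to check which checkpoint is in Google Drive. TensorFlow checkpoints are written as `<prefix>-<number>.index` plus `.data-*` shard files, where the number is the global step or a save counter depending on how the script saves; the checkpoint prefix `train_trans.py` uses is not shown in this notebook, so the sketch below accepts any prefix. `latest_checkpoint_step` is a hypothetical helper; inside TensorFlow, `tf.train.latest_checkpoint(checkpoint_dir)` returns the path recorded in the directory's `checkpoint` state file.

```python
import os
import re

# Matches e.g. "model.ckpt-5000.index" or "ckpt-3.index"
_INDEX_RE = re.compile(r"^(?P<prefix>.+)-(?P<num>\d+)\.index$")

def latest_checkpoint_step(checkpoint_dir):
    """Return (number, checkpoint path prefix) of the highest-numbered checkpoint, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    best = None
    for name in os.listdir(checkpoint_dir):
        match = _INDEX_RE.match(name)
        if match:
            num = int(match.group("num"))
            if best is None or num > best[0]:
                path = os.path.join(checkpoint_dir, f"{match.group('prefix')}-{num}")
                best = (num, path)
    return best

# Directory from CONFIG['checkpoint_dir'] above.
found = latest_checkpoint_step("/content/drive/MyDrive/estranet_checkpoints")
print(found if found else "No checkpoint found.")
```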