# EstraNet Training on ASCADf Dataset

This notebook sets up and trains the EstraNet model on the ASCADf (ASCAD with fixed key) dataset.

**Paper**: [EstraNet: An Efficient Shift-Invariant Transformer Network for Side-Channel Analysis](https://tches.iacr.org/index.php/TCHES/article/view/11255)

---

## 📋 Setup Checklist
- ✅ Install dependencies
- ✅ Download ASCADf dataset from Google Drive
- ✅ Apply TensorFlow 2.13+ compatibility fixes
- ✅ Configure training parameters
- ✅ Train the model
- ✅ Evaluate results

## 1️⃣ Environment Setup

First, check whether the notebook is running on Google Colab and whether a GPU is available.

In [1]:
# Check if running on Google Colab
try:
    import google.colab  # normally only importable inside a Colab runtime
    IN_COLAB = True
    print("✅ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("📝 Running on local Jupyter")

# Check GPU availability
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
print(f"\nTensorFlow version: {tf.__version__}")
print(f"GPU Available: {gpus}")

if gpus:
    print("🚀 GPU detected! Training will be accelerated.")
else:
    print("⚠️ No GPU detected. Training will be slower on CPU.")
    if IN_COLAB:
        print("💡 Enable GPU: Runtime > Change runtime type > Hardware accelerator > GPU")

✅ Running on Google Colab

TensorFlow version: 2.19.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
🚀 GPU detected! Training will be accelerated.


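## 2️⃣ Clone Repository and Install Dependencies

This cell assumes Google Colab: it switches to `/content`, deletes any earlier clone, and clones a fresh copy of the repository. On a local Jupyter setup, adjust the base directory first.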

In [2]:
import os, shutil

os.chdir('/content')  # Colab's default working directory
if os.path.exists('EstraNet'):
    shutil.rmtree('EstraNet')  # Remove any earlier clone so we start from a clean copy

!git clone https://github.com/loshithan/EstraNet.git
os.chdir('EstraNet')
print(f"✅ Clean! Directory: {os.getcwd()}")

Cloning into 'EstraNet'...
remote: Enumerating objects: 61, done.
remote: Counting objects: 100% (61/61), done.
remote: Compressing objects: 100% (31/31), done.
remote: Total 61 (delta 30), reused 61 (delta 30), pack-reused 0 (from 0)
Receiving objects: 100% (61/61), 41.46 KiB | 758.00 KiB/s, done.
Resolving deltas: 100% (30/30), done.
✅ Clean! Directory: /content/EstraNet


In [3]:
# Install required dependencies
print("📦 Installing dependencies...\n")
# Optional version pins, left commented out because Colab already ships these packages:
# %pip install -q absl-py==2.3.1 numpy==1.24.3 scipy==1.10.1 h5py==3.11.0

# Install gdown for downloading from Google Drive
%pip install -q gdown

# Note: this uses the TensorFlow version pre-installed in Colab (2.19.0 in the run below)
# The compatibility fixes in Section 3 target TensorFlow 2.13+
print("\n✅ All dependencies installed!")
print(f"Using TensorFlow {tf.__version__} (pre-installed)")

📦 Installing dependencies...


✅ All dependencies installed!
Using TensorFlow 2.19.0 (pre-installed)


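## 3️⃣ Apply TensorFlow 2.13+ Compatibility Fixes

⚠️ **IMPORTANT: You MUST run this cell before training!**

The upstream code does not run unmodified on TensorFlow 2.13+ (including the Keras 3-based releases from TensorFlow 2.16 onward). This cell patches `transformer.py`, `train_trans.py`, and `fast_attention.py` in place using string and regular-expression replacements. Each pattern matches only the unpatched text, so re-running the cell leaves already-patched files unchanged.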

In [4]:
print("🔧 Applying ALL TensorFlow 2.13+ compatibility fixes...\n")

# FIX 1: transformer.py
print("📝 Fixing transformer.py...")
with open('transformer.py', 'r', encoding='utf-8') as f:
    content = f.read()

content = content.replace(
    'def call(self, inp, softmax_attn_smoothing=1, training=False):',
    'def call(self, inputs, softmax_attn_smoothing=1, training=False):'
)
content = content.replace(
    'inp = tf.expand_dims(inp, axis=-1)',
    'inp = tf.expand_dims(inputs, axis=-1)',
    1
)
content = content.replace(
    'pos_ft, pos_ft_slopes = self.pos_feature(slen, bsz)',
    'pos_ft, pos_ft_slopes = self.pos_feature(slen=slen, bsz=bsz)'
)
content = content.replace(
    'from tensorflow.keras.layers.experimental import SyncBatchNormalization',
    'from tensorflow.keras.layers import BatchNormalization as SyncBatchNormalization'
)
content = content.replace('if l is 0 else', 'if l == 0 else')

# Guard the slope normalization against division by zero when slen == 1
content = content.replace(
    'normalized_slopes = (1. / float(slen-1)) * self.slopes',
    'normalized_slopes = (1. / max(float(slen-1), 1.0)) * self.slopes'
)

with open('transformer.py', 'w', encoding='utf-8') as f:
    f.write(content)
print("  ✅ transformer.py fixed!")

# FIX 2: train_trans.py
print("\n📝 Fixing train_trans.py...")
with open('train_trans.py', 'r', encoding='utf-8') as f:
    content = f.read()

content = content.replace('.reset_states()', '.reset_state()')
content = content.replace(
    'logits = model(inps, softmax_attn_smoothing, training=True)[0]',
    'logits = model(inputs=inps, softmax_attn_smoothing=softmax_attn_smoothing, training=True)[0]'
)
content = content.replace(
    'logits = model(inps)[0]',
    'logits = model(inputs=inps)[0]'
)

with open('train_trans.py', 'w', encoding='utf-8') as f:
    f.write(content)
print("  ✅ train_trans.py fixed!")

# FIX 3: fast_attention.py
# Keras 3's add_weight() takes `shape` as its first positional argument,
# so the weight name must be passed explicitly as name="..."
print("\n📝 Fixing fast_attention.py...")
import re
with open('fast_attention.py', 'r', encoding='utf-8') as f:
    content = f.read()
pattern = r'self\.add_weight\(\s*"([^"]+)"\s*,'
content = re.sub(pattern, r'self.add_weight(name="\1",', content)
with open('fast_attention.py', 'w', encoding='utf-8') as f:
    f.write(content)
print("  ✅ fast_attention.py fixed!")

print("\n🚀 ALL FIXES APPLIED! Run training now.")

🔧 Applying ALL TensorFlow 2.13+ compatibility fixes...

📝 Fixing transformer.py...
  ✅ transformer.py fixed!

📝 Fixing train_trans.py...
  ✅ train_trans.py fixed!

📝 Fixing fast_attention.py...
  ✅ fast_attention.py fixed!

🚀 ALL FIXES APPLIED! Run training now.
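The cell above prints "fixed!" whether or not each pattern was actually found, so a change in the cloned repository could silently leave a file unpatched. Below is a minimal sketch of a stricter variant. It assumes plain-string `(old, new)` replacements; unlike the cell above, it replaces every occurrence rather than supporting a count. `patch_file` is a hypothetical helper and is not part of the repository.

```python
from pathlib import Path


def patch_file(path, replacements):
    """Apply (old, new) string replacements to a file in place.

    Returns the list of `old` patterns that were neither found nor already
    replaced, so callers can tell "patched" apart from "nothing matched".
    A pattern whose replacement text is already present counts as applied,
    which keeps the function safe to re-run.
    """
    file = Path(path)
    content = file.read_text(encoding="utf-8")
    missing = []
    for old, new in replacements:
        if old in content:
            content = content.replace(old, new)  # replaces every occurrence
        elif new not in content:
            missing.append(old)
    file.write_text(content, encoding="utf-8")
    return missing
```

For example, `patch_file('train_trans.py', [('.reset_states()', '.reset_state()')])` returns an empty list when the rename was applied or had already been applied, and returns the pattern otherwise.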


## 4️⃣ Download ASCADf Dataset

Download the ASCAD dataset with fixed key from Google Drive.

In [5]:
import os
import gdown

# Create data directory
os.makedirs('data', exist_ok=True)

# ASCADf dataset configuration
file_id = "1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M"
DATASET_PATH = "data/ASCAD.h5"

if not os.path.exists(DATASET_PATH):
    print("📥 Downloading ASCADf dataset from Google Drive...")
    print("   This should only take a moment (~47 MB)\n")

    # Download using gdown
    gdown.download(f"https://drive.google.com/uc?id={file_id}", DATASET_PATH, quiet=False)

    print("\n✅ Dataset downloaded successfully!")
else:
    print("✅ Dataset already exists")

# Verify dataset
import h5py
with h5py.File(DATASET_PATH, 'r') as f:
    print("\n📊 Dataset info:")
    print(f"  Keys: {list(f.keys())}")
    if 'Profiling_traces' in f:
        print(f"  Profiling traces shape: {f['Profiling_traces/traces'].shape}")
    if 'Attack_traces' in f:
        print(f"  Attack traces shape: {f['Attack_traces/traces'].shape}")

📥 Downloading ASCADf dataset from Google Drive...
   This should only take a moment (~47 MB)



Downloading...
From: https://drive.google.com/uc?id=1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M
To: /content/EstraNet/data/ASCAD.h5
100%|██████████| 46.6M/46.6M [00:00<00:00, 157MB/s]


✅ Dataset downloaded successfully!

📊 Dataset info:
  Keys: ['Attack_traces', 'Profiling_traces']
  Profiling traces shape: (50000, 700)
  Attack traces shape: (10000, 700)
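A failed or interrupted download can leave a file at `DATASET_PATH` that is not valid HDF5 (for example, an HTML error page), and `h5py.File` then fails with a less obvious error. Below is a minimal sketch of a pre-check. It assumes the file begins with the standard 8-byte HDF5 signature. HDF5 also allows the signature at later offsets after a user block, which this sketch does not handle. `looks_like_hdf5` is a hypothetical helper.

```python
# Standard signature at the start of an HDF5 file (no user block).
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if the file at `path` starts with the HDF5 signature."""
    try:
        with open(path, "rb") as f:
            return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
    except OSError:
        # Missing or unreadable file
        return False
```

If `looks_like_hdf5(DATASET_PATH)` returns `False`, you can delete the file so that the download cell fetches it again on the next run.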





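## 5️⃣ Configure Training Parameters

Set up the training configuration and adjust the parameters as needed. Check `input_length` against the data before starting a long run. The file downloaded above contains 700-sample traces (see the shapes printed by the previous cell), while `input_length` below is set to 10,000 samples and the comments refer to 40,000-sample traces.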

In [6]:
# Training Configuration
config = {
    # Data config
    'data_path': DATASET_PATH,
    'dataset': 'ASCAD',
    'input_length': 10000,  # or 40000 for full traces
    'data_desync': 200,     # 400 for input_length=40000
    
    # Training config
    'train_batch_size': 16,
    'eval_batch_size': 16,
    'train_steps': 4000000,
    'warmup_steps': 1000000,
    'iterations': 20000,
    'save_steps': 40000,
    
    # Optimization config
    'learning_rate': 2.5e-4,
    'clip': 0.25,
    'min_lr_ratio': 0.004,
    
    # Model architecture
    'n_layer': 2,
    'd_model': 128,
    'd_head': 32,
    'n_head': 8,
    'd_inner': 256,
    'n_head_softmax': 8,
    'd_head_softmax': 16,
    'dropout': 0.05,
    'conv_kernel_size': 3,
    'n_conv_layer': 2,
    'pool_size': 20,
    'd_kernel_map': 512,
    'beta_hat_2': 150,
    'model_normalization': 'preLC',
    'head_initialization': 'forward',
    'softmax_attn': True,
    
    # Checkpoint config
    'checkpoint_dir': './',
    'result_path': 'results',
    'warm_start': False,
    'use_tpu': False,
    'max_eval_batch': 100,
}

print("⚙️ Training Configuration:")
print(f"  Dataset: {config['dataset']}")
print(f"  Input length: {config['input_length']}")
print(f"  Batch size: {config['train_batch_size']}")
print(f"  Training steps: {config['train_steps']:,}")
print(f"  Model layers: {config['n_layer']}")
print(f"  Model dimension: {config['d_model']}")
print(f"  Attention heads: {config['n_head']}")

⚙️ Training Configuration:
  Dataset: ASCAD
  Input length: 10000
  Batch size: 16
  Training steps: 4,000,000
  Model layers: 2
  Model dimension: 128
  Attention heads: 8
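The dataset cell above reported 700 samples per trace, while `input_length` is 10,000. Before launching a multi-hour job, a quick sanity check can catch mismatches like this one. Below is a minimal sketch. It assumes the training script needs at least `input_length` samples per trace and that `warmup_steps` must be smaller than `train_steps`. `check_config` is a hypothetical helper, and the trace length is passed in rather than read from the file.

```python
def check_config(cfg, trace_length):
    """Return a list of human-readable problems found in a training config.

    `trace_length` is the number of samples per trace in the HDF5 file,
    e.g. the second dimension of Profiling_traces/traces.
    """
    problems = []
    if cfg["input_length"] > trace_length:
        problems.append(
            f"input_length={cfg['input_length']} exceeds the "
            f"{trace_length} samples available per trace"
        )
    if cfg["warmup_steps"] >= cfg["train_steps"]:
        problems.append("warmup_steps should be smaller than train_steps")
    return problems


# Usage in the notebook (needs h5py, DATASET_PATH and config from above):
# with h5py.File(DATASET_PATH, "r") as f:
#     trace_length = f["Profiling_traces/traces"].shape[1]
# for problem in check_config(config, trace_length):
#     print("⚠️", problem)
```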


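## 6️⃣ Train the Model

Now train the EstraNet model. With the default 4,000,000 steps, a run takes many hours, depending on your hardware. The `Unable to register cuFFT/cuDNN/cuBLAS factory` and `computation placer already registered` messages that TensorFlow prints at startup are generally harmless and can be ignored.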

In [7]:
# Build the command arguments
args = [
    f'--use_tpu={config["use_tpu"]}',
    f'--data_path={config["data_path"]}',
    f'--dataset={config["dataset"]}',
    f'--checkpoint_dir={config["checkpoint_dir"]}',
    f'--warm_start={config["warm_start"]}',
    f'--result_path={config["result_path"]}',
    f'--learning_rate={config["learning_rate"]}',
    f'--clip={config["clip"]}',
    f'--min_lr_ratio={config["min_lr_ratio"]}',
    f'--warmup_steps={config["warmup_steps"]}',
    f'--input_length={config["input_length"]}',
    f'--data_desync={config["data_desync"]}',
    f'--train_batch_size={config["train_batch_size"]}',
    f'--eval_batch_size={config["eval_batch_size"]}',
    f'--train_steps={config["train_steps"]}',
    f'--iterations={config["iterations"]}',
    f'--save_steps={config["save_steps"]}',
    f'--n_layer={config["n_layer"]}',
    f'--d_model={config["d_model"]}',
    f'--d_head={config["d_head"]}',
    f'--n_head={config["n_head"]}',
    f'--d_inner={config["d_inner"]}',
    f'--n_head_softmax={config["n_head_softmax"]}',
    f'--d_head_softmax={config["d_head_softmax"]}',
    f'--dropout={config["dropout"]}',
    f'--conv_kernel_size={config["conv_kernel_size"]}',
    f'--n_conv_layer={config["n_conv_layer"]}',
    f'--pool_size={config["pool_size"]}',
    f'--d_kernel_map={config["d_kernel_map"]}',
    f'--beta_hat_2={config["beta_hat_2"]}',
    f'--model_normalization={config["model_normalization"]}',
    f'--head_initialization={config["head_initialization"]}',
    f'--softmax_attn={config["softmax_attn"]}',
    f'--max_eval_batch={config["max_eval_batch"]}',
    '--do_train=True'
]

args_str = ' '.join(args)
print("🚀 Starting training...\n")
print(f"Arguments: {args_str[:100]}...\n")

# Execute training - output will be displayed automatically
!python train_trans.py {args_str}

🚀 Starting training...

Arguments: --use_tpu=False --data_path=data/ASCAD.h5 --dataset=ASCAD --checkpoint_dir=./ --warm_start=False --r...

2026-02-08 16:03:29.150432: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770566609.169106     718 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770566609.174557     718 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770566609.190607     718 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770566609.190639     718 computation_placer.cc:177] computation placer already registered. Please check li
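The output captured above is cut off mid-line, and a long run is easier to inspect afterwards when its complete output is kept in a file. Below is a minimal sketch that runs the same script through Python's `subprocess` module. It writes stdout and stderr to a log file and returns the exit code. `run_logged` and the log file name are assumptions and are not part of the repository.

```python
import subprocess
import sys  # used in the usage example below


def run_logged(argv, log_path):
    """Run `argv` as a child process, sending stdout and stderr to `log_path`.

    Returns the process exit code (0 means the script finished without error).
    Passing an argument list instead of a shell string avoids quoting issues.
    """
    with open(log_path, "w", encoding="utf-8") as log:
        result = subprocess.run(argv, stdout=log, stderr=subprocess.STDOUT, check=False)
    return result.returncode


# Usage in the notebook, reusing the `args` list built in the training cell:
# code = run_logged([sys.executable, "train_trans.py", *args], "train.log")
# print("exit code:", code)
# !tail -n 20 train.log
```

A non-zero exit code indicates that the script stopped with an error, and the tail of the log usually shows where.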

## 7️⃣ Evaluate the Model (Run After Training Completes)

⚠️ **Run this cell only after training completes!**

This cell evaluates the trained model on the attack traces. You can re-run it later to evaluate newer checkpoints.

In [8]:
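import os
import glob
import re

# Check if checkpoints exist
checkpoint_files = glob.glob('*.index')
if not checkpoint_files:
    print("⚠️ No checkpoint files found!")
    print("   Make sure training has created at least one checkpoint.")
    print("   Checkpoints are saved every", config['save_steps'], "steps.")
else:
    # Pick the checkpoint with the highest step number; a plain string sort
    # would rank trans_long-10.index before trans_long-2.index
    def checkpoint_step(path):
        match = re.search(r'-(\d+)\.index$', path)
        return int(match.group(1)) if match else -1

    latest = max(checkpoint_files, key=checkpoint_step)
    print(f"✅ Found {len(checkpoint_files)} checkpoint(s)")
    print(f"   Latest checkpoint: {latest}\n")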
    
    # Build evaluation arguments
    eval_args = [
        f'--use_tpu={config["use_tpu"]}',
        f'--data_path={config["data_path"]}',
        f'--dataset={config["dataset"]}',
        f'--checkpoint_dir={config["checkpoint_dir"]}',
        '--checkpoint_idx=0',
        f'--warm_start={config["warm_start"]}',
        f'--result_path={config["result_path"]}',
        f'--learning_rate={config["learning_rate"]}',
        f'--clip={config["clip"]}',
        f'--min_lr_ratio={config["min_lr_ratio"]}',
        f'--warmup_steps={config["warmup_steps"]}',
        f'--input_length={config["input_length"]}',
        f'--train_batch_size={config["train_batch_size"]}',
        f'--eval_batch_size={config["eval_batch_size"]}',
        f'--train_steps={config["train_steps"]}',
        f'--iterations={config["iterations"]}',
        f'--save_steps={config["save_steps"]}',
        f'--n_layer={config["n_layer"]}',
        f'--d_model={config["d_model"]}',
        f'--d_head={config["d_head"]}',
        f'--n_head={config["n_head"]}',
        f'--d_inner={config["d_inner"]}',
        f'--n_head_softmax={config["n_head_softmax"]}',
        f'--d_head_softmax={config["d_head_softmax"]}',
        f'--dropout={config["dropout"]}',
        f'--conv_kernel_size={config["conv_kernel_size"]}',
        f'--n_conv_layer={config["n_conv_layer"]}',
        f'--pool_size={config["pool_size"]}',
        f'--d_kernel_map={config["d_kernel_map"]}',
        f'--beta_hat_2={config["beta_hat_2"]}',
        f'--model_normalization={config["model_normalization"]}',
        f'--head_initialization={config["head_initialization"]}',
        f'--softmax_attn={config["softmax_attn"]}',
        f'--max_eval_batch={config["max_eval_batch"]}',
        '--output_attn=False',
        '--do_train=False'
    ]
    
    eval_args_str = ' '.join(eval_args)
    print("📊 Starting evaluation...\n")
    
    # Execute evaluation
    !python train_trans.py {eval_args_str}

✅ Found 1 checkpoint(s)
   Latest checkpoint: trans_long-1.index

📊 Starting evaluation...

2026-02-08 16:04:07.501399: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770566647.521867     915 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770566647.527381     915 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770566647.540752     915 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770566647.540779     915 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more t
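The training and evaluation cells repeat almost the same flag list by hand, which makes it easy for the two lists to drift apart. For example, the evaluation list does not pass `--data_desync`. Below is a minimal sketch that derives both lists from `config`. It assumes every config key maps to a `--key=value` flag of the same name, which holds for all keys used in the cells above. `build_flags` is a hypothetical helper.

```python
def build_flags(cfg, **overrides):
    """Turn a config dict into absl-style `--name=value` flags.

    Keyword overrides add new flags or replace config values, so the
    training and evaluation commands share one source of truth.
    The input dict is not modified.
    """
    merged = {**cfg, **overrides}
    return [f"--{name}={value}" for name, value in merged.items()]


# Usage in the notebook:
# train_args = build_flags(config, do_train=True)
# eval_args = build_flags(config, do_train=False, checkpoint_idx=0, output_attn=False)
# !python train_trans.py {' '.join(eval_args)}
```

Booleans are rendered as `True`/`False`, matching the strings the cells above already pass to `train_trans.py`.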

## 8️⃣ View Results

Check the results directory for evaluation metrics and guessing entropy.

In [9]:
import os

# List files in results directory
if os.path.exists('results'):
    print("📁 Results directory contents:")
    for file in os.listdir('results'):
        filepath = os.path.join('results', file)
        size = os.path.getsize(filepath)
        print(f"  - {file} ({size:,} bytes)")
else:
    print("⚠️ Results directory not found")

# List checkpoint files
print("\n💾 Checkpoint files:")
checkpoint_files = [f for f in os.listdir('.') if 'checkpoint' in f or f.endswith('.index') or f.endswith('.data-00000-of-00001')]
if checkpoint_files:
    for file in sorted(checkpoint_files)[:10]:  # Show first 10
        print(f"  - {file}")
else:
    print("  No checkpoints found yet")

⚠️ Results directory not found

💾 Checkpoint files:
  - checkpoint
  - trans_long-1.data-00000-of-00001
  - trans_long-1.index


## 📝 Notes

### Training Tips:
- **GPU Acceleration**: Make sure GPU is enabled in Colab (Runtime > Change runtime type > GPU)
- **Training Time**: Full training with 4M steps will take many hours. Consider reducing `train_steps` for testing.
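- **Checkpoints**: A checkpoint is saved every `save_steps` steps (40,000 in this config), and training can be resumed from it. The `warm_start` option in the config (set to `False` here) appears to control resuming.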
- **Memory**: If you run out of memory, try reducing `train_batch_size` or `input_length`.

### Quick Test Run:
For a quick test, change these config values after running the Section 5 cell and before running the training cell in Section 6:
```python
config['train_steps'] = 100000  # Reduced from 4M
config['warmup_steps'] = 10000  # Reduced from 1M
config['save_steps'] = 10000    # Save more frequently
```
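The quick-test values above shrink `warmup_steps` along with `train_steps`. To see the effect on the learning rate, here is a minimal sketch of one common way to combine `learning_rate`, `warmup_steps`, `train_steps`, and `min_lr_ratio`: linear warmup to `learning_rate`, followed by cosine decay down to `min_lr_ratio × learning_rate`. This schedule is an assumption based on the flag names; check `train_trans.py` for the schedule EstraNet actually uses. `lr_at` is a hypothetical helper.

```python
import math


def lr_at(step, learning_rate=2.5e-4, train_steps=4_000_000,
          warmup_steps=1_000_000, min_lr_ratio=0.004):
    """Learning rate at `step` under an assumed warmup + cosine-decay schedule."""
    if step < warmup_steps:
        # Linear warmup from 0 up to learning_rate
        return learning_rate * step / warmup_steps
    decay_steps = max(train_steps - warmup_steps, 1)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # Cosine decay from learning_rate down to min_lr_ratio * learning_rate
    return learning_rate * ((1.0 - min_lr_ratio) * cosine + min_lr_ratio)


# Quick-test settings from the cell above:
for step in (0, 5_000, 10_000, 55_000, 100_000):
    print(step, lr_at(step, train_steps=100_000, warmup_steps=10_000))
```

Under this assumed schedule, the quick-test run peaks at step 10,000 and ends at about 1e-6 (0.004 × 2.5e-4).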

### Evaluation:
- **When to run**: Only run the evaluation cell (Section 7) after training has created at least one checkpoint
- **Checkpoints**: The evaluation will use the latest checkpoint automatically
- **Re-run**: You can re-run the evaluation cell anytime to check the latest checkpoint

### Compatibility Fixes Applied:
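This notebook automatically patches the following TensorFlow 2.13+ / Keras 3 compatibility issues (Keras 3 is the default Keras from TensorFlow 2.16 onward):
- ✅ Replaced `SyncBatchNormalization` with `BatchNormalization` (the experimental `SyncBatchNormalization` import is not available under Keras 3; plain `BatchNormalization` does not synchronize statistics across replicas, which only matters for multi-device training)
- ✅ Fixed integer comparison syntax (changed `is` to `==`)
- ✅ Fixed `add_weight()` calls in fast_attention.py (the weight name is now passed as `name=`, because in Keras 3 the first positional argument of `add_weight()` is `shape`)
- ✅ Renamed `reset_states()` to `reset_state()` in train_trans.py
- ✅ Fixed the model call signature for Keras 3 (keyword arguments required)
- ✅ Fixed the `PositionalFeature` call signature for Keras 3 (all arguments passed as keywords)
- ✅ Guarded the slope normalization in transformer.py against division by zero when the sequence length is 1

**Note**: These fixes were used with TensorFlow 2.19 in the run above and are intended for TensorFlow 2.13+, including the Keras 3-based releases (2.16 onward).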
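Two of these fixes are pure text transformations, so you can check them on sample lines without cloning the repository. The small sketch below applies the same regular expression and string replacement used in Section 3 to made-up source lines. The sample lines are illustrative and are not copied from the repository.

```python
import re

# Same pattern as in Section 3: a quoted first argument to add_weight()
# becomes an explicit name= keyword, because in Keras 3 the first
# positional parameter of add_weight() is `shape`.
ADD_WEIGHT_PATTERN = r'self\.add_weight\(\s*"([^"]+)"\s*,'


def fix_add_weight(source):
    return re.sub(ADD_WEIGHT_PATTERN, r'self.add_weight(name="\1",', source)


def fix_is_zero(source):
    # `is` tests object identity, not equality; `==` is the intended check
    return source.replace("if l is 0 else", "if l == 0 else")


sample = 'w = self.add_weight("proj_kernel", shape=(d, d))'
print(fix_add_weight(sample))
print(fix_is_zero("x = a if l is 0 else b"))
```

Because the pattern requires a quoted string immediately after `add_weight(`, lines that already use `name=` are left untouched, which is why re-running the fix cell is safe.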

### References:
- **Paper**: [IACR TCHES 2024](https://tches.iacr.org/index.php/TCHES/article/view/11255)
- **GitHub**: [suvadeep-iitb/EstraNet](https://github.com/suvadeep-iitb/EstraNet)
- **ASCAD Dataset**: [ANSSI-FR/ASCAD](https://github.com/ANSSI-FR/ASCAD)