In [1]:
# Check if running on Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("‚úÖ Running on Google Colab")
except:
    IN_COLAB = False
    print("üìù Running on local Jupyter")

# Check GPU availability
import tensorflow as tf
print(f"\nTensorFlow version: {tf.__version__}")
print(f"GPU Available: {tf.config.list_physical_devices('GPU')}")

if tf.config.list_physical_devices('GPU'):
    print("üöÄ GPU detected! Training will be accelerated.")
else:
    print("‚ö†Ô∏è No GPU detected. Training will be slower on CPU.")
    if IN_COLAB:
        print("üí° Enable GPU: Runtime > Change runtime type > Hardware accelerator > GPU")

‚úÖ Running on Google Colab

TensorFlow version: 2.19.0
GPU Available: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
üöÄ GPU detected! Training will be accelerated.


In [2]:
import os, shutil

os.chdir('/content')
if os.path.exists('EstraNet'):
    shutil.rmtree('EstraNet')  # Remove nested mess

!git clone https://github.com/loshithan/EstraNet.git
os.chdir('EstraNet')
print(f"‚úÖ Clean! Directory: {os.getcwd()}")

Cloning into 'EstraNet'...
remote: Enumerating objects: 140, done.[K
remote: Counting objects: 100% (140/140), done.[K
remote: Compressing objects: 100% (92/92), done.[K
remote: Total 140 (delta 71), reused 117 (delta 48), pack-reused 0 (from 0)[K
Receiving objects: 100% (140/140), 3.13 MiB | 9.44 MiB/s, done.
Resolving deltas: 100% (71/71), done.
‚úÖ Clean! Directory: /content/EstraNet


In [3]:
import os
import gdown

# Create data directory
os.makedirs('data', exist_ok=True)

# ASCADf dataset configuration
file_id = "1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M"
DATASET_PATH = "data/ASCAD.h5"

if not os.path.exists(DATASET_PATH):
    print("üì• Downloading ASCADf dataset from Google Drive...")
    print("   This may take a few minutes (~1.5 GB)\n")

    # Download using gdown
    gdown.download(f"https://drive.google.com/uc?id={file_id}", DATASET_PATH, quiet=False)

    print("\n‚úÖ Dataset downloaded successfully!")
else:
    print("‚úÖ Dataset already exists")

# Verify dataset
import h5py
with h5py.File(DATASET_PATH, 'r') as f:
    print(f"\nüìä Dataset info:")
    print(f"  Keys: {list(f.keys())}")
    if 'Profiling_traces' in f:
        print(f"  Profiling traces shape: {f['Profiling_traces/traces'].shape}")
    if 'Attack_traces' in f:
        print(f"  Attack traces shape: {f['Attack_traces/traces'].shape}")

üì• Downloading ASCADf dataset from Google Drive...
   This may take a few minutes (~1.5 GB)



Downloading...
From: https://drive.google.com/uc?id=1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M
To: /content/EstraNet/data/ASCAD.h5
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 46.6M/46.6M [00:00<00:00, 112MB/s]


‚úÖ Dataset downloaded successfully!

üìä Dataset info:
  Keys: ['Attack_traces', 'Profiling_traces']
  Profiling traces shape: (50000, 700)
  Attack traces shape: (10000, 700)





In [4]:
# Install required dependencies
print("üì¶ Installing dependencies...\n")
# %pip install -q absl-py==2.3.1 numpy==1.24.3 scipy==1.10.1 h5py==3.11.0

# Install gdown for downloading from Google Drive
%pip install -q gdown

# Note: Using TensorFlow version pre-installed in Colab (2.16+ / 2.19+)
# The compatibility fixes in Section 3 work with all TensorFlow 2.13+ versions
print("\n‚úÖ All dependencies installed!")
print(f"Using TensorFlow {tf.__version__} (pre-installed)")

üì¶ Installing dependencies...


‚úÖ All dependencies installed!
Using TensorFlow 2.19.0 (pre-installed)


In [5]:
# ============================================================================
# TRAIN GNN MODEL IN COLAB
# ============================================================================
# Paste this into a new Colab cell

print("üî∑ Training GNN Model")
print("="*70)

# Configuration
CONFIG = {
    'checkpoint_dir': '/content/drive/MyDrive/EstraNet/checkpoints_gnn',
    'result_path': 'results/gnn',
    'train_steps': 5000,
    'save_steps': 200,
    'train_batch_size': 256,
    'eval_batch_size': 32,
    'learning_rate': 0.00025,
    'model_type': 'gnn',  # KEY: Use GNN
}

import os
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(CONFIG['result_path'], exist_ok=True)

# Build training command
train_cmd = f"""
python train_trans.py \\
    --data_path=data/ASCAD.h5 \\
    --checkpoint_dir={CONFIG['checkpoint_dir']} \\
    --model_type={CONFIG['model_type']} \\
    --dataset=ASCAD \\
    --input_length=700 \\
    --eval_batch_size={CONFIG['eval_batch_size']} \\
    --n_layer=2 \\
    --d_model=128 \\
    --d_inner=256 \\
    --n_head_softmax=8 \\
    --d_head_softmax=16 \\
    --dropout=0.05 \\
    --conv_kernel_size=3 \\
    --n_conv_layer=2 \\
    --pool_size=2 \\
    --beta_hat_2=150 \\
    --model_normalization=preLC \\
    --softmax_attn=True \\
    --do_train=True \\
    --learning_rate={CONFIG['learning_rate']} \\
    --clip=0.25 \\
    --min_lr_ratio=0.004 \\
    --warmup_steps=0 \\
    --train_batch_size={CONFIG['train_batch_size']} \\
    --train_steps={CONFIG['train_steps']} \\
    --iterations=500 \\
    --save_steps={CONFIG['save_steps']} \\
    --result_path={CONFIG['result_path']}
"""

print("Starting GNN training...")
print(f"Model: GNN (211,876 parameters - 51% less than Transformer)")
print(f"Checkpoints: {CONFIG['checkpoint_dir']}")
print(f"Training steps: {CONFIG['train_steps']:,}\n")

!{train_cmd}


üî∑ Training GNN Model
Starting GNN training...
Model: GNN (211,876 parameters - 51% less than Transformer)
Checkpoints: /content/drive/MyDrive/EstraNet/checkpoints_gnn
Training steps: 5,000

2026-02-13 04:13:22.652503: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770956002.673303    7557 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770956002.680023    7557 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770956002.695943    7557 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770956002.695972    7557 computation_placer.cc:177] comp

In [12]:
# ============================================================================
# TEST GNN ARCHITECTURE (Rank 29 Config)
# ============================================================================
# Paste into a Colab cell to verify the model builds correctly
# with the 'Rank 29' parameters (Pool=2, Input=700).

!python test_gnn.py


2026-02-13 04:25:06.751573: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770956706.772463   11915 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770956706.779156   11915 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770956706.795279   11915 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770956706.795306   11915 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770956706.795309   11915 computation_placer.cc:177] computation placer alr

In [20]:
# ============================================================================
# EVALUATE GNN (RANK 29 MODEL)
# ============================================================================
# This script evaluates the CORRECT configuration (Pool=2, Input=700).

print("üîç Evaluating GNN Checkpoints (Correct Config)")
print("="*70)

import tensorflow as tf
import numpy as np
import h5py
import os
import glob
from gnn_estranet import GNNEstraNet
from evaluation_utils import compute_key_rank

# ----------------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------------
# Must match RETRAIN_GNN_CORRECTED.txt
CHECKPOINT_DIR = '/content/drive/MyDrive/EstraNet/checkpoints_gnn'
INPUT_LENGTH = 700
POOL_SIZE = 2
N_TRACES = 2000 # Fast check (use 10000 for full precision)

# ----------------------------------------------------------------------------
# 1. LOAD DATA
# ----------------------------------------------------------------------------
print("\nüì• Loading ASCAD dataset...")
with h5py.File('data/ASCAD.h5', 'r') as f:
    # Load limited traces for speed
    traces = f['Attack_traces']['traces'][:N_TRACES]
    metadata = f['Attack_traces']['metadata'][:N_TRACES]

    # Process inputs (Slice to 700)
    traces = traces[:, :INPUT_LENGTH]

    # FIX: Cast to float32 for TensorFlow model
    # This prevents the "Value passed to parameter 'input' has DataType int8" error
    traces = traces.astype(np.float32)

    # Extract labels/keys
    plaintexts = metadata['plaintext'][:, 2].astype(np.uint8)
    keys = metadata['key'][:, 2].astype(np.uint8)

print(f"‚úÖ Loaded {len(traces)} traces (Length: {traces.shape[1]})")
print(f"   Trace Type: {traces.dtype} (Must be float32)")

# ----------------------------------------------------------------------------
# 2. BUILD MODEL (Correct Rank 29 Config)
# ----------------------------------------------------------------------------
print("\nüèóÔ∏è Building GNN model...")
model = GNNEstraNet(
    n_gcn_layers=2,
    d_model=128,
    k_neighbors=5,
    graph_pooling='mean',
    d_head_softmax=16,
    n_head_softmax=8,
    dropout=0.05,
    n_classes=256,
    conv_kernel_size=3,
    n_conv_layer=2,
    pool_size=POOL_SIZE,  # CRITICAL: Must be 2
    beta_hat_2=150,
    model_normalization='preLC',
    softmax_attn=True,
    output_attn=False
)

# Dummy pass to initialize weights
# Using float32 input explicitly
dummy_input = tf.zeros((1, INPUT_LENGTH), dtype=tf.float32)
model(dummy_input, softmax_attn_smoothing=None, training=False)
print(f"‚úÖ Model Built (Pool Size: {POOL_SIZE})")

# ----------------------------------------------------------------------------
# 3. EVALUATE CHECKPOINTS
# ----------------------------------------------------------------------------
if not os.path.exists(CHECKPOINT_DIR):
    print(f"‚ùå Error: Checkpoint folder not found: {CHECKPOINT_DIR}")
    # Fallback to local
    CHECKPOINT_DIR = 'checkpoints'

print(f"üìÇ Searching for checkpoints in: {CHECKPOINT_DIR}")
ckpt_files = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "*.index")))
if not ckpt_files:
    print("‚ùå No checkpoints found!")
else:
    print(f"‚úÖ Found {len(ckpt_files)} checkpoints.")

for ckpt_path in ckpt_files:
    # Remove extension to get prefix
    prefix = ckpt_path.replace(".index", "")
    fname = os.path.basename(prefix)

    print(f"\nTesting {fname}...")
    try:
        ckpt = tf.train.Checkpoint(model=model)
        ckpt.restore(prefix).expect_partial()
    except Exception as e:
        print(f"‚ö†Ô∏è Failed to load {fname}: {e}")
        continue

    # Inference
    try:
        preds = model.predict(traces, batch_size=256, verbose=0)

        # FIX: Handle tuple return (e.g. if model returns (logits, attn))
        if isinstance(preds, tuple):
            preds = preds[0]

    except Exception as e:
        print(f"‚ùå Prediction failed for {fname}: {e}")
        continue

    # Rank
    # Compute rank evolution (returns array of ranks for 1..N traces)
    try:
        ranks = compute_key_rank(preds, plaintexts, keys)
        final_rank = ranks[-1]

        # Check efficiency (traces to rank 0)
        success_idx = np.where(ranks == 0)[0]
        traces_to_0 = success_idx[0] + 1 if len(success_idx) > 0 else ">" + str(N_TRACES)

        print(f"üèÜ Rank: {final_rank:.2f} | Broken at: {traces_to_0} traces")
    except Exception as e:
        print(f"‚ö†Ô∏è Rank computation failed: {e}")

print("\n‚úÖ Evaluation Complete.")

üîç Evaluating GNN Checkpoints (Correct Config)

üì• Loading ASCAD dataset...
‚úÖ Loaded 2000 traces (Length: 700)
   Trace Type: float32 (Must be float32)

üèóÔ∏è Building GNN model...
‚úÖ GNN Graph Construction: 175 nodes (from Input Length 175)
‚úÖ Model Built (Pool Size: 2)
üìÇ Searching for checkpoints in: /content/drive/MyDrive/EstraNet/checkpoints_gnn
‚úÖ Found 11 checkpoints.

Testing gnn_ASCAD-1...
üèÜ Rank: 1.00 | Broken at: 1473 traces

Testing gnn_ASCAD-10...
üèÜ Rank: 203.00 | Broken at: >2000 traces

Testing gnn_ASCAD-11...
üèÜ Rank: 204.00 | Broken at: >2000 traces

Testing gnn_ASCAD-2...
üèÜ Rank: 210.00 | Broken at: >2000 traces

Testing gnn_ASCAD-3...
üèÜ Rank: 206.00 | Broken at: >2000 traces

Testing gnn_ASCAD-4...
üèÜ Rank: 204.00 | Broken at: >2000 traces

Testing gnn_ASCAD-5...
üèÜ Rank: 205.00 | Broken at: >2000 traces

Testing gnn_ASCAD-6...
üèÜ Rank: 205.00 | Broken at: >2000 traces

Testing gnn_ASCAD-7...
üèÜ Rank: 205.00 | Broken at: >2000 trace

In [7]:
import os
import glob

# Locations to check
locations = [
    '/content/drive/MyDrive/EstraNet/checkpoints_gnn',
    '/content/EstraNet/checkpoints',
    'checkpoints',
]

print("üîç Checking for checkpoints...")
found_any = False
for loc in locations:
    if os.path.exists(loc):
        files = glob.glob(os.path.join(loc, "*.index"))
        if files:
            print(f"‚úÖ Found {len(files)} checkpoints in: {loc}")
            print(f"   Example: {files[0]}")
            found_any = True
            # Update the variable for the next cell
            actual_checkpoint_dir = loc
        else:
            print(f"‚ùå Folder exists but empty: {loc}")
    else:
        print(f"‚ùå Folder not found: {loc}")

if not found_any:
    print("\n‚ö†Ô∏è No checkpoints found in expected locations.")

üîç Checking for checkpoints...
‚úÖ Found 11 checkpoints in: /content/drive/MyDrive/EstraNet/checkpoints_gnn
   Example: /content/drive/MyDrive/EstraNet/checkpoints_gnn/gnn_ASCAD-7.index
‚úÖ Found 1 checkpoints in: /content/EstraNet/checkpoints
   Example: /content/EstraNet/checkpoints/trans_long-8.index
‚úÖ Found 1 checkpoints in: checkpoints
   Example: checkpoints/trans_long-8.index
