In [None]:
# Environment check: report which TensorFlow build this runtime provides.
# The notebook relies on the TensorFlow that is pre-installed in the runtime
# (nothing installs or pins it), so the version is printed for the record.
import tensorflow as tf

print("\nTensorFlow version: " + tf.__version__)

In [None]:
import os, shutil

os.chdir('/content')
if os.path.exists('EstraNet'):
    shutil.rmtree('EstraNet')  # Remove nested mess

!git clone https://github.com/loshithan/EstraNet.git
os.chdir('EstraNet')
print(f"‚úÖ Clean! Directory: {os.getcwd()}")

In [None]:
# ============================================================================
# DOWNLOAD THE ASCADf DATASET
# ============================================================================
# data/ASCAD.h5 is read by the training and evaluation cells. The clone cell
# above deletes and re-clones /content/EstraNet on every run, so the file must
# be fetched again after each fresh clone. The download is skipped when the
# file is already present, so re-running this cell is cheap.
import os
import gdown  # NOTE(review): imported before the %pip install cell below; relies on gdown already being available
import h5py

# ASCADf dataset configuration
ASCAD_FILE_ID = "1WNajWT0qFbpqPJiuePS_HeXxsCvUHI5M"
DATASET_PATH = "data/ASCAD.h5"

os.makedirs(os.path.dirname(DATASET_PATH), exist_ok=True)

if not os.path.exists(DATASET_PATH):
    print("Downloading ASCADf dataset from Google Drive...")
    print("   This may take a few minutes (~1.5 GB)\n")

    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={ASCAD_FILE_ID}", DATASET_PATH, quiet=False
    )
    # gdown may report a failure by returning None instead of raising
    # (version-dependent); stop here rather than later with an h5py error.
    if downloaded is None or not os.path.exists(DATASET_PATH):
        raise RuntimeError(
            f"Download of {DATASET_PATH} failed (Google Drive file id {ASCAD_FILE_ID})."
        )
    print("\nDataset downloaded successfully!")
else:
    print("Dataset already exists")

# Verify the dataset layout the later cells rely on.
with h5py.File(DATASET_PATH, 'r') as f:
    print("\nDataset info:")
    print(f"  Keys: {list(f.keys())}")
    if 'Profiling_traces' in f:
        print(f"  Profiling traces shape: {f['Profiling_traces/traces'].shape}")
    if 'Attack_traces' in f:
        print(f"  Attack traces shape: {f['Attack_traces/traces'].shape}")

In [None]:
# Install required dependencies
print("üì¶ Installing dependencies...\n")
# %pip install -q absl-py==2.3.1 numpy==1.24.3 scipy==1.10.1 h5py==3.11.0

# Install gdown for downloading from Google Drive
%pip install -q gdown

# Note: Using TensorFlow version pre-installed in Colab (2.16+ / 2.19+)
# The compatibility fixes in Section 3 work with all TensorFlow 2.13+ versions
print("\n‚úÖ All dependencies installed!")
print(f"Using TensorFlow {tf.__version__} (pre-installed)")

In [None]:
# ============================================================================
# TRAIN UPGRADED "ATTENTION GNN" MODEL IN COLAB
# ============================================================================
# Paste this into a new Colab cell

print("üî∑ Training Upgraded GNN Model (v2: Attention + Deeper)")
print("="*70)

# Configuration
CONFIG = {
    'checkpoint_dir': '/content/drive/MyDrive/EstraNet/checkpoints_gnn_attention_v2',
    'result_path': 'results/gnn_attention',
    'train_steps': 50000,   # Increased for convergence
    'save_steps': 2000,
    'train_batch_size': 256,
    'eval_batch_size': 32,
    'learning_rate': 0.0002, # Slightly higher learning rate for GNN
    'model_type': 'gnn',
    
    # NEW GNN HYPERPARAMETERS
    'n_gcn_layers': 4,       # Deeper (was 2)
    'k_neighbors': 15,       # Wider Context (was 5)
    'graph_pooling': 'attention' # The Secret Sauce!
}

import os
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(CONFIG['result_path'], exist_ok=True)

# Build training command
train_cmd = f"""
python train_trans.py \\
    --data_path=data/ASCAD.h5 \\
    --checkpoint_dir={CONFIG['checkpoint_dir']} \\
    --model_type={CONFIG['model_type']} \\
    --dataset=ASCAD \\
    --input_length=700 \\
    
    # --- MODEL ARCHITECTURE ---
    --d_model=128 \\
    --n_gcn_layers={CONFIG['n_gcn_layers']} \\
    --k_neighbors={CONFIG['k_neighbors']} \\
    --graph_pooling={CONFIG['graph_pooling']} \\
    --conv_kernel_size=3 \\
    --n_conv_layer=2 \\
    --pool_size=2 \\
    --dropout=0.1 \\
    
    # --- TRAINING PARAMS ---
    --do_train=True \\
    --learning_rate={CONFIG['learning_rate']} \\
    --clip=0.25 \\
    --min_lr_ratio=0.004 \\
    --warmup_steps=1000 \\
    --train_batch_size={CONFIG['train_batch_size']} \\
    --eval_batch_size={CONFIG['eval_batch_size']} \\
    --train_steps={CONFIG['train_steps']} \\
    --iterations=500 \\
    --save_steps={CONFIG['save_steps']} \\
    --result_path={CONFIG['result_path']}
"""

print("Starting GNN training...")
print(f"Model: GNN-Attention (Est. 350k parameters - Still lightweight!)")
print(f"Features: {CONFIG['n_gcn_layers']} Layers, {CONFIG['k_neighbors']} Neighbors, Attention Pooling")
print(f"Checkpoints: {CONFIG['checkpoint_dir']}")
print(f"Training steps: {CONFIG['train_steps']:,}\n")

!{train_cmd}


In [None]:
# ============================================================================
# TEST GNN ARCHITECTURE (Rank 29 Config)
# ============================================================================
# Paste into a Colab cell to verify the model builds correctly
# with the 'Rank 29' parameters (Pool=2, Input=700).

import os
import shutil

# Fix: Copy GNN scripts from subfolder to root
if os.path.exists('gnn-scripts'):
    print("üìÇ Moving GNN scripts to root...")
    for f in os.listdir('gnn-scripts'):
        if f.endswith('.py'):
            shutil.copy(os.path.join('gnn-scripts', f), '.')
            print(f"   Copied {f}")

print("\nüöÄ Running test_gnn.py...")
!python test_gnn.py

In [None]:
# ============================================================================
# EVALUATE GNN (RANK 29 MODEL)
# ============================================================================
# This script evaluates the CORRECT configuration (Pool=2, Input=700).
# Every checkpoint in CHECKPOINT_DIR is restored into a GNNEstraNet built with
# the hyper-parameters below and scored with compute_key_rank() on the first
# N_TRACES attack traces.
# NOTE(review): these are the "Rank 29" hyper-parameters (2 GCN layers, k=5,
# mean pooling), NOT the attention-v2 config trained above. To evaluate
# checkpoints_gnn_attention_v2, change CHECKPOINT_DIR and the model together.

print("üîç Evaluating GNN Checkpoints (Correct Config)")
print("="*70)

import tensorflow as tf
import numpy as np
import h5py
import os
import re
import glob
from gnn_estranet import GNNEstraNet
from evaluation_utils import compute_key_rank

# ----------------------------------------------------------------------------
# CONFIGURATION
# ----------------------------------------------------------------------------
# Must match RETRAIN_GNN_CORRECTED.txt
CHECKPOINT_DIR = '/content/drive/MyDrive/EstraNet/checkpoints_gnn'
INPUT_LENGTH = 700
POOL_SIZE = 2
N_TRACES = 2000 # Fast check (use 10000 for full precision)
TARGET_BYTE = 2  # column of metadata['plaintext'] / metadata['key'] that is attacked


def checkpoint_sort_key(index_path):
    """Sort key for '<prefix>.index' files: the trailing step number of the prefix.

    Plain string sorting would put 'ckpt-10' before 'ckpt-2'. Prefixes without
    a trailing number sort first, by path.
    """
    prefix = os.path.basename(index_path)[:-len('.index')]
    match = re.search(r'(\d+)$', prefix)
    return (int(match.group(1)) if match else -1, index_path)


def predict_logits(model, inputs, batch_size=256):
    """Run `model` in inference mode batch by batch and return a NumPy array.

    Calls the model exactly like the warm-up pass below (explicit
    softmax_attn_smoothing=None, training=False); model.predict() would call it
    without that argument. If the model returns a tuple/list (e.g.
    (logits, attn)), only the first element is kept.
    """
    outputs = []
    for start in range(0, len(inputs), batch_size):
        out = model(inputs[start:start + batch_size], softmax_attn_smoothing=None, training=False)
        if isinstance(out, (tuple, list)):
            out = out[0]
        outputs.append(np.asarray(out))
    if not outputs:
        return np.empty((0,), dtype=np.float32)
    return np.concatenate(outputs, axis=0)


def traces_to_stable_rank0(ranks):
    """Traces needed for the key rank to reach 0 and stay there, or None.

    Assumes ranks[i] is the rank after i + 1 traces (see compute_key_rank).
    A rank that touches 0 and then rises again does not count as broken.
    """
    ranks = np.asarray(ranks)
    if ranks.size == 0 or ranks[-1] != 0:
        return None
    not_zero = np.flatnonzero(ranks != 0)
    # First trace count after the last non-zero rank (1-based).
    return int(not_zero[-1]) + 2 if not_zero.size else 1


# ----------------------------------------------------------------------------
# 1. LOAD DATA
# ----------------------------------------------------------------------------
print("\nüì• Loading ASCAD dataset...")
with h5py.File('data/ASCAD.h5', 'r') as f:
    # Read only the first N_TRACES traces and the first INPUT_LENGTH samples.
    traces = f['Attack_traces']['traces'][:N_TRACES, :INPUT_LENGTH]
    metadata = f['Attack_traces']['metadata'][:N_TRACES]

# FIX: Cast to float32 for TensorFlow model
# This prevents the "Value passed to parameter 'input' has DataType int8" error
traces = traces.astype(np.float32)

# Extract labels/keys
plaintexts = metadata['plaintext'][:, TARGET_BYTE].astype(np.uint8)
keys = metadata['key'][:, TARGET_BYTE].astype(np.uint8)

print(f"‚úÖ Loaded {len(traces)} traces (Length: {traces.shape[1]})")
print(f"   Trace Type: {traces.dtype} (Must be float32)")

# ----------------------------------------------------------------------------
# 2. BUILD MODEL (Correct Rank 29 Config)
# ----------------------------------------------------------------------------
print("\nüèóÔ∏è Building GNN model...")
model = GNNEstraNet(
    n_gcn_layers=2,
    d_model=128,
    k_neighbors=5,
    graph_pooling='mean',
    d_head_softmax=16,
    n_head_softmax=8,
    dropout=0.05,
    n_classes=256,
    conv_kernel_size=3,
    n_conv_layer=2,
    pool_size=POOL_SIZE,  # CRITICAL: Must be 2
    beta_hat_2=150,
    model_normalization='preLC',
    softmax_attn=True,
    output_attn=False
)

# Dummy pass to create the weights so checkpoints can be restored into them.
dummy_input = tf.zeros((1, INPUT_LENGTH), dtype=tf.float32)
model(dummy_input, softmax_attn_smoothing=None, training=False)
print(f"‚úÖ Model Built (Pool Size: {POOL_SIZE})")

# ----------------------------------------------------------------------------
# 3. EVALUATE CHECKPOINTS
# ----------------------------------------------------------------------------
if not os.path.exists(CHECKPOINT_DIR):
    print(f"‚ùå Error: Checkpoint folder not found: {CHECKPOINT_DIR}")
    # Fallback to local
    CHECKPOINT_DIR = 'checkpoints'

print(f"üìÇ Searching for checkpoints in: {CHECKPOINT_DIR}")
ckpt_files = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "*.index")), key=checkpoint_sort_key)
if not ckpt_files:
    print("‚ùå No checkpoints found!")
else:
    print(f"‚úÖ Found {len(ckpt_files)} checkpoints.")

for ckpt_path in ckpt_files:
    # Strip only the trailing extension (a directory name may contain ".index").
    prefix = ckpt_path[:-len(".index")]
    fname = os.path.basename(prefix)

    print(f"\nTesting {fname}...")
    try:
        ckpt = tf.train.Checkpoint(model=model)
        ckpt.restore(prefix).expect_partial()
    except Exception as e:
        print(f"‚ö†Ô∏è Failed to load {fname}: {e}")
        continue

    # Inference
    try:
        preds = predict_logits(model, traces)
    except Exception as e:
        print(f"‚ùå Prediction failed for {fname}: {e}")
        continue

    # Rank evolution; assumed to hold the rank after 1..N traces.
    # np.asarray so the comparisons below also work if a list is returned.
    try:
        ranks = np.asarray(compute_key_rank(preds, plaintexts, keys))
        final_rank = ranks[-1]

        n_to_break = traces_to_stable_rank0(ranks)
        traces_to_0 = n_to_break if n_to_break is not None else ">" + str(N_TRACES)

        print(f"üèÜ Rank: {final_rank:.2f} | Broken at: {traces_to_0} traces")
    except Exception as e:
        print(f"‚ö†Ô∏è Rank computation failed: {e}")

print("\n‚úÖ Evaluation Complete.")

In [None]:
# ============================================================================
# BACKUP LOCAL CHECKPOINTS TO DRIVE
# ============================================================================
# Copies the checkpoints_gnn folder from the Colab VM disk into a new,
# timestamped folder on Google Drive (never overwrites an earlier backup).
import os
import shutil
import datetime
from google.colab import drive

# 1. Mount Drive
# Check for an actual mount, not mere existence: a cell that creates folders
# under /content/drive before mounting leaves a plain local directory there,
# and the "backup" would then silently land on the ephemeral VM disk.
if not os.path.ismount('/content/drive'):
    print("üìÇ Mounting Google Drive...")
    drive.mount('/content/drive')

# 2. Configuration
SOURCE_DIR = '/content/checkpoints_gnn'  # Local Colab path
# Check if it exists locally, otherwise check inside repo folder
if not os.path.exists(SOURCE_DIR):
    SOURCE_DIR = '/content/EstraNet/checkpoints_gnn'

# Destination with timestamp to avoid overwriting
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
DEST_DIR = f'/content/drive/MyDrive/EstraNet/checkpoints_gnn_backup_{timestamp}'

# 3. Copy (copytree creates DEST_DIR and any missing parents)
if os.path.exists(SOURCE_DIR):
    print(f"\nüì¶ Found local checkpoints at: {SOURCE_DIR}")
    print(f"üöÄ Copying to: {DEST_DIR}...")

    try:
        shutil.copytree(SOURCE_DIR, DEST_DIR)
        print(f"‚úÖ Backup successful! Folder: checkpoints_gnn_backup_{timestamp}")
    except Exception as e:
        print(f"‚ùå Backup failed: {e}")
else:
    print(f"\n‚ö†Ô∏è No local '{SOURCE_DIR}' folder found to upload.")
    print(f"   (Your checkpoints might already be in Drive at: /content/drive/MyDrive/EstraNet/checkpoints_gnn)")

In [None]:
# Moves EstraNet/checkpoints_gnn on Drive to MyDrive/checkpoints_gnn (the
# folder the zip cell below archives). If that name is already taken, the
# folder is moved to the first free "_moved", "_moved2", ... name instead.
import os
import shutil


def unique_destination(path):
    """Return `path` if nothing exists there, else the first free `<path>_moved[N]`.

    shutil.move() puts the source *inside* an existing destination directory,
    so a name that does not exist yet is needed for a plain rename.
    """
    if not os.path.exists(path):
        return path
    candidate = path + "_moved"
    counter = 2
    while os.path.exists(candidate):
        candidate = f"{path}_moved{counter}"
        counter += 1
    return candidate


# Paths
source_path = '/content/drive/MyDrive/EstraNet/checkpoints_gnn'
dest_path = '/content/drive/MyDrive/checkpoints_gnn'

print(f"üîÑ Attempting to move: {source_path} -> {dest_path}")

if os.path.exists(source_path):
    # Check if destination already exists
    if os.path.exists(dest_path):
        print(f"‚ö†Ô∏è Destination folder '{dest_path}' already exists.")
        dest_path = unique_destination(dest_path)
        print(f"   Moving to '{dest_path}' instead to avoid overwriting.")

    try:
        shutil.move(source_path, dest_path)
        print(f"‚úÖ Successfully moved to: {dest_path}")
    except Exception as e:
        print(f"‚ùå Error moving folder: {e}")
else:
    print(f"‚ùå Source folder not found: {source_path}")
    print("   Please check if the folder path is correct.")

In [None]:
# import os

# # Target directory we just moved
# target_dir = '/content/drive/MyDrive/checkpoints_gnn'

# print("üîç Verifying Google Drive Sync...")

# if os.path.exists(target_dir):
#     print(f"‚úÖ Drive is mounted and folder exists: {target_dir}")

#     # List a few files to confirm access
#     files = os.listdir(target_dir)
#     print(f"   Contains {len(files)} files/folders.")
#     if files:
#         print(f"   Example: {files[0]}")
# else:
#     print(f"‚ùå Folder not found: {target_dir}")
#     print("   Drive might not be mounted correctly or the move failed.")
#     # Optional: Suggest remounting only if strictly needed
#     print("   If this fails, try: drive.mount('/content/drive', force_remount=True)")

In [None]:
import shutil
import os
from google.colab import files

# Zip the Drive checkpoint folder (produced by the move cell above) onto the
# VM's /content disk, then hand the archive to the browser for download.
folder_to_zip = '/content/drive/MyDrive/checkpoints_gnn'
zip_name = '/content/gnn_checkpoints_archive2'

if not os.path.exists(folder_to_zip):
    print(f"‚ùå Error: The folder '{folder_to_zip}' was not found in the Colab runtime.")
else:
    print(f"üì¶ Zipping '{folder_to_zip}'... (This may take a moment)")

    # make_archive appends the '.zip' extension and returns the archive path.
    archive_path = shutil.make_archive(zip_name, 'zip', folder_to_zip)

    print(f"‚úÖ Zip created: {archive_path}")
    print("‚¨áÔ∏è Starting download to your local machine...")

    # Trigger browser download
    files.download(archive_path)

In [None]:
# import os
# import glob

# # Locations to check
# locations = [
#     '/content/drive/MyDrive/EstraNet/checkpoints_gnn',
#     '/content/EstraNet/checkpoints',
#     'checkpoints',
# ]

# print("üîç Checking for checkpoints...")
# found_any = False
# for loc in locations:
#     if os.path.exists(loc):
#         files = glob.glob(os.path.join(loc, "*.index"))
#         if files:
#             print(f"‚úÖ Found {len(files)} checkpoints in: {loc}")
#             print(f"   Example: {files[0]}")
#             found_any = True
#             # Update the variable for the next cell
#             actual_checkpoint_dir = loc
#         else:
#             print(f"‚ùå Folder exists but empty: {loc}")
#     else:
#         print(f"‚ùå Folder not found: {loc}")

# if not found_any:
#     print("\n‚ö†Ô∏è No checkpoints found in expected locations.")