# ESTRaNet — Checkpoint → INT8 TFLite → MCU Deployment

This notebook:
1. Loads your pre-trained ESTRaNet checkpoint
2. Converts to INT8 TFLite with Post-Training Quantization
3. Evaluates key rank on test data
4. Prepares the model for MCU deployment

> **No training** — uses existing checkpoint: `trans_long-5`

## 1 — Setup & Mount Drive

In [1]:
import os

from google.colab import drive

# `os` is imported before mounting so it stays bound even if the mount below
# fails (e.g. authorization is cancelled) and later cells are run anyway.
drive.mount('/content/drive')

os.chdir('/content/drive/MyDrive')  # or wherever your repo is
print('Working directory:', os.getcwd())

MessageError: User cancelled dfs_ephemeral authorization

In [None]:
!pip install -q absl-py h5py

import tensorflow as tf
import numpy as np
import sys

print('TensorFlow:', tf.__version__)
print('GPUs:', tf.config.list_physical_devices('GPU'))

## 2 — Configuration

Update these paths to match your Drive structure.

In [None]:
# ── Paths (EDIT THESE) ────────────────────────────────────────────────────────
REPO_PATH      = '/content/drive/MyDrive/estranet'  # your ESTRaNet code repo
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints_transformer_desync0'
CHECKPOINT_IDX = 5  # trans_long-5

# Dataset for evaluation & calibration
DATA_PATH      = '/content/drive/MyDrive/ASCAD.h5'
DATASET_TYPE   = 'ASCAD'  # or 'CHES20'

# Model config (must match your trained checkpoint)
CONFIG = dict(
    input_length      = 700,
    n_layer           = 6,
    d_model           = 128,
    d_head            = 32,
    n_head            = 4,
    d_inner           = 256,
    n_head_softmax    = 4,
    d_head_softmax    = 32,
    dropout           = 0.1,  # will be set to 0.0 at inference
    n_conv_layer      = 1,
    pool_size         = 2,
    d_kernel_map      = 512,
    beta_hat_2        = 150,
    model_normalization = 'preLC',
    head_initialization = 'forward',
)

sys.path.insert(0, REPO_PATH)
print(f'Checkpoint: {CHECKPOINT_DIR}/trans_long-{CHECKPOINT_IDX}')
print(f'Dataset:    {DATA_PATH}')

## 3 — Load Dataset

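The test split serves two purposes:
1. **Calibration dataset** for INT8 quantization (~200 traces)
2. **Full test set** for key rank evaluation

The calibration traces are drawn from the same split that is used for evaluation, so the INT8 key rank may be slightly optimistic. If a separate profiling subset is available, use it for calibration instead.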

In [None]:
import data_utils

# A single test split feeds both INT8 calibration and key-rank evaluation.
dataset_kwargs = {
    'data_path': DATA_PATH,
    'split': 'test',
    'input_length': CONFIG['input_length'],
    'data_desync': 0,  # use 0 for fixed evaluation, or match your training desync
}
test_data = data_utils.Dataset(**dataset_kwargs)

print(f'Test traces: {test_data.num_samples}')
print(f'Plaintexts: {test_data.plaintexts.shape}')
print(f'Keys:       {test_data.keys.shape}')

## 4 — Restore Pre-Trained Model

Creates the model architecture and loads your checkpoint weights.

In [None]:
from transformer import Transformer

# ASCAD uses a 256-class head; the other configured dataset (CHES20) uses 4.
n_classes = 256 if DATASET_TYPE == 'ASCAD' else 4

model = Transformer(
    n_layer             = CONFIG['n_layer'],
    d_model             = CONFIG['d_model'],
    d_head              = CONFIG['d_head'],
    n_head              = CONFIG['n_head'],
    d_inner             = CONFIG['d_inner'],
    d_head_softmax      = CONFIG['d_head_softmax'],
    n_head_softmax      = CONFIG['n_head_softmax'],
    dropout             = 0.0,  # inference mode
    n_classes           = n_classes,
    conv_kernel_size    = 3,
    n_conv_layer        = CONFIG['n_conv_layer'],
    pool_size           = CONFIG['pool_size'],
    d_kernel_map        = CONFIG['d_kernel_map'],
    beta_hat_2          = CONFIG['beta_hat_2'],
    model_normalization = CONFIG['model_normalization'],
    head_initialization = CONFIG['head_initialization'],
    softmax_attn        = True,
    output_attn         = False,
)

# Build the model with a dummy input so all variables exist before restoring.
_ = model(tf.zeros([1, CONFIG['input_length']]))

# Restore checkpoint
checkpoint = tf.train.Checkpoint(model=model)
chk_path = os.path.join(CHECKPOINT_DIR, f'trans_long-{CHECKPOINT_IDX}')

print(f'Restoring: {chk_path}')
status = checkpoint.read(chk_path)
# expect_partial() only silences warnings about checkpoint entries we do not
# load (e.g. optimizer state). It does not verify that the model itself was
# restored, so explicitly require every model object to be matched; otherwise
# a wrong path/config would silently evaluate randomly initialised weights.
status.expect_partial()
status.assert_existing_objects_matched()

print('✓ Checkpoint restored successfully')
print(f'  Total parameters: {model.count_params():,}')

## 5 — Evaluate FP32 Model (Baseline)

Compute the key rank of the FP32 model on the test set. This is the baseline that the INT8 model is compared against.

In [None]:
import evaluation_utils

# Run inference on the full test set. Assumes GetTFRecords(training=False)
# yields traces in the same order as test_data.plaintexts / test_data.keys
# (TODO confirm in data_utils).
test_dataset = test_data.GetTFRecords(batch_size=32, training=False)
predictions = model.predict(test_dataset, verbose=1)

# The model may return its logits alone or inside a list/tuple (logits
# first). Keras keeps tuple outputs as tuples, which a list-only check misses.
test_scores = predictions[0] if isinstance(predictions, (list, tuple)) else predictions

print(f'Predictions shape: {test_scores.shape}')

# Repeat the key-rank computation 100 times and average the curves. This only
# adds information if compute_key_rank is randomized internally (e.g. shuffles
# the trace order) — TODO confirm in evaluation_utils.
print('\nComputing key rank (100 iterations)...')
key_ranks_fp32 = np.stack(
    [
        evaluation_utils.compute_key_rank(test_scores, test_data.plaintexts, test_data.keys)
        for _ in range(100)
    ],
    axis=0,
)
mean_ranks_fp32 = np.mean(key_ranks_fp32, axis=0)

print('\n─── FP32 Baseline Key Rank ───')
print(f'  Min rank:  {mean_ranks_fp32.min():.2f}')
print(f'  Rank @ 10:  {mean_ranks_fp32[9]:.2f}')
print(f'  Rank @ 100: {mean_ranks_fp32[99]:.2f}')
print(f'  Rank @ 500: {mean_ranks_fp32[499]:.2f}')

## 6 — Export to SavedModel

Required intermediate step for TFLite conversion.

In [None]:
saved_model_path = os.path.join(CHECKPOINT_DIR, 'saved_model')

# Define serving signature with fixed input shape
@tf.function(input_signature=[tf.TensorSpec([None, CONFIG['input_length']], tf.float32)])
def serving_fn(inputs):
    """Serving entry point that returns only the logits, never attention weights."""
    outputs = model(inputs, training=False)
    # The model may return a list/tuple (logits first) or a bare logits tensor.
    # Indexing a bare tensor with [0] would silently keep only the first
    # sample's scores, so unwrap only when a container is returned.
    if isinstance(outputs, (list, tuple)):
        outputs = outputs[0]
    return outputs

# Save
tf.saved_model.save(
    model,
    saved_model_path,
    signatures={'serving_default': serving_fn}
)

print(f'✓ SavedModel exported to: {saved_model_path}')

## 7 — INT8 Post-Training Quantization

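Converts the FP32 model to INT8 using a calibration dataset. Operations that fall back to `SELECT_TF_OPS` (Flex ops) or to custom ops are not supported by TensorFlow Lite Micro. To run on an MCU, the converted model must contain only builtin ops, or custom ops with matching kernels registered on the device.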

In [None]:
# ── Prepare calibration dataset ───────────────────────────────────────────────
NUM_CALIB_SAMPLES = 200

# Get calibration traces from test set
calib_dataset = test_data.GetTFRecords(batch_size=NUM_CALIB_SAMPLES, training=False)
calib_traces, _ = next(iter(calib_dataset))
calib_traces = calib_traces.numpy()

print(f'Calibration dataset: {calib_traces.shape}')

def representative_dataset_gen():
    """Generator for calibration data."""
    for i in range(calib_traces.shape[0]):
        # TFLite expects batch dimension
        yield [calib_traces[i:i+1]]

# ── Convert to TFLite INT8 ────────────────────────────────────────────────────
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)

# Enable INT8 quantization
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset_gen

# Force INT8 for all operations
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
    tf.lite.OpsSet.SELECT_TF_OPS,  # fallback for ops like einsum
]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

# Allow custom ops if needed (LayerCentering, etc.)
converter.allow_custom_ops = True

print('\nConverting to INT8 TFLite (this may take 1-2 minutes)...')
tflite_model = converter.convert()

# Save the quantized model
tflite_path = os.path.join(CHECKPOINT_DIR, 'estranet_int8.tflite')
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)

size_fp32 = model.count_params() * 4 / (1024**2)  # MB
size_int8 = len(tflite_model) / (1024**2)  # MB
compression_ratio = size_fp32 / size_int8

print(f'\n✓ INT8 TFLite model saved to: {tflite_path}')
print(f'  FP32 model size: {size_fp32:.2f} MB')
print(f'  INT8 model size: {size_int8:.2f} MB')
print(f'  Compression:     {compression_ratio:.1f}x')

## 8 — Evaluate INT8 Model

Run inference with the quantized model and check accuracy degradation.

In [None]:
# ── Load TFLite interpreter ───────────────────────────────────────────────────
interpreter = tf.lite.Interpreter(model_path=tflite_path)
interpreter.allocate_tensors()

input_details  = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print('TFLite model loaded')
print(f"  Input:  {input_details['shape']} {input_details['dtype']}")
print(f"  Output: {output_details['shape']} {output_details['dtype']}")

# ── Get quantization parameters ───────────────────────────────────────────────
input_scale, input_zero_point = input_details['quantization']
output_scale, output_zero_point = output_details['quantization']

print(f'\nQuantization params:')
print(f'  Input:  scale={input_scale:.6f}, zero_point={input_zero_point}')
print(f'  Output: scale={output_scale:.6f}, zero_point={output_zero_point}')

# ── Run inference on test set ──────────────────────────────────────────────────
def quantize_input(x, scale=None, zero_point=None):
    """Convert an FP32 trace to the interpreter's INT8 input representation.

    Applies the affine mapping q = round(x / scale) + zero_point and saturates
    the result to the int8 range [-128, 127].

    Args:
        x: array-like of floats.
        scale: quantization scale; defaults to the global `input_scale`.
        zero_point: quantization zero point; defaults to the global
            `input_zero_point`.

    Returns:
        np.ndarray of dtype int8 with the same shape as `x`.
    """
    if scale is None:
        scale = input_scale
    if zero_point is None:
        zero_point = input_zero_point
    # Round to nearest and clip *before* casting: a bare .astype(np.int8)
    # truncates toward zero and wraps values outside [-128, 127].
    q = np.round(np.asarray(x) / scale) + zero_point
    return np.clip(q, -128, 127).astype(np.int8)

def dequantize_output(x, scale=None, zero_point=None):
    """Convert an INT8 interpreter output back to FP32 scores.

    Applies the inverse affine mapping (q - zero_point) * scale.

    Args:
        x: integer array returned by the interpreter.
        scale: quantization scale; defaults to the global `output_scale`.
        zero_point: quantization zero point; defaults to the global
            `output_zero_point`.

    Returns:
        np.ndarray of floats with the same shape as `x`.
    """
    if scale is None:
        scale = output_scale
    if zero_point is None:
        zero_point = output_zero_point
    return (x.astype(np.float32) - zero_point) * scale

print('\nRunning INT8 inference on test set...')

# Only (de)quantize when the converted model really has int8 I/O. With float
# I/O the reported scale is 0.0 and quantize_input would divide by zero.
input_is_int8 = input_details['dtype'] == np.int8
output_is_int8 = output_details['dtype'] == np.int8

test_scores_int8 = []

test_dataset_eval = test_data.GetTFRecords(batch_size=1, training=False)
for i, (trace_batch, _) in enumerate(test_dataset_eval.take(test_data.num_samples)):
    if i % 1000 == 0:
        print(f'  {i}/{test_data.num_samples}')

    # Convert the trace to the interpreter's input dtype
    trace = trace_batch.numpy()[0]
    if input_is_int8:
        model_input = quantize_input(trace)
    else:
        model_input = trace.astype(input_details['dtype'])

    # Run inference (re-add the batch dimension of 1)
    interpreter.set_tensor(input_details['index'], model_input[None, :])
    interpreter.invoke()

    # Get the output and convert it back to FP32 scores
    raw_output = interpreter.get_tensor(output_details['index'])
    if output_is_int8:
        output_fp32 = dequantize_output(raw_output)
    else:
        output_fp32 = raw_output.astype(np.float32)

    test_scores_int8.append(output_fp32[0])

test_scores_int8 = np.array(test_scores_int8)
print(f'INT8 predictions shape: {test_scores_int8.shape}')

In [None]:
# ── Compute key rank for INT8 model ───────────────────────────────────────────
print('Computing INT8 key rank (100 iterations)...')
key_rank_list_int8 = []
for _ in range(100):
    key_ranks = evaluation_utils.compute_key_rank(
        test_scores_int8,
        test_data.plaintexts,
        test_data.keys
    )
    key_rank_list_int8.append(key_ranks)

key_ranks_int8 = np.stack(key_rank_list_int8, axis=0)
mean_ranks_int8 = np.mean(key_ranks_int8, axis=0)

print('\n─── INT8 Key Rank ───')
print(f'  Min rank:   {mean_ranks_int8.min():.2f}')
print(f'  Rank @ 10:  {mean_ranks_int8[9]:.2f}')
print(f'  Rank @ 100: {mean_ranks_int8[99]:.2f}')
print(f'  Rank @ 500: {mean_ranks_int8[499]:.2f}')

print('\n─── Accuracy Comparison ───')
print(f'  FP32 rank @ 100: {mean_ranks_fp32[99]:.2f}')
print(f'  INT8 rank @ 100: {mean_ranks_int8[99]:.2f}')
print(f'  Degradation:     {mean_ranks_int8[99] - mean_ranks_fp32[99]:.2f}')

## 9 — MCU Deployment Preparation

Convert the `.tflite` file to a C array for embedding in MCU firmware.

In [None]:
# ── Generate C header file ────────────────────────────────────────────────────
c_array_path = os.path.join(CHECKPOINT_DIR, 'estranet_model_data.cc')

!xxd -i {tflite_path} > {c_array_path}

print(f'✓ C array written to: {c_array_path}')
print('\nTo use in your MCU project:')
print('  1. Copy estranet_model_data.cc to your firmware project')
print('  2. Include TFLite Micro runtime (github.com/tensorflow/tflite-micro)')
print('  3. Link with CMSIS-NN kernels for Cortex-M acceleration')
print('\nExample usage:')
print('''
  #include "estranet_model_data.cc"
  
  // Load model
  tflite::MicroInterpreter interpreter(
      tflite::GetModel(g_model),
      ops_resolver,
      tensor_arena,
      kTensorArenaSize
  );
  
  // Run inference
  TfLiteTensor* input = interpreter.input(0);
  // ... copy your trace to input->data.int8 ...
  interpreter.Invoke();
  TfLiteTensor* output = interpreter.output(0);
''')

## 10 — Summary

Review final metrics and next steps.

In [None]:
# Collect every report line first, then emit them in one print call.
heavy_rule = '═' * 70
rank_fp32_at_100 = mean_ranks_fp32[99]
rank_int8_at_100 = mean_ranks_int8[99]

summary_lines = [
    heavy_rule,
    '                    QUANTIZATION SUMMARY',
    heavy_rule,
    f'\nCheckpoint:  {CHECKPOINT_DIR}/trans_long-{CHECKPOINT_IDX}',
    f'Dataset:     {DATA_PATH}',
    f'Test traces: {test_data.num_samples}',
    # Model size section
    '\n┌─ Model Size',
    f'│  FP32:  {size_fp32:.2f} MB',
    f'│  INT8:  {size_int8:.2f} MB',
    f'│  Ratio: {compression_ratio:.1f}x compression',
    '│',
    # Key-rank section
    '┌─ Key Rank @ 100 traces',
    f'│  FP32:  {rank_fp32_at_100:.2f}',
    f'│  INT8:  {rank_int8_at_100:.2f}',
    f'│  Loss:  {rank_int8_at_100 - rank_fp32_at_100:.2f}',
    '│',
    # Artifact section
    '┌─ Output Files',
    f'│  TFLite: {tflite_path}',
    f'│  C code: {c_array_path}',
    '└─',
    '\n✓ Model ready for MCU deployment',
    '=' * 70,
]
print('\n'.join(summary_lines))