## Setup

In [None]:
# Base Python libraries
import os
from math import log, log1p
from typing import Callable
from collections import defaultdict

# Third-party Python libraries
from tqdm.auto import tqdm
import numpy as np
from scipy.stats import spearmanr
import torch
from torch import nn, optim
from torch.utils.data import random_split, DataLoader

# Our own libraries
from common import *
from datasets.ascadv1 import ASCADv1, ascadv1_download # A second-order leaking dataset we'll use for this example
from utils.metrics import get_rank

In [None]:
# Directory holding the fixed-key ASCADv1 data. The dataset will be auto-downloaded
# here if not already present. Feel free to change this directory.
ASCADv1_ROOT = os.path.join(RESOURCE_DIR, 'ascadv1-fixed')

# Number of optimizer steps to run, and number of traces per training minibatch.
training_steps = 10000
minibatch_size = 256

# Prefer the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Budget hyperparameter, equal to the value of \gamma when \eta is uniform.
# Called \overline{\gamma} in the paper.
gamma_budget = 0.5

In [None]:
# Build the datasets and dataloaders used by the rest of the notebook.
# The download call comes first so that the files exist before the dataset objects are constructed.
ascadv1_download(ASCADv1_ROOT) # make sure the dataset has been downloaded and extracted
profiling_dataset = ASCADv1(root=ASCADv1_ROOT, train=True, variable_keys=False) # train dataset (called 'profiling' dataset by side-channel community)
attack_dataset = ASCADv1(root=ASCADv1_ROOT, train=False, variable_keys=False) # test dataset (called 'attack' dataset by side-channel community)
input_features = profiling_dataset.timesteps_per_trace # called T in the paper
output_classes = profiling_dataset.class_count # cardinality of \mathsf{Y} in the paper
# Hold out 10000 profiling examples for validation. The split sizes are hard-coded, and random_split
#  requires integer lengths to sum to len(profiling_dataset). No generator is passed, so the split
#  differs between runs unless the global torch seed is fixed.
train_dataset, val_dataset = random_split(profiling_dataset, lengths=(40000, 10000))
# Only the training loader shuffles; the validation and attack loaders use twice the training minibatch size.
train_dataloader = DataLoader(train_dataset, batch_size=minibatch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2*minibatch_size)
attack_dataloader = DataLoader(attack_dataset, batch_size=2*minibatch_size)

## Computing the oracle signal to noise ratio

The basic approach when carrying out a side-channel attack on AES-128 is to target an internal variable called the first SubBytes output. This variable is defined as $S := \operatorname{Sbox}(k \oplus w)$ where $k$ denotes one byte of the cryptographic key, $w$ denotes the corresponding byte of the plaintext, $\oplus$ denotes the bitwise exclusive-or operation, and $\operatorname{Sbox}$ denotes an invertible nonlinear operation which is the same for all AES implementations. Note that the plaintext is not a secret, so if the attacker can determine $S$, they can then compute the corresponding byte of the key as $k = \operatorname{Sbox}^{-1}(S) \oplus w.$

Since the SubBytes output is a major target for attackers, designers of AES implementations have developed countermeasures to make it more difficult to attack. In the implementations from which the ASCADv1 datasets were collected, the SubBytes variable is protected by a countermeasure called Boolean masking. Before every encryption, random bytes called *masks* are sampled uniformly at random. Note that for a random bit $b$ with arbitrary distribution, if we generate an independent random mask bit as $r \sim \mathcal{U}\{0, 1\}$, the variable $b \oplus r$ is statistically independent of $b$. The same is true if $b$ and $r$ are bitstrings (e.g. bytes) rather than single bits. Boolean masking exploits this fact by modifying the AES algorithm so that the SubBytes variable $S$ is never directly operated on and therefore never directly influences power consumption or EM radiation. The algorithm operates only on variables $S \oplus r$ for various mask variables $r$. Thus, attackers must determine ordered pairs $(r, S \oplus r)$ to determine $S$, where $r$ and $S \oplus r$ are usually leaked at temporally-distant points in time. While deep learning methods have proven capable of overcoming countermeasures of this nature, such countermeasures make it significantly harder to attack using older parametric statistics-based techniques. Refer to algorithm 1 of Benadjila et al. (2020) for details.

Boolean-masked implementations often leak due to many such pairs of internal AES variables. Identification of these variables generally requires significant domain knowledge and careful analysis of the AES implementation, as well as knowledge of the internally-generated random mask variables of the implementation (on top of the key and plaintext, which are commonly assumed to be known to attackers during the 'profiling' phase of their attack). In Benadjila et al. (2020), the creators of the ASCADv1 datasets identified 2 pairs of random variables which leak. 2 additional pairs were subsequently discovered by Egger et al. (2022). In the side-channel community, leakage localization for non-masked datasets is usually done using simple first-order parametric statistical methods such as computing the signal to noise ratio between individual power measurements and the SubBytes variable. This does not work for second-order datasets such as ASCADv1 because each individual power measurement is by design nearly-independent of the SubBytes variable. However, it is possible to individually target each of the leaky internal AES variables with such techniques, then average the 'leakiness' assessments to obtain an 'oracle' estimate of which points in the trace leak information about the SubBytes variable.

In [None]:
# The list starts with the standard target when carrying out AES attacks. While 'subbytes' *does* leak
#  from the ASCADv1 datasets, because they are Boolean masked we cannot identify the leaking points
#  with the standard first-order parametric statistical techniques.
# The remaining entries are the pairs of leaking variables identified by Egger et al. (2022). It is
#  possible to use the standard first-order methods to identify leaking points for these individual
#  variables, then average the results to get a list of leaky points for SubBytes.
targets = ['subbytes'] + [
    variable
    for share_pair in (
        ('r_in', 'plaintext__key__r_in'),
        ('r', 'subbytes__r'),
        ('r_out', 'subbytes__r_out'),
        ('s_prev__subbytes__r_out', 'security_load'),
    )
    for variable in share_pair
]


In [None]:
# Family of occlusion-aware classifiers, called \Phi_\theta in the paper: an MLP with four
#  500-unit hidden layers. The input width has a factor of 2 because we are feeding it both
#  the power trace and the occlusion mask.
classifiers = nn.Sequential(
    nn.Linear(2*input_features, 500),
    nn.ReLU(),
    # three identical hidden blocks, unpacked in creation order
    *(layer for _ in range(3) for layer in (nn.Linear(500, 500), nn.ReLU())),
    nn.Linear(500, output_classes),
)

# The erasure probability logits before reparameterizing for the budget constraint, called
#  \tilde{\bm{\eta}} in the paper. Initialized close to zero and marked as a trainable leaf tensor.
eta_tau = (0.01*torch.randn(input_features)).requires_grad_(True)

# Separate Adam optimizers for the classifier weights and for the erasure logits.
theta_optimizer = optim.Adam(classifiers.parameters(), lr=1e-4)
etat_optimizer = optim.Adam([eta_tau], lr=1e-3)
print(classifiers)

In [None]:
def get_erasure_prob_logits(eta_tau: torch.Tensor, gamma_budget: float) -> torch.Tensor: # equation 6 in paper
    """Reparameterize the unconstrained logits so that they respect the erasure budget.

    The logits are shifted by a single scalar so that, when ``eta_tau`` is uniform, every
    returned logit equals ``logit(gamma_budget)``, i.e. ``sigmoid`` of the output equals
    ``gamma_budget`` everywhere.

    Args:
        eta_tau: Unconstrained erasure logits (\\tilde{\\bm{\\eta}} in the paper), one per input feature.
        gamma_budget: Budget hyperparameter, strictly between 0 and 1.

    Returns:
        A tensor with the same shape as ``eta_tau`` holding the erasure probability logits.

    Raises:
        ValueError: If ``gamma_budget`` is not strictly between 0 and 1.
    """
    if not 0.0 < gamma_budget < 1.0:
        raise ValueError(f'gamma_budget must lie strictly between 0 and 1, got {gamma_budget}')
    # Take the feature count from the tensor itself rather than from a notebook-level global, so the
    #  normalization always matches the logits that were actually passed in.
    feature_count = eta_tau.numel()
    log_normalizer = torch.logsumexp(eta_tau.reshape(-1), dim=0)
    gamma_tau = eta_tau - log_normalizer + log(feature_count) + log(gamma_budget) - log1p(-gamma_budget)
    return gamma_tau

def sample_from_concrete_distribution(batch_size: int, prob_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Draw relaxed (continuous) Bernoulli samples from per-feature probability logits.

    Args:
        batch_size: Number of independent samples to draw.
        prob_logits: Logits of the Bernoulli probabilities.
        temperature: Relaxation temperature dividing the noisy logits before the sigmoid.

    Returns:
        A tensor of values in (0, 1). For 1-D ``prob_logits`` of length T the result has
        shape ``(batch_size, 1, T)``, because the logits are tiled with three repeat factors.
    """
    # Everything stays in log space for numerical stability; passing probabilities and taking
    #  their log would be less stable.
    tiling = (batch_size, 1, 1)
    log_p = nn.functional.logsigmoid(prob_logits)[None].repeat(*tiling)
    log_not_p = nn.functional.logsigmoid(prob_logits.neg())[None].repeat(*tiling)
    # Uniform noise, clamped away from 0 and 1 so that neither log below sees a zero.
    noise = torch.rand_like(log_p).clamp_(min=1e-6, max=1-1e-6)
    # Perturb the logits with logistic noise, log(u) - log(1-u).
    relaxed_logits = log_p - log_not_p
    relaxed_logits = relaxed_logits + noise.log()
    relaxed_logits = relaxed_logits - (1-noise).log()
    return torch.sigmoid(relaxed_logits/temperature)

def get_masked_logits(classifiers: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Occlude ``x`` according to ``mask`` and run the classifier on the occluded input.

    Args:
        classifiers: Callable taking a single tensor, the occluded input concatenated with
            ``1 - mask`` along the last dimension, and returning logits.
        x: Input tensor.
        mask: Erasure mask with values in [0, 1]; 1 means the feature is fully replaced by noise.
            It must have the same shape as the occluded input so that the two can be concatenated.

    Returns:
        Whatever ``classifiers`` returns for the concatenated input.
    """
    # Erased features are replaced by standard Gaussian noise rather than zeros.
    masked_input = (1-mask)*x + mask*torch.randn_like(x)
    # The classifier also sees which features were kept (1 - mask), doubling the last dimension.
    logits = classifiers(torch.cat([masked_input, 1-mask], dim=-1))
    return logits

In [None]:
# Adversarial training loop. The classifiers (theta) are trained to minimize the cross-entropy on
#  randomly occluded traces, while the erasure logits (eta_tau) are trained to maximize that same
#  loss. Both players are updated simultaneously at every step. A validation pass runs each time
#  the inner loop over train_dataloader finishes or is cut short by the step limit.
training_curves = defaultdict(list)
step_idx = 0
progress_bar = tqdm(total=training_steps)
while step_idx < training_steps:
    for x, y in train_dataloader:
        batch_size, *_ = x.shape
        gamma_tau = get_erasure_prob_logits(eta_tau, gamma_budget)
        mask = sample_from_concrete_distribution(batch_size, gamma_tau)
        classifier_logits = get_masked_logits(classifiers, x, mask)
        theta_loss = nn.functional.cross_entropy(classifier_logits, y)
        eta_tau_loss = -theta_loss # the erasure distribution plays against the classifiers
        theta_optimizer.zero_grad()
        etat_optimizer.zero_grad()
        # Each backward call accumulates gradients only into its own player's tensors. The graph
        #  must be retained after the first call because the second one walks it again.
        theta_loss.backward(retain_graph=True, inputs=list(classifiers.parameters()))
        # Pass the leaf tensor itself: list(eta_tau) would split it into freshly-created
        #  per-element tensors, none of which is the tensor that etat_optimizer updates.
        eta_tau_loss.backward(inputs=[eta_tau])
        theta_optimizer.step()
        etat_optimizer.step()
        with torch.no_grad():
            training_curves['train_loss'].append(theta_loss.item())
            training_curves['classifiers_rank'].append(get_rank(classifier_logits, y).mean().item()) # Lower === more-accurate.
        step_idx += 1
        progress_bar.update(1)
        if step_idx >= training_steps:
            break
    with torch.no_grad():
        val_loss, val_rank = [], []
        for x, y in val_dataloader: # a DataLoader is iterated directly, it is not callable
            batch_size, *_ = x.shape
            gamma_tau = get_erasure_prob_logits(eta_tau, gamma_budget)
            mask = sample_from_concrete_distribution(batch_size, gamma_tau)
            logits = get_masked_logits(classifiers, x, mask)
            # Score the logits of this validation batch, not the last training batch's logits.
            loss = nn.functional.cross_entropy(logits, y)
            val_loss.append(loss.item())
            val_rank.append(get_rank(logits, y).mean().item())
        training_curves['val_loss'].append(np.mean(val_loss))
        training_curves['val_rank'].append(np.mean(val_rank))
progress_bar.close()