In [1]:
"""Evaluate a model's testing keyrank at N traces and how many traces are required to achieve 99% accuracy"""

import data
import sys
import numpy as np
import torch
import json

import training

import keyrank_rs

# Reproducibility settings, applied once at notebook start.
# Restrict cuDNN to deterministic convolution algorithms.
torch.backends.cudnn.deterministic = True
# Disable the cuDNN auto-tuner, which may select different algorithms between runs.
torch.backends.cudnn.benchmark = False
# Make PyTorch use deterministic implementations where available and raise a
# RuntimeError for operations that only have a nondeterministic implementation.
torch.use_deterministic_algorithms(True)

In [23]:
def get_best_epoch(model_name) -> int:
    """Return the 0-based position of the lowest score in a model's eval file.

    Reads ``models/eval/{model_name}.txt``, discards its first line (the
    "N traces" header) and parses every remaining line as a float.

    Args:
        model_name: File stem of the evaluation file under ``models/eval``.

    Returns:
        Index (as a Python ``int``) of the first minimum among the parsed
        values; index 0 corresponds to the first line after the header.

    Raises:
        ValueError: If a line cannot be parsed as a float, or if there are no
            values after the header line.
    """
    eval_path = f"models/eval/{model_name}.txt"
    with open(eval_path, 'r') as handle:
        rows = handle.readlines()

    # rows[0] is the "N traces" header; everything after it is one score per line.
    scores = np.array([float(row.strip()) for row in rows[1:]])
    return scores.argmin().item()

def metadata_best_epoch(model_name) -> int:
    """Return the 0-based position of the lowest score stored in a model's metadata.

    Loads ``models/{model_name}/metadata.json`` and takes the sequence stored
    at index 1 of its ``"scores"`` entry (presumably the validation scores —
    confirm against the code that writes the metadata).

    Args:
        model_name: Name of the model directory under ``models``.

    Returns:
        Index (as a Python ``int``) of the first minimum of that sequence.
    """
    metadata_path = f"models/{model_name}/metadata.json"
    with open(metadata_path) as handle:
        all_scores = json.load(handle)["scores"]

    val_scores = np.array(all_scores[1])
    return val_scores.argmin().item()


# Prefer the GPU, but fall back to the CPU so that the `.to(device)` calls in
# later cells do not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collects the evaluation results produced by later cells.
testing_data = {}


# Experiment configuration. These values are passed to the data loader below
# and are also interpolated into the checkpoint directory names in later cells.
PREDICTION_TARGET = "combo"
TARGET_BYTE_IDX = 1
TRACE_START = 400
TRACE_END = 1500
SEED = 777

# Only the third returned loader is used in this notebook.
# NOTE(review): the meaning of the positional 200 is defined by
# data.get_dataloaders — presumably a batch size; confirm against that module.
_, _, test_loader = data.get_dataloaders(
    200,
    PREDICTION_TARGET,
    TARGET_BYTE_IDX,
    TRACE_START,
    TRACE_END,
    SEED,
)

In [24]:
from models import *

IMPL = "fixslice"
ARCH = "zhang"


LOSS_WEIGHT = 0.5

N_TRACES = 10


testing_data["n_traces"] = N_TRACES

# "combo" evaluates two single-target models together; every other prediction
# target loads exactly one model named after the target itself.
targets_to_load = ["sbox", "sbox2"] if PREDICTION_TARGET == "combo" else [PREDICTION_TARGET]

models = []
for target in targets_to_load:
    model_name = f"{IMPL}-{target}-byte{TARGET_BYTE_IDX}-{ARCH}-{TRACE_START}_{TRACE_END}-s{SEED}"
    epoch = metadata_best_epoch(model_name)
    model_path = f"models/{model_name}/epoch{epoch}.pt"
    print(model_path)

    # The checkpoint is a fully pickled module, so weights_only must be False
    # (newer PyTorch releases default to True and refuse to unpickle it).
    # SECURITY: unpickling executes arbitrary code — only load trusted checkpoints.
    # map_location places the model on the same device the traces are moved to.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()

    models.append(model)
# After the loop `model` refers to the last loaded model (the only one for
# non-combo targets), matching what later cells expect.
#testing_data["best_epoch"] = epoch




with torch.no_grad():
    if PREDICTION_TARGET == "key":
        testing_keyrank = training.mean_keyrank(model, test_loader, N_TRACES)
    elif PREDICTION_TARGET == "sbox":
        testing_keyrank = training.mean_sbox_rank(model, test_loader, N_TRACES)
    elif PREDICTION_TARGET == "sbox2":
        testing_keyrank = training.mean_sbox_rank(model, test_loader, N_TRACES, plaintext=1)
    elif PREDICTION_TARGET in ["2sbox", "2sbox*", "2sbox..."]:
        testing_keyrank = training.mean_2sbox_rank(model, test_loader, N_TRACES)
    elif PREDICTION_TARGET == "combo":
        # No mean key rank is computed for the two-model combination.
        testing_keyrank = None
    else:
        # Fail loudly instead of leaving testing_keyrank undefined or stale.
        raise ValueError(f"Unsupported PREDICTION_TARGET: {PREDICTION_TARGET!r}")

testing_data["testing_keyrank"] = testing_keyrank
testing_keyrank

models/fixslice-sbox-byte1-zhang-400_1500-s777/epoch32.pt


  model = torch.load(model_path)


models/fixslice-sbox2-byte1-zhang-400_1500-s777/epoch20.pt


In [25]:
from tqdm import tqdm

def traces_for_99acc(model, test_loader, block=0):
    """Return how many traces are needed before at least 99% of keys rank first.

    For each item yielded by ``test_loader`` (treated as one key), the model's
    S-box scores are converted to per-trace key scores, turned into
    log-probabilities and accumulated trace by trace. After each additional
    trace the highest-scoring key candidate is compared with the true key.

    Args:
        model: Callable mapping the trace tensor to S-box scores.
        test_loader: Iterable yielding ``(traces, plaintexts, true_key)``.
        block: Column of ``plaintexts`` to pair with the model's scores.

    Returns:
        The 1-based trace count of the first row in which at least 495 keys
        are recovered, or -1 if no row reaches that count.

    Note:
        The success matrix is fixed at 500 x 500 and the threshold at 495
        (99% of 500), so at most 500 keys and 500 traces per key are supported.
    """
    # maps (N traces, key) to success/failure
    success_matrix = np.zeros((500, 500))

    for key_idx, (traces, plaintexts, true_key) in tqdm(enumerate(test_loader), unit='key'):

        # All traces for this key
        traces = traces.to(device).squeeze()

        # Plaintext bytes of the selected block, one per trace
        plaintexts = plaintexts.long().detach().cpu().numpy().squeeze()[:, block]

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            sbox_scores = model(traces)

        numpy_scores = sbox_scores.detach().cpu().numpy()

        numpy_keyscores = keyrank_rs.sbox_scores_to_keyscores_parallel(plaintexts, numpy_scores)

        # Log-probabilities over the key candidates of every trace. softmax
        # needs an explicit dim; the last axis holds the key candidates.
        log_probs = torch.Tensor(numpy_keyscores).softmax(dim=-1).log()

        keyscore_acc = torch.zeros(256)

        for n_traces, log_prob in enumerate(log_probs):

            keyscore_acc += log_prob
            rank1_value = keyscore_acc.argmax().item()
            success_matrix[n_traces, key_idx] = 1. if rank1_value == true_key else 0.

    # Number of keys recovered after 1, 2, ... traces
    N_traces_successes = success_matrix.sum(axis=1)

    print(N_traces_successes)

    for idx, n_hits in enumerate(N_traces_successes):
        if n_hits >= 495.0: # 99%
            return idx+1

    return -1

In [26]:
def traces_for_99acc_2sbox(model, test_loader):
    """Return how many traces a two-output model needs for 99% key recovery.

    The model's two score sets are each paired with one plaintext column,
    converted to per-trace key scores and then to log-probabilities. The two
    log-probability tables are added and accumulated trace by trace; after
    each trace the highest-scoring key candidate is compared with the true key.

    Args:
        model: Callable mapping the trace tensor to S-box scores, either as a
            list/tuple of tensors or as a single tensor whose first axis
            indexes the outputs.
        test_loader: Iterable yielding ``(traces, plaintexts, true_key)``.

    Returns:
        The 1-based trace count of the first row in which at least 495 keys
        are recovered, or -1 if no row reaches that count.

    Note:
        The success matrix is fixed at 500 x 500 and the threshold at 495
        (99% of 500), so at most 500 keys and 500 traces per key are supported.
    """
    # maps (N traces, key) to success/failure
    success_matrix = np.zeros((500, 500))

    for key_idx, (traces, plaintexts, true_key) in tqdm(enumerate(test_loader), unit='key'):

        # All traces for this key
        traces = traces.to(device).squeeze()

        # Transpose so that iterating yields one plaintext column per model output.
        plaintexts = plaintexts.long().detach().cpu().numpy().squeeze().transpose((1, 0))

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            sbox_scores = model(traces)

        # 2PT model outputs a sequence of 2 tensors
        if isinstance(sbox_scores, (list, tuple)):
            sbox_scores = torch.stack(list(sbox_scores)).detach().cpu().numpy()
        else:
            sbox_scores = sbox_scores.detach().cpu().numpy()

        # One log-probability table per model output. The tensor is built once
        # from the NumPy result; it is not re-wrapped after moving to `device`.
        log_probs_both = []
        for scores, pt in zip(sbox_scores, plaintexts):
            numpy_keyscores = keyrank_rs.sbox_scores_to_keyscores_parallel(pt, scores)
            log_probs_both.append(
                torch.Tensor(numpy_keyscores).to(device).softmax(dim=1).log()
            )

        # Combine both outputs and move to the CPU once for accumulation.
        keyscores_combo = (log_probs_both[0] + log_probs_both[1]).cpu()

        keyscore_acc = torch.zeros(256)

        for n_traces, keyscore in enumerate(keyscores_combo):

            keyscore_acc += keyscore
            rank1_value = keyscore_acc.argmax().item()
            success_matrix[n_traces, key_idx] = 1. if rank1_value == true_key else 0.

    # Number of keys recovered after 1, 2, ... traces
    N_traces_successes = success_matrix.sum(axis=1)

    print(N_traces_successes)

    for idx, n_hits in enumerate(N_traces_successes):
        if n_hits >= 495.0: # 99%
            return idx+1

    return -1


In [30]:
from tqdm import tqdm

def traces_for_99acc_combo(sbox1_model, sbox2_model, test_loader):
    """Return how many traces two single-output models jointly need for 99% key recovery.

    ``sbox1_model`` is paired with plaintext column 0 and ``sbox2_model`` with
    plaintext column 1. Each model's scores are converted to per-trace key
    scores and then to log-probabilities; the two tables are summed and
    accumulated trace by trace, and after each trace the highest-scoring key
    candidate is compared with the true key.

    Args:
        sbox1_model: Callable mapping the trace tensor to scores, used with
            plaintext column 0.
        sbox2_model: Callable mapping the trace tensor to scores, used with
            plaintext column 1.
        test_loader: Iterable yielding ``(traces, plaintexts, true_key)``.

    Returns:
        The 1-based trace count of the first row in which at least 495 keys
        are recovered, or -1 if no row reaches that count.
    """
    # success_matrix[n, k] is 1.0 when key k is ranked first after n + 1 traces
    success_matrix = np.zeros((500,500))

    for key_idx, batch in tqdm(enumerate(test_loader), unit='key'):
        traces, plaintexts, true_key = batch

        traces = traces.to(device).squeeze()
        pt_array = plaintexts.to(device).long().detach().cpu().numpy().squeeze()
        pt_first, pt_second = pt_array[:, 0], pt_array[:, 1]

        # Run both models on the same traces, first model first.
        raw_outputs = [net(traces) for net in (sbox1_model, sbox2_model)]
        first_scores, second_scores = (out.detach().cpu().numpy() for out in raw_outputs)

        first_keyscores = keyrank_rs.sbox_scores_to_keyscores_parallel(pt_first, first_scores)
        second_keyscores = keyrank_rs.sbox_scores_to_keyscores_parallel(pt_second, second_scores)

        # Per-trace log-probabilities of each model, summed into one table.
        log_prob_tables = [
            torch.Tensor(table).to(device).softmax(1).log()
            for table in (first_keyscores, second_keyscores)
        ]
        combined = (log_prob_tables[0] + log_prob_tables[1]).cpu()

        running_total = torch.zeros(256)
        for trace_idx, trace_log_probs in enumerate(combined):
            running_total += trace_log_probs
            top_candidate = running_total.argmax().item()
            success_matrix[trace_idx, key_idx] = 1. if top_candidate == true_key else 0.

    N_traces_successes = success_matrix.sum(axis=1)

    print(N_traces_successes)

    # First row whose hit count reaches 495 (99% of 500), reported 1-based.
    return next(
        (row + 1 for row, n_hits in enumerate(N_traces_successes) if n_hits >= 495.0),
        -1,
    )

In [31]:
# Pick the traces-to-99%-accuracy routine that matches the prediction target.
if PREDICTION_TARGET == "sbox":
    output = traces_for_99acc(model, test_loader, block=0)
elif PREDICTION_TARGET == "sbox2":
    output = traces_for_99acc(model, test_loader, block=1)
# Same set of two-output targets that the key-rank evaluation cell accepts.
elif PREDICTION_TARGET in ["2sbox", "2sbox*", "2sbox..."]:
    output = traces_for_99acc_2sbox(model, test_loader)
elif PREDICTION_TARGET == "combo":
    output = traces_for_99acc_combo(models[0], models[1], test_loader)
else:
    # Fail loudly instead of leaving `output` undefined or showing a stale
    # value from a previous run.
    raise ValueError(
        f"No traces-for-99%-accuracy routine for PREDICTION_TARGET: {PREDICTION_TARGET!r}"
    )

output

500key [00:49, 10.04key/s]

[ 47. 136. 246. 347. 398. 442. 470. 477. 490. 494. 497. 499. 500. 500.
 499. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500. 500.
 500. 




11

In [29]:
# Display the current value of `output`, which is assigned by the
# target-dispatch cell.
output