In [10]:
import os

import numpy as np
from tqdm import tqdm

from src.data.loaders.ascad import ASCADData
from src.dlla.berg import make_mlp
from src.dlla.hw import prepare_traces_dl, dlla_known_p
from src.pollution.gaussian_noise import gaussian_noise
from src.tools.cache import cache_np
from src.trace_set.database import Database
from src.trace_set.pollution import Pollution, PollutionType
from src.trace_set.set_hw import TraceSetHW
from src.trace_set.window import get_windows, extract_traces

In [11]:
# Source [EDIT]
DB = Database.aisy

# Limit number of traces [EDIT]
LIMIT_PROF = None
LIMIT_ATT = 1000

# Select targets and noise parameters

RAW_TRACES, WINDOW_JITTER_PARAMS, GAUSS_PARAMS, LIMIT_RAW = [None] * 4

if DB is Database.ascad_none or DB is Database.ascad:
    TARGET_ROUND = 0
    TARGET_BYTE = 0

    WINDOW_JITTER_PARAMS = np.arange(0, 205, 5)
    GAUSS_PARAMS = np.arange(0, 102, 2)

    RAW_TRACES = ASCADData.raw()['traces']
    # NOTE(review): X_CXT[:LIMIT_RAW] with -1 drops the last raw trace; presumably this aligns
    # the raw traces with the trace set — confirm.
    LIMIT_RAW = -1

if DB is Database.ascad:
    TARGET_BYTE = 2

    # NOTE: a float-step np.arange yields values such as 0.15000000000000002 (not exactly 0.15).
    WINDOW_JITTER_PARAMS = np.arange(0, 2.05, .05)
    GAUSS_PARAMS = np.arange(0, 5.1, .1)

if DB is Database.aisy:
    TARGET_ROUND = 4
    TARGET_BYTE = 0

    WINDOW_JITTER_PARAMS = np.arange(0, 460, 10)
    GAUSS_PARAMS = np.arange(0, 4100, 100)

    # LIMIT_RAW stays None, so X_CXT keeps all raw traces.
    RAW_TRACES = cache_np("aisy_traces")

# Fail fast for a database without a configuration above: RAW_TRACES would otherwise be None and
# only fail later, inside get_windows / extract_traces.
if RAW_TRACES is None:
    raise ValueError(f"Unsupported database ({DB.name}): no raw traces or pollution parameters configured.")

# Load the clean trace set and look up its windows in the raw traces. get_windows presumably
# locates SAMPLE_TRACE within RAW_TRACES, returning the sample window and the wider context
# window used for window jitter — confirm.
TRACE_SET = TraceSetHW(DB)
SAMPLE_TRACE = TRACE_SET.profile()[0][0]
WINDOW, WINDOW_CXT = get_windows(RAW_TRACES, SAMPLE_TRACE)

# Isolate context trace for window jitter.
# Gets cached, as this procedure takes some time (depending on disk read speed)
X_CXT = cache_np(f"{DB.name}_x_cxt", extract_traces, RAW_TRACES, WINDOW_CXT)[:LIMIT_RAW]

In [12]:
# Interleaved split of the context traces: indices 2, 5, 8, ... (i % 3 == 2) are attack traces,
# every other trace is a profiling trace.
PROFILING_MASK = np.arange(len(X_CXT)) % 3 != 2

In [13]:
X_PROF, Y_PROF = TRACE_SET.profile_states()
X_ATT, Y_ATT = TRACE_SET.attack_states()

# Split the raw context traces with the same interleaved PROFILING_MASK.
X_PROF_CXT = X_CXT[PROFILING_MASK]
X_ATT_CXT = X_CXT[~PROFILING_MASK]

# Desynchronized traces are cut from these context traces but stored with Y_PROF / Y_ATT as
# labels, so each context split must line up one-to-one with its labels.
assert len(X_PROF_CXT) == len(Y_PROF), \
    f"Profiling context traces ({len(X_PROF_CXT)}) do not match profiling labels ({len(Y_PROF)})."
assert len(X_ATT_CXT) == len(Y_ATT), \
    f"Attack context traces ({len(X_ATT_CXT)}) do not match attack labels ({len(Y_ATT)})."

In [14]:
def verify(db: Database, pollution: Pollution):
    """
    Assess leakage of a polluted trace set and print the resulting p-value.

    Opens the trace set for ``db`` with ``pollution`` applied, limited to
    (LIMIT_PROF, LIMIT_ATT) traces. It prepares the profiling and attack traces with
    ``prepare_traces_dl`` and builds an MLP on the profiling traces with ``make_mlp``. It then
    computes the p-value with ``dlla_known_p`` on the attack traces.

    :param db: Database the polluted trace set belongs to.
    :param pollution: Pollution (type and parameter) identifying the trace set.
    :return: The p-value reported by ``dlla_known_p``. Previously None was returned; callers
        that ignore the result are unaffected.
    """
    trace_set = TraceSetHW(db, pollution, (LIMIT_PROF, LIMIT_ATT))
    x_prof, y_prof, x_att, y_att = prepare_traces_dl(*trace_set.profile(), *trace_set.attack())
    model = make_mlp(x_prof, y_prof, progress=False)
    p_value = dlla_known_p(model, x_att, y_att)

    print(f"Pollution {pollution.type} ({pollution.parameter}): p-value ({p_value}).")

    return p_value

In [15]:
def desync(traces: np.ndarray, window: (int, int), sigma: float):
    """
    Simulate window jitter by cutting a randomly shifted window out of every trace.

    For each trace, a shift is drawn from a zero-mean normal distribution with standard
    deviation ``sigma`` and rounded to the nearest integer. Output row ``ix`` is
    ``traces[ix, start + shift:end + shift]``. Shifts come from NumPy's global random state,
    so results are reproducible only if ``np.random.seed`` is set beforehand.

    :param traces: 2-D array, one (context) trace per row.
    :param window: ``(start, end)`` sample indices of the unshifted window within each row.
    :param sigma: Standard deviation of the shift, in samples.
    :return: Array of shape ``(len(traces), end - start)`` with the dtype of ``traces``.
    :raises ValueError: If a shift is at least the window length (the shifted window would no
        longer overlap the original one), or if a shifted window falls outside the traces.
    """
    start, end = window
    num_traces = len(traces)
    num_sample_points = end - start

    # Nothing to shift; also avoids np.max on an empty array below.
    if num_traces == 0:
        return np.empty((0, num_sample_points), dtype=traces.dtype)

    shifts = np.round(np.random.normal(scale=sigma, size=num_traces)).astype(int)

    if np.max(np.abs(shifts)) >= num_sample_points:
        raise ValueError(f"Window jitter parameter ({sigma}) too high. PoI is not always within the resulting traces.")

    offsets = shifts + start

    # A negative offset, or one running past the trace end, would produce a short slice and fail
    # with an opaque broadcasting error halfway through the loop; reject it up front instead.
    trace_length = traces.shape[1]
    if offsets.min() < 0 or offsets.max() + num_sample_points > trace_length:
        raise ValueError(
            f"Window jitter parameter ({sigma}) too high. "
            f"Shifted window exceeds the traces (length {trace_length})."
        )

    # Every row is overwritten below, so no initialization is needed.
    res = np.empty((num_traces, num_sample_points), dtype=traces.dtype)

    for ix in tqdm(range(num_traces), f"Trace desynchronization, sigma={sigma}"):
        offset = offsets[ix]
        res[ix] = traces[ix, offset:offset + num_sample_points]

    return res

def apply_desync(db, x_prof_cxt, y_prof, x_att_cxt, y_att, window: (int, int), params: list):
    """
    Build one desynchronized trace set per window-jitter parameter, skipping existing ones.

    For each parameter, the profiling and attack context traces are desynchronized with
    ``desync`` (profiling first, then attack). The result is stored together with the given
    labels, and the new trace set is checked with ``verify``.
    """
    for sigma in params:
        pollution = Pollution(PollutionType.desync, sigma)
        target = TraceSetHW(db, pollution, (LIMIT_PROF, LIMIT_ATT))

        # Already generated on an earlier run: leave it untouched.
        if os.path.exists(target.path):
            continue

        # Arguments are evaluated left to right: profiling traces are shifted before attack traces.
        target.create(desync(x_prof_cxt, window, sigma), y_prof, desync(x_att_cxt, window, sigma), y_att)
        verify(db, pollution)

In [16]:
# Generate every missing window-jitter trace set for the configured database.
apply_desync(
    db=DB,
    x_prof_cxt=X_PROF_CXT,
    y_prof=Y_PROF,
    x_att_cxt=X_ATT_CXT,
    y_att=Y_ATT,
    window=WINDOW,
    params=WINDOW_JITTER_PARAMS,
)

In [17]:
def apply_gauss(db, params: list):
    """
    Build one Gaussian-noise trace set per noise parameter, skipping existing ones.

    For each parameter that has no trace set on disk yet, the clean trace set of ``db`` is
    loaded, and ``gaussian_noise`` is applied to its profiling and attack states. The result is
    stored with the clean labels and checked with ``verify``.

    :param db: Database whose clean trace set is polluted.
    :param params: Noise parameters, one output trace set each.
    """
    for param in params:
        pollution = Pollution(PollutionType.gauss, param)
        out = TraceSetHW(db, pollution, (LIMIT_PROF, LIMIT_ATT))

        # Already generated: skip without reading the clean trace set from disk.
        if os.path.exists(out.path):
            continue

        # Reloaded per generated parameter so each noise level starts from freshly read traces,
        # even if gaussian_noise were to modify its input (not visible here).
        default = TraceSetHW(db)
        x_prof, y_prof = default.profile_states()
        x_att, y_att = default.attack_states()

        xn = gaussian_noise(x_prof, param)
        xn_att = gaussian_noise(x_att, param)

        out.create(xn, y_prof, xn_att, y_att)

        verify(db, pollution)

# Generate every missing Gaussian-noise trace set for the configured database.
apply_gauss(db=DB, params=GAUSS_PARAMS)