In [None]:
import numpy as np
import matplotlib
matplotlib.use('nbAgg')
from matplotlib import pyplot as plt
from tqdm import tqdm

from datasets.aes_rd import AES_RD
from leakage_detectors.non_learning import get_trace_means, get_sum_of_differences, get_signal_to_noise_ratio

In [None]:
# Build the AES_RD dataset with its default constructor arguments; every cell
# below reads this `dataset` object.
dataset = AES_RD()
# Print the dataset's string representation as a quick sanity check.
print(dataset)

In [None]:
# Visual sanity check: draw the first 15 traces of the dataset on a 3x5 grid,
# one trace per panel.
fig, axes = plt.subplots(3, 5, figsize=(20, 12))
for idx, ax in enumerate(axes.ravel()):
    # Each dataset item is a pair whose first element is the trace; drop its
    # singleton dimensions before plotting. `trace` deliberately stays bound
    # (squeezed) after the loop because a later cell reads it.
    trace = dataset[idx][0].squeeze()
    ax.plot(trace)
plt.show()

In [None]:
# Per-time-sample mean and standard deviation over the whole dataset, computed
# in two passes with running averages so the traces never have to be held in
# memory at once.
#
# The accumulators are shaped from the first dataset item rather than from a
# `trace` variable left behind by an earlier cell, so this cell does not depend
# on which cell happened to run before it.
first_trace, _ = dataset[0]
mean = np.zeros_like(first_trace.squeeze())
var = np.zeros_like(mean)

# Pass 1: running mean. After seeing idx+1 traces,
#   mean_new = trace/(idx+1) + mean_old*idx/(idx+1).
for idx, (trace, _) in enumerate(tqdm(dataset)):
    mean = (1/(idx+1))*trace.squeeze() + (idx/(idx+1))*mean

# Pass 2: running average of the squared deviation from the final mean, i.e.
# the population variance (divides by N, not N-1).
for idx, (trace, _) in enumerate(tqdm(dataset)):
    var = (1/(idx+1))*(trace.squeeze() - mean)**2 + (idx/(idx+1))*var

stdev = np.sqrt(var)

In [None]:
# Show the per-sample mean (left) and standard deviation (right) side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = ((mean, 'Mean'), (stdev, 'Std. deviation'))
for ax, (series, ylabel) in zip(axes, panels):
    ax.plot(series)
    ax.set_xlabel('Time')
    ax.set_ylabel(ylabel)
plt.show()

In [None]:
# Run the non-learning leakage detectors on the dataset.
# `trace_means` is computed once and handed to both detectors through their
# `trace_means` keyword (presumably so neither has to recompute it -- confirm
# in leakage_detectors.non_learning).
trace_means = get_trace_means(dataset)
# Sum-of-differences result; plotted against the sample axis in the next cell.
sod_mask = get_sum_of_differences(dataset, trace_means=trace_means)
# Signal-to-noise-ratio result; plotted alongside the sum-of-differences curve.
snr_mask = get_signal_to_noise_ratio(dataset, trace_means=trace_means)

In [None]:
# Plot the two leakage measures side by side: sum of differences (left) and
# signal-to-noise ratio (right). The masks are squeezed to drop singleton
# dimensions before plotting.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
axes[0].plot(sod_mask.squeeze())
axes[1].plot(snr_mask.squeeze())
# Label the panels like the mean/std-deviation figure so the two curves can be
# told apart when the notebook is skimmed.
axes[0].set_xlabel('Time')
axes[0].set_ylabel('Sum of differences')
axes[1].set_xlabel('Time')
axes[1].set_ylabel('Signal-to-noise ratio')
plt.show()

In [None]:
# Per-target SNR sweep: for every (target variable, byte) combination, select
# that target on the dataset, recompute the signal-to-noise ratio and draw it
# in its own panel. Rows are bytes, columns are target variables.
target_variables = ['subbytes']
target_bytes = list(range(16))

# squeeze=False makes plt.subplots always return a 2-D array of Axes, so the
# [row, column] indexing below works for every grid size -- including 1x1,
# where the default would return a bare Axes object that cannot be indexed.
fig, axes = plt.subplots(
    len(target_bytes), len(target_variables),
    figsize=(4*len(target_variables), 4*len(target_bytes)),
    squeeze=False
)

# NOTE(review): select_target appears to reconfigure `dataset` in place, so
# after this cell the dataset stays on the last (variable, byte) pair and
# `snr_mask` holds the last byte's result -- confirm before re-running earlier
# cells that read either of them.
# The context manager closes the progress bar once the sweep finishes.
with tqdm(total=len(target_bytes)*len(target_variables)) as progress_bar:
    for tb_idx, target_byte in enumerate(target_bytes):
        axes_r = axes[tb_idx, :]
        for tv_idx, target_variable in enumerate(target_variables):
            dataset.select_target(variables=target_variable, bytes=target_byte)
            ax = axes_r[tv_idx]
            snr_mask = get_signal_to_noise_ratio(dataset)
            # Squeeze before plotting, as the summary SNR plot does: a mask
            # with a leading singleton dimension would otherwise be drawn as
            # many one-point lines (nothing visible) instead of one curve.
            ax.plot(snr_mask.squeeze())
            ax.set_title(f'{target_variable}(byte={target_byte})')
            progress_bar.update(1)
plt.show()