In [None]:
# COLAB
# Mounts Google Drive at /content/drive so later cells can reach files stored there.
# NOTE(review): the guard is hard-coded to True, so this cell requires the google.colab
# package; flip it to False when running outside Colab.
if (True):
    from google.colab import drive
    drive.mount('/content/drive')

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Importing the submodule also binds the top-level name `torch` (the package itself is
# otherwise only available as `t` in this notebook).
import torch._dynamo
# suppress_errors: per the PyTorch docs, compilation failures fall back to eager
# execution instead of raising. reset() drops previously cached compilation state.
torch._dynamo.config.suppress_errors = True
torch._dynamo.reset()


# Allow reduced-precision (TF32) float32 matmul / cuDNN kernels for speed.
# NOTE(review): this setting is global, so it also applies to float32 matmuls used for
# metric computation later in the notebook - confirm the precision loss is acceptable.
t.set_float32_matmul_precision('high')
t.backends.cuda.matmul.allow_tf32 = True
t.backends.cudnn.allow_tf32 = True


import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import pandas as pd

import sklearn as sk

import h5py
import time
import os
import json
import gc


from datetime import datetime

In [None]:
def HDF5_viewer(content, dir, N_object):
    """Load one group of an HDF5 file and unpack its metadata into a dense array.

    Args:
        content: name of the top-level group to read inside the file.
        dir: path of the HDF5 file (opened read-only).
        N_object: number of leading metadata fields to unpack from every record.

    Returns:
        (x, y, m): the 'traces' dataset, the 'labels' dataset and an int16 array
        of shape (N_object, n_records, field_len) with m[i, k, j] equal to element j
        of field i of metadata record k. field_len is taken from the first field of
        the first record, so all unpacked fields are assumed to have that length.
    """
    with h5py.File(dir, 'r') as h5_file:
        group = h5_file[content]
        labels = np.array(group['labels'][:])
        traces = np.array(group['traces'][:])
        records = group['metadata'][:]

    n_records = len(records)
    field_len = len(records[0][0])

    meta = np.zeros(shape = (N_object, n_records, field_len), dtype = np.int16)

    # Transpose "record -> field -> element" into "field -> record -> element".
    for rec_idx, record in enumerate(records):
        for obj_idx in range(N_object):
            field = record[obj_idx]
            for pos in range(field_len):
                meta[obj_idx, rec_idx, pos] = field[pos]

    return traces, labels, meta

def Import_hdf5_data_SC(content, dir, k_pos = None, p_pos = None, byte_a = None):
    """Read traces and labels from an HDF5 group and append two metadata bytes.

    Args:
        content: name of the top-level group to read.
        dir: path of the HDF5 file (opened read-only).
        k_pos: index of the metadata field whose byte goes into the last column
            (presumably the key - confirm against the dataset layout).
        p_pos: index of the metadata field whose byte goes into the middle column
            (presumably the plaintext - confirm against the dataset layout).
        byte_a: position of the byte taken from both fields.

    Returns:
        [x, y_ext] where x is float32 with shape (N_data, 1, N) and y_ext is int64
        with shape (N_data, 3) holding the columns (label, p byte, k byte).
    """
    with h5py.File(dir, 'r') as h5_file:
        group = h5_file[content]
        raw_labels = group['labels'][:]
        raw_traces = group['traces'][:]
        records = group['metadata'][:]

    n_traces, n_samples = raw_traces.shape
    traces = np.array(raw_traces, dtype = np.float32).reshape(n_traces, 1, n_samples)
    labels = np.array(raw_labels, dtype = np.int64).reshape(n_traces, 1)

    p_col = np.zeros(shape = (n_traces, 1), dtype = np.int64)
    k_col = np.zeros(shape = (n_traces, 1), dtype = np.int64)

    for row in range(n_traces):
        record = records[row]
        p_col[row, 0] = record[p_pos][byte_a]
        k_col[row, 0] = record[k_pos][byte_a]

    labels_ext = np.concatenate((labels, p_col, k_col), axis = 1)

    return [traces, labels_ext]

def Preprocess_data_SC(data_p, data_a, size, augmentation = None, scaler = None):
    """Scale, subsample and shuffle the profiling set; shuffle and scale the attack set.

    Args:
        data_p: [x, y_ext] for profiling; x has shape (N_data, N_char, N) and y_ext
            has shape (N_data, N_elements) with the class label in column 0.
        data_a: [x, y_ext] for the attack set, same layout as data_p.
        size: train_size handed to StratifiedShuffleSplit; the value 1 keeps every
            profiling trace and skips the split.
        augmentation: optional callable applied to data_p before anything else; it
            must return a [x, y_ext] pair.
        scaler: 'Standard' or 'MinMax' to fit that scikit-learn scaler on the
            profiling traces and apply it to both sets; None for no scaling. Any
            other non-string object is used as an already fitted scaler: only its
            transform() is applied, and only to the attack traces (legacy behaviour).

    Returns:
        (data_p_prepro, data_a_prepro): two [x, y_ext] lists of torch tensors,
        float32 traces of shape (N_data, N_char, N) and int64 labels.

    Raises:
        ValueError: if scaler is a string other than 'Standard' or 'MinMax'.

    Notes:
        The inputs are no longer modified. The previous version rebound
        data_p[0] / data_a[0] to 2-D arrays and shuffled them in place, so running
        the cell a second time failed on the 3-way shape unpack. It also crashed
        with the default scaler=None when it reached scaler.transform().
    """
    N_elements = data_p[1].shape[1]

    if augmentation is not None:
        data_p = augmentation(data_p)

    # Work on local references only, so the caller's lists and arrays stay intact.
    x_p, y_ext_p = data_p[0], data_p[1]
    N_data, N_char, N = x_p.shape
    x_p = x_p.reshape(N_data, N_char*N)

    # Scaler: fitted on the full profiling set, before any subsampling.
    if isinstance(scaler, str):
        if scaler == 'Standard':
            from sklearn.preprocessing import StandardScaler
            scaler = StandardScaler()
        elif scaler == 'MinMax':
            from sklearn.preprocessing import MinMaxScaler
            scaler = MinMaxScaler()
        else:
            raise ValueError(f"Unknown scaler {scaler!r}; expected 'Standard', 'MinMax' or None")

        x_p = scaler.fit_transform(x_p)

    # StratifiedShuffleSplit: keep a class-balanced fraction of the profiling set.
    if size != 1:
        from sklearn.model_selection import StratifiedShuffleSplit
        sss = StratifiedShuffleSplit(
            n_splits = 1,
            train_size = size,
            random_state = 22
        )

        train_ind = next(sss.split(X = x_p, y = y_ext_p[:, 0]))[0]

        x_p = x_p[train_ind, :]
        y_ext_p = y_ext_p[train_ind, :]

        N_data = x_p.shape[0]

    # Fancy indexing returns shuffled copies; traces and labels share one permutation.
    ind_shuffle = np.random.permutation(N_data)
    x_p = x_p[ind_shuffle, :]
    y_ext_p = y_ext_p[ind_shuffle, :]

    x_p = t.from_numpy(x_p).to(t.float32).view(N_data, N_char, N)
    y_ext_p = t.from_numpy(y_ext_p).to(t.int64).view(N_data, N_elements)

    data_p_prepro = [x_p, y_ext_p]

    x_a, y_ext_a = data_a[0], data_a[1]
    N_data, N_char, N = x_a.shape
    x_a = x_a.reshape(N_data, N_char*N)

    ind_shuffle = np.random.permutation(N_data)
    x_a = x_a[ind_shuffle, :]
    y_ext_a = y_ext_a[ind_shuffle, :]

    # Only scale the attack traces when a scaler is actually available.
    if scaler is not None:
        x_a = scaler.transform(x_a)

    x_a = t.from_numpy(x_a).to(t.float32).view(N_data, N_char, N)
    y_ext_a = t.from_numpy(y_ext_a).to(t.int64).view(N_data, N_elements)

    data_a_prepro = [x_a, y_ext_a]

    return data_p_prepro, data_a_prepro

In [None]:
from matplotlib.pyplot import figure, show
from matplotlib.gridspec import GridSpec
from IPython.display import clear_output, display, update_display


class Metrics:
    def __init__(self,
        do_CDA = False,
        do_conv = False,
        do_fit = False,
        do_loss = False,
        do_acc = False,

        x_attack = None,
        y_attack = None,
        p_attack = None,
        k_attack = None,
        x_test = None,
        y_test = None,

        N_attacks = None,
        N_traces = None,
        Delta_traces = None,
        N_population = None,
        Batch = None,
        Epochs = None,

        device = None,
        display_id = None,
    ):
        """Set up the per-epoch metric histories and the live figure placeholder.

        Args:
            do_CDA, do_conv, do_fit, do_loss, do_acc: switches selecting which
                optional metrics are recorded and plotted. The GE history is
                always allocated.
            x_attack, y_attack, p_attack, k_attack: attack-set tensors; `.to(device)`
                is called on all four unconditionally, so they must be provided.
            x_test, y_test: test-set tensors, only moved/stored when do_acc or
                do_loss is set.
            N_attacks: number of simulated attacks per measurement.
            N_traces: number of traces drawn for each simulated attack.
            Delta_traces: number of trailing traces averaged by measure_conv.
            N_population: number of random batches used by measure_loss/measure_acc.
            Batch: size of each of those random batches.
            Epochs: total number of epochs (used for axis limits and file shapes).
            device: torch device holding the data and the metric tensors.
            display_id: IPython display id under which the figure is shown/updated.
        """
        self.device = device

        self.do_CDA = do_CDA
        self.do_conv = do_conv
        self.do_fit = do_fit
        self.do_loss = do_loss
        self.do_acc = do_acc

        self.x_attack = x_attack.to(device)
        self.y_attack = y_attack.to(device)
        self.p_attack = p_attack.to(device)
        self.k_attack = k_attack.to(device)

        # The test split is only needed by the loss and accuracy metrics.
        if (self.do_acc) or (self.do_loss):
            self.x_test = x_test.to(device)
            self.y_test = y_test.to(device)

        self.N_attacks = N_attacks
        self.N_traces = N_traces
        self.Delta_traces = Delta_traces
        self.N_population = N_population
        self.Batch = Batch
        self.Epochs = Epochs

        # Timing state consumed by plot(); t_ref is overwritten by the trainer.
        self.t_ref = 0
        self.t_round = 0
        self.epoch_ind = 0

        self.display_id = display_id

        # Show an empty figure once so later update_display() calls have a
        # display slot to replace; the figure itself is closed right away.
        self.fig = plt.figure(figsize = (15, 20))
        self.fig.add_subplot().axis('off')
        display(self.fig, display_id = self.display_id)
        plt.close(self.fig)

        # Histories start empty and grow by one row per measurement.
        if (self.do_CDA):   self.CDA = t.tensor([], dtype = t.float32, device = device)
        if (self.do_conv):   self.conv = t.tensor([], dtype = t.float32, device = device)
        if (self.do_loss):   self.loss = t.tensor([], dtype = t.float32, device = device)
        if (self.do_acc):   self.acc = t.tensor([], dtype = t.float32, device = device)
        if (self.do_fit):   self.fit = np.array([], dtype = np.float32)

        self.GE = t.tensor([], dtype = t.float32, device = device)
        # Scratch buffer: one row of per-trace ranks for every simulated attack.
        self.GE_aux = t.zeros(N_attacks, N_traces, dtype = t.float32, device = device)

        # 256-entry AES S-box lookup table, indexed by the byte value.
        self.AES_SBOX = t.tensor([
            0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
            0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
            0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
            0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
            0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
            0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
            0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
            0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
            0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
            0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
            0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
            0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
            0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
            0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
            0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
            0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16
        ], device = self.device)

    def close_fig(self):
        """Close the current matplotlib figure and drop the reference to it."""
        plt.close(self.fig)
        # After this, save_plot() cannot be used until plot() builds a new figure.
        self.fig = None

    def add_to_dataset(x_i, x_ip1):
        if (x_i.shape[0] == 0):
            x_i = x_ip1

        else: x_i = t.cat((
            x_i,
            x_ip1
        ), dim = 0)

        return x_i

    def add_to_dataset_numpy(x_i, x_ip1):
        if (x_i.shape[0] == 0):
            x_i = x_ip1

        else: x_i = np.concatenate((
            x_i,
            x_ip1
        ), axis = 0)

        return x_i

    def measure_GE(self, net):
        """Run N_attacks simulated attacks and append the key-rank curve to self.GE.

        For every attack, N_traces attack traces are drawn at random, the network
        output is turned into per-candidate log-scores, the scores are accumulated
        over traces and the rank of the correct key is recorded after each trace in
        self.GE_aux[attack]. The mean over attacks and the variance divided by
        N_attacks are then appended to self.GE as one (1, N_traces, 2) entry.

        Args:
            net: model called on the selected attack traces; it is switched to
                eval mode. Its output is indexed with S-box values, so it is
                presumably (batch, 256) - confirm against the model definition.
        """
        with t.no_grad():

            for attack in range(self.N_attacks):

                # Random subset of N_traces distinct attack traces, kept as a column.
                ind = t.randperm(self.x_attack.shape[0], dtype = t.int32, device = self.device)[:self.N_traces].view(self.N_traces, 1)

                net.eval()
                leak = net(self.x_attack[ind.squeeze(dim = 1),...]).softmax(dim = -1)
                # assumes p_attack / k_attack are 1-D so these become (N_traces, 1) - TODO confirm
                plain = self.p_attack[ind]
                key = self.k_attack[ind]

                trace = t.arange(self.N_traces, dtype = t.int32, device = self.device).view(self.N_traces, 1)
                k = t.arange(256, dtype = t.int32, device = self.device)

                # For each trace and each candidate k: log-probability of class SBOX[k XOR plain].
                # NOTE(review): 1e-40 is below the smallest normal float32; confirm it is not
                # flushed to zero on the target device.
                x = leak[trace, self.AES_SBOX[k^plain]].add(1e-40).log()
                # Multiplying by a lower-triangular matrix of ones is a running sum over
                # traces: row i holds the accumulated score of rows 0..i.
                A = t.ones(self.N_traces, self.N_traces, device = self.device).tril()
                x = A@x

                # (row, rank) pairs where the descending-sorted scores equal the score of the
                # correct key; ties yield several ranks for the same row.
                GE_nf = x.sort(dim = 1, descending = True).values.eq(x[trace, key]).nonzero(as_tuple = False)

                # Sentinel rows of -1 on both ends let the neighbour comparisons below run
                # without special-casing the first and last pair.
                trace_nf = t.arange(GE_nf.shape[0], dtype = t.int32, device = self.device)
                GE_nf = t.cat((
                    -1*t.ones(1, 2, dtype = t.int32, device = self.device),
                    GE_nf,
                    -1*t.ones(1, 2, dtype = t.int32, device = self.device)
                ), dim = 0)

                # First line: rank of the last pair of every row (row index differs from the next pair).
                # Second line: rank of the first pair of every row (differs from the previous pair).
                # Their average is the tie-averaged rank. Both assignments need exactly one
                # group per trace, i.e. every row must contain at least one match.
                self.GE_aux[attack, :]  = GE_nf[GE_nf[trace_nf + 1, 0].ne(GE_nf[trace_nf + 2, 0]).nonzero(as_tuple = True)[0] + 1, 1].float()
                self.GE_aux[attack, :] += GE_nf[GE_nf[trace_nf + 1, 0].ne(GE_nf[trace_nf + 0, 0]).nonzero(as_tuple = True)[0] + 1, 1].float()
                self.GE_aux[attack, :] /= 2

                t.cuda.empty_cache()

            # Column 0: mean rank over attacks; column 1: variance over attacks / N_attacks.
            GE_aux = t.cat((
                self.GE_aux.mean(dim = 0, dtype = t.float32).view(self.N_traces, 1),
                self.GE_aux.var(dim = 0).view(self.N_traces, 1)/self.N_attacks
            ), dim = 1).view(1, self.N_traces, 2).detach()

            self.GE = Metrics.add_to_dataset(self.GE, GE_aux)

            t.cuda.empty_cache()

    def measure_conv(self):
        with t.no_grad():
            conv_aux = self.GE_aux[:, -self.Delta_traces:].sum(dim = 1, dtype = t.float32)/self.Delta_traces

            conv_aux = t.cat((
                conv_aux.mean(dtype = t.float32).view(1),
                conv_aux.var().view(1)/self.N_attacks
            )).view(1, 2).detach()

            self.conv = Metrics.add_to_dataset(self.conv, conv_aux)

            t.cuda.empty_cache()

    def measure_CDA(self):
        with t.no_grad():
            CDA_aux_i = self.GE_aux.sum(dim = 1, dtype = t.float32)/(128*self.N_traces/2)

            CDA_aux = t.cat((
                CDA_aux_i.mean(dtype = t.float32).view(1),
                CDA_aux_i.var().view(1)/self.N_attacks
            )).view(1, 2).detach()

            self.CDA = Metrics.add_to_dataset(self.CDA, CDA_aux)

            t.cuda.empty_cache()

    def measure_fit(self):
        from scipy.optimize import curve_fit, Bounds

        def g(x, s, n, gamma, alpha):
            return n + s*np.exp(gamma - gamma*x**alpha)

        limits = Bounds(
            lb = np.array([0, 0, 0, 0], dtype = np.float32),
            ub = np.array([200, 200, 10, 10], dtype = np.float32)
        )

        GE_data = self.GE[-1, :, :].cpu().numpy().astype(np.float32)

        p0 = np.array([1, 128, 0.1, 0.2])
        x_SCIPY = np.arange(self.N_traces, dtype = np.float32) + 1
        y_SCIPY = GE_data[:, 0]
        se2_y_SCIPY = GE_data[:, 1]

        popt, pcov = curve_fit(
            f = g,
            xdata = x_SCIPY,
            ydata = y_SCIPY,
            p0 = p0,
            bounds = limits,
            sigma = np.sqrt(se2_y_SCIPY) + 10-8,
            absolute_sigma = True
        )

        popt[0] = popt[0] + popt[1]
        pcov[0,0] = pcov[0,0] + pcov[1,1] + 2*pcov[0,1]

        fit_aux = np.concatenate((
            popt.reshape(1, 4, 1),
            pcov.diagonal().reshape(1, 4, 1)
        ), axis = 2)

        self.fit = Metrics.add_to_dataset_numpy(self.fit, fit_aux)

    def measure_loss(self, net, loss, loss_train):
        with t.no_grad():
            net.eval()

            loss_aux =  t.zeros(2, self.N_population, dtype = t.float32, device = self.device)

            for i in range(self.N_population):
                ind_test = t.randperm(self.x_test.shape[0], dtype = t.int32, device = self.device)[:self.Batch].view(self.Batch)
                ind_attack = t.randperm(self.x_attack.shape[0], dtype = t.int32, device = self.device)[:self.Batch].view(self.Batch)

                loss_aux[0, i] = loss(net(self.x_test[ind_test, ...]), self.y_test[ind_test])
                loss_aux[1, i] = loss(net(self.x_attack[ind_attack, ...]), self.y_attack[ind_attack])

            t.cuda.empty_cache()

            loss_aux = t.cat((
                loss_train,  # 1, 2
                t.cat((
                    loss_aux.mean(dim = 1, dtype = t.float32).view(2, 1),
                    loss_aux.var(dim = 1).view(2, 1)/self.N_population
                ), dim = 1)
            ), dim = 0).view(1, 3, 2).detach()

            self.loss = Metrics.add_to_dataset(self.loss, loss_aux)

            t.cuda.empty_cache()

    def acc_func(self, net, x, y, Batch):

        with t.no_grad():
            net.eval()

            ind_train = t.randperm(x.shape[0], dtype = t.int32, device = x.device)[:Batch].view(Batch)

            pred_train = net(x[ind_train,...]).argmax(dim = -1)
            A = y[ind_train]
            acc_aux = pred_train.eq(y[ind_train]).float().mean(dtype = t.float32).detach()

            t.cuda.empty_cache()

            return acc_aux

    def measure_acc(self, net, acc_train):
        with t.no_grad():
            net.eval()

            acc_aux =  t.zeros(2, self.N_population, dtype = t.float32, device = self.device)


            for i in range(self.N_population):
                ind_test = t.randperm(self.x_test.shape[0], dtype = t.int32, device = self.device)[:self.Batch].view(self.Batch)
                ind_attack = t.randperm(self.x_attack.shape[0], dtype = t.int32, device = self.device)[:self.Batch].view(self.Batch)

                pred_test = net(self.x_test[ind_test,...]).argmax(dim = -1)
                pred_attack = net(self.x_attack[ind_attack,...]).argmax(dim = -1)

                acc_aux[0,i] = (pred_test.eq(self.y_test[ind_test])).float().mean(dtype = t.float32)
                acc_aux[1,i] = (pred_attack.eq(self.y_attack[ind_test])).float().mean(dtype = t.float32)

            t.cuda.empty_cache()

            acc_aux = t.cat((
                acc_train,  # 1, 2
                t.cat((
                    acc_aux.mean(dim = 1, dtype = t.float32).view(2, 1),
                    acc_aux.var(dim = 1).view(2, 1)/self.N_population
                ), dim = 1)
            ), dim = 0).view(1, 3, 2).detach()

            self.acc = Metrics.add_to_dataset(self.acc, acc_aux)

            t.cuda.empty_cache()

    def measure(self, net, loss = None, loss_train = None, acc_train = None):
        self.measure_GE(net)

        if (self.do_conv):  self.measure_conv()
        if (self.do_fit):   self.measure_fit()
        if (self.do_CDA):   self.measure_CDA()
        if (self.do_loss):  self.measure_loss(net, loss, loss_train)
        if (self.do_acc):   self.measure_acc(net, acc_train)

    def plot(self):
        t_now = time.time() - self.t_ref
        delta_time = t_now - self.t_round
        T = delta_time*(self.Epochs - self.epoch_ind - 1) + t_now
        self.t_round = t_now

        # clear_output(wait=True)

        self.fig = figure(figsize = (15, 20))
        gs = GridSpec(5, 2, height_ratios = [0.03, 0.2425, 0.2425, 0.2425, 0.2425])

        # Time bar
        ax = self.fig.add_subplot(gs[0,:])
        ax.barh(y = 0, width = round(t_now/60, 2), color = 'green')
        ax.set_xlabel('Time (min)')
        ax.set_yticks([])
        ax.set_xlim(0, round(T/60, 2))
        ax.text(
            x = 0.1,
            y = -0.1,
            s = f't_epoch={round(delta_time, 2)} s',
            fontsize = 12
        )

        # Loss graph
        if (self.do_loss):
            ax = self.fig.add_subplot(gs[1,0])
            y = self.loss[:, :, 0].cpu().numpy()
            desv = t.sqrt(self.loss[:, :, 1]).cpu().numpy()

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 0],
                label = 'Train',
                color = 'red'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 0] - desv[:, 0],
                y2 = y[:, 0] + desv[:, 0],
                color = 'red',
                alpha = 0.3
            )

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 1],
                label = 'Test',
                color = 'blue'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 1] - desv[:, 1],
                y2 = y[:, 1] + desv[:, 1],
                color = 'blue',
                alpha = 0.3
            )

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 2],
                label = 'Attack',
                color = 'green'
            )

            ax.fill_between(
                np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 2] - desv[:, 2],
                y2 = y[:, 2] + desv[:, 2],
                color = 'green',
                alpha = 0.3
            )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('loss')

            ax.legend()
            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

        # acc graph
        if (self.do_acc):
            ax = self.fig.add_subplot(gs[1,1])
            y = self.acc[:, :, 0].cpu().numpy()
            desv = t.sqrt(self.acc[:, :, 1]).cpu().numpy()

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 0],
                label = 'Train',
                color = 'red'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 0] - desv[:, 0],
                y2 = y[:, 0] + desv[:, 0],
                color = 'red',
                alpha = 0.3
            )

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 1],
                label = 'Test',
                color = 'blue'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 1] - desv[:, 1],
                y2 = y[:, 1] + desv[:, 1],
                color = 'blue',
                alpha = 0.3
            )

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y[:, 2],
                label = 'Attack',
                color = 'green'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y[:, 2] - desv[:, 2],
                y2 = y[:, 2] + desv[:, 2],
                color = 'green',
                alpha = 0.3
            )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('acc')

            ax.legend()
            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

        # CDA graph
        if (self.do_CDA):
            ax = self.fig.add_subplot(gs[2,0])
            y = self.CDA[:, 0].cpu().numpy()
            desv = t.sqrt(self.CDA[:, 1]).cpu().numpy()

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y,
                label = 'Attack',
                color = 'green'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y - desv,
                y2 = y + desv,
                color = 'green',
                alpha = 0.3
            )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('CDA')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

        # conv graph
        if (self.do_conv):
            ax = self.fig.add_subplot(gs[2,1])
            y = self.conv[:, 0].cpu().numpy()
            desv = t.sqrt(self.conv[:, 1]).cpu().numpy()

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                y,
                label = 'Attack',
                color = 'green'
            )

            ax.fill_between(
                x = np.arange(1, self.epoch_ind + 2),
                y1 = y - desv,
                y2 = y + desv,
                color = 'green',
                alpha = 0.3
            )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('conv')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

        # fit graph
        if (self.do_fit):
            ax = self.fig.add_subplot(gs[3,0])

            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                self.fit[:, 0, 0],
                label = 'Attack',
                color = 'green'
            )

            desv = np.sqrt(self.fit[:, 0, 1])

            # ax.fill_between(
            #     x = np.arange(1, self.epoch_ind + 2),
            #     y1 = self.fit[:, 0, 0] - desv,
            #     y2 = self.fit[:, 0, 0] + desv,
            #     color = 'green',
            #     alpha = 0.3
            # )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('s_fit')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

            ax = self.fig.add_subplot(gs[3,1])
            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                self.fit[:, 1, 0],
                label = 'Attack',
                color = 'green'
            )

            desv = np.sqrt(self.fit[:, 1, 1])

            # ax.fill_between(
            #     x = np.arange(1, self.epoch_ind + 2),
            #     y1 = self.fit[:, 1, 0] - desv,
            #     y2 = self.fit[:, 1, 0] + desv,
            #     color = 'green',
            #     alpha = 0.3
            # )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('n_fit')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

            ax = self.fig.add_subplot(gs[4,0])
            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                self.fit[:, 2, 0],
                label = 'Attack',
                color = 'green'
            )

            desv = np.sqrt(self.fit[:, 2, 1])

            # ax.fill_between(
            #     x = np.arange(1, self.epoch_ind + 2),
            #     y1 = self.fit[:, 2, 0] - desv,
            #     y2 = self.fit[:, 2, 0] + desv,
            #     color = 'green',
            #     alpha = 0.3
            # )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('gamma_fit')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

            ax = self.fig.add_subplot(gs[4,1])
            ax.plot(
                np.arange(1, self.epoch_ind + 2),
                self.fit[:, 3, 0],
                label = 'Attack',
                color = 'green'
            )

            desv = np.sqrt(self.fit[:, 3, 1])

            # ax.fill_between(
            #     x = np.arange(1, self.epoch_ind + 2),
            #     y1 = self.fit[:, 3, 0] - desv,
            #     y2 = self.fit[:, 3, 0] + desv,
            #     color = 'green',
            #     alpha = 0.3
            # )

            ax.set_xlim(1, self.Epochs)
            ax.set_xlabel('#epoch')
            ax.set_ylabel('alpha_fit')

            ax.grid(True, which="both", linestyle="--", linewidth=0.5)

        self.fig.suptitle(f'Model Trainning           Epoch = {self.epoch_ind + 1}/{self.Epochs}', fontsize = 15, y = 0.99)
        self.fig.tight_layout()
        update_display(self.fig, display_id = self.display_id)
        plt.close(self.fig)

        self.epoch_ind += 1

    def save_plot(self, dir, model_name, label):
        self.fig.savefig(f'{dir}/{model_name}_{label}.png')

    def save_data(self, dir, model_name, label):
        with h5py.File(f'{dir}/{model_name}_{label}.h5', 'a') as file:

            GE = file.require_group('GE')

            GE_data = self.GE.cpu().numpy().astype(np.float32)

            data = GE.require_dataset(
                name = 'GE_E',
                shape = (self.Epochs, self.N_traces),
                dtype = np.float32,
                exact = False
            )

            data[:,:] = GE_data[:, :, 0]

            data = GE.require_dataset(
                name = 'GE_se2',
                shape = (self.Epochs, self.N_traces),
                dtype = np.float32,
                exact = False
            )

            data[:,:] = GE_data[:, :, 1]

            if (self.do_conv):
                conv = file.require_group('conv')

                conv_data = self.conv.cpu().numpy().astype(np.float32)

                data = conv.require_dataset(
                    name = 'conv_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = conv_data[:, 0]

                data = conv.require_dataset(
                    name = 'conv_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = conv_data[:, 1]

            if (self.do_CDA):
                CDA = file.require_group('CDA')

                CDA_data = self.CDA.cpu().numpy().astype(np.float32)

                data = CDA.require_dataset(
                    name = 'CDA_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = CDA_data[:, 0]

                data = CDA.require_dataset(
                    name = 'CDA_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = CDA_data[:, 1]

            if (self.do_loss):
                loss = file.require_group('loss')

                loss_data = self.loss.cpu().numpy().astype(np.float32)

                loss_train = loss.require_group('loss_train')

                data = loss_train.require_dataset(
                    name = 'loss_train_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 0, 0]

                data = loss_train.require_dataset(
                    name = 'loss_train_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 1, 1]

                loss_test = loss.require_group('loss_test')

                data = loss_test.require_dataset(
                    name = 'loss_test_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 1, 0]

                data = loss_test.require_dataset(
                    name = 'loss_test_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 1, 1]

                loss_attack = loss.require_group('loss_attack')

                data = loss_attack.require_dataset(
                    name = 'loss_attack_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 2, 0]

                data = loss_attack.require_dataset(
                    name = 'loss_attack_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = loss_data[:, 2, 1]

            if (self.do_acc):
                acc = file.require_group('acc')

                acc_data = self.acc.cpu().numpy().astype(np.float32)

                acc_train = acc.require_group('acc_train')

                data = acc_train.require_dataset(
                    name = 'acc_train_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 0, 0]

                data = acc_train.require_dataset(
                    name = 'loss_train_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 0, 1]

                acc_test = acc.require_group('acc_test')

                data = acc_test.require_dataset(
                    name = 'acc_test_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 1, 0]

                data = acc_test.require_dataset(
                    name = 'acc_test_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 1, 1]

                acc_attack = acc.require_group('acc_attack')

                data = acc_attack.require_dataset(
                    name = 'acc_attack_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 2, 0]

                data = acc_attack.require_dataset(
                    name = 'acc_attack_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = acc_data[:, 2, 1]

            if (self.do_fit):
                fit = file.require_group('fit')

                fit_data = self.fit

                s = fit.require_group('s_fit')

                data = s.require_dataset(
                    name = 's_fit_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 0, 0]

                data = s.require_dataset(
                    name = 's_fit_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 0, 1]

                n = fit.require_group('n_fit')

                data = n.require_dataset(
                    name = 'n_fit_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 1, 0]

                data = n.require_dataset(
                    name = 'n_fit_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 1, 1]

                gamma = fit.require_group('gamma_fit')

                data = gamma.require_dataset(
                    name = 'gamma_fit_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 2, 0]

                data = gamma.require_dataset(
                    name = 'gamma_fit_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 2, 1]

                alpha = fit.require_group('alpha_fit')

                data = alpha.require_dataset(
                    name = 'alpha_fit_E',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 3, 0]

                data = alpha.require_dataset(
                    name = 'alpha_fit_se2',
                    shape = self.Epochs,
                    dtype = np.float32,
                    exact = False
                )

                data[:] = fit_data[:, 3, 1]



In [None]:
class Trainer:
    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        def __init__(self, keys, data):
            self.keys = keys
            self.data = data

        def __len__(self):  return len(self.keys)

        def __getitem__(self, i):
            key = self.keys[i]
            datum = self.data[i]

            return (key, datum)

    def __init__(self,
        device = t.device('cpu'),

        x_train = None,
        y_train = None,

        x_test = None,
        y_test = None,

        x_attack = None,
        p_attack = None,
        k_attack = None,

        EPOCH = None,
        BATCH_SIZE = None,

        metrics = None,

        net_name = None,
        save_dir = None,
        save_w_path = False,

    ):
        self.device = device

        self.x_train = x_train.to(self.device)
        self.y_train = y_train.to(self.device)

        self.x_test = x_test.to(self.device)
        self.y_test = y_test.to(self.device)

        self.x_attack = x_attack.to(self.device)
        self.p_attack = p_attack.to(self.device)
        self.k_attack = k_attack.to(self.device)

        self.EPOCH = EPOCH
        self.BATCH_SIZE = BATCH_SIZE

        self.metrics = metrics

        self.net_name = net_name
        self.save_dir = save_dir
        self.save_w_path = save_w_path

        from torch.utils.data import DataLoader

        dataset = Trainer.MyDataset(
            keys = self.y_train,
            data = self.x_train
        )

        self.dataloader = DataLoader(
            dataset = dataset,
            batch_size = BATCH_SIZE,
            shuffle = True,
        )

    def train(self, net = None, opt = None, reg = None, sch = None, loss_func = None, trainning_date = None):
        self.metrics.t_ref = time.time()
        self.metrics.t_round = 0

        for epoch in range(self.EPOCH):
            if (self.save_w_path):  model_evo_batch = {}

            if (self.metrics.do_loss):  loss_train_aux = t.tensor([], dtype = t.float32, device = self.device)
            if (self.metrics.do_acc):   acc_train_aux = t.tensor([], dtype = t.float32, device = self.device)

            N_Batches = 0
            for y_batch, x_batch in self.dataloader:
                net.train()
                opt.zero_grad()

                with t.autocast(device_type = self.device.type, dtype = t.float16, enabled = False):
                    loss_value = loss_func(net(x_batch), y_batch)

                # loss_value = loss_func(net(x_batch), y_batch)

                loss_value.backward()
                opt.step()

                if (reg != None):
                    opt.zero_grad()

                    reg_value = opt.param_groups[0]['lr']*reg[0](net, self.device)
                    reg_value.backward()

                    reg[1].step()

                if (sch != None): sch.step()

                if (self.metrics.do_loss):  loss_train_aux = Metrics.add_to_dataset(loss_train_aux, loss_value.detach().view(1))
                if (self.metrics.do_acc):
                    acc_train_aux = Metrics.add_to_dataset(
                        acc_train_aux,
                        self.metrics.acc_func(
                            net, self.x_train, self.y_train, self.metrics.Batch
                        ).view(1)
                    )

                if (self.save_w_path):  model_evo_batch[f'Epoch = {epoch} //--// Batch = {N_Batches}'] = net.state_dict()

                N_Batches += 1

                t.cuda.empty_cache()

            if (self.save_w_path): t.save(model_evo_batch, f'{self.save_dir}/{self.net_name}_{trainning_date} ---- Epoch = {epoch}.pt')

            with t.no_grad():
                net.eval()
                if (self.metrics.do_loss):
                    loss_train = t.cat((
                        loss_train_aux.mean(dtype = t.float32).view(1),
                        loss_train_aux.var().view(1)/N_Batches
                    )).view(1, 2)

                else: loss_train = None

                if (self.metrics.do_acc):
                    acc_train = t.cat((
                        acc_train_aux.mean(dtype = t.float32).view(1),
                        acc_train_aux.var().view(1)/N_Batches
                    )).view(1, 2)

                else: acc_train = None

                self.metrics.measure(net, loss_func, loss_train, acc_train)
                self.metrics.plot()

                t.cuda.empty_cache()

        t.save(net.state_dict() , f'{self.save_dir}/{self.net_name}_{trainning_date}.pt')

        self.metrics.save_plot(self.save_dir, self.net_name, trainning_date)
        self.metrics.save_data(self.save_dir, self.net_name, trainning_date)
        self.metrics.close_fig()

In [None]:
# Read the raw attack-set arrays (traces, labels, rearranged metadata) for a quick look.
# NOTE(review): `dir` is an empty string here - it must point at the HDF5 file before running.
x, y, m = HDF5_viewer(content = 'Attack_traces', dir = '', N_object = 3)

# Kept as the cell's last expression so the notebook renders the metadata array.
m

In [None]:
# Load the profiling and attack sets with the same metadata layout
# (metadata entry 0 / entry 1, byte index 2) and pre-process them together.
# NOTE(review): both `dir` arguments are empty strings - set the dataset path before running.
data_p, data_a = Preprocess_data_SC(
    data_p = Import_hdf5_data_SC(content = 'Profiling_traces', dir = '', p_pos = 0, k_pos = 1, byte_a = 2),
    data_a = Import_hdf5_data_SC(content = 'Attack_traces', dir = '', p_pos = 0, k_pos = 1, byte_a = 2),
    size = 0.6,
    augmentation = None,
    scaler = 'Standard'
)

# Element 0 holds the traces; element 1 holds the label matrix.
# Column order (label, p, k) is the one built by Import_hdf5_data_SC;
# assumes Preprocess_data_SC keeps that order - TODO confirm.
x_p, y_p = data_p[0], data_p[1][:,0]

x_attack, y_attack = data_a[0], data_a[1][:,0]
p_attack, k_attack = data_a[1][:,1], data_a[1][:,2]

In [None]:
# Repeated training runs: each iteration draws a fresh stratified train/test
# split of the profiling data, builds a new network with its optimizer,
# scheduler, regulariser and metrics recorder, and trains it with `Trainer`.
from sklearn.model_selection import StratifiedShuffleSplit

device = t.device('cuda')

TEST = 0.2          # fraction of the profiling set held out as test data
EPOCH = 75
BATCH_SIZE = 256
lr = 0.005

# Feature toggles (previously hard-coded `if (True)` / `if (False)` guards).
USE_COMPILE = False
USE_SCHEDULER = True
USE_REGULARIZER = True

N_trains = 5
for i in range(N_trains):
    # Release memory left over from the previous run before allocating again.
    gc.collect()
    t.cuda.empty_cache()

    print(f'Trainning #{i}')

    # The timestamp doubles as the unique tag in every file name written by this run.
    trainning_date = datetime.now().strftime("%Y-%m-%d %H_%M_%S.%f")
    print(f'Date of the the trainning = {trainning_date}')
    print('\n')

    # No random_state is set, so every run gets a different split.
    sss = StratifiedShuffleSplit(
        n_splits = 1,
        test_size = TEST
    )

    # Only the labels matter for stratification; X is a same-length placeholder.
    train_ind, test_ind = next(sss.split(X = t.zeros_like(y_p), y = y_p))

    x_train = x_p[train_ind,...].to(t.float32)
    y_train = y_p[train_ind].to(t.int64)

    x_test = x_p[test_ind,...].to(t.float32)
    y_test = y_p[test_ind].to(t.int64)

    net = Equivariant_Wavelet_Network(device).to(device)
    print(net)

    if (USE_COMPILE):
        net = t.compile(
            model = net,
            backend = 'inductor',
            fullgraph = True,
            dynamic = True,
            options = {
                'max-autotune': True,
                'epilogue_fusion': True,
                'shape_padding': True,
                # 'triton.cudagraphs': True,

                # 'optimize_dtypes': True,
                # 'enable_python_fallback': True,
                # 'dynamic_shapes': True,
            }
        )

    loss = nn.CrossEntropyLoss()

    opt = optim.Adam(
        params = net.parameters(),
        lr = lr,
        weight_decay = 0,
        amsgrad = False
    )

    sch = None
    if (USE_SCHEDULER):
        # The scheduler is stepped once per batch, and the training DataLoader
        # keeps the last partial batch, so an epoch has ceil(N / BATCH_SIZE)
        # steps. The previous `int(N / BATCH_SIZE + 1)` added one extra step per
        # epoch whenever N was an exact multiple of BATCH_SIZE, which left the
        # one-cycle schedule unfinished at the end of training.
        steps_per_epoch = -(-x_train.shape[0] // BATCH_SIZE)

        sch = optim.lr_scheduler.OneCycleLR(
            optimizer = opt,
            epochs = EPOCH,
            total_steps = steps_per_epoch*EPOCH,
            max_lr = lr,
            pct_start = 0.4,
            div_factor = 10,
            final_div_factor = 10,
            anneal_strategy = 'linear',
            three_phase = True,
            cycle_momentum = False,
        )

    reg = None
    if (USE_REGULARIZER):
        # torch.autograd.set_detect_anomaly(True)

        # Combined wavelet penalty, applied by the trainer in its own optimizer step.
        def Reg_loss(net, device):   return Reg_Wavelet_mean0(net, device) + Reg_Wavelet_2norm1(net, device)

        # Used as the learning rate of the regularisation optimizer below.
        weight_decay = 0.00001

        reg_opt = optim.SGD(
            params = net.parameters(),
            lr = weight_decay
        )

        reg = [Reg_loss, reg_opt]

    metrics = Metrics(
        do_CDA = True,
        do_conv = True,
        do_fit = True,
        do_loss = True,
        do_acc = True,
        x_attack = x_attack,
        y_attack = y_attack,
        p_attack = p_attack,
        k_attack = k_attack,
        x_test = x_test,
        y_test = y_test,
        N_attacks = 100,
        N_traces = 1000,
        Delta_traces = 200,
        N_population = 100,
        Batch = 50,
        Epochs = EPOCH,
        device = device,
        display_id = f'INFO_{i}'
    )


    # NOTE(review): `net_name` and `save_dir` are empty strings here - set them
    # before running, since the trainer builds its output paths from both.
    trainer = Trainer(
        device = device,
        x_train = x_train,
        y_train = y_train,
        x_test = x_test,
        y_test = y_test,
        x_attack = x_attack,
        p_attack = p_attack,
        k_attack = k_attack,
        EPOCH = EPOCH,
        BATCH_SIZE = BATCH_SIZE,
        metrics = metrics,
        net_name = '',
        save_dir = "",
        save_w_path = False
    )

    # Second clean-up, after the trainer has moved its tensors to the device.
    gc.collect()
    t.cuda.empty_cache()

    trainer.train(
        net = net,
        opt = opt,
        reg = reg,
        sch = sch,
        loss_func = loss,
        trainning_date = trainning_date
    )

    print('\n\n')
