## CNN attack via ResNet using Skywater non-linearized data

The goal of this notebook is to correctly preprocess the given data as tensors that can be used to train ResNet101.

In [1]:
# dataloader.py
# Necessary imports
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import re

### Creating dataloaders
Used files from Kareem's GitHub repo. I may have made mistakes in sampling the data, so feel free to change anything that has been configured incorrectly.

In [2]:
# dataloader.py
# Returns list of files with given format 
def get_files(directory, format, digital_index=0):
    """Collect the files in ``directory`` whose names match ``format``.

    Inputs:
        directory     : folder to search for files
        format        : regular expression matched against the start of each file name
        digital_index : index of the regex group holding the integer label
    Outputs:
        list          : [(file_name, file_path, label) ... ] in ``os.listdir`` order
    """
    pattern = re.compile(format)

    # Pair every directory entry with its (possibly None) match object
    candidates = ((name, pattern.match(name)) for name in os.listdir(directory))

    return [
        (name, os.path.join(directory, name), int(hit.groups()[digital_index]))
        for name, hit in candidates
        if hit
    ]

# Creates dataset with given traces
class TraceDataset(Dataset):
    """Dataset of traces stored as whitespace-separated text files.

    Each item is ``(trace, label)``: ``trace`` is a 1-D ``np.float32`` array
    holding the second column of the file (the first line is skipped as a
    header) and ``label`` is the stored label passed through ``process_label``.

    NOTE: ``cached_traces`` and ``trace_list`` are class attributes, so the
    cache is shared by every instance of this class and of its subclasses.
    Cached traces are keyed by file name only (not by full path).
    """
    cached_traces = {}  # file name -> trace array
    trace_list    = []  # one entry per cached file name

    def __init__(self, file_list, cache=True):
        """
        Inputs:
            file_list : [(file_name, file_path, label) ... ]
            cache     : keep loaded traces in the shared cache when True
        """
        self.file_list = file_list
        self.cache     = cache

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        """Return ``(trace, processed_label)`` for the file at ``index``."""
        fname, fpath, label = self.file_list[index]
        label = self.process_label(label)

        if self.cache and fname in self.cached_traces:
            return self.cached_traces[fname], label
        return self.load_trace(fname, fpath), label

    def get_info(self, index):
        """Return the raw ``(file_name, file_path, label)`` entry at ``index``."""
        return self.file_list[index]

    def load_trace(self, fname, fpath):
        """Read the trace stored in ``fpath`` and (optionally) cache it.

        The first line is treated as a header and skipped. Every other
        non-blank line must hold exactly two whitespace-separated columns
        (time, value); only the value column is kept.

        Raises:
            ValueError : if a non-blank data line does not have two columns
                         or its value cannot be parsed as a float.
        """
        values = []
        with open(fpath, 'r') as file:
            file.readline()  # skip the header line
            for line in file:
                columns = line.split()
                if not columns:
                    # Blank lines (e.g. a trailing newline) carry no sample
                    continue
                _, value = columns
                values.append(np.float32(value))

        trace = np.array(values, dtype=np.float32)

        if self.cache:
            # Only register a file once so repeated loads do not pile up
            # duplicate entries in trace_list.
            if fname not in self.cached_traces:
                self.trace_list.append(trace)
            self.cached_traces[fname] = trace

        return trace

    def process_label(self, label):
        """Hook for subclasses; the base class returns the label unchanged."""
        return label

    def cache_all(self):
        """Load every file of ``file_list`` into the shared cache.

        Files that are already cached are not read again.

        Raises:
            RuntimeError : if this dataset was created with ``cache=False``.
        """
        if not self.cache:
            raise RuntimeError("cache_all() requires the dataset to be created with cache=True")

        print("Caching all traces")
        for fname, fpath, _ in self.file_list:
            if fname not in self.cached_traces:
                self.load_trace(fname, fpath)
        print("DONE Caching all traces")

class TraceDatasetBW(TraceDataset):
    """Bit-wise variant of TraceDataset: the label is a single bit of the stored label."""

    def __init__(self, file_list, bit_select, cache=True):
        """
        Inputs:
            file_list  : [(file_name, file_path, label) ... ]
            bit_select : position of the label bit to expose (0 = least significant)
            cache      : forwarded to TraceDataset
        """
        super().__init__(file_list, cache=cache)
        self.bit_mask = 1 << bit_select

    def process_label(self, label):
        """Return 1 when the selected bit of ``label`` is set, otherwise 0."""
        selected = label & self.bit_mask
        if selected:
            return 1
        return 0

class TraceDatasetBuilder:
    """Collects trace files and builds the datasets / dataloaders over them.

    Typical use: ``add_files`` (one or more times), then ``build``, then
    ``build_dataloaders``. ``build`` creates one ``TraceDataset`` over all
    files (``self.dataset``) plus one ``TraceDatasetBW`` per ADC bit
    (``self.datasets[b]`` selects bit ``b``).
    """

    def __init__(self, adc_bitwidth=8, cache=True):
        """
        Inputs:
            adc_bitwidth : number of per-bit datasets created by ``build``
            cache        : forwarded to every dataset created by ``build``
        """
        self.file_list        = []
        self.cache = cache
        self.adc_bits = adc_bitwidth

        self.dataset = None
        self.dataloader = None
        self.datasets = []
        self.dataloaders = []

    def add_files(self, directory, format, label_group):
        ''' Appends the matching powertrace files to ``self.file_list``
        Inputs:
            directory   : folder to search for files
            format      : regular expression matched against the start of each file name
            label_group : group index for digital output label corresponding to trace
        Outputs:
            None        : entries (file_name, file_path, label) are appended to
                          ``self.file_list`` in sorted file-name order
        '''
        pattern = re.compile(format)

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs and machines.
        for fname in sorted(os.listdir(directory)):
            match = pattern.match(fname)
            if match is None:
                continue

            fpath = os.path.join(directory, fname)
            dvalue = int(match.groups()[label_group])
            self.file_list.append((fname, fpath, dvalue))

    def build(self):
        """Create the full dataset and one bit-wise dataset per ADC bit.

        The per-bit list is rebuilt from scratch, so calling ``build`` again
        (e.g. after adding more files) does not leave duplicate datasets.
        """
        self.dataset = TraceDataset(self.file_list, cache=self.cache)
        self.datasets = [
            TraceDatasetBW(self.file_list, bit, cache=self.cache)
            for bit in range(self.adc_bits)
        ]

        if self.cache:
            self.dataset.cache_all()

    def build_dataloaders(self, **kwargs): # batch_size=256, shuffle=True
        """Wrap every dataset in a DataLoader; ``kwargs`` are forwarded to DataLoader.

        Raises:
            RuntimeError : if ``build`` has not been called yet.
        """
        if self.dataset is None:
            raise RuntimeError("build() must be called before build_dataloaders()")

        self.dataloader = DataLoader(self.dataset, **kwargs)
        self.dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.datasets]

In [3]:
# Build the trace datasets and their dataloaders.
# Traces are read from the 'sky_Dec_18_2151' folder inside the current working
# directory; the first regex group of each file name is used as the label.
pwd = os.getcwd()

builder = TraceDatasetBuilder(cache=True, adc_bitwidth=8)
builder.add_files(
    directory=os.path.join(pwd, 'sky_Dec_18_2151'),
    format="sky_d(\\d+)_.*\\.txt",
    label_group=0,
)
builder.build()
builder.build_dataloaders(shuffle=True, batch_size=256)

Caching all traces
DONE Caching all traces


### Setup for ResNet
Section where the initial imports and variables for ResNet are set up.

In [4]:
# ResNet101 code
# Necessary imports
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torchvision.models import resnet101, ResNet101_Weights
import numpy as np
import datetime


### Optimizing CUDA version
Section where we designate the most up-to-date CUDA version supported by the current GPU.

Implementation is in progress, but it is a lot of unnecessary work; we will do it only if CPU training speeds turn out to be unacceptably slow.

## Training CNNs - ResNet101
Using a pretrained ResNet101, we train each CNN until all of them reach an accuracy of 1.

I'm not sure if aiming for an accuracy of 1 is beneficial, as it is just overfitting the model to the training data. A more realistic value may be 0.99, but will set it to 1 for the current simulated environment.

The training function automatically reduces the learning rate used in Adam based on target accuracy. Currently testing different values and decrease rates. 

In [None]:
# Trains one binary ResNet101 classifier per dataloader (one per ADC bit),
# resuming from "resnet101_checkpoint_<bit>.pth" when such a file exists.
# The Adam learning rate is halved each time the epoch accuracy passes the
# next entry of `target_acc`.

def set_learning_rate(optimizer, learning_rate):
    """Write ``learning_rate`` into every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate


dataloaders = builder.dataloaders
# cnns[i] is the model trained on dataloaders[i]; pre-sizing the list keeps
# that mapping valid even though the bits are trained from high to low.
cnns = [None] * len(dataloaders)
timestamp   = datetime.datetime.now().strftime('%Y%m%d_%H%M')

# Disabled live plotting (kept for reference):
# plt.ion()
# figs, axs = plt.subplots(2)
# axs[0].set_title("Loss")
# axs[1].set_title("Accuracy")

print(f"Installed CUDA version: {torch.version.cuda}")

# Settings shared by every CNN
# Target accuracies to update the learning rate
# Use different values if needed
target_acc = [0.90, 0.95, 0.99, 0.995, 1.0]
initial_learning_rate = 1e-4  # Reduced learning rate for stability
num_epochs = 1000
max_grad_norm = 1.0  # Gradient clipping threshold

# Create CNN per dataset
for i in range(len(dataloaders) - 1, -1, -1):
    print(f"Starting training for \"cnn_{i}\"...")
    # Model: ResNet101, pretrained=true, using ResNet101_Weights.DEFAULT for up-to-date values
    cnn = resnet101(weights=ResNet101_Weights.DEFAULT)
    # Single input channel and two output classes (bit is 0 or 1);
    # both replaced layers start from fresh (non-pretrained) weights.
    cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    cnn.fc = nn.Linear(cnn.fc.in_features, 2)
    # Loss function: not specified in paper, used Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()

    # Store the CNN at the index of the dataloader it belongs to
    cnns[i] = cnn

    # Create checkpoint to save progress
    checkpoint_path = f"resnet101_checkpoint_{i}.pth"
    start_epoch = 0

    if torch.cuda.is_available():
        print("GPU found, running training on GPU...")
        device = torch.device("cuda")
        cnn = cnn.to(device)
    else:
        print("No GPU found, running training on CPU...")
        print("Recheck CUDA version and if your GPU supports it.")
        device = torch.device("cpu")

    # Optimizer: not specified in paper, used Adam
    target_acc_index = 0
    learning_rate = initial_learning_rate
    optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

    # Paramters for plot, may erase if not used (disabled):
    # loss_arr = []  # Array used to store loss values over epoches
    # acc_arr  = []  # Array used to store accuracy values over epoches
    # loss_g = None
    # acc_g  = None

    # Try to load .pth file
    # NEED TO ADD FUNCTIONALITY TO CHECK INTEGRITY OF .pth FILE
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        cnn.load_state_dict(checkpoint['cnn_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        # float() accepts both plain numbers and 0-dim tensors, so checkpoints
        # that stored the accuracy as a tensor can still be read.
        reached_target_acc = float(checkpoint['reached_acc'])
        print(f"Previous reached accuracy: {reached_target_acc}")
        if reached_target_acc == 1:
            print(f"Skipping training for \"cnn_{i}\", already reached accuracy of 1.\n\n")
            continue
        while target_acc_index < len(target_acc) - 1 and target_acc[target_acc_index] < reached_target_acc:
            target_acc_index += 1
            learning_rate /= 2
        # The optimizer state restored above carries its own learning rate;
        # overwrite it with the rate derived from the reached accuracy.
        set_learning_rate(optimizer, learning_rate)
        print(f"Updated target accuracy: {target_acc[target_acc_index]}")
        print(f"Updated learning rate: {learning_rate}")
    except FileNotFoundError:
        print(f"No checkpoint found for \"cnn_{i}\". Starting from scratch.")

    # Start training
    for epoch in range(start_epoch, num_epochs):
        correct = 0       # correctly classified samples in this epoch
        total = 0         # samples seen in this epoch
        nan_detected = False
        cnn.train()

        for inputs, labels in dataloaders[i]:
            inputs, labels = inputs.to(device), labels.to(device)
            # Add dimensions for channels and width
            inputs = inputs.unsqueeze(1).unsqueeze(-1)
            optimizer.zero_grad()
            output = cnn(inputs)
            # Check for NaN in outputs
            if torch.isnan(output).any():
                print("NaN detected in cnn outputs.")
                nan_detected = True
                break

            loss = criterion(output, labels)
            # Check for NaN in loss
            if torch.isnan(loss):
                print("NaN detected in loss. Stopping training.")
                nan_detected = True
                break
            # print(f"Loss: {loss.item()}")
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(cnn.parameters(), max_grad_norm)
            optimizer.step()
            _, predicted = torch.max(output, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        if nan_detected:
            # Really stop: do not overwrite the last good checkpoint with a
            # diverged model and do not start another epoch.
            print(f"Aborting training for \"cnn_{i}\" because of NaN values.\n")
            break
        if total == 0:
            print(f"No training samples available for \"cnn_{i}\", skipping.\n")
            break

        # Fraction of correctly classified samples, independent of batch size
        accuracy = correct / total
        # acc_arr.append(accuracy)
        # loss_arr.append(loss.item())

        # Change rate of update for printing accuracy accordingly
        # NOTE: the printed loss is the loss of the last batch of the epoch
        if (epoch + 1) % 1 == 0:
            print(f'TRAINING: cnn[{i}], Epoch {epoch+1}, Loss: {loss.item()}')
            print(f'TRAINING: cnn[{i}], Epoch {epoch+1}, Accuracy: {accuracy}')
        # Disabled live plotting:
        # if epoch % 50 == 0:
        #     if loss_g: loss_g.remove()
        #     if acc_g:  acc_g.remove()
        #     loss_g = axs[0].plot(loss_arr, color='lightgray', linestyle='dotted')[0]
        #     loss_a = axs[1].plot(acc_arr,  color='lightgray', linestyle='dotted')[0]
        #     plt.pause(0.01)

        # Save checkpoint
        torch.save({
            'epoch': epoch + 1,
            'cnn_state_dict': cnn.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'reached_acc': accuracy
        }, checkpoint_path)
        print(f"Checkpoint saved for epoch {epoch + 1}")

        # Check oscillation


        # If training reached accuracy of 1, stop training
        if accuracy == 1:
            print(f"Reached accuracy of 1. Stopping training for \"cnn_{i}\".\n")
            break
        # Update learning rate if accuracy reaches target value
        # Reducing stepsize accordingly so the optimizer does not overshoot
        elif target_acc[target_acc_index] < accuracy and target_acc_index != len(target_acc) - 1:
            temp = learning_rate
            # Skip every target that was passed in one go; the bound keeps the
            # index on the last entry at most.
            while target_acc_index < len(target_acc) - 1 and target_acc[target_acc_index] < accuracy:
                target_acc_index += 1
                learning_rate = learning_rate / 2
            # Make the optimizer actually use the reduced rate
            set_learning_rate(optimizer, learning_rate)
            print(f"Reached target accuracy {target_acc[target_acc_index-1]}")
            print(f"New target accuracy {target_acc[target_acc_index]}")
            print(f"Updating learning rate FROM: {temp}, TO: {learning_rate}")

    # Disabled plotting of the final curves of this CNN:
    # label = f'cnn[{i}]'
    # axs[0].plot(loss_arr, label=label)
    # axs[1].plot(acc_arr,  label=label)
    # axs[0].legend()
    # axs[1].legend()
    # plt.pause(0.01)

# plt.pause(60*10)

Installed CUDA version: None
Starting training for "cnn_7"...
No GPU found, running training on CPU...
Recheck CUDA version and if your GPU supports it.
Checkpoint loaded. Resuming from epoch 9
Previous reached accuracy: 1.0
Skipping training for "cnn_7", already reached accuracy of 1.


Starting training for "cnn_6"...
No GPU found, running training on CPU...
Recheck CUDA version and if your GPU supports it.
Checkpoint loaded. Resuming from epoch 18
Previous reached accuracy: 1.0
Skipping training for "cnn_6", already reached accuracy of 1.


Starting training for "cnn_5"...
No GPU found, running training on CPU...
Recheck CUDA version and if your GPU supports it.
Checkpoint loaded. Resuming from epoch 26
Previous reached accuracy: 1.0
Skipping training for "cnn_5", already reached accuracy of 1.


Starting training for "cnn_4"...
No GPU found, running training on CPU...
Recheck CUDA version and if your GPU supports it.
Checkpoint loaded. Resuming from epoch 51
Previous reached accura

### Testing the CNNs

Using the resulting CNNs saved in the .pth files, we test each CNN using the provided power traces.

Currently the testing data is a subset of the training data. In the future, if we can generate more data, we will be able to use separate datasets.



In [None]:
# Run evaluation using testing data
# CURRENTLY USING TRAINING DATA TO TEST DATA: use separate dataset in the future
#
# For every bit, a fresh ResNet101 with the training architecture is created
# and filled with the weights of that bit's own checkpoint file, so the
# evaluated model always belongs to dataloaders[i].
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloaders = builder.dataloaders

for i in range(len(dataloaders) - 1, -1, -1):
    checkpoint_path = f"resnet101_checkpoint_{i}.pth"
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        # Nothing to evaluate without trained weights
        print(f"No checkpoint found for \"cnn_{i}\", skipping evaluation.\n\n")
        continue

    print(f"Checkpoint loaded for \"cnn_{i}\" (saved after epoch {checkpoint['epoch']})")
    # float() accepts both plain numbers and 0-dim tensors
    reached_target_acc = float(checkpoint['reached_acc'])
    print(f"Previous reached accuracy: {reached_target_acc}")
    if reached_target_acc != 1:
        print(f"\"cnn_{i}\" did not reach accuracy of 1, skipping evaluation.\n\n")
        continue

    # Same architecture as in training: 1 input channel, 2 output classes.
    # No pretrained download is needed because the checkpoint holds all weights.
    cnn = resnet101(weights=None)
    cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    cnn.fc = nn.Linear(cnn.fc.in_features, 2)
    cnn.load_state_dict(checkpoint['cnn_state_dict'])
    cnn = cnn.to(device)
    cnn.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloaders[i]:
            inputs, labels = inputs.to(device), labels.to(device)
            # Add dimensions for channels and width
            inputs = inputs.unsqueeze(1).unsqueeze(-1)
            outputs = cnn(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print(f"No samples available to evaluate \"cnn_{i}\".")
        continue

    print(f"Accuracy: {100 * correct / total:.2f}%")