## CNN attack via ResNet using Skywater non-linearized data

The goal of this notebook is to correctly preprocess the given power-trace data into tensors that can be used to train ResNet18 (one binary classifier per ADC output bit).

In [1]:
# dataloader.py
# Necessary imports
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import re

### Creating dataloaders
The dataloader code below is adapted from files in Kareem's GitHub repo. I may have made mistakes in sampling the data, so feel free to correct anything that has been configured incorrectly.

In [2]:
# dataloader.py
# Returns list of files with given format 
def get_files(directory, format, digital_index=0):
    '''Collect the files in ``directory`` whose names match a regex.

    Inputs:
        directory     : folder to scan (not recursive)
        format        : regular expression matched against each file name
        digital_index : index into the match's groups() holding the integer label
    Outputs:
        list of (file_name, file_path, label) tuples, in os.listdir order
    '''
    pattern = re.compile(format)
    candidates = ((name, pattern.match(name)) for name in os.listdir(directory))
    return [
        (name, os.path.join(directory, name), int(found.groups()[digital_index]))
        for name, found in candidates
        if found
    ]

# Creates dataset with given traces
class TraceDataset(Dataset):
    '''Dataset of power traces stored as text files.

    Each item is ``(trace, label)``: ``trace`` is a 1-D float32 numpy array with
    the value column of the file, and ``label`` is the file's integer label passed
    through ``process_label`` (identity here; subclasses override it).

    Args:
        file_list: list of ``(fname, fpath, label)`` tuples.
        cache: if True, parsed traces are kept in memory keyed by file name.

    NOTE: ``cached_traces`` and ``trace_list`` are class attributes, so every
    instance (including subclasses such as TraceDatasetBW) shares one cache.
    A single ``cache_all()`` call therefore serves all datasets built over the
    same files, but cached traces also accumulate across instances.
    '''
    cached_traces = {}
    trace_list    = []

    def __init__(self, file_list, cache=True):
        self.file_list = file_list
        self.cache     = cache

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, index):
        fname, fpath, label = self.file_list[index]
        label = self.process_label(label)

        if self.cache and fname in self.cached_traces:
            return self.cached_traces[fname], label
        return self.load_trace(fname, fpath), label

    def get_info(self, index):
        '''Return the raw ``(fname, fpath, label)`` tuple at ``index``.'''
        return self.file_list[index]

    def load_trace(self, fname, fpath):
        '''Parse one trace file and return its value column as a float32 array.

        The first line is treated as a header and skipped. Every other non-blank
        line must contain two whitespace-separated fields (time, value); only the
        value is kept. Values that fail to parse are reported and skipped.
        When caching is enabled the result is stored in the shared cache.
        '''
        valu_arr = []
        with open(fpath, 'r') as file:
            file.readline()  # header line, unused
            for line in file:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing newline at end of file.
                    continue
                _time, value = fields
                try:
                    # Default: parse the value string as-is. Previously this name
                    # was only bound inside the branch below, so a value without
                    # "e-" raised NameError or silently reused the previous value.
                    strip_val = value
                    # Fixed error of float32 incorrectly translating values
                    if re.search(r"(?<=e-)\d+", value):
                        # NOTE(review): slicing to 10 characters (11 with a leading
                        # '-') can drop all or part of the exponent: '1.234567890e-05'
                        # becomes '1.23456789', while '1.2345e-05' is kept whole.
                        # Values are only on a common scale if every line has the
                        # same format and exponent -- confirm this is intended.
                        strip_val = value[0:11] if value[0] == "-" else value[0:10]
                    float_val = np.float64(strip_val)
                    valu_arr.append(np.float32(round(float_val, 7)))
                except ValueError as e:
                    print(f"Error parsing value '{value}': {e}")

        trace = np.array(valu_arr, dtype=np.float32)

        if self.cache:
            self.cached_traces[fname] = trace
            self.trace_list.append(trace)

        return trace

    def process_label(self, label):
        '''Map the raw integer label to the training target (identity here).'''
        return label

    def cache_all(self):
        '''Parse and cache every trace in ``file_list`` (requires ``cache=True``).

        Files already present in the shared cache are not parsed again.
        '''
        assert self.cache, "cache_all() requires cache=True"

        print("Caching all traces")
        for fname, fpath, _label in self.file_list:
            if fname not in self.cached_traces:
                self.load_trace(fname, fpath)
        print("DONE Caching all traces")

class TraceDatasetBW(TraceDataset):
    '''Trace dataset whose label is a single bit of the integer ADC label.

    Args:
        file_list: list of ``(fname, fpath, label)`` tuples, as for TraceDataset.
        bit_select: bit position (0 = least significant) used as the 0/1 label.
        cache: forwarded to TraceDataset.
    '''

    def __init__(self, file_list, bit_select, cache=True):
        super().__init__(file_list, cache=cache)
        self.bit_mask = 1 << bit_select

    def process_label(self, label):
        # 1 when the selected bit is set in the integer label, otherwise 0.
        return int((label & self.bit_mask) != 0)

class TraceDatasetBuilder:
    '''Collects power-trace files and builds datasets/DataLoaders for an ADC.

    ``dataset``/``dataloader`` use the full integer label; ``datasets``/
    ``dataloaders`` hold one binary dataset per ADC bit (index = bit position).
    '''

    def __init__(self, adc_bitwidth=8, cache=True):
        self.file_list = []
        self.cache     = cache
        self.adc_bits  = adc_bitwidth

        self.dataset     = None
        self.dataloader  = None
        self.datasets    = []
        self.dataloaders = []

    def add_files(self, directory, format, label_group):
        ''' Adds every matching powertrace file in a directory to the file list.
        Inputs:
            directory   : folder to search for files
            format      : regular expression to match filenames
            label_group : index into the match's groups() holding the integer
                          digital output label corresponding to the trace
        Side effects:
            Appends (file_name, file_path, label) tuples to self.file_list,
            in sorted file-name order so that runs are reproducible.
        '''
        pattern = re.compile(format)

        for fname in sorted(os.listdir(directory)):
            if match := pattern.match(fname):
                fpath = os.path.join(directory, fname)
                dvalue = int(match.groups()[label_group])
                self.file_list.append((fname, fpath, dvalue))

    def build(self):
        '''Creates the full-label dataset and one TraceDatasetBW per ADC bit.

        Safe to call more than once: the per-bit list is rebuilt, not appended to.
        '''
        self.dataset = TraceDataset(self.file_list, cache=self.cache)
        self.datasets = [TraceDatasetBW(self.file_list, b, cache=self.cache)
                         for b in range(self.adc_bits)]

        if self.cache:
            # TraceDataset's cache is a class attribute, so the per-bit datasets
            # reuse the traces cached here.
            self.dataset.cache_all()

    def build_dataloaders(self, **kwargs): # batch_size=256, shuffle=True
        '''Wraps the datasets from build() in DataLoaders; kwargs are forwarded.'''
        if self.dataset is None:
            raise RuntimeError("build() must be called before build_dataloaders()")
        self.dataloader = DataLoader(self.dataset, **kwargs)
        self.dataloaders = [DataLoader(dataset, **kwargs) for dataset in self.datasets]

In [3]:
# Create dataloaders
# Directory holding the Skywater trace files. Override it with the SKY_DATA_DIR
# environment variable instead of editing this cell.
pwd = os.getcwd()
data_dir = os.environ.get("SKY_DATA_DIR", os.path.join(pwd, 'sky_Dec_18_2151'))
if not os.path.isdir(data_dir):
    raise FileNotFoundError(
        f"Trace directory not found: {data_dir}. "
        "Set SKY_DATA_DIR or place 'sky_Dec_18_2151' in the working directory."
    )

builder = TraceDatasetBuilder(adc_bitwidth=8, cache=True)
# File names look like sky_d<code>_....txt; regex group 0 is the digital output label.
builder.add_files(data_dir, r"sky_d(\d+)_.*\.txt", 0)
builder.build()
builder.build_dataloaders(batch_size=256, shuffle=True)

FileNotFoundError: [Errno 2] No such file or directory: '/Users/calya/Desktop/PowerTraces/CNN_PSA_ResNet/sky_Dec_18_2151'

### Setup for ResNet
This section sets up the imports and variables needed for ResNet.

In [4]:
# ResNet18 code
# Necessary imports
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torchvision.models import resnet18, ResNet18_Weights
import numpy as np
import datetime


In [None]:
# Create dataset

### Optimizing CUDA version
This section selects the latest CUDA version supported by the current GPU.

Implementation is in progress, but it involves a lot of work that may be unnecessary; it will only be finished if CPU training turns out to be too slow.

## Training CNNs - ResNet18
Starting from a pretrained ResNet18, we train one CNN per ADC output bit until each reaches a training accuracy of 1.0 (or the maximum number of epochs is reached).

I'm not sure aiming for an accuracy of 1.0 is beneficial, as it may simply overfit the model to the training data. A more realistic target may be 0.99, but it is set to 1.0 for the current simulated environment.

When `enable_weight_decay` is `True`, the training loop reduces Adam's learning rate each time a target accuracy is reached. Different values and decrease rates are currently being tested.

### Hyperparameters

1) def_lr: Default learning rate. Change this value to test different learning rate values.
2) divide_lr: Divisor used to reduce the learning rate when `enable_weight_decay` is enabled: each time the *k*-th target accuracy is passed, the learning rate is divided by `divide_lr * k`. Increase this value to reduce the learning rate more.
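3) enable_weight_decay: If set to `True`, enables (1) updating the learning rate according to the reached accuracy and (2) oscillation detection (resetting the learning rate to `def_lr` when the accuracy has barely changed for several consecutive epochs). Note: despite its name, this controls learning-rate decay, not Adam's `weight_decay` (L2 regularization) option.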
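4) freeze_layers: If set to `True`, freezes all pretrained ResNet layers so they receive no additional training. The replaced input convolution (`conv1`) and output layer (`fc`) are created after freezing and remain trainable.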

In [None]:
# Main hyperparameters (see the list above for details)
def_lr = 0.0001              # default Adam learning rate
divide_lr = 2                # learning-rate divisor applied when a target accuracy is passed
enable_weight_decay = False  # learning-rate decay + oscillation reset (not L2 weight decay)
freeze_layers = False        # freeze the pretrained ResNet layers

### Helper Functions

In [5]:
def cnn_init():
    '''Initialize the CNN and its core training components.

    Reads the module-level hyperparameters ``def_lr`` and ``freeze_layers``.

    Inputs:
        None
    Returns:
        1) cnn: ResNet18 loaded with ResNet18_Weights.DEFAULT, with its input
           convolution replaced by a 1-input-channel conv and its output layer
           replaced by a 2-class linear layer
        2) criterion: loss function, current default=CrossEntropyLoss
        3) learning_rate: learning rate, current default=def_lr
        4) optimizer: optimizer, current default=Adam over all model parameters
    '''
    # Model: ResNet18, pretrained, using ResNet18_Weights.DEFAULT for up-to-date values
    cnn = resnet18(weights=ResNet18_Weights.DEFAULT)
    if freeze_layers:
        # Freeze all pretrained layers. conv1 and fc are replaced below, so the new
        # layers are created afterwards with trainable parameters.
        for param in cnn.parameters():
            param.requires_grad = False
    # Traces are fed as 1-channel inputs, so conv1 takes a single input channel.
    # NOTE(review): this conv1 is freshly initialized, so the pretrained
    # first-layer filters are not used.
    cnn.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Resulting output is either 0 or 1
    cnn.fc = nn.Linear(cnn.fc.in_features, 2)
    # Loss function: not specified in paper, using Cross Entropy Loss
    criterion = nn.CrossEntropyLoss()

    # Learning rate: default set to def_lr, adjust accordingly
    # Optimizer: not specified in paper, using Adam
    learning_rate = def_lr
    optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)
    return cnn, criterion, learning_rate, optimizer

def param_init():
    '''Initialize the parameters used in training.

    Reads the module-level hyperparameter ``enable_weight_decay``.

    Inputs:
        None
    Returns:
        TARGET ACCURACY PARAMETERS
        1) target_acc: ascending target accuracies; update_learning_rate lowers
           the learning rate each time one of them is passed
        2) target_acc_index: index of the current target in target_acc (starts at 0)
        OSCILLATION CHECKING PARAMETERS
        3) is_osc: True while a learning-rate reset after detected oscillation is
           in progress (starts False)
        4) osc_count: counter of consecutive epochs with nearly unchanged accuracy
        TRAINING PARAMETERS
        5) num_epochs: maximum number of epochs per training run
        6) max_grad_norm: gradient clipping threshold
        WEIGHT DECAY PARAMETER
        7) weight_decay: value of the hyperparameter ``enable_weight_decay``.
           If True, enables the learning-rate update and oscillation handling;
           if False, disables them.
    '''
    # target accuracy parameters
    target_acc = [0.90, 0.95, 0.99, 0.995, 1.0]
    target_acc_index = 0

    # oscillation checking parameters
    is_osc = False
    osc_count = 0

    # training parameters
    num_epochs = 1000
    max_grad_norm = 1.0

    # weight decay parameter
    weight_decay = enable_weight_decay

    return target_acc, target_acc_index, is_osc, osc_count, num_epochs, max_grad_norm, weight_decay

def load_pth_file(checkpoint_path, device, load_values,
                  required_values=('cnn_state_dict', 'optimizer_state_dict', 'epoch', 'reached_acc')):
    '''Load a training checkpoint and check that it has the entries training needs.

    Inputs:
        1) checkpoint_path: string of .pth file name
        2) device: device tensors are mapped to (torch.load map_location)
        3) load_values: keys to read from the checkpoint
        4) required_values: keys that must be present. Any other key in
           load_values that is missing defaults to 0 (e.g. 'osc_count').
    Returns:
        pth_variables: dict mapping each key in load_values to its value.
    Raises:
        FileNotFoundError: checkpoint_path does not exist.
        KeyError: a key listed in both load_values and required_values is missing.
    '''
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        print(f"No checkpoint found at {checkpoint_path}. Starting from scratch.")
        raise

    descriptions = {
        'cnn_state_dict': "CNN state dictionary",
        'optimizer_state_dict': "optimizer state dictionary",
        'epoch': "epoch",
        'reached_acc': "recorded accuracy of current epoch",
        'osc_count': "oscillation count"
    }

    # Check keys explicitly: a dict.get(key, 0) lookup never raises, which would
    # let a corrupt checkpoint through and fail later with an obscure error.
    pth_variables = {}
    for key in load_values:
        if key in checkpoint:
            pth_variables[key] = checkpoint[key]
        elif key in required_values:
            print(f"Critical error: Missing \"{key}\" that stores \"{descriptions.get(key, key)}\". Checkpoint file is invalid.")
            raise KeyError(key)
        else:
            print(f"Warning: Missing \"{key}\" that stores \"{descriptions.get(key, key)}\". Default values will be used.")
            pth_variables[key] = 0
    return pth_variables

def update_learning_rate(target_acc, target_acc_index, learning_rate, reached_acc, force_update, divisor=None):
    '''Advance past every target accuracy already exceeded and lower the learning rate.

    This lets the values in target_acc change without touching this function.
    Inputs:
        1) target_acc: ascending array of target accuracies
        2) target_acc_index: index of the current target accuracy
        3) learning_rate: current learning rate
        4) reached_acc: accuracy just reached
        5) force_update: if True, always print the resulting target and rate;
           otherwise print only when the rate changed
        6) divisor: base divisor for the learning rate; defaults to the
           module-level hyperparameter ``divide_lr``
    Returns:
        1) learning_rate: new (possibly unchanged) learning rate
        2) target_acc_index: index of the new target accuracy
    '''
    if divisor is None:
        divisor = divide_lr
    previous_lr = learning_rate
    updated = False
    # A target counts as passed only when strictly exceeded; the last target is never passed.
    while target_acc_index < len(target_acc) - 1 and target_acc[target_acc_index] < reached_acc:
        updated = True
        target_acc_index += 1
        # Divide by divisor * (new index): each successive target cuts the rate harder.
        learning_rate /= divisor * target_acc_index
    if force_update:
        print(f"\tUpdated target accuracy: {target_acc[target_acc_index]}")
        print(f"\tUpdated learning rate: {learning_rate}")
    elif updated:
        print(f"Reached target accuracy {target_acc[target_acc_index-1]}")
        print(f"New target accuracy {target_acc[target_acc_index]}")
        print(f"Updating learning rate FROM: {previous_lr}, TO: {learning_rate}")
    return learning_rate, target_acc_index

In [None]:
# Train one binary CNN per ADC output bit, from the highest bit down to bit 0.
# Each model is checkpointed after every epoch and resumed from its checkpoint on re-run.
dataloaders = builder.dataloaders
num_bits    = len(dataloaders)
# Indexed by bit so that cnns[i] is the model trained on dataloaders[i]. (The loop
# runs from the highest bit down, so appending would store the models in reverse.)
cnns        = [None] * num_bits
timestamp   = datetime.datetime.now().strftime('%Y%m%d_%H%M')

print(f"Installed CUDA version: {torch.version.cuda}\n")


def set_learning_rate(opt, lr):
    '''Set the learning rate of every parameter group of ``opt`` in place.

    Unlike building a new Adam optimizer, this keeps Adam's running moment
    estimates, including any optimizer state loaded from a checkpoint.
    '''
    for param_group in opt.param_groups:
        param_group['lr'] = lr


# Create CNN per dataset
for i in range(num_bits - 1, -1, -1):
    print(f"Starting training for \"cnn_{i}\"...")

    # initialize CNN
    cnn, criterion, learning_rate, optimizer = cnn_init()
    # Save default learning rate (restored when oscillation is detected)
    default_lr = learning_rate
    # initialize parameters
    target_acc, target_acc_index, is_osc, osc_count, num_epochs, max_grad_norm, weight_decay = param_init()

    if weight_decay:
        print("\tWeight decay enabled. Learning rate update and oscillation detection enabled.")
    else:
        print("\tWeight decay disabled. Learning rate update and oscillation detection disabled.")

    cnns[i] = cnn

    # Checkpoint to save progress.
    # NOTE: the file name says "resnet101" although the model is ResNet18; it is
    # kept unchanged so that existing checkpoints still resume.
    checkpoint_path = f"resnet101_checkpoint_{i}.pth"

    # Set device to cuda if available
    if torch.cuda.is_available():
        print("\tGPU found, running training on GPU...")
        device = torch.device("cuda")
        cnn = cnn.to(device)
    else:
        print("\tNo GPU found, running training on CPU...")
        print("\tRecheck CUDA version and if your GPU supports it.")
        device = torch.device("cpu")

    # Attempt to load saved .pth file
    # start_epoch = starting epoch number
    # prev_acc = accuracy achieved by the loaded .pth file (0 when starting fresh)
    # osc_count = variable used to check oscillation, same name as dictionary key in .pth file
    load_values = ['cnn_state_dict', 'optimizer_state_dict', 'epoch', 'reached_acc', 'osc_count']
    try:
        pth_vars = load_pth_file(checkpoint_path, device, load_values)
    except FileNotFoundError:
        start_epoch = 0
        prev_acc = 0.0
        osc_count = 0
    else:
        cnn.load_state_dict(pth_vars['cnn_state_dict'])
        optimizer.load_state_dict(pth_vars['optimizer_state_dict'])
        start_epoch = pth_vars['epoch']
        # float() also accepts 0-d tensors saved by earlier versions of this cell.
        prev_acc = float(pth_vars['reached_acc'])
        osc_count = pth_vars['osc_count']

        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}.")
        print(f"\tPrevious reached accuracy: {prev_acc}.")

    # Skip training if target accuracy is reached
    if prev_acc == 1:
        print(f"\tSkipping training: accuracy of 1 already achieved.\n")
        continue

    if weight_decay:
        # Adjust learning rate in place so the loaded optimizer state is kept.
        learning_rate, target_acc_index = update_learning_rate(target_acc, target_acc_index, learning_rate, prev_acc, force_update=True)
        set_learning_rate(optimizer, learning_rate)

    train_loader = dataloaders[i]
    num_samples = len(train_loader.dataset)

    # Start training
    for e in range(start_epoch, num_epochs):
        correct = 0
        loss = None
        diverged = False
        cnn.train()

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Add dimensions for channels and width: (batch, L) -> (batch, 1, L, 1)
            inputs = inputs.unsqueeze(1).unsqueeze(-1)
            optimizer.zero_grad()
            output = cnn(inputs)
            # Check for NaN in outputs
            if torch.isnan(output).any():
                print("NaN detected in cnn outputs.")
                diverged = True
                break

            loss = criterion(output, labels)
            # Check for NaN in loss
            if torch.isnan(loss):
                print("NaN detected in loss. Stopping training.")
                diverged = True
                break
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(cnn.parameters(), max_grad_norm)
            optimizer.step()
            _, predicted = torch.max(output, 1)
            correct += (predicted == labels).sum().item()

        if diverged:
            # Keep the last good checkpoint instead of saving a diverged model.
            print(f"Stopping training for \"cnn_{i}\" without saving epoch {e + 1}.\n")
            break

        # Fraction of all training samples (not of one batch) classified correctly.
        accuracy = correct / num_samples if num_samples else 0.0
        loss_value = loss.item() if loss is not None else float('nan')

        print(f'TRAINING: cnn[{i}], Epoch {e+1}, Loss: {loss_value}')
        print(f'TRAINING: cnn[{i}], Epoch {e+1}, Accuracy: {accuracy}')

        if weight_decay:
            # Check oscillation and update learning rate
            # 1) is_osc == False: not in reset phase
            # 1-a) check difference between accuracy and prev_acc, update osc_count
            # 1-b) IF updated osc_count > 7, means oscillating, set learning rate to default value
            # 1-c) ELSE not oscillating, check if target learning rate is reached
            # 2) is_osc == True: in reset phase
            # 2-a) check if osc_count == 0
            # 2-b) IF osc_count == 0, reset phase is complete, re-update learning rate
            if not is_osc:
                if abs(accuracy - prev_acc) < 0.001:
                    osc_count += 1
                else:
                    osc_count = 0
                if osc_count > 7:
                    print(f"\tOSCILLATION DETECTED. Resetting learning rate to {default_lr}...")
                    is_osc = True
                    learning_rate = default_lr
                    osc_count = 3
                else:
                    learning_rate, target_acc_index = update_learning_rate(target_acc, target_acc_index, learning_rate, accuracy, force_update=False)
            else:
                if osc_count == 0:
                    print("\tLearning rate reset completed, adjusting learning rate...")
                    learning_rate, target_acc_index = update_learning_rate(target_acc, 0, learning_rate, accuracy, force_update=True)
                    is_osc = False
                else:
                    osc_count -= 1
            # Apply the current rate in place: this makes the oscillation reset take
            # effect and keeps Adam's moment estimates across epochs.
            set_learning_rate(optimizer, learning_rate)

        # Save checkpoint
        torch.save({
            'cnn_state_dict': cnn.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': e + 1,
            'reached_acc': accuracy,
            'osc_count': osc_count
        }, checkpoint_path)
        print(f"Checkpoint saved for epoch {e + 1}")

        prev_acc = accuracy

        if accuracy == 1:
            print(f"Reached accuracy of 1. Stopping training for \"cnn_{i}\".\n")
            break

### Testing the CNNs

Using the trained CNNs saved in the .pth files, we test each CNN on the provided power traces.

Currently the test data is the training data itself, so the reported accuracy measures fit rather than generalization. In the future, if we can generate more data, we will be able to use a separate test dataset.



In [None]:
# Run evaluation using testing data
# CURRENTLY USING TRAINING DATA TO TEST DATA: use separate dataset in the future
# Each model is rebuilt and loaded from its own checkpoint, so this cell does not
# depend on leftover variables (cnn, optimizer, checkpoint_path) from training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
eval_loaders = builder.dataloaders

for i in range(len(eval_loaders) - 1, -1, -1):
    checkpoint_path = f"resnet101_checkpoint_{i}.pth"
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        print(f"No checkpoint found for \"cnn_{i}\", skipping evaluation.\n\n")
        continue

    print(f"Checkpoint for \"cnn_{i}\" loaded (trained for {checkpoint['epoch']} epochs).")
    # float() also accepts 0-d tensors saved by earlier versions of the training cell.
    reached_target_acc = float(checkpoint['reached_acc'])
    print(f"Previous reached accuracy: {reached_target_acc}")
    if reached_target_acc != 1:
        print(f"\"cnn_{i}\" did not reach accuracy of 1, skipping evaluation.\n\n")
        continue

    model = cnn_init()[0].to(device)
    model.load_state_dict(checkpoint['cnn_state_dict'])
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in eval_loaders[i]:
            inputs, labels = inputs.to(device), labels.to(device)
            # Add dimensions for channels and width
            inputs = inputs.unsqueeze(1).unsqueeze(-1)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy_pct = 100 * correct / total if total else 0.0
    print(f"cnn_{i} accuracy: {accuracy_pct:.2f}%\n")