In [1]:
%load_ext autoreload
%autoreload 2

import math
import torch
from torch import nn
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

from pathlib import Path
from context import LocalLearning
from tqdm.notebook import tqdm
import numpy as np
from matplotlib import pyplot as plt
import os
import copy
from scipy import stats

import pickle

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and 3.8
# removed the old names; use whichever name this installation provides.
_SEABORN_PAPER = (
    "seaborn-v0_8-paper" if "seaborn-v0_8-paper" in plt.style.available else "seaborn-paper"
)
plt.style.use([_SEABORN_PAPER, "./A1PosterPortrait.mplstyle"])

In [2]:
# Hyper-parameters: DataLoader batch size and learning rate.
BATCH_SIZE, LEARNING_RATE = 1000, 1e-4

In [3]:
# Directory holding the trained model checkpoints (both names refer to the same Path).
model_path = ll_model_path = Path("../data/models/CIFAR10_PowerLaw")
# Figure output directory (not used in this section of the notebook).
figure_path = Path("../data/figures/NORAConf23Poster")

# Checkpoint file names, relative to model_path.
khmodel_file, khmodel_scheduled_file = Path("KHModel.pty"), Path("KHModel_scheduled.pty")
bpmodel_file, bpmodel_scheduled_file = Path("BPModel.pty"), Path("BPModel_scheduled.pty")

In [4]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Poster colour per model family (an earlier choice for "bp" was "#106151").
cmap = dict(
    kh="#762a83",
    hybrid="#f8a953",
    bp="#1b7837",
)

# Load Models

In [6]:
# Restore the KHModel checkpoint (its first layer is the FKHL3 `local_learning`
# module, per the repr printed below) and put it in evaluation mode.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
khmodel_state = torch.load(model_path / khmodel_file, map_location=device)
khmodel = LocalLearning.KHModel(khmodel_state["fkhl3-state"])
khmodel.load_state_dict(khmodel_state["model_state_dict"])
khmodel.to(device)
khmodel.eval()

KHModel(
  (local_learning): FKHL3(
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (relu_h): ReLU()
  (dense): Linear(in_features=2000, out_features=10, bias=True)
  (softMax): Softmax(dim=-1)
)

In [7]:
# Restore the backprop-trained SHLP checkpoint and put it in evaluation mode.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
bpmodel_state = torch.load(model_path / bpmodel_file, map_location=device)
bpmodel = LocalLearning.SHLP(params=bpmodel_state["params"])
bpmodel.load_state_dict(bpmodel_state["model_state_dict"])
bpmodel.to(device)
bpmodel.eval()

SHLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (ReLU): ReLU()
  (dense): Linear(in_features=2000, out_features=10, bias=True)
)

# Perform Mia's Adversarial Attacks (White Noise, FGSM, PGD)

In [8]:
def cross_entropy_loss(outputs, targets):
    """Batch-averaged negative log-likelihood of ``targets`` under ``softmax(outputs)``.

    Args:
        outputs: 2-D tensor of per-class scores, (batch, n_classes); it is passed
            through ``log_softmax`` along the last dimension.
        targets: integer class indices, (batch,).

    Returns:
        Scalar tensor: sum over the batch of ``-log p(target)`` divided by the batch size.

    NOTE(review): the KHModel repr ends in a Softmax layer; if its forward returns
    probabilities, this applies softmax a second time — confirm that is intended.
    """
    # log_softmax is the numerically stable route to log-probabilities.
    log_p = nn.functional.log_softmax(outputs, dim=-1)
    target_log_p = torch.gather(log_p, 1, targets.unsqueeze(-1))
    batch_size = len(outputs)
    return (-target_log_p).sum() / batch_size

ce_loss = cross_entropy_loss

In [9]:
# CIFAR-10 test split through LocalLearning.LpUnitCIFAR10 with p=3.0
# (presumably every image is rescaled to unit L_3 norm — confirm in LocalLearning).
cifar10Test = LocalLearning.LpUnitCIFAR10(
    root="../data/CIFAR10", train=False, transform=ToTensor(), p=3.0,
)

# Batches of BATCH_SIZE, presumably moved to `device` by DeviceDataLoader; shuffled.
TestLoader = LocalLearning.DeviceDataLoader(
    cifar10Test, device=device, batch_size=BATCH_SIZE, num_workers=4, shuffle=True,
)

Files already downloaded and verified


In [10]:
def data_critirium(
    dataloader,
    model,
    crit,
    thres=None,
    ):
    """Collect the samples from ``dataloader`` that satisfy a prediction criterion.

    Args:
        dataloader: iterable of ``(features, labels)`` batches.
        model: module returning per-class scores, shape (batch, n_classes).
        crit: selection criterion, one of
            * ``"correct"``       -- argmax prediction equals the label;
            * ``"thres"``         -- score of the predicted class is >= ``thres``;
            * ``"correct_thres"`` -- both of the above.
        thres: score threshold; required by the two threshold criteria.
            NOTE(review): for models without a final softmax these scores are raw
            logits, not probabilities.

    Returns:
        ``(data, lab_data)``: the kept features, concatenated along dim 0 with the
        same per-sample shape as the input batches, and their labels. Labels use the
        default floating dtype, matching the original implementation; callers
        cast them to long.

    Raises:
        ValueError: if ``crit`` is unknown or ``thres`` is missing for a threshold criterion.
    """
    if crit not in ("correct", "correct_thres", "thres"):
        raise ValueError("Not a valid criterium")
    if crit != "correct" and thres is None:
        raise ValueError(f"criterium {crit!r} requires a threshold `thres`")

    model.eval()
    feature_chunks = []
    label_chunks = []

    # Pure inference: no autograd graph is needed to select samples.
    with torch.no_grad():
        for features, labels in dataloader:
            scores = model(features)
            pred = torch.argmax(scores, dim=-1)
            keep = torch.ones_like(pred, dtype=torch.bool)
            if crit in ("correct", "correct_thres"):
                keep &= pred == labels
            if crit in ("thres", "correct_thres"):
                # Score of the predicted class, indexed with the actual batch size
                # (the previous version hard-coded 1000 rows).
                top_score = scores.gather(-1, pred.unsqueeze(-1)).squeeze(-1)
                keep &= top_score >= thres
            feature_chunks.append(features[keep])
            label_chunks.append(labels[keep])

    if not feature_chunks:
        # Empty dataloader: fall back to the original empty placeholders.
        return torch.zeros((0, 32, 32, 3)).to(device), torch.zeros((0)).to(device)

    # Concatenate once instead of growing a tensor per batch (was quadratic), which
    # also drops the hard-coded (N, 32, 32, 3) placeholder shape.
    data = torch.cat(feature_chunks, dim=0)
    lab_data = torch.cat(label_chunks, dim=0).to(torch.get_default_dtype())
    return data, lab_data

In [11]:
# Perturbation sweep: eps_start + n * step_size for n = 0 .. max_num_steps (inclusive).
max_num_steps, step_size, eps_start = 10000, 0.0001, 0

eps_list = list(map(lambda n: eps_start + n * step_size, range(max_num_steps + 1)))

In [12]:
def _accuracy_vs_critical_perturbation(crit_eps_per_image):
    """Map each observed critical perturbation ``u`` to the % of images still correct at ``u``.

    An image counts as still correctly classified at ``u`` if it was first
    misclassified at a strictly larger perturbation, or never at all (NaN entry).
    """
    siz = len(crit_eps_per_image)
    never_flipped = crit_eps_per_image.isnan()
    accuracy_dict = {}
    # torch.unique sorts ascending, so keys are inserted in increasing order.
    for unique_eps in torch.unique(crit_eps_per_image[~never_flipped]):
        still_correct = (crit_eps_per_image > unique_eps) | never_flipped
        accuracy_dict[unique_eps.item()] = (still_correct.sum().item() * 100) / siz
    return accuracy_dict


def crit_eps(criterium, model, attack, print_accuracy=False):
    """Sweep perturbation strength over ``eps_list`` and record when each test image flips.

    Images are selected with ``data_critirium(TestLoader, model, criterium)``. For
    every ``eps`` in the module-level ``eps_list`` they are perturbed with the chosen
    attack and re-classified. The first step at which an image is misclassified
    fixes its critical values.

    Args:
        criterium: selection criterion forwarded to ``data_critirium``.
        model: classifier under attack (class scores along the last dimension).
        attack: ``"WN"`` (standard-normal noise scaled by ``eps``), ``"FGSM"``
            (``eps * sign`` of the ``ce_loss`` gradient at the clean inputs) or
            ``"PGD"`` (see the note in the loop).
        print_accuracy: print the final accuracy and the share of never-flipped images.

    Returns:
        crit_eps_per_image: np.ndarray, mean absolute per-value perturbation at
            which each image was first misclassified (NaN if never).
        crit_dist_per_image: np.ndarray, L2 distance to the clean image at that step
            (NaN if never).
        correct: np.ndarray, accuracy (fraction) at every eps step.
        images: list of ``[perturbed[:10], labels[:10], preds[:10], features, eps]``
            for every even step index <= 50.
        accuracy_dict: {critical perturbation u: % of images still correct at u};
            never-flipped images count as correct.
        accuracy_dict_actual: {eps: accuracy in % at that step}.

    Raises:
        ValueError: if ``attack`` is not one of the supported names.
    """
    if attack not in ("WN", "PGD", "FGSM"):
        # Previously an unknown name silently evaluated clean images at every step.
        raise ValueError(f"Unknown attack {attack!r}; expected 'WN', 'PGD' or 'FGSM'")

    features, labels = data_critirium(TestLoader, model, criterium)
    labels = labels.long().to(device)
    siz = len(labels)

    crit_eps_per_image = torch.full((siz,), math.nan, device=device)
    crit_dist_per_image = torch.full((siz,), math.nan, device=device)

    b_norm = 0.05  # L_inf radius of the projection ball used by "PGD"
    noise = torch.randn(features.shape).to(device)

    # Input gradient of the loss at the clean images (used by FGSM/PGD).
    # autograd.grad avoids accumulating .grad on the model parameters; the former
    # Adam optimizer was never stepped and only served to zero those gradients.
    features.requires_grad_(True)
    loss = ce_loss(model(features), labels)
    (input_grad,) = torch.autograd.grad(loss, features)
    grad_sign = input_grad.sign()
    features = features.detach()

    n_values = math.prod(features.shape[1:])  # number of values per image
    flat_clean = features.flatten(1)

    correct = []
    images = []
    accuracy_dict_actual = {}
    perturbed_image = features

    with torch.no_grad(), tqdm(total=len(eps_list)) as pbar:
        for i, eps in enumerate(eps_list):
            if attack == "WN":
                perturbed_image = torch.clamp(features + eps * noise, min=0, max=1)
            elif attack == "PGD":
                # NOTE(review): the gradient sign is taken once at the clean inputs
                # and reused. eps (the step) grows along eps_list, and steps
                # accumulate across iterations before projection onto the L_inf ball
                # of radius b_norm. Standard PGD recomputes the gradient at each
                # iterate — confirm this variant is intended.
                adv_image = perturbed_image + eps * grad_sign
                delta = torch.clamp(adv_image - features, min=-b_norm, max=b_norm)
                perturbed_image = torch.clamp(features + delta, min=0, max=1)
            else:  # "FGSM"
                perturbed_image = torch.clamp(features + eps * grad_sign, 0, 1)

            preds_perturbed = torch.argmax(model(perturbed_image), dim=-1)
            alike = preds_perturbed == labels
            accuracy = (alike.sum() / siz).item()
            correct.append(accuracy)
            accuracy_dict_actual[eps] = accuracy * 100

            # Images misclassified for the first time at this step.
            mask = ~alike & crit_eps_per_image.isnan()
            flat_pert = perturbed_image.flatten(1)
            crit_dist_per_image[mask] = torch.norm(flat_clean - flat_pert, dim=1)[mask]
            avg_perturbation = torch.abs(flat_clean - flat_pert).sum(dim=1) / n_values
            crit_eps_per_image[mask] = avg_perturbation[mask]

            pbar.update(1)

            if i % 2 == 0 and i <= 50:
                # NOTE(review): `features` is stored in full (not [:10]), kept as before.
                images.append([perturbed_image[:10], labels[:10], preds_perturbed[:10], features, eps])

    accuracy_dict = _accuracy_vs_critical_perturbation(crit_eps_per_image)

    if print_accuracy:
        print(f"{correct[-1]*100:.2f}% is still correctly classified")
        print(f"{100*(torch.sum(crit_eps_per_image.isnan()).item())/siz}% have been correctly classified at every step")

    return (
        crit_eps_per_image.cpu().numpy(),
        crit_dist_per_image.cpu().numpy(),
        np.array(correct),
        images,
        accuracy_dict,
        accuracy_dict_actual,
    )

In [17]:
# White-noise sweep for both models. Re-seeding right before each call makes the
# noise draw reproducible (and, presumably, the shuffled TestLoader order too,
# if DeviceDataLoader draws from torch's global RNG).
NOISE_SEED = 0

torch.manual_seed(NOISE_SEED)
criteps_bp_wn, critdist_bp_wn, correct_bp_wn, images_bp_wn, accuracy_dict_bp_wn, accuracy_act_bp_wn = crit_eps(
    "correct",
    bpmodel,
    "WN",
    print_accuracy=True,
)

torch.manual_seed(NOISE_SEED)
criteps_ll_wn, critdist_ll_wn, correct_ll_wn, images_ll_wn, accuracy_dict_ll_wn, accuracy_act_ll_wn = crit_eps(
    "correct",
    khmodel,
    "WN",
    print_accuracy=True,
)

  0%|          | 0/10001 [00:00<?, ?it/s]

10.99% is still correctly classified
7.122093023255814% have been correctly classified at every step


  0%|          | 0/10001 [00:00<?, ?it/s]

10.80% is still correctly classified
7.442748091603053% have been correctly classified at every step


In [22]:
# Bundle the white-noise sweep results per model and pickle them next to the checkpoints.
khmodel_adversarial_data_WN = {
    "accuracy dict": accuracy_dict_ll_wn,
    "actual accuracy dict": accuracy_act_ll_wn,
    "critical epsilon": criteps_ll_wn,
    "critical distance": critdist_ll_wn,
}
bpmodel_adversarial_data_WN = {
    "accuracy dict": accuracy_dict_bp_wn,
    "actual accuracy dict": accuracy_act_bp_wn,
    "critical epsilon": criteps_bp_wn,
    "critical distance": critdist_bp_wn,
}

for pkl_name, payload in (
    ("khmodel_attack_statistics_WhiteNoise.pkl", khmodel_adversarial_data_WN),
    ("bpmodel_attack_statistics_WhiteNoise.pkl", bpmodel_adversarial_data_WN),
):
    with open(model_path / pkl_name, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [13]:
# FGSM sweep on the correctly classified test images, backprop model first.
criteps_bp_fgsm, critdist_bp_fgsm, correct_bp_fgsm, images_bp_fgsm, accuracy_dict_bp_fgsm, accuracy_act_bp_fgsm = crit_eps(
    criterium="correct",
    model=bpmodel,
    attack="FGSM",
    print_accuracy=True,
)
criteps_ll_fgsm, critdist_ll_fgsm, correct_ll_fgsm, images_ll_fgsm, accuracy_dict_ll_fgsm, accuracy_act_ll_fgsm = crit_eps(
    criterium="correct",
    model=khmodel,
    attack="FGSM",
    print_accuracy=True,
)

  0%|          | 0/10001 [00:00<?, ?it/s]

0.00% is still correctly classified
0.0% have been correctly classified at every step


  0%|          | 0/10001 [00:00<?, ?it/s]

0.42% is still correctly classified
0.0% have been correctly classified at every step


In [15]:
# Bundle the FGSM sweep results per model and pickle them next to the checkpoints.
khmodel_adversarial_data_FGSM = {
    "accuracy dict": accuracy_dict_ll_fgsm,
    "actual accuracy dict": accuracy_act_ll_fgsm,
    "critical epsilon": criteps_ll_fgsm,
    "critical distance": critdist_ll_fgsm,
}
bpmodel_adversarial_data_FGSM = {
    "accuracy dict": accuracy_dict_bp_fgsm,
    "actual accuracy dict": accuracy_act_bp_fgsm,
    "critical epsilon": criteps_bp_fgsm,
    "critical distance": critdist_bp_fgsm,
}

for pkl_name, payload in (
    ("khmodel_attack_statistics_FGSM.pkl", khmodel_adversarial_data_FGSM),
    ("bpmodel_attack_statistics_FGSM.pkl", bpmodel_adversarial_data_FGSM),
):
    with open(model_path / pkl_name, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [19]:
# "PGD" sweep on the correctly classified test images, backprop model first.
criteps_bp_pgd, critdist_bp_pgd, correct_bp_pgd, images_bp_pgd, accuracy_dict_bp_pgd, accuracy_act_bp_pgd = crit_eps(
    criterium="correct",
    model=bpmodel,
    attack="PGD",
    print_accuracy=True,
)
criteps_ll_pgd, critdist_ll_pgd, correct_ll_pgd, images_ll_pgd, accuracy_dict_ll_pgd, accuracy_act_ll_pgd = crit_eps(
    criterium="correct",
    model=khmodel,
    attack="PGD",
    print_accuracy=True,
)

  0%|          | 0/10001 [00:00<?, ?it/s]

0.00% is still correctly classified
0.0% have been correctly classified at every step


  0%|          | 0/10001 [00:00<?, ?it/s]

0.57% is still correctly classified
0.4961832061068702% have been correctly classified at every step


In [23]:
# Bundle the PGD sweep results per model and pickle them next to the checkpoints.
khmodel_adversarial_data_PGD = {
    "accuracy dict": accuracy_dict_ll_pgd,
    "actual accuracy dict": accuracy_act_ll_pgd,
    "critical epsilon": criteps_ll_pgd,
    "critical distance": critdist_ll_pgd,
}
bpmodel_adversarial_data_PGD = {
    "accuracy dict": accuracy_dict_bp_pgd,
    "actual accuracy dict": accuracy_act_bp_pgd,
    "critical epsilon": criteps_bp_pgd,
    "critical distance": critdist_bp_pgd,
}

for pkl_name, payload in (
    ("khmodel_attack_statistics_PGD.pkl", khmodel_adversarial_data_PGD),
    ("bpmodel_attack_statistics_PGD.pkl", bpmodel_adversarial_data_PGD),
):
    with open(model_path / pkl_name, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

1.0