In [1]:
import os

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch import nn
from torch.optim import Adam

from pathlib import Path
import LocalLearning_copy as LocalLearning
from tqdm.notebook import tqdm
import numpy as np
from matplotlib import pyplot as plt
from sklearn.linear_model import LogisticRegression

ModuleNotFoundError: No module named 'sklearn'

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Hyper parameters (batch size is shared by the train and test loaders).
BATCH_SIZE, NUMBER_OF_EPOCHS = 1_000, 1_000
LEARNING_RATE = 0.0001

# Loss function used to obtain input gradients for the adversarial attack.
ce_loss = nn.CrossEntropyLoss()

In [None]:
# Build the train and test splits with identical settings (only `train` differs).
# The train split is constructed first, as before.
cifar10Train, cifar10Test = (
    LocalLearning.LpUnitCIFAR10(
        root="../data/CIFAR10",
        train=is_train_split,
        transform=ToTensor(),
        p=3.0,
    )
    for is_train_split in (True, False)
)

# Wrap both splits in device-aware loaders; the test loader is created first.
# NOTE(review): shuffle=True on the test loader means any later "first N
# samples" subset differs between runs — confirm this is intended.
TestLoader, TrainLoader = (
    LocalLearning.DeviceDataLoader(
        split,
        device=device,
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True,
    )
    for split in (cifar10Test, cifar10Train)
)

In [None]:
def _load_trained_model(model_dir, model_cls, idx):
    """Load the ``idx``-th checkpoint found in ``model_dir`` as a ``model_cls`` model.

    Only regular files directly inside ``model_dir`` are considered. They are
    sorted by path so that a given ``idx`` always refers to the same checkpoint
    (plain directory listings come back in arbitrary order).

    Parameters
    ----------
    model_dir : str or Path
        Directory that contains the checkpoint files.
    model_cls : callable
        Model constructor; called with the checkpoint's ``"fkhl3-state"`` entry.
    idx : int
        Position of the checkpoint in the sorted file list.

    Returns
    -------
    The constructed model, in eval mode, with the stored weights, on ``device``.

    Raises
    ------
    IndexError
        If ``idx`` is outside the range of available checkpoint files.
    """
    model_dir = Path(model_dir)
    checkpoint_files = sorted(p for p in model_dir.iterdir() if p.is_file())
    checkpoint_path = checkpoint_files[idx]

    with torch.no_grad():
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        # map_location lets checkpoints saved on a GPU load on a CPU-only host.
        trained_state = torch.load(checkpoint_path, map_location=device)
        model = model_cls(trained_state["fkhl3-state"])
        model.eval()
        model.load_state_dict(trained_state["model_state_dict"])
        model.to(device)

    return model


def load_trained_model_bp(idx):
    """Return the ``idx``-th model of the "bp" ensemble (``LocalLearning.KHModel_bp``)."""
    return _load_trained_model(
        Path("../data/models/KHModelCIFAR10_ensemble/bp"),
        LocalLearning.KHModel_bp,
        idx,
    )


def load_trained_model_ll(idx):
    """Return the ``idx``-th model of the "ll" ensemble (``LocalLearning.KHModel``)."""
    return _load_trained_model(
        Path("../data/models/KHModelCIFAR10_ensemble/ll"),
        LocalLearning.KHModel,
        idx,
    )

In [None]:
"""
"fkhl3-path": str(llmodels_path / model_file),
"fkhl3-state": ll_trained_state,
"model_state_dict": khmodel.state_dict(),
"loss_history": loss_history,
"accuracy_history": accuracy_history
""" 
print()

In [None]:
def acc_total(
    test: DataLoader,
    model: nn.Module, 
    thres,
    crit
    ):
    """Select samples by a criterion and report accuracy on the selection.

    For every batch the model's top score and predicted class are computed,
    then a subset of the samples is selected according to ``crit``:

    * ``"correct_thres"`` – prediction is correct AND top score >= ``thres``
    * ``"thres"``         – top score >= ``thres``
    * anything else       – prediction is correct

    Parameters
    ----------
    test : iterable of ``(features, labels)`` batches (any batch size).
    model : callable module returning one score per class for each sample,
        i.e. a 2-D ``(batch, classes)`` output.
        NOTE(review): the threshold is compared against the raw model output;
        presumably the model emits softmax probabilities — confirm.
    thres : score threshold; ignored unless ``crit`` uses it.
    crit : selection criterion, see above.

    Returns
    -------
    (correct, total) : ``correct`` is the fraction of selected samples that
        are classified correctly, ``total`` the number of selected samples.
        When nothing is selected the result is ``(nan, 0)``.
    """
    model.eval()
    n_selected = 0
    n_correct = 0

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        for features, labels in test:
            preds = model(features)
            pred = torch.argmax(preds, dim=-1)
            # Score of the predicted class; independent of the batch size.
            top_score = preds.max(dim=-1).values
            is_correct = pred == labels

            if crit == "correct_thres":
                keep = (top_score >= thres) & is_correct
            elif crit == "thres":
                keep = top_score >= thres
            else:
                keep = is_correct

            n_selected += int(keep.sum())
            n_correct += int((keep & is_correct).sum())

    if n_selected == 0:
        # Nothing met the criterion (or the loader was empty).
        return float("nan"), 0

    return n_correct / n_selected, n_selected

In [None]:
def acc_total_n_models(n,modeltype,dataloader,thres,crit=None):
    """Run ``acc_total`` on the first ``n`` models of one ensemble.

    Parameters
    ----------
    n : number of models to evaluate (ensemble indices ``0 .. n-1``).
    modeltype : ``"bp"`` or ``"ll"``; selects which ensemble loader is used.
    dataloader : iterable of ``(features, labels)`` batches passed to ``acc_total``.
    thres : score threshold forwarded to ``acc_total``.
    crit : selection criterion forwarded to ``acc_total``.

    Returns
    -------
    (list_acc, list_n) : per-model accuracy and per-model number of selected
        samples, in model-index order.

    Raises
    ------
    ValueError
        If ``modeltype`` is neither ``"bp"`` nor ``"ll"`` (previously this
        silently produced two empty lists).
    """
    if modeltype == "bp":
        load_model = load_trained_model_bp
    elif modeltype == "ll":
        load_model = load_trained_model_ll
    else:
        raise ValueError(
            f"Unknown modeltype {modeltype!r}; expected 'bp' or 'll'"
        )

    list_acc = []
    list_n = []
    for i in tqdm(range(n)):
        model = load_model(i)
        correct, total = acc_total(dataloader, model, thres, crit)
        list_acc.append(correct)
        list_n.append(total)

    return list_acc, list_n

In [None]:
# Ensemble evaluation settings: number of models, score threshold, data source.
n, threshold = 100, 0.8
dataloader = TestLoader

In [None]:
def print_info(n, model, dataloader, threshold, crit):
    """Evaluate ``n`` models of one ensemble and print a short summary.

    Parameters
    ----------
    n : number of ensemble members to evaluate.
    model : ensemble identifier forwarded as ``modeltype`` (``"bp"`` or ``"ll"``).
    dataloader : iterable of ``(features, labels)`` batches to evaluate on.
    threshold : score threshold used by the threshold-based criteria.
    crit : selection criterion (``"correct_thres"``, ``"thres"`` or anything
        else for "correct only").
    """
    # Per-model results come from the ensemble helper; acc_total itself only
    # handles a single model and takes a different argument list.
    acc, tot = acc_total_n_models(n, model, dataloader, threshold, crit)

    if crit == "correct_thres":
        print(f"Criterium = Correct and above {threshold}")
    elif crit == "thres":
        print(f"Criterium = Above {threshold}")
    else:
        print("Criterium = Correct")

    print(f"Mean correct for {n} {model} models with Softmax >= {threshold} on the given data : {np.mean(acc)*100:.2f} %")
    print(f"Mean number of pictures = {np.mean(tot)}")

In [None]:
# Sweep over every (model family, criterion) pair; the loop body does nothing,
# so the only effect is that `model` and `crit` end up bound to the last pair.
for model in ("bp", "ll"):
    for crit in ("correct_thres", "correct", "thres"):
        ...

Choosing "correct" as the only criterion and the TestLoader as the data set.

In [None]:
def data_critirium(
    dataloader,
    model, 
    crit,
    thres = None
    ):
    """Collect the samples of ``dataloader`` that satisfy a selection criterion.

    Criteria (``crit``):

    * ``"correct"``       – the model's prediction equals the label
    * ``"correct_thres"`` – prediction is correct AND top score >= ``thres``
    * anything else       – top score >= ``thres``

    Parameters
    ----------
    dataloader : iterable of ``(features, labels)`` batches (any batch size,
        any per-sample feature shape).
    model : callable module returning a 2-D ``(batch, classes)`` score tensor.
        NOTE(review): the threshold is compared against the raw model output;
        presumably the model emits softmax probabilities — confirm.
    crit : selection criterion, see above.
    thres : score threshold; required for every criterion except ``"correct"``.

    Returns
    -------
    (data, lab_data) : the selected features and their labels, concatenated
        along dimension 0 in loader order. Labels keep the dtype delivered by
        the loader. If the loader yields no batches at all, empty tensors of
        shape ``(0, 32, 32, 3)`` and ``(0,)`` on ``device`` are returned.

    Raises
    ------
    ValueError
        If the criterion needs a threshold but ``thres`` is ``None``.
    """
    if crit != "correct" and thres is None:
        raise ValueError(f"criterion {crit!r} requires a threshold (thres)")

    model.eval()
    kept_features = []
    kept_labels = []

    # Selection only needs forward passes; skip building autograd graphs.
    with torch.no_grad():
        for features, labels in dataloader:
            preds = model(features)
            pred = torch.argmax(preds, dim=-1)

            if crit == "correct":
                filtr_idx = pred == labels
            else:
                # Score of the predicted class; independent of the batch size.
                filtr_idx = preds.max(dim=-1).values >= thres
                if crit == "correct_thres":
                    filtr_idx = filtr_idx & (pred == labels)

            kept_features.append(features[filtr_idx])
            kept_labels.append(labels[filtr_idx])

    if not kept_features:
        # No batches at all: keep the historical empty-result shapes.
        empty_data = torch.zeros((0, 32, 32, 3)).to(device)
        empty_labels = torch.zeros((0), dtype=torch.long).to(device)
        return empty_data, empty_labels

    # Concatenate once at the end instead of re-allocating on every batch.
    data = torch.cat(kept_features, dim=0)
    lab_data = torch.cat(kept_labels, dim=0)

    return data, lab_data

In [None]:
# Local-learning model #7: keep the test samples it classifies correctly
# with a top score of at least 0.8.
model7 = load_trained_model_ll(7)
data7, lab7 = data_critirium(
    dataloader=TestLoader,
    model=model7,
    crit="correct_thres",
    thres=0.8,
)

# Backprop model #8: same selection rule.
model8 = load_trained_model_bp(8)
data8, lab8 = data_critirium(
    dataloader=TestLoader,
    model=model8,
    crit="correct_thres",
    thres=0.8,
)

In [None]:
def test_attack(
    feats,
    labs, 
    model, 
    attack,
    loss_fn, 
    optimizer, 
    eps,
    std=None,
    ):
    
    freq_correct = 0
    total = 0 

    for i in range(3):
        features = feats[i*1000:(i+1)*1000]
        labels = (labs[i*1000:(i+1)*1000])
        labels = labels.type(torch.LongTensor).to(device)
        features.requires_grad = True
        preds = model(features)
        loss = loss_fn(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        perturbed = attack(features,eps)

        preds_perturbed = torch.argmax(model(perturbed), dim=-1)
        freq_correct += (torch.abs(preds_perturbed - labels) == 0).sum()
        total += len(preds)

    correct = (freq_correct/total).item()
    
    return correct


def FGSM(features, epsilon):
    """Fast-gradient-sign perturbation of ``features``.

    Moves every input element by ``epsilon`` in the direction of the sign of
    its stored gradient (``features.grad``) and clamps the result to [0, 1].
    ``features.grad`` must already be populated by a backward pass.
    """
    signed_step = features.grad.data.sign() * epsilon
    return torch.clamp(features + signed_step, 0, 1)

In [None]:
# The optimizers are only handed to test_attack so it can clear parameter
# gradients between batches.
AdamOpt_ll = Adam(model7.parameters(), lr=LEARNING_RATE)
AdamOpt_bp = Adam(model8.parameters(), lr=LEARNING_RATE)

# Attack strengths: 0, 1e-4, 2e-4, ..., 99e-4.
epslist = [step * 0.0001 for step in range(100)]

# Accuracy (in percent) under attack for the local-learning model ...
listll = [
    100 * test_attack(data7, lab7, model7, FGSM, ce_loss, AdamOpt_ll, eps)
    for eps in epslist
]

print()

# ... and for the backprop model.
listbp = [
    100 * test_attack(data8, lab8, model8, FGSM, ce_loss, AdamOpt_bp, eps)
    for eps in epslist
]

In [None]:
# Accuracy under attack versus attack strength, one labelled figure per model
# so the two curves can be told apart when the notebook is skimmed.
fig, ax = plt.subplots()
ax.plot(epslist, listll)
ax.set(
    xlabel="FGSM epsilon",
    ylabel="Accuracy on perturbed samples [%]",
    title="FGSM attack: local-learning (ll) model",
)
plt.show()

fig, ax = plt.subplots()
ax.plot(epslist, listbp)
ax.set(
    xlabel="FGSM epsilon",
    ylabel="Accuracy on perturbed samples [%]",
    title="FGSM attack: backprop (bp) model",
)
plt.show()