# Setup environment

In [None]:
# Essentials
import os
import numpy as np
import importlib
import json
from datetime import datetime

os.environ["GIT_PYTHON_GIT_EXECUTABLE"] = "/usr/bin/git"

import git
git.refresh("/usr/bin/git")

from foolbox import PyTorchModel

# PyTorch
import torch
from torch.utils.data import DataLoader

# Utils
import utils
importlib.reload(utils)

from utils import get_files, save_image, make_dirs, get_model, select_gpu, get_data, get_class_weigths
from utils import CustomTransforms, My_data, FocalLoss

# OnePixelAttack
import OnePixelAttack
importlib.reload(OnePixelAttack)

# TriangleAttack
import TriangleAttack
importlib.reload(TriangleAttack)

import ProjectedGradientDescent
importlib.reload(ProjectedGradientDescent)

## Setup Cuda

In [None]:
# Configure the PyTorch CUDA allocator through its environment variable:
# a higher max split size (512 MB) to avoid memory problems.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:512"})

In [None]:
# Use the GPU index returned by select_gpu() when CUDA is available,
# otherwise fall back to the CPU.
device = torch.device(f"cuda:{select_gpu()}" if torch.cuda.is_available() else "cpu")

print(device)

if torch.cuda.is_available():
    # Release cached allocator blocks before reporting memory usage.
    torch.cuda.empty_cache()
    # Report the GPU that was actually selected above; device=None would
    # describe the current default CUDA device, which can be a different card.
    print(torch.cuda.memory_summary(device=device, abbreviated=False))
    # Seed the CUDA random number generators (current device, then all devices).
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    # NOTE: benchmark=True together with deterministic=False lets cuDNN pick
    # algorithms per input shape, so results may vary between runs even with
    # the seeds set above.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

## Setup Transforms

In [None]:
# Create the project's transform provider once, then fetch the two
# transforms used in this notebook by name.
custom_transforms = CustomTransforms()
resize_transform, test_transform = (
    custom_transforms.get_transform(key) for key in ("resize_tensor", "test")
)

## Get the data

In [None]:
# Load the train and test files that are used for the model from the
# "BreaKHis_v1" folder inside the current working directory.
dataset_root = os.path.join(os.getcwd(), "BreaKHis_v1")
train_dict, test_dict = get_data(dataset_root)

In [None]:
# Glob pattern for the PNG test images stored below ./dataset/test/original
# (presumably the unattacked originals, given the sibling attack folders
# written later). How the `**` segments expand depends on get_files from
# utils — TODO confirm.
test_files = "./dataset/test/original/**/**/*.png"

In [None]:
# Resolve the test image files, wrap them in the project dataset with the
# test transform, and iterate them with a DataLoader using default settings.
test_image_files = get_files(test_files)
org_norm = My_data(test_image_files, transforms=test_transform)
org_dataloader = DataLoader(dataset=org_norm)

## Setup network

In [None]:
# Name of the network under evaluation; also used for the output folder
# names of the adversarial images written further down.
model_name = "swin"
# Obtain the network from the project helper, which receives the target
# device and the model name.
model = get_model(device, model_name)
# Switch the network to evaluation mode before inference and attacks.
model.eval()

## Setup performance metrics

In [None]:
# Labels used to compose the names of the result files written below.
filepaths = dict(
    Network=["SWIN", "RESNET", "RETRAINED SWIN", "RETRAINED RESNET"],
    Attack=["Regular", "OnePixel", "Triangle", "PGD"],
)
# Index into filepaths["Network"] selecting the label of the network
# that is being evaluated in this run.
network = 0

# Evaluate performance

## Evaluate model performance

In [None]:
# One (initially empty) result slot per test image, keyed by the image's
# dataset index rendered as a string.
plot_data = {str(index): [] for index in range(len(org_norm))}

In [None]:
# Run the unmodified test images through the model and store, per image:
# [file path, indices of the true label(s), predicted class, raw model outputs].
with torch.no_grad():
    for i, (image, label) in enumerate(org_dataloader):
        path = org_norm.__getpath__(i)

        # Positions in the label vector of the first batch element that hold a 1.
        label_values = label[0].tolist()
        true_label = [index for index, value in enumerate(label_values) if value == 1]

        confs = model(image.to(device))
        pred_label = confs.argmax(dim=1)

        # Store the outputs of the first batch element as a plain list so
        # the record can be serialised to JSON later.
        scores = confs.cpu().numpy()[0].tolist()
        plot_data[str(i)] = [path, true_label, pred_label.item(), scores]

In [None]:
# Result file name: "<network label> - <attack label> - <timestamp>".
network_label = filepaths["Network"][network]
attack_label = filepaths["Attack"][0]
timestamp = datetime.now().strftime("%Y-%m-%d %H%M")
name = f"{network_label} - {attack_label} - {timestamp}"

# Write the collected results as JSON text into "<name>.txt".
with open(f"{name}.txt", "w") as output:
    serialized = json.dumps(plot_data)
    output.write(serialized)

## Perform and evaluate Advanced OnePixel Attack

In [None]:
# Run the one-pixel attack over the test DataLoader with gradient tracking
# disabled. The meaning of the leading positional argument (1) is defined by
# OnePixelAttack.attack — TODO confirm against that module.
with torch.no_grad():
    op_advs = OnePixelAttack.attack(
        1,
        model,
        device,
        org_dataloader,
        pixel_count=1,
        maxiter=50,
        popsize=15,
    )

In [None]:
# Reset the result container for the one-pixel evaluation: one empty slot
# per test image, keyed by the stringified dataset index.
plot_data = {str(position): [] for position in range(len(org_norm))}

In [None]:
# Create the output folder for this model's one-pixel adversarial images.
# NOTE(review): make_dirs receives "test/..." while images are saved below
# "./dataset/test/..." — presumably make_dirs adds the "dataset" prefix; verify.
make_dirs(f"test/one_pixel_attack/{model_name}")
adv_dir = f"./dataset/test/one_pixel_attack/{model_name}/"

# Classify each adversarial image, save it, and store per image:
# [file path, indices of the true label(s), predicted class, raw model outputs].
with torch.no_grad():
    for i, (image, label) in enumerate(org_dataloader):
        path = org_norm.__getpath__(i)
        label_values = label[0].tolist()
        true_label = [index for index, value in enumerate(label_values) if value == 1]

        # The i-th adversarial is fed to the model as-is; its first element
        # is what gets written to disk (presumably a batch of one — confirm).
        adversarial = op_advs[i]
        confs = model(adversarial.to(device))
        pred_label = confs.argmax(dim=1)
        save_image(adversarial[0], adv_dir + path)

        scores = confs.cpu().numpy()[0].tolist()
        plot_data[str(i)] = [path, true_label, pred_label.item(), scores]

In [None]:
# Result file name: network label, attack label and a timestamp joined by " - ".
name = " - ".join(
    (
        filepaths["Network"][network],
        filepaths["Attack"][1],
        datetime.now().strftime("%Y-%m-%d %H%M"),
    )
)

# Write the one-pixel results as JSON text into "<name>.txt".
with open(name + ".txt", "w") as output:
    serialized = json.dumps(plot_data)
    output.write(serialized)

## Perform and evaluate Triangle attack   

In [None]:
# Wrap the network (in evaluation mode) as a foolbox model whose inputs are
# bounded to the range (0, 1).
pt_model = PyTorchModel(model.eval(), bounds=(0, 1), device=device)
# Run on the A10 GPU as it has the most memory; it processes about
# 100 images in 20 minutes for the SWIN model.
print("Attack !")

with torch.no_grad():
    ta_model = TriangleAttack.TA(pt_model, input_device=device)
    # The attack returns four values; only my_advs is used in the cells below.
    attack_output = ta_model.attack(org_dataloader)
    my_advs, q_list, my_intermediates, max_length = attack_output
    print("TA Attack Done")

In [None]:
# One (initially empty) result slot per adversarial image produced by the
# triangle attack, keyed by the stringified index.
plot_data = {
    str(i): [] for i in range(len(my_advs))
}

# Create the output folder for this model's triangle-attack adversarial images.
make_dirs("test/triangle_attack/"+model_name)

# Inference only: run under torch.no_grad() like the other evaluation loops so
# no autograd graph is built for every image, and move each input to the
# model's device explicitly instead of relying on where the attack left it.
with torch.no_grad():
    for i, adv in enumerate(my_advs):
        path = org_norm.__getpath__(i)
        # The label is read straight from the dataset item (second element),
        # so unlike the DataLoader loops there is no batch dimension to strip.
        label_values = org_norm[i][1].tolist()
        true_label = [index for index, value in enumerate(label_values) if value == 1]

        # The model is called on a batch, so add a leading batch dimension.
        confs = model(adv.unsqueeze(0).to(device))
        pred_label = torch.argmax(confs, dim=1)

        save_image(adv, "./dataset/test/triangle_attack/"+model_name+"/" + path)
        # Record: [file path, true label indices, predicted class, raw outputs].
        plot_data[str(i)] = [path, true_label, pred_label.item(), confs.cpu().numpy()[0].tolist()]

In [None]:
# Result file name: "<network label> - <attack label> - <timestamp>".
name = f'{filepaths["Network"][network]} - {filepaths["Attack"][2]} - {datetime.now().strftime("%Y-%m-%d %H%M")}'

# Write the triangle-attack results as JSON text into "<name>.txt".
with open(f"{name}.txt", "w") as output:
    serialized = json.dumps(plot_data)
    output.write(serialized)

## Perform and evaluate Projected Gradient Descent attack

In [None]:
# Reset the result container for the PGD evaluation: one empty slot per
# test image, keyed by the stringified dataset index.
plot_data = {str(idx): [] for idx in range(len(org_norm))}

# Do NOT wrap this call in torch.no_grad(): the attack is run with gradient
# tracking enabled (PGD is a gradient-based attack).
class_weights = get_class_weigths(train_dict).to(device)
pgd_advs = ProjectedGradientDescent.pgd_attack(org_dataloader, model, device, class_weights)

In [None]:
# Create the output folder for this model's PGD adversarial images.
# NOTE(review): make_dirs receives "test/..." while images are saved below
# "./dataset/test/..." — presumably make_dirs adds the "dataset" prefix; verify.
make_dirs(f"test/pgd_attack/{model_name}")
pgd_dir = f"./dataset/test/pgd_attack/{model_name}/"

# Classify each PGD adversarial image, save it, and store per image:
# [file path, indices of the true label(s), predicted class, raw model outputs].
with torch.no_grad():
    for i, (image, label) in enumerate(org_dataloader):
        path = org_norm.__getpath__(i)
        label_values = label[0].tolist()
        true_label = [index for index, value in enumerate(label_values) if value == 1]

        # The i-th adversarial is fed to the model as-is; its first element
        # is what gets written to disk (presumably a batch of one — confirm).
        adversarial = pgd_advs[i]
        confs = model(adversarial.to(device))
        pred_label = confs.argmax(dim=1)
        save_image(adversarial[0], pgd_dir + path)

        scores = confs.cpu().numpy()[0].tolist()
        plot_data[str(i)] = [path, true_label, pred_label.item(), scores]

In [None]:
# Result file name: "<network label> - <attack label> - <timestamp>".
name_parts = [
    filepaths["Network"][network],
    filepaths["Attack"][3],
    datetime.now().strftime("%Y-%m-%d %H%M"),
]
name = "{} - {} - {}".format(*name_parts)

# Write the PGD results as JSON text into "<name>.txt".
with open(name + ".txt", "w") as output:
    serialized = json.dumps(plot_data)
    output.write(serialized)