In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os

# Repository root = grandparent of the current working directory
# (abspath('..') is the parent of cwd; dirname() goes one level further up).
# It is appended to sys.path so the project-local `object_detection` imports
# in this cell can be resolved.
PROJECT_ROOT = os.path.dirname(os.path.abspath(os.pardir))
sys.path.append(PROJECT_ROOT)

import torch
import torchvision
from torchvision.transforms import transforms 
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from PIL import Image
from datetime import date


from object_detection.mnist_model import MNIST
from object_detection import mnist_inference
from object_detection import mnist_evaluation
from object_detection.fgsm_attack import fgsm_attack
from object_detection.transform import Invertor


In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): CONFIG["batch_size"] is not referenced in this notebook view; the
# attack DataLoader uses batch_size=1 because data_fgsm_attack compares labels via .item().
CONFIG = dict(batch_size=200)

In [None]:
# Invert the PIL image first, then convert it to a tensor. A single flat Compose
# is sufficient; the previous nested Compose([ToTensor()]) added nothing.
transform = transforms.Compose([Invertor(), transforms.ToTensor()])

# Load the MNIST test split (downloaded into ./mnist_content if missing).
test_data = torchvision.datasets.MNIST(
    "mnist_content", train=False, transform=transform, download=True
)

# batch_size must stay 1: data_fgsm_attack compares prediction and label with
# .item(), which only works for single-sample batches.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Load the trained model
model = MNIST()
model = model.to(DEVICE)

MODEL_FPATH = os.path.join(PROJECT_ROOT, "checkpoints", "mnist_checkpoints", "best.pth")
# map_location remaps stored tensors onto DEVICE. Without it, torch.load restores
# them to the device they were saved from, which fails for a GPU-saved checkpoint
# on a CPU-only machine (where DEVICE falls back to "cpu").
model.load_state_dict(torch.load(MODEL_FPATH, map_location=DEVICE)["model_state"])

# Initialize the inference wrapper
infer = mnist_inference.Inference(model, device=DEVICE)

# Switch the model to evaluation mode (per the original note, this matters for
# its dropout layers).
model.eval()


In [None]:
def data_fgsm_attack(model, device, test_loader, epsilon):
    """Build FGSM-perturbed copies of the test samples the model classifies correctly.

    For every correctly classified sample, the gradient of the NLL loss with
    respect to the input image is computed and passed, together with
    ``epsilon``, to ``fgsm_attack``. The perturbed result is reshaped to
    28x28 and wrapped in a ``PIL.Image``. Misclassified samples are skipped.

    Args:
        model: classifier called on flattened ``(-1, 28 * 28)`` inputs. Its
            output is fed to ``F.nll_loss``, so it is presumably expected to
            return log-probabilities (TODO: confirm).
        device: torch device the samples are moved to; should match the model's.
        test_loader: iterable of ``(data, target)`` batches. Each batch must
            contain exactly ONE sample, because predictions and targets are
            compared via ``.item()``.
        epsilon: perturbation magnitude forwarded to ``fgsm_attack``.

    Returns:
        list of ``[PIL.Image.Image, int]`` pairs (perturbed image, true label).
    """
    output_data = []
    for data, target in tqdm(test_loader):
        data, target = data.to(device), target.to(device)

        # FGSM needs the gradient with respect to the input image itself.
        data.requires_grad = True

        output = model(data.reshape(-1, 28 * 28))
        # Index of the highest score = predicted class.
        init_pred = output.max(1, keepdim=True)[1]

        # Only attack samples the model gets right; already-wrong ones are skipped.
        if init_pred.item() != target.item():
            continue

        loss = F.nll_loss(output, target)
        model.zero_grad()
        loss.backward()
        data_grad = data.grad

        perturbed = fgsm_attack(
            data.reshape(-1, 28 * 28), epsilon, data_grad.reshape(-1, 28 * 28)
        )

        # .cpu() is required: Tensor.numpy() raises for tensors on a CUDA device.
        perturbed_array = perturbed.detach().cpu().numpy().reshape(28, 28)
        output_data.append([Image.fromarray(perturbed_array), target.item()])

    return output_data


In [None]:
_SAVE = True
epsilon = [0.05, 0.1]

# Resolve the output folder and the run date once, so every file from this run
# goes to the same place with the same date stamp (even if the run crosses midnight).
RESULTS_DIR = os.path.join(PROJECT_ROOT, 'results', 'mnist_fgsm_attacks')
RUN_DATE = date.today()
if _SAVE:
    # DataFrame.to_csv does not create missing parent directories.
    os.makedirs(RESULTS_DIR, exist_ok=True)

for eps in epsilon:
    # Getting data with fgsm attack
    print('\n Creating perturbed data ')
    perturbed_data = data_fgsm_attack(model, DEVICE, test_dataloader, eps)

    mnist_evaluator = mnist_evaluation.MnistEvaluator(infer, perturbed_data)
    # Getting predictions
    print('Getting predictions')
    predictions = mnist_evaluator.evaluate()
    # Getting metrics
    metrics = mnist_evaluator.classification_report()

    print(f'Classification report with fgsm attack eps={eps}')
    print(metrics)

    if _SAVE:
        predictions.to_csv(
            os.path.join(RESULTS_DIR, f'pred_eps_{eps}_{RUN_DATE}.csv'),
            index_label='id',
        )
        metrics.to_csv(
            os.path.join(RESULTS_DIR, f'classification_report_eps_{eps}_{RUN_DATE}.csv'),
            index_label='label',
        )

