In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os

# os.pardir resolves to the parent of the kernel's working directory and
# dirname() climbs one more level, so PROJECT_ROOT is the cwd's grandparent.
# NOTE(review): assumes the notebook is started two levels below the project
# root — confirm against the repository layout.
PROJECT_ROOT = os.path.dirname(os.path.abspath(os.pardir))

# Make the project package importable. The guard keeps re-runs of this cell
# from appending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
import torchvision
from torchvision.transforms import transforms 
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from PIL import Image
from datetime import date
from ipywidgets import interact, IntSlider
import matplotlib.pyplot as plt



from object_detection.mnist_model import MNIST
from object_detection import mnist_inference
from object_detection import mnist_evaluation
from object_detection.transform import Invertor
from object_detection.fgsm_attack import fgsm_attack


In [None]:
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# NOTE(review): CONFIG["batch_size"] is not referenced by any of the cells shown here.
CONFIG = dict(batch_size=200)
# Output folder for the per-epsilon prediction and classification-report CSVs.
SAVE_DPATH = os.path.join(PROJECT_ROOT, "results", "mnist_fgsm_attacks")
os.makedirs(SAVE_DPATH, exist_ok=True)


In [None]:
# Preprocessing: the project's Invertor (presumably inverts pixel intensities —
# confirm in object_detection.transform), then conversion to a float tensor.
transform = transforms.Compose([Invertor(), transforms.ToTensor()])

# MNIST test split, stored under ./mnist_content (relative to the working
# directory) and downloaded there if missing.
test_data = torchvision.datasets.MNIST(
    "mnist_content",
    train=False,
    transform=transform,
    download=True,
)

# One sample per batch, in dataset order (no shuffling).
test_dataloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=1, shuffle=False)

In [None]:
# Load the trained model and move it to DEVICE.
model = MNIST()
model = model.to(DEVICE)

MODEL_FPATH = os.path.join(PROJECT_ROOT, "checkpoints", "mnist_checkpoints", "best.pth")
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved from
# a GPU can still be restored when DEVICE falls back to the CPU.
# torch.load unpickles the file: only point it at trusted checkpoints.
model.load_state_dict(torch.load(MODEL_FPATH, map_location=DEVICE)["model_state"])

# Wrap the model in the project's inference helper.
infer = mnist_inference.Inference(model, device=DEVICE)

# Switch the model to evaluation mode. The original note says this matters for
# the model's dropout layers (architecture not visible here — confirm).
model.eval()


In [None]:
def data_fgsm_attack(model, device, test_loader, epsilon):
    """Build FGSM-perturbed copies of the test images the model classifies correctly.

    For every sample whose clean prediction matches its label, the gradient of
    the NLL loss with respect to the input is computed and handed to
    ``fgsm_attack`` together with ``epsilon``. Misclassified samples are skipped,
    so both returned lists cover exactly the same samples, in loader order.

    Any loader batch size is supported; with ``batch_size=1`` the output is the
    same as attacking one image at a time.

    Args:
        model: classifier that takes flattened ``(batch, 28 * 28)`` inputs.
            ``F.nll_loss`` is applied to its output, so it is assumed to return
            log-probabilities (e.g. a ``log_softmax`` head) — TODO confirm
            against ``MNIST``. Keep it in eval mode so that, for batch sizes
            > 1, samples in a batch do not influence each other.
        device: torch device the images and labels are moved to.
        test_loader: iterable of ``(images, labels)`` batches; each image must
            contain 28 * 28 values.
        epsilon: perturbation magnitude forwarded to ``fgsm_attack``.

    Returns:
        ``(output_data, origin_data)``: two lists of ``[PIL.Image.Image, int]``
        pairs holding the perturbed and the original 28x28 images, each with its
        true label.
    """
    output_data = []
    origin_data = []
    for data, target in tqdm(test_loader):
        data, target = data.to(device), target.to(device)

        # Flattened leaf tensor that tracks gradients: FGSM needs d(loss)/d(input).
        flat = data.reshape(-1, 28 * 28).detach().requires_grad_(True)

        output = model(flat)
        # Index of the highest-scoring class for each sample.
        init_pred = output.max(1)[1]
        correct = init_pred == target

        # Samples the model already gets wrong are not attacked.
        if not bool(correct.any()):
            continue

        # Summing (instead of averaging) keeps each sample's input gradient
        # independent of the batch size; for a single sample both are equal.
        loss = F.nll_loss(output[correct], target[correct], reduction="sum")

        # Clear stale parameter gradients, then backpropagate to the input.
        model.zero_grad()
        loss.backward()

        perturbed = fgsm_attack(flat, epsilon, flat.grad)

        # Back to per-sample 28x28 numpy arrays, then to PIL images.
        perturbed_imgs = perturbed.detach().cpu().numpy().reshape(-1, 28, 28)
        original_imgs = flat.detach().cpu().numpy().reshape(-1, 28, 28)
        labels = target.cpu().tolist()
        for i in correct.nonzero(as_tuple=True)[0].tolist():
            output_data.append([Image.fromarray(perturbed_imgs[i]), labels[i]])
            # Clean counterpart, kept for side-by-side visualisation.
            origin_data.append([Image.fromarray(original_imgs[i]), labels[i]])

    return output_data, origin_data


Рассмотрим влияние FGSM-атак на точность модели. Для анализа возьмём несколько значений epsilon.

<center>Epsilon = 0.05</center>

In [None]:
_SAVE_PRED = True
_SAVE_METRICS = True
eps = 0.05

# Build the adversarial test set (plus the matching clean images) for this epsilon.
print('\n Creating perturbed data ')
perturbed_data, origin_data = data_fgsm_attack(model, DEVICE, test_dataloader, eps)

# Evaluate the model on the perturbed images.
mnist_evaluator = mnist_evaluation.MnistEvaluator(infer, perturbed_data)
print('Getting predictions')
predictions = mnist_evaluator.evaluate()
metrics = mnist_evaluator.classification_report()

print(f'Classification report with fgsm attack eps={eps}')
print(metrics)

# Optionally persist the results; file names carry epsilon and today's date.
if _SAVE_PRED:
    predictions.to_csv(
        os.path.join(SAVE_DPATH, f'pred_eps_{eps}_{date.today()}.csv'),
        index_label='id',
    )
if _SAVE_METRICS:
    metrics.to_csv(
        os.path.join(SAVE_DPATH, f'classification_report_eps_{eps}_{date.today()}.csv'),
        index_label='label',
    )




Получим предсказания для оригинальных изображений

In [None]:
# Evaluate the model on the clean images returned alongside the attack.
# data_fgsm_attack keeps only samples the model originally classified correctly,
# so this report is not the model's overall test accuracy.
mnist_evaluator_origin_data = mnist_evaluation.MnistEvaluator(infer, origin_data)
print('Getting predictions for original img')
predictions_without_attack = mnist_evaluator_origin_data.evaluate()

metrics_original_img = mnist_evaluator_origin_data.classification_report()

print('Classification report with original images')
print(metrics_original_img)



In [None]:
# Keep the eps=0.05 results under their own names: the next experiment cell
# reassigns eps / predictions / perturbed_data.
eps_05, predictions_05, perturbed_data_05 = eps, predictions, perturbed_data

Посмотрим на оригинальные изображения и изображения с атакой при epsilon = 0.05

In [None]:
@interact
def show_predictions(index=IntSlider(value=0, min=0, max=len(perturbed_data_05) - 1)):
    """Show a clean test image next to its eps=0.05 FGSM version, with labels.

    ``index`` is a position in the lists produced by ``data_fgsm_attack``; the
    clean and perturbed lists are aligned because both hold the same correctly
    classified samples in the same order.
    """
    # Clean image and the model's prediction on it.
    # NOTE(review): assumes column 1 of the evaluator output is the predicted label.
    or_img, _ = origin_data[index]
    or_img = np.array(or_img)
    pred_label_without_attack = predictions_without_attack.iloc[index, 1]

    # Perturbed image, its true label and the prediction under attack.
    test_img_05, target_05 = perturbed_data_05[index]
    test_img_05 = np.array(test_img_05)
    pred_label_05 = predictions_05.iloc[index, 1]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[10, 10])

    ax[0].imshow(or_img, cmap='gray')
    ax[0].set_title("Original image")
    ax[1].imshow(test_img_05, cmap='gray')
    ax[1].set_title(f"Epsilon : {eps_05}")

    print(f"True Label: {target_05}")
    print(f"Predicted Label without attack: {pred_label_without_attack}")
    print(f"Predicted Label with fgsm attack: {pred_label_05}")
    plt.show()

<center>Epsilon = 0.1</center>

In [None]:
_SAVE_PRED = True
_SAVE_METRICS = True
eps = 0.1

# Build the adversarial test set (plus the matching clean images) for this epsilon.
print('\n Creating perturbed data ')
perturbed_data, origin_data = data_fgsm_attack(model, DEVICE, test_dataloader, eps)

# Evaluate the model on the perturbed images.
mnist_evaluator = mnist_evaluation.MnistEvaluator(infer, perturbed_data)
print('Getting predictions')
predictions = mnist_evaluator.evaluate()
metrics = mnist_evaluator.classification_report()

print(f'Classification report with fgsm attack eps={eps}')
print(metrics)

# Optionally persist the results; file names carry epsilon and today's date.
if _SAVE_PRED:
    predictions.to_csv(
        os.path.join(SAVE_DPATH, f'pred_eps_{eps}_{date.today()}.csv'),
        index_label='id',
    )
if _SAVE_METRICS:
    metrics.to_csv(
        os.path.join(SAVE_DPATH, f'classification_report_eps_{eps}_{date.today()}.csv'),
        index_label='label',
    )



Получим предсказания для оригинальных изображений 

In [None]:
# Clean-image predictions for the samples kept by the eps=0.1 run.
# data_fgsm_attack selects origin_data only by the clean prediction (epsilon
# plays no part), so this is expected to match the earlier clean-image report.
mnist_evaluator_origin_data = mnist_evaluation.MnistEvaluator(infer, origin_data)
print('Getting predictions for original img')
predictions_without_attack = mnist_evaluator_origin_data.evaluate()
metrics_original_img = mnist_evaluator_origin_data.classification_report()

print(f'Classification report with original images')
print(metrics_original_img)


In [None]:
# Keep the eps=0.1 results under their own names for the visualisation below.
eps_01, predictions_01, perturbed_data_01 = eps, predictions, perturbed_data

Посмотрим на оригинальные изображения и изображения с атакой при epsilon = 0.1

In [None]:

@interact
def show_predictions(index=IntSlider(value=0, min=0, max=len(perturbed_data_01) - 1)):
    """Show a clean test image next to its eps=0.1 FGSM version, with labels.

    ``index`` is a position in the lists produced by ``data_fgsm_attack``; the
    clean and perturbed lists are aligned because both hold the same correctly
    classified samples in the same order.
    """
    # Clean image and the model's prediction on it.
    # NOTE(review): assumes column 1 of the evaluator output is the predicted label.
    or_img, _ = origin_data[index]
    or_img = np.array(or_img)
    pred_label_without_attack = predictions_without_attack.iloc[index, 1]

    # Perturbed image, its true label and the prediction under attack.
    test_img_01, target_01 = perturbed_data_01[index]
    test_img_01 = np.array(test_img_01)
    pred_label_01 = predictions_01.iloc[index, 1]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[10, 10])

    ax[0].imshow(or_img, cmap='gray')
    ax[0].set_title("Original image")
    ax[1].imshow(test_img_01, cmap='gray')
    ax[1].set_title(f"Epsilon : {eps_01}")

    print(f"True Label: {target_01}")
    print(f"Predicted Label without attack: {pred_label_without_attack}")
    print(f"Predicted Label: {pred_label_01}")
    plt.show()
