In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
import random

# Project root is resolved as the parent of the current working directory,
# so the paths below depend on where the notebook kernel was started.
PROJECT_DPATH = os.path.abspath(os.pardir)
DATA_DPATH = os.path.join(PROJECT_DPATH, "data")

# for pip environment: expose the project root on the import path.
# Guarded so that re-running this cell does not append duplicate entries.
if PROJECT_DPATH not in sys.path:
    sys.path.append(PROJECT_DPATH)

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from PIL import Image
from tqdm import tqdm

from mnist_recognition.inference import Inference
from mnist_recognition.evaluation import Evaluator
from mnist_recognition.fgsm_attack import fgsm_attack
from mnist_recognition.models import MlpModel
from mnist_recognition.transforms import Invertor, AlbuAugmentation, Convertor
from mnist_recognition.utils.fs import get_date_string

# Seed every random number generator used in this notebook so that repeated
# runs produce the same numbers.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
# Seed all GPUs rather than only the current one; silently ignored when CUDA
# is not available.
torch.cuda.manual_seed_all(RANDOM_SEED)
# Deterministic cuDNN kernels are only half of the setup: the cuDNN
# auto-tuner must be disabled too, otherwise it may pick different
# algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# All result files of this run go to <project>/results/<get_date_string()>.
SAVE_DPATH = os.path.join(
    PROJECT_DPATH,
    "results",
    get_date_string(),
)
# exist_ok=True keeps the cell re-runnable when the folder is already there.
os.makedirs(SAVE_DPATH, exist_ok=True)

## Загрузка обученной модели

In [None]:
# Runtime settings. BATCH_SIZE is not referenced by the other cells visible
# in this notebook.
BATCH_SIZE = 64
# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint selection: switch `model_name` to evaluate another checkpoint.
# model_name = "best_valid_with_augmentations.pth"
model_name = "best_valid.pth"
checkpoint_dpath = os.path.join(PROJECT_DPATH, "checkpoints")
model_fpath = os.path.join(checkpoint_dpath, model_name)

# Inference wrapper built from the checkpoint; reused by every Evaluator below.
infer = Inference.from_file(model_fpath, device=DEVICE)

## Оценка на исходной тестовой выборке 

In [None]:
# MNIST test split (downloaded on first use); every image goes through the
# project's Invertor transform.
test_data = torchvision.datasets.MNIST(
    root=DATA_DPATH,
    train=False,
    transform=Invertor(),
    download=True,
)

print(f"Тестовая выборка содержит {len(test_data)} изображений")

In [None]:
# Pair the loaded inference wrapper with the clean test split.
evaluator = Evaluator(infer, test_data)

In [None]:
# Run the evaluation. The result is written out with .to_csv() in the save
# cell below, so it is presumably a pandas DataFrame -- TODO confirm.
predictions = evaluator.evaluate()

In [None]:
# Classification report for the clean test split. Presumably derived from the
# preceding evaluate() call -- TODO confirm against Evaluator.
metrics = evaluator.classification_report()

print(metrics)

In [None]:
# Set to False to skip writing the result files of this section.
_SAVE = True

if _SAVE:
    # File names carry the checkpoint name up to its first dot.
    fpath = os.path.join(
        SAVE_DPATH, f"source_test_evaluation_{model_name.split('.')[0]}.csv"
    )
    metric_fpath = os.path.join(
        SAVE_DPATH, f"source_test_classification_report_{model_name.split('.')[0]}.csv"
    )

    # Predictions first, then the classification report.
    predictions.to_csv(fpath, index_label='id')
    metrics.to_csv(metric_fpath, index_label='label')

## Оценка на тестовой выборке c альбументациями

In [None]:
albu = AlbuAugmentation()

# Build an augmented copy of the test split as a list of [image, label] pairs.
test_data_with_augmentation = []
for img, label in tqdm(test_data, desc="Test Data Processing"):
    # The augmentation is applied to a numpy array; the result is wrapped
    # back into a PIL image before it is stored.
    transformed_img = Image.fromarray(albu(np.array(img)))
    test_data_with_augmentation.append([transformed_img, label])

print(f"Тестовая выборка содержит {len(test_data_with_augmentation)} изображений")

In [None]:
# Same inference wrapper, now evaluated on the augmented test samples.
evaluator = Evaluator(infer, test_data_with_augmentation)

In [None]:
# Run the evaluation on the augmented samples; the result is saved with
# .to_csv() in the save cell below.
predictions = evaluator.evaluate()

In [None]:
# Classification report for the augmented test samples. Presumably derived
# from the preceding evaluate() call -- TODO confirm against Evaluator.
metrics = evaluator.classification_report()

print(metrics)

In [None]:
# Set to False to skip writing the result files of this section.
_SAVE = True

if _SAVE:
    # File names carry the checkpoint name up to its first dot.
    fpath = os.path.join(
        SAVE_DPATH, f"aug_test_evaluation_{model_name.split('.')[0]}.csv"
    )
    metric_fpath = os.path.join(
        SAVE_DPATH, f"aug_test_classification_report_{model_name.split('.')[0]}.csv"
    )

    # Predictions first, then the classification report.
    predictions.to_csv(fpath, index_label='id')
    metrics.to_csv(metric_fpath, index_label='label')

## Оценка на данных с атаками

In [None]:
def data_fgsm_attack(model, device, test_loader, epsilon):
    """Build an FGSM-perturbed copy of every correctly classified sample.

    Each batch is pushed through ``model``. Samples the model already gets
    wrong are skipped; the rest are perturbed with ``fgsm_attack`` using the
    gradient of the NLL loss with respect to the input.

    Batches of any size are supported. The gradient is taken from the summed
    loss, so, provided the model processes samples independently, every sample
    receives the same gradient it would get in a batch of one.

    Args:
        model: network applied to the inputs flattened to ``(-1, 28 * 28)``.
            Its output is passed to ``F.nll_loss``, so it presumably returns
            log-probabilities -- TODO confirm against the model definition.
            The train/eval mode is left untouched; callers will normally want
            to call ``model.eval()`` first.
        device: device the batches are moved to before the forward pass.
        test_loader: iterable yielding ``(data, target)`` tensor pairs in
            which every sample holds 28 * 28 values.
        epsilon: attack strength, forwarded unchanged to ``fgsm_attack``.

    Returns:
        ``(output_data, origin_data)``: two parallel lists of
        ``[PIL.Image.Image, int]`` pairs holding the perturbed image and the
        unmodified input image, each together with its true label.
    """
    output_data = []
    origin_data = []
    # Loop over all batches in the test set
    for data, target in tqdm(test_loader):
        # Send the data and labels to the device
        data, target = data.to(device), target.to(device)

        # The attack needs the gradient with respect to the input itself
        data.requires_grad = True

        # Forward pass the data through the model
        output = model(data.reshape(-1, 28 * 28))
        init_pred = output.max(1)[1]  # index of the largest output per sample

        # Samples that are already misclassified are not attacked
        correct = init_pred == target
        if not correct.any():
            continue

        # Summed (not averaged) loss: per-sample gradients do not get scaled
        # by the batch size; identical to the mean for a batch of one.
        loss = F.nll_loss(output, target, reduction="sum")

        # Zero all existing gradients, then backpropagate to the input
        model.zero_grad()
        loss.backward()

        flat_grad = data.grad.detach().reshape(-1, 28 * 28)
        flat_data = data.reshape(-1, 28 * 28)
        labels = target.tolist()

        for idx in torch.nonzero(correct).flatten().tolist():
            # fgsm_attack is called per sample with (1, 784) inputs, the same
            # shapes it receives when the loader uses batch_size=1.
            perturbed = fgsm_attack(
                flat_data[idx:idx + 1], epsilon, flat_grad[idx:idx + 1]
            )

            # Transform tensor to type PIL.Image.Image
            perturbed_img = Image.fromarray(
                np.reshape(perturbed.cpu().detach().numpy(), (28, 28))
            )
            output_data.append([perturbed_img, labels[idx]])

            # Keep the unperturbed input as well, for visualisation
            original_img = Image.fromarray(
                np.reshape(data[idx].cpu().detach().numpy(), (28, 28))
            )
            origin_data.append([original_img, labels[idx]])

    return output_data, origin_data

In [None]:
# The attack works on tensors, so the images are converted with ToTensor
# after the project's Invertor transform.
transform = transforms.Compose(
    [
        Invertor(),
        transforms.ToTensor()
    ]
)
# Load the test split.
# NOTE: this rebinds `test_data`, which an earlier cell created without
# ToTensor; re-run that cell before repeating the clean-data evaluation.
test_data = torchvision.datasets.MNIST(
    DATA_DPATH, train=False, transform=transform, download=True
)

test_dataloader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=1,
    shuffle=False
)

model = MlpModel()
model = model.to(DEVICE)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_fpath, map_location=DEVICE)["model_state"])
# Evaluation mode: layers such as dropout must not randomise the attack.
model.eval()

eps = 0.1
perturbed_data, origin_data = data_fgsm_attack(model, DEVICE, test_dataloader, eps)

print(f"Данные с атаками содержат {len(perturbed_data)} изображений.")

In [None]:
# Evaluate the same inference wrapper on the FGSM-perturbed samples.
evaluator = Evaluator(infer, perturbed_data)

In [None]:
# Run the evaluation on the perturbed samples; the result is saved with
# .to_csv() in the save cell below.
predictions = evaluator.evaluate()

In [None]:
# Classification report for the FGSM-perturbed samples. Presumably derived
# from the preceding evaluate() call -- TODO confirm against Evaluator.
metrics = evaluator.classification_report()

# `eps` is the attack strength set in the cell that generated the samples.
print(f'Classification report with fgsm attack eps={eps}')
print(metrics)

In [None]:
# Set to False to skip writing the result files of this section.
_SAVE = True

if _SAVE:
    # File names carry the attack strength and the checkpoint name up to its
    # first dot.
    fpath = os.path.join(
        SAVE_DPATH, f"fsgm_{eps}_test_evaluation_{model_name.split('.')[0]}.csv"
    )
    metric_fpath = os.path.join(
        SAVE_DPATH, f"fsgm_{eps}_test_classification_report_{model_name.split('.')[0]}.csv"
    )

    # Predictions first, then the classification report.
    predictions.to_csv(fpath, index_label='id')
    metrics.to_csv(metric_fpath, index_label='label')

## Оценка на комбинированной выборке (атаки + альбументации)

In [None]:
# Augmentation pipeline applied before the attack; it ends with ToTensor
# because the attack works on tensors.
transform_aug = transforms.Compose(
    [Invertor(), Convertor(), AlbuAugmentation(), transforms.ToTensor()]
)

# NOTE: this rebinds `test_data_with_augmentation`, which an earlier cell
# built as a plain list of [image, label] pairs.
test_data_with_augmentation = torchvision.datasets.MNIST(
    DATA_DPATH, train=False, transform=transform_aug, download=True
)

test_dataloader_aug = torch.utils.data.DataLoader(
    dataset=test_data_with_augmentation,
    batch_size=1,
    shuffle=False
)

# NOTE: `eps` is rebound here; the report and save cells below read it.
eps = 0.05
perturbed_data_aug, _ = data_fgsm_attack(model, DEVICE, test_dataloader_aug, eps)

# Report the size of the combined set itself, not of the FGSM-only list
# produced in the previous section.
print(f"Комбинированная выборка содержит {len(perturbed_data_aug)} изображений.")

In [None]:
# Evaluate the same inference wrapper on the augmented + FGSM-perturbed samples.
evaluator = Evaluator(infer, perturbed_data_aug)

In [None]:
# Run the evaluation on the combined samples; the result is saved with
# .to_csv() in the save cell below.
predictions = evaluator.evaluate()

In [None]:
# Classification report for the combined (augmentation + FGSM) samples.
# Presumably derived from the preceding evaluate() call -- TODO confirm.
metrics = evaluator.classification_report()

# `eps` is the attack strength set in the cell that generated the samples.
print(f'Classification report with fgsm attack eps={eps}')
print(metrics)

In [None]:
# Set to False to skip writing the result files of this section.
_SAVE = True

if _SAVE:
    # File names carry the attack strength and the checkpoint name up to its
    # first dot.
    fpath = os.path.join(
        SAVE_DPATH, f"mixed_{eps}_test_evaluation_{model_name.split('.')[0]}.csv"
    )
    metric_fpath = os.path.join(
        SAVE_DPATH, f"mixed_{eps}_test_classification_report_{model_name.split('.')[0]}.csv"
    )

    # Predictions first, then the classification report.
    predictions.to_csv(fpath, index_label='id')
    metrics.to_csv(metric_fpath, index_label='label')