In [25]:
# Necessary
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchdiffeq import odeint_adjoint as odeint
from utils import *

# CONSTANT
# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs (a hard-coded "cuda" fails on CPU-only machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 1
BATCH_SIZE = 32
# NOTE(review): MNIST images are 28x28x1; (32, 32, 3) does not match the data
# loaded below and IMG_SIZE is not used in this section — confirm intent.
IMG_SIZE = (32, 32, 3)

In [26]:
# Load data: the MNIST training split with no transform applied, so each
# sample is an untransformed (image, label) pair. Downloaded into DIR if absent.
DIR = "./data/mnist/"
MNIST = torchvision.datasets.MNIST(
    DIR,
    train=True,
    download=True,
    transform=None,
    target_transform=None,
)


#ds_len_, normal_ds_, pertubed_ds_ = preprocess_data(MNIST)


In [27]:
def model_state_dict_parallel_convert(state_dict, mode):
    """Add or strip the ``'module.'`` key prefix used by ``nn.DataParallel``.

    Args:
        state_dict: mapping from parameter names to values (e.g. tensors).
        mode: one of
            * ``'to_single'``  - strip a leading ``'module.'`` from every key
              that has one; other keys are kept unchanged.
            * ``'to_parallel'`` - prepend ``'module.'`` to every key that does
              not already start with it.
            * ``'same'``        - return ``state_dict`` itself, untouched.

    Returns:
        A new ``OrderedDict`` with converted keys (values are not copied), or
        the very same ``state_dict`` object for ``mode == 'same'``.

    Raises:
        ValueError: if ``mode`` is not one of the values above.
    """
    from collections import OrderedDict

    prefix = 'module.'
    if mode == 'same':
        return state_dict
    if mode == 'to_single':
        # Strip only keys that actually carry the prefix: blindly slicing
        # k[7:] would mangle checkpoints saved without DataParallel.
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
    if mode == 'to_parallel':
        # Never produce 'module.module.' for keys that are already prefixed.
        return OrderedDict(
            (k if k.startswith(prefix) else prefix + k, v)
            for k, v in state_dict.items()
        )
    raise ValueError(
        "mode must be 'to_single', 'to_parallel' or 'same', got {!r}".format(mode)
    )


In [28]:
lr = 1e-3

# Loss on raw logits; it requires a float target with the same shape as the logits.
loss_fn = F.binary_cross_entropy_with_logits

# Build the networks on the target device. The construction order
# (ODE block, ODE net, CNN) is kept so parameter initialisation is unchanged.
ode_func = ODEBlock().to(device)
ode_model = ODENet(ode_func, device=device).to(device)
cnn_model = Network().to(device)

# One Adam optimizer per model, both with the same learning rate.
ode_optimizer = torch.optim.Adam(ode_model.parameters(), lr=lr)
cnn_optimizer = torch.optim.Adam(cnn_model.parameters(), lr=lr)


In [29]:
# Restore pretrained weights. The stored keys presumably carry nn.DataParallel's
# "module." prefix, so they are converted with mode="to_single" before loading.
# torch.load unpickles the file: only load checkpoints from trusted sources.
checkpoint_location = torch.device(device)

cnn_state_dict = model_state_dict_parallel_convert(
    torch.load("./model/cnn_origin/mnist_original_origin.pt", map_location=checkpoint_location),
    mode="to_single",
)
cnn_model.load_state_dict(cnn_state_dict)

ode_state_dict = model_state_dict_parallel_convert(
    torch.load("./model/ode_origin/mnist_original_origin.pt", map_location=checkpoint_location),
    mode="to_single",
)
ode_model.load_state_dict(ode_state_dict)

<All keys matched successfully>

In [30]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    """Apply one Fast Gradient Sign Method step to ``image``.

    Computes ``clamp(image + epsilon * sign(data_grad), 0, 1)``.

    Args:
        image: input tensor to perturb.
        epsilon: size of the step applied to every element.
        data_grad: gradient of the loss with respect to ``image``.

    Returns:
        The perturbed tensor, clipped to the [0, 1] range.
    """
    step = epsilon * torch.sign(data_grad)
    # Clip so the adversarial image stays in the valid [0, 1] pixel range.
    return torch.clamp(image + step, min=0, max=1)


In [57]:
def test(model, device, test_loader, loss_fn, epsilon):
    correct = 0
    for data, target in test_loader:

        data, target = data.to(device), target.to(device)
        data.requires_grad = True

        # Forward pass the data through the model
        outputs = model(data)
        _, init_pred = torch.max(outputs, 1)
        _, correct_labels = torch.max(target, 1)

        # If the initial prediction is wrong, skip
        if (init_pred == correct_labels).sum().item() != correct_labels.size():
            continue

        loss = loss_fn(outputs, correct_labels)
        model.zero_grad()

        loss.backward()
        data_grad = data.grad.data
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classify the perturbed image
        output = model(perturbed_data)
        _, final_pred = torch.max(outputs, 1)
        if final_pred.item() == target.item():
            correct += 1

    final_acc = correct/float(len(test_loader))
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, len(test_loader), final_acc))

    
    return final_acc


In [51]:
def preprocess_fgsm_data(data, shape=(28, 28), device="cpu", num_classes=10):
    """Convert an iterable of ``(image, label)`` pairs into a TensorDataset.

    Args:
        data: iterable of ``(image, label)`` pairs. Each image must be
            convertible with ``np.asarray`` to ``shape[0] * shape[1]`` values
            (e.g. an untransformed MNIST sample). Labels are integer class indices.
        shape: ``(height, width)`` of one image.
        device: device the resulting tensors are moved to.
        num_classes: number of classes used for the one-hot label encoding.

    Returns:
        ``(ds_len, ds)``:
        * ``ds_len`` is the number of samples.
        * ``ds`` is ``{"original": TensorDataset(x, y)}``, where ``x`` is
          float32 of shape (N, 1, height, width) divided by 255, and ``y`` is
          int64 one-hot of shape (N, num_classes).
    """
    height, width = shape
    images = []
    labels = []
    for x, y in data:
        # Single-channel (1, height, width) layout; both dimensions of
        # ``shape`` are used so non-square images reshape correctly.
        images.append(np.asarray(x, dtype=np.float32).reshape((1, height, width)))
        labels.append(y)

    # Stack into one ndarray first: torch.Tensor(list_of_ndarrays) is very slow.
    if images:
        x_np = np.stack(images)
    else:
        x_np = np.empty((0, 1, height, width), dtype=np.float32)
    x_data = torch.from_numpy(x_np).to(device)
    y_data = F.one_hot(
        torch.as_tensor(labels, dtype=torch.int64), num_classes=num_classes
    ).to(device)

    ds = {"original": TensorDataset(x_data / 255.0, y_data)}
    return len(labels), ds


In [44]:
# NOTE(review): MNIST was loaded with train=True, so the FGSM evaluation
# below runs on training images — confirm this is intended.
_ds_len, _ds = preprocess_fgsm_data(MNIST, shape=(28, 28), device="cpu")

In [60]:
# Evaluate on the whole preprocessed set, one sample per batch.
val_loader = DataLoader(_ds['original'], batch_size=1, shuffle=True)


In [61]:
# Accuracy of the CNN under an FGSM attack with epsilon = 0.3.
test(cnn_model, device, val_loader, loss_fn, epsilon=0.3)


Epsilon: 0.3	Test Accuracy = 0 / 60000 = 0.0


0.0