In [1]:
# Necessary
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split
import torchvision
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torchdiffeq import odeint_adjoint as odeint
# from jupyterthemes import jtplot
from utils import *
# jtplot.style(theme="chesterish")
# Constants
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): EPOCHS, BATCH_SIZE and IMG_SIZE are not referenced by any cell
# visible here; IMG_SIZE also does not match the (28, 28) default used by
# preprocess_fgsm_data -- confirm whether they are still needed.
EPOCHS = 1
BATCH_SIZE = 32
IMG_SIZE = (32, 32, 3)

In [2]:
# Load data: the MNIST training split, downloaded into DIR when it is not
# already present. No transform is applied to the images or to the labels.
DIR = "./data/mnist/"
MNIST = torchvision.datasets.MNIST(
    DIR,
    train=True,
    transform=None,
    target_transform=None,
    download=True,
)
# Lazy (index, sample) iterator over the dataset; not consumed by any cell
# visible here.
ls = enumerate(MNIST)

#ds_len_, normal_ds_, pertubed_ds_ = preprocess_data(MNIST)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
# Instantiate the two classifiers that are compared later in the notebook.
# Network, ODEBlock and ODENet are not defined in this notebook; presumably
# they are provided by `from utils import *` -- TODO confirm.
cnn_model = Network()
# The ODE block instance is handed to the ODENet constructor.
ode_func = ODEBlock()
ode_model = ODENet(ode_func)


# In[4]:


def model_state_dict_parallel_convert(state_dict, mode):
    """Rename state-dict keys between the plain and the ``DataParallel`` layout.

    Args:
        state_dict: Mapping from parameter name to value.
        mode: One of
            ``'to_single'``   -- strip a leading ``'module.'`` from each key
                                 that starts with it (other keys are kept),
            ``'to_parallel'`` -- prepend ``'module.'`` to every key,
            ``'same'``        -- return ``state_dict`` itself, unchanged.

    Returns:
        A new ``OrderedDict`` with the renamed keys in the original order, or
        the very same ``state_dict`` object when ``mode == 'same'``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    from collections import OrderedDict

    prefix = "module."

    if mode == 'same':
        return state_dict

    if mode == 'to_single':
        # Only the leading prefix is removed. A blanket str.replace would also
        # mangle names that merely contain the text, e.g.
        # "submodule.weight" -> "subweight".
        return OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )

    if mode == 'to_parallel':
        return OrderedDict((prefix + k, v) for k, v in state_dict.items())

    raise ValueError(
        "mode = to_single / to_parallel / same (got {!r})".format(mode)
    )
# Restore the pretrained weights of both models. The checkpoints are mapped
# straight onto `device` (instead of a hard-coded CUDA device) so loading uses
# the same target as the `.to(device)` calls below.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted
# source.
ode_state_dict = torch.load("./model/ode_origin/mnist_origin_origin.pt", map_location=device)
# Drop the "module." key prefix so the keys match the unwrapped model.
ode_state_dict = model_state_dict_parallel_convert(ode_state_dict, mode="to_single")
ode_model.load_state_dict(ode_state_dict)

cnn_state_dict = torch.load("./model/cnn_origin/mnist_origin_origin.pt", map_location=device)
cnn_state_dict = model_state_dict_parallel_convert(cnn_state_dict, mode="to_single")
cnn_model.load_state_dict(cnn_state_dict)

ode_model = ode_model.to(device)
cnn_model = cnn_model.to(device)


In [4]:
# def visualize_model(model,data, typ="ode", sigma = 50.0):
#     import random
#     cnt = 0
#     tr = []; fal = []
#     for _, dp in list(enumerate(data)):
#         np_dp = torch.tensor(np.array(dp[0]).reshape((1,1,28,28))).float()
#         np_dp = np_dp + torch.normal(torch.zeros(np_dp.shape),torch.ones(np_dp.shape) * sigma).float()
#         np_dp = np_dp.to(device)
#         preds = model(np_dp)
#         preds = torch.argmax(preds,dim=1).item()
#         if preds == dp[1]:  tr.append((np_dp,dp[1],preds))
#         else: fal.append((np_dp,dp[1],preds))
#         cnt+=1
#         if cnt == 6001: break
#     print(f"With sigma = {sigma}, the number of images going to wrong is: {len(fal)} / 2000 (images)")
#     random.shuffle(tr)
#     random.shuffle(fal)
#     return tr, fal
# def plot(tr, fal, typ="ode"):
#     fig, ax = plt.subplots(2,5,sharex=True,sharey=True)
#     for i in range(5):
#         ax[0][i].imshow(tr[i][0].cpu().detach().numpy().reshape((28,28)),cmap="gray")
#         ax[0][i].set_xlabel(f"truths: {tr[i][1]}\nPred: {tr[i][2]}")
#         ax[1][i].imshow(fal[i][0].cpu().detach().numpy().reshape((28,28)),cmap="gray")
#         ax[1][i].set_xlabel(f"truths: {fal[i][1]}\nPred: {fal[i][2]}")
#     #plt.xlabel(f"{typ} model with sigma {sigma}")
#     plt.show()
# tr1, fal1 = visualize_model(cnn_model, MNIST, typ="cnnnet", sigma=sig)
# tr2, fal2 = visualize_model(ode_model, MNIST, typ="odenet", sigma=sig)

In [5]:
# sig = 100.0

# plot(tr1,fal1, typ="ode")

# plot(tr2,fal2, typ="cnn")

In [6]:
# _ds_len, _ds = preprocess_data(MNIST)

# print(_ds)

In [7]:
# FGSM attack code
def fgsm_attack(image, epsilon, data_grad):
    """Apply one Fast Gradient Sign Method step to ``image``.

    Args:
        image: Tensor to perturb.
        epsilon: Step size applied to every element.
        data_grad: Gradient of the loss with respect to ``image``; only its
            sign is used.

    Returns:
        ``image + epsilon * sign(data_grad)`` clamped to the range [0, 1].
    """
    step = epsilon * data_grad.sign()
    # Clamp so the result stays inside the [0, 1] range.
    return torch.clamp(image + step, 0, 1)

In [8]:
def test(model, device, test_loader, loss_fn, epsilon):
    """Measure a model's accuracy under an FGSM attack of strength ``epsilon``.

    Samples the model already misclassifies are not attacked, but they still
    count in the denominator, so the result is "fraction of all samples that
    are classified correctly both before and after the attack".

    Args:
        model: Classifier returning one score per class along dim 1.
        device: Device the batches are moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches of any batch size;
            ``target`` is one-hot along dim 1 (its argmax is the true label).
        loss_fn: Called as ``loss_fn(output.float(), target.float())``; must
            return a scalar that can be back-propagated.
        epsilon: FGSM step size forwarded to ``fgsm_attack``.

    Returns:
        The accuracy as a float; ``0.0`` when the loader yields no samples.
    """
    correct = 0
    total = 0
    model.eval()
    for data, target in test_loader:
        target = target.to(device)
        # Work on a fresh leaf tensor so gradients never accumulate on a
        # tensor that the loader may hand out again.
        data = data.to(device).detach().requires_grad_(True)

        # Forward pass the clean data through the model.
        output = model(data)
        _, init_pred = torch.max(output, 1)
        _, correct_labels = torch.max(target, 1)
        total += correct_labels.numel()

        # Only samples that are classified correctly to begin with are attacked.
        attackable = init_pred == correct_labels
        if not bool(attackable.any()):
            continue

        loss = loss_fn(output.float(), target.float())
        model.zero_grad()
        loss.backward()

        # The attack and the re-classification need no autograd graph.
        with torch.no_grad():
            perturbed_data = fgsm_attack(data, epsilon, data.grad)
            _, final_pred = torch.max(model(perturbed_data), 1)

        survived = (final_pred == correct_labels) & attackable
        correct += int(survived.sum().item())

    # Guard against an empty loader instead of dividing by zero.
    final_acc = correct / float(total) if total else 0.0
    print("Epsilon: {}\tTest Accuracy = {} / {} = {}".format(epsilon, correct, total, final_acc))

    return final_acc

In [9]:
import random

def preprocess_fgsm_data(data, shape = (28,28), device="cpu", numTest=1000):
    """Randomly sample images from ``data`` and pack them into a TensorDataset.

    Each ``(x, y)`` pair of ``data`` is skipped with probability 0.5 (using the
    global ``random`` module); sampling stops once ``numTest`` pairs have been
    kept. If ``numTest`` is never reached, every kept pair is returned.

    Args:
        data: Iterable of ``(image, label)`` pairs; ``np.array(image)`` must
            hold ``shape[0] * shape[1]`` values and ``label`` must be an
            integer class index in ``[0, 10)``.
        shape: ``(height, width)`` of one image.
        device: Device the resulting tensors are moved to.
        numTest: Maximum number of samples to keep.

    Returns:
        ``(ds_len, ds)`` where ``ds_len`` is the number of kept samples and
        ``ds`` is ``{"original": TensorDataset(images / 255.0, one_hot_labels)}``
        with images of shape ``(ds_len, 1, height, width)``.
    """
    height, width = shape[0], shape[1]
    images = []
    labels = []
    # Iterate lazily: materialising the whole dataset first would decode every
    # image even though only about 2 * numTest of them are ever looked at.
    for x, y in data:
        if random.random() < 0.5:
            continue
        # Use both dimensions of `shape` so non-square images work as well.
        images.append(np.array(x).reshape((1, height, width)))
        labels.append(y)
        if len(labels) == numTest:
            break

    y_data = F.one_hot(torch.Tensor(labels).to(torch.int64), num_classes=10)
    y_data = y_data.to(device)
    # One contiguous float32 array converts far faster than a list of arrays.
    x_data = torch.from_numpy(np.asarray(images, dtype=np.float32))
    x_data = x_data.to(device)

    # Scale the pixel values by 1/255 before building the dataset.
    ds = {"original": TensorDataset(x_data / 255.0, y_data)}
    return len(labels), ds


In [10]:
# Build the evaluation subset from MNIST with the defaults of
# preprocess_fgsm_data (28x28 images, CPU tensors, at most 1000 samples).
# NOTE(review): the sampling draws from the `random` module and no seed is set
# in the visible cells, so the chosen subset differs from run to run.
_ds_len, _ds = preprocess_fgsm_data(MNIST)


In [11]:
import time
# Evaluate both models under FGSM at several attack strengths and time each
# run. One sample per batch, so every image is attacked individually.
val_loader = DataLoader(_ds['original'], batch_size=1)
loss_fn = torch.nn.functional.binary_cross_entropy_with_logits
FGSM_EPSILONS = (0.15, 0.3, 0.5)

for position, (label, fgsm_model) in enumerate(
    (('CNN model', cnn_model), ('ODE model', ode_model))
):
    # A separator line goes between the models, not before the first one.
    if position:
        print('---------')
    print(label)
    for fgsm_epsilon in FGSM_EPSILONS:
        start = time.time()
        test(fgsm_model, device, val_loader, loss_fn, fgsm_epsilon)
        print('took: ',time.time() - start)

CNN model


  input = module(input)


Epsilon: 0.15	Test Accuracy = 57 / 1000 = 0.057
took:  5.2390055656433105
Epsilon: 0.3	Test Accuracy = 5 / 1000 = 0.005
took:  5.890883445739746
Epsilon: 0.5	Test Accuracy = 18 / 1000 = 0.018
took:  5.934651613235474
---------
ODE model
Epsilon: 0.15	Test Accuracy = 312 / 1000 = 0.312
took:  52.79925298690796
Epsilon: 0.3	Test Accuracy = 62 / 1000 = 0.062
took:  53.87685799598694
Epsilon: 0.5	Test Accuracy = 50 / 1000 = 0.05
took:  53.409910440444946
