In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import time
import os
import torch

from datasets.mnist import MNIST

from models.cae_model import CAE
from models.conv_model import CNN

from train import train_ae, train_cnn
%load_ext autoreload
%autoreload 2
from cem import ContrastiveExplanationMethod

# set random seeds for reproducibility (although the CEM is fully deterministic)
torch.manual_seed(0)
np.random.seed(0)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Build the project's MNIST dataset wrapper with a batch size of 64.
dataset = MNIST(batch_size=64)
# NOTE(review): FashionMNIST is not imported in the cells visible here; import it before enabling the line below.
# dataset = FashionMNIST()

# Training the classifier

In [6]:
# Prefer the first GPU but fall back to the CPU so the notebook also runs on
# machines without CUDA; a single variable keeps the model and the trainer on
# the same device.
cnn_device = "cuda:0" if torch.cuda.is_available() else "cpu"

cnn = CNN(device=cnn_device)

# NOTE(review): presumably train_cnn restores the weights at load_path when that
# file exists and trains otherwise - confirm in train.py.
train_cnn(cnn, dataset, iterations=7, lr=0.1, save_fn='mnist-cnn', device=cnn_device, load_path="./models/saved_models/mnist-cnn.h5")

  out = nn.functional.softmax(out)


loss after step 0:2.3022148609161377 accuracy: 0.109375
loss after step 100:2.3019509315490723 accuracy: 0.140625
loss after step 200:2.301959991455078 accuracy: 0.0625
loss after step 300:2.30271577835083 accuracy: 0.078125
loss after step 400:2.3030552864074707 accuracy: 0.125
loss after step 500:2.302427053451538 accuracy: 0.078125
loss after step 600:2.3021762371063232 accuracy: 0.109375
loss after step 700:2.302530288696289 accuracy: 0.078125
loss after step 800:2.3017024993896484 accuracy: 0.109375
loss after step 900:2.3015575408935547 accuracy: 0.15625
done with iteration: 0/7
loss after step 0:2.3027946949005127 accuracy: 0.09375
loss after step 100:2.3013687133789062 accuracy: 0.125
loss after step 200:2.3005127906799316 accuracy: 0.1875
loss after step 300:2.3006529808044434 accuracy: 0.15625
loss after step 400:2.301255226135254 accuracy: 0.125
loss after step 500:2.3014659881591797 accuracy: 0.09375
loss after step 600:2.3002796173095703 accuracy: 0.125
loss after step 700

In [None]:
# Run the classifier on one batch and keep inputs and outputs as NumPy arrays.
images, _ = dataset.get_batch()

# NOTE(review): the classifier may live on a GPU while this batch comes straight
# from the dataset - confirm that CNN moves its inputs to its own device.
output = cnn(images)

images = images.numpy()
# .cpu() is a no-op for CPU tensors and is required before .numpy() whenever the
# classifier returns a CUDA tensor.
output = output.detach().cpu().numpy()


In [None]:
# # evaluate the cnn by uncommenting this cell

# total_acc = 0
# total_batches = 0
# for step, (batch_inputs, batch_targets) in enumerate(dataset.test_loader):
    
#     predictions = cnn(batch_inputs)
#     acc = (predictions.argmax(1).cpu().numpy() == batch_targets.cpu().numpy()).sum()/(predictions.shape[0] )
#     total_batches += 1
#     total_acc += acc
    
# print("acc: {}".format(total_acc / total_batches))

# Training the autoencoder

This section trains the autoencoder, which is used as a regularizer on the data space in which the perturbations are searched for.

In [None]:
# Train or load autoencoder
# NOTE(review): presumably train_ae loads the weights at load_path when that file
# exists and trains for `iterations` otherwise - confirm in train.py.
cae = CAE(device="cpu")

train_ae(cae, dataset, iterations=10, save_fn="mnist-cae", device="cpu", load_path="models/saved_models/mnist-cae-no-rs.h5")

In [None]:
# obtain one batch of images
images, _ = dataset.get_batch()

# Shift out-of-place so the tensor handed out by the dataset is left untouched.
# NOTE(review): assumes the autoencoder expects inputs offset by +0.5 relative
# to what the dataset returns - confirm against the dataset normalisation.
images = images + 0.5

# get sample outputs
output = cae(images)
# prep images for display
images = images.numpy()

# use detach when it's an output that requires_grad
output = output.detach().numpy()

# plot the first four input images and their reconstructions
fig, axes = plt.subplots(nrows=2, ncols=4, sharex=True, sharey=True, figsize=(25,4))

# input images on top row, reconstructions on bottom; the loop variable has its
# own name so the global `images` is not rebound to the reconstructions
for batch, row in zip([images, output], axes):
    for img, ax in zip(batch, row):
        ax.imshow(np.squeeze(img))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

# Contrastive Explanation Method

In [None]:
# optimal: kappa 30, gamma 1.0, beta 0.1, lr 0.01
# NOTE(review): kappa below is 10, not the 30 listed as optimal above - confirm which value is intended.

# Hyper-parameters forwarded to ContrastiveExplanationMethod. kappa, gamma, beta
# and lr are also read by save_imgs() to name the output directory, and `device`
# is used when fetching samples, so these names must stay in sync with those cells.
kappa = 10
gamma = 1.0
beta = 0.1
lr = 0.01
device = "cpu"

# The explainer receives the classifier (cnn) and the autoencoder (cae) built in
# earlier cells; the meaning of the remaining keyword arguments is defined in the
# cem module, which is not visible here.
CEM = ContrastiveExplanationMethod(
    cnn,
    cae,
    iterations=1000,
    n_searches=9,
    kappa=kappa,
    gamma=gamma,
    beta=beta,
    learning_rate=lr,
    c_init=10.0,
    device=device,
    verbal=False
)

In [None]:
def save_imgs():
    """Save the current sample and its CEM result as grayscale PNG files.

    Takes no arguments; it reads the notebook globals ``mode``, ``kappa``,
    ``gamma``, ``beta``, ``lr``, ``before``, ``after``, ``image`` and
    ``best_delta``, all of which must be set before the call.

    Four files are written into a directory named after the hyper-parameters
    (created if missing): the original image, ``best_delta - image``,
    ``best_delta`` itself and ``image - best_delta``.
    """
    dirname = "saved_perturbations/mode-{}-kappa-{}-gamma-{}-beta-{}-lr-{}".format(mode, kappa, gamma, beta, lr)
    os.makedirs(dirname, exist_ok=True)

    # Read the clock once so the four files of one sample always share the same
    # prefix, even when the call straddles a second boundary.
    stamp = int(time.time())

    fname_orig = dirname + "/{}-cb-{}-ca-{}-orig.png".format(stamp, before, after)
    fname_pert = dirname + "/{}-before-{}-after-{}-pert.png".format(stamp, before, after)
    fname_combined = dirname + "/{}-before-{}-after-{}-pn.png".format(stamp, before, after)
    fname_combined_pp = dirname + "/{}-before-{}-after-{}-pp.png".format(stamp, before, after)

    # Convert to plain arrays up front: matplotlib cannot consume tensors that
    # require grad or live on a GPU. detach()/cpu() are no-ops otherwise.
    original = image.squeeze().detach().cpu().numpy()
    delta = best_delta.view(28, 28).detach().cpu().numpy()

    plt.imsave(fname_orig, original, cmap="gray")
    plt.imsave(fname_pert, delta - original, cmap="gray")
    plt.imsave(fname_combined, delta, cmap="gray")
    plt.imsave(fname_combined_pp, original - delta, cmap="gray")

In [None]:
# For class labels 0-9: fetch one sample, run the explainer in both "PP" and
# "PN" mode, and save the resulting images.
for i in range(10):
    # obtain one sample
    image = dataset.get_sample_by_class(class_label=i, show_image=False).to(device)

    print("IMAGE FROM CLASS: {}".format(i))
    # Tensor.argmax works on any device, so no NumPy / .cpu() round-trip is needed
    before = cnn(image.squeeze(-1)).detach().argmax().item()

    for mode in ["PP", "PN"]:
        print("mode: {}".format(mode))
        best_delta = CEM.explain(image, mode=mode)

        # "PP" classifies the image with the returned tensor subtracted;
        # "PN" classifies the returned tensor directly.
        if mode == "PP":
            perturbed = image.squeeze(-1) - best_delta.view(1,28,28)
        else:
            perturbed = best_delta.view(1,28,28)
        after = cnn(perturbed).detach().argmax().item()

        # save_imgs() reads image, mode, before, after and best_delta from the globals set above
        save_imgs()

In [None]:
# Inspect the most recent explanation. Relies on the globals `image` and `mode`
# left behind by a previously executed cell, and on CEM.best_delta.

# print original image
plt.imshow(image.view(28,28), cmap="gray")
plt.show()

# classification before
# Tensor.argmax works on any device, unlike np.argmax on a tensor that was not moved to the CPU
before = cnn(image.squeeze(-1)).detach().argmax().item()
print("classification before perturbation: {}".format(before))

if mode == "PP":
    plt.imshow(image.squeeze() - CEM.best_delta.view(28,28), cmap="gray")
    plt.show()
    after = cnn(image.squeeze(-1) - CEM.best_delta.view(1,28,28)).detach().argmax().item()
    print("classification of delta: {}".format(after))
else:
    plt.imshow(CEM.best_delta.view(28,28),  cmap="gray")
    plt.show()
    after = cnn(CEM.best_delta.view(1,28,28)).detach().argmax().item()
    print("classification after perturbation: {}".format(after))

In [None]:
# Sanity check: positive infinity compares greater than a finite number.
print(5 < float("inf"))

In [None]:
from models.conv_model_copy import CNN as CNN1

# Smoke-test the alternative classifier implementation on a single sample.
# NOTE(review): this rebinds the global `image`, which save_imgs() and the
# inspection cell also read - re-run the explanation loop before using them again.
image, _ = dataset.get_sample()

# Built with default constructor arguments.
# NOTE(review): presumably untrained unless CNN1 loads weights itself - verify.
cnn1 = CNN1()

# Bare last expression: the notebook displays the model output for this sample.
cnn1(image)
