In [None]:
import numpy as np
import torch
from tqdm import tqdm
from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist, fashion_mnist
from torchvision.datasets import EMNIST
from time import time
import numpy as np
import robust_onlinehd

In [None]:
# Dataset to evaluate: 'mnist', 'fashion_mnist' or 'emnist' (EMNIST "letters" split).
# It selects the data loaded below, the checkpoint filenames and the class-name mapping.
dataset = 'mnist'

In [None]:
# loads the train/test split of the selected dataset
def load(name=None):
    """Load train/test images and labels for one supported dataset.

    Args:
        name: 'mnist', 'fashion_mnist' or 'emnist' (EMNIST "letters" split).
            Defaults to the notebook-level ``dataset`` setting, so ``load()``
            keeps working as before.

    Returns:
        (x, x_test, y, y_test): float image tensors (a trailing channel axis is
        added when the raw images have none) and long label tensors.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    if name is None:
        name = dataset

    def emnist_letters(train):
        # Load one EMNIST "letters" split as (images, labels) numpy arrays.
        split = EMNIST('./data/EMNIST', split = 'letters', train = train, download = True)
        # Add a channel axis and swap the two spatial axes (presumably to undo
        # EMNIST's transposed storage).
        images = split.data.unsqueeze(3).numpy().transpose((0, 2, 1, 3))
        # Shift labels down by one so the first letter maps to class 0.
        return images, split.targets.numpy() - 1

    if name == 'mnist':
        (x, y), (x_test, y_test) = mnist.load_data()
    elif name == 'fashion_mnist':
        (x, y), (x_test, y_test) = fashion_mnist.load_data()
    elif name == 'emnist':
        x, y = emnist_letters(train = True)
        x_test, y_test = emnist_letters(train = False)
    else:
        # Previously any unknown name silently loaded EMNIST, which disagreed with
        # the class-name mapping; fail loudly instead.
        raise ValueError(
            'unknown dataset %r; expected one of mnist, fashion_mnist, emnist' % (name,)
        )

    # changes data to pytorch's tensors
    x = torch.from_numpy(x).float()
    y = torch.from_numpy(y).long().squeeze()
    x_test = torch.from_numpy(x_test).float()
    y_test = torch.from_numpy(y_test).long().squeeze()

    # Grayscale images arrive as (N, H, W); add a channel axis -> (N, H, W, 1).
    if x.dim() == 3:
        x = x.unsqueeze(3)
        x_test = x_test.unsqueeze(3)

    return x, x_test, y, y_test


# Load the dataset selected by `dataset`. Of these, later cells here only read
# y_test (as the true labels of the attacked test samples).
print('Loading...')
x, x_test, y, y_test = load()

In [None]:
# Seed run directories to aggregate over (one checkpoint pair per seed).
seeds = ['seed%d' % seed_number for seed_number in (27, 33, 54, 71, 88)]

In [None]:
# Top-level results directory (presumably one per hyperparameter setting).
# Checkpoint paths are built as '<hp>/<seed>/<run>/robust_onlinehd_<dataset>.pt'.
hp = 'hp1'

In [None]:
# Per-seed checkpoints from the 'full_result' run directory; the 'model' entry of
# each one is used as that seed's robust model further down.
# Since torch 2.6, torch.load defaults to weights_only=True, which rejects
# arbitrary pickled objects such as the stored model, so opt out explicitly.
# Only do this for trusted, locally produced files: unpickling can execute code.
robust_cache = []
for seed in seeds:
    checkpoint_path = '%s/%s/full_result/robust_onlinehd_%s.pt' % (hp, seed, dataset)
    robust_cache.append(torch.load(checkpoint_path, weights_only=False))

In [None]:
# Per-seed checkpoints from the 'nothing' run directory (presumably the undefended
# baseline). Their 'pops', 'success_idx', 'indices' and 'targets' entries are read below.
# These hold non-tensor Python objects that torch>=2.6's default weights_only=True
# may reject, so opt out explicitly. Only do this for trusted, locally produced
# files: unpickling can execute code.
origin_cache = []
for seed in seeds:
    checkpoint_path = '%s/%s/nothing/robust_onlinehd_%s.pt' % (hp, seed, dataset)
    origin_cache.append(torch.load(checkpoint_path, weights_only=False))

In [None]:
# Concatenate the saved attack populations of every seed into one array.
# NOTE(review): the next cell rebinds `pops` inside its loop before anything reads
# this value, so this concatenation appears unused; consider removing it.
pops = np.concatenate([cache['pops'] for cache in origin_cache])

In [None]:
# Collect, per seed, the adversarial images that fooled that seed's origin model,
# together with the stored target label and the true test label of each one.
imgs = [[] for _ in origin_cache]
misclassified = [[] for _ in origin_cache]
correct = [[] for _ in origin_cache]
for i, cache in enumerate(origin_cache):
    pops = cache['pops']
    success_idx = cache['success_idx']
    targets = cache['targets']
    # NOTE(review): cache['indices'] is stored but never used here. If the attack ran
    # on a subset of the test set, s[0] may be a position within that subset, and the
    # true label would then be y_test[indices[s[0]]] rather than y_test[s[0]]; confirm.
    for s in success_idx:
        # s[0] indexes the attacked sample; s[1] holds the two indices into `pops`
        # that locate the successful adversarial image.
        imgs[i].append(pops[s[1][0]][s[1][1]])
        misclassified[i].append(targets[s[0]].item())
        correct[i].append(y_test[s[0]].item())
# Stack each seed's images into one ndarray before converting. Building a tensor
# from a list of ndarrays is very slow (PyTorch warns about it). The values and the
# float32 dtype are the same as torch.Tensor(list).float() would produce.
imgs = [torch.from_numpy(np.asarray(seed_imgs, dtype=np.float32)) for seed_imgs in imgs]
correct = [torch.tensor(seed_labels, dtype=torch.float32) for seed_labels in correct]

In [None]:
# One robust model per seed, in the same order as `seeds`.
models = [entry['model'] for entry in robust_cache]

In [None]:
# Count adversarial examples (crafted against seed m's origin model) that seed m's
# robust model labels with their true test label, summed over all seeds.
success = sum(
    (correct[m] == models[m](imgs[m]).cpu()).sum().item()
    for m in range(len(models))
)

In [None]:
# Number of successful adversarial examples across all seeds.
total = np.sum([len(per_seed_labels) for per_seed_labels in correct])

In [None]:
# Show how many successful adversarial examples were collected over all seeds.
total

In [None]:
# Fraction of adversarial examples that the matching robust model classifies correctly.
robust_accuracy = success / total
print(robust_accuracy)

In [None]:
# Map class indices to human-readable names for the selected dataset.
if dataset == 'emnist':
    # EMNIST letters labels are shifted down by one when loaded, so 0 -> 'a'.
    labels = dict(enumerate('abcdefghijklmnopqrstuvwxyz'))
elif dataset == 'mnist':
    labels = {digit: str(digit) for digit in range(10)}
elif dataset == 'fashion_mnist':
    labels = dict(enumerate([
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ]))
else:
    # Don't silently fall back to Fashion-MNIST names for an unknown dataset.
    raise ValueError('unsupported dataset: %r' % (dataset,))

In [None]:
# Show one random adversarial example. The left title is the robust model's
# prediction; the right title is the stored target label from `misclassified`.
seed_idx = torch.randint(0, len(imgs), (1,)).item()
sample_idx = torch.randint(0, len(imgs[seed_idx]), (1,)).item()
print(seed_idx, sample_idx)
# Keep a leading batch dimension of 1 for the model call.
adv_img = imgs[seed_idx][sample_idx].unsqueeze(0)
# Use the robust model from the same seed the example was crafted under, matching
# how the accuracy loop pairs models[m] with imgs[m] (not always models[0]).
robust_label = labels[models[seed_idx](adv_img).item()]
f, axes = plt.subplots(1, 2)
# plt.gray() makes gray the default colormap and returns None, so both panels
# (and later cells) render in gray.
axes[0].imshow(adv_img.squeeze(), cmap=plt.gray())
_ = axes[0].set_title('Robust model : %s' % robust_label)
axes[1].imshow(adv_img.squeeze())
_ = axes[1].set_title('Origin model : %s' % labels[misclassified[seed_idx][sample_idx]])

Example indices noted for the other datasets:

- emnist: 125, 183
- fashion_mnist: 209, 253, 224


In [None]:
# Pool the successful adversarial examples of all seeds into flat tensors:
# imgs (images), correct (true test labels), and misclassified (stored target labels, as a list).
imgs = []
misclassified = []
correct = []
for cache in origin_cache:
    pops = cache['pops']
    success_idx = cache['success_idx']
    targets = cache['targets']
    # NOTE(review): cache['indices'] is stored but unused. If s[0] is a position within
    # an attacked subset, the true label would be y_test[indices[s[0]]]; confirm.
    for s in success_idx:
        # s[0] indexes the attacked sample; s[1] locates the image inside `pops`.
        imgs.append(pops[s[1][0]][s[1][1]])
        misclassified.append(targets[s[0]].item())
        correct.append(y_test[s[0]].item())
# Stack into one ndarray first; a tensor built from a list of ndarrays is very slow.
# The resulting float32 values are the same as torch.Tensor(imgs).float().
imgs = torch.from_numpy(np.asarray(imgs, dtype=np.float32))
correct = torch.tensor(correct, dtype=torch.float32)

In [None]:
# Evaluate every robust model on the pooled adversarial examples of all seeds and
# add up the number of correct predictions over models.
success = 0
for model in models:
    predictions = model(imgs).cpu()
    success += (predictions == correct).sum().item()

In [None]:
# Average the pooled correct count over the robust models.
# NOTE: not idempotent; re-running this cell divides by len(models) again.
success /= len(models)

In [None]:
# Mean (over robust models) fraction of pooled adversarial examples classified with
# their true label; valid only if the averaging cell above ran exactly once.
success / total

In [None]:
# Inspect pooled adversarial example 206 (index into the flattened `imgs` tensor).
plt.imshow(imgs[206])

In [None]:
# Same example after models[0].local_maximum followed by models[0].quantizing
# (a batch dimension is added for the calls and squeezed away for display).
plt.imshow(models[0].quantizing(models[0].local_maximum(imgs[206].unsqueeze(0))).squeeze())

In [None]:
# Same example after models[0].quantizing alone, for comparison with the cell above.
plt.imshow(models[0].quantizing(imgs[206].unsqueeze(0)).squeeze())

In [None]:
# Compare models[0]'s preprocessing steps on pooled adversarial example 150:
# raw image, local maximum only, quantization only, and both chained.
hd_model = models[0]
sample = imgs[150].unsqueeze(0)  # batch of one for the model's methods
f, axes = plt.subplots(1, 4)
f.set_figheight(15)
f.set_figwidth(15)
# plt.gray() sets the default colormap to gray and returns None (so cmap uses it).
axes[0].imshow(imgs[150], cmap=plt.gray())
_ = axes[0].set_title('original image')
panels = (
    ('local maximum only', lambda img: hd_model.local_maximum(img)),
    ('quantization only', lambda img: hd_model.quantizing(img)),
    ('maxpooling + quantization', lambda img: hd_model.quantizing(hd_model.local_maximum(img))),
)
for ax, (title, transform) in zip(axes[1:], panels):
    ax.imshow(transform(sample).squeeze())
    _ = ax.set_title(title)