##### Please make sure to download and preprocess the data before running this notebook
##### Use "run.sh" to download the data
##### Use "data_prep_*.ipynb" to preprocess the data

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import scipy.io
import random
from PIL import Image
from torchvision import transforms
import matlab
import matlab.engine
import bayesiancoresets as bc

%load_ext autoreload
%autoreload 2

# Feed data into the model and extract intermediate features

In [None]:
# Build the Siamese network, restore its trained weights from the checkpoint,
# then switch it to inference mode and move it onto the GPU.
from model import SiameseNet
model = SiameseNet()
checkpoint = torch.load('ckpt/exp_old/best_model_ckpt.tar')
model.load_state_dict(checkpoint['model_state'])
model.eval()
model.cuda()

In [None]:
def extract_feature(x):
    """Run the convolutional part of the global ``model`` on ``x``.

    ``conv1``..``conv3`` are each followed by a 2x2 max-pool and a ReLU;
    ``conv4`` is followed by a ReLU only.

    Args:
        x: input batch for ``model.conv1`` (presumably (N, 1, H, W)
            grayscale images -- TODO confirm against the model definition).

    Returns:
        A 2-D tensor with one flattened feature row per input sample.
    """
    feat = x
    for conv in (model.conv1, model.conv2, model.conv3):
        feat = F.relu(F.max_pool2d(conv(feat), 2))
    feat = F.relu(model.conv4(feat))
    # Keep the batch dimension, collapse everything else into one axis.
    return feat.view(feat.shape[0], -1)

def compare_feature(x1, h2):
    """Score raw inputs against pre-computed intermediate features.

    Args:
        x1: input batch passed through ``model.sub_forward``.
        h2: flattened features fed to ``model.fc1`` (presumably the output
            of ``extract_feature`` -- verify against the caller).

    Returns:
        The output of ``model.fc2`` applied to the element-wise absolute
        difference of the two embeddings (no sigmoid is applied here).
    """
    # Finish the forward pass for the cached features first, then embed x1.
    emb_cached = F.sigmoid(model.fc1(h2))
    emb_input = model.sub_forward(x1)
    return model.fc2(torch.abs(emb_input - emb_cached))

In [None]:
# Test split, read as an image-folder dataset. No transform is attached to the
# dataset itself; `transform` is applied by hand in the cells that consume the
# images (after converting them with .convert('L')).
from torchvision import datasets as dset
transform = transforms.ToTensor()
dataset = dset.ImageFolder(root='./data/changed/test')

# Extracts the intermediate features from the network

In [None]:
# extract intermediate features: one tensor per class, saved to 'features.pth'
#
# Consecutive samples that share a label are treated as one class, exactly as
# the label-change detection did before (this assumes the dataset yields each
# class as one contiguous run).
from itertools import groupby

features = []
with torch.no_grad():
    for label, samples in groupby(dataset, key=lambda sample: sample[1]):
        # Build the class batch in one go instead of growing a GPU tensor
        # image by image, which re-copied the whole batch on every append.
        batch = torch.stack(
            [transform(image.convert('L')) for image, _label in samples]
        ).cuda()
        # Under no_grad the result carries no graph, so .cpu() is sufficient.
        features.append(extract_feature(batch).cpu())
        # Report progress only once the class has actually been processed.
        if label % 10 == 9:
            print(label + 1, "classes done")
torch.save(features, 'features.pth')

In [None]:
# load the pre-computed features
# 'features.pth' is the list of per-class feature tensors written by
# torch.save(features, 'features.pth') in the extraction cell.
# NOTE(review): torch.load unpickles the file -- only load a features file
# that you produced yourself or otherwise trust.
features = torch.load('features.pth')

# Select images
##### There are 4 algorithms available: SP, GIGA, FW, RND
##### Since the selection may take some time, we do not provide a single function that does everything together.
##### Instead, choose the number of images to select (`M`) and the algorithm to use; running that algorithm's cell produces a set of selected indices.
##### Then run the last cell, which uses this set of indices, to test that case.

In [None]:
# SP
# Selects M samples per class by calling the MATLAB function `SP` on the
# transposed feature matrix of that class.
M = 3 # number of points
selected = []
eng = matlab.engine.start_matlab()
try:
    M_mat = matlab.double([M])
    for i, f in enumerate(features):
        print(i)  # progress: index of the class being processed
        f_mat = matlab.double(f.numpy().tolist())
        s = eng.SP(eng.transpose(f_mat), M_mat)
        # MATLAB indices are 1-based; shift them to 0-based row indices.
        selected.append([int(ind) - 1 for ind in s[0]])
finally:
    # Always shut the engine down, even if SP fails for some class.
    eng.quit()

# h[c] holds the selected feature rows of class c. Stack once instead of
# concatenating onto a growing tensor; keep the empty tensor for no classes.
if selected:
    h = torch.stack([f[idx] for idx, f in zip(selected, features)])
else:
    h = torch.Tensor(0)

In [None]:
# GIGA
# Selects M samples per class: the rows whose coreset weight is non-zero.
M = 3 # number of points
select_alg = bc.GIGA
selected = []

for f in features:
    # The chosen indices address rows of f, so fewer than M rows could never
    # satisfy the loop below; fail loudly instead of spinning forever.
    if f.size(0) < M:
        raise ValueError(
            'cannot select {} points from a class with only {} samples'
            .format(M, f.size(0))
        )
    alg = select_alg(f.numpy())
    alg.run(M)
    chosen = alg.weights().nonzero()[0]
    new_M = M
    # M iterations may give fewer than M non-zero weights; raise the
    # iteration budget by the shortfall until M points are selected.
    while len(chosen) < M:
        new_M += M - len(chosen)
        alg.run(new_M)
        chosen = alg.weights().nonzero()[0]
    selected.append(chosen.tolist())

# h[c] holds the selected feature rows of class c. Stack once instead of
# concatenating onto a growing tensor; keep the empty tensor for no classes.
if selected:
    h = torch.stack([f[idx] for idx, f in zip(selected, features)])
else:
    h = torch.Tensor(0)

In [None]:
# FW
# Selects M samples per class: the rows whose Frank-Wolfe coreset weight is
# non-zero.
M = 3 # number of points
select_alg = bc.FrankWolfe
selected = []

for f in features:
    # The chosen indices address rows of f, so fewer than M rows could never
    # satisfy the loop below; fail loudly instead of spinning forever.
    if f.size(0) < M:
        raise ValueError(
            'cannot select {} points from a class with only {} samples'
            .format(M, f.size(0))
        )
    alg = select_alg(f.numpy())
    alg.run(M)
    chosen = alg.weights().nonzero()[0]
    new_M = M
    # M iterations may give fewer than M non-zero weights; raise the
    # iteration budget by the shortfall until M points are selected.
    while len(chosen) < M:
        new_M += M - len(chosen)
        alg.run(new_M)
        chosen = alg.weights().nonzero()[0]
    selected.append(chosen.tolist())

# h[c] holds the selected feature rows of class c. Stack once instead of
# concatenating onto a growing tensor; keep the empty tensor for no classes.
if selected:
    h = torch.stack([f[idx] for idx, f in zip(selected, features)])
else:
    h = torch.Tensor(0)

In [None]:
# RND
# Baseline: pick M distinct samples per class uniformly at random.
M = 3 # number of points

# Sample without replacement so a class never gets the same image twice, and
# draw from the actual class size rather than a hard-coded 20.
selected = [
    np.random.choice(f.size(0), size=M, replace=False).tolist()
    for f in features
]

# h[c] holds the selected feature rows of class c. Stack once instead of
# concatenating onto a growing tensor; keep the empty tensor for no classes.
if selected:
    h = torch.stack([f[idx] for idx, f in zip(selected, features)])
else:
    h = torch.Tensor(0)

# Visualization example. The selected images are marked in red

In [None]:
# Show the first 20 images of the dataset on a 2 x 10 grid; images whose index
# is in selected[0] get a red 20x20 square pasted into the top-left corner.
plt.figure(figsize=(20, 2))
for idx, (img, _label) in enumerate(dataset):
    if idx == 20:
        break

    axis = plt.subplot(2, 10, idx + 1)
    if idx in selected[0]:
        marker = Image.new('RGBA', (20, 20), 'red')
        img.paste(marker, (0, 0))
    plt.imshow(img)
    # Hide both axes so only the image itself is drawn.
    for side in (axis.get_xaxis(), axis.get_yaxis()):
        side.set_visible(False)

plt.tight_layout(pad=0, w_pad=-70)
plt.show()

# Test the accuracy

In [None]:
# Classify every non-selected image by comparing it with the selected features
# of each class (rows of `h`); the predicted class is the one with the highest
# mean sigmoid score. The dataset is assumed to hold 20 consecutive images per
# class, which is how `selected` is indexed below.
images_per_class = 20
with torch.no_grad():
    correct = 0
    count = 0
    # Upload the selected features once instead of once per image and class.
    h_gpu = h.cuda()
    for i, (x, y) in enumerate(dataset):
        if i % 100 == 99:
            print('processed numbers:', i + 1)
            if count:  # nothing evaluated yet -> no accuracy to report
                print('current accuracy:', correct / count)
        # Images that were selected as class representatives are not tested.
        if i % images_per_class in selected[i // images_per_class]:
            continue
        count += 1
        x = transform(x.convert('L')).unsqueeze(0).cuda()
        # Repeat the image so it pairs with each selected feature of a class.
        x = torch.cat(h_gpu.size(1) * [x], 0)
        all_out = [
            F.sigmoid(compare_feature(x, h2)).mean().item()
            for h2 in h_gpu
        ]
        pred = np.argmax(all_out)
        correct += int(pred == y)
# Avoid a ZeroDivisionError when no image was evaluated.
acc = correct / count if count else float('nan')
print('\nfinal accuracy:', acc)