# Does ReDO learn classification ?

To run this notebook you have to have following libraries installed (tested with specified verisinons):

 * numpy == 1.17.4

 * scipy == 1.3.3

 * PIL == 4.3.0

 * sklearn == 0.21.3

 * pytorch == 1.3.1

The directories should be organized as follows: 
 
1) In the same directory with this notebook there should be downloaded repositoriy of [ReDO](https://github.com/mickaelChen/ReDO.git). 
 
2) Download and extract: Dataset, Segmentations, Image labels and data splits from http://www.robots.ox.ac.uk/~vgg/data/flowers/102/. The obtained jpg folder, segmin folder and setid.mat file should be placed in the folder 'data/flowers'.

3) Download and unzip weigth [dataset_nets_state.tar.gz](https://drive.google.com/drive/folders/1hUb2iOTJAbWw1NotWGAsEt4ASomhOwbh) into 'weights' folder

 We provide script "prepare_working_directory.sh" which does everything **except step 3**, which need to be done manually.


In [0]:
import os
import random

import numpy as np
from scipy import io
from PIL import Image
from sklearn import neighbors
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.nn.utils import spectral_norm
from torchvision.utils import save_image

import torchvision
from  torchvision import datasets, transforms


import ReDO.models as models
import ReDO.datasets as datasets

device = torch.device('cuda:0')
DATAPATH = 'data/flowers'
WEIGHTPATH = 'weights'

Here we redefine classes of some Neural Networks provided by ReDO and the dataset class

In [0]:
# We redefine class of segmentator 
# to make it return embeddings from the middle part of the network
class _netEncM(nn.Module):
    def __init__(self, sizex=128, nIn=3, nMasks=2, nRes=5, nf=128, temperature=1):
        super(_netEncM, self).__init__()
        self.nMasks = nMasks
        sizex = sizex // 4 
        self.cnn = nn.Sequential(*([models._downConv(nIn, nf)] +
                                   [models._resBloc(nf=nf) for i in range(nRes)]))
        self.psp = nn.ModuleList([nn.Sequential(nn.AvgPool2d(sizex),
                                                nn.Conv2d(nf,1,1),
                                                nn.Upsample(size=sizex, mode='bilinear')),
                                  nn.Sequential(nn.AvgPool2d(sizex//2, sizex//2),
                                                nn.Conv2d(nf,1,1),
                                                nn.Upsample(size=sizex, mode='bilinear')),
                                  nn.Sequential(nn.AvgPool2d(sizex//3, sizex//3),
                                                nn.Conv2d(nf,1,1),
                                                nn.Upsample(size=sizex, mode='bilinear')),
                                  nn.Sequential(nn.AvgPool2d(sizex//6, sizex//6),
                                                nn.Conv2d(nf,1,1),
                                                nn.Upsample(size=sizex, mode='bilinear'))])
        self.out = models._upConv(1 if nMasks == 2 else nMasks, nf+4)
        self.temperature = temperature
    def forward(self, x):
        f = self.cnn(x)
        # m = self.out(torch.cat([f] + [pnet(f) for pnet in self.psp], 1))
        # if self.nMasks == 2:
        #     m = torch.sigmoid(m / self.temperature)
        #     m = torch.cat((m, (1-m)), 1)
        # else:
        #     m = F.softmax(m / self.temperature, dim=1)
        return f

# We redefine class of the Discriminator to make it
# return embeddings from the last FC layerinstead of real/fake image prediction
class _resDiscriminator128(nn.Module):
    def __init__(self, nIn=3, nf=64, selfAtt=False):
        super(_resDiscriminator128, self).__init__()
        self.blocs = []
        self.sc = []
        # first bloc
        self.bloc0 = nn.Sequential(spectral_norm(nn.Conv2d(nIn, nf, 3, 1, 1, bias=True)),
                                   nn.ReLU(),
                                   spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)),
                                   nn.AvgPool2d(2),)
        self.sc0 = nn.Sequential(nn.AvgPool2d(2),
                                 spectral_norm(nn.Conv2d(nIn, nf, 1, bias=True)),)
        if selfAtt:
            self.selfAtt = models.SelfAttention(nf)
        else:
            self.selfAtt = nn.Sequential()
        # Down blocs
        for i in range(4):
            nfPrev = nf
            nf = nf*2
            self.blocs.append(nn.Sequential(nn.ReLU(),
                                            spectral_norm(nn.Conv2d(nfPrev, nf, 3, 1, 1, bias=True)),
                                            nn.ReLU(),
                                            spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)),
                                            nn.AvgPool2d(2),))
            self.sc.append(nn.Sequential(nn.AvgPool2d(2),
                                         spectral_norm(nn.Conv2d(nfPrev, nf, 1, bias=True)),))
        # Last Bloc
        self.blocs.append(nn.Sequential(nn.ReLU(),
                                        spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)),
                                        nn.ReLU(),
                                        spectral_norm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))))
        self.sc.append(nn.Sequential())
        self.dense = nn.Linear(nf, 1)
        self.blocs = nn.ModuleList(self.blocs)
        self.sc = nn.ModuleList(self.sc)
    def forward(self, x):
        x = self.selfAtt(self.bloc0(x) + self.sc0(x))
        for k in range(len(self.blocs)):
            x = self.blocs[k](x) + self.sc[k](x)
        x = x.sum(3).sum(2)
        return x


# We redefine class of Dataset to be able to store labels of images
class FlowersDataset(torch.utils.data.Dataset):
    def __init__(self, dataPath, sets='train', transform=transforms.ToTensor()):
        super(FlowersDataset, self).__init__()
        self.files =  io.loadmat(os.path.join(dataPath, "setid.mat"))
        self.labels = io.loadmat(os.path.join(dataPath, "imagelabels.mat")).get('labels')[0]
        # if sets == 'train':
        #    self.files = self.files.get('tstid')[0]
        # elif sets == 'val':
        #    self.files = self.files.get('valid')[0]
        # else:
        #    self.files = self.files.get('trnid')[0]
        self.transform = transform
        self.datapath = dataPath
    def __len__(self):
        return len(self.files.get('tstid')[0]) + \
               len(self.files.get('valid')[0]) + \
               len(self.files.get('trnid')[0])
    def __getitem__(self, idx):
        imgname = "image_%05d.jpg" % (idx + 1)
        segname = "segmim_%05d.jpg" % (idx + 1)
        label   = self.labels[idx]
        img = self.transform(Image.open(os.path.join(self.datapath, "jpg", imgname)))
        seg = np.array(Image.open(os.path.join(self.datapath, "segmim", segname)))
        seg = 1 - ((seg[:,:,0:1] == 0) + (seg[:,:,1:2] == 0) + (seg[:,:,2:3] == 254))
        seg = (seg * 255).astype('uint8').repeat(3,axis=2)
        seg = self.transform(Image.fromarray(seg))[:1]
        return img * 2 - 1, seg, label

Load weights

In [0]:
states = torch.load(WEIGHTPATH, map_location={'cuda:0' : 'cuda:0'})
opt = states['options']
if "netEncM" in states:
    netEncM = _netEncM(sizex=opt.sizex, nIn=opt.nx, nMasks=opt.nMasks, nRes=opt.nResM, nf=opt.nfM, temperature=opt.temperature).to(device)
    netEncM.load_state_dict(states["netEncM"])
    netEncM.eval()
if "netGenX" in states:
    netGenX = models._netGenX(sizex=opt.sizex, nOut=opt.nx, nc=opt.nz, nf=opt.nfX, nMasks=opt.nMasks, selfAtt=opt.useSelfAttG).to(device)
    netGenX.load_state_dict(states["netGenX"])
    netGenX.eval()
if "netRecZ" in states:
    netRecZ = models._netRecZ(sizex=opt.sizex, nIn=opt.nx, nc=opt.nz, nf=opt.nfZ, nMasks=opt.nMasks).to(device)
    netRecZ.load_state_dict(states["netRecZ"])
    netRecZ.eval()
if "netDX" in states:
    netDX = _resDiscriminator128(nIn=opt.nx, nf=opt.nfD, selfAtt=opt.useSelfAttD).to(device)
    netDX.load_state_dict(states["netDX"])
    netDX.eval()

Load dataset

In [0]:
dataset = FlowersDataset(DATAPATH, "train",
            torchvision.transforms.Compose([torchvision.transforms.Resize(opt.sizex, Image.NEAREST),
                                            torchvision.transforms.CenterCrop(opt.sizex),
                                            torchvision.transforms.ToTensor(),
            ]))

Create a function to automatize evaluation. The function takes callback function which takes data sample and computes embeddings. Then it trains KNN on the specified number of batches and compute accuracy on the rest of the dataset.

In [0]:
def evaluate_embedding(get_embedding, batch_size=100, batches_for_train=2):
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Compute embeddings
    train_X, train_y = get_embedding(loader, n_batches=batches_for_train)
    test_X, test_y = get_embedding(loader, n_batches=len(loader) - batches_for_train)
    print('Sizes are: ', train_X.size(), test_X.size())

    # Train KNN
    clf = neighbors.KNeighborsClassifier(1, weights='distance')
    clf.fit(train_X, train_y)

    # Compute accuracy
    pred_y = clf.predict(test_X)
    accuracy = accuracy_score(test_y, pred_y)
    print('Accuracy: ', accuracy)
    return accuracy

# Expirements

## Raw Images KNN 

In [0]:
def get_data_raw(loader, n_batches=1):
    labels = []
    embeddings = []
    with torch.no_grad():
      for i in range(n_batches):
        xData, mData, batch_labels = next(iter(loader))
        xData = xData.to(device)
        mData = mData.to(device)
        batch_embeddings = netEncM(xData).reshape(100, -1)

        labels.append(batch_labels)
        embeddings.append(batch_embeddings)

    return torch.cat(embeddings, dim=0).cpu() , torch.cat(labels, dim=0).cpu()

In [0]:
evaluate_embedding(get_data_raw)

## Raw images resized to 3x16x16

In [0]:
def get_data_raw_resized(loader, n_batches=1):
    labels = []
    embeddings = []
    with torch.no_grad():
        for i in range(n_batches):
            xData, mData, batch_labels = next(iter(loader))
            xData = xData.to(device)
            mData = mData.to(device)
            batch_embeddings = F.max_pool2d(xData, 8).reshape(100, -1)

            labels.append(batch_labels)
            embeddings.append(batch_embeddings)

    return torch.cat(embeddings, dim=0).cpu() , torch.cat(labels, dim=0).cpu()

In [0]:
evaluate_embedding(get_data_raw_resized)

## KNN on Segmentation embeddings

In [0]:
def get_data_from_segm(loader, n_batches=1):
    labels = []
    embeddings = []
    with torch.no_grad():
        for i in range(n_batches):
            xData, mData, batch_labels = next(iter(loader))
            xData = xData.to(device)
            mData = mData.to(device)
            batch_embeddings = netEncM(xData).reshape(100, -1)

            labels.append(batch_labels)
            embeddings.append(batch_embeddings)


    return torch.cat(embeddings, dim=0).cpu() , torch.cat(labels, dim=0).cpu()

In [0]:
evaluate_embedding(get_data_from_segm)

## KNN on Discriminator Embeddings

In [0]:
def get_data_from_descr(loader, n_batches=1):
    labels = []
    embeddings = []
    with torch.no_grad():
        for i in range(n_batches):
            xData, mData, batch_labels = next(iter(loader))
            xData = xData.to(device)
            mData = mData.to(device)
            batch_embeddings = netDX(xData)

            labels.append(batch_labels)
            embeddings.append(batch_embeddings) #xData.reshape(100, -1))


    return torch.cat(embeddings, dim=0).cpu() , torch.cat(labels, dim=0).cpu()

In [0]:
evaluate_embedding(get_data_from_descr)

0.229625

### The below code is redundant and we didn't use it in our report.

In [0]:
class AutoEncoder(nn.Module):
    def __init__(self, code_size):
        super().__init__()
        self.code_size = code_size
        # Encoder
        self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
        self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
        self.enc_cnn_3 = nn.Conv2d(20, 30, kernel_size=5)
        self.enc_linear_1 = nn.Linear(30 * 12 * 12, 2048)
        self.enc_linear_2 = nn.Linear(2048, self.code_size)
        # Decoder
        self.dec_linear_1 = nn.Linear(self.code_size, 2048)
        self.dec_linear_2 = nn.Linear(2048, 3* IMAGE_WIDTH * IMAGE_HEIGHT)
    def forward(self, images):
        code = self.encode(images)
        out = self.decode(code)
        return out, code
    def encode(self, images):
        code = self.enc_cnn_1(images)
        code = F.selu(F.max_pool2d(code, 2))
        code = self.enc_cnn_2(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = self.enc_cnn_3(code)
        code = F.selu(F.max_pool2d(code, 2))
        code = code.view([images.size(0), -1])
        code = F.selu(self.enc_linear_1(code))
        code = self.enc_linear_2(code)
        return code
    def decode(self, code):
        out = F.selu(self.dec_linear_1(code))
        out = torch.sigmoid(self.dec_linear_2(out))
        out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
        return out

IMAGE_WIDTH = IMAGE_HEIGHT = 128

In [0]:
# Hyperparameters
code_size = 1024
num_epochs = 5
batch_size = 128
lr = 0.01
optimizer_cls = optim.Adam
# Load data
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Instantiate model
autoencoder = AutoEncoder(code_size).to(device)
loss_fn = nn.MSELoss()
optimizer = optimizer_cls(autoencoder.parameters(), lr=lr)
# Training loop
for epoch in range(num_epochs):
    print("Epoch %d" % epoch)
    for i, (images, _, _) in enumerate(loader):    # Ignore image labels
        out, code = autoencoder(Variable(images).to(device))
        optimizer.zero_grad()
        loss = loss_fn(out, images.to(device))
        loss.backward()
        optimizer.step()
        print("Loss = %.3f" % loss.data)
  
    # update LR
    lr /= 10
    for g in optimizer.param_groups:
        g['lr'] = lr

In [0]:
def get_data_from_autoenc(loader, n_batches=1):
    labels = []
    embeddings = []
    with torch.no_grad():
        for i in range(n_batches):
            xData, mData, batch_labels = next(iter(loader))
            xData = xData.to(device)
            mData = mData.to(device)
            batch_embeddings = autoencoder.encode(xData)

            labels.append(batch_labels)
            embeddings.append(batch_embeddings) 
  
  return torch.cat(embeddings, dim=0).cpu() , torch.cat(labels, dim=0).cpu()

In [41]:
evaluate_embedding(get_data_from_autoenc, batch_size=128)

Sizes are:  torch.Size([256, 1024]) torch.Size([7936, 1024])
Accuracy:  0.06149193548387097


0.06149193548387097