In [1]:
from models.ResnetPatchAE import PatchAutoEncoder
# from models.ConvPatchAE import PatchAutoEncoder
import torchvision.transforms as T
from PIL import Image
import torch.nn as nn
import numpy as np
import torchvision
import torch
import math

print(torch.__version__)
print(torchvision.__version__)
print(Image.__version__)

1.10.1
0.11.2
8.4.0


In [2]:
class PatchCIFAR100(torchvision.datasets.CIFAR100):
    """Overrides torchvision CIFAR100 to return patches and class targets
    """
    def __init__(self, transforms=None, grid_size=4, **kwds):
        super().__init__(**kwds)
        self.transforms = transforms
        self.grid_size = grid_size

    def __getitem__(self, index):
        image, class_targets = self.data[index], self.targets[index]

        if len(image.shape) == 2:
            image = gray2rgb(image)

        if self.transforms is not None:
            image = self.transforms(image)  

        shape = np.array(image.shape)
        patch_rw, patch_cl = shape[1]//self.grid_size, shape[2]//self.grid_size

        scale = T.Compose([T.Resize((patch_rw*self.grid_size, patch_cl*self.grid_size))])
        padding = torch.nn.ZeroPad2d((patch_cl, patch_cl, patch_rw, patch_rw))
        img = padding(scale(image))
        patches = img.data.unfold(0, 3, 3).unfold(1, patch_rw, patch_rw).unfold(2, patch_cl, patch_cl)

        neighbours = torch.zeros(self.grid_size*self.grid_size, 8, shape[0], patch_rw, patch_cl)
        target = torch.zeros(self.grid_size*self.grid_size, shape[0], patch_rw, patch_cl)

        k = 0

        for i in range(1, self.grid_size+1):
            for j in range(1, self.grid_size+1):

                neighbours[k, 0, :, :, :] = patches[0, i-1, j-1, :, :, :]
                neighbours[k, 1, :, :, :] = patches[0, i-1, j, :, :, :]
                neighbours[k, 2, :, :, :] = patches[0, i-1, j+1, :, :, :]
                neighbours[k, 3, :, :, :] = patches[0, i, j-1, :, :, :]
                target[k, :, :, :] = patches[0, i, j, :, :, :]
                neighbours[k, 4, :, :, :] = patches[0, i, j+1, :, :, :]
                neighbours[k, 5, :, :, :] = patches[0, i+1, j-1, :, :, :]
                neighbours[k, 6, :, :, :] = patches[0, i+1, j, :, :, :]
                neighbours[k, 7, :, :, :] = patches[0, i+1, j+1, :, :, :]

                k += 1           

        return target, class_targets

In [3]:
transforms = T.Compose([
                        T.ToPILImage(),
                        T.ToTensor(),
                      ])
batch_size = 128
grid_size=4

train_set_lineval = PatchCIFAR100(transforms=transforms, 
                     grid_size=grid_size,
                     root='data', train=True, 
                     )
test_set_lineval = PatchCIFAR100(transforms=transforms, 
                     grid_size=grid_size,
                     root='data', train=False, 
                     )

train_loader_lineval = torch.utils.data.DataLoader(train_set_lineval, batch_size=batch_size, shuffle=True)
test_loader_lineval = torch.utils.data.DataLoader(test_set_lineval, batch_size=batch_size, shuffle=False)

In [4]:
backbone_lineval = PatchAutoEncoder(in_channels=3, out_channels=64, flatten=True)
cuda = True
epoch_num = 120

ckp = torch.load('ckpts/cifar_100/resnet_ae_l2_0001_3_augmneted/checkpoint_' + str(epoch_num) + '.ckp', 'cuda' if cuda else None)
backbone_lineval.load_state_dict(ckp['state_dict'])

Channels:  [16, 32, 64]
Flatten:  True


<All keys matched successfully>

In [5]:
linear_layer = torch.nn.Sequential(torch.nn.Linear(64*16, 100))

def backbone_output(model, data):
    aggregate = torch.zeros((data.shape[0], 16, 64))
    for i in range(data.shape[1]):
        output = model(data[:, i])
        aggregate[:, i] = output
        
    return aggregate.reshape((data.shape[0], 16*64))

In [6]:
optimizer = torch.optim.Adam(linear_layer.parameters())                               
CE = torch.nn.CrossEntropyLoss()
linear_layer.train()
backbone_lineval.encoder.tower.eval()

print('Linear evaluation')
for epoch in range(5):
    accuracy_list = list()
    step = 0
    
    for i, (data, target) in enumerate(train_loader_lineval):
        optimizer.zero_grad()
        output = backbone_output(backbone_lineval.encoder.tower, data).detach()        
        output = linear_layer(output)
        loss = CE(output, target)
        loss.backward()
        optimizer.step()
        
        # Estimate the accuracy
        prediction = output.argmax(-1)
        correct = prediction.eq(target.view_as(prediction)).sum()
        accuracy = (100.0 * correct / len(target))
        accuracy_list.append(accuracy.item())
        
    print('Epoch [{}] loss: {:.5f}; accuracy: {:.2f}%' \
            .format(epoch+1, loss.item(), sum(accuracy_list)/len(accuracy_list)))  

Linear evaluation
Epoch [1] loss: 3.40884; accuracy: 18.32%
Epoch [2] loss: 2.88396; accuracy: 30.58%
Epoch [3] loss: 2.51456; accuracy: 35.09%
Epoch [4] loss: 2.45039; accuracy: 37.98%
Epoch [5] loss: 2.50128; accuracy: 40.05%


In [7]:
accuracy_list = list()
for i, (data, target) in enumerate(test_loader_lineval):
    output = backbone_output(backbone_lineval.encoder.tower, data).detach()
    output = linear_layer(output)
     # Estimate the accuracy
    prediction = output.argmax(-1)
    correct = prediction.eq(target.view_as(prediction)).sum()
    accuracy = (100.0 * correct / len(target))
    accuracy_list.append(accuracy.item())

print('Test accuracy: {:.2f}%'.format(sum(accuracy_list)/len(accuracy_list)))

Test accuracy: 32.91%


| Model      | Test Accuracy | Epoch|
| ---------- | ----------- | ----------- |
| ResnetPatchAEInner      | 10.18 %  |  100  |
| ResnetPatchAE   |      27.30%   | 100|
| ResnetPatchAE   |      30.46%   | 10|
| ResnetPatchAEAugmented   |      34.08%   | 10|
| ResnetPatchAEAugmented   |      33.18%   | 30|
| ResnetAEAugmented   |      16.15%   | 10|
| ResnetAEAugmented   |      17.81%   | 100|
| Conv4AEAugmented   |      24.55%   | 10|