# Loading Modules

In [104]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from LogGabor import LogGabor
from utils import view_data
from typing import List, Tuple

# Loading the Data

In [105]:
#args.offset_max = 40 #like in the paper

In [106]:
image_size = 256

transform =  transforms.Compose([
            transforms.Resize((int(image_size), int(image_size))),
            #transforms.AutoAugment(), # https://pytorch.org/vision/master/transforms.html#torchvision.transforms.AutoAugment
            transforms.ToTensor(),      # Convert the image to pyTorch Tensor data type.
        ])

In [107]:
image_path = "../data/animal/"

image_dataset = { 'train' : datasets.ImageFolder(
                            image_path+'train', 
                            transform=transform
                        ),
                  'test' : datasets.ImageFolder(
                            image_path+'test', 
                            transform=transform
                        )
                }

In [108]:
dataset_size = {'train' : len(image_dataset['train']),
                'test' : len(image_dataset['test'])}

dataset_size['train'], dataset_size['test']

(2000, 1200)

In [109]:
batch_size = 50
num_workers = 1

dataloader = { 'train' : torch.utils.data.DataLoader(
                            image_dataset['train'], batch_size=batch_size,
                            shuffle=True, 
                            num_workers=num_workers,
                        ),
               'test' : torch.utils.data.DataLoader(
                            image_dataset['test'], batch_size=batch_size,
                            shuffle=True, 
                            num_workers=num_workers,
                        )
             }

- The original format

# Creating an Attention Transformer model with log-polar entry (POLO-STN)

In [None]:
class Polo_AttentionTransNet(nn.Module):
    
    def __init__(self):
        super(Polo_AttentionTransNet, self).__init__()

        ##  The what pathway
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 50, 5)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 2)


        self.downscale = nn.Parameter(torch.tensor([[.2, 0], [0, .2]], dtype=torch.float), requires_grad=False)

    def stn(self: object, x:torch.Tensor) -> Tuple[torch.Tensor]:
        
        x = torch.tensor(x, dtype=torch.float)
        mu = torch.tensor([0, 0],dtype=torch.float)
        mu = mu.unsqueeze(0).repeat(x.size()[0], 1)   
        std = np.exp(-3/2)
        sigma = torch.tensor([std, std],dtype=torch.float)
        sigma = sigma.unsqueeze(0).repeat(x.size()[0], 1)   
              
        self.q = torch.distributions.Normal(mu, sigma)
        z = self.q.rsample()
        print(z[0])
        
        theta = torch.cat((self.downscale.unsqueeze(0).repeat(x.size(0), 1, 1), 
                           z.unsqueeze(2)),
                            dim=2)
        
        grid_size = torch.Size([x.size()[0], x.size()[1], 28, 28])
        grid = F.affine_grid(theta, grid_size)
        x = F.grid_sample(x, grid)
       
        return x, theta, z

    def forward(self, x):
        # transform the input
        x, theta, z = self.stn(x)                                   

        # print(x.shape)
        # Perform the usual forward pass
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)

        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x, theta, z

In [None]:
def train(epoch, loader):
    model.train()
    for batch_idx, (data, target) in enumerate(loader):

        data, target = data.to(device, dtype=torch.double), target.to(device)

        optimizer.zero_grad()
        output, theta, z = model(data)
        
        loss = loss_func(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, n_epochs, batch_idx * len(data),
                len(dataloader['train'].dataset),
                100. * batch_idx / len(dataloader['train']), loss.item()))


def test(loader):
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in loader:
            data, target = data.to(device, dtype=torch.double), target.to(device)

            optimizer.zero_grad()
            output, theta, z = model(data)

            # sum up batch loss
            #test_loss += F.nll_loss(output, target, size_average=False).item()
            test_loss += loss_func(output, target).item()
            # get the index of the max log-probability
            #pred = output.max(1, keepdim=True)[1]
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(dataloader['test'].dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.
              format(test_loss, correct, len(dataloader['test'].dataset),
                     100. * correct / len(dataloader['test'].dataset)))
        return correct / len(dataloader['test'].dataset)

# Training 

In [None]:
lr = 1e-4
log_interval = 1000

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = torch.load("../models/low_comp_polo_stn.pt")
model = Polo_AttentionTransNet().to(device)

In [None]:
optimizer = optim.Adam(model.parameters(), lr=lr)
#optimizer = optim.SGD(model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss()
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9, last_epoch=-1) #, verbose=True)

In [None]:
acc = []

In [None]:
n_epochs = 100
for epoch in range(n_epochs):
    train(epoch, dataloader['train'])
    curr_acc = test(dataloader['test'])
    acc.append(curr_acc)
    #scheduler.step()

tensor([0.1019, 0.2776])
tensor([-0.0268, -0.3799])
tensor([-0.1220,  0.1738])
tensor([0.0372, 0.1412])
tensor([0.0736, 0.1394])
tensor([-0.0076,  0.2421])
tensor([0.0826, 0.2255])
tensor([-0.2906, -0.1683])
tensor([0.1116, 0.0828])
tensor([-0.3784, -0.0180])
tensor([-0.0310,  0.0845])
tensor([0.0591, 0.0487])
tensor([-0.0926,  0.0868])
tensor([-0.0143,  0.3228])
tensor([-0.0347, -0.2162])
tensor([ 0.2330, -0.3177])
tensor([-0.0664,  0.0149])
tensor([-0.0299,  0.1847])
tensor([-0.2768, -0.3199])
tensor([0.1915, 0.0041])
tensor([ 0.4675, -0.0697])
tensor([-0.1199, -0.1278])
tensor([-0.1662, -0.2571])
tensor([-0.3072,  0.1417])
tensor([ 0.3294, -0.1055])
tensor([-0.0003, -0.0719])
tensor([0.3138, 0.3982])
tensor([ 0.1455, -0.1690])
tensor([-0.1028,  0.2212])
tensor([0.1042, 0.2828])
tensor([-0.0520,  0.0492])
tensor([ 0.0715, -0.0718])
tensor([-0.4023,  0.1579])
tensor([-0.2272, -0.0321])
tensor([-0.2465,  0.0348])
tensor([ 0.1591, -0.0505])
tensor([ 0.1451, -0.1796])
tensor([-0.1259, -0

tensor([-0.1028,  0.1047])
tensor([-0.5436,  0.0224])
tensor([-0.0150, -0.0274])
tensor([0.0546, 0.4162])
tensor([-0.1689, -0.0807])
tensor([-0.1219,  0.0825])
tensor([-0.2460, -0.2334])
tensor([-0.1049,  0.0311])
tensor([-0.2416, -0.1245])
tensor([-0.0455,  0.3503])
tensor([-0.3364,  0.1479])
tensor([-0.3502, -0.1620])
tensor([0.2855, 0.1835])
tensor([0.1900, 0.1576])
tensor([ 0.3597, -0.0682])
tensor([0.5920, 0.2266])
tensor([-0.1606, -0.1046])
tensor([-0.0165,  0.3467])
tensor([-0.1046, -0.5305])
tensor([-0.6196,  0.1370])
tensor([-0.2635, -0.0236])
tensor([-0.1739, -0.1579])
tensor([-0.4512, -0.2809])
tensor([0.3019, 0.1637])
tensor([0.0138, 0.2126])
tensor([ 0.1390, -0.2071])
tensor([0.1081, 0.2030])
tensor([0.0072, 0.4914])
tensor([ 0.2359, -0.0129])

Test set: Average loss: 0.0133, Accuracy: 715/1200 (60%)

tensor([0.1346, 0.1498])
tensor([ 0.4373, -0.4670])
tensor([0.2092, 0.2250])
tensor([0.0407, 0.0118])
tensor([-0.5206, -0.0470])
tensor([0.0543, 0.3602])
tensor([-0.1907,  0.

In [None]:
plt.plot(acc)

In [None]:
torch.save(model, "stn_lenet_baseline.pt")

In [None]:
data, label = next(iter(dataloader['test']))

In [None]:
output = model.stn(data)

In [None]:
num=1
with torch.no_grad():
    cat, theta, z = model(data)
cat = torch.argmax(F.softmax(cat),1)

In [None]:
#theta = torch.cat((model.downscale.unsqueeze(0).repeat(
#            theta.size(0), 1, 1), theta.unsqueeze(2)),
#                          dim=2)
        
        #theta = theta.view(-1, 2, 3)
for num in range(50):
    plt.figure(figsize=(5,10))
    plt.subplot(1,2,1)
    img = data[num:num+1,...]   
    th = theta[num:num+1,...]
    #theta[0,:,2] = torch.FloatTensor((0,.5))
    #print(th)
    grid_size = torch.Size([1, 3, 28, 28])
    grid = F.affine_grid(th, grid_size)
    img_grid = F.grid_sample(img, grid)
    plt.imshow(data[num,...].permute(1,2,0).detach().numpy())
    plt.title(str(num)+', '+str(label[num]))
    plt.subplot(1,2,2)
    plt.imshow(img_grid[0,:].permute(1,2,0).detach().numpy())
    plt.plot(14,14,'r+')
    plt.title(cat[num])

In [None]:
#theta = torch.cat((model.downscale.unsqueeze(0).repeat(
#            theta.size(0), 1, 1), theta.unsqueeze(2)),
#                          dim=2)
mem_z  = []
mem_cat  = []

        #theta = theta.view(-1, 2, 3)
for num in [43,]*50: #range(50):
    with torch.no_grad():
        #img = data[0][num,...].unsqueeze(0)
        #polo_img = {'in': data[1]['in'][num,...].unsqueeze(0),
        #             'out': data[1]['out'][num,...].unsqueeze(0)}
        cat, theta, z = model(data)
    cat = torch.argmax(F.softmax(cat),1)
    plt.figure(figsize=(5,10))
    plt.subplot(1,2,1)
    img = data[num:num+1,...]   
    th = theta[num:num+1,...]
    mem_z.append(z)
    mem_cat.append(cat)
    #theta[0,:,2] = torch.FloatTensor((0,.5))
    #print(th)
    grid_size = torch.Size([1, 3, 28, 28])
    grid = F.affine_grid(th, grid_size)
    img_grid = F.grid_sample(img, grid)
    plt.imshow(data[num,...].permute(1,2,0).detach().numpy())
    plt.title(label[num])
    plt.subplot(1,2,2)
    plt.imshow(img_grid[0,:].permute(1,2,0).detach().numpy())
    plt.plot(14,14,'r+')
    plt.title(cat[num])


In [None]:
for num in range(batch_size):
    plt.figure()
    plt.imshow(data[num,...].permute(1,2,0).detach().numpy())
    for i, z in enumerate(mem_z):
        if mem_cat[i][num] == 1:
            plt.plot(127.5 + z[num][0]*128, 127.5 + z[num][1]*128,'r+')
        else:
            plt.plot(127.5 + z[num][0]*128, 127.5 + z[num][1]*128,'b+')