In [22]:
import torch
import pprint as pp
import torchvision
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import numpy as np
from operator import add
import matplotlib.pyplot as plt

In [23]:
# hyperparameters

batch_size = 64  # mini-batch size handed to both the train and the test DataLoader
num_caps = 10  # default number of Capsule modules created by TAE
learning_rate = 0.01  # `lr` passed to the Adam optimizer
weight_decay = 1e-5  # NOTE(review): not read by any visible cell -- confirm whether the optimizer should use it
momentum = 0.5  # NOTE(review): not read by any visible cell -- confirm whether it is still needed
# Disable cuDNN for the whole process (the notebook does not state why -- TODO confirm it is still required).
torch.backends.cudnn.enabled = False

In [24]:
# Build the MNIST train/test data loaders.

# Per-image preprocessing: convert to a tensor, then normalise with the
# fixed mean / std pair given below.
transform = Compose([
    ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    # transforms.RandomRotation(20)
])

mnist_train = datasets.MNIST('/files/', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('/files/', train=False, download=True, transform=transform)

# Both loaders share the same settings: shuffled, and the final incomplete
# batch is dropped so every batch holds exactly `batch_size` samples.
loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=True)
train_dataloader = DataLoader(mnist_train, **loader_options)
test_dataloader = DataLoader(mnist_test, **loader_options)

# Sanity check: report the tensor shapes of the first test batch only.
for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64


In [25]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [32]:
def shift_image(image):
    """Circularly shift `image` by a random integer offset along its last two axes.

    One offset in [-3, 3] is drawn per axis from NumPy's global random state
    and applied to the whole tensor (so every sample of a batch receives the
    same shift).

    Args:
        image: tensor whose last axis is rolled by `dx` and whose
            second-to-last axis is rolled by `dy`.

    Returns:
        Tuple `(shifted, dx, dy)`: the rolled tensor and the two offsets.
    """
    # Draw order (dx first, then dy) is kept so a seeded run is reproducible.
    dx = np.random.randint(-3, 4)
    dy = np.random.randint(-3, 4)
    shifted = torch.roll(image, shifts=(dx, dy), dims=(-1, -2))
    return shifted, dx, dy

In [33]:
class Capsule(nn.Module):
    """Small convolutional encoder/decoder pair with a 3-value bottleneck.

    `encode` maps a (N, 1, 28, 28) image batch to a scalar `p` and a
    2-vector `xy` per sample (all three squashed by a sigmoid).
    `decode` maps a 2-vector per sample back to a (N, 1, 28, 28) image.
    """

    # Shape (channels, height, width) of the feature map fed to the
    # transposed-convolution decoder.
    _LATENT_SHAPE = (4, 7, 7)

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # N, 1, 28, 28
            nn.Conv2d(1, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2),  # 14

            nn.Conv2d(16, 8, 3, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.MaxPool2d(2),  # 7
        )

        # Maps the 8 spatially-averaged encoder channels to (p, x, y).
        self.lin = nn.Sequential(
            nn.Dropout(),
            nn.Linear(8, 3),
            nn.Sigmoid()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(4, 2, 3, stride=2, padding=1, bias=False, output_padding=1),  # B, 2, 14, 14
            nn.BatchNorm2d(2),
            nn.ReLU(True),

            nn.ConvTranspose2d(2, 1, 3, stride=2, padding=1, bias=False, output_padding=1),  # B, 1, 28, 28
        )

        # Expands a 2-vector into the flattened 4x7x7 decoder input.
        self.decoder_lin = nn.Sequential(
            nn.Linear(2, 4*49),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Encode an image batch.

        Args:
            x: tensor of shape (N, 1, 28, 28).

        Returns:
            Tuple `(p, xy)` with shapes (N, 1) and (N, 2): column 0 and
            columns 1-2 of the sigmoid output of `self.lin`.
        """
        x = self.encoder(x)
        # Global average over the two spatial axes -> (N, 8).
        x = x.mean(-1).mean(-1)
        z = self.lin(x)
        p = z[:, [0]]
        xy = z[:, [1, 2]]
        return p, xy

    def decode(self, xy):
        """Decode 2-vectors into images.

        Args:
            xy: tensor of shape (N, 2) for any batch size N.

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        xy = self.decoder_lin(xy)
        # Infer the batch dimension instead of hard-coding it, so batches of
        # any size (e.g. a final partial batch) can be decoded.
        xy = xy.reshape(-1, *self._LATENT_SHAPE)
        return self.decoder(xy)
    
# Stand-alone Capsule moved to the selected device. No other visible cell reads
# `capsule`; TAE constructs its own Capsule instances.
capsule = Capsule().to(device)

In [34]:
class TAE(nn.Module):
    """Sum of several Capsule reconstructions, each weighted by its own `p`.

    Args:
        num_caps: number of Capsule modules to create (defaults to the
            notebook-level `num_caps`).
    """

    def __init__(self, num_caps=num_caps):
        super().__init__()
        capsules = [Capsule().to(device) for _ in range(num_caps)]
        self.caps = nn.ModuleList(capsules)

    def forward(self, img, dx_dy):
        """Reconstruct `img` after offsetting every capsule's `xy` by `dx_dy`.

        Args:
            img: image batch passed to each capsule's `encode`.
            dx_dy: tensor added to each capsule's `xy` output before decoding.

        Returns:
            The sum over capsules of `decode(xy + dx_dy) * p`, with `p`
            broadcast over the two trailing image axes.
        """
        reconstruction = 0.0
        for capsule_net in self.caps:
            presence, position = capsule_net.encode(img)
            rendered = capsule_net.decode(position + dx_dy)
            # presence is (N, 1); add two axes so it scales the whole image.
            weighted = rendered * presence[:, :, None, None]
            reconstruction = reconstruction + weighted
        return reconstruction
    
    

# Build the network on the selected device and attach an Adam optimizer to
# all of its parameters; the bare `model` at the end displays its structure.
model = TAE()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
model

TAE(
  (caps): ModuleList(
    (0): Capsule(
      (encoder): Sequential(
        (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (lin): Sequential(
        (0): Dropout(p=0.5, inplace=False)
        (1): Linear(in_features=8, out_features=3, bias=True)
        (2): Sigmoid()
      )
      (decoder): Sequential(
        (0): ConvTranspose2d(4, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
  

In [35]:
# All 25 integer offset pairs [a, b] with a, b in -2..2, ordered row by row
# (first component varies slowest).
shifts = [
    [first, second]
    for first in range(-2, 3)
    for second in range(-2, 3)
]

In [36]:
def train(train_loader, epoch=None):
    """Run one optimisation pass over `train_loader` and print the mean loss.

    For every batch a single random shift is drawn with `shift_image`; the
    notebook-level `model` receives the unshifted batch plus the shift and is
    trained (MSE loss, notebook-level `optimizer`) to output the shifted batch.

    Args:
        train_loader: iterable of `(images, labels)` batches that supports
            `len()`. The labels are not used.
        epoch: optional epoch number to include in the progress line. When
            omitted only the loss is printed.

    Returns:
        Tuple `(output_image, batch_features)` for the last batch: the model's
        reconstruction (detached from the autograd graph) and the input batch.

    Raises:
        ValueError: if `train_loader` contains no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader contains no batches")

    criterion = nn.MSELoss().to(device)
    # Make sure dropout / batch-norm layers are in training mode.
    model.train()

    running_loss = 0.0
    for batch_features, _ in train_loader:
        batch_features = batch_features.to(device)
        # Target is the batch rolled by (dx, dy); the model sees the original.
        new_features, dx, dy = shift_image(batch_features)
        shift = torch.tensor([dx, dy], dtype=torch.float32, device=device)

        optimizer.zero_grad()
        output_image = model(batch_features, shift)

        # reconstruction loss against the shifted batch
        train_loss = criterion(output_image, new_features)
        train_loss.backward()
        optimizer.step()

        running_loss += train_loss.item()

    mean_loss = running_loss / num_batches

    # Report the loss. The epoch number is only shown when the caller supplies
    # it, instead of being derived from an unrelated global.
    if epoch is None:
        print("loss = {:.6f}".format(mean_loss))
    else:
        print("epoch : {}, loss = {:.6f}".format(epoch, mean_loss))

    # Detach so callers that store the reconstruction do not keep the whole
    # autograd graph of the last batch alive.
    return output_image.detach(), batch_features

def test(test_loader):
    """Reconstruct the first batch of `test_loader` under every offset in `shifts`.

    The notebook-level `model` is called once per `[dx, dy]` pair of the
    notebook-level `shifts` list, always on the same unshifted batch, so the
    returned list is ordered exactly like `shifts`.

    Inference runs without gradient tracking and with the model in eval mode;
    the model's previous train/eval mode is restored before returning.

    Args:
        test_loader: iterable of `(images, labels)` batches. Only the first
            batch is used.

    Returns:
        Tuple `(outputs, label)`: `outputs` is a list with one reconstruction
        batch per entry of `shifts`, `label` is the label of the first sample
        of the batch. `([], None)` if the loader yields no batch.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_features, target in test_loader:
                batch_features, target = batch_features.to(device), target.to(device)

                outputs = []
                for dx, dy in shifts:
                    # Use the grid offset itself; drawing a fresh random shift
                    # here would make the output unrelated to `shifts`.
                    shift = torch.tensor([dx, dy], dtype=torch.float32, device=device)
                    outputs.append(model(batch_features, shift))

                # Only the first batch is evaluated.
                return outputs, target[0]
    finally:
        model.train(was_training)

    return [], None

In [None]:
# Anomaly detection is a debugging aid that adds checks to every backward
# pass. NOTE(review): switch it off once training is known to be stable.
torch.autograd.set_detect_anomaly(True)

epochs = 20
outputs = []  # last-batch reconstruction of every epoch
for epoch in range(epochs):
    print(f"Epoch {epoch+1}\n-------------------------------")
    output, _ = train(train_dataloader)
    # Store a detached tensor so the list does not keep one autograd graph
    # per epoch alive.
    outputs.append(output.detach())

# test() returns (reconstructions, label); unpack it so `shift_grid` is the
# list of reconstructions rather than the whole tuple.
shift_grid, target = test(test_dataloader)

print("Finished")

Epoch 1
-------------------------------
epoch : 21/20, loss = 0.571363
Epoch 2
-------------------------------


In [None]:
# for k in range(0, epochs, 4): 
#     plt.figure(figsize=(9, 2))
#     plt.gray()
#     imgs = outputs[k].detach().cpu().numpy()
#     for i, item in enumerate(imgs):
#         if i >= 9: break
#         plt.subplot(2,9, i+1)
#         plt.imshow(item[0])

#     for i, item in enumerate(recon):
#         if i >= 9: break
#         plt.subplot(2, 9, 9+i+1)
#         plt.imshow(item)



In [None]:
# Draw five figures; each one evaluates a fresh test batch and shows the first
# sample of up to 25 reconstructions on a 5x5 grid of grayscale panels.
for _ in range(5):
    shift_grid, target = test(test_dataloader)
    print(target)
    plt.figure(figsize=(5, 5))
    plt.gray()
    for panel, reconstruction in enumerate(shift_grid[:25], start=1):
        pixels = reconstruction.detach().cpu().numpy()
        plt.subplot(5, 5, panel)
        # first sample of the batch, single channel
        plt.imshow(pixels[0][0])
