 # Loading Modules

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from what_where.main import init, MNIST
from what_where.where import RetinaFill, WhereShift, RetinaBackground, RetinaMask
from utils import view_dataset
from stn import STN_128x128 # importing the network

In [3]:
class FoveatedConv2d(nn.Module):
    """Four-branch convolution over progressively smaller centre crops.

    Every branch sees a square window centred on the input and maps it to
    ``int(out_channels / 4)`` feature channels; the four results are
    concatenated along the channel axis (``dim=1``):

    * ``conv1`` - the uncropped input, spatially halved (average pooling
      followed by a 3x3 convolution when ``pool`` is true, otherwise a
      strided, dilated 3x3 convolution);
    * ``conv2`` - the central ``input_size / 2`` window, 3x3 convolution
      at unchanged resolution;
    * ``conv3`` - the central ``input_size / 4`` window, transposed
      convolution with stride 2;
    * ``conv4`` - the central ``input_size / 8`` window, transposed
      convolution with stride 4.

    With the default ``input_size=96`` all four branches produce 48x48 maps.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Requested total number of output channels. It is
            divided by four and truncated, so the concatenated output
            actually has ``4 * int(out_channels / 4)`` channels.
        input_size: Side length used to pre-compute the crop borders.
            NOTE(review): the branch outputs only line up for
            concatenation when the real input has this side length;
            multiples of 8 appear to be required - confirm against callers.
        pool: Selects the average-pooling variant of the first branch.
    """

    def __init__(self, in_channels, out_channels, input_size=96, pool=False):
        super().__init__()
        # The channel budget is split evenly (and truncated) across the four branches.
        branch_channels = int(out_channels / 4)

        def centre_border(fraction):
            # Margin removed from each side so a window of side input_size / fraction is left.
            return int((input_size - (input_size / fraction)) / 2)

        # Branch 1: whole field of view, downsampled by two.
        if not pool:
            full_view = nn.Conv2d(
                in_channels, branch_channels,
                kernel_size=3, stride=2, dilation=2, padding=2,
            )
        else:
            full_view = nn.Sequential(
                nn.AvgPool2d(2),
                nn.Conv2d(in_channels, branch_channels, kernel_size=3, stride=1, padding=1),
            )
        self.conv1 = full_view
        self.border1 = 0

        # Branch 2: central half, kept at its native resolution.
        self.conv2 = nn.Conv2d(
            in_channels, branch_channels,
            kernel_size=3, stride=1, padding=1,
        )
        self.border2 = centre_border(2)

        # Branch 3: central quarter, upsampled by two.
        self.conv3 = nn.ConvTranspose2d(
            in_channels, branch_channels,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.border3 = centre_border(4)

        # Branch 4: central eighth, upsampled by four.
        self.conv4 = nn.ConvTranspose2d(
            in_channels, branch_channels,
            kernel_size=3, stride=4, padding=0, output_padding=1,
        )
        self.border4 = centre_border(8)

    def crop(self, x, border):
        """Strip ``border`` entries from both ends of axes 2 and 3 of ``x``."""
        rows = slice(border, x.size(2) - border)
        cols = slice(border, x.size(3) - border)
        return x[:, :, rows, cols]

    def forward(self, x):
        """Apply each branch to its centre crop and concatenate along ``dim=1``."""
        branches = (
            (self.conv1, self.border1),
            (self.conv2, self.border2),
            (self.conv3, self.border3),
            (self.conv4, self.border4),
        )
        features = [conv(self.crop(x, border)) for conv, border in branches]
        return torch.cat(features, dim=1)

In [4]:
# Instantiate the STN_128x128 network imported from `stn`, using its default constructor arguments.
# NOTE(review): despite the `fov_` prefix this is an STN_128x128 instance, not a FoveatedConv2d - consider renaming.
fov_net = STN_128x128()

In [5]:
# Bare expression: the notebook shows the module's repr (its layer listing) as this cell's output.
fov_net

STN_128x128(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 100, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=84100, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (localization): Sequential(
    (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): ReLU(inplace=True)
    (9): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): ReLU(inplace=True)
  )
  (fc_loc): Sequenti