In [None]:
import imp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import modules, models, layers, ops
from rf_pool.utils import functions, datasets, stimuli

In [None]:
# Reload the rf_pool modules so source edits take effect without restarting the kernel.
# `imp` is deprecated (and removed in Python 3.12); importlib.reload is the supported
# replacement. The original reload order is kept (utils first, then ops/layers/modules/models).
import importlib

for _module in (functions, stimuli, datasets, ops, layers, modules, models):
    importlib.reload(_module)

In [None]:
# MNIST train and test splits. Each 28x28 digit is zero-padded by 12 px on every side
# (-> 52x52) and converted to a float tensor.
transform = transforms.Compose([
    transforms.Pad(12),
    transforms.ToTensor(),
])
trainset, testset = (
    torchvision.datasets.MNIST(root='../rf_pool/data', train=is_train,
                               download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# Mini-batch loaders for both splits: batches of 10, reshuffled every pass, 2 worker processes.
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=10, shuffle=True, num_workers=2)
    for split in (trainset, testset)
)

In [None]:
# Seed the global RNGs before any weights are created. This makes weight initialization
# and DataLoader shuffling reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Start from an empty feed-forward network; layers are appended in the next cell.
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Build the network in order:
#   - three conv -> ReLU -> 2x2 max-pool stages;
#   - a 3x3 conv down to 10 channels;
#   - a final module with input_shape=(-1, 10), presumably a reshape to (N, 10) class scores.
conv_stages = [(1, 16, 5), (16, 32, 5), (32, 32, 5)]
for stage_idx, (in_ch, out_ch, ksize) in enumerate(conv_stages):
    model.append(str(stage_idx),
                 rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(in_ch, out_ch, ksize),
                                             activation=torch.nn.ReLU(),
                                             pool=torch.nn.MaxPool2d(2)))
model.append('3', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32, 10, 3)))
model.append('4', rf_pool.modules.FeedForward(input_shape=(-1, 10)))

In [None]:
# Per-layer output shapes for a batch of 10 single-channel images:
# 28x28 MNIST digits padded by 12 px on each side -> 52x52.
model.output_shapes((10, 1, 28 + 2 * 12, 28 + 2 * 12))

In [None]:
# Cross-entropy objective, minimized by SGD with momentum over all model parameters.
optim = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Train on padded MNIST, periodically displaying layer '0' filters in grayscale.
# The first positional arg is presumably the epoch count, and `monitor=100` presumably
# sets the number of batches between reports; both are defined by model.train (TODO confirm).
weights_view = {'layer_id': '0', 'cmap': 'gray'}
loss_history = model.train(10, trainloader, loss_fn, optim,
                           show_weights=weights_view, monitor=100)

In [None]:
# Snapshot the trained weights so they can later be restored via model.load_weights.
# NOTE(review): whether this is a deep copy or live references depends on
# download_weights — confirm before relying on it after further training.
trained_weights = model.download_weights()

In [None]:
# Accuracy of the freshly trained network on the padded MNIST test split.
test_accuracy = model.get_accuracy(testloader)
test_accuracy

## Add a receptive-field pooling layer and test on crowded digits

In [None]:
# Crowded-digit dataset built from the MNIST *test* split (this is evaluation data, not training data).
# The positional args (0, 20000) and the spacing/background_size units are defined by
# datasets.CrowdedDataset — presumably a start index/count and pixel sizes; TODO confirm.
crowd_set = datasets.CrowdedDataset(
    testset, 0, 20000,
    transform=transforms.ToTensor(),
    spacing=20,
    background_size=100,
)
crowd_loader = torch.utils.data.DataLoader(
    crowd_set,
    batch_size=10,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Peek at one batch and show its first image with its label as a sanity check.
# The builtin next() is required: DataLoader iterators no longer provide a `.next()` method.
# NOTE(review): this samples the plain (padded) testloader, not crowd_loader — confirm which
# dataset was meant to be inspected; `data` is reused later for model.output_shapes.
data, label = next(iter(testloader))
fig, ax = plt.subplots()
ax.imshow(data[0, 0], cmap='gray')
ax.set_title(f'label: {label[0].item()}')
ax.axis('off')
plt.show()
print(label[0])

In [None]:
# Restore the snapshot taken right after training (trained_weights) before the
# architecture is modified in the following cells.
model.load_weights(trained_weights)

In [None]:
# Detach the final reshape layer ('4', built with input_shape=(-1, 10)); the removed
# module is shown as the cell output.
# NOTE(review): not idempotent — re-running this cell without the next one will fail
# (presumably KeyError) because '4' is already gone.
removed_reshape = model.layers.pop('4')
removed_reshape

In [None]:
# Re-attach a final layer '4' with input_shape=(-1, 10), presumably a reshape to (N, 10).
# NOTE(review): this is identical to the layer popped in the previous cell — confirm whether
# something was meant to be inserted before it.
reshape_layer = rf_pool.modules.FeedForward(input_shape=(-1, 10))
model.append('4', reshape_layer)

In [None]:
# Map each label id to a branch name: ids 0-9 -> 'branch_0', ids 10-35 -> 'branch_1'.
# Keys are numpy integers (from np.arange); they hash and compare equal to plain ints.
label_params = {label: ('branch_0' if label < 10 else 'branch_1')
                for label in np.arange(36)}

In [None]:
# Per-layer output shapes for the batch sampled earlier (`data`).
# NOTE(review): `data` comes from testloader (52x52 padded MNIST), not crowd_loader, so these
# shapes may not match what get_accuracy sees on crowded images — confirm intent.
model.output_shapes(data.shape)

In [None]:
# Accuracy on crowded digits of the network trained on single (uncrowded) MNIST digits.
# NOTE(review): crop=(slice(2,4), slice(2,4)) presumably selects a 2x2 spatial window of the
# network output around the target digit — the semantics are defined by get_accuracy; TODO confirm.
model.get_accuracy(crowd_loader, crop=(slice(2,4), slice(2,4)))

In [None]:
# Receptive-field pooling layer (max pooling, kernel_size=2) over a foveated lattice
# covering a 50x50 map. The positional args 0.23 and 0., plus n_rings/offset, are
# interpreted by init_foveated_lattice — TODO document their meaning.
img_shape = torch.as_tensor((50, 50))
lattice_utils = rf_pool.utils.lattice
mu, sigma = lattice_utils.init_foveated_lattice(
    img_shape, 0.23, 0., n_rf=None, n_rings=10, offset=[0., -20.],
)
rf_layer = rf_pool.layers.RF_Pool(
    mu=mu,
    sigma=sigma,
    img_shape=img_shape,
    lattice_fn=lattice_utils.mask_kernel_lattice,
    pool_type='max',
    kernel_size=2,
)
rf_layer.show_lattice()
n_kernels = mu.shape[0]  # presumably one row of mu per receptive field
print(mu.shape)

In [None]:
# Swap the first layer's pooling module for the RF pooling layer. torch.nn.Module.add_module
# overwrites an existing child registered under the same name. This presumably replaces the
# MaxPool2d(2) passed as `pool=` when layer '0' was built — confirm forward_layer stores it as 'pool'.
# NOTE(review): rf_layer was built for img_shape=(50, 50); confirm this matches the spatial
# size of layer '0''s conv output for the inputs used with it.
model.layers['0'].forward_layer.add_module('pool', rf_layer)