In [None]:
import imp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import modules, models, layers, ops
from rf_pool.utils import functions, datasets

In [None]:
# Reload the rf_pool submodules (in dependency order) so edits to their source
# are picked up without restarting the kernel.
# importlib.reload replaces imp.reload: the `imp` module is deprecated and was
# removed in Python 3.12.
import importlib

for _module in (datasets, functions, ops, layers, modules, models):
    importlib.reload(_module)

In [None]:
# Load the EMNIST "byclass" train and test splits (not plain MNIST).
# Pad(12) grows each 28x28 image to 52x52. A fixed 90-degree rotation followed by an
# unconditional vertical flip amounts to a transpose -- presumably to undo the
# transposed orientation in which EMNIST images are stored.
EMNIST_ROOT = '../rf_pool/data'
transform = transforms.Compose([
    transforms.Pad(12),
    transforms.RandomRotation((90, 90)),  # degree range (90, 90): always exactly 90
    transforms.RandomVerticalFlip(1.),    # p=1.: always flip
    transforms.ToTensor(),
])
trainset = torchvision.datasets.EMNIST(root=EMNIST_ROOT, split='byclass', train=True,
                                       download=True, transform=transform)
testset = torchvision.datasets.EMNIST(root=EMNIST_ROOT, split='byclass', train=False,
                                      download=True, transform=transform)

In [None]:
# for shifting fractured label sets
def custom_mapping(labels):
    """Map a sparse set of class labels onto contiguous indices 0..K-1.

    Used to shift fractured label sets (e.g. ``[0..9, 36..61]`` -> ``0..35``) so
    the targets fall in ``range(num_classes)``, as CrossEntropyLoss requires.

    Parameters
    ----------
    labels : iterable of numbers
        Original labels to keep. Duplicates are ignored.

    Returns
    -------
    dict
        ``{original_label: new_index}``, with new indices assigned in ascending
        order of the original labels.
    """
    # np.unique both sorts and de-duplicates. With a plain np.sort, a repeated
    # label would consume an extra index and leave gaps / out-of-range targets.
    unique_labels = np.unique(np.asarray(labels))
    # .item() gives plain Python scalars; lookups with numpy ints still work.
    return {label.item(): index for index, label in enumerate(unique_labels)}

In [None]:
# get crowded MNIST training data
target_labels = list(np.arange(10))+list(np.arange(36,62))
crowd_train = datasets.CrowdedDataset(trainset, 0, 100000, target_labels=target_labels,
                                      transform=transforms.ToTensor(), label_map=custom_mapping(target_labels),
                                      spacing=0, background_size=52)

In [None]:
# Crowded test set with the same label subset, label map and 52x52 background
# as the training set.
crowd_test = datasets.CrowdedDataset(
    testset, 0, 25000,
    target_labels=target_labels,
    label_map=custom_mapping(target_labels),
    transform=transforms.ToTensor(),
    spacing=0,
    background_size=52,
)

In [None]:
# create trainloader
trainloader = torch.utils.data.DataLoader(crowd_train, batch_size=10,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(crowd_test, batch_size=10,
                                         shuffle=True, num_workers=2)

In [None]:
# initialize model
# Seed torch's global RNG so that the Conv2d weight initialisation in the next cell
# and the DataLoader shuffling during training are reproducible on Restart & Run All.
torch.manual_seed(0)
# Initialize the network; its layers are added with model.append.
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# append layers of model
model.append('0', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(1,16,5), 
                                              activation=torch.nn.ReLU(), 
                                              pool=torch.nn.MaxPool2d(2)))
model.append('1', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(16,32,5), 
                                              activation=torch.nn.ReLU(),
                                              pool=torch.nn.MaxPool2d(2)))
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32,32,5),
                                              activation=torch.nn.ReLU(),
                                              pool=torch.nn.MaxPool2d(2)))
branch2 = rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32,10,3))
branch3 = rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32,26,3))
model.append('3', rf_pool.modules.Branch(branches=[branch2, branch3], cat_output=True))
model.append('4', rf_pool.modules.FeedForward(input_shape=(-1,36)))

In [None]:
# print output_shapes with branch output concatenated
# NCHW shape of a single 52x52 one-channel crowded image; the reported shapes
# include the concatenated branch output.
probe_shape = (1, 1, 52, 52)
model.output_shapes(probe_shape)

In [None]:
# set loss function and optimizer
# Plain SGD with momentum over all model parameters; targets are the
# re-indexed class labels, scored with cross-entropy.
optim = torch.optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# train model, monitor weights and lattice
# Train on the crowded training set and show the layer-'0' filters while monitoring.
# `crowded_loader` and `label_params` were never defined in this notebook (NameError
# on a fresh kernel); the training loader is `trainloader`.
# NOTE(review): re-add `label_params=...` if model.train actually needs it.
loss_history = model.train(1, trainloader, loss_fn, optim, monitor=100,
                           show_weights={'layer_id': '0', 'cmap': 'gray'})

In [None]:
# Accuracy on the 52x52 crowded test set.
test_accuracy = model.get_accuracy(testloader)
test_accuracy

In [None]:
# crowded digit dataset
target_labels = list(np.arange(10)) + list(np.arange(36, 62))
# Larger crowded test set (100x100 background) over the same digit + lowercase
# letter classes. This rebinds `crowd_test`; the existing `testloader` still holds
# the 52x52 dataset it was built from.
crowd_test = datasets.CrowdedDataset(testset, 0, 25000, target_labels=target_labels,
                                     transform=transforms.ToTensor(),
                                     label_map=custom_mapping(target_labels),
                                     spacing=0, background_size=100)

In [None]:
# create trainloader
# Loader for the 100x100 crowded test set (not a training loader). Shuffled so the
# next cell displays a random example; accuracy is unaffected by order.
crowd_testloader = torch.utils.data.DataLoader(crowd_test, batch_size=10,
                                               shuffle=True, num_workers=2)

In [None]:
# Show the first image of one batch and its (re-indexed) label.
# Python 3 iterators have no `.next()` method; use the built-in next().
data, label = next(iter(crowd_testloader))
fig, ax = plt.subplots()
ax.imshow(data[0, 0], cmap='gray')
plt.show()
label[0]

In [None]:
# Drop the final reshape stage '4' so that larger inputs yield a spatial map of
# class scores instead of failing the (-1, 36) reshape. Guarded so re-running
# the cell does not raise KeyError.
if '4' in model.layers.keys():
    model.layers.pop('4')

In [None]:
# Per-layer output shapes for a batch of 100x100 crowded images.
batch_shape = data.shape
model.output_shapes(batch_shape)

In [None]:
# For 100x100 inputs the final map is 7x7 (100->96->48->44->22->18->9->7); the
# crop presumably scores only the centre location (row 3, col 3).
center_crop = (slice(3, 4), slice(3, 4))
model.get_accuracy(crowd_testloader, crop=center_crop)

In [None]:
# rf layer
# Build an RF pooling layer from a foveated lattice. 44x44 is the size of layer
# '1's conv output for a 100x100 input (100 -> 96 -> 48 -> 44).
img_shape = torch.as_tensor((44, 44))
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(
    img_shape, 0.23, 0.,
    n_rf=None,
    n_rings=10,
    offset=[0., -20.],
)
rf_layer = rf_pool.layers.RF_Pool(
    mu=mu,
    sigma=sigma,
    img_shape=img_shape,
    lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
    pool_type='max',
    kernel_size=2,
)
rf_layer.show_lattice()
print(mu.shape)
n_kernels = mu.shape[0]  # one row of mu per receptive field

In [None]:
# Register rf_layer under the name 'pool' in layer '1's forward sequence;
# presumably this replaces the MaxPool2d(2) registered under that name.
layer1_forward = model.layers['1'].forward_layer
layer1_forward.add_module('pool', rf_layer)