In [None]:
import imp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import modules, models, layers, ops
from rf_pool.utils import functions, datasets, stimuli

In [None]:
# Reload the rf_pool modules so source edits are picked up without restarting
# the kernel. `imp` is deprecated (removed in Python 3.12); importlib.reload has
# the same semantics. Lower-level modules are reloaded before `models`, which
# presumably imports them.
import importlib
importlib.reload(functions)
importlib.reload(stimuli)
importlib.reload(datasets)
importlib.reload(ops)
importlib.reload(layers)
importlib.reload(modules)
importlib.reload(models)

In [None]:
# load the MNIST train and test splits as tensors (downloaded if not present)
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(
    root='../rf_pool/data', train=True, transform=transform, download=True)
testset = torchvision.datasets.MNIST(
    root='../rf_pool/data', train=False, transform=transform, download=True)

In [None]:
# wrap both splits in shuffled loaders: batches of 10, two worker processes
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=10, num_workers=2, shuffle=True)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=10, num_workers=2, shuffle=True)

In [None]:
# initialize the network container (no constructor arguments); layers are
# added to it with `model.append`
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# append layers of model:
#   '0': Conv2d(1, 32, 5) -> ReLU -> MaxPool2d(2)
#   '1': Conv2d(32, 64, 5) -> ReLU -> MaxPool2d(2)
#   '2': Conv2d(64, 10, 4), no activation/pool -- presumably one channel per digit class
#   '3': no hidden module, only input_shape=(-1,10); presumably reshapes the
#        output to (batch, 10) logits for CrossEntropyLoss -- confirm
# NOTE(review): assuming 28x28 MNIST inputs, unpadded convs and FeedForward
# applying hidden -> activation -> pool, the spatial size goes 28->24->12->8->4->1.
model.append('0', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(1,32,5),
                                              activation=torch.nn.ReLU(), 
                                              pool=torch.nn.MaxPool2d(2)))
model.append('1', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32,64,5),
                                              activation=torch.nn.ReLU(),
                                              pool=torch.nn.MaxPool2d(2)))
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(64,10,4)))
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1,10)))

In [None]:
# SGD with momentum over every model parameter, scored with cross-entropy
optim = torch.optim.SGD(model.parameters(), momentum=0.9, lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# train the model on the MNIST training loader (no RF lattice exists yet at
# this point in the notebook)
# NOTE(review): the leading 1 is presumably the number of epochs, monitor=100
# the reporting interval, and show_weights plots layer '0' weights in gray;
# the return value is presumably the loss history -- confirm against
# FeedForwardNetwork.train.
loss_history = model.train(1, trainloader, loss_fn, optim, monitor=100,
                           show_weights={'layer_id': '0', 'cmap': 'gray'})

In [None]:
# accuracy of the trained network on the (uncrowded) MNIST test loader
acc = model.get_accuracy(testloader)

In [None]:
# load a saved experiment; only the second return value (`extras`) is kept.
# NOTE(review): binding `_` shadows IPython's last-output variable. load_model
# on a .pkl file presumably unpickles it -- only load trusted files -- and may
# overwrite `model`'s weights; confirm.
# If `extras` is a dict without a 'crowd_int' entry, crowd_int is None.
(_, extras) = model.load_model('crowding_experiment/attention_3deg.pkl')
crowd_int = extras.get('crowd_int')

In [None]:
# Plot 1 - mean(c) against crowd_int['extent'] for each flanker configuration.
# One loop replaces four copy-pasted plot calls; labels keep the legend in sync.
fig, ax = plt.subplots()
for key in ('outer', 'inner', 'radial', 'tangential'):
    ax.plot(crowd_int.get('extent'),
            [1. - np.mean(c) for c in crowd_int.get(key)],
            label=key)
ax.set_xlabel('extent')
ax.set_ylabel('1 - mean')
ax.legend()
plt.show()
# fig.savefig('crowding_experiment/test.png', dpi=600)

## Add a receptive-field pooling layer and test on crowded digits

In [None]:
# remove the reshape layer '3' so the conv stack can be applied to larger
# crowded images. Not idempotent: re-running this cell fails once '3' is gone
# (presumably with a KeyError).
model.layers.pop('3')

In [None]:
# per-layer output shapes for one 1x118x118 image (the crowded-stimulus size
# used below). An explicit shape is used because `data` is not defined until
# later cells, so `data.shape` raised NameError on a fresh kernel.
model.output_shapes((1, 1, 118, 118))

In [None]:
# re-append a reshape-only layer under the new id '4'
# NOTE(review): with 118x118 inputs layer '2' no longer yields 1x1 maps;
# confirm how input_shape=(-1,10) treats the larger spatial output.
model.append('4', rf_pool.modules.FeedForward(input_shape=(-1,10)))

In [None]:
# build a crowded-digit set from the MNIST *test* split. The second positional
# argument is the flanker count (it receives `n_flankers` in the evaluation
# loop); 1000 is presumably the number of samples -- confirm.
crowd_set = datasets.CrowdedDataset(
    testset, 1, 1000,
    spacing=20, background_size=118, axis=0,
    transform=transforms.ToTensor(),
)
crowd_loader = torch.utils.data.DataLoader(
    crowd_set, batch_size=10, num_workers=2, shuffle=True)

In [None]:
# show the first image of one crowded batch and print its label.
# `.next()` is a Python 2 alias that newer PyTorch DataLoader iterators no
# longer provide; the builtin next() works on every version.
data, label = next(iter(crowd_loader))
plt.imshow(data[0, 0], cmap='gray')
plt.show()
print(label[0])

In [None]:
from rf_pool.utils import visualize
import importlib
# reload so edits to rf_pool.utils.visualize are picked up without a kernel
# restart; importlib.reload replaces the deprecated imp.reload (imp was removed
# in Python 3.12)
importlib.reload(visualize)

In [None]:
# probe image: a 28x28 square of ones (rows/cols 45..72, i.e. 59 +/- 14) on a
# blank 1x1x118x118 canvas. index_rfs presumably returns the indices of the
# layer-'1' receptive fields this square reaches -- confirm.
# NOTE(review): layer '1' only receives its RF_Pool module in a later cell.
data = torch.zeros(1, 1, 118, 118)
data[..., 45:73, 45:73] = 1.
rf_idx = rf_pool.utils.visualize.index_rfs(model, '1', data)[0]

In [None]:
# heatmap(s) of the receptive fields of layer '1'
# NOTE(review): layer '1' only gets an RF_Pool module in a later cell; on a
# fresh kernel this presumably fails or sees the MaxPool2d -- confirm order.
heatmap = rf_pool.utils.visualize.rf_heatmap(model, '1')

In [None]:
# random weights (uniform in [0, 1)) broadcast over heatmap's last two dims;
# unseeded, so the resulting figure differs on every run.
# NOTE(review): 53 presumably has to match heatmap.shape[0] -- consider
# deriving it from `heatmap` instead of hard-coding it.
tmp = torch.rand(53, 1, 1)

In [None]:
# collapse the randomly weighted heatmaps along dim 0 into a single image
plt.imshow((heatmap * tmp).sum(0))
plt.show()

In [None]:
# rf layer: build a foveated lattice of receptive fields (mu, sigma) over a
# 53x53 image and wrap the subset selected by `rf_idx` in an RF_Pool layer.
# NOTE(review): 53x53 presumably matches layer '1''s conv output for 118x118
# inputs; `rf_pool.utils.lattice` is never imported explicitly in this
# notebook, and `rf_idx` must come from an earlier cell -- confirm both.
img_shape = torch.as_tensor((53,53))
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, 0.23, 0., n_rf=None, n_rings=11, offset=[0.,-30.])
rf_layer = rf_pool.layers.RF_Pool(mu=mu[rf_idx], sigma=sigma[rf_idx], img_shape=img_shape, 
                                  lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
                                  pool_type=None, kernel_size=2, thr=np.exp(-1.))
rf_layer.show_lattice()
print(mu.shape)  # shape of the full lattice, before selecting rf_idx
n_kernels = mu.shape[0]  # NOTE(review): counts all lattice RFs, not only mu[rf_idx]

In [None]:
# install the RF layer as the 'pool' module of layer '1'; presumably this
# replaces the MaxPool2d(2) passed as `pool` when layer '1' was built
model.layers['1'].forward_layer.add_module('pool', rf_layer)

In [None]:
# per-layer output shapes for one 1x1x118x118 input with the RF layer installed
model.output_shapes((1,1,118,118))

In [None]:
# show the first image of one crowded batch and print its label.
# `.next()` is a Python 2 alias that newer PyTorch DataLoader iterators no
# longer provide; the builtin next() works on every version.
data, label = next(iter(crowd_loader))
plt.imshow(data[0, 0], cmap='gray')
plt.show()
print(label[0])

In [None]:
# recorded accuracies (presumably percent) per flanker configuration; the
# condition lists are shorter than `spacing` because not every spacing has
# been filled in yet
acc_2d = dict(
    spacing=[0.75, 1., 1.25, 1.5, 1.75],
    outer=[55.1, 64.4, 65.3],
    inner=[52.0, 65.5, 70.1],
    radial=[31.6, 58.6, 65.2],
    tangential=[41.1, 66.4, 71.7],
    none=[72.5],
)

In [None]:
# empty result lists (one independent list per condition), filled by the
# evaluation loop in spacing order
acc_2d_3layer = {
    'spacing': [0.75, 1., 1.25, 1.5, 1.75],
    **{cond: [] for cond in ('outer', 'inner', 'radial', 'tangential', 'none')},
}

In [None]:
# Evaluate crowded-digit accuracy for each flanker configuration and spacing.
# Looping over acc_2d_3layer['spacing'] keeps every result list aligned
# index-by-index with the 'spacing' list; a hard-coded [1.25, 1.5, 1.75] put
# the 1.25 result at index 0 (spacing 0.75) on a fresh kernel.
# Per-configuration settings: 'radial'/'tangential' use 2 flankers; the axis is
# 0 except for 'inner' (pi) and 'tangential' (pi/2).
n_flankers_by_key = {'outer': 1, 'inner': 1, 'radial': 2, 'tangential': 2}
axis_by_key = {'outer': 0., 'inner': np.pi, 'radial': 0., 'tangential': np.pi / 2.}
for spacing in acc_2d_3layer.get('spacing'):
    # 'radial' and 'tangential' are currently disabled; add them here to run them
    for key in ['outer', 'inner']:
        n_flankers = n_flankers_by_key[key]
        axis = axis_by_key[key]
        # crowded MNIST *test* data
        # NOTE(review): the MNIST test set is re-created every iteration; hoist it
        # out of the loop if CrowdedDataset does not mutate the dataset it wraps
        testset = torchvision.datasets.MNIST(root='../rf_pool/data', train=False, download=True,
                                             transform=transform)
        crowd_set = datasets.CrowdedDataset(testset, n_flankers, 1000,
                                            transform=transforms.ToTensor(),
                                            spacing=20*spacing, background_size=100, axis=axis)
        crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=2,
                                                   shuffle=True, num_workers=2)
        # get accuracy; crop presumably restricts the read-out to output location (9, 9)
        # NOTE(review): background_size=100 differs from the 118x118 inputs the
        # RF layer (img_shape 53x53) was built for -- confirm which is intended
        acc = model.get_accuracy(crowd_loader, crop=(slice(9,10), slice(9,10)), monitor=10)
        acc_2d_3layer[key].append(acc)
        print(key, acc_2d_3layer)