In [None]:
import pickle 

import imp
from IPython.display import clear_output, display
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import os

from context import rf_pool

In [None]:
from rf_pool import models, modules, pool, ops
from rf_pool.utils import lattice, functions, visualize, datasets, stimuli

In [None]:
from experiment_functions import * 

In [None]:
# Output folders used below: results/ receives the saved accuracies and
# heatmap images, datasets/ is where the crowded-digit base set is looked up.
directories = ['results', 'datasets']
for d in directories:
    if os.path.exists(d):
        continue  # already present (as a folder or file); leave it alone
    os.mkdir(d)

**Load MNIST Data**

In [None]:
# MNIST train and test splits, converted to tensors (ToTensor is the only
# transform); stored under ../../data and downloaded if not already there.
transform = transforms.Compose([transforms.ToTensor()])
mnist_kwargs = dict(root='../../data', download=True, transform=transform)
trainset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
testset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# Shuffled, one-image-per-batch loaders over both MNIST splits.
loader_opts = dict(batch_size=1, shuffle=True, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, **loader_opts)
testloader = torch.utils.data.DataLoader(testset, **loader_opts)

In [None]:
# Load the cached crowded-digits base set if it exists; otherwise use None.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you created.
base_set_filename = 'MNIST_CrowdedDataset.pkl'
base_set_path = os.path.join('datasets', base_set_filename)
if os.path.exists(base_set_path):
    # `with` closes the file handle (the bare open() in a call leaked it)
    with open(base_set_path, 'rb') as f:
        base_set = pickle.load(f)
else:
    base_set = None
    print('Base Set not found!')

# set which label maps to which digit (identity for the ten MNIST classes)
label_map = {n: n for n in range(10)}

**Load Model**

In [None]:
# initialize the network container (constructed with no arguments);
# its layers are appended in the next cell
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Conv stack: (layer id, in channels, out channels, kernel size, add ReLU + MaxPool2d(2)).
# Each block is constructed and appended in order, exactly as before.
conv_specs = [
    ('0', 1, 32, 5, True),
    ('1', 32, 64, 5, True),
    ('2', 64, 10, 4, False),
]
for layer_name, c_in, c_out, k, with_pool in conv_specs:
    if with_pool:
        block = rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(c_in, c_out, k),
                                            activation=torch.nn.ReLU(),
                                            pool=torch.nn.MaxPool2d(2))
    else:
        block = rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(c_in, c_out, k))
    model.append(layer_name, block)
# final layer only declares input_shape=(-1, 10); it is removed again after loading
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1, 10)))

In [None]:
# Restore the previously trained weights; load_model returns a pair and only
# the second item is kept (as `extras`).
model_path = 'models/MNIST_rate_0.2_10k_3deg.pkl'
_, extras = model.load_model(model_path)

In [None]:
# remove the reshape layer '3' (the input_shape=(-1,10) layer appended above);
# as the cell's last expression, whatever pop returns is displayed as output
model.layers.pop('3')

**Replace the max-pool layer with an RF-pool layer**

In [None]:
# create the rf layer
img_shape = torch.Size((53,53))
offset = [0., -30.] # right visual field (3deg)
ref_axis = 0. # set the reference axis for the crowding configurations
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap, n_rings=n_rings, std=std,
                                                        offset=offset)
rf_layer = rf_pool.pool.RF_Pool(mu=mu, sigma=sigma, img_shape=img_shape, 
                                lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
                                pool_fn='max_pool', kernel_size=2)
n_kernels = rf_layer.mu.shape[0]
# append the rf pool layer to the model
layer_id = '1'
model.layers[layer_id].forward_layer.add_module('pool', rf_layer)
visualize.heatmap(model, '1');

**Test Attention**

In [None]:
# set experiment parameters
exp_type = 'Attention'
n_test = 100
batch_size = 1
# get target data
target_set = create_crowd_set(testset, n_test, 118, 0, 0, 0, base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

In [None]:
extents = np.arange(7., 29.5, 2.5)
spacing = 1.
crowd_types = ['radial', 'tangential', 'outer', 'inner']


def _random_rotation():
    # unseeded random lattice rotation in [0, pi)
    return np.pi * np.random.rand()


# For every flanker configuration, sweep the attention extent and record accuracy.
RF_ACC = {}
for key in crowd_types:
    n_flankers, axis = get_crowd_params(key, ref_axis)
    crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing,
                                 base_set, label_map)
    crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                               shuffle=False, num_workers=2)
    param_space = extents
    param_name = ['extent']
    lattice_fn = lattice.init_foveated_lattice
    lattice_kwargs = dict(img_shape=img_shape, scale=RF_rate, spacing=gap,
                          std=std, n_rings=n_rings, offset=offset,
                          rotate=_random_rotation)
    search_kwargs = dict(model=model, batch_size=1, layer_id='1', extent=None,
                         lattice_fn=lattice_fn, lattice_kwargs=lattice_kwargs)
    ACC = functions.param_search(get_accuracy, [target_loader, crowd_loader],
                                 search_kwargs, param_name, param_space,
                                 verbose=False, show_cost=False)
    RF_ACC[key + '_attn'] = ACC
    RF_ACC['extent'] = param_space

In [None]:
# Save the attention accuracies. %g (not %d) keeps fractional sizes such as
# n_test=100 -> "0.1k" instead of truncating every small run to "0k";
# multiples of 1000 keep their old names (10000 -> "10k").
with open('results/ACC_%s_%gk.pkl' % (exp_type.lower(), n_test / 1000), 'wb') as f:
    pickle.dump(RF_ACC, f)

**Test Spacing**

In [None]:
# set experiment parameters
exp_type = 'Spacing'
n_test = 100
batch_size = 1
# set lattice kwargs
img_shape = torch.Size((53,53))
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.

# get angle between RFs
n_RF = np.floor(np.pi / RF_rate)
angles = 2. * np.pi * np.linspace(0., 1., np.int(n_RF))[:-1]
rot_angle = angles[1]

# get target data
target_set = create_crowd_set(testset, n_test, 118, 0, 0, 0, base_set=base_set, label_map=label_map,
                              transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

In [None]:
spacings = np.arange(1., 2.25, 0.25)
crowd_types = ['outer','inner','radial','tangential']
# run the experiment: accuracy for every (crowd type, flanker spacing) pair
RF_ACC = {}
for key in crowd_types:
    n_flankers, axis = get_crowd_params(key, ref_axis)
    ACC = []
    for spacing in spacings:
        clear_output(wait=True)
        display('%s: %a' % (key, spacing))
        # jitter_fn / rotate_fn are re-created with seed=0 for every spacing
        # (presumably so each condition sees the same random sequence)
        crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing, base_set, label_map,
                                     transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
        crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                   shuffle=False, num_workers=2)
        lattice_fn = lattice.init_foveated_lattice
        lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap,
                          'std': std, 'n_rings': n_rings, 'offset': offset,
                          'rotate': rotate_fn(rot_angle / 2., base_angle=0, seed=0)}
        # Get accuracy
        ACC.append(get_accuracy(target_loader, crowd_loader,
                                **{'layer_id':'1', 'batch_size': batch_size, 'model':model,'extent': None,
                                   'lattice_fn': lattice_fn,
                                   'lattice_kwargs': lattice_kwargs}))
    # store the spacings that were actually swept (previously a re-typed
    # np.arange copy that could drift from `spacings`)
    RF_ACC.update({key + '_space': ACC, 'spacing': spacings})

In [None]:
# Save the spacing accuracies. %g (not %d) keeps fractional sizes such as
# n_test=100 -> "0.1k" instead of truncating every small run to "0k";
# multiples of 1000 keep their old names (10000 -> "10k").
with open('results/ACC_%s_%gk.pkl' % (exp_type.lower(), n_test / 1000), 'wb') as f:
    pickle.dump(RF_ACC, f)

**Test Size/Position Attention**

In [None]:
# set experiment parameters
exp_type = 'Attention_Size_Position'
n_test = 100
batch_size = 1
# get target data
target_set = create_crowd_set(testset, n_test, 118, 0, 0, 0, base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

In [None]:
extents = np.arange(7., 29.5, 2.5)
spacing = 1.
crowd_types = ['radial', 'tangential', 'outer', 'inner']
# For each crowd type run two extent sweeps: update_mu only, then update_sigma only.
RF_ACC = {}
for key in crowd_types:
    ACC = []
    for update_mu, update_sigma in ((True, False), (False, True)):
        n_flankers, axis = get_crowd_params(key, ref_axis)
        crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing,
                                     base_set, label_map)
        crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                   shuffle=False, num_workers=2)
        param_space = extents
        param_name = ['extent']
        lattice_fn = lattice.init_foveated_lattice
        lattice_kwargs = dict(img_shape=img_shape, scale=RF_rate, spacing=gap,
                              std=std, n_rings=n_rings, offset=offset,
                              rotate=lambda: np.pi * np.random.rand())
        search_kwargs = dict(model=model, batch_size=1, layer_id='1', extent=None,
                             lattice_fn=lattice_fn, lattice_kwargs=lattice_kwargs,
                             update_mu=update_mu, update_sigma=update_sigma)
        sweep = functions.param_search(get_accuracy, [target_loader, crowd_loader],
                                       search_kwargs, param_name, param_space,
                                       verbose=False, show_cost=False)
        ACC.append(sweep)
    RF_ACC.update({key + '_attn': ACC, 'extent': param_space})

In [None]:
# Save the size/position accuracies. %g (not %d) keeps fractional sizes such as
# n_test=100 -> "0.1k" instead of truncating every small run to "0k";
# multiples of 1000 keep their old names (10000 -> "10k").
with open('results/ACC_%s_%gk.pkl' % (exp_type.lower(), n_test / 1000), 'wb') as f:
    pickle.dump(RF_ACC, f)

**Get Heatmaps**

In [None]:
# set experiment parameters
exp_type = 'Heatmaps'
n_test = 100
batch_size = 1
# set lattice kwargs
img_shape = torch.Size((53,53))
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
# get target data
target_set = create_crowd_set(testset, n_test, 118, 0, 0, 0, base_set=base_set, label_map=label_map,
                              transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
# set rms function
rms = lambda x: torch.sqrt(torch.mean(torch.pow(torch.tensor(x), 2.)))
# get results
results = pickle.load(open('ACC_spacing_10k.pkl', 'rb'))
results.update(pickle.load(open('ACC_attention_10k.pkl', 'rb')))

In [None]:
spacings = [1., 2.]
extents = [None, 7., 27.]
crowd_types = ['inner', 'outer', 'radial', 'tangential']


def _result_index(values, target):
    """Return the index of the single entry of `values` numerically close to `target`.

    Replaces `np.where(values == target)[0].item()`, which needs exact float
    equality, silently finds nothing when `values` is a plain list, and then
    fails with an opaque size error. Raises ValueError unless exactly one entry
    matches.
    """
    matches = np.flatnonzero(np.isclose(np.asarray(values, dtype=float), target))
    if matches.size != 1:
        raise ValueError('expected exactly one entry close to %r in %r, found %d'
                         % (target, values, matches.size))
    return matches.item()


# run the experiment
RF_heatmaps = {}
# get accuracy heatmaps
for key in crowd_types:
    n_flankers, axis = get_crowd_params(key, ref_axis)
    heatmap = []
    for extent in extents:
        for spacing in spacings:
            # with attention (extent set), only the closest spacing is run
            if extent is not None and spacing > 1.:
                continue
            # get crowding set (without target to account for flankers in get_contribution call)
            crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing, base_set, label_map,
                                         no_target=True,
                                         transform=transforms.Compose([transforms.ToTensor(),
                                                                       jitter_fn(5, 5, seed=0)]))
            crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                       shuffle=False, num_workers=2)
            # set lattice kwargs
            lattice_fn = lattice.init_foveated_lattice
            lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap,
                              'std': std, 'n_rings': n_rings, 'offset': offset,
                              'rotate': rotate_fn(rot_angle / 2., base_angle=0, seed=0)}
            # pick the stored accuracy for this condition: spacing results when
            # there is no attention, attention results otherwise
            if extent is None:
                result_name = key + '_space'
                result_idx = _result_index(results.get('spacing'), spacing)
            else:
                result_name = key + '_attn'
                result_idx = _result_index(results.get('extent'), extent)
            # get rf scores
            heatmap.append(get_contribution(target_loader, crowd_loader,
                                            **{'layer_id':'1',
                                               'model':model,
                                               'RF_mask': None,
                                               'acc': results[result_name][result_idx][-1],
                                               'extent': extent,
                                               'lattice_fn': lattice_fn,
                                               'lattice_kwargs': lattice_kwargs}))
            # monitor progress
            clear_output(wait=True)
            display('%s: spacing %a, extent %a' % (key, spacing, extent))
            # plot heatmap
            visualize.heatmap(model, '1', score_map=heatmap[-1], cmap='Greens', colorbar=True,
                              filename='results/heatmap_%s_%s_%s.png' % (key, spacing, extent));
    # append heatmaps
    RF_heatmaps.update({key + '_heatmap': heatmap, 'spacing': spacings, 'extent': extents})

In [None]:
# Save the heatmaps. %g (not %d) keeps fractional sizes such as
# n_test=100 -> "0.1k" instead of truncating every small run to "0k";
# multiples of 1000 keep their old names (10000 -> "10k").
with open('results/ACC_%s_%gk.pkl' % (exp_type.lower(), n_test / 1000), 'wb') as f:
    pickle.dump(RF_heatmaps, f)