In [None]:
import pickle 

import imp
from IPython.display import clear_output, display
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import os
from sklearn import linear_model
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from context import rf_pool

In [None]:
from rf_pool import models, modules, pool, ops
from rf_pool.utils import lattice, functions, visualize, datasets, stimuli

In [None]:
from experiment_functions import * 

In [None]:
# create the output folders for results and cached datasets; exist_ok makes this a
# no-op when they already exist (and avoids the check-then-create race of os.path.exists)
directories = ['results', 'datasets']
for d in directories:
    os.makedirs(d, exist_ok=True)

**Load MNIST Data**

In [None]:
# get MNIST training data
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(root='../../data', train=True, download=True, 
                                       transform=transform)
testset = torchvision.datasets.MNIST(root='../../data', train=False, download=True,
                                     transform=transform)

In [None]:
# create trainloader
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=1,
                                         shuffle=True, num_workers=2)

In [None]:
# load in crowded digits base set
base_set_filename = 'MNIST_cross_CrowdedDataset.pkl'
if os.path.exists('datasets/' + base_set_filename):
    base_set = pickle.load(open('datasets/' + base_set_filename, 'rb'))
else:
    base_set = None
    print('Base Set not found!')
    
# set what labels mmap to what digit
label_map = {}
label_map.update([(n,n) for n in range(10)])

**Load Model**

In [None]:
# initialize model
# initialize the rf_pool feed-forward model; its layers are appended in the next cell
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# append layers of model
model.append('0', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(1,32,5),
                                              activation=torch.nn.ReLU(), 
                                              pool=torch.nn.MaxPool2d(2)))
model.append('1', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(32,64,5),
                                              activation=torch.nn.ReLU(),
                                              pool=torch.nn.MaxPool2d(2)))
model.append('2', rf_pool.modules.FeedForward(hidden=torch.nn.Conv2d(64,10,4)))
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1,10)))

In [None]:
# load previous model and results
# load previous model and results; load_model returns a pair and only the
# second element (`extras`) is kept
load_result = model.load_model('models/MNIST_rate_0.2_10k_3deg.pkl')
_, extras = load_result

In [None]:
# remove reshape layer 
# remove reshape layer: layer '3' was appended above with input_shape=(-1,10)
# NOTE(review): not idempotent -- re-running this cell fails once '3' is gone
model.layers.pop('3')

**Replace the max-pool layer of layer '1' with an RF-pool layer**

In [None]:
# create the rf layer
img_shape = torch.Size((53,53))
offset = [0., -30.] # right visual field (3deg)
ref_axis = 0. # set the reference axis for the crowding configurations
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, RF_rate, gap, n_rings=n_rings, std=std,
                                                        offset=offset)
rf_layer = rf_pool.pool.RF_Pool(mu=mu, sigma=sigma, img_shape=img_shape, 
                                lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
                                pool_fn='max_pool', kernel_size=2)
n_kernels = rf_layer.mu.shape[0]
# append the rf pool layer to the model
layer_id = '1'
model.layers[layer_id].forward_layer.add_module('pool', rf_layer)
visualize.heatmap(model, '1');

**Run Experiments**

In [None]:
# set experiment parameters
exp_type = 'accuracy'
n_test = 1000
batch_size = 1

# set lattice kwargs
img_shape = torch.Size((53,53))
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.

# get angle between RFs
n_RF = np.floor(np.pi / RF_rate)
angles = 2. * np.pi * np.linspace(0., 1., np.int(n_RF))[:-1]
rot_angle = angles[1]

# set spacings, extents, configuration types
spacings = [1.] #np.linspace(1., 2., 9)
extents = np.arange(7.,29.5,2.5)
crowd_types = ['cross'] #['inner', 'outer', 'radial', 'tangential']

# set function
fn = getattr(experiment_functions, 'get_%s' % exp_type.lower())

In [None]:
# run the experiment
output = {}
for key in crowd_types:
    n_flankers, axis = get_crowd_params(key, ref_axis)
    output_i = {}
    for extent in extents:
        for spacing in spacings:
            # get target data
            target_set = create_crowd_set(testset, n_test, 118, 0, 0, 0,
                                          base_set=base_set, label_map=label_map,
                                          transform=transforms.Compose([transforms.ToTensor(), 
                                                                        jitter_fn(5, 5, seed=0)]))
            target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                                        shuffle=False, num_workers=2)
            # create square data set to get redundancy metric
            s = stimuli.make_crowded_stimuli(torch.ones(20,20), [], 0., (118,118))
            square_set = datasets.Dataset(data=[s], transform=transforms.Compose([transforms.ToTensor(), 
                                          jitter_fn(5, 5, seed=0)]))
            square_loader = torch.utils.data.DataLoader(square_set, batch_size=1, shuffle=False, num_workers=2)
            # get crowding set (without target to account for flankers in get_contribution call)
            crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing, 
                                         base_set=base_set, label_map=label_map,
                                         transform=transforms.Compose([transforms.ToTensor(),
                                                                       jitter_fn(5, 5, seed=0)]))
            crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                       shuffle=False, num_workers=2)
            # set lattice kwargs 
            lattice_fn = lattice.init_foveated_lattice
            lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap, 
                              'std': std, 'n_rings': n_rings, 'offset': offset, 
                              'rotate': rotate_fn(rot_angle / 2., base_angle=ref_axis, seed=0)}
            # get output_i
            output_i.update({'spacing_%a_extent_%a' % (spacing, extent): 
                             fn(target_loader, crowd_loader, #square_loader,
                                **{'layer_id': layer_id,
                                   'model': model,
                                   'extent': extent,
                                   'lattice_fn': lattice_fn,
                                   'lattice_kwargs': lattice_kwargs})})
            # monitor progress
            clear_output(wait=True)
            display('%s: spacing %a, extent %a' % (key, spacing, extent))
    # update output
    output.update({'%s_%s' % (key, exp_type.lower()): output_i, 'spacing': spacings, 'extent': extents})

In [None]:
# persist the experiment output, e.g. results/accuracy_1k.pkl
results_path = 'results/%s_%dk.pkl' % (exp_type.lower(), n_test/1000)
with open(results_path, 'wb') as results_file:
    pickle.dump(output, results_file)

**PCA of fidelity, redundancy, and target confidence**

In [None]:
# set experiment parameters
exp_type = 'accuracy'
n_test = 1000

# set output, redundancy, fidelity
output = pickle.load(open('results/%s_%dk.pkl' % (exp_type.lower(), n_test/1000), 'rb'))
redundancy = pickle.load(open('results/redundancy_1k.pkl', 'rb'))
fidelity = pickle.load(open('results/mse_1k.pkl', 'rb'))
key = 'cross'
output_key = '%s_%s' % (key, 'ACC') #exp_type.lower())
fidelity_key = '%s_%s' % (key, 'MSE')
redundancy_key = '%s_%s' % (key, 'redundancy')

# set stat function
stat_fn = lambda x: x[0]

In [None]:
# create data array
X = []
for i, space in enumerate(output.get('spacing')):
    for j, extent in enumerate(output.get('extent')):
        X.append([fidelity.get(fidelity_key).get('spacing_%a_attn_None' % space).sum() / 5., 
                  redundancy.get(redundancy_key).get('spacing_%a_extent_%a' % (1., extent)), 
                  stat_fn(output.get(output_key).get('spacing_%a_extent_%a' % (space, extent)))])
X = np.array(X)

In [None]:
# fit per-column standardization (zero mean, unit variance); fit() returns the
# scaler itself, which is shown as the cell output
scaler = StandardScaler().fit(X)
scaler

In [None]:
# replace X with its standardized version
# NOTE(review): not idempotent -- re-running this cell standardizes X a second time
X = scaler.transform(X)

In [None]:
# two-component PCA, fitted on the standardized feature matrix in the next cell
pca = PCA(n_components=2)

In [None]:
# fit PCA on the standardized [fidelity, redundancy, target statistic] rows
pca.fit(X)

In [None]:
# fraction of the total variance captured by each of the two components
print(pca.explained_variance_ratio_)

In [None]:
# principal axes (one row per component) in standardized-feature coordinates
print(pca.components_)

In [None]:
# Map points from standardized space back to the original feature units for plotting.
# StandardScaler.inverse_transform requires 2-D (n_samples, n_features) input; a bare
# 1-D vector raises ValueError in current scikit-learn, so wrap it and take row 0.
s = scaler.inverse_transform([[0., 0., 0.]])[0]        # standardized origin -> feature means
e = scaler.inverse_transform(pca.components_[:1])[0]   # tip of the first principal axis
# NOTE(review): assumes X is still the standardized matrix (i.e. the regression
# section, which overwrites X's last column, has not been run yet)
data = scaler.inverse_transform(X)

In [None]:
# 3-D scatter of the conditions with the first principal axis drawn in red
fig = plt.figure()
# Axes3D(fig) no longer adds itself to the figure as of Matplotlib 3.6 (deprecated
# in 3.4), which leaves an empty plot; add_subplot(projection='3d') is the supported
# path and forwards azim to the Axes3D constructor.
ax = fig.add_subplot(projection='3d', azim=20.)
ax.scatter(data[:,0], data[:,1], data[:,2])
ax.plot([s[0], e[0]], [s[1], e[1]], [s[2], e[2]], color='red')
ax.set_xlabel('Fidelity')
ax.set_ylabel('Redundancy')
ax.set_zlabel('Target Confidence');

**Multiple Linear Regression (incremental R² of redundancy and fidelity)**

In [None]:
# Use the last (standardized target) column as y, then overwrite that column with
# ones so it acts as the intercept regressor in the fits below.
# NOTE(review): mutates X in place -- re-running this cell would make y all ones.
y = list(X[:, -1])
X[:, -1] = 1.

In [None]:
# baseline: intercept-only model (X[:, 2:] is the column of ones set above)
clf = linear_model.LinearRegression(fit_intercept=False)
intercept_cols = X[:, 2:]
r2_inter = clf.fit(intercept_cols, y).score(intercept_cols, y)
print('Intercept:', r2_inter)

In [None]:
# add redundancy (column 1) to the intercept column; r2_attn is the *incremental*
# R^2 due to redundancy (the name is kept because the next cell reads it)
clf = linear_model.LinearRegression(fit_intercept=False)
redundancy_cols = X[:, 1:]
r2_attn = clf.fit(redundancy_cols, y).score(redundancy_cols, y) - r2_inter
print('Redundancy:', r2_attn)

In [None]:
# full model (fidelity + redundancy + intercept); r2_space is the incremental
# R^2 due to fidelity on top of the two smaller models
clf = linear_model.LinearRegression(fit_intercept=False)
r2_full = clf.fit(X, y).score(X, y)
r2_space = r2_full - (r2_attn + r2_inter)
print('Fidelity:', r2_space)