In [None]:
import pickle 

import imp
from IPython.display import clear_output, display
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import models, modules, layers, ops
from rf_pool.utils import lattice, functions, visualize, datasets, stimuli

**Load MNIST Data**

In [None]:
# Download (if needed) both MNIST splits as tensors; the crowding experiments
# below draw their digits from `testset`.
transform = transforms.Compose([transforms.ToTensor()])
trainset = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Single-image, shuffled loaders over the training and test splits.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=1, shuffle=True, num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=1, shuffle=True, num_workers=2,
)

In [None]:
# Load the pickled base crowded dataset. create_crowd_set passes its recorded
# target/flanker indices on when it is given as `base_set`.
# NOTE: pickle.load can execute arbitrary code -- only load files you created/trust.
# `with` guarantees the file handle is closed (the bare open() leaked it).
with open('crowding_experiment/MNIST_CrowdedDataset.pkl', 'rb') as base_set_file:
    base_set = pickle.load(base_set_file)

**Load Model**

In [None]:
# Empty feed-forward container; its layers are appended in the next cell,
# before the saved weights are loaded.
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Rebuild the architecture that the saved model was trained with:
#   '0': Conv2d(1->32, 5x5) + ReLU + MaxPool2d(2)
#   '1': Conv2d(32->64, 5x5) + ReLU + MaxPool2d(2)  (its pool is replaced by RF_Pool later)
#   '2': Conv2d(64->10, 4x4)
#   '3': FeedForward(input_shape=(-1, 10)) -- the reshape layer removed later
model.append('0', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(1, 32, 5),
    activation=torch.nn.ReLU(),
    pool=torch.nn.MaxPool2d(2),
))
model.append('1', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(32, 64, 5),
    activation=torch.nn.ReLU(),
    pool=torch.nn.MaxPool2d(2),
))
model.append('2', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(64, 10, 4),
))
model.append('3', rf_pool.modules.FeedForward(
    input_shape=(-1, 10),
))

In [None]:
# Restore the model plus the result dictionaries written by the
# "Save model and results" cell (they come back in `extras`).
_, extras = model.load_model('crowding_experiment/attention_3deg.pkl')
interference_crop, similarity_rfs = (extras.get(name) for name in ('interference_crop', 'similarity_rfs'))

In [None]:
# Remove the reshape layer '3' appended above.
# Guarded so that re-running this cell does not raise KeyError once it is gone.
if '3' in model.layers:
    model.layers.pop('3')

In [None]:
# Swap layer '1''s max-pool for an RF_Pool layer with a foveated lattice.
# img_shape 53x53 matches the size of layer '1''s conv output for the 118x118
# crowded images used below (118 -> 114 -> 57 -> 53), presumably by design.
layer_id = '1'
img_shape = torch.Size((53, 53))
mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(
    img_shape, 0.23, 0.,
    n_rf=None,
    n_rings=11,
    offset=[0., -30.],
    rotate_rings=False,
)
rf_layer = rf_pool.layers.RF_Pool(
    mu=mu,
    sigma=sigma,
    img_shape=img_shape,
    lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
    pool_type=None,
    kernel_size=2,
    thr=np.exp(-1.),
)
rf_layer.show_lattice()
n_kernels = rf_layer.mu.shape[0]
print(rf_layer.mu.shape)
model.layers[layer_id].forward_layer.add_module('pool', rf_layer)

**Define helper functions for the experiments**

In [None]:
def get_crowd_params(crowd_type):
    """Map a crowding configuration name to ``(n_flankers, axis)``.

    'outer' -> one flanker on axis 0; 'inner' -> one flanker on axis pi;
    any other name (e.g. 'radial') -> two flankers on axis 0.
    """
    n_flankers = 1 if crowd_type in ('outer', 'inner') else 2
    axis = np.pi if crowd_type == 'inner' else 0.
    return n_flankers, axis

def create_crowd_set(dataset, n_images, img_size, n_flankers, axis, spacing, base_set=None,
                     label_map=None, no_target=False):
    """Build a ``datasets.CrowdedDataset`` of ``n_images`` crowded images.

    Args:
        dataset: source image dataset (MNIST in this notebook).
        n_images: number of crowded images to generate.
        img_size: forwarded as ``background_size``.
        n_flankers, axis: flanker count and placement axis (see get_crowd_params).
        spacing: target-flanker spacing; multiplied by 20 before being forwarded.
        base_set: optional existing CrowdedDataset. When given, its
            ``recorded_target_indices`` / ``recorded_flanker_indices`` are passed
            positionally together with ``load_previous=True`` (presumably so every
            condition reuses the same digits).
        label_map: forwarded unchanged.
        no_target: forwarded unchanged (used for flanker-only sets in this notebook).

    Returns:
        The constructed CrowdedDataset, with a ``ToTensor`` transform.
    """
    reuse_indices = base_set is not None
    ctor_args = [dataset, n_flankers, n_images]
    if reuse_indices:
        ctor_args += [base_set.recorded_target_indices,
                      base_set.recorded_flanker_indices]
    return datasets.CrowdedDataset(*ctor_args, no_target=no_target,
                                   load_previous=reuse_indices, label_map=label_map,
                                   transform=transforms.ToTensor(),
                                   spacing=20*spacing, background_size=img_size, axis=axis)

In [None]:
def apply_attention_field(model, layer_id, mu, sigma, loc, extent):
    """Update the RF lattice of ``model.layers[layer_id]`` with an attentional field.

    Builds a zero map the size of the RF layer's ``img_shape`` and places
    ``1 / extent`` at pixel ``loc``. It passes that map with ``mu``/``sigma`` to
    ``lattice.update_mu_sigma``, then writes the returned values into the
    layer's ``pool`` via ``pool.set``.

    Returns:
        ``model`` itself (its pool layer is updated in place).
    """
    rf_pool_layer = model.layers[layer_id].forward_layer.pool
    field_shape = rf_pool_layer.get(['img_shape'])[0]
    attn_field = torch.zeros(field_shape)
    row, col = loc[0], loc[1]
    attn_field[row, col] = 1. / extent
    updated_mu, updated_sigma = lattice.update_mu_sigma(mu, sigma, attn_field)
    rf_pool_layer.set(mu=updated_mu, sigma=updated_sigma)
    return model

In [None]:
def cropped_interference(model, layer_id, target, target_flanker, crop):
    """Flanker interference measured on a spatial crop of a layer's output.

    Both inputs are run (without gradients) through every layer up to and
    including ``layer_id``. The window ``[:, :, crop[0], crop[1]]`` of each
    output is kept and flattened per batch item. The result is
    ``1 - functions.pairwise_cosine_similarity(target_crop, crowd_crop)``.
    """
    layer_ids = model.pre_layer_ids(layer_id) + [layer_id]
    rows, cols = crop[0], crop[1]

    def _cropped_features(x):
        # forward through the layer stack, then keep only the crop window
        return model.apply_layers(x, layer_ids)[:, :, rows, cols].flatten(1)

    with torch.no_grad():
        target_crop = _cropped_features(target)
        crowd_crop = _cropped_features(target_flanker)
    similarity = functions.pairwise_cosine_similarity(target_crop, crowd_crop)
    return 1. - similarity

def _masked_rf_outputs(model, x, layer_id, thr):
    """Per-RF feature vectors of ``x`` at ``layer_id``, zeroed for RFs not selected by ``rf_index``."""
    rf_out = model.rf_output(x, layer_id, retain_shape=True)[0]
    # swap dims 1 and 2 and flatten the rest; presumably (batch, ch, n_rf, ...) ->
    # (batch, n_rf, features) -- TODO confirm rf_output's layout
    rf_out = torch.transpose(rf_out, 1, 2).flatten(2)
    active = model.rf_index(x, layer_id, thr=thr).unsqueeze(-1).float()
    return torch.mul(rf_out, active)


def rf_cosine_similarity(model, layer_id, target, flanker, target_flanker, thr=0.1):
    """Per-RF cosine similarity of target and flanker responses to the crowded response.

    Args:
        model: network exposing ``rf_output`` and ``rf_index``.
        layer_id: layer whose RF outputs are compared. The original version
            ignored this argument and always used layer '1'.
        target, flanker, target_flanker: input batches (target alone, flankers
            alone, target with flankers).
        thr: threshold passed to ``model.rf_index`` to select active RFs
            (default 0.1, the previously hard-coded value).

    Returns:
        ``(target_rfs_sim, flanker_rfs_sim)``: ``functions.pairwise_cosine_similarity``
        along the last axis between (target, crowd) and (flanker, crowd) RF features.
    """
    with torch.no_grad():
        target_rfs = _masked_rf_outputs(model, target, layer_id, thr)
        flanker_rfs = _masked_rf_outputs(model, flanker, layer_id, thr)
        crowd_rfs = _masked_rf_outputs(model, target_flanker, layer_id, thr)
    target_rfs_sim = functions.pairwise_cosine_similarity(target_rfs, crowd_rfs, axis=-1)
    flanker_rfs_sim = functions.pairwise_cosine_similarity(flanker_rfs, crowd_rfs, axis=-1)
    return target_rfs_sim, flanker_rfs_sim

In [None]:
def mask_similarity(x, y):
    """Float mask that is 0 where an (x, y) pair is excluded and 1 elsewhere.

    A pair is excluded when both values are NaN, or when both are >= 0.9999.
    """
    both_nan = torch.isnan(x) & torch.isnan(y)
    both_saturated = torch.ge(x, 0.9999) & torch.ge(y, 0.9999)
    return 1. - (both_nan | both_saturated).float()

In [None]:
# Shared plot styling: a gray colormap (used for grid lines in plot_interference)
# and a 10pt Arial font for all figures.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9.
cmap_gray = plt.get_cmap("gray")
font = {'family': 'arial',
        'size': 10}
matplotlib.rc('font', **font)

def plot_interference(x, y, std, label, color, filename=None):
    """Plot one mean-interference curve (with error bars) per condition.

    Args:
        x: x positions shared by all curves (attentional-field extent).
        y, std, label, color: parallel sequences, one entry per curve: the
            mean values, their error-bar sizes, the legend label and the color.
        filename: if truthy, the figure is also saved there at 600 dpi.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 4))
    curves = zip(y, std, label, color)
    for mean_vals, err_vals, curve_name, curve_color in curves:
        ax.errorbar(x, mean_vals, yerr=err_vals, color=curve_color, lw=3,
                    fmt='o-', label=curve_name, ms=10)
    ax.yaxis.grid(which="major", color=cmap_gray(.8), linestyle='--', linewidth=1)
    ax.set_xlabel("Attentional Field Extent (DVA)")
    ax.set_ylabel("Mean Interference")
    ax.legend()
    plt.show()
    if filename:
        fig.savefig(filename, dpi=600)

**Get cropped interference**

In [None]:
# Settings for the cropped-interference experiment.
batch_size = 1
img_size = 118  # background size of the crowded images
n_test = 1000
# target-flanker spacing (create_crowd_set multiplies it by 20)
spacing = 1.
# identity label mapping for digits 0-9
label_map = {n: n for n in range(10)}
# target-only images (0 flankers) built from the MNIST *test* split;
# uses img_size rather than a duplicated literal 118
target_set = create_crowd_set(testset, n_test, img_size, 0, 0, 0, base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                           shuffle=False, num_workers=2)

In [None]:
# Results container: attentional-field extents, plus for each crowd type a
# list that receives one list of per-image interference values per extent.
interference_crop = {'extent': np.arange(10, 22.5, 2.5)}
interference_crop.update((crowd_type, []) for crowd_type in ('outer', 'inner', 'radial'))

In [None]:
# For each flanker configuration and attentional-field extent, measure how much
# the flankers change layer `layer_id`'s response inside a fixed crop.
for key in ['outer','inner','radial']:
    # crowded versions of the target images (same base_set indices)
    n_flankers, axis = get_crowd_params(key)
    crowd_set = create_crowd_set(testset, n_test, img_size, n_flankers, axis, spacing, base_set, label_map)
    for extent in list(interference_crop.get('extent')):
        # rebuild the lattice, then apply attention at (26, 26) -- the center of
        # the 53x53 img_shape -- with this extent. Uses layer_id, the same layer
        # cropped_interference reads (previously a hard-coded '1').
        mu, sigma = lattice.init_foveated_lattice(img_shape, 0.23, 0., n_rf=None, n_rings=11,
                                                  offset=[0.,-30.], rotate_rings=False)
        model = apply_attention_field(model, layer_id, mu, sigma, [26,26], extent)
        target_crop_int = []
        for i, ((target, _), (target_flanker, _)) in enumerate(zip(target_set, crowd_set)):
            t_int = cropped_interference(model, layer_id, target.unsqueeze(0),
                                         target_flanker.unsqueeze(0),
                                         crop=(slice(11,15), slice(11,15)))
            # skip images whose interference is undefined (NaN)
            if not torch.isnan(t_int):
                target_crop_int.append(t_int.item())
            # progress: fraction of images processed (i + 1 so the last one reads 1.000)
            clear_output(wait=True)
            display('%s %s' % (key, extent))
            display('%0.3f' % ((i + 1) / len(target_set)))
            display('interference: %f' % torch.mean(torch.tensor(target_crop_int)))
        # one list of per-image interference values per extent
        interference_crop[key].append(target_crop_int)

**Compute per-RF cosine similarity for (Target, Target+Flanker) and (Flanker, Target+Flanker)**

In [None]:
# Settings for the per-RF cosine-similarity experiment.
batch_size = 1
img_size = 118  # background size of the crowded images
n_test = 100
# target-flanker spacing (create_crowd_set multiplies it by 20)
spacing = 1.
# identity label mapping for digits 0-9
label_map = {n: n for n in range(10)}
# target-only images (0 flankers) built from the MNIST *test* split;
# uses img_size rather than a duplicated literal 118
target_set = create_crowd_set(testset, n_test, img_size, 0, 0, 0, base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                           shuffle=False, num_workers=2)

In [None]:
# Results container: attentional-field extents, plus for each crowd type one
# list for target-vs-crowd and one ('<type>_flanker') for flanker-vs-crowd
# similarities, each receiving one tensor per extent.
similarity_rfs = {'extent': np.arange(10, 22.5, 2.5),
                  **{name: [] for crowd_type in ('outer', 'inner', 'radial')
                     for name in (crowd_type, crowd_type + '_flanker')}}

In [None]:
# For each flanker configuration and attentional-field extent, compare every
# RF's response to the crowded image with its response to the target alone and
# to the flankers alone.
for key in ['outer','inner','radial']:
    n_flankers, axis = get_crowd_params(key)
    # flanker-only and target+flanker versions of the same images
    # (img_size instead of a duplicated literal 118)
    flanker_set = create_crowd_set(testset, n_test, img_size, n_flankers, axis, spacing,
                                   base_set, label_map, no_target=True)
    crowd_set = create_crowd_set(testset, n_test, img_size, n_flankers, axis, spacing,
                                 base_set, label_map)
    for extent in list(similarity_rfs.get('extent')):
        # rebuild the lattice, then apply attention at (26, 26) with this extent
        mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, 0.23, 0., n_rf=None, n_rings=11,
                                                                offset=[0.,-30.], rotate_rings=False)
        model = apply_attention_field(model, layer_id, mu, sigma, [26,26], extent)
        target_rfs_sim = []
        flanker_rfs_sim = []
        for i, ((target, _), (flanker, _), (target_flanker, _)) in enumerate(zip(target_set, flanker_set, crowd_set)):
            t_sim, f_sim = rf_cosine_similarity(model, layer_id, target.unsqueeze(0),
                                                flanker.unsqueeze(0), target_flanker.unsqueeze(0))
            target_rfs_sim.append(t_sim)
            flanker_rfs_sim.append(f_sim)
            # progress: fraction of images processed (i + 1 so the last one reads 1.000)
            clear_output(wait=True)
            display('%s %s' % (key, extent))
            display('%0.3f' % ((i + 1) / len(target_set)))
        # one concatenated similarity tensor per extent (per-image results along dim 0)
        similarity_rfs[key].append(torch.cat(target_rfs_sim))
        similarity_rfs[key + '_flanker'].append(torch.cat(flanker_rfs_sim))

**Save model and results**

In [None]:
# Persist the model together with both result dictionaries; the loading cell
# near the top of the notebook reads them back from `extras`.
model.save_model(
    'crowding_experiment/attention_3deg.pkl',
    {'interference_crop': interference_crop, 'similarity_rfs': similarity_rfs},
)

**Plot cropped interference and RF heatmaps**

In [None]:
# Mean +/- std of per-image interference at each extent, per flanker configuration.
# Extents are plotted as extent / 10 (the x-axis is labelled in DVA).
x = interference_crop.get('extent') / 10.
# one point per extent -- derived from len(x) instead of a hard-coded 5
y = [[np.mean(interference_crop.get(key)[i]) for i in range(len(x))] for key in ['outer','inner','radial']]
std = [[np.std(interference_crop.get(key)[i]) for i in range(len(x))] for key in ['outer','inner','radial']]
filename = 'crowding_experiment/attention_experiments.png'
# 'radial' (two flankers) is labelled 'both' in the legend
plot_interference(x, y, std, ['outer','inner','both'], ['blue','green','orange'], filename)

In [None]:
# create heatmaps of attentional differences
# For each crowd type, compare the smallest (index 0) and largest (index -1)
# attentional-field extents stored in similarity_rfs['extent'].
for key in ['outer','inner','radial']:
    for i in [0,-1]:
        # update rfs with spatial extent (at center of RF layer)
        # ([26, 26] is the center of the 53x53 img_shape used for the lattice)
        mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(img_shape, 0.23, 0., n_rf=None, n_rings=11, 
                                                                offset=[0.,-30.], rotate_rings=False)
        model = apply_attention_field(model, layer_id, mu, sigma, [26,26], similarity_rfs.get('extent')[i])
        # get target, flanker similarities
        # res only tags the output filename with the extent value
        res = str(similarity_rfs.get('extent')[i])
        # clone so the in-place NaN zeroing below does not alter the stored results
        x = similarity_rfs.get(key)[i].clone()
        y = similarity_rfs.get(key + '_flanker')[i].clone()
        # get mask
        # mask_similarity zeroes pairs where both similarities are NaN or both >= 0.9999;
        # columns (RFs) with fewer than 5 remaining entries are then dropped entirely.
        # NOTE(review): assumes x/y are (n_images, n_rf) as built by torch.cat in the
        # similarity loop -- confirm.
        mask = mask_similarity(x, y)
        mask[:, torch.sum(mask, 0) < 5] = 0
        # remove nans from similarities
        x[torch.isnan(x)] = 0.
        y[torch.isnan(y)] = 0.
        # get mean difference between target and flanker similarities
        # (positive: the crowded RF response is closer to the target's than to the flanker's)
        # NOTE(review): fully masked columns give 0/0 = NaN scores -- confirm that
        # visualize.heatmap handles NaN as intended.
        scores = torch.div(torch.sum(torch.mul(x - y, mask), 0), torch.sum(mask, 0))
        # get example crowded images
        # (a single crowded image of this type, passed to the heatmap as `input`)
        n_flankers, axis = get_crowd_params(key)
        crowd_set = create_crowd_set(testset, 1, img_size, n_flankers, axis, spacing, base_set=base_set, 
                                     label_map=label_map)
        # crowd_set[0][0][0]: first item's image, first channel -- presumably (C, H, W) from ToTensor
        visualize.heatmap(model, layer_id, scores, -1., 1., outline_rfs=True, input=crowd_set[0][0][0],
                          filename='crowding_experiment/' + key + '_' + res + '_attention_heatmap.png')