In [None]:
import pickle 

import imp
from IPython.display import clear_output, display
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from context import rf_pool

In [None]:
from rf_pool import models, modules, pool, ops
from rf_pool.utils import lattice, functions, visualize, datasets, stimuli

**Load MNIST Data**

In [None]:
# get the MNIST train and test splits (downloaded to ../data if missing)
transform = transforms.Compose([transforms.ToTensor()])
trainset, testset = (
    torchvision.datasets.MNIST(root='../data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# create shuffled single-image loaders for the train and test splits
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=1, shuffle=True, num_workers=2)
    for split in (trainset, testset)
)

In [None]:
# load the base crowded dataset; create_crowd_set reuses its recorded
# target/flanker indices (presumably so every experiment draws the same images).
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open('crowding_experiment/MNIST_CrowdedDataset.pkl', 'rb') as f:
    base_set = pickle.load(f)

**Load Model**

In [None]:
# initialize the feed-forward network; its layers are appended in the next cell
model = rf_pool.models.FeedForwardNetwork()

In [None]:
# Build the network: two conv blocks (5x5 conv -> ReLU -> 2x2 max-pool), a
# 4x4 conv with 10 output channels, and a final layer with input_shape=(-1, 10)
# (presumably a reshape to 10 class scores).
model.append('0', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
    activation=torch.nn.ReLU(),
    pool=torch.nn.MaxPool2d(kernel_size=2),
))
model.append('1', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
    activation=torch.nn.ReLU(),
    pool=torch.nn.MaxPool2d(kernel_size=2),
))
model.append('2', rf_pool.modules.FeedForward(
    hidden=torch.nn.Conv2d(in_channels=64, out_channels=10, kernel_size=4),
))
model.append('3', rf_pool.modules.FeedForward(input_shape=(-1, 10)))

In [None]:
# load previous model and results; load_model returns a pair and only the
# second element (`extras`) is kept — it is not used in the cells shown here.
# NOTE(review): presumably restores the trained weights for the layers built above — confirm.
(_, extras) = model.load_model('crowding_experiment/MNIST_rate_0.2_10k_3deg.pkl')

In [None]:
# remove reshape layer '3' (built with input_shape=(-1, 10)) so layer '2' is the
# last layer; the popped layer is displayed as this cell's output.
model.layers.pop('3')

In [None]:
# Get peak feature value: the largest activation of layers '0'-'1' over the
# test set. It is later passed to the PSNR functions as the peak (numerator).
# no_grad avoids building an autograd graph for every test image.
peak = 0.
with torch.no_grad():
    for i, (data, label) in enumerate(testloader):
        tmp_peak = torch.max(model.apply_layers(data, ['0', '1'])).item()
        peak = max(peak, tmp_peak)
        clear_output(wait=True)
        display('progress %a' % ((i+1) / len(testloader)))

In [None]:
# Hard-coded peak feature value; overrides whatever the previous cell computed.
# NOTE(review): presumably a cached result of that slow cell so it can be
# skipped — recompute it if the model weights or dataset change.
peak = 10.207415580749512

In [None]:
# Install an RF pooling layer whose receptive fields come from a foveated
# lattice into layer `layer_id` (no attention is applied here). add_module
# registers it under the name 'pool', presumably replacing that layer's MaxPool2d.
img_shape = torch.Size((53, 53))
offset = [0., -30.]
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
layer_id = '1'

mu, sigma = rf_pool.utils.lattice.init_foveated_lattice(
    img_shape, RF_rate, gap, n_rings=n_rings, std=std, offset=offset)
rf_layer = rf_pool.pool.RF_Pool(
    mu=mu, sigma=sigma, img_shape=img_shape,
    lattice_fn=rf_pool.utils.lattice.mask_kernel_lattice,
    pool_fn='max_pool', kernel_size=2, retain_shape=True)
rf_layer.show_lattice()
print(rf_layer.mu.shape)
n_kernels = rf_layer.mu.shape[0]

model.layers[layer_id].forward_layer.add_module('pool', rf_layer)
visualize.heatmap(model, layer_id);

In [None]:
# Get sigmas for the density/size tests: np.unique returns the sorted unique
# RF sigmas of the foveated lattice, and [-4:-1] keeps the 4th-, 3rd- and
# 2nd-largest (the largest is excluded). Fewer than 3 values result if the
# lattice has fewer than 4 distinct sigmas.
sigma_test = np.unique(sigma)[-4:-1]

**Define helper functions for the experiments**

In [None]:
def get_crowd_params(crowd_type):
    """Map a crowding configuration name to ``(n_flankers, axis)``.

    'outer' and 'inner' use a single flanker; every other name uses two.
    The axis angle (radians) is pi for 'inner', pi/2 for 'tangential',
    and 0 for anything else (e.g. 'outer', 'radial').
    """
    n_flankers = 1 if crowd_type in ('outer', 'inner') else 2
    if crowd_type == 'inner':
        return n_flankers, np.pi
    if crowd_type == 'tangential':
        return n_flankers, np.pi / 2.
    return n_flankers, 0.

def create_crowd_set(dataset, n_images, img_size, n_flankers, axis, spacing, base_set=None,
                     label_map=None, no_target=False, transform=transforms.ToTensor()):
    """Build a ``datasets.CrowdedDataset`` of targets with flankers.

    Args:
        dataset: source image dataset.
        n_images: number of crowded images to create.
        img_size: passed as the dataset's ``background_size``.
        n_flankers, axis: flanker count and placement axis (see get_crowd_params).
        spacing: multiplied by 20 before being passed as the dataset's ``spacing``.
        base_set: optional previously built CrowdedDataset; when given, its
            recorded target/flanker indices are reused (``load_previous=True``).
        label_map, no_target, transform: forwarded unchanged.
    """
    common = dict(no_target=no_target, label_map=label_map, transform=transform,
                  spacing=20*spacing, background_size=img_size, axis=axis)
    if base_set is not None:
        return datasets.CrowdedDataset(dataset, n_flankers, n_images,
                                       base_set.recorded_target_indices,
                                       base_set.recorded_flanker_indices,
                                       load_previous=True, **common)
    return datasets.CrowdedDataset(dataset, n_flankers, n_images,
                                   load_previous=False, **common)

In [None]:
def apply_attention_field(model, layer_id, mu, sigma, loc, extent):
    """Apply a point attention field to the RFs of layer ``layer_id``.

    A zero field of the pooling layer's ``img_shape`` gets the value
    ``1 / extent`` at pixel ``(loc[0], loc[1])``. ``lattice.update_mu_sigma``
    derives new RF centres/sizes from ``mu``, ``sigma`` and that field, and
    the results are set on the layer's pooling module (in place).

    Returns:
        The same ``model`` object, for chaining.
    """
    rf_pool_layer = model.layers[layer_id].forward_layer.pool
    field_shape = rf_pool_layer.get(['img_shape'])[0]
    row, col = loc[0], loc[1]
    attn_field = torch.zeros(field_shape)
    attn_field[row, col] = 1. / extent
    attended_mu, attended_sigma = lattice.update_mu_sigma(mu, sigma, attn_field)
    rf_pool_layer.set(mu=attended_mu, sigma=attended_sigma)
    return model

In [None]:
def get_SNR_accuracy(target_loader, crowd_loader, peak, batch_size=1, model=None, RF_mask=None, 
                     extent=None, lattice_fn=None, lattice_kwargs=None, layer_id='1'):
    """PSNR between isolated-target and crowded RF outputs, plus readout accuracy.

    Target and crowd batches are drawn in lockstep with ``zip``, so iteration
    stops at the shorter loader. For each batch:

    1. If both ``lattice_fn`` and ``lattice_kwargs`` are given, a new lattice
       ``lattice_fn(**lattice_kwargs, rotate=...)`` is set on the pooling
       module of layer ``layer_id``.
    2. The RF mask is ``RF_mask`` if given, else
       ``model.rf_index(target, layer_id, thr=0.1)``.
    3. If ``extent`` is truthy, an attention field of weight ``1/extent`` at
       pixel (26, 26) is applied to the RFs.
    4. For target ("signal") and crowd ("noise"), the max of
       ``model.rf_output`` over the last two dims plus ``torch.rand`` noise is
       taken; PSNR = 10*log10(peak**2 / MSE) with MSE averaged over dim 1.
    5. The crowd output is multiplied by the mask (reshaped to
       ``(batch_size, 1, -1, 1, 1)``), max-pooled over dim 2, read out through
       layer '2', and its argmax compared to the labels.

    Args:
        target_loader, crowd_loader: iterables of ``(images, labels)`` batches.
        peak: tensor holding the peak feature value (PSNR numerator).
        batch_size: batch size of the loaders; used to reshape the mask.
        model: rf_pool model whose layer ``layer_id`` holds an RF pooling module.
        RF_mask: optional fixed mask over RFs (cloned; never modified).
        extent: optional attention extent; falsy disables attention.
        lattice_fn, lattice_kwargs: optional per-batch lattice generator and
            kwargs. ``lattice_kwargs['rotate']`` may be a number or a
            zero-argument callable. The caller's dict is NOT modified.
        layer_id: id of the RF pooling layer (default ``'1'``).

    Returns:
        ``(avg_SNR, pct_correct, SNR_is, acc_is)``: PSNR averaged over all
        masked RFs, fraction of processed images classified correctly,
        per-batch mean masked PSNR, and per-batch number of correct predictions.
        Empty input yields NaN averages.
    """
    reset_lattice = lattice_fn is not None and lattice_kwargs is not None
    if reset_lattice:
        # Copy before popping 'rotate': param_search reuses one dict across
        # evaluations, and popping in place silently dropped the rotation
        # after the first evaluation.
        lattice_kwargs = dict(lattice_kwargs)
        rotate = lattice_kwargs.pop('rotate', 0.)
        rotate_fn = rotate if callable(rotate) else (lambda: rotate)
    elif extent:
        # No per-batch lattice: attend relative to the layer's current RFs,
        # read once so attention is not compounded across batches.
        mu, sigma = model.layers[layer_id].forward_layer.pool.get(['mu', 'sigma'])
    # running totals
    SNR = torch.zeros(1)
    correct = torch.zeros(1)
    SNR_is = []
    acc_is = []
    cnt = 0.
    n_images = 0
    if RF_mask is not None:
        mask_i = RF_mask.clone()
    for (target, labels), (crowd, _) in zip(target_loader, crowd_loader):
        if reset_lattice:
            mu, sigma = lattice_fn(**lattice_kwargs, rotate=rotate_fn())
            model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
        if RF_mask is None:
            mask_i = model.rf_index(target, layer_id, thr=0.1).float()
        if extent:
            model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
        with torch.no_grad():
            target_output = model.rf_output(target, layer_id, retain_shape=True)
            signal = torch.max(target_output.flatten(-2), -1)[0] + torch.rand(target_output.shape[:-2])
            crowd_output = model.rf_output(crowd, layer_id, retain_shape=True)
            noise = torch.max(crowd_output.flatten(-2), -1)[0] + torch.rand(crowd_output.shape[:-2])
            # keep only masked RFs, max over the RF dim, then read out through layer '2'
            masked_output = torch.max(torch.mul(crowd_output, mask_i.reshape(batch_size, 1, -1, 1, 1)), 2)[0]
            output = torch.max(model.apply_layers(masked_output, ['2']).flatten(-2), -1)[0]
            correct_i = torch.sum(torch.max(output, -1)[1] == labels).item()
        correct += correct_i
        n_images += len(labels)
        # peak SNR per RF, zeroed outside the mask
        MSE = torch.mean(torch.pow(signal - noise, 2), dim=[1])
        SNR_i = torch.mul(10. * torch.log10(torch.div(torch.pow(peak, 2), MSE)), mask_i)
        SNR = SNR + torch.sum(SNR_i)
        SNR_is.append(torch.div(torch.sum(SNR_i), torch.sum(mask_i)).item())
        acc_is.append(correct_i)
        cnt += torch.sum(mask_i).item()
    # Divide by images actually processed: len(target_loader) counts batches
    # and ignores a shorter crowd_loader.
    pct_correct = (correct / n_images).item()
    avg_SNR = torch.div(SNR, cnt).item()
    return (avg_SNR, pct_correct, SNR_is, acc_is)

In [None]:
def get_heatmaps(target_loader, crowd_loader, peak, batch_size=1, model=None, RF_mask=None, 
                 extent=None, lattice_fn=None, lattice_kwargs=None, layer_id='1'):
    """Per-RF PSNR heatmap between isolated-target and crowded RF outputs.

    Target and crowd batches are drawn in lockstep with ``zip``. For each batch:
    optionally a new lattice ``lattice_fn(**lattice_kwargs, rotate=...)`` is set
    on layer ``layer_id``; the RF mask (``RF_mask`` or
    ``model.rf_index(target, layer_id, thr=0.1)``) is accumulated; optionally an
    attention field of weight ``1/extent`` at pixel (26, 26) is applied; then
    PSNR = 10*log10(peak**2 / MSE) per RF is accumulated where masked. The
    signal/noise are the max of ``model.rf_output`` over the last two dims plus
    ``torch.rand`` noise, and MSE is averaged over dim 1.

    Args:
        target_loader, crowd_loader: iterables of ``(images, labels)`` batches.
        peak: tensor holding the peak feature value (PSNR numerator).
        batch_size: unused; kept for signature parity with the accuracy function.
        model: rf_pool model whose layer ``layer_id`` holds an RF pooling module.
        RF_mask: optional fixed mask over RFs (cloned; never modified).
        extent: optional attention extent; falsy disables attention.
        lattice_fn, lattice_kwargs: optional per-batch lattice generator and
            kwargs. ``lattice_kwargs['rotate']`` may be a number or a
            zero-argument callable. The caller's dict is NOT modified.
        layer_id: id of the RF pooling layer (default ``'1'``).

    Returns:
        Accumulated masked PSNR divided by the per-RF mask count, shape
        ``(1, n_kernels)`` for single-image batches; NaN for RFs never masked.
    """
    n_kernels = model.layers[layer_id].forward_layer.pool.mu.shape[0]
    reset_lattice = lattice_fn is not None and lattice_kwargs is not None
    if reset_lattice:
        # Copy before popping 'rotate' so the caller's dict keeps it.
        lattice_kwargs = dict(lattice_kwargs)
        rotate = lattice_kwargs.pop('rotate', 0.)
        rotate_fn = rotate if callable(rotate) else (lambda: rotate)
    elif extent:
        # No per-batch lattice: attend relative to the layer's current RFs,
        # read once so attention is not compounded across batches.
        mu, sigma = model.layers[layer_id].forward_layer.pool.get(['mu', 'sigma'])
    SNR = torch.zeros(1, n_kernels)
    mask = torch.zeros(1, n_kernels)
    if RF_mask is not None:
        mask_i = RF_mask.clone()
    for (target, _labels), (crowd, _) in zip(target_loader, crowd_loader):
        if reset_lattice:
            mu, sigma = lattice_fn(**lattice_kwargs, rotate=rotate_fn())
            model.layers[layer_id].forward_layer.pool.set(mu=mu, sigma=sigma)
        if RF_mask is None:
            mask_i = model.rf_index(target, layer_id, thr=0.1).float()
        mask = mask + mask_i
        if extent:
            model = apply_attention_field(model, layer_id, mu, sigma, [26, 26], extent)
        with torch.no_grad():
            target_output = model.rf_output(target, layer_id, retain_shape=True)
            signal = torch.max(target_output.flatten(-2), -1)[0] + torch.rand(target_output.shape[:-2])
            crowd_output = model.rf_output(crowd, layer_id, retain_shape=True)
            noise = torch.max(crowd_output.flatten(-2), -1)[0] + torch.rand(crowd_output.shape[:-2])
        # MSE between target and crowding stimuli, then masked PSNR per RF
        MSE = torch.mean(torch.pow(signal - noise, 2), dim=[1])
        SNR_i = torch.mul(10. * torch.log10(torch.div(torch.pow(peak, 2), MSE)), mask_i)
        SNR = SNR + SNR_i
    return torch.div(SNR, mask)

**Test RF size and density**

In [None]:
# identity mapping for the ten digit labels
label_map = {digit: digit for digit in range(10)}

In [None]:
# Stimuli for the RF size/density tests: isolated targets and the same
# base-set targets (presumably) with radial flankers, on a 118-px background.
n_test = 100
spacing = 1.
batch_size = 1
n_flankers, axis = get_crowd_params('radial')
target_set = create_crowd_set(testset, n_test, 118, n_flankers=0, axis=0, spacing=0,
                              base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing,
                             base_set=base_set, label_map=label_map)
crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                           shuffle=False, num_workers=2)

In [None]:
exp_type = 'Density'
img_shape = torch.Size((53,53))
# centre pixel of the image ((26, 26) for 53x53)
center = (torch.tensor(img_shape) - 1.) / 2.

# sweep values: RF sigmas for 'size', multiples of sigma_test[n] for 'density'
if exp_type.lower() == 'size':
    rng = np.arange(2., 23.)
else:
    rng = np.arange(2.25, 2.75, 0.25)
RF_SNR = {}
for n in range(3):
    if exp_type.lower() == 'size':
        # RF size: a single RF at the image centre whose sigma is swept on the
        # pooling layer itself, so the lattice must NOT be regenerated per image
        # (a leftover lattice_kwargs would overwrite the swept sigma).
        mu = rf_pool.utils.lattice.init_uniform_lattice(center, 1, 0.)[0]
        model.layers['1'].forward_layer.pool.set(mu=mu)
        param_space = [torch.tensor([x]) for x in rng]
        param_name = ['model', 'layers', '1', 'forward_layer', 'pool', 'sigma']
        lattice_fn = None
        lattice_kwargs = None
    else:
        # RF density: 2x2 uniform lattice with sigma_init=sigma_test[n], rotated
        # by a random angle in [0, pi) per image; spacing swept as multiples of it.
        lattice_fn = lattice.init_uniform_lattice
        lattice_kwargs = {'center': center, 'n_kernel_side': (2,2), 'spacing': 0.,
                          'sigma_init': sigma_test[n], 'rotate': lambda : np.pi * np.random.rand()}
        param_space = [s*sigma_test[n] for s in rng]
        param_name = ['lattice_kwargs', 'spacing']
    # NOTE(review): RF_mask is computed with the lattice currently installed in
    # layer '1'; in 'density' mode a new 2x2 lattice is installed per image, so
    # the mask may not match its number of kernels — confirm.
    RF_mask = model.rf_index(target_set[0][0].unsqueeze(0), '1', thr=0.1).float()
    SNR = functions.param_search(get_SNR_accuracy, [target_loader, crowd_loader, torch.tensor(peak)], 
                                 {'model': model, 'RF_mask': RF_mask,
                                  'lattice_fn': lattice_fn,
                                  'lattice_kwargs': lattice_kwargs},
                                 param_name, param_space, verbose=False, show_cost=False)
    swept = list(rng) if exp_type.lower() == 'size' else list(param_space)
    RF_SNR.update({'cost_%d' % n: SNR, 'sigma_%d' % n: swept})

In [None]:
# Save this experiment's results dict (RF_SNR) under a name built from
# exp_type and n_test.
# NOTE(review): '%dk' % (n_test/1000) truncates, so any n_test < 1000 is
# written as "_0k" — confirm that naming is intended before relying on it.
with open('crowding_experiment/PSNR_%s_%dk.pkl' % (exp_type.lower(), n_test/1000), 'wb') as f:
    pickle.dump(RF_SNR, f)

**Test Attention**

In [None]:
exp_type = 'Attention'
RF_SNR = {}
n_test = 10
batch_size = 1
spacing = 1.
# attention extents to sweep (passed to get_SNR_accuracy as `extent`)
param_space = np.arange(7., 29.5, 2.5)
param_name = ['extent']
# foveated lattice reuses img_shape, RF_rate, gap, std, n_rings and offset
# from the earlier cells
lattice_fn = lattice.init_foveated_lattice
# isolated targets (no flankers)
target_set = create_crowd_set(testset, n_test, 118, n_flankers=0, axis=0, spacing=0,
                              base_set=base_set, label_map=label_map)
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
for key in ['radial', 'tangential']:
    n_flankers, axis = get_crowd_params(key)
    crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing,
                                 base_set, label_map)
    crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                               shuffle=False, num_workers=2)
    # rebuilt for every key: get_SNR_accuracy may pop 'rotate' from this dict
    lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap,
                      'std': std, 'n_rings': n_rings, 'offset': offset,
                      'rotate': lambda : np.pi * np.random.rand()}
    SNR = functions.param_search(get_SNR_accuracy, [target_loader, crowd_loader, torch.tensor(peak)],
                                 {'batch_size': batch_size, 'model': model, 'extent': None,
                                  'lattice_fn': lattice_fn, 'lattice_kwargs': lattice_kwargs},
                                 param_name, param_space, verbose=False, show_cost=False)
    RF_SNR.update({key + '_attn': SNR, 'extent': param_space})

In [None]:
# Save this experiment's results dict (RF_SNR) under a name built from
# exp_type and n_test.
# NOTE(review): '%dk' % (n_test/1000) truncates, so any n_test < 1000 is
# written as "_0k" — confirm that naming is intended before relying on it.
with open('crowding_experiment/PSNR_%s_%dk.pkl' % (exp_type.lower(), n_test/1000), 'wb') as f:
    pickle.dump(RF_SNR, f)

**Test Spacing**

In [None]:
def rotate_fn(max_angle, base_angle=0., seed=0):
    """Return a seeded sampler of rotation angles.

    Each call of the returned function draws from its own
    ``np.random.RandomState(seed)`` and returns
    ``base_angle + u * max_angle`` with ``u`` uniform in [-1, 1).
    """
    gen = np.random.RandomState(seed=seed)
    return lambda : (2. * gen.rand() - 1.) * max_angle + base_angle

def jitter_fn(max_w, max_h, seed=0):
    """Return a seeded transform that circularly shifts a tensor with torch.roll.

    Two shifts are drawn per call (first scaled by ``max_w``, then by
    ``max_h``) from a private ``np.random.RandomState(seed)`` and truncated
    toward zero with ``int``.

    NOTE(review): the ``max_w`` shift is applied to dim -2 and the ``max_h``
    shift to dim -1 (rows/cols for (..., H, W) tensors); the symmetric
    jitter_fn(5, 5, ...) calls in this notebook are unaffected — confirm intent.
    """
    gen = np.random.RandomState(seed=seed)

    def _shift(max_shift):
        # np.int was an alias of int and was removed in NumPy 1.24
        return int((2. * gen.rand() - 1.) * max_shift)

    return lambda x: torch.roll(x, [_shift(max_w), _shift(max_h)], dims=(-2, -1))

In [None]:
# get angle between RFs
n_RF = np.floor(np.pi / RF_rate)
angles = 2. * np.pi * np.linspace(0., 1., np.int(n_RF))[:-1]
rot_angle = angles[1]

In [None]:
exp_type = 'Spacing'
RF_SNR = {}
n_test = 100
batch_size = 1
# foveated lattice parameters (offset is reused from the RF-layer setup cell)
img_shape = torch.Size((53, 53))
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
spacings = np.arange(1., 2.25, 0.25)
lattice_fn = lattice.init_foveated_lattice
# jittered isolated targets
# NOTE(review): target and crowd sets each get jitter_fn(5, 5, seed=0); whether
# their per-image shifts line up depends on how CrowdedDataset and the
# DataLoader workers consume each generator — confirm.
target_set = create_crowd_set(testset, n_test, 118, n_flankers=0, axis=0, spacing=0,
                              base_set=base_set, label_map=label_map,
                              transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
for key in ['outer', 'inner', 'radial', 'tangential']:
    n_flankers, axis = get_crowd_params(key)
    SNR = []
    for spacing in spacings:
        clear_output(wait=True)
        display('%s: %a' % (key, spacing))
        crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing, base_set, label_map,
                                     transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
        crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                   shuffle=False, num_workers=2)
        # fresh dict and fresh seed-0 rotation sampler per spacing, so every
        # spacing is evaluated with the same rotation sequence
        lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap,
                          'std': std, 'n_rings': n_rings, 'offset': offset,
                          'rotate': rotate_fn(rot_angle / 2., base_angle=0, seed=0)}
        SNR.append(get_SNR_accuracy(target_loader, crowd_loader, torch.tensor(peak),
                                    batch_size=batch_size, model=model, extent=None,
                                    lattice_fn=lattice_fn, lattice_kwargs=lattice_kwargs))
    RF_SNR.update({key + '_space': SNR, 'spacing': spacings})

In [None]:
# Save this experiment's results dict (RF_SNR) under a name built from
# exp_type and n_test.
# NOTE(review): '%dk' % (n_test/1000) truncates, so any n_test < 1000 is
# written as "_0k" — confirm that naming is intended before relying on it.
with open('crowding_experiment/PSNR_%s_%dk.pkl' % (exp_type.lower(), n_test/1000), 'wb') as f:
    pickle.dump(RF_SNR, f)

**Get Heatmaps**

In [None]:
exp_type = 'Heatmaps'
RF_SNR = {}
n_test = 100
batch_size = 1
# foveated lattice parameters (offset is reused from the RF-layer setup cell)
img_shape = torch.Size((53, 53))
RF_rate = 0.2
gap = 0.
n_rings = 10
std = 1.
lattice_fn = lattice.init_foveated_lattice
# jittered isolated targets
target_set = create_crowd_set(testset, n_test, 118, n_flankers=0, axis=0, spacing=0,
                              base_set=base_set, label_map=label_map,
                              transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
target_loader = torch.utils.data.DataLoader(target_set, batch_size=batch_size,
                                            shuffle=False, num_workers=2)
# SNR heatmaps: every spacing without attention; attention only at spacing 1
for key in ['outer', 'inner', 'radial', 'tangential']:
    n_flankers, axis = get_crowd_params(key)
    SNR = []
    for extent in [None, 7., 27.]:
        for spacing in [1., 2.]:
            if extent is None or spacing <= 1.:
                clear_output(wait=True)
                display('%s: spacing %a, extent %a' % (key, spacing, extent))
                crowd_set = create_crowd_set(testset, n_test, 118, n_flankers, axis, spacing, base_set, label_map,
                                             transform=transforms.Compose([transforms.ToTensor(), jitter_fn(5, 5, seed=0)]))
                crowd_loader = torch.utils.data.DataLoader(crowd_set, batch_size=batch_size,
                                                           shuffle=False, num_workers=2)
                # fixed (unrotated) lattice, rebuilt per call
                lattice_kwargs = {'img_shape': img_shape, 'scale': RF_rate, 'spacing': gap,
                                  'std': std, 'n_rings': n_rings, 'offset': offset,
                                  'rotate': 0.}
                SNR.append(get_heatmaps(target_loader, crowd_loader, torch.tensor(peak),
                                        batch_size=batch_size, model=model, extent=extent,
                                        lattice_fn=lattice_fn, lattice_kwargs=lattice_kwargs))
    RF_SNR.update({key + '_hm': SNR, 'spacing': [1., 2.], 'extent': [7., 27.]})

In [None]:
# Save this experiment's results dict (RF_SNR) under a name built from
# exp_type and n_test.
# NOTE(review): '%dk' % (n_test/1000) truncates, so any n_test < 1000 is
# written as "_0k" — confirm that naming is intended before relying on it.
with open('crowding_experiment/PSNR_%s_%dk.pkl' % (exp_type.lower(), n_test/1000), 'wb') as f:
    pickle.dump(RF_SNR, f)