# SNPE & RF

Running this notebook requires the following repositories to be installed locally and checked out on the specified branches:
- delfi 
- lfimodels / maprf_elife 
- maprf_mn  / elife      


In [None]:
%%capture
%matplotlib inline

import os

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import setup_sim, get_data_o

# Root folder that holds the results/ tree read by this notebook.
# Set the MAPRF_ROOT environment variable to point at a local copy of the
# materials; if it is unset, the original author's path is used unchanged.
root_path = os.environ.get(
    'MAPRF_ROOT',
    '/home/marcel/Dropbox (mackelab)/team/Write/Working_Manuscripts/InferenceNeuralDynamics_2018/materials/fig2/mapRF')

In [None]:
## training data and true parameters, data, statistics

seed = 42
idx_cell = 6 # index of the toy cell to load

# The .npy files appear to hold pickled dicts wrapped in 0-d object arrays
# (hence the [()] unwrapping). NumPy >= 1.16.3 refuses to unpickle unless
# allow_pickle=True is passed. Only load files from a trusted source:
# unpickling can execute arbitrary code.
tmp = np.load(root_path + '/results/SNPE/toycell_' + str(idx_cell) + '/ground_truth_data.npy',
              allow_pickle=True)[()]
obs_stats, pars_true, rf = tmp['obs_stats'],  tmp['pars_true'], tmp['rf']

sim_info = np.load(root_path +'/results/sim_info.npy', allow_pickle=True)[()]
d, params_ls = sim_info['d'], sim_info['params_ls']

# the cell we want to work with should have this number of spikes
# (the spike count is the last summary statistic)
assert obs_stats[0,-1] == 307, 'unexpected spike count: wrong toy cell loaded?'

# parameter names in the column order used throughout the notebook
labels_params = ['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']
# names of the subset shown in the partial-posterior panel
labels_params_select = ['bias', 'gain', 'phase', 'angle']

# load SNPE results

In [None]:
# Stored SNPE fit: a pickled dict with the posterior, the proposal and the prior.
# allow_pickle=True is required on NumPy >= 1.16.3; only load trusted files,
# since unpickling can execute arbitrary code.
tmp = np.load(root_path + '/results/SNPE/toycell_6/maprf_100k_elife_prior01_run_9_round4_param9_nosvi_CDELFI_posterior.npy',
              allow_pickle=True)[()]
posterior, proposal, prior = tmp['posterior'], tmp['proposal'], tmp['prior']

# Per-parameter transform settings shared by the prior and the posterior.
# The bounds appear to follow the labels_params order (phase up to pi,
# angle up to 2*pi, xo/yo in [-1, 1]) -- the meaning of the flag codes is
# defined by delfi; TODO confirm against the TransformedNormal implementation.
transform_kwargs = dict(flags=[0,0,2,1,2,1,1,2,2],
                        lower=[0,0,0,0,0,0,0,-1,-1],
                        upper=[0,0,np.pi,0,2*np.pi,0,0,1,1])

# prior rebuilt as a transformed normal for plotting
plot_prior = dd.TransformedNormal(m=prior.m, S=prior.S, **transform_kwargs)

# posterior mixture rebuilt component by component with the same transforms
components = [posterior.xs[i] for i in range(posterior.n_components)]
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms=[comp.m for comp in components],
                            Ss=[comp.S for comp in components],
                            a=posterior.a,
                            **transform_kwargs)

# axis limits (lower, upper) per parameter for the full-posterior plot
lims_post = np.array([[-1.5, -1.1, .001,         0,          .001, 0, 0, -.999, -.999], 
                 [ 1.5,  1.1, .999*np.pi, 2.5,   1.999*np.pi, 2, 4., .999,   .999]]).T

# load comparison against MCMC sampler

In [None]:
n_samples=1000000
path = root_path + '/results/MCMC/'
savefile = path + 'toycell_' + str(idx_cell) + '/maprf_MCMC_prior01_run_1_'+ str(n_samples)+'samples_param9_5min.npy'
# Pickled dict of MCMC results; allow_pickle=True is required on
# NumPy >= 1.16.3. Only load trusted files (unpickling can run arbitrary code).
tmp = np.load(savefile, allow_pickle=True)[()]

T, params_dict_true = tmp['T'], tmp['params_dict_true']

# column order of the sample matrix; every later index refers to this order
params_ls = ['bias', 'gain', 'phase', 'freq','angle','ratio','width', 'xo', 'yo']
# one column block per parameter: each T[key] is forced to a 2-d column layout
# before the blocks are stacked horizontally
samples = np.hstack([np.atleast_2d(T[key].T).T for key in params_ls])

def symmetrize_sample_modes(samples):
    """Distribute samples over the symmetric modes of the parameter space.

    Half of the rows are kept as they are; the other half is replaced by
    mirrored copies (angle shifted by pi, then gain negated together with
    phase -> pi - phase). Presumably these transformations leave the
    receptive field unchanged -- TODO confirm against the model.

    Parameters
    ----------
    samples : np.ndarray, shape (n_samples, 9)
        Columns ordered as bias, gain, phase, freq, angle, ratio, width,
        xo, yo. Phase must lie in [0, pi], angle in [0, 2*pi], and freq,
        ratio and width must be non-negative.

    Returns
    -------
    np.ndarray, shape (n_samples, 9)
        New array; the input is not modified.

    Raises
    ------
    AssertionError
        If the shape or any of the value ranges above is violated.
    """
    assert samples.ndim == 2 and samples.shape[1] == 9

    phase, angle = samples[:, 2], samples[:, 4]
    # assumes phase in [0, pi]; the comparison must be applied to the max
    # itself, not taken inside np.max (which would only test "any in range")
    assert np.min(phase) >= 0. and np.max(phase) <= np.pi
    # assumes angle in [0, 2*pi]
    assert np.min(angle) >= 0. and np.max(angle) <= 2*np.pi
    # assumes freq, ratio and width are non-negative
    assert np.all(np.min(samples[:, np.array([3, 5, 6])], axis=0) >= 0.)

    # Copy with the angle shifted by pi, wrapped back into [0, 2*pi].
    # Both masks come from the original angles, so no row is shifted twice;
    # rows with angle == pi exactly are left untouched.
    rotated = samples.copy()
    above = angle > np.pi
    below = angle < np.pi
    rotated[above, 4] -= np.pi
    rotated[below, 4] += np.pi
    # keeping every other row of the stack preserves the number of samples
    samples_all = np.vstack((samples, rotated))[::2, :]

    # Copy with the sign of the gain flipped and the phase mirrored around pi/2.
    flipped = samples_all.copy()
    flipped[:, 1] = -flipped[:, 1]
    flipped[:, 2] = np.pi - flipped[:, 2]
    samples_all = np.vstack((samples_all, flipped))[::2, :]

    return samples_all

# spread the MCMC samples over the symmetric modes
samples = symmetrize_sample_modes(samples)

# Ground-truth parameters in the same column order as params_ls / samples /
# lims_samples (bias, gain, phase, freq, angle, ratio, width, xo, yo).
# freq must come before angle: with the two swapped, the ground-truth marker
# drawn on the 'angle' axis would actually be the frequency value.
kernel_s = params_dict_true['kernel']['s']
kernel_l = params_dict_true['kernel']['l']
pars_raw = np.array([ params_dict_true['glm']['bias'],
                      kernel_s['gain'],
                      kernel_s['phase'],
                      kernel_s['freq'],
                      kernel_s['angle'],
                      kernel_s['ratio'],
                      kernel_s['width'],
                      kernel_l['xo'],
                      kernel_l['yo'] ])

# axis limits (lower, upper) per parameter for plots that overlay MCMC samples
lims_samples = np.array([[-1.5, -1.5, .001, 0,       .001, 0, 0, -0.5, -0.5], 
                 [ 1.5,  1.5, .999*np.pi, 3, 1.999*np.pi, 3, 5, 0.5, 0.5]]).T

## panel for summary statistics

In [None]:
# Spike-triggered average: every summary statistic except the last entry
# (the spike count), reshaped to a d x d image.
sta_image = obs_stats[0, :-1].reshape(d, d)

plt.figure(figsize=(6, 6))
plt.imshow(sta_image, cmap='gray', interpolation='None')
plt.title('STA')

# option to add contours of ground-truth RF
# NOTE(review): `g` and `lvls` are only created in the posterior-samples cell
# further down, so that cell has to be run first before enabling this.
add_gt = False
if add_gt:
    rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
    gt_levels = [lvls[0]*rfm.min(), lvls[1]*rfm.max()]
    plt.contour(rfm, levels=gt_levels, colors='r')

plt.show()

## panel for parameters

In [None]:
# Image of the ground-truth receptive field `rf`.
fig_rf = plt.figure(figsize=(6, 6))
ax_rf = fig_rf.gca()
ax_rf.imshow(rf, interpolation='None')
ax_rf.set_title('ground-truth receptive field')
plt.show()


## panel for (partial) posterior

In [None]:
# Restrict prior and posterior to the four parameters shown in this panel.
idx = np.array([0, 1, 2, 4])  # bias, gain, phase and angle (cf. labels_params_select)

# sub-vectors / sub-matrices of every mixture component for the selected parameters
post_means = [comp.m[idx] for comp in plot_post.xs]
post_covs = [comp.S[idx][:, idx] for comp in plot_post.xs]

plot_post_select = dd.mixture.MoTG(
    ms=post_means,
    Ss=post_covs,
    a=plot_post.a,
    flags=plot_post.flags[idx],
    lower=plot_post.lower[idx],
    upper=plot_post.upper[idx],
)

plot_prior_select = dd.TransformedNormal(
    m=plot_prior.m[idx],
    S=plot_prior.S[idx][:, idx],
    flags=plot_prior.flags[idx],
    lower=plot_prior.lower[idx],
    upper=plot_prior.upper[idx],
)

# posterior against prior, with the MCMC samples and the ground truth overlaid
fig, _ = plot_pdf(
    plot_post_select,
    pdf2=plot_prior_select,
    lims=lims_samples[idx],
    gt=pars_raw.reshape(-1)[idx],
    figsize=(9, 9),
    resolution=100,
    samples=samples[:, idx],
    labels_params=labels_params_select,
);

## panel for posterior samples 

In [None]:

# this snippet of code requires the mapRF repository (to instantiate g.model)
# NOTE: it rebinds prior, d, obs_stats, pars_true and rf, which were loaded
# from the stored results earlier in the notebook.
g, prior, d = setup_sim(seed, path=root_path)
filename = root_path + '/results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
obs_stats, pars_true = get_data_o(filename, g, seed)
rf = g.model.params_to_rf(pars_true)[0]

# contour levels as fractions of each RF's minimum / maximum, and number of posterior draws
lvls, n_draws = [0.5, 0.5], 10
plt.figure(figsize=(6,6))
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')
for i in range(n_draws):
    # one receptive field per posterior draw, outlined on top of the STA
    rfm = g.model.params_to_rf(posterior.gen().reshape(-1))[0]
    # Successive contour calls draw into the same axes by default;
    # plt.hold() was removed in Matplotlib 3.0 and must not be called.
    plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()])
plt.title('RF posterior draws')

# ground-truth receptive field in red
rfm = g.model.params_to_rf(pars_true.reshape(-1))[0]
plt.contour(rfm, levels=[lvls[0]*rfm.min(), lvls[1]*rfm.max()], colors='r')

plt.show()


# supplementary figure: full posterior

In [None]:
# Ground truth passed through plot_post._f before plotting
# (presumably the map into the plotted parameter space -- TODO confirm).
gt_full = plot_post._f(pars_true.reshape(1, -1)).reshape(-1)

fig, _ = plot_pdf(
    plot_post,
    pdf2=plot_prior,
    lims=lims_post,
    gt=gt_full,
    figsize=(16, 16),
    resolution=100,
    labels_params=labels_params,
)


In [None]:
# Full 9-parameter posterior against the prior, with the MCMC samples and the
# raw ground-truth parameters overlaid.
fig, _ = plot_pdf(
    plot_post,
    pdf2=plot_prior,
    lims=lims_samples,
    gt=pars_raw.reshape(-1),
    figsize=(16, 16),
    resolution=100,
    samples=samples,
    labels_params=labels_params,
);
