# inference for Gabor-GLM with ABC methods

learning receptive field parameters from inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters

- we run a classical likelihood-free inference algorithm (SMC-ABC) on the Gabor-GLM simulator
- like SNPE, SMC-ABC iteratively refines a posterior estimate across multiple rounds
- within each round, SMC-ABC runs a rejection-sampling scheme that rejects parameters based on a distance measure $d(x,x_o)$.


- the design of this distance measure $d$ can be tricky and requires good summary statistics.
- here we use a standard approach: the L2 norm of the error on (normalized) summary statistics, $d(x,x_o) = || (x-x_o) \ / \ \mathrm{std} ||_2$.
- this standard approach does not work well here, since the summary statistics in this case include a $1641$-dimensional spike-triggered average

In [None]:
%%capture
%matplotlib inline

import theano
theano.config.floatX='float64'

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
#from delfi.utils.viz import plot_pdf

from model.gabor_rf import maprf as model
from model.gabor_stats import maprfStats
from utils import setup_sim, setup_sampler, quick_plot, contour_draws

import sys; sys.path.append('../')
from common import col, svg, plot_pdf, samples_nd

from support_files.run_abc import run_smc

# parameters for this experiment

In [None]:
# seed for generation of xo for selected cell. MCMC currently not seeded !
seed = 42

# load toy cell number i = idx_cell
idx_cell = 6

# maximum number of simulations and number of particles handed to run_smc
maxsim = int(1e6)
n_particles = int(1e3)

# path (without extension) under which the SMC results are stored
savefile = 'results/SMC/toycell_{}/maprf_SMC_prior01_run_1_{}particles_param9'.format(
    idx_cell, n_particles)

# load cell, generate xo

In [None]:
# generator g, prior and side length d of the (d x d) STA image
g, prior, d = setup_sim(seed, path='')

# ground-truth parameter dictionary of the selected toy cell
filename = 'results/toy_cells/toy_cell_{}.npy'.format(idx_cell)
params_dict_true = np.load(filename, allow_pickle=True)[()]
params_dict_true['kernel']['t'] = {'value' : 1.}

# hand the ground-truth parameters to the model and seed its RNG
m = g.model
m.params_dict = params_dict_true.copy()
m.rng = np.random.RandomState(seed=seed)

# read the parameter vector first, then simulate the observation xo
pars_true = m.read_params_buffer()
obs = m.gen_single()
obs_stats = g.summary.calc([obs])

# plot observed STA (all but the last summary statistic) next to the
# ground-truth receptive field
rf = g.model.params_to_rf(pars_true)[0]
sta_obs = obs_stats[0, :-1].reshape(d, d)
plt.imshow(np.hstack((sta_obs, rf)), interpolation='None')
plt.show()

# the last summary statistic is the spike count
print('spike count', obs_stats[0,-1])

# define distance function (based on pilot runs)

In [None]:
# pilot run: 10000 draws from the generator (parameters gts, summary stats pilots)
gts, pilots = g.gen(10000)
n_stats = pilots.shape[1]

# only the last statistic (spike count) is centred; all others keep mean 0
stats_mean = np.zeros((1, n_stats))
stats_mean[0, -1] = pilots[:, -1].mean()

# all STA pixels share one global std over the pilot runs
stats_std = np.ones((1, n_stats))
stats_std[0, :-1] = pilots[:, :-1].std()
# firing rate re-normalized to contribute ~20% of total loss
# (counts 1/4 as much as the d**2 STA pixels)
stats_std[0, -1] = 4 / (d**2) * pilots[:, -1].std()

class normed_summary():
    """Summary statistics of g.summary, shifted by stats_mean and scaled by stats_std."""

    def calc(self, y):
        """Return (g.summary.calc(y) - stats_mean) / stats_std."""
        raw_stats = g.summary.calc(y)
        centred = raw_stats - stats_mean
        return centred / stats_std

# observed summary statistics, shifted and scaled with the pilot-run mean/std
obs_statz = (obs_stats.reshape(-1) - stats_mean) / stats_std

## simulations from ground-truth parameter
- to better understand distances between summary statistics, we simulate repeatedly from the ground-truth parameters $\theta^*$ that originally generated $x_o$.
- these distances $d(x,x_o)$ for $x \sim p(x|\theta^*)$ will typically be smaller than those for $x \sim p(x|\theta)$ for prior-drawn $\theta\sim p(\theta)$, and even for $\theta \sim p(\theta|x_o)$.

In [None]:
# 1000 repeated simulations from the ground-truth parameters
n_rep = 1000
y_true = g.model.gen([pars_true] * n_rep)

# summary statistics of every simulation, stacked row-wise
x_true = [g.summary.calc(y) for y in y_true]
stats_true = np.vstack(x_true)

# unnormalized L2 distance
- no division by stats_std in calculation of distance
- distance in this case dominated by the spike counts

In [None]:
# compute distances over pilot runs
def calc_dist(stats_1, stats_2):
    """Euclidean distance between two sets of summary statistics."""
    diff = stats_1 - stats_2
    return np.sqrt(np.sum(np.square(diff)))

# histogram bins shared by both distance distributions
dist_bins = np.linspace(0, 500, 100)

# flatten() takes an order string ('C', 'F', ...); the former flatten(0) is
# rejected by current numpy, so the default C order is used explicitly
xo_flat = obs_stats.flatten()

# distances between xo and simulations from the ground-truth parameters
dists = np.array([calc_dist(x_sim, xo_flat) for x_sim in stats_true])

print(r'x from ground-truth $\theta^*$')
print('minimal distance: ', np.min(dists))

# show distance histogram (use to pick initial epsilon, e.g. roughly as median distance)
# `density=True` replaces the `normed` keyword that matplotlib 3.1 removed
plt.hist(dists, bins=dist_bins, density=True, label=r'x from ground-truth $\theta^*$')

# distances between xo and the pilot simulations (prior-drawn parameters)
dists = np.array([calc_dist(x_sim, xo_flat) for x_sim in pilots])
# show distance histogram (use to pick initial epsilon, e.g. roughly as median distance)
plt.hist(dists, bins=dist_bins, density=True, label=r'x from prior-drawn $\theta$')
plt.legend()
plt.xlabel('distance')
plt.ylabel('rel. frequency')
print(r'x from prior-drawn $\theta$')
print('minimal distance: ', np.min(dists))
plt.show()

# STA-only distances
- note that even for samples $x \sim p(x|\theta^*)$ from ground-truth parameters (blue), minimal distances $d(x,x_o)$ are clearly >>0.
- summary-stats space is so high-dimensional and noise-driven that the closest distances are not achieved by
  simulations $x$ from ground-truth parameters $\theta^*$, but by $x$ from parameters $\theta$ that yield next-to-zero firing rates (with $||x_o||_2$ as an effective lower bound)!

In [None]:
# compute distances over pilot runs
def calc_dist(stats_1, stats_2):
    """Euclidean distance between summary statistics, ignoring the last entry of each."""
    diff = stats_1[:-1] - stats_2[:-1]
    return np.sqrt(np.sum(np.square(diff)))

# histogram bins shared by both distance distributions
dist_bins = np.linspace(0, 10, 100)
xo_flat = obs_stats.flatten()

# distances between xo and simulations from the ground-truth parameters
dists = np.array([calc_dist(x_sim, xo_flat) for x_sim in stats_true])

# show distance histogram (use to pick initial epsilon, e.g. roughly as median distance)
# `density=True` replaces the `normed` keyword that matplotlib 3.1 removed
plt.hist(dists, bins=dist_bins, density=True, label=r'x from ground-truth $\theta^*$')

# distances between xo and the pilot simulations (prior-drawn parameters)
dists = np.array([calc_dist(x_sim, xo_flat) for x_sim in pilots])
plt.hist(dists, bins=dist_bins, density=True, label=r'x from prior-drawn $\theta$')

# L2 norm of the observed STA (all summary statistics except the last one),
# drawn as a vertical reference line
L2o = np.sqrt(np.sum(obs_stats[0,:-1]**2))
plt.plot([L2o, L2o], [0,5], 'g', label=r'$||x_o||_2$')

plt.legend()
plt.xlabel('distance')
plt.ylabel('rel. frequency')

plt.show()

print('L2 norm of observed STA:', L2o)

# normalized L2 distance
- a weighted L2 distance with 80% of loss coming from errors in STAs and 20% from errors in spike count
- note that 20% roughly corresponds to 2 out of 9 model parameters (gain and bias) determining the firing rate
- $x$ generated from ground-truth parameters now have lower distances $d(x,x_o)$ than those $x$ from prior-drawn $\theta$
- shape information (STAs) with 80% still dominates the overall loss

In [None]:
# compute distances over pilot runs

def calc_dist(stats_1, stats_2):
    """Euclidean distance between summary statistics, each difference divided by stats_std."""
    scaled_diff = (stats_1 - stats_2) / stats_std
    return np.sqrt(np.sum(np.square(scaled_diff)))

# histogram bins shared by both distance distributions
dist_bins = np.linspace(0, 150, 100)

# distances between xo and simulations from the ground-truth parameters
dists = np.array([calc_dist(x_sim, obs_stats) for x_sim in stats_true])

print(np.min(dists))
# show distance histogram (use to pick initial epsilon, e.g. roughly as median distance)
# `density=True` replaces the `normed` keyword that matplotlib 3.1 removed
plt.hist(dists, bins=dist_bins, density=True, label=r'x from ground-truth $\theta^*$')

# distances between xo and the pilot simulations (prior-drawn parameters);
# these are the `dists` that the following cells sort and take the median of
dists = np.array([calc_dist(x_sim, obs_stats) for x_sim in pilots])
plt.hist(dists, bins=dist_bins, density=True, label=r'x from prior-drawn $\theta$')

plt.legend()
plt.show()

# visualize 10 closest summary stats to xo under normalized L2 distance
- due to the noise on each pixel of the high-dimensional STAs, almost no information on the actual RF shape can be extracted from a pixel-based L2 loss ...
- closest $x$ to $x_o$ mostly defined through their spike count (307 spikes in total for $x_o$)

In [None]:
# output file for the figure of the 10 closest pilot summary statistics (panel a)
PANEL_A = 'svg/SMC_ABC__loss_10samples.svg'

In [None]:
# contour levels as fractions of the RF minimum and maximum
lvls = [0.2, 0.2]

import matplotlib as mpl
with mpl.rc_context(fname='../.matplotlibrc'):

    plt.figure(figsize=(10, 4.2))

    # pilot indices sorted by increasing distance in `dists`
    order = np.argsort(dists)

    for i in range(10):
        idx = order[i]
        plt.subplot(2, 5, i + 1)

        # normalized STA of the i-th closest pilot simulation
        stats_z = (pilots[idx, :] - stats_mean) / stats_std
        sta_img = stats_z[0, :-1].reshape(d, d)
        plt.imshow(sta_img, interpolation='None', cmap='gray')
        plt.title('{} spikes'.format(int(pilots[idx, -1])), loc='right')

        # overlay contours of the receptive field that generated this simulation
        rfm = g.model.params_to_rf(gts[idx, :].reshape(-1))[0]
        contour_levels = [lvls[0] * rfm.min(), lvls[1] * rfm.max()]
        plt.contour(rfm, levels=contour_levels, colors=[col['SMC']])
        plt.axis('off')

    plt.subplots_adjust(wspace=0.2, hspace=0.1, left=0.1, bottom=0.12)

    plt.savefig(PANEL_A, transparent=True)
    plt.show()

# run ABC

### run SMC

In [None]:
import os

# create the output directory that `savefile` points into; this replaces the
# shell call `mkdir -p results/SMC/toycell_6/`, which ignored idx_cell
os.makedirs(os.path.dirname(savefile), exist_ok=True)

seed = 42 # SMC seed

# initial epsilon: the distance in `dists` that lies closest to the median distance
eps_init = dists[np.argmin( (dists-np.median(dists))**2 )]
print(eps_init)

all_ps, all_xs, all_logweights, all_eps, all_nsims = run_smc(model=g.model, prior=g.prior, summary=normed_summary(),
                                                     obs_stats=obs_statz,
                                                     seed=seed, fn=savefile,
                                                     n_particles=n_particles, eps_init=eps_init, maxsim=maxsim)

# load results

In [None]:
# load the results file that belongs to `savefile` (np.save appends '.npy'),
# so the path follows idx_cell / n_particles instead of being hard-coded.
# NOTE: allow_pickle=True unpickles arbitrary objects - only load trusted files.
res = np.load(savefile + '.npy', allow_pickle=True)[()]

# expose every stored entry as a notebook-level variable of the same name
globals().update(res)

# check results
- sampled posteriors on 'bias' and 'gain' (restricting spike count and firing rate) are much tighter relative to the priors
- marginals over remaining parameters (defining the shape and location of the RF) are very similar to the prior, i.e. we failed to learn much about $\theta$ from $x_o$.

In [None]:
fig_inches = (10,10)

# plotting range per parameter: g.prior.m +/- 5 * sqrt(diag(g.prior.S))
# (presumably prior mean and covariance - confirm against the prior class)
prior_sd = np.sqrt(np.diag(g.prior.S))
lims_samples = np.vstack([g.prior.m - 5*prior_sd, g.prior.m + 5*prior_sd]).T

labels_params = ['bias', 'log gain', 'logit phase', 'log freq', 'logit angle',
                'log ratio', 'log width', 'logit $x_o$', 'logit $y_o$']

PANEL_B = 'svg/supp_smc_abc.svg'

with mpl.rc_context(fname='../.matplotlibrc'):
    # prior vs. particles of the last SMC round, with ground-truth parameters marked
    fig, axes = plot_pdf(
        g.prior,
        samples=all_ps[-1].T,
        levels=(0.9499, 0.95),
        gt=pars_true.flatten(),
        figsize=fig_inches,
        lims=lims_samples,
        labels_params=labels_params,
        col1=col['SMC'],
        col2=(0.55,0.,0.),
        col4=col['GT'],
    )

    # diagonal panels: x ticks only at the two range limits, no y ticks
    for i in range(g.prior.ndim):
        diag_ax = axes[i, i]
        diag_ax.set_xticks([lims_samples[i, 0], lims_samples[i, 1]])
        diag_ax.set_yticks([])

    sns.despine(offset=5, left=True)

    fig.savefig(PANEL_B, transparent=True)

## store results

In [None]:
# collect settings, normalization constants and SMC outputs in one dictionary
stored = dict(
    eps_init=eps_init,
    obs_statz=obs_statz,
    obs_stats=obs_stats,
    n_particles=n_particles,
    maxsim=maxsim,
    stats_mean=stats_mean,
    stats_std=stats_std,
    all_ps=all_ps,
    all_logweights=all_logweights,
    all_eps=all_eps,
    all_nsims=all_nsims,
    params_dict_true=params_dict_true,
)

# write the dictionary to `savefile`
np.save(savefile, stored)

## Compose figure

In [None]:
# explicit names instead of `from svgutils.compose import *`
# (add Grid to this import when enabling the commented-out Grid line below)
from svgutils.compose import Figure, Panel, SVG, Text

# > Inkscape pixel is 1/90 of an inch, other software usually uses 1/72.
# > http://www.inkscapeforum.com/viewtopic.php?f=6&t=5964
svg_scale = 1.25  # set this to 1.25 for Inkscape, 1.0 otherwise

# Panel letters in Arial, 12pt, weight 800
kwargs_text = {'size': '12pt', 'font': 'Arial', 'weight': '800'}

# not used in this cell
pxw = 720
pxh = 760

# input panels: a and b are written by earlier cells of this notebook
PANEL_A = 'svg/SMC_ABC__loss_10samples.svg'
PANEL_B = 'svg/supp_smc_abc.svg'
PANEL_C = 'fig/fig3_gabor_supp_smc_abc_comps.svg'

# stack the three panels vertically, each with its panel letter
f = Figure("20.3cm", "28.3cm",

    Panel(
          SVG(PANEL_A).scale(1.22).move(3, 5),
          Text("a", 0, 13, **kwargs_text),
    ).move(0, 0),

    Panel(
          SVG(PANEL_B).scale(svg_scale).move(15,10),
          Text("b", 0, 13, **kwargs_text),
    ).move(0, 310),

    Panel(
          SVG(PANEL_C).scale(svg_scale).move(15,15),
          Text("c", 0, 13, **kwargs_text),
    ).move(0, 730),

    #Grid(10, 10),
)

# write the composed figure and display it
f.save("fig/fig3_gabor_supp_smc_abc.svg")
svg('fig/fig3_gabor_supp_smc_abc.svg')