# SNPE & RF

learning receptive field parameters from inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters

# calibration testing, late-round SNPE-A and multimodality

- this notebook implements **simulation-based calibration** (SBC) for the SNPE-A fits to the mapRF application


- an important insight regarding SBC is that it does not test posteriors $p(\theta|x_o)$, but entire conditional densities $p(\theta |x)$
- for SNPE (A/B), this means we can do SBC very cheaply, as these methods return $p(\theta |x)$ 'amortized' if run from the prior $p(\theta)$


- another important insight is that SNPE (A/B) is also 'locally' amortized when run from proposal priors $\tilde{p}(\theta)$,
i.e. for all $x \sim \tilde{p}(x) = \int p(x|\theta) \tilde{p}(\theta) d\theta$
- for SNPE-A, this means we can cheaply test calibration for later rounds (not just the first round) by comparing the calibration of the **uncorrected** conditional density $\tilde{p}(\theta|x)$ (as directly returned by the MDN) against the proposal $\tilde{p}(\theta)$.
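- this doesn't work for SNPE-B, which on later rounds is valid only for $x\sim\tilde{p}(x)$ but returns the corrected $p(\theta|x)$ (rather than the uncorrected $\tilde{p}(\theta|x)$ needed for the comparison against $\tilde{p}(\theta)$)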


- below, we show results that suggest that SBC can be used to find previously undetected multimodality in the posteriors: 
- in cases where the conditional density for $x\sim \tilde{p}(x)$ is generally multimodal, but the MDN failed to capture that (see first- and second-round results), we find the distribution of the rank statistic to be multimodal
- this behavior disappears as soon as the MDN 'gets' that the posterior is in fact multimodal (see round #3 below)

In [None]:
%%capture
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

from lfimodels.maprf.utils import get_maprf_prior_01, setup_sim, setup_sampler, get_data_o, quick_plot, contour_draws

from lfimodels.maprf.maprf import maprf as model
from lfimodels.maprf.maprfStats import maprfStats

# Seed handed to setup_sim, get_data_o and the infer.CDELFI constructor in the cells below.
seed = 42

In [None]:
## training data and true parameters, data, statistics

idx_cell = 3 # index of the toy cell whose saved data file is loaded
filename = './results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'
# setup_sim returns the generator `g`, a prior and `d`; `d` is used further down
# as the spatial size of the network input (n_inputs = (1,d,d)).
g, prior, d = setup_sim(seed, path='.')
# Observed summary statistics and the true parameters for the selected cell.
obs_stats, pars_true = get_data_o(filename, g, seed)
# Receptive field implied by the true parameters (first element of the returned
# value); the trailing commented-out slice is a disabled crop.
rf = g.model.params_to_rf(pars_true)[0]#[10:31, 10:31]

plt.imshow(rf, interpolation='None')
plt.show()

# The last entry of the observed statistics is reported as the spike count.
print('spike count', obs_stats[0,-1])

# Parameter names, used as subplot titles / axis labels in the plots below.
labels_params=['bias', 'gain', 'phase', 'freq', 'angle', 'ratio', 'width', 'xo', 'yo']

In [None]:
class SBC(object):
    """Simulation-based calibration of a conditional density estimate.

    Parameters and data are drawn jointly from ``generator``; for every draw
    the uncorrected conditional predicted by ``inf`` is sampled and the rank
    of the true (test-function-transformed) parameters among those samples
    is recorded.

    Arguments
    ---------
    generator : object exposing ``gen(N)``; element 0 of its return value is
        used as the parameters, element 1 as the data/statistics
    inf : object exposing ``predict_uncorrected(x)``, whose result must in
        turn expose ``gen(L)``
    f : test function, applied both to a single parameter vector and to a
        batch of ``L`` conditional samples
    dim : output dimensionality of ``f``
    """

    def __init__(self, generator, inf, f, dim):
        self.generator, self.inf = generator, inf
        self.f, self.dim = f, dim

    def sample_full(self, N):
        """Draw N joint samples and return them as (parameters, data)."""
        # will sample from generator.proposal unless it's None
        drawn = self.generator.gen(N)
        params, stats = drawn[0], drawn[1]
        return params, stats

    def get_conditional(self, x):
        """Return the uncorrected conditional density predicted for data x."""
        return self.inf.predict_uncorrected(x)

    def test(self, N, L):
        """Compute rank statistics for N joint draws, L samples per conditional.

        Returns an array with one row per draw actually returned by the
        generator and ``dim`` columns; each entry counts how many of the L
        conditional samples are strictly larger than the true value.
        """
        params, stats = self.sample_full(N)
        # trust the number of rows actually returned rather than the request
        n_draws = params.shape[0]

        ranks = np.empty((n_draws, self.dim))

        for row in range(n_draws):
            truth = self.f(params[row, :]).reshape(1, -1)
            conditional = self.get_conditional(stats[row, :])

            samples = self.f(conditional.gen(L))
            assert samples.shape == (L, truth.size)

            exceeds = truth < samples
            ranks[row, :] = np.sum(exceeds, axis=0)

        return ranks
    

In [None]:
# some copy-pasted parameter settings for the used network and training regime
# (the architecture values are forwarded to infer.CDELFI in the loading cells below)

filter_sizes=[3,3,3,3,2]   # 5 conv ReLU layers
n_filters=(16,16,32,32,64) # 16 to 64 filters
pool_sizes=[1,2,2,2,2]     # five entries, presumably one pooling size per conv layer - TODO confirm
n_hiddens=[100,100,100]     # 3 fully connected layers

# NOTE(review): n_train, n_components and n_rounds are not referenced by the cells
# visible in this notebook; n_components is read from the saved network spec instead.
n_train=100000
n_components=8
n_rounds=1
n_inputs_hidden = 1

# NOTE(review): lr_decay, epochs and minibatch are likewise not used in the visible
# cells (no training happens here; networks are loaded from disk).
lr_decay = 0.99
epochs=20
minibatch=50

# Remaining keyword arguments forwarded unchanged to infer.CDELFI.
svi=False          
reg_lambda=0.      
pilot_samples=1000 
prior_norm = True   
init_norm = False  
rank = None   

run_id = 7 # seventh iteration of network/settings for eLife application; used in the file names of rounds 2 and 3
run_id_first = 6 # seventh iteration branches from sixth only after first round, so the round-1 network is loaded from run 6


# first round

- first-round SNPE-A fit: single-component MoG fit to all $x\sim p(x)$ resulting from the (broad) prior
- very hard problem (prior big $\rightarrow$ $p(x)$ broad $\rightarrow$ many different $x$ to map onto their $\theta$)


- we compare the calibration of $p(\theta|x)$ (as returned by the MDN on round #1) against the prior $p(\theta)$

In [None]:
# Rebuild the round-1 inference object and load its saved network parameters.
round_ = 1

g, _, _ = setup_sim(seed, path='.') # reinit generator seed 
_,_ = get_data_o(filename, g, seed) # for pilot runs
# The round-1 network was saved under run `run_id_first`.
filename4 = './results/SNPE/maprf_100k_elife_prior01_run_'+str(run_id_first)+'_round' + str(round_) + '_param9_nosvi_base_net_only.pkl'
tmp = io.load_pkl(filename4)
# The number of mixture components is taken from the saved network spec, not from
# the `n_components` setting; all other arguments come from the settings cell.
# NOTE(review): CDELFI appears to be the SNPE-A inference class the text refers to - confirm.
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
# Overwrite the freshly initialised network parameters with the saved ones.
inf.network.params_dict = tmp['network.params_dict']
inf.round = round_
# Uncorrected conditional density evaluated at the observed statistics.
posterior=inf.predict_uncorrected(obs_stats)


In [None]:
# Wrap the prior and the round-1 estimate as transformed distributions for plotting.
# flags / lower / upper carry one entry per parameter; the non-zero upper bounds
# (pi and 2*pi) sit at the positions of 'phase' and 'angle' in labels_params.
# NOTE(review): presumably the flag values select a per-parameter transform
# (0 = none, 1 and 2 = constrained) - confirm against delfi's TransformedNormal.
plot_prior = dd.TransformedNormal(m=g.prior.m, S = g.prior.S,
                            flags=[0,0,2,1,2,1,1,0,0],
                            lower=[0,0,0,0,0,0,0,0,0], upper=[0,0,np.pi,0,2*np.pi,0,0,0,0]) 

# Mixture built from the component means, covariances and weights of `posterior`,
# using the same flags and bounds as the prior above.
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms= [posterior.xs[i].m for i in range(posterior.n_components)], 
                            Ss =[posterior.xs[i].S for i in range(posterior.n_components)],
                            a = posterior.a, 
                            flags=[0,0,2,1,2,1,1,0,0],
                            lower=[0,0,0,0,0,0,0,0,0], upper=[0,0,np.pi,0,2*np.pi,0,0,0,0]) 


# Plot limits: written as (lower row, upper row) and transposed to one
# (lower, upper) pair per parameter.
lims = np.array([[-2, -2,    0, 0,       0, 0, 0, -1.5, -1], 
                 [ 2,  2,np.pi, 3, 2*np.pi, 3, 3, 1.5,   1]]).T

# Marginals of the estimate against the prior; the ground truth marker is the true
# parameter vector passed through plot_post._f (presumably the forward transform
# into the plotted parameter space - TODO confirm).
fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=plot_post._f(pars_true.reshape(1,-1)).reshape(-1), 
                  figsize=(16,16), resolution=100, diag_only=True, diag_only_cols=3, diag_only_rows=3,
                  labels_params=labels_params)


In [None]:
def f(x):
    """Identity test function: rank statistics are computed on the raw parameters."""
    return x

# Calibration test for the inference object and generator prepared above.
sbc = SBC(generator=g, inf=inf, f=f, dim=len(labels_params))

N = 10000     # number of joint (parameter, data) draws requested from the generator
L = N//100    # number of samples drawn from each conditional density
res = sbc.test(N, L)

# Rank statistics are integers in 0..L, so use L+1 unit-width bins centred on them.
bins = np.arange(L + 2) - .5

plt.figure(figsize=(16,16))
for i, label in enumerate(labels_params):
    plt.subplot(3,3,i+1)
    # `density=True` replaces the `normed=True` keyword, which newer Matplotlib
    # releases no longer accept.
    plt.hist(res[:,i], color='r', density=True, bins=bins)
    plt.title(label)
plt.show()

# second round

- second-round SNPE-A fit: single-component MoG fit to all $x\sim \tilde{p}(x)$ resulting from the single-component proposal
- easier problem (proposal narrow $\rightarrow$ $\tilde{p}(x)$ narrower $\rightarrow$ fewer different $x$ to map onto their $\theta$)


- we compare the calibration of  **uncorrected** $\tilde{p}(\theta|x)$ (returned by the MDN on round #2) against the **proposal prior** $\tilde{p}(\theta)$ returned from round #1

In [None]:
# The proposals were not saved separately, so the first-round network has to be
# loaded first in order to recover the proposal used for round 2.

round_ = 1
g, _, _ = setup_sim(seed, path='.') # reinit generator seed 
_,_ = get_data_o(filename, g, seed) # for pilot runs
# The round-1 network lives under run `run_id_first`.
filename4 = './results/SNPE/maprf_100k_elife_prior01_run_'+str(run_id_first)+'_round' + str(round_) + '_param9_nosvi_base_net_only.pkl'
tmp = io.load_pkl(filename4)
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
inf.network.params_dict = tmp['network.params_dict']

# The round-1 estimate at the observed statistics serves as the round-2 proposal.
# NOTE(review): unlike the round-1 cell, inf.round is not set on this temporary
# object before predict_uncorrected - confirm the prediction does not depend on it.
proposal = inf.predict_uncorrected(obs_stats)


# Now rebuild generator and inference object for round 2 (saved under `run_id`).
round_ = 2
g, _, _ = setup_sim(seed, path='.') # reinit generator seed 
_,_ = get_data_o(filename, g, seed) # for pilot runs
filename4 = './results/SNPE/maprf_100k_elife_prior01_run_'+str(run_id)+'_round' + str(round_) + '_param9_nosvi_base_net_only.pkl'
tmp = io.load_pkl(filename4)
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
inf.network.params_dict = tmp['network.params_dict']
# Attach the proposal to the generator so later sampling draws from it
# (the SBC class notes that the generator samples from its proposal unless it is None).
inf.generator.proposal = proposal

# NOTE(review): here inf.round is assigned after the prediction, whereas the
# round-1 cell assigns it before - confirm the order is irrelevant.
posterior=inf.predict_uncorrected(obs_stats)
inf.round = round_

## **uncorrected** posterior $\tilde{p}(\theta | x_o)$ in round 2

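- notice the estimated marginal to the gain is unimodal (this round uses a single-component fit), i.e. the MDN does not yet capture the second mode for the gain that appears in round #3 below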

In [None]:
# Wrap the round-2 estimate as a transformed mixture for plotting, built from the
# component means, covariances and weights of `posterior`.
# flags / lower / upper are the same per-parameter settings used for `plot_prior`.
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms= [posterior.xs[i].m for i in range(posterior.n_components)], 
                            Ss =[posterior.xs[i].S for i in range(posterior.n_components)],
                            a = posterior.a, 
                            flags=[0,0,2,1,2,1,1,0,0],
                            lower=[0,0,0,0,0,0,0,0,0], upper=[0,0,np.pi,0,2*np.pi,0,0,0,0]) 

# Marginals of the estimate against `plot_prior` (defined in the round-1 plotting
# cell); the ground truth marker is the true parameter vector passed through
# plot_post._f (presumably the forward transform - TODO confirm).
fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=plot_post._f(pars_true.reshape(1,-1)).reshape(-1), 
                  figsize=(16,16), resolution=100, diag_only=True, diag_only_cols=3, diag_only_rows=3,
                  labels_params=labels_params)


In [None]:
def f(x):
    """Identity test function: rank statistics are computed on the raw parameters."""
    return x

# Calibration test for the round-2 inference object; the generator now carries the
# proposal attached in the loading cell above.
sbc = SBC(generator=g, inf=inf, f=f, dim=len(labels_params))

N = 10000     # number of joint (parameter, data) draws requested from the generator
L = N//100    # number of samples drawn from each conditional density
res = sbc.test(N, L)

# Rank statistics are integers in 0..L, so use L+1 unit-width bins centred on them.
bins = np.arange(L + 2) - .5

plt.figure(figsize=(16,16))
for i, label in enumerate(labels_params):
    plt.subplot(3,3,i+1)
    # `density=True` replaces the `normed=True` keyword, which newer Matplotlib
    # releases no longer accept.
    plt.hist(res[:,i], color='r', density=True, bins=bins)
    plt.title(label)
plt.show()

# third round


- **last**-round SNPE-A fit: **four**-component MoG fit to all $x\sim \tilde{p}(x)$ resulting from a single-component proposal 



- we compare the calibration of **uncorrected** $\tilde{p}(\theta|x)$ (returned by the MDN on round #3) against the **proposal prior** $\tilde{p}(\theta)$ returned from round #2

In [None]:
# The proposals were not saved separately, so the second-round network has to be
# loaded first in order to recover the proposal used for round 3.

round_ = 2
g, _, _ = setup_sim(seed, path='.') # reinit generator seed 
_,_ = get_data_o(filename, g, seed) # for pilot runs
filename4 = './results/SNPE/maprf_100k_elife_prior01_run_'+str(run_id)+'_round' + str(round_) + '_param9_nosvi_base_net_only.pkl'
tmp = io.load_pkl(filename4)
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
inf.network.params_dict = tmp['network.params_dict']

# The round-2 uncorrected estimate at the observed statistics serves as the
# round-3 proposal.
# NOTE(review): no proposal is attached to this temporary generator and inf.round
# is not set before predict_uncorrected - confirm neither affects the prediction.
proposal = inf.predict_uncorrected(obs_stats)


# Now rebuild generator and inference object for round 3.
round_ = 3
g, _, _ = setup_sim(seed, path='.') # reinit generator seed 
_,_ = get_data_o(filename, g, seed) # for pilot runs
filename4 = './results/SNPE/maprf_100k_elife_prior01_run_'+str(run_id)+'_round' + str(round_) + '_param9_nosvi_base_net_only.pkl'
tmp = io.load_pkl(filename4)
inf = infer.CDELFI(generator=g, obs=obs_stats, prior_norm=prior_norm, init_norm=init_norm,
                 pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda, svi=svi,
                 n_components=tmp['network.spec_dict']['n_components'], rank=rank,
                 n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d),
                 filter_sizes=filter_sizes, pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,verbose=True)
inf.network.params_dict = tmp['network.params_dict']
# Attach the proposal to the generator so later sampling draws from it
# (the SBC class notes that the generator samples from its proposal unless it is None).
inf.generator.proposal = proposal

# NOTE(review): inf.round is assigned after the prediction here, whereas the
# round-1 cell assigns it before - confirm the order is irrelevant.
posterior=inf.predict_uncorrected(obs_stats)
inf.round = round_

## **uncorrected** posterior $\tilde{p}(\theta | x_o)$ in round #3

- notice the skewed weighting between the two modes for the gain - they're actually almost equally large after the analytical correction for the proposal, which in this case was closer to the right mode (see above for the uncorrected posterior on round #2)

In [None]:
# Wrap the round-3 estimate as a transformed mixture for plotting, built from the
# component means, covariances and weights of `posterior`.
# flags / lower / upper are the same per-parameter settings used for `plot_prior`.
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
                            ms= [posterior.xs[i].m for i in range(posterior.n_components)], 
                            Ss =[posterior.xs[i].S for i in range(posterior.n_components)],
                            a = posterior.a, 
                            flags=[0,0,2,1,2,1,1,0,0],
                            lower=[0,0,0,0,0,0,0,0,0], upper=[0,0,np.pi,0,2*np.pi,0,0,0,0]) 

# Marginals of the estimate against `plot_prior` (defined in the round-1 plotting
# cell); the ground truth marker is the true parameter vector passed through
# plot_post._f (presumably the forward transform - TODO confirm).
fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=plot_post._f(pars_true.reshape(1,-1)).reshape(-1), 
                  figsize=(16,16), resolution=100, diag_only=True, diag_only_cols=3, diag_only_rows=3,
                  labels_params=labels_params)

In [None]:
def f(x):
    """Identity test function: rank statistics are computed on the raw parameters."""
    return x

# Calibration test for the round-3 inference object; the generator now carries the
# proposal attached in the loading cell above.
sbc = SBC(generator=g, inf=inf, f=f, dim=len(labels_params))

N = 10000     # number of joint (parameter, data) draws requested from the generator
L = N//100    # number of samples drawn from each conditional density
res = sbc.test(N, L)

# Rank statistics are integers in 0..L, so use L+1 unit-width bins centred on them.
bins = np.arange(L + 2) - .5

plt.figure(figsize=(16,16))
for i, label in enumerate(labels_params):
    plt.subplot(3,3,i+1)
    # `density=True` replaces the `normed=True` keyword, which newer Matplotlib
    # releases no longer accept.
    plt.hist(res[:,i], color='r', density=True, bins=bins)
    plt.title(label)
plt.show()

# calibrating calibration
- sanity check:
- let us set $p(\theta|x) = p(\theta) \ \forall x$ for now, which is trivially self-consistent with $p(\theta)$.
- likewise, setting $\tilde{p}(\theta|x) = \tilde{p}(\theta) \ \forall x$ is trivially self-consistent with the proposal $\tilde{p}(\theta)$.
- histograms should be beautifully uniform
- if they look too wild, we should use more samples (higher N) for the test

In [None]:
class SBC_naive(object):
    """Simulation-based calibration with a trivial, data-independent conditional.

    Instead of querying the inference object, the conditional density for
    every data point is the very distribution the generator samples its
    parameters from (its proposal if one is set, otherwise its prior).

    Arguments
    ---------
    generator : object exposing ``gen(N)`` plus ``prior`` and ``proposal``
        attributes; element 0 of ``gen``'s return value is used as the
        parameters, element 1 as the data/statistics
    inf : stored but not consulted by any method of this class
    f : test function, applied both to a single parameter vector and to a
        batch of ``L`` samples
    dim : output dimensionality of ``f``
    """

    def __init__(self, generator, inf, f, dim):
        self.generator, self.inf = generator, inf
        self.f, self.dim = f, dim

    def sample_full(self, N):
        """Draw N joint samples and return them as (parameters, data)."""
        # will sample from generator.proposal unless it's None
        drawn = self.generator.gen(N)
        params, stats = drawn[0], drawn[1]
        return params, stats

    def get_conditional(self, x):
        """Return the generator's proposal, or its prior when no proposal is set.

        The data argument ``x`` is ignored.
        """
        proposal = self.generator.proposal
        return self.generator.prior if proposal is None else proposal

    def test(self, N, L):
        """Compute rank statistics for N joint draws, L samples per conditional.

        Returns an array with one row per draw actually returned by the
        generator and ``dim`` columns; each entry counts how many of the L
        samples are strictly larger than the true value.
        """
        params, stats = self.sample_full(N)
        # trust the number of rows actually returned rather than the request
        n_draws = params.shape[0]

        ranks = np.empty((n_draws, self.dim))

        for row in range(n_draws):
            truth = self.f(params[row, :]).reshape(1, -1)
            conditional = self.get_conditional(stats[row, :])

            samples = self.f(conditional.gen(L))
            assert samples.shape == (L, truth.size)

            exceeds = truth < samples
            ranks[row, :] = np.sum(exceeds, axis=0)

        return ranks
    

In [None]:
def f(x):
    """Identity test function: rank statistics are computed on the raw parameters."""
    return x

# Sanity check: SBC_naive uses the generator's own proposal (or prior) as the
# conditional for every data point, so the rank histograms should come out uniform.
sbc = SBC_naive(generator=g, inf=inf, f=f, dim=len(labels_params))

N = 10000     # number of joint (parameter, data) draws requested from the generator
L = N//100    # number of samples drawn from each conditional density
res = sbc.test(N, L)

# Rank statistics are integers in 0..L, so use L+1 unit-width bins centred on them.
bins = np.arange(L + 2) - .5

plt.figure(figsize=(16,16))
for i, label in enumerate(labels_params):
    plt.subplot(3,3,i+1)
    # `density=True` replaces the `normed=True` keyword, which newer Matplotlib
    # releases no longer accept.
    plt.hist(res[:,i], color='r', density=True, bins=bins)
    # label the panels like the per-round calibration plots
    plt.title(label)
plt.show()