# inference for Gabor-GLM with SNPE

learning receptive field parameters from inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear neuron models with parameterized linear filters

- we fit a mixture-density network with convolutional layers to directly obtain posterior estimates from spike-triggered averages (STAs)
- two-stage fitting procedure: 
    - a first round identifies the rough region in parameter space by fitting a Gaussian posterior approximation
    - a second round identifies the exact posterior shape within that region by fitting an 8-component mixture of Gaussians. 
    
- this notebook imports a custom CDELFI which does custom init of components for second round

In [None]:
%%capture
%matplotlib inline

# Device selection for Theano: the flag is written to the environment here,
# before `import theano` further down in this cell.
use_gpu = True
if use_gpu:
    import os    
    os.environ['THEANO_FLAGS'] = "device=cuda0"

import theano
theano.config.floatX='float32'  # Theano float type set to single precision

import matplotlib.pyplot as plt
import numpy as np
import lasagne.nonlinearities as lnl
import dill as pickle  # note: `pickle` refers to dill in this notebook

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
from support_files.CDELFI import CDELFI  # project-local CDELFI variant
import delfi.utils.io as io
from delfi.utils.viz import plot_pdf

# project-local helpers and the Gabor receptive-field model / summary statistics
from utils import get_maprf_prior_01, setup_sim, setup_sampler, \
get_data_o, quick_plot, contour_draws
from model.gabor_rf import maprf as model
from model.gabor_stats import maprfStats

# seed handed to the simulator setup, numpy's RNG and the inference object in later cells
seed = 42

In [None]:

def plot_hist_marginals(data, weights=None, lims=None, gt=None, upper=False, rasterized=False):
    """
    Plots marginal histograms and pairwise scatter plots of a dataset.

    Parameters
    ----------
    data : array-like
        Samples. Either 1-D (a single variable) or 2-D with one row per
        sample and one column per variable.
    weights : array-like or None
        Optional per-sample weights. They weight the histograms and are used
        as colour values of the scatter points.
    lims : array-like or None
        Axis limits. For 2-D data either a single [low, high] pair that is
        applied to every variable, or one [low, high] pair per variable.
    gt : array-like or None
        Optional reference value(s), drawn in red on top of the plots.
    upper : bool
        If True the upper triangle of the subplot grid is filled, otherwise
        the lower triangle.
    rasterized : bool
        Forwarded to the matplotlib hist/scatter calls.

    Returns
    -------
    fig : matplotlib figure containing the plot(s)
    """

    data = np.asarray(data)
    # square-root rule for the number of bins; never fewer than one bin
    n_bins = max(1, int(np.sqrt(data.shape[0])))

    # `density=True` replaces the `normed=True` keyword, which was removed
    # from hist() in matplotlib 3.1.
    hist_kwargs = dict(weights=weights, bins=n_bins, density=True, rasterized=rasterized)

    if data.ndim == 1:

        fig, ax = plt.subplots(1, 1)
        ax.hist(data, **hist_kwargs)
        ax.set_ylim([0.0, ax.get_ylim()[1]])
        ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
        if lims is not None:
            ax.set_xlim(lims)
        if gt is not None:
            ax.vlines(gt, 0, ax.get_ylim()[1], color='r')

    else:

        n_dim = data.shape[1]
        fig = plt.figure()

        # unweighted samples are drawn in black, weighted ones are coloured by weight
        if weights is None:
            col = 'k'
            vmin, vmax = None, None
        else:
            col = weights
            vmin, vmax = 0., np.max(weights)

        if lims is not None:
            lims = np.asarray(lims)
            # a single [low, high] pair is repeated for every dimension
            lims = np.tile(lims, [n_dim, 1]) if lims.ndim == 1 else lims

        for i in range(n_dim):
            # columns j >= i fill the upper triangle, j <= i the lower one
            for j in range(i, n_dim) if upper else range(i + 1):

                ax = fig.add_subplot(n_dim, n_dim, i * n_dim + j + 1)

                if i == j:
                    # diagonal: marginal histogram of variable i
                    ax.hist(data[:, i], **hist_kwargs)
                    ax.set_ylim([0.0, ax.get_ylim()[1]])
                    ax.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
                    if i < n_dim - 1 and not upper:
                        ax.tick_params(axis='x', which='both', labelbottom=False)
                    if lims is not None:
                        ax.set_xlim(lims[i])
                    if gt is not None:
                        ax.vlines(gt[i], 0, ax.get_ylim()[1], color='r')

                else:
                    # off-diagonal: scatter of variable j (x-axis) against variable i (y-axis)
                    ax.scatter(data[:, j], data[:, i], c=col, s=3, marker='o', vmin=vmin, vmax=vmax,
                               cmap='binary', edgecolors='none', rasterized=rasterized)
                    if i < n_dim - 1:
                        ax.tick_params(axis='x', which='both', labelbottom=False)
                    if j > 0:
                        ax.tick_params(axis='y', which='both', labelleft=False)
                    if j == n_dim - 1:
                        ax.tick_params(axis='y', which='both', labelright=True)
                    if lims is not None:
                        ax.set_xlim(lims[j])
                        ax.set_ylim(lims[i])
                    if gt is not None:
                        ax.scatter(gt[j], gt[i], c='r', s=20, marker='o', edgecolors='none')

    return fig
    

In [None]:
# observation, models
reload_obs_stats = False

if reload_obs_stats:
    # reload the ground-truth data stored by a previous run of the else-branch
    gtd = np.load('results/SNPE/toycell_6/ground_truth_data.npy', allow_pickle=True)[()]
    obs_stats = gtd['obs_stats']
    # also restore the true parameters and RF: later cells use `pars_true`
    pars_true, rf = gtd['pars_true'], gtd['rf']

    sim_info = np.load('results/sim_info.npy', allow_pickle=True)[()]
    d, params_ls = sim_info['d'], sim_info['params_ls']
    p = get_maprf_prior_01(params_ls)
    prior = p[0]
    # NOTE(review): the generator is built without a model here, so cells that
    # call g.model.params_to_rf(...) will not work in reload mode.
    g = dg.Default(model=None, prior=prior, summary=None)
else:
    # result dirs (creates results/, results/SNPE/ and results/SNPE/toycell_6/)
    import os
    os.makedirs('results/SNPE/toycell_6/', exist_ok=True)

    # training data and true parameters, data, statistics
    idx_cell = 6 # load toy cell number 6 (cosine-shaped RF with 1Hz firing rate)
    filename = 'results/toy_cells/toy_cell_' + str(idx_cell) + '.npy'

    g, prior, d = setup_sim(seed, path='')
    obs_stats, pars_true = get_data_o(filename, g, seed)
    rf = g.model.params_to_rf(pars_true)[0]

    # plot ground-truth receptive field
    plt.imshow(rf, interpolation='None')
    plt.show()

    # obs_stats holds the summary statistics:
    # (STA, spike count (over 5 minutes simulation)); the count is obs_stats[0,-1]
    np.save('results/SNPE/toycell_6/ground_truth_data',
            {'obs_stats' : obs_stats, 'pars_true' : pars_true, 'rf' : rf})

    # visualize RFs defined by prior-drawn parameters theta
    contour_draws(g.prior, g, obs_stats, d=d)
    print(obs_stats)

In [None]:
# network architecture: 9 layer network [5x conv, 3x fully conn., 1x MoG]
# NOTE(review): the APT call in this notebook is configured with density='maf',
# so the "MoG" output layer named above may be outdated - confirm.

filter_sizes = [3, 3, 3, 3, 2]     # 5 conv ReLU layers
n_filters = (16, 16, 32, 32, 32)   # 16 to 32 filters
pool_sizes = [1, 2, 2, 2, 2]       # pooling layers
n_hiddens = [50, 50]               # 2 fully connected layers per MAF
actfun = lnl.rectify               # using ReLUs for fully connected layers


# N = 10k for first round
n_train = 10000
n_rounds = 2

# feature for CNN architectures: passing a value directly to the hidden layers (bypassing the conv layers).
# In this case, we pass the number of spikes (single number) directly, which allows normalizing the STAs
# and hence helps out the conv layers. Without that extra input, we couldn't recover the RF gain anymore!
n_inputs_hidden = 1

# some learning-schedule parameters
lr_decay = 0.999   # learning-rate decay over epochs
epochs = 500       # number of epochs
minibatch = 100    # minibatch-size for stochastic gradient descent

svi = False        # whether to regularize the network weights. Large N should make this do very little anyway
reg_lambda = 0.0   # regularization strength (not used if svi=False)

pilot_samples = 1000  # z-scoring only applies to extra inputs (here: firing rate) directly fed to fully connected layers

prior_norm = False  # normalizes prior scales to mean zero and unit variances.
                    # Helpful if parameters have vastly different scales.
init_norm = False   # normalizes network initialization. Not yet supported for conv- and ReLU-layers

rank = None   # no rank constraint on covariance matrices of posterior


# settings forwarded to the MAF density (n_mades, maf_actfun, maf_mode) when the inference object is built
n_mades = 5
act_fun = 'tanh'
mode = 'random'

rng = np.random
rng.seed(seed)

batch_norm = False
val_frac = 0.02

# cannot deal with incomplete minibatches right now: the validation-set size
# has to be a whole multiple of the minibatch size. The product is a float, so
# it is rounded before the modulo test to keep floating-point error (e.g.
# 10000 * 0.07 = 700.0000000000001) from failing the check spuriously.
n_val_exact = n_train * val_frac
n_val = int(round(n_val_exact))
assert np.isclose(n_val_exact, n_val) and n_val % minibatch == 0, \
    'n_train * val_frac must be an integer multiple of minibatch'


In [None]:
# show the last entry of the observed summary statistics
# (the spike count, according to the note in the data-loading cell)
obs_stats[0,-1]

In [None]:

# Build the APT inference object from the settings of the hyperparameter cell.
# `svi` and `n_bypass` are taken from the configured variables instead of
# repeating their values, so that editing the hyperparameter cell takes effect.
inf = infer.APT(
    generator=g,
    obs=obs_stats,
    prior_norm=prior_norm,          # prior normalization switch
    pilot_samples=pilot_samples,
    seed=seed,
    svi=svi,

    # fully connected and convolutional layer sizes
    n_hiddens=n_hiddens,
    n_filters=n_filters,

    # masked autoregressive flow as density estimator
    density='maf',
    n_mades=n_mades,
    maf_actfun=act_fun,
    maf_mode=mode,
    batch_norm=batch_norm,

    # inputs: a d x d image plus extra value(s) that bypass the conv layers
    n_inputs=d*d,
    input_shape=(1, d, d),
    n_bypass=n_inputs_hidden,
    filter_sizes=filter_sizes,
    pool_sizes=pool_sizes,
    actfun=actfun,
    verbose=True)

# cell output: dtype of network parameter number 1
inf.network.aps[1].dtype

In [None]:
# print parameter numbers per layer (just weights, not biases)
def get_shape(i):
    """Return the shape of the value stored in network parameter number `i`."""
    param = inf.network.aps[i]
    return param.get_value().shape

# every second parameter, starting at index 1 and stopping before 17
weight_shapes = [get_shape(idx) for idx in range(1, 17, 2)]
print(weight_shapes)
print([np.prod(shape) for shape in weight_shapes])

In [None]:
# run SNPE-C
# Trains for `n_rounds` rounds and returns the training log, the training
# data and the posterior estimates.
print('fitting model with SNPE-C')  # message previously misspelled as "SNPC-C"
log, trn_data, posteriors = inf.run(

                    n_train=n_train,
                    epochs=epochs,
                    proposal='atomic',

                    # one atom fewer than the minibatch size
                    n_atoms=minibatch - 1,
                    moo='resample',

                    n_rounds=n_rounds,
                    train_on_all=False,
                    minibatch=minibatch,
                    val_frac=val_frac,
                    silent_fail=False,
                    verbose=True,
                    print_each_epoch=True)


In [None]:
# Transformed version of the prior; only its `_f` mapping is used below.
# It is the same for every round, so it is built once outside the loop.
# presumably `_f` maps samples into the original parameter space - confirm
plot_prior = dd.TransformedNormal(m=g.prior.m, S = g.prior.S,
                            flags=[0,0,2,1,2,1,1,2,2],
                            lower=[0,0,0,0,0,0,0,-1,-1], upper=[0,0,np.pi,0,2*np.pi,0,0,1,1])

# one [low, high] axis range per parameter
marginal_lims = [[-1.5,1.5], [-1.1,1.1], [0,np.pi], [0, 2.5], [0,2*np.pi], [0,2], [0,4], [-1,1], [-1,1]]

# one figure per round; after the loop `posterior` refers to the last round's posterior
for posterior in posteriors:

    post_draws = posterior.gen(1000)
    post_draws_trans = plot_prior._f(post_draws)

    fig = plot_hist_marginals(data=post_draws_trans, weights=None,
                              lims=marginal_lims,
                              gt=None, upper=True, rasterized=False)
    fig.set_figwidth(16)
    fig.set_figheight(16)
    fig.show()
    
    

# plot posteriors in original space (back-transformed)
Fitting Gaussians on log-transformed (frequency, ratio, width) and logit-transformed (phase, angle, location) parameters gives log-normal and logit-normal marginals, respectively, on the original parameters. The 9-dimensional joint distribution of all parameters can be transformed analytically.

In [None]:
# per-parameter transformation settings, used for both prior and posterior
tf_flags = [0,0,2,1,2,1,1,2,2]
tf_lower = [0,0,0,0,0,0,0,-1,-1]
tf_upper = [0,0,np.pi,0,2*np.pi,0,0,1,1]

# each constructor receives its own copy of the setting lists
plot_prior = dd.TransformedNormal(
    m=g.prior.m,
    S=g.prior.S,
    flags=list(tf_flags),
    lower=list(tf_lower),
    upper=list(tf_upper))

# transformed mixture built from the means, covariances and weights of `posterior`
component_ids = range(posterior.n_components)
plot_post = dd.mixture.TransformedGaussianMixture.MoTG(
    ms=[posterior.xs[k].m for k in component_ids],
    Ss=[posterior.xs[k].S for k in component_ids],
    a=posterior.a,
    flags=list(tf_flags),
    lower=list(tf_lower),
    upper=list(tf_upper))

# plotting range per parameter: first row lower ends, second row upper ends,
# transposed to one [low, high] row per parameter
lims = np.array([
    [-2, -1.5, .001, 0, .001, 0, 0, -.999, -.999],
    [2, 1.5, .999*np.pi, 3, 1.999*np.pi, 3, 3, .999, .999],
]).T

# true parameters passed through the posterior's `_f` mapping, flattened to 1-D
gt_transformed = plot_post._f(pars_true.reshape(1,-1)).reshape(-1)

param_labels = ['bias', 'gain', 'logit phase', 'log freq', 'logit angle', 'log ratio', 'log width', 'xo', 'yo']

fig, _ = plot_pdf(plot_post, pdf2=plot_prior, lims=lims, gt=gt_transformed,
                  figsize=(16,16), resolution=100,
                  labels_params=param_labels)

## show contours of posterior draws

In [None]:
lvls = [0.5, 0.5]   # contour levels as fractions of the RF minimum and maximum
p = posterior
n_draws = 10


def _rf_contour(params, **contour_kwargs):
    """Draw the two contour lines of the RF defined by `params` and return that RF."""
    rf_img = g.model.params_to_rf(params.reshape(-1))[0]
    levels = [lvls[0]*rf_img.min(), lvls[1]*rf_img.max()]
    plt.contour(rf_img, levels=levels, **contour_kwargs)
    return rf_img


plt.figure(figsize=(6,6))
# background: observed statistics without the last entry, reshaped to a d x d image
plt.imshow(obs_stats[0,:-1].reshape(d,d), interpolation='None', cmap='gray')

# contours of RFs generated from posterior draws
for draw_idx in range(n_draws):
    rfm = _rf_contour(p.gen())
plt.title('RF posterior draws')

# contours of the RF for the true parameters, in red
rfm = _rf_contour(pars_true, colors='r')
plt.show()


# store final results

In [None]:
round_ = 2

# all result files share this path prefix
result_prefix = 'results/SNPE/toycell_6/maprf_100k_prior01_run_1_round' + str(round_) + '_param9_nosvi_CDELFI'
filename1 = result_prefix + '.pkl'
filename2 = result_prefix + '_res.pkl'       # defined but not written to in this cell
filename4 = result_prefix + '_net_only.pkl'

# Store what inf.run returned. The run cell binds these to `log`, `trn_data`
# and `posteriors`; the previously used names `log2`, `trn_data2`,
# `posteriors2` are not defined anywhere in this notebook.
io.save_pkl((log, trn_data, posteriors), filename1)

# store the network specification and parameters separately
net = inf.network
data = {'network.spec_dict' : net.spec_dict,
        'network.params_dict' : net.params_dict }
io.save_pkl(data, filename4)

In [None]:
# key results for figure 3 in paper
# Saves the final-round posterior together with the generator's proposal and
# prior. `posteriors` is the list returned by inf.run; the previously used
# name `posteriors2` is not defined anywhere in this notebook.
np.save('results/SNPE/toycell_6/maprf_100k_prior01_run_1_round' + str(round_) + '_param9_nosvi_CDELFI_posterior',
        {'posterior' : posteriors[-1],
         'proposal' : inf.generator.proposal,
         'prior' : inf.generator.prior})

In [None]:
# read the stored key results back in
round_ = 2
posterior_file = ('results/SNPE/toycell_6/maprf_100k_prior01_run_1_round'
                  + str(round_)
                  + '_param9_nosvi_CDELFI_posterior.npy')
# the saved dict comes back wrapped in a 0-d object array; .item() unwraps it
stored = np.load(posterior_file, allow_pickle=True)
p = stored.item()