# deep identity mapping

Learning receptive-field parameters from the inputs (white-noise videos) and outputs (spike trains) of linear-nonlinear (LN) neuron models with parameterized linear filters.

In [None]:
%%capture
%matplotlib inline
# notebook currently depends on code found only in feature_maprf-branch of lfi_models !

import delfi.neuralnet as dn
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import matplotlib.pyplot as plt
import numpy as np
import lfimodels.maprf.utils as utils

from lfimodels.maprf.maprf import maprf
from lfimodels.maprf.maprfStats import maprfStats
from delfi.utils.viz import plot_pdf

import lasagne.layers as ll
import theano
import theano.tensor as tt
import collections


In [None]:

# Set up the simulator, a Gaussian prior over its (transformed) parameters,
# the summary statistics, and one "observed" dataset generated from
# hand-picked ground-truth parameters.
seed = 42

## simulation model

d = 41                    # edge length of the (square) d x d receptive field
parametrization = 'gabor' # ['full', 'gaussian', 'gabor']

# Parameters to infer: the GLM bias, Gabor shape parameters of the spatial
# kernel ('s') and the values of the temporal kernel ('t').
params_ls = {'glm': ('bias',),
             'kernel': {'s' : ('phase', 'vec_f', 'ratio', 'width', ),
                        't' : ('value',)}}

# alternative parametrization using 'vec_A' in place of 'phase'
#params_ls = {'glm': ('bias',),
#             'kernel': {'s' : ('vec_A', 'vec_f', 'ratio', 'width', ),
#                        't' : ('value',)}}


# d x d spatial extent; the last axis (2) presumably counts temporal bins -- confirm against maprf
filter_shape = np.array((d,d,2))
m = maprf(filter_shape=filter_shape, 
          parametrization=parametrization,
          params_ls = params_ls,
          seed=seed, 
          dt = 0.025, 
          duration=1000 )


## prior over simulation parameters
# One Gaussian entry per inferred quantity; 'sigma' holds standard deviations
# (they become the diagonal of the factor L). Means are concatenated in key
# order, which must match the model's parameter-vector layout (assumed, not
# checked here).
prior = collections.OrderedDict()
if 'bias' in m.params_ls['glm']:
    prior['bias'] = {'mu' : np.array([0]), 'sigma' : np.array([1.]) }
#if 'vec_A' in m.params_ls['kernel']['s']:
#    prior['vec_A']  = {'mu' : np.zeros(2), 'sigma' : 1.0 * np.ones(2) }
if 'phase' in m.params_ls['kernel']['s']:
    prior['logit_φ']  = {'mu' : np.array([0]), 'sigma' : np.array([1.]) }    
if 'vec_f' in m.params_ls['kernel']['s']:
    prior['vec_f']  = {'mu' : np.zeros(2), 'sigma' : 1.0 * np.ones(2) }
if 'ratio' in m.params_ls['kernel']['s']:
    prior['log_γ']  = {'mu' : np.array([-0.098]), 'sigma' : np.array([0.256])}
if 'width' in m.params_ls['kernel']['s']:
    prior['log_b']  = {'mu' : np.array([ 0.955]), 'sigma' : np.array([0.236])}
if 'xo' in m.params_ls['kernel']['s']:
    prior['xo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}
if 'yo' in m.params_ls['kernel']['s']:
    prior['yo'] = {'mu' : np.array([0.]), 'sigma' : np.array([5/np.sqrt(.5)])}
# diagonal factor for all scalar/vector parameters added so far
L = np.diag(np.concatenate([prior[i]['sigma'] for i in list(prior.keys())]))
if 'value' in m.params_ls['kernel']['t']:
    # Temporal-kernel prior. Λ weights each time bin by s*exp(1-s) with
    # s = t/0.075, and D is a first-difference matrix. Using inv(D) Λ as the
    # factor makes the kt block of S = L.T L equal to Λ inv(D D^T) Λ = Σ.
    # (Σ itself is not used further in this cell.)
    ax_t = m._gen.axis_t
    Λ =  np.diag(ax_t / 0.075 * np.exp(1 - ax_t / 0.075))
    D = np.eye(ax_t.shape[0]) - np.eye(ax_t.shape[0], k=-1)
    F = np.dot(D, D.T)
    Σ = np.dot(Λ, np.linalg.inv(F).dot(Λ))
    prior['kt'] = {'mu': np.zeros_like(ax_t), 'sigma': np.linalg.inv(D).dot(Λ)}
    # append the kt factor as a second diagonal block
    L = np.block([[L, np.zeros((L.shape[0], ax_t.size))], 
                  [np.zeros((ax_t.size, L.shape[1])), prior['kt']['sigma']]])
mu  = np.concatenate([prior[i][ 'mu'  ] for i in prior.keys()])
p = dd.Gaussian(m=mu, S=L.T.dot(L), seed=seed)

## data summary statistics
# d*d summary features, one per receptive-field pixel (presumably the
# spike-triggered average -- confirm against maprfStats)

s = maprfStats(n_summary=d*d)

g = dg.Default(model=m, prior=p, summary=s)

## training data and true parameters, data, statistics
# Ground truth in the model's native units. read_params_buffer() then returns
# the matching parameter vector (presumably in the prior's transformed
# coordinates -- confirm), which is used as `gt` in the plots.

params_dict_true = {'glm': {'binsize': m.dt,
                            'bias': -0.5},
                    'kernel': {'s': {'angle': 0.7,
                                     'freq': .3,
                                     'gain': 2,
                                     'phase': np.pi/4,
                                     'ratio': 1.,
                                     'width': 2.5},
#                                    'xo': 0.,
#                                    'yo': 0.},
                               't': {'value': np.array([1.,0.])}}}
m.params_dict = params_dict_true
pars_true = m.read_params_buffer()

# one simulated "observation" and its summary statistics
obs = m.gen_single()
obs_stats = s.calc([obs])



In [None]:
# Quick inspection of the model's parameter bookkeeping.
m.params_dict

In [None]:
m.params_idx

In [None]:
m.params_ls

In [None]:
# Scratch check at the true phase pi/4: the first term is the logistic sigmoid
# of x. NOTE(review): the second term is log(1/(1-x)), not the logit
# log(x/(1-x)) -- confirm which was intended.
x = np.pi/4
1/(1+np.exp(-x)), np.log( 1 / (1 - x))

In [None]:
np.pi/4

In [None]:
# ground-truth parameter vector passed through the model's reparam_p2m
m.reparam_p2m(pars_true)

In [None]:
# Train delfi's SNPE on the summary statistics of simulated data, then compare
# prior, ground-truth and posterior receptive fields visually.
n_hiddens=(30,30)
n_filters=(16,16,16)

n_train=1000
epochs=100
minibatch=50
n_rounds=3
n_components=1

# n_inputs=(1,d,d): the summary vector is fed to the network as a single-channel d x d image
inf = infer.SNPE(generator=g, obs=obs_stats, prior_norm=True, pilot_samples=100, seed=seed, 
                 n_components=n_components, n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d))
logs = inf.run(n_train=n_train, epochs=epochs, minibatch=minibatch, n_rounds=n_rounds)

posterior = inf.predict(obs_stats)
# plot_pdf reads `.ndim`; copy it from the first mixture component
# (presumably the returned mixture does not set it itself -- confirm)
posterior.ndim = posterior.xs[0].ndim

# bunch of example prior draws
# NOTE(review): plt.title acts on the current axes, so the title only lands on the last subplot
plt.figure(figsize=(16,10))
for i in range(15):
    plt.subplot(3,5,i+1)
    plt.imshow(m.params_to_rf(p.gen().reshape(-1))[0], interpolation='None')
plt.title('RF prior draws')
plt.show()

plt.figure(figsize=(16,5))
plt.subplot(1,4,1)
plt.imshow(m.params_to_rf(p.m)[0], interpolation='None')
plt.title('prior mean RF')
plt.subplot(1,4,2)
plt.imshow(obs_stats.reshape(d,d), interpolation='None')
plt.title('data STA')
plt.subplot(1,4,3)
plt.imshow(m.params_to_rf(pars_true)[0], interpolation='None')
plt.title('ground-truth RF')
plt.subplot(1,4,4)
plt.imshow(m.params_to_rf(posterior.calc_mean_and_cov()[0])[0], interpolation='None')
plt.title('posterior mean RF')
plt.show()

# bunch of example posterior draws
plt.figure(figsize=(16,10))
for i in range(15):
    plt.subplot(3,5,i+1)
    plt.imshow(m.params_to_rf(posterior.gen().reshape(-1))[0], interpolation='None')
plt.title('RF posterior draws')
plt.show()

# all pairwise marginals of fitted posterior
plot_pdf(posterior, pdf2=p, lims=[-5,5], gt=pars_true.reshape(-1), figsize=(16,16));


In [None]:
# Posterior-to-prior ratio of marginal variances; values below 1 mean the
# posterior is narrower than the prior in that dimension.
np.diag(posterior.calc_mean_and_cov()[1]) / np.diag(p.S)

In [None]:
# Labels follow the prior's key order: bias, logit_φ, vec_f (2), log_γ, log_b, kt.
plot_pdf(posterior, pdf2=p, lims=[-5,5], gt=pars_true.reshape(-1), figsize=(16,16),
labels_params=['b','logit_phase','f_1','f_2','log ratio','log width']+['kt_'+str(i) for i in range(filter_shape[2])]);

In [None]:
# Per-round diagnostics. Select the simulations whose summary image is
# closest (squared Euclidean distance) to the observed one, show them, and
# compare a Gaussian fit to all parameters drawn in that round with the
# network posterior.
plt.figure(figsize=(6,4))
plt.imshow(obs_stats.reshape(d,d), interpolation='None')
plt.title('x0')
plt.show()

for r in range(n_rounds):
    print('round ', str(r))
    # presumably logs[1][r] = (parameters, simulated summary stats) of round r -- confirm against SNPE.run
    parameters =  logs[1][r][0]
    samples = logs[1][r][1]

    # keep the x_dists.size//20 closest simulations (about 5%; strict '<' excludes the threshold itself)
    x_dists = np.sum( (samples[:,0,:,:] - obs_stats.reshape(1, d,d))**2, axis=(1,2))
    thresh = np.sort(x_dists)[x_dists.size//20]
    idx_0 = np.where(x_dists < thresh)[0]
    print(len(idx_0))

    plt.hist(x_dists, bins=np.linspace(0, x_dists.max(), 50) )
    plt.plot([thresh, thresh], [0, 50], 'r')   # marker height 50 is arbitrary
    plt.show()


    # up to 25 of the accepted simulated summary images
    plt.figure(figsize=(12,8))
    for i in range(np.min((25, len(idx_0)))):
        plt.subplot(5,5,i+1)
        j = i# np.random.randint(len(idx_0))
        plt.imshow(samples[idx_0[j],0,:,:], interpolation='None')
    plt.show()

    # Gaussian moment fit to all parameters of this round (not only the accepted ones)
    m_,S_ = parameters.mean(axis=0), np.cov(parameters.T)
    #S_[0,0] = 1.
    posterior_sampled = dd.Gaussian(m=m_, S=S_)

    # all pairwise marginals of fitted posterior
    plot_pdf(posterior, lims=[-5,5], figsize=(16,16), samples=parameters[idx_0,:].T);

    print('posterior mean:', posterior.xs[0].m)
    print('sampled mean:', posterior_sampled.m)
    plt.plot(posterior.xs[0].m)
    plt.plot(posterior_sampled.m)
    plt.show()

In [None]:
# Round-trip check for the bounded phase transform: squash standard-normal
# draws into (0, pi/2) with a scaled logistic sigmoid, then map them back with
# the matching logit. `params` should reproduce `x` up to rounding error.
x = np.random.normal(size=10000)
half_pi = np.pi / 2.
squashed = half_pi / (1. + np.exp(-x))
params = np.log(squashed / (half_pi - squashed))
params.mean()/np.pi, params.std(), params.min()/np.pi, params.max()/np.pi


In [None]:
x.shape, params.shape

In [None]:
# Histogram of one parameter among the accepted simulations of the last
# diagnostics round (parameters/idx_0 keep their values from that loop).
# With the prior ordering (b, logit_φ, f_1, f_2, log_γ, log_b, kt_0, kt_1),
# index 7 is presumably kt_1 -- confirm against the temporal-bin count.
i = 7
plt.hist(parameters[idx_0,i])
plt.plot([pars_true[i], pars_true[i]], [0,10], 'r--')   # ground truth
plt.plot([posterior.xs[0].m[i], posterior.xs[0].m[i]], [0,10], 'g')   # first-component posterior mean
plt.show()

In [None]:
posterior.calc_mean_and_cov()[0]

In [None]:
p.m

In [None]:
pars_true

In [None]:
np.diag(p.S)

In [None]:
np.diag(posterior.calc_mean_and_cov()[1])

In [None]:
# compare with pairwise prior marginals
plot_pdf(posterior, pdf2=p, lims=[-2,2], gt=pars_true.reshape(-1), figsize=(16,16),
labels_params=['b','A_1','A_2','f_1','f_2','log ratio','log width']+['kt_'+str(i) for i in range(filter_shape[2])]);


In [None]:
import delfi.utils.colormaps as cmaps
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import time

def plot_2pdf(pdf1, lims, pdf2=None, gt=None, contours=False, levels=(0.68, 0.95),
             resolution=500, labels_params=None, ticks=False, diag_only=False,
             diag_only_cols=4, diag_only_rows=4, figsize=(5, 5), fontscale=1,
             partial=False, samples=None, col1='k', col2='b', col3='r'):
    """Plot 1-D and 2-D marginals of one or two pdfs.

    Diagonal panels overlay the 1-D marginals of `pdf1` (color `col2`) and
    `pdf2` (color `col3`). Off-diagonal panels show the 2-D marginal of
    `pdf1` above the diagonal (i < j) and of `pdf2` below it (i > j). A panel
    is hidden when the corresponding pdf is None.

    Parameters
    ----------
    pdf1 : object
        Distribution exposing `ndim` and `eval(x, ii=..., log=False)`.
    lims : array
        Either one [lo, hi] pair used for every dimension, or one pair per
        dimension (shape (ndim, 2)).
    pdf2 : object or None
        Optional second distribution; must have the same `ndim` as `pdf1`.
    gt : array or None
        Ground-truth parameter vector, marked in red.
    contours : bool
        If True, off-diagonal panels show probability contours instead of
        density images.
    levels : tuple
        Probability levels for the contours.
    resolution : int
        Number of grid points per dimension.
    labels_params : list of str or None
        Axis labels of the diagonal panels.
    ticks : bool
        If True, includes ticks in the diagonal plots.
    diag_only : bool
        If True, only the 1-D marginals are plotted.
    diag_only_cols, diag_only_rows : int
        Grid size used when `diag_only` is True.
    figsize : tuple
    fontscale : float
    partial : bool
        If True, plots at most 3 parameters. Only available if `diag_only`
        is False.
    samples : array or None
        Samples of shape (ndim, n_samples), plotted as histograms along the
        pdfs. When given, `contours` is forced on, `levels` defaults to
        (0.68, 0.95) if None, and `lims` is replaced by the per-dimension
        range of the samples.
    col1, col2, col3 : str
        Colors of the sample histograms, `pdf1` and `pdf2`, respectively.

    Returns
    -------
    fig, ax
        The matplotlib figure and its axes (a single Axes in the 1-D case,
        else an array of shape (rows, cols)).
    """

    pdfs = (pdf1, pdf2)
    colrs = (col2, col3)

    if not (pdf1 is None or pdf2 is None):
        assert pdf1.ndim == pdf2.ndim

    if samples is not None:
        contours = True
        if levels is None:
            levels = (0.68, 0.95)
        # limits are taken from the sample range in every dimension
        lims_min = np.min(samples, axis=1)
        lims_max = np.max(samples, axis=1)
        lims = np.concatenate(
            (lims_min.reshape(-1, 1), lims_max.reshape(-1, 1)), axis=1)
    else:
        lims = np.asarray(lims)
        # a single [lo, hi] pair is shared by all dimensions
        lims = np.tile(lims, [pdf1.ndim, 1]) if lims.ndim == 1 else lims

    if pdf1.ndim == 1:

        fig, ax = plt.subplots(1, 1, facecolor='white', figsize=figsize)

        if samples is not None:
            # only dimension 0 exists in this branch
            ax.hist(samples[0, :], bins=100, normed=True,
                    color=col1,
                    edgecolor=col1)

        xx = np.linspace(lims[0, 0], lims[0, 1], resolution)

        for pdf, col in zip(pdfs, colrs):
            if pdf is not None:
                pp = pdf.eval(xx[:, np.newaxis], log=False)
                ax.plot(xx, pp, color=col)
        ax.set_xlim(lims[0])
        ax.set_ylim([0, ax.get_ylim()[1]])
        if gt is not None:
            ax.vlines(gt, 0, ax.get_ylim()[1], color='r')

        if ticks:
            ax.get_yaxis().set_tick_params(which='both', direction='out')
            ax.get_xaxis().set_tick_params(which='both', direction='out')
            ax.set_xticks(np.linspace(lims[0, 0], lims[0, 1], 2))
            ax.set_yticks(np.linspace(min(pp), max(pp), 2))
            ax.xaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
            ax.yaxis.set_major_formatter(mpl.ticker.FormatStrFormatter('%.1f'))
        else:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

    else:

        if not diag_only:
            if partial:
                rows = min(3, pdf1.ndim)
                cols = min(3, pdf1.ndim)
            else:
                rows = pdf1.ndim
                cols = pdf1.ndim
        else:
            cols = diag_only_cols
            rows = diag_only_rows
            # diag-only panels fill the grid row by row; (r, c) is the current cell
            r = 0
            c = -1

        fig, ax = plt.subplots(rows, cols, facecolor='white', figsize=figsize)
        ax = ax.reshape(rows, cols)

        for i in range(rows):
            for j in range(cols):

                if i == j:
                    # Choose the target panel once per diagonal entry (not once per pdf),
                    # so both pdfs and the sample histogram share one panel.
                    if diag_only:
                        c += 1
                    else:
                        r = i
                        c = j

                    if samples is not None:
                        ax[r, c].hist(samples[i, :], bins=100, normed=True,
                                      color=col1,
                                      edgecolor=col1)
                    xx = np.linspace(lims[i, 0], lims[i, 1], resolution)

                    for pdf, col in zip(pdfs, colrs):
                        if pdf is not None:
                            pp = pdf.eval(xx, ii=[i], log=False)
                            ax[r, c].plot(xx, pp, color=col)
                    ax[r, c].set_xlim(lims[i])
                    ax[r, c].set_ylim([0, ax[r, c].get_ylim()[1]])

                    if gt is not None:
                        ax[r, c].vlines(
                            gt[i], 0, ax[r, c].get_ylim()[1], color='r')

                    if ticks:
                        ax[r, c].get_yaxis().set_tick_params(
                            which='both', direction='out', labelsize=fontscale * 20)
                        ax[r, c].get_xaxis().set_tick_params(
                            which='both', direction='out', labelsize=fontscale * 20)
                        ax[r, c].set_xticks(np.linspace(
                            lims[i, 0], lims[i, 1], 2))
                        # y ticks span the last evaluated marginal
                        ax[r, c].set_yticks(np.linspace(min(pp), max(pp), 2))
                        ax[r, c].xaxis.set_major_formatter(
                            mpl.ticker.FormatStrFormatter('%.1f'))
                        ax[r, c].yaxis.set_major_formatter(
                            mpl.ticker.FormatStrFormatter('%.1f'))
                    else:
                        ax[r, c].get_xaxis().set_ticks([])
                        ax[r, c].get_yaxis().set_ticks([])

                    if labels_params is not None:
                        ax[r, c].set_xlabel(
                            labels_params[i], fontsize=fontscale * 15)
                    else:
                        # an empty string clears the label ([] would be rendered as the text "[]")
                        ax[r, c].set_xlabel('')

                    x0, x1 = ax[r, c].get_xlim()
                    y0, y1 = ax[r, c].get_ylim()
                    ax[r, c].set_aspect((x1 - x0) / (y1 - y0))

                    if partial and i == rows - 1:
                        ax[i, j].text(x1 + (x1 - x0) / 6., (y0 + y1) /
                                      2., '...', fontsize=fontscale * 25)
                        plt.text(x1 + (x1 - x0) / 8.4, y0 - (y1 - y0) /
                                 6., '...', fontsize=fontscale * 25, rotation=-45)

                else:
                    if diag_only:
                        continue

                    # upper triangle shows pdf1, lower triangle pdf2
                    if i < j:
                        pdf = pdfs[0]
                    else:
                        pdf = pdfs[1]

                    if pdf is None:
                        ax[i, j].get_yaxis().set_visible(False)
                        ax[i, j].get_xaxis().set_visible(False)
                        ax[i, j].set_axis_off()
                        continue

                    if samples is not None:
                        H, xedges, yedges = np.histogram2d(
                            samples[i, :], samples[j, :], bins=30, normed=True)
                        ax[i, j].imshow(np.flipud(H), origin='lower', extent=[
                                        xedges[0], xedges[-1], yedges[0], yedges[-1]])

                    xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
                    yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
                    X, Y = np.meshgrid(xx, yy)
                    xy = np.concatenate(
                        [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
                    pp = pdf.eval(xy, ii=[i, j], log=False)
                    pp = pp.reshape(list(X.shape))
                    if contours:
                        ax[i, j].contour(X, Y, probs2contours(
                            np.flipud(pp.T), levels), levels, colors=('w', 'y'))
                    else:
                        ax[i, j].imshow(pp.T, origin='lower', cmap=cmaps.parula,
                                        extent=[lims[i, 0], lims[i, 1], lims[j, 0], lims[j, 1]],
                                        aspect='auto', interpolation='none')
                    ax[i, j].set_xlim(lims[i])
                    ax[i, j].set_ylim(lims[j])

                    if gt is not None:
                        ax[i, j].plot(gt[j], gt[i], 'r.', ms=10,
                                      markeredgewidth=0.0)

                    ax[i, j].get_xaxis().set_ticks([])
                    ax[i, j].get_yaxis().set_ticks([])
                    ax[i, j].set_axis_off()

                    x0, x1 = ax[i, j].get_xlim()
                    y0, y1 = ax[i, j].get_ylim()
                    ax[i, j].set_aspect((x1 - x0) / (y1 - y0))

                    if partial and j == cols - 1:
                        ax[i, j].text(x1 + (x1 - x0) / 6., (y0 + y1) /
                                      2., '...', fontsize=fontscale * 25)

                # diag-only: wrap to the next grid row after the last column
                if diag_only and c == cols - 1:
                    c = -1
                    r += 1

    return fig, ax

def probs2contours(probs, levels):
    """Label every cell of a probability array with its credible level.

    Cells are ranked by decreasing probability. A cell whose normalized
    cumulative mass (in that ranking) is <= a level gets that level's value,
    so nested highest-probability regions can be drawn with
    ``contour(..., levels)``. Cells outside every level keep the value 1.0.

    Parameters
    ----------
    probs : array_like
        Probability array. It doesn't have to sum to 1, but it is assumed to
        contain all the mass. Integer arrays and nested lists are accepted.
    levels : list
        Percentile levels, each in [0.0, 1.0].

    Returns
    -------
    Float array of the same shape as `probs` holding the level labels.
    """
    # make sure all contour levels are in [0.0, 1.0]
    levels = np.asarray(levels)
    assert np.all(levels <= 1.0) and np.all(levels >= 0.0)

    # flatten probability array
    probs = np.asarray(probs)
    shape = probs.shape
    probs = probs.flatten()

    # sort probabilities in descending order
    idx_sort = probs.argsort()[::-1]
    idx_unsort = idx_sort.argsort()
    probs = probs[idx_sort]

    # Cumulative probabilities, normalized. Cast to float first: in-place true
    # division on an integer cumsum raises a casting error.
    cum_probs = probs.cumsum().astype(float)
    cum_probs /= cum_probs[-1]

    # create contours at levels; visit levels from largest to smallest so the
    # tightest region overwrites the looser ones
    contours = np.ones_like(cum_probs)
    for level in np.sort(levels)[::-1]:
        contours[cum_probs <= level] = level

    # restore the original order and shape
    return np.reshape(contours[idx_unsort], shape)

In [None]:

# Prior (upper triangle) vs. SNPE posterior (lower triangle), with histograms
# of the parameters left over from the last diagnostics round. Because
# `samples` is given, plot_2pdf ignores lims=[-2,2] and uses the sample ranges.
plot_2pdf(p, pdf2=posterior, lims=[-2,2], figsize=(12,12), samples=parameters.T);

In [None]:
# Second SNPE run: use the first-run posterior as proposal prior and fit a
# 4-component mixture, then repeat the visual comparisons for `posterior2`.
n_hiddens=(30,30)
n_filters=(16,16,16)

n_train=5000
epochs=200
minibatch=50
n_rounds=2
n_components=4


g2 = dg.Default(model=m, prior=posterior, summary=s)


inf = infer.SNPE(generator=g2, obs=obs_stats, prior_norm=False, pilot_samples=None, seed=seed, 
                 n_components=n_components, n_hiddens=n_hiddens, n_filters=n_filters, n_inputs = (1,d,d))
logs = inf.run(n_train=n_train, epochs=epochs, minibatch=minibatch, n_rounds=n_rounds)

posterior2 = inf.predict(obs_stats)
# plot_pdf reads `.ndim`; take it from posterior2's own component
posterior2.ndim = posterior2.xs[0].ndim

# bunch of example prior draws
plt.figure(figsize=(16,10))
for i in range(15):
    plt.subplot(3,5,i+1)
    plt.imshow(m.params_to_rf(p.gen().reshape(-1))[0], interpolation='None')
plt.title('RF prior draws')
plt.show()

plt.figure(figsize=(16,5))
plt.subplot(1,4,1)
plt.imshow(m.params_to_rf(p.m)[0], interpolation='None')
plt.title('prior mean RF')
plt.subplot(1,4,2)
plt.imshow(obs_stats.reshape(d,d), interpolation='None')
plt.title('data STA')
plt.subplot(1,4,3)
plt.imshow(m.params_to_rf(pars_true)[0], interpolation='None')
plt.title('ground-truth RF')
plt.subplot(1,4,4)
plt.imshow(m.params_to_rf(posterior2.calc_mean_and_cov()[0])[0], interpolation='None')
plt.title('posterior mean RF')
plt.show()

# bunch of example posterior draws
plt.figure(figsize=(16,10))
for i in range(15):
    plt.subplot(3,5,i+1)
    plt.imshow(m.params_to_rf(posterior2.gen().reshape(-1))[0], interpolation='None')
plt.title('RF posterior draws')
plt.show()

# all pairwise marginals of fitted posterior
# labels follow the prior's key order for the 'phase' parametrization
plot_pdf(posterior2, lims=[-5,5], gt=pars_true.reshape(-1), figsize=(16,16),
labels_params=['b','logit_phase','f_1','f_2','log ratio','log width']+['kt_'+str(i) for i in range(filter_shape[2])]);


In [None]:
# Per-round diagnostics for the second SNPE run (same procedure as for the
# first run).
# NOTE(review): the plots and prints below still use `posterior` (first run)
# rather than `posterior2` -- confirm which comparison is intended.
plt.figure(figsize=(6,4))
plt.imshow(obs_stats.reshape(d,d), interpolation='None')
plt.title('x0')
plt.show()

for r in range(n_rounds):
    print('round ', str(r))
    # presumably logs[1][r] = (parameters, simulated summary stats) of round r -- confirm against SNPE.run
    parameters =  logs[1][r][0]
    samples = logs[1][r][1]

    # keep the x_dists.size//20 closest simulations (about 5%)
    x_dists = np.sum( (samples[:,0,:,:] - obs_stats.reshape(1, d,d))**2, axis=(1,2))
    thresh = np.sort(x_dists)[x_dists.size//20]
    idx_0 = np.where(x_dists < thresh)[0]
    print(len(idx_0))

    plt.hist(x_dists, bins=np.linspace(0, x_dists.max(), 50) )
    plt.plot([thresh, thresh], [0, 50], 'r')   # marker height 50 is arbitrary
    plt.show()


    # up to 25 of the accepted simulated summary images
    plt.figure(figsize=(12,8))
    for i in range(np.min((25, len(idx_0)))):
        plt.subplot(5,5,i+1)
        j = i# np.random.randint(len(idx_0))
        plt.imshow(samples[idx_0[j],0,:,:], interpolation='None')
    plt.show()

    # Gaussian moment fit to all parameters of this round
    m_,S_ = parameters.mean(axis=0), np.cov(parameters.T)
    #S_[0,0] = 1.
    posterior_sampled = dd.Gaussian(m=m_, S=S_)

    # all pairwise marginals of fitted posterior
    plot_pdf(posterior, lims=[-5,5], figsize=(16,16), samples=parameters[idx_0,:].T);

    print('posterior mean:', posterior.xs[0].m)
    print('sampled mean:', posterior_sampled.m)
    plt.plot(posterior.xs[0].m)
    plt.plot(posterior_sampled.m)
    plt.show()

# compare with maprf sampling

In [None]:
# NOTE(review): `T` is only assigned in a later cell (`T, L = inference.sample(5000)`);
# evaluating it before that cell has run raises NameError.
T

In [None]:
import numpy as np
import numpy.random as nr
import maprf.config as config
import maprf.rfs.v1 as V1
import maprf.invlink as invlink
import maprf.glm as glm 
from maprf.utils import *
from maprf.data import SymbolicData
import time
import maprf.filters as filters
import maprf.kernels as kernels
# from maprf.sampling.slice import EllipticalSliceSampler as ESS

import theano.printing as printing
import theano.tensor as tt

import theano
from theano import In

import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

from os import path
from maprf.inference import *

def pyprint(var, filename):
    """Render the theano graph of `var` into the PDF file `filename` with pydotprint."""
    options = dict(format='pdf', outfile=filename,
                   high_contrast=False, with_ids=True)
    printing.pydotprint(var, **options)

# Build the maprf MCMC inference model on the same stimulus grids and data as
# the SNPE experiment, so both posteriors can be compared.
cfg = config.load(path.join('config.yaml'))

# The forward part of the model
rf = V1.SimpleLinear()
emt = glm.Poisson()
# inputs and outputs (`empty` presumably comes from `from maprf.utils import *`)
data = [theano.shared(empty(3), 'frames'),
        theano.shared(empty(1, dtype='int64'))]
frames, spikes = data

# fill the grids with the simulator's spatial and temporal axes
rf.grids['s'][0].set_value(m._gen.grid_x)
rf.grids['s'][1].set_value(m._gen.grid_y)
rf.grids['t'][0].set_value(m._gen.axis_t)

import numpy.linalg as linalg
# Build the prior for the temporal kernel (same construction as in the setup cell).
# The scaled time axis is named s_t, not s: `s` holds the maprfStats summary
# object used by the SNPE generators, and overwriting it breaks re-running those cells.
ax_t = rf.grids['t'][0].get_value()
s_t = ax_t / 0.075
n = ax_t.shape[0]
Λ =  np.diag(s_t * np.exp(1 - s_t))
D = np.eye(n) - np.eye(n, k=-1)
F = np.dot(D, D.T)
Σ = np.dot(Λ, linalg.inv(F).dot(Λ))

# inference model (Inference/GaborSampler/KernelSampler presumably from `from maprf.inference import *`)
inference = Inference(rf, emt)
inference.priors = cfg['priors']
inference.priors['kernel']['t'] = {'mu': np.zeros_like(ax_t), 'sigma': linalg.cholesky(Σ)}


inference.add_sampler(GaborSampler())
inference.add_sampler(KernelSampler())

plt.imshow(Σ, interpolation='None')
plt.show()

print('inputs: ', inference.inputs)
print('priors: ', inference.priors)

inference.build(data)
inference.compile()


# Values assigned to the inference variables (presumably the sampler's
# starting state -- confirm). Note that this model uses 'vec_A' where the
# SNPE model uses 'phase'.
inference.loglik['xo'] = 0
inference.loglik['yo'] = 0
#inference.loglik['kt'] = np.array([ 0.02047043,  0.51640702,  0.61731474,  0.01362172, -0.37586342,
#       -0.3750627 , -0.23319645, -0.1131711 , -0.04670192, -0.01714343,
#       -0.00575457])
inference.loglik['kt'] = np.array([0.5, 0.])
inference.loglik['vec_A'] = np.zeros(2)  # np.array([2.0, 0.0])
inference.loglik['vec_f'] = np.zeros(2)  # 0.3 * np.array([np.cos(0.7), np.sin(0.7)])
inference.loglik['log_γ'] = 0.0
inference.loglik['log_b'] = np.log(2.5)

# load the observed stimulus frames and spike counts into the shared variables
frames.set_value(obs['I'].reshape(-1,d,d))
spikes.set_value(obs['data'])
plt.plot(spikes.get_value())

print(np.sum(obs['data']))

In [None]:

# Rejection-style check with a fixed distance threshold.
# NOTE(review): `trn_data` is not defined anywhere in this notebook, so this
# cell raises NameError as written. It presumably predates the logs-based
# diagnostics above.
x_dists = np.sum( (trn_data[1][:,0,:,:] - obs_stats.reshape(1, d,d))**2, axis=(1,2))
thresh = 5.
idx_0 = np.where(x_dists < thresh)[0]
print(len(idx_0))

plt.hist(x_dists, bins=np.linspace(0, 100, 50) )
plt.show()

plt.figure(figsize=(6,4))
plt.imshow(obs_stats.reshape(d,d), interpolation='None')
plt.show()

# NOTE(review): assumes at least 25 accepted simulations; idx_0[j] raises IndexError otherwise
plt.figure(figsize=(12,8))
for i in range(25):
    plt.subplot(5,5,i+1)
    j = i# np.random.randint(len(idx_0))
    plt.imshow(trn_data[1][idx_0[j],0,:,:], interpolation='None')
plt.show()

# all pairwise marginals of fitted posterior
plot_pdf(posterior.xs[0], lims=[-5,5], figsize=(16,16), samples=trn_data[0][idx_0,:].T);



# NOTE(review): `samples` is whatever was bound last. After the diagnostics
# loops it holds simulated summary images (indexed samples[:,0,:,:] there),
# not parameter samples -- confirm before trusting this covariance.
S=np.cov(samples.T)
S[0,0] = 1.
posterior_sampled = dd.Gaussian(m=samples.mean(axis=0), S=S)

# all pairwise marginals of fitted posterior
plot_pdf(posterior_sampled, lims=[-5,5], figsize=(16,16), samples=trn_data[0][idx_0,:].T);

print('posterior mean:', posterior.xs[0].m)
print('sampled mean:', posterior_sampled.m)
plt.plot(posterior.xs[0].m)
plt.plot(posterior_sampled.m)
plt.show()

In [None]:
import datetime

# Draw 5000 MCMC samples from the maprf model and inspect the RF position.
# NOTE(review): this rebinds `L` (previously the prior factor from the setup
# cell) and `x`.
T, L = inference.sample(5000)
# key the sample traces by variable name instead of by theano variable
T = {k.name: t for k, t in T.items()}

x = T['xo']
y = T['yo']

# trace and histogram of xo, skipping the first 500 samples (presumably burn-in)
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[500:])

plt.subplot(122)
plt.hist(x[500:], alpha=0.5, normed=True)
plt.show()


# joint scatter of (xo, yo) after the first 500 samples
plt.figure(figsize=(15, 4))
plt.subplot(121)
plt.plot(x[500:], y[500:], '.k', alpha=0.1)
plt.show()


In [None]:
# Persist the MCMC samples when this session produced them (T from
# inference.sample), then always continue from the copy on disk.
try:
    np.savez('posterior_samples', {'T' : T})
except NameError:
    # T was not created in this session (sampling cell not run): nothing to save
    pass
except OSError as err:
    print('could not save posterior samples:', err)
# np.savez stored the dict as a pickled 0-d object array under 'arr_0';
# NumPy >= 1.16.3 only unpickles it with allow_pickle=True (trusted local file).
T = np.load('posterior_samples.npz', allow_pickle=True)['arr_0'].tolist()['T']

# The MCMC model has no GLM bias, so a zero column stands in for 'b'.
T['b'] = np.zeros((T['vec_A'].shape[0], 1))

# one row per sample, columns ordered b, vec_A (2), vec_f (2), log_γ, log_b, kt
samples = np.hstack([np.atleast_2d(T[key].T).T for key in ['b','vec_A','vec_f','log_γ','log_b', 'kt']])

# compare with pairwise priors
# NOTE(review): these samples use the 'vec_A' parametrization (one more
# dimension than the 'phase'-based SNPE posterior), so columns after 'b' do
# not line up with posterior.xs[0] or pars_true -- confirm before interpreting.
plot_pdf(posterior.xs[0], lims=[-5,5], gt=pars_true.reshape(-1), figsize=(16,16), samples=samples.T,
labels_params=['b','A_1','A_2','f_1','f_2','log ratio','log width']+['kt_'+str(i) for i in range(filter_shape[2])]);


In [None]:

# Fit a Gaussian to the MCMC samples by moments and compare it with the SNPE posterior.
S=np.cov(samples.T)
# Column 0 ('b') is the all-zero stand-in for the bias, so its variance is 0
# and S would be singular; put a unit variance there instead.
S[0,0] = 1.
posterior_sampled = dd.Gaussian(m=samples.mean(axis=0), S=S)

# all pairwise marginals of fitted posterior
plot_pdf(posterior_sampled, lims=[-5,5], figsize=(16,16), samples=samples.T);

print('posterior mean:', posterior.xs[0].m)
print('sampled mean:', posterior_sampled.m)
plt.plot(posterior.xs[0].m)
plt.plot(posterior_sampled.m)
plt.show()

In [None]:
plot_pdf(posterior_sampled, lims=[-5,5], figsize=(16,16), samples=samples.T);


In [None]:
import delfi.distribution as dd

def save_mog(mog, filename=None):
    """Save a (mixture of) Gaussian(s) as a dict of plain arrays.

    Stores the mixture weights, the component means and covariances, and the
    seed instead of pickling the distribution object itself. Note that
    `np.save` still pickles the dict container, so reading the file back
    needs ``np.load(..., allow_pickle=True)``.

    Parameters
    ----------
    mog : dd.MoG or dd.Gaussian
        Distribution to save. A single Gaussian is stored as a one-component
        mixture with weight 1.
    filename : str or None
        Target path for `np.save`. If None, nothing is written and the
        dictionary is returned instead.

    Returns
    -------
    dict or None
        The dictionary of arrays when `filename` is None, otherwise None.

    Raises
    ------
    TypeError
        If `mog` is neither a `dd.MoG` nor a `dd.Gaussian`.
    """
    # check the argument itself (not a global `posterior`)
    if isinstance(mog, dd.MoG):
        save_dict = {'a'  : mog.a,
                     'ms' : [x.m for x in mog.xs],
                     'Ss' : [x.S for x in mog.xs],
                     'seed' : mog.seed}
    elif isinstance(mog, dd.Gaussian):
        save_dict = {'a'  : np.ones(1),
                     'ms' : [mog.m],
                     'Ss' : [mog.S],
                     'seed' : mog.seed}
    else:
        raise TypeError('expected dd.MoG or dd.Gaussian, got %s'
                        % type(mog).__name__)

    if filename is None:
        return save_dict
    np.save(filename, save_dict)
    
def load_mog(filename):
    """Load a mixture of Gaussians written by `np.save` of a save_mog dict.

    Parameters
    ----------
    filename : str
        Save file location, with or without the '.npy' extension.

    Returns
    -------
    dd.MoG
        Mixture rebuilt from the stored weights, means, covariances and seed
        (a saved single Gaussian comes back as a one-component mixture).

    Notes
    -----
    The file holds a pickled dict, so it is loaded with ``allow_pickle=True``.
    Only load files from trusted sources.
    """
    if not filename.endswith('.npy'):
        filename += '.npy'

    # np.save wrapped the dict in a 0-d object array; [()] unwraps it.
    # NumPy >= 1.16.3 refuses to unpickle unless allow_pickle is set.
    sd = np.load(filename, allow_pickle=True)[()]

    return dd.MoG(a=sd['a'], ms=sd['ms'], Ss=sd['Ss'], seed=sd['seed'])

def save_res(p, posterior, network, filename):
    """Save prior, posterior and network parameters/spec to one .npz archive.

    The entries are stored under the keys 'p', 'posterior', 'net_pars' and
    'net_spec', the names `load_res` reads. (Positional `np.savez` arguments
    would be stored as 'arr_0', 'arr_1', ... instead.)

    Parameters
    ----------
    p, posterior : dd.MoG or dd.Gaussian
        Prior and posterior, converted with `save_mog`.
    network : object
        Network exposing `params_dict` and `spec_dict`.
    filename : str
        Target path for `np.savez`.
    """
    prior_dict = save_mog(p)
    posterior_dict = save_mog(posterior)

    np.savez(filename,
             p=prior_dict,
             posterior=posterior_dict,
             net_pars=network.params_dict,
             net_spec=network.spec_dict)
    
    
def load_res(filename, network_class=None):
    """Load prior, posterior and network from an .npz archive.

    Expects the keys 'p', 'posterior', 'net_spec' and 'net_pars', each
    holding a pickled dict (stored as a 0-d object array).

    Parameters
    ----------
    filename : str
        Path of the .npz archive. Only load trusted files: the stored dicts
        are unpickled.
    network_class : callable or None
        Constructor for the network. Defaults to `dn.NeuralNet`.

    Returns
    -------
    p, posterior : dd.MoG
        Prior and posterior rebuilt as mixtures of Gaussians.
    network : object
        Network rebuilt from the stored spec, with its parameters restored.
    """
    if network_class is None:
        # NOTE(review): the original referenced an undefined `NeuralNetwork`;
        # assumes delfi's network class is `dn.NeuralNet` -- confirm.
        network_class = dn.NeuralNet

    def _to_mog(ld):
        # rebuild a mixture from a save_mog-style dict
        return dd.MoG(a=ld['a'], ms=ld['ms'], Ss=ld['Ss'], seed=ld['seed'])

    # [()] unwraps the 0-d object arrays np.savez created for the dicts
    with np.load(filename, allow_pickle=True) as load_file:
        p = _to_mog(load_file['p'][()])
        posterior = _to_mog(load_file['posterior'][()])
        ns = load_file['net_spec'][()]
        net_pars = load_file['net_pars'][()]

    network = network_class(n_inputs=ns['n_inputs'],
                            n_outputs=ns['n_outputs'],
                            n_components=ns['n_components'],
                            n_filters=ns['n_filters'],
                            n_hiddens=ns['n_hiddens'],
                            n_rnn=ns['n_rnn'],
                            seed=ns['seed'],
                            svi=ns['svi'])
    network.params_dict = net_pars

    return p, posterior, network
    
