## GLMAT: 2D kernel

In [None]:
from __future__ import print_function, division
import os
import numpy as np
import scipy as sp
from numba import jit, guvectorize
import mat_neuron.core as matmodel

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# model parameters: (α1, α2, ω, τ1, τ2, tref)
# (index order matters: predict_spikes / lci_poisson read params[0..5] in
# exactly this order)
matparams = np.asarray([100, 2, 7, 10, 200, 2], dtype='d')
# model time step; tref / model_dt gives the refractory length in steps
model_dt = 0.5

# data parameters
duration = 40000  # same time units as model_dt
n_samples = int(duration / model_dt)  # number of model time steps
n_assim = 3  # trials used to fit the model (assim_data)
n_test = 0   # extra trials kept aside (test_data)

# assimilation parameters
nthreads = 4   # passed to emcee.EnsembleSampler as `threads`
nwalkers = 2000
nsteps = 500
matparams_i = [0,1,2]  # indices into matparams that are fitted: α1, α2, ω
matparams_n = len(matparams_i)
matparams_f = matparams[matparams_i]  # true values of the fitted parameters

In [None]:
# STRF: keep this very simple for proof of principle
stim_dt = 10.0   # stimulus frame duration, same units as model_dt
nfreq = 30       # number of kernel rows (channels)
ntau  = 30       # number of kernel columns (time lags) kept

from scipy.signal import resample
# Open the archive in a `with` block so the underlying .npz (zip) file is
# closed once the arrays we need have been read out of it.
with np.load('../../filters.npz') as filts:
    # `.files` is the list of stored array names; printing `.keys()` only
    # shows a KeysView wrapper on current NumPy versions.
    print(filts.files)
    # Resample the 'bbm' filter to `nfreq` rows (axis 0), keep the first
    # `ntau` columns and scale by 5.
    k1 = resample(filts['bbm'], nfreq, axis=0)[:,:ntau] * 5

def filter_stimulus(S, k1, upsample=None):
    """Convolve each stimulus row with its kernel row and sum across rows.

    Parameters
    ----------
    S : 2-D array, shape (nf, nt)
        Stimulus; row ``i`` is convolved with ``k1[i]``.
    k1 : 2-D array with at least ``nf`` rows
        Per-row 1-D filters. Assumes each row is no longer than ``nt``
        (``mode="same"`` would otherwise return a longer array).
    upsample : int or None
        If given, the summed signal is linearly interpolated onto a grid with
        ``upsample`` points per input sample (``nt * upsample`` points for an
        integer factor). Points past the last input sample take its value.

    Returns
    -------
    1-D float array of length ``nt`` (or the upsampled length).

    Notes
    -----
    ``np.convolve(..., mode="same")`` centres the kernel on each sample, so
    this filter is not causal. The same function is used both to simulate and
    to fit, so the two stay consistent.
    """
    nf, nt = S.shape
    X = np.zeros(nt)
    for i in range(nf):
        X += np.convolve(S[i], k1[i], mode="same")
    if upsample is None:
        return X
    t = np.arange(X.size)
    # Build the fine grid from an integer sample count instead of a float
    # step: np.arange with a non-integer step can yield an off-by-one length.
    n_out = int(np.ceil(X.size * upsample))
    tu = np.arange(n_out) / upsample
    return np.interp(tu, t, X)

# Show the kernel used to simulate the data (rows: the nfreq resampled
# channels, columns: the first ntau lags).
sns.heatmap(k1)


In [None]:
# generate some random data to fit
np.random.seed(1)
stim_dt = 10.0
# integer number of model time steps per stimulus frame
upsample = int(stim_dt / model_dt)
data = []
# Gaussian white-noise stimulus with nfreq rows and one column per stimulus
# frame. The column count uses the same integer `upsample` that
# filter_stimulus applies, so the upsampled drive has n_samples // upsample
# * upsample model steps.
stim = np.random.randn(nfreq, n_samples // upsample)
# silence the first 100 frames
stim[:,:100] = 0
# aspect="auto": the array is far wider than tall, so square pixels would
# render it as an unreadable thin strip
plt.imshow(stim, aspect="auto")

In [None]:
@jit
def _predict_spikes_kernel(V, params, dt, R):
    """Compiled core of `predict_spikes`; `R` holds one uniform draw per step.

    Returns an int array, 1 at steps where a spike was emitted.
    """
    N = V.size
    S = np.zeros(N, dtype='i')
    omega = params[2]
    # per-step decay factors of the two adaptation variables
    A1 = np.exp(-dt / params[3])
    A2 = np.exp(-dt / params[4])
    # spikes are suppressed for this many steps after each spike
    i_refrac = int(params[5] / dt)
    H1 = 0.0
    H2 = 0.0
    # index of the last refractory step; -1 so that step 0 may spike
    iref = -1
    for i in range(N):
        H1 *= A1
        H2 *= A2
        # per-step spike probability (may exceed 1 for large drive)
        p = np.exp(V[i] - H1 - H2 - omega) * dt
        if i > iref and p > R[i]:
            H1 += params[0]
            H2 += params[1]
            iref = i + i_refrac
            S[i] = 1
    return S


def predict_spikes(V, params, dt):
    """Simulate one stochastic spike train from the drive `V`.

    Parameters
    ----------
    V : 1-D ndarray
        Model drive, one value per time step of length `dt`.
    params : 1-D ndarray
        (α1, α2, ω, τ1, τ2, tref). On every step H1 and H2 decay by
        exp(-dt/τ1) and exp(-dt/τ2). A spike is emitted when
        exp(V - H1 - H2 - ω) * dt exceeds a uniform draw and the neuron is
        not refractory. A spike adds α1 / α2 to H1 / H2 and blocks further
        spikes for int(tref / dt) steps.
    dt : float
        Time step.

    Returns
    -------
    int ndarray of length V.size, 1 where a spike occurred.
    """
    # Draw the uniforms with NumPy's global generator *outside* numba code:
    # numba-compiled functions keep their own RNG state, so np.random.seed()
    # called from the notebook would not make draws inside @jit reproducible.
    R = np.random.uniform(size=V.size)
    return _predict_spikes_kernel(V, params, dt, R)


@jit
def lci_poisson(V, params, spikes, dt):
    """Log-likelihood of a spike train under the adaptive Poisson model.

    Sums ``spikes[i] * mu_i - dt * exp(mu_i)`` with
    ``mu_i = V[i] - H1 - H2 - ω``. This is the discrete-time Poisson
    log-likelihood for rate exp(mu), dropping terms that do not depend on the
    parameters. H1 and H2 decay by exp(-dt/τ1) and exp(-dt/τ2) each step, and
    each observed spike increments them by α1 and α2.

    params layout: (α1, α2, ω, τ1, τ2, tref); tref is not used here.
    NOTE(review): unlike predict_spikes, no refractory period is modelled.

    Returns -inf if the running sum becomes non-finite (overflow or NaN).
    """
    N = V.size
    lp = 0.0
    omega = params[2]
    A1 = np.exp(-dt / params[3])
    A2 = np.exp(-dt / params[4])
    H1 = 0.0
    H2 = 0.0
    for i in range(N):
        H1 *= A1
        H2 *= A2
        mu = V[i] - H1 - H2 - omega
        lp += spikes[i] * mu - dt * np.exp(mu)
        if not np.isfinite(lp):
            # The parameters are unsupported: -inf makes the sampler reject
            # them. (+inf would make them look infinitely *good*.)
            return -np.inf
        if spikes[i]:
            H1 += params[0]
            H2 += params[1]
    return lp
        
# The drive is computed once from the true kernel; every trial shares this
# same V array and differs only in its random spike draws.
V = filter_stimulus(stim, k1, upsample)
for i in range(n_assim + n_test):
    spikes = predict_spikes(V, matparams, model_dt)
    # log-likelihood of this trial under the true parameters (reference value)
    lp = lci_poisson(V, matparams, spikes, model_dt)
    # spike times as sample indices
    z = np.nonzero(spikes)[0]
    d = {"V": V,
         "duration": duration,
         "spike_t": z, 
         "spike_v": spikes,
         "lci": lp
         }
    data.append(d)

# split into assimilation and test sets
assim_data = data[:n_assim]
test_data = data[n_assim:]

In [None]:
# Top: model drive (every trial shares the same V, so plot it once).
# Bottom: spike raster, one row per trial.
ax1 = plt.subplot(211)
ax2 = plt.subplot(212)
ax1.plot(data[0]["V"])
for i, d in enumerate(data):
    ax2.vlines(d["spike_t"], i, i + 0.5)
ax1.set_xlim(0, 8000)
ax2.set_xlim(0, 8000)
# Print explicitly: a notebook cell only displays its last bare expression,
# so the summed log-likelihood was previously computed and thrown away.
print("summed lci over trials: {}".format(sum(d["lci"] for d in data)))
print("spikes in trial 0: {}".format(len(data[0]["spike_t"])))

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos
import emcee

# Prior on the fitted MAT parameters (α1, α2, ω, in matparams_i order):
# independent priors.uniform terms with the bounds below.
mat_prior = priors.joint_independent(
                [ priors.uniform(-50,  200),
                  priors.uniform(-5,   10),
                  priors.uniform( 0,  20),
                ])
# [low, high] per fitted MAT parameter. Used to draw walker start positions
# and as corner-plot ranges; kept in sync with mat_prior by hand.
startparams = np.asarray([[-50, 200],
                          [-5, 10],
                          [0, 20],
                         ], dtype='d')

# lasso prior on RF parameters
rf_lambda = 1.0

# this is the local copy of the parameters that we'll update in each step
# (mparams: full 6-element MAT vector; rfparams: flattened kernel)
mparams = matparams.copy()
rfparams = k1.copy().flatten()

def lnpost_dyn(theta):
    """Posterior probability for dynamical parameters (log scale).

    Parameters
    ----------
    theta : 1-D array
        The first ``matparams_n`` entries are the fitted MAT parameters
        (positions ``matparams_i`` of the full parameter vector). The rest is
        the flattened (nfreq, ntau) receptive field.

    Returns
    -------
    float
        log prior + summed lci_poisson log-likelihood over ``assim_data``,
        or -inf if the prior is non-finite.

    Side effects: writes ``theta`` into the module-level ``mparams`` and
    ``rfparams`` working copies.
    """
    mparams[matparams_i] = theta[:matparams_n]
    rfparams[:] = theta[matparams_n:]
    # Lasso (L1) penalty on the RF coefficients only; the MAT parameters are
    # covered by mat_prior and must not be shrunk toward zero.
    rf_prior = -np.sum(np.abs(theta[matparams_n:])) * rf_lambda
    ll = mat_prior(theta[:matparams_n]) + rf_prior
    if not np.isfinite(ll):
        return -np.inf
    # every assimilation trial shares the same stimulus, so filter it once
    V = filter_stimulus(stim, rfparams.reshape(nfreq, ntau), upsample)
    lp = 0
    for d in assim_data:
        lp += lci_poisson(V, mparams, d["spike_v"], model_dt)
    return ll + lp

In [None]:
# theoretically this is as good as it can get
theta_true = np.concatenate([matparams_f, k1.flatten()])
print("lnpost of p_true: {}".format(lnpost_dyn(theta_true)))
# and this is our initial state
p0 = np.concatenate([startpos.uniform_independent(nwalkers, startparams[:,0], startparams[:,1]),
                     startpos.normal_independent(nwalkers, k1.flatten(), [0.1] * k1.size)],
                   axis=1)
theta_0 = np.median(p0, 0)
print("lnpost of p0 median: {}".format(lnpost_dyn(theta_0)))
%timeit lnpost_dyn(theta_true)
%time for theta_0 in p0: lnpost_dyn(theta_0)

In [None]:
# One ensemble over the full theta (MAT parameters + flattened RF), started
# from p0 and evaluated with nthreads workers.
sampler = emcee.EnsembleSampler(nwalkers, theta_true.size, lnpost_dyn, threads=nthreads)
# NOTE(review): presumably reports progress every 25 of the nsteps — confirm
# against neurofit.utils.convergence_tracker.
tracker = utils.convergence_tracker(nsteps, 25)

# pos / prob are rebound every iteration, so after the loop they hold the
# final walker positions and log-posteriors (the full chain is presumably
# not retained because storechain=False).
for pos, prob, _ in tracker(sampler.sample(p0, iterations=nsteps, storechain=False)): 
    continue

In [None]:
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
# Per-parameter median across walkers is the point estimate; load it into the
# working copies that later cells (predictions, rasters) read.
theta = np.median(pos, 0)
mparams[matparams_i] = theta[:matparams_n]
rfparams[:] = theta[matparams_n:]
rf_est = rfparams.reshape(nfreq, ntau)
# Share one colour scale so the true and estimated kernels are comparable;
# independent autoscaling hides amplitude differences.
vmin = min(k1.min(), rf_est.min())
vmax = max(k1.max(), rf_est.max())
plt.subplot(121)
sns.heatmap(k1, vmin=vmin, vmax=vmax)
plt.subplot(122)
sns.heatmap(rf_est, vmin=vmin, vmax=vmax)
print("true MAT params:      {}".format(matparams_f))
print("estimated MAT params: {}".format(theta[:matparams_n]))

In [None]:
from corner import corner
sns.set_style("whitegrid")

# final walker positions for the fitted MAT parameters only
mpos = pos[:,:matparams_n]
# Labels follow the MAT parameter order (α1, α2, ω, τ1, τ2, tref) and are
# selected with matparams_i, giving exactly one label per fitted column.
all_matlabs = ['a1', 'a2', 'w', 't1', 't2', 'tref']
matlabs = [all_matlabs[j] for j in matparams_i]
# Plot ranges come from the prior bounds. Use a loop name that does not
# shadow the `sp` (scipy) alias: Python 2 comprehensions leak their variable.
c = corner(mpos,
           range=[tuple(bounds) for bounds in startparams],
           bins=50, smooth=2, smooth1d=0,
           labels=matlabs,
           truths=matparams_f)

In [None]:
# see how well predictions line up
d = assim_data[0]
# Rebuild the drive from the estimated RF. rfparams is flat, so reshape it
# back to (nfreq, ntau); filter_stimulus takes (S, k1, upsample).
Vpred = filter_stimulus(stim, rfparams.reshape(nfreq, ntau), upsample)

plt.plot(d["V"][:4000], label="true V")
plt.plot(Vpred[:4000], label="predicted V")
plt.legend()

In [None]:
# Raster: the first len(data) rows (red) are the simulated data; the next 10
# rows are fresh spike trains from the fitted model driven by Vpred.
for i, d in enumerate(data):
    plt.vlines(d["spike_t"], i, i + 0.5, 'r')

for i in range(len(data), len(data) + 10):
    # predict_spikes takes (V, params, dt) only; Vpred is already at model_dt
    S = predict_spikes(Vpred, mparams, model_dt)
    spk_t = S.nonzero()[0]
    plt.vlines(spk_t, i, i + 0.5)

plt.xlim(0, 10000)