In [None]:
from __future__ import print_function, division
import sys
import imp
import numpy as np
import scipy as sp

import mat_neuron._model as mat
from dstrf import strf, mle

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# MAT model parameters, in order: (ω, α1, α2, τ1, τ2, tref).
# predict_spikes subtracts ω from the drive and passes (α1, α2), (τ1, τ2)
# and tref on to the simulator.
matparams = np.array([7, 100, 2, 10, 200, 2], dtype='d')
model_dt = 0.5

# Positions (into matparams) of the parameters that are estimated; the
# remaining entries stay fixed at their true values.
matparams_i = [0, 1, 2]
matparams_n = len(matparams_i)
matparams_f = np.take(matparams, matparams_i)

In [None]:
# convolution kernel
from dstrf import filters

# Stimulus frame duration, kernel length (in frames), requested number of
# raised-cosine basis functions, and gain applied to the true kernel.
stim_dt = 10.0
ntau = 60
ntbas = 8
kscale = 2.0
# number of model bins per stimulus frame
upsample = int(stim_dt / model_dt)

# True kernel from filters.gammadiff. It is time-reversed and scaled by
# kscale; the plot below uses the matching reversed, negated time axis.
k1, kt = filters.gammadiff(ntau * stim_dt / 32,
                           ntau * stim_dt / 16,
                           5,
                           ntau * stim_dt,
                           stim_dt)
k1 = kscale * k1[::-1]

# raised-cosine basis functions; ntbas is refreshed from the returned basis
kcosbas = strf.cosbasis(ntau, ntbas)
ntbas = kcosbas.shape[1]
# coordinates of the true kernel in that basis
k1c = strf.to_basis(k1, kcosbas)

plt.plot(-kt[::-1], k1)

In [None]:
def filter_stimulus(S, k1):
    """Causally filter the 1-D stimulus S with the kernel k1.

    Takes the 'full' cross-correlation of S with k1 and keeps its first
    S.size samples. For len(k1) <= len(S), output[t] is
    sum_n S[t - len(k1) + 1 + n] * k1[n], so it uses only S up to time t, and
    k1[-1] weights the current sample S[t].
    """
    n_out = S.size
    full_xcorr = np.correlate(S, k1, mode="full")
    return full_xcorr[:n_out]


def predict_spikes(V, params, dt, upsample):
    """Simulate spikes from the drive V with mat.predict_poisson.

    params must hold exactly six values (ω, α1, α2, τ1, τ2, tref); unpacking
    raises ValueError otherwise. ω is subtracted from V, and the amplitudes
    (α1, α2) and time constants (τ1, τ2) are passed as pairs.
    """
    omega, a1, a2, t1, t2, tref = params
    amplitudes = (a1, a2)
    time_constants = (t1, t2)
    return mat.predict_poisson(V - omega, amplitudes, time_constants, tref,
                               dt, upsample)


In [None]:
# data parameters
# duration shares time units with model_dt (n_bins = duration / model_dt);
# n_frames is the number of stimulus frames (upsample model bins per frame).
duration = 100000
n_bins = int(duration / model_dt)
n_frames = n_bins // upsample
n_assim = 1   # trials used for fitting (assim_data)
n_test = 5    # held-out trials (test_data) for the posterior predictive check

# generate data to fit
# Both numpy's RNG and mat's RNG are seeded. The generated data depend on the
# exact order of the random draws below, so do not reorder these statements.
np.random.seed(1)
mat.random_seed(1)
data = []
stim = np.random.randn(n_frames)   # i.i.d. standard-normal value per frame
stim[:100] = 0  # NOTE(review): first 100 frames zeroed, presumably a quiet lead-in — confirm
        
# The drive V is computed once; every trial resamples spikes from the same V.
V = filter_stimulus(stim, k1)
for i in range(n_assim + n_test):
    spikes = predict_spikes(V, matparams, model_dt, upsample)
    # adaptation term(s) from this spike train, using the true τ1, τ2
    H = mat.adaptation(spikes, matparams[3:5], model_dt)
    z = np.nonzero(spikes)[0]   # indices of the nonzero entries of spikes
    d = {"H": H,
         "duration": duration,
         "spike_t": z, 
         "spike_v": spikes,
         # log-likelihood of this trial under the true parameters
         "loglike": mat.log_likelihood_poisson(V - matparams[0], H, spikes, matparams[1:3], model_dt, upsample)
        }
    data.append(d)

# split into assimilation and test sets
assim_data = data[:n_assim]
test_data = data[n_assim:]

In [None]:
# Top: the filtered stimulus (drive), one sample per stimulus frame.
# Bottom: spike rasters for every simulated trial, in model bins.
# 8000 model bins correspond to 8000 / upsample stimulus frames.
fig, (ax1, ax2) = plt.subplots(2, 1)
for i, d in enumerate(data):
    ax2.vlines(d["spike_t"], i, i + 0.5)
# Recompute the drive rather than reading the global V, which a later cell
# may reassign.
ax1.plot(filter_stimulus(stim, k1))
ax1.set_xlim(0, 8000 / upsample)
ax2.set_xlim(0, 8000)
print("log likelihood: {}".format(np.sum([d["loglike"] for d in data])))
# duration is in the same time units as model_dt (n_bins = duration / model_dt),
# so spikes / duration is a rate per unit time, not per model bin (dt).
print("spikes: {}; rate: {} / unit time".format(np.mean([d["spike_t"].size for d in data]),
                                               np.mean([d["spike_t"].size / d["duration"] for d in data])))

In [None]:
# initial guess of parameters using cross-validated ML
ntbas = 8
kcosbas = strf.cosbasis(ntau, ntbas)
spike_v = np.stack([d["spike_v"] for d in assim_data], axis=1)
spike_h = np.stack([d["H"] for d in assim_data], axis=2)
mlest = mle.mat(stim, kcosbas, spike_v, spike_h, stim_dt, model_dt, nlin="softplus")
%time w0 = mlest.estimate(reg_alpha=1.0)

In [None]:
# ML estimate of the first three (MAT) parameters, then a plot of the true
# kernel k1 against the STA and ML filters. The STA and ML filters are both
# mapped back from basis coefficients with kcosbas.
# NOTE(review): assumes w0 is laid out as [3 MAT params, RF basis coefficients]
# (lnprior makes the same assumption) — confirm against dstrf.mle.mat.
print(w0[:3])
rf_sta = strf.from_basis(mlest.sta(), kcosbas)   # presumably the spike-triggered average
rf_ml = strf.from_basis(w0[3:], kcosbas)
plt.plot(k1)
plt.plot(rf_sta)
plt.plot(rf_ml)

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos
import emcee

# assimilation parameters
# A single thread is used on macOS ('darwin'), eight elsewhere.
nthreads = 1 if sys.platform == 'darwin' else 8
nwalkers = 500
nsteps = 500

# Independent uniform priors on the three fitted MAT parameters, in the
# order they appear at the start of theta: (ω, α1, α2).
mat_prior = priors.joint_independent([
    priors.uniform(0, 20),      # ω
    priors.uniform(-50, 200),   # α1
    priors.uniform(-5, 10),     # α2
])

# lasso prior on RF parameters: weight of the L1 penalty applied in lnprior
rf_lambda = 1.0

def matbounds(t1, t2, tr):
    """Build a feasibility check on the MAT adaptation amplitudes.

    Given the time constants t1, t2 and the refractory period tr, returns a
    predicate over mparams = (ω, α1, α2). The predicate is truthy only when
    α2 > aa1 * α1 and α2 > aa2 * α1. aa1 and aa2 are the slopes computed
    below; both are negative for positive t1, t2 and tr.
    """
    aa1 = -(1 - np.exp(-tr/t2))/(1 - np.exp(-tr/t1))
    aa2 = -(np.exp(tr/t2) - 1)/(np.exp(tr/t1) - 1)

    def within_bounds(mparams):
        alpha2 = mparams[2]
        alpha1 = mparams[1]
        return (alpha2 > aa1 * alpha1) and (alpha2 > aa2 * alpha1)

    return within_bounds

# Amplitude bound check built from the fixed true values matparams[3:6]
# (τ1, τ2, tref); used by lnprior to reject infeasible (α1, α2).
matboundprior = matbounds(*matparams[3:6])

def lnprior(theta):
    """Log prior for theta = [ω, α1, α2, RF basis coefficients...].

    Returns -inf when matboundprior rejects the MAT parameters or when the
    total prior is not finite. Otherwise returns mat_prior of the first three
    entries minus rf_lambda times the L1 norm of the RF coefficients (the
    lasso penalty).
    """
    mparams, rfparams = theta[:3], theta[3:]
    if not matboundprior(mparams):
        return -np.inf
    total = mat_prior(mparams) - rf_lambda * np.sum(np.abs(rfparams))
    return total if np.isfinite(total) else -np.inf


def loglike_poisson(V, H, spike_t, alpha, dt):
    """Log-likelihood of spike bins under an exponential-link intensity.

    With mu = V - H @ alpha and intensity exp(mu), returns
    sum(mu[spike_t]) - dt * sum(exp(mu)), up to terms independent of mu.
    """
    drive = V - np.dot(H, alpha)
    spike_term = np.sum(drive[spike_t])
    rate_integral = np.sum(np.exp(drive)) * dt
    return spike_term - rate_integral


def loglike_sigmoid(V, H, spike_t, alpha, dt):
    """Log-likelihood of spike bins under a logistic-link intensity.

    With mu = V - H @ alpha and lambda = 1 / (1 + exp(-mu)), returns
    sum(log lambda[spike_t]) - dt * sum(lambda).

    log(lambda) is computed as -logaddexp(0, -mu), which stays finite for
    large negative mu. The algebraically equal (1 + tanh(mu/2)) / 2 loses
    precision there and rounds to exactly 0 (around mu < -38), so log(0)
    turned the whole likelihood into -inf.
    """
    mu = V - np.dot(H, alpha)
    log_lmb = -np.logaddexp(0, -mu)
    lmb = np.exp(log_lmb)
    return log_lmb[spike_t].sum() - lmb.sum() * dt


def loglike_softplus(V, H, spike_t, alpha, dt):
    """Log-likelihood of spike bins under a softplus-link intensity.

    With mu = V - H @ alpha and lambda = log(1 + exp(mu)), returns
    sum(log lambda[spike_t]) - dt * sum(lambda).

    lambda is computed as logaddexp(0, mu), which is exact for large mu.
    The former log1p(exp(mu)) overflowed to inf for mu > ~709, and the
    result became inf - inf = nan.
    """
    mu = V - np.dot(H, alpha)
    lmb = np.logaddexp(0, mu)
    return np.log(lmb[spike_t]).sum() - lmb.sum() * dt
    
    
def lnlike(theta):
    """Softplus-link log-likelihood of theta, summed over assim_data trials.

    theta[0] (ω) is subtracted from mlest.V_interp(theta), and theta[1:3]
    (α1, α2) weight each trial's H in loglike_softplus.
    """
    omega, alphas = theta[0], theta[1:3]
    drive = mlest.V_interp(theta).squeeze() - omega
    return sum(loglike_softplus(drive, d["H"], d["spike_t"], alphas, model_dt)
               for d in assim_data)

def lnpost_dyn(theta):
    """Log posterior (up to a constant) for the dynamical parameters theta.

    Returns -inf immediately when lnprior(theta) is not finite. The
    likelihood is then never evaluated outside the prior's support, and
    -inf minus an infinite likelihood term cannot produce nan.
    """
    lp = lnprior(theta)
    if not np.isfinite(lp):
        return -np.inf
    # mlest.loglike is subtracted, so it presumably returns the *negative*
    # log-likelihood (a quantity to minimize) — TODO confirm in dstrf.mle.
    # lnlike(theta) is an alternative pure-Python likelihood that can be
    # added here instead.
    return lp - mlest.loglike(theta)
    

In [None]:
# Reference value: log posterior at the true parameters (the fitted MAT
# parameters plus the true kernel's basis coefficients k1c), for comparison
# with what the sampler reaches.
theta_true = np.concatenate([matparams_f, k1c])
print("lnpost of p_true: {}".format(lnpost_dyn(theta_true)))
# initial state is a gaussian ball around the ML estimate
# (third argument presumably the per-coordinate spread, here 2 * |w0|;
# NOTE(review): any exactly-zero entry of w0 would give every walker the same
# value in that coordinate — check w0 before sampling)
p0 = startpos.normal_independent(nwalkers, w0, np.abs(w0) * 2)
theta_0 = np.median(p0, 0)
print("lnpost of p0 median: {}".format(lnpost_dyn(theta_0)))
%timeit lnpost_dyn(theta_true)

In [None]:
# Run the emcee ensemble sampler from p0, using the emcee 2.x API (`threads=`,
# `storechain=`). The chain is not stored, so the loop simply drains the
# generator; afterwards `pos` and `prob` hold the final walker positions and
# their log posteriors.
# NOTE(review): utils.convergence_tracker presumably reports progress every
# 25 steps — confirm in neurofit.utils.
sampler = emcee.EnsembleSampler(nwalkers, theta_true.size, lnpost_dyn,
                                threads=nthreads)
tracker = utils.convergence_tracker(nsteps, 25)

for pos, prob, _ in tracker(sampler.sample(p0, iterations=nsteps,
                                           storechain=False)):
    pass

In [None]:
# Final ensemble state: median log posterior, mean acceptance fraction, and
# the per-coordinate median across walkers, split into MAT and RF parts.
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
theta = np.median(pos, 0)
mparams = theta[:matparams_n]
rfparams = theta[matparams_n:]
# true kernel vs. the RF rebuilt from the median basis coefficients
plt.plot(k1)
plt.plot(strf.from_basis(rfparams, kcosbas))
# true, ML-initial and posterior-median MAT parameters, for comparison
# NOTE(review): w0 is sliced with a literal 3 while theta uses matparams_n;
# the two agree only while matparams_n == 3.
print(matparams_f)
print(w0[:3])
print(theta[:matparams_n])

In [None]:
from corner import corner
sns.set_style("whitegrid")

# Corner plot of the final walker positions for the fitted MAT parameters.
# Labels are looked up through matparams_i in the full parameter order
# (ω, α1, α2, τ1, τ2, tref), so they always match the columns of mpos and
# the truths. The previous hard-coded ['a1', 'a2', 'w'] mislabelled every column.
mat_param_labels = ['w', 'a1', 'a2', 't1', 't2', 'tref']
mpos = pos[:, :matparams_n]
matlabs = [mat_param_labels[i] for i in matparams_i]
c = corner(mpos,
           bins=50, smooth=2, smooth1d=0,
           labels=matlabs,
           truths=matparams_f)

In [None]:
# see how well predictions line up
d = assim_data[0]
Vpred = mlest.V(theta)

plt.plot(V[:400])
plt.plot(Vpred[:400])

In [None]:
# posterior predictive distribution
for j, d in enumerate(test_data):
    plt.vlines(d["spike_t"], j, j + 0.5, 'r')

mparamp = matparams.copy()
samples = np.random.permutation(nwalkers)[:n_test]
for i, idx in enumerate(samples):
    sample = pos[idx]
    V = mlest.V(sample)
    mparamp[matparams_i] = sample[:matparams_n]
    S = predict_spikes(V, mparamp, model_dt, upsample)
    spk_t = S.nonzero()[0]
    plt.vlines(spk_t, i + j + 1, i + j + 1.5)

plt.xlim(0, 10000)