## GLMAT: 2D kernel, song stimuli, ML+MC estimation

This notebook demonstrates the full assimilation technique using song stimuli. The song waveform is processed into a 2D spectrogram, which is then convolved with a 2D spectrotemporal receptive field (STRF) to produce the "voltage" of the GLMAT model. The adaptation "current" is calculated by convolving the spike trains with two exponential kernels. The goal of the assimilation is to estimate the parameters of the RF, the spike threshold, and the amplitudes of the adaptation kernels (the time constants of the kernels are held fixed). The parameter count of the RF is reduced by using a low-rank approximation (i.e., an outer product of two vectors) and by projecting time into a basis set of raised cosine filters that are spaced exponentially.
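To make the parameter reduction concrete, here is a minimal sketch of a raised-cosine basis with bumps spaced evenly in log time, along with the resulting parameter counts. The function `raised_cosine_basis` and its `offset` argument are hypothetical and written only for illustration; this is not the implementation behind `strf.cosbasis`, which may differ in spacing and normalization.

```python
import numpy as np

def raised_cosine_basis(ntau, nbasis, offset=1.0):
    """Hypothetical helper: nbasis raised-cosine bumps spaced evenly in log time."""
    t = np.arange(ntau)
    logt = np.log(t + offset)                         # warp time logarithmically
    centers = np.linspace(logt[0], logt[-1], nbasis)  # bump centers in warped time
    width = centers[1] - centers[0]
    # clip so that each bump covers exactly one period of the cosine
    arg = np.clip((logt[:, None] - centers[None, :]) * np.pi / (2 * width),
                  -np.pi, np.pi)
    return 0.5 * (1 + np.cos(arg))                    # shape (ntau, nbasis)

nfreq, ntau, ntbas, krank = 20, 40, 10, 1
basis = raised_cosine_basis(ntau, ntbas)

n_full = nfreq * ntau                  # full spectrotemporal RF: 800 parameters
n_basis = nfreq * ntbas                # time projected onto the basis: 200
n_lowrank = krank * (nfreq + ntbas)    # rank-1 factorization in the basis: 30
print(basis.shape, n_full, n_basis, n_lowrank)
```

Bumps at short lags are narrow and bumps at long lags are wide, so the basis spends most of its resolution near the spike.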

The approach is to use elastic-net penalized maximum-likelihood estimation to get a first guess at the parameters. This model uses the time-compression basis set but does not factorize the RF. The next step involves factorizing the RF to further reduce the parameters and then using MCMC to sample the posterior distribution of the parameters.

In [None]:
from __future__ import print_function, division
import os
import sys
import imp
import numpy as np

import mat_neuron._model as mat
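# NOTE: `spikes` is used below (spikes.psth) but was not imported; it is assumed to be a dstrf submodule
from dstrf import strf, mle, io, performance, spikes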

# plotting packages
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")
mpl.rcParams['image.origin'] = 'lower'
mpl.rcParams['image.aspect'] = 'auto'
mpl.rcParams['image.cmap'] = 'jet'

The MAT model is governed by a small number of parameters: the spike threshold (omega), the amplitudes of the adaptation kernels (alpha_1, alpha_2), the time constants of the adaptation kernels (tau_1, tau_2), and the absolute refractory period. In addition, a function must be chosen for spike generation. The 'softplus' function, log(1 + exp(mu)), is a good choice because it grows only linearly when mu is large and so doesn't saturate as readily as an exponential. Because there can only be one spike per bin, saturation causes the estimated parameters to be biased below the true values. Note that the cells in this notebook set the nonlinearity to "exp".
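The difference between the two nonlinearities can be seen in a small numerical sketch. The conversion from rate to per-bin spike probability used here (one or more events from a Poisson process in a bin of width `dt`) is an assumption made for illustration and may not match how the GLMAT code generates spikes.

```python
import numpy as np

def softplus(mu):
    """log(1 + exp(mu)), computed stably via logaddexp."""
    return np.logaddexp(0.0, mu)

mu = np.array([-5.0, 0.0, 5.0, 50.0])
rate_exp = np.exp(mu)        # grows exponentially with mu
rate_soft = softplus(mu)     # approaches mu itself for large mu

# assumed link: probability of at least one spike in a bin of width dt
dt = 0.5
p_exp = 1 - np.exp(-rate_exp * dt)
p_soft = 1 - np.exp(-rate_soft * dt)

for m, pe, ps in zip(mu, p_exp, p_soft):
    print("mu = {:5.1f}: p(exp) = {:.4f}, p(softplus) = {:.4f}".format(m, pe, ps))
```

With the exponential, the per-bin probability is already indistinguishable from 1 at moderate values of mu, whereas the softplus approaches 1 more gradually.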

In [None]:
# model parameters: (ω, α1, α2, τ1, τ2, tref)
matparams = np.asarray([7, 100, 2, 10, 200, 2], dtype='d')
model_dt = 0.5
nlin = "exp"

matparams_i = [0,1,2]
matparams_n = len(matparams_i)
matparams_f = matparams[matparams_i]

In [None]:
# STRF: keep this very simple for proof of principle
stim_dt = 3.0
upsample = int(stim_dt / model_dt)
spec_window = 4.0
spec_compress = 10
f_min = 1.0
f_max = 8.0
nfreq = 20
ntau  = 40
# these parameters influence how the STRF dimensional reduction will work
ntbas = 10       # number of temporal basis functions
krank = 1        # rank of bilinear approximation
kcosbas = strf.cosbasis(ntau, ntbas)
ntbas = kcosbas.shape[1]

Here we load some data from a real neural recording from the CRCNS dataset. We're going to replace the actual neural response with a simulation based on the dstrf model. In the original experiment, stimuli were presented individually in a pseudorandom order. To simplify the model, we concatenate the stimuli, setting padding between the stimuli sufficient to capture any offset responses. Note that the spike responses are convolved with the adaptation kernels before merging stimuli so that we don't inadvertently carry over spike history from trials that are not truly contiguous.
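The reason for filtering each trial before merging can be illustrated with a small sketch. The helper `adaptation_traces` is hypothetical and is not the code inside `io.preprocess_spikes`; it assumes causal exponential kernels in which a spike in one bin first affects the following bin.

```python
import numpy as np

def adaptation_traces(spikes, taus, dt):
    """Hypothetical helper: filter a 0/1 spike train with causal exponential kernels.

    Returns an array of shape (nbins, len(taus)).
    """
    taus = np.asarray(taus, dtype=float)
    decay = np.exp(-dt / taus)                 # per-bin decay factor for each kernel
    H = np.zeros((spikes.shape[0], taus.size))
    for i in range(1, spikes.shape[0]):
        H[i] = decay * (H[i - 1] + spikes[i - 1])
    return H

# two trials that were NOT contiguous in the experiment
trial_a = np.zeros(10)
trial_a[8] = 1                                 # spike near the end of trial a
trial_b = np.zeros(10)

# filter each trial on its own, then concatenate
H_a = adaptation_traces(trial_a, [10.0, 200.0], dt=0.5)
H_b = adaptation_traces(trial_b, [10.0, 200.0], dt=0.5)
H_merged = np.concatenate([H_a, H_b])
```

Because `trial_b` is filtered separately, its adaptation trace starts at zero; filtering the concatenated spike train instead would let the late spike in `trial_a` leak into the start of `trial_b`.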

In [None]:
# data parameters
pad_before = 0    # how much to pad stimulus before onset
pad_after = ntau * stim_dt # how much to pad after offset
n_trials = 10     # set the number of trials to use (needs to be the same for all stimuli)
p_test = 0.2      # proportion of trials to use for test

# load stimuli and responses
root = os.path.join(os.environ["HOME"], "data", "crcns")
cell = "oo1920_6_B"
stim_type = "conspecific"
data = io.load_crcns(cell, stim_type, root, spec_window, stim_dt, f_min=f_min, f_max=f_max, f_count=nfreq, 
                     compress=spec_compress, gammatone=True)
io.pad_stimuli(data, pad_before, pad_after, fill_value=0.0)
io.preprocess_spikes(data, model_dt, matparams[3:5])

n_test = int(p_test * len(data))
# split into assimilation and test sets and merge stimuli
assim_data = io.merge_data(data[:-n_test])
test_data = io.merge_data(data[-n_test:])

In [None]:
stim = assim_data["stim"]
plt.subplot(311).imshow(assim_data["stim"], extent=(0, assim_data["duration"], f_min, f_max))
t_spike = np.linspace(0, assim_data["duration"], assim_data["spike_v"].shape[0])
plt.subplot(312).plot(t_spike, assim_data["spike_v"].sum(1))
ax = plt.subplot(313)
for i, spk in enumerate(assim_data["spike_t"]):
    ax.vlines(spk * model_dt, i, i + 0.5)
    ax.plot(t_spike, assim_data["spike_h"][:, :, i])
for ax in plt.gcf().axes:
    ax.set_xlim(1750, 2200)

## Estimate parameters

The reg_alpha and reg_lambda parameters set the L1 and L2 penalties, respectively, for the initial ML estimation. Note that we supply the nonlinearity function to the constructor too, as this determines how the log-likelihood is calculated.
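As a rough sketch of what an elastic-net penalty looks like, the function below adds an L1 term weighted by `reg_alpha` and an L2 term weighted by `reg_lambda` to a negative log-likelihood. The scaling conventions (for example, whether the L2 term carries a factor of 1/2) and the names `elastic_net_penalty` and `penalized_objective` are assumptions; the objective used by `mle.estimator` may be defined differently.

```python
import numpy as np

def elastic_net_penalty(w, reg_alpha, reg_lambda):
    """L1 penalty (reg_alpha) plus L2 penalty (reg_lambda); scaling is assumed."""
    w = np.asarray(w, dtype=float)
    return reg_alpha * np.sum(np.abs(w)) + reg_lambda * np.sum(w ** 2)

def penalized_objective(w, negloglike, reg_alpha, reg_lambda):
    """Quantity to minimize: negative log-likelihood plus the penalty."""
    return negloglike(w) + elastic_net_penalty(w, reg_alpha, reg_lambda)

# toy quadratic stand-in for a negative log-likelihood, for illustration only
toy_nll = lambda w: 0.5 * np.sum((np.asarray(w) - 1.0) ** 2)

w = np.array([1.0, -2.0, 0.0])
pen = elastic_net_penalty(w, reg_alpha=1e1, reg_lambda=1e1)
obj = penalized_objective(w, toy_nll, reg_alpha=1e1, reg_lambda=1e1)
print(pen, obj)
```

The L1 term pushes small RF coefficients to exactly zero, while the L2 term shrinks large ones.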

In [None]:
# initial guess of parameters using penalized ML. Note that we provide the cosine basis set to the constructor of
# mle.estimator, which causes the design matrix to be in the cosine basis set
mlest = mle.estimator(assim_data["stim"], kcosbas, assim_data["spike_v"], assim_data["spike_h"],
                      assim_data["stim_dt"], assim_data["spike_dt"], nlin=nlin)
# also construct an estimator for the test data so we can score results
mltest = mle.estimator(test_data["stim"], kcosbas, test_data["spike_v"], test_data["spike_h"],
                       test_data["stim_dt"], test_data["spike_dt"], nlin=nlin)
%time w0 = mlest.estimate(reg_lambda=1e1, reg_alpha=1e1)
# TODO: cross-validation to find best regularization parameters

In [None]:
print(w0[:3])
# the RF is in the cosine basis set, which we need to convert back in order to visualize
rf_sta = strf.as_matrix(mlest.sta(), kcosbas)
rf_mle = strf.as_matrix(w0[3:], kcosbas)
k0f, k0t = strf.factorize(strf.as_matrix(w0[3:], ntbas), krank)
rf_mlb = strf.from_basis(np.dot(k0f, k0t), kcosbas)
plt.subplot(221).imshow(rf_sta, extent=(0, ntau, f_min, f_max))
plt.title("STA")
plt.subplot(222).imshow(rf_mle, extent=(0, ntau, f_min, f_max))
plt.title("MAP estimate")
plt.subplot(224).imshow(rf_mlb, extent=(0, ntau, f_min, f_max))
plt.title("MAP (rank {})".format(krank))

In [None]:
t_stim = np.linspace(0, assim_data["duration"], assim_data["stim"].shape[1])
plt.subplot(311).imshow(assim_data["stim"], extent=(0, assim_data["duration"], f_min, f_max))
V = mlest.V(w0)
plt.subplot(312).plot(t_stim, V)
ax = plt.subplot(313)
for j, spk in enumerate(assim_data["spike_t"]):
    ax.vlines(spk * model_dt, j, j + 0.5, 'r')
pred = np.zeros_like(assim_data["spike_v"])
for i in range(pred.shape[1]):  # simulate one trial per column so no column of pred is left empty
    pred[:, i] = mlest.predict(w0, matparams[3:], V)
    spk_t = pred[:, i].nonzero()[0]
    ax.vlines(spk_t * model_dt, i + j + 1, i + j + 1.5)
pred_psth = spikes.psth(pred, upsample, 1)
test_psth = spikes.psth(assim_data["spike_v"], upsample, 1)
ax.plot(t_stim, test_psth, t_stim, pred_psth)
for ax in plt.gcf().axes:
    ax.set_xlim(0, 4000)

print("loglike: {}".format(-mlest.loglike(w0)))
print("CC: {}".format(np.corrcoef(test_psth, pred_psth)[0, 1]))
print("spike count: data = {}, pred = {}".format(assim_data["spike_v"].sum(), pred.sum()))

In [None]:
imp.reload(performance)
t_stim = np.linspace(0, test_data["duration"], test_data["stim"].shape[1])
plt.subplot(311).imshow(test_data["stim"], extent=(0, test_data["duration"], f_min, f_max))
V = mltest.V(w0)
plt.subplot(312).plot(t_stim, V)
ax = plt.subplot(313)
for j, spk in enumerate(test_data["spike_t"]):
    ax.vlines(spk * model_dt, j, j + 0.5, 'r')
pred = np.zeros_like(test_data["spike_v"])
for i in range(pred.shape[1]):  # simulate one trial per column so no column of pred is left empty
    pred[:, i] = mltest.predict(w0, matparams[3:], V)
    spk_t = pred[:, i].nonzero()[0]
    ax.vlines(spk_t * model_dt, i + j + 1, i + j + 1.5)
pred_psth = spikes.psth(pred, upsample, 1)
test_psth = spikes.psth(test_data["spike_v"], upsample, 1)
ax.plot(t_stim, test_psth, t_stim, pred_psth)
for ax in plt.gcf().axes:
    ax.set_xlim(0, 4000)

# even/odd trial correlation; trials are the columns of spike_v, so split along axis 1
eo = performance.corrcoef(test_data["spike_v"][:, ::2], test_data["spike_v"][:, 1::2], upsample, 1)
print("loglike: {}".format(-mltest.loglike(w0)))
print("CC: {}/{}".format(np.corrcoef(test_psth, pred_psth)[0, 1], eo))
print("spike count: data = {}, pred = {}".format(test_data["spike_v"].sum(), pred.sum()))

We'll use the ML estimate to seed the MCMC sampler. We're going to reduce the size of the parameter space by factorizing the RF (i.e., a bilinear approximation). Note that we try to use the mlest object as much as possible to do the calculations rather than reimplement things; however, there can be some significant performance enhancements from an optimized implementation.
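For readers unfamiliar with the bilinear approximation, the sketch below factorizes a matrix into rank-limited spectral and temporal factors with an SVD, flattens the factors into a parameter vector, and rebuilds the matrix from that vector. The function `factorize_svd` and the parameter layout are assumptions for illustration; `strf.factorize` and `strf.defactorize` may use a different algorithm or ordering.

```python
import numpy as np

def factorize_svd(k, rank):
    """Hypothetical helper: best rank-`rank` factors of k via truncated SVD."""
    U, s, Vt = np.linalg.svd(k, full_matrices=False)
    root = np.sqrt(s[:rank])
    kf = U[:, :rank] * root            # spectral factors, shape (nfreq, rank)
    kt = Vt[:rank] * root[:, None]     # temporal factors, shape (rank, ntbas)
    return kf, kt

rng = np.random.default_rng(0)
nfreq, ntbas, rank = 20, 10, 1
# a toy RF (in the temporal basis) that is exactly rank 1
k_true = np.outer(rng.normal(size=nfreq), rng.normal(size=ntbas))

kf, kt = factorize_svd(k_true, rank)
theta = np.r_[kf.ravel(), kt.ravel()]  # flat parameter vector for a sampler

# rebuild the RF from the flat vector (assumed layout: spectral, then temporal)
kf2 = theta[:nfreq * rank].reshape(nfreq, rank)
kt2 = theta[nfreq * rank:].reshape(rank, ntbas)
k_rebuilt = np.dot(kf2, kt2)
print(theta.size, np.allclose(k_rebuilt, k_true))
```

The factorization is not unique: scaling `kf` by a constant and `kt` by its inverse gives the same RF, which is worth keeping in mind when interpreting the sampled factors.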

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos

# the MAT parameters are just bounded between reasonable limits. These may need to be expanded when using real data.
mat_prior = priors.joint_independent(
                [ priors.uniform( 0,  20),
                  priors.uniform(-50,  200),
                  priors.uniform(-5,   10),
                ])

# lasso (L1) prior on RF parameters. This plays the same role as the reg_alpha parameter in the penalized MLE
rf_alpha = 1.
    
def lnpost(theta):
    """Posterior probability for dynamical parameters"""
    mparams = theta[:3]
    rfparams = theta[3:]
    rf_prior = -np.sum(np.abs(rfparams)) * rf_alpha
    ll = mat_prior(mparams) + rf_prior
    if not np.isfinite(ll):
        return -np.inf
    # reassemble strf from low-rank factors
    k = strf.defactorize(rfparams, nfreq, krank).flatten()
    w = np.r_[mparams, k]
    return ll - mlest.loglike(w)


In [None]:
# get low-rank approx of the ML STRF
w0_bl = np.r_[w0[:3], k0f.flatten(), k0t.flatten()]
print("lnpost of ML estimate: {}".format(lnpost(w0_bl)))
#%timeit lnpost(w0_bl)
%time for i in range(1000): lnpost(w0_bl)

This code sets up the MCMC sampler. We initialize the walkers (chains) in a Gaussian around the ML estimate, with a standard deviation of twice the absolute value of the best guess for each parameter. The model converges fairly quickly, but then we let it sample for a while.
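The initialization scheme can be sketched in plain NumPy. The function `init_walkers` is a hypothetical stand-in written for illustration and is not the implementation of `startpos.normal_independent`; the example vector is made up.

```python
import numpy as np

def init_walkers(nwalkers, center, scale, seed=None):
    """Hypothetical helper: draw walker positions from independent Gaussians."""
    rng = np.random.default_rng(seed)
    center = np.asarray(center, dtype=float)
    scale = np.asarray(scale, dtype=float)
    return center + rng.normal(size=(nwalkers, center.size)) * scale

w = np.array([7.0, 100.0, 2.0, -0.5, 0.0])     # made-up starting estimate
p0 = init_walkers(500, w, np.abs(w) * 2, seed=1)
print(p0.shape, p0.std(axis=0))
```

Two consequences of this choice are worth noting. A parameter whose estimate is exactly zero gets zero spread, so every walker starts at the same value for it. And with a spread this wide, some walkers may start outside the bounds of a uniform prior, where the log posterior is `-inf`.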

In [None]:
import emcee
# assimilation parameters
if sys.platform == 'darwin':
    nthreads = 1
else:
    nthreads = 8
nwalkers = 500
nsteps = 3000

# initialize walkers
pos = p0 = startpos.normal_independent(nwalkers, w0_bl, np.abs(w0_bl) * 2)
# initialize the sampler
sampler = emcee.EnsembleSampler(nwalkers, w0_bl.size, lnpost, threads=nthreads)

In [None]:
# start the sampler
tracker = utils.convergence_tracker(nsteps, 100)
for pos, prob, like in tracker(sampler.sample(pos, iterations=nsteps, storechain=True)): 
    continue

In [None]:
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
try:
    print("autocorrelation time: {}".format(sampler.acor))
except Exception:
    # the autocorrelation estimate can fail if the chain is too short
    pass
w1 = np.median(pos, 0)
rfparams = w1[3:]
rf_map = strf.from_basis(strf.defactorize(rfparams, nfreq, krank), kcosbas)
print(w0[:matparams_n])
print(w1[:matparams_n])
plt.subplot(222)
plt.imshow(rf_mle, extent=(0, ntau, f_min, f_max))
plt.subplot(223)
plt.imshow(rf_mlb, extent=(0, ntau, f_min, f_max))
plt.subplot(224)
plt.imshow(rf_map, extent=(0, ntau, f_min, f_max))

In [None]:
from corner import corner
sns.set_style("whitegrid")

mpos = pos[:,:matparams_n]
matlabs = ['w','a1','a2',]
c = corner(mpos,
       bins=50, smooth=2,smooth1d=0,
       labels=matlabs,
       truths=matparams[:3])

In [None]:
# posterior predictive distribution
stim = test_data["stim"]
t_stim = np.linspace(0, test_data["duration"], stim.shape[1])
plt.subplot(311).imshow(stim, extent=(0, test_data["duration"], f_min, f_max))
vax = plt.subplot(312)
spax = plt.subplot(313)
# spike times are stored in units of model_dt under "spike_t", as in the cells above
for j, spk in enumerate(test_data["spike_t"]):
    spax.vlines(spk * model_dt, j, j + 0.5, 'r')
mparams = matparams.copy()
samples = np.random.permutation(nwalkers)[:j]
for i, idx in enumerate(samples):
    rfparams = pos[idx, matparams_n:]
    k = strf.defactorize(rfparams, nfreq, krank)
    rf = strf.from_basis(k, kcosbas)
    V = strf.convolve(stim, rf)
    vax.plot(t_stim, V)
    mparams[matparams_i] = pos[idx, :matparams_n]
    # NOTE: predict_spikes is not defined or imported anywhere in this notebook
    S = predict_spikes(V, mparams, model_dt, upsample)
    spk_t = S.nonzero()[0]
    spax.vlines(spk_t * model_dt, i + j + 1, i + j + 1.5)

for ax in plt.gcf().axes:
    ax.set_xlim(0, 1000)