## GLMAT: 2D kernel, song stimuli, biophysical model, ML+MC estimation

This notebook tests the full estimation method on data generated by a biophysical dSTRF model. The goal is to estimate the parameters of the RF and the adaptation kernels. The parameter count of the RF is reduced by using a low-rank approximation (i.e., a sum of outer products of frequency and time vectors) and by projecting time into a basis set of raised cosine filters that are spaced exponentially.
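To make the parameter reduction concrete, here is a minimal sketch of a log-spaced raised-cosine basis and a rank-2 RF built from it. The function `raised_cosine_basis` is hypothetical and is not the implementation of `strf.cosbasis`; the exact spacing and widths used by the library may differ. The dimensions match the values set later in this notebook.

```python
import numpy as np

def raised_cosine_basis(ntau, nbasis, offset=1.0):
    """Hypothetical log-spaced raised-cosine basis, shape (ntau, nbasis)."""
    t = np.arange(ntau, dtype=float)
    warped = np.log(t + offset)                      # compress long lags
    centers = np.linspace(warped[0], warped[-1], nbasis)
    width = centers[1] - centers[0]
    # each bump falls to zero two center spacings away from its peak
    arg = np.clip((warped[:, None] - centers[None, :]) * np.pi / (2 * width),
                  -np.pi, np.pi)
    return 0.5 * (np.cos(arg) + 1.0)

nfreq, ntau, ntbas, krank = 20, 30, 10, 2
basis = raised_cosine_basis(ntau, ntbas)

rng = np.random.RandomState(0)
k_freq = rng.randn(nfreq, krank)      # spectral factors (random, for illustration)
k_time = rng.randn(ntbas, krank)      # temporal factors, in basis coordinates

# sum of krank outer products, with the temporal factors expanded to ntau lags
rf = np.dot(k_freq, np.dot(basis, k_time).T)

n_full = nfreq * ntau                 # free parameters in the full RF
n_reduced = krank * (nfreq + ntbas)   # free parameters in the factorized RF
print(rf.shape, n_full, n_reduced)
```

With these dimensions the full RF has 600 free parameters and the factorized form has 60.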

In [None]:
from __future__ import print_function, division
import os
import sys
import imp
import numpy as np

from dstrf import strf, mle, io, spikes, performance

# plotting packages
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package

sns.set_style("whitegrid")
mpl.rcParams['image.origin'] = 'lower'
mpl.rcParams['image.aspect'] = 'auto'
mpl.rcParams['image.cmap'] = 'jet'

The MAT model is governed by a small number of parameters: the spike threshold (omega), the amplitudes of the adaptation kernels (alpha_1, alpha_2), the time constants of the adaptation kernels (tau_1, tau_2), and the absolute refractory period. In addition, a function must be chosen for the probability of spike generation. The 'softplus' function, log(1 + exp(mu)), is a good choice because it doesn't saturate as readily when mu is large. Because there can only be one spike per bin, saturation causes the estimated parameters to be less than the true parameters.
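The sketch below shows how these pieces might fit together. It assumes exponential adaptation kernels and the sign convention mu = V - omega - alpha_1 * h_1 - alpha_2 * h_2 that appears in the last cell of this notebook. The helper names and the parameter values are illustrative only and are not part of dstrf.

```python
import numpy as np

def softplus(mu):
    """log(1 + exp(mu)), computed without overflow for large mu."""
    return np.logaddexp(0.0, mu)

def adaptation_trace(spike_train, tau, dt):
    """Causal convolution of a spike train with exp(-t / tau) (assumed kernel)."""
    h = np.zeros(len(spike_train))
    decay = np.exp(-dt / tau)
    for i in range(1, len(spike_train)):
        # a spike in bin i - 1 first affects bin i, then decays
        h[i] = (h[i - 1] + spike_train[i - 1]) * decay
    return h

dt = 0.5                                   # ms, same as model_dt below
spike_train = np.zeros(200)
spike_train[50] = 1                        # a single spike
h1 = adaptation_trace(spike_train, 10.0, dt)
h2 = adaptation_trace(spike_train, 200.0, dt)

V = np.zeros(200)                          # constant drive, for illustration
omega, alpha_1, alpha_2 = 5.0, 20.0, 2.0   # illustrative values, not fitted
mu = V - omega - alpha_1 * h1 - alpha_2 * h2
rate = softplus(mu)                        # conditional intensity per bin
```

For large mu, softplus(mu) approaches mu, so the intensity keeps growing roughly linearly instead of flattening out.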

In [None]:
# model time step
model_dt = 0.5

# spectrogram parameters
stim_dt = 3.0
upsample = int(stim_dt / model_dt)
spec_window = 2.5
spec_compress = 10
f_min = 0.0
f_max = 8.0

# strf parameters
nfreq = 20
ntau  = 30
# these parameters influence how the STRF dimensional reduction will work
ntbas = 10       # number of temporal basis functions
krank = 2        # rank of bilinear approximation
kcosbas = strf.cosbasis(ntau, ntbas)
ntbas = kcosbas.shape[1]

# time constants for the spike history kernels
htau = np.asarray([10, 200], dtype='d')
# nonlinearity used to calculate likelihood
nlin = "softplus"

Here we load data from the dSTRF experiments. Each RF was convolved with 30 different songs, then the linear output was fed into the biophysical equations.
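As a rough illustration of the linear stage, the following sketch convolves a spectrogram with an RF to produce a one-dimensional drive. The function `strf_response` is hypothetical, and it assumes that lag 0 is stored in the first column of the RF; the files loaded below may use a different orientation, which is why the notebook applies `np.fliplr` to the true RF.

```python
import numpy as np

def strf_response(spec, rf):
    """Linear response of an RF to a spectrogram.

    spec: (nfreq, nt) spectrogram; rf: (nfreq, ntau), lag 0 in column 0 (assumed).
    Returns an array of length nt.
    """
    nfreq, nt = spec.shape
    out = np.zeros(nt)
    for f in range(nfreq):
        # causal convolution along time, summed over frequency channels
        out += np.convolve(spec[f], rf[f], mode="full")[:nt]
    return out

rng = np.random.RandomState(0)
rf_demo = rng.randn(4, 6)
spec_demo = np.zeros((4, 20))
spec_demo[2, 5] = 1.0                  # an impulse in channel 2 at time bin 5
drive = strf_response(spec_demo, rf_demo)
```

An impulse in one frequency channel reproduces that channel's temporal kernel, starting at the time of the impulse.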

In [None]:
# data parameters - the number of trials for assimilation and for testing
pad_before = 0    # how much to pad stimulus before onset
pad_after = ntau * stim_dt # how much to pad after offset
n_trials = 10      # set the number of trials to use (needs to be the same for all stimuli)
p_test = 0.5      # proportion of trials to use for test

# song stimulus:
root = "/scratch/dmeliza/modeldata/SNR4posp-norm/"
cell = "b-tonic-24"
data = io.load_rothman(cell, root, spec_window, stim_dt, f_min=f_min, f_max=f_max, f_count=nfreq, 
                       compress=spec_compress, gammatone=False)

io.pad_stimuli(data, pad_before, pad_after, fill_value=0.0);
io.preprocess_spikes(data, model_dt, htau)

n_test = int(p_test * len(data))
# split into assimilation and test sets and merge stimuli
assim_data = io.merge_data(data[:-n_test])
test_data = io.merge_data(data[-n_test:])

In [None]:
stim = assim_data["stim"]
plt.subplot(311).imshow(assim_data["stim"], extent=(0, assim_data["duration"], f_min, f_max))
t_spike = np.linspace(0, assim_data["duration"], assim_data["spike_v"].shape[0])
plt.subplot(312).plot(t_spike, assim_data["spike_v"].sum(1))
ax = plt.subplot(313)
for i, spk in enumerate(assim_data["spike_t"]):
    ax.vlines(spk * model_dt, i, i + 0.5)
    ax.plot(t_spike, assim_data["spike_h"][:, :, i])
for ax in plt.gcf().axes:
    ax.set_xlim(0,2000)

## Estimate parameters

The reg_alpha and reg_lambda parameters set the L1 and L2 penalties for the initial ML estimation. Note that we supply the nonlinearity function to the constructor too, as this determines how the log-likelihood is calculated.
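A minimal sketch of the penalized objective, following the form of the commented-out `rf_prior` expression in `lnpost` below: the L1 term is scaled by reg_alpha and the L2 term by reg_lambda. Leaving the first three (MAT) parameters unpenalized is an assumption, and `penalized_objective` is a hypothetical helper, not the estimator's internal code.

```python
import numpy as np

def penalized_objective(negloglike, w, reg_alpha, reg_lambda, n_unpenalized=3):
    """Negative log-likelihood plus elastic-net penalty on the RF weights."""
    rf = np.asarray(w, dtype=float)[n_unpenalized:]
    l1 = reg_alpha * np.sum(np.abs(rf))    # promotes sparse RF weights
    l2 = reg_lambda * np.dot(rf, rf)       # shrinks large RF weights
    return negloglike + l1 + l2

w_demo = np.array([1.0, 1.0, 1.0, 2.0, -3.0])
cost = penalized_objective(10.0, w_demo, reg_alpha=0.5, reg_lambda=0.1)
```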

In [None]:
# initial guess of parameters using penalized ML. Note that we provide the cosine basis set to the constructor of
# mle.mat, which causes the design matrix to be in the cosine basis set
mlest = mle.mat(assim_data["stim"], kcosbas, assim_data["spike_v"], assim_data["spike_h"],
                assim_data["stim_dt"], assim_data["spike_dt"], nlin=nlin)
# also construct an estimator for the test object so we can score results
mltest = mle.mat(test_data["stim"], kcosbas, test_data["spike_v"], test_data["spike_h"],
                       test_data["stim_dt"], test_data["spike_dt"], nlin=nlin)

In [None]:
# low-rank estimators
mlest = mle.matfact(assim_data["stim"], kcosbas, krank, assim_data["spike_v"], assim_data["spike_h"],
                assim_data["stim_dt"], assim_data["spike_dt"], nlin=nlin)
mltest = mle.matfact(test_data["stim"], kcosbas, krank, test_data["spike_v"], test_data["spike_h"],
                       test_data["stim_dt"], test_data["spike_dt"], nlin=nlin)

In [None]:
rf_alpha = 1e1
rf_lambda = 1e1
%time w0 = mlest.estimate(reg_lambda=rf_lambda, reg_alpha=rf_alpha)

In [None]:
print(w0[:3])
# the RF is in the cosine basis set, which we need to convert back in order to visualize
rf_sta = strf.as_matrix(mlest.sta(), kcosbas)
#rf_mle = strf.as_matrix(w0[3:], kcosbas)
rf_mle = strf.from_basis(strf.defactorize(w0[3:], nfreq, krank), kcosbas)

rf_true = np.fliplr(io.load_rothman_rf(cell, root))
from scipy.signal import resample
# pad RF out to ntau
# if rf_true.shape[1] < ntau * stim_dt:
#     rf_true = np.column_stack([np.zeros((rf_true.shape[0], int(stim_dt * ntau) - rf_true.shape[1])), rf_true])
rsmpl_rf = np.fliplr(resample(resample(rf_true,20),40,axis=1))
rsmpl_rf = np.vstack((np.zeros((3,40)), rsmpl_rf[:-3,:]))
rsmpl_rf = np.hstack((np.zeros((20,3)), rsmpl_rf[:,:-3]))

plt.subplot(421).imshow(rf_sta)
plt.title("STA")
plt.subplot(422).imshow(rf_mle)
plt.title("ML estimate")
plt.subplot(424).imshow(rf_true)
plt.title("True RF")

In [None]:
import progressbar
from dstrf import crossvalidate

#reg_grid = np.logspace(-1, 5, 50)[::-1]
l1_ratios = [0.1, 0.5, 0.7, 0.9, 0.95]
reg_grid = np.logspace(-1, 5, 20)[::-1]
scores = []
results = []

bar = progressbar.ProgressBar(max_value=len(l1_ratios) * len(reg_grid),
                              widgets=[
                                ' [', progressbar.Timer(), '] ',
                                progressbar.Bar(),
                                ' (', progressbar.ETA(), ') ',
                            ])
for reg, s, w in bar(crossvalidate.elasticnet(mlest, 4, reg_grid, l1_ratios, avextol=1e-5, disp=False)):
    scores.append(s)
    results.append((reg, s, w))
    
best_idx = np.argmax(scores)
best = results[best_idx]
print("best solution at {}: {}".format(best[0], best[1]))
rf_alpha, rf_lambda = best[0]
w0 = best[2]

In [None]:
print("best solution at {}: {}".format(best[0], best[1]))
print(w0[:3])
rf_mle = strf.from_basis(strf.defactorize(w0[3:], nfreq, krank), kcosbas)
plt.subplot(422).imshow(rf_mle[:,-13:])
plt.title("ML estimate")
plt.subplot(424).imshow(rf_true)

In [None]:
mat.random_seed(1)
t_stim = np.linspace(0, test_data["duration"], test_data["stim"].shape[1])
plt.subplot(311).imshow(test_data["stim"], extent=(0, test_data["duration"], f_min, f_max))
V = mltest.V(w0)
plt.subplot(312).plot(t_stim, V)
ax = plt.subplot(313)
for j, spk in enumerate(test_data["spike_t"]):
    ax.vlines(spk * model_dt, j, j + 0.5, 'r')
pred = np.zeros_like(test_data["spike_v"])
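# simulate one predicted trial per column of pred (range(j) would leave the last column empty)
for i in range(pred.shape[1]):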
    pred[:, i] = mltest.predict(w0, matparams[3:], V)
    spk_t = pred[:, i].nonzero()[0]
    ax.vlines(spk_t * model_dt, i + j + 1, i + j + 1.5)
pred_psth = spikes.psth(pred, upsample, 1)
test_psth = spikes.psth(test_data["spike_v"], upsample, 1)
ax.plot(t_stim, test_psth, t_stim, pred_psth)
for ax in plt.gcf().axes:
    ax.set_xlim(0, 4200)

#eo = performance.corrcoef(test_data["spike_v"][::2], test_data["spike_v"][1::2], upsample, 1)
print("loglike: {}".format(-mltest.loglike(w0)))
#print("CC: {}/{}".format(np.corrcoef(test_psth, pred_psth)[0, 1], eo))
print("spike count: data = {}, pred = {}".format(test_data["spike_v"].sum(), pred.sum()))

We'll use the ML estimate to seed the MCMC sampler. We're going to reduce the size of the parameter space by factorizing the RF (i.e., a bilinear approximation). Note that we try to use the mlest object as much as possible to do the calculations rather than reimplement things; however, there can be some significant performance enhancements from an optimized implementation.

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos

# the MAT parameters are just bounded between reasonable limits. These may need to be expanded when using real data.
mat_prior = priors.joint_independent(
                [ priors.uniform( 0,  20),
                  priors.uniform(-50,  200),
                  priors.uniform(-5,   10),
                ])

def lnpost(theta):
    """Log posterior probability of the MAT and RF parameters"""
    mparams = theta[:3]
    rfparams = theta[3:]
    #rf_prior = -np.sum(np.abs(rfparams)) * rf_alpha - np.dot(rfparams, rfparams) * rf_lambda
    ll = mat_prior(mparams) #+ rf_prior
    if not np.isfinite(ll):
        return -np.inf
    w = np.r_[mparams, rfparams]
    ll -= mlest.loglike(w, rf_lambda, rf_alpha)
    return -np.inf if not np.isfinite(ll) else ll

In [None]:
# get low-rank approx of the ML STRF
# k0f, k0t = strf.factorize(strf.as_matrix(w0[3:], ntbas))
# rf_mlb = strf.from_basis(np.dot(k0f, k0t), kcosbas)
# w0_bl = np.r_[w0[:3], k0f.squeeze(), k0t.squeeze()]
print("lnpost of ML estimate: {}".format(lnpost(w0)))
%timeit lnpost(w0)

This code starts the MCMC sampler. We initialize the walkers (chains) in a Gaussian around the ML estimate, with a standard deviation of twice the absolute value of the best guess for each parameter. The model converges fairly quickly, but then we let it sample for a while.
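The initialization described above can be sketched in plain NumPy. This is an assumption about what `startpos.normal_independent` does, not a copy of its implementation; `normal_start` and the example vector are hypothetical.

```python
import numpy as np

def normal_start(nwalkers, center, scale, seed=None):
    """Draw walker positions from independent Gaussians around center."""
    rng = np.random.RandomState(seed)
    center = np.asarray(center, dtype=float)
    scale = np.asarray(scale, dtype=float)
    # one row per walker, one column per parameter
    return center + scale * rng.randn(nwalkers, center.size)

w_ml = np.array([5.0, -20.0, 0.0, 0.3])    # illustrative ML estimate
p_start = normal_start(500, w_ml, 2 * np.abs(w_ml), seed=1)
```

One consequence of scaling the spread by the absolute value of the estimate is that any parameter estimated as exactly zero gets no scatter, so all walkers start at the same value in that coordinate. Walkers that land outside the uniform prior bounds start with a log posterior of negative infinity.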

In [None]:
import emcee
# assimilation parameters
if sys.platform == 'darwin':
    nthreads = 1
else:
    nthreads = 8
nwalkers = 500
nsteps = 500

# initialize walkers
pos = p0 = startpos.normal_independent(nwalkers, w0, np.abs(w0) * 2)
# initialize the sampler
sampler = emcee.EnsembleSampler(nwalkers, w0.size, lnpost, threads=nthreads)

In [None]:
# start the sampler
tracker = utils.convergence_tracker(nsteps, 100)
for pos, prob, like in tracker(sampler.sample(pos, iterations=nsteps, storechain=True)): 
    continue

In [None]:
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
try:
    print("autocorrelation time: {}".format(sampler.acor))
except Exception:
    # the autocorrelation estimate can fail when the chain is too short
    pass
w1 = np.median(pos, 0)
rfparams = w1[3:]
rf_map = strf.from_basis(strf.defactorize(rfparams, nfreq, krank), kcosbas)
# the first three parameters are the MAT parameters (omega, alpha_1, alpha_2)
print(w0[:3])
print(w1[:3])
# print("mle error: {}; map error: {}".format(strf.subspace(k1, rf_mle), strf.subspace(k1, rf_map)))
#plt.subplot(221)
#sns.heatmap(strf.from_basis(k1c, kcosbas), cmap='jet')
plt.subplot(222)
sns.heatmap(rf_mle[:,-13:], cmap='jet')
plt.subplot(224)
sns.heatmap(rf_map[:,-13:], cmap='jet')

In [None]:
from corner import corner
sns.set_style("whitegrid")

mpos = pos[:, :3]  # MAT parameters only (omega, alpha_1, alpha_2)
matlabs = ['w','a1','a2',]
c = corner(mpos,
       bins=50, smooth=2,smooth1d=0,
       labels=matlabs)

In [None]:
mat.random_seed(1)
n_draw = 10
t_stim = np.linspace(0, test_data["duration"], test_data["stim"].shape[1])
plt.subplot(311).imshow(test_data["stim"], extent=(0, test_data["duration"], f_min, f_max))
vax = plt.subplot(312)
ax = plt.subplot(313)
for j, spk in enumerate(test_data["spike_t"]):
    ax.vlines(spk * model_dt, j, j + 0.5, 'r')
    
samples = np.random.permutation(nwalkers)[:n_draw]
pred = np.zeros((test_data["spike_v"].shape[0], n_draw), dtype=test_data["spike_v"].dtype)
for i, idx in enumerate(samples):
    mparams = pos[idx]
    V_mc = mltest.V(mparams)
    vax.plot(t_stim, V_mc)
    pred[:, i] = mltest.predict(mparams, matparams[3:], V_mc)
    spk_t = pred[:, i].nonzero()[0]
    ax.vlines(spk_t * model_dt, i + j + 1, i + j + 1.5)
pred_psth = spikes.psth(pred, upsample, 1)
test_psth = spikes.psth(test_data["spike_v"], upsample, 1)
ax.plot(t_stim, test_psth, t_stim, pred_psth)
for ax in plt.gcf().axes:
    ax.set_xlim(0, 4200)

#eo = performance.corrcoef(test_data["spike_v"][::2], test_data["spike_v"][1::2], upsample, 1)
print("loglike: {}".format(-mltest.loglike(w1)))
#print("CC: {}/{}".format(np.corrcoef(test_psth, pred_psth)[0, 1], eo))
print("spike count: data = {}, pred = {}".format(test_data["spike_v"].sum() / n_trials, pred.sum() / n_draw))

In [None]:
%matplotlib inline
plt.figure(figsize=(4,6))
stim = test_data["stim"]
t_spike = np.linspace(0, test_data["duration"], test_data["spike_v"].shape[0])
#t_stim = np.linspace(0, test_data["duration"], test_data["I"].shape[0])

plt.subplot(511).imshow(test_data["stim"], extent=(0, test_data["duration"], f_min, f_max))

plt.subplot(512).vlines(test_data["spike_t"][5] * model_dt, -0.4, 0.4)
# for i, spk in enumerate(test_data["spike_t"]):
#     ax.vlines(spk * model_dt, i, i + 0.5)
plt.subplot(513).plot(t_stim, V_mc)
H = -test_data["spike_h"][:, :, 5]
plt.subplot(514).plot(t_spike, H)
lci = mltest.lci(w1)
plt.subplot(515).plot(t_spike, np.exp(lci[:, 5]))
for ax in plt.gcf().axes:
    ax.set_xlim(19760,21800)
plt.savefig("glm_fit_example_{}.pdf".format(cell))

In [None]:
# save data for testing in the Pillow code
import scipy.io as sio
V = mlest.V_interp(theta_true).squeeze()
H = assim_resp[0]["H"]
mu = V - H[:, 0] * matparams[1] - H[:, 1] * matparams[2] - matparams[0]
sio.savemat('glmat_2dbi_song_twin.mat', {"stim": (stim - stim.mean(1)[:, np.newaxis]) / stim.std(1)[:, np.newaxis], 
                                         "spikes": assim_resp[0]["spike_v"],
                                         "stim_dt": stim_dt,
                                         "spike_dt": model_dt,
                                         "rf": k1,
                                         "Istim": V,
                                         "Itot": mu})