## GLMAT: 2D kernel (low-rank), ML+MC estimation — song stimuli

In [None]:
from __future__ import print_function, division
import os
import sys
import imp
import numpy as np

import mat_neuron._model as mat
from dstrf import strf, mle, io

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# MAT model parameters, in order: (ω, α1, α2, τ1, τ2, tref)
matparams = np.array([7, 100, 2, 10, 200, 2], dtype=np.float64)
# time step of the spiking model
model_dt = 0.5

# positions in matparams of the parameters that get estimated (ω, α1, α2);
# kept as a list because it is used for fancy indexing into parameter arrays
matparams_i = list(range(3))
matparams_n = len(matparams_i)
# values of the estimated parameters that were used to generate the data
matparams_f = matparams[matparams_i]

In [None]:
# STRF: keep this very simple for proof of principle
from scipy.signal import resample

stim_dt = 3.0
# number of model time steps per stimulus frame
upsample = int(stim_dt / model_dt)
kscale = 8
f_min, f_max = 0.25, 8.0
nfreq, ntau = 20, 40
# requested number of temporal basis functions; replaced below by the number
# of columns actually returned by strf.cosbasis
ntbas = 8

# raised-cosine basis functions
kcosbas = strf.cosbasis(ntau, ntbas)
ntbas = kcosbas.shape[1]

# kernel used to simulate responses: scale the stored 'bbs' filter, resample it
# to nfreq points along axis 0, then keep the first ntau columns in reverse order
filts = np.load('../../filters.npz')
print(filts.keys())
scaled_filter = filts['bbs'] * kscale
k1 = resample(scaled_filter, nfreq, axis=0)[:, ntau - 1::-1]

sns.heatmap(k1)

In [None]:
# data parameters
n_test = 5      # number of entries of `data` held out for testing
n_trials = 3    # number of simulated trials per stimulus set

# song stimulus:
root = os.path.join(os.environ["HOME"], "data", "crcns")
cell = "yg0616_4_B"
stim_type = "conspecific"
data = io.load_crcns(cell, stim_type, root, 4.0, stim_dt, f_min=f_min, f_max=f_max, f_count=nfreq, compress=1, gammatone=True)
# split into assimilation and test sets and merge stimuli.
# The assimilation set is everything except the last n_test entries, so the
# test set must be exactly those last n_test entries (data[-n_test:]).
# Slicing with data[n_test:] would instead drop the first n_test entries and
# share all the remaining ones with the assimilation set.
assim_data = io.merge_data(data[:-n_test], pad_before=1000, pad_after=1000, dt=stim_dt)
test_data = io.merge_data(data[-n_test:], pad_before=1000, pad_after=1000, dt=stim_dt)
# top panel: merged assimilation stimulus; bottom panel: recorded spike rasters
plt.subplot(211).imshow(assim_data['stim'], cmap='jet', aspect='auto', extent=(0, assim_data["duration"], f_min, f_max))
ax = plt.subplot(212)
for i, d in enumerate(assim_data["spikes"]):
    ax.vlines(d, i, i + 0.5)
ax.set_xlim(0, assim_data["duration"])

In [None]:
# fix the seed passed to mat_neuron's random_seed (presumably so the simulated
# spike trains below are reproducible — confirm against mat_neuron)
mat.random_seed(1)

def predict_spikes(V, params, dt, upsample):
    """Simulate a spike response to V with mat.predict_poisson.

    params must unpack into exactly six values, in the order
    (omega, a1, a2, t1, t2, tref). omega is subtracted from V; (a1, a2),
    (t1, t2) and tref are forwarded together with dt and upsample. The value
    returned by mat.predict_poisson is passed back unchanged.
    """
    threshold, alpha1, alpha2, tau1, tau2, refractory = params
    drive = V - threshold
    alphas = (alpha1, alpha2)
    taus = (tau1, tau2)
    return mat.predict_poisson(drive, alphas, taus, refractory, dt, upsample)

stim = assim_data["stim"]
# drive for the simulation: stimulus convolved with the generating kernel
V = strf.convolve(stim, k1)
# one dict per simulated trial: the spike response ("spike_v"), the indices of
# its nonzero bins ("spike_t"), and the adaptation computed from it ("H")
assim_resp = []
for trial in range(n_trials):
    spike_train = predict_spikes(V, matparams, model_dt, upsample)
    adaptation = mat.adaptation(spike_train, matparams[3:5], model_dt)
    spike_bins = np.nonzero(spike_train)[0]
    assim_resp.append({
        "H": adaptation,
        "spike_t": spike_bins,
        "spike_v": spike_train,
    })


In [None]:
# overview of the simulated data: stimulus, voltage, and spike rasters
duration = assim_data["duration"]
plt.subplot(311).imshow(stim, cmap='jet', aspect='auto', extent=(0, duration, f_min, f_max))
volt_times = np.linspace(0, duration, V.size)
plt.subplot(312).plot(volt_times, V)
raster_ax = plt.subplot(313)
for row, resp in enumerate(assim_resp):
    # spike_t holds bin indices; multiplying by model_dt converts them to time
    raster_ax.vlines(resp["spike_t"] * model_dt, row, row + 0.5)
# restrict every panel to the same x-range, 0 to 4000
for panel in plt.gcf().axes:
    panel.set_xlim(0, 4000)

## Estimate parameters

In [None]:
# initial guess of parameters using ML
spikes = np.stack([d["spike_v"] for d in assim_resp], axis=1)
mlest = mle.estimator(stim, spikes, kcosbas, matparams[3:5], stim_dt, model_dt)
%time w0 = mlest.estimate(maxiter=500)

In [None]:
# first three entries of the ML estimate, then a 2x2 comparison of kernels:
# generating kernel, its projection onto the basis, the STA, and the ML estimate
print(w0[:3])
k1c = strf.to_basis(k1, kcosbas)
k1_projected = strf.from_basis(k1c, kcosbas)
rf_sta = strf.as_matrix(mlest.sta(), kcosbas)
rf_mle = strf.as_matrix(w0[3:], kcosbas)
img_opts = dict(cmap='jet', aspect='auto')
plt.subplot(221).imshow(k1, **img_opts)
plt.subplot(222).imshow(k1_projected, **img_opts)
plt.subplot(223).imshow(rf_sta, **img_opts)
plt.subplot(224).imshow(rf_mle, **img_opts)

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos
import emcee

# assimilation parameters: a single thread on macOS, eight elsewhere
nthreads = 1 if sys.platform == 'darwin' else 8
nwalkers = 1000
nsteps = 200

# independent uniform priors, one (lower, upper) pair per estimated MAT parameter
mat_bounds = [(0, 20), (-50, 200), (-5, 10)]
mat_prior = priors.joint_independent(
    [priors.uniform(lower, upper) for lower, upper in mat_bounds])

# lasso prior on RF parameters
rf_lambda = 1.0
# matrix taken from the ML estimator; lnpost_dyn multiplies it with the
# flattened kernel to obtain the voltage
X_stim = mlest._X_stim.get_value()

def lnpost_dyn(theta):
    """Log posterior of the MAT parameters and rank-1 RF factors in theta.

    theta[:3] are the MAT parameters scored by mat_prior; theta[3:] holds the
    first nfreq entries as one kernel factor and the remainder as the other.
    The result is the prior (mat_prior plus an L1 penalty weighted by
    rf_lambda) plus the Poisson log likelihood summed over every trial in
    assim_resp. Returns -inf without evaluating the likelihood when the prior
    is not finite.
    """
    mat_part, rf_part = theta[:3], theta[3:]
    l1_penalty = -np.sum(np.abs(rf_part)) * rf_lambda
    log_prior = mat_prior(mat_part) + l1_penalty
    if not np.isfinite(log_prior):
        return -np.inf
    # reassemble strf as the outer product of its two factors
    freq_factor = rf_part[:nfreq]
    time_factor = rf_part[nfreq:]
    kernel = np.outer(freq_factor, time_factor).flatten()
    V = np.dot(X_stim, kernel)
    # the log_likelihood method in mat-neuron will abort if the likelihood
    # blows up, so it's a bit faster at converging.
    log_like = sum(
        mat.log_likelihood_poisson(V - mat_part[0], resp["H"], resp["spike_v"],
                                   mat_part[1:3], model_dt, upsample)
        for resp in assim_resp
    )
    return log_prior + log_like

In [None]:
# get low-rank approx of the ML STRF
k0f, k0t = strf.factorize(strf.as_matrix(w0[3:], ntbas))
rf_mlb = strf.from_basis(np.dot(k0f, k0t), kcosbas)
w0_bl = np.r_[w0[:3], k0f.squeeze(), k0t.squeeze()]
print("lnpost of ML estimate: {}".format(lnpost_dyn(w0_bl)))

# and this is our initial population of walkers
pos = p0 = startpos.normal_independent(nwalkers, w0_bl, np.abs(w0_bl) * 0.2)
theta_0 = np.median(p0, 0)
print("lnpost of p0 median: {}".format(lnpost_dyn(theta_0)))
%timeit lnpost_dyn(theta_0)

In [None]:
sampler = emcee.EnsembleSampler(nwalkers, theta_0.size, lnpost_dyn, threads=nthreads)
tracker = utils.convergence_tracker(nsteps, 1)

# exhaust the sampling iterator; after the loop, pos and prob hold the values
# yielded by the final step
chain_steps = sampler.sample(pos, iterations=nsteps, storechain=True)
for pos, prob, _ in tracker(chain_steps):
    pass

In [None]:
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
# point estimate: per-parameter median over the final walker positions
w1 = np.median(pos, 0)
rfparams = w1[3:]
kf, kt = rfparams[:nfreq], rfparams[nfreq:]
rf_map = strf.from_basis(np.outer(kf, kt), kcosbas)
print(w0[:matparams_n])
print(w1[:matparams_n])
mle_error = strf.subspace(k1, rf_mle)
map_error = strf.subspace(k1, rf_map)
print("mle error: {}; map error: {}".format(mle_error, map_error))
# 2x2 comparison: generating kernel, ML estimate, low-rank ML, sampled estimate
img_opts = dict(cmap='jet', aspect='auto')
plt.subplot(221).imshow(k1, **img_opts)
plt.subplot(222).imshow(rf_mle, **img_opts)
plt.subplot(223).imshow(rf_mlb, **img_opts)
plt.subplot(224).imshow(rf_map, **img_opts)

In [None]:
from corner import corner
sns.set_style("whitegrid")

# pairwise marginals of the MAT parameters over the final walker positions
mpos = pos[:,:matparams_n]
matlabs = ['w','a1','a2',]
# Mark the values that generated the data. matparams_f holds exactly the
# matparams_n generating values for the columns of mpos; the name used here
# before (theta_true) is never assigned in the visible notebook and raised a
# NameError.
c = corner(mpos,
       bins=50, smooth=2,smooth1d=0,
       labels=matlabs,
       truths=matparams_f)

In [None]:
# see how well predictions line up
# Reference voltage: the stimulus convolved with the generating kernel k1, the
# same construction used to simulate the responses and to compute V_map below.
# This replaces mlest.V(theta_true), which failed because theta_true is never
# assigned in the visible notebook.
# NOTE(review): confirm that mlest.V(w0) is on the same sampling grid and scale
# as strf.convolve output, so the three traces are directly comparable.
Vref = strf.convolve(stim, k1)
V_ml = mlest.V(w0)
V_map = strf.convolve(stim, rf_map)

plt.plot(Vref)
plt.plot(V_ml)
plt.plot(V_map)
plt.xlim(0, 1000)

In [None]:
# posterior predictive distribution
stim = test_data["stim"]

# reference rasters (red): responses of the generating model to the test stimulus.
# assim_resp is deliberately left untouched here; it still holds the training
# responses that lnpost_dyn reads.
V = strf.convolve(stim, k1)
for j in range(n_trials):
    spikes = predict_spikes(V, matparams, model_dt, upsample)
    z = np.nonzero(spikes)[0]
    plt.vlines(z, j, j + 0.5, 'r')

# predicted rasters: the kernel is fixed at rf_map and only the estimated MAT
# parameters vary, each set taken from a randomly chosen walker
ndraw = 10
mparams = matparams.copy()
samples = np.random.permutation(nwalkers)[:ndraw]
V = strf.convolve(stim, rf_map)
for i, idx in enumerate(samples):
    mparams[matparams_i] = pos[idx, :matparams_n]
    S = predict_spikes(V, mparams, model_dt, upsample)
    spk_t = S.nonzero()[0]
    # stack predicted rows directly above the n_trials reference rows, without
    # depending on the loop variable left over from the reference loop
    row = n_trials + i
    plt.vlines(spk_t, row, row + 0.5)

plt.xlim(0, 6000)