## GLMAT: 2D kernel, ML estimation

In [None]:
from __future__ import print_function, division
import os
import sys
import numpy as np

import mat_neuron._model as mat
from dstrf import strf, mle

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# MAT model parameters, in order: (omega, alpha1, alpha2, tau1, tau2, tref)
matparams = np.asarray([7, 100, 2, 10, 200, 2], dtype='d')
model_dt = 0.5                          # integration step of the spiking model

# simulated-data settings
duration = 400000
n_samples = int(duration / model_dt)    # number of model time steps
n_assim, n_test = 1, 0                  # trials used for fitting / held out

# indices (into matparams) of the parameters being estimated: omega, alpha1, alpha2
matparams_i = list(range(3))
matparams_n = len(matparams_i)
matparams_f = matparams[matparams_i]    # true values of the fitted parameters

In [None]:
# STRF: keep this very simple for proof of principle
from scipy.signal import resample

stim_dt = 10.0
kscale = 5      # amplitude scaling applied to the loaded filter
nfreq = 28      # number of spectral channels (rows of the kernel / stimulus)
ntau = 30       # number of time lags kept from the loaded filter
ntbas = 8       # number of raised-cosine temporal basis functions

# raised-cosine basis functions
kcosbas = strf.cosbasis(ntau, ntbas)

# np.load on an .npz archive returns an NpzFile that holds the file open;
# use it as a context manager so the handle is released once k1 is built.
with np.load('../../filters.npz') as filts:
    print(filts.keys())
    # resample the filter to nfreq channels along axis 0, keep the first
    # ntau lags and reverse their order along the time axis
    k1 = resample(filts['bbs'] * kscale, nfreq, axis=0)[:, ntau-1::-1]

plt.imshow(k1, cmap='jet', aspect='auto')

In [None]:
# white-noise spectrogram used as the stimulus for the simulation
np.random.seed(1)
mat.random_seed(1)
stim_dt = 10.0
upsample = int(stim_dt / model_dt)                 # model steps per stimulus frame
n_frames = int(n_samples / (stim_dt / model_dt))   # stimulus frames covering n_samples
stim = np.random.randn(nfreq, n_frames)
stim[:, :100] = 0                                  # silence the first 100 frames
plt.imshow(stim, aspect='auto')

In [None]:
# song stimulus:
# cell = "yg0616_4_B"
# stim_type = "conspecific"

# stims, durations, spk_data, spky_data, names = utils.load_crcns(cell, stim_type, nfreq, t_dsample=1, compress=1, names=True)
# plt.imshow(stims[0], aspect='auto')

In [None]:
def filter_stimulus(S, kernel):
    """Filter a spectrogram with a spectrotemporal kernel.

    Each row of S is convolved with the matching row of `kernel`; the
    full convolution is truncated to the length of S, so column j of the
    kernel weights the stimulus j samples in the past. The kernel must be
    passed in this lag order (not flipped). The per-row results are summed
    into a single 1-D float array of length S.shape[1].
    """
    n_chan, n_time = S.shape
    total = np.zeros(n_time)
    for chan, stim_row in enumerate(S):
        causal = np.convolve(stim_row, kernel[chan], mode="full")
        total += causal[:n_time]
    return total

def predict_spikes(V, params, dt, upsample):
    """Simulate a spike train with mat.predict_poisson.

    `params` is the full 6-element MAT vector (omega, a1, a2, t1, t2, tref);
    omega is subtracted from V before it is passed on, and the adaptation
    amplitudes and time constants are grouped into pairs.
    """
    omega, alpha1, alpha2, tau1, tau2, t_ref = params
    drive = V - omega
    amplitudes = (alpha1, alpha2)
    time_constants = (tau1, tau2)
    return mat.predict_poisson(drive, amplitudes, time_constants, t_ref, dt, upsample)

In [None]:
# simulate spike trains from the true model; every trial shares the same drive V
V = filter_stimulus(stim, np.fliplr(k1))
data = []
for i in range(n_assim + n_test):
    spikes = predict_spikes(V, matparams, model_dt, upsample)
    data.append({
        "H": mat.adaptation(spikes, matparams[3:5], model_dt),
        "duration": duration,
        "spike_t": np.nonzero(spikes)[0],
        "spike_v": spikes,
    })

# the first n_assim trials are used for fitting, the rest are held out
assim_data, test_data = data[:n_assim], data[n_assim:]

In [None]:
# top: filtered stimulus V; bottom: spike raster, one row per trial
ax1, ax2 = plt.subplot(211), plt.subplot(212)
ax1.plot(V)
for row, trial in enumerate(data):
    ax2.vlines(trial["spike_t"], row, row + 0.5)
# V has one sample per stimulus frame while spike times are (presumably)
# in model steps, so the top axis limit is divided by upsample
ax1.set_xlim(0, 8000 // upsample)
ax2.set_xlim(0, 8000)
len(data[0]["spike_t"])

## Estimate parameters

In [None]:
# initial guess of parameters using ML
spikes = np.stack([d["spike_v"] for d in data], axis=1)
mlest = mle.estimator(stim, spikes, kcosbas, matparams[3:5], stim_dt, model_dt)
%time w0 = mlest.estimate()

In [None]:
# true kernel, its basis projection, the STA and the ML estimate (2x2 grid)
print(w0[:3])
k1c = strf.to_basis(k1, kcosbas)
rf_sta = strf.as_matrix(mlest.sta(), ntbas)
rf_mle = strf.as_matrix(w0[3:], ntbas)
panels = [k1] + [strf.from_basis(coefs, kcosbas) for coefs in (k1c, rf_sta, rf_mle)]
for slot, img in enumerate(panels, start=1):
    plt.subplot(2, 2, slot).imshow(img, cmap='jet', aspect='auto')

In [None]:
# estimate parameters using emcee
from neurofit import priors, costs, utils, startpos
import emcee

# emcee run settings
nthreads = 8      # threads handed to EnsembleSampler
nwalkers = 2000   # walkers in the ensemble
nsteps = 500      # sampler iterations

# joint prior over the fitted MAT parameters (omega, alpha1, alpha2)
mat_prior = priors.joint_independent([
    priors.uniform(0, 20),
    priors.uniform(-50, 200),
    priors.uniform(-5, 10),
])

# weight of the L1 (lasso) penalty on the RF basis coefficients
rf_lambda = 1.0

def lnpost_dyn(theta):
    """Log posterior of the MAT and RF parameters.

    theta[:3] are the fitted MAT parameters (omega, alpha1, alpha2);
    theta[3:] are the RF coefficients in the raised-cosine basis.
    Returns -inf when the log prior is not finite; otherwise the log prior
    plus the Poisson log likelihood summed over the assimilation trials.
    """
    mparams = theta[:3]
    rfparams = theta[3:]
    # prior on the MAT parameters plus an L1 (lasso) penalty on the RF coefficients
    lprior = mat_prior(mparams) + (-np.sum(np.abs(rfparams)) * rf_lambda)
    if not np.isfinite(lprior):
        return -np.inf
    # Build the full-resolution kernel and filter the stimulus in Python; this
    # works around a crash in the OS X Accelerate framework. On Linux the
    # equivalent `V = mlest.V(theta)` can be used instead.
    k = np.fliplr(strf.from_basis(strf.as_matrix(rfparams, ntbas), kcosbas))
    V = filter_stimulus(stim, k)
    # The likelihood must use the *sampled* adaptation amplitudes mparams[1:3];
    # the true values matparams[1:3] would leave alpha1/alpha2 constrained only
    # by the prior. mat.log_likelihood_poisson aborts if the likelihood blows
    # up, which speeds convergence.
    llik = 0
    for d in assim_data:
        llik += mat.log_likelihood_poisson(V - mparams[0], d["H"], d["spike_v"],
                                           mparams[1:3], model_dt, upsample)
    return lprior + llik

In [None]:
# theoretically this is as good as it can get
theta_true = np.concatenate([matparams[:3], strf.as_vector(k1c)])
print("lnpost of p_true: {}".format(lnpost_dyn(theta_true)))
# and this is our initial population of walkers
pos = p0 = startpos.normal_independent(nwalkers, w0, np.abs(w0) * 0.2)
theta_0 = np.median(p0, 0)
print("lnpost of p0 median: {}".format(lnpost_dyn(theta_0)))
%timeit lnpost_dyn(theta_true)

In [None]:
# Run the ensemble sampler without storing the chain; after the loop, pos and
# prob hold the last walker positions and log-probabilities.
sampler = emcee.EnsembleSampler(nwalkers, theta_true.size, lnpost_dyn, threads=nthreads)
tracker = utils.convergence_tracker(nsteps, 25)
chain_steps = sampler.sample(pos, iterations=nsteps, storechain=False)
for pos, prob, _ in tracker(chain_steps):
    pass

In [None]:
print("lnpost of p median: {}".format(np.median(prob)))
print("average acceptance fraction: {}".format(sampler.acceptance_fraction.mean()))
# point estimate: the median of the final walker positions
w1 = np.median(pos, 0)
mp_map = w1[:matparams_n]
# w1[3:] are basis coefficients; reshape with ntbas columns and project back to
# the full-resolution kernel (same orientation as k1), as done for rf_mle above
rf_map = strf.from_basis(strf.as_matrix(w1[3:], ntbas), kcosbas)
print(mp_map)
# compare both estimates with the true kernel at the same (full) resolution
rf_mle_full = strf.from_basis(rf_mle, kcosbas)
print("mle error: {}; map error: {}".format(strf.subspace(k1, rf_mle_full), strf.subspace(k1, rf_map)))
plt.subplot(121).imshow(k1, cmap='jet', aspect='auto')
plt.subplot(122).imshow(rf_map, cmap='jet', aspect='auto')

In [None]:
from corner import corner
sns.set_style("whitegrid")

# corner plot of the fitted MAT parameters over the final walker positions
mpos = pos[:, :matparams_n]
matlabs = ['w', 'a1', 'a2']
# truths needs one value per plotted column, i.e. only the MAT parameters,
# not the RF coefficients that follow them in theta_true
c = corner(mpos,
           bins=50, smooth=2, smooth1d=0,
           labels=matlabs,
           truths=theta_true[:matparams_n])

In [None]:
# first 200 samples of mlest.V under the true, ML and posterior-median parameters
Vref, V_ml, V_map = (mlest.V(theta) for theta in (theta_true, w0, w1))
for trace in (Vref, V_ml, V_map):
    plt.plot(trace[:200])

In [None]:
# observed spike trains (red), then 10 trials simulated from the posterior-median fit
for i, d in enumerate(data):
    plt.vlines(d["spike_t"], i, i + 0.5, 'r')

# true parameter vector with the fitted entries replaced by their estimates
mparams = matparams.copy()
mparams[matparams_i] = mp_map
for i in range(len(data), len(data) + 10):
    S = predict_spikes(V_map, mparams, model_dt, upsample)
    spk_t = S.nonzero()[0]
    plt.vlines(spk_t, i, i + 0.5)

plt.xlim(0, 10000)