# GLM Demo: multivariate song stimulus, CRCNS neurons

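This notebook demonstrates fitting a GLM to a real neuron's responses to song stimuli. The default configuration loads cell `st376_4_3_1` from neurobank; the commented-out `data` settings in the config cell show how to load a neuron from the CRCNS dataset instead.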

In [None]:
# Setup: imports, notebook extensions, and an empty config placeholder.
from __future__ import print_function, division
import os
import sys
# NOTE(review): `imp` is deprecated and was removed in Python 3.12;
# `importlib.reload` is the modern replacement for the `imp.reload` calls
# used when loading data.
import imp
import numpy as np

import mat_neuron._model as mat
from dstrf import io, strf, mle, simulate, data, filters, models, spikes, performance

# YAML config support (provides the %%yaml cell magic) and plotting packages
import ruamel.yaml as yaml
%reload_ext yamlmagic
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

# Placeholder; the `%%yaml cfg` cell that follows stores the parsed config
# under this name.
cfg = {}

In [None]:
%%yaml cfg
# Parsed into `cfg`, then read through the munchified copy `cf` (e.g. cf.model.dt).
model:
  # time step used to convert spike indices to plot times and to size PSTH bins
  dt: 0.5
  # adaptation time constants and refractory period; passed to
  # models.matconstraint and models.matbounds
  ataus: [10.0, 200.0]
  t_refract: 2.0
  filter:
    # rank of the factorized filter; len and ncos are the arguments to strf.cosbasis
    rank: 2
    len: 50
    ncos: 12
  prior:
    # l1 is passed to mlest.estimate as reg_alpha, l2 as reg_lambda
    l1: 585.0558005127248
    l2: 30.792410553301334
data:
  # name of the loader function looked up in dstrf.data with getattr
  source: "neurobank"
  cell: "st376_4_3_1"
  # alternative: load a CRCNS cell instead
  #source: "crcns"
  #root: "../../crcns"
  #cell: "yg0616_4_B"
  stimulus:
    stim_type: "conspecific"
    include:
      - A0
      - A0_motifs_000
      - A0_motifs_001
      - A0_motifs_002
      - A0_motifs_003
      - A1
      - A2
      - A2_motifs_000
      - A3
      - A4
      - A5
      - A6_motifs_000
      - A8
      - A8_motifs_000
      - A8_motifs_001
      - A8_motifs_002
      - A8_motifs_003
      - A9
      - B0
      - B0_motifs_000
      - B0_motifs_001
      - B0_motifs_002
      - B0_motifs_003
      - B1
      - B2
      - B2_motifs_000
      - B3
      - B4
      - B5
      - B6_motifs_000
      - B8
      - B8_motifs_000
      - B8_motifs_001
      - B8_motifs_002
      - B8_motifs_003
      - B9
      - C0
      - C0_motifs_000
      - C0_motifs_001
      - C0_motifs_002
      - C0_motifs_003
      - C1
      - C2
      - C2_motifs_000
      - C6_motifs_000
      - C8_motifs_000
      - C8_motifs_001
      - C8_motifs_002
      - C8_motifs_003
    spectrogram:
      window: 4.0
      compress: 1
      # f_min/f_max also set the frequency-axis extent of the spectrogram plots;
      # f_count is used when defactorizing the MLE filter
      f_min: 0.3
      f_max: 10.0
      f_count: 24
      gammatone: True
  prepadding: 50.0
  dt: 1.0
  # number of simulated trials drawn in the prediction cell. Note there is no
  # test_proportion key, so p_test is None when the data are split.
  test_trials: 50
# not read directly by the notebook cells; presumably consumed by the data
# loader, which receives the whole config (TODO confirm)
spike_detect:
  thresh: -20.0
  rise_dt: 1.0
# not referenced by any cell in this notebook
emcee:
  nsteps: 10
  nthreads: 8
  nwalkers: 500
  startpos_scale: 2.0
  bounds:
  - [0, 20]
  - [-50, 200]
  - [-5, 10]  

In [None]:
import munch

# Wrap the plain dict produced by the %%yaml cell so nested settings can be
# read as attributes (cf.model.dt, cf.data.source, ...).
cf = munch.munchify(cfg)

In [None]:
# Re-import the project modules so edits on disk are picked up without
# restarting the kernel.
for _module in (data, io):
    imp.reload(_module)

# cf.data.source names the loader function in dstrf.data (e.g. "neurobank").
stim_fun = getattr(data, cf.data.source)
raw_data = stim_fun(cf)

# Fraction reserved for testing, if configured; this config sets no
# test_proportion, so p_test is None here.
p_test = cf.data.get("test_proportion", None)
assim_data = io.merge_data(io.subselect_data(raw_data, p_test))

In [None]:
# Show the names of the loaded stimuli (the list is the cell's output).
[rec['stim_name'] for rec in raw_data]

In [None]:
# Summary of the assimilation data plus its even/odd reliability.
# spike_v is indexed (time bin, trial): shape[0] is reported as spike bins and
# shape[1] is used as the trial count. The even/odd split must therefore be
# taken over trials (axis 1), not over time bins.
psth_dt = 5
upsample = int(psth_dt / cf.model.dt)
eo = performance.corrcoef(assim_data["spike_v"][:, ::2], assim_data["spike_v"][:, 1::2], upsample, 1)

print("duration:", assim_data["duration"])
print("stim bins:", assim_data["stim"].shape[1])
print("spike bins:", assim_data["spike_v"].shape[0])
print("total spikes:", np.sum(assim_data["spike_v"]))
# spikes / duration / number of trials, scaled by 1000
print("avg spike rate:", 1000 * np.sum(assim_data["spike_v"]) / assim_data["duration"] / assim_data["spike_v"].shape[1])
print("EO cc: %3.3f" % eo)

In [None]:
# Time axes spanning the full assimilation duration for stimulus and spikes.
t_stim = np.linspace(0, assim_data["duration"], assim_data["stim"].shape[1])
t_spike = np.linspace(0, assim_data["duration"], assim_data["spike_v"].shape[0])

# Top: spectrogram; bottom: one raster row per trial.
spec_cfg = cf.data.stimulus.spectrogram
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(16, 5))
spec_extent = (0, assim_data["duration"], spec_cfg.f_min, spec_cfg.f_max)
axes[0].imshow(assim_data["stim"], extent=spec_extent,
               cmap='jet', origin='lower', aspect='auto')

raster_ax = axes[1]
for trial, spike_idx in enumerate(assim_data["spike_t"]):
    # spike indices are scaled by the model time step to get plot times
    raster_ax.vlines(spike_idx * cf.model.dt, trial, trial + 0.5)

axes[0].set_xlim(0, 8000);

## Estimate parameters

In [None]:
# Initial guess of parameters using ML: build the factorized estimator with a
# rank-`krank` filter expressed in the cosine basis from strf.cosbasis(len, ncos).
krank = cf.model.filter.rank
kcosbas = strf.cosbasis(cf.model.filter.len, cf.model.filter.ncos)
# A single call: retrying the identical call after a TypeError cannot succeed,
# so any TypeError is allowed to surface directly.
mlest = mle.matfact(assim_data["stim"], kcosbas, krank, assim_data["spike_v"], assim_data["spike_h"],
                    assim_data["stim_dt"], assim_data["spike_dt"])

In [None]:
%%time
nparams = 1 + mlest.n_hparams + mlest.n_kparams
constraint = models.matconstraint(nparams, cf.model.ataus[0], cf.model.ataus[1], cf.model.t_refract)
w0 = mlest.estimate(reg_lambda=cf.model.prior.l2, reg_alpha=cf.model.prior.l1, 
                    method='trust-constr', constraints=[constraint],
                    gtol=1e-2)

In [None]:
# Unconstrained MLE with the same L2 (reg_lambda) and L1 (reg_alpha) penalties.
# This overwrites the w0 from the constrained fit in the previous cell.
%time w0 = mlest.estimate(reg_lambda=cf.model.prior.l2, reg_alpha=cf.model.prior.l1)

In [None]:
# Compare the spike-triggered average with the MLE filter. w0[:3] are printed
# as the rate/adaptation parameters; w0[3:] are the factorized filter weights.
print("MLE rate and adaptation parameters:", w0[:3])
rf_sta = strf.as_matrix(mlest.sta(), kcosbas)
filter_coefs = strf.defactorize(w0[3:], cf.data.stimulus.spectrogram.f_count, krank)
rf_mle = strf.from_basis(filter_coefs, kcosbas)

fig, axes = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(6, 3))
panels = [(rf_sta, "STA"), (rf_mle, "MLE (rank-{})".format(krank))]
for ax, (rf, title) in zip(axes, panels):
    ax.imshow(rf, cmap='jet', aspect='auto')
    ax.set_title(title)

In [None]:
# Check the estimate against the bounds built by models.matbounds from the
# adaptation time constants and refractory period.
matboundprior = models.matbounds(cf.model.ataus[0], cf.model.ataus[1], cf.model.t_refract)
in_bounds = matboundprior(w0)
if not in_bounds:
    print("parameters out of bounds")

## Predict responses

In [None]:
# Held-out data (first=False: presumably the part of raw_data not used for
# assimilation) and an estimator on it for evaluating the fitted w0.
test_data = io.merge_data(io.subselect_data(raw_data, p_test, first=False))
mltest = mle.matfact(
    test_data["stim"], kcosbas, krank,
    test_data["spike_v"], test_data["spike_h"],
    test_data["stim_dt"], test_data["spike_dt"],
)

In [None]:
# Summary of the held-out test set. Every statistic, including the average
# rate, is computed from test_data (spike_v is indexed (time bin, trial)).
print("duration:", test_data["duration"])
print("stim bins:", test_data["stim"].shape[1])
print("spike bins:", test_data["spike_v"].shape[0])
print("total spikes:", np.sum(test_data["spike_v"]))
# spikes / duration / number of trials, scaled by 1000
print("avg spike rate:", 1000 * np.sum(test_data["spike_v"]) / test_data["duration"] / test_data["spike_v"].shape[1])

In [None]:
# Held-out evaluation: spectrogram, model output V, observed vs simulated
# rasters, and PSTHs, followed by correlation / likelihood metrics.
fig, axes = plt.subplots(nrows=4, ncols=1, sharex=True, figsize=(18, 9))
axes[0].imshow(test_data["stim"],
               extent=(0, test_data["duration"], cf.data.stimulus.spectrogram.f_min, cf.data.stimulus.spectrogram.f_max),
               cmap='jet', origin='lower', aspect='auto')

t_stim = np.linspace(0, test_data["duration"], test_data["stim"].shape[1])
t_spike = np.linspace(0, test_data["duration"], test_data["spike_v"].shape[0])

# model output V for the test stimulus under the fitted parameters
Vpred = mltest.V(w0)
axes[1].plot(t_stim, Vpred)

# Observed trials are drawn above (offset by n_trials); simulated trials below in red.
n_trials = cf.data.get('test_trials', test_data["ntrials"])
for i, spk in enumerate(test_data["spike_t"]):
    axes[2].vlines(spk * cf.model.dt, i - 0.4 + n_trials, i + 0.4 + n_trials)
pred = np.zeros((t_spike.size, n_trials), dtype=mltest.spikes.dtype)

for j in range(n_trials):
    pred[:, j] = models.predict_spikes_glm(Vpred, w0[:3], cf)
    spk_t = pred[:, j].nonzero()[0]
    axes[2].vlines(spk_t * cf.model.dt, j - 0.4, j + 0.4, color='r')

# PSTHs in psth_dt bins, normalized by the number of trials in each set.
psth_dt = 10
upsample = int(psth_dt / cf.model.dt)
pred_psth = spikes.psth(pred, upsample, 1) / n_trials
test_psth = spikes.psth(test_data["spike_v"], upsample, 1) / test_data["ntrials"]
t_psth = np.linspace(0, test_data["duration"], test_psth.size)
axes[3].plot(t_psth, test_psth, t_psth, pred_psth)
axes[3].set_xlim(0, 8000);

# Even/odd reliability: spike_v is (time bin, trial), so split the observed
# trials along axis 1; slicing axis 0 would interleave time bins instead.
eo = performance.corrcoef(test_data["spike_v"][:, ::2], test_data["spike_v"][:, 1::2], upsample, 1)
cc = performance.corrcoef(test_data["spike_v"], pred, upsample, 1)
print("EO cc: %3.3f" % eo)
print("pred cc: %3.3f" % cc)
print("log-likelihood: %f" % mltest.loglike(w0))
# mean over trials of the per-trial spike count (sum over time bins)
print("spike count: data = {}, pred = {}".format(test_data["spike_v"].sum(0).mean(), pred.sum(0).mean()))