## GLMAT: 2D kernel, noise stimulus, ML estimation

In [None]:
from __future__ import print_function, division
import os
import sys
import numpy as np
import scipy as sp

import mat_neuron._model as mat
from dstrf import strf, mle

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# MAT model parameter vector, ordered as (ω, α1, α2, τ1, τ2, tref).
matparams = np.array([7, 100, 2, 10, 200, 2], dtype=np.float64)
# Step size of the simulated model; `duration` is expressed in the same units.
model_dt = 0.5

# Data parameters: total length of each trial and the number of model-rate
# samples it contains, plus how many trials go to assimilation vs. testing.
duration = 400000
n_samples = int(duration / model_dt)
n_assim, n_test = 1, 0

In [None]:
# STRF: deliberately simple, as a proof of principle.
stim_dt = 10.0
# nfreq: number of rows the filter is resampled to (and rows of the stimulus);
# ntau: number of filter columns kept; ntbas: number of basis functions.
nfreq, ntau, ntbas = 28, 30, 8
kscale = 5
# Raised-cosine basis functions used to parameterize the kernel.
kcosbas = strf.cosbasis(ntau, ntbas)

from scipy.signal import resample
filts = np.load('../../filters.npz')
print(filts.keys())
# Scale the stored 'bbs' filter, resample it to nfreq rows, then keep columns
# ntau-1 down to 0, i.e. the first ntau columns in reversed order.
bbs_scaled = filts['bbs'] * kscale
bbs_resampled = resample(bbs_scaled, nfreq, axis=0)
k1 = bbs_resampled[:, ntau-1::-1]
# NOTE(review): presumably the coefficients of k1 in the cosine basis — confirm
# against dstrf.strf.to_basis.
k1c = strf.to_basis(k1, kcosbas)

plt.imshow(k1, cmap='jet', aspect='auto')

In [None]:
# Generate a random (Gaussian white noise) stimulus to fit; seeded so that the
# same stimulus is drawn on every run.
np.random.seed(1)
stim_dt = 10.0
# Number of model time steps per stimulus frame.
upsample = int(stim_dt / model_dt)
# One column per stimulus frame, one row per frequency channel.
n_frames = int(n_samples / (stim_dt / model_dt))
stim = np.random.randn(nfreq, n_frames)
# Silence the first 100 frames.
stim[:, :100] = 0
plt.imshow(stim[:, :400], aspect='auto')

In [None]:
# Seed the random number generator exposed by mat_neuron before simulating.
# NOTE(review): presumably this makes the spike trains drawn through
# mat.predict_poisson reproducible — confirm against mat_neuron.
mat.random_seed(1)

def predict_spikes(V, params, dt, upsample):
    """Simulate a spike train from a driving voltage via mat.predict_poisson.

    Parameters
    ----------
    V : driving voltage; the first element of `params` is subtracted from it.
    params : sequence of exactly six values, (ω, α1, α2, τ1, τ2, tref).
    dt : model time step, forwarded unchanged.
    upsample : forwarded unchanged.
        NOTE(review): presumably the number of model steps per sample of V —
        confirm against mat_neuron.

    Returns
    -------
    Whatever mat.predict_poisson returns for these arguments.
    """
    omega, alpha1, alpha2, tau1, tau2, t_refractory = params
    drive = V - omega
    alphas = (alpha1, alpha2)
    taus = (tau1, tau2)
    return mat.predict_poisson(drive, alphas, taus, t_refractory, dt, upsample)

# Drive for every trial: the stimulus passed through the true kernel k1.
V = strf.convolve(stim, k1)

# Simulate n_assim + n_test trials from the same drive and keep, per trial,
# the adaptation terms, the trial duration, and the spikes both as indices of
# the nonzero entries ("spike_t") and as the full vector ("spike_v").
data = []
for _trial in range(n_assim + n_test):
    spikes = predict_spikes(V, matparams, model_dt, upsample)
    # matparams[3:5] are the two time constants (τ1, τ2).
    H = mat.adaptation(spikes, matparams[3:5], model_dt)
    data.append({
        "H": H,
        "duration": duration,
        "spike_t": np.nonzero(spikes)[0],
        "spike_v": spikes,
    })

# Split into assimilation and test sets.
assim_data, test_data = data[:n_assim], data[n_assim:]

In [None]:
# Top panel: driving voltage; bottom panel: one raster row per trial.
ax1 = plt.subplot(211)
ax2 = plt.subplot(212)
ax1.plot(V)
for row, trial in enumerate(data):
    ax2.vlines(trial["spike_t"], row, row + 0.5)
# Show the first 8000 model bins of the raster. The voltage axis limit is that
# window divided by `upsample`.
# NOTE(review): assumes V has one sample per stimulus frame — confirm.
window_bins = 8000
ax1.set_xlim(0, window_bins // upsample)
ax2.set_xlim(0, window_bins)
# Cell output: number of spikes in the first trial.
len(data[0]["spike_t"])

## Estimate parameters

In [None]:
from theano import config
import scipy.optimize as op
# Everything handed to theano is cast to its configured float type.
ftype = config.floatX

# Combine the trials: spike vectors are stacked as columns (axis 1) ...
spike_columns = [trial["spike_v"] for trial in data]
spikes = np.stack(spike_columns, axis=1).astype(ftype)
# ... and the spikes in the exponential basis set are stacked along a third
# axis, one slab per trial.
adaptation_slabs = [trial["H"] for trial in data]
X_spikes = np.stack(adaptation_slabs, axis=2).astype(ftype)
# Design matrix for the stimulus, built with the cosine basis.
X_stim = strf.lagged_matrix(stim, kcosbas).astype(ftype)

# Initial guess of the STRF from the stimulus/spike correlation.
sta = strf.correlate(X_stim, spikes)
rf_sta = strf.as_matrix(sta, ntbas)

# Top row: true kernel and initial guess as basis coefficients;
# bottom row: the same two after strf.from_basis.
im_opts = dict(cmap='jet', aspect='auto')
plt.subplot(221).imshow(k1c, **im_opts)
plt.subplot(222).imshow(rf_sta, **im_opts)
plt.subplot(223).imshow(strf.from_basis(k1c, kcosbas), **im_opts)
plt.subplot(224).imshow(strf.from_basis(rf_sta, kcosbas), **im_opts)

In [None]:
from theano import function, config, shared, sparse, gradient
import theano.tensor as T
import scipy.sparse as sps

# Promote single-trial inputs so the code below can always assume a trailing
# trial axis. X_spikes itself must be rebound: its shape is unpacked into three
# dimensions and its last axis is rolled to the front further down.
if X_spikes.ndim == 2:
    X_spikes = np.expand_dims(X_spikes, 2)
if spikes.ndim == 1:
    spikes = np.expand_dims(spikes, 1)

nframes, nk = X_stim.shape
nbins, nalpha, ntrials = X_spikes.shape
upsample = int(stim_dt / model_dt)
# Interpolation matrix: kron(I_nframes, ones(upsample, 1)) has shape
# (nframes * upsample, nframes) and repeats each stimulus-frame value
# `upsample` times, bringing the stimulus drive onto the model time grid.
interp = sps.kron(sps.eye(nframes),
                  np.ones((upsample, 1), dtype=config.floatX),
                  format='csc')

# load the data into theano.shared structures
M = shared(interp)
dt = shared(model_dt)
Xstim = shared(X_stim)
# trial axis first: (ntrials, nbins, nalpha)
Xspke = shared(np.rollaxis(X_spikes, 2))
# row (time bin) and column (trial) indices of the nonzero entries of `spikes`
spkx, spky = map(shared, spikes.nonzero())

# Split out the parameter vector: w = [dc, h (nalpha values), k (the rest)].
w = T.vector('w')
dc = w[0]
h = w[1:(nalpha+1)]
k = w[(nalpha+1):]
# stimulus drive, one value per stimulus frame
Vx = T.dot(Xstim, k)
# Vx has to be promoted to a matrix for structured_dot to work
Vi = sparse.structured_dot(M, T.shape_padright(Vx))
# adaptation term, transposed to (nbins, ntrials)
H = T.dot(Xspke, h).T
mu = Vi - H - dc
# Objective to minimize: sum(exp(mu)) * dt minus mu summed over the spike bins,
# i.e. the negative log-likelihood of a Poisson process with rate exp(mu).
ll = T.exp(mu).sum() * dt - mu[spkx, spky].sum()
dL = T.grad(ll, w)
# arbitrary vector for hessian-vector product
v = T.vector('v')
# Differentiating dL . v w.r.t. w gives the Hessian-vector product without
# forming the full Hessian.
ddLv = T.grad(T.sum(dL * v), w)

# Compiled functions: drive, adaptation, objective, gradient, Hessian-vector.
fV = function([w], Vx)
fH = function([w], H)
fL = function([w], ll)
fgrad = function([w], dL)
fhess = function([w, v], ddLv)

In [None]:
# Initial likelihood: evaluate the objective fL at the starting point, which is
# three zeros (the dc offset and the two adaptation weights) followed by the
# correlation-based kernel estimate `sta`.
w0 = np.r_[0, 0, 0, sta]
fL(w0)

In [None]:
%%time
# Start from zeros for the three leading parameters plus the `sta` kernel
# estimate, then minimize fL with Newton-CG. fgrad supplies the gradient and
# fhess the Hessian-vector product; the search stops after at most 100
# iterations.
w0 = np.r_[0, 0, 0, sta]
w1 =  op.fmin_ncg(fL, w0, fgrad, fhess_p=fhess, maxiter=100)

In [None]:
# The first three fitted parameters are the non-kernel terms; the remainder is
# the kernel in the cosine basis.
leading_params, kernel_params = w1[:3], w1[3:]
print(leading_params)
rf_mle = strf.as_matrix(kernel_params, ntbas)

# Left to right: true kernel, initial guess, ML estimate.
im_opts = dict(cmap='jet', aspect='auto')
plt.subplot(131).imshow(strf.from_basis(k1c, kcosbas), **im_opts)
plt.subplot(132).imshow(strf.from_basis(rf_sta, kcosbas), **im_opts)
plt.subplot(133).imshow(strf.from_basis(rf_mle, kcosbas), **im_opts)

In [None]:
# Fit a low-rank approximation to the ML kernel: factorize it into a
# frequency factor (k0f) and a temporal factor (k0t) of rank `krank`.
krank = 1
k0f, k0t = strf.factorize(rf_mle, krank)
# Recombine the two factors and display the result.
rf_lowrank = np.dot(k0f, k0t)
plt.imshow(strf.from_basis(rf_lowrank, kcosbas), cmap='jet', aspect='auto')

In [None]:
# Bilinear version of the model: number of frequency and temporal parameters.
nkf, nkt = nfreq * krank, ntbas * krank

# Bilinear convolution. The temporal factor is repeated along the diagonal of
# a block-diagonal sparse matrix, the design matrix is multiplied through it,
# and the result is combined with the frequency factor.
kt_block = k0t.reshape(nkt, krank)
Mkt = sps.kron(sps.eye(nkf), kt_block, format='csc')
dSSdx = X_stim * Mkt # sparse dot product
Vpb = np.dot(dSSdx, k0f)
# Reference: drive from the full-rank ML kernel.
Vpp = np.dot(X_stim, w1[3:])

# Overlay the first 400 samples of both drives.
plt.plot(Vpb[:400])
plt.plot(Vpp[:400])

In [None]:
from theano.tensor import slinalg
# Bilinear likelihood graph. It has the same structure as the full-rank graph,
# except that the kernel k is built from a temporal factor kt and a frequency
# factor kf instead of being read directly from the parameter vector.
# split out the parameter vector

# the factor sizes are wrapped in theano shared variables and used below as
# slice bounds and reshape dimensions
_nkt = shared(nkt)
_nkf = shared(nkf)
# parameter layout: w = [dc, h (nalpha values), kt (nkt values), kf (nkf values)]
w = T.vector('w')
dc = w[0]
h = w[1:(nalpha+1)]
kt = w[(nalpha+1):(nalpha+_nkt+1)]
kf = w[(nalpha+_nkt+1):(nalpha+_nkt+_nkf+1)]
# (nkf x krank) . (krank x nkt), flattened into a single kernel vector
k = T.dot(kf.reshape((_nkf, krank)), kt.reshape((krank, _nkt))).ravel()

# convolution: first with block-diagonal matrix containing temporal vectors then with freq vectors
# it might be nice for these to be sparse
# Mkt = T.slinalg.kron(T.eye(nkf), kt.reshape((nkt, krank)))
# Vxt = T.dot(Xstim, Mkt)
# Vx = T.dot(Vxt, kf)
Vx = T.dot(Xstim, k)
# Vx has to be promoted to a matrix for structured_dot to work
Vi = sparse.structured_dot(M, T.shape_padright(Vx))
H = T.dot(Xspke, h).T
mu = Vi - H - dc
# objective to minimize: sum(exp(mu)) * dt minus mu summed over the spike bins
ll = T.exp(mu).sum() * dt - mu[spkx, spky].sum()
dL = T.grad(ll, w)
# arbitrary vector for hessian-vector product
v = T.vector('v')
ddLv = T.grad(T.sum(dL * v), w)

# compiled functions for the bilinear model: drive, adaptation, objective,
# gradient, and Hessian-vector product
fVb = function([w], Vx)
fHb = function([w], H)
fLb = function([w], ll)
fgradb = function([w], dL)
fhessb = function([w, v], ddLv)

In [None]:
# Starting point for the bilinear fit: three zeros for the leading parameters,
# then the temporal factor, then the frequency factor from the low-rank
# factorization. The order matches how the bilinear graph slices w (kt before kf).
w2 = np.r_[0, 0, 0, k0t.squeeze(), k0f.squeeze()]
# Newton-CG with the bilinear gradient and Hessian-vector product, capped at
# 100 iterations and timed by the %time magic.
%time w3 =  op.fmin_ncg(fLb, w2, fgradb, fhess_p=fhessb, maxiter=100)

In [None]:
# Leading three parameters of the bilinear fit.
print(w3[:3])
# The temporal factor occupies the next nkt entries; the frequency factor is
# everything after that.
kt_end = 3 + nkt
w3_kt, w3_kf = w3[3:kt_end], w3[kt_end:]
# Rebuild the kernel (in the cosine basis) as the outer product of the factors.
rf_bl = np.outer(w3_kf, w3_kt)

# Top: bilinear estimate; bottom: true kernel with its columns reversed.
im_opts = dict(cmap='jet', aspect='auto')
plt.subplot(211).imshow(strf.from_basis(rf_bl, kcosbas), **im_opts)
plt.subplot(212).imshow(k1[:, ::-1], **im_opts)