In [None]:
from __future__ import print_function, division
import numpy as np
import scipy as sp
from numba import jit

from dstrf import mat, strf, mle

# plotting packages
%matplotlib inline
import matplotlib.pyplot as plt # plotting functions
import seaborn as sns           # data visualization package
sns.set_style("whitegrid")

In [None]:
# --- neuron model ---
# layout of the parameter vector: (α1, α2, ω, τ1, τ2, tref)
matparams = np.array([100.0, 2.0, 7.0, 10.0, 200.0, 2.0], dtype=np.float64)
# step size of the model time grid
model_dt = 0.5

# --- simulated data ---
# total simulated time, and the number of model steps needed to span it
duration = 500000
n_samples = int(duration / model_dt)
# number of trials generated for fitting (assimilation) and for testing
n_assim, n_test = 1, 0

In [None]:
# Stimulus time base and candidate convolution kernels
# (a simple alpha function and a difference of two gamma functions).
from scipy.signal import resample
from scipy.special import gamma

stim_dt = 10.0
# number of model time steps per stimulus sample
upsample = int(stim_dt / model_dt)

# support of the kernels, sampled on the stimulus grid
tt = np.arange(0, 600, stim_dt)

# alpha filter (stored time-reversed)
tau_h = 50
ka = (tt / tau_h * np.exp(-tt / tau_h))[::-1]

# difference of gammas: same exponent (5), two different time constants
tau_h1, tau_h2 = 600. / 32, 600. / 16
kg1 = 1 / (gamma(6) * tau_h1) * (tt / tau_h1) ** 5 * np.exp(-tt / tau_h1)
kg2 = 1 / (gamma(6) * tau_h2) * (tt / tau_h2) ** 5 * np.exp(-tt / tau_h2)
kg = (kg1 - kg2 / 1.5)[::-1]
# rescale to unit Euclidean norm
kg = kg / np.linalg.norm(kg)

# select a filter: scale the gamma kernel by 5 and undo the time reversal
k1 = (5 * kg)[::-1]
plt.plot(k1)

def filter_stimulus(S, k1, upsample=None):
    """Causally filter a 1-D stimulus with a kernel, optionally upsampling the result.

    Parameters
    ----------
    S : 1-D array
        Stimulus samples.
    k1 : 1-D array
        Convolution kernel; ``k1[0]`` weights the current sample and ``k1[j]``
        the sample ``j`` steps in the past.
    upsample : number or None
        When given, the filtered signal is linearly interpolated onto a grid
        with ``upsample`` points per original sample. ``None`` (the default)
        returns the signal at the original sampling rate.

    Returns
    -------
    1-D array with ``S.size`` elements when ``upsample`` is None, otherwise
    ``ceil(S.size * upsample)`` elements.
    """
    # "full" convolution truncated to the stimulus length, so that output[i]
    # depends only on S[:i + 1]
    X = np.convolve(S, k1, mode="full")[:S.size]
    if upsample is None:
        return X
    t = np.arange(X.size)
    # Build the fine grid from integer indices instead of np.arange with a
    # fractional step: the length of a float-step arange is not numerically
    # stable, and index / upsample gives correctly rounded grid points.
    n_out = int(np.ceil(X.size * upsample))
    tu = np.arange(n_out) / float(upsample)
    # points past the last sample are clamped to X[-1] by np.interp
    return np.interp(tu, t, X)
    

In [None]:
@jit
def predict_spikes(V, params, dt, upsample):
    """Simulate a stochastic spike train driven by V with two decaying adaptation terms.

    Parameters
    ----------
    V : array
        Driving signal; each element is held for ``upsample`` simulation steps.
    params : indexable
        ``params[0]`` and ``params[1]`` are the amounts added to the two
        adaptation variables at each spike, ``params[2]`` is a constant
        subtracted inside the exponential, ``params[3]`` and ``params[4]`` are
        the decay time constants of the adaptation variables, and ``params[5]``
        is the refractory period (divided by ``dt`` to get a step count).
    dt : float
        Simulation time step.
    upsample : int
        Number of simulation steps per element of ``V``.

    Returns
    -------
    Vi : array of length ``V.size * upsample``
        ``V`` repeated onto the simulation grid.
    S : integer array of the same length
        1 at steps where a spike was emitted, 0 elsewhere.
    """
    N = V.size * upsample
    S = np.zeros(N, dtype='i')
    Vi = np.zeros(N, dtype='d')
    # one uniform draw per step, all drawn up front from numpy's global RNG
    R = np.random.uniform(size=N)
    omega = params[2]
    # per-step multiplicative decay factors of the two adaptation variables
    A1 = np.exp(-dt / params[3])
    A2 = np.exp(-dt / params[4])
    # refractory period expressed in simulation steps (truncated)
    i_refrac = int(params[5] / dt)
    H1 = H2 = 0
    # last step index at which spiking is still blocked
    iref = 0
    for i in range(N):
        Vi[i] = V[i // upsample]
        H1 *= A1
        H2 *= A2
        # exponential of the adapted drive, scaled by dt, compared with the draw below
        p = np.exp(Vi[i] - H1 - H2 - omega) * dt
        # spike only strictly after the refractory window (so never at step 0)
        if i > iref and p > R[i]:
            H1 += params[0]
            H2 += params[1]
            iref = i + i_refrac
            S[i] = 1
    return Vi, S


def lci_poisson(V, H, spikes, params, dt):
    """Return the summed pointwise Poisson log-likelihood term.

    For each time step the log intensity is
    ``mu = V - H[:, 0] * params[0] - H[:, 1] * params[1] - params[2]``
    and the contribution is ``spikes * mu - dt * exp(mu)``; the function
    returns the sum of these contributions over all steps.
    """
    h1_term = H[:, 0] * params[0]
    h2_term = H[:, 1] * params[1]
    log_intensity = V - h1_term - h2_term - params[2]
    pointwise = spikes * log_intensity - dt * np.exp(log_intensity)
    return pointwise.sum()


In [None]:
# Simulate the trials that will be fitted.
# Seed numpy's global RNG so the stimulus and the spike draws are reproducible.
np.random.seed(1)

# white-noise stimulus on the stimulus grid, with the first 100 samples silenced
stim = np.random.randn(int(n_samples / (stim_dt / model_dt)))
stim[:100] = 0

# filtered stimulus at the stimulus sampling rate; predict_spikes is given
# `upsample` so it can expand this onto the model grid
V = filter_stimulus(stim, k1, upsample=1)

data = []
n_trials = n_assim + n_test
for i in range(n_trials):
    Vinterp, spikes = predict_spikes(V, matparams, model_dt, upsample)
    # one adaptation column per time constant (τ1, τ2)
    H = np.column_stack((mat.adaptation(spikes, matparams[3], model_dt),
                         mat.adaptation(spikes, matparams[4], model_dt)))
    # indices of the model steps that contain a spike
    z = np.flatnonzero(spikes)
    d = dict(V=Vinterp,
             H=H,
             duration=duration,
             spike_t=z,
             spike_v=spikes,
             lci=lci_poisson(Vinterp, H, spikes, matparams, model_dt))
    data.append(d)

# the first n_assim trials are used for assimilation, the remainder for testing
assim_data, test_data = data[:n_assim], data[n_assim:]

In [None]:
# Top panel: driving signal of the last trial. Bottom panel: spike raster, one row per trial.
ax1 = plt.subplot(211)
ax2 = plt.subplot(212)
for i, d in enumerate(data):
    # each spike is a short vertical tick on row i
    ax2.vlines(d["spike_t"], i, i + 0.5)
# `d` is the last trial here; it is also reused by later cells
ax1.plot(d["V"])
for ax in (ax1, ax2):
    ax.set_xlim(0, 8000)

# summed log-likelihood term over all trials
total_lci = sum(trial["lci"] for trial in data)
print(total_lci)

# mean spike count per trial and mean spike count divided by trial duration
spike_counts = [trial["spike_t"].size for trial in data]
spike_rates = [trial["spike_t"].size / trial["duration"] for trial in data]
print("spikes: {}; rate: {} / dt".format(np.mean(spike_counts), np.mean(spike_rates)))

In [None]:
# Design matrix for the stimulus: a single stimulus channel (row) with as many
# lags as the true kernel has samples.
X_stim = strf.lagged_matrix(stim[np.newaxis, :], k1.size)

# Initial guess of the STRF from the spikes of trial `d`
# (`d` is the last trial left bound by the raster-plot loop).
sta = strf.correlate(X_stim, d["spike_v"])

# overlay the reversed initial guess on the true kernel
plt.plot(np.flipud(sta))
plt.plot(k1)

In [None]:
# Pool the per-trial spike vectors into a spike count per model time step.
psth = np.sum([trial["spike_v"] for trial in data], axis=0)

# Adaptation computed from the pooled counts, one column per time constant (τ1, τ2).
H = np.column_stack((mat.predict_adaptation(psth, matparams[3], model_dt),
                     mat.predict_adaptation(psth, matparams[4], model_dt)))

# compare the pooled counts with the spikes of trial `d` over the first 8000 steps
plt.plot(psth[:8000])
plt.plot(d["spike_v"][:8000])

In [None]:
# Build the likelihood functions for trial `d`.
lfuns = mle.make_likelihood(X_stim, d["H"], d["spike_t"], stim_dt, model_dt)

# Ground-truth parameter vector in the layout (ω, α1, α2, kernel reversed).
w = np.concatenate(([matparams[2]], matparams[:2], k1[::-1]))

# evaluate and plot the 'lci' function at the true parameters (first 8000 points)
lci = lfuns['lci'](w)
plt.plot(lci[:8000])

# value of the objective at the true parameters (shown as the cell output)
lfuns['loglike'](w)

In [None]:
import scipy.optimize as op

# Initial guess: zeros for the three scalar parameters and the spike-triggered
# average for the kernel (floats, so the optimizer never sees an integer x0).
w0 = np.r_[0., 0., 0., sta]

# Newton-CG minimization of the objective using the analytic gradient and Hessian.
# Start from the data-derived guess w0 rather than from `w`: `w` holds the
# parameters used to generate the data, so seeding the optimizer with it would
# not test whether they can be recovered.
w1 = op.fmin_ncg(lfuns['loglike'], w0, lfuns['gradient'],
                 fhess=lfuns['hessian'], maxiter=100)

In [None]:
# Estimated scalar parameters; the first three entries of the parameter vector.
print(w1[:3])

# The kernel occupies w[3:] in reversed order. Slice first, then reverse:
# the previous `w1[:3:-1]` stops *before* index 3 and so dropped the kernel
# coefficient stored at index 3 (the last point of the plotted curve).
rf_est = w1[3:][::-1]
plt.plot(rf_est)
# kernel part of the initial guess, for comparison
plt.plot(w0[3:][::-1])
# true kernel
plt.plot(k1)