# Autoregressive Point-Processes as Latent State-Space Models

## Configure notebook

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
from __future__ import division
from __future__ import print_function

# Load scipy/numpy/matplotlib
from   scipy.linalg import expm
import matplotlib.pyplot as plt
from   pylab import *

# Configure figure resolution
plt.rcParams['figure.figsize'] = (12.0, 6.0)
plt.rcParams['savefig.dpi'   ] = 100

from izh       import * # Routines for sampling Izhikevich neurons
from plot      import * # Misc. plotting routines
from glm       import * # GLM fitting
from arppglm   import * # Sampling and integration
from utilities import * # Other utilities
from arguments import * # Argument verification

figure_prefix = "RuleSanguinetti2018_figure_"

import numpy as np

In [None]:
from neurotools.nlab import *

# Case study: phasic bursting Izhikevich neuron

In [None]:
# Izhikevich parameters
izh = (0.02,0.25,-55,0.05) # a, b, c, d
dt  = 1.0                  # time step handed to sim_izh

nplot = 1000   # time points for plotting

# Generate constant drive with synaptic noise
# (Gaussian samples with mean I and standard deviation sqrt(I))
I = 0.6
stimulus = randn(nplot)*sqrt(I)+I

# Plot current input
subplot(311); plot(stimulus)
xlim(0,nplot); nox(); noaxis(); ylabel('pA')
title('Current injected')

# Solve Izh model. dt is passed explicitly, as in every other sim_izh call
# in this notebook, instead of silently relying on the function's default.
state = sim_izh(*izh,stimulus,dt=dt)

# Plot voltage (column 1) and spikes (last column)
subplot(312); plot(state[:,1],color=OCHRE);
xlim(0,nplot); noaxis(); addspikes(state[:,-1]); ylabel('mV');
title('Simulated voltage and spikes');

## Train model on pulses

GLMs can emulate neural firing, but have limited ability to generalize outside of the dynamical regime in which they are trained (Weber & Pillow 2017). For this reason, we train with stimuli that elicit phasic bursting responses (tonic bursting seems to be possible at higher stimulation currents, but interferes with the GLM's ability to model the phasic bursting regime).

In [None]:
# ---- Pulse training stimuli ----
# Earlier, easier settings (kept for reference, not executed):
#   offset     = -1     # Baseline current (picoamps)
#   min_amp    = 0.3    # Smallest current step (picoamps)
#   max_amp    = 0.7    # Largest  current step (picoamps)
#   min_pulse  = 10     # Shortest pulse duration (ms)
#   max_pulse  = 500    # Longest  pulse duration (ms)

# Active settings: more challenging
offset     = -0.5   # Baseline current (picoamps)
min_amp    = 0.05   # Smallest current step (picoamps)
max_amp    = 3.0    # Largest  current step (picoamps)
min_pulse  = 10     # Shortest pulse duration (ms)
max_pulse  = 500    # Longest  pulse duration (ms)

# Ten log-spaced amplitudes and ten log-spaced (integer) durations
amplitudes = exp(linspace(log(min_amp),log(max_amp),10))
durationms = int32(exp(linspace(log(min_pulse),log(max_pulse),10)))
stimulus   = pulse_sequence(amplitudes,durationms,offset)

# ---- Ornstein–Uhlenbeck (OU) process training noise ----
# Earlier setting (not executed): ssvar = 0.005
ssvar     = 0.05           # Noise steady-state variance (ln(pA)^2)
tau       = 200            # Noise correlation time constant (ms)
noisevar  = 2*ssvar/tau    # Noise fluctuation variance
sigma     = sqrt(noisevar) # Noise fluctuation standard deviation

# Add one OU trial to the pulse sequence, in place
stimulus  += sample_ou_process(0,sigma,tau,dt,len(stimulus),ntrial=1).ravel()

ntrain     = len(stimulus)

# Panel 1: training stimulus
subplot(311)
plot(stimulus)
nox()
noaxis()
xlim(0,ntrain)
ylabel('pA')
title('Training stimulus');

# Simulate the Izhikevich model; column 1 is voltage, column 2 is spikes
state = sim_izh(*izh,stimulus,dt=dt)
v = state[:,1]
Y = state[:,2]

# Panel 2: voltage trace with spike markers
subplot(312)
plot(v,color=OCHRE)
addspikes(Y,lw=0.05)
noaxis()
xlim(0,ntrain)
ylabel('mV')
title('Simulated voltage and spikes');

# Fit GLM to Izhikevich model

### Define history basis functions

In [None]:
# History basis configuration:
#   N = duration of history filter
#   K = number of basis elements
#   D = duration of shortest basis element
N,K,D = 150,8,5
B = make_cosine_basis(K,N,D,normalize=False)

# Draw every basis element against time lag
subplot(421)
plot(B.T,color=BLACK,clip_on=False);
xlim(0,N)
ylim(0,0.5)
simpleaxis()
xlabel('Time lag (ms)')
title('History basis functions')

### Generate stimulus and spiking history training features

In [None]:
# Stimulus history features: the input current convolved with each basis
# element, truncated to the training length. These are needed to model
# subthreshold dynamics.
Bh = array([convolve(b,stimulus) for b in B]).T[:ntrain]

# Spike history features use the basis with one leading zero column
# prepended, then follow the same convolve-and-truncate recipe.
Bp = concatenate([zeros((K,1)),B],axis=1)
By = array([convolve(b,Y) for b in Bp]).T[:ntrain]

# Panel 1: stimulus history features
subplot(311)
plot(Bh)
noxyaxes()
title('Stimulus history features');

# Panel 2: spike history features, with a thin line at every spike time
subplot(312)
plot(By)
for t in find(Y>0):
    axvline(t,lw=0.1,color=BLACK)
noaxis()
noy()
xlabel('ms')
title('Spike history features');

## Train model

In [None]:
# Feature matrix: spike-history columns first, stimulus-history columns second
X = concatenate([By,Bh],axis=1)
m,bhat = fitGLM(X,Y)

# Split the fitted weights in the same column order as X
bhat_spikehist,bhat_stimulus = bhat[:K],bhat[K:]

# Spike-history weights as a (K,1) column
beta = bhat_spikehist.reshape(K,1)

### Pulse stimulus for demonstration

In [None]:
# Demo pulse parameters
duration = 150  # Pulse duration (ms)
padding  = 50   # Pulse padding (ms)
burnin   = 200  # Time for Izhikevich model to settle (ms)
current  = 0.3  # Pulse current (pA)
ndemo    = duration + 2*padding # total length of demo stimulus (ms)

# Demo stimulus: baseline everywhere, with the pulse placed after
# burn-in and the leading padding
demo_stimulus = zeros(ndemo+burnin) + offset
demo_stimulus[burnin+padding:][:duration] = current

figure(figsize=(6,6))

# Panel 1: demo stimulus (burn-in removed)
subplot(411)
plot(demo_stimulus[burnin:])
nox()
noaxis()
xlim(0,ndemo)
ylabel('pA')
title('Training stimulus');

# Simulate the Izhikevich model and drop the burn-in samples
demo_state = sim_izh(*izh,demo_stimulus,dt=dt)
demo_v,demo_Y = demo_state[burnin:,1],demo_state[burnin:,2]

# Panel 2: voltage with spike markers
subplot(412)
plot(demo_v,color=OCHRE);
addspikes(demo_Y)
nox()
noaxis()
xlim(0,ndemo)
ylabel('mV');
title('Simulated voltage and spikes');

# Build GLM filter responses for the demo window
demo_Bh = array([convolve(b,demo_stimulus) for b in B ]).T[burnin-1:][:ndemo,:]
demo_By = array([convolve(b,demo_Y       ) for b in Bp]).T[:ndemo,:]
demo_X  = concatenate([demo_By,demo_Bh],axis=1)

# Panel 3: stimulus history features
subplot(413)
plot(demo_Bh);
xlim(0,ndemo)
nox()
noaxis()
ylabel('a.u.')
title('Stimulus history features')

# Panel 4: spiking history features
subplot(414)
plot(demo_By);
xlim(0,ndemo)
noaxis()
xlabel('ms')
ylabel('a.u.')
title('Spiking history features')

subplots_adjust(hspace=0.5)
plt.draw()

In [None]:
figure(figsize=(6,8))

# Conversion factor from natural-log units to decibels
dB    = log10(e)*10

def labeltime():
    '''
    Write the history-filter duration ("<N> ms") at the right-hand end of
    the current axes, bottom-aligned 5 pixels above y=0.
    '''
    right_edge = xlim()[1]
    just_above = 0+pixels_to_yunits(5)
    text(right_edge,just_above,'%d ms'%N,
        horizontalalignment='right',
        verticalalignment='bottom',
        fontsize=9)

# Panel A (top-left): stimulus history filter, i.e. the fitted stimulus
# weights projected back through the basis B and scaled to dB
stimyscale = 2
a1=subplot2grid((5,2),(0,0),colspan=1)
plot(bhat_stimulus.dot(B)*dB,color='k',lw=1,clip_on=False)
axhline(0,color='k',lw=1)
xlim(0,N); ylim(-stimyscale,stimyscale); nox(); nicey(); simpleraxis();
ylabel('Gain (dB)',fontsize=9); fudgey(10); labeltime()
title('Stimulus filter')
subfigurelabel('A')

# Top-right: spike history (post-spike) filter, same projection and scaling
# (note: this reuses the name a1 for a second axes)
histyscale = 20
a1=subplot2grid((5,2),(0,1),colspan=1)
plot(bhat_spikehist.dot(B)*dB,color='k',lw=1)
axhline(0,color='k',lw=1)
xlim(0,N); ylim(-histyscale,histyscale); nox(); nicey(); simpleraxis(); labeltime()
title('Post-spike filter')

# Panel B: Izhikevich voltage trace with spike lines, plus the stimulus
# drawn underneath as a normalized trace
a2=subplot2grid((5,2),(1,0),colspan=2)
plot(demo_v,'k',lw=0.7)
draw()
# Read the y-limits after drawing so spike lines span the plotted range
yl = ylim()
for t in find(demo_Y)-1:
    plot([t,t],yl,color='k',lw=0.3)
# Reserve a strip below the voltage trace for the stimulus
height = abs(diff(yl)*0.25)
lower  = yl[0]-height*1.5
# Normalize the demo stimulus to [0,1] before scaling into that strip
ii     = demo_stimulus[burnin:][:ndemo]
ii     = (ii-min(ii))/(max(ii)-min(ii))
plot(ii*height+lower,color='k',lw=0.5,clip_on=False)
yscalebar(yl[1]-35,50,'50 mV'); 
yscalebar(height*0.5+lower,height,'%0.1f pA'%current)
xlim(0,ndemo); ylim(lower,yl[1]); noxyaxes()
title('Izhikevich neuron response')
subfigurelabel('B')

# Panel C: spike history contribution (weights applied to spike features)
a3=subplot2grid((5,2),(2,0),colspan=2)
plot(bhat_spikehist.dot(demo_By.T),color='k',lw=0.7)
xlim(0,ndemo); yscalebar(mean(ylim()),10,'10 dB'); noxyaxes()
title('Post-spike contribution to log-intensity')
subfigurelabel('C')

# Panel D: stimulus contribution (weights applied to stimulus features)
a4=subplot2grid((5,2),(3,0),colspan=2)
plot(bhat_stimulus.dot(demo_Bh.T),color='k',lw=0.7)
xlim(0,ndemo); yscalebar(mean(ylim()),10,'10 dB'); noxyaxes()
title('Stimulus contribution to log-intensity')
subfigurelabel('D')

# Panel E: sample the spiking response of the GLM
a5=subplot2grid((5,2),(4,0),colspan=2)
nsample = 20
# Stimulus drive including the fitted offset m; later cells reuse `stim`
stim = m + bhat_stimulus.dot(demo_Bh.T)
ysamp,logratesamp = ensemble_sample(stim,B,beta,nsample)
# 1-ysamp with a gray colormap, presumably so spikes render dark — confirm
pcolormesh(1-ysamp.T,cmap="gray")
# NOTE(review): tick range 0..250 is hard-coded; it equals ndemo for the
# pulse parameters used in this notebook
noaxis(); xticks(arange(0,251,50)); yticks([0,nsample],['0','%s'%nsample])
xlabel('Time (ms)'); ylabel('Sample #',fontsize=9); fudgey(20)
title('Sampled autoregressive point-process model')
subfigurelabel('E',dy=10)

# Make final adjustments (pixel nudges applied after the layout is drawn)
plt.draw()
subplots_adjust(hspace=0.5)
nudge_axis_y(-10,a2); adjust_axis_height_pixels(20,a3); 
nudge_axis_y(-10,a3); adjust_axis_height_pixels(20,a4)
suptitle('Phasic bursting autoregressive PP-GLM model')

savefig(figure_prefix+'1.pdf',transparent=True,bbox_inches='tight')

# Construct low-dimensional system for history process

If the history basis is chosen suitably, the resulting linear system closely approximates the history basis. One can also use a linear system for the history filter from the outset, e.g. a collection of decaying exponential basis functions, enabling an exact model. Since history bases are commonly used, and the filtering approach is discussed elsewhere, we demonstrate the low-dimensional delay-line projection here. 

In [None]:
# Create discrete differentiation operator on the N-sample delay line
Dtau = -eye(N) + eye(N,k=-1)
# Create delta operator (to inject signal into delay line)
S = zeros((N,1))
S[0,0] = 1
# Perform a change of basis from function space into the basis projection B
A = B.dot(Dtau).dot(pinv(B))
C = B.dot(S)

# Show the full operators (top row) and their projections (bottom row).
# The LaTeX titles are raw strings: the previous non-raw literals relied on
# the invalid escapes "\p" and "\d", which newer Python versions warn about.
# The resulting text is unchanged.
figure(figsize=(8,4))
subplot(221); imshow(Dtau)
title(r'$\partial_\tau$')
subplot(222); imshow(S.T,aspect=N/2)
title(r'$\delta_{\tau=0}$')
subplot(223); imshow(A)
title(r'$B \partial_\tau B^{+}$')
subplot(224)
imshow(C.T,aspect=K/2)
title(r'$B \delta_{\tau=0}$')
subplots_adjust(hspace=0.5,wspace=-0.3)

# Save model for later use

In [None]:
# Bind the `scipy` name explicitly; the visible imports only pull `expm`
# out of scipy.linalg, which does not make `scipy.io` available by itself.
import scipy.io

# Everything later stages need from the trained model
saved_training_model = {
    'K' : K,
    'B' : B,
    'By': By,
    'Bh': Bh,
    'A' : A,
    'C' : C,
    'Y' : Y,
    'dt': dt,
}
# savemat returns None, so its result must not be assigned back to
# saved_training_model (doing so discarded the dictionary).
scipy.io.savemat('saved_training_model.mat',saved_training_model)

# Illustrate basis projection

In [None]:
figure(figsize=(9,1.5))
styles = ['--','-',':']

# Unit impulse at lag zero
impulse = zeros(N)
impulse[0]=1

# Left panel: original basis, obtained by convolving each element with
# the impulse; line styles cycle through `styles`
subplot(121)
for i,b in enumerate(array([convolve(b,impulse) for b in B])):
    plot(b[:N],lw=1,linestyle=styles[i%len(styles)],color='k',clip_on=False)
simpleaxis()
xlim(0,N)
ylim(0,.5)
nicexy()
title('Original basis')
xticks([0,150],['0','150 ms'])

# Right panel: impulse response of the projected linear system (A,C)
filtered = linfilter(A,C,impulse)
subplot(122)
for i,b in enumerate(filtered.T):
    plot(b[:N],lw=1,linestyle=styles[i%len(styles)],color='k',clip_on=False)
simpleaxis()
xlim(0,N)
ylim(0,.5)
nicexy()
title('Approximated (filtered) basis')
xticks([0,150],['0','150 ms'])

savefig(figure_prefix+'2.pdf',transparent=True,bbox_inches='tight')

# Estimate single-time marginal log-intensity and variance using several different procedures



In [None]:
# Reference moments: sample the point-process model directly, then
# smooth both returned series with a 5-sample box filter
demo_logxpp,demo_logvpp,_,_ = ensemble_sample_moments(
    stim,B,beta,M=10000)
demo_lxpp,demo_lvpp = box_filter(demo_logxpp,5),box_filter(demo_logvpp,5)

# Same quantities from the Langevin approximation of the point process
demo_logxlv,demo_logvlv,_,_ = langevin_sample_moments(
    stim,A,beta,C,M=10000)
demo_lxlv,demo_lvlv = box_filter(demo_logxlv,5),box_filter(demo_logvlv,5)

In [None]:
# Mean-field / linear noise approximation
demo_logxmf,demo_logvmf,_,_ = integrate_moments(
    stim,A,beta,C,method="LNA",int_method="euler")

# Second-order expansion: only first two moments of rate are used
demo_logxso,demo_logvso,_,_ = integrate_moments(
    stim,A,beta,C,method="second_order",int_method="euler")

# Moment closure
demo_logxmc,demo_logvmc,_,_ = integrate_moments(
    stim,A,beta,C,method="moment_closure",int_method="euler")

In [None]:
figure(figsize=(12,5))

# Every panel shows the sampled point process (black, filled) as the
# reference, overlaid with one approximation (colored, unfilled).

# Plot sampled Langevin against sampled point process
subplot(221)
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_lxlv,demo_lvlv,OCHRE,filled=0)
xlim(0,ndemo); noxyaxes()
title('Langevin')

# Plot LNA against sampled point process
# (was subplot(322), which put this panel on a 3x2 grid while the other
# three panels use the 2x2 grid)
subplot(222)
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_logxmf,demo_logvmf,TURQUOISE,filled=0)
xlim(0,ndemo); noxyaxes()
title('Mean-field, linear noise approximation')

# Plot moment-closure against sampled point process
subplot(223)
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_logxmc,demo_logvmc,RUST,filled=0)
xlim(0,ndemo); noy(); noxyaxes()
title('Gaussian moment-closure')

# Plot second-order against sampled point process
# This amounts to moment closure, where the GLM nonlinearity
# is locally approximated as a quadratic function. The log-
# Gaussian distribution on the rates is, in general, too right
# skewed, and over-estimates the probability of high rates. 
# The quadratic approximation removes the effects of higher
# order moments, attenuating this heavy tail and leading to
# a moment closure that is less stiff to integrate and more
# accurately captures the second moment.
subplot(224)
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_logxso,demo_logvso, AZURE,filled=0)
xlim(0,ndemo); noy(); noxyaxes();
title('Second-order approximation')


subplots_adjust(wspace=0.1,hspace=0.2)

# Demonstration pulse sequence

Construct a more "naturalistic" input stimulus

In [None]:
NSTIM    = 3000   # number of samples kept for the demo

# Random pulse sequence
offset     = -0.5
pulse_stimulus = pulse_sequence(linspace(0.5,1.0,5),int32(linspace(50,500,7)),offset)

# OU process defining additional Gaussian noise
ntrain = len(pulse_stimulus)
ssvar    = 1
tau      = 100
noisevar = 2*ssvar/tau
sigma    = sqrt(noisevar)
noise2   = sample_ou_process(0,sigma,tau,dt,ntrain,ntrial=1).ravel()

# Combine pulses with noise (no further filtering is applied here) and keep
# 2*NSTIM samples; the first NSTIM act as a discarded lead-in for the
# simulation below
stimulus = pulse_stimulus + noise2
stimulus = stimulus[:NSTIM*2]

# Plot the retained half of the stimulus.
# Was subplot(411): the other two panels use a 3-row grid, so the first
# panel now uses the same grid.
subplot(311); plot(stimulus[NSTIM:])
xlim(0,NSTIM); ylabel('pA'); noaxis(); nox()
title('Stimulus')

# Solve Izh model on the full stimulus, then drop the first NSTIM samples
state = sim_izh(*izh,stimulus,dt=dt)
V     = state[:,1][NSTIM:]
Y     = state[:,2][NSTIM:]
subplot(312); plot(V,color=OCHRE);
addspikes(Y)
xlim(0,NSTIM); noaxis(); nox()
title('Simulated voltage and spikes');
ylabel('mV');

# From here on `stimulus` refers to the retained half only
stimulus = stimulus[NSTIM:]

# Build stimulus filter (history trace of I)
# These are needed to model subthreshold dynamics
demo_Bh = array([convolve(b,stimulus) for b in B ]).T[:NSTIM]
demo_By = array([convolve(b,Y)        for b in Bp]).T[:NSTIM,:]

# Plot stimulus history features
subplot(313); plot(demo_Bh); 
xlim(0,NSTIM); noaxis(); xlabel('Time (ms)')
title('Stimulus history features')

# Demonstrate moment-closure approximation of the AR-PP-GLM

In [None]:
# Stimulus drive: fitted offset plus the filtered stimulus
stim = m + bhat_stimulus.dot(demo_Bh.T)

# Reference: sample the point-process model, then box-filter (width 5)
logxpp,logvpp,ratepp,ratevpp = ensemble_sample_moments(stim,B,beta,M=1000)
lxpp,lvpp = box_filter(logxpp,5),box_filter(logvpp,5)

# Langevin approximation of the point process, smoothed the same way
logxlv,logvlv,expmlv,expvlv = langevin_sample_moments(stim,A,beta,C,M=5000)
lxlv,lvlv = box_filter(logxlv,5),box_filter(logvlv,5)

# Second-order expansion (only first two moments of rate are used),
# exponential integrator
logxso,logvso,_,_ = integrate_moments(
    stim,A,beta,C,
    method="second_order",int_method="exponential",oversample=3)

# Moment closure, Euler integrator
logxmc,logvmc,_,_ = integrate_moments(
    stim,A,beta,C,
    method="moment_closure",int_method="euler",oversample=5)

# Mean-field / linear noise approximation, Euler integrator
logxmf,logvmf,_,_ = integrate_moments(
    stim,A,beta,C,
    method="LNA",int_method="euler",oversample=5)

In [None]:
# Figure 4: top two rows (panels A-D) compare each approximation with the
# sampled point process on the short demo pulse; the remaining rows show
# the longer pulse-sequence stimulus and sampled spike rasters.
figure(figsize=(8,13))
NROWS = 10     # rows in the subplot2grid layout
NPLOT = 3000
NSHOW = 1000   # number of samples shown on the x-axis of the long panels

def stimmarks():
    # Gray vertical lines at the start and end of the demo pulse
    axvline(padding,color=(0.5,)*3,lw=0.5)
    axvline(ndemo-padding,color=(0.5,)*3,lw=0.5)

# Plot second-order approximation
# (drawn first so its automatic y-limits can be reused for panels A-C)
subplot2grid((NROWS,2),(1,1),facecolor=(1,1,1,0))
stderrplot(demo_lxpp  ,demo_lvpp  ,BLACK    ,filled=1)
stderrplot(demo_logxso,demo_logvso,AZURE,filled=0)
stimmarks(); xlim(0,ndemo); noy(); noaxis()
yl = ylim()
xlabel('Time (ms)')
title('Second order')
subfigurelabel('D')
    
# Plot true sampled GLM against Langevin sampled GLM
subplot2grid((NROWS,2),(0,0),facecolor=(1,1,1,0))
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_lxlv,demo_lvlv,OCHRE,filled=0)
stimmarks(); ylim(*yl); xlim(0,ndemo); noxyaxes()
# NOTE(review): this scale bar divides by dB while the '20 dB' bars in the
# lower panels do not — confirm which convention is intended
yscalebar(mean(ylim())/dB,20/dB,'20 dB'); 
title('Langevin approximation')
subfigurelabel('A')

# Plot mean-field LNA solution for moments
subplot2grid((NROWS,2),(0,1),facecolor=(1,1,1,0))
stderrplot(demo_lxpp  ,demo_lvpp  ,BLACK,filled=1)
stderrplot(demo_logxmf,demo_logvmf,TURQUOISE,filled=0)
stimmarks(); ylim(*yl); xlim(0,ndemo); noxyaxes(); simpleraxis()
title('Mean-field, LNA')
subfigurelabel('B')

# Plot moment closure moments
subplot2grid((NROWS,2),(1,0),facecolor=(1,1,1,0))
stderrplot(demo_lxpp,demo_lvpp,BLACK,filled=1)
stderrplot(demo_logxmc,demo_logvmc,RUST,filled=0)
stimmarks(); ylim(*yl); xlim(0,ndemo); noy(); noaxis()
xlabel('Time (ms)')
title('Gaussian moment-closure')
subfigurelabel('C')

# Plot Izh neuron voltage with the scaled stimulus drawn above it
sc = 20   # vertical scale applied to the stimulus
ax4=subplot2grid((NROWS,2),(2,0),colspan=2,facecolor=(1,1,1,0))
plot(V,color=RUST,lw=1.25)
xlim(0,NSHOW);
yscalebar(min(V)+25,50,'50 mV')
# Shift the scaled stimulus so it sits above the current y-range.
# NOTE: this rebinds the global `offset` (the baseline current used when
# building stimuli in earlier cells).
offset = -min(stimulus*sc) + ylim()[1]+30
ss = stimulus*sc + offset
plot(ss,color=BLACK,lw=1)
yscalebar(mean(ss),sc*5,'5 pA'); 
noxyaxes()
title('Stimulus example')
subfigurelabel('E')

# Plot sampled point-process moments (black) against the
# Langevin approximation (ochre)
ax5=subplot2grid((NROWS,2),(3,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp,lvpp,BLACK,filled=1)
# Langevin approximation
stderrplot(logxlv,logvlv,OCHRE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Langevin approximation')
subfigurelabel('F')

# Plot sampled from Langevin (preserves some autocorrelation)
ax6=subplot2grid((NROWS,2),(4,0),colspan=2,facecolor=(1,1,1,0))
nsamp = 20
# Poisson counts drawn from the exponentiated Langevin samples
p = np.random.poisson(exp(langevin_sample(stim,A,beta,C,M=nsamp)))
# Show only whether each bin holds at least one event
pcolormesh(-int32(p.T>0),cmap='gray')
noaxis(); nox(); xlim(0,NSHOW);
ylabel('Sample',fontsize=9);

# Plot sampled point-process moments (black) against the
# second-order state-space moments (azure)
ax7=subplot2grid((NROWS,2),(5,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(logxso,logvso,AZURE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Second-order state-space model')
subfigurelabel('G')

# Plot sampled from single-time marginals (no time correlation)
ax8=subplot2grid((NROWS,2),(6,0),colspan=2,facecolor=(1,1,1,0))
# Rate built from the log-mean and log-variance as exp(mean)*(1+variance/2)
rate = exp(logxso)*(1+0.5*logvso)
p = np.random.poisson(rate[:,None],(len(rate),nsamp))
pcolormesh(-int32(p.T>0),cmap='gray')
noaxis(); xticks(arange(0,1001,100)); xlim(0,NSHOW);
xlabel('Time (ms)'); ylabel('Sample',fontsize=9); 

# Adjust axes (pixel-level nudges; values are hand-tuned for this layout)
subplots_adjust(wspace=0.1,hspace=0.3)
nudge_axis_y(-75,ax4)
adjust_axis_height_pixels(20,ax5)
nudge_axis_y(-85,ax5)
adjust_axis_height_pixels(30,ax6)
nudge_axis_y(-35,ax6)
adjust_axis_height_pixels(20,ax7)
nudge_axis_y(-40,ax7)
adjust_axis_height_pixels(30,ax8)
nudge_axis_y(10,ax8)

savefig(figure_prefix+'4.pdf',transparent=True,bbox_inches='tight',format='pdf')

# Measurement, filtering, and inference

In [None]:
# Number of leading time bins handed to the filtering routines in the
# cells that follow (they slice stim[:NFILT] and Y[:NFILT]). The trailing
# "#200" is a smaller alternative value left in place as a comment.
NFILT = NPLOT#200
#assert(NFILT<=NPLOT)

from measurements import *
from arppglm import *

## Test different updates (variational, Laplace, moment-matching)

In [None]:
# Integrate the second-order moment equations to obtain prior moments
logxso,logvso,M1,M2 = integrate_moments(stim,A,beta,C,
                                        method     = "second_order",
                                        int_method = "exponential",
                                        oversample = 3)

# Use the first time bin that contains a spike, and read the prior moments,
# the stimulus drive and the observation at that same bin.
t = find(Y>0)[0]
m1 = M1[t]
m2 = M2[t]
s  = stim[t]
# Was Y[i]: `i` was only defined as a leftover loop index from an earlier
# cell, so the observation did not correspond to time t.
y  = Y[t]

# Plot Prior: Gaussian over the projected state beta^T x
v  = beta.T.dot(m2).dot(beta)[0,0]
mu = beta.T.dot(m1)
ss = sqrt(v)
x = np.linspace(mu-4*ss,mu+6*ss,150)
plot(x,npdf(mu,v**0.5,x).ravel(),color=GREEN,label='prior')

# Posterior, exact: Poisson log-likelihood plus Gaussian log-prior,
# evaluated on the grid and normalized numerically
l = y*(x+s)-sexp(x+s)
l-= np.max(l)
l+=-.5*(x-mu)**2/v-.5*slog(v) 
p = sexp(l)
eps = 1e-12
p[p<eps] = eps   
p = (p/np.sum(p)).ravel()
plot(x,p/diff(x)[0],color=OCHRE,label='posterior')
# Mark the mode of the gridded posterior
pmode = argmax(p)
scatter([x[pmode]],[p[pmode]/diff(x)[0]],s=20,color=BLACK)
axvline(x[pmode],color=BLACK)

# Gaussian approximation from the moment-matching update
mp,vp = univariate_lgp_update_moment(mu,v,y,s,1.0)
l = -.5*(x-mp)**2/vp-.5*slog(vp) 
p = sexp(l-np.max(l))
eps = 1e-12
p[p<eps] = eps   
p /= np.sum(p)
plot(x,p/diff(x)[0],color=RUST,label='posterior moments')

# Gaussian approximation from the variational update
mp,vp = univariate_lgp_update_variational(mu,v,y,s,1.0)
l = -.5*(x-mp)**2/vp-.5*slog(vp) 
p = sexp(l-np.max(l))
eps = 1e-12
p[p<eps] = eps   
p /= np.sum(p)
plot(x,p/diff(x)[0],color=AZURE,label='posterior variational')

# Gaussian approximation from the Laplace update
mp,vp = univariate_lgp_update_laplace(mu,v,y,s,1.0)
l = -.5*(x-mp)**2/vp-.5*slog(vp) 
p = sexp(l-np.max(l))
eps = 1e-12
p[p<eps] = eps   
p /= np.sum(p)
plot(x,p/diff(x)[0],color=MAUVE,label='posterior Laplace')

legend()

# Benchmark different measurement updates

In [None]:
# Set the log-rate offset by hand. This replaces the value of `m` that was
# assigned earlier in the notebook; anything already computed from the old
# value (for example `stim`) is not recomputed here.
m = -5.1

In [None]:
figure(figsize=(10,8))

# GLM log-intensity built from the stimulus and spike-history features.
# It does not depend on the measurement update, so compute it once instead
# of once per panel.
logr = m + bhat_stimulus.dot(demo_Bh.T) + bhat_spikehist.dot(demo_By.T)

# One panel per measurement update; tic()/toc() time each filtering run
for i,measurement in enumerate("variational laplace moment".split()):
    tic()
    fallLR,fallLV,fallM1,fallM2,nll = filter_moments(stim[:NFILT],Y[:NFILT],A,beta,C,m,
                                         method      = "moment_closure",
                                         int_method  = "euler",
                                         measurement = measurement,
                                         oversample  = 25,
                                         reg_cov     = 0.0001,
                                         reg_rate    = 0.0000)
    toc()
    print(nll)
    subplot(4,1,i+1)
    stderrplot(fallLR,fallLV,BLACK,filled=1,lw=0.5)
    xlim(0,NFILT);
    plot(logr,color=RUST,lw=0.5)
    simpleraxis()
    title('method = %s'%measurement)
    ylim(-40,20)

# Compare to moment integration without filtering.
# (A stray subplot(4,1,i+1) call that re-selected the last loop panel via
# the leaked loop index was dropped; the panel is chosen explicitly below.)
tic()
fallLR,fallLV,fallM1,fallM2 = integrate_moments(stim[:NFILT],A,beta,C,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 4)
toc()
subplot(4,1,4)
stderrplot(fallLR,fallLV,BLACK,filled=1,lw=0.5)
xlim(0,NFILT);
plot(logr,color=RUST,lw=0.5)
simpleraxis()
title('No Measurements')

subplots_adjust(hspace=0.4)

## Optimize model likelihood using filtering (new model)

In [None]:
# Settings read as globals by objective() and forwarded to filter_moments
measurement = "moment"
method      = "moment_closure"
int_method  = "euler"
oversample  = 25
showplot    = False   # when True, objective() draws a plot on every call

# Offset in effect when `stim` is rescaled inside objective()
baseline_m = m

#@memoize
def objective(parameters):
    '''
    Negative log-likelihood of the observed spikes, obtained by filtering.

    Reads the notebook globals K, stim, Y, A, C, NFILT, baseline_m, method,
    int_method, measurement, oversample and showplot.

    Parameters
    ----------
    parameters : array-like, length 1+K
        Entry 0 is the offset `m`; the remaining K entries are the history
        weights, reshaped to a (K,1) column `beta`.

    Returns
    -------
    float
        The `nll` value returned by `filter_moments`, or `inf` when
        filtering raises an exception or yields a non-finite value.
    '''
    # parameters encode beta,m (asarray so that plain lists are accepted)
    parameters = np.asarray(parameters).ravel()
    m          = parameters[0]
    beta       = parameters[1:].reshape((K,1))
    # NOTE(review): `stim` is added to the log-rate elsewhere in this
    # notebook, so this multiplicative rescaling (rather than an additive
    # shift by m-baseline_m) looks suspicious — confirm it is intended.
    stim2      = stim[:NFILT]*exp(m-baseline_m)

    # We need unconstrained system to also be stable!
    # Disabled stability check, kept for reference:
    #   try:
    #       LR,LV,M1,M2 = integrate_moments(stim2,A,beta,C,
    #                                       method     = method,
    #                                       int_method = int_method,
    #                                       oversample = oversample)
    #       # Ensure moments do not diverge
    #       rate = exp(LR)*(1+0.5*LV)
    #       if not all([np.all(np.isfinite(x)) for x in [LR,LV,M1,M2,rate]]):
    #           return inf
    #   except (KeyboardInterrupt, SystemExit): raise
    #   except: return inf

    # Get likelihood by filtering
    try:
        LR,LV,M1,M2,nll = filter_moments(stim2,Y[:NFILT],A,beta,C,m,
                                         method      = method,
                                         int_method  = int_method,
                                         measurement = measurement,
                                         oversample  = oversample,
                                         reg_cov     = 0.005,
                                         reg_rate    = 0.005)
    except Exception:
        # Deliberate best effort: any failure inside the filter marks this
        # parameter vector as infeasible. KeyboardInterrupt and SystemExit
        # are not subclasses of Exception and therefore still propagate.
        return np.inf

    print(nll,'['+','.join(['%0.6f'%x for x in parameters])+']')
    if showplot:
        subplot(311)
        logr = m + bhat_stimulus.dot(demo_Bh.T) + bhat_spikehist.dot(demo_By.T)
        # Plot the moments computed in this call; the previous code drew the
        # globals fallLR/fallLV left over from other cells.
        stderrplot(LR,LV,BLACK,filled=1,lw=0.5)
        xlim(0,NFILT);
        plot(logr,color=RUST,lw=0.5)
        simpleraxis()
        show()

    if not np.isfinite(nll):
        return np.inf
    return nll

# Initial guess assembled from the current offset and history weights
p0 = np.zeros((1+K))
p0[0 ] = m
p0[1:] = beta.ravel()

# Some previous optimization results; might want to start here 
# NOTE: this assignment replaces the p0 assembled just above
p0 = [-5.11891304,7.73303371,-17.66706182,14.49299672,-8.02784155,4.92439071,-3.49133731,1.31580891,-0.97790629]
parameters = minimize_retry(objective,p0)
# Re-evaluate the objective at the returned parameters and report them
nll = objective(parameters)
print(nll,'['+','.join(['%0.8f'%x for x in parameters])+']')

# Inspect result

In [None]:
# Unpack the optimized parameter vector: offset (shifted by -0.5) and
# the (K,1) column of history weights
parameters = array(parameters)
m2,beta2   = parameters[0]-0.5,parameters[1:].reshape((K,1))

# Stimulus drive rescaled for the new offset
stim2 = stim*exp(m2-m)

# Log-intensity of the new model on the observed features
logr2 = m2 + bhat_stimulus.dot(demo_Bh.T) + beta2.T.dot(demo_By.T)[0]

# "True" sample from new point process model, box-filtered (width 5)
logxpp2,logvpp2,ratepp2,ratevpp2 = ensemble_sample_moments(
    stim2,B,beta2,M=1000)
lxpp2,lvpp2 = box_filter(logxpp2,5),box_filter(logvpp2,5)

In [None]:
# Four stacked panels: filtering with the original and the optimized
# parameters (rows 1-2), then plain moment integration without
# measurements for the same two parameter sets (rows 3-4).
NSHOW = 1000

# Row 1: filter with the original parameters (beta, m)
subplot(411)
fallLR,fallLV,fallM1,fallM2,nll = filter_moments(stim[:NFILT],Y[:NFILT],A,beta,C,m,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 5,
                                     reg_cov     = 0.001)
logr = m + bhat_stimulus.dot(demo_Bh.T) + bhat_spikehist.dot(demo_By.T)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr,color=RUST)
simpleraxis()
title('unconstrained')

# Row 2: filter with the optimized parameters (beta2, m2)
# NOTE(review): the black reference here is still lxpp/lvpp (sampled from
# the original model), not lxpp2/lvpp2 — confirm this is intended
subplot(412)
fallLR,fallLV,fallM1,fallM2,nll = filter_moments(stim2[:NFILT],Y[:NFILT],A,beta2,C,m2,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 5,
                                     reg_cov     = 0.001)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr2,color=RUST)
simpleraxis()

# Row 3: moment integration without measurements, original parameters
subplot(413)
fallLR,fallLV,fallM1,fallM2 = integrate_moments(stim[:NFILT],A,beta,C,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 5)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr,color=RUST)
simpleraxis()

# Row 4: moment integration without measurements, optimized parameters;
# the results are kept as logxso2/logvso2 for the next figure
subplot(414)
logxso2,logvso2,fallM1,fallM2 = integrate_moments(stim2[:NFILT],A,beta2,C,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 5)
stderrplot(lxpp2,lvpp2,BLACK,filled=1)
stderrplot(logxso2,logvso2,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr2,color=RUST)
simpleraxis()

In [None]:
# Compare the second-order state-space moments of the original model and
# of the re-optimized model against their sampled point-process moments.
figure(figsize=(8,13))

# Voltage trace with the scaled stimulus drawn above it
sc = 20   # vertical scale applied to the stimulus
ax4=subplot2grid((NROWS,2),(2,0),colspan=2,facecolor=(1,1,1,0))
plot(V,color=RUST,lw=1.25)
xlim(0,NSHOW);
yscalebar(min(V)+25,50,'50 mV')
# NOTE: rebinds the global `offset`
offset = -min(stimulus*sc) + ylim()[1]+30
ss = stimulus*sc + offset
plot(ss,color=BLACK,lw=1)
yscalebar(mean(ss),sc*5,'5 pA'); 
noxyaxes()
title('Stimulus example')
subfigurelabel('A')

# Fixed y-limits shared by the moment panels below
yl = (-40,20)

# Original model: sampled moments (black) vs second-order moments (azure)
ax5=subplot2grid((NROWS,2),(3,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(logxso,logvso,AZURE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Second-order state-space model')
subfigurelabel('B')

# Plot sampled from single-time marginals (no time correlation)
ax6=subplot2grid((NROWS,2),(4,0),colspan=2,facecolor=(1,1,1,0))
rate = exp(logxso)*(1+0.5*logvso)
p = np.random.poisson(rate[:,None],(len(rate),nsamp))
pcolormesh(-int32(p.T>0),cmap='gray')
noxyaxes();  xlim(0,NSHOW);
ylabel('Sample',fontsize=9); 

# Re-optimized model: sampled moments (black) vs second-order moments (azure)
ax7=subplot2grid((NROWS,2),(5,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp2,lvpp2,BLACK,filled=1)
stderrplot(logxso2,logvso2,AZURE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Second-order state-space model')
# NOTE(review): the labels in this figure run A, B, then G — 'G' may be a
# leftover from the figure this cell was copied from; confirm
subfigurelabel('G')

# Plot sampled from single-time marginals (no time correlation)
ax8=subplot2grid((NROWS,2),(6,0),colspan=2,facecolor=(1,1,1,0))
rate = exp(logxso2)*(1+0.5*logvso2)
p = np.random.poisson(rate[:,None],(len(rate),nsamp))
pcolormesh(-int32(p.T>0),cmap='gray')
noaxis(); xticks(arange(0,1001,100)); xlim(0,NSHOW);
xlabel('Time (ms)'); ylabel('Sample',fontsize=9); 


# Adjust axes (pixel-level nudges; values are hand-tuned for this layout)
subplots_adjust(wspace=0.1,hspace=0.3)
nudge_axis_y(-75,ax4)
adjust_axis_height_pixels(20,ax5)
nudge_axis_y(-85,ax5)
adjust_axis_height_pixels(30,ax6)
nudge_axis_y(-35,ax6)
adjust_axis_height_pixels(20,ax7)
nudge_axis_y(-40,ax7)
adjust_axis_height_pixels(30,ax8)
nudge_axis_y(10,ax8)

# Problems!

Optimization is WAY too slow. There's also something wrong (excessive bias toward high rates). But to figure out what's wrong, we need a faster update!

# The surrogate method

In [None]:
from arppglm import *
from measurements import *
from utilities import *
from arguments import *

In [None]:
def filter_moments_surrogate(stim,Y,A,beta,C,m,
    dt          = 1.0,
    oversample  = 10,
    maxrate     = 500,
    maxvcorr    = 2000,
    method      = "moment_closure",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = 0.01,
    reg_rate    = 0.001,
    return_surrogates = True,
    surrogates        = None):
    '''
    Filter the latent moments given observations, using surrogate
    likelihoods in the measurement update.

    In one mode (`return_surrogates=True`) the per-time-step surrogate
    likelihood parameters are computed and returned; in the other
    (`return_surrogates=False`) precomputed `surrogates` are consumed.

    Parameters
    ----------
    stim : sequence
        Per-time-step drive added to the projected state `beta.T.dot(M1)`.
    Y : sequence
        Observations, indexed in step with `stim`.
    A, C : arrays
        Linear system matrices; `A*dt/oversample` is passed to the
        integrator and update factories, `C` injects the rate into the state.
    beta : array with K entries, shaped (K,1)
        Projection from state to log-rate.
    m : scalar
        Offset forwarded to the measurement update (and to `pinv` when
        building the initial mean).
    dt : float
        Time step of one entry of `stim`.
    oversample : int
        Number of integration sub-steps per time step; must be >= 1.
    maxrate : float
        The projected log-rate is clipped at `log(maxrate)`.
    maxvcorr : float
        Forwarded to `get_update_function`.
    method, int_method, measurement : str
        Forwarded to `get_update_function`, `get_moment_integrator` and
        `get_measurement` respectively.
    reg_cov : float
        Diagonal regularization added to the covariance after each update.
    reg_rate : float
        Forwarded to the measurement update.
    return_surrogates : bool
        Compute and return surrogate likelihoods (requires `surrogates=None`).
    surrogates : array of shape (T,2) or None
        Precomputed surrogates (required when `return_surrogates=False`).

    Returns
    -------
    allLR, allLV : arrays of shape (T,)
        Projected log-rate mean and variance after each measurement update.
    allM1 : array of shape (T,K)
    allM2 : array of shape (T,K,K)
        State mean and covariance after each measurement update.
    nll : float
        Accumulated negative log-likelihood scaled by 1/T, or `inf` if the
        failure heuristic fired (entries after that point stay zero).
    surrogates : array of shape (T,2)
        Only returned when `return_surrogates=True`.

    Raises
    ------
    ValueError
        If `oversample < 1`, or if `return_surrogates` and `surrogates`
        are inconsistent with each other.
    '''
    # check arguments
    if oversample<1:
        raise ValueError('oversample must be a positive integer')
    if return_surrogates and not surrogates is None:
        raise ValueError('Asked to compute surrogate likelihoods, but surrogates provided?')
    if not return_surrogates and surrogates is None:
        raise ValueError('No surrogate likelihoods provided?')
    # Precompute constants
    maxlogr   = np.log(maxrate)
    dtfine    = dt/oversample
    T         = len(stim)
    K         = beta.size
    I         = np.eye(K)
    Cb        = C.dot(beta.T)
    CC        = C.dot(C.T)
    Adt       = A*dtfine
    # Get measurement update function
    measurement            = get_measurement(measurement)
    mean_update,cov_update = get_moment_integrator(int_method,Adt)
    update                 = get_update_function(method,Cb,Adt,maxvcorr)
    # accumulate negative log-likelihood up to a constant
    nll = 0
    llrescale = 1.0/len(stim)
    # Store moments. These buffers were previously never allocated here,
    # so the function only ran if same-named globals happened to exist.
    allM1 = np.zeros((T,K))
    allM2 = np.zeros((T,K,K))
    allLR = np.zeros((T))
    allLV = np.zeros((T))
    # Store surrogate likelihoods
    if return_surrogates:
        surrogates = np.zeros((T,2))
    # Initial condition for moments
    # NOTE(review): if `pinv` here is numpy's, its second positional
    # argument is a cutoff (rcond), not a right-hand side — confirm that
    # passing `m` this way is intended.
    M1 = pinv(beta,m).reshape((K,1))
    M2 = np.eye(K)*1e-2
    for i,s in enumerate(stim):
        # Integrate moments forward
        for j in range(oversample):
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)*dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C*Rm
        # Measurement update
        # If computing surrogates, request and store them
        if return_surrogates:
            M1,M2,ll,(mr,tr) = measurement_update_projected_gaussian_surrogate(\
                      M1,M2,Y[i],beta,s,dt,m,reg_rate,measurement,
                      return_surrogate=True)
            surrogates[i] = mr,tr
        # Otherwise, pass precomputed surrogate likelihood for update
        else:
            mr,tr = surrogates[i]
            M1,M2,ll = measurement_update_projected_gaussian_surrogate(\
                      M1,M2,Y[i],beta,s,dt,m,reg_rate,measurement,
                      surrogate=(mr, tr))
        nll -= ll*llrescale

        # Regularize: symmetrize and add enough to the diagonal to offset
        # any negative diagonal entry, plus reg_cov
        strength = reg_cov+max(0,-np.min(np.diag(M2)))
        M2 = 0.5*(M2+M2.T) + strength*I
        # Store moments
        allM1[i] = M1[:,0].copy()
        allM2[i] = M2.copy()
        allLR[i] = beta.T.dot(M1)+s
        allLV[i] = beta.T.dot(M2).dot(beta)

        # Heuristic: detect numerical failure and exit early.
        # The first test was `np.any(M1)<-1e5`, which compares a boolean
        # with -1e5 and is therefore always False; the comparison belongs
        # inside np.any.
        failed = bool(np.any(M1<-1e5)
                      or np.any(logx>100*maxlogr)
                      or np.any(nll<-1e10))
        if failed:
            nll = np.inf
            break
    if return_surrogates:
        return allLR,allLV,allM1,allM2,nll,surrogates
    else:
        return allLR,allLV,allM1,allM2,nll

In [None]:
def filter_likelihood_surrogate(stim,Y,A,beta,C,m,
    dt          = 1.0,
    oversample  = 10,
    maxrate     = 500,
    maxvcorr    = 2000,
    method      = "moment_closure",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = 0.01,
    reg_rate    = 0.001,
    surrogates  = None):
    '''
    Run the moment filter using precomputed surrogate likelihoods and
    return only the accumulated negative log-likelihood.

    For every time step the covariance is symmetrized and regularized,
    the moments are integrated forward `oversample` times with step
    `dt/oversample`, and the measurement update
    `measurement_update_projected_gaussian_surrogate` is applied with the
    stored surrogate for that step. Per-step moments are not stored.

    Parameters
    ----------
    stim : sequence
        One value per time step; added to `beta.T.dot(M1)` to form the
        log-rate.
    Y : sequence
        Observations; `Y[i]` is passed to the measurement update at step i.
    A : array
        Multiplied by `dt/oversample` and handed to
        `get_moment_integrator` and `get_update_function`.
    beta : array
        `beta.size` defines the state dimension K. Presumably a (K,1)
        column vector, as other calls in this notebook pass -- TODO confirm.
    C : array
        Used to form `C.dot(beta.T)` and `C.dot(C.T)`.
    m : scalar
        Passed to `pinv(beta,m)` for the initial mean and to the
        measurement update.
    dt : float
        Time step; also passed to the measurement update.
    oversample : int
        Number of integration sub-steps per time step.
    maxrate : float
        `log(maxrate)` caps the log-rate used during integration.
    maxvcorr : float
        Passed to `get_update_function`.
    method, int_method, measurement : str
        Selectors passed to `get_update_function`,
        `get_moment_integrator` and `get_measurement` respectively.
    reg_cov : float
        Added to the diagonal of the covariance at every step.
    reg_rate : float
        Passed to the measurement update.
    surrogates : sequence of pairs
        Required. `surrogates[i]` must unpack into the two surrogate
        parameters `(mr, tr)` for step i.

    Returns
    -------
    nll : float
        Sum over steps of `-ll/len(stim)`, or `np.inf` if the numerical
        failure heuristic triggers.
    '''
    assert surrogates is not None, \
        "filter_likelihood_surrogate requires precomputed surrogates"
    # Precompute constants
    maxlogr   = np.log(maxrate)
    dtfine    = dt/oversample
    K         = beta.size
    I         = np.eye(K)
    Cb        = C.dot(beta.T)
    CC        = C.dot(C.T)
    Adt       = A*dtfine
    # Get measurement update function
    measurement = get_measurement(measurement)
    # Build moment integrator functions
    mean_update, cov_update = get_moment_integrator(int_method,Adt)
    # Get update function (computes expected rate from moments)
    update = get_update_function(method,Cb,Adt,maxvcorr)
    # accumulate negative log-likelihood up to a constant
    nll = 0
    llrescale = 1.0/len(stim)
    # Initial condition for moments
    M1 = pinv(beta,m).reshape((K,1))
    M2 = np.eye(K)*1e-2
    for i,s in enumerate(stim):
        # Regularize: symmetrize, then lift the diagonal by reg_cov plus
        # whatever is needed to offset the most negative diagonal entry
        M2 = 0.5*(M2+M2.T)+(reg_cov+max(0,-np.min(np.diag(M2))))*I
        # Integrate moments forward
        for j in range(oversample):
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)*dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C*Rm
        # Measurement update using the stored surrogate likelihood
        mr,tr    = surrogates[i]
        M1,M2,ll = measurement_update_projected_gaussian_surrogate(\
                  M1,M2,Y[i],beta,s,dt,m,reg_rate,measurement,
                  surrogate=(mr, tr))
        nll -= ll*llrescale
        # Heuristic: detect numerical failure and exit early
        failed = False
        # Compare element-wise BEFORE reducing; np.any(M1)<-1e5 is always False
        failed|= np.any(M1<-1e5)
        failed|= logx>100*maxlogr
        failed|= nll<-1e10
        if failed:
            nll = np.inf
            break
    return nll

### Test it!

In [None]:
measurement = "moment"
method      = "second_order"
int_method  = "euler"
oversample  = 4
showplot    = False

# Pack the current parameters as [m, beta...]; used as the optimizer's
# starting point in the cell that defines `objective`
p0 = np.zeros((1+K))
p0[0 ] = m
p0[1:] = beta.ravel()

# Reference log-rate trace built from m, bhat_stimulus and bhat_spikehist.
# It is the same for every panel, so compute it once.
logr = m + bhat_stimulus.dot(demo_Bh.T) + bhat_spikehist.dot(demo_By.T)

# Keyword arguments shared by both filter_moments_surrogate calls
_filter_options = dict(
    method      = method,
    int_method  = int_method,
    measurement = measurement,
    oversample  = oversample,
    reg_cov     = 0.0001,
    reg_rate    = 0.0001)

def _plot_log_rate_panel(panel,log_rate,log_var,label):
    '''Draw one comparison panel: log-rate mean/variance band plus `logr`.'''
    subplot(panel)
    stderrplot(log_rate,log_var,BLACK,filled=1,lw=0.5)
    xlim(0,NFILT);
    plot(logr,color=RUST,lw=0.5)
    simpleraxis()
    title(label)

# 1) Filter with measurement updates, storing the surrogate likelihoods
tic()
LR,LV,M1,M2,nll,surrogates = filter_moments_surrogate(stim[:NFILT],Y[:NFILT],A,beta,C,m,
                                 return_surrogates = True,
                                 **_filter_options)
toc()
print(nll)
_plot_log_rate_panel(311,LR,LV,'With measurements')

# 2) Filter again, this time re-using the stored surrogates
tic()
LR,LV,M1,M2,nll = filter_moments_surrogate(stim[:NFILT],Y[:NFILT],A,beta,C,m,
                                 return_surrogates = False,
                                 surrogates  = surrogates,
                                 **_filter_options)
toc()
print(nll)
_plot_log_rate_panel(312,LR,LV,'Surrogate measurements')

# 3) Integrate the moments without measurements. integrate_moments returns no
#    likelihood, so nothing is printed here (the original re-printed the stale
#    nll from step 2).
tic()
fallLR,fallLV,fallM1,fallM2 = integrate_moments(stim[:NFILT],A,beta,C,
                                     method      = method,
                                     int_method  = int_method,
                                     oversample  = oversample)
toc()
_plot_log_rate_panel(313,fallLR,fallLV,'No measurements')


In [None]:
# Offset of the model the surrogates were computed with; the stimulus is
# rescaled relative to it inside the objective.
baseline_m = m

def objective(parameters):
    '''
    Negative log-likelihood of the surrogate filter for packed parameters.

    Parameters
    ----------
    parameters : array
        Flattened to 1-D; element 0 is the offset `m`, the remaining K
        elements are reshaped into the (K,1) vector `beta`.

    Returns
    -------
    nll : float
        Value returned by `filter_moments_surrogate` (with the stored
        `surrogates`), or `np.inf` if the filter raised an exception.
    '''
    # parameters encode beta,m
    parameters = parameters.ravel()
    m    = parameters[0]
    beta = parameters[1:].reshape((K,1))

    # Get likelihood by filtering
    try:
        _,_,_,_,nll = filter_moments_surrogate(
            stim[:NFILT]*exp(m-baseline_m),
            Y[:NFILT],
            A,beta,C,m,
            method      = method,
            int_method  = int_method,
            measurement = measurement,
            oversample  = oversample,
            reg_cov     = 0.0001,
            reg_rate    = 0.0001,
            return_surrogates = False,
            surrogates  = surrogates)
    except Exception:
        # KeyboardInterrupt/SystemExit are not Exception subclasses, so they
        # still propagate; any other failure is reported and scored as inf
        # so the optimizer can move on.
        import traceback
        traceback.print_exc()
        nll = np.inf

    print(nll,'['+','.join(['%0.6f'%x for x in parameters])+']')
    return nll

parameters = minimize_retry(objective,p0)
print('['+','.join(['%0.4f'%x for x in parameters])+']')

# Inspect dynamics of new model

In [None]:
# Presumably a saved result from an earlier optimization run; uncomment to
# use it instead of the freshly optimized `parameters` -- TODO confirm.
#parameters = [-6.8034,-14.5821,-12.4526,19.5295,-11.7096,9.0415,-4.9698,-1.2497,-0.5224]

# Get new parameters from optimization
# Layout is [offset, beta...]: element 0 is the offset, the rest is beta.
parameters = array(parameters)
# NOTE(review): the offset is shifted by a hard-coded -0.5; the reason is not
# evident from this cell -- confirm and document.
m2         = parameters[0]-0.5
beta2      = parameters[1:].reshape((K,1))
# Reference log-rate trace for the new parameters (same construction as
# `logr`, with beta2 in place of bhat_spikehist)
logr2      = m2 + bhat_stimulus.dot(demo_Bh.T) + (beta2.T.dot(demo_By.T))[0]

# Rescale the stimulus by exp(m2-m) to go with the changed offset
stim2 = stim*exp(m2-m)
# "True" sample from new point process model
logxpp2,logvpp2,ratepp2,ratevpp2 = ensemble_sample_moments(stim2,B,beta2,M=1000)
# Smooth the sampled log-rate mean and variance with a width-5 box filter
lxpp2 = box_filter(logxpp2,5)
lvpp2 = box_filter(logvpp2,5)

In [None]:
# Number of time points shown on the x-axis (also used by the figure cell below)
NSHOW = 1000

# Panel 1: filter with measurements (Y), original parameters (beta, m)
subplot(411)
fallLR,fallLV,fallM1,fallM2,nll = filter_moments(stim[:NFILT],Y[:NFILT],A,beta,C,m,
                                     method      = "second_order",
                                     int_method  = "euler",
                                     oversample  = 5,
                                     reg_cov     = 0.001)
logr = m + bhat_stimulus.dot(demo_Bh.T) + bhat_spikehist.dot(demo_By.T)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr,color=RUST)
simpleraxis()
title('unconstrained')

# Panel 2: filter with measurements (Y), re-optimized parameters (beta2, m2)
# and the rescaled stimulus stim2; still overlaid on lxpp/lvpp
subplot(412)
fallLR,fallLV,fallM1,fallM2,nll = filter_moments(
    stim2[:NFILT],Y[:NFILT],A,beta2,C,m2,
    method      = "second_order",
    int_method  = "euler",
    oversample  = 5,
    reg_cov     = 0.001)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr2,color=RUST)
simpleraxis()

# Panel 3: moment integration without measurements (integrate_moments takes
# no Y), original parameters
subplot(413)
fallLR,fallLV,fallM1,fallM2 = integrate_moments(
    stim[:NFILT],A,beta,C,
    method      = "second_order",
    int_method  = "euler",
    oversample  = 5)
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(fallLR,fallLV,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr,color=RUST)
simpleraxis()

# Panel 4: moment integration without measurements, re-optimized parameters.
# The results are kept as logxso2/logvso2 because the figure cell below reads them.
# NOTE(review): every other quantity in this panel belongs to the new model,
# but the overlay is `logr`, not `logr2` -- confirm this is intended.
subplot(414)
logxso2,logvso2,fallM1,fallM2 = integrate_moments(
    stim2[:NFILT],A,beta2,C,
    method      = "second_order",
    int_method  = "euler",
    oversample  = 5)
stderrplot(lxpp2,lvpp2,BLACK,filled=1)
stderrplot(logxso2,logvso2,AZURE,filled=0)
xlim(0,NSHOW);
plot(logr,color=RUST)
simpleraxis()

In [None]:
# Multi-panel figure: stimulus/voltage example, then for the original and the
# re-optimized model a log-rate comparison panel and a sampled-spike raster.
# NROWS, V, stimulus, nsamp, lxpp/lvpp and logxso/logvso are expected to be
# defined by earlier cells (not visible here).
figure(figsize=(8,13))

# Display-only scale factor applied to the stimulus so it can share the
# voltage axis; the '5 pA' scale bar below is drawn with height sc*5
sc = 20
ax4=subplot2grid((NROWS,2),(2,0),colspan=2,facecolor=(1,1,1,0))
plot(V,color=RUST,lw=1.25)
xlim(0,NSHOW);
yscalebar(min(V)+25,50,'50 mV')
# Shift the scaled stimulus so it sits above the current top of the axis
offset = -min(stimulus*sc) + ylim()[1]+30
ss = stimulus*sc + offset
plot(ss,color=BLACK,lw=1)
yscalebar(mean(ss),sc*5,'5 pA'); 
noxyaxes()
title('Stimulus example')
subfigurelabel('A')

# Shared y-limits for the two log-rate panels
yl = (-40,20)

# Log-rate comparison, original model: sampled moments (lxpp, lvpp, filled)
# against the second-order state-space moments (logxso, logvso, outline)
ax5=subplot2grid((NROWS,2),(3,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp,lvpp,BLACK,filled=1)
stderrplot(logxso,logvso,AZURE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Second-order state-space model')
subfigurelabel('B')

# Plot sampled from single-time marginals (no time correlation)
ax6=subplot2grid((NROWS,2),(4,0),colspan=2,facecolor=(1,1,1,0))
# Second-order approximation of the mean rate: exp(mean)*(1 + variance/2)
rate = exp(logxso)*(1+0.5*logvso)
# Independent Poisson draws per time bin, nsamp samples each
p = np.random.poisson(rate[:,None],(len(rate),nsamp))
# Negate the spike indicator so bins with spikes take the lowest value and
# render dark under the 'gray' colormap
pcolormesh(-int32(p.T>0),cmap='gray')
noxyaxes();  xlim(0,NSHOW);
ylabel('Sample',fontsize=9); 

# Log-rate comparison, re-optimized model: sampled moments (lxpp2, lvpp2)
# against the second-order state-space moments (logxso2, logvso2)
# NOTE(review): panel labels jump from 'B' to 'G' and the raster panels are
# unlabeled -- presumably left over from a larger figure; confirm.
ax7=subplot2grid((NROWS,2),(5,0),colspan=2,facecolor=(1,1,1,0))
stderrplot(lxpp2,lvpp2,BLACK,filled=1)
stderrplot(logxso2,logvso2,AZURE,filled=0)
xlim(0,NSHOW); ylim(*yl); yscalebar(mean(ylim()),20,'20 dB'); noxyaxes(); 
title('Second-order state-space model')
subfigurelabel('G')

# Plot sampled from single-time marginals (no time correlation)
ax8=subplot2grid((NROWS,2),(6,0),colspan=2,facecolor=(1,1,1,0))
rate = exp(logxso2)*(1+0.5*logvso2)
p = np.random.poisson(rate[:,None],(len(rate),nsamp))
pcolormesh(-int32(p.T>0),cmap='gray')
# Only the bottom panel keeps an x-axis; ticks are hard-coded for 0..1000
noaxis(); xticks(arange(0,1001,100)); xlim(0,NSHOW);
xlabel('Time (ms)'); ylabel('Sample',fontsize=9); 


# Adjust axes
# Manual layout: the nudges and height changes are applied in sequence to
# individual axes, so keep this order when editing.
subplots_adjust(wspace=0.1,hspace=0.3)
nudge_axis_y(-75,ax4)
adjust_axis_height_pixels(20,ax5)
nudge_axis_y(-85,ax5)
adjust_axis_height_pixels(30,ax6)
nudge_axis_y(-35,ax6)
adjust_axis_height_pixels(20,ax7)
nudge_axis_y(-40,ax7)
adjust_axis_height_pixels(30,ax8)
nudge_axis_y(10,ax8)

# Derivation of variational update

Match a multivariate Gaussian to the true measurement posterior using variational Bayes: minimize the KL divergence from the approximating to the true distribution.

$$
\Pr(x) = \left\lvert 2\pi\Sigma \right\rvert ^{-\frac 1 2} \exp\left({-\frac 1 2 (x-\mu)^\top \Sigma ^{-1} (x-\mu)}\right)
$$

The observation model is Poisson 

$$
\Pr(y|x) = \frac 1 {y!} \lambda^y e^{-\lambda}
$$

$$
\ln\lambda = \beta^\top x + I(t)
$$

For a variational update, we minimize

$$
D_{KL}( Q \| \Pr(x|y) ) = 
\int_x Q(x) \ln \frac {Q(x)} {\Pr(x|y)} = \left< \ln \frac {Q(x)} {\Pr(x|y)} \right>_Q
$$

Where $Q(x)$ is the approximating variational posterior, which we will take to be multivariate normal $Q(x) \sim \mathcal{N}(\tilde\mu,\tilde\Sigma)$.


Since $\Pr(x|y) \propto \Pr(y|x) \Pr(x)$, and the normalizer $\Pr(y)$ does not depend on $Q$, we can minimize $D_{KL}$ by minimizing

$$
\left< \ln \frac {Q(x) } {\Pr(y|x)  {\Pr(x)}  }
\right>_Q
$$

The prior and the likelihood can be separated, and we can focus on minimizing

$$
D_{KL}\left( Q \| P \right)
-
\left< \ln \Pr(y|x) \right>_Q
$$

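The divergence $D_{KL}\left( Q \| P \right)$ between two Gaussians, here the posterior approximation $Q = \mathcal{N}(\tilde\mu,\tilde\Sigma)$ and the prior $P = \mathcal{N}(\mu,\Sigma)$, has the form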

$$
D_{KL}\left( Q \| P \right) = 
\frac 1 2 \left[tr( \Sigma^{-1} \tilde\Sigma )
+ (\mu - \tilde\mu)^\top \Sigma^{-1} (\mu - \tilde\mu)
+ \ln \frac {|\Sigma|}{|\tilde\Sigma|}
\right] + \text{constant}
$$

Consider then the second term, $\left< \ln \Pr(y|x) \right>_Q$

$$
\ln\Pr(y|x) = -\ln y! + y \ln \lambda - \lambda
$$

For the purposes of optimizing the posterior, $-\ln y!$ is constant and can be ignored. More generally, any term that does not depend on $x$ can be dropped, since it is constant with respect to optimizing the posterior. Taking the expectation over $Q$, we get

$$
\left<\ln\Pr(y|x)\right>_Q = 
y \beta^\top \tilde\mu - \left< \lambda \right>_Q + \text{constant}
$$

The expectation $\left< \lambda \right>_Q$ looks familiar! It is the same expectation that we use for estimating $\left<\lambda\right>$ in the moment closure, and we can use the same approximation for evaluating it, this time at the variational posterior $Q$. 

$$
\left< \lambda \right>_Q = \exp\left( 
\beta^\top \tilde\mu + I(t) + \tfrac 1 2 \beta^\top \tilde \Sigma \beta
\right)
$$ 

Overall, we need to minimize

$$
\frac 1 2 \left[tr( \Sigma^{-1} \tilde\Sigma )
+ (\tilde\mu - \mu)^\top \Sigma^{-1} (\tilde\mu - \mu)
+ \ln \frac {|\Sigma|}{|\tilde\Sigma|}
\right]
-\left[
y \beta^\top \tilde\mu - \left< \lambda \right>_Q 
\right]
$$


The gradient in $\tilde\mu$ is

$$
\nabla_{\tilde\mu}
\left[\dots
\right] = 
\Sigma^{-1} (\tilde\mu - \mu)
+\left(\left< \lambda \right>_Q - y\right)\beta
$$

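The Hessian in $\tilde\mu$ follows by differentiating the gradient once more, using $\nabla_{\tilde\mu}\left<\lambda\right>_Q = \left<\lambda\right>_Q \beta$:

$$
\nabla^2_{\tilde\mu} = \Sigma^{-1} + \left<\lambda\right>_Q \beta\beta^\top
$$

For fixed $\tilde\mu$, setting the gradient in $\tilde\Sigma$ to zero shows that the optimal $\tilde\Sigma$ satisfies (Zhao and Park 2016)

$$
\tilde\Sigma^{-1} = \Sigma^{-1} + \left<\lambda\right>_Q \beta\beta^\top
$$

Note that $\left<\lambda\right>_Q$ itself depends on $\tilde\Sigma$, so this is a fixed-point condition to be iterated.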

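This provides a coordinate-descent approach to obtaining the variational posterior: alternate between a Newton step in $\tilde\mu$ and the update of $\tilde\Sigma$ until both converge.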

How to backpropagate across a non-conjugate update. 
For a non-conjugate update, we estimate

$$
Q(x) \approx P(x|y) = P(y|x) \frac {P(x)} {P(y)}
$$

Usually $Q(x)\sim\mathcal{N}(\hat\mu,\hat\Sigma)$ is multivariate Gaussian. We estimate the likelihood $P(y)$ as

$$
P(y) \approx \frac{P(y|x{=}\hat\mu)} {Q(x{=}\hat\mu)} P(x{=}\hat\mu)
$$

Or the log-likelihood

$$
\ln P(y) \approx \ln P(y|x{=}\hat\mu) - \ln Q(x{=}\hat\mu) + \ln P(x{=}\hat\mu)
$$

The likelihood and state-update are usually differentiable in parameters $\theta$. The derivative of the posterior, which is the solution to a minimization problem, is less clear. 

$$
\nabla_\theta \ln Q(x{=}\hat\mu)
$$

If $Q$ is Gaussian evaluated at its mean, the quadratic term vanishes and only the normalizer remains:

$$
\ln Q(x{=}\hat\mu) = -\frac 1 2 \ln \left\lvert 2 \pi \hat \Sigma \right\rvert
$$

To backpropagate / chain rule, we need to know

$$
\nabla_\theta \ln \left\lvert \hat \Sigma \right\rvert
$$


In [None]:
# Floating-point precision that the limits below are derived from
dtype = 'float32'
# Smallest positive normal number representable in `dtype`
# NOTE(review): presumably a floor applied before taking logs -- usage is
# outside the visible part of this cell; confirm.
min_log = np.finfo(dtype).tiny
# log(sqrt(largest finite `dtype` value)), i.e. exp(max_exp)**2 is roughly the
# largest finite value. Presumably an upper clip for exp() arguments -- confirm.
max_exp = np.log(np.sqrt(np.finfo(dtype).max))