# Set up environment

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
# Python-2 compatibility imports (all of these features are always on in Python 3)
from __future__ import absolute_import
from __future__ import with_statement
from __future__ import division
from __future__ import nested_scopes
from __future__ import generators
from __future__ import unicode_literals
from __future__ import print_function

# Load scipy/numpy/matplotlib
from   scipy.linalg import expm
import matplotlib.pyplot as plt
from   pylab import *
import numpy as np
import scipy.io
from warnings import warn

# Configure figure resolution
plt.rcParams['figure.figsize'] = (12.0, 6.0)
plt.rcParams['savefig.dpi'   ] = 100

# Project-local modules. NOTE(review): the wildcard imports hide where helpers
# used below (e.g. sexp, scalar, ascolumn, asvector, stderrplot, tic/toc and
# colour constants such as RUST/OCHRE/BLACK) are defined.
from izh       import * # Routines for sampling Izhikevich neurons
from plot      import * # Misc. plotting routines
from glm       import * # GLM fitting
from arppglm   import * # Sampling and integration
from utilities import * # Other utilities
from arguments import * # Argument verification

# Disabled alternative Theano configuration (float64 on the legacy `gpu`
# device). It is an unused string literal, so none of it executes.
'''
import os
dtype='float64'
flags = 'mode=FAST_RUN,device=gpu,floatX=%s'%dtype
if dtype!='float64':
    flags += ',warn_float64=warn'
os.environ["THEANO_FLAGS"] = flags
import theano
import theano.tensor as T
'''

# Active Theano configuration. THEANO_FLAGS is assigned before `import theano`
# below (Theano reads that variable when it is imported).
import os
dtype='float32'
# NOTE(review): presumably avoids an MKL / GNU OpenMP threading conflict -- confirm.
os.environ['MKL_THREADING_LAYER']='GNU'
# NOTE(review): everything after `#` on the next line is a comment, so floatX is
# never passed to Theano even though dtype is 'float32'; the trailing comma also
# leaves an empty ',,' entry once the float64 warning flag is appended below.
flags = 'mode=FAST_COMPILE,device=cuda0,'#,floatX=%s'%dtype
if dtype!='float64':
     flags += ',warn_float64=warn'
os.environ["THEANO_FLAGS"] = flags

import theano
import theano.tensor as T

# Presumably provides the build_*_theano routines used in later cells
from theano_arppglm import *

print('Workspace Initialized')

# Load saved GLM features

In [None]:
# Pick the saved model to analyse
#filename = 'saved_training_model.mat'
filename = 'saved_training_model_badburster.mat'

saved_training_model = scipy.io.loadmat(filename)

# Model matrices and constants, cast to the working dtype from the setup cell
K, B, By, Bh, A, C, Y, dt = (
    np.array(saved_training_model[key], dtype=dtype)
    for key in ('K', 'B', 'By', 'Bh', 'A', 'C', 'Y', 'dt'))

# Training / test design matrices: By_* columns followed by Bh_* columns
# (the GLM parameter vector uses the same order), plus the spike counts
Bh_train = saved_training_model['Bh_train']
By_train = saved_training_model['By_train']
X_train  = concatenate([By_train,Bh_train],axis=1)
Y_train  = asvector(saved_training_model['Y_train'])

Bh_test  = saved_training_model['Bh_test']
By_test  = saved_training_model['By_test']
X_test   = concatenate([By_test,Bh_test],axis=1)
Y_test   = asvector(saved_training_model['Y_test'])

# loadmat returns MATLAB scalars as 2-D arrays; unwrap K to a plain int
K  = int(scalar(K))

# N = number of training bins; per-bin arrays in later cells use this length
N = len(X_train)
assert len(Y_train) == N, 'X_train and Y_train have different lengths'
STARTPLOT=0
NPLOT=N

print('Saved GLM features loaded')
print(N)

# Plotting window used by niceaxis(); narrow it to zoom in
#STARTSHOW = 14000
#STOPSHOW = 16000
STARTSHOW = 0
STOPSHOW = N

### GLM helpers

In [None]:
def lograte(Bh,By,p):
    '''
    GLM log-intensity for every time bin, computed directly from the
    observed features (the standard GLM prediction).

    Parameter layout: p[0] is the offset, p[1:K+1] weights the By
    columns and p[K+1:] weights the Bh columns.
    '''
    offset = array(p).ravel()[0]
    w_by   = ascolumn(p[1:K+1])
    w_bh   = ascolumn(p[1+K:])
    # Same summation order as offset + Bh-term + By-term
    with_bh = offset + Bh.dot(w_bh)
    return with_bh + By.dot(w_by)

def logmean(Bh,M1,p):
    '''
    Mean-field log-intensity: the history term is evaluated at the
    history-process means M1 (one row per bin) rather than at the
    observed spike history. Returns a 1-D array.
    '''
    offset = array(p).ravel()[0]
    w_hist = ascolumn(p[1:K+1])
    w_stim = ascolumn(p[1+K:])
    means  = np.squeeze(M1)
    history_term  = (w_hist.T.dot(means.T))[0]
    stimulus_term = (offset + Bh.dot(w_stim))[:,0]
    return history_term + stimulus_term

def get_stim(Bh,p):
    '''
    Stimulus-driven part of the log-intensity: the offset p[0] plus the
    Bh-feature term (weights p[K+1:]). The history term is excluded.
    Returned as a 1-D array with one entry per bin.
    '''
    offset = array(p).ravel()[0]
    w_stim = ascolumn(p[1+K:])
    return (offset + Bh.dot(w_stim))[:,0]

def filter_GLM_np(Bh,p):
    '''
    Mean-field integration of the history process in numpy.

    The history state M1 (K x 1) is advanced once per bin with a forward-Euler
    step of A, driven by the rate R = sexp(beta' M1 + stim[i]). Here stim comes
    from get_stim and already includes the offset p[0], and beta = p[1:K+1].

    Returns an (n, K) array holding M1 after each of the n = len(Bh) bins.
    '''
    beta = ascolumn(p[1:K+1])
    stim = get_stim(Bh,p)
    n    = len(stim)
    allM1_np = np.zeros((n,K))
    M1       = np.zeros((K,1))
    for i in range(n):
        # Use the history weights from the argument p (not the global p0), and
        # do not add the offset again: get_stim has already included it.
        R   = scalar(sexp(beta.T.dot(M1)+stim[i]))
        M1 += A.dot(M1)*dt + C.dot(R)
        allM1_np[i] = M1[:,0]
    return allM1_np

def addspikes_(Y_=None):
    '''
    Draw a thin vertical line at every bin with a positive spike count.

    Y_ : spike counts to mark; None or True means "use the global Y".
    '''
    if Y_ is None or Y_ is True:
        Y_ = Y
    # np.flatnonzero gives the flat indices of nonzero entries, like pylab's
    # former `find` (matplotlib.mlab.find, removed in matplotlib 3.1).
    for t in np.flatnonzero(np.asarray(Y_)>0):
        axvline(t,color=OCHRE,lw=0.4)
    
def niceaxis(plotspikes=True):
    '''
    Tidy the current axes: legend (if anything is labelled), simplified
    axes, x-range STARTSHOW..STOPSHOW, and optional spike markers.

    plotspikes : True to mark spikes from the global Y, an array of counts
        to mark those instead, or False/None to draw no spike markers.
    '''
    import warnings
    with warnings.catch_warnings():
        # legend() warns when no artist is labelled; that case is expected
        warnings.filterwarnings("ignore",message='No labelled objects found')
        legend()
    simpleraxis()
    xlim(STARTSHOW,STOPSHOW)
    # Both None and False disable the markers (a plain `is not None` test
    # would still draw them for False).
    if plotspikes is not None and plotspikes is not False:
        addspikes_(plotspikes)

print('GLM helpers done')

In [None]:
# Re-fit the GLM on the training design matrix
m,bhat  = fitGLM(X_train,asvector(Y_train))

# Pack the parameters as [offset, weights...]
p0 = np.zeros((1+len(bhat)))
p0[0], p0[1:] = m, bhat

# Compare the GLM log-intensity with its mean-field approximation
allM1_np = filter_GLM_np(Bh_train,p0)
train_lograte = lograte(Bh_train,By_train,p0)

subplot(311)
plot(train_lograte,lw=0.4,label='conditional intensity')
plot(logmean(Bh_train,allM1_np,p0),lw=0.4,label='mean-field',color=RUST)
niceaxis()
ylim(min(train_lograte),5)

# Filtering 

### Parameters at which to filter

In [None]:
# Integration / filtering settings
oversample = 10     # integration sub-steps per data bin
maxrate    = 10.0   # rates are clipped to at most this value
maxvcorr   = 10.0   # cap on the variance correction of the expected rate
# NOTE(review): this replaces the dt loaded from the .mat file -- confirm the
# saved model uses a data time resolution of 1.0
dt         = 1.0
reg_cov, reg_rate = 1e-5, 1e-5

p = p0.copy()
#p[1:K+1] *= 0.775
stim_np = get_stim(Bh_train,p)
beta_np = ascolumn(p[1:K+1])
print('Filtering using p=',v2str(p))

# Helper function to compute negative expected log-likelihood
def post_hoc_nll(LR,LV,Y_=None):
    '''
    Negative expected Poisson-form log-likelihood of spike counts, given a
    Gaussian estimate (mean LR, variance LV) of the log-rate in each bin.

    The expected rate uses the second-order approximation sexp(LR)*(1+LV/2).

    Y_ : spike counts to score. Defaults to the global Y; pass it explicitly,
         because later cells rebind Y.
    '''
    if Y_ is None:
        Y_ = Y
    R0 = sexp(LR)
    R1 = R0*(1+0.5*LV)
    ELL  = np.mean(Y_*LR - R1)
    return -ELL

### Build Theano routines

For integrating moments (not conditioned on data), filtering (conditioned on data), and filtering using surrogate likelihoods (Gaussian approximations).

### Build Theano routines

For integrating moments (not conditioned on data), filtering (conditioned on data), and filtering using surrogate likelihoods (Gaussian approximations).

In [None]:
from theano_arppglm import *

# Compile the maximum-likelihood GLM objective. The f/g/h suffixes presumably
# denote value/gradient/Hessian -- confirm in theano_arppglm.
(GLM_log_intensity,
 GLMNLL_f,
 GLMNLL_g,
 GLMNLL_h) = build_ML_GLM_likelihood_theano()

In [None]:
# Compile the moment integrator (moments are not conditioned on the data)
_integrate_opts = dict(
    dt         = dt,
    oversample = oversample,
    maxrate    = maxrate,
    maxvcorr   = maxvcorr,
    method     = "second_order",
    int_method = "euler",
)
(integrate_moments_theano,
 EMNLL_filt,
 EMNLL_grad) = build_integrate_moments_theano(N, A, C, **_integrate_opts)

In [None]:
# Compile the moment filter (conditioned on data) that also returns surrogates
_filter_build_opts = dict(
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    return_surrogates = True,
)
(filter_moments_theano,
 NLL_filt,
 NLL_grad) = build_filter_moments_theano(N, A, C, **_filter_build_opts)

In [None]:
# Compile the filter variant that consumes surrogate (Gaussian) likelihoods
_surrogate_build_opts = dict(
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    return_surrogates = False,
    use_surrogates    = True,
)
(filter_surrogate_theano,
 SNLL_filt,
 SNLL_grad) = build_filter_moments_theano(N, A, C, **_surrogate_build_opts)

print('Theano functions bulit')

## Integrate to establish prior

We will use the integrated moments (computed without measurements) as a prior distribution for our depth-D filtering with measurement updates.

In [None]:
# Unpack p: offset m, history weights (beta, also kept as beta_np) and
# Bh weights beta_st; stim is the offset plus the Bh term for every bin
m       = array(p).ravel()[0]
beta    = ascolumn(p[1:K+1])
beta_st = ascolumn(p[1+K:])
beta_np = ascolumn(p[1:K+1])
stim_np = stim = (m + Bh_train.dot(beta_st))[:,0]

In [None]:
p = p0.copy()
#p[1:]*=3

# Unpack p (same layout as the GLM helpers)
m        = array(p).ravel()[0]
beta     = ascolumn(p[1:K+1])
beta_st  = ascolumn(p[1+K:])
beta_np  = ascolumn(p[1:K+1])
stim_np  = stim = (m + Bh_train.dot(beta_st))[:,0]

print('Filtering using p=',v2str(p))

def _show_integrated(panel,LR,LV,label):
    # One panel: log-rate mean with its spread, spike markers and a title
    subplot(panel)
    stderrplot(LR,LV,color=BLACK,lw=0.5)
    niceaxis()
    xlim(STARTSHOW,STOPSHOW)
    title(label)

# Moments integrated without measurements: numpy reference
tic()
allLRni,allLVni,allM1ni,allM2ni = integrate_moments(stim_np,A,beta_np,C,
    dt=dt, oversample=oversample, maxrate=maxrate, maxvcorr=maxvcorr,
    method="second_order", int_method="euler")
toc()
_show_integrated(411,allLRni,allLVni,'Integrating, numpy')

# Same quantity from the compiled theano routine
tic()
allLRti,allLVti,allM1ti,allM2ti = integrate_moments_theano(Bh_train,p)
toc()
_show_integrated(412,allLRti,allLVti,'Integrating, theano')

tight_layout()

# Deep filtering in Numpy

In [None]:
def _show_filtered(panel,LR,LV,label):
    # One panel: log-rate mean with its spread, spike markers and a title
    subplot(panel)
    stderrplot(LR,LV,color=BLACK,lw=0.5)
    niceaxis()
    xlim(STARTSHOW,STOPSHOW)
    title(label)

# Settings for the deep (full-length) numpy filter
_deep_filter_opts = dict(
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    return_surrogates = True,
)

tic()
allLRn,allLVn,allM1n,allM2n,nlln,mrn,vrn = filter_moments(
    stim,Y_train,A,beta,C,p[0],**_deep_filter_opts)
toc()
_show_filtered(411,allLRn,allLVn,'Filtering, numpy')
print('nll, numpy',nlln)

tic()
allLRt,allLVt,allM1t,allM2t,nllt,mrt,vrt = filter_moments_theano(Bh_train,Y_train,p)
toc()
_show_filtered(412,allLRt,allLVt,'Filtering, theano')
print('nll, theano',nllt)

# Overlay both log-rate means to check the implementations agree
subplot(413)
plot(allLRn,color=BLACK,label='log-λ numpy')
plot(allLRt,':',color=RUST,label='log-λ theano')
niceaxis()
xlim(STARTSHOW,STOPSHOW)

tight_layout()

## Set up prior

In [None]:
# Depth of the shallow filter: each output bin is obtained by filtering only
# the last D observations, starting from a prior state.
D = 5
ND = N-D

# The theano-filtered moments serve as the prior (window initial conditions)
priorLR,priorLV,priorM1,priorM2 = allLRt,allLVt,allM1t,allM2t

# shift prior one time-step to exactly match case
# in which we filter linearly starting at time 0
#priorM1[1:]=priorM1[:-1]
#priorM2[1:]=priorM2[:-1]
#priorM1[0]=np.zeros((K,1))
#priorM2[0]=np.eye(K)*1e-6

# Uninformative state for windows that would start before bin 0
defaultM1 = np.zeros((K,1))
defaultM2 = np.eye(K)*1e-6

# iniM*[i] is the initial state of the window ending at bin i, i.e. the prior
# at bin i-(D-1). Slicing by N-nlead (rather than [:-D+1]) also works for D=1,
# where [:-D+1] == [:0] would be empty.
nlead = D-1
iniM1 = np.zeros((N,K,1))
iniM2 = np.zeros((N,K,K))
iniM1[:nlead]=defaultM1
iniM2[:nlead]=defaultM2
iniM1[nlead:]=priorM1[:N-nlead]
iniM2[nlead:]=priorM2[:N-nlead]
allM1 = iniM1.copy()
allM2 = iniM2.copy()

## Start with naive implementation for reference

Demonstrate shallow depth-5 filtering. Even starting from a prior with no information about the filtered state, these results can be relatively accurate. This could lead to parallel filtering routines that accelerate inference.

In [None]:


# Precompute constants (these globals are reused by every cell below)
maxlogr   = np.log(maxrate)
# NOTE(review): maxratemc is not used in the visible cells
maxratemc = maxvcorr*maxrate
dtfine    = dt/oversample
Cb        = C.dot(beta.T)
CC        = C.dot(C.T)
Adt       = A*dtfine

themeasurement = 'moment'
int_method = 'euler'
method = 'second_order'

# Get measurement update function
measurement = get_measurement(themeasurement)
# Build moment integrator functions
mean_update, cov_update = get_moment_integrator(int_method,Adt)
# Get update function (computes expected rate from moments)
update = get_update_function(method,Cb,Adt,maxvcorr)

allLR = np.zeros(N)
allLV = np.zeros(N)

# Reference shallow filter: for each bin i, run the full numpy filter over
# only the window stim[a:i+1] (at most D bins), initialised from the prior
# state at bin i-D+1, or from the default state if that is before bin 0.
# Keep only the estimate at the window's last bin.
for i in range(N):
    b = i+1
    a = i-D+1
    a = max(0,a)
    c = i-D+1
    ini = (priorM1[c],priorM2[c]) if c>=0 else (defaultM1,defaultM2)
    lr,lv,_,_,_,_,_ = filter_moments(stim[a:b],Y_train[a:b],A,beta,C,p[0],
        dt          = dt,
        oversample  = oversample,
        maxrate     = maxrate,
        maxvcorr    = maxvcorr,
        method      = "second_order",
        int_method  = "euler",
        measurement = "moment",
        reg_cov     = reg_cov,
        reg_rate    = reg_rate,
        return_surrogates = True,
        initial_conditions = ini)
    allLR[i] = lr[-1]
    allLV[i] = lv[-1]

assert(all(isfinite(allLR)))
assert(all(isfinite(allLV)))

# Kept as the reference that the later re-implementations are compared with
allLRref,allLVref = allLR,allLV

subplot(311)
stderrplot(allLRn,allLVn,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## Unpack naive parallel filtering

In [None]:
# Naive shallow filter written out by hand: for every bin i, start from the
# prior state at the window start and filter the (at most) D bins ending at i.
allLR = np.zeros(N)
allLV = np.zeros(N)

for i in range(N):
    b = i+1
    a = i-D+1
    # Initial condition for moments
    M1 = priorM1[a] if a>=0 else defaultM1
    M2 = priorM2[a] if a>=0 else defaultM2

    a = max(0,a)
    # Window inputs get their own names: rebinding the global Y here would
    # change the spikes niceaxis() draws (it falls back to the global Y).
    S_win = stim[a:b]
    Y_win = Y_train[a:b]

    for s,y in zip(S_win,Y_win):
        # Regularize: symmetrize and make the smallest diagonal >= reg_cov
        if reg_cov>0:
            strength = reg_cov+max(0,-np.min(np.diag(M2)))
            M2 = 0.5*(M2+M2.T) + strength*np.eye(K)
        # Integrate moments forward through `oversample` sub-steps
        for k in range(oversample):
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)
            R0    = min(maxrate,R0)
            R0   *= dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C*Rm
        # Measurement update
        M1,M2,ll,mr,vr = measurement_update_projected_gaussian(\
                  M1,M2,y,beta,s,dt,m,reg_rate,measurement)

    # Read out the log-rate at the window's last bin (s == stim[i] here)
    allLR[i] = min(beta.T.dot(M1)+s,maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)

# Checked once after the loop; checking the full arrays every iteration is O(N^2)
assert(all(isfinite(allLR)))
assert(all(isfinite(allLV)))

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## Re-order loops toward a parallel implementation

In [None]:
S = stim
Y = Y_train
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
# Per-window covariance regularizer reg_cov*I (later cells reuse allRC)
allRC = np.zeros((N,K,K))
allRC[...] = reg_cov*np.eye(K)

# Initial state of the window ending at each bin: prior at bin i-(D-1), or the
# default state before bin 0 (N-nlead slicing also works for D=1)
nlead = D-1
iniM1 = np.zeros((N,K,1))
iniM2 = np.zeros((N,K,K))
iniM1[:nlead]=defaultM1
iniM2[:nlead]=defaultM2
iniM1[nlead:]=priorM1[:N-nlead]
iniM2[nlead:]=priorM2[:N-nlead]
allM1[...] = iniM1
allM2[...] = iniM2

# Outer loop over depth, inner loop over windows: at depth offset di, window i
# integrates and observes bin max(0, i+di)
for di in arange(-D+1,1):
    for i in range(N):
        # Initial condition for moments
        M1 = allM1[i]
        M2 = allM2[i]
        offset = max(0,di+i)
        s = S[offset]
        y = Y[offset]
        # Regularize
        if reg_cov>0:
            strength = reg_cov+max(0,-np.min(np.diag(M2)))
            M2 = 0.5*(M2+M2.T) + strength*np.eye(K)
        # Integrate moments forward
        for k in range(oversample):
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)
            R0    = min(maxrate,R0)
            R0   *= dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C*Rm
        # Measurement update
        M1,M2,ll,mr,vr = measurement_update_projected_gaussian(\
                  M1,M2,y,beta,s,dt,m,reg_rate,measurement)
        allM1[i]=M1
        allM2[i]=M2
    # Checked once per depth step; checking the full arrays after every
    # window made this pass O(N^2)
    assert(all(isfinite(allM1)))
    assert(all(isfinite(allM2)))

# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)
assert(all(isfinite(allLR)))
assert(all(isfinite(allLV)))

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## One more loop re-arrangement

In [None]:
S = stim
Y = Y_train
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
# Per-window covariance regularizer reg_cov*I
allRC = np.zeros((N,K,K))
allRC[...] = reg_cov*np.eye(K)
allM1[...] = iniM1
allM2[...] = iniM2

for di in arange(-D+1,1):
    # Reset values that really shouldn't be being integrated?
    #allM1[:-di,...]=iniM1[:-di,...]
    #allM2[:-di,...]=iniM2[:-di,...]
    # Regularize: symmetrize and add reg_cov*I (no diagonal-dependent term here)
    if reg_cov>0:
        for i in range(N):
            M2 = allM2[i]
            M2 = 0.5*(M2+M2.T) + reg_cov*np.eye(K)
            allM2[i] = M2

    # Integrate every window forward one sub-step at a time
    for k in range(oversample):
        for i in range(N):
            # Initial condition for moments
            M1 = allM1[i]
            M2 = allM2[i]
            offset = max(0,di+i)
            s = S[offset]
            logv  = beta.T.dot(M2).dot(beta)
            logx  = min(beta.T.dot(M1)+s,maxlogr)
            R0    = sexp(logx)
            R0    = min(maxrate,R0)
            R0   *= dtfine
            Rm,J  = update(logx,logv,R0,M1,M2)
            M2    = cov_update(M2,J) + CC*Rm
            M1    = mean_update(M1)  + C*Rm
            allM1[i]=M1
            allM2[i]=M2

    # Measurement update
    for i in range(N):
        M1    = allM1[i]
        M2    = allM2[i]
        offset = max(0,di+i)
        s     = S[offset]
        y     = Y[offset]
        M1,M2,ll,mr,vr = measurement_update_projected_gaussian(\
                  M1,M2,y,beta,s,dt,m,reg_rate,measurement,safe=True)
        allM1[i]=M1
        allM2[i]=M2
    # Checked once per depth step; checking the full arrays after every
    # window made this pass O(N^2)
    assert(all(isfinite(allM1)))
    assert(all(isfinite(allM2)))

# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)
assert(all(isfinite(allLR)))
assert(all(isfinite(allLV)))

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## Some vectorization of inner loops

In [None]:
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
S = stim
Y = Y_train
allM1[...] = iniM1
allM2[...] = iniM2

# Same D-pass scheme, but each integration sub-step is vectorized over all N
# windows; the measurement update is still a per-window loop.
# NOTE: allRC (per-window reg_cov*I) is defined in an earlier cell.
for di in arange(-D+1,1):
    # Windows i < -di would start before bin 0: hold them at their initial
    # state (a no-op when di == 0, since [:0] is empty)
    allM1[:-di,...]=iniM1[:-di,...]
    allM2[:-di,...]=iniM2[:-di,...]
    # Regularize: symmetrize and add reg_cov*I
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    # Integrate moments forward
    for k in range(oversample):
        # Bin observed by each window at this depth (clamped at bin 0)
        offsets = np.maximum(0,arange(N)+di)
        S_ = S[offsets]
        # Project each state onto the log-rate: variance, and mean clipped at maxlogr
        LOGV = allM2.dot(beta[:,0]).dot(beta[:,0])
        LOGX = np.minimum(maxlogr,allM1[:,:,0].dot(beta[:,0])+S_)
        # Clipped rate times dtfine, and its capped second-order variance correction
        R0_  = np.minimum(maxrate,sexp(LOGX))*dtfine
        RM = R0_ * np.minimum(1+0.5*LOGV,maxvcorr)
        # Euler step: M1 += Adt M1 + C RM;  M2 += J M2 + (J M2)^T + CC RM,
        # with J = Adt + Cb*R0 for each window
        J_   = Cb[None,:,:]*R0_[:,None,None]+Adt[None,:,:]
        allM1 += np.matmul(Adt,allM1[:,:,:])
        JM2_   = np.matmul(J_,allM2)
        allM2 += JM2_ + JM2_.transpose((0,2,1))
        allM1 +=  C[None,:,:]*RM[:,None,None]
        allM2 += CC[None,:,:]*RM[:,None,None]
        # Keep moments bounded to avoid numerical blow-up
        allM1 = np.clip(allM1,-100,100)
        allM2 = np.clip(allM2,-100,100)
    # Measurement update (library routine, one window at a time)
    for i in range(N):
        M1    = allM1[i]
        M2    = allM2[i]
        offset = max(0,di+i)
        s     = S[offset]
        y     = Y[offset]
        M1,M2,ll,mr,vr = measurement_update_projected_gaussian(\
                  M1,M2,y,beta,s,dt,m,reg_rate,measurement)
        allM1[i]=M1
        allM2[i]=M2
    allM1 = np.clip(allM1,-100,100)
    allM2 = np.clip(allM2,-100,100)

# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## The parallel measurement is a tricky one! Start by unpacking it

In [None]:
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
S = stim
Y = Y_train
allM1[...] = iniM1
allM2[...] = iniM2

for di in arange(-D+1,1):
    # Windows i < -di would start before bin 0: hold them at their initial
    # state (a no-op when di == 0)
    allM1[:-di,...]=iniM1[:-di,...]
    allM2[:-di,...]=iniM2[:-di,...]
    # Regularize
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    # Integrate moments forward
    for k in range(oversample):
        offsets = np.maximum(0,arange(N)+di)
        S_ = S[offsets]
        LOGV = allM2.dot(beta[:,0]).dot(beta[:,0])
        LOGX = np.minimum(maxlogr,allM1[:,:,0].dot(beta[:,0])+S_)
        R0_  = np.minimum(maxrate,sexp(LOGX))*dtfine
        RM = R0_ * np.minimum(1+0.5*LOGV,maxvcorr)
        J_   = Cb[None,:,:]*R0_[:,None,None]+Adt[None,:,:]
        allM1 += np.matmul(Adt,allM1[:,:,:])
        JM2_   = np.matmul(J_,allM2)
        allM2 += JM2_ + JM2_.transpose((0,2,1))
        allM1 +=  C[None,:,:]*RM[:,None,None]
        allM2 += CC[None,:,:]*RM[:,None,None]
        allM1 = np.clip(allM1,-100,100)
        allM2 = np.clip(allM2,-100,100)
    # Measurement update, with measurement_update_projected_gaussian
    # written out inline, one window at a time
    for i in range(N):
        M1    = allM1[i]
        M2    = allM2[i]
        offset = max(0,di+i)
        s     = S[offset]
        y     = Y[offset]
        # Project the state onto the log-rate: mean lm, variance lv (>= eps)
        m2b = M2.dot(beta)
        lv  = max(eps,(beta.T.dot(m2b))[0,0])
        lm  = (beta.T.dot(M1))[0,0]
        lt  = 1/lv
        # Shrink the projection toward m with extra precision reg_rate
        tq  = lt + reg_rate
        mq = (lm*lt+m*reg_rate)/tq
        vq  = 1/tq
        # Posterior mean/variance of the log-rate given this observation
        mp,vp = measurement(mq,vq,y,s,dt)
        if vp<eps: vp=eps
        tp = 1/vp
        # Gaussian surrogate likelihood: posterior divided by projected prior
        tr = max(eps,tp-lt)
        vr = scalar(1/tr)
        mr = (mp*tp-lm*lt)/tr
        # Kalman-style update of the full state with the surrogate
        Kg  = m2b/(vr+lv)
        M2 = M2 - Kg.dot(m2b.T)
        M1 = M1 + Kg*(mr-lm)
        allM1[i]=M1
        allM2[i]=M2
    # Clip once per pass. Inputs to the loop were already clipped at the end
    # of integration, so this matches clipping after every window (which cost
    # O(N) per window, O(N^2) per pass).
    allM1 = np.clip(allM1,-100,100)
    allM2 = np.clip(allM2,-100,100)

# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## Vectorize measurement

In [None]:
S = stim
Y = Y_train
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
allM1[...] = iniM1
allM2[...] = iniM2

for di in np.arange(-D+1,1):
    # Windows i < -di would start before bin 0: hold them at their initial
    # state (a no-op when di == 0)
    allM1[:-di,...]=iniM1[:-di,...]
    allM2[:-di,...]=iniM2[:-di,...]
    # Regularize
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    # Bin observed by each window at this depth (clamped at bin 0); the same
    # for every integration sub-step and for the measurement update
    offsets = np.maximum(0,np.arange(N)+di)
    S_ = S[offsets]
    Y_ = Y[offsets]
    # Integrate moments forward
    for k in range(oversample):
        LOGV = allM2.dot(beta[:,0]).dot(beta[:,0])
        LOGX = np.minimum(maxlogr,allM1[:,:,0].dot(beta[:,0])+S_)
        R0_  = np.minimum(maxrate,sexp(LOGX))*dtfine
        RM = R0_ * np.minimum(1+0.5*LOGV,maxvcorr)
        J_   = Cb[None,:,:]*R0_[:,None,None]+Adt[None,:,:]
        allM1 += np.matmul(Adt,allM1[:,:,:])
        JM2_   = np.matmul(J_,allM2)
        allM2 += JM2_ + JM2_.transpose((0,2,1))
        allM1 +=  C[None,:,:]*RM[:,None,None]
        allM2 += CC[None,:,:]*RM[:,None,None]
        allM1 = np.clip(allM1,-100,100)
        allM2 = np.clip(allM2,-100,100)
    # Parallel measurement update: project each state onto the log-rate
    M2B_ = np.matmul(allM2,beta)
    LV = allM2.dot(beta[:,0]).dot(beta[:,0])
    LV = np.maximum(1e-12,LV)
    LM = allM1[:,:,0].dot(beta[:,0])
    LT = 1/LV
    # Shrink toward m with extra precision reg_rate
    TQ = LT + reg_rate
    VQ = 1/TQ
    MQ = (LM*LT+m*reg_rate)*VQ
    # Posterior moments by quadrature on a 25-point grid over +/-4 std devs
    intr = np.linspace(-4,4,25)
    X_ = intr[None,:]*np.sqrt(VQ)[:,None]+MQ[:,None]
    R0_ = X_ + S_[:,None]+slog(dt)
    L = Y_[:,None]*R0_-sexp(R0_)
    L = L - np.max(L,axis=1)[:,None]
    L += -.5*(X_-MQ[:,None])**2/VQ[:,None]-.5*slog(VQ)[:,None]
    PR = sexp(L)
    PR = np.maximum(1e-7,PR)
    NR = np.sum(PR,axis=1)
    MP = np.sum(X_*PR,axis=1)/NR
    VP = np.sum((X_-MP[:,None])**2*PR,axis=1)/NR
    VP = np.maximum(1e-12,VP)
    # Gaussian surrogate likelihood: posterior divided by projected prior
    TP = 1/VP
    TR = TP-LT
    TR = np.maximum(1e-12,TR)
    VR = 1/TR
    MR = (MP*TP-LM*LT)*VR
    # Kalman-style update of every window's full state
    KG = M2B_/(VR+LV)[:,None,None]
    allM2 -= np.matmul(KG,M2B_.transpose(0,2,1))
    allM1 += KG*(MR-LM)[:,None,None]
    # likelihood
    LOGR = MP+S_
    LOGPYX = Y_*LOGR-sexp(LOGR)
    LL = LOGPYX + 0.5*slog(VP/LV) - 0.5*(MP-LM)**2/LV
    # Clip once per pass, after the measurement update, as in the per-window
    # version (clipping inside the read-out loop below would run N times and
    # leave window 0 unclipped)
    allM1 = np.clip(allM1,-100,100)
    allM2 = np.clip(allM2,-100,100)

# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

## Clean things up a bit

In [None]:
# NOTE: S, Y and allRC are taken from earlier cells. Unlike the previous
# versions, this one applies no np.clip to the moments.
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
allM1[...] = iniM1
allM2[...] = iniM2

for di in range(-D+1,1):
    # Windows i < -di would start before bin 0: hold them at their initial
    # state (a no-op when di == 0)
    allM1[:-di,...]=iniM1[:-di,...]
    allM2[:-di,...]=iniM2[:-di,...]
    # Regularize: symmetrize and add reg_cov*I
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    # Bin observed by each window at this depth (clamped at bin 0)
    offsets = np.maximum(0,np.arange(N)+di)
    S_ = S[offsets]
    Y_ = Y[offsets]
    # Euler sub-steps of the moment equations, vectorized over windows
    for k in range(oversample):
        LOGV = allM2.dot(beta[:,0]).dot(beta[:,0])
        LOGX = np.minimum(maxlogr,allM1[:,:,0].dot(beta[:,0])+S_)
        R0_  = np.minimum(maxrate,sexp(LOGX))*dtfine
        RM = R0_ * np.minimum(1+0.5*LOGV,maxvcorr)
        J_   = Cb[None,:,:]*R0_[:,None,None]+Adt[None,:,:]
        allM1 += np.matmul(Adt,allM1[:,:,:])
        JM2_   = np.matmul(J_,allM2)
        allM2 += JM2_ + JM2_.transpose((0,2,1))
        allM1 +=  C[None,:,:]*RM[:,None,None]
        allM2 += CC[None,:,:]*RM[:,None,None]
    # Parallel measurement update: project each state onto the log-rate
    M2B_ = np.matmul(allM2,beta)
    LV = allM2.dot(beta[:,0]).dot(beta[:,0])
    LV = np.maximum(1e-12,LV)
    LM = allM1[:,:,0].dot(beta[:,0])
    LT = 1/LV
    # Shrink toward m with extra precision reg_rate
    TQ = LT + reg_rate
    VQ = 1/TQ
    MQ = (LM*LT+m*reg_rate)*VQ
    # Posterior moments by quadrature on a 25-point grid over +/-4 std devs.
    # The Gaussian log-density term uses intr**2, which equals (X_-MQ)**2/VQ.
    intr = np.linspace(-4,4,25)
    X_ = intr[None,:]*np.sqrt(VQ)[:,None]+MQ[:,None]
    R0_ = X_ + S_[:,None]+slog(dt)
    L = Y_[:,None]*R0_-sexp(R0_)
    L = L - np.max(L,axis=1)[:,None]
    L += -.5*((intr**2)[None,:]+slog(VQ)[:,None])
    PR = sexp(L)
    PR = np.maximum(1e-7,PR)
    NR = 1/np.sum(PR,axis=1)
    MP = np.sum(X_*PR,axis=1)*NR
    VP = np.sum((X_-MP[:,None])**2*PR,axis=1)*NR
    VP = np.maximum(1e-12,VP)
    # Gaussian surrogate likelihood: posterior divided by projected prior
    TP = 1/VP
    VR = 1/np.maximum(1e-12,TP-LT)
    MR = (MP*TP-LM*LT)*VR
    # Kalman-style update of every window's full state
    KG = M2B_/(VR+LV)[:,None,None]
    allM2 -= np.matmul(KG,M2B_.transpose(0,2,1))
    allM1 += KG*(MR-LM)[:,None,None]
    # Per-window approximate log-likelihood of the observed count
    LOGR = MP+S_
    LOGPYX = Y_*LOGR-sexp(LOGR)
    LL = LOGPYX - 0.5*(slog(LV/VP) + (MP-LM)**2/LV)
# Read out the log-rate mean (clipped at maxlogr) and variance of each window
for i in range(N):
    M1 = allM1[i,:,0]
    M2 = allM2[i]
    allLR[i] = min(beta.T.dot(M1)+S[i],maxlogr)
    allLV[i] = beta.T.dot(M2).dot(beta)

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()

# compare likelihoods
# parallel shallow likelihood as close to filtered likelihood
# as it is to the theano implementation
# meaning that a shallow filter is as accurate as a deep filter
# up to numerical precision errors
print(nllt,nlln,-mean(LL))
ylim(max(ylim()[0],-100),5)

## Convert to theano scan style syntax

Building a theano program with a for loop leads to computational graphs that are too large; they need to be written as scans for efficiency.

In [None]:
# Every window starts from its prior initial state
allM1 = iniM1.copy()
allM2 = iniM2.copy()
allLR = np.zeros(N)
allLV = np.zeros(N)

def project_moments_parallel(allM1,allM2,S_,beta_=None,maxlogr_=None):
    '''
    Project every window's state onto the log-rate.

    allM1   : (n,K,1) means;  allM2 : (n,K,K) covariances;  S_ : (n,) inputs.
    beta_   : (K,1) projection; defaults to the notebook global `beta`.
    maxlogr_: upper clip for the log-rate; defaults to the global `maxlogr`.

    Returns (LOGV, LOGM, LOGX): projected variance beta' M2 beta, projected
    mean beta' M1, and the clipped log-rate min(maxlogr, LOGM + S_).
    '''
    # Globals are only read when no explicit value is passed
    if beta_ is None:
        beta_ = beta
    if maxlogr_ is None:
        maxlogr_ = maxlogr
    b = beta_[:,0]
    LOGV   = allM2.dot(b).dot(b)
    LOGM   = allM1[:,:,0].dot(b)
    LOGX   = np.minimum(maxlogr_,LOGM+S_)
    return LOGV,LOGM,LOGX

def integrate_moments_parallel(allM1,allM2,S_):
    '''
    One Euler sub-step (length dtfine) of the moment equations for all windows.

    allM1, allM2 : window means and covariances; both are updated in place
                   (+=) and also returned.
    S_           : input term for the bin each window currently observes.

    Uses the notebook globals maxrate, dtfine, maxvcorr, Adt, C, Cb and CC.
    '''
    #LOGV   = allM2.dot(beta[:,0]).dot(beta[:,0])
    #LOGX   = np.minimum(maxlogr,allM1[:,:,0].dot(beta[:,0])+S_)
    LOGV,LOGM,LOGX = project_moments_parallel(allM1,allM2,S_)
    # Clipped rate times dtfine, and its capped second-order variance correction
    R0_    = np.minimum(maxrate,sexp(LOGX))*dtfine
    RM     = R0_ * np.minimum(1+0.5*LOGV,maxvcorr)
    # Mean: M1 += Adt M1 + C RM
    allM1 += np.matmul(Adt,allM1[:,:,:])
    allM1 += C[None,:,:]*RM[:,None,None]
    # Covariance: M2 += J M2 + (J M2)^T + CC RM, with J = Adt + Cb*R0 per window
    J_     = Cb[None,:,:]*R0_[:,None,None]+Adt[None,:,:]
    JM2_   = np.matmul(J_,allM2)
    allM2 += JM2_ + JM2_.transpose((0,2,1))
    allM2 += CC[None,:,:]*RM[:,None,None]
    return allM1,allM2

def measurement_update_parallel(allM1,allM2,S_,Y_):
    '''
    Vectorized measurement update of every window.

    Each state is projected onto the log-rate (mean LM, variance LV), the
    projection is shrunk toward m with extra precision reg_rate, and the
    posterior mean/variance given count Y_ are computed by quadrature on a
    25-point grid spanning +/-4 standard deviations. Dividing that posterior
    by the projected prior gives a Gaussian surrogate likelihood (MR, VR),
    which updates the full state with a Kalman-style gain.

    Returns (allM1, allM2, LL). allM1 and allM2 are also modified in place.
    LL is a per-window approximate log-likelihood of the observed count.
    '''
    # Parallel measurement update
    M2B_ = np.matmul(allM2,beta)
    LV = allM2.dot(beta[:,0]).dot(beta[:,0])
    LV = np.maximum(1e-12,LV)
    LM = allM1[:,:,0].dot(beta[:,0])
    LT = 1/LV
    TQ = LT + reg_rate
    VQ = 1/TQ
    MQ = (LM*LT+m*reg_rate)*VQ
    # Quadrature grid in units of the regularized standard deviation
    intr = np.linspace(-4,4,25)
    X_ = intr[None,:]*np.sqrt(VQ)[:,None]+MQ[:,None]
    # Poisson-form log-likelihood y*log(rate*dt) - rate*dt at each grid point,
    # shifted by its max for numerical stability, plus the Gaussian log-density
    R0_ = X_ + S_[:,None]+slog(dt)
    L = Y_[:,None]*R0_-sexp(R0_)
    L = L - np.max(L,axis=1)[:,None]
    L += -.5*((intr**2)[None,:]+slog(VQ)[:,None])
    PR = np.maximum(1e-7,sexp(L))
    NR = 1/np.sum(PR,axis=1)
    MP = np.sum(X_*PR,axis=1)*NR
    VP = np.sum((X_-MP[:,None])**2*PR,axis=1)*NR
    # Surrogate likelihood precision/mean: posterior minus projected prior
    TP = 1/np.maximum(1e-12,VP)
    VR = 1/np.maximum(1e-12,TP-LT)
    MR = (MP*TP-LM*LT)*VR
    KG = M2B_/(VR+LV)[:,None,None]
    allM2 -= np.matmul(KG,M2B_.transpose(0,2,1))
    allM1 += KG*(MR-LM)[:,None,None]
    LOGR = MP+S_
    LOGPYX = Y_*LOGR-sexp(LOGR)
    LL = LOGPYX - 0.5*(slog(LV/VP) + (MP-LM)**2/LV)
    return allM1,allM2,LL

def filter_moments_parallel(di,allM1,allM2):
    '''
    Advance every shallow-filter window by one depth step.

    di    : depth offset, run from -D+1 up to 0; at this step window i
            integrates and then observes bin max(0, i+di).
    allM1 : window means (allocated as (N,K,1) in this cell).
    allM2 : window covariances (allocated as (N,K,K) in this cell).

    Returns (allM1, allM2, LL), where LL is the per-window log-likelihood from
    measurement_update_parallel. The input arrays may also be modified in
    place, so callers should use the returned arrays.
    Reads the notebook globals iniM1, iniM2, reg_cov, allRC, S, Y, N and
    oversample.
    '''
    # Windows i < -di would start before bin 0: hold them at their initial
    # state. max(0,-di) keeps this a no-op for di >= 0; a bare allM1[:-di]
    # would reset every row except the last di when di > 0.
    nreset = max(0,-di)
    allM1[:nreset,...]=iniM1[:nreset,...]
    allM2[:nreset,...]=iniM2[:nreset,...]
    # Regularize: symmetrize and add reg_cov*I
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    # Bin observed by each window at this depth (clamped at bin 0)
    offsets = np.maximum(0,np.arange(N)+di)
    S_ = S[offsets]
    Y_ = Y[offsets]
    for k in range(oversample):
        allM1,allM2 = integrate_moments_parallel(allM1,allM2,S_)
    allM1,allM2,LL = measurement_update_parallel(allM1,allM2,S_,Y_)
    return allM1,allM2,LL

# Run the D depth steps, from the earliest offset (-D+1) up to 0
for di in range(1-D,1):
    allM1,allM2,LL = filter_moments_parallel(di,allM1,allM2)

# Read out each window's log-rate variance and mean (clipped at maxlogr)
beta_vec = beta[:,0]
allLV = allM2.dot(beta_vec).dot(beta_vec)
allLR = np.minimum(maxlogr,allM1[:,:,0].dot(beta_vec)+S)

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()

# Compare likelihoods: the parallel shallow likelihood is as close to the
# deep filtered likelihood as it is to the theano implementation, i.e. the
# shallow filter is as accurate as the deep one up to numerical precision
print(nllt,nlln,-mean(LL))
ylim(max(ylim()[0],-100),5)

## Compartmentalize in function

Note that parallel shallow filtering is still slower than running the full forward pass, as we must perform O(D*N) work as opposed to O(N). However, this admits a depth-D algorithm in theano, which may give us some improvement.

In [None]:
# Integration / measurement settings shared by both filters below
_shared_opts = dict(
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
)

# Deep (full-length) numpy filter, as the reference
tic()
allLRn,allLVn,allM1n,allM2n,nlln,mrn,vrn = filter_moments(
    stim,Y_train,A,beta,C,p[0],return_surrogates=True,**_shared_opts)
toc()

from dstep import filter_moments_dstep

# Shallow depth-D filter, each window initialised from (iniM1, iniM2)
tic()
allLRnd,allLVnd,allM1nd,allM2nd,nllnd = filter_moments_dstep(
    D,stim,Y_train,A,beta,C,p[0],prior=(iniM1,iniM2),**_shared_opts)
toc()

subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLRnd,allLVnd,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

# Compare likelihoods: the parallel shallow likelihood is as close to the
# deep filtered likelihood as it is to the theano implementation, i.e. the
# shallow filter is as accurate as the deep one up to numerical precision
print(nllt,nlln,nllnd)

## Add prior for initial conditions

This can get us a little closer to the "true" filtered states

In [None]:
# NOTE: this cell is disabled -- its whole body is a string literal, so none of
# it runs. If enabled, it would re-run the depth-D filter using the
# theano-integrated moments as the prior and compare the likelihoods.
'''_,_,priorM1,priorM2 = integrate_moments_theano(Bh_train,p)

tic()
allLRnd,allLVnd,allM1nd,allM2nd,nllnd = filter_moments_dstep(D,stim,Y_train,A,beta,C,p[0],
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    prior       = (priorM1,priorM2))
toc()

subplot(311)
stderrplot(allLRref,allLVref,color=AZURE,lw=0.5,filled=0)
stderrplot(allLRt,allLVt,color=BLACK,lw=0.5)
#stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLRnd,allLVnd,color=RUST,lw=0.5,filled=0)
niceaxis()
ylim(max(ylim()[0],-100),5)

# compare likelihoods
# parallel shallow likelihood as close to filtered likelihood
# as it is to the theano implementation
# meaning that a shallow filter is as accurate as a deep filter
# up to numerical precision errors
print(nllt,nlln,nllnd)'''

## Confirm that shallow filtering can handle cases where deep filtering fails

In [None]:
# Disabled cell: the code below is wrapped in a triple-quoted string, so running
# this cell only evaluates that string literal; none of the enclosed code runs.
# Remove the surrounding quotes to re-enable it.
'''p = p0.copy()
p[1:]*=5

m        = array(p).ravel()[0]
beta     = ascolumn(p[1:K+1])
beta_st  = ascolumn(p[1+K:])
stim     = (m + Bh_train.dot(beta_st))[:,0]

D = 5

tic()
allLRn,allLVn,allM1n,allM2n,nlln,mrn,vrn = filter_moments(stim,Y_train,A,beta,C,p[0],
    dt          = dt,
    oversample  = oversample,
    maxrate     = maxrate,
    maxvcorr    = maxvcorr,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    return_surrogates = True)
toc()


tic()
allLRt,allLVt,allM1t,allM2t,nllt,mrt,vrt = filter_moments_theano(Bh_train,Y_train,p)
toc()

from dstep import filter_moments_dstep


pLR,pLV,priorM1,priorM2 = integrate_moments_theano(Bh_train,p)

tic()
allLRnd,allLVnd,allM1nd,allM2nd,nllnd = filter_moments_dstep(D,stim,Y_train,A,beta,C,p[0],
    dt          = dt,
    oversample  = oversample,
    maxrate     = 1,
    maxvcorr    = 5,
    method      = "second_order",
    int_method  = "euler",
    measurement = "moment",
    reg_cov     = reg_cov,
    reg_rate    = reg_rate,
    prior       = (priorM1,priorM2))
toc()

subplot(311)
stderrplot(allLRn,allLVn,color=BLACK,lw=0.5)
stderrplot(allLRt,allLVt,color=AZURE,lw=0.5,filled=0)
stderrplot(pLR,pLV,color=OCHRE,lw=0.5,filled=0)
stderrplot(allLRnd,allLVnd,color=RUST,lw=0.5,filled=0)
niceaxis()

# compare likelihoods
# parallel shallow likelihood as close to filtered likelihood
# as it is to the theano implementation
# meaning that a shallow filter is as accurate as a deep filter
# up to numerical precision errors
print(nllt,nlln,nllnd)

ylim(-50,10)'''

## Now work toward a theano implementation

The hope is that on a good GPU, shallow depth-D filtering will be faster than deep filtering. Implementing this all at once is too challenging, so let's implement Theano functions that replace one piece at a time.

## Implementation in Theano

In [None]:
from theano_arppglm import *
from theano_arppglm import Tmatmul

In [None]:
# Unpack the parameter vector p: p[0] is the offset m, p[1:K+1] the weights
# beta on the K-dimensional state, p[K+1:] the weights beta_st on Bh_train.
m        = array(p).ravel()[0]
beta     = ascolumn(p[1:K+1])
beta_st  = ascolumn(p[1+K:])
# Per-bin drive m + Bh_train . beta_st, taking column 0 as a 1-D vector
stim     = (m + Bh_train.dot(beta_st))[:,0]
# *_np names: stim_np is the same object as stim; beta_np is rebuilt from p.
stim_np  = stim
beta_np  = ascolumn(p[1:K+1])

In [None]:
# Inputs for this cell: per-bin drive and training counts
S = stim
Y = Y_train

# Covariance regularizer for every time bin: N stacked copies of reg_cov*I_K.
# A single broadcast assignment replaces the per-row Python loop.
allRC = np.zeros((N,K,K))
allRC[...] = reg_cov*np.eye(K)

# Model matrices/weights as Theano constants. dimshuffle('x',0,1) adds a
# leading broadcastable axis so the 2-D matrices combine with per-bin batches.
TAdt  = Tcon(Adt)
Tbeta = Tcon(beta)
Tb    = Tcon(p.ravel()[1:K+1])   # same slice as beta, flattened to a vector
TC    = Tcon(C ).dimshuffle('x',0,1)
TCb   = Tcon(Cb).dimshuffle('x',0,1)
TCC   = Tcon(CC).dimshuffle('x',0,1)

# Scalar settings as Theano constants
mxl = Tcon(maxlogr)
mxr = Tcon(maxrate)
dtf = Tcon(dtfine)
xvc = Tcon(maxvcorr)
rr  = Tcon(reg_rate)
mm  = Tcon(m)

def project_moments_parallel_theano_source(M1,M2,S):
    """Symbolically project every bin's state moments onto the log-rate.

    M1 holds the per-bin means (indexed M1[:,:,0]) and M2 the per-bin
    covariances; S is the per-bin drive. Returns, each of length N:
      LOGV  projected variance Tb' M2 Tb
      LOGM  projected mean     Tb' M1
      LOGX  LOGM + S, passed through Tmn against maxlogr
      R0    Tsexp(LOGX) through Tmn against maxrate, times dtfine
      RM    R0 times (1 + LOGV/2) passed through Tmn against maxvcorr
    """
    mean_proj = M1[:,:,0].dot(Tb)                   # N
    var_proj  = M2.dot(Tb).dot(Tb)                  # N
    log_rate  = Tmn(mxl, mean_proj + S)             # N
    rate0     = Tmn(mxr, Tsexp(log_rate))*dtf       # N
    rate_corr = rate0 * Tmn(1.0 + 0.5*var_proj, xvc) # N
    return var_proj, mean_proj, log_rate, rate0, rate_corr

# Symbolic inputs shared by several compiled functions in this cell
TallM1 = T.tensor3("TallM1",dtype=dtype)
TallM2 = T.tensor3("TallM2",dtype=dtype)
TallS_ = T.vector("TallS_",dtype=dtype)
project_moments_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TallS_],
    out = project_moments_parallel_theano_source(TallM1,TallM2,TallS_))

def euler_update_moments_parallel_theano_source(M1,M2,R0,RM):
    """One symbolic Euler step of the moment equations for all bins.

    R0 and RM are the per-bin rates from the projection step. Builds and
    returns new expressions for (M1, M2):
      M1 + Adt M1 + C*RM
      M2 + (J M2 + (J M2)') + CC*RM,   with J = Cb*R0 + Adt per bin
    """
    rm    = RM[:,None,None]
    jac   = TCb*R0[:,None,None] + TAdt[None,:,:]
    jm2   = T.batched_dot(jac,M2)
    newM1 = M1 + TAdt.dot(M1).transpose(1,0,2) + TC*rm
    newM2 = M2 + (jm2 + jm2.transpose((0,2,1))) + TCC*rm
    return newM1,newM2

TR0 = T.vector("TR0",dtype=dtype)
TRM = T.vector("TRM",dtype=dtype)
euler_update_moments_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TR0,TRM],
    out = euler_update_moments_parallel_theano_source(TallM1,TallM2,TR0,TRM))

def integrate_moments_parallel_theano_source(M1,M2,S):
    """One symbolic fine step: project moments to rates, then Euler-update (M1, M2)."""
    _, _, _, rate0, rate_corr = project_moments_parallel_theano_source(M1,M2,S)
    newM1, newM2 = euler_update_moments_parallel_theano_source(M1,M2,rate0,rate_corr)
    return newM1, newM2

integrate_moments_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TallS_],
    out = integrate_moments_parallel_theano_source(TallM1,TallM2,TallS_))

def univariate_prior_parallel_theano_source(LM,LT):
    """Precision-weighted combination of (mean LM, precision LT) with (m, reg_rate).

    Returns (MQ, VQ): the combined mean and Tsinv of the summed precision.
    """
    VQ = Tsinv(LT + rr)
    MQ = (LM*LT + mm*rr)*VQ
    return MQ, VQ

TLM = T.vector("TLM",dtype=dtype)
TLT = T.vector("TLT",dtype=dtype)
univariate_prior_parallel_theano = Tfun(
    inp = [TLM,TLT],
    out = univariate_prior_parallel_theano_source(TLM,TLT))

def quadrature_moments_parallel(MQ,VQ,S_,Y_):
    """Grid quadrature of each bin's log-rate posterior (numpy).

    Nodes are 25 points on [-4, 4] standard deviations sqrt(VQ) around MQ.
    Each node x is weighted by exp(Y*r - exp(r)) with r = x + S + log(dt),
    times exp(-grid^2/2)/sqrt(VQ); weights are floored at 1e-7.
    Returns the weighted mean MP and variance VP per bin.
    """
    grid   = np.linspace(-4,4,25)
    nodes  = grid[None,:]*np.sqrt(VQ)[:,None] + MQ[:,None]
    logr   = nodes + S_[:,None] + slog(dt)
    loglik = Y_[:,None]*logr - sexp(logr)
    # subtract each row's max before exponentiating
    loglik = loglik - np.max(loglik,axis=1)[:,None]
    loglik += -.5*((grid**2)[None,:] + slog(VQ)[:,None])
    weights  = np.maximum(1e-7, sexp(loglik))
    inv_norm = 1/np.sum(weights,axis=1)
    MP = np.sum(nodes*weights,axis=1)*inv_norm
    VP = np.sum((nodes - MP[:,None])**2*weights,axis=1)*inv_norm
    return MP,VP

# Quadrature grid: 25 nodes on [-4, 4] (in units of the prior standard deviation)
Tintr = Tcon(np.linspace(-4,4,25))
def quadrature_moments_parallel_theano_source(MQ,VQ,S,Y):
    """Symbolic grid quadrature of each bin's log-rate posterior.

    Weights exp(Y*r - exp(r)) with r = x + S + log(dt) on nodes
    x = MQ + grid*sqrt(VQ), times the Gaussian node weight, floored at 1e-7.
    Returns the weighted mean MP and variance VP per bin.
    """
    nodes   = Tintr[None,:]*T.sqrt(VQ)[:,None] + MQ[:,None]
    logr    = nodes + S[:,None] + Tslog(Tcast(dt))
    loglik  = Y[:,None]*logr - Tsexp(logr)
    loglik  = loglik - T.max(loglik,axis=1)[:,None]
    loglik += -0.5*((Tintr**2.0)[None,:] + Tslog(VQ)[:,None])
    weights  = Tmx(Tcon(1e-7),Tsexp(loglik))
    inv_norm = Tsinv(T.sum(weights,axis=1))
    MP = T.sum(nodes*weights,axis=1)*inv_norm
    VP = T.sum((nodes - MP[:,None])**2.0*weights,axis=1)*inv_norm
    return MP,VP

TMQ = T.vector("TMQ",dtype=dtype)
TVQ = T.vector("TVQ",dtype=dtype)
TS_ = T.vector("TS_",dtype=dtype)
TY_ = T.vector("TY_",dtype=dtype)
quadrature_moments_parallel_theano = Tfun(
    inp = [TMQ,TVQ,TS_,TY_],
    out = quadrature_moments_parallel_theano_source(TMQ,TVQ,TS_,TY_))

def surrogate_likelihood_parallel_theano_source(LM,LT,MP,TP):
    """Surrogate Gaussian likelihood from the quadrature posterior.

    Uses the precision difference TP - LT between posterior (MP, TP) and the
    projected prior (LM, LT). Returns the surrogate mean MR and variance VR.
    """
    VR = Tsinv(TP - LT)
    MR = (MP*TP - LM*LT)*VR
    return MR, VR

TLM = T.vector("TLM",dtype=dtype)
TLT = T.vector("TLT",dtype=dtype)
TMP = T.vector("TMP",dtype=dtype)
TTP = T.vector("TTP",dtype=dtype)
surrogate_likelihood_parallel_theano = Tfun(
    inp = [TLM,TLT,TMP,TTP],
    out = surrogate_likelihood_parallel_theano_source(TLM,TLT,TMP,TTP))

def conditional_gaussian_parallel_theano_source(M1,M2,MR,VR,LM,LV):
    """Symbolically condition every bin's state moments on a scalar surrogate.

    The surrogate observes the projection of the state onto Tbeta with mean
    MR and variance VR; LM and LV are the projection's current mean and
    variance. Uses the Kalman-gain form KG = M2 beta / (VR + LV).
    Returns the updated (M1, M2).
    """
    M2B = M2.dot(Tbeta) # NxKx1
    KG  = M2B/(VR+LV)[:,None,None] #NxKx1
    M2 -= T.batched_dot(KG,M2B.transpose(0,2,1))  # subtract KG (M2 beta)'
    M1 += KG*(MR-LM)[:,None,None]                 # shift mean by gain * innovation
    return M1,M2
# Debug names match the Python names so graph printouts and missing-input
# errors do not confuse these with TLM/TLT.
TMR = T.vector("TMR",dtype=dtype)
TVR = T.vector("TVR",dtype=dtype)
TLV = T.vector("TLV",dtype=dtype)
conditional_gaussian_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TMR,TVR,TLM,TLV],
    out = conditional_gaussian_parallel_theano_source(TallM1,TallM2,TMR,TVR,TLM,TLV))

def loglikelihood_parallel_theano_source(M1,M2,S,LM,LV,MP,VP,Y=None):
    """Symbolic per-bin log-likelihood after the measurement update.

    M1, M2 : conditioned state moments; S : per-bin drive.
    LM, LV : mean/variance of the projected log-rate before conditioning.
    MP, VP : quadrature posterior mean/variance.
    Y      : per-bin counts. Pass these explicitly. When omitted, this falls
             back to the notebook-global ``Y_`` so that the compiled 7-input
             loglikelihood_parallel_theano below keeps its signature.
    Returns Y*LR - exp(LR) - 0.5*(log(LV/VP) + (MP-LM)^2/LV), where LR is the
    capped log-rate (LOGX) of the conditioned moments.
    """
    if Y is None:
        Y = Y_  # legacy fallback to a notebook global; see docstring
    _,_,LR,_,_ = project_moments_parallel_theano_source(M1,M2,S)
    LOGPYX     = Y*LR-Tsexp(LR)
    LL         = LOGPYX - 0.5*(Tslog(LV/VP) + (MP-LM)**2.0/LV)
    return LL
TVP = T.vector("TVP",dtype=dtype)
loglikelihood_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TS_,TLM,TLV,TMP,TVP],
    out = loglikelihood_parallel_theano_source(TallM1,TallM2,TS_,TLM,TLV,TMP,TVP))

def measurement_update_parallel_theano_source(M1,M2,S,Y):
    """Symbolic measurement update for every bin at once.

    The update runs these stages in order:
      1. Project the moments to a log-rate (LM, LV).
      2. Combine the projection with the regularizing prior.
      3. Run the quadrature against the counts Y.
      4. Convert the result to a surrogate likelihood (MR, VR).
      5. Condition the state moments on that surrogate.
    Returns (M1, M2, LL), where LL is the per-bin log-likelihood computed
    from the same counts Y.
    """
    LV,LM,_,_,_ = project_moments_parallel_theano_source(M1,M2,S)
    LT    = Tsinv(LV)
    MQ,VQ = univariate_prior_parallel_theano_source(LM,LT)
    MP,VP = quadrature_moments_parallel_theano_source(MQ,VQ,S,Y)
    TP    = Tsinv(VP)
    MR,VR = surrogate_likelihood_parallel_theano_source(LM,LT,MP,TP)
    M1,M2 = conditional_gaussian_parallel_theano_source(M1,M2,MR,VR,LM,LV)
    # Pass Y explicitly so the likelihood uses the same (offset) counts as the
    # quadrature rather than a global that ignores the depth offset.
    LL    = loglikelihood_parallel_theano_source(M1,M2,S,LM,LV,MP,VP,Y)
    return M1,M2,LL
measurement_update_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TS_,TY_],
    out = measurement_update_parallel_theano_source(TallM1,TallM2,TS_,TY_))

def integrate_dt_parallel_theano_source(M1,M2,S):
    """Chain `oversample` symbolic fine steps (each using dtfine) between measurements."""
    for _substep in range(oversample):
        M1, M2 = integrate_moments_parallel_theano_source(M1, M2, S)
    return M1, M2

integrate_dt_parallel_theano = Tfun(
    inp = [TallM1,TallM2,TS_],
    out = integrate_dt_parallel_theano_source(TallM1,TallM2,TS_))

def filter_moments_parallel_theano_source(di,M1,M2,S,Y):
    """One depth step of the shallow parallel filter (symbolic).

    Bin i is advanced using the drive/counts at index max(0, i+di); the scan
    below steps di through 1-D, ..., 0. Returns (M1, M2, LL).
    """
    # Optional regularization: symmetrize each covariance and add reg_cov*I
    if reg_cov>0.0:
        M2 = 0.5*(M2 + M2.transpose(0,2,1)) + Tcast(allRC)
    idx = T.cast(Tmx(0.0,T.arange(N)+di),'int32')
    S_shift = S[idx]
    Y_shift = Y[idx]
    M1, M2 = integrate_dt_parallel_theano_source(M1,M2,S_shift)
    return measurement_update_parallel_theano_source(M1,M2,S_shift,Y_shift)

Tdi = T.scalar("Tdi",dtype=dtype)
filter_moments_parallel_theano = Tfun(
    inp = [Tdi,TallM1,TallM2,TS_,TY_],
    out = filter_moments_parallel_theano_source(Tdi,TallM1,TallM2,TS_,TY_))

# Depth D Loop
# theano.scan calls filter_moments_parallel_theano_source once per di in
# 1-D..0 (n_steps=D). (M1, M2) are threaded through the steps starting from
# (iniM1, iniM2); LL is only collected per step (outputs_info entry None).
[_M1,_M2,_LL], up = theano.scan(filter_moments_parallel_theano_source,
                                sequences     = [Tcon(arange(1-D,1))],
                                outputs_info  = [Tcon(iniM1),Tcon(iniM2),None],
                                non_sequences = [TS_,TY_],
                                n_steps       = D,
                                name          = 'scan_moments_parallel_theano')
#
# Keep the final step's moments; ALLLR is the capped log-rate (LOGX) and
# ALLLV the projected variance. The last output is -mean(LL) of the final step.
ALLM1,ALLM2 = _M1[-1],_M2[-1]
ALLLV,_,ALLLR,_,_ = project_moments_parallel_theano_source(ALLM1,ALLM2,TS_)
scan_moments_parallel_theano = Tfun(\
    inp = [TS_,TY_],
    out = [ALLLR,ALLLV,ALLM1,ALLM2,-T.mean(_LL[-1])],
    upd = up)

##########################################

# Run and time the compiled depth-D filter on the training data
tic()
allLR,allLV,allM1,allM2,NLL = scan_moments_parallel_theano(S,Y)
toc()

# Top panel: reference (black) vs. Theano shallow filter (rust, unfilled)
subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()

# Compare likelihoods. Expectation: the parallel shallow likelihood (NLL) is
# as close to the filtered likelihood (nlln) as that is to the Theano
# implementation (nllt), i.e. a shallow filter is as accurate as a deep filter
# up to numerical precision errors.
print(nllt,nlln,NLL)
ylim(max(ylim()[0],-100),5) 

## Numpy integrator broken down into functions for easier conversion to theano

In [None]:
# Inputs for this cell: per-bin drive and training counts
S = stim
Y = Y_train

# Per-bin buffers (N time bins, K-dimensional state)
allLR = np.zeros(N)
allLV = np.zeros(N)
allM1 = np.zeros((N,K,1))
allM2 = np.zeros((N,K,K))
# Covariance regularizer: reg_cov*I_K broadcast to every bin (no Python loop)
allRC = np.zeros((N,K,K))
allRC[...] = reg_cov*np.eye(K)
# Start every bin from the initial moments (broadcast assignment)
allM1[...] = iniM1
allM2[...] = iniM2

########################################################
# Numpy

def project_moments_parallel(allM1,allM2,S_):
    """Project every bin's state moments onto the log-rate (numpy).

    Returns (LOGV, LOGM, LOGX, R0, RM), each of length N:
    projected variance, projected mean, mean+S capped at maxlogr,
    min(maxrate, exp(LOGX))*dtfine, and R0 scaled by min(1+LOGV/2, maxvcorr).
    """
    w        = beta[:,0]
    log_mean = allM1[:,:,0].dot(w)
    log_var  = allM2.dot(w).dot(w)
    log_x    = np.minimum(maxlogr, log_mean + S_)
    rate0    = np.minimum(maxrate, sexp(log_x))*dtfine      # N
    rate_m   = rate0 * np.minimum(1 + 0.5*log_var, maxvcorr) # N
    return log_var, log_mean, log_x, rate0, rate_m

def euler_update_moments_parallel(allM1,allM2,R0,RM):
    """One Euler step of the moment equations, applied in place (numpy).

    allM1 and allM2 are modified in place and also returned:
      M1 <- M1 + Adt M1 + C*RM
      M2 <- M2 + J M2 + (J M2)' + CC*RM,   J = Cb*R0 + Adt per bin
    """
    rm = RM[:,None,None]
    # mean update
    allM1 += np.matmul(Adt,allM1)
    allM1 += C[None,:,:]*rm
    # covariance update
    jac    = Cb[None,:,:]*R0[:,None,None] + Adt[None,:,:]
    jm2    = np.matmul(jac,allM2)
    allM2 += jm2 + jm2.transpose((0,2,1))
    allM2 += CC[None,:,:]*rm
    return allM1,allM2

def integrate_moments_parallel(allM1,allM2,S_):
    """Project to rates, then take one in-place Euler step; returns (allM1, allM2)."""
    _, _, _, rate0, rate_m = project_moments_parallel(allM1,allM2,S_)
    newM1, newM2 = euler_update_moments_parallel(allM1,allM2,rate0,rate_m)
    return newM1, newM2

def integrate_dt_parallel(allM1,allM2,S_):
    """Apply `oversample` fine Euler steps between measurements (numpy)."""
    for _substep in range(oversample):
        allM1, allM2 = integrate_moments_parallel(allM1, allM2, S_)
    return allM1, allM2

def quadrature_moments_parallel(MQ,VQ,S_,Y_):
    """Grid quadrature of each bin's log-rate posterior (numpy).

    Nodes: 25 points on [-4, 4] standard deviations sqrt(VQ) around MQ.
    Each node x gets weight exp(Y*r - exp(r)) * exp(-grid^2/2)/sqrt(VQ),
    with r = x + S + log(dt), floored at 1e-7. Returns (MP, VP), the
    weighted mean and variance per bin.
    """
    grid    = np.linspace(-4,4,25)
    spread  = np.sqrt(VQ)[:,None]
    nodes   = grid[None,:]*spread + MQ[:,None]
    logr    = nodes + S_[:,None] + slog(dt)
    loglik  = Y_[:,None]*logr - sexp(logr)
    loglik  = loglik - np.max(loglik,axis=1)[:,None]   # stabilize before exp
    loglik += -.5*((grid**2)[None,:] + slog(VQ)[:,None])
    weights  = np.maximum(1e-7, sexp(loglik))
    inv_norm = 1/np.sum(weights,axis=1)
    MP = np.sum(nodes*weights,axis=1)*inv_norm
    VP = np.sum((nodes - MP[:,None])**2*weights,axis=1)*inv_norm
    return MP,VP

def univariate_prior_parallel(LM,LT):
    """Precision-weighted combination of (LM, precision LT) with (m, reg_rate).

    Returns (MQ, VQ): combined mean and variance 1/(LT + reg_rate).
    """
    VQ = 1/(LT + reg_rate)
    MQ = (LM*LT + m*reg_rate)*VQ
    return MQ, VQ

def surrogate_likelihood_parallel(LM,LT,MP,TP):
    """Surrogate Gaussian likelihood from posterior (MP, TP) and prior (LM, LT).

    The precision difference TP - LT is floored at 1e-12 before inverting.
    Returns the surrogate mean MR and variance VR.
    """
    precision_gap = np.maximum(1e-12, TP - LT)
    VR = 1/precision_gap
    MR = (MP*TP - LM*LT)*VR
    return MR, VR

def conditional_gaussian_parallel(allM1,allM2,MR,VR,LM,LV):
    """Condition each bin's state moments on the surrogate (MR, VR), in place.

    Gain = M2 beta / (VR + LV); M2 -= gain (M2 beta)'; M1 += gain*(MR - LM).
    allM1/allM2 are modified in place and also returned.
    """
    cov_beta = np.matmul(allM2,beta)              # N x K x 1
    gain     = cov_beta/(VR+LV)[:,None,None]      # N x K x 1
    allM2   -= np.matmul(gain, cov_beta.transpose(0,2,1))
    allM1   += gain*(MR-LM)[:,None,None]
    return allM1,allM2

def loglikelihood_parallel(allM1,allM2,S_,LM,LV,MP,VP,Y=None):
    """Per-bin log-likelihood after the measurement update (numpy).

    allM1, allM2 : conditioned state moments; S_ : per-bin drive.
    LM, LV       : projected log-rate mean/variance before conditioning.
    MP, VP       : quadrature posterior mean/variance.
    Y            : per-bin counts. Pass these explicitly; when omitted this
                   falls back to the notebook-global ``Y_`` for callers that
                   do not supply counts.
    Returns Y*LR - exp(LR) - 0.5*(log(LV/VP) + (MP-LM)^2/LV).
    """
    if Y is None:
        Y = Y_  # legacy fallback to a notebook global
    _,_,LR,_,_ = project_moments_parallel(allM1,allM2,S_)
    LOGPYX     = Y*LR-sexp(LR)
    LL         = LOGPYX - 0.5*(slog(LV/VP) + (MP-LM)**2/LV)
    return LL

def measurement_update_parallel(allM1,allM2,S_,Y_):
    # Measurement update for all bins, mixing numpy and compiled Theano stages.
    # NOTE(review): despite the numpy section, only the quadrature and the
    # 1e-12 precision floors are numpy here; the projection, prior, surrogate,
    # conditioning and log-likelihood all call the compiled *_theano functions
    # from the earlier cell.
    # NOTE(review): loglikelihood_parallel_theano receives no counts argument,
    # so Y_ does not reach the likelihood term here — confirm intended.
    # NOTE(review): filter_moments_parallel below calls
    # measurement_update_parallel_theano, not this function.
    LV,LM,_,_,_ = project_moments_parallel_theano(allM1,allM2,S_)
    LT          = 1/np.maximum(1e-12,LV)
    MQ,VQ       = univariate_prior_parallel_theano(LM,LT)
    MP,VP       = quadrature_moments_parallel(MQ,VQ,S_,Y_)
    TP          = 1/np.maximum(1e-12,VP)
    MR,VR       = surrogate_likelihood_parallel_theano(LM,LT,MP,TP)
    allM1,allM2 = conditional_gaussian_parallel_theano(allM1,allM2,MR,VR,LM,LV)
    LL          = loglikelihood_parallel_theano(allM1,allM2,S_,LM,LV,MP,VP)
    return allM1,allM2,LL

def filter_moments_parallel(di,allM1,allM2,S,Y):
    # One depth step of the shallow filter: a numpy driver around the compiled
    # Theano stages. The driver below passes di = 1-D, ..., 0; bin i reads
    # S/Y at index max(0, i+di). Returns (allM1, allM2, LL).
    # Reset values that really shouldn't be being integrated? 
    # For di < 0, [:-di] selects the first |di| bins (the ones whose offset
    # index is clipped to 0); for di == 0, [:-0] is an empty slice, so nothing
    # is reset. These assignments write into the caller's arrays in place.
    # NOTE(review): the Theano version of this step has no such reset — confirm
    # which behavior is intended.
    allM1[:-di,...]=iniM1[:-di,...]
    allM2[:-di,...]=iniM2[:-di,...]
    # Regularize: symmetrize each covariance and add reg_cov*I (new array)
    if reg_cov>0:
        allM2 = 0.5*(allM2 + allM2.transpose(0,2,1)) + allRC
    offsets = np.maximum(0,np.arange(N)+di)
    S = S[offsets]
    Y = Y[offsets]
    allM1,allM2 = integrate_dt_parallel_theano(allM1,allM2,S)
    allM1,allM2,LL = measurement_update_parallel_theano(allM1,allM2,S,Y)
    return allM1,allM2,LL

def scan_moments_parallel():
    """Run D shallow filter steps (di = 1-D, ..., 0) from the initial moments.

    Uses the notebook globals S and Y. Returns (allLR, allLV, allM1, allM2,
    NLL), with allLR capped at maxlogr and NLL = -mean(LL) of the last step.
    """
    M1_cur = iniM1.copy()
    M2_cur = iniM2.copy()
    for offset in range(-D+1,1):
        M1_cur, M2_cur, LL = filter_moments_parallel(offset, M1_cur, M2_cur, S, Y)
    w     = beta[:,0]
    allLV = M2_cur.dot(w).dot(w)
    allLR = np.minimum(maxlogr, M1_cur[:,:,0].dot(w) + S)
    return allLR, allLV, M1_cur, M2_cur, -mean(LL)

# Run and time the numpy-driven depth-D filter
tic()
allLR,allLV,allM1,allM2,NLL = scan_moments_parallel()
toc()

# Top panel: reference (black) vs. shallow filter (rust, unfilled)
subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()

# Compare likelihoods. Expectation: the parallel shallow likelihood is as close
# to the filtered likelihood (nlln) as that is to the Theano implementation
# (nllt), i.e. a shallow filter is as accurate as a deep filter up to
# numerical precision errors.
# NLL is -mean(LL) of the final depth step as returned by scan_moments_parallel;
# LL itself is local to that function and does not exist at this scope.
print(nllt,nlln,NLL)
ylim(max(ylim()[0],-100),5) 

# Reorganize Theano implementation

Separate functions are great for debugging, but let's clean things up a bit!

In [None]:
# Inputs for the fused implementation in this cell
S = stim
Y = Y_train

# Model matrices/weights as Theano constants; dimshuffle('x',0,1) adds a
# leading broadcastable axis so the 2-D matrices combine with per-bin batches.
TAdt  = Tcon(Adt)
Tbeta = Tcon(beta)
Tb    = Tcon(p.ravel()[1:K+1])
TC    = Tcon(C ).dimshuffle('x',0,1)
TCb   = Tcon(Cb).dimshuffle('x',0,1)
TCC   = Tcon(CC).dimshuffle('x',0,1)

# Scalar settings as Theano constants
mxl = Tcon(maxlogr)
mxr = Tcon(maxrate)
dtf = Tcon(dtfine)
xvc = Tcon(maxvcorr)
rr  = Tcon(reg_rate)
mm  = Tcon(m)

# 3-D symbolic inputs for the per-bin state moments (means, covariances)
TM1 = T.tensor3("TM1",dtype=dtype)
TM2 = T.tensor3("TM2",dtype=dtype)

def integrate_moments_parallel_theano_source(M1,M2,S):
    """One symbolic fine Euler step for all bins (projection and update fused).

    Projects the moments onto Tb to get the rates R0/RM, then returns
      M1 + (Adt M1 + C*RM)
      M2 + (J M2 + (J M2)' + CC*RM),   J = Cb*R0 + Adt per bin
    """
    mean_lr = M1[:,:,0].dot(Tb)                        # N
    var_lr  = M2.dot(Tb).dot(Tb)                       # N
    rate0   = Tmn(mxr,Tsexp(Tmn(mxl,mean_lr+S)))*dtf   # N
    rate_m  = rate0 * Tmn(1.0+0.5*var_lr,xvc)          # N
    rm      = rate_m[:,None,None]
    jac     = TCb*rate0[:,None,None] + TAdt[None,:,:]
    jm2     = T.batched_dot(jac,M2)
    M2_next = M2 + (jm2 + jm2.transpose(0,2,1) + TCC*rm)
    M1_next = M1 + (TAdt.dot(M1).transpose(1,0,2) + TC*rm)
    return M1_next,M2_next

# Quadrature grid: 25 nodes on [-4, 4], in units of sqrt(VQ) around MQ
Tintr = Tcon(np.linspace(-4,4,25))
def measurement_update_parallel_theano_source(M1,M2,S_,Y_):
    """Symbolic measurement update for all bins (fused version).

    Inlines projection, univariate prior, quadrature, surrogate likelihood,
    conditional-Gaussian update and log-likelihood. S_ and Y_ are the per-bin
    drive and counts. Returns (M1, M2, NLL) with NLL = -mean(LL) over bins.
    """
    # Projected log-rate: variance LV, mean LM, precision LT
    LV = M2.dot(Tb).dot(Tb) # N
    LM = M1[:,:,0].dot(Tb)  # N
    LT = Tsinv(LV)
    # Precision-weighted combination with (m, reg_rate)
    TQ = LT + rr
    VQ = Tsinv(TQ)
    MQ = (LM*LT+mm*rr)*VQ
    # Quadrature on nodes MQ + grid*sqrt(VQ); r = x + S + log(dt)
    X_ = Tintr[None,:]*T.sqrt(VQ)[:,None]+MQ[:,None]
    R0 = X_ + S_[:,None]+Tslog(dt)
    L  = Y_[:,None]*R0-Tsexp(R0)
    L  = L - T.max(L,axis=1)[:,None]  # subtract row max before exponentiating
    L += -0.5*((Tintr**2.0)[None,:]+Tslog(VQ)[:,None])
    PR = Tmx(1e-7,Tsexp(L))           # weights floored at 1e-7
    NR = Tsinv(T.sum(PR,axis=1))
    MP = T.sum(X_*PR,axis=1)*NR       # weighted mean
    VP = T.sum((X_-MP[:,None])**2*PR,axis=1)*NR  # weighted variance
    TP = Tsinv(VP)
    # Surrogate likelihood from the precision difference TP - LT
    VR = Tsinv(TP-LT)
    MR = (MP*TP-LM*LT)*VR
    # Multivariate conditional update
    M2B_  = M2.dot(Tbeta) # NxKx1
    KG    = M2B_/(VR+LV)[:,None,None] #NxKx1
    M2   -= T.batched_dot(KG,M2B_.transpose(0,2,1))
    M1   += KG*(MR-LM)[:,None,None]
    # Log-likelihood using the conditioned moments' capped log-rate
    LR    = Tmn(mxl,M1[:,:,0].dot(Tb)+S_) # N
    LOGPYX= Y_*LR-Tsexp(LR)
    LL    = LOGPYX - 0.5*(Tslog(LV/VP) + (MP-LM)**2.0/LV)
    return M1,M2,-T.mean(LL)

def filter_moments_parallel_theano_source(di,M1,M2,S_,Y_):
    """One symbolic depth step of the shallow filter (fused version).

    Bin i uses the drive/counts at index max(0, i+di); the scan below steps
    di through 1-D, ..., 0. Returns (M1, M2, NLL).
    """
    # Optional regularization: symmetrize each covariance and add reg_cov*I
    if reg_cov>0:
        M2 = 0.5*(M2 + M2.transpose(0,2,1)) + Tcast(allRC)
    idx = T.cast(T.maximum(0,T.arange(N)+di),'int32')
    S_shift = S_[idx]
    Y_shift = Y_[idx]
    for _substep in range(oversample):
        M1, M2 = integrate_moments_parallel_theano_source(M1,M2,S_shift)
    return measurement_update_parallel_theano_source(M1,M2,S_shift,Y_shift)

Tdi = T.scalar("Tdi",dtype=dtype)
filter_moments_parallel_theano = Tfun(
    inp = [Tdi,TM1,TM2,TS_,TY_],
    out = filter_moments_parallel_theano_source(Tdi,TM1,TM2,TS_,TY_))

# Depth D Loop
# theano.scan calls filter_moments_parallel_theano_source once per di in
# 1-D..0 (n_steps=D), threading (M1, M2) from (iniM1, iniM2); the per-step NLL
# is only collected (outputs_info entry None).
[_M1,_M2,_NLL], up = theano.scan(filter_moments_parallel_theano_source,
                                sequences     = [Tcon(arange(1-D,1))],
                                outputs_info  = [Tcon(iniM1),Tcon(iniM2),None],
                                non_sequences = [TS_,TY_],
                                n_steps       = D,
                                name          = 'scan_moments_parallel_theano')
#
# Final step's moments, projected variance and capped log-rate
M1,M2 = _M1[-1],_M2[-1]
ALLLV = M2.dot(Tb).dot(Tb) # N
ALLLR = T.minimum(maxlogr,M1[:,:,0].dot(Tb)+TS_) # N

# Compiled driver; the last output is the final step's NLL (already -mean(LL)).
scan_moments_parallel_theano = Tfun(\
    inp = [TS_,TY_],
    out = [ALLLR,ALLLV,M1,M2,_NLL[-1]],
    upd = up)

##########################################

# Run and time the compiled filter. Note: this rebinds the names M1/M2 from the
# symbolic expressions above to the returned arrays.
tic()
allLR,allLV,M1,M2,NLL = scan_moments_parallel_theano(S,Y)
toc()

# Top panel: reference (black) vs. fused shallow filter (rust, unfilled)
subplot(311)
stderrplot(allLRref,allLVref,color=BLACK,lw=0.5)
stderrplot(allLR,allLV,color=RUST,lw=0.5,filled=0)
niceaxis()

# Compare likelihoods. Expectation: the parallel shallow likelihood (NLL) is
# as close to the filtered likelihood (nlln) as that is to the Theano
# implementation (nllt), i.e. a shallow filter is as accurate as a deep filter
# up to numerical precision errors.
print(nllt,nlln,NLL)
ylim(max(ylim()[0],-100),5) 