# Implementation of nonlinear (locally (bi)linear) models

Zhao & Park (2016), "Interpretable Nonlinear Dynamic Modeling
of Neural Trajectories"

\begin{align}
y_t &= (1-\exp(-\tau^2))\, y_{t-1} + C x_{t-1} + B_{t-1} u_{t-1} + d + \epsilon_t \\
\mbox{vec}\left(B_t\right) &= W x_t \nonumber \\
x_{i,t} &= \Phi_i(y_t) = \frac{1}{Z_t} \exp\left(- \frac{||y_t - z_i ||^2}{2\sigma_i^2}\right) \nonumber \\
\epsilon_t &\sim \mathcal{N}(0, R) \nonumber
\end{align}
where $Z_t = \sum_{j=1}^n \exp\left(- ||y_t - z_j||^2 / (2\sigma_j^2)\right)$ normalizes the features $x_t$ to sum to one (in the code, a small constant $e$ is added to $Z_t$ for numerical stability).

Parameters:
- $C \in \mathbb{R}^{p \times n}$,
- $d \in \mathbb{R}^p$,
- $\mbox{diag}(R) \in \mathbb{R}^p$,
- $W \in \mathbb{R}^{p\cdot{}m \times n}$,
- $\tau \in \mathbb{R}$,
- $\forall i = 1, \ldots,n: z_i \in \mathbb{R}^p, \sigma_i^2 \in \mathbb{R}$


In [None]:
import numpy as np
import scipy as sp


# Model sizes: p observed dims, n latent features (one per RBF centre),
# m input dims, T simulated time steps.
p, n, m, T = 2, 2, 2, 50

pars = {
    # --- classic LDS parameters ---
    'C': -np.array([[1, 0], [0, 1]]),   # latent-to-observation loading, shape (p, n)
    'd': np.ones(p),                    # constant observation offset
    'R': np.zeros(p),                   # diagonal of the noise covariance (zero: noise-free)
    # --- auto-regression on observed variables ---
    # chosen so that exp(-tau**2) == 0.01, i.e. an AR coefficient of 0.99
    'tau': np.sqrt(-np.log(0.01)),
    # --- fixed non-linear mapping from observations to latents ---
    'Z': np.array([[1, 0], [0, 1]]),    # row i is the RBF centre z_i, shape (n, p)
    'sig': np.ones(n),                  # sigma_i**2 for each RBF (condition_on divides by 2*sig)
    # --- bilinear dependence on inputs & latents ---
    # multiplied by zero, so the bilinear term is switched off; the random draw
    # is still made so the RNG stream matches.
    'W': 0 * np.random.normal(size=(p * m, n)),
}
# B_t = W x_t reshaped to (p, m); pars['W'] is looked up each time B is called
pars['B'] = lambda x: np.reshape(pars['W'].dot(x), (p, m))

# --- technical convenience parameters ---
pars['e'] = 10e-7                              # normaliser regulariser (10e-7 == 1e-6)
pars['sqR'] = np.sqrt(pars['R'])               # per-dimension noise standard deviation
pars['alpha'] = 1 - np.exp(-pars['tau'] ** 2)  # auto-regression coefficient

def condition_on(y, params=None):
    """Map an observation y_t to its normalised RBF features x_t = Phi(y_t).

    For every row z_i of params['Z'] computes
    phi_i = exp(-||y - z_i||^2 / (2 * sig_i)) and divides by
    (params['e'] + sum_i phi_i), so the features sum to just under one
    for a small positive 'e'.

    Args:
        y: observation vector of length p.
        params: parameter dict providing 'Z' (centres, one per row), 'sig'
            (the sigma_i**2 values) and 'e' (normaliser regulariser).
            Defaults to the module-level ``pars``, so ``condition_on(y)``
            behaves exactly as before; passing it explicitly mirrors the
            signature of ``predict``.

    Returns:
        Array with one entry per row of params['Z'].
    """
    if params is None:
        params = pars
    sq_dist = np.sum((y - params['Z']) ** 2, 1)
    phi = np.exp(-sq_dist / (2 * params['sig']))
    # 'e' keeps the division finite if every phi underflows to zero
    return phi / (params['e'] + phi.sum())

def predict(y, u, x, pars, eps):
    """Advance the observation one step: return y_t given y_{t-1}, u_{t-1}, x_{t-1}.

    y_t = alpha * y + C x + B(x) u + d + sqR * eps

    Args:
        y: previous observation.
        u: input applied at the previous step.
        x: latent features of the previous observation.
        pars: dict with 'alpha', 'C', 'B' (callable x -> matrix), 'd' and 'sqR'.
        eps: standard-normal draw, scaled elementwise by pars['sqR'].
    """
    autoregressive = pars['alpha'] * y
    latent_drive = pars['C'].dot(x)
    input_drive = pars['B'](x).dot(u)
    offset = pars['d']
    noise = pars['sqR'] * eps
    return autoregressive + latent_drive + input_drive + offset + noise

# Simulate T steps of the model from y_0 = 0 with standard-normal inputs u_t.
# At each step the features x_{t-1} = Phi(y_{t-1}) are computed first and then
# used (together with u_{t-1}) to produce y_t.
u = np.random.normal(size=(T, m))
y, x = np.zeros((T, p)), np.zeros((T, n))
for t in range(1, T):
    x[t-1] = condition_on(y[t-1])
    y[t] = predict(y[t-1], u[t-1], x[t-1], pars, np.random.normal(size=p))
# The loop only fills x[0..T-2]; compute the features of the final observation
# too, otherwise x[T-1] stays all-zero and shows up as a spurious drop in plots.
x[-1] = condition_on(y[-1])

%matplotlib inline
import matplotlib.pyplot as plt

plt.figure(figsize=(20,10))
plt.subplot(3,1,1)
plt.plot(y)
plt.xlabel('t')
plt.ylabel('y')
plt.subplot(3,1,2)
plt.plot(x)
plt.xlabel('t')
plt.ylabel('x')
plt.subplot(3,1,3)
tmp = np.exp( - np.sum(y**2,1)/ 2 )
plt.plot(tmp / (pars['e']+3*tmp))
plt.show()
    

In [None]:
# Per-time-step sum of the normalised features; each row computed by
# condition_on sums to just under one.
plt.plot(np.sum(x, axis=1))