# implementation of non-linear (locally (bi-)linear) models

Zhao et al. (2016), "Interpretable Nonlinear Dynamic Modeling
of Neural Trajectories"

\begin{align}
y_{t+1} &= (1-\exp(-\tau^2)) y_t + C x_t + B_t u_t + \epsilon_t \\
\mbox{vec}\left(B_t\right) &= W x_t \nonumber \\
x_{i,t} &= \Phi_i(y_t) = \frac{1}{Z} \exp\left(- \frac{\|y_t - z_i\|^2}{2\sigma_i^2}\right) \nonumber \\
\epsilon_t &\sim \mathcal{N}(0, R) \nonumber
\end{align}
Parameters (with $y_t \in \mathbb{R}^p$, $x_t \in \mathbb{R}^n$, $u_t \in \mathbb{R}^m$):
- $C \in \mathbb{R}^{p \times n}$, 
- $\mbox{diag}(R) \in \mathbb{R}^p$, 
- $W \in \mathbb{R}^{p\cdot{}m \times n}$, 
- $\tau \in \mathbb{R}$, 
- $\forall i = 1, \ldots,n: z_i \in \mathbb{R}^p, \sigma_i^2 \in \mathbb{R}$

Notice that $B_t u_t = \left( u_t^\top \otimes I_p \right) W x_t$, where $I_p$ is the $p \times p$ identity matrix, i.e. the input-dependent term is bilinear in $x_t$ and $u_t$.


In [None]:
import numpy as np
import scipy as sp


# Problem sizes: p observed dimensions, n basis functions (latents),
# m input dimensions, T simulated time steps.
p, n, m, T = 20, 5, 3, 100000

pars = {
    # classic LDS part: loading matrix and diagonal of the noise covariance
    'C': np.random.normal(size=(p, n)) / np.sqrt(n),
    'R': np.ones(p),
    # auto-regression on the observed variables;
    # use np.sqrt(-np.log(0.99)) instead of 0 to switch it on
    'tau': 0,
    # fixed non-linear map from observations to latents:
    # one centre (row of Z) and one width per basis function
    'Z': np.random.normal(size=(n, p)),
    'sig2': np.ones(n) * p,
    # bilinear dependence on inputs and latents
    'W': np.random.normal(size=(p * m, n)),
}
# B(x) is the (p, m) input matrix whose stacked columns equal W x.
pars['B'] = lambda x: pars['W'].dot(x).reshape(m, p).T

# Input: a single random vector held constant over all T time steps.
u_const = np.random.normal(size=m)
u = np.tile(u_const, (T, 1))

# Technical convenience parameters.
pars['e'] = 10e-10                                   # added to the normaliser in condition_on
pars['sqR'] = np.sqrt(pars['R'])                     # noise standard deviations
pars['alpha'] = 1 - np.exp(-(pars['tau'] ** 2))      # AR coefficient on y_t
# With constant input, B(x) u_const is linear in x and folds into C.
pars['Ceff'] = pars['C'] + np.kron(u_const, np.eye(p)).dot(pars['W'])
print('alpha = ', pars['alpha'])


def condition_on(y):
    """Map an observation to normalised radial-basis-function activations.

    Each row of the global ``pars['Z']`` is one centre; the squared
    Euclidean distance from ``y`` to every centre is turned into a Gaussian
    bump with width ``pars['sig2']``.  The activations are then divided by
    their sum plus the small constant ``pars['e']``.

    Parameters
    ----------
    y : array broadcastable against ``pars['Z']``
        Observation to evaluate the basis functions at.

    Returns
    -------
    ndarray
        One activation per row of ``pars['Z']``.
    """
    sq_dist = np.square(y - pars['Z']).sum(axis=1)
    weights = np.exp(-sq_dist / (2 * pars['sig2']))
    normaliser = pars['e'] + weights.sum()
    return weights / normaliser

def predict(y, u, x, pars, eps):
    """Compute the next observation from the current state.

    Parameters
    ----------
    y : current observation.
    u : current input.
    x : current latent (basis-function) activations.
    pars : mapping providing ``'alpha'``, ``'C'``, ``'sqR'`` and the
        callable ``'B'`` that maps ``x`` to an input matrix.
    eps : noise draw, scaled element-wise by ``pars['sqR']``.

    Returns
    -------
    ``alpha * y + C x + B(x) u + sqR * eps``
    """
    autoregressive = pars['alpha'] * y
    latent_drive = pars['C'].dot(x)
    input_drive = pars['B'](x).dot(u)
    noise = pars['sqR'] * eps
    return autoregressive + latent_drive + input_drive + noise


# Simulate the model forward in time, starting from y[0] = 0.
# (For white-noise input instead of the constant one, set
#  u = np.random.normal(size=(T, m)) before running this.)
y = np.zeros((T, p))
x = np.zeros((T, n))
y2 = np.zeros((T, p))
for step in range(T - 1):
    # latents are a deterministic function of the current observation
    x[step] = condition_on(y[step])
    noise = np.random.normal(size=p)
    y[step + 1] = predict(y[step], u[step], x[step], pars, noise)
# latents for the final observation, which has no successor
x[-1] = condition_on(y[-1])

%matplotlib inline
import matplotlib.pyplot as plt

# Plot the simulated observations (top) and latent activations (middle)
# against time on a three-row figure.
plt.figure(figsize=(20, 10))
for row, (series, label) in enumerate(((y, 'y'), (x, 'x')), start=1):
    plt.subplot(3, 1, row)
    plt.plot(series)
    plt.xlabel('t')
    plt.ylabel(label)
# the third panel is created but nothing is drawn into it
plt.subplot(3, 1, 3)
plt.show()
    

In [None]:
# Compare the empirical covariance of y (left) with the covariance
# implied by the effective loading matrix, Ceff Pi Ceff^T + diag(R) (right).
# Diagonals are zeroed so that only the off-diagonal structure is shown.
Pi = np.cov(x.T)
L0 = np.cov(y.T)
Lr = pars['Ceff'].dot(Pi).dot(pars['Ceff'].T) + np.diag(pars['R'])

diag = np.diag_indices(p)
for cov in (L0, Lr):
    cov[diag] *= 0

plt.figure(figsize=(15, 8))
for panel, cov in enumerate((L0, Lr), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(cov, interpolation='None')
plt.show()


In [None]:
# First observed coordinate at t = 101..110 against the first row of Ceff
# applied to the latents one step earlier (t = 100..109).
observed = y[101:111, 0]
predicted = pars['Ceff'][0, :].dot(x[100:110, :].T).T
plt.plot(observed, predicted)
plt.show()

In [None]:
import time
# Step through the simulation one time point per figure.
# Left: first two coordinates of y[t] (black) and of the first three centres.
# Right: activations of the first three basis functions at time t.
# NOTE: the former plt.hold(True) calls are gone; pyplot.hold was removed in
# matplotlib 3.0 and successive plot calls already draw into the same axes.
for t in range(T):
    print('t =', t)
    plt.subplot(1, 2, 1)
    plt.plot(y[t][0], y[t][1], color='k', marker='o')
    for i in range(3):
        plt.plot(pars['Z'][i, 0], pars['Z'][i, 1], marker='o')
    plt.subplot(1, 2, 2)
    for i, colour in enumerate('bgr'):
        plt.bar(i, x[t][i], color=colour)
    plt.show()


In [None]:
# IPython help query: displays the documentation of plt.bar.
# The trailing "?" is IPython/notebook syntax, not plain Python.
plt.bar?

In [None]:
# Overlay the observation y[37] (black) on all basis-function centres,
# each drawn as one line across the p coordinates.
# plt.hold(True) is dropped (removed in matplotlib 3.0; overlaying is the
# default), and plt.show is now actually called.
plt.plot(y[37], color='k')
plt.plot(pars['Z'].T)
plt.show()

In [None]:
# Unnormalised basis-function activations for the initial observation y[0].
np.exp(-np.square(y[0] - pars['Z']).sum(axis=1) / (2 * pars['sig2']))