# implementation of non-linear (locally (bi-)linear) models

Zhao et al. (2016), "Interpretable Nonlinear Dynamic Modeling
of Neural Trajectories"

\begin{align}
y_{t+1} &= (1-\exp(-\tau^2)) y_t + C x_t + B_t u_t + \epsilon_t \\
\mbox{vec}\left(B_t\right) &= W x_t \nonumber \\
x_{i,t} &= \Phi_i(y_t) = \frac{1}{Z} \exp\left(- \frac{||y_t - z_i ||^2}{2\sigma_i^2}\right) \nonumber \\
\epsilon_t &\sim \mathcal{N}(0, R) \nonumber
\end{align}
Parameters: 
- $C \in \mathbb{R}^{p \times n}$, 
- $\mbox{diag}(R) \in \mathbb{R}^p$, 
- $W \in \mathbb{R}^{p\cdot{}m \times n}$, 
- $\tau \in \mathbb{R}$, $\alpha = 1- \exp(-\tau^2)$ is the effective parameter of the AR process
- $\forall i = 1, \ldots,n: z_i \in \mathbb{R}^p, \sigma_i^2 \in \mathbb{R}$

Remarks:
- $B_t u_t = \left( u_t^\top \otimes I_{p}  \right) W x_t$, i.e. the input-dependent terms are bilinear in $x_t, u_t$. 
- $x_t = \Phi(y_t)$ is identical to the 'responsibilities' in a Gaussian mixture model with n spherical mixture components, with $\mu_i = z_i$, $\Sigma_i = \sigma_i^2 I_p$.
- with $C,W = 0$, the model is a simple autoregressive process. To fix a certain output variance for each $y_i$, choose
$R_{ii} = (1-\alpha^2) \mbox{Var}[y_i]$

Cookbook:
- setting $z_i = -C_{:,i}$ potentially leads to interesting (albeit extreme) dynamics
- centering the $C_{:,i}$ ensures dynamics are centered on the origin even for small $n$
- with $\tau =0$ (no AR part), the $\sigma_i^2$ become the single-most important parameters to get right
- the behavior of the model can change drastically with $\bar{\sigma}^2 = \langle \sigma_i^2 \rangle$ 

In [None]:
%matplotlib inline

import numpy as np
import scipy as sp
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import glob, os, psutil, time

# problem sizes: p observed dims, n latent (RBF) dims, m input dims, T time steps
# (2*T steps are simulated in total)
p,n,m,T = 10000,5,2,1000
# numbers of time lags; k+l lagged covariance matrices are computed further down
k,l =  3,3
pars = {}

# I/O matter
mmap, chunksize = True, np.min((2*T,10000))
#data_path, save_file = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/', 'test'
data_path, save_file = '../fits/', 'test'
# random (sorted) subsets of at most 1000 observed dimensions each
pa, pb = np.min((p, 1000)), np.min((p, 1000))
idx_a, idx_b = np.sort(np.random.choice(p,pa,replace=False)), np.sort(np.random.choice(p,pb,replace=False))
verbose=True

# auto-regression on observed variables
pars['tau'] = 0 * np.sqrt(-np.log( 0.5 ))
pars['alpha'] = 1 - np.exp(-pars['tau']**2)
pars['R'] = (1-pars['alpha']**2) * np.ones(p)

# classic LDS pars
pars['C'] = 1 * np.random.normal(size=(p,n))
# centre C so that its n columns sum to zero
pars['C'] -= np.mean(pars['C'], axis=1).reshape((-1,1))

# fixed non-linear mapping from observed to latents
pars['Z'] = - pars['C'].T.copy() 
#pars['Z'] = np.random.normal(size=(n,p))
pars['sig2'] = 1/2 * p/n * np.ones(n)

# input: one random vector, held constant over all 2*T time steps
u_const = np.random.normal(size=m)
u = np.tile(u_const, (2*T, 1))

# bilinear dependence on inputs & latents
pars['W'] = 0 * np.random.normal(size=(p*m,n))/np.sqrt(n)
# The (m, p) shape is bound when the lambda is created: later cells reuse `m`
# as a loop variable, which would otherwise silently change the reshape.
pars['B'] = lambda x, _shape=(m, p): np.reshape((pars['W'].dot(x)), _shape).T
# (u^T kron I_p) W equals sum_j u_j * W[j*p:(j+1)*p, :]; contracting over the
# m row-blocks of W avoids materialising the dense p x (p*m) Kronecker matrix.
pars['Ceff'] = pars['C'] + np.tensordot(u_const, pars['W'].reshape(m, p, n), axes=1)
def predict(y, u, x,pars):
    """Return the noise-free one-step prediction of the observed variables.

    Sums three terms: the auto-regressive part ``pars['alpha'] * y``, the
    latent drive ``pars['C'].dot(x)``, and the input term ``pars['B'](x).dot(u)``,
    where ``pars['B']`` is a callable mapping the latent ``x`` to a matrix.
    """
    autoregressive = pars['alpha'] * y
    latent_drive = pars['C'].dot(x)
    input_matrix = pars['B'](x)
    return autoregressive + latent_drive + input_matrix.dot(u)

# numerical helpers: small constant guarding the normalisation in condition_on,
# and the per-dimension noise standard deviation
pars.update({'e': 10e-30, 'sqR': np.sqrt(pars['R'])})
print('alpha = ', pars['alpha'])

def condition_on(y):
    """Return the normalised RBF activations of observation ``y``.

    For each centre (row of ``pars['Z']``) the squared distance to ``y`` is
    scaled by ``2 * pars['sig2']`` and passed through ``exp(-.)``; the result
    is divided by its sum plus the small constant ``pars['e']``.

    The largest log-activation is subtracted before exponentiating. This
    leaves the normalised result unchanged up to the negligible ``pars['e']``
    term, but keeps the activations from all underflowing to zero when ``y``
    is far from every centre.
    """
    log_phi = - np.sum((y-pars['Z'])**2,1) / (2*pars['sig2'])
    phi = np.exp(log_phi - np.max(log_phi))
    phi /= (pars['e'] + phi.sum())
    return phi

In [None]:
# draw data: simulate 2*T steps of latents x and observations y
x = np.zeros((2*T,n))
eps = np.random.normal(size=(2*T,p))
if mmap:
    # np.float (an alias of the builtin float) was removed from NumPy;
    # np.float64 is the same dtype.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='w+', shape=(2*T,p))
else:
    y = np.empty(shape=(2*T,p))
y[0] =  np.mean(pars['Z'], axis=0) # np.random.normal(size=p)   
for t in range(1,2*T):
    x[t-1] = condition_on(y[t-1])    
    y[ t ] = predict(y[t-1], u[t-1], x[t-1], pars) + pars['sqR'] * eps[t-1]
    if mmap and np.mod(t,chunksize)==0:
        del y # releases RAM, forces flush to disk
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r+', shape=(2*T,p))        
x[2*T-1] = condition_on(y[ 2*T-1 ])
if mmap:
    del y # releases RAM, forces flush to disk
    # re-open read-only for the analysis below
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(2*T,p))

# time-lagged covariances between the idx_a and idx_b subsets of y, one per lag.
# The loop variable is called `lag` (not `m`) so it does not overwrite the
# input dimension `m` defined in the setup cell.
Qs = []
for lag in range(k+l):
    print('computing time-lagged covariance for lag ', str(lag))
    # np.cov of the stacked variables is (pa+pb) x (pa+pb); the upper-right
    # pa x pb block holds the cross-covariances
    Q_lag = np.cov(y[lag:lag-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]
    if mmap:
        # np.float was removed from NumPy; np.float64 is the same dtype
        Q = np.memmap(data_path+'Qs_'+str(lag), dtype=np.float64,
                      mode='w+', shape=(pa,pb))
        Q[:] = Q_lag
        del Q # flush to disk, then re-open read-only
        Qs.append(np.memmap(data_path+'Qs_'+str(lag), dtype=np.float64,
                            mode='r', shape=(pa,pb)))
    else:
        Qs.append(Q_lag)
    
# overview of the second half of the simulation (time steps T .. 2*T-1)
plt.figure(figsize=(20,10))
plt.subplot(3,1,1)
plt.plot(y[T:, idx_a])
plt.xlabel('t')
plt.ylabel('y')
plt.subplot(3,1,2)
plt.plot(x[T:])
plt.xlabel('t')
plt.ylabel('x')
# zoom: 100 consecutive time steps starting at 3*T//2
plt.subplot(3,1,3)
t_zoom = 3*T//2+np.arange(100,dtype=int)
plt.plot(t_zoom, x[t_zoom])
plt.xlabel('t')
plt.ylabel('x')
plt.show()
if mmap:
    del y # releases RAM, forces flush to disk
    # np.float was removed from NumPy; np.float64 is the same dtype
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(2*T,p))

print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())    

In [None]:
# plt.hold was removed from Matplotlib (overplotting is the default), so the
# former plt.hold(True) calls are dropped.
plt.figure(figsize=(20,30))

# panel 1: first two observed dimensions, trajectory plus the n RBF centres
plt.subplot(3,2,1)
i = [0,1]
plt.plot(y[T:,i[0]], y[T:,i[1]], 'ko-')
for n_ in range(n):
    plt.plot(pars['Z'][n_][i[0]],pars['Z'][n_][i[1]], 'o', markersize=20)    
plt.xlabel('y_'+str(i[0]+1))
plt.ylabel('y_'+str(i[1]+1))
plt.title('projection onto 2D, traj. + n=' + str(n) + ' RBF centres')

# panels 2-4: the same view for three random pairs of observed dimensions
for j in range(3):
    plt.subplot(3,2,2+j)
    i = np.random.choice(p,2,replace=False)
    plt.plot(y[T:,i[0]], y[T:,i[1]], 'ko-')
    for n_ in range(n):
        plt.plot(pars['Z'][n_][i[0]],pars['Z'][n_][i[1]], 'o', markersize=20)    
    plt.xlabel('y_'+str(i[0]+1))
    plt.ylabel('y_'+str(i[1]+1))
    
# panel 5: cumulative explained variance of a PCA on y (skipped when y is memory-mapped)
plt.subplot(3,2,5)
if not mmap:
    pca = PCA()
    pca.fit(y[T:]-np.mean(y[T:],0))
    if T >= p:
        plt.plot(range(1,p+1), np.cumsum(pca.explained_variance_ratio_)/np.sum(pca.explained_variance_ratio_))
    plt.plot(np.linspace(0,p+1,np.min((20,p))), 
             np.cumsum(pca.explained_variance_ratio_[:np.min((20,p))])/np.sum(pca.explained_variance_ratio_), 
             'r--')
    plt.legend(('cum. var. expl.', 'first 20, x-axis rescaled'))
plt.xlabel('#eigvalue')
plt.title('% explained variance of PCA (if p not too large)')
plt.show()

print('noise variance / total variance: \n ', np.mean(pars['R']/np.var(y,axis=0)))

# For moderate p only: compare the empirical covariance of y with the one
# implied by the latents, Ceff cov(x) Ceff^T + cov(noise). Diagonals are zeroed
# so the images and the scatter plot show off-diagonal structure only.
if p < 1000:
    L0 = np.cov(y.T)
    L0[np.diag_indices(p)] *= 0
    Lr = (pars['Ceff']).dot(np.cov(x.T)).dot((pars['Ceff']).T) + np.cov(np.atleast_2d((pars['sqR'])*eps).T) #np.diag(pars['R'])
    Lr[np.diag_indices(p)] *= 0

    plt.figure(figsize=(15,8))
    plt.subplot(1,3,1)
    plt.imshow(L0, interpolation='None')
    
    plt.title('emp. cov matrix')
    plt.subplot(1,3,2)
    plt.imshow(Lr, interpolation='None')
    plt.title('est. cov matrix (C*cov(x)*C.T + cov(eps))')
    plt.subplot(1,3,3)
    plt.plot(L0[:], Lr[:], 'b.')
    plt.xlabel('emp.')
    # the estimate is on the vertical axis (previously xlabel was set twice)
    plt.ylabel('est.')
    plt.show()

    print(np.var(y), np.var(pars['sqR']*eps), 
          np.mean(np.abs(x.dot(pars['C'].T))), np.var(x.dot(pars['C'].T)))
#print(np.var(y, axis=0), np.var(pars['sqR']*eps, axis=0), np.mean(np.abs(x.dot(pars['C'].T)), axis=0))

# nonlinearity of latents: fit a linear map x_{t+1} = A_ls x_t by least squares
# on the first T steps and compare its predictions with the actual next latents
xp = x[0:T-1]
xf = x[1:T]
# lstsq solves xp @ M ~= xf; A_ls = M.T acts on column vectors
A_ls = np.linalg.lstsq(xp, xf, rcond=None)[0].T
# row-wise prediction must therefore use A_ls.T (xp.dot(A_ls) applied the
# transposed map)
xf_l = xp.dot(A_ls.T)
# subplot grid sizes must be integers (np.ceil returns a float)
n_rows = int(np.ceil(n/2))
plt.figure(figsize=(n_rows*10,20))
for i in range(n):
    plt.subplot(n_rows,2,i+1)
    plt.plot(xf_l[:,i], xf[:,i], 'k.')
    # identity line in red for reference; plt.hold is gone from Matplotlib
    # and overplotting is the default
    plt.plot(xf_l[:,i], xf_l[:,i], 'r.')
    plt.xlabel('x_pred_lin')
    plt.ylabel('x_emp')
    plt.axis('equal')
    plt.title('non-linearity of x_'+str(i+1))
plt.show()
    

In [None]:
# display the least-squares transition matrix fitted to the latents above
A_ls

In [None]:
# np.linalg.eig returns (eigenvalues, eigenvectors); [0] keeps the eigenvalues,
# which are then displayed in sorted order
print('eigenvalues of A (least squares on x_t vs. x_{t-1})')
np.sort(np.linalg.eig(A_ls)[0])


In [None]:
# average of each latent dimension over all simulated time steps
np.mean(x,0)

In [None]:
# first 20 eigenvalues of L0, in the order eigvals returns them (not sorted).
# NOTE(review): L0 is only assigned inside the `if p < 1000:` branch above or by
# the final cell of this notebook, so with p = 10000 this cell raises a
# NameError unless that final cell has been run first.
np.linalg.eigvals(L0)[:20]

In [None]:
# covariance matrix of the latent dimensions across time
np.cov(x.T)

In [None]:

# correlation matrices of the stacked (lagged, unlagged) latents, one panel per
# lag. The loop variable is `lag` so the input dimension `m` is not overwritten.
plt.figure(figsize=(20, 20))
for lag in range(k+l):
    # subplot indices are 1-based; index 0 is rejected by Matplotlib
    plt.subplot(k+l,1,lag+1)
    Q = np.corrcoef(x[lag:lag-(k+l),:].T, x[:-(k+l),:].T)
    plt.imshow(Q, interpolation='None')    
    plt.colorbar()
    plt.title('m' + str(lag))
plt.show()



In [None]:
# For each lag: covariance predicted from the latents (left), empirical
# covariance of y (middle), and a scatter of one against the other (right).
# Each lag gets its own figure, so no outer figure is created (it stayed empty).
# The loop variable is `lag` so the input dimension `m` is not overwritten.
for lag in range(k+l):
    print('computing time-lagged covariance for lag ', str(lag))
    
    plt.figure(figsize=(20,8))
    plt.subplot(1,3,1)
    # lagged cross-covariance of the latents: upper-right n x n block
    cov_x = np.cov(x[lag:lag-(k+l),:].T, x[:-(k+l),:].T)[:n,n:]
    # only the idx_a rows / idx_b columns are needed, so subset Ceff first
    # instead of forming the full p x p product
    Q = pars['Ceff'][idx_a].dot(cov_x).dot(pars['Ceff'][idx_b].T)
    Q[np.diag_indices(pa)] *= 0
    plt.imshow(Q, interpolation='None')    

    plt.subplot(1,3,2)    
    Qf = np.empty((pa,pb))
    Qf[:] = np.cov(y[lag:lag-(k+l),idx_a].T, y[:-(k+l),idx_b].T)[:pa,pa:]    
    if mmap:
        del y # releases RAM, forces flush to disk
        # np.float was removed from NumPy; np.float64 is the same dtype
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(2*T,p))    
    Qf[np.diag_indices(pa)] *= 0
    plt.imshow(Qf, interpolation='None')    

    plt.subplot(1,3,3)    
    plt.plot(Qf[:], Q[:], 'b.')
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

# Temporarily switch the working directory to ../core while importing the
# project modules, then switch to ../dev.
# NOTE(review): presumably this relies on the current directory being on
# sys.path so the modules in ../core resolve — confirm.
os.chdir('../core')
import ssm_scripts, ssm_fit
from utility import get_subpop_stats, gen_data
from SSID_Hankel_loss import run_bad, plot_slim, plot_outputs_l2_gradient_test
os.chdir('../dev')

#np.random.seed(0)

#y -= np.mean(y,axis=0)

# settings for fitting algorithm
batch_size = 1
max_zip_size = 100
max_iter = 100
a = 0.001
b1 = 0.9
b2 = 0.99
e = 1e-8
a_R = 0.001
linearity = 'False'
stable = False
sym_psd = False

# create subpopulations: two index sets, each covering all p observed dimensions
sub_pops = (np.arange(p), np.arange(p))

subpop_stats = get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = subpop_stats

# fit from the default initialisation and time the call
pars_init = 'default'
t = time.time()
_, pars_est, traces = run_bad(
    k=k, l=l, n=n, y=y, Qs=Qs, Om=Om,
    idx_a=idx_a, idx_b=idx_b,
    sub_pops=sub_pops, idx_grp=idx_grp, co_obs=co_obs, obs_idx=obs_idx,
    linearity=linearity, stable=stable, init=pars_init,
    alpha=a, alpha_R=a_R, b1=b1, b2=b2, e=e,
    max_iter=max_iter, batch_size=batch_size,
    verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size,
)

print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

plot_slim(Qs,k,l,pars_est,idx_a,idx_b,traces,mmap,data_path)
  

In [None]:
# scatter the entries of the lag-1 covariance against those of the lag-5 covariance
Q_lag1, Q_lag5 = Qs[1], Qs[5]
plt.plot(Q_lag1[:], Q_lag5[:], 'b.')
plt.show()

In [None]:
# settings for fitting algorithm
batch_size, max_zip_size, max_iter = 1, 100, 100
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = 0.001
linearity, stable, sym_psd = 'False', False, False

# continue fitting, initialised at the previous estimate pars_est.
# Restart the timer here; otherwise the reported time would be measured from
# the start of the previous fitting cell.
t = time.time()
_, pars_est, traces = run_bad(k=k,l=l,n=n,y=y, Qs=Qs,Om=Om,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      linearity=linearity,stable=stable,init=pars_est,
                                      alpha=a,alpha_R=a_R,b1=b1,b2=b2,e=e,max_iter=max_iter,batch_size=batch_size,
                                      verbose=verbose, sym_psd=sym_psd, max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

plot_slim(Qs,k,l,pars_est,idx_a,idx_b,traces,mmap,data_path)
  

In [None]:
# lag-0 covariance: data (Qs[0]) against the fitted model C X[:n] C^T + diag(R);
# all entries in blue, diagonal entries overplotted in red.
# NOTE(review): Qs[0] is pa x pb while the model term is built from the full
# pars_est['C'] — shapes only agree if the estimate is restricted to the same
# indices; confirm against run_bad's output.
cov_model = pars_est['C'].dot(pars_est['X'][:n,:]).dot(pars_est['C'].T)
plt.plot(Qs[0], cov_model + np.diag(pars_est['R']), 'b.')
# plt.hold was removed from Matplotlib; overplotting is the default
plt.plot(np.diag(Qs[0]), np.diag(cov_model) + pars_est['R'], 'r.')
plt.show()

In [None]:
# empirical covariance of y next to the covariance implied by the latents,
# Ceff Pi Ceff^T + diag(R); both diagonals are zeroed before plotting
diag_idx = np.diag_indices(p)
Pi = np.cov(x.T)

L0 = np.cov(y.T)
L0[diag_idx] *= 0

C_eff = pars['Ceff']
Lr = C_eff.dot(Pi).dot(C_eff.T) + np.diag(pars['R'])
Lr[diag_idx] *= 0

plt.figure(figsize=(15,8))
for panel, cov_img in enumerate((L0, Lr), start=1):
    plt.subplot(1,2,panel)
    plt.imshow(cov_img, interpolation='None')
plt.show()
