In [1]:
# Re-import edited modules (e.g. em, loaders) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from glob import glob

import numpy as np
import scipy
import scipy.linalg
import scipy.stats

In [3]:
from pyuoi.linear_model.var import VAR
from sklearn.model_selection import KFold

In [4]:
# Make modules two directories up importable (presumably where em / loaders live).
# NOTE(review): the path is relative to the kernel's working directory, so this assumes
# the notebook is started from its own folder.
sys.path.append('../..')

In [5]:
from em import ARMAStateSpaceML
from loaders import load_sabes
from dca.cov_util import calc_cross_cov_mats_from_data
from neurosim.models.var import VAR as VARss
from neurosim.models.var import form_companion
from neurosim.models.ssr import StateSpaceRealization as SSR

In [6]:
# Goal: on *Sabes* data, show that fitting an ARMAStateSpace model initialized from a VAR fit improves
# cross-validated cross-covariance-matrix (CCM) prediction. That dataset is too large to fit locally,
# so we demonstrate on a synthetic model instead.

In [7]:
# Synthetic ground truth: random stable dynamics A, identity input matrix, and an
# observation matrix C whose obs_dim rows are orthonormal vectors in the state space.
SEED = 0
# Seeds the legacy global RNG, which np.random.normal and ortho_group.rvs (random_state=None)
# both draw from; presumably ssr.trajectory uses it too.
np.random.seed(SEED)

state_dim = 30
obs_dim = 20
# Entry std 1 / (1.7 sqrt(n)) puts the spectral radius near 1/1.7 ~ 0.59 (circular law).
A_ENTRY_STD = 1 / (1.7 * np.sqrt(state_dim))
MAX_SPECTRAL_RADIUS = 0.99  # reject draws that are not strictly stable

A = np.random.normal(scale=A_ENTRY_STD, size=(state_dim, state_dim))
while np.max(np.abs(np.linalg.eigvals(A))) > MAX_SPECTRAL_RADIUS:
    A = np.random.normal(scale=A_ENTRY_STD, size=(state_dim, state_dim))

C = scipy.stats.ortho_group.rvs(state_dim)[:, :obs_dim].T
ssr = SSR(A=A, B=np.eye(state_dim), C=C)
# True-model cross-covariances for reference.
ccm = ssr.autocorrelation(5)

In [8]:
# Simulate a 5,000-step trajectory from the synthetic system.
y = ssr.trajectory(5000)

In [9]:
# Five folds; KFold does not shuffle by default, so each test fold is a contiguous time block.
cv_split = list(KFold(n_splits=5).split(y))

In [10]:
# Use only the first fold: with unshuffled KFold its test set is the first fifth of the
# trajectory and its training set is one contiguous block. For later folds the training
# indices form two disjoint segments that y[train_idxs] would splice together.
train_idxs, test_idxs = cv_split[0]

In [11]:
# Fit a VAR(2) model by ordinary least squares (OLS).
VAR_ORDER = 2
varmodel = VAR(order=VAR_ORDER, estimator='ols')

In [12]:
# Estimate the VAR coefficients on the training fold only.
varmodel.fit(y[train_idxs])

Rows of VAR matrix processed: 20it [00:00, 631.22it/s]


In [13]:
# Wrap the fitted VAR coefficients in neurosim's VAR model so its implied
# cross-covariances (varss.autocorrelation) can be compared against held-out data.
varss = VARss(varmodel.coef_)

In [14]:
# Empirical cross-covariances of the held-out fold and the VAR-implied ones, each at 5 lags
# (presumably lags 0..4).
ccm_test = calc_cross_cov_mats_from_data(y[test_idxs], 5)
ccm_var = varss.autocorrelation(5)

In [15]:
# Frobenius-norm error between the VAR-implied and held-out cross-covariance, one entry per lag.
ccm_var_error = np.array([
    np.linalg.norm(ccm_var[lag] - ccm_test[lag])
    for lag in range(5)
])

In [16]:
# Per-lag held-out error of the OLS VAR (baseline for the EM-fitted model).
ccm_var_error

array([1.3704781 , 1.30507517, 1.12619335, 1.07052301, 0.99779408])

In [17]:
# EM-fitted ARMA state-space model. 'manual' init presumably means the starting parameters
# are supplied explicitly to fit() (Ainit, Cinit, Kinit, ...).
varmaxmodel = ARMAStateSpaceML(init_strategy='manual')

In [18]:
# Compute the minimum-phase factorization of the VAR realization; presumably this populates
# varss.Bmin / varss.Dmin, which the EM initialization reads.
varss.solve_min_phase()

In [19]:
# Initialize EM from the minimum-phase realization of the fitted VAR. The VAR companion
# state has obs_dim * order dimensions (40 here), so state_dim must follow varss.A rather
# than the 30-dim state of the generating system; a hard-coded 30 contradicted Ainit.
varmaxmodel.fit(
    y[train_idxs],
    state_dim=varss.A.shape[0],
    Ainit=varss.A,
    Cinit=varss.C,
    Kinit=varss.Bmin @ np.linalg.inv(varss.Dmin),  # Kinit = Bmin Dmin^{-1}
    Rinit=varss.Dmin @ varss.Dmin.T,                 # Rinit = Dmin Dmin^T
    x0=np.zeros(varss.A.shape[0]),
    Sigma0=varss.P,
    Qinit=varss.Bmin @ varss.Bmin.T,
)

  Psqrt_pred[i] = Psqrt_pred_
  Ppred[i] = Psqrt_pred_ @ Psqrt_pred_.T
  Psqrt_filt[i] = Psqrt_filt_
  Psmooth[i] = Psqrt_smooth @ Psqrt_smooth.T
L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A.transpose(-2, -1).conj()).transpose(-2, -1).conj() (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448224956/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:1284.)
  self.Psqrt0 = torch.cholesky(torch.tensor(P0)).float()


E step: 6.877015


  return _VF.chain_matmul(matrices)  # type: ignore[attr-defined]
100%|██████████| 100/100 [02:14<00:00,  1.35s/it]


M step: 138.856035
E step: 6.669816


100%|██████████| 100/100 [02:15<00:00,  1.35s/it]


M step: 139.531156
E step: 6.598305


100%|██████████| 100/100 [02:11<00:00,  1.32s/it]


M step: 136.213006
E step: 6.648780


100%|██████████| 100/100 [02:12<00:00,  1.32s/it]


M step: 136.562677
E step: 6.616940


100%|██████████| 100/100 [02:13<00:00,  1.33s/it]


M step: 137.399914
E step: 6.652199


100%|██████████| 100/100 [02:16<00:00,  1.36s/it]


M step: 140.594361
E step: 6.533943


100%|██████████| 100/100 [02:08<00:00,  1.28s/it]


M step: 132.650943
E step: 6.506279


100%|██████████| 100/100 [02:06<00:00,  1.27s/it]


M step: 131.104180
E step: 6.164299


100%|██████████| 100/100 [02:04<00:00,  1.25s/it]


M step: 129.221010
E step: 6.256009


100%|██████████| 100/100 [02:07<00:00,  1.27s/it]


M step: 131.359431
E step: 6.409265


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 132.012982
E step: 6.155097


100%|██████████| 100/100 [02:08<00:00,  1.29s/it]


M step: 133.027756
E step: 6.312376


100%|██████████| 100/100 [02:08<00:00,  1.29s/it]


M step: 132.934118
E step: 6.380029


100%|██████████| 100/100 [02:10<00:00,  1.31s/it]


M step: 134.980922
E step: 6.318720


100%|██████████| 100/100 [02:08<00:00,  1.29s/it]


M step: 133.182337
E step: 6.302460


100%|██████████| 100/100 [02:09<00:00,  1.30s/it]


M step: 134.146145
E step: 6.696658


100%|██████████| 100/100 [02:07<00:00,  1.27s/it]


M step: 131.255192
E step: 6.208906


100%|██████████| 100/100 [02:16<00:00,  1.36s/it]


M step: 140.751959
E step: 6.976629


100%|██████████| 100/100 [02:21<00:00,  1.42s/it]


M step: 145.948832
E step: 6.852685


100%|██████████| 100/100 [02:19<00:00,  1.40s/it]


M step: 143.791918
E step: 6.584129


100%|██████████| 100/100 [02:11<00:00,  1.32s/it]


M step: 135.865894
E step: 6.428014


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 131.822931
E step: 6.253425


100%|██████████| 100/100 [02:06<00:00,  1.26s/it]


M step: 130.205164
E step: 6.321633


100%|██████████| 100/100 [02:08<00:00,  1.29s/it]


M step: 132.813526
E step: 6.303153


100%|██████████| 100/100 [02:06<00:00,  1.27s/it]


M step: 131.138513
E step: 6.449687


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 131.799231
E step: 6.452705


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.447930
E step: 6.377103


100%|██████████| 100/100 [02:05<00:00,  1.25s/it]


M step: 129.357952
E step: 6.137690


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 131.911952
E step: 6.171821


100%|██████████| 100/100 [02:06<00:00,  1.26s/it]


M step: 130.697622
E step: 6.452004


100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


M step: 127.534045
E step: 6.117009


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.551723
E step: 6.129442


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.931013
E step: 6.267442


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.507782
E step: 6.281939


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.822958
E step: 6.455819


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.203382
E step: 6.231394


100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


M step: 127.525336
E step: 6.435318


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.685053
E step: 6.201349


100%|██████████| 100/100 [02:02<00:00,  1.23s/it]


M step: 126.926360
E step: 6.180972


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.680332
E step: 6.218460


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.739226
E step: 6.386546


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 128.082615
E step: 6.322929


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.758098
E step: 6.380089


100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


M step: 127.473770
E step: 6.158637


100%|██████████| 100/100 [02:02<00:00,  1.22s/it]


M step: 126.497786
E step: 6.165068


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.681344
E step: 6.204326


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 128.127752
E step: 6.209128


100%|██████████| 100/100 [02:02<00:00,  1.22s/it]


M step: 126.193038
E step: 6.357074


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 127.953151
E step: 6.145125


100%|██████████| 100/100 [02:13<00:00,  1.34s/it]


M step: 138.632405
E step: 6.846924


100%|██████████| 100/100 [02:20<00:00,  1.40s/it]


M step: 145.068144
E step: 6.626424


100%|██████████| 100/100 [02:17<00:00,  1.37s/it]


M step: 142.049842
E step: 6.637808


100%|██████████| 100/100 [02:18<00:00,  1.39s/it]


M step: 143.805781
E step: 7.242793


100%|██████████| 100/100 [02:12<00:00,  1.33s/it]


M step: 136.678593
E step: 6.180429


100%|██████████| 100/100 [01:59<00:00,  1.20s/it]


M step: 123.918494
E step: 6.093565


100%|██████████| 100/100 [01:59<00:00,  1.20s/it]


M step: 123.779736
E step: 5.985243


100%|██████████| 100/100 [01:59<00:00,  1.20s/it]


M step: 124.048221
E step: 6.140686


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 126.099638
E step: 6.048418


100%|██████████| 100/100 [02:00<00:00,  1.20s/it]


M step: 124.236115
E step: 6.169329


100%|██████████| 100/100 [02:01<00:00,  1.21s/it]


M step: 125.561396
E step: 6.058623


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 125.811127
E step: 6.297299


100%|██████████| 100/100 [02:02<00:00,  1.22s/it]


M step: 126.578651
E step: 6.233667


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 125.913505
E step: 6.402116


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 125.983393
E step: 6.466828


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 125.741106
E step: 6.448556


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 126.120841
E step: 6.164275


100%|██████████| 100/100 [02:02<00:00,  1.23s/it]


M step: 126.716243
E step: 6.209215


100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


M step: 127.192215
E step: 6.043835


100%|██████████| 100/100 [02:03<00:00,  1.23s/it]


M step: 127.607267
E step: 6.128891


100%|██████████| 100/100 [02:01<00:00,  1.21s/it]


M step: 125.150401
E step: 6.206873


100%|██████████| 100/100 [02:01<00:00,  1.22s/it]


M step: 126.058673
E step: 6.085663


100%|██████████| 100/100 [02:02<00:00,  1.23s/it]


M step: 127.082204
E step: 6.105018


100%|██████████| 100/100 [02:03<00:00,  1.24s/it]


M step: 128.134309
E step: 6.129392


100%|██████████| 100/100 [02:02<00:00,  1.22s/it]


M step: 126.691266
E step: 6.124869


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.391147
E step: 6.089148


100%|██████████| 100/100 [02:04<00:00,  1.25s/it]


M step: 128.878625
E step: 6.696704


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]


M step: 128.639442
E step: 6.339595


100%|██████████| 100/100 [02:02<00:00,  1.22s/it]


M step: 126.582052
E step: 6.182097


100%|██████████| 100/100 [02:06<00:00,  1.27s/it]


M step: 130.987176
E step: 6.411013


100%|██████████| 100/100 [02:05<00:00,  1.25s/it]


M step: 129.595745
E step: 6.085804


100%|██████████| 100/100 [02:07<00:00,  1.27s/it]


M step: 131.456381
E step: 6.776192


100%|██████████| 100/100 [02:14<00:00,  1.34s/it]


M step: 139.196619
E step: 7.327208


100%|██████████| 100/100 [02:15<00:00,  1.35s/it]


M step: 139.354543
E step: 6.437987


100%|██████████| 100/100 [02:11<00:00,  1.31s/it]


M step: 135.556750
E step: 6.571547


100%|██████████| 100/100 [02:10<00:00,  1.30s/it]


M step: 134.327272
E step: 6.410771


100%|██████████| 100/100 [02:10<00:00,  1.30s/it]


M step: 134.294792
E step: 6.352110


100%|██████████| 100/100 [02:11<00:00,  1.31s/it]


M step: 135.532847
E step: 6.464947


100%|██████████| 100/100 [02:10<00:00,  1.30s/it]


M step: 134.429373
E step: 6.409553


100%|██████████| 100/100 [02:10<00:00,  1.31s/it]


M step: 134.936028
E step: 6.496848


100%|██████████| 100/100 [02:10<00:00,  1.31s/it]


M step: 135.270874
E step: 6.560016


100%|██████████| 100/100 [02:11<00:00,  1.32s/it]


M step: 136.072162
E step: 6.265138


100%|██████████| 100/100 [02:09<00:00,  1.30s/it]


M step: 133.950256
E step: 6.212267


100%|██████████| 100/100 [02:07<00:00,  1.27s/it]


M step: 131.629893
E step: 6.589193


100%|██████████| 100/100 [02:09<00:00,  1.29s/it]


M step: 133.763909
E step: 6.355378


100%|██████████| 100/100 [02:07<00:00,  1.27s/it]


M step: 131.683274
E step: 6.293971


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 132.356908
E step: 6.554104


100%|██████████| 100/100 [02:09<00:00,  1.30s/it]


M step: 134.062575
E step: 6.472778


100%|██████████| 100/100 [02:10<00:00,  1.30s/it]


M step: 134.722640
E step: 6.258533


100%|██████████| 100/100 [02:04<00:00,  1.25s/it]


M step: 129.210367
E step: 6.366287


100%|██████████| 100/100 [02:07<00:00,  1.28s/it]


M step: 132.367291


In [23]:
# Eigenvalue magnitudes of the true (generating) dynamics matrix, for comparison with the EM fit.
# NOTE(review): this cell was executed after the next one (In [23] vs In [22]).
np.abs(np.linalg.eigvals(A))

array([0.57277005, 0.57277005, 0.52274874, 0.5059852 , 0.5059852 ,
       0.53413055, 0.53413055, 0.47570937, 0.47570937, 0.50046757,
       0.50046757, 0.39653482, 0.39653482, 0.39740329, 0.39740329,
       0.34455306, 0.34455306, 0.28284213, 0.38452126, 0.38452126,
       0.46954518, 0.43630209, 0.30691375, 0.26642725, 0.26642725,
       0.20581297, 0.20581297, 0.09976788, 0.09976788, 0.06514524])

In [22]:
# Eigenvalue magnitudes of the EM-fitted dynamics matrix: 40 of them (the VAR(2) companion
# dimension), versus 30 for the true system.
np.abs(np.linalg.eigvals(varmaxmodel.A))

array([0.5727169 , 0.5727169 , 0.54088318, 0.54088318, 0.5200568 ,
       0.5200568 , 0.47115094, 0.48527573, 0.48527573, 0.46942597,
       0.46942597, 0.40369581, 0.40369581, 0.3418304 , 0.3418304 ,
       0.36707684, 0.36707684, 0.25180371, 0.25180371, 0.2888591 ,
       0.2888591 , 0.20053664, 0.38009606, 0.3239354 , 0.3239354 ,
       0.20064043, 0.20064043, 0.2961741 , 0.28816883, 0.28816883,
       0.23834564, 0.23834564, 0.18020215, 0.18020215, 0.19134532,
       0.19134532, 0.09786041, 0.09786041, 0.04497333, 0.04497333])

In [24]:
# Fitted Q (presumably the state-noise covariance; it was initialized as Bmin Bmin^T).
# NOTE(review): the full matrix is large; consider showing a summary such as its eigenvalues.
varmaxmodel.Q

array([[ 2.59725650e-01,  1.35558770e-01,  9.03894822e-02, ...,
        -1.24578375e-02,  2.38846074e-01,  1.36475029e-01],
       [ 1.35558770e-01,  2.89426768e-01,  6.04160904e-02, ...,
         6.46501187e-02,  1.92544170e-01,  2.73604728e-02],
       [ 9.03894822e-02,  6.04160904e-02,  1.86769343e-01, ...,
        -1.99227919e-02,  1.30371678e-01, -8.80720099e-02],
       ...,
       [-1.24578375e-02,  6.46501187e-02, -1.99227919e-02, ...,
         1.02549292e+00, -5.56532157e-03,  7.46308756e-03],
       [ 2.38846074e-01,  1.92544170e-01,  1.30371678e-01, ...,
        -5.56532157e-03,  1.02162045e+00, -6.97240376e-04],
       [ 1.36475029e-01,  2.73604728e-02, -8.80720099e-02, ...,
         7.46308756e-03, -6.97240376e-04,  1.03280567e+00]])

In [25]:
def autocorrelation(A, C, K, R, T):
    """Model-implied cross-covariance matrices of an innovation-form state-space model.

    The computation corresponds to the stationary model

        x_{t+1} = A x_t + K e_t
        y_t     = C x_t + e_t,        Cov(e_t) = R,

    with e_t white and uncorrelated with x_t.

    Parameters
    ----------
    A : ndarray, shape (n, n)
        State transition matrix. Should be stable (spectral radius < 1) so the discrete
        Lyapunov solution is the stationary state covariance.
    C : ndarray, shape (m, n)
        Observation matrix.
    K : ndarray, shape (n, m)
        Gain mapping the innovations e_t into the state.
    R : ndarray, shape (m, m)
        Innovation covariance (assumed symmetric).
    T : int
        Number of lags to return.

    Returns
    -------
    ndarray, shape (T, m, m)
        ``autocorr[k] = E[y_{t+k} y_t^T]`` for k = 0, ..., T-1. For ``T <= 0`` an
        empty ``(0, m, m)`` array is returned.
    """
    obs_dim = C.shape[0]
    autocorr = np.zeros((max(T, 0), obs_dim, obs_dim))
    if T <= 0:
        return autocorr

    # Stationary state covariance: solves P = A P A^T + K R K^T.
    P = scipy.linalg.solve_discrete_lyapunov(A, K @ R @ K.T)

    # Lag 0: C x_t and e_t are uncorrelated, so their covariances add.
    autocorr[0] = C @ P @ C.T + R

    # G = E[x_{t+1} y_t^T] = A P C^T + K R, and E[y_{t+k} y_t^T] = C A^{k-1} G for k >= 1.
    # Advance G by one multiplication with A per lag instead of recomputing matrix powers.
    G = A @ P @ C.T + K @ R
    for lag in range(1, T):
        autocorr[lag] = C @ G
        G = A @ G
    return autocorr

In [26]:
# Cross-covariances implied by the EM-fitted model, at the same 5 lags as ccm_test / ccm_var.
ccm_varmax = autocorrelation(
    A=varmaxmodel.A,
    C=varmaxmodel.C,
    K=varmaxmodel.K,
    R=varmaxmodel.R,
    T=5,
)

In [27]:
# Frobenius-norm error between the EM-model-implied and held-out cross-covariance, one entry per lag.
ccm_varmax_error = np.array([
    np.linalg.norm(ccm_varmax[lag] - ccm_test[lag])
    for lag in range(5)
])

In [28]:
# Per-lag held-out error of the EM-fitted model. It is worse than the OLS VAR at every lag,
# most visibly at lag 0 (about 4.48 vs 1.37), so the EM refit did not improve CCM prediction here.
ccm_varmax_error

array([4.48239783, 2.3785857 , 1.72001859, 1.1860499 , 1.06047833])

In [29]:
# OLS VAR errors again, for side-by-side comparison with the EM-fitted model.
ccm_var_error

array([1.3704781 , 1.30507517, 1.12619335, 1.07052301, 0.99779408])