In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
from neurosim.models.ssr import StateSpaceRealization as SSR
from neurosim.models.varma import gen_var1_connectivity
from tqdm import tqdm
import pdb
from scipy.stats import ortho_group

In [3]:
import jax
import jax.numpy as jnp

In [4]:
import sys
sys.path.append('../..')
from utils import calc_loadings
from riccati import riccati_opt

### Testing Riccati Equation solver implementation

In [5]:
from scipy.stats import ortho_group

In [6]:
# Test on the minimal positive real realization problem
ssdim = 10
projdim = 2
noise_strength = 1

A = 1/5 * np.random.normal(size=(ssdim, ssdim))

if max(np.abs(np.linalg.eigvals(A))) >= 1:
    while max(np.abs(np.linalg.eigvals(A))) >= 1:
        A = 1/5 * np.random.normal(size=(ssdim, ssdim))

# Non-random B!
B = np.random.normal(size=(ssdim, ssdim))
P = scipy.linalg.solve_discrete_lyapunov(A, B @ B.T)

C = np.eye(ssdim)
D = noise_strength * np.eye(ssdim)

Cbar = (C @ P @ A.T + D @ B.T)

# Projection matrix
V = ortho_group.rvs(ssdim)[:, 0:projdim].T
Lambda0 = V @ (C @ P @ C.T + D @ D.T) @ V.T

In [7]:
# Does the ambient P satisfy the ricatti equation?

In [240]:
PP = scipy.linalg.solve_discrete_are(A.T, (V @ C).T, np.zeros(A.shape), -Lambda0, s=-(V @ Cbar).T)

In [None]:
# Verify correctness of solution

In [241]:
riccati(PP, A, V @ C, V @ Cbar, Lambda0)

array([[-9.54791801e-15,  5.77315973e-15, -8.88178420e-15,
        -7.10542736e-15,  2.66453526e-15, -2.66453526e-15,
        -6.77236045e-15,  1.33226763e-15, -4.77395901e-15,
        -2.99760217e-15],
       [ 5.55111512e-15, -1.33226763e-14,  4.88498131e-15,
         1.33226763e-15,  0.00000000e+00, -1.55431223e-15,
         3.10862447e-15, -9.10382880e-15,  1.11022302e-15,
         0.00000000e+00],
       [-8.65973959e-15,  5.32907052e-15,  2.22044605e-15,
        -1.06581410e-14, -1.11022302e-15,  1.77635684e-15,
        -6.10622664e-16,  2.66453526e-15, -8.88178420e-16,
         2.66453526e-15],
       [-6.88338275e-15,  2.22044605e-15, -1.11022302e-14,
        -1.73194792e-14,  0.00000000e+00, -1.99840144e-15,
        -1.55431223e-15, -8.88178420e-16,  4.44089210e-16,
         4.21884749e-15],
       [ 2.66453526e-15,  2.22044605e-16, -8.88178420e-16,
        -4.44089210e-16,  5.32907052e-15,  2.22044605e-15,
        -1.88737914e-15,  8.88178420e-16, -7.77156117e-16,
        -1.

In [210]:
def riccati(x, a, c, cbar, j):
    
    return a @ x @ a.T - x + (cbar.T - a @ x @ c.T) @ np.linalg.inv(j - c @ x @ c.T) @ (cbar.T - a @ x @ c.T).T

In [259]:
Delta = scipy.linalg.sqrtm(Lambda0 - V @ C @ PP @ C.T @ V.T)
BB = ((V @ Cbar).T - A @ PP @ (V @ C).T) @ np.linalg.inv(Delta)
DD = Delta

In [260]:
ssm_base = SSR(A=A, B=B, C=C, D=noise_strength*np.eye(ssdim))
ssm_proj1 = SSR(A=A, B=B, C=C, D=noise_strength*np.eye(ssdim))
ssm_proj2 = SSR(A=A, B=BB, C=V @ C, D=DD)

In [255]:
ccm0 = ssm_base.autocorrelation(10)

In [257]:
V @ ccm0[0, ...] @ V.T

array([[20.49850374,  5.51089172],
       [ 5.51089172, 29.34478341]])

In [256]:
ccm0

array([[[ 2.50811834e+01,  5.18248254e+00, -3.98015859e+00,
          9.37358710e+00, -4.37396773e+00,  8.66091227e-01,
         -3.64542462e-01, -7.95389526e+00, -1.00728609e+01,
         -3.72680933e+00],
        [ 5.18248254e+00,  3.22002133e+01, -8.43767213e+00,
          1.85943086e+00, -7.23394150e+00,  1.19538624e+00,
          3.03002188e+00,  1.46484256e+00, -6.04340576e+00,
          3.63480148e+00],
        [-3.98015859e+00, -8.43767213e+00,  1.44152268e+01,
          5.64734706e+00,  8.12403010e-01,  3.18404801e-01,
          1.22828381e+00, -4.75735243e+00,  1.41968795e+00,
          4.30581518e+00],
        [ 9.37358710e+00,  1.85943086e+00,  5.64734706e+00,
          2.77553833e+01, -2.82273589e+00,  1.58361114e+00,
          2.28744088e+00, -7.34162690e+00,  4.76142583e+00,
          6.82534955e+00],
        [-4.37396773e+00, -7.23394150e+00,  8.12403010e-01,
         -2.82273589e+00,  2.32636127e+01, -4.39101078e+00,
          4.72710930e+00, -4.26815699e+00,  8.875959

In [261]:
ccm1 = ssm_proj1.autocorrelation(10, proj=V)
ccm2 = ssm_proj2.autocorrelation(10)

In [None]:
# VERY IMPORTANT: Verficiation of minimum phase property

In [266]:
np.abs(np.linalg.eigvals(A - BB @ np.linalg.inv(DD) @ V))

array([0.73777688, 0.57993333, 0.57993333, 0.40550238, 0.40550238,
       0.26291958, 0.26291958, 0.38505546, 0.23669226, 0.0476823 ])

In [262]:
ccm2

array([[[20.49850374,  5.51089172],
        [ 5.51089172, 29.34478341]],

       [[-0.15218347, -2.24580294],
        [-4.73528694, -8.63995405]],

       [[-0.49370838, -0.13385936],
        [ 0.70223348,  1.95747754]],

       [[-0.9917214 , -1.85316767],
        [ 1.37665605, -1.02461924]],

       [[-0.93370732,  0.12956689],
        [-0.36188606,  0.71043229]],

       [[-0.37456093, -0.73041463],
        [-0.84304363, -1.44923945]],

       [[ 0.19934368,  0.21764652],
        [-0.55198461, -0.27218266]],

       [[ 0.26242531,  0.09202791],
        [-0.41545162, -0.47392901]],

       [[ 0.218259  ,  0.18392069],
        [-0.14041895, -0.12721794]],

       [[ 0.11686732,  0.05895613],
        [ 0.03269685, -0.08536792]]])

In [263]:
ccm1

array([[[20.49850374,  5.51089172],
        [ 5.51089172, 29.34478341]],

       [[-0.15218347, -2.24580294],
        [-4.73528694, -8.63995405]],

       [[-0.49370838, -0.13385936],
        [ 0.70223348,  1.95747754]],

       [[-0.9917214 , -1.85316767],
        [ 1.37665605, -1.02461924]],

       [[-0.93370732,  0.12956689],
        [-0.36188606,  0.71043229]],

       [[-0.37456093, -0.73041463],
        [-0.84304363, -1.44923945]],

       [[ 0.19934368,  0.21764652],
        [-0.55198461, -0.27218266]],

       [[ 0.26242531,  0.09202791],
        [-0.41545162, -0.47392901]],

       [[ 0.218259  ,  0.18392069],
        [-0.14041895, -0.12721794]],

       [[ 0.11686732,  0.05895613],
        [ 0.03269685, -0.08536792]]])

In [9]:
def solve_riccati(F, G, H, J):
    
    # Form matrix pencil
    A = np.block([[F - G @ np.linalg.inv(J + J.T) @ H, np.zeros((H.shape[1], H.shape[1]))],
                  [-H.T @ np.linalg.inv(J + J.T) @ H, np.eye(H.shape[1])]])

    B = np.block([[np.eye(G.shape[0]), -G @ np.linalg.inv(J + J.T) @ G.T], \
                  [np.zeros((G.shape[0], F.shape[0])), F.T - H.T @ np.linalg.inv(J + J.T) @ G.T]])
    
    # Need to compute the stable deflating subspace of the above matrix pencil

    # For now, use scipy's built in sorting function
    AA, BB, _, _, Q, Z = scipy.linalg.ordqz(A, B, output='complex', sort='iuc')
#     # Need to get the block diagonals to obtain the eigenvalues
#     AAblocks = [AA[2*i:2*i + 2,2*i:2*i + 2] for i in range(AA.shape[0]//2)]
#     BBblocks = [BB[2*i:2*i + 2,2*i:2*i + 2] for i in range(BB.shape[0]//2)]
        
    # Identify generalized eigenvalues with absolute value < 1
#     lam = np.divide(np.linalg.eigvals(AA), np.linalg.eigvals(BB))
#     stable_eigs = np.squeeze(np.argwhere(np.abs(lam) < 1))
    n = Z.shape[0]    
    Zstable = Z[:, 0:n//2]    
    
    return Zstable[n//2:, :] @ np.linalg.inv(Zstable[0:n//2, :])    

### Comparison of Balanced Bases to Optimized Subspaces - PI is lower by balancing!

In [20]:
from dca.dca import DynamicalComponentsAnalysis as DCA
from dca.cov_util import calc_cov_from_cross_cov_mats, calc_pi_from_cov
from neurosim.utils.riccati import riccati_solve

In [22]:
# Scale matrix non-normality
ssdim = 20

# B = np.random.normal(size=(10, 10))
B = np.eye(ssdim)
C = np.eye(ssdim)
#D = 1e-3 * np.eye(ssdim)

reps = 10
t = 15

# loadings_balanced = np.zeros((reps, 10))
# loadings_optimized = np.zeros((reps, 10))

pi_cc_top2 = np.zeros(reps)
pi_balanced_proj = np.zeros(reps)
pi_opt = np.zeros(reps)

for rep in tqdm(range(reps)):

    A = 1/5 * np.random.normal(size=(ssdim, ssdim))

    if max(np.abs(np.linalg.eigvals(A))) >= 1:
        while max(np.abs(np.linalg.eigvals(A))) >= 1:
            A = 1/5 * np.random.normal(size=(ssdim, ssdim))

    ssr = SSR(A=A, B=B, C=C)
    ssr.solve_min_phase()
    ssr.solve_max_phase()
    Pmin = riccati_solve(ssr.A, ssr.C, ssr.Cbar, ssr.cov, Pinit=ssr.P, tol=1e-8)
    Qmin = riccati_solve(ssr.A.T, ssr.Cbar, ssr.C, ssr.cov, Pinit=ssr.P, tol=1e-8)

    # Cholesky decomposition
    R = np.linalg.cholesky(Pmin)
    U, s, Vh = np.linalg.svd(R.T @ Qmin @ R)

    T = scipy.linalg.sqrtm(np.diag(np.sqrt(s))) @ U.T @ np.linalg.inv(R)

    # Quick check that we are getting the right PI
    assert(np.isclose(-1/2 * np.sum(np.log(1 - s)), ssr.pi(10), atol=1e-5))

    pi_cc_top2[rep] = -1 * np.sum(np.log(1-s[0:2]))

    # Optimize using the exact covariance sequence
    ssm = SSR(A=A, B=B, C=C)

    ccm1 = ssm.autocorrelation(t)
    ccm2 = ssm.autocorrelation(2 * t)

    dcamodel = DCA(d=2, T=t)
    dcamodel.cross_covs = ccm2
    dcamodel.fit_projection()
    
    pi_opt[rep] = max(dcamodel.scores)


    # Check against the projection suggested by balanced truncation
    Cbalanced = C @ np.linalg.inv(T)
    Cbalanced = C[0:2, :]

    ssm_balanced = SSR(A = A, B=B, C=Cbalanced)
    pi_balanced_proj[rep] = ssm_balanced.pi(10)
    

100%|██████████| 10/10 [00:09<00:00,  1.07it/s]


In [23]:
pi_opt

array([2.52782579, 1.98992782, 1.55890989, 2.49144498, 3.05700805,
       1.58311454, 2.86251124, 4.48079095, 2.58158331, 2.29356703])

In [24]:
pi_balanced_proj

array([0.8207508 , 0.39164305, 0.40566156, 1.41757604, 0.93335865,
       0.34764095, 1.19588973, 2.72997262, 1.37273505, 0.72110266])