In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
import py4sid

# NOTE: seeding is disabled, so the random index draw below (np.random.choice)
# differs from run to run; uncomment for a reproducible run.
# NOTE(review): gen_data presumably also draws from the global RNG - confirm.
#np.random.seed(0)

# define problem size
# NOTE(review): presumably p = number of observed variables, n = latent
# dimensionality, T = number of time steps - confirm against gen_data.
p, n, T = 1000, 10, 100000
# time lags 0, 1, ..., 9
lag_range = np.arange(10)

# I/O settings
# mmap and chunksize are forwarded to gen_data; chunksize is min(p, 10000).
mmap, chunksize = True, np.min((p,10000))
# NOTE: absolute, machine-specific path; adjust before running elsewhere.
data_path, save_file = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/', 'test'
# save_file and verbose are not referenced again in the cells shown here.
verbose=True

# create subpopulations
# A single subpopulation holding all p indices; obs_pops / obs_time list
# population 0 together with time T.
# NOTE(review): presumably this means "everything observed for the whole
# recording" - confirm against ObservationScheme.
obs_scheme = ObservationScheme(p=p,T=T,
                               sub_pops=(np.arange(p),), 
                               obs_pops=(0,),
                               obs_time=(T,))

# draw system matrices
print('(p,n,T) = ', (p,n,T), '\n')
nr = 2 # number of real eigenvalues
# NOTE(review): presumably lower/upper bounds for the real (_r) and complex
# (_c) eigenvalues - confirm against gen_data.
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.8, 0.99, 0.8, 0.99
# gen_data returns five values; the last two are discarded here, and x is
# not used again in the cells shown here.
pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                             eig_m_r, eig_M_r, 
                                             eig_m_c, eig_M_c,
                                             mmap, chunksize,
                                             data_path)
# Passed as init= to run_bad in the fitting cell.
pars_est='default'

# Sorted random subset (without replacement) of min(p, 1000) indices out of
# range(p); with p = 1000 this selects every index. idx_b is an independent
# copy of the same index set.
idx_a = np.sort(np.random.choice(p, np.minimum(p,1000), replace=False))
idx_b = idx_a.copy()

# Co-occurrence weights for the chosen index sets, computed by the
# observation scheme for every lag in lag_range.
W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
print('computing time-lagged covariances')
# NOTE(review): Qs presumably holds the time-lagged covariances announced by
# the print above; the meaning of Om is not visible here - confirm.
Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                      idx_a=idx_a,idx_b=idx_b,W=W,sso=True,
                      mmap=mmap,data_path=data_path,ts=None,ms=None)

In [None]:
# One more than the largest lag in lag_range; handed to py4sid as its second
# argument.
kl_ = lag_range.max() + 1

# Moment-based baseline estimate from py4sid; C is compared against the true
# and the fitted C in a later cell.
A, C = py4sid.estimate_parameters_moments(y, kl_, n)

In [None]:
from scipy import linalg as la

# Both values are forwarded unchanged to run_bad below.
parametrization='nl'
sso = True

# settings for quick initial SGD fitting phase for our model
# max_zip_size is passed to run_bad as max_epoch_size.
batch_size, max_zip_size, max_iter = 1, np.inf, 20
# Passed to run_bad as alpha, b1, b2 and e.
# NOTE(review): presumably Adam-style step size, moment decay rates and
# epsilon - confirm against run_bad.
a, b1, b2, e = 0.01, 0.98, 0.99, 1e-8
a_decay = 0.98
# Not referenced again in the cells shown here.
a_R = 1 * a

# One row per iteration; pars_track writes a 0 into column 0 and the
# principal angles to the true C into the remaining columns.
proj_errors = np.zeros((max_iter,n+1))
def principal_angle(A, B):
    """Return the normalized principal angles between span(A) and span(B).

    Both inputs are orthonormalized internally with ``la.orth``, so they may
    be arbitrary bases and do not have to be column-orthogonal themselves.
    Inputs with fewer than two dimensions are treated as a single column.

    Parameters
    ----------
    A, B : array_like
        Matrices whose columns span the two subspaces. Both must have the
        same number of rows.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, so 0 means an aligned direction and
        1 an orthogonal one. There is one entry per singular value of
        ``orth(A).T @ orth(B)``, i.e. min(rank(A), rank(B)) entries, ordered
        from the smallest to the largest angle.
    """
    # Accept plain lists/tuples as well as arrays.
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    # The cosines of the principal angles are the singular values of Qa^T Qb;
    # the singular vectors are not needed, so skip computing them.
    cosines = la.svd(la.orth(A).T.dot(la.orth(B)), compute_uv=False)
    # Round-off can push a singular value slightly above 1, where arccos
    # would return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars, t):
    """Record the subspace error of the current estimate in row t of proj_errors.

    Handed to run_bad as its pars_track callback. The first entry of pars is
    taken as the current estimate of C; the row written to the module-level
    proj_errors buffer is a leading 0 followed by the normalized principal
    angles between pars_true['C'] and that estimate.
    """
    c_est = pars[0]
    angles = principal_angle(pars_true['C'], c_est)
    proj_errors[t] = np.hstack((0, angles))
            
# Run the fit. NOTE: pars_est, Qs, Om and W are both inputs to and outputs of
# this call, so re-running this cell starts from the previous result instead
# of the initial pars_est = 'default'; re-run the setup cell first for a
# fresh fit. pars_track is invoked by run_bad and fills proj_errors.
# verbose is passed as a literal True here, not via the `verbose` variable.
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      parametrization=parametrization, sso=sso,
                                      Qs=Qs, Om=Om, W=W, return_aux = True,
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,a_decay=a_decay,
                                      batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                      pars_track=pars_track,save_every=np.inf, data_path=data_path)

# Summary of the fit, followed by the fitting time t returned by run_bad.
print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
print('fitting time was ', t, 's')
# Number of columns of an orthonormal basis of the estimated C, i.e. its
# numerical rank.
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# Principal-angle traces recorded by pars_track; column 0 (always 0) is
# skipped.
plt.plot(proj_errors[:,1:])
plt.show()


In [None]:

# Normalized principal angles to the true C: first for the py4sid estimate C
# (requires the py4sid cell to have been run), then for the fitted
# pars_est['C'].
principal_angle(pars_true['C'], C),  principal_angle(pars_true['C'], pars_est['C'])

In [None]:
# Third entry of the traces returned by run_bad, unpacked into two parts.
# NOTE(review): presumably the first/second moment estimates of the
# optimizer - confirm against run_bad.
mp,vp = traces[2]

In [None]:
# Shapes of the individual entries of mp.
[mp[i].shape for i in range(len(mp))]

In [None]:
# Shapes of the individual entries of vp. The list must index vp itself;
# indexing mp here would only repeat the mp shapes.
[vp[i].shape for i in range(len(vp))]