In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
import py4sid

rnd_seed = 0

# define problem size
lag_range = np.arange(8)          # time lags for the true stacked covariances below
kl_ = np.max(lag_range)+1         # max lag + 1
p, n = 200, 4                     # p observed units, n latent dimensions
T = 1000                          # number of time points

dtype=np.float32

nr = 0 # number of real eigenvalues
snr = (0.1, 0.1)
whiten = True
# lower/upper bounds for the real (_r) and complex (_c) eigenvalues of the true A
# (presumably bounds on their magnitudes -- TODO confirm in ssidid.utility.gen_data)
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.98, 0.99, 0.98, 0.99
mmap, chunksize, sso = False, np.min((p,10000)), True

# Output / memmap directory. Overridable through the STITCHING_FITS_DIR
# environment variable so the notebook runs on other machines; the default is
# the original hard-coded location.
data_path = os.environ.get(
    'STITCHING_FITS_DIR',
    '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/')
save_file = 'test'
verbose=True

# Seeding is disabled, so every run draws a new system; uncomment for
# reproducible data.
#np.random.seed(rnd_seed)
pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                 eig_m_r, eig_M_r,
                                 eig_m_c, eig_M_c,
                                 mmap, chunksize,
                                 data_path,
                                 snr=snr, whiten=whiten,
                                 dtype=dtype)
if mmap:
    y = np.memmap(data_path+'y', dtype=dtype, mode='r+', shape=(T,p))
# centre every unit; with mmap=True this writes through to the file (mode='r+')
y -= y.mean(axis=0)
# true stacked lagged latent covariances: one A^m Pi block per lag in lag_range
pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])

plt.figure(figsize=(20,5))
plt.plot(x[:100,:])
plt.title('true latent trajectories (first 100 time steps)')
plt.xlabel('time step')
plt.show()

# print explicitly: a bare expression in the middle of a cell is never displayed
print('eigenvalues of true A:', np.linalg.eigvals(pars_true['A']))

# keep a copy of the data at the original time scale
y_ = y.copy()


In [None]:
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B do NOT need to be orthonormal. Each one is first replaced by an
    orthonormal basis of its column space (scipy.linalg.orth). 1-D inputs are
    treated as single column vectors.

    Returns
    -------
    ndarray of shape (min(rank A, rank B),)
        arccos of the singular values of orth(A).T @ orth(B), divided by pi/2.
        A value of 0 means aligned directions and 1 means orthogonal ones.
        Singular values come out in descending order, so the angles ascend.
    """
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the principal angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # round-off can push a cosine slightly above 1, outside arccos' domain
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

In [None]:
import numpy as np

class Super_res():
    """Very minimalistic read-only wrapper around a 2-D numpy array.

    Increases the resolution along the first axis by an integer factor r:
    row t of the wrapper is row t // r of the wrapped array, so every original
    row appears r times in a row.

    - written only to meet basic requirements of ssidid code
    - a (rows, cols) tuple index is applied as an OUTER index
      (rows x cols block), unlike numpy's paired fancy indexing
    - does not work with outer indexing such as np.ix_()
    - does not support copying, value assignment, ...
    """

    def __init__(self, y, r):
        self._y = y
        self._r = r

        assert len(y.shape) == 2
        self.shape = (y.shape[0] * self._r, y.shape[1])
        self.dtype  = y.dtype

    def __len__(self):
        return self.shape[0]

    def _rows(self, t):
        """Map super-resolved row index/indices t to rows of the wrapped array."""
        if isinstance(t, slice):
            # slices cannot be floor-divided; expand them to explicit indices
            t = np.arange(*t.indices(self.shape[0]))
        # floor division also maps negative indices correctly, since
        # (T*r - k) // r == T + (-k) // r
        return t // self._r

    def __getitem__(self, t):
        if isinstance(t, tuple):
            # atleast_2d keeps a scalar row index 2-D, so the column index
            # always applies to axis 1
            return np.atleast_2d(self._y[self._rows(t[0])])[:,t[1]]
        else:
            return self._y[self._rows(t)]
        
        
        
""" overwriting the code for computing the time-lagged covariances """        
# only part in the whole module that requires outer indexing via np.ix_ # 
def _rows_cols(y, rows, cols):
    """Return the outer-indexed block y[rows][:, cols].

    Plain numpy arrays (np.memmap included) need np.ix_ for this. Super_res
    does not support np.ix_, but it already treats a (rows, cols) tuple as an
    outer index.
    """
    if isinstance(y, np.ndarray):
        return y[np.ix_(rows, cols)]
    return y[rows, cols]


def f_l2_Hankel_comp_Q_Om(n,y,lag_range,obs_scheme,idx_a,idx_b,W,sso=False,
                          mmap=False,data_path=None,ts=None,ms=None):
    """Compute weighted time-lagged covariance sums and their observation masks.

    For each lag index m, Qs[m] is a (len(idx_a), len(idx_b)) array. It holds
    the sum of outer products y[t+lag, a] * y[t, b] over the time points where
    both units were observed (lag = lag_range[m]), multiplied elementwise by
    the weights W[m]. Om[m] is a boolean array of the same shape marking the
    entries that received any data. Lag indices not in ms stay zero / False.

    Parameters
    ----------
    n : int
        Latent dimensionality; unused here, kept for interface compatibility.
    y : (T, p) array-like
        Data, either a numpy array (incl. np.memmap) or a Super_res-like
        wrapper that treats y[rows, cols] as an outer index.
    lag_range : array of int
        Time lags; Qs[m] belongs to lag lag_range[m].
    obs_scheme : presumably ssidid.ObservationScheme
        Must provide idx_grp and, depending on sso, gen_get_coobs_intervals
        or gen_get_observed.
    idx_a, idx_b : arrays of int
        Units indexing the rows / columns of every Qs[m].
    W : list of arrays
        Per-lag weights, of shape (len(idx_grp), len(idx_grp)) or (p, p).
    sso : bool
        If True, use the group-wise co-observation intervals; otherwise loop
        over the individual time points ts.
    mmap, data_path : bool, str
        If mmap, every Qs[m] is also saved to data_path + 'Qs_' + str(lag).
    ts, ms : iterables, optional
        Time points (non-sso path only) and lag indices to compute. Defaults
        are range(T - (max lag + 1)) and all lag indices.

    Returns
    -------
    Qs, Om : lists of length len(lag_range)

    Raises
    ------
    Exception
        If W[0] has neither of the supported shapes.
    """
    T,p = y.shape
    kl_ = np.max(lag_range)+1
    pa, pb = len(idx_a), len(idx_b)
    idx_grp = obs_scheme.idx_grp
    n_grp = len(idx_grp)

    ts = range(T-kl_) if ts is None else ts
    ms = range(len(lag_range)) if ms is None else ms

    Qs = [np.zeros((pa,pb), dtype=y.dtype) for _ in range(len(lag_range))]
    Om = [np.zeros((pa,pb), dtype=bool) for _ in range(len(lag_range))]

    if sso:
        get_coobs_intervals = obs_scheme.gen_get_coobs_intervals(lag_range)
        for j in range(n_grp):
            b = np.intersect1d(idx_grp[j], idx_b)
            b_Q = np.isin(idx_b, b)
            for i in range(n_grp):
                a = np.intersect1d(idx_grp[i], idx_a)
                a_Q = np.isin(idx_a, a)
                for m in ms:
                    # NOTE(review): assumes get_coobs_intervals takes the lag
                    # INDEX m and returns the later time points t+lag at which
                    # group i (at t+lag) and group j (at t) were co-observed
                    # (note ordering of j,i) -- TODO confirm in ssidid.
                    idx_coobs_ijm = np.asarray(get_coobs_intervals(j,i,m))
                    if len(idx_coobs_ijm) > 0:
                        # Shift by the lag VALUE, as in the non-sso branch.
                        # Subtracting the index m was only right for
                        # lag_range == arange(k).
                        lag = lag_range[m]
                        # Outer indexing is needed here. Plain y[idx, a] pairs
                        # the two index arrays elementwise on numpy arrays.
                        Qs[m][np.ix_(a_Q,b_Q)] = _rows_cols(y, idx_coobs_ijm, a).T.dot(
                            _rows_cols(y, idx_coobs_ijm - lag, b))
                        Om[m][np.ix_(a_Q,b_Q)] = True
    else:
        get_observed = obs_scheme.gen_get_observed()
        for m in ms:
            m_ = lag_range[m]
            for t in ts:
                a = np.intersect1d(get_observed(t+m_), idx_a)
                b = np.intersect1d(get_observed(t),    idx_b)
                a_Q = np.isin(idx_a, a)
                b_Q = np.isin(idx_b, b)

                # np.outer flattens, so 2-D rows from Super_res are fine too
                Qs[m][np.ix_(a_Q, b_Q)] += np.outer(y[t+m_,a], y[t,b])
                Om[m][np.ix_(a_Q, b_Q)] = True

    if W[0].shape == (n_grp, n_grp):
        # group-level weights: scale every (group i, group j) block of Qs[m];
        # the masks do not depend on m, so build them once
        a_masks = [np.isin(idx_a, g) for g in idx_grp]
        b_masks = [np.isin(idx_b, g) for g in idx_grp]
        for m in ms:
            for i in range(n_grp):
                for j in range(n_grp):
                    Qs[m][np.ix_(a_masks[i], b_masks[j])] *= W[m][i,j]

    elif W[0].shape == (p, p):
        # unit-level weights
        for m in ms:
            Qs[m] = Qs[m] * W[m][np.ix_(idx_a,idx_b)]

    else:
        raise Exception('shape misfit for weights W[m] at time-lag m=0')

    if mmap: # probably computing the Qs is costly
        for m in range(len(lag_range)):
            np.save(data_path+'Qs_'+str(lag_range[m]), Qs[m])

    return Qs, Om


In [None]:

# create subpopulations
nsp = 1
lag_range = np.arange(2*nsp)
#y = Super_res(y_, 1)
T, p = y.shape

# nsp contiguous blocks of units; the last block also takes any remainder
sub_pops = [np.arange(i*p//nsp, (i+1)*p//nsp) for i in range(nsp-1)]
sub_pops.append(np.arange((nsp-1)*p//nsp, p))


# pass-and-flyback pattern: sub-populations 0, 1, ..., nsp-1 visited cyclically;
# obs_time holds len(obs_pops) integer times evenly spaced in (0, T]
reps = T//nsp
obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)


# up-down pattern
#reps = T//(2*nsp)
#obs_pops = np.hstack([np.hstack((np.arange(1,len(sub_pops)),np.arange(len(sub_pops)-1)[::-1])) for i in range(reps)])
#obs_pops = np.hstack([np.arange(1), obs_pops, np.arange(1,len(sub_pops))])
#obs_time = np.array([i*T//len(obs_pops) for i in range(1,len(obs_pops)+1)])

obs_scheme = ObservationScheme(p=p, T=T, 
                                sub_pops=sub_pops, 
                                obs_pops=obs_pops, 
                                obs_time=obs_time)

idx_a, idx_b = np.arange(p), np.arange(p)
W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
obs_scheme.gen_mask_from_scheme()
plt.figure(figsize=(20,5))
plt.imshow(obs_scheme.mask[:100,:].T, interpolation='None', aspect='auto')
plt.title('observation mask (first 100 time steps)')
plt.xlabel('time step')
plt.ylabel('unit')
plt.show()


print('computing time-lagged covariances')
# same sso setting as used for the weights W above (was hard-coded sso=True,
# which silently disagreed with W whenever sso was False)
Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                      idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                      mmap=mmap,data_path=data_path,ts=None,ms=None)



In [None]:
# Fit the model with ssidid's multi-stage default procedure. Each tuple below
# holds one setting per stage (two stages here). Passing 'default' presumably
# lets run_default initialise the parameters itself -- TODO confirm.
pars_est = 'default'
#lag_range = np.arange(2)

from ssidid.icml_scripts import run_default

# NOTE: rebinds rnd_seed (an int in the setup cell) to the full numpy RNG state
# before fitting, presumably so the fit can be replayed via np.random.set_state.
rnd_seed = np.random.get_state()
pars_est, traces, ts= run_default(
            alphas    = (0.01, 0.001), 
            b1s       = (0.98, 0.95), 
            a_decays  = (0.98, 0.98), 
            batch_sizes = (1, 10), 
            max_zip_sizes =  (1000,250), 
            max_iters = (50, 50 ),
            parametrizations = ('nl', 'ln'),
            pars_est=pars_est, pars_true=pars_true, n=n, 
            y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range, 
            idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
            traces=[[], [], []], ts = [], dtype=dtype)    
# presumably ts collects per-stage run times -- TODO confirm
print('time =', sum(ts))

In [None]:
# Time-lagged covariances under an ObservationScheme built without sub_pops
# (presumably every unit observed at all times -- TODO confirm the default).
# NOTE(review): this OVERWRITES the global Qs computed from the sub-sampled
# scheme. Later cells (the comparison plots and the later run_default /
# run_bad calls) therefore pair these Qs with Om / W of the partial scheme.
# Rename (e.g. Qs_full) if that is not intended.
# NOTE(review): W_full uses sso=sso, but the covariances use sso=True.
idx_a, idx_b = np.arange(p), np.arange(p)
obs_scheme_full = ObservationScheme(p=p, T=T)
W_full = obs_scheme_full.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
Qs, _ = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme_full,
                      idx_a=idx_a,idx_b=idx_b,W=W_full,sso=True,
                      mmap=mmap,data_path=data_path,ts=None,ms=None)   

In [None]:

# Eigenvalues of A^5 in the complex plane: estimate (red circles) vs. truth (black crosses)
ev_est = np.linalg.eigvals(np.linalg.matrix_power(pars_est['A'], 5))
ev_true =  np.linalg.eigvals(np.linalg.matrix_power(pars_true['A'], 5))
plt.plot(  np.real(ev_est), np.imag(ev_est), 'ro')
plt.plot(  np.real(ev_true), np.imag(ev_true), 'kx')
plt.show()



In [None]:
pars = pars_est
C, X, R = pars['C'].copy(), pars['X'], pars['R']
A = pars['A']
Pi = pars['X'][:n, :]

# rotate the loadings of sub-population i by A^i (identity when nsp == 1)
for i in range(len(sub_pops)):
    if len(sub_pops[i]) > 0:
        C[sub_pops[i],:] = C[sub_pops[i],:].dot(np.linalg.matrix_power(A, np.mod(i, len(sub_pops))))

idx_a, idx_b = np.arange(p), np.arange(p)
for m in range(len(lag_range)):
    lag = lag_range[m]

    # Model estimate at this lag: the m-th n-by-n block of the stacked X.
    # Previously A^0 Pi was used for every m, so every lag was compared
    # against the lag-0 estimate.
    Qest = C[idx_a,:].dot(X[m*n:(m+1)*n, :]).dot(C[idx_b,:].T)

    # empirical covariance at the original time scale (y_), shifted by the lag
    # value rather than by its index m
    Qys = np.cov(y_[lag:y_.shape[0]-kl_+lag,:].T, y_[:y_.shape[0]-kl_, :].T)[:p, p:]

    plt.figure(figsize=(12,7))
    plt.suptitle('time lag s = ' + str(lag))
    plt.subplot(2,4,1)
    plt.imshow(Qs[m][np.ix_(idx_a, idx_b)], interpolation='None')
    plt.title('emp. (fast time-scale)')
    plt.colorbar()
    plt.subplot(2,4,5)
    plt.plot(Qs[m][Om[m]], Qest[Om[m]], '.')
    plt.title('observed entries')
    plt.xlabel('emp. (fast time-scale)')
    plt.ylabel('est.')
    plt.subplot(2,4,2)
    plt.imshow(Qest, interpolation='None')
    plt.title('est.')
    plt.colorbar()
    plt.subplot(2,4,6)
    plt.plot(Qs[m][np.ix_(idx_a, idx_b)].reshape(-1), Qys[np.ix_(idx_a, idx_b)].reshape(-1), '.')
    plt.xlabel('emp. (fast time-scale)')
    plt.ylabel('emp. (orig. time-scale)')
    plt.subplot(2,4,3)
    plt.imshow(Qys[np.ix_(idx_a, idx_b)], interpolation='None')
    plt.colorbar()
    plt.title('emp. (orig. time-scale)')
    plt.subplot(2,4,7)
    # Entries without data under the observation scheme. The x-axis is Qs[m],
    # the same quantity as in subplot 5; the old labels wrongly named it the
    # orig.-time-scale covariance.
    plt.plot(Qs[m][np.invert(Om[m])], Qest[np.invert(Om[m])], '.')
    plt.title('unobserved entries')
    plt.xlabel('emp. (fast time-scale)')
    plt.ylabel('est.')

    plt.subplot(2,4,4)
    plt.imshow(Om[m], interpolation='None')
    plt.title('observation scheme at lag s =' +str(lag_range[m]))
    plt.show()


In [None]:
# For every cyclic shift j, rotate the loadings of sub-population i by
# A^((i + j) mod nsp) and print the principal angles to the true C.
# The assignment before the loop only survives if sub_pops is empty; after
# the loop C holds the last shift (j = nsp - 1).
C = pars_est['C'].copy() 
A = pars_est['A']

for j in range(len(sub_pops)):
    C = pars_est['C'].copy() 
    for i in range(len(sub_pops)):
        if len(sub_pops[i]) > 0:
            C[sub_pops[i],:] = C[sub_pops[i],:].dot(np.linalg.matrix_power(A, np.mod(i+j, len(sub_pops))))
    print(principal_angle(C, pars_true['C']))

In [None]:
# principal angles for C as left by the preceding cell (last cyclic shift)
principal_angle(C, pars_true['C'])

In [None]:
# principal angles for the unrotated estimated loadings
principal_angle(pars_est['C'], pars_true['C'])

In [None]:
Pi = X[:n, :]
# Qs holds one entry per lag in lag_range, so Qs[5] raised an IndexError for
# shorter lag ranges (e.g. lag_range = np.arange(2*nsp) with nsp = 1).
# Fall back to the last available lag and label which one is shown.
lag_idx = min(5, len(Qs) - 1)
plt.subplot(1,2,1)
plt.imshow(C.dot(np.linalg.inv(pars_est['A'])).dot(Pi).dot(C.T), interpolation='None')
plt.title('C A^-1 Pi C^T (est.)')
plt.subplot(1,2,2)
plt.imshow(Qs[lag_idx], interpolation='None')
plt.title('emp. Qs[%d]' % lag_idx)
plt.show()

In [None]:
# heat map of the estimated Pi
plt.imshow(pars_est['Pi'], interpolation='None')
plt.show()

In [None]:

# True vs. estimated stacked lagged covariances X (transposed, shared colour range).
# NOTE(review): pars_true['X'] was built for the 8 lags of the setup cell,
# whereas pars_est['X'] presumably only covers the current lag_range, so the
# two panels can differ in width.
plt.figure(figsize=(20, 3))
plt.subplot(1,2,1)
plt.imshow(pars_true['X'].T, interpolation='None', aspect='auto', clim=[-1,1])
plt.colorbar()
plt.title('true X')
plt.subplot(1,2,2)
plt.imshow(pars_est['X'].T, interpolation='None', aspect='auto', clim=[-1,1])
plt.colorbar()
plt.title('est. X')
plt.show()

# NOTE: rebinds the global m (also used as a loop variable in other cells)
m = 5

# eigenvalues of A^m: truth (black crosses) vs. estimate (red stars)
ev_true = np.linalg.eigvals(np.linalg.matrix_power(pars_true['A'], m))
ev_est  = np.linalg.eigvals(np.linalg.matrix_power(pars_est['A'], m))

plt.plot(np.real(ev_true), np.imag(ev_true), 'kx', markersize=10)
plt.plot(np.real(ev_est ), np.imag(ev_est), 'r*', markersize=10)


In [None]:
# Refit from scratch with 'ln' in both stages and more iterations.
# NOTE(review): unlike the first run_default call, no dtype is passed here.
# Qs at this point is the fully-observed-scheme Qs recomputed in an earlier
# cell, while Om / W belong to the sub-sampled scheme.
from ssidid.icml_scripts import run_default

rnd_seed = np.random.get_state()
pars_est, traces, ts= run_default(
            alphas    = (0.01, 0.001), 
            b1s       = (0.98, 0.95), 
            a_decays  = (0.98, 0.98), 
            batch_sizes = (1, 10), 
            max_zip_sizes =  (1000,250), 
            max_iters = (100, 100 ),
            parametrizations = ('ln', 'ln'),
            pars_est='default', pars_true=pars_true, n=n, 
            y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range, 
            idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
            traces=[[], [], []], ts = [])    

In [None]:

# Same diagnostics as the earlier X / eigenvalue cell, for the refit and m = 1.
plt.figure(figsize=(20, 3))
plt.subplot(1,2,1)
plt.imshow(pars_true['X'].T, interpolation='None', aspect='auto', clim=[-1,1])
plt.colorbar()
plt.title('true X')
plt.subplot(1,2,2)
plt.imshow(pars_est['X'].T, interpolation='None', aspect='auto', clim=[-1,1])
plt.colorbar()
plt.title('est. X')
plt.show()

# NOTE: rebinds the global m
m = 1

# eigenvalues of A^m: truth (black crosses) vs. estimate (red stars)
ev_true = np.linalg.eigvals(np.linalg.matrix_power(pars_true['A'], m))
ev_est  = np.linalg.eigvals(np.linalg.matrix_power(pars_est['A'], m))

plt.plot(np.real(ev_true), np.imag(ev_true), 'kx', markersize=10)
plt.plot(np.real(ev_est ), np.imag(ev_est), 'r*', markersize=10)


In [None]:
# estimated observation-noise parameter R
pars_est['R']

In [None]:

def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B do NOT need to be orthonormal. Each one is first replaced by an
    orthonormal basis of its column space (scipy.linalg.orth). 1-D inputs are
    treated as single column vectors.

    Returns
    -------
    ndarray of shape (min(rank A, rank B),)
        arccos of the singular values of orth(A).T @ orth(B), divided by pi/2.
        A value of 0 means aligned directions and 1 means orthogonal ones.

    NOTE(review): principal_angle is defined several times in this notebook;
    keep the copies in sync or define it once near the top.
    """
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the principal angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # round-off can push a cosine slightly above 1, outside arccos' domain
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

print('per-subpops principal angles')
C = pars_est['C'].copy()
# Loop over all sub-populations. The old code indexed sub_pops[1] directly,
# which raised an IndexError whenever nsp == 1.
for pop in sub_pops:
    print(principal_angle(pars_true['C'][pop,:], C[pop,:]))

print('final principal angles')
C = pars_est['C'].copy()
print(principal_angle(pars_true['C'], C))

print('final principal angles (sign-flipped)')
C = pars_est['C'].copy()
# flip the sign of all loadings of the first sub-population
C[sub_pops[0],:] *= -1
print(principal_angle(pars_true['C'], C))


In [None]:
# Model lag-0 covariance C Pi C^T + diag(R), under the true parameters,
# restricted to entries NOT observed at lag 0 (Om[0] False).
# NOTE(review): np.diag(R) assumes R is a vector of per-unit noise variances.
pars = pars_true
C, Pi, R = pars['C'], pars['Pi'], pars['R']
Qgt = C.dot(Pi).dot(C.T) + np.diag(R)
plt.imshow(Qgt * np.invert(Om[0]), interpolation='None')
plt.show()

# same under the estimated parameters
# NOTE(review): C aliases pars_est['C'] (no copy), so re-enabling the sign flip
# below would modify pars_est in place.
pars = pars_est
C, Pi, R = pars['C'], pars['Pi'], pars['R']
#C[sub_pops[2],:] *= -1
Q = C.dot(Pi).dot(C.T) + np.diag(R)
plt.imshow(Q * np.invert(Om[0]), interpolation='None')
plt.show()

# res[i, j]: correlation between estimated and true covariance entries of the
# (sub-population i, sub-population j) block
res = np.zeros((nsp, nsp))
for i in range(nsp):
    for j in range(nsp):
        Qij = Q[np.ix_(sub_pops[i], sub_pops[j])]
        Qgtij = Qgt[np.ix_(sub_pops[i], sub_pops[j])]
        res[i,j] = np.corrcoef(Qij.reshape(-1), Qgtij.reshape(-1))[0,1]
 
plt.imshow(res, interpolation='None')
plt.colorbar()
plt.show()
#np.corrcoef(np.cov(y.T)[np.invert(Om[0])], Q[np.invert(Om[0])])


#np.corrcoef(np.cov(y.T)[np.invert(Om[0])], np.cov(y[1:T,:].T, y[:T-1,:].T)[:p,p:][np.invert(Om[0])])

In [None]:
# Flip the sign of latent columns 0, 2 and 3 for the first sub-population only
# and recompute the principal angles (column index 3 requires n >= 4).
C = pars_est['C'].copy()

C[np.ix_(sub_pops[0], [0,2,3])] *= -1 

print(principal_angle(pars_true['C'], C))


In [None]:
# lag-index-1 cross-covariance between sub-population 1 (rows) and 0 (columns)
# NOTE(review): sub_pops[1] requires nsp >= 2; with nsp = 1 this raises IndexError.
plt.imshow(Qs[1][np.ix_(sub_pops[1],sub_pops[0])])

In [None]:
# Correlation between the estimated and true lag-1 cross-covariance of
# sub-population i with sub-population i - m_. Negative indices wrap, so
# i=0, m_=1 pairs sub_pops[0] with the last sub-population.
# The estimate uses the second n-by-n block of X, which is lag 1 only when
# lag_range starts at 0 with unit steps.
i = 0
m_ = 1
Qest = pars_est['C'][sub_pops[i],:].dot( pars_est['X'][n:2*n,:]).dot(pars_est['C'][sub_pops[i-m_],:].T)
Qgd =  pars_true['C'][sub_pops[i],:].dot(pars_true['A'].dot(pars_true['Pi'])).dot(pars_true['C'][sub_pops[i-m_],:].T)

np.corrcoef(Qest.reshape(-1), Qgd.reshape(-1))


In [None]:
# principal angles restricted to the rows of sub-population i and its
# predecessor i-1 (wraps to the last one for i = 0)
i = 0
principal_angle(pars_true['C'][np.hstack((sub_pops[i], sub_pops[i-1])),:],
                pars_est['C'][np.hstack((sub_pops[i], sub_pops[i-1])),:])                

In [None]:
C = pars_est['C'].copy()
#C[sub_pops[1],:] *= -1
# principal angles between true and estimated loadings, one line per sub-population
for i, pop in enumerate(sub_pops):
    print(principal_angle(pars_true['C'][pop, :], C[pop, :]))

In [None]:
# Baseline used in the comparison plot further down (labelled 'classical Hankel
# SSID' there), fitted on y. Overwrites the globals A and C.
kl_ = np.max(lag_range)+1
A, C = py4sid.estimate_parameters_moments(y,kl_,n)

In [None]:
# Same baseline with 2*n as the second argument (presumably the Hankel size /
# number of lags -- TODO confirm in py4sid). kl_ is recomputed but not used by
# this call. Overwrites the globals A and C.
kl_ = np.max(lag_range)+1
A, C = py4sid.estimate_parameters_moments(y,2*n,n)

In [None]:

# Principal angle per latent dimension: global C (the py4sid baseline from the
# cells above) in green vs. the current SSIDID estimate in blue.
plt.plot(np.arange(1, n+1), principal_angle(pars_true['C'], C), 'go') 
plt.plot(np.arange(1, n+1), principal_angle(pars_true['C'], pars_est['C']), 'bo')
plt.legend(('classical Hankel SSID', 'SSIDID'), loc=2)
plt.xlabel('# latent dim')
plt.ylabel('principal angle')
plt.axis([0.5, n+0.5, 0, 1.05])
plt.show()

In [None]:
from scipy import linalg as la

parametrization='nl'
sso = True

# settings for quick initial SGD fitting phase for our model
# a, b1, b2, e are passed to run_bad as alpha, b1, b2, e (names suggest an
# Adam-style optimiser -- TODO confirm); a_R is not used in any visible cell.
batch_size, max_zip_size, max_iter = 1, np.inf, 20
a, b1, b2, e = 0.01, 0.98, 0.99, 1e-8
a_decay = 0.98
a_R = 1 * a

# one row per iteration: a leading 0 followed by the n principal angles (filled by pars_track)
proj_errors = np.zeros((max_iter,n+1))
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B do NOT need to be orthonormal. Each one is first replaced by an
    orthonormal basis of its column space (scipy.linalg.orth). 1-D inputs are
    treated as single column vectors.

    Returns
    -------
    ndarray of shape (min(rank A, rank B),)
        arccos of the singular values of orth(A).T @ orth(B), divided by pi/2.
        A value of 0 means aligned directions and 1 means orthogonal ones.

    NOTE(review): principal_angle is defined several times in this notebook;
    keep the copies in sync or define it once near the top.
    """
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the principal angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # round-off can push a cosine slightly above 1, outside arccos' domain
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)
    
def pars_track(pars, t):
    """Progress callback handed to run_bad via pars_track=.

    Stores, in row t of the global proj_errors, a leading 0 followed by the
    principal angles between the true C and the current estimate pars[0].
    NOTE(review): assumes t stays below proj_errors.shape[0] (= max_iter);
    confirm how run_bad numbers its calls.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# Continue fitting from the current pars_est (init=pars_est). The Qs, Om and W
# returned by run_bad overwrite the globals, and t is reported as fitting time.
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      parametrization=parametrization, sso=sso,
                                      Qs=Qs, Om=Om, W=W, 
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,a_decay=a_decay,
                                      batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                      pars_track=pars_track,save_every=np.inf, data_path=data_path)

print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal angles per latent dimension over iterations (column 0 is the zero placeholder)
plt.plot(proj_errors[:,1:])
plt.show()


In [None]:
# second fitting phase: larger batches, smaller step size, lower b1
# (proj_errors is not reset, so its rows are overwritten by the next run)
batch_size, max_zip_size, max_iter = 10, np.inf, 20
a, b1, b2, e = 0.005, 0.9, 0.99, 1e-8
a_decay = 0.98
a_R = 1 * a

def pars_track(pars, t):
    """Progress callback handed to run_bad via pars_track=.

    Stores, in row t of the global proj_errors, a leading 0 followed by the
    principal angles between the true C and the current estimate pars[0].
    NOTE(review): assumes t stays below proj_errors.shape[0]; confirm how
    run_bad numbers its calls.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# Second fitting phase, warm-started from the previous estimate. The Qs, Om
# and W returned by run_bad overwrite the globals again.
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      parametrization=parametrization, sso=sso,
                                      Qs=Qs, Om=Om, W=W, 
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,a_decay=a_decay,
                                      batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                      pars_track=pars_track,save_every=np.inf, data_path=data_path)

print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal angles per latent dimension over iterations (column 0 is the zero placeholder)
plt.plot(proj_errors[:,1:])
plt.show()


In [None]:

# (baseline C from the py4sid cells, current SSIDID estimate)
principal_angle(pars_true['C'], C),  principal_angle(pars_true['C'], pars_est['C'])

In [None]:
# Restart from the default initialisation with lags 1..10.
# NOTE(review): Qs, Om and W passed to the next run_bad are still those from
# the previous lag_range. Unless run_bad recomputes them, the lag sets do not
# match -- TODO confirm.
pars_est = 'default'
lag_range = np.arange(1,11)

In [None]:
from scipy import linalg as la

parametrization='nl'
sso = True

# settings for quick initial SGD fitting phase for our model
# (same values as the earlier phase-1 settings cell; a_R is not used in any visible cell)
batch_size, max_zip_size, max_iter = 1, np.inf, 20
a, b1, b2, e = 0.01, 0.98, 0.99, 1e-8
a_decay = 0.98
a_R = 1 * a

# fresh per-iteration log: a leading 0 followed by the n principal angles (filled by pars_track)
proj_errors = np.zeros((max_iter,n+1))
def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B do NOT need to be orthonormal. Each one is first replaced by an
    orthonormal basis of its column space (scipy.linalg.orth). 1-D inputs are
    treated as single column vectors.

    Returns
    -------
    ndarray of shape (min(rank A, rank B),)
        arccos of the singular values of orth(A).T @ orth(B), divided by pi/2.
        A value of 0 means aligned directions and 1 means orthogonal ones.

    NOTE(review): principal_angle is defined several times in this notebook;
    keep the copies in sync or define it once near the top.
    """
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the principal angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # round-off can push a cosine slightly above 1, outside arccos' domain
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)
    
def pars_track(pars, t):
    """Progress callback handed to run_bad via pars_track=.

    Stores, in row t of the global proj_errors, a leading 0 followed by the
    principal angles between the true C and the current estimate pars[0].
    NOTE(review): assumes t stays below proj_errors.shape[0]; confirm how
    run_bad numbers its calls.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# Phase-1 fit for lag_range 1..10, starting from init='default'. The Qs, Om and
# W returned by run_bad overwrite the globals.
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      parametrization=parametrization, sso=sso,
                                      Qs=Qs, Om=Om, W=W, 
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,a_decay=a_decay,
                                      batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                      pars_track=pars_track,save_every=np.inf, data_path=data_path)

print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal angles per latent dimension over iterations (column 0 is the zero placeholder)
plt.plot(proj_errors[:,1:])
plt.show()


In [None]:
# second fitting phase for lag_range 1..10: larger batches, smaller step size
batch_size, max_zip_size, max_iter = 10, np.inf, 20
a, b1, b2, e = 0.005, 0.9, 0.99, 1e-8
a_decay = 0.98
a_R = 1 * a

def pars_track(pars, t):
    """Progress callback handed to run_bad via pars_track=.

    Stores, in row t of the global proj_errors, a leading 0 followed by the
    principal angles between the true C and the current estimate pars[0].
    NOTE(review): assumes t stays below proj_errors.shape[0]; confirm how
    run_bad numbers its calls.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# Phase-2 fit for lag_range 1..10, warm-started from the phase-1 estimate.
_, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      parametrization=parametrization, sso=sso,
                                      Qs=Qs, Om=Om, W=W, 
                                      alpha=a,b1=b1,b2=b2,e=e,max_iter=max_iter,a_decay=a_decay,
                                      batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                      pars_track=pars_track,save_every=np.inf, data_path=data_path)

print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal angles per latent dimension over iterations (column 0 is the zero placeholder)
plt.plot(proj_errors[:,1:])
plt.show()


In [None]:

# Final comparison per latent dimension: global C (the py4sid baseline; no
# later cell rebinds the global C) in green vs. the final SSIDID estimate in blue.
plt.plot(np.arange(1, n+1), principal_angle(pars_true['C'], C), 'go') 
plt.plot(np.arange(1, n+1), principal_angle(pars_true['C'], pars_est['C']), 'bo')
plt.legend(('classical Hankel SSID', 'SSIDID'), loc=2)
plt.xlabel('# latent dim')
plt.ylabel('principal angle')
plt.axis([0.5, n+0.5, 0, 1.05])
plt.show()