# Experiment 'e4': stitching, scale, EM and us



In [None]:
%matplotlib inline
import os
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.icml_scripts import run_default
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import gen_pars

import time
import scipy as sp
from scipy import stats
from scipy import linalg as la
import numpy as np
import matplotlib.pyplot as plt
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import draw_data
from ssm_scripts import setup_fit_lds

def principal_angle(A, B):
    """Return the normalized principal angles between span(A) and span(B).

    Both inputs are orthonormalized internally with ``la.orth``, so their
    columns do not have to be orthonormal beforehand.  Inputs with fewer
    than two dimensions are treated as a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices (or vectors) with the same number of rows whose column
        spaces are compared.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, so 0 corresponds to a shared
        direction and 1 to orthogonal directions.
    """
    # Accept plain lists/tuples as well as ndarrays.
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # only the values are needed, so skip computing U and V.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Rounding can push a cosine marginally above 1, where arccos is NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# ---------------------------------------------------------------------------
# Experiment configuration.
# p is the number of rows of C, n is passed to fit_lds as x_dim, and T is the
# number of samples requested from draw_data.
# ---------------------------------------------------------------------------
ps = np.array([1e2, 1e3, 1e4, 1e5],dtype=int)
data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e2/'
dtype=np.float32
mmap, verbose = False, True
whiten = False

###################################
# select simulation setup here !  #
i = 1                             #
p,n,T = ps[i],10, 30000           #
snr = (1.0, 1.0)                  #
###################################

max_iter_EM = 50

nr = 0 # number of real eigenvalues
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.95, 0.99
ev_r = np.linspace(eig_m_r, eig_M_r, nr)

# Results are written into a per-p sub-directory.
data_path = data_path + 'p' + str(p) + '/'
for rnd_seed in range(23,30):

    # ---- ground-truth parameters and simulated data ------------------------
    # NOTE: the order of the np.random calls below determines the drawn
    # values for a given seed; do not reorder them.
    np.random.seed(rnd_seed)
    nc, nc_u = n - nr, (n - nr)//2
    # Complex eigenvalues: phases from a concentrated von Mises draw,
    # magnitudes linearly spaced in [eig_m_c, eig_M_c].
    ev_c = np.exp(2 * 1j * np.pi * np.random.vonmises(mu=0, kappa=1000, size=nc_u))
    ev_c = np.linspace(eig_m_c, eig_M_c, (n - nr)//2) * ev_c
    pars_true = gen_pars(p,n, nr, ev_r, ev_c, snr, whiten, dtype=dtype)
    pars_true['d'], pars_true['mu0'], pars_true['V0'] = np.zeros(p), np.zeros(n), pars_true['Pi'].copy()
    # Replace C by an orthonormal basis of its column space, rescaled by sqrt(p/n).
    pars_true['C'] = la.orth(pars_true['C']) * np.sqrt(p) / np.sqrt(n)
    pars_true['R'] = np.asarray(np.random.uniform(size=p, low=snr[0], high=snr[1]), dtype=dtype)
    x,y = draw_data(pars_true,T, dtype=dtype)
    y -= y.mean(axis=0)

    # Use at most 1000 randomly chosen variables for the covariance-based fit.
    idx_a = np.sort(np.random.choice(p, 1000, replace=False)) if p > 1000 else np.arange(p)
    idx_b = idx_a.copy()

    plt.subplot(1,2,1)
    plt.imshow(pars_true['A'], interpolation='None')
    plt.colorbar()
    plt.subplot(1,2,2)
    plt.imshow(np.corrcoef(x.T), interpolation='None')
    del x  # x is only needed for the plot above; release it early
    plt.colorbar()
    plt.show()


    # ---- EM fit ------------------------------------------------------------
    rnd_seed_fit = np.random.get_state()
    pars_init={
        'C' : np.asarray(la.orth(np.random.normal(size=(p,n))),dtype=dtype) * np.sqrt(p) / np.sqrt(n),
        'A' : np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
        'Q' : np.asarray(np.eye(n), dtype=dtype),
        'R' : 2*np.ones(p, dtype=dtype)
    }

    likes = np.zeros(max_iter_EM)
    res = np.zeros((max_iter_EM, n+1))
    # Two overlapping index ranges (first 51% and last 51% of the p variables).
    # Presumably obs_time holds the end of each population's observation
    # window -- TODO confirm against ssm_scripts.
    obs_scheme = {'sub_pops': [list(range(0,int(0.51*p))),list(range(int(0.49*p),p))],
             'obs_pops': [0,1],
             'obs_time': [T//2,T]}
    fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1),
                            u=None,
                            max_iter=max_iter_EM,
                            epsilon=np.log(1.001),
                            eps_cov=1e-3,
                            plot_flag=False,
                            trace_pars_flag=True,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_init[ 'd' ] = np.zeros(p, dtype=dtype)
    pars_init['mu0'] = np.zeros(n ,dtype=dtype)
    pars_init[ 'V0'] = np.eye(  n ,dtype=dtype)

    # pars_hat aliases pars_init until fit_lds returns; if EM fails, the
    # initial parameters are what gets stored as the "estimate".
    pars_hat = pars_init
    t = time.time()
    try:
        pars_hat['B'] = np.empty((n,0), dtype=dtype)
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat,
                              obs_scheme=obs_scheme,
                              save_file=None)
        # Subspace error of every traced C after the first entry.
        for i_ in range(len(pars_hat['Cs'])-1):
            res[i_,1:] = principal_angle(pars_hat['Cs'][i_+1], pars_true['C'])
        # Drop the parameter traces so they are not written to disk.
        pars_hat['Cs'] = None
        pars_hat['As'] = None
        pars_hat['Bs'] = None
        pars_hat['Qs'] = None
        pars_hat['mu0s'] = None
        pars_hat['V0s'] = None
        pars_hat['ds'] = None

        likes = np.array(ll[1:])
        elapsed_time = time.time() - t
    except Exception as err:
        # Best-effort: a failed EM run must not abort the remaining seeds.
        # Catch Exception (not a bare except) so that Ctrl-C still stops the
        # notebook, and report what went wrong.
        elapsed_time = np.nan
        likes = np.zeros(0)
        print('\n ')
        print('EM BROKE')
        print(repr(err))
        print('\n ')
    print('elapsed time for fitting is')
    print(elapsed_time)
    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t
    print('fitting time: ', t)

    plt.figure()
    plt.subplot(1,2,1)
    plt.plot(res[:,1:])
    plt.title('final princ. angles')
    plt.subplot(1,2,2)
    plt.plot(likes)
    plt.show()

    save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,
                 'obs_scheme' : obs_scheme, 'y' : None,
                 'pars_true' : pars_true, 'pars_est_EM' : pars_hat,
                 'traces' : [likes, res], 'ts_EM': [t],
                 'rnd_seed' : rnd_seed_fit
                }
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + '_seed' + str(rnd_seed) + 'e2_EM'
    # save_dict is a Python dict and can only be stored as a pickled object
    # array, so allow_pickle=False must not be passed here: depending on the
    # NumPy version it is either stored as a stray array named 'allow_pickle'
    # or makes the save fail.  This now matches the second save below.
    np.savez(data_path + file_name, save_dict)


    # ---- ssidid (covariance-based) fit ------------------------------------
    lag_range = np.arange(2*n)
    sso = True
    obs_scheme = ObservationScheme(p=p, T=T,
                                    sub_pops= (np.arange(0,int(0.51*p)),np.arange(int(0.49*p),p)),
                                    obs_pops=(0,1),
                                    obs_time=(T//2,T))
    obs_scheme.comp_subpop_stats()

    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
    # Zero the two off-diagonal entries of every lag's weight array.
    # Presumably these are the cross-population weights -- TODO confirm.
    for m in range(len(lag_range)):
        W[m][0,1] = 0
        W[m][1,0] = 0

    print('computing time-lagged covariances')
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)
    # Mask out every entry that pairs a variable of idx_grp[0] with one of
    # idx_grp[1] (in either order).
    for m in range(len(lag_range)):
        Om[m][np.ix_(np.where(np.in1d(idx_a, obs_scheme.idx_grp[0]))[0],
                     np.where(np.in1d(idx_b, obs_scheme.idx_grp[1]))[0])] = False
        Om[m][np.ix_(np.where(np.in1d(idx_a, obs_scheme.idx_grp[1]))[0],
                     np.where(np.in1d(idx_b, obs_scheme.idx_grp[0]))[0])] = False

    rnd_seed_fit = np.random.get_state()
    pars_init['B'] = np.eye(n) # somewhat unfortunate, used B to denote input matrix for ssid_fit, and sqrt(Pi) here
    pars_est, traces, ts= run_default(
                alphas    = (0.1, 0.0),
                b1s       = (0.9, 0.9),
                a_decays  = (0.95, 0.9999),
                batch_sizes = (1, 1),
                max_zip_sizes =  (T//100,10),
                max_iters = (100, 10 ),
                parametrizations = ('nl', 'ln'),
                pars_est=pars_init, pars_true=pars_true, n=n,
                y=y, sso=True, obs_scheme=obs_scheme, lag_range=lag_range,
                idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                traces=[[], [], []], ts = [])

    save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                 'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : None,
                 'pars_true' : pars_true, 'pars_est' : pars_est,
                 'traces' : traces, 'ts':ts,
                 'rnd_seed' : rnd_seed_fit
                }
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + '_seed' + str(rnd_seed) + 'e2_final'
    np.savez(data_path + file_name, save_dict)

# test, fully obs (high-noise)

In [None]:
%matplotlib inline
import os
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.icml_scripts import run_default
from ssidid import ObservationScheme, progprint_xrange

import time
import scipy as sp
from scipy import linalg as la
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import draw_data
from ssm_scripts import setup_fit_lds
from scipy import stats


def principal_angle(A, B):
    """Return the normalized principal angles between span(A) and span(B).

    Both inputs are orthonormalized internally with ``la.orth``, so their
    columns do not have to be orthonormal beforehand.  Inputs with fewer
    than two dimensions are treated as a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices (or vectors) with the same number of rows whose column
        spaces are compared.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, so 0 corresponds to a shared
        direction and 1 to orthogonal directions.
    """
    # Accept plain lists/tuples as well as ndarrays.
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # only the values are needed, so skip computing U and V.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Rounding can push a cosine marginally above 1, where arccos is NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# ---------------------------------------------------------------------------
# Fully observed test run: configuration.
# ---------------------------------------------------------------------------
ps = np.array([1e2, 1e3, 1e4, 1e5],dtype=int)
data_path = '/home/mackelab/Desktop/Projects/Stitching/results/'

i = 2

p,n,T = ps[i],10,2000
snr = (9.0, 9.0)
# np.float was only an alias of the builtin float and was removed in
# NumPy 1.24; np.float64 is the same type and works on old and new versions.
dtype=np.float64
mmap, verbose = False, True
whiten = False


for rnd_seed in range(20,21):

    nr = np.mod(n,2) # number of real eigenvalues
    eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.9, 0.99

    # ---- build a real dynamics matrix A with prescribed eigenvalues -------
    # NOTE: the order of the np.random calls below determines the drawn
    # values for a given seed; do not reorder them.
    np.random.seed(rnd_seed)
    nc, nc_u = n - nr, (n - nr)//2
    ev_r = np.linspace(eig_m_r, eig_M_r, nr)
    ev_c = np.exp(2 * 1j * np.pi * np.random.vonmises(mu=0, kappa=1000, size=nc_u))
    ev_c = np.linspace(eig_m_c, eig_M_c, (n - nr)//2) * ev_c
    # Here Q holds the (complex) eigenvectors and D the eigenvalues of A;
    # Q is re-used further down for the innovation noise covariance.
    Q, D = np.zeros((n,n), dtype=complex), np.zeros(n, dtype=complex)
    D[:nr] = ev_r
    Q[:,:nr] = np.random.normal(size=(n,nr))
    Q[:,:nr] /= np.sqrt((Q[:,:nr]**2).sum(axis=0)).reshape(1,nr)
    ev_c_r, ev_c_c = np.real(ev_c), np.imag(ev_c)

    #V = np.random.normal(size=(n,n))
    V = 0.9 * np.eye(n) + 0.1 * np.random.normal(size=(n,n))
    #V = np.maximum(-0.2, V)
    #V = np.minimum( 0.2, V)
    # Each pair of columns of V yields one complex-conjugate eigenvector pair.
    # The loop index is k (not i) so the setup selector `i` is not clobbered.
    for k in range(nc_u):
        Vk = V[:,k*2:(k+1)*2] / np.sqrt( np.sum(V[:,k*2:(k+1)*2]**2) )
        Q[:,nr+k], Q[:,nr+nc_u+k] = Vk[:,0]+1j*Vk[:,1], Vk[:,0]-1j*Vk[:,1]
        D[nr+k], D[nr+k+nc_u] = ev_c_r[k]+1j*ev_c_c[k], ev_c_r[k]-1j*ev_c_c[k]
    A = Q.dot(np.diag(D)).dot(np.linalg.inv(Q))
    # Conjugate pairs should make A real up to rounding error.
    assert np.allclose(A, np.real(A))
    A = np.real(A)

    #A = np.random.normal(size=(n,n))
    #l, V = np.linalg.eig(A)
    #l  = 0.9 * (l / np.abs(l))
    #A = np.real(V.dot(np.diag(l)).dot(np.linalg.inv(V)))

    #A = np.random.normal(size=(n,n))
    #A = np.maximum(-.25, A)
    #A = np.minimum(0.25, A)

    # generate innovation noise covariance matrix Q
    Q = np.atleast_2d(stats.wishart(n, np.eye(n)).rvs()/(n)) # np.eye(n)
    # Pi solves the discrete Lyapunov equation for (A, Q).
    Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))
    #L = np.linalg.cholesky(Pi)
    #Linv = np.linalg.inv(L)
    #A, Q = Linv.dot(A).dot(L), Linv.dot(Q).dot(Linv.T)
    #Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))

    pars_true = { 'A' : A, 'Q' : Q, 'Pi' : Pi, 'C' : np.random.normal(size=(p,n)) / np.sqrt(n) }
    pars_true['d'], pars_true['mu0'], pars_true['V0'] = np.zeros(p), np.zeros(n), pars_true['Pi'].copy()
    pars_true['C'] = la.orth(pars_true['C']) * np.sqrt(p) / np.sqrt(n)
    pars_true['R'] = np.asarray(np.random.uniform(size=p, low=snr[0], high=snr[1]), dtype=dtype)

    # ---- simulate data ------------------------------------------------------
    #np.random.seed(rnd_seed)
    # Draw 1.5*T samples and discard the first T//2 of them.
    x, _ = draw_data(pars_true, int(1.5*T))
    x = x[T//2:,:]

    # Standardize every latent dimension before generating observations.
    x = (x - x.mean(axis=0))/np.std(x,axis=0)

    eps = np.sqrt(pars_true['R']).reshape(1,p) * np.random.normal(size=(x.shape[0],p))
    y = x.dot(pars_true['C'].T) + eps
    y -= y.mean(axis=0)

    plt.subplot(1,2,1)
    plt.imshow(pars_true['A'], interpolation='None')
    plt.colorbar()
    plt.subplot(1,2,2)
    plt.imshow(np.corrcoef(x.T), interpolation='None')
    plt.colorbar()
    plt.show()

    print('eigvals cov(x) ', np.linalg.eigvals(np.cov(x.T)))

    # ---- EM fit -------------------------------------------------------------
    pars_init = {'A'  : np.diag(np.linspace(0.89, 0.91, n)),
         'Q' : np.eye(n,dtype=dtype),
         'd' : np.zeros(p, dtype=dtype),
         'mu0' : np.zeros(n),
         'V0' : np.eye(n),
         'C'  : np.random.normal(size=(p,n))/np.sqrt(n),
         'R'  : 1. * np.ones(p,dtype=dtype),
         'B'  : np.empty((n, 0))}

    max_iter = 10
    likes = np.zeros(max_iter)
    res = np.zeros((max_iter, n+1))

    rnd_seed_fit = np.random.get_state()
    obs_scheme = {'sub_pops': [list(range(0,p)),list(range(0,p))],
             'obs_pops': [0,1],
             'obs_time': [T//2,T]}
    # max_iter=1: the fit is advanced one call at a time in the loop below so
    # the subspace error can be recorded after every call.
    fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1),
                            u=None,
                            max_iter=1,
                            epsilon=np.log(1.001),
                            eps_cov=0,
                            plot_flag=False,
                            trace_pars_flag=False,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_hat = pars_init
    t = time.time()
    for i_ in range(max_iter):
        pars_hat['B'] = np.empty((n,0))
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat,
                              obs_scheme=obs_scheme,
                              save_file=None)
        res[i_,1:] = principal_angle(pars_hat['C'], pars_true['C'])
        likes[i_] = ll[-1]
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)

    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t

    print('fitting time: ', t)

    plt.figure()
    plt.subplot(1,2,1)
    plt.plot(res[:,1:])
    plt.title('final princ. angles')
    plt.subplot(1,2,2)
    plt.plot(likes)
    plt.show()

    # ---- ssidid (covariance-based) fit ------------------------------------
    lag_range = np.arange(2*n)
    sso = True
    # Restrict the covariance computation to the first min(200, p) variables.
    idx_a = np.sort(np.arange(np.minimum(200,p)))
    idx_b = idx_a.copy()
    obs_scheme = ObservationScheme(p=p, T=T,
                                    sub_pops=(np.arange(p),),
                                    obs_pops=(0,),
                                    obs_time=(T,))
    obs_scheme.comp_subpop_stats()

    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
    print('computing time-lagged covariances')
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)

    pars_est={
        'C' : la.orth(np.random.normal(size=(p,n))) * np.sqrt(p) / np.sqrt(n),
        'A' : np.diag(np.linspace(0.9, 0.9, n)),
        'B' : np.eye(n),
        'Pi': np.eye(n),
        'R' : np.ones(p)
    }
    # X stacks A^m Pi for every lag m in lag_range.
    pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])

    pars_est, traces, ts= run_default(
                alphas    = (0.1, 0.0),
                b1s       = (0.9, 0.9),
                a_decays  = (0.95, 0.9999),
                batch_sizes = (1, 1),
                max_zip_sizes =  (T//100,10),
                max_iters = (100, 10 ),
                parametrizations = ('nl', 'ln'),
                pars_est=pars_est, pars_true=pars_true, n=n,
                y=y, sso=True, obs_scheme=obs_scheme, lag_range=lag_range,
                idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                traces=[[], [], []], ts = [])


In [None]:
%matplotlib inline
import os
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.icml_scripts import run_default
from ssidid import ObservationScheme, progprint_xrange

import time
import scipy as sp
from scipy import linalg as la
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import draw_data
from ssm_scripts import setup_fit_lds
from scipy import stats


def principal_angle(A, B):
    """Return the normalized principal angles between span(A) and span(B).

    Both inputs are orthonormalized internally with ``la.orth``, so their
    columns do not have to be orthonormal beforehand.  Inputs with fewer
    than two dimensions are treated as a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices (or vectors) with the same number of rows whose column
        spaces are compared.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, so 0 corresponds to a shared
        direction and 1 to orthogonal directions.
    """
    # Accept plain lists/tuples as well as ndarrays.
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # only the values are needed, so skip computing U and V.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Rounding can push a cosine marginally above 1, where arccos is NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# ---------------------------------------------------------------------------
# Fully observed test run with a longer recording: configuration.
# ---------------------------------------------------------------------------
ps = np.array([1e2, 1e3, 1e4, 1e5],dtype=int)
data_path = '/home/mackelab/Desktop/Projects/Stitching/results/'

i = 2

p,n,T = ps[i],10,20000
snr = (9.0, 9.0)
# np.float was only an alias of the builtin float and was removed in
# NumPy 1.24; np.float64 is the same type and works on old and new versions.
dtype=np.float64
mmap, verbose = False, True
whiten = False


for rnd_seed in range(20,21):

    nr = np.mod(n,2) # number of real eigenvalues
    eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.9, 0.99

    # ---- build a real dynamics matrix A with prescribed eigenvalues -------
    # NOTE: the order of the np.random calls below determines the drawn
    # values for a given seed; do not reorder them.
    np.random.seed(rnd_seed)
    nc, nc_u = n - nr, (n - nr)//2
    ev_r = np.linspace(eig_m_r, eig_M_r, nr)
    ev_c = np.exp(2 * 1j * np.pi * np.random.vonmises(mu=0, kappa=1000, size=nc_u))
    ev_c = np.linspace(eig_m_c, eig_M_c, (n - nr)//2) * ev_c
    # Here Q holds the (complex) eigenvectors and D the eigenvalues of A;
    # Q is re-used further down for the innovation noise covariance.
    Q, D = np.zeros((n,n), dtype=complex), np.zeros(n, dtype=complex)
    D[:nr] = ev_r
    Q[:,:nr] = np.random.normal(size=(n,nr))
    Q[:,:nr] /= np.sqrt((Q[:,:nr]**2).sum(axis=0)).reshape(1,nr)
    ev_c_r, ev_c_c = np.real(ev_c), np.imag(ev_c)

    #V = np.random.normal(size=(n,n))
    V = 0.9 * np.eye(n) + 0.1 * np.random.normal(size=(n,n))
    #V = np.maximum(-0.2, V)
    #V = np.minimum( 0.2, V)
    # Each pair of columns of V yields one complex-conjugate eigenvector pair.
    # The loop index is k (not i) so the setup selector `i` is not clobbered.
    for k in range(nc_u):
        Vk = V[:,k*2:(k+1)*2] / np.sqrt( np.sum(V[:,k*2:(k+1)*2]**2) )
        Q[:,nr+k], Q[:,nr+nc_u+k] = Vk[:,0]+1j*Vk[:,1], Vk[:,0]-1j*Vk[:,1]
        D[nr+k], D[nr+k+nc_u] = ev_c_r[k]+1j*ev_c_c[k], ev_c_r[k]-1j*ev_c_c[k]
    A = Q.dot(np.diag(D)).dot(np.linalg.inv(Q))
    # Conjugate pairs should make A real up to rounding error.
    assert np.allclose(A, np.real(A))
    A = np.real(A)

    #A = np.random.normal(size=(n,n))
    #l, V = np.linalg.eig(A)
    #l  = 0.9 * (l / np.abs(l))
    #A = np.real(V.dot(np.diag(l)).dot(np.linalg.inv(V)))

    #A = np.random.normal(size=(n,n))
    #A = np.maximum(-.25, A)
    #A = np.minimum(0.25, A)

    # generate innovation noise covariance matrix Q
    Q = np.atleast_2d(stats.wishart(n, np.eye(n)).rvs()/(n)) # np.eye(n)
    # Pi solves the discrete Lyapunov equation for (A, Q).
    Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))
    #L = np.linalg.cholesky(Pi)
    #Linv = np.linalg.inv(L)
    #A, Q = Linv.dot(A).dot(L), Linv.dot(Q).dot(Linv.T)
    #Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))

    pars_true = { 'A' : A, 'Q' : Q, 'Pi' : Pi, 'C' : np.random.normal(size=(p,n)) / np.sqrt(n) }
    pars_true['d'], pars_true['mu0'], pars_true['V0'] = np.zeros(p), np.zeros(n), pars_true['Pi'].copy()
    pars_true['C'] = la.orth(pars_true['C']) * np.sqrt(p) / np.sqrt(n)
    pars_true['R'] = np.asarray(np.random.uniform(size=p, low=snr[0], high=snr[1]), dtype=dtype)

    # ---- simulate data ------------------------------------------------------
    #np.random.seed(rnd_seed)
    # Draw 1.5*T samples and discard the first T//2 of them.
    x, _ = draw_data(pars_true, int(1.5*T))
    x = x[T//2:,:]

    # Standardize every latent dimension before generating observations.
    x = (x - x.mean(axis=0))/np.std(x,axis=0)

    eps = np.sqrt(pars_true['R']).reshape(1,p) * np.random.normal(size=(x.shape[0],p))
    y = x.dot(pars_true['C'].T) + eps
    y -= y.mean(axis=0)

    plt.subplot(1,2,1)
    plt.imshow(pars_true['A'], interpolation='None')
    plt.colorbar()
    plt.subplot(1,2,2)
    plt.imshow(np.corrcoef(x.T), interpolation='None')
    plt.colorbar()
    plt.show()

    print('eigvals cov(x) ', np.linalg.eigvals(np.cov(x.T)))

    # ---- EM fit -------------------------------------------------------------
    pars_init = {'A'  : np.diag(np.linspace(0.89, 0.91, n)),
         'Q' : np.eye(n,dtype=dtype),
         'd' : np.zeros(p, dtype=dtype),
         'mu0' : np.zeros(n),
         'V0' : np.eye(n),
         'C'  : np.random.normal(size=(p,n))/np.sqrt(n),
         'R'  : 1. * np.ones(p,dtype=dtype),
         'B'  : np.empty((n, 0))}

    max_iter = 10
    likes = np.zeros(max_iter)
    res = np.zeros((max_iter, n+1))

    rnd_seed_fit = np.random.get_state()
    obs_scheme = {'sub_pops': [list(range(0,p)),list(range(0,p))],
             'obs_pops': [0,1],
             'obs_time': [T//2,T]}
    # max_iter=1: the fit is advanced one call at a time in the loop below so
    # the subspace error can be recorded after every call.
    fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1),
                            u=None,
                            max_iter=1,
                            epsilon=np.log(1.001),
                            eps_cov=0,
                            plot_flag=False,
                            trace_pars_flag=False,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_hat = pars_init
    t = time.time()
    for i_ in range(max_iter):
        pars_hat['B'] = np.empty((n,0))
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat,
                              obs_scheme=obs_scheme,
                              save_file=None)
        res[i_,1:] = principal_angle(pars_hat['C'], pars_true['C'])
        likes[i_] = ll[-1]
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)

    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t

    print('fitting time: ', t)

    plt.figure()
    plt.subplot(1,2,1)
    plt.plot(res[:,1:])
    plt.title('final princ. angles')
    plt.subplot(1,2,2)
    plt.plot(likes)
    plt.show()

    # ---- ssidid (covariance-based) fit ------------------------------------
    lag_range = np.arange(2*n)
    sso = True
    # Restrict the covariance computation to the first min(200, p) variables.
    idx_a = np.sort(np.arange(np.minimum(200,p)))
    idx_b = idx_a.copy()
    obs_scheme = ObservationScheme(p=p, T=T,
                                    sub_pops=(np.arange(p),),
                                    obs_pops=(0,),
                                    obs_time=(T,))
    obs_scheme.comp_subpop_stats()

    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
    print('computing time-lagged covariances')
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)

    pars_est={
        'C' : la.orth(np.random.normal(size=(p,n))) * np.sqrt(p) / np.sqrt(n),
        'A' : np.diag(np.linspace(0.9, 0.9, n)),
        'B' : np.eye(n),
        'Pi': np.eye(n),
        'R' : np.ones(p)
    }
    # X stacks A^m Pi for every lag m in lag_range.
    pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])

    pars_est, traces, ts= run_default(
                alphas    = (0.1, 0.0),
                b1s       = (0.9, 0.9),
                a_decays  = (0.95, 0.9999),
                batch_sizes = (1, 1),
                max_zip_sizes =  (T//100,10),
                max_iters = (100, 10 ),
                parametrizations = ('nl', 'ln'),
                pars_est=pars_est, pars_true=pars_true, n=n,
                y=y, sso=True, obs_scheme=obs_scheme, lag_range=lag_range,
                idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                traces=[[], [], []], ts = [])


In [None]:
%matplotlib inline
import os
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
import ssm_scripts
import ssm_fit

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.icml_scripts import run_default
from ssidid import ObservationScheme, progprint_xrange

import time
import scipy as sp
from scipy import linalg as la
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import draw_data
from ssm_scripts import setup_fit_lds
from scipy import stats


def principal_angle(A, B):
    """Return the normalized principal angles between span(A) and span(B).

    Both inputs are orthonormalized internally with ``la.orth``, so their
    columns do not have to be orthonormal beforehand.  Inputs with fewer
    than two dimensions are treated as a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices (or vectors) with the same number of rows whose column
        spaces are compared.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, so 0 corresponds to a shared
        direction and 1 to orthogonal directions.
    """
    # Accept plain lists/tuples as well as ndarrays.
    A, B = np.asarray(A), np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim < 2) else A
    B = np.atleast_2d(B).T if (B.ndim < 2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # only the values are needed, so skip computing U and V.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Rounding can push a cosine marginally above 1, where arccos is NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)


# ---------------------------------------------------------------------------
# Fully observed test run on a small model (n = 3), ten seeds: configuration.
# ---------------------------------------------------------------------------
ps = np.array([1e2, 1e3, 1e4, 1e5],dtype=int)
data_path = '/home/mackelab/Desktop/Projects/Stitching/results/'

i = 0

p,n,T = ps[i],3,20000
snr = (9.0, 9.0)
# np.float was only an alias of the builtin float and was removed in
# NumPy 1.24; np.float64 is the same type and works on old and new versions.
dtype=np.float64
mmap, verbose = False, True
whiten = False


for rnd_seed in range(20,30):

    nr = np.mod(n,2) # number of real eigenvalues
    eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.9, 0.99

    # ---- build a real dynamics matrix A with prescribed eigenvalues -------
    # NOTE: the order of the np.random calls below determines the drawn
    # values for a given seed; do not reorder them.
    np.random.seed(rnd_seed)
    nc, nc_u = n - nr, (n - nr)//2
    ev_r = np.linspace(eig_m_r, eig_M_r, nr)
    ev_c = np.exp(2 * 1j * np.pi * np.random.vonmises(mu=0, kappa=1000, size=nc_u))
    ev_c = np.linspace(eig_m_c, eig_M_c, (n - nr)//2) * ev_c
    # Here Q holds the (complex) eigenvectors and D the eigenvalues of A;
    # Q is re-used further down for the innovation noise covariance.
    Q, D = np.zeros((n,n), dtype=complex), np.zeros(n, dtype=complex)
    D[:nr] = ev_r
    Q[:,:nr] = np.random.normal(size=(n,nr))
    Q[:,:nr] /= np.sqrt((Q[:,:nr]**2).sum(axis=0)).reshape(1,nr)
    ev_c_r, ev_c_c = np.real(ev_c), np.imag(ev_c)

    #V = np.random.normal(size=(n,n))
    V = 0.9 * np.eye(n) + 0.1 * np.random.normal(size=(n,n))
    #V = np.maximum(-0.2, V)
    #V = np.minimum( 0.2, V)
    # Each pair of columns of V yields one complex-conjugate eigenvector pair.
    # The loop index is k (not i) so the setup selector `i` is not clobbered.
    for k in range(nc_u):
        Vk = V[:,k*2:(k+1)*2] / np.sqrt( np.sum(V[:,k*2:(k+1)*2]**2) )
        Q[:,nr+k], Q[:,nr+nc_u+k] = Vk[:,0]+1j*Vk[:,1], Vk[:,0]-1j*Vk[:,1]
        D[nr+k], D[nr+k+nc_u] = ev_c_r[k]+1j*ev_c_c[k], ev_c_r[k]-1j*ev_c_c[k]
    A = Q.dot(np.diag(D)).dot(np.linalg.inv(Q))
    # Conjugate pairs should make A real up to rounding error.
    assert np.allclose(A, np.real(A))
    A = np.real(A)

    #A = np.random.normal(size=(n,n))
    #l, V = np.linalg.eig(A)
    #l  = 0.9 * (l / np.abs(l))
    #A = np.real(V.dot(np.diag(l)).dot(np.linalg.inv(V)))

    #A = np.random.normal(size=(n,n))
    #A = np.maximum(-.25, A)
    #A = np.minimum(0.25, A)

    # generate innovation noise covariance matrix Q
    Q = np.atleast_2d(stats.wishart(n, np.eye(n)).rvs()/(n)) # np.eye(n)
    # Pi solves the discrete Lyapunov equation for (A, Q).
    Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))
    #L = np.linalg.cholesky(Pi)
    #Linv = np.linalg.inv(L)
    #A, Q = Linv.dot(A).dot(L), Linv.dot(Q).dot(Linv.T)
    #Pi = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(A, Q))

    pars_true = { 'A' : A, 'Q' : Q, 'Pi' : Pi, 'C' : np.random.normal(size=(p,n)) / np.sqrt(n) }
    pars_true['d'], pars_true['mu0'], pars_true['V0'] = np.zeros(p), np.zeros(n), pars_true['Pi'].copy()
    pars_true['C'] = la.orth(pars_true['C']) * np.sqrt(p) / np.sqrt(n)
    pars_true['R'] = np.asarray(np.random.uniform(size=p, low=snr[0], high=snr[1]), dtype=dtype)

    # ---- simulate data ------------------------------------------------------
    #np.random.seed(rnd_seed)
    # Draw 1.5*T samples and discard the first T//2 of them.
    x, _ = draw_data(pars_true, int(1.5*T))
    x = x[T//2:,:]

    # Standardize every latent dimension before generating observations.
    x = (x - x.mean(axis=0))/np.std(x,axis=0)

    eps = np.sqrt(pars_true['R']).reshape(1,p) * np.random.normal(size=(x.shape[0],p))
    y = x.dot(pars_true['C'].T) + eps
    y -= y.mean(axis=0)

    plt.subplot(1,2,1)
    plt.imshow(pars_true['A'], interpolation='None')
    plt.colorbar()
    plt.subplot(1,2,2)
    plt.imshow(np.corrcoef(x.T), interpolation='None')
    plt.colorbar()
    plt.show()

    print('eigvals cov(x) ', np.linalg.eigvals(np.cov(x.T)))

    # ---- EM fit -------------------------------------------------------------
    pars_init = {'A'  : np.diag(np.linspace(0.89, 0.91, n)),
         'Q' : np.eye(n,dtype=dtype),
         'd' : np.zeros(p, dtype=dtype),
         'mu0' : np.zeros(n),
         'V0' : np.eye(n),
         'C'  : np.random.normal(size=(p,n))/np.sqrt(n),
         'R'  : 1. * np.ones(p,dtype=dtype),
         'B'  : np.empty((n, 0))}

    max_iter = 10
    likes = np.zeros(max_iter)
    res = np.zeros((max_iter, n+1))

    rnd_seed_fit = np.random.get_state()
    obs_scheme = {'sub_pops': [list(range(0,p)),list(range(0,p))],
             'obs_pops': [0,1],
             'obs_time': [T//2,T]}
    # max_iter=1: the fit is advanced one call at a time in the loop below so
    # the subspace error can be recorded after every call.
    fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1),
                            u=None,
                            max_iter=1,
                            epsilon=np.log(1.001),
                            eps_cov=0,
                            plot_flag=False,
                            trace_pars_flag=False,
                            trace_stats_flag=False,
                            diag_R_flag=True,
                            use_A_flag=True,
                            use_B_flag=False)

    # fit the model to data
    print('fitting model to data')
    pars_hat = pars_init
    t = time.time()
    for i_ in range(max_iter):
        pars_hat['B'] = np.empty((n,0))
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat,
                              obs_scheme=obs_scheme,
                              save_file=None)
        res[i_,1:] = principal_angle(pars_hat['C'], pars_true['C'])
        likes[i_] = ll[-1]
    elapsed_time = time.time() - t
    print('elapsed time for fitting is')
    print(elapsed_time)

    pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                       pars_hat['Q'])

    t = time.time() - t

    print('fitting time: ', t)

    plt.figure()
    plt.subplot(1,2,1)
    plt.plot(res[:,1:])
    plt.title('final princ. angles')
    plt.subplot(1,2,2)
    plt.plot(likes)
    plt.show()

    # ---- ssidid (covariance-based) fit ------------------------------------
    lag_range = np.arange(2*n)
    sso = True
    # Restrict the covariance computation to the first min(200, p) variables.
    idx_a = np.sort(np.arange(np.minimum(200,p)))
    idx_b = idx_a.copy()
    obs_scheme = ObservationScheme(p=p, T=T,
                                    sub_pops=(np.arange(p),),
                                    obs_pops=(0,),
                                    obs_time=(T,))
    obs_scheme.comp_subpop_stats()

    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
    print('computing time-lagged covariances')
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)

    pars_est={
        'C' : la.orth(np.random.normal(size=(p,n))) * np.sqrt(p) / np.sqrt(n),
        'A' : np.diag(np.linspace(0.9, 0.9, n)),
        'B' : np.eye(n),
        'Pi': np.eye(n),
        'R' : np.ones(p)
    }
    # X stacks A^m Pi for every lag m in lag_range.
    pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])

    pars_est, traces, ts= run_default(
                alphas    = (0.1, 0.0),
                b1s       = (0.9, 0.9),
                a_decays  = (0.95, 0.9999),
                batch_sizes = (10, 1),
                max_zip_sizes =  (200,10),
                max_iters = (100, 10 ),
                parametrizations = ('nl', 'ln'),
                pars_est=pars_est, pars_true=pars_true, n=n,
                y=y, sso=True, obs_scheme=obs_scheme, lag_range=lag_range,
                idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                traces=[[], [], []], ts = [])


In [None]:
# Variance of y taken along axis 0 (one value per column), averaged over columns.
np.mean(np.var(y,axis=0))

In [None]:
# Display the dimensions of the observation array y.
y.shape

In [None]:
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.icml_scripts import run_default
from ssidid import ObservationScheme, progprint_xrange

# Recompute the time-lagged covariance statistics for lags 0 .. n-1 on the
# first min(200, p) variables, then report the loss at the true parameters.
lag_range = np.arange(n)
sso = True
idx_a = np.sort(np.arange(np.minimum(200, p)))
idx_b = idx_a.copy()

# A single sub-population containing all p variables, observed for T steps.
obs_scheme = ObservationScheme(
    p=p,
    T=T,
    sub_pops=(np.arange(p),),
    obs_pops=(0,),
    obs_time=(T,),
)
obs_scheme.comp_subpop_stats()

W = obs_scheme.comp_coocurrence_weights(
    lag_range,
    sso=sso,
    idx_a=idx_a,
    idx_b=idx_b,
)
print('computing time-lagged covariances')
Qs, Om = f_l2_Hankel_comp_Q_Om(
    n=n,
    y=y,
    lag_range=lag_range,
    obs_scheme=obs_scheme,
    idx_a=idx_a,
    idx_b=idx_b,
    W=W,
    sso=sso,
    mmap=mmap,
    data_path=data_path,
    ts=None,
    ms=None,
)

# X stacks A^m Pi for every lag m in lag_range.
pars_true['X'] = np.vstack(
    [np.linalg.matrix_power(pars_true['A'], m).dot(pars_true['Pi'])
     for m in lag_range]
)
print(
    'true param. loss: ',
    f_l2_Hankel_nl(
        C=pars_true['C'],
        X=pars_true['X'],
        R=pars_true['R'],
        Qs=Qs,
        Om=Om,
        lag_range=lag_range,
        ms=range(len(lag_range)),
        idx_a=idx_a,
        idx_b=idx_b,
    ),
)

print_slim(Qs, Om, lag_range, pars_true, idx_a, idx_b, None, False, data_path)


In [None]:


# Fresh random initialisation for the covariance-based fit: C is a random
# orthonormal basis rescaled by sqrt(p)/sqrt(n), A is 0.9 * identity, and
# B and Pi are identity matrices.
pars_est = dict(
    C=la.orth(np.random.normal(size=(p, n))) * np.sqrt(p) / np.sqrt(n),
    A=np.diag(np.linspace(0.9, 0.9, n)),
    B=np.eye(n),
    Pi=np.eye(n),
    R=np.ones(p),
)
# X stacks A^m Pi for every lag m in lag_range.
pars_est['X'] = np.vstack(
    [np.linalg.matrix_power(pars_est['A'], m).dot(pars_est['Pi'])
     for m in lag_range]
)
np.linalg.norm(pars_est['C'])

# Alternative: start the fit at the true parameters.
#pars_est = {
#    'C' : pars_true['C'].copy(),
#    'A' : pars_true['A'].copy(),
#    'B' : np.linalg.cholesky(pars_true['Pi']),
#    'Pi': pars_true['Pi'].copy(),
#    'R' : pars_true['R'].copy(),
#    'X' : pars_true['X'].copy()
#}

pars_est, traces, ts = run_default(
    alphas=(0.05, 0.0),
    b1s=(0.9, 0.9),
    a_decays=(0.95, 0.9999),
    batch_sizes=(10, 1),
    max_zip_sizes=(150, 10),
    max_iters=(100, 10),
    parametrizations=('nl', 'ln'),
    pars_est=pars_est,
    pars_true=pars_true,
    n=n,
    y=y,
    sso=True,
    obs_scheme=obs_scheme,
    lag_range=lag_range,
    idx_a=idx_a,
    idx_b=idx_b,
    Qs=Qs,
    Om=Om,
    W=W,
    traces=[[], [], []],
    ts=[],
)

In [None]:

# Re-draw the random initial parameters and show the Frobenius norm of C.
pars_est = dict(
    C=la.orth(np.random.normal(size=(p, n))) * np.sqrt(p) / np.sqrt(n),
    A=np.diag(np.linspace(0.9, 0.9, n)),
    B=np.eye(n),
    Pi=np.eye(n),
    R=np.ones(p),
)
# X stacks A^m Pi for every lag m in lag_range.
pars_est['X'] = np.vstack(
    [np.linalg.matrix_power(pars_est['A'], m).dot(pars_est['Pi'])
     for m in lag_range]
)
np.linalg.norm(pars_est['C'])

In [None]:
# Display the 'Pi' entry of the ground-truth parameters.
pars_true['Pi']

In [None]:
# Display the current value of p.
p

In [None]:
# Evaluate the loss at the ground-truth parameters for the current
# lag_range / covariance statistics.
# X stacks A^m Pi for every lag m in lag_range.
pars_true['X'] = np.vstack(
    [np.linalg.matrix_power(pars_true['A'], m).dot(pars_true['Pi'])
     for m in lag_range]
)
print(
    'true param. loss: ',
    f_l2_Hankel_nl(
        C=pars_true['C'],
        X=pars_true['X'],
        R=pars_true['R'],
        Qs=Qs,
        Om=Om,
        lag_range=lag_range,
        ms=range(len(lag_range)),
        idx_a=idx_a,
        idx_b=idx_b,
    ),
)

print_slim(Qs, Om, lag_range, pars_true, idx_a, idx_b, None, False, data_path)