# Experiment 'e2': recovering LDS parameters for varying system sizes

- system size p = 100, 1k, 10k, 100k, n = 10, and signal-to-noise ratio such that R gives 90% of total variance. 
- fitting first dynamics-agnostic, then switching to the linearized parameterization, with in particular $A$ extracted from the agnostic fit
- direct comparison with GROUSE on subspace identification task (principal angles). 

## notes:
- this is a master file. Individual runs for different data-set lengths $T \in \{10^3, 10^4, 10^5\}$ were executed on copies of this file that might have been slightly altered (e.g. reducing the max batch size and instead running more epochs for small $T$). They should be backed up in the lab dropbox folder. 

# Load stored full data, extract observed data for this experiment

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
#from subtracking import Grouse, calc_subspace_proj_error
from ssidid.utility import gen_data, gen_pars, draw_data 
from ssidid.icml_scripts import run_default


# Index of the configuration (row of ps / ns / Ts) used by this notebook instance.
i = 1
run = '_e2_init'

# One repetition of the experiment per seed.
rnd_seeds = range(20, 30)

# Candidate settings: system size p, dimensionality n and data length T.
ps = np.array([10**2, 10**3, 10**4, 10**5], dtype=int)
ns = [10] * len(ps)
Ts = np.full(len(ps), int(1e5), dtype=int)

# Time lags 0..19.
lag_range = np.arange(0, 20)

# Select the active configuration.
p, n, T = ps[i], ns[i], Ts[i]
snr = (1, 1)

# Root folder for results, plus a per-system-size subfolder name.
data_path = '/groups/turaga/home/nonnenmacherm/results/icml/icml_e2/'
#data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e2/'
dat_path = '{}p{}/'.format(data_path, p)


for rnd_seed in rnd_seeds:

    # --- ground-truth eigenvalue spectrum -----------------------------------
    nr = 0 # number of real eigenvalues
    whiten = True
    eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.95, 0.99
    nc, nc_u = n - nr, (n - nr)//2
    ev_r = np.linspace(eig_m_r, eig_M_r, nr)
    # The phases used to be drawn from the global RNG *before* it was seeded
    # below, so the ground-truth eigenvalues could not be reproduced from
    # rnd_seed. Drawing them from a dedicated generator seeded with rnd_seed
    # makes them deterministic while leaving the global random stream (and
    # hence everything drawn after np.random.seed) exactly as before.
    ev_rng = np.random.RandomState(rnd_seed)
    ev_c = np.exp(2 * 1j * np.pi * ev_rng.vonmises(mu=0, kappa=1000, size=nc_u))
    # Moduli linearly spaced in [eig_m_c, eig_M_c], one per drawn phase.
    ev_c = np.linspace(eig_m_c, eig_M_c, nc_u) * ev_c
    mmap, verbose = False, True

    # --- ground-truth parameters and simulated data -------------------------
    np.random.seed(rnd_seed)
    pars_true = gen_pars(p,n, nr, ev_r, ev_c, snr, whiten)
    pars_true['d'], pars_true['mu0'], pars_true['V0'] = np.zeros(p), np.zeros(n), pars_true['Pi'].copy()
    # Replace C by an orthonormal basis of its column space.
    pars_true['C'] = la.orth(pars_true['C'])
    _,y = draw_data(pars_true,T)
    # Remove the mean over axis 0 in place.
    # NOTE(review): presumably y is (T, p), i.e. this centers each observed
    # dimension over time - confirm against draw_data.
    y -= y.mean(axis=0)

    # For large systems, restrict to a random subset of 1000 sorted indices;
    # otherwise use all p indices. Both index sets are identical.
    idx_a = np.sort(np.random.choice(p, 1000, replace=False)) if p > 1000 else np.arange(p)
    idx_b = idx_a.copy()

    # --- observation scheme --------------------------------------------------
    # A single subpopulation containing all p indices, with one entry T in
    # obs_time.
    # NOTE(review): looks like "everything observed for the whole recording" -
    # confirm against ObservationScheme.
    sso = True
    obs_scheme = ObservationScheme(p=p, T=T, 
                                    sub_pops=(np.arange(p),), 
                                    obs_pops=(0,), 
                                    obs_time=(T,))
    obs_scheme.comp_subpop_stats()    

    # --- time-lagged covariances ----------------------------------------------
    W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
    print('computing time-lagged covariances')
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)

    def principal_angle(A, B):
        """Return the normalized principal angles between span(A) and span(B).

        A and B are array-likes whose columns span the two subspaces. The
        columns need not be orthonormal: both inputs are orthonormalized
        internally with ``la.orth``. Inputs with fewer than two dimensions
        are treated as a single column vector.

        Returns a 1-D array holding ``arccos`` of the singular values of
        ``orth(A).T.dot(orth(B))``, divided by pi/2, so 0 means aligned
        directions and 1 means orthogonal directions.
        """
        A = np.asarray(A)
        B = np.asarray(B)
        A = A.reshape(-1, 1) if A.ndim < 2 else A
        B = B.reshape(-1, 1) if B.ndim < 2 else B
        A = la.orth(A)
        B = la.orth(B)
        # Only the singular values (cosines of the angles) are needed.
        cosines = la.svd(A.T.dot(B), compute_uv=False)
        # Round-off can push a cosine marginally outside [0, 1], which would
        # make arccos return NaN.
        return np.arccos(np.clip(cosines, 0.0, 1.0)) / (np.pi/2)


    # Stack A^m * Pi over all lags m to obtain the ground-truth 'X' parameter.
    lagged_blocks = []
    for m in lag_range:
        A_pow_m = np.linalg.matrix_power(pars_true['A'], m)
        lagged_blocks.append(A_pow_m.dot(pars_true['Pi']))
    pars_true['X'] = np.vstack(lagged_blocks)

    # Loss of the ground-truth parameters on the empirical covariances.
    loss_true = f_l2_Hankel_nl(
        C=pars_true['C'],
        X=pars_true['X'],
        R=pars_true['R'],
        Qs=Qs,
        Om=Om,
        lag_range=lag_range,
        ms=range(len(lag_range)),
        idx_a=idx_a,
        idx_b=idx_b,
    )
    print('true param. loss: ', loss_true)
    print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)

    # Placeholders: no parameters have been estimated yet at this point.
    pars_est = 'default'
    pars_est_g = 'default'

    # Snapshot of the full setup (including covariances) before fitting.
    save_dict = dict(
        p=p,
        n=n,
        T=T,
        snr=snr,
        obs_scheme=obs_scheme,
        lag_range=lag_range,
        mmap=mmap,
        pars_true=pars_true,
        pars_est=pars_est,
        idx_a=idx_a,
        idx_b=idx_b,
        W=W,
        Qs=Qs,
        Om=Om,
        rnd_seed=rnd_seed,
    )
    file_name = 'p{}n{}T{}_seed{}e2_init'.format(p, n, T, rnd_seed)
    # The dict is passed positionally, so np.savez stores it as one object entry.
    np.savez(data_path + file_name, save_dict)

    # Remember the RNG state right before fitting; it is stored with the
    # final results.
    rnd_seed_fit = np.random.get_state()
    #np.random.seed(rnd_seed)

    # Two-stage fit: first entry of each tuple belongs to the 'nl'
    # parametrization, second entry to the 'ln' parametrization.
    pars_est, traces, ts = run_default(
        alphas=(0.01, 0.001),
        b1s=(0.98, 0.95),
        a_decays=(0.98, 0.98),
        batch_sizes=(1, 10),
        max_zip_sizes=(1000, 250),
        max_iters=(100, 200),
        parametrizations=('nl', 'ln'),
        pars_est='default',
        pars_true=pars_true,
        n=n,
        y=y,
        sso=sso,
        obs_scheme=obs_scheme,
        lag_range=lag_range,
        idx_a=idx_a,
        idx_b=idx_b,
        Qs=Qs,
        Om=Om,
        W=W,
        traces=[[], [], []],
        ts=[],
    )

    """
    # settings for GROUSE
    a_grouse = 1.
    tracker = Grouse(p, n, a_grouse )
    max_epoch_size = 1000
    max_iter_grouse = 200
    get_obs = obs_scheme.gen_get_observed()

    # fit GROUSE
    print('\n - GROUSE')
    tracker.step = a_grouse
    ct = 1.
    error = np.zeros((max_iter_grouse, n+1))
    t = time.time()
    get_obs = obs_scheme.gen_get_observed()

    for i in range(max_iter_grouse):
        if np.mod(i,max_iter_grouse//10) == 0:
            print('finished % ' + str((100*i)//max_iter_grouse))
        idx = np.random.permutation(T-np.max(lag_range)-1)
        idx = idx[:max_epoch_size] if len(idx) > max_epoch_size else idx
        for j in range(len(idx)):
            obs_idx =  np.zeros((p,1), dtype=bool)
            obs_idx[get_obs(idx[j])] = True
            tracker.consume(y[idx[j],:].reshape(-1,1), obs_idx)
            ct += 1     
            tracker.step = a_grouse / ct

        error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
    t = time.time() - t
    pars_est_g = {'C' : tracker.U.copy()}

    print('final proj. error (est.): ', str(error[-1][0]))

    plt.subplot(1,2,1)
    plt.plot(error[:,1:])
    plt.title('subspace proj. error (GROUSE)')
    plt.subplot(1,2,2)
    plt.loglog(error[:,1:])
    plt.title('subspace proj. error (GROUSE)')
    plt.show()

    traces_g = [error.copy()]
    ts_g = [t]            

    print('filtering data') 
    obs_scheme.gen_mask_from_scheme()
    tracker = Grouse(p, n, 0. )
    tracker.U = pars_est_g['C'].copy()
    x_g = np.zeros((T,n))
    for t in range (T):
        x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
    obs_scheme.mask = None    

    lag_range_g = np.arange(20)
    kl_ = np.max(lag_range_g) + 1
    print('extracting dynamics parameters') 
    pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range_g])
    pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:], pars_est_g['X'][n:len(lag_range_g)*n,:])[0]
    pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2 
    ev_est = np.linalg.eigvals(pars_est_g['A'])
    del x_g                
    """
    # The GROUSE baseline is disabled, so its results are stored as placeholders.
    pars_est_g = None
    traces_g = None
    ts_g = np.inf

    # Final results; the data and the covariance matrices are deliberately
    # stored as None here.
    save_dict = dict(
        p=p,
        n=n,
        T=T,
        snr=snr,
        lag_range=lag_range,
        obs_scheme=obs_scheme,
        mmap=mmap,
        y=None,
        pars_true=pars_true,
        pars_est=pars_est,
        pars_est_g=pars_est_g,
        idx_a=idx_a,
        idx_b=idx_b,
        W=None,
        Qs=None,
        Om=None,
        traces=traces,
        traces_g=traces_g,
        ts=ts,
        ts_g=ts_g,
        rnd_seed=rnd_seed_fit,  # RNG state captured right before fitting
    )
    file_name = 'p{}n{}T{}_seed{}e2_final'.format(p, n, T, rnd_seed)
    np.savez(data_path + file_name, save_dict)