# Serial subset observations

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3'

# Problem size: p observed units, n latent dimensions and the lags 0..9
# used for the time-lagged covariances (kl_ = number of lags).
p = 1000
n = 10
lag_range = np.arange(10)
kl_ = lag_range.max() + 1
# Recording length: 100000 time steps plus kl_ extra steps.
T_full = kl_ + 100000
T = T_full

# Data-generation settings.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    A and B need not be column-orthogonal: both are orthonormalised with
    ``scipy.linalg.orth`` first. A 1-D input (array or list) is treated as a
    single column vector.

    Returns an array of angles divided by pi/2, so 0 means a shared
    direction and 1 means orthogonal directions, ordered from the smallest
    to the largest angle.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # The cosines of the principal angles are the singular values of Qa^T Qb
    # for orthonormal bases Qa, Qb; the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Round-off can push a cosine marginally above 1, outside arccos' domain.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings
mmap = True
chunksize = np.min((p, 1000))
verbose = True

# Data seeds to process and sub-population overlap sizes to sweep.
rnd_seeds = range(40, 50)
overlaps = (10,)


for rnd_seed in rnd_seeds:

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # Observed data as a (T, p) memmap. np.float (removed in NumPy 1.24) was
    # an alias of the builtin float, i.e. float64.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)

    # The *_init.npz file holds a dict pickled by np.savez, which np.load only
    # unpickles with allow_pickle=True (default False since NumPy 1.16.3).
    # Only load trusted, self-generated files this way.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_init'
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    pars_true = load_file['pars_true']
    print('angles ev(A): ', np.sort(np.angle(np.linalg.eigvals(pars_true['A'])))/np.pi)

    # Alternative: regenerate the data instead of loading it.
    #np.random.seed(rnd_seed)
    #pars_true, x, y, _, _ = gen_data(p,n,lag_range,T_full, nr,
    #                                 eig_m_r, eig_M_r, 
    #                                 eig_m_c, eig_M_c,
    #                                 mmap, chunksize,
    #                                 data_path,
    #                                 snr=snr, whiten=whiten)        
    #pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])

    sso = True

    for overlap in overlaps:

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # ns contiguous sub-populations; neighbouring ones share
        # 2*(overlap//2) units. obs_time holds the end points of
        # len(obs_pops) equal-length time blocks.
        reps = 1
        ns = 10
        sub_pops = [np.arange(0,  p//ns + overlap//2)]
        sub_pops = sub_pops + [np.arange(i*p//ns-overlap//2, (i+1)*p//ns+overlap//2) for i in range(1,ns-1)]
        sub_pops.append(np.arange((ns-1)*p//ns-overlap//2,p))
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T, 
                                        sub_pops=sub_pops, 
                                        obs_pops=obs_pops, 
                                        obs_time=obs_time)

        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            # reuse the lag-0 co-occurrence weights for every other lag
            for m in range(1, len(lag_range)):
                W[m] = W[0].copy()

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)
        if overlap < p:
            # reuse the lag-0 observation pattern for every lag
            for m in range(len(lag_range)):
                Om[m] = Om[0].copy()

        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))

        sub_pops = obs_scheme.sub_pops
        # Wall-clock seed for the fit. Kept in its own variable so the data seed
        # (the loop variable rnd_seed) is not overwritten; it is the value
        # stored under 'rnd_seed' in the results file below.
        rnd_seed_fit = int(time.time())
        np.random.seed(rnd_seed_fit)
        pars_est, traces, ts= run_default(
                    alphas    = (0.1, 0.001), 
                    b1s       = (0.9 , 0.9), 
                    a_decays  = (0.95, 0.98), 
                    batch_sizes = (1, 10), 
                    max_zip_sizes =  (1000,100), 
                    max_iters = (100, 100 ),
                    parametrizations = ('nl', 'ln'),
                    pars_est='default', pars_true=pars_true, n=n, 
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range, 
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # Diagnostics: principal angles within every sub-population, overall,
        # and overall after flipping the sign of the first sub-population's rows.
        print('per-subpops principal angles')
        C = pars_est['C'].copy()
        for sub_pop in sub_pops:
            print(principal_angle(pars_true['C'][sub_pop,:], C[sub_pop,:]))

        print('final principal angles')
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))

        del C

        # GROUSE baseline, currently disabled: the string below is never executed.
        """
        
        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse )
        max_epoch_size = 100
        max_iter_grouse = 1000
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        error = np.zeros((max_iter_grouse, n+1))
        t = time.time()
        get_obs = obs_scheme.gen_get_observed()
        
        for i in range(max_iter_grouse):
            if np.mod(i,max_iter_grouse//10) == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            idx = np.random.permutation(T-np.max(lag_range)-1)
            idx = idx[:max_epoch_size] if len(idx) > max_epoch_size else idx
            for j in range(len(idx)):
                obs_idx =  np.zeros((p,1), dtype=bool)
                obs_idx[get_obs(idx[j])] = True
                tracker.consume(y[idx[j],:].reshape(-1,1), obs_idx)
                ct += 1     
                tracker.step = a_grouse / ct
        
            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
        t = time.time() - t
        pars_est_g = {'C' : tracker.U.copy()}

        print('final proj. error (est.): ', str(error[-1][0]))

        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.show()

        print('per-subpops principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C = pars_est_g['C'].copy()
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))

        del C    
        traces_g = [error.copy()]
        ts_g = [t]            

        print('filtering data') 
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
        obs_scheme.mask = None    

        lag_range_g = np.arange(20)
        kl_ = np.max(lag_range_g) + 1
        print('extracting dynamics parameters') 
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range_g])
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:], pars_est_g['X'][n:len(lag_range_g)*n,:])[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2 
        ev_est = np.linalg.eigvals(pars_est_g['A'])
        del x_g                
        """
        pars_est_g, traces_g, ts_g = None, None, None

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rnd_seed_fit
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        np.savez(data_path + file_name, save_dict)


# FA fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from sklearn.decomposition import FactorAnalysis

run = '_e3'

# Problem size. T is the number of time steps read from the data file;
# T_full only enters the name of the per-seed *_init.npz file.
p = 1000
n = 10
lag_range = np.arange(10)
kl_ = lag_range.max() + 1
T = kl_ + 100000
T_full = 100030

# Data-generation settings.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    A and B need not be column-orthogonal: both are orthonormalised with
    ``scipy.linalg.orth`` first. A 1-D input (array or list) is treated as a
    single column vector.

    Returns an array of angles divided by pi/2, so 0 means a shared
    direction and 1 means orthogonal directions, ordered from the smallest
    to the largest angle.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # The cosines of the principal angles are the singular values of Qa^T Qb
    # for orthonormal bases Qa, Qb; the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Round-off can push a cosine marginally above 1, outside arccos' domain.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O matter
mmap, chunksize = True, np.min((p,1000))
verbose=True
# np.float was removed in NumPy 1.24; it aliased the builtin float (float64).
dtype=np.float64
rnd_seeds = range(30,50)
# Full overlap sweeps, kept for reference; the single setting below is used.
#overlaps = (0,5,10,15,20,25,50,100,300,1000)
#overlaps = (1000, 300, 100, 50, 25, 20, 15, 10, 5, 0)
overlaps = (10,)
for rnd_seed in rnd_seeds:

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    idx_a, idx_b = np.arange(p), np.arange(p)
    # The *_init.npz file holds a dict pickled by np.savez, which np.load only
    # unpickles with allow_pickle=True (default False since NumPy 1.16.3).
    # Only load trusted, self-generated files this way.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_init'
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    pars_true = load_file['pars_true']

    sso = True

    for overlap in overlaps:

        # Re-open the copy-on-write memmap for every overlap setting: the
        # masking below zeroes entries of `y` in memory, and those zeros must
        # not carry over into the next (differently masked) overlap setting.
        # np.float (removed in NumPy 1.24) was an alias of float64.
        y = np.memmap(data_path+'y', dtype=np.float64, mode='c', shape=(T,p))

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # ns contiguous sub-populations; neighbouring ones share
        # 2*(overlap//2) units. obs_time holds the end points of
        # len(obs_pops) equal-length time blocks.
        reps = 1
        ns = 10
        sub_pops = [np.arange(0,  p//ns + overlap//2)]
        sub_pops = sub_pops + [np.arange(i*p//ns-overlap//2, (i+1)*p//ns+overlap//2) for i in range(1,ns-1)]
        sub_pops.append(np.arange((ns-1)*p//ns-overlap//2,p))
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T, 
                                        sub_pops=sub_pops, 
                                        obs_pops=obs_pops, 
                                        obs_time=obs_time)
        obs_scheme.comp_subpop_stats()
        # Zero every entry the observation scheme's mask marks as unobserved.
        obs_scheme.gen_mask_from_scheme()
        y[np.invert(obs_scheme.mask)] = 0
        obs_scheme.mask = None
        obs_scheme.use_mask = False

        # Fit one FactorAnalysis model per sub-population j, using only the
        # rows of the j-th time block [obs_time[j-1], obs_time[j]).
        pars_ests = []
        traces = []
        ts = []
        for j, sub_pop in enumerate(sub_pops):
            print('fitting subpop #' + str(j))
            t_start = 0 if j == 0 else obs_time[j-1]
            data = y[t_start:obs_time[j], sub_pop]
            print('data shape ', data.shape)
            rnd_seed_fit = np.random.get_state()
            fa = FactorAnalysis(n_components=n, 
                                tol=0.01, 
                                copy=True, 
                                max_iter=1000, 
                                noise_variance_init=None, 
                                svd_method='randomized', 
                                iterated_power=3, 
                                random_state=0)
            t = time.time()
            fa.fit(data)
            t = time.time() - t
            print('fitting time: ', t)
            print('principal angles: ', principal_angle(pars_true['C'][sub_pop,:], fa.components_.T))

            pars_ests.append({
                'C' : fa.components_.T.copy(),
                'Pi' : np.eye(n).copy(),
                'R' : fa.noise_variance_.copy()
            })
            traces.append(fa.loglike_.copy())
            ts.append(t)
            del fa

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : None,
                     'pars_true' : pars_true, 
                     'pars_ests' : pars_ests, 'traces' : traces, 'ts': ts, 
                     'rnd_seed' : rnd_seed_fit
                    }

        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_FAsp'
        np.savez(data_path + file_name, save_dict)

In [None]:
from scipy import linalg as la 
run = '_e3'

# Seeds and overlap settings whose per-subpopulation FA fits are stitched.
overlaps = (10,)
rnd_seeds = range(30, 50)


# calc_subspace_proj_error is otherwise only imported by the first cell.
from subtracking import calc_subspace_proj_error

subsp_errors_FAst    = np.zeros((len(overlaps), len(rnd_seeds)))
subsp_errors_FAst_f  = np.zeros((len(overlaps), len(rnd_seeds)))
for rndsidx, rnd_seed in enumerate(rnd_seeds):

    for i, overlap in enumerate(overlaps):

        data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_FAsp'
        # Pickled dict saved via np.savez: needs allow_pickle=True (default
        # False since NumPy 1.16.3). Only load trusted, self-generated files.
        load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
        pars_ests = load_file['pars_ests']
        pars_true = load_file['pars_true']
        obs_scheme = load_file['obs_scheme']
        p,n = pars_true['C'].shape

        # Stitch the per-subpopulation loadings into one p x n matrix C12:
        # each block is rotated (orthogonal Procrustes) onto the units it
        # shares with the previous, already aligned, block.
        sub_pops = obs_scheme.sub_pops
        C12 = np.zeros_like(pars_true['C'])
        C12[sub_pops[0],:] = pars_ests[0]['C']
        for j in range(1,len(sub_pops)):
            idx_overlap = np.intersect1d(sub_pops[j-1], sub_pops[j])
            C2 = np.full((p, n), np.nan)
            C2[sub_pops[j], :] = pars_ests[j]['C']

            # Only align when the sub-populations actually share units: with
            # the layout of the FA fitting cell, overlap == 1 yields none, and a
            # Procrustes fit on zero rows is not meaningful.
            if idx_overlap.size > 0:
                W, _ = la.orthogonal_procrustes(C2[idx_overlap,:], C12[idx_overlap,:])
            else:
                W = np.eye(n)
            C12[sub_pops[j],:] = C2[sub_pops[j],:].dot(W)

        err = calc_subspace_proj_error(pars_true['C'], C12)
        subsp_errors_FAst[i,rndsidx]    = err
        subsp_errors_FAst_f[i, rndsidx] = err

        # subsp_errors_FAst_f keeps the better of C12 and C12 with the rows of
        # the first sub-population sign-flipped.
        C = C12.copy()
        C[sub_pops[0],:] *= -1
        err_flipped = calc_subspace_proj_error(pars_true['C'], C)
        if err_flipped < err:
            C12 = C
            subsp_errors_FAst_f[i, rndsidx] = err_flipped

plt.semilogx(np.array(overlaps)+0.01, subsp_errors_FAst, 'k')
plt.show()

# Collect the stitched-FA subspace errors (saving is currently disabled).
data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
save_dict = dict(
    run=run,
    p=p,
    n=n,
    overlaps=overlaps,
    rnd_seeds=rnd_seeds,
    subsp_errors_FAst=subsp_errors_FAst,
)
#np.save(data_path + 'fig3_B_FAst_final_data', save_dict)


# add dynamics estimates to FA fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error


mmap = True
verbose = True

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    A and B need not be column-orthogonal: both are orthonormalised with
    ``scipy.linalg.orth`` first. A 1-D input (array or list) is treated as a
    single column vector.

    Returns an array of angles divided by pi/2, so 0 means a shared
    direction and 1 means orthogonal directions, ordered from the smallest
    to the largest angle.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # The cosines of the principal angles are the singular values of Qa^T Qb
    # for orthonormal bases Qa, Qb; the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Round-off can push a cosine marginally above 1, outside arccos' domain.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)

run = '_e3'

p,T_full,n,snr = 1000, 100010, 10, (1., 1.)

data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
# Combined FA results ('p1000n10T100010_e3_FA_all.npz' for the values above).
# They are stored as a dict pickled by np.savez, so np.load needs
# allow_pickle=True (default False since NumPy 1.16.3). Trusted local file only.
file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_FA_all'
load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
overlaps = load_file['overlaps']
rnd_seeds = load_file['rnd_seeds']

lag_range = np.arange(10)
kl_ = np.max(lag_range) + 1

for rndsidx, rnd_seed in enumerate(rnd_seeds):

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'

    for i, overlap in enumerate(overlaps):

        T = T_full

        print('T = ', T)
        print('overlap =', overlap)
        # np.float (removed in NumPy 1.24) was an alias of float64.
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

        # Two sub-populations sharing exactly `overlap` units.
        sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]

        reps = 1
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T, 
                                        sub_pops=sub_pops, 
                                        obs_pops=obs_pops, 
                                        obs_time=obs_time)
        obs_scheme.comp_subpop_stats()

        # FA estimate for this seed/overlap. It is the dict stored inside
        # load_file, so the keys added below end up in load_file as well.
        pars_est = load_file['pars_est_all'][rndsidx][i]

        print('filtering data')
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est['C'].copy()
        # Latent estimates per time step via Grouse._project on the estimated
        # loadings, restricted by the observation mask of that time step.
        x_g = np.zeros((T,n))
        for t in range(T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
        obs_scheme.mask = None

        print('extracting dynamics parameters')
        # X stacks, for every lag m, the n x n cross-covariance block between
        # x_{t+m} and x_t (np.cov on the stacked variables, upper-right block).
        pars_est['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range])
        # A: least-squares solution mapping the lag blocks 0..k-2 onto 1..k-1.
        pars_est['A'] = np.linalg.lstsq(pars_est['X'][:(len(lag_range)-1)*n,:], pars_est['X'][n:len(lag_range)*n,:])[0]
        # Pi: symmetrised lag-0 covariance.
        pars_est['Pi'] = (pars_est['X'][:n,:] + pars_est['X'][:n,:].T)/2


# Write the FA results, now including the estimated dynamics, to a new file.
save_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
file_name = f'p{p}n{n}T{T}{run}_FA_all_addedDyns'
np.savez(save_path + file_name, load_file)

In [None]:
# Display the random seeds stored in load_file (the combined FA results).
load_file['rnd_seeds']

# add EM fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om

os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
from ssm_scripts import setup_fit_lds

run = '_e3'

# Problem size: p observed units, n latent dimensions and the lags 0..9
# used for the time-lagged covariances (kl_ = number of lags).
p = 1000
n = 10
lag_range = np.arange(10)
kl_ = lag_range.max() + 1
# Recording length: 100000 time steps plus kl_ extra steps.
T_full = kl_ + 100000
T = T_full

# Data-generation settings.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    A and B need not be column-orthogonal: both are orthonormalised with
    ``scipy.linalg.orth`` first. A 1-D input (array or list) is treated as a
    single column vector.

    Returns an array of angles divided by pi/2, so 0 means a shared
    direction and 1 means orthogonal directions, ordered from the smallest
    to the largest angle.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # The cosines of the principal angles are the singular values of Qa^T Qb
    # for orthonormal bases Qa, Qb; the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Round-off can push a cosine marginally above 1, outside arccos' domain.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O matter
mmap, chunksize = True, np.min((p,1000))
verbose=True
# np.float was removed in NumPy 1.24; it aliased the builtin float (float64).
dtype=np.float64
rnd_seeds = range(40,50)
#overlaps = (0,10,15,20,25,50,100,300,1000)
overlaps = (1000, 300, 100, 50, 25, 20, 15, 10, 0)


for rnd_seed in rnd_seeds:

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # np.float (removed in NumPy 1.24) was an alias of float64.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    # The *_init.npz file holds a dict pickled by np.savez, which np.load only
    # unpickles with allow_pickle=True (default False since NumPy 1.16.3).
    # Only load trusted, self-generated files this way.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_init'
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    pars_true = load_file['pars_true']
    print('angles ev(A): ', np.sort(np.angle(np.linalg.eigvals(pars_true['A'])))/np.pi)

    sso = True

    for overlap in overlaps:

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Two sub-populations sharing exactly `overlap` units; obs_time holds
        # the end points of two equal-length time blocks.
        sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]

        reps = 1
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = {'sub_pops': sub_pops,
                      'obs_pops': obs_pops,
                      'obs_time': obs_time}

        # Random initialisation seeded from the wall clock; the resulting RNG
        # state is saved with the results ('rnd_seed_fit') for reproducibility.
        np.random.seed(int(time.time()))
        rnd_seed_fit = np.random.get_state()
        pars_init={
            'C' : np.asarray(la.orth(np.random.normal(size=(p,n))),dtype=dtype) * np.sqrt(p) / np.sqrt(n),
            'A' : np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
            'Q' : np.asarray(np.eye(n), dtype=dtype),
            'R' : 2*np.ones(p, dtype=dtype),
            'd' : np.zeros(p, dtype=dtype),
            'mu0' :  np.zeros(n ,dtype=dtype),
            'V0' : np.eye(  n ,dtype=dtype)            
        }

        # EM
        max_iter_EM = 200
        eps_cov = 1e-5
        epsilon = 1e-6
        # principal angles to the true loadings per EM iteration (column 0 unused)
        res = np.zeros((max_iter_EM, n+1))
        fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1), 
                                u=None, 
                                max_iter=max_iter_EM,
                                epsilon=epsilon, 
                                eps_cov=eps_cov,
                                plot_flag=False, 
                                trace_pars_flag=True, 
                                trace_stats_flag=False, 
                                diag_R_flag=True,
                                use_A_flag=True, 
                                use_B_flag=False)

        pars_hat = pars_init
        t = time.time()
        pars_hat['B'] = np.empty((n,0), dtype=dtype)
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat, 
                              obs_scheme=obs_scheme,
                              save_file=None)
        for i_ in range(len(pars_hat['Cs'])-1):
            res[i_,1:] = principal_angle(pars_hat['Cs'][i_+1], pars_true['C'])

        likes = np.array(ll[1:])
        elapsed_time = time.time() - t
        print('elapsed time for fitting is')
        print(elapsed_time)
        pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], 
                                                           pars_hat['Q'])

        # total time, including the Lyapunov solve for Pi
        t = time.time() - t
        print('fitting time: ', t)

        plt.figure()
        plt.subplot(1,2,1)
        plt.plot(res[:,1:])
        plt.title('final princ. angles')
        plt.subplot(1,2,2)
        plt.plot(likes)
        plt.show()

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'pars_true' : pars_true, 'pars_est_EM' : pars_hat, 
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces_EM' : [likes, res], 'ts_EM':[t], 
                     'rnd_seed' : rnd_seed, 'rnd_seed_fit' : rnd_seed_fit,
                     'eps_cov' : eps_cov, 'epsilon' : epsilon
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_EM_r2'
        np.savez(data_path + file_name, save_dict)


# add more EM iterations to previous fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om

os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
from ssm_scripts import setup_fit_lds

run = '_e3'

# Problem size: p observed units, n latent dimensions and the lags 0..9
# used for the time-lagged covariances (kl_ = number of lags).
p = 1000
n = 10
lag_range = np.arange(10)
kl_ = lag_range.max() + 1
# Recording length: 100000 time steps plus kl_ extra steps.
T_full = kl_ + 100000
T = T_full

# Data-generation settings.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    A and B need not be column-orthogonal: both are orthonormalised with
    ``scipy.linalg.orth`` first. A 1-D input (array or list) is treated as a
    single column vector.

    Returns an array of angles divided by pi/2, so 0 means a shared
    direction and 1 means orthogonal directions, ordered from the smallest
    to the largest angle.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # The cosines of the principal angles are the singular values of Qa^T Qb
    # for orthonormal bases Qa, Qb; the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Round-off can push a cosine marginally above 1, outside arccos' domain.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O matter
mmap, chunksize = True, np.min((p,1000))
verbose=True
# np.float was removed in NumPy 1.24; it aliased the builtin float (float64).
dtype=np.float64
rnd_seeds = range(43,50)
overlaps = (1000, 300, 100, 50, 25, 20, 15, 10, 0)
#overlaps = (0, )


for rnd_seed in rnd_seeds:

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # np.float (removed in NumPy 1.24) was an alias of float64.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    # The .npz files hold dicts pickled by np.savez, which np.load only
    # unpickles with allow_pickle=True (default False since NumPy 1.16.3).
    # Only load trusted, self-generated files this way.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_init'
    load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
    pars_true = load_file['pars_true']
    print('angles ev(A): ', np.sort(np.angle(np.linalg.eigvals(pars_true['A'])))/np.pi)

    sso = True

    for overlap in overlaps:

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Two sub-populations sharing exactly `overlap` units; obs_time holds
        # the end points of two equal-length time blocks.
        sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]

        reps = 1
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = {'sub_pops': sub_pops,
                      'obs_pops': obs_pops,
                      'obs_time': obs_time}

        # Previous EM results for this overlap; they are extended and written
        # back to the same file at the end of the iteration.
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_EM'
        load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
        pars_hat = load_file['pars_est_EM']
        traces_EM = load_file['traces_EM']
        ts_EM = load_file['ts_EM']
        eps_cov = load_file['eps_cov']

        # EM
        max_iter_EM = 100
        epsilon = 1e-6
        # principal angles to the true loadings per new EM iteration (column 0 unused)
        res = np.zeros((max_iter_EM, n+1))
        fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1), 
                                u=None, 
                                max_iter=max_iter_EM,
                                epsilon=epsilon, 
                                eps_cov=eps_cov,
                                plot_flag=False, 
                                trace_pars_flag=True, 
                                trace_stats_flag=False, 
                                diag_R_flag=True,
                                use_A_flag=True, 
                                use_B_flag=False)

        # Drop the last entry of every parameter trace before continuing.
        # NOTE(review): presumably fit_lds re-appends the starting parameters
        # as the first trace entry of the continued run — confirm.
        pars_hat['As'].pop()
        pars_hat['Qs'].pop()
        pars_hat['Bs'].pop()
        pars_hat['mu0s'].pop()
        pars_hat['V0s'].pop()
        pars_hat['Cs'].pop()
        pars_hat['Rs'].pop()
        pars_hat['ds'].pop()
        pars_hat['B'] = np.empty((n,0), dtype=dtype)
        t = time.time()
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat, 
                              obs_scheme=obs_scheme,
                              save_file=None)
        # NOTE(review): this assumes the continued run performed all
        # max_iter_EM iterations; if EM stops early (epsilon), the first rows
        # of res are computed from C traces of the previous run.
        for i_ in range(max_iter_EM):
            res[i_,1:] = principal_angle(pars_hat['Cs'][len(pars_hat['Cs'])-max_iter_EM+i_], pars_true['C'])

        likes = np.array(ll[1:])
        elapsed_time = time.time() - t
        print('elapsed time for fitting is')
        print(elapsed_time)
        pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], 
                                                           pars_hat['Q'])

        # total time, including the Lyapunov solve for Pi
        t = time.time() - t
        print('fitting time: ', t)

        plt.figure()
        plt.subplot(1,2,1)
        plt.plot(res[:,1:])
        plt.title('final princ. angles')
        plt.subplot(1,2,2)
        plt.plot(likes)
        plt.show()

        # Append the new iterations to the stored traces.
        ts_EM.append(t)
        traces_EM[0] = np.hstack((traces_EM[0], likes))
        traces_EM[1] = np.vstack((traces_EM[1], res))
        snr, lag_range = load_file['snr'], load_file['lag_range']
        idx_a, idx_b = load_file['idx_a'], load_file['idx_b']

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'pars_true' : pars_true, 'pars_est_EM' : pars_hat, 
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces_EM' : traces_EM, 'ts_EM':ts_EM, 
                     'rnd_seed' : rnd_seed, 'eps_cov' : eps_cov, 'epsilon' : epsilon
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_EM'
        np.savez(data_path + file_name, save_dict)


# add SSID fits with harsh learning rate

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3'

# Problem size: p observed units, n latent dimensions and the lags 0..9
# used for the time-lagged covariances (kl_ = number of lags).
p = 1000
n = 10
lag_range = np.arange(10)
kl_ = lag_range.max() + 1
# Recording length: 100000 time steps plus kl_ extra steps.
T_full = kl_ + 100000
T = T_full

# Data-generation settings.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do not need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like, shape (p, k) or (p,)
        Matrices with the same number of rows whose column spans are compared.

    Returns
    -------
    ndarray
        min(rank A, rank B) principal angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Promote vectors to single-column matrices.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    # The singular values of Qa^T Qb are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Clip round-off above 1 so arccos does not return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings: data are accessed through np.memmap below.
mmap = True
chunksize = np.min((p, 1000))
verbose = True

# Seeds (one result directory each) and subpopulation overlaps to fit.
rnd_seeds = range(30, 50)
overlaps = (10,)


for rnd_seed in rnd_seeds:

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # np.float was only an alias of the builtin float (float64) and was
    # removed in NumPy 1.24, so request float64 explicitly.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

    print('first data entries: ', y[:10,:10])

    # NOTE(review): the init file is looked up with T_full + 20 in its name,
    # while results below are saved with T -- confirm this is intended.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full + 20)+'_e3_init'
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true']
    sso = True
    idx_a, idx_b = np.arange(p), np.arange(p)

    for overlap in overlaps:

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Two subpopulations: units [0, (p+overlap)//2) and [(p-overlap)//2, p);
        # they share `overlap` units when p + overlap is even.
        sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

        # One session per subpopulation, with session boundaries at equally
        # spaced time points in [0, T].
        reps = 1
        obs_pops = np.concatenate([np.arange(len(sub_pops)) for r in range(reps)])
        obs_time = np.linspace(0, T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T,
                                       sub_pops=sub_pops,
                                       obs_pops=obs_pops,
                                       obs_time=obs_time)

        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            # Zero the weights between the two observation groups.
            for m in range(len(lag_range)):
                W[m][0,1] = 0
                W[m][1,0] = 0

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)
        if overlap < p:
            # Mark cross-group covariance entries as unobserved.
            for m in range(len(lag_range)):
                Om[m][np.ix_(obs_scheme.idx_grp[0], obs_scheme.idx_grp[1])] = False
                Om[m][np.ix_(obs_scheme.idx_grp[1], obs_scheme.idx_grp[0])] = False

        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))


        sub_pops = obs_scheme.sub_pops
        # Record the RNG state *before* drawing the random initialisation so
        # the run can be reproduced. It is stored under the historical
        # 'rnd_seed' key without clobbering the loop variable.
        rng_state = np.random.get_state()
        # NOTE(review): `dtype` is not defined in this cell; it must come from
        # an earlier cell in the same kernel -- confirm its value.
        pars_init = {
            'C' : np.asarray(la.orth(np.random.normal(size=(p,n))), dtype=dtype) * np.sqrt(p) / np.sqrt(n),
            'A' : np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
            'Pi' : np.asarray(np.eye(n), dtype=dtype),
            'R' : np.ones(p, dtype=dtype)
        }
        pars_init['X'] = np.vstack([np.linalg.matrix_power(pars_init['A'],m).dot(pars_init['Pi']) for m in lag_range])

        pars_est, traces, ts = run_default(
                    alphas    = (0.05, 0.001),
                    b1s       = (0.9, 0.95),
                    a_decays  = (0.95, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000, 100),
                    max_iters = (100, 100 ),
                    parametrizations = ('nl', 'ln'),
                    pars_est=pars_init, pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        print('final subsp. proj. error: ', calc_subspace_proj_error(pars_est['C'], pars_true['C']))


        print('per-subpops principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C = pars_est['C'].copy()
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))

        del C

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : None,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : None, 'ts':ts, 'ts_g':None,
                     'rnd_seed' : rng_state
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_r2'
        np.savez(data_path + file_name, save_dict)


In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

# Tag appended to every result file name of this experiment.
run = '_e3'

# Problem size: time lags, observed dimension p and latent dimension n.
lag_range = np.arange(10)
kl_ = lag_range.max() + 1  # number of lags (k + l)
p = 1000
n = 10
T_full = kl_ + 100000
T = T_full

# Data-generation settings (eig_* presumably bound the eigenvalues of the
# real (r) and complex (c) modes of the dynamics -- TODO confirm).
nr = 0  # number of real eigenvalues
snr = (1.0, 1.0)
whiten = True
eig_m_r = 0.90
eig_M_r = 0.99
eig_m_c = 0.90
eig_M_c = 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do not need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like, shape (p, k) or (p,)
        Matrices with the same number of rows whose column spans are compared.

    Returns
    -------
    ndarray
        min(rank A, rank B) principal angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Promote vectors to single-column matrices.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    # The singular values of Qa^T Qb are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Clip round-off above 1 so arccos does not return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings: data are accessed through np.memmap below.
mmap = True
chunksize = np.min((p, 1000))
verbose = True

# Seeds and subpopulation overlaps to evaluate.
rnd_seeds = range(30, 50)
overlaps = (10,)

# Final subspace projection error: one row per overlap, one column per seed.
subsp_errs = np.zeros((len(overlaps), len(rnd_seeds)))

# Column index into subsp_errs for the seed being processed.
rndsidx = 0

for rndsidx, rnd_seed in enumerate(rnd_seeds):

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # np.float was only an alias of float64 and was removed in NumPy 1.24.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

    print('first data entries: ', y[:10,:10])

    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full + 20)+'_e3_init'
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true']
    sso = True
    idx_a, idx_b = np.arange(p), np.arange(p)

    for i, overlap in enumerate(overlaps):

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Same two overlapping subpopulations as used for fitting.
        sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

        # Load the SSID estimate stored by the '_r2' fitting run.
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap) + '_r2'
        load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
        pars_est = load_file['pars_est']

        # Compute the error once; reuse it for reporting and for the summary table.
        subsp_err = calc_subspace_proj_error(pars_est['C'], pars_true['C'])
        print('final subsp. proj. error: ', subsp_err)
        subsp_errs[i, rndsidx] = subsp_err

        print('per-subpops principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C = pars_est['C'].copy()
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))
        


In [None]:
# Drop the entry at seed column 10 (rnd_seed 40 for rnd_seeds = range(30, 50))
# from the NaN-aware summary. NOTE(review): the reason for the exclusion is
# not recorded here.
subsp_errs[0][10] = np.nan

In [None]:
# Mean and standard deviation of the subspace errors, ignoring NaN entries.
(np.nanmean(subsp_errs), np.nanstd(subsp_errs))

# add GROUSE fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

# Tag appended to every result file name of this experiment.
run = '_e3_slim'

# Problem size: time lags, observed dimension p and latent dimension n.
lag_range = np.arange(30)
kl_ = lag_range.max() + 1  # number of lags (k + l)
p = 1000
n = 10
T_full = kl_ + 100000
T = T_full

# Data-generation settings (eig_* presumably bound the eigenvalues of the
# real (r) and complex (c) modes of the dynamics -- TODO confirm).
nr = 0  # number of real eigenvalues
snr = (1.0, 1.0)
whiten = True
eig_m_r = 0.90
eig_M_r = 0.99
eig_m_c = 0.90
eig_M_c = 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do not need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like, shape (p, k) or (p,)
        Matrices with the same number of rows whose column spans are compared.

    Returns
    -------
    ndarray
        min(rank A, rank B) principal angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Promote vectors to single-column matrices.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    # The singular values of Qa^T Qb are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Clip round-off above 1 so arccos does not return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings: data are accessed through np.memmap below.
mmap = True
chunksize = np.min((p, 1000))
verbose = True

# Seeds and overlap settings whose stored fits are post-processed.
rnd_seeds = range(30, 40)
overlaps = (5000, 50000)


for rnd_seed in rnd_seeds:

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/explore/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # np.float was only an alias of float64 and was removed in NumPy 1.24.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

    idx_a, idx_b = np.arange(p), np.arange(p)

    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+'_e3_init'
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true']
    sso = True

    for overlap in overlaps:

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Load the stored SSID (and, below, GROUSE) results for this setting.
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
        pars_est = load_file['pars_est']
        traces = load_file['traces']
        ts = load_file['ts']
        obs_scheme = load_file['obs_scheme']
        obs_scheme.use_mask = False # sso scheme! just to make sure...

        # W, Qs and Om are recomputed from the data, so the stored copies are
        # not loaded.
        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            for m in range(len(lag_range)):
                W[m][0,1] = 0
                W[m][1,0] = 0

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)


        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))


        print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces[-1],False,data_path)
        print('est. param. loss: ', f_l2_Hankel_nl(C=pars_est['C'],X=pars_est['X'],R=pars_est['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))

        sub_pops = obs_scheme.sub_pops
        print('\n')
        if len(sub_pops) > 1:
            print('overlap: ', str(len(np.intersect1d(sub_pops[0], sub_pops[1]))))

        ts_g = load_file['ts_g']
        traces_g = load_file['traces_g']
        pars_est_g = load_file['pars_est_g'].copy()
        # Seed/RNG state stored with the fit; kept apart from the loop variable.
        saved_rnd_seed = load_file['rnd_seed']

        """
        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse )
        max_epoch_size = 1000
        max_iter_grouse = 1500
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        error = np.zeros((max_iter_grouse, n+1))
        t = time.time()
        get_obs = obs_scheme.gen_get_observed()

        for i in range(max_iter_grouse):
            if np.mod(i,max_iter_grouse//10) == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            idx = np.random.permutation(T-np.max(lag_range)-1)
            idx = idx[:max_epoch_size] if len(idx) > max_epoch_size else idx
            for j in range(len(idx)):
                obs_idx =  np.zeros((p,1), dtype=bool)
                obs_idx[get_obs(idx[j])] = True
                tracker.consume(y[idx[j],:].reshape(-1,1), obs_idx)
                ct += 1
                tracker.step = a_grouse / ct

            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
        t = time.time() - t
        pars_est_g = {'C' : tracker.U.copy()}

        print('final proj. error (est.): ', str(error[-1][0]))

        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.show()

        print('per-subpops principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C = pars_est_g['C'].copy()
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))

        del C
        traces_g = [error.copy()]
        ts_g = [t]

        """

        # Project every time step onto the GROUSE subspace using the scheme's mask.
        print('filtering data')
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
        obs_scheme.mask = None

        # X: stacked time-lagged latent cross-covariances cov(x_{t+m}, x_t);
        # A: least-squares fit between consecutive lag blocks of X;
        # Pi: symmetrised lag-0 block.
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range])
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range)-1)*n,:], pars_est_g['X'][n:len(lag_range)*n,:])[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        ev_est = np.linalg.eigvals(pars_est_g['A'])


        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : saved_rnd_seed
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        np.savez(data_path + file_name, save_dict)


# Get estimates of linear latent dynamics

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error

# Default seed and data directory; the per-seed loop of this cell reassigns both.
rnd_seed = 32
data_path = ('/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/seed_'
             + str(rnd_seed) + '/')
mmap = True
verbose = True

# Legacy ground-truth loader, kept for reference only:
#run = '_e3_init'
#p,T_full,n,snr,lag_range = 1000, 200020, 20, (1., 1.), np.arange(20)
#file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) +  run
#load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
#pars_true = load_file['pars_true'].copy()
#del file_name
#kl_ = np.max(lag_range)+1
#ev_true = np.linalg.eigvals(pars_true['A'])

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do not need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like, shape (p, k) or (p,)
        Matrices with the same number of rows whose column spans are compared.

    Returns
    -------
    ndarray
        min(rank A, rank B) principal angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Promote vectors to single-column matrices.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    # The singular values of Qa^T Qb are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Clip round-off above 1 so arccos does not return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)

# Seeds and subpopulation overlaps whose stored fits get dynamics estimates.
rnd_seeds = range(30, 40)
overlaps = (10, 15, 20, 25, 50, 100, 300)
run = '_e3_slim'

# Problem size and lags used for extracting the latent dynamics.
p = 1000
T_full = 100030
n = 10
snr = (1.0, 1.0)
lag_range = np.arange(20)
kl_ = lag_range.max() + 1

for rnd_seed in rnd_seeds:

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'

    # Every overlap of a seed reads the same full-length recording, so open
    # it once per seed. np.float (removed in NumPy 1.24) was an alias of float64.
    T = T_full
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

    for overlap in overlaps:

        print('T = ', T)
        print('overlap =', overlap)

        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_'+str(overlap)
        load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()

        obs_scheme, W,Qs,Om = load_file['obs_scheme'], load_file['W'],load_file['Qs'],load_file['Om']
        pars_est_g = load_file['pars_est_g']
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()

        # Project every time step onto the GROUSE subspace using the scheme's mask.
        print('filtering data')
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
        obs_scheme.mask = None

        # X: stacked time-lagged latent cross-covariances cov(x_{t+m}, x_t);
        # A: least-squares fit between consecutive lag blocks of X;
        # Pi: symmetrised lag-0 block.
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range])
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range)-1)*n,:], pars_est_g['X'][n:len(lag_range)*n,:])[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        ev_est = np.linalg.eigvals(pars_est_g['A'])

        print('storing')
        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : load_file['pars_true'], 'pars_est' : load_file['pars_est'],
                     'idx_a' : load_file['idx_a'],'idx_b' : load_file['idx_b'], 'W' : W,'Qs' : Qs,'Om' : Om,
                     'traces' : load_file['traces'], 'ts': load_file['ts'],
                     'rnd_seed' : load_file['rnd_seed'],
                     'pars_est_g' : pars_est_g, 'traces_g' :  load_file['traces_g'], 'ts_g': load_file['ts_g']
                    }

        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        np.savez(data_path + file_name, save_dict)


# missing at random

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

# Tag appended to every result file name of this experiment.
run = '_e3rnd'

# Lags for the SSID fit (lag_range) and for GROUSE dynamics extraction (lag_range_g).
lag_range = np.arange(10)
lag_range_g = np.arange(20)

kl_ = lag_range.max() + 1  # number of SSID lags (k + l)
p = 1000
n = 10
T_full = 100030           # length of the stored full recording
T = kl_ + 10000           # number of leading samples used here

# Data-generation settings (eig_* presumably bound the eigenvalues of the
# real (r) and complex (c) modes of the dynamics -- TODO confirm).
nr = 0  # number of real eigenvalues
snr = (9.0, 9.0)
whiten = True
eig_m_r = 0.90
eig_M_r = 0.99
eig_m_c = 0.90
eig_M_c = 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do not need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like, shape (p, k) or (p,)
        Matrices with the same number of rows whose column spans are compared.

    Returns
    -------
    ndarray
        min(rank A, rank B) principal angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Promote vectors to single-column matrices.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    # The singular values of Qa^T Qb are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # Clip round-off above 1 so arccos does not return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings: data are accessed through np.memmap below.
mmap = True
chunksize = np.min((p, 1000))
verbose = True


sso = True

# One subpopulation covering all p units, observed until T.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=(np.arange(p),),
                               obs_pops=(0,),
                               obs_time=(T,))
idx_a = np.arange(p)
idx_b = np.arange(p)


# Seeds to process and fractions of units observed at each time step.
rnd_seeds = range(11, 20)
fracs_obs = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)

for rnd_seed in rnd_seeds:

    #data_path = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/icml_e3/rnd/seed_' + str(rnd_seed) + '/'
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'
    save_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/rnd/seed_' + str(int(rnd_seed)) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # Disabled data-generation code (a bare string, never executed).
    """
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                     eig_m_r, eig_M_r, 
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)        
    y = np.memmap(data_path+'y', dtype=np.float, mode='r+', shape=(T,p))
    y -= y.mean(axis=0)
    #y[np.invert(obs_scheme.mask)] = np.nan
    del y
    y = np.memmap(data_path+'y', dtype=np.float, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    """

    # Ground truth of the e1 data set whose recording is reused here.
    # np.int was removed in NumPy 1.24; the builtin int gives the same value.
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr)//1)) + 'e1_init'
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true']

    # Copy the first T samples of the full recording to save_path/y.
    # (np.float, removed in NumPy 1.24, was an alias of float64.)
    y_full = np.memmap(data_path+'y_full', dtype=np.float64, mode='r', shape=(T_full,p))
    y = np.memmap(save_path+'y', dtype=np.float64, mode='w+', shape=(T,p))
    y[:] = y_full[:T, :]
    del y_full
    del y
    chunksize = np.minimum(p, 100)
    if mmap:
        print('ensuring zero-mean data for given observation scheme')
        # Centre the copied data on disk, one block of columns at a time.
        for i in progprint_xrange(p//chunksize, perline=10):
            cols = slice(i*chunksize, (i+1)*chunksize)
            y = np.memmap(save_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
            del y
        if (p//chunksize)*chunksize < p:
            cols = slice((p//chunksize)*chunksize, None)
            y = np.memmap(save_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
            del y
        y = np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    else:
        # `y` was released above, so load the copied data into memory first.
        y = np.array(np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p)))
        y -= y.mean(axis=0)

    for frac_obs in fracs_obs:

        print('\n')
        print('fraction observed:', str(frac_obs))
        print('\n')

        # Observe a fresh random subset of n_obs units at every time step.
        # np.random.choice needs an integer size; np.ceil returns a float.
        n_obs = int(np.ceil(p * frac_obs))
        mask = np.zeros((T,p),dtype=bool)
        for t in range(T):
            mask[t, np.random.choice(p, n_obs, replace=False)] = 1
        obs_scheme.mask = mask

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        if p*T < 1e8:
            plt.figure(figsize=(20,10))
            plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
            plt.grid('off')
            plt.show()

        print('computing time-lagged covariances')
        obs_scheme.use_mask = False
        W_ = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W_,sso=True,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)

        obs_scheme.use_mask = True
        W = [ 1 / (frac_obs**2 * T * np.ones((1,1))) for m in range(len(lag_range))]


        pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))
        # Pass None explicitly (as the sso cells do) instead of `_`, which here
        # would only be IPython's last-output variable.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)

        # Record the RNG state before fitting; stored under the 'rnd_seed' key
        # without clobbering the loop variable.
        rng_state = np.random.get_state()
        pars_est, traces, ts = run_default(
                    alphas    = (0.01, 0.001),
                    b1s       = (0.98, 0.95),
                    a_decays  = (0.98, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000,250),
                    max_iters = (250, 100),
                    parametrizations = ('nl', 'ln'),
                    pars_est='default', pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse )
        max_epoch_size = 1000
        max_iter_grouse = 1000
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE with step size a_grouse / (number of samples consumed)
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        error = np.zeros((max_iter_grouse, n+1))
        t_start = time.time()

        for i in range(max_iter_grouse):
            if verbose and np.mod(i,max_iter_grouse//10) == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            idx = np.random.permutation(T-np.max(lag_range)-1)
            idx = idx[:max_epoch_size] if len(idx) > max_epoch_size else idx
            for j in idx:
                tracker.consume(y[j,:].reshape(-1,1), obs_scheme.mask[j,:].reshape(-1,1))
                ct += 1
                tracker.step = a_grouse / ct

            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))

        t_fit = time.time() - t_start

        pars_est_g = {'C' : tracker.U.copy()}
        traces_g = [error.copy()]
        ts_g = [t_fit]

        # extracting dynamics for GROUSE
        print('filtering data')
        # NOTE(review): gen_mask_from_scheme() replaces the random mask set above.
        # With this all-units scheme the filtering below (and the saved
        # obs_scheme) presumably no longer reflect the subsampling -- confirm.
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)

        # Separate lag count for GROUSE, so the global kl_ (used to define T)
        # is not overwritten.
        kl_g = np.max(lag_range_g) + 1
        # X: stacked time-lagged latent cross-covariances cov(x_{t+m}, x_t);
        # A: least-squares fit between consecutive lag blocks of X;
        # Pi: symmetrised lag-0 block.
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_g+1)+m, :].T, x_g[:-(kl_g+1), :].T)[:n,n:] for m in lag_range_g])
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:], pars_est_g['X'][n:len(lag_range_g)*n,:])[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        ev_est = np.linalg.eigvals(pars_est_g['A'])
        del x_g

        # Real and imaginary parts of the eigenvalues: SSID (green), GROUSE (blue), truth (black).
        plt.plot(np.real(np.linalg.eigvals( pars_est['A'])), 'go-')
        plt.plot(np.real(np.linalg.eigvals(pars_est_g['A'])), 'bo-')
        plt.plot(np.real(np.linalg.eigvals(pars_true['A'])), 'k')
        plt.show()
        plt.plot(np.imag(np.linalg.eigvals( pars_est['A'])), 'go-')
        plt.plot(np.imag(np.linalg.eigvals(pars_est_g['A'])), 'bo-')
        plt.plot(np.imag(np.linalg.eigvals(pars_true['A'])), 'k')
        plt.show()

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : save_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rng_state
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(100*frac_obs))
        np.savez(save_path + file_name, save_dict)

        print('final proj. error (est.): ', str(error[-1][0]))

        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')

        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.show()

        plt.plot(principal_angle(pars_true['C'], tracker.U), 'ro')
        plt.plot(principal_angle(pars_true['C'], pars_est['C']), 'bo')
        plt.legend(('GROUSE', 'SSIDID'))
        plt.ylabel('principal angles')
        plt.show()


In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

# Tag appended to every result file name of this experiment.
run = '_e3rnd'

# Lags for the SSID fit (lag_range) and for GROUSE dynamics extraction (lag_range_g).
lag_range = np.arange(10)
lag_range_g = np.arange(20)

kl_ = lag_range.max() + 1  # number of SSID lags (k + l)
p = 1000
n = 10
T_full = 100030           # length of the stored full recording
T = kl_ + 10000           # number of leading samples used here

# Data-generation settings (eig_* presumably bound the eigenvalues of the
# real (r) and complex (c) modes of the dynamics -- TODO confirm).
nr = 0  # number of real eigenvalues
snr = (9.0, 9.0)
whiten = True
eig_m_r = 0.90
eig_M_r = 0.99
eig_m_c = 0.90
eig_M_c = 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalized internally with ``scipy.linalg.orth``, so
    they do not need to be column-orthogonal. 1-D inputs are treated as single
    column vectors.

    Parameters
    ----------
    A, B : array_like, shape (d, k) or (d,)
        Matrices whose column spaces are compared; they must share the row
        count d.

    Returns
    -------
    ndarray
        ``min(rank A, rank B)`` principal angles in ascending order, divided
        by pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Treat vectors as single columns.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    Q_a = la.orth(A)
    Q_b = la.orth(B)
    # The singular values of Q_a^T Q_b are the cosines of the principal
    # angles; clip guards arccos against round-off just outside [0, 1].
    cosines = np.clip(la.svdvals(Q_a.T.dot(Q_b)), 0.0, 1.0)
    return np.arccos(cosines) / (np.pi / 2)


# I/O settings
mmap = True
chunksize = np.min((p, 1000))
verbose = True

sso = True

# One sub-population containing all p units (obs_pops=(0,), obs_time=(T,));
# the random per-time-step masks are assigned inside the experiment loop.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=(np.arange(p),),
                               obs_pops=(0,),
                               obs_time=(T,))
idx_a = np.arange(p)
idx_b = np.arange(p)

# Seeds to run and fractions of units observed per time step.
rnd_seeds = range(19, 20)
fracs_obs = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)

for rnd_seed in rnd_seeds:
    # One experiment per seed: load the stored e1 recording, keep its first T
    # time steps, then for every fraction in fracs_obs draw a random
    # per-time-step observation mask, fit SSIDID (run_default) and GROUSE,
    # and save both fits to save_path.

    #data_path = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/icml_e3/rnd/seed_' + str(rnd_seed) + '/'
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'
    save_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/rnd/seed_' + str(int(rnd_seed)) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # Data generation (disabled, kept for reference); data are loaded below.
    """
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                     eig_m_r, eig_M_r,
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)
    y = np.memmap(data_path+'y', dtype=np.float, mode='r+', shape=(T,p))
    y -= y.mean(axis=0)
    #y[np.invert(obs_scheme.mask)] = np.nan
    del y
    y = np.memmap(data_path+'y', dtype=np.float, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    """

    # Ground-truth parameters stored next to the e1 data. The .npz holds a
    # pickled dict, so allow_pickle is required -- only load trusted files.
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr)//1)) + 'e1_init'
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    pars_true = load_file['pars_true']

    # Copy the first T time steps of the full recording into this
    # experiment's own memmap so it can be centered in place.
    y_full = np.memmap(data_path+'y_full', dtype=np.float64, mode='r', shape=(T_full,p))
    y = np.memmap(save_path+'y', dtype=np.float64, mode='w+', shape=(T,p))
    y[:] = y_full[:T, :]
    del y_full
    del y  # deleting a memmap flushes pending writes to disk
    chunksize = np.minimum(p, 100)
    if mmap:
        print('ensuring zero-mean data for given observation scheme')
        # Center column chunks one at a time, reopening the memmap per chunk;
        # the ceiling division includes a trailing partial chunk.
        n_chunks = -(-p // chunksize)
        for i in progprint_xrange(n_chunks, perline=10):
            y = np.memmap(save_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            cols = slice(i*chunksize, (i+1)*chunksize)
            y[:, cols] -= y[:, cols].mean(axis=0)
            del y
        y = np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    else:
        # Load the truncated data into memory and center it there.
        y = np.array(np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p)))
        y -= y.mean(axis=0)

    for frac_obs in fracs_obs:

        print('\n')
        print('fraction observed:', str(frac_obs))
        print('\n')

        # Random mask: at every time step observe n_obs of the p units, drawn
        # without replacement (np.random.choice needs an integer size).
        n_obs = int(np.ceil(p * frac_obs))
        mask = np.zeros((T,p), dtype=bool)
        for t in range(T):
            mask[t, np.random.choice(p, n_obs, replace=False)] = True
        obs_scheme.mask = mask

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        if p*T < 1e8:
            plt.figure(figsize=(20,10))
            plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
            plt.grid(False)  # a string such as 'off' is truthy and would enable the grid
            plt.show()

        print('computing time-lagged covariances')
        obs_scheme.use_mask = False
        W_ = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W_,sso=True,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)

        obs_scheme.use_mask = True
        # Same scalar weight for every lag; presumably the inverse of the
        # expected co-observation count of a unit pair under the random mask.
        W = [1 / (frac_obs**2 * T * np.ones((1,1))) for m in range(len(lag_range))]
        #Om = [np.ones((p,p), dtype=bool) for m in lag_range]
        #Qs = [np.zeros((p,p)) for m in lag_range]

        # True lagged latent covariances A^m Pi, stacked over lags.
        pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))
        # NOTE(review): `_` is not assigned in this cell; inside IPython it is
        # the "last output" variable, so this call only works in a notebook
        # session -- confirm what print_slim expects here.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,_,False,data_path)

        # The global RNG is never reseeded; record its state (saved under the
        # 'rnd_seed' key below) without clobbering the loop variable rnd_seed.
        rng_state = np.random.get_state()
        pars_est, traces, ts = run_default(
                    alphas    = (0.01, 0.001),
                    b1s       = (0.98, 0.95),
                    a_decays  = (0.98, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000,250),
                    max_iters = (250, 100),
                    parametrizations = ('nl', 'ln'),
                    pars_est='default', pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse)
        max_epoch_size = 1000
        max_iter_grouse = 1000
        # NOTE(review): result unused; call kept in case it has side effects.
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        # Per iteration: column 0 = subspace projection error,
        # columns 1..n = normalized principal angles to the true C.
        error = np.zeros((max_iter_grouse, n+1))
        print_every = max(1, max_iter_grouse // 10)  # avoid modulo by zero
        t_start = time.time()

        for i in range(max_iter_grouse):
            if verbose and i % print_every == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            # Visit up to max_epoch_size random time steps per iteration.
            idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
            for j in idx:
                tracker.consume(y[j,:].reshape(-1,1), obs_scheme.mask[j,:].reshape(-1,1))
                ct += 1
                tracker.step = a_grouse / ct  # decaying step size

            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U),
                                  principal_angle(pars_true['C'], tracker.U)))

        pars_est_g = {'C' : tracker.U.copy()}
        traces_g = [error.copy()]
        ts_g = [time.time() - t_start]

        # extracting dynamics for GROUSE
        print('filtering data')
        # NOTE(review): gen_mask_from_scheme() presumably rebuilds
        # obs_scheme.mask from the sub_pops/obs_time scheme (here every unit
        # observed), replacing the random mask drawn above; the projection
        # below would then see fully observed data -- confirm this is intended.
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0.)
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range(T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)

        # Lagged covariances cov(x_{t+m}, x_t) for m in lag_range_g, each over
        # the same number of time points. A separate name keeps the global kl_.
        kl_g = np.max(lag_range_g) + 1
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_g+1)+m, :].T, x_g[:-(kl_g+1), :].T)[:n,n:] for m in lag_range_g])
        # NOTE(review): this solves X_m B = X_{m+1}; if X_m ~ A^m Pi then
        # B ~ Pi^{-1} A Pi, which shares A's eigenvalues (what is plotted
        # below) but is not A itself in the basis of pars_est_g['C'].
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:],
                                          pars_est_g['X'][n:len(lag_range_g)*n,:], rcond=None)[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        del x_g

        # Real then imaginary parts of the eigenvalues:
        # SSIDID (green), GROUSE (blue), truth (black).
        for part in (np.real, np.imag):
            plt.plot(part(np.linalg.eigvals(pars_est['A'])), 'go-')
            plt.plot(part(np.linalg.eigvals(pars_est_g['A'])), 'bo-')
            plt.plot(part(np.linalg.eigvals(pars_true['A'])), 'k')
            plt.show()

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : save_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rng_state
                    }
        # round() guards the percentage against float error (e.g. 0.57*100).
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(round(100*frac_obs)))
        np.savez(save_path + file_name, save_dict)

        print('final proj. error (est.): ', str(error[-1][0]))

        # error[:, 1:] holds GROUSE's normalized principal angles per iteration.
        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('principal angles (GROUSE)')

        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('principal angles (GROUSE)')
        plt.show()

        plt.plot(principal_angle(pars_true['C'], tracker.U), 'ro')
        plt.plot(principal_angle(pars_true['C'], pars_est['C']), 'bo')
        plt.legend(('GROUSE', 'SSIDID'))
        plt.ylabel('principal angles')
        plt.show()


In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3rnd'

# Problem size: lags for the Hankel (SSIDID) fit, and a longer lag set used
# when estimating GROUSE dynamics.
lag_range = np.arange(10)
lag_range_g = np.arange(20)
kl_ = lag_range.max() + 1          # number of time lags k+l

p = 1000                           # number of observed units
n = 10                             # latent (subspace) dimension
T_full = 100030                    # length of the stored full recording
T = kl_ + 10000                    # length of the truncated recording used here

# Data-generation settings (arguments of the disabled gen_data call below).
nr = 0                             # number of real eigenvalues
snr = (9., 9.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalized internally with ``scipy.linalg.orth``, so
    they do not need to be column-orthogonal. 1-D inputs are treated as single
    column vectors.

    Parameters
    ----------
    A, B : array_like, shape (d, k) or (d,)
        Matrices whose column spaces are compared; they must share the row
        count d.

    Returns
    -------
    ndarray
        ``min(rank A, rank B)`` principal angles in ascending order, divided
        by pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Treat vectors as single columns.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    Q_a = la.orth(A)
    Q_b = la.orth(B)
    # The singular values of Q_a^T Q_b are the cosines of the principal
    # angles; clip guards arccos against round-off just outside [0, 1].
    cosines = np.clip(la.svdvals(Q_a.T.dot(Q_b)), 0.0, 1.0)
    return np.arccos(cosines) / (np.pi / 2)


# I/O settings
mmap = True
chunksize = np.min((p, 1000))
verbose = True

sso = True

# One sub-population containing all p units (obs_pops=(0,), obs_time=(T,));
# the random per-time-step masks are assigned inside the experiment loop.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=(np.arange(p),),
                               obs_pops=(0,),
                               obs_time=(T,))
idx_a = np.arange(p)
idx_b = np.arange(p)

# Seeds to run and fractions of units observed per time step.
rnd_seeds = range(10, 20)
fracs_obs = (0.1,)

for rnd_seed in rnd_seeds:
    # One experiment per seed: load the stored e1 recording, keep its first T
    # time steps, then for every fraction in fracs_obs draw a random
    # per-time-step observation mask, fit SSIDID (run_default) and GROUSE,
    # and save both fits to save_path.

    #data_path = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/icml_e3/rnd/seed_' + str(rnd_seed) + '/'
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'
    save_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/rnd/seed_' + str(int(rnd_seed)) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # Data generation (disabled, kept for reference); data are loaded below.
    """
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                     eig_m_r, eig_M_r,
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)
    y = np.memmap(data_path+'y', dtype=np.float, mode='r+', shape=(T,p))
    y -= y.mean(axis=0)
    #y[np.invert(obs_scheme.mask)] = np.nan
    del y
    y = np.memmap(data_path+'y', dtype=np.float, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    """

    # Ground-truth parameters stored next to the e1 data. The .npz holds a
    # pickled dict, so allow_pickle is required -- only load trusted files.
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr)//1)) + 'e1_init'
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    pars_true = load_file['pars_true']

    # Copy the first T time steps of the full recording into this
    # experiment's own memmap so it can be centered in place.
    y_full = np.memmap(data_path+'y_full', dtype=np.float64, mode='r', shape=(T_full,p))
    y = np.memmap(save_path+'y', dtype=np.float64, mode='w+', shape=(T,p))
    y[:] = y_full[:T, :]
    del y_full
    del y  # deleting a memmap flushes pending writes to disk
    chunksize = np.minimum(p, 100)
    if mmap:
        print('ensuring zero-mean data for given observation scheme')
        # Center column chunks one at a time, reopening the memmap per chunk;
        # the ceiling division includes a trailing partial chunk.
        n_chunks = -(-p // chunksize)
        for i in progprint_xrange(n_chunks, perline=10):
            y = np.memmap(save_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            cols = slice(i*chunksize, (i+1)*chunksize)
            y[:, cols] -= y[:, cols].mean(axis=0)
            del y
        y = np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    else:
        # Load the truncated data into memory and center it there.
        y = np.array(np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p)))
        y -= y.mean(axis=0)

    for frac_obs in fracs_obs:

        print('\n')
        print('fraction observed:', str(frac_obs))
        print('\n')

        # Random mask: at every time step observe n_obs of the p units, drawn
        # without replacement (np.random.choice needs an integer size).
        n_obs = int(np.ceil(p * frac_obs))
        mask = np.zeros((T,p), dtype=bool)
        for t in range(T):
            mask[t, np.random.choice(p, n_obs, replace=False)] = True
        obs_scheme.mask = mask

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        if p*T < 1e8:
            plt.figure(figsize=(20,10))
            plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
            plt.grid(False)  # a string such as 'off' is truthy and would enable the grid
            plt.show()

        print('computing time-lagged covariances')
        obs_scheme.use_mask = False
        W_ = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W_,sso=True,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)

        obs_scheme.use_mask = True
        # Same scalar weight for every lag; presumably the inverse of the
        # expected co-observation count of a unit pair under the random mask.
        W = [1 / (frac_obs**2 * T * np.ones((1,1))) for m in range(len(lag_range))]
        #Om = [np.ones((p,p), dtype=bool) for m in lag_range]
        #Qs = [np.zeros((p,p)) for m in lag_range]

        # True lagged latent covariances A^m Pi, stacked over lags.
        pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))
        # NOTE(review): `_` is not assigned in this cell; inside IPython it is
        # the "last output" variable, so this call only works in a notebook
        # session -- confirm what print_slim expects here.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,_,False,data_path)

        # The global RNG is never reseeded; record its state (saved under the
        # 'rnd_seed' key below) without clobbering the loop variable rnd_seed.
        rng_state = np.random.get_state()
        pars_est, traces, ts = run_default(
                    alphas    = (0.01, 0.001),
                    b1s       = (0.98, 0.95),
                    a_decays  = (0.98, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000,500),
                    max_iters = (200, 200),
                    parametrizations = ('nl', 'ln'),
                    pars_est='default', pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # settings for GROUSE
        a_grouse = 10.
        tracker = Grouse(p, n, a_grouse)
        max_epoch_size = 1000
        max_iter_grouse = 1000
        # NOTE(review): result unused; call kept in case it has side effects.
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        # Per iteration: column 0 = subspace projection error,
        # columns 1..n = normalized principal angles to the true C.
        error = np.zeros((max_iter_grouse, n+1))
        print_every = max(1, max_iter_grouse // 10)  # avoid modulo by zero
        t_start = time.time()

        for i in range(max_iter_grouse):
            if verbose and i % print_every == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            # Visit up to max_epoch_size random time steps per iteration.
            idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
            for j in idx:
                tracker.consume(y[j,:].reshape(-1,1), obs_scheme.mask[j,:].reshape(-1,1))
                ct += 1
                tracker.step = a_grouse / ct  # decaying step size

            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U),
                                  principal_angle(pars_true['C'], tracker.U)))

        pars_est_g = {'C' : tracker.U.copy()}
        traces_g = [error.copy()]
        ts_g = [time.time() - t_start]

        # extracting dynamics for GROUSE
        print('filtering data')
        # NOTE(review): gen_mask_from_scheme() presumably rebuilds
        # obs_scheme.mask from the sub_pops/obs_time scheme (here every unit
        # observed), replacing the random mask drawn above; the projection
        # below would then see fully observed data -- confirm this is intended.
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0.)
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range(T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)

        # Lagged covariances cov(x_{t+m}, x_t) for m in lag_range_g, each over
        # the same number of time points. A separate name keeps the global kl_.
        kl_g = np.max(lag_range_g) + 1
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_g+1)+m, :].T, x_g[:-(kl_g+1), :].T)[:n,n:] for m in lag_range_g])
        # NOTE(review): this solves X_m B = X_{m+1}; if X_m ~ A^m Pi then
        # B ~ Pi^{-1} A Pi, which shares A's eigenvalues (what is plotted
        # below) but is not A itself in the basis of pars_est_g['C'].
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:],
                                          pars_est_g['X'][n:len(lag_range_g)*n,:], rcond=None)[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        del x_g

        # Real then imaginary parts of the eigenvalues:
        # SSIDID (green), GROUSE (blue), truth (black).
        for part in (np.real, np.imag):
            plt.plot(part(np.linalg.eigvals(pars_est['A'])), 'go-')
            plt.plot(part(np.linalg.eigvals(pars_est_g['A'])), 'bo-')
            plt.plot(part(np.linalg.eigvals(pars_true['A'])), 'k')
            plt.show()

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : save_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rng_state
                    }
        # round() guards the percentage against float error (e.g. 0.57*100).
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(round(100*frac_obs)))
        np.savez(save_path + file_name, save_dict)

        print('final proj. error (est.): ', str(error[-1][0]))

        # error[:, 1:] holds GROUSE's normalized principal angles per iteration.
        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('principal angles (GROUSE)')

        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('principal angles (GROUSE)')
        plt.show()

        plt.plot(principal_angle(pars_true['C'], tracker.U), 'ro')
        plt.plot(principal_angle(pars_true['C'], pars_est['C']), 'bo')
        plt.legend(('GROUSE', 'SSIDID'))
        plt.ylabel('principal angles')
        plt.show()


# add GROUSE fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3rnd'

# Problem size: lags for the Hankel (SSIDID) fit, and a longer lag set used
# when estimating GROUSE dynamics.
lag_range = np.arange(10)
lag_range_g = np.arange(20)
kl_ = lag_range.max() + 1          # number of time lags k+l

p = 1000                           # number of observed units
n = 10                             # latent (subspace) dimension
T_full = 100030                    # length of the stored full recording
T = kl_ + 10000                    # length of the truncated recording used here

# Data-generation settings (arguments of the disabled gen_data call below).
nr = 0                             # number of real eigenvalues
snr = (9., 9.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalized internally with ``scipy.linalg.orth``, so
    they do not need to be column-orthogonal. 1-D inputs are treated as single
    column vectors.

    Parameters
    ----------
    A, B : array_like, shape (d, k) or (d,)
        Matrices whose column spaces are compared; they must share the row
        count d.

    Returns
    -------
    ndarray
        ``min(rank A, rank B)`` principal angles in ascending order, divided
        by pi/2 so that 0 means aligned and 1 means orthogonal directions.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # Treat vectors as single columns.
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    Q_a = la.orth(A)
    Q_b = la.orth(B)
    # The singular values of Q_a^T Q_b are the cosines of the principal
    # angles; clip guards arccos against round-off just outside [0, 1].
    cosines = np.clip(la.svdvals(Q_a.T.dot(Q_b)), 0.0, 1.0)
    return np.arccos(cosines) / (np.pi / 2)


# I/O settings
mmap = True
chunksize = np.min((p, 1000))
verbose = True

sso = True

# One sub-population containing all p units (obs_pops=(0,), obs_time=(T,));
# the random per-time-step masks are assigned inside the experiment loop.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=(np.arange(p),),
                               obs_pops=(0,),
                               obs_time=(T,))
idx_a = np.arange(p)
idx_b = np.arange(p)

# Seeds to run and fractions of units observed per time step.
rnd_seeds = range(10, 20)
fracs_obs = (0.1,)

for rnd_seed in rnd_seeds:
    # One experiment per seed: load the stored e1 recording, keep its first T
    # time steps, then for every fraction in fracs_obs draw a random
    # per-time-step observation mask, fit SSIDID (run_default) and GROUSE,
    # and save both fits to save_path.

    #data_path = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/icml_e3/rnd/seed_' + str(rnd_seed) + '/'
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'
    save_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/rnd/seed_' + str(int(rnd_seed)) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # Data generation (disabled, kept for reference); data are loaded below.
    """
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                     eig_m_r, eig_M_r,
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)
    y = np.memmap(data_path+'y', dtype=np.float, mode='r+', shape=(T,p))
    y -= y.mean(axis=0)
    #y[np.invert(obs_scheme.mask)] = np.nan
    del y
    y = np.memmap(data_path+'y', dtype=np.float, mode='r', shape=(T,p))
    idx_a, idx_b = np.arange(p), np.arange(p)
    """

    # Ground-truth parameters stored next to the e1 data. The .npz holds a
    # pickled dict, so allow_pickle is required -- only load trusted files.
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr)//1)) + 'e1_init'
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    pars_true = load_file['pars_true']

    # Copy the first T time steps of the full recording into this
    # experiment's own memmap so it can be centered in place.
    y_full = np.memmap(data_path+'y_full', dtype=np.float64, mode='r', shape=(T_full,p))
    y = np.memmap(save_path+'y', dtype=np.float64, mode='w+', shape=(T,p))
    y[:] = y_full[:T, :]
    del y_full
    del y  # deleting a memmap flushes pending writes to disk
    chunksize = np.minimum(p, 100)
    if mmap:
        print('ensuring zero-mean data for given observation scheme')
        # Center column chunks one at a time, reopening the memmap per chunk;
        # the ceiling division includes a trailing partial chunk.
        n_chunks = -(-p // chunksize)
        for i in progprint_xrange(n_chunks, perline=10):
            y = np.memmap(save_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            cols = slice(i*chunksize, (i+1)*chunksize)
            y[:, cols] -= y[:, cols].mean(axis=0)
            del y
        y = np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    else:
        # Load the truncated data into memory and center it there.
        y = np.array(np.memmap(save_path+'y', dtype=np.float64, mode='r', shape=(T,p)))
        y -= y.mean(axis=0)

    for frac_obs in fracs_obs:

        print('\n')
        print('fraction observed:', str(frac_obs))
        print('\n')

        # Random mask: at every time step observe n_obs of the p units, drawn
        # without replacement (np.random.choice needs an integer size).
        n_obs = int(np.ceil(p * frac_obs))
        mask = np.zeros((T,p), dtype=bool)
        for t in range(T):
            mask[t, np.random.choice(p, n_obs, replace=False)] = True
        obs_scheme.mask = mask

        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        if p*T < 1e8:
            plt.figure(figsize=(20,10))
            plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
            plt.grid(False)  # a string such as 'off' is truthy and would enable the grid
            plt.show()

        print('computing time-lagged covariances')
        obs_scheme.use_mask = False
        W_ = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W_,sso=True,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)

        obs_scheme.use_mask = True
        # Same scalar weight for every lag; presumably the inverse of the
        # expected co-observation count of a unit pair under the random mask.
        W = [1 / (frac_obs**2 * T * np.ones((1,1))) for m in range(len(lag_range))]
        #Om = [np.ones((p,p), dtype=bool) for m in lag_range]
        #Qs = [np.zeros((p,p)) for m in lag_range]

        # True lagged latent covariances A^m Pi, stacked over lags.
        pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))
        # NOTE(review): `_` is not assigned in this cell; inside IPython it is
        # the "last output" variable, so this call only works in a notebook
        # session -- confirm what print_slim expects here.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,_,False,data_path)

        # The global RNG is never reseeded; record its state (saved under the
        # 'rnd_seed' key below) without clobbering the loop variable rnd_seed.
        rng_state = np.random.get_state()
        pars_est, traces, ts = run_default(
                    alphas    = (0.01, 0.001),
                    b1s       = (0.98, 0.95),
                    a_decays  = (0.98, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000,500),
                    max_iters = (200, 200),
                    parametrizations = ('nl', 'ln'),
                    pars_est='default', pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse)
        max_epoch_size = 1000
        max_iter_grouse = 2000
        # NOTE(review): result unused; call kept in case it has side effects.
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        # Per iteration: column 0 = subspace projection error,
        # columns 1..n = normalized principal angles to the true C.
        error = np.zeros((max_iter_grouse, n+1))
        print_every = max(1, max_iter_grouse // 10)  # avoid modulo by zero
        t_start = time.time()

        for i in range(max_iter_grouse):
            if verbose and i % print_every == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            # Visit up to max_epoch_size random time steps per iteration.
            idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
            for j in idx:
                tracker.consume(y[j,:].reshape(-1,1), obs_scheme.mask[j,:].reshape(-1,1))
                ct += 1
                tracker.step = a_grouse / ct  # decaying step size

            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U),
                                  principal_angle(pars_true['C'], tracker.U)))

        pars_est_g = {'C' : tracker.U.copy()}
        traces_g = [error.copy()]
        ts_g = [time.time() - t_start]

        # extracting dynamics for GROUSE
        print('filtering data')
        # NOTE(review): gen_mask_from_scheme() presumably rebuilds
        # obs_scheme.mask from the sub_pops/obs_time scheme (here every unit
        # observed), replacing the random mask drawn above; the projection
        # below would then see fully observed data -- confirm this is intended.
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0.)
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range(T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)

        # Lagged covariances cov(x_{t+m}, x_t) for m in lag_range_g, each over
        # the same number of time points. A separate name keeps the global kl_.
        kl_g = np.max(lag_range_g) + 1
        print('extracting dynamics parameters')
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_g+1)+m, :].T, x_g[:-(kl_g+1), :].T)[:n,n:] for m in lag_range_g])
        # NOTE(review): this solves X_m B = X_{m+1}; if X_m ~ A^m Pi then
        # B ~ Pi^{-1} A Pi, which shares A's eigenvalues (what is plotted
        # below) but is not A itself in the basis of pars_est_g['C'].
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range_g)-1)*n,:],
                                          pars_est_g['X'][n:len(lag_range_g)*n,:], rcond=None)[0]
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2
        del x_g

        # Real then imaginary parts of the eigenvalues:
        # SSIDID (green), GROUSE (blue), truth (black).
        for part in (np.real, np.imag):
            plt.plot(part(np.linalg.eigvals(pars_est['A'])), 'go-')
            plt.plot(part(np.linalg.eigvals(pars_est_g['A'])), 'bo-')
            plt.plot(part(np.linalg.eigvals(pars_true['A'])), 'k')
            plt.show()

        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : save_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rng_state
                    }
        # round() guards the percentage against float error (e.g. 0.57*100).
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(round(100*frac_obs)))
        np.savez(save_path + file_name, save_dict)

        print('final proj. error (est.): ', str(error[-1][0]))

        # error[:, 1:] holds GROUSE's normalized principal angles per iteration.
        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('principal angles (GROUSE)')

        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('principal angles (GROUSE)')
        plt.show()

        plt.plot(principal_angle(pars_true['C'], tracker.U), 'ro')
        plt.plot(principal_angle(pars_true['C'], pars_est['C']), 'bo')
        plt.legend(('GROUSE', 'SSIDID'))
        plt.ylabel('principal angles')
        plt.show()


In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3_slim'

# Problem size.
lag_range = np.arange(30)
kl_ = lag_range.max() + 1          # number of time lags k+l
p = 1000                           # number of observed units
n = 10                             # latent (subspace) dimension
T_full = kl_ + 100000
T = T_full                         # the full recording is used

# Data-generation settings (gen_data-style arguments).
nr = 0                             # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r = 0.90, 0.99
eig_m_c, eig_M_c = 0.90, 0.99

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do NOT need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices with the same number of rows, or 1-D vectors.

    Returns
    -------
    ndarray
        Principal angles in ascending order, divided by ``pi/2`` so that
        0 means aligned and 1 means orthogonal.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # 1-D inputs become single column vectors.
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Qa = la.orth(A)
    Qb = la.orth(B)
    # The cosines of the principal angles are the singular values of
    # Qa^T Qb (returned in descending order). Clip at 1 so round-off
    # cannot push arccos out of its domain.
    cosines = la.svdvals(Qa.T.dot(Qb))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# I/O settings
mmap = True
chunksize = np.min((p, 1000))
verbose = True

# Seeds (one data directory each) and overlap values; each overlap value
# selects one saved result file per seed.
rnd_seeds = range(30, 40)
overlaps = (5000, 50000)


for rnd_seed in rnd_seeds:
    
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/explore/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')
    
    # Observed data of this seed, memory-mapped read-only with shape (T, p).
    # `np.float` was only an alias of float64 and was removed in NumPy 1.24.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
        
    idx_a, idx_b = np.arange(p), np.arange(p)

    # Ground-truth parameters come from this seed's '_e3_init' file
    # (deliberately not tied to the current `run` suffix).
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+'_e3_init'
    # The .npz holds a pickled dict as a 0-d object array under 'arr_0', so
    # allow_pickle=True is required on NumPy >= 1.16.3. Only load trusted,
    # locally produced result files this way.
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    pars_true = load_file['pars_true']
    sso = True

    for overlap in overlaps:

        # report problem size
        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')
        
        # previously fitted results for this overlap setting
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
            load_file = npz['arr_0'].tolist()
        pars_est = load_file['pars_est']
        traces = load_file['traces']
        ts = load_file['ts']
        obs_scheme = load_file['obs_scheme']
        obs_scheme.use_mask = False # sso scheme! just to make sure...
        
        # W, Qs and Om are recomputed from the data below instead of being read
        # from the result file (this cell writes them back as None).
        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            # Zero entries [0,1] and [1,0] of every lag's weight matrix.
            # NOTE(review): presumably the cross-subpopulation weights - confirm.
            for m in range(len(lag_range)):
                W[m][0,1] = 0
                W[m][1,0] = 0

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)    
        
        
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))        

                                                   
        print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces[-1],False,data_path)
        print('est. param. loss: ', f_l2_Hankel_nl(C=pars_est['C'],X=pars_est['X'],R=pars_est['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))        
                
        sub_pops = obs_scheme.sub_pops
        print('\n')
        if len(sub_pops) > 1:
            print('overlap: ', str(len(np.intersect1d(sub_pops[0], sub_pops[1]))))            

        ts_g = load_file['ts_g']
        traces_g = load_file['traces_g']
        pars_est_g = load_file['pars_est_g'].copy()
        # Seed recorded in the result file; kept separate so the loop
        # variable `rnd_seed` is not shadowed.
        rnd_seed_saved = load_file['rnd_seed']
        
        """        
        # settings for GROUSE
        a_grouse = 1.
        tracker = Grouse(p, n, a_grouse )
        max_epoch_size = 1000
        max_iter_grouse = 1500
        get_obs = obs_scheme.gen_get_observed()

        # fit GROUSE
        print('\n - GROUSE')
        tracker.step = a_grouse
        ct = 1.
        error = np.zeros((max_iter_grouse, n+1))
        t = time.time()
        get_obs = obs_scheme.gen_get_observed()
        
        for i in range(max_iter_grouse):
            if np.mod(i,max_iter_grouse//10) == 0:
                print('finished % ' + str((100*i)//max_iter_grouse))
            idx = np.random.permutation(T-np.max(lag_range)-1)
            idx = idx[:max_epoch_size] if len(idx) > max_epoch_size else idx
            for j in range(len(idx)):
                obs_idx =  np.zeros((p,1), dtype=bool)
                obs_idx[get_obs(idx[j])] = True
                tracker.consume(y[idx[j],:].reshape(-1,1), obs_idx)
                ct += 1     
                tracker.step = a_grouse / ct
        
            error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
        t = time.time() - t
        pars_est_g = {'C' : tracker.U.copy()}

        print('final proj. error (est.): ', str(error[-1][0]))

        plt.subplot(1,2,1)
        plt.plot(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.subplot(1,2,2)
        plt.loglog(error[:,1:])
        plt.title('subspace proj. error (GROUSE)')
        plt.show()

        print('per-subpops principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est_g['C'].copy()
        print(principal_angle(pars_true['C'], C))

        print('final principal angles (sign-flipped)')
        C = pars_est_g['C'].copy()
        C[sub_pops[0],:] *= -1
        print(principal_angle(pars_true['C'], C))

        del C    
        traces_g = [error.copy()]
        ts_g = [t]         
        
        """

        # Project each time step onto the GROUSE subspace, using only the
        # entries the observation mask marks as observed.
        print('filtering data') 
        obs_scheme.gen_mask_from_scheme()
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)
        obs_scheme.mask = None  # drop the mask again before the scheme is saved

        print('extracting dynamics parameters') 
        # X stacks the lagged latent covariances cov(x_{t+m}, x_t), m in lag_range.
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range])
        # Least-squares M with X[m] M ~ X[m+1] over consecutive lag blocks.
        # rcond=None uses the machine-precision cutoff (NumPy >= 2.0 default)
        # and silences the FutureWarning on older NumPy versions.
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range)-1)*n,:], pars_est_g['X'][n:len(lag_range)*n,:], rcond=None)[0]
        # symmetrised lag-0 block
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2 
        ev_est = np.linalg.eigvals(pars_est_g['A'])
        
        
        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_est_g' : pars_est_g,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_g' : traces_g, 'ts':ts, 'ts_g':ts_g,
                     'rnd_seed' : rnd_seed_saved
                    }
        # overwrite the result file with the updated GROUSE parameters
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)
        np.savez(data_path + file_name, save_dict)    


# Estimate linear-dynamics parameters (X, A, Pi) for the GROUSE fits

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error

# I/O settings
mmap = True
verbose = True

# Problem size and noise setting
p, n, T = 1000, 10, 20020
snr = (1., 1.)

# Seeds and observed fractions whose saved results are processed
rnd_seeds = range(30, 40)
fracs_obs = (0.1, 0.2, 0.3, 0.4, 0.5,
             0.6, 0.7, 0.8, 0.9, 1.0)

run = '_e3rnd'
lag_range = np.arange(20)

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so they do NOT need to be column-orthogonal. A 1-D input is treated as
    a single column vector.

    Parameters
    ----------
    A, B : array_like
        Matrices with the same number of rows, or 1-D vectors.

    Returns
    -------
    ndarray
        Principal angles in ascending order, divided by ``pi/2`` so that
        0 means aligned and 1 means orthogonal.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    # 1-D inputs become single column vectors.
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Qa = la.orth(A)
    Qb = la.orth(B)
    # The cosines of the principal angles are the singular values of
    # Qa^T Qb (returned in descending order). Clip at 1 so round-off
    # cannot push arccos out of its domain.
    cosines = la.svdvals(Qa.T.dot(Qb))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)

for rnd_seed in rnd_seeds:
    
    data_path = '/media/marcel/636f7b46-1fd1-4600-b69e-86d2ed82002c/stitching/hankel/icml_e3/rnd/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')    
    
    # Ground-truth parameters are read from the first fraction's result file.
    # Previously this used `frac_obs`, which at this point was undefined (first
    # seed) or a stale value left over from the previous seed / an earlier cell.
    # NOTE(review): assumes all per-fraction files of a seed store the same
    # 'pars_true' - confirm.
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(100*fracs_obs[0]))
    # The .npz holds a pickled dict as a 0-d object array under 'arr_0', so
    # allow_pickle=True is required on NumPy >= 1.16.3. Only load trusted,
    # locally produced result files this way.
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    pars_true = load_file['pars_true'].copy()
    del load_file
    kl_ = np.max(lag_range)+1
    ev_true = np.linalg.eigvals(pars_true['A'])

    # The data file does not depend on frac_obs, so it is mapped once per seed.
    # `np.float` was only an alias of float64 and was removed in NumPy 1.24.
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))

    for frac_obs in fracs_obs:

        print('T = ', T)
        print('frac_obs =', frac_obs)

        # previously fitted results for this observed fraction
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(int(100*frac_obs))
        with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
            load_file = npz['arr_0'].tolist()
        obs_scheme, W,Qs,Om = load_file['obs_scheme'], load_file['W'],load_file['Qs'],load_file['Om']
        pars_est_g = load_file['pars_est_g']
        traces_g = load_file['traces_g']
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()

        plt.figure(figsize=(20,10))
        plt.imshow(obs_scheme.mask.T, interpolation='None')
        plt.show()
        
        # Project each time step onto the GROUSE subspace, using only the
        # entries the stored observation mask marks as observed.
        print('filtering data') 
        tracker = Grouse(p, n, 0. )
        tracker.U = pars_est_g['C'].copy()
        x_g = np.zeros((T,n))
        for t in range (T):
            x_g[t,:] = tracker._project(y[t,:].reshape(p,1), obs_scheme.mask[t,:].reshape(p,1)).reshape(-1)

        print('extracting dynamics parameters') 
        # X stacks the lagged latent covariances cov(x_{t+m}, x_t), m in lag_range.
        pars_est_g['X'] = np.vstack([np.cov(x_g[m:-(kl_+1)+m, :].T, x_g[:-(kl_+1), :].T)[:n,n:] for m in lag_range])
        # Least-squares M with X[m] M ~ X[m+1] over consecutive lag blocks.
        # rcond=None uses the machine-precision cutoff (NumPy >= 2.0 default)
        # and silences the FutureWarning on older NumPy versions.
        pars_est_g['A'] = np.linalg.lstsq(pars_est_g['X'][:(len(lag_range)-1)*n,:], pars_est_g['X'][n:len(lag_range)*n,:], rcond=None)[0]
        # symmetrised lag-0 block
        pars_est_g['Pi'] = (pars_est_g['X'][:n,:] + pars_est_g['X'][:n,:].T)/2 
        ev_est = np.linalg.eigvals(pars_est_g['A'])

        print('storing')
        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : load_file['pars_true'], 'pars_est' : load_file['pars_est'], 
                     'idx_a' : load_file['idx_a'],'idx_b' : load_file['idx_b'], 'W' : W,'Qs' : Qs,'Om' : Om,
                     'traces' : load_file['traces'], 'ts': load_file['ts'], 
                     'rnd_seed' : load_file['rnd_seed'], 
                     'pars_est_g' : pars_est_g, 'traces_g' :  load_file['traces_g'], 'ts_g': load_file['ts_g']
                    }
        
        # overwrite the result file with the updated GROUSE parameters
        np.savez(data_path + file_name, save_dict)        