# Fig 2 - stitching of two subpopulations as a function of overlap

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm

# ---- experiment configuration -------------------------------------------

# directory that figure files are written to
fig_path = '/home/mackelab/Desktop/Projects/Stitching/figures/'

compare_grouse_dyns = False

mmap = True
verbose = True

# p, n and T_full are spelled into the result file names below
p = 1000
n = 10
T_full = 100030
T = T_full
snr = (1., 1.)

# time lags, overlap sizes and random seeds covered by the stored fits
lag_range = np.arange(20)
overlaps = (0, 10, 15, 20, 25, 50, 100, 300, 1000)
rnd_seeds = range(30, 50)

def comp_slim(Qs,Om,lag_range,pars,idx_a,idx_b,traces=None,mmap=False,data_path=None):
    """Correlate model-predicted with empirical covariances, one value per lag.

    For the m-th entry of `lag_range` the prediction is
    ``C[idx_a] . X_m . C[idx_b].T``, where ``X_m`` is the m-th block of ``n``
    rows of ``pars['X']``. For lag 0, ``pars['R']`` is added on the entries
    whose row and column refer to the same variable. The Pearson correlation
    between prediction and empirical matrix is taken over the entries
    selected by the boolean mask ``Om[m]``.

    Parameters
    ----------
    Qs : sequence of 2-D arrays, or None
        Empirical matrices, indexed by position in `lag_range`.
        Only read when `mmap` is False.
    Om : sequence of boolean arrays
        One mask per lag selecting the entries that enter the correlation.
    lag_range : sequence of int
        Time lags; its length fixes the length of the result.
    pars : dict
        Needs 'C' (2-D, n columns), 'X' (row blocks of height n, one per
        lag) and 'R' (1-D, indexed by variable).
    idx_a, idx_b : 1-D integer arrays
        Row / column variable indices of the matrices.
    traces : unused
        Kept for call compatibility.
    mmap : bool
        If True, read each empirical matrix from the raw float64 file
        ``data_path + 'Qs_' + str(lag)`` via ``np.memmap``.
    data_path : str or None
        Prefix of the memory-mapped files; required when `mmap` is True.

    Returns
    -------
    numpy.ndarray, shape (len(lag_range),)
        Correlation coefficient per lag.
    """
    kl = len(lag_range)
    out = np.zeros(kl)

    n = pars['C'].shape[1]
    pa, pb = idx_a.size, idx_b.size

    # positions (within idx_a / idx_b) of the variables present in both
    idx_ab = np.intersect1d(idx_a, idx_b)
    idx_a_ab = np.where(np.in1d(idx_a, idx_ab))[0]
    idx_b_ab = np.where(np.in1d(idx_b, idx_ab))[0]

    # the row selections of C do not depend on the lag
    C_a = pars['C'][idx_a,:]
    C_b = pars['C'][idx_b,:]

    for m in range(kl):
        m_ = lag_range[m]
        Qrec = C_a.dot(pars['X'][m*n:(m+1)*n, :]).dot(C_b.T)
        if m_ == 0:
            Qrec[np.ix_(idx_a_ab, idx_b_ab)] += np.diag(pars['R'][idx_ab])
        if mmap:
            # np.float64 instead of the removed alias np.float
            Q = np.memmap(data_path+'Qs_'+str(m_), dtype=np.float64, mode='r', shape=(pa,pb))
        else:
            Q = Qs[m]
        # use Q here (not Qs[m]) so that the memory-mapped matrix is the one compared
        out[m] = np.corrcoef( Qrec[Om[m]].reshape(-1), Q[Om[m]].reshape(-1) )[0,1]
        del Q  # drop the reference so a memory map can be released right away
    return out

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so their columns do not have to be orthonormal. 1-D inputs are treated
    as single column vectors; array-likes such as lists are accepted.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2 (0 = aligned, 1 = orthogonal), in
        the order of the singular values of ``A.T.dot(B)``, i.e. from the
        smallest angle to the largest.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # clip round-off above 1 so that arccos stays real
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# one colour per overlap value, picked from the 'jet' colormap
clrs = cm.jet(np.linspace(50, 200, len(overlaps), dtype=int))
lgnd = []

plt.figure(figsize=(16,16))
idx_subplot = 2   # index of the first per-seed panel

# shapes shared by the result containers below
_per_overlap = len(overlaps)
_per_fit = (len(overlaps), len(rnd_seeds))
_per_fit_and_lag = (len(overlaps), len(rnd_seeds), len(lag_range))

subsp_errors_avg = np.zeros(_per_overlap)
subsp_errors_avg_g = np.zeros(_per_overlap)

# bookkeeping flags per (overlap, seed)
redid = np.zeros(_per_fit, dtype=bool)
bitflip = np.zeros(_per_fit, dtype=bool)

# values read from the end of the fit traces
ls = np.zeros(_per_fit)
lsf = np.zeros(_per_fit)

# subspace errors: as fitted / after optional sign flip / per sub-population
subsp_errors = np.zeros(_per_fit)
subsp_errorsf = np.zeros(_per_fit)
subsp_errors1 = np.zeros(_per_fit)
subsp_errors2 = np.zeros(_per_fit)

# per-lag covariance correlations ('_st': evaluated on the inverted mask)
tmp_corrs = np.zeros(_per_fit_and_lag)
tmp_corrsf = np.zeros(_per_fit_and_lag)
tmp_corrs_st = np.zeros(_per_fit_and_lag)
tmp_corrsf_st = np.zeros(_per_fit_and_lag)

dyn_errors_abs = np.zeros(_per_fit)
dyn_errors_agl = np.zeros(_per_fit)

# same quantities for the '_g' parameter estimates
subsp_errors_g = np.zeros(_per_fit)
subsp_errorsf_g = np.zeros(_per_fit)
subsp_errors1_g = np.zeros(_per_fit)
subsp_errors2_g = np.zeros(_per_fit)
dyn_errors_abs_g = np.zeros(_per_fit)
dyn_errors_agl_g = np.zeros(_per_fit)

tmp_corrs_g = np.zeros(_per_fit_and_lag)
tmp_corrs_st_g = np.zeros(_per_fit_and_lag)

# correlations obtained with the ground-truth parameters
tmp_corrs_true = np.zeros(_per_fit_and_lag)


# Main evaluation loop: for every random seed, load ground truth and empirical
# covariances, evaluate the stored fits for every overlap value, and draw one
# panel (subspace error vs. overlap) per seed.
for rndsidx in range(len(rnd_seeds)):
    
    rnd_seed = rnd_seeds[rndsidx]

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'
    
    # Ground-truth parameters come from the '_e3_init' archive of this seed.
    # NOTE(review): the archive stores a pickled dict; recent NumPy releases
    # require allow_pickle=True in np.load for this - confirm for the NumPy in use.
    run = '_e3_init'
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) +  run
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true'].copy()
    #lag_range = load_file['lag_range']
    # stack A^m . Pi for every lag m; comp_slim reads these n-row blocks by position
    pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
    
    # empirical matrices, one file per lag position
    Qs = [np.load(data_path + 'Qs_' + str(m) + '.npy') for m in range(len(lag_range))]

    # eigenvalue statistics of the true dynamics matrix
    # (std_angl / std_tmsc are not used again in the visible part of this notebook)
    ev_true = np.linalg.eigvals(pars_true['A'])
    #ev_idx = np.argsort(np.angle(ev_true))
    #ev_true = ev_true[n//2:]
    std_angl = np.std(np.angle(ev_true))
    std_tmsc = np.std(np.log(1-np.abs(ev_true)))  
    
    # fits for the individual overlap values live in the '_e3_slim_<overlap>' archives
    run = '_e3_slim'
    for i in range(len(overlaps)):
        
        overlap = overlaps[i]

        print('overlap = ', str(overlap))

        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_'+str(overlap)
        load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
        obs_scheme = load_file['obs_scheme']
        pars_est = load_file['pars_est']
        pars_est_g = load_file['pars_est_g']
        traces = load_file['traces']
        traces_g = load_file['traces_g']

        print(pars_est_g.keys())
        
        # Om marks the pairs of variables that share a sub-population;
        # the same mask object is reused for every lag
        sub_pops = obs_scheme.sub_pops
        Om = np.zeros((len(idx_a),len(idx_b)), dtype=bool)
        for i_ in range(len(sub_pops)):
            Om[np.ix_(sub_pops[i_], sub_pops[i_])] = True
        Om = [Om for m in range(len(lag_range))]
        
        
        # last entry of the first trace (plotted as 'final loss' in a later cell)
        ls[   i,rndsidx] = traces[0][-1][-1]
        if rnd_seed in range(30,40):
            ls[i,rndsidx] /= len(lag_range) # used an old loss function version that did not normalise by #time-lags
        
        # reference: correlations achieved by the ground-truth parameters
        tmp_corrs_true[i, rndsidx] = comp_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,None)

        
        # ---- '_g' estimates (drawn in blue and labelled 'GROUSE' in the panel legend) ----
        C = pars_est_g['C'].copy()
        subsp_errors_g[i,rndsidx] = calc_subspace_proj_error(pars_true['C'],C)      
        subsp_errors1_g[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[0],:],C[sub_pops[0],:])      
        subsp_errors2_g[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[1],:],C[sub_pops[1],:])      

        # R is set to zero for these estimates before the covariances are reconstructed
        pars_est_g['R'] = np.zeros_like(pars_true['R'])
        pars_est_g['X'] = np.vstack([ np.linalg.matrix_power(pars_est_g['A'],m).dot(pars_est_g['Pi']) for m in lag_range])        
        tmp_corrs_g[i, rndsidx,:] = comp_slim(Qs,Om,lag_range,pars_est_g,idx_a,idx_b,None,False,None)
        # '_st': same correlation, evaluated on the entries NOT selected by Om
        tmp_corrs_st_g[i, rndsidx,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est_g,idx_a,idx_b,None,False,None)

        # try negating the rows of the first sub-population; keep the flip
        # (in place, in the loaded dict) if it lowers the subspace error
        C[sub_pops[0],:] *= -1
        if calc_subspace_proj_error(pars_true['C'], C) < calc_subspace_proj_error(pars_true['C'], pars_est_g['C']):
            pars_est_g['C'][sub_pops[0],:] *= -1
        subsp_errorsf_g[i, rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est_g['C'])
                
        # ---- pars_est estimates (drawn in green and labelled 'ssidid' in the panel legend) ----
        C = pars_est['C'].copy()
        subsp_errors[i,rndsidx] = calc_subspace_proj_error(pars_true['C'], C)
        subsp_errors1[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[0],:],C[sub_pops[0],:])      
        subsp_errors2[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[1],:],C[sub_pops[1],:])      
        
        pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])        
        tmp_corrs[i, rndsidx,:] = comp_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False,None)
        tmp_corrs_st[i, rndsidx,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est,idx_a,idx_b,None,False,None)
        # same sign-flip test as above; this time the flip is recorded in `bitflip`
        C[sub_pops[0],:] *= -1
        if calc_subspace_proj_error(pars_true['C'], C) < calc_subspace_proj_error(pars_true['C'], pars_est['C']):
            bitflip[i,rndsidx] = True
            print('flipping bit')
            pars_est['C'][sub_pops[0],:] *= -1    
        # the '...f' quantities are evaluated after the optional flip
        subsp_errorsf[i, rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est['C'])
        tmp_corrsf[i, rndsidx,:] = comp_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False,None)
        tmp_corrsf_st[i, rndsidx,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est,idx_a,idx_b,None,False,None)
        
        
                                        
        # Disabled branch: would replace the results by those of the
        # '_e3final_<overlap>' archive and mark the fit in `redid`.
        # With `if False` it never runs, so `redid` and `lsf` stay all-zero.
        if False: #ls[i, rndsidx] > 100.:
            
                redid[i,rndsidx]   = True           

                file_name_rep = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+'_e3final_'+str(overlap)
                load_file = np.load(data_path + file_name_rep + '.npz')['arr_0'].tolist()
                idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
                obs_scheme = load_file['obs_scheme']
                pars_est   = load_file['pars_est']
                traces     = load_file['traces']            
                C = pars_est['C'].copy()
                subsp_errors[i,rndsidx] = calc_subspace_proj_error(pars_true['C'], C)     
                subsp_errors1[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[0],:],C[sub_pops[0],:])      
                subsp_errors2[i,rndsidx] = calc_subspace_proj_error(pars_true['C'][sub_pops[1],:],C[sub_pops[1],:])      
                C[sub_pops[0],:] *= -1
                if calc_subspace_proj_error(pars_true['C'], C) < calc_subspace_proj_error(pars_true['C'], pars_est['C']):
                    pars_est['C'][sub_pops[0],:] *= -1    
                    bitflip[i,rndsidx] = True
                subsp_errorsf[i,rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est['C'])  
                lsf[i, rndsidx] =  traces[0][-1][-1]
            
    # ---- one panel per seed on a 7x3 grid ----
    plt.subplot(7,3,idx_subplot)
    
    # four points far outside the axis limits set below; they only provide
    # the line/marker styles that the legend entries are matched to
    plt.loglog(1e-10, 1e-10, 'o-', color='b', linewidth=1.5)
    plt.loglog(1e-10, 1e-10, 'o-', color='g', linewidth=2)
    plt.loglog(1e-10, 1e-10, 'r*', markersize=10)    
    plt.loglog(1e-10, 1e-10, 'mx', markersize=8, markeredgewidth=3)
    # overlaps are shifted by 0.1 so that overlap 0 can be drawn on the log axis
    # solid: error on all variables; dashed: error within each sub-population
    plt.loglog(np.array(overlaps)+0.1, subsp_errors_g[:,rndsidx], 'o-', color='b', linewidth=1.5)
    plt.loglog(np.array(overlaps)+0.1, subsp_errors[:,rndsidx], 'o-', color='g', linewidth=2)
    plt.loglog(np.array(overlaps)+0.1, subsp_errors1_g[:,rndsidx], 'o--', color='b', linewidth=1.5)
    plt.loglog(np.array(overlaps)+0.1, subsp_errors2_g[:,rndsidx], 'o--', color='b', linewidth=1.5)
    plt.loglog(np.array(overlaps)+0.1, subsp_errors1[:,rndsidx], 'o--', color='g', linewidth=2)
    plt.loglog(np.array(overlaps)+0.1, subsp_errors2[:,rndsidx], 'o--', color='g', linewidth=2)
    # markers for re-run fits (at 1.25x the error) and sign-flipped fits (at 1.5x the error)
    plt.loglog(np.array(overlaps)[redid[:,rndsidx]]+0.1, 1.25 * subsp_errors[:,rndsidx][redid[:,rndsidx]], 'r*', markersize=10)
    plt.loglog(np.array(overlaps)[bitflip[:,rndsidx]]+0.1, 1.5 * subsp_errors[:,rndsidx][bitflip[:,rndsidx]], 'mx', markersize=8, markeredgewidth=3)
    # NOTE(review): plt.box is given the string 'off'; whether this hides the
    # frame depends on the matplotlib version - confirm
    plt.box('off')
    # the legend is attached to a single panel only
    if rndsidx == 9:
        lgnd = ['GROUSE', 'ssidid', 'rep. fit', 'bitflip']
        plt.legend(lgnd, loc=1,frameon=False)
    plt.ylabel('subsp. proj. error')
    if rndsidx > 6:
        plt.xlabel('overlap')
    plt.title('random seed ' + str(rnd_seed))
    plt.axis([0.09, 1100, 0.01, 1.6])
    # tick at 0.1 is labelled '0' because of the +0.1 shift above
    plt.xticks([0.1, 1, 10, 100, 1000], ['0', '1', '10', '100', '1000'])
        
    idx_subplot +=1
        
# Extra panel: observation mask of the overlap-100 run. `data_path` and `run`
# still hold the values of the last iteration of the loop above.
plt.subplot(4,3,1)
file_name = ''.join(['p', str(p), 'n', str(n), 'T', str(T_full), str(run), '_', str(100)])
npz_path = data_path + file_name + '.npz'
load_file = np.load(npz_path)['arr_0'].tolist()
obs_scheme = load_file['obs_scheme']
obs_scheme.gen_mask_from_scheme()
# show every 1000th row of the mask (the last row is excluded by the slice)
mask_rows = obs_scheme.mask[:-1:1000, :]
plt.imshow(mask_rows.T, aspect='auto', interpolation='None')
plt.xticks([0,50,100], [0, 50000, 100000])
plt.xlabel('time t')
plt.ylabel('variable index i')
plt.title('example data observation scheme (10% overlap)')
# must follow imshow: recolours the image that was just drawn
plt.set_cmap('gray')

#plt.savefig(fig_path + 'sso_10seeds_raw.pdf')

plt.show()


# Collect the evaluation results and store them as one pickled dict.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'

# run settings
save_dict = {
    'run' : run,
    'p' : p,
    'T' : T,
    'T_full' : T_full,
    'n' : n,
    'snr' : snr,
    'rnd_seeds' : rnd_seeds,
    'overlaps' : overlaps,
}

# per-lag covariance correlations
save_dict.update({
    'tmp_corrs' : tmp_corrs,
    'tmp_corrsf' : tmp_corrsf,
    'tmp_corrs_st' : tmp_corrs_st,
    'tmp_corrsf_st' : tmp_corrsf_st,
    'tmp_corrs_g' : tmp_corrs_g,
    'tmp_corrs_st_g' : tmp_corrs_st_g,
})

# subspace / dynamics errors of the pars_est fits
save_dict.update({
    'subsp_errors' : subsp_errors,
    'subsp_errorsf' : subsp_errorsf,
    'subsp_errors1' : subsp_errors1,
    'subsp_errors2' : subsp_errors2,
    'dyn_errors_abs' : dyn_errors_abs,
    'dyn_errors_agl' : dyn_errors_agl,
})

# the same quantities for the '_g' fits
save_dict.update({
    'subsp_errors_g' : subsp_errors_g,
    'subsp_errorsf_g' : subsp_errorsf_g,
    'subsp_errors1_g' : subsp_errors1_g,
    'subsp_errors2_g' : subsp_errors2_g,
    'dyn_errors_abs_g' : dyn_errors_abs_g,
    'dyn_errors_agl_g' : dyn_errors_agl_g,
})

np.save(data_path + 'fig3_B_data', save_dict)

    

# selection of re-runs

In [None]:
# Read, for every overlap value, the number of entries selected by the lag-0
# mask stored in the '_e3_<overlap>' archive of seed 30.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(30)+'/'
nrmlzrs = np.zeros(len(overlaps))
for i, overlap in enumerate(overlaps):
    file_name = ''.join(['p', str(p), 'n', str(n), 'T', str(T_full), '_e3_', str(overlap)])
    npz_path = data_path + file_name + '.npz'
    load_file = np.load(npz_path)['arr_0'].tolist()
    Om = load_file['Om']
    nrmlzrs[i] = Om[0].sum()


In [None]:


# Points far left of the visible x-range; they only supply the two marker
# colours that the legend entries refer to.
for i in range(3):
    for dummy_color in ('k', 'r'):
        plt.loglog(0.00001, 1, 'o', color=dummy_color, linewidth=2.5)

plt.legend(('good fits', 'fits rerun'), loc=1)

# normalised final loss of every fit, grouped by overlap value
for i, overlap in enumerate(overlaps):
    sub_pops = (np.arange((p+overlap)//2),np.arange((p-overlap)//2,p))
    tmp = np.zeros((p,p), dtype=bool)
    tmp[np.ix_(sub_pops[0], sub_pops[0])] = True

    nrmlzr = int(nrmlzrs[i]) #= (len(sub_pops[0])**2 + len(sub_pops[1])**2 + len(sub_pops[0])*len(sub_pops[1]))
    print(overlap, nrmlzr)

    rerun = redid[i,:]
    kept = np.invert(rerun)
    x_kept = np.ones(np.sum(kept))*overlaps[i]+0.1
    x_rerun = np.ones(np.sum(rerun))*overlaps[i]+0.1
    plt.loglog(x_kept, ls[i,kept]/nrmlzr , 'o', color='k', linewidth=2.5)
    plt.loglog(x_rerun, ls[i,rerun]/nrmlzr , 'o', color='r', linewidth=2.5)

plt.axis([0.09, 1100, 1e-6, 1e-2])

# tick labels are the tick positions divided by 10, matching the '[%]' x-label
plt.xticks([0.1, 1, 10, 100, 1000], ['0', '0.1', '1', '10', '100'])
plt.xlabel('overlap [%]')
plt.ylabel('final loss (norm. RMSE)')
plt.title('selection criterion for re-running ssidid algorithm for stitching')

#plt.savefig(fig_path + 'sso_selection_of_fits_to_redo.pdf')

plt.show()

# sso - naive FA

In [None]:
# settings of the naive factor-analysis runs
run = '_e3'
p = 1000
n = 10
T = 100010

def comp_slim(Qs,Om,lag_range,pars,idx_a,idx_b,traces=None,mmap=False,data_path=None):
    """Correlate model-predicted with empirical covariances, one value per lag.

    For the m-th entry of `lag_range` the prediction is
    ``C[idx_a] . X_m . C[idx_b].T``, where ``X_m`` is the m-th block of ``n``
    rows of ``pars['X']``. For lag 0, ``pars['R']`` is added on the entries
    whose row and column refer to the same variable. The Pearson correlation
    between prediction and empirical matrix is taken over the entries
    selected by the boolean mask ``Om[m]``.

    Parameters
    ----------
    Qs : sequence of 2-D arrays, or None
        Empirical matrices, indexed by position in `lag_range`.
        Only read when `mmap` is False.
    Om : sequence of boolean arrays
        One mask per lag selecting the entries that enter the correlation.
    lag_range : sequence of int
        Time lags; its length fixes the length of the result.
    pars : dict
        Needs 'C' (2-D, n columns), 'X' (row blocks of height n, one per
        lag) and 'R' (1-D, indexed by variable).
    idx_a, idx_b : 1-D integer arrays
        Row / column variable indices of the matrices.
    traces : unused
        Kept for call compatibility.
    mmap : bool
        If True, read each empirical matrix from the raw float64 file
        ``data_path + 'Qs_' + str(lag)`` via ``np.memmap``.
    data_path : str or None
        Prefix of the memory-mapped files; required when `mmap` is True.

    Returns
    -------
    numpy.ndarray, shape (len(lag_range),)
        Correlation coefficient per lag.
    """
    kl = len(lag_range)
    out = np.zeros(kl)

    n = pars['C'].shape[1]
    pa, pb = idx_a.size, idx_b.size

    # positions (within idx_a / idx_b) of the variables present in both
    idx_ab = np.intersect1d(idx_a, idx_b)
    idx_a_ab = np.where(np.in1d(idx_a, idx_ab))[0]
    idx_b_ab = np.where(np.in1d(idx_b, idx_ab))[0]

    # the row selections of C do not depend on the lag
    C_a = pars['C'][idx_a,:]
    C_b = pars['C'][idx_b,:]

    for m in range(kl):
        m_ = lag_range[m]
        Qrec = C_a.dot(pars['X'][m*n:(m+1)*n, :]).dot(C_b.T)
        if m_ == 0:
            Qrec[np.ix_(idx_a_ab, idx_b_ab)] += np.diag(pars['R'][idx_ab])
        if mmap:
            # np.float64 instead of the removed alias np.float
            Q = np.memmap(data_path+'Qs_'+str(m_), dtype=np.float64, mode='r', shape=(pa,pb))
        else:
            Q = Qs[m]
        # use Q here (not Qs[m]) so that the memory-mapped matrix is the one compared
        out[m] = np.corrcoef( Qrec[Om[m]].reshape(-1), Q[Om[m]].reshape(-1) )[0,1]
        del Q  # drop the reference so a memory map can be released right away
    return out

# Evaluate the naive factor-analysis fits: subspace error and per-lag
# covariance correlations for every (overlap, seed) pair.
lag_range = np.arange(20)

data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
file_name = 'p1000n10T100010_e3_FA_all_addedDyns'
# The archive holds a pickled dict, which NumPy >= 1.16.3 only unpickles with
# allow_pickle=True. Unpickling executes code - only load files you produced yourself.
load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()

# Read the experiment grid from the archive BEFORE sizing the result arrays,
# so that their shapes match the loaded fits rather than whatever
# `overlaps` / `rnd_seeds` were left over from an earlier cell.
overlaps = load_file['overlaps']
rnd_seeds = load_file['rnd_seeds']
idx_a, idx_b = np.arange(p), np.arange(p)

subsp_errors  = np.zeros((len(overlaps), len(rnd_seeds)))
tmp_corrs  = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))
tmp_corrs_st  = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))

for rndsidx in range(len(rnd_seeds)):

    pars_true = load_file['pars_true_all'][rndsidx]

    # empirical matrices of this seed, one file per lag position
    rnd_seed = rnd_seeds[rndsidx]
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'
    Qs = [np.load(data_path + 'Qs_' + str(m) + '.npy') for m in range(len(lag_range))]

    for i in range(len(overlaps)):

        overlap = overlaps[i]

        # two index ranges sharing `overlap` variables in the middle
        sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]
        # Om marks the pairs of variables that share a sub-population;
        # the same mask is used for every lag
        Om = np.zeros((len(idx_a),len(idx_b)), dtype=bool)
        for i_ in range(len(sub_pops)):
            Om[np.ix_(sub_pops[i_], sub_pops[i_])] = True
        Om = [Om for m in range(len(lag_range))]

        print(Om[0].sum())

        pars_est = load_file['pars_est_all'][rndsidx][i]
        # stack A^m . Pi for every lag m; comp_slim reads these blocks by position
        pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])

        subsp_errors[i,rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est['C'])

        tmp_corrs[i, rndsidx,:] = comp_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False,None)

        # '_st': the same correlation on the entries NOT selected by Om
        tmp_corrs_st[i, rndsidx,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est,idx_a,idx_b,None,False,None)
                                                           
                                                           
# Store the naive-FA evaluation as one pickled dict.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'

# run settings first ...
save_dict = {
    'run' : run,
    'p' : p,
    'T' : T,
    'n' : n,
    'rnd_seeds' : rnd_seeds,
    'overlaps' : overlaps,
}
# ... then the results
save_dict.update({
    'subsp_errors' : subsp_errors,
    'tmp_corrs' : tmp_corrs,
    'tmp_corrs_st' : tmp_corrs_st,
})

np.save(data_path + 'fig3_B_FA_data', save_dict)


# sso - EM 

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm

# Evaluate the EM fits. The results are stored in two archives, one per batch
# of ten seeds; both batches are written into the same result arrays, the
# second batch offset by the size of the first.
run = '_e3'
p, n = 1000, 10
T = 100000 + 10

lag_range = np.arange(20)
idx_a, idx_b = np.arange(p), np.arange(p)

save_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/e3_EM/'

seed_batches = (range(30,40), range(40,50))
n_seeds_total = sum(len(seeds) for seeds in seed_batches)

# Load both archives up front so that the result arrays can be sized from the
# stored overlap grid instead of a stale `overlaps` from an earlier cell.
# The archives hold pickled dicts, which NumPy >= 1.16.3 only unpickles with
# allow_pickle=True. Unpickling executes code - only load files you produced yourself.
em_batches = []
for seeds in seed_batches:
    file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run) + '_EM_' + str(min(seeds))+ '_' +str(max(seeds)) + '_best'
    em_batches.append((seeds, np.load(save_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()))
overlaps = em_batches[0][1]['overlaps']

# one colour per overlap value
clrs = np.flipud(cm.hot(np.linspace(0, 0.9, len(overlaps))))

tmp_corrs  = np.zeros((len(overlaps), n_seeds_total, len(lag_range)))
tmp_corrsf = np.zeros((len(overlaps), n_seeds_total, len(lag_range)))
tmp_corrs_st  = np.zeros((len(overlaps), n_seeds_total, len(lag_range)))
tmp_corrsf_st = np.zeros((len(overlaps), n_seeds_total, len(lag_range)))

subsp_errors  = np.zeros((len(overlaps), n_seeds_total))
# allocated here so the cell does not depend on an array left over from an
# earlier cell; note that it is not part of the dict saved below
subsp_errorsf = np.zeros((len(overlaps), n_seeds_total))

plt.figure(figsize=(20,10))

seeds_done = 0
for rnd_seeds, load_file in em_batches:

    rndsidx_offset = seeds_done   # first result column of this batch
    overlaps = load_file['overlaps']

    for rndsidx in range(len(rnd_seeds)):

        rnd_seed = rnd_seeds[rndsidx]
        col = rndsidx + rndsidx_offset
        data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'
        # empirical matrices of this seed, one file per lag position
        Qs = [np.load(data_path + 'Qs_' + str(m) + '.npy') for m in range(len(lag_range))]
        pars_true = load_file['pars_true_all'][rndsidx]

        for i in range(len(overlaps)):

            overlap = overlaps[i]
            # two index ranges sharing `overlap` variables in the middle
            sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]
            # Om marks the pairs of variables that share a sub-population
            Om = np.zeros((len(idx_a),len(idx_b)), dtype=bool)
            for i_ in range(len(sub_pops)):
                Om[np.ix_(sub_pops[i_], sub_pops[i_])] = True
            Om = [Om for m in range(len(lag_range))]

            print(Om[0].sum())

            # last entry of the first stored trace; used as y-value in the scatter plot
            l = load_file['traces_all'][rndsidx][i][0][-1]
            pars_est = load_file['pars_est_all' ][rndsidx][i]
            subsp_errors[i,col] = calc_subspace_proj_error(pars_true['C'],
                                               pars_est['C'], ortho=False)

            # stationary covariance from the estimated dynamics, then the
            # stack A^m . Pi that comp_slim reads block-wise
            pars_est['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_est['A'], pars_est['Q'])
            pars_est['X'] = np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range])
            tmp_corrs[i, col,:] = comp_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False,None)
            # '_st': the same correlation on the entries NOT selected by Om
            tmp_corrs_st[i, col,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est,idx_a,idx_b,None,False,None)
            plt.plot(subsp_errors[i,col], l, 's', color=clrs[i])

            # try negating the rows of the first sub-population; keep the
            # flip (in place) if it lowers the subspace error
            C = pars_est['C'].copy()
            C[sub_pops[0],:] *= -1
            if calc_subspace_proj_error(pars_true['C'], C) < calc_subspace_proj_error(pars_true['C'], pars_est['C']):
                print('flipping bit')
                pars_est['C'][sub_pops[0],:] *= -1
            # the '...f' quantities are evaluated after the optional flip
            subsp_errorsf[i, col] = calc_subspace_proj_error(pars_true['C'], pars_est['C'])
            tmp_corrsf[i, col,:] = comp_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,None,False,None)
            tmp_corrsf_st[i, col,:] = comp_slim(Qs,[np.invert(Om[m]) for m in range(len(lag_range))],
                                               lag_range,pars_est,idx_a,idx_b,None,False,None)

            print( tmp_corrsf_st[i, col,:])

    seeds_done += len(rnd_seeds)

plt.show()

# Store the EM evaluation as one pickled dict.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'

# run settings ('T_full' is filled with T here)
save_dict = {
    'run' : run,
    'p' : p,
    'T' : T,
    'T_full' : T,
    'n' : n,
    'overlaps' : overlaps,
}
# results
save_dict.update({
    'subsp_errors' : subsp_errors,
    'tmp_corrs' : tmp_corrs,
    'tmp_corrs_st' : tmp_corrs_st,
    'tmp_corrsf' : tmp_corrsf,
    'tmp_corrsf_st' : tmp_corrsf_st,
})

np.save(data_path + 'fig3_B_EM_final_data', save_dict)

# add factor analysis (properly stitched)

In [None]:
from scipy import linalg as la 
# Load the per-sub-population FA fits and write the same content back under
# the '_FA_<first seed>_<last seed>_addedDyns' name.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/e3_EM/'
source_file = data_path + 'p1000n10T100010_e3_FAsp_30_49.npz'
load_file = np.load(source_file)['arr_0'].tolist()
overlaps, rnd_seeds = load_file['overlaps'], load_file['rnd_seeds']

save_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
# p, n, T and run are taken from the notebook state at this point
file_name = ''.join(['p', str(p), 'n', str(n), 'T', str(T), str(run),
                     '_FA_', str(min(rnd_seeds)), '_', str(max(rnd_seeds)), '_addedDyns'])
np.savez(save_path + file_name, load_file)

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    Both inputs are orthonormalised internally with ``scipy.linalg.orth``,
    so their columns do not have to be orthonormal. 1-D inputs are treated
    as single column vectors; array-likes such as lists are accepted.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2 (0 = aligned, 1 = orthogonal), in
        the order of the singular values of ``A.T.dot(B)``, i.e. from the
        smallest angle to the largest.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # only the singular values (cosines of the angles) are needed
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # clip round-off above 1 so that arccos stays real
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# Stitch the two per-sub-population FA loading matrices into one loading
# matrix for all variables and measure its subspace error, with and without
# a sign flip of the first sub-population.
run = '_e3'

subsp_errors_FAst    = np.zeros((len(overlaps), len(rnd_seeds)))
subsp_errors_FAst_f  = np.zeros((len(overlaps), len(rnd_seeds)))

# re-initialised here but not filled in this cell
tmp_corrs  = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))
tmp_corrsf = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))
tmp_corrs_st  = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))
tmp_corrsf_st = np.zeros((len(overlaps), len(rnd_seeds), len(lag_range)))

for i in range(len(overlaps)):
    overlap = overlaps[i]
    # (a stray `rnd_seed = rnd_seeds[rndsidx]` stood here: it read `rndsidx`
    #  before the seed loop below defines it, and its value was never used)
    sub_pops = [np.arange((p+overlap)//2),np.arange((p-overlap)//2,p)]
    idx_overlap = np.intersect1d(sub_pops[0], sub_pops[1])
    # rows of the shared variables inside each sub-population's own loading
    # matrix: sub_pops[0] starts at index 0, so the shared indices are used
    # directly; in sub_pops[1] the shared variables are the leading rows
    idx1 = np.intersect1d(sub_pops[0], idx_overlap)
    idx2 = np.arange(np.intersect1d(sub_pops[1], idx_overlap).size)
    print(len(idx1))
    print(len(idx2))
    for rndsidx in range(len(rnd_seeds)):

        pars_est_sp1 = load_file['pars_est_sp1_all'][rndsidx][i]
        pars_est_sp2 = load_file['pars_est_sp2_all'][rndsidx][i]
        pars_true = load_file['pars_true_all'][rndsidx]
        p,n = pars_true['C'].shape

        # orthogonal map that best aligns the first fit with the second on
        # the shared rows; without overlap there is nothing to align
        if overlap > 0:
            W, _ = la.orthogonal_procrustes(pars_est_sp1['C'][idx1,:],
                                     pars_est_sp2['C'][idx2,:])
        else:
            W = np.eye(n)
        # the second assignment overwrites the shared rows with the aligned first fit
        C12 = np.zeros((p,n))
        C12[sub_pops[1],:] = pars_est_sp2['C']
        C12[sub_pops[0],:] = pars_est_sp1['C'].dot(W)

        err = calc_subspace_proj_error(pars_true['C'], C12)
        subsp_errors_FAst[i,rndsidx]    = err
        subsp_errors_FAst_f[i, rndsidx] = err

        # try negating the rows of the first sub-population; keep the flip
        # if it lowers the error
        C = C12.copy()
        C[sub_pops[0],:] *= -1
        err_flipped = calc_subspace_proj_error(pars_true['C'], C)
        if err_flipped < err:
            C12[sub_pops[0],:] *= -1
            subsp_errors_FAst_f[i, rndsidx] = err_flipped



# one line per seed: black = as stitched, green = after the optional sign flip
plt.semilogx(np.array(overlaps)+0.01, subsp_errors_FAst, 'k')
plt.semilogx(np.array(overlaps)+0.01, subsp_errors_FAst_f, 'g')
plt.show()

data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
save_dict = {
    'run' : run, 
    'p' : p,
    'n' : n, 
    'overlaps' : overlaps ,
    'rnd_seeds' : rnd_seeds,
    'subsp_errors_FAst' : subsp_errors_FAst,   
}
#np.save(data_path + 'fig3_B_FAst_final_data', save_dict)       


# figure 3:  actual figure

In [None]:

% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time
import itertools

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm
# ---- names and colours used throughout the figure ----
algo_name = 'S3ID'
algo_name_EM = 'sEM'

color_us = (244/255, 152/255, 25/255)
color_us_bitflipped = 'r'
color_GROUSE = 'b'
color_FA = 0.4 * np.ones(3)
color_FAst = 'g'
colors_EM = 'm'

# error bands: XTimes scales the standard deviation; with use_SEM it is
# additionally divided by sqrt(number of runs)
use_SEM = True
XTimes = 1.
fig_path =  '/home/mackelab/Desktop/Projects/Stitching/figures/'

compare_grouse_dyns = False

mmap = True
verbose = True

p = 1000
T = 100030
n = 10
snr = (1., 1.)

plt.figure(figsize=(16,5))

# ---- left panel: sso results ----
plt.subplot(1,2,1)

# Legend proxies: one point per method, plotted in the order of `lgnd`.
# (Proxies for stitched FA and for per-iteration EM curves are disabled.)
lgnd = ['FA (naive)', 'GROUSE', algo_name_EM, algo_name]
for proxy_color in (color_FA, color_GROUSE, colors_EM, color_us):
    plt.plot(1e-20, -10, 'o-', color=proxy_color, linewidth=2.5)

# marker-only proxy for the bit-flipped fits of our method
plt.plot(1e-20, -10, 'o', color=color_us_bitflipped, linewidth=2.5)
lgnd.append(algo_name + '\n (bit flipped)')


# add FA results
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
load_file = np.load(data_path + 'fig3_B_FA_data.npy').tolist()
 
overlaps = load_file['overlaps'][::-1]              # need to sort these ascending 
subsp_errors    =  load_file['subsp_errors'][::-1]  # for fill_between below 

numRuns = subsp_errors.shape[1]

idx_pos = np.arange(1, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]
m = np.mean(subsp_errors[idx_pos,:], axis=1)
plt.semilogx(overlaps_pos, m, 'o-', color=color_FA, linewidth=2.5)    
s = XTimes * np.std( subsp_errors[idx_pos,:], axis=1)# / np.sqrt(subsp_errors[idx_pos,:].shape[1])
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s, 
                 facecolor=color_FA, alpha=0.5)
# add zero overlap
m = np.mean(subsp_errors[0,:])
s = XTimes * np.std( subsp_errors[0,:]) #/ np.sqrt(subsp_errors[0,:].size)
s /= np.sqrt(numRuns) if use_SEM else s

plt.semilogx(5, m, 'o-', color=color_FA, linewidth=2.5)
plt.fill_between([4.9, 5.1], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=color_FA, alpha=0.5)


# add GROUSE results
load_file = np.load(data_path + 'fig3_B_data.npy').tolist()

overlaps = load_file['overlaps']
subsp_errors_g   =  load_file['subsp_errors_g']
subsp_errorsf_g  =  load_file['subsp_errorsf_g']
subsp_errors1_g  =  load_file['subsp_errors1_g']
subsp_errors2_g  =  load_file['subsp_errors2_g']
dyn_errors_abs_g =  load_file['dyn_errors_abs_g']
dyn_errors_agl_g =  load_file['dyn_errors_agl_g']

numRuns = subsp_errors_g.shape[1]

idx_pos = np.arange(1, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]
plt.semilogx(overlaps_pos, np.mean(subsp_errors_g[idx_pos,:], axis=1), 'o-', color=color_GROUSE, linewidth=2.5)

# add std / sem shaded areas

m = np.mean(subsp_errors_g[idx_pos,:], axis=1)
s = XTimes * np.std( subsp_errors_g[idx_pos,:], axis=1) #/ np.sqrt(subsp_errors_g[idx_pos,:].shape[1])
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s, 
                 facecolor=color_GROUSE, alpha=0.5)    

# add raw results for zero overlap:
m = np.mean(subsp_errors_g[0,:])
s = XTimes * np.std( subsp_errors_g[0,:]) #/ np.sqrt(subsp_errors_g[0,:].size)
s /= np.sqrt(numRuns) if use_SEM else s
plt.semilogx(5, m, 'o-', color=color_GROUSE, linewidth=2.5)
plt.fill_between([4.9, 5.1], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=color_GROUSE, alpha=0.5)
    

# add EM results

load_file = np.load(data_path + 'fig3_B_EM_final_data.npy').tolist()
overlaps        = load_file['overlaps'][::-1]      # need to sort these ascending 
subsp_errors    =  load_file['subsp_errors'][::-1] # for fill_between below 

numRuns = subsp_errors.shape[1]

idx_pos = np.arange(1, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]

m = np.mean(subsp_errors[idx_pos,:], axis=1)
plt.semilogx(overlaps_pos, m, 'o-', color=colors_EM, linewidth=2.5)    
s = XTimes * np.std( subsp_errors[idx_pos,:], axis=1) #/ np.sqrt(subsp_errors[idx_pos,:].shape[1])
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s, 
                 facecolor=colors_EM, alpha=0.5)

m = np.mean(subsp_errors[0,:])
plt.semilogx(5, m, 'o-', color=colors_EM, linewidth=2.5)    
s = XTimes * np.std( subsp_errors[0,:]) #/ np.sqrt(subsp_errors[0,:].size)
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between([4.9, 5.1], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=colors_EM, alpha=0.5)

# add SSID results
load_file = np.load(data_path + 'fig3_B_data.npy').tolist()
overlaps = load_file['overlaps']
subsp_errors    =  load_file['subsp_errors']
subsp_errorsf   =  load_file['subsp_errorsf']
subsp_errors1   =  load_file['subsp_errors1']
subsp_errors2   =  load_file['subsp_errors2']
dyn_errors_abs   =  load_file['dyn_errors_abs']
dyn_errors_agl   =  load_file['dyn_errors_agl']
        
numRuns = subsp_errors.shape[1]
    
m = np.mean(subsp_errorsf[0,:])
plt.semilogx(5, m, 'o', color=color_us_bitflipped, linewidth=2.5)

idx_pos = np.arange(1, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]
plt.semilogx(overlaps_pos, np.mean(subsp_errors[  idx_pos,:], axis=1), 'o-', color=color_us, linewidth=2.5)    

# add std / sem shaded areas
m = np.mean(subsp_errors[idx_pos,:], axis=1)
s = XTimes * np.std( subsp_errors[idx_pos,:], axis=1) #/ np.sqrt(subsp_errors[idx_pos,:].shape[1])
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s, 
                 facecolor=color_us, alpha=0.5)

# add bitflipped results for zero overlap:
m = np.mean(subsp_errorsf[0,:])
s = XTimes * np.std( subsp_errorsf[0,:]) 
s /= np.sqrt(numRuns) if use_SEM else s
plt.semilogx(5, m, 'o-', color=color_us_bitflipped, linewidth=2.5)
plt.fill_between([4.9, 5.1], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=color_us_bitflipped, alpha=0.5)

# add raw results for zero overlap:
m = np.mean(subsp_errors[0,:])
s = XTimes * np.std( subsp_errors[0,:]) #/ np.sqrt(subsp_errors[0,:].size)
s /= np.sqrt(numRuns) if use_SEM else s
plt.semilogx(5, m, 'o-', color=color_us, linewidth=2.5)
plt.fill_between([4.9, 5.1], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=color_us, alpha=0.5)
    
"""    
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
load_file = np.load(data_path + 'fig3_B_FAst_final_data.npy').tolist()          
overlaps  = np.array(load_file['overlaps'])[::-1]
rnd_seeds = np.array(load_file['rnd_seeds'])[::-1]
subsp_errors_FAst = np.flipud(load_file['subsp_errors_FAst'])

numRuns = subsp_errors_FAst.shape[1]

idx_pos = np.arange(1, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]
m = np.mean(subsp_errors_FAst[idx_pos,:], axis=1)
plt.semilogx(overlaps_pos, m, 'o-', color=color_FAst, linewidth=2.5)    
s = XTimes * np.std( subsp_errors_FAst[idx_pos,:], axis=1)# / np.sqrt(subsp_errors[idx_pos,:].shape[1])
s /= np.sqrt(numRuns) if use_SEM else s
plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s, 
                 facecolor=color_FAst, alpha=0.5)
# add zero overlap
m = np.mean(subsp_errors_FAst[0,:])
s = XTimes * np.std( subsp_errors_FAst[0,:]) #/ np.sqrt(subsp_errors[0,:].size)
s /= np.sqrt(numRuns) if use_SEM else s
plt.semilogx(1, m, 'o-', color=color_FAst, linewidth=2.5)
plt.fill_between([0.98, 1.02], (m-s) * np.ones(2),(m+s) * np.ones(2), where=None, 
                 facecolor=color_FAst, alpha=0.5)
"""


plt.xticks([1, 10, 100, 1000], ['0', '1', '10', '100'])
plt.legend(lgnd, loc=1,frameon=False)
plt.axis([4.5, 1100, 0.0, 0.9])
plt.yticks([0., 0.2, 0.4, 0.6, 0.8])
plt.ylabel('subsp. proj. error')
plt.xlabel('overlap o')

plt.fill_between(5 * np.array([0.925, 1/0.925]), 
                 [-.05, -.05], 
                 [0.85, 0.85],
                 edgecolor='k', 
                 where=None, facecolor='k', alpha=0.15)
plt.box('off')




def _lag_mean_and_band(corrs, n_runs):
    """Return mean and band half-width of ``corrs`` over its axis 1.

    ``corrs`` is indexed with three axes; statistics are taken over axis 1,
    so both results are indexed as (first axis, last axis).  The half-width
    is ``XTimes`` times the standard deviation, additionally divided by
    ``sqrt(n_runs)`` if ``use_SEM`` is set; without ``use_SEM`` the scaled
    standard deviation is returned unchanged.
    """
    mean = np.mean(corrs, axis=1)
    band = XTimes * np.std(corrs, axis=1)
    if use_SEM:
        band = band / np.sqrt(n_runs)
    return mean, band


def _plot_lag_curve(mean, band, color, alpha, linewidth=2):
    """Draw one curve over ``lag_range`` with its shaded band."""
    plt.plot(lag_range, mean, 'o-', color=color, linewidth=linewidth, alpha=alpha)
    plt.fill_between(lag_range, mean - band, mean + band,
                     facecolor=color, alpha=0.5)


lag_range = np.arange(20)
idx_overlaps = [0, 1, 7]    # rows (overlap settings) that get their own panel

data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/'
alphas = np.ones(len(overlaps))#np.linspace(0.5, 1.0, len(overlaps))

# The result files and the statistics derived from them do not depend on the
# panel, so they are read and computed once, before the loop.
# NOTE: allow_pickle=True is required by NumPy >= 1.16.3 to read the pickled
# dictionaries; unpickling can execute arbitrary code - only load trusted files.

# EM results
load_file = np.load(data_path + 'fig3_B_EM_final_data.npy', allow_pickle=True).item()
overlaps_EM = load_file['overlaps']
tmp_corrs_st_EM = load_file['tmp_corrs_st']
idx_pos = np.arange(0, len(overlaps_EM), dtype=int)
m_EM, s_EM = _lag_mean_and_band(tmp_corrs_st_EM[idx_pos,:,:], tmp_corrs_st_EM.shape[1])
# the EM rows are flipped so that index i addresses them like the other methods
m_EM, s_EM = np.flipud(m_EM), np.flipud(s_EM)

# GROUSE and SSID results (same file)
load_file = np.load(data_path + 'fig3_B_data.npy', allow_pickle=True).item()
overlaps = load_file['overlaps']
idx_pos = np.arange(0, len(overlaps), dtype=int)
overlaps_pos = np.array(overlaps)[idx_pos]

# NOTE: despite the 'f' in the variable names, these hold the entries
# 'tmp_corrs_g' / 'tmp_corrs_st_g' of the file.
tmp_corrsf_g = load_file['tmp_corrs_g']
tmp_corrsf_st_g = load_file['tmp_corrs_st_g']
m_g, s_g = _lag_mean_and_band(tmp_corrsf_st_g[idx_pos,:,:], tmp_corrsf_g.shape[1])

tmp_corrsf = load_file['tmp_corrsf']
tmp_corrsf_st = load_file['tmp_corrsf_st']
m_flip, s_flip = _lag_mean_and_band(tmp_corrsf_st[idx_pos,:,:], tmp_corrsf_st.shape[1])

tmp_corrs = load_file['tmp_corrs']
tmp_corrs_st = load_file['tmp_corrs_st']
numRuns = tmp_corrs_st.shape[1]
m_us, s_us = _lag_mean_and_band(tmp_corrs_st[idx_pos,:,:], numRuns)

# Naive-FA curves are currently not drawn.  To add them, load
# 'fig3_B_FA_data.npy', take its 'tmp_corrs_st' entry, flip the statistics
# like for EM above and draw them with color_FA.

for i_, i in enumerate(idx_overlaps):
    # panels are stacked bottom-up in the third column of a 3 x 4 grid
    plt.subplot(3,4,11-i_*4)

    _plot_lag_curve(m_g[i,:], s_g[i,:], color_GROUSE, alphas[i])    # GROUSE
    _plot_lag_curve(m_EM[i,:], s_EM[i,:], colors_EM, alphas[i])     # EM
    _plot_lag_curve(m_us[i,:], s_us[i,:], color_us, alphas[i])      # SSID

    if i_ == 2:
        plt.axis([0, np.max(lag_range)+.5, 0.89, 1.005])
        plt.xticks([])
        plt.yticks([0.9, 0.95, 1.0])
        plt.title('30 % overlap')
    elif i_ == 1:
        plt.axis([0, np.max(lag_range)+.5, 0.6, 1.02])
        plt.yticks([0.6, 0.8, 1.0])
        plt.xticks([])
        plt.title('1 % overlap')
    elif i_ == 0:
        # only this panel additionally shows the bit-flipped SSID results
        _plot_lag_curve(m_flip[i,:], s_flip[i,:], color_us_bitflipped, alphas[i],
                        linewidth=2.5)

        plt.axis([0, np.max(lag_range)+.5, -0.05, 1.02])
        plt.yticks([0.0, 0.5, 1.0])
        plt.xlabel('time-lag s')
        plt.title('0 % overlap')
    plt.ylabel('corr. of cov.')
    # Booleans are required below: the strings 'off' / 'on' are both truthy
    # and are not interpreted as switches by current Matplotlib versions.
    plt.box(False)
    plt.tick_params(axis="both", which="both", top=False, right=False, labelleft=True, tickdir='out')
    
    
# Right-most panel: subspace projection error over EM iterations.
plt.subplot(1,4,4)
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/e3_EM/'
# NOTE: allow_pickle=True is required by NumPy >= 1.16.3 to read the pickled
# dictionary; unpickling can execute arbitrary code - only load trusted files.
load_file = np.load(data_path + 'p1000n10T100010_e3_EM_perf_30_49_best.npy',
                    allow_pickle=True).item()
perf_traces = load_file['perf_traces']
overlaps = load_file['overlaps']
rnd_seeds = load_file['rnd_seeds']
overlaps_plot = [0, 1, 2, 4, 7, 8] #np.arange(len(overlaps))

# one shade of purple per plotted overlap
clrs = cm.Purples(np.linspace(100, 255, len(overlaps_plot), dtype=int)[::-1])

# Thick mean traces.  They are drawn in a loop of their own, before any
# single-run trace, because plt.legend below labels the first lines on the
# axes in order.
for i_, i in enumerate(overlaps_plot):
    tmp = perf_traces[i,:,:].copy()
    for trace in tmp:                       # rows are views into tmp
        unused = trace == 0
        if unused.any():
            # Zero entries are replaced by the value right before the first
            # zero (presumably iterations that were never run - confirm).
            first_unused = np.flatnonzero(unused)[0]
            # If the trace already starts with a zero there is no previous
            # value; index -1 would silently wrap around to the last sample.
            if first_unused > 0:
                trace[unused] = trace[first_unused - 1]
    plt.plot(np.mean(tmp, axis=0), color=clrs[i_], linewidth=4)

# Thin single-run traces; zero entries are masked with NaN so they are not drawn.
for i_, i in enumerate(overlaps_plot):
    tmp = perf_traces[i,:,:].copy()
    tmp[tmp==0] = np.nan
    plt.plot(tmp.T, '-', color=clrs[i_], linewidth=1, alpha=0.4)

# Booleans are required below: the strings 'off' / 'on' are both truthy and
# are not interpreted as switches by current Matplotlib versions.
plt.box(False)
plt.tick_params(axis="both", which="both", top=False, right=False, labelleft=True, tickdir='out')
plt.legend([ 'o = '+str(overlaps[i]/10) + '%' for i in overlaps_plot], loc=1)
plt.xlabel('EM iterations')
plt.ylabel('subsp. proj. error')
plt.yticks([0., 0.2, 0.4, 0.6, 0.8])
plt.axis([0, 200, 0, 0.9])



plt.savefig(fig_path + 'fig2_SEMs.pdf')

plt.show()

# spiking-data figure