# save-file trimming

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm


mmap, verbose = True, True
p, T, n, snr = 1000, 100030, 10, (1., 1.)
rnd_seeds = range(30, 40)

exp_id = 'sso'

if exp_id == 'sso':
    exp_vars = (1000,)  # overlaps
    factor = 1.
    run = '_e3'
    exp_dir = 'sso'
else:
    exp_vars = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)  # fractions observed
    factor = 100  # switch to %
    run = '_e3rnd'
    # random-missing results live under icml_e3/rnd/ (previously this branch
    # still read from the hard-coded sso/ directory)
    exp_dir = 'rnd'

# For every result file, write a "_slim_" copy next to it: Qs and Om are dropped
# (stored as None) and pars_true is replaced by a pointer to the _init.npz file.
results_dir = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/' + exp_dir + '/'

for rnd_seed in rnd_seeds:

    print('rnd_seed  =', str(rnd_seed))
    data_path = results_dir + 'seed_' + str(rnd_seed) + '/'

    for exp_var in exp_vars:

        print('exp_var = ', str(exp_var))

        stem = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + str(run)
        tag = str(int(factor * exp_var))
        in_file = data_path + stem + '_' + tag + '.npz'

        # Best effort per file: a missing/corrupt file or missing key is reported
        # and skipped, instead of silently abandoning the remaining exp_vars.
        try:
            # The dict is stored as a pickled object array; NumPy >= 1.16.3 needs
            # allow_pickle=True. Only load trusted files this way.
            with np.load(in_file, allow_pickle=True) as npz:
                load_file = npz['arr_0'].tolist()

            save_dict = {'p': p, 'n': n, 'T': T, 'snr': snr,
                         'lag_range': load_file['lag_range'],
                         'obs_scheme': load_file['obs_scheme'],
                         'mmap': mmap,
                         # NOTE(review): `y` is not defined in this cell, so mmap=False fails here
                         'y': data_path if mmap else y,
                         'pars_true': 'see _init.npz file',
                         'pars_est': load_file['pars_est'],
                         'pars_est_g': load_file['pars_est_g'],
                         'idx_a': load_file['idx_a'].copy(),
                         'idx_b': load_file['idx_b'].copy(),
                         'W': load_file['W'],
                         'Qs': None, 'Om': None,
                         'traces': load_file['traces'],
                         'traces_g': load_file['traces_g'],
                         'ts': load_file['ts'],
                         'ts_g': load_file['ts_g'],
                         'rnd_seed': load_file['rnd_seed'],
                         }
            np.savez(data_path + stem + '_slim_' + tag, save_dict)
        except Exception as err:
            print('skipping', in_file, '-', repr(err))

# illustration figures - observation schemes

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default

run = '_e3rnd'


##############################
# Data missing at random     #
##############################

# fig_path is otherwise only assigned in later cells; fall back to the figures
# directory used by the "Figure 2 a" cell so this cell also runs in a fresh kernel.
if 'fig_path' not in globals():
    fig_path = '/home/mackelab/Desktop/Projects/Stitching/figures/'

# define problem size
lag_range = np.arange(20)
kl_ = np.max(lag_range)+1
p, n = 100, 10
T_full = 200 + kl_
T = T_full

nr = 0 # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.99, 0.90, 0.99

# I/O matter
mmap, chunksize = True, np.min((p,1000))
verbose=True


sso = True

# one sub-population containing all p variables (obs_time=(T,))
obs_scheme = ObservationScheme(p=p, T=T,
                                sub_pops=(np.arange(p),),
                                obs_pops=(0,),
                                obs_time=(T,))

# Replace the scheme's mask: at every time step a fresh random subset of
# ceil(frac_obs * p) variables is marked as observed.
frac_obs = 0.3
n_obs = int(np.ceil(p * frac_obs))  # np.random.choice needs an integer sample size
mask = np.zeros((T,p),dtype=bool)
for t in range(T):
    mask[t, np.random.choice(p, n_obs, replace=False)] = True
obs_scheme.mask = mask
plt.imshow(mask.T, interpolation='None', aspect='auto')
plt.xlabel('T')
plt.ylabel('i')
plt.xticks([100, 200])
plt.yticks([1, 50, 100])
plt.savefig(fig_path + 'random.pdf')
plt.show()


##############################
# Serial subset observations #
##############################

# define problem size
lag_range = np.arange(30)
kl_ = np.max(lag_range)+1
p, n = 1000, 10
T_full = 1000 + kl_
T = T_full

nr = 0 # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.99, 0.90, 0.99

# I/O matter
mmap, chunksize = True, np.min((p,1000))
verbose=True


idx_a, idx_b = np.arange(p), np.arange(p)

sso = True

overlap = 1000

# report the problem size
print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# two sub-populations whose index ranges share `overlap` variables
sub_pops = (np.arange((p+overlap)//2), np.arange((p-overlap)//2, p))

# sub-populations take turns; obs_time holds the end point of each observation
# window, splitting [0, T] evenly
reps = 1
obs_pops = np.concatenate([np.arange(len(sub_pops)) for r in range(reps)])
obs_time = np.linspace(0, T, len(obs_pops)+1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T,
                                sub_pops=sub_pops,
                                obs_pops=obs_pops,
                                obs_time=obs_time)


obs_scheme.gen_mask_from_scheme()

plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
plt.xlabel('T')
plt.ylabel('i')
plt.xticks([500, 1000], ['50.000', '100.000'])
plt.yticks([1, 500, 1000])
# set the colour limits before saving so the PDF matches the displayed figure
plt.clim([0,1])
plt.savefig(fig_path + 'full_overlap.pdf')
plt.show()


p,T = 41, 1200
# two sub-populations (planes 0-20 and 20-40) sharing plane 20
sub_pops = (np.arange(21),np.arange(20,p))

reps = 1
obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])

obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T,
                                sub_pops=sub_pops,
                                obs_pops=obs_pops,
                                obs_time=obs_time)


plt.figure(figsize=(10,6))
obs_scheme.gen_mask_from_scheme()
mask = obs_scheme.mask
plt.imshow(mask.T, aspect='auto', interpolation='None')
plt.yticks([0,20,40], ['1', '21', '41'])
plt.xticks([0,599,1199], ['1', '600', '1200'])
plt.ylabel('imaging plane z')
plt.xlabel('time t')
plt.title('observation scheme')
plt.set_cmap('gray')
# Booleans, not "on"/"off" strings: current matplotlib treats any non-empty
# string as True, which would switch the top/right ticks back on.
# `direction` is the documented keyword for tick direction.
plt.tick_params(axis="both", which="both", top=False, right=False, labelleft=True, direction='out')
plt.savefig(fig_path + 'zebrafish_observation_scheme_2sp.pdf')
plt.show()


# Figure 2 a

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time
import itertools

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm

fig_path =  '/home/mackelab/Desktop/Projects/Stitching/figures/'

compare_grouse_dyns = False

mmap, verbose = True, True

plot_subsp_vs_dyns = 'subsp'


lag_range = np.arange(10)
kl_ = np.max(lag_range) + 1
p,T_full,n,snr = 1000, 100030, 10, (9., 9.)
Ts = np.array([1000, 3000, 5000, 10000, 30000, 50000, 100000]) + kl_
#Ts = np.array([1000, 3000, 10000, 30000, 100000]) + kl_
Ts_g = Ts + 30 - kl_
rnd_seeds = range(10, 20)


perms = np.array(list(itertools.permutations(range(n//2))))

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    A and B do not need to be orthonormal: each is orthonormalised with
    ``scipy.linalg.orth`` first. 1-D inputs are treated as single column
    vectors; array-likes (e.g. nested lists) are accepted.

    Returns:
        1-D array of min(rank A, rank B) angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # only the singular values (cosines of the angles) are needed
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # rounding can push cosines slightly above 1, where arccos is NaN
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)

def find_nearest_pairs(x, y):
    """Greedily pair the rows of x with the rows of y by Euclidean distance.

    At each step the closest remaining (x-row, y-row) pair is chosen and both
    rows are removed from further consideration. This is greedy, not an
    optimal assignment.

    Parameters:
        x, y: array-likes of identical 2-D shape (n, d).

    Returns:
        (n, 2) int array; row i holds [index into x, index into y] of the
        i-th pair chosen.

    Raises:
        ValueError: if x and y are not 2-D arrays of the same shape.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim != 2 or x.shape != y.shape:
        raise ValueError('x and y must be 2-D arrays of identical shape, got %s and %s'
                         % (x.shape, y.shape))
    n = x.shape[0]

    dist = spatial.distance.cdist(x, y)  # dist[i, j] = ||x_i - y_j||
    out = np.zeros((n, 2), dtype=int)
    rx, ry = np.arange(n), np.arange(n)  # rows of x / y still unpaired
    for i in range(n):
        sub = dist[np.ix_(rx, ry)]
        # positional shape argument: the `dims=` keyword is deprecated since NumPy 1.16
        a, b = np.unravel_index(np.argmin(sub), sub.shape)
        out[i, :] = [rx[a], ry[b]]
        rx = np.delete(rx, a)
        ry = np.delete(ry, b)

    return out

# Canvas for the 4x3 grid of per-seed panels (panel 1 is filled with the
# example observation scheme after the loop; seeds start at panel 2).
if plot_subsp_vs_dyns == 'subsp': plt.figure(figsize=(16, 16))
idx_subplot = 2
clrs = cm.jet(np.linspace(50, 200, n // 2, dtype=int))  # one colour per retained eigenvalue
lgnd = []

# Result tables: one row per recording length in Ts, one column per seed.
redid = np.zeros((len(Ts), len(rnd_seeds)), dtype=bool)
ls = np.zeros((len(Ts), len(rnd_seeds)))

# ssidid estimates
subsp_errors, dyn_errors_abs, dyn_errors_agl = (np.zeros((len(Ts), len(rnd_seeds))) for _ in range(3))

# GROUSE estimates (suffix _g)
subsp_errors_g, dyn_errors_abs_g, dyn_errors_agl_g = (np.zeros((len(Ts), len(rnd_seeds))) for _ in range(3))

def _load_npz_dict(path):
    """Return the object np.savez stored under 'arr_0' in the .npz file `path`.

    The result dicts are stored as pickled object arrays, which NumPy >= 1.16.3
    only loads with allow_pickle=True. Only use this on trusted files:
    unpickling can execute arbitrary code.
    """
    with np.load(path, allow_pickle=True) as npz:
        return npz['arr_0'].tolist()


def _upper_half_eigs(A, n):
    """Return the eigenvalues of A sorted by np.angle, keeping entries n//2 onwards."""
    ev = np.linalg.eigvals(A)
    return ev[np.argsort(np.angle(ev))][n // 2:]


def _match_eigs(ev_true, ev_est, std_angl, std_tmsc, perms):
    """Reorder ev_est to best match ev_true.

    Both sets are embedded as rows (angle / std_angl, log(1 - |ev|) / std_tmsc);
    the permutation in `perms` with the smallest mean squared distance wins.

    Returns:
        (x_, y_, ev_est_matched): embedding of ev_true, embedding of the
        reordered ev_est, and the reordered ev_est.
    """
    x_ = np.vstack([np.angle(ev_true) / std_angl, np.log(1 - np.abs(ev_true)) / std_tmsc]).T
    y_ = np.vstack([np.angle(ev_est) / std_angl, np.log(1 - np.abs(ev_est)) / std_tmsc]).T
    cost = np.mean((x_[:, 0] - y_[perms, 0])**2 + (x_[:, 1] - y_[perms, 1])**2, axis=1)
    best = perms[np.argmin(cost)]
    return x_, y_[best], ev_est[best]


def _eig_errors(ev_true, ev_est):
    """Return (RMS error of log10(1 - |ev|), RMS error of np.angle(ev) * 90/pi) for matched eigenvalues."""
    err_abs = np.sqrt(np.mean((np.log10(1 - np.abs(ev_est)) - np.log10(1 - np.abs(ev_true)))**2))
    # NOTE(review): /pi*90 maps an error of pi rad to 90; degrees would be 180/pi - confirm intended.
    err_agl = np.sqrt(np.mean((np.angle(ev_est) - np.angle(ev_true))**2)) / np.pi * 90
    return err_abs, err_agl


def _plot_eig_pairs(x_, y_, clrs):
    """Scatter matched embeddings: ground truth as 'x', estimate as 'o', one colour per pair."""
    # plt.hold was removed in matplotlib 3.0; overplotting is the default.
    for c in range(len(clrs)):
        plt.plot(x_[c, 1], x_[c, 0], 'x', color=clrs[c, :], markersize=10, markeredgewidth=3)
        plt.plot(y_[c, 1], y_[c, 0], 'o', color=clrs[c, :], markersize=10, markeredgewidth=3)


for rndsidx, rnd_seed in enumerate(rnd_seeds):

    print('rnd seed: ', rnd_seed)

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'
    snr_tag = str(int(np.mean(snr) // 1))  # np.int was removed in NumPy 1.24
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + snr_tag + 'e1_init'
    load_file = _load_npz_dict(data_path + file_name + '.npz')
    pars_true = load_file['pars_true'].copy()

    ev_true = _upper_half_eigs(pars_true['A'], n)
    # per-seed scales normalising the eigenvalue embedding used for matching
    std_angl = np.std(np.angle(ev_true))
    std_tmsc = np.std(np.log(1 - np.abs(ev_true)))

    run = '_e1'
    for i, T in enumerate(Ts):

        print('T = ', str(T))

        file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + snr_tag + '_run' + str(run)
        load_file = _load_npz_dict(data_path + file_name + '.npz')
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
        obs_scheme = load_file['obs_scheme']
        pars_est = load_file['pars_est']
        traces = load_file['traces']
        W = load_file['W']
        Qs = load_file['Qs']
        Om = load_file['Om']
        pars_est_g = load_file['pars_est_g']
        traces_g = load_file['traces_g']

        #ls[   i,rndsidx] = traces[0][-1][-1]
        sub_pops = obs_scheme.sub_pops

        subsp_errors_g[i, rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est_g['C'].copy())
        subsp_errors[i, rndsidx] = calc_subspace_proj_error(pars_true['C'], pars_est['C'].copy())

        # ssidid
        ev_est = _upper_half_eigs(pars_est['A'], n)
        x_, y_, ev_est = _match_eigs(ev_true, ev_est, std_angl, std_tmsc, perms)
        dyn_errors_abs[i, rndsidx], dyn_errors_agl[i, rndsidx] = _eig_errors(ev_true, ev_est)
        # KDE variant of the dynamics score:
        #kde = sp.stats.gaussian_kde(x_.T)(y_.T)
        #dyn_errors_abs[i,rndsidx] = kde.mean()
        #dyn_errors_agl[i,rndsidx] = np.sum(np.log(kde))

        if plot_subsp_vs_dyns == 'dyns' and T < 300000:
            plt.subplot(1, 2, 1)
            _plot_eig_pairs(x_, y_, clrs)

        # GROUSE
        ev_est_g = _upper_half_eigs(pars_est_g['A'], n)
        x_, y_, ev_est_g = _match_eigs(ev_true, ev_est_g, std_angl, std_tmsc, perms)
        dyn_errors_abs_g[i, rndsidx], dyn_errors_agl_g[i, rndsidx] = _eig_errors(ev_true, ev_est_g)
        # KDE variant of the dynamics score:
        #kde = sp.stats.gaussian_kde(x_.T)(y_.T)
        #dyn_errors_abs_g[i,rndsidx] = kde.mean()
        #dyn_errors_agl_g[i,rndsidx] = np.sum(np.log(kde))

        if plot_subsp_vs_dyns == 'dyns':
            if T < 300000:
                plt.subplot(1, 2, 2)
                _plot_eig_pairs(x_, y_, clrs)
            plt.show()

        # optional flags for runs where ssidid did worse than GROUSE:
        #if dyn_errors_abs[i, rndsidx] > dyn_errors_abs_g[i, rndsidx]:
        #    redid[i,rndsidx] = True
        #if dyn_errors_agl[i, rndsidx] > dyn_errors_agl_g[i, rndsidx]:
        #    redid[i,rndsidx] = True
        #if subsp_errors[i,rndsidx] > 1.05 * subsp_errors_g[i,rndsidx]:
        #    redid[i,rndsidx] = True

    if plot_subsp_vs_dyns == 'subsp':
        plt.subplot(4, 3, idx_subplot)
        plt.loglog(np.array(Ts) + 0.1, subsp_errors_g[:, rndsidx], 'o-', color='b', linewidth=1.5)
        plt.loglog(np.array(Ts) + 0.1, subsp_errors[:, rndsidx], 'o-', color='g', linewidth=2)
        plt.box(False)  # plt.box('off') would turn the frame ON: non-empty strings are truthy
        lgnd = ['GROUSE', 'ssidid']
        if rndsidx == len(rnd_seeds) - 1:  # legend on the last seed's panel
            plt.legend(lgnd, loc=1, frameon=False)
        plt.ylabel('subsp. proj. error')
        if idx_subplot > 9:  # bottom row of the 4x3 grid
            plt.xlabel('T')
        plt.title('random seed ' + str(rnd_seed))
        plt.xticks([1000, 10000, 100000], ['1000', '10000', '100000'])

    idx_subplot += 1

if plot_subsp_vs_dyns == 'subsp':
    # Panel 1: observation scheme of the full-length (T_full) run, read from the
    # data_path of the last seed processed above.
    plt.subplot(4, 3, 1)
    # int(), not np.int (removed in NumPy 1.24)
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr) // 1)) + '_run' + str(run)
    # pickled dict inside; NumPy >= 1.16.3 needs allow_pickle=True (trusted files only)
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()
    obs_scheme = load_file['obs_scheme']
    obs_scheme.gen_mask_from_scheme()
    # show every 1000th time step only, hence the rescaled x tick labels
    plt.imshow(obs_scheme.mask[0:-1:1000, :].T, aspect='auto', interpolation='None')
    plt.xticks([0, 50, 100], [0, 50000, 100000])
    plt.ylabel('variable index i')
    plt.xlabel('time t')
    plt.title('example data observation scheme (fully observed)')
    plt.set_cmap('gray')
    plt.clim([0., 1.])
    plt.savefig(fig_path + 'varyT_10seeds.pdf')
plt.show()

# NOTE(review): nothing in this notebook writes to `ls` after it is zero-initialised,
# so this selection is currently always empty.
redid[ls > 1.1 * np.mean(ls, axis=1).reshape(ls.shape[0], 1)]
    

In [None]:
# Per-seed dynamics errors over the recording lengths in Ts
# (green: ssidid, blue: GROUSE; 'o-': dyn_errors_abs*, 'x--': dyn_errors_agl*).
plt.figure(figsize=(16, 8))
for i in range(len(rnd_seeds)):
    plt.subplot(4, 3, i + 1)
    plt.plot(dyn_errors_abs[:, i], 'go-')
    plt.plot(dyn_errors_abs_g[:, i], 'bo-')
    plt.plot(dyn_errors_agl[:, i], 'gx--')
    plt.plot(dyn_errors_agl_g[:, i], 'bx--')
    # red stars on the ssidid angle curve where `redid` is set
    plt.plot(np.where(redid[:, i])[0], dyn_errors_agl[:, i][redid[:, i]], 'r*', markersize=10)
    plt.box(False)  # plt.box('off') would turn the frame ON: non-empty strings are truthy
plt.show()


def _plot_mean_sem(errors, colour):
    """Plot the mean over seeds (axis 1) of `errors` as 'o-' and mean +- std/sqrt(n_seeds) as '--'."""
    m = np.mean(errors, axis=1)
    s = np.std(errors, axis=1) / np.sqrt(errors.shape[1])
    plt.plot(m, colour + 'o-')
    plt.plot(m - s, colour + '--')
    plt.plot(m + s, colour + '--')


# Averages over seeds: left dyn_errors_abs*, right dyn_errors_agl*.
plt.subplot(1, 2, 1)
_plot_mean_sem(dyn_errors_abs, 'g')
_plot_mean_sem(dyn_errors_abs_g, 'b')

plt.subplot(1, 2, 2)
_plot_mean_sem(dyn_errors_agl, 'g')
_plot_mean_sem(dyn_errors_agl_g, 'b')

plt.show()
    

In [None]:

plt.figure(figsize=(14, 6))

# Left panel: mean subspace projection error over seeds vs. recording length T.
plt.subplot(1, 2, 1)
plt.loglog(np.array(Ts) + 0.1, np.mean(subsp_errors_g, axis=1), 'o-', color='b', linewidth=2.5)
plt.loglog(np.array(Ts) + 0.1, np.mean(subsp_errors, axis=1), 'o-', color='g', linewidth=2.5)
plt.box(False)  # plt.box('off') would turn the frame ON: non-empty strings are truthy
lgnd = ['GROUSE', 'ssidid']
plt.legend(lgnd, loc=1, frameon=False)
plt.ylabel('subsp. proj. error')
plt.xlabel('recording length T')
plt.title('subspace identification')
plt.axis([900, 110000, 0.01, 1.])
plt.xticks([1000, 10000, 100000], ['1000', '10000', '100000'])


# Right panel: mean dynamics errors over seeds, each with a band of
# +- std/sqrt(n_seeds).
plt.subplot(1, 2, 2)

# complex-angle errors (dashed, 'x')
h_agl_g, = plt.semilogx(Ts, np.mean(dyn_errors_agl_g, axis=1), 'bx--', linewidth=2)
h_agl, = plt.semilogx(Ts, np.mean(dyn_errors_agl, axis=1), 'gx--', linewidth=2)
for errors, colour in ((dyn_errors_agl, 'green'), (dyn_errors_agl_g, 'blue')):
    m = np.mean(errors, axis=1)
    s = np.std(errors, axis=1) / np.sqrt(errors.shape[1])
    plt.fill_between(Ts, m - s, m + s, facecolor=colour, alpha=0.3)

# time-scale errors (solid, 'o')
h_abs_g, = plt.semilogx(Ts, np.mean(dyn_errors_abs_g, axis=1), 'bo-', linewidth=2)
h_abs, = plt.semilogx(Ts, np.mean(dyn_errors_abs, axis=1), 'go-', linewidth=2)
for errors, colour in ((dyn_errors_abs, 'green'), (dyn_errors_abs_g, 'blue')):
    m = np.mean(errors, axis=1)
    s = np.std(errors, axis=1) / np.sqrt(errors.shape[1])
    plt.fill_between(Ts, m - s, m + s, facecolor=colour, alpha=0.5)

lgnd = ['time-scale error (GROUSE)', 'time-scale error (ssidid)', 'angle error (GROUSE)', 'angle error (ssidid)']
plt.ylabel('root mean squared error')

plt.title('dynamics identification')
plt.axis([900, 110000, 0.01, 0.7])
plt.xticks([1000, 10000, 100000], ['1000', '10000', '100000'])
plt.box(False)
# Pair labels with explicit handles: legend(labels) alone assigns them in
# plotting order, which put the time-scale labels on the angle curves.
plt.legend([h_abs_g, h_abs, h_agl_g, h_agl], lgnd, loc=1, frameon=False)
plt.xlabel('recording length T')

plt.savefig(fig_path + 'sim_fully_obs_avg.pdf')
plt.show()




In [None]:

plt.figure(figsize=(14, 6))

# Left panel: mean subspace projection error over seeds vs. recording length T.
plt.subplot(1, 2, 1)
plt.loglog(np.array(Ts) + 0.1, np.mean(subsp_errors_g, axis=1), 'o-', color='b', linewidth=2.5)
plt.loglog(np.array(Ts) + 0.1, np.mean(subsp_errors, axis=1), 'o-', color='g', linewidth=2.5)
plt.box(False)  # plt.box('off') would turn the frame ON: non-empty strings are truthy
lgnd = ['GROUSE', 'ssidid']
plt.legend(lgnd, loc=1, frameon=False)
plt.ylabel('subsp. proj. error')
plt.xlabel('recording length T')
plt.title('subspace identification')
plt.axis([900, 110000, 0.01, 1.])
plt.xticks([1000, 10000, 100000], ['1000', '10000', '100000'])


# Right panel: dyn_errors_agl / dyn_errors_agl_g (not the absolute-value errors),
# mean over seeds with a +- std/sqrt(n_seeds) band.
plt.subplot(1, 2, 2)
plt.semilogx(Ts, np.mean(dyn_errors_agl_g, axis=1), 'bo-', linewidth=2)
plt.semilogx(Ts, np.mean(dyn_errors_agl, axis=1), 'go-', linewidth=2)
for errors, colour in ((dyn_errors_agl, 'green'), (dyn_errors_agl_g, 'blue')):
    m = np.mean(errors, axis=1)
    s = np.std(errors, axis=1) / np.sqrt(errors.shape[1])
    plt.fill_between(Ts, m - s, m + s, facecolor=colour, alpha=0.5)

# NOTE(review): the y-label and the "_kde" file name only fit when the analysis loop's
# commented-out gaussian_kde lines are enabled; with the default RMSE angle errors the
# label is wrong - confirm which variant produced the data before saving.
lgnd = ['GROUSE', 'ssidid']
plt.ylabel('Kernel density estimate')

plt.title('dynamics identification')
plt.xticks([1000, 10000, 100000], ['1000', '10000', '100000'])
plt.box(False)
plt.legend(lgnd, loc=2, frameon=False)
plt.xlabel('recording length T')

plt.savefig(fig_path + 'sim_fully_obs_avg_kde.pdf')
plt.show()




# Fig 3 a - data missing at random

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time
import itertools

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm


# Problem size; T is the length of the missing-at-random runs, T_full the
# length encoded in the icml_e1 "_init" file names.
p, T, n, snr = 1000, 10010, 10, (9., 9.)
T_full = 100030
rnd_seeds = range(10, 20)
fracs_obs = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)  # fractions observed

run = '_e3rnd'
plot_subsp_vs_dyns = 'subsp'  # 'subsp': summary panels after the loop, 'dyns': eigenvalue scatter plots

# every ordering of the n//2 retained eigenvalues; used to match estimates to ground truth
perms = np.array(list(itertools.permutations(range(n // 2))))

mmap, verbose = True, True
fig_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/figs/'

# canvas for the two summary panels drawn after the seed loop
if plot_subsp_vs_dyns == 'subsp': plt.figure(figsize=(10, 4))

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    A and B do not need to be orthonormal: each is orthonormalised with
    ``scipy.linalg.orth`` first. 1-D inputs are treated as single column
    vectors; array-likes (e.g. nested lists) are accepted.

    Returns:
        1-D array of min(rank A, rank B) angles in ascending order, divided by
        pi/2 so that 0 means aligned and 1 means orthogonal.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    # only the singular values (cosines of the angles) are needed
    cosines = la.svdvals(la.orth(A).T.dot(la.orth(B)))
    # rounding can push cosines slightly above 1, where arccos is NaN
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)

# one colour per retained eigenvalue for the 'dyns' scatter plots
clrs = cm.jet(np.linspace(50, 200, n // 2, dtype=int))
lgnd = []

# Result tables: one row per observed fraction, one column per seed.
# ssidid estimates
subsp_errors, dyn_errors_abs, dyn_errors_agl = (np.zeros((len(fracs_obs), len(rnd_seeds))) for _ in range(3))
# GROUSE estimates (suffix _g)
subsp_errors_g, dyn_errors_abs_g, dyn_errors_agl_g = (np.zeros((len(fracs_obs), len(rnd_seeds))) for _ in range(3))

def _load_npz_dict(path):
    """Return the object np.savez stored under 'arr_0' in the .npz file `path`.

    The result dicts are stored as pickled object arrays, which NumPy >= 1.16.3
    only loads with allow_pickle=True. Only use this on trusted files:
    unpickling can execute arbitrary code.
    """
    with np.load(path, allow_pickle=True) as npz:
        return npz['arr_0'].tolist()


def _upper_half_eigs(A, n):
    """Return the eigenvalues of A sorted by np.angle, keeping entries n//2 onwards."""
    ev = np.linalg.eigvals(A)
    return ev[np.argsort(np.angle(ev))][n // 2:]


def _match_eigs(ev_true, ev_est, std_angl, std_tmsc, perms):
    """Reorder ev_est to best match ev_true.

    Both sets are embedded as rows (angle / std_angl, log(1 - |ev|) / std_tmsc);
    the permutation in `perms` with the smallest mean squared distance wins.

    Returns:
        (x_, y_, ev_est_matched): embedding of ev_true, embedding of the
        reordered ev_est, and the reordered ev_est.
    """
    x_ = np.vstack([np.angle(ev_true) / std_angl, np.log(1 - np.abs(ev_true)) / std_tmsc]).T
    y_ = np.vstack([np.angle(ev_est) / std_angl, np.log(1 - np.abs(ev_est)) / std_tmsc]).T
    cost = np.mean((x_[:, 0] - y_[perms, 0])**2 + (x_[:, 1] - y_[perms, 1])**2, axis=1)
    best = perms[np.argmin(cost)]
    return x_, y_[best], ev_est[best]


def _eig_errors(ev_true, ev_est):
    """Return (RMS error of log10(1 - |ev|), RMS error of np.angle(ev) * 90/pi) for matched eigenvalues."""
    err_abs = np.sqrt(np.mean((np.log10(1 - np.abs(ev_est)) - np.log10(1 - np.abs(ev_true)))**2))
    # NOTE(review): /pi*90 maps an error of pi rad to 90; degrees would be 180/pi - confirm intended.
    err_agl = np.sqrt(np.mean((np.angle(ev_est) - np.angle(ev_true))**2)) / np.pi * 90
    return err_abs, err_agl


def _plot_eig_pairs(x_, y_, clrs):
    """Scatter matched embeddings: ground truth as 'x', estimate as 'o', one colour per pair."""
    # plt.hold was removed in matplotlib 3.0; overplotting is the default.
    for c in range(len(clrs)):
        plt.plot(x_[c, 1], x_[c, 0], 'x', color=clrs[c, :], markersize=10, markeredgewidth=3)
        plt.plot(y_[c, 1], y_[c, 0], 'o', color=clrs[c, :], markersize=10, markeredgewidth=3)


def _flip_rows_if_better(C_true, pars, rows):
    """Negate pars['C'][rows] in place if that lowers the subspace projection error; return the final error."""
    C_flipped = pars['C'].copy()
    C_flipped[rows] *= -1
    if calc_subspace_proj_error(C_true, C_flipped) < calc_subspace_proj_error(C_true, pars['C']):
        pars['C'][rows] *= -1
    return calc_subspace_proj_error(C_true, pars['C'])


for rndsidx, rnd_seed in enumerate(rnd_seeds):

    data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/rnd/seed_' + str(rnd_seed) + '/'
    init_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/seed_' + str(int(rnd_seed)) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # ground truth from the e1 init file (int(), not np.int: removed in NumPy 1.24)
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr) // 1)) + 'e1_init'
    load_file = _load_npz_dict(init_path + file_name + '.npz')
    pars_true = load_file['pars_true']

    ev_true = _upper_half_eigs(pars_true['A'], n)
    # per-seed scales normalising the eigenvalue embedding used for matching
    std_angl = np.std(np.angle(ev_true))
    std_tmsc = np.std(np.log(1 - np.abs(ev_true)))

    for i, frac_obs in enumerate(fracs_obs):

        print('frac_obs = ', str(frac_obs))

        file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + str(run) + '_' + str(int(100 * frac_obs))
        load_file = _load_npz_dict(data_path + file_name + '.npz')
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
        obs_scheme = load_file['obs_scheme']
        pars_est = load_file['pars_est']
        pars_est_g = load_file['pars_est_g']
        traces = load_file['traces']
        traces_g = load_file['traces_g']
        W = load_file['W']
        Qs = load_file['Qs']
        Om = load_file['Om']

        # resolve the sign of C on the first sub-population before scoring
        subsp_errors[i, rndsidx] = _flip_rows_if_better(pars_true['C'], pars_est, obs_scheme.sub_pops[0])
        subsp_errors_g[i, rndsidx] = _flip_rows_if_better(pars_true['C'], pars_est_g, obs_scheme.sub_pops[0])

        # ssidid
        ev_est = _upper_half_eigs(pars_est['A'], n)
        x_, y_, ev_est = _match_eigs(ev_true, ev_est, std_angl, std_tmsc, perms)
        dyn_errors_abs[i, rndsidx], dyn_errors_agl[i, rndsidx] = _eig_errors(ev_true, ev_est)
        # KDE variant of the dynamics score:
        #kde = sp.stats.gaussian_kde(x_.T)(y_.T)
        #dyn_errors_abs[i,rndsidx] = kde.mean()
        #dyn_errors_agl[i,rndsidx] = np.sum(np.log(kde))

        if plot_subsp_vs_dyns == 'dyns' and T < 300000:
            plt.subplot(1, 2, 1)
            _plot_eig_pairs(x_, y_, clrs)

        # GROUSE
        ev_est_g = _upper_half_eigs(pars_est_g['A'], n)
        x_, y_, ev_est_g = _match_eigs(ev_true, ev_est_g, std_angl, std_tmsc, perms)
        dyn_errors_abs_g[i, rndsidx], dyn_errors_agl_g[i, rndsidx] = _eig_errors(ev_true, ev_est_g)
        # KDE variant of the dynamics score:
        #kde = sp.stats.gaussian_kde(x_.T)(y_.T)
        #dyn_errors_abs_g[i,rndsidx] = kde.mean()
        #dyn_errors_agl_g[i,rndsidx] = np.sum(np.log(kde))

        if plot_subsp_vs_dyns == 'dyns' and T < 300000:
            plt.subplot(1, 2, 2)
            _plot_eig_pairs(x_, y_, clrs)

    # Only flush per seed in 'dyns' mode: in 'subsp' mode a show() here would
    # close the summary figure before its panels are drawn after the loop.
    if plot_subsp_vs_dyns == 'dyns':
        plt.show()
        
# Quick-look summary: seed-averaged (axis=1) subspace errors (left) and
# eigenvalue errors of A (right) versus fraction of observed variables.
if plot_subsp_vs_dyns == 'subsp':
    plt.subplot(1, 2, 1)
    plt.semilogy(fracs_obs, np.mean(subsp_errors_g, axis=1), 'o-', color='b', linewidth=1.5)
    plt.semilogy(fracs_obs, np.mean(subsp_errors, axis=1), 'o-', color='g', linewidth=2)
    # plt.box('off') relied on matplotlib parsing 'on'/'off' strings; newer
    # versions treat any non-empty string as True (frame ON). Pass a bool.
    plt.box(False)
    lgnd = ['GROUSE', 'ssidid']
    plt.legend(lgnd, loc=1, frameon=False)
    plt.ylabel('subsp. proj. error')
    plt.xlabel('% observed')
    plt.title('subspace estimation')

    plt.subplot(1, 2, 2)

    # solid/circles: time-scale errors; dashed/crosses: angle errors
    plt.plot(fracs_obs, np.mean(dyn_errors_abs_g, axis=1), 'bo-', linewidth=1.5)
    plt.plot(fracs_obs, np.mean(dyn_errors_abs, axis=1),   'go-', linewidth=2)
    plt.plot(fracs_obs, np.mean(dyn_errors_agl_g, axis=1), 'bx--', linewidth=1.5)
    plt.plot(fracs_obs, np.mean(dyn_errors_agl, axis=1),   'gx--', linewidth=2)
    # Four curves are drawn, so four labels are needed (previously only two
    # labels were given, leaving the angle curves unlabelled).
    lgnd = ['time-scale error (GROUSE)', 'time-scale error (ssidid)',
            'angle error (GROUSE)', 'angle error (ssidid)']
    plt.legend(lgnd, loc=1, frameon=False)
    plt.box(False)
    plt.ylabel('eigenvalues of A')
    plt.xlabel('% observed')
    plt.title('spectrum of dynamics')
    plt.show()
    
    

In [None]:

def _mean_sem(errs):
    """Return (mean, standard error of the mean) of ``errs`` along axis 1.

    ``errs`` is a 2-d array indexed as [condition, seed] (as filled e.g. by
    ``dyn_errors_abs_g[i, rndsidx]``), so the statistics are across seeds.
    """
    errs = np.asarray(errs)
    return np.mean(errs, axis=1), np.std(errs, axis=1) / np.sqrt(errs.shape[1])


# Summary figure for the randomly-observed experiment: mean +/- s.e.m. over
# random seeds of the subspace error (left) and of the eigenvalue errors
# (right), as a function of the fraction of observed variables.
plt.figure(figsize=(14,4))
frac = np.array(fracs_obs)

plt.subplot(1,2,1)
m_g, s_g = _mean_sem(subsp_errors_g)
m, s = _mean_sem(subsp_errors)
h_grouse, = plt.plot(frac, m_g, 'o-', color='b', linewidth=2.5)
h_ssidid, = plt.plot(frac, m, 'o-', color='g', linewidth=2.5)
# shaded areas: mean +/- s.e.m.
plt.fill_between(frac, m-s, m+s, where=m+s>=m-s,
                 facecolor='green', alpha=0.5)
plt.fill_between(frac, m_g-s_g, m_g+s_g, where=m_g+s_g>=m_g-s_g,
                 facecolor='blue', alpha=0.5)

# plt.box('off') relied on matplotlib parsing 'on'/'off' strings; newer
# versions treat any non-empty string as True, so pass a bool.
plt.box(False)
lgnd = ['GROUSE', 'ssidid']
# Pass handles explicitly: with labels only, newer matplotlib pairs labels
# with artists in creation order, which includes the shaded areas.
plt.legend([h_grouse, h_ssidid], lgnd, loc=1, frameon=False)
plt.ylabel('subsp. proj. error')
plt.xlabel('fraction observed [%]')
plt.title('subspace identification')
plt.axis([0.05, 1.05, 0.09, 0.6])
plt.xticks([.20, .40, .60, .80, 1.00], ['20', '40', '60', '80', '100'])
plt.yticks([0.1, 0.2, 0.3], ['0.1', '0.2', '0.3'])

plt.subplot(1,2,2)

# solid/circles: time-scale errors; dashed/crosses: complex-angle errors
handles = []
for errs_g, errs, fmt_g, fmt, alpha in (
        (dyn_errors_abs_g, dyn_errors_abs, 'bo-', 'go-', 0.5),
        (dyn_errors_agl_g, dyn_errors_agl, 'bx--', 'gx--', 0.3)):
    m_g, s_g = _mean_sem(errs_g)
    m, s = _mean_sem(errs)
    handles += plt.plot(frac, m_g, fmt_g, linewidth=2)
    handles += plt.plot(frac, m, fmt, linewidth=2)
    plt.fill_between(frac, m-s, m+s, where=m+s>=m-s,
                     facecolor='green', alpha=alpha)
    plt.fill_between(frac, m_g-s_g, m_g+s_g, where=m_g+s_g>=m_g-s_g,
                     facecolor='blue', alpha=alpha)

# labels in the same order as `handles` (abs GROUSE, abs ssidid, angle GROUSE, angle ssidid)
lgnd = ['time-scale error (GROUSE)', 'time-scale error (ssidid)', 'angle error (GROUSE)', 'angle error (ssidid)']
plt.ylabel('root mean squared error')

plt.title('dynamics identification')
plt.axis([0.05, 1.05, 0.01, 0.45])
plt.xticks([.20, .40, .60, .80, 1.00], ['20', '40', '60', '80', '100'])
plt.yticks([0, .10, .20, .30])
plt.box(False)
plt.legend(handles, lgnd, loc=1, frameon=False)
plt.xlabel('fraction observed [%]')

plt.savefig(fig_path + 'sim_rnd_obs_avg.pdf')
plt.show()




# Fig 3b - stitching two subpopulations as a function of their overlap

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
from scipy import spatial
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import calc_subspace_proj_error

from matplotlib import cm

# output directory for the saved figures
fig_path = '/home/mackelab/Desktop/Projects/Stitching/figures/'

# flags
compare_grouse_dyns = False
mmap = True
verbose = True

# simulation setup shared by all result files of this experiment
p = 1000          # number of observed variables
T_full = 100030   # trace length used in the result file names
n = 10
snr = (1., 1.)

# conditions: overlap = number of variables shared by the two subpopulations
overlaps = (0, 10, 15, 20, 25, 50, 100, 300, 1000)
rnd_seeds = range(30, 40)


def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    The inputs do NOT need to be orthonormal: both are orthonormalised with
    ``scipy.linalg.orth`` first. 1-d inputs are treated as single column
    vectors. Array-likes (e.g. lists) are accepted.

    Parameters
    ----------
    A, B : array_like, shape (d, k_A) / (d, k_B), or (d,)

    Returns
    -------
    ndarray, shape (min(rank A, rank B),)
        Principal angles divided by pi/2, so 0 means aligned and 1 means
        orthogonal; sorted ascending (cosines are the singular values,
        which come out in descending order).
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # the singular vectors are not needed.
    cosines = la.svdvals(A.T.dot(B))
    # Round-off can push a cosine slightly above 1, where arccos is undefined.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# one colour per overlap condition, sampled from the jet colormap
clrs = cm.jet(np.linspace(50, 200, len(overlaps), dtype=int))
lgnd = []

plt.figure(figsize=(16,16))
# panel 1 of the 4x3 grid is kept free for the observation-scheme example;
# per-seed panels start at 2
idx_subplot = 2

_grid = (len(overlaps), len(rnd_seeds))   # [overlap condition, seed]

subsp_errors_avg, subsp_errors_avg_g = np.zeros(len(overlaps)), np.zeros(len(overlaps))

# bookkeeping per (overlap, seed): re-run used?, sign flip applied?, final loss
redid = np.zeros(_grid, dtype=bool)
bitflip = np.zeros(_grid, dtype=bool)
ls = np.zeros(_grid)

# ssidid: all variables, after sign-flip correction, sub-population 1, sub-population 2
subsp_errors, subsp_errorsf, subsp_errors1, subsp_errors2 = (np.zeros(_grid) for _ in range(4))

dyn_errors_abs, dyn_errors_agl = (np.zeros(_grid) for _ in range(2))

# GROUSE counterparts of the arrays above
subsp_errors_g, subsp_errorsf_g, subsp_errors1_g, subsp_errors2_g = (np.zeros(_grid) for _ in range(4))
dyn_errors_abs_g, dyn_errors_agl_g = (np.zeros(_grid) for _ in range(2))

def _subsp_errors(C_true, C_est, sub_pops):
    """Return subspace projection errors of ``C_est`` w.r.t. ``C_true``.

    Three values, in order: error over all rows, over rows ``sub_pops[0]``,
    and over rows ``sub_pops[1]``.
    """
    return (calc_subspace_proj_error(C_true, C_est),
            calc_subspace_proj_error(C_true[sub_pops[0], :], C_est[sub_pops[0], :]),
            calc_subspace_proj_error(C_true[sub_pops[1], :], C_est[sub_pops[1], :]))


def _flip_if_better(C_true, C_est, rows):
    """Negate ``C_est[rows, :]`` in place if that lowers the error w.r.t. ``C_true``.

    Returns True if the rows were flipped, False otherwise.
    """
    C_flip = C_est.copy()
    C_flip[rows, :] *= -1
    if calc_subspace_proj_error(C_true, C_flip) < calc_subspace_proj_error(C_true, C_est):
        C_est[rows, :] *= -1
        return True
    return False


# For every seed and overlap: subspace errors of GROUSE and ssidid (raw, per
# sub-population, and after an optional sign flip of sub-population 0);
# ssidid fits with a large final loss are replaced by their re-run. One panel
# per seed is drawn into the 4x3 grid.
for rndsidx, rnd_seed in enumerate(rnd_seeds):

    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_'+str(rnd_seed)+'/'
    # NOTE(review): data_path_rep is never read; the re-run ('_e3final_')
    # files below are loaded from data_path -- confirm this is intended.
    data_path_rep =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/explore/seed_'+str(rnd_seed)+'/'

    # ground-truth parameters of this seed
    run = '_e3_init'
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) +  run
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    pars_true = load_file['pars_true'].copy()
    C_true = pars_true['C']

    # eigenvalue statistics of the true dynamics (not used further in this cell)
    ev_true = np.linalg.eigvals(pars_true['A'])
    std_angl = np.std(np.angle(ev_true))
    std_tmsc = np.std(np.log(1-np.abs(ev_true)))

    run = '_e3_slim'
    for i, overlap in enumerate(overlaps):

        print('overlap = ', str(overlap))

        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_'+str(overlap)
        load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
        idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
        obs_scheme = load_file['obs_scheme']
        pars_est = load_file['pars_est']
        pars_est_g = load_file['pars_est_g']
        traces = load_file['traces']
        traces_g = load_file['traces_g']
        W = load_file['W']
        Qs = load_file['Qs']
        Om = load_file['Om']

        # final loss of the ssidid fit; decides below whether the re-run is used
        ls[i, rndsidx] = traces[0][-1][-1]

        sub_pops = obs_scheme.sub_pops

        # GROUSE: raw errors, then error after the optional sign flip
        (subsp_errors_g[i, rndsidx],
         subsp_errors1_g[i, rndsidx],
         subsp_errors2_g[i, rndsidx]) = _subsp_errors(C_true, pars_est_g['C'], sub_pops)
        _flip_if_better(C_true, pars_est_g['C'], sub_pops[0])
        subsp_errorsf_g[i, rndsidx] = calc_subspace_proj_error(C_true, pars_est_g['C'])

        # ssidid: same, recording whether the flip was applied
        (subsp_errors[i, rndsidx],
         subsp_errors1[i, rndsidx],
         subsp_errors2[i, rndsidx]) = _subsp_errors(C_true, pars_est['C'], sub_pops)
        if _flip_if_better(C_true, pars_est['C'], sub_pops[0]):
            bitflip[i, rndsidx] = True
            print('flipping bit')
        subsp_errorsf[i, rndsidx] = calc_subspace_proj_error(C_true, pars_est['C'])

        # large final loss: replace the ssidid results by those of the re-run fit
        if ls[i, rndsidx] > 1000.:

            redid[i, rndsidx] = True

            file_name_rep = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+'_e3final_'+str(overlap)
            load_file = np.load(data_path + file_name_rep + '.npz')['arr_0'].tolist()
            idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()
            obs_scheme = load_file['obs_scheme']
            pars_est   = load_file['pars_est']
            traces     = load_file['traces']
            (subsp_errors[i, rndsidx],
             subsp_errors1[i, rndsidx],
             subsp_errors2[i, rndsidx]) = _subsp_errors(C_true, pars_est['C'], sub_pops)
            # NOTE(review): bitflip is only ever set here, never cleared, so a
            # flip of the original fit stays marked even if the re-run needs none.
            if _flip_if_better(C_true, pars_est['C'], sub_pops[0]):
                bitflip[i, rndsidx] = True
            subsp_errorsf[i, rndsidx] = calc_subspace_proj_error(C_true, pars_est['C'])

    plt.subplot(4,3,idx_subplot)

    # invisible dummy artists (far outside the axis range) defining the four
    # legend entries; they are the first four artists of the panel
    plt.loglog(1e-10, 1e-10, 'o-', color='b', linewidth=1.5)
    plt.loglog(1e-10, 1e-10, 'o-', color='g', linewidth=2)
    plt.loglog(1e-10, 1e-10, 'r*', markersize=10)
    plt.loglog(1e-10, 1e-10, 'mx', markersize=8, markeredgewidth=3)

    # shift by 0.1 so that overlap 0 can be shown on the log axis
    x_ovl = np.array(overlaps) + 0.1
    # solid: all variables; dashed: each sub-population separately
    plt.loglog(x_ovl, subsp_errors_g[:,rndsidx], 'o-', color='b', linewidth=1.5)
    plt.loglog(x_ovl, subsp_errors[:,rndsidx], 'o-', color='g', linewidth=2)
    plt.loglog(x_ovl, subsp_errors1_g[:,rndsidx], 'o--', color='b', linewidth=1.5)
    plt.loglog(x_ovl, subsp_errors2_g[:,rndsidx], 'o--', color='b', linewidth=1.5)
    plt.loglog(x_ovl, subsp_errors1[:,rndsidx], 'o--', color='g', linewidth=2)
    plt.loglog(x_ovl, subsp_errors2[:,rndsidx], 'o--', color='g', linewidth=2)
    # markers above the ssidid curve: re-run fits (red) and sign flips (magenta)
    plt.loglog(x_ovl[redid[:,rndsidx]], 1.25 * subsp_errors[:,rndsidx][redid[:,rndsidx]], 'r*', markersize=10)
    plt.loglog(x_ovl[bitflip[:,rndsidx]], 1.5 * subsp_errors[:,rndsidx][bitflip[:,rndsidx]], 'mx', markersize=8, markeredgewidth=3)
    # plt.box('off') relied on matplotlib parsing 'on'/'off'; pass a bool
    plt.box(False)
    # legend only on the panel of the last seed (was hard-coded to index 9)
    if rndsidx == len(rnd_seeds) - 1:
        lgnd = ['GROUSE', 'ssidid', 'rep. fit', 'bitflip']
        plt.legend(lgnd, loc=1,frameon=False)
    plt.ylabel('subsp. proj. error')
    if rndsidx > 6:
        plt.xlabel('overlap')
    plt.title('random seed ' + str(rnd_seed))
    plt.axis([0.09, 1100, 0.01, 1.6])
    plt.xticks([0.1, 1, 10, 100, 1000], ['0', '1', '10', '100', '1000'])

    idx_subplot +=1
    
# Panel 1 of the 4x3 grid: observation mask of the overlap-100 result file
# (data_path / run as left by the per-seed loop), every 1000th time step.
plt.subplot(4,3,1)
file_name = 'p{}n{}T{}{}_{}'.format(p, n, T_full, run, 100)
load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
obs_scheme = load_file['obs_scheme']
obs_scheme.gen_mask_from_scheme()
mask_sub = obs_scheme.mask[0:-1:1000, :]
plt.imshow(mask_sub.T, aspect='auto', interpolation='None')
plt.xticks([0, 50, 100], [0, 50000, 100000])
plt.ylabel('variable index i')
plt.xlabel('time t')
plt.title('example data observation scheme (10% overlap)')
plt.set_cmap('gray')

plt.show()
    

# selection of re-runs

In [None]:
# Per-overlap loss normaliser: Om[0].sum() of the '_e3_' result file in
# data_path (the last seed's directory from the loop above).
# NOTE(review): presumably Om[0] is a boolean mask of observed entries -- confirm.
nrmlzrs = np.zeros(len(overlaps))
for i, overlap in enumerate(overlaps):
    file_name = 'p{}n{}T{}_e3_{}'.format(p, n, T_full, overlap)
    load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
    Om = load_file['Om']
    nrmlzrs[i] = Om[0].sum()


In [None]:


# Normalised final ssidid loss per seed and overlap. Black: original fit
# kept; red: fit replaced by a re-run (final loss above the threshold used in
# the per-seed loop). Uses ls, redid and nrmlzrs computed above.

# Legend handles: single invisible markers left of the plotted x-range
# (previously drawn three times over; the extra copies were redundant).
h_good, = plt.loglog(0.00001, 1, 'o', color='k', linewidth=2.5)
h_rerun, = plt.loglog(0.00001, 1, 'o', color='r', linewidth=2.5)
plt.legend((h_good, h_rerun), ('good fits', 'fits rerun'), loc=1)

for i, overlap in enumerate(overlaps):
    # (removed: an unused p x p boolean mask and sub_pops tuple that were
    # rebuilt on every iteration)
    nrmlzr = int(nrmlzrs[i])

    print(overlap, nrmlzr)
    rerun = redid[i, :]
    x_pos = overlap + 0.1   # shift so that overlap 0 is visible on the log axis
    plt.loglog(np.full(np.sum(~rerun), x_pos), ls[i, ~rerun] / nrmlzr, 'o', color='k', linewidth=2.5)
    plt.loglog(np.full(np.sum(rerun), x_pos), ls[i, rerun] / nrmlzr, 'o', color='r', linewidth=2.5)

plt.axis([0.09, 1100, 0.0001, 1.0])

# tick labels give the overlap as a percentage of the p = 1000 variables
plt.xticks([0.1, 1, 10, 100, 1000], ['0', '0.1', '1', '10', '100'])
plt.ylabel('final loss (norm. MSE)')
plt.xlabel('overlap [%]')
plt.title('selection criterion for re-running ssidid algorithm for stitching')

plt.savefig(fig_path + 'sso_selection_of_fits_to_redo.pdf')

plt.show()

# final summary figure

In [None]:

plt.figure(figsize=(14,4))

# Left panel: observation mask of the overlap-100 result file (data_path /
# run as left by the per-seed loop), every 1000th time step.
plt.subplot(1,2,1)
file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T_full)+str(run)+'_'+str(100)
load_file = np.load(data_path + file_name + '.npz')['arr_0'].tolist()
obs_scheme = load_file['obs_scheme']
obs_scheme.gen_mask_from_scheme()
mask = obs_scheme.mask[0:-1:1000, :]
plt.imshow(mask.T, aspect='auto', interpolation='None')
plt.xticks([0,50,100], [0, 50000, 100000])
plt.yticks([0,500,1000], ['1', '500', '1000'])
plt.ylabel('variable index i')
plt.xlabel('time t')
plt.title('example data observation scheme (10% overlap)')
plt.set_cmap('gray')

# Right panel: subspace error vs. overlap, mean +/- s.e.m. across seeds.
plt.subplot(1,2,2)
n_seeds = subsp_errors.shape[1]

if np.min(overlaps) == 0:
    # Overlap 0 cannot be placed on a log axis: plot positive overlaps as
    # curves and the zero-overlap results as isolated points near x = 5.
    ovl = np.array(overlaps)
    i0 = int(np.argmin(ovl))            # index of the zero-overlap condition
    idx_pos = np.flatnonzero(ovl > 0)   # indices of all positive overlaps
    overlaps_pos = ovl[idx_pos]

    m_g = np.mean(subsp_errors_g[idx_pos,:], axis=1)
    s_g = np.std(subsp_errors_g[idx_pos,:], axis=1) / np.sqrt(n_seeds)
    m = np.mean(subsp_errors[idx_pos,:], axis=1)
    s = np.std(subsp_errors[idx_pos,:], axis=1) / np.sqrt(n_seeds)
    h_g, = plt.semilogx(overlaps_pos, m_g, 'o-', color='b', linewidth=2.5)
    h_s, = plt.semilogx(overlaps_pos, m, 'o-', color='g', linewidth=2.5)
    plt.fill_between(overlaps_pos, m-s, m+s, where=m+s>=m-s,
                     facecolor='green', alpha=0.5)
    plt.fill_between(overlaps_pos, m_g-s_g, m_g+s_g, where=m_g+s_g>=m_g-s_g,
                     facecolor='blue', alpha=0.5)

    # zero overlap: ssidid after sign-flip correction (cyan), raw ssidid
    # (green) and GROUSE (blue), each with its s.e.m. band
    zero_overlap = ((subsp_errorsf,  5,   [4.9, 5.1], 'c'),
                    (subsp_errors,   5,   [4.9, 5.1], 'g'),
                    (subsp_errors_g, 5.1, [5.0, 5.2], 'b'))
    h_zero = []
    for errs, x0, span, clr in zero_overlap:
        m0 = np.mean(errs[i0,:])
        s0 = np.std(errs[i0,:]) / np.sqrt(errs[i0,:].size)
        h_zero += plt.semilogx(x0, m0, 'o-', color=clr, linewidth=2.5)
        plt.fill_between(span, (m0-s0) * np.ones(2), (m0+s0) * np.ones(2), where=None,
                         facecolor=clr, alpha=0.5)

    plt.axis([4.0, 1100, 0.0, 0.8])
    plt.xticks([5, 10, 100, 1000], ['0', '1', '10', '100'])
    plt.yticks([0., 0.2, 0.4, 0.6, 0.8])

    lgnd = ['GROUSE', 'ssidid', 'ssidid \n (bit flipped)']
    # explicit handles so labels cannot attach to the shaded areas; the cyan
    # zero-overlap point stands for the bit-flipped results
    plt.legend([h_g, h_s, h_zero[0]], lgnd, loc=1, frameon=False)


elif np.min(overlaps) ==  10:

    ovl = np.array(overlaps)
    m_g = np.mean(subsp_errors_g, axis=1)
    s_g = np.std(subsp_errors_g, axis=1) / np.sqrt(n_seeds)
    m = np.mean(subsp_errors, axis=1)
    s = np.std(subsp_errors, axis=1) / np.sqrt(n_seeds)
    h_g, = plt.semilogx(ovl, m_g, 'o-', color='b', linewidth=2.5)
    h_s, = plt.semilogx(ovl, m, 'o-', color='g', linewidth=2.5)
    plt.fill_between(ovl, m-s, m+s, where=m+s>=m-s,
                     facecolor='green', alpha=0.5)
    plt.fill_between(ovl, m_g-s_g, m_g+s_g, where=m_g+s_g>=m_g-s_g,
                     facecolor='blue', alpha=0.5)

    plt.axis([9, 1100, 0.0, 0.5])
    plt.xticks([10, 100, 1000], ['1', '10', '100'])
    plt.yticks([0., 0.2, 0.4, 0.6, 0.8])

    lgnd = ['GROUSE', 'ssidid']
    plt.legend([h_g, h_s], lgnd, loc=1, frameon=False)

# plt.box('off') relied on matplotlib parsing 'on'/'off'; pass a bool
plt.box(False)
plt.ylabel('subsp. proj. error')
plt.xlabel('overlap [%]')
plt.title('subspace identification')

plt.savefig(fig_path + 'sim_stitched_overlap_avg.pdf')
plt.show()


# add lines for individual subpopulations
#m = np.mean(np.hstack((subsp_errors1_g,subsp_errors2_g)),axis=1)
#s =np.std(np.hstack((subsp_errors1_g,subsp_errors2_g)),axis=1) / np.sqrt(np.hstack((subsp_errors1_g,subsp_errors2_g)).shape[1])
#plt.loglog(np.array(overlaps)+0.1, m - s, '--', color='b', linewidth=2.5)
#plt.loglog(np.array(overlaps)+0.1, m + s, '--', color='b', linewidth=2.5)
#m = np.mean(np.hstack((subsp_errors1,  subsp_errors2  )),axis=1)
#s =np.std(np.hstack((subsp_errors1,subsp_errors2)),axis=1) / np.sqrt(np.hstack((subsp_errors1,subsp_errors2)).shape[1])
#plt.loglog(np.array(overlaps)+0.1, m - s, '--', color='g', linewidth=2.5)
#plt.loglog(np.array(overlaps)+0.1, m + s, '--', color='g', linewidth=2.5)





# Fig 4a

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from ssidid import ObservationScheme

fig_path = '/home/mackelab/Desktop/Projects/Stitching/figures/'


p = 41    # number of observed variables (pixels / z-planes in the schematic)
T = 100   # number of time steps


# create subpopulations: ns contiguous blocks of p//ns variables; the last
# block also takes the remaining p % ns variables (ns = 10 subpopulations)
ns = 10
block = p // ns
sub_pops = [np.arange(i*block, (i+1)*block) for i in range(ns-1)] + [np.arange((ns-1)*block, p)]

# observe sub-populations 0, 1, ..., ns-1 in turn, repeated obs_sweeps times
obs_sweeps = 10
obs_pops = np.tile(np.arange(len(sub_pops)), obs_sweeps)
# equally spaced switching times over [0, T]; presumably the end of each
# observation period (TODO confirm against ObservationScheme)
n_periods = len(obs_pops)
obs_time = np.array([k*T//n_periods for k in range(1, n_periods+1)])

obs_scheme = ObservationScheme(p=p,T=T,
                               sub_pops=sub_pops,
                               obs_pops=obs_pops,
                               obs_time=obs_time)
obs_scheme.gen_mask_from_scheme()

plt.figure(figsize=(16,10))

# top left: schematic, vertical lines mark the switches between periods
ax = plt.subplot(3,2,1)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# middle left: same schematic plus the scan path within each period (dashed)
ax = plt.subplot(3,2,3)
x = np.arange(11)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
    plt.plot(10*i+x, p-p*x/10, '--', color='b')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# bottom left: the generated observation mask (transposed for display)
ax = plt.subplot(3,2,5)
plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
plt.gray()
plt.xticks(np.arange(10,101,10)-0.5, range(1,11,1))
plt.yticks(np.arange(0,41,10), np.arange(1, 42, 10))
plt.xlabel('time [s]')
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

lag_range = np.arange(10)
W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True)

# right half: 1/W[m] for every lag m, two panels per row
for i in range(5):
    for j in range(2):
        m = 2*i + j
        plt.subplot(6,4,3+i*4+j)
        plt.imshow(1/W[m], interpolation='None')
        plt.xticks([])
        plt.yticks([])
        plt.ylabel('m = ' + str(m))
        plt.gray()

# 'U m' panel: constant placeholder image
plt.subplot(6, 6, 35)
plt.imshow(np.ones((10,10)))
plt.clim([0,1])
plt.xticks([])
plt.yticks([])
plt.ylabel('U m')


plt.savefig(fig_path + 'fig4_A.pdf')

plt.show()

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from ssidid import ObservationScheme

fig_path = '/home/mackelab/Desktop/Projects/Stitching/figures/'




p = 41     # number of observed variables (pixels / z-planes in the schematic)
T = 1200   # number of time steps
lag_range = np.arange(10)

# create subpopulations: ns contiguous blocks of p//ns variables; the last
# block also takes the remaining p % ns variables (ns = 10 subpopulations)
ns = 10
block = p // ns
sub_pops = [np.arange(i*block, (i+1)*block) for i in range(ns-1)] + [np.arange((ns-1)*block, p)]

# zig-zag scan: sub-populations 0..ns-1 then ns-1..0, repeated obs_sweeps
# times; one sub-population per time step (obs_time = 1..T)
obs_sweeps = 60
one_sweep = np.hstack((np.arange(0, len(sub_pops)), np.arange(len(sub_pops))[::-1]))
obs_pops = np.tile(one_sweep, obs_sweeps)
obs_time = np.arange(1, T+1)

obs_scheme = ObservationScheme(p=p,T=T,
                               sub_pops=sub_pops,
                               obs_pops=obs_pops,
                               obs_time=obs_time)
obs_scheme.gen_mask_from_scheme()

# common colour limit for all 1/W panels and their sum
clim_max = 185

plt.figure(figsize=(16,10))

# top left: schematic, vertical lines mark the switches between periods
ax = plt.subplot(3,2,1)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# middle left: schematic plus the scan path, alternating down / up
ax = plt.subplot(3,2,3)
x = np.arange(11)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
    y_path = p-p*x/10 if np.mod(i,2) == 0 else p*x/10
    plt.plot(10*i+x, y_path, '--', color='b')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# bottom left: first 100 time steps of the generated mask (transposed)
ax = plt.subplot(3,2,5)
plt.imshow(obs_scheme.mask[:100,:].T, interpolation='None', aspect='auto')
plt.gray()
plt.xticks(np.arange(10,101,10)-0.5, range(1,11,1))
plt.yticks(np.arange(0,41,10), np.arange(1, 42, 10))
plt.xlabel('time [s]')
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True)

# right half: 1/W[m] for every lag m, two panels per row
for i in range(5):
    for j in range(2):
        m = 2*i + j
        plt.subplot(6,4,3+i*4+j)
        plt.imshow(1/W[m], interpolation='None')
        plt.xticks([])
        plt.yticks([])
        plt.ylabel('m = ' + str(m))
        plt.gray()
        plt.clim(0, clim_max)

# 'U m' panel: sum of 1/W over all lags
plt.subplot(6, 6, 35)
tmp = np.zeros_like(W[0])
for W_m in W:
    tmp += 1/W_m
plt.imshow(tmp, interpolation='None')
plt.clim(0, clim_max)
plt.xticks([])
plt.yticks([])
plt.ylabel('U m')


plt.savefig(fig_path + 'fig4_A_alt10.pdf')

plt.show()

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from ssidid import ObservationScheme

fig_path =  '/home/mackelab/Desktop/Projects/Stitching/figures/'




p = 41              # number of observed variables (pixels / z-planes in the schematic)
T = 10000           # number of time steps
lag_range = np.arange(11)   # time lags 0..10

# create subpopulations: ns-1 contiguous blocks of 4 variables (0..39) plus
# a final block with the remaining variables (here only variable 40)
ns = 11

sub_pops = [np.arange(i*(4), (i+1)*(4)) for i in range(ns-1)] + [np.arange((ns-1)*4,p)]

# zig-zag scan without repeating the turning points: sub-populations
# 1, ..., ns-1 followed by ns-2, ..., 0, repeated obs_sweeps times;
# one sub-population per time step (obs_time = 1..T)
obs_sweeps = 500

obs_pops = np.hstack([np.hstack((np.arange(1,len(sub_pops)),np.arange(len(sub_pops)-1)[::-1])) for i in range(obs_sweeps)])
obs_time = np.arange(1,T+1)

obs_scheme = ObservationScheme(p=p,T=T,
                               sub_pops=sub_pops,
                               obs_pops=obs_pops,
                               obs_time=obs_time)
obs_scheme.gen_mask_from_scheme()

plt.figure(figsize=(16,10))

# top left: schematic, vertical lines mark the switches between periods
ax = plt.subplot(3,2,1)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# middle left: schematic plus the scan path, alternating down / up
ax = plt.subplot(3,2,3)
x = np.arange(11)
for i in range(ns+1):
    plt.plot(10*i*np.array([1,1]), np.array([0,p]), 'k')
    if np.mod(i,2)==0:
        plt.plot(10*i+x, p-p*x/10, '--', color='b')
    else:
        plt.plot(10*i+x, p*x/10, '--', color='b')
plt.xticks(range(10,101,10), range(1,11))
plt.yticks(np.arange(0,41,10)+.5, np.arange(1, 42, 10)[::-1])
plt.axis([0,100,0,p])
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

# bottom left: first 100 time steps of the generated mask (transposed)
ax = plt.subplot(3,2,5)
plt.imshow(obs_scheme.mask[:100,:].T, interpolation='None', aspect='auto')
plt.gray()
plt.xticks(np.arange(10,101,10)-0.5, range(1,11,1))
plt.yticks(np.arange(0,41,10), np.arange(1, 42, 10))
plt.xlabel('time [s]')
plt.ylabel('observed pixel/z-plane')
ax.get_yaxis().set_tick_params(direction='out')
ax.get_xaxis().set_tick_params(direction='out')

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True)

# right half: 1/W[m] for lags m = 0..9, two panels per row
for i in range(5):
    plt.subplot(6,4,3+i*4)
    plt.imshow(1/W[2*i], interpolation='None')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('m = ' + str(2*i))
    plt.gray()

    plt.subplot(6,4,4+i*4)
    plt.imshow(1/W[2*i+1], interpolation='None')
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('m = ' + str(2*i+1))
    plt.gray()

# last lag: W[10] belongs to m = 10 (was mislabelled 'm = 11'; the other
# panels label W[k] as 'm = k')
plt.subplot(6,4,23)
plt.imshow(1/W[10], interpolation='None')
plt.xticks([])
plt.yticks([])
plt.ylabel('m = ' + str(10))
plt.gray()

# 'U m' panel: sum of 1/W over all lags
plt.subplot(6,4,24)
tmp = np.zeros_like(W[0])
for i in range(len(W)):
    tmp += 1/W[i]
plt.imshow(tmp, interpolation='None')
plt.clim([0,tmp.max()])
plt.xticks([])
plt.yticks([])
plt.ylabel('U m')


plt.savefig(fig_path + 'fig4_A_alt11.pdf')

plt.show()