In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default
from sklearn.decomposition import FactorAnalysis

run = '_e3'
# define problem size
lag_range = np.arange(10)      # time lags of the lagged covariances
kl_ = np.max(lag_range)+1      # largest lag + 1
p, n = 1000, 10                # observed dimension, latent dimension
T_full = 10000 + kl_
T = T_full
# `np.float` was merely an alias of the builtin float (i.e. float64); it was
# deprecated in NumPy 1.20 and removed in 1.24 (AttributeError there).
dtype = np.float64

nr = 0 # number of real eigenvalues
snr = (4., 4.)                 # forwarded to gen_data
whiten = True                  # forwarded to gen_data
# Presumably min/max eigenvalue magnitudes for real (r) and complex (c)
# eigenvalues of the generated dynamics — TODO confirm against gen_data.
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.99, 0.999, 0.99, 0.999

def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B need not be orthonormal: each is orthonormalised with
    ``scipy.linalg.orth`` first (which also drops numerically dependent
    columns). 1-D inputs are treated as single column vectors.

    Parameters
    ----------
    A, B : array_like
        Arrays of shape (p,) or (p, k) with the same number of rows p.

    Returns
    -------
    ndarray
        The min(rank(A), rank(B)) principal angles in increasing order,
        divided by pi/2: 0 means a shared direction, 1 means orthogonal.
        Empty if either input has rank 0.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    A = la.orth(A)
    B = la.orth(B)
    if A.shape[1] == 0 or B.shape[1] == 0:
        # A rank-0 input spans no directions, hence there are no angles.
        return np.zeros(0)
    # Singular values of A^T B are the cosines of the principal angles;
    # clip round-off that lands just outside [0, 1] before taking arccos.
    cosines = np.clip(la.svdvals(A.T.dot(B)), 0.0, 1.0)
    return np.arccos(cosines) / (np.pi / 2)


# I/O settings: keep the data in memory (no memory-mapping)
mmap = False
chunksize = np.min((p, 1000))
verbose = True

# runs to perform: random seeds and subpopulation overlaps
rnd_seeds = range(43, 44)
overlaps = (20,)


# One experiment per random seed: generate a synthetic data set with gen_data,
# fit factor analysis (FA) to each subpopulation separately and stitch the FA
# loadings together, fit all subpopulations jointly with run_default, then save
# the results with np.savez.
for rnd_seed in rnd_seeds:
    
    # Per-seed results directory (also handed to gen_data and the helpers below).
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')
    
    
    # The global NumPy RNG is seeded once per seed; gen_data, the idx_a draw and
    # the random C initialisation further down all consume it, so the order of
    # these statements determines the results.
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T_full, nr,
                                     eig_m_r, eig_M_r, 
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)        
    # Lag blocks X_m = A^m Pi for m in lag_range, stacked vertically.
    pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])
    
    # With p <= 1000 this is a sorted draw of all p indices, i.e. arange(p);
    # it still advances the RNG.
    idx_a = np.sort(np.random.choice(p, np.minimum(p,1000), replace=False))
    idx_b = idx_a.copy()

    # Flag forwarded to the weighting, covariance and fitting helpers.
    # TODO(review): document what `sso` selects inside those helpers.
    sso = True

    for overlap in overlaps:

        # Report the problem size (p, n, number of lags, T).
        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Observation scheme: ns consecutive subpopulations tiling the p
        # observed dimensions, neighbours sharing 2*(overlap//2) indices.
        # obs_pops lists the subpopulation observed in each period (0..ns-1 in
        # order, once since reps = 1) and obs_time the period end points.
        reps = 1
        ns = 10
        sub_pops = [np.arange(0,  p//ns + overlap//2)]
        sub_pops = sub_pops + [np.arange(i*p//ns-overlap//2, (i+1)*p//ns+overlap//2) for i in range(1,ns-1)]
        sub_pops.append(np.arange((ns-1)*p//ns-overlap//2,p))
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T, 
                                        sub_pops=sub_pops, 
                                        obs_pops=obs_pops, 
                                        obs_time=obs_time)

        # Co-occurrence weights per lag; for every lag m >= 1 they are replaced
        # by a copy of the lag-0 weights (`overlap < p` holds for all overlaps
        # used here).
        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            for m in range(1, len(lag_range)):
                W[m] = W[0].copy()
                #W[m][0,1] = 0
                #W[m][1,0] = 0

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)    
        # Om[m] of every lag is replaced by a copy of Om[0].
        if overlap < p:
            for m in range(len(lag_range)):    
                Om[m] = Om[0].copy()
                #Om[m][np.ix_(obs_scheme.idx_grp[0], obs_scheme.idx_grp[1])] = False
                #Om[m][np.ix_(obs_scheme.idx_grp[1], obs_scheme.idx_grp[0])] = False

        # Diagnostics: summary of the lagged covariances and the loss at the
        # true parameters.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))        

        
        sub_pops = obs_scheme.sub_pops
        # Snapshot of the RNG state; not used further within this loop.
        rnd_seed_fit = np.random.get_state()    
        
        # start fitting: FA on each subpopulation separately, using only the
        # time window in which it is observed ([0, obs_time[0]) for j == 0,
        # [obs_time[j-1], obs_time[j]) otherwise).
        pars_ests = []
        tracess = []
        tss = []
        for j in range(len(sub_pops)):
            print('fitting subpop #' + str(j))        
            data = y[:obs_time[0], sub_pops[0]] if j==0 else y[obs_time[j-1]:obs_time[j], sub_pops[j]]
            print('data shape ', data.shape)
            fa = FactorAnalysis(n_components=n, 
                                tol=0.01, 
                                copy=True, 
                                max_iter=1000, 
                                noise_variance_init=None, 
                                svd_method='randomized', 
                                iterated_power=3, 
                                random_state=0)
            t = time.time()
            fa.fit(data)
            t = time.time() - t
            print('fitting time: ', t)
            print('principal angles: ', principal_angle(pars_true['C'][sub_pops[j],:], fa.components_.T))

            
            # components_ is (n_components, n_features), hence the transpose.
            pars_ests.append({
                'C' : fa.components_.T.copy(),
                'Pi' : np.eye(n).copy(),
                'R' : fa.noise_variance_.copy()
            })
            tracess.append(fa.loglike_.copy())
            tss.append(t)        
            del fa

        
        # Stitch the per-subpopulation FA loadings: start from subpopulation 0
        # and rotate each following loading matrix by the orthogonal Procrustes
        # solution that best maps it onto the already stitched loadings on the
        # indices shared with the previous subpopulation.
        C12 = np.zeros_like(pars_true['C'])    
        C12[sub_pops[0],:] = pars_ests[0]['C']
        for j in range(1,len(sub_pops)):
            idx_overlap = np.intersect1d(sub_pops[j-1], sub_pops[j])
            C2 = np.nan * np.zeros((p,n))            
            C2[sub_pops[ j ], :] = pars_ests[ j ]['C']

            if overlap > 0:
                # second return value is the Procrustes scale (typo'd name, unused)
                Q, sclale = la.orthogonal_procrustes(C2[idx_overlap,:], C12[idx_overlap,:])
            else:
                Q = np.eye(n)
            C12[sub_pops[j],:] = C2[sub_pops[j],:].dot(Q)

        print(' total subsp. error') 
        print(principal_angle(pars_true['C'], C12))
        
        
        
        # Random initialisation for run_default: C with orthonormal columns
        # scaled by sqrt(p/n), diagonal A with entries in [0.89, 0.91],
        # Pi = B = I, R = 1; X holds the matching lag blocks A^m Pi.
        pars_init={
            'C' : np.asarray(la.orth(np.random.normal(size=(p,n))),dtype=dtype) * np.sqrt(p) / np.sqrt(n),
            'A' : np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
            'Pi' : np.asarray(np.eye(n), dtype=dtype),
            'B' :  np.asarray(np.eye(n), dtype=dtype),
            'R' :  np.ones(p, dtype=dtype)
        }            
        pars_init['X'] = np.vstack([ np.linalg.matrix_power(pars_init['A'],m).dot(pars_init['Pi']) for m in lag_range])

        # Joint fit on all subpopulations. Each tuple presumably holds one
        # setting per fitting stage — TODO confirm against run_default.
        pars_est, traces, ts= run_default(
                    alphas    = (0.01, 0.0005), 
                    b1s       = (0.95 , 0.9), 
                    a_decays  = (0.98, 0.98), 
                    batch_sizes = (1, 10), 
                    max_zip_sizes =  (1000,100), 
                    max_iters = (200, 10 ),
                    parametrizations = ('nl', 'nl'),
                    pars_est=pars_init, pars_true=pars_true, n=n, 
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range, 
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])    


        # Only subpopulations 0 and 1 are reported here.
        print('per-subpops principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'], C))
            
            
        # W, Qs and Om are not stored. np.savez stores save_dict as a single
        # pickled object array ('arr_0'); load it with allow_pickle=True.
        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_ests_FA' : pars_ests,
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_FA' : tracess, 'ts':ts, 'ts_FA':tss,
                     'rnd_seed' : rnd_seed
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)  + 'breakFA'
        np.savez(data_path + file_name, save_dict)    


In [None]:
# Display the R entry of the run_default estimate (presumably the diagonal
# observation-noise variances, as for the FA estimates) — inspection only.
pars_est['R']

In [None]:
# Fit the LDS by EM, initialised from the run_default estimate `pars_est`.
# setup_fit_lds comes from the pyRRHDLDS code base, made importable by
# switching into its directory.
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
from ssm_scripts import setup_fit_lds


# EM    
max_iter_EM = 200
eps_cov = 1e-20
epsilon = 0
# res[i, 1:] receives the principal angles between EM iterate i+1 and the true
# C; column 0 is never written.
res = np.zeros((max_iter_EM, n+1))
# y is indexed y[time, index] in the FA fits; setup_fit_lds gets it as (p, T, 1).
fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1), 
                        u=None, 
                        max_iter=max_iter_EM,
                        epsilon=epsilon, 
                        eps_cov=eps_cov,
                        plot_flag=False, 
                        trace_pars_flag=True, 
                        trace_stats_flag=False, 
                        diag_R_flag=True,
                        use_A_flag=True, 
                        use_B_flag=False)        

if pars_est['A'] is None:
    # The lag blocks X_m = A^m Pi are stacked vertically in X and satisfy
    # X_{m+1} = A X_m, i.e. A acts from the LEFT. Solve
    # A [X_0 ... X_{k-2}] = [X_1 ... X_{k-1}] on the horizontally stacked
    # blocks; regressing the vertical stacks instead yields Pi^{-1} A Pi.
    X_blocks = np.split(pars_est['X'][:len(lag_range)*n, :], len(lag_range))
    X_prev = np.hstack(X_blocks[:-1])
    X_next = np.hstack(X_blocks[1:])
    pars_est['A'] = np.linalg.lstsq(X_prev.T, X_next.T, rcond=None)[0].T
pars_est['Q'] = np.eye(n)
pars_est['d'] = np.zeros(p, dtype=dtype)
pars_est['mu0'] = np.zeros(n ,dtype=dtype)
pars_est['V0'] = np.eye(  n ,dtype=dtype)            

# NOTE: alias, not a copy — the 'B' entry set below is also written into pars_est.
pars_hat = pars_est
t = time.time()

pars_hat['B'] = np.empty((n,0), dtype=dtype)
pars_hat,ll = fit_lds(x_dim=n,
                      pars=pars_hat, 
                      obs_scheme={'sub_pops' : obs_scheme.sub_pops,
                                  'obs_pops' : obs_scheme.obs_pops,
                                  'obs_time' : obs_scheme.obs_time},
                      save_file=None)
for i_ in range(len(pars_hat['Cs'])-1):
    res[i_,1:] = principal_angle(pars_hat['Cs'][i_+1], pars_true['C'])

likes = np.array(ll[1:])
elapsed_time = time.time() - t
print('elapsed time for fitting is')
print(elapsed_time)
# Stationary latent covariance implied by the fitted A and Q.
pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'], 
                                                   pars_hat['Q'])        

t = time.time() - t    
print('fitting time: ', t)

print(principal_angle(pars_true['C'], pars_hat['C']))

plt.figure()
plt.subplot(1,2,1)
plt.plot(res[:,1:])
plt.title('final princ. angles')
plt.subplot(1,2,2)
plt.plot(likes)
plt.show()


In [None]:
# Moduli of the eigenvalues of the EM estimate of A (X_m = A^m Pi, so moduli
# below 1 correspond to decaying, i.e. stable, latent dynamics) — inspection only.
np.abs(np.linalg.eigvals(pars_hat['A']))

In [None]:
# Visualise the observation mask, thinned to every 100th entry along axis 0
# (presumably time — TODO confirm the layout of obs_scheme.mask).
obs_scheme.gen_mask_from_scheme()
mask_thinned = obs_scheme.mask[::100, :]
plt.figure(figsize=(10, 10))
plt.imshow(mask_thinned.T, aspect='auto')
plt.show()

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time
from scipy import linalg as la
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.icml_scripts import run_default
from sklearn.decomposition import FactorAnalysis

run = '_e3'
# define problem size
lag_range = np.arange(10)      # time lags of the lagged covariances
kl_ = np.max(lag_range)+1      # largest lag + 1
p, n = 1000, 10                # observed dimension, latent dimension
T_full = 10000 + kl_
T = T_full
# `np.float` was merely an alias of the builtin float (i.e. float64); it was
# deprecated in NumPy 1.20 and removed in 1.24 (AttributeError there).
dtype = np.float64

nr = 0 # number of real eigenvalues
snr = (4., 4.)                 # forwarded to gen_data
whiten = True                  # forwarded to gen_data
# Presumably min/max eigenvalue magnitudes for real (r) and complex (c)
# eigenvalues of the generated dynamics — TODO confirm against gen_data.
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.99, 0.999, 0.99, 0.999

def principal_angle(A, B):
    """Principal angles between the column spaces of A and B, scaled to [0, 1].

    A and B need not be orthonormal: each is orthonormalised with
    ``scipy.linalg.orth`` first (which also drops numerically dependent
    columns). 1-D inputs are treated as single column vectors.

    Parameters
    ----------
    A, B : array_like
        Arrays of shape (p,) or (p, k) with the same number of rows p.

    Returns
    -------
    ndarray
        The min(rank(A), rank(B)) principal angles in increasing order,
        divided by pi/2: 0 means a shared direction, 1 means orthogonal.
        Empty if either input has rank 0.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    A = la.orth(A)
    B = la.orth(B)
    if A.shape[1] == 0 or B.shape[1] == 0:
        # A rank-0 input spans no directions, hence there are no angles.
        return np.zeros(0)
    # Singular values of A^T B are the cosines of the principal angles;
    # clip round-off that lands just outside [0, 1] before taking arccos.
    cosines = np.clip(la.svdvals(A.T.dot(B)), 0.0, 1.0)
    return np.arccos(cosines) / (np.pi / 2)


# I/O settings: keep the data in memory (no memory-mapping)
mmap = False
chunksize = np.min((p, 1000))
verbose = True

# runs to perform: random seeds and subpopulation overlaps
rnd_seeds = range(43, 44)
overlaps = (20,)


# setup_fit_lds (EM stage below) lives in the pyRRHDLDS code base. Import it
# here, exactly as the stand-alone EM cell does, so this cell also runs in a
# fresh kernel instead of relying on another cell's import.
os.chdir("/home/mackelab/Desktop/Projects/Stitching/code/pyRRHDLDS/core")
from ssm_scripts import setup_fit_lds

# One experiment per random seed: generate synthetic data with gen_data, fit FA
# to each subpopulation and stitch the loadings, fit all subpopulations jointly
# with run_default, refine with EM, and save everything with np.savez.
for rnd_seed in rnd_seeds:

    # Per-seed results directory (also handed to gen_data and the helpers below).
    data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e3/sso/seed_' + str(rnd_seed) + '/'

    print('\n')
    print('\n')
    print('seed:', str(rnd_seed))
    print('\n')
    print('\n')

    # The global NumPy RNG is seeded once per seed; gen_data, the idx_a draw and
    # the random C initialisation further down all consume it, so the order of
    # these statements determines the results.
    np.random.seed(rnd_seed)
    pars_true, x, y, _, _ = gen_data(p,n,lag_range,T_full, nr,
                                     eig_m_r, eig_M_r,
                                     eig_m_c, eig_M_c,
                                     mmap, chunksize,
                                     data_path,
                                     snr=snr, whiten=whiten)
    # Lag blocks X_m = A^m Pi for m in lag_range, stacked vertically.
    pars_true['X'] = np.vstack([ np.linalg.matrix_power(pars_true['A'],m).dot(pars_true['Pi']) for m in lag_range])

    # With p <= 1000 this is a sorted draw of all p indices, i.e. arange(p);
    # it still advances the RNG.
    idx_a = np.sort(np.random.choice(p, np.minimum(p,1000), replace=False))
    idx_b = idx_a.copy()

    # Flag forwarded to the weighting, covariance and fitting helpers.
    # TODO(review): document what `sso` selects inside those helpers.
    sso = True

    for overlap in overlaps:

        # Report the problem size (p, n, number of lags, T).
        print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

        # Observation scheme: ns consecutive subpopulations tiling the p
        # observed dimensions, neighbours sharing 2*(overlap//2) indices.
        # obs_pops lists the subpopulation observed in each period (0..ns-1 in
        # order, once since reps = 1) and obs_time the period end points.
        reps = 1
        ns = 10
        sub_pops = [np.arange(0,  p//ns + overlap//2)]
        sub_pops = sub_pops + [np.arange(i*p//ns-overlap//2, (i+1)*p//ns+overlap//2) for i in range(1,ns-1)]
        sub_pops.append(np.arange((ns-1)*p//ns-overlap//2,p))
        obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
        obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
        obs_scheme = ObservationScheme(p=p, T=T,
                                        sub_pops=sub_pops,
                                        obs_pops=obs_pops,
                                        obs_time=obs_time)

        # Co-occurrence weights per lag; for every lag m >= 1 they are replaced
        # by a copy of the lag-0 weights.
        W = obs_scheme.comp_coocurrence_weights(lag_range, sso=sso, idx_a=idx_a, idx_b=idx_b)
        if overlap < p:
            for m in range(1, len(lag_range)):
                W[m] = W[0].copy()
                #W[m][0,1] = 0
                #W[m][1,0] = 0

        print('computing time-lagged covariances')
        Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                              idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                              mmap=mmap,data_path=data_path,ts=None,ms=None)
        # Om[m] of every lag is replaced by a copy of Om[0].
        if overlap < p:
            for m in range(len(lag_range)):
                Om[m] = Om[0].copy()
                #Om[m][np.ix_(obs_scheme.idx_grp[0], obs_scheme.idx_grp[1])] = False
                #Om[m][np.ix_(obs_scheme.idx_grp[1], obs_scheme.idx_grp[0])] = False

        # Diagnostics: summary of the lagged covariances and the loss at the
        # true parameters.
        print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)
        print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],X=pars_true['X'],R=pars_true['R'],
                                       Qs=Qs,Om=Om,lag_range=lag_range,ms=range(len(lag_range)),idx_a=idx_a,idx_b=idx_b))


        sub_pops = obs_scheme.sub_pops
        rnd_seed_fit = np.random.get_state()

        # Stage 1: FA on each subpopulation separately, using only the time
        # window in which it is observed ([0, obs_time[0]) for j == 0,
        # [obs_time[j-1], obs_time[j]) otherwise).
        pars_ests = []
        tracess = []
        tss = []
        for j in range(len(sub_pops)):
            print('fitting subpop #' + str(j))
            data = y[:obs_time[0], sub_pops[0]] if j==0 else y[obs_time[j-1]:obs_time[j], sub_pops[j]]
            print('data shape ', data.shape)
            fa = FactorAnalysis(n_components=n,
                                tol=0.01,
                                copy=True,
                                max_iter=1000,
                                noise_variance_init=None,
                                svd_method='randomized',
                                iterated_power=3,
                                random_state=0)
            t = time.time()
            fa.fit(data)
            t = time.time() - t
            print('fitting time: ', t)
            print('principal angles: ', principal_angle(pars_true['C'][sub_pops[j],:], fa.components_.T))

            # components_ is (n_components, n_features), hence the transpose.
            pars_ests.append({
                'C' : fa.components_.T.copy(),
                'Pi' : np.eye(n).copy(),
                'R' : fa.noise_variance_.copy()
            })
            tracess.append(fa.loglike_.copy())
            tss.append(t)
            del fa

        # Stitch the FA loadings: start from subpopulation 0 and rotate each
        # following loading matrix by the orthogonal Procrustes solution that
        # best maps it onto the already stitched loadings on the indices shared
        # with the previous subpopulation.
        C12 = np.zeros_like(pars_true['C'])
        C12[sub_pops[0],:] = pars_ests[0]['C']
        for j in range(1,len(sub_pops)):
            idx_overlap = np.intersect1d(sub_pops[j-1], sub_pops[j])
            C2 = np.nan * np.zeros((p,n))
            C2[sub_pops[ j ], :] = pars_ests[ j ]['C']

            if overlap > 0:
                # orthogonal_procrustes returns (R, scale); the scale is unused.
                Q, _ = la.orthogonal_procrustes(C2[idx_overlap,:], C12[idx_overlap,:])
            else:
                Q = np.eye(n)
            C12[sub_pops[j],:] = C2[sub_pops[j],:].dot(Q)

        print(' total subsp. error')
        print(principal_angle(pars_true['C'], C12))

        # Stage 2: joint fit with run_default from a random initialisation:
        # C with orthonormal columns scaled by sqrt(p/n), diagonal A with
        # entries in [0.89, 0.91], Pi = B = I, R = 1, X the matching A^m Pi.
        pars_init={
            'C' : np.asarray(la.orth(np.random.normal(size=(p,n))),dtype=dtype) * np.sqrt(p) / np.sqrt(n),
            'A' : np.asarray(np.diag(np.linspace(0.89, 0.91, n)), dtype=dtype),
            'Pi' : np.asarray(np.eye(n), dtype=dtype),
            'B' :  np.asarray(np.eye(n), dtype=dtype),
            'R' :  np.ones(p, dtype=dtype)
        }
        pars_init['X'] = np.vstack([ np.linalg.matrix_power(pars_init['A'],m).dot(pars_init['Pi']) for m in lag_range])

        pars_est, traces, ts= run_default(
                    alphas    = (0.01, 0.0005),
                    b1s       = (0.95 , 0.9),
                    a_decays  = (0.98, 0.98),
                    batch_sizes = (1, 10),
                    max_zip_sizes =  (1000,100),
                    max_iters = (200, 10 ),
                    parametrizations = ('nl', 'nl'),
                    pars_est=pars_init, pars_true=pars_true, n=n,
                    y=y, sso=sso, obs_scheme=obs_scheme, lag_range=lag_range,
                    idx_a=idx_a, idx_b=idx_b,Qs=Qs,Om=Om, W=W,
                    traces=[[], [], []], ts = [])

        # Only subpopulations 0 and 1 are reported here.
        print('per-subpops principal angles')
        C = pars_est['C'].copy()
        print(principal_angle(pars_true['C'][sub_pops[0],:], C[sub_pops[0],:]))
        print(principal_angle(pars_true['C'][sub_pops[1],:], C[sub_pops[1],:]))

        print('final principal angles')
        print(principal_angle(pars_true['C'], C))

        # Stage 3: EM, initialised from the run_default estimate.
        max_iter_EM = 200
        eps_cov = 1e-20
        epsilon = 0
        # res[i, 1:] receives the principal angles between EM iterate i+1 and
        # the true C; column 0 is never written.
        res = np.zeros((max_iter_EM, n+1))
        fit_lds = setup_fit_lds(y=y.T.reshape(p,T,1),
                                u=None,
                                max_iter=max_iter_EM,
                                epsilon=epsilon,
                                eps_cov=eps_cov,
                                plot_flag=False,
                                trace_pars_flag=True,
                                trace_stats_flag=False,
                                diag_R_flag=True,
                                use_A_flag=True,
                                use_B_flag=False)

        if pars_est['A'] is None:
            # The lag blocks X_m = A^m Pi are stacked vertically in X and
            # satisfy X_{m+1} = A X_m, i.e. A acts from the LEFT. Solve
            # A [X_0 ... X_{k-2}] = [X_1 ... X_{k-1}] on the horizontally
            # stacked blocks; regressing the vertical stacks yields Pi^{-1} A Pi.
            X_blocks = np.split(pars_est['X'][:len(lag_range)*n, :], len(lag_range))
            X_prev = np.hstack(X_blocks[:-1])
            X_next = np.hstack(X_blocks[1:])
            pars_est['A'] = np.linalg.lstsq(X_prev.T, X_next.T, rcond=None)[0].T
        pars_est['Q'] = np.eye(n)
        pars_est['d'] = np.zeros(p, dtype=dtype)
        pars_est['mu0'] = np.zeros(n ,dtype=dtype)
        pars_est['V0'] = np.eye(  n ,dtype=dtype)

        # NOTE: alias, not a copy — these EM initialisation fields (and 'B'
        # below) are also written into pars_est, which is saved at the end.
        pars_hat = pars_est
        t = time.time()

        pars_hat['B'] = np.empty((n,0), dtype=dtype)
        pars_hat,ll = fit_lds(x_dim=n,
                              pars=pars_hat,
                              obs_scheme={'sub_pops' : obs_scheme.sub_pops,
                                          'obs_pops' : obs_scheme.obs_pops,
                                          'obs_time' : obs_scheme.obs_time},
                              save_file=None)
        for i_ in range(len(pars_hat['Cs'])-1):
            res[i_,1:] = principal_angle(pars_hat['Cs'][i_+1], pars_true['C'])

        likes = np.array(ll[1:])
        elapsed_time = time.time() - t
        print('elapsed time for fitting is')
        print(elapsed_time)
        # Stationary latent covariance implied by the fitted A and Q.
        pars_hat['Pi'] = sp.linalg.solve_discrete_lyapunov(pars_hat['A'],
                                                           pars_hat['Q'])

        t = time.time() - t
        print('fitting time: ', t)

        print(principal_angle(pars_true['C'], pars_hat['C']))

        plt.figure()
        plt.subplot(1,2,1)
        plt.plot(res[:,1:])
        plt.title('final princ. angles')
        plt.subplot(1,2,2)
        plt.plot(likes)
        plt.show()

        # W, Qs and Om are not stored. np.savez stores save_dict as a single
        # pickled object array ('arr_0'); load it with allow_pickle=True.
        save_dict = {'p' : p,'n' : n,'T' : T,'snr' : snr,'lag_range' : lag_range,
                     'obs_scheme' : obs_scheme, 'mmap' : mmap,'y' : data_path if mmap else y,
                     'pars_true' : pars_true, 'pars_est' : pars_est, 'pars_ests_FA' : pars_ests,
                     'pars_est_EM' : pars_hat, 'traces_EM' : [likes],
                     'idx_a' : idx_a,'idx_b' : idx_b, 'W' : None,'Qs' : None,'Om' : None,
                     'traces' : traces, 'traces_FA' : tracess, 'ts':ts, 'ts_FA':tss,
                     'rnd_seed' : rnd_seed
                    }
        file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+str(run)+'_'+str(overlap)  + 'breakFA'
        np.savez(data_path + file_name, save_dict)
