In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# Directory handed to gen_data / print_slim and used as prefix for the per-run result files.
# NOTE(review): absolute, machine-specific path -- adjust before running on another machine.
data_path = '/home/marcel/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'

# The seed is commented out, so every execution draws different random numbers.
#np.random.seed(0)

# define problem size
#   p: number of observed variables (columns of the observation mask)
#   n: dimensionality passed to the fitting routines (presumably the latent dimension -- confirm)
#   T: number of time points (rows of the observation mask)
p = 10
n = 2
T = 200
lag_range_full = np.arange(4)             # time-lags 0, 1, 2, 3
lag_range = np.array(lag_range_full)      # independent working copy of the lag range
kl_ = lag_range.max() + 1                 # largest lag plus one

nr = 0                                    # number of real eigenvalues
snr = (.25, .25)
# eigenvalue bounds forwarded to gen_data
# (names suggest min/max for real and complex eigenvalues -- confirm)
eig_m_r, eig_M_r = 0.9, 0.99
eig_m_c, eig_M_c = 0.9, 0.99

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap = False
chunksize = np.min([p, 2000])
verbose = True

# create subpopulations
# Two overlapping index sets over the p observed variables: 0..5 and 4..p-1.
sub_pops = (np.arange(0,6), np.arange(4,p))
# obs_pops[i] is the sub-population observed up to (excluding) time obs_time[i]:
# sub-population 0 during the first half of the recording, sub-population 1 afterwards.
obs_pops = np.array([0,1])
obs_time = np.array([T//2,T])

# get_subpop_stats returns nine values; only obs_idx, idx_grp, co_obs and Om are kept.
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

# Fix: pass the obs_idx computed above. The original passed obs_time for obs_idx,
# which left the computed obs_idx unused and handed the switch times over twice.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops, obs_pops=obs_pops,
                               obs_time=obs_time, obs_idx=obs_idx,
                               idx_grp=idx_grp)
    
# settings for quick initial SGD fitting phase for our model
# a, b1, b2, e are passed to run_bad as alpha_C, b1_C, b2_C, e_C
# (presumably Adam-style step size, moment decay rates and epsilon -- confirm against run_bad).
batch_size, max_zip_size, max_iter = 10, np.inf, 500
a, b1, b2, e = 0.01, 0.9, 0.99, 1e-8
a_R = 1 * a  # step size passed as alpha_R, tied to a

# settings for the longer, late SGD fitting phase (larger batches, smaller step size)
batch_size_late, max_zip_size_late, max_iter_late = 50, np.inf, 2500
a_late, b1_late, b2_late, e_late = 0.001, 0.9, 0.99, 1e-8
# Fix: tie the late alpha_R to the late step size. The original used `1 * a`,
# i.e. the ten times larger step size of the initial phase.
a_R_late = 1 * a_late

# settings for GROUSE
a_grouse = 0.0001        # third argument of the Grouse constructor
max_iter_grouse = 3000   # number of passes over all T time points


# Binary observation mask of shape (T, p): mask[t, j] == 1 marks variable j as observed at time t.
mask = np.zeros((T, p))
for t in range(T):
    # first observation window whose end time lies beyond t (None once all windows have ended)
    window = next((k for k, t_end in enumerate(obs_time) if t < t_end), None)
    if window is not None:
        mask[t, sub_pops[obs_pops[window]]] = 1

# attach the mask to the observation scheme and display it (variables x time)
obs_scheme.mask = mask
plt.imshow(mask.T)
plt.show()

# number of independent repetitions of the experiment
num_runs = 500
# per-run subspace projection errors; columns: multi-lag fit, single-lag fit, GROUSE
res = np.zeros([num_runs, 3])
# per-run loss evaluated at the true parameters; columns: multi-lag, single-lag
rgt = np.zeros([num_runs, 2])
for run in range(num_runs):

    print('run {}/{}'.format(run + 1, num_runs))
    
    # draw system matrices
    # lag_range is reset first because the single-lag fit further down overwrites it
    lag_range = np.array(lag_range_full)
    gen_args = (p, n, lag_range, T, nr,
                eig_m_r, eig_M_r, eig_m_c, eig_M_c,
                mmap, chunksize, data_path)
    pars_true, x, y, Qs, idx_a, idx_b = gen_data(*gen_args, snr=snr)

    # pars_true['X'] stacks A^k Pi vertically for every lag k of the full lag range
    A_true, Pi_true = pars_true['A'], pars_true['Pi']
    pars_true['X'] = np.vstack([np.linalg.matrix_power(A_true, lag).dot(Pi_true)
                                for lag in lag_range_full])

    # fit our model with multiple time-lags
    print('\n - multiple lags')
    # keyword arguments shared by both optimisation phases of this fit
    shared_m = dict(lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
                    obs_scheme=obs_scheme, verbose=verbose)
    pars_est_m = 'default'
    t = time.time()
    # phase 1: short run with the initial settings, starting from the 'default' initialisation
    _, pars_est_m, traces_m = run_bad(init=pars_est_m,
                                      alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e,
                                      max_iter=max_iter, batch_size=batch_size,
                                      max_zip_size=max_zip_size, **shared_m)
    # phase 2: continue from the phase-1 estimate with the *_late settings
    _, pars_est_m, traces_m2 = run_bad(init=pars_est_m,
                                       alpha_C=a_late, alpha_R=a_R_late,
                                       b1_C=b1_late, b2_C=b2_late, e_C=e_late,
                                       max_iter=max_iter_late, batch_size=batch_size_late,
                                       max_zip_size=max_zip_size_late, **shared_m)
    # concatenate the first two trace arrays of both phases
    traces_m = tuple(np.hstack((traces_m[part], traces_m2[part])) for part in (0, 1))
    t_m = time.time() - t
    print_slim(Qs,lag_range,pars_est_m,idx_a,idx_b,traces_m,mmap,data_path)
    print('fitting time was ', t_m, 's')

    # loss of the true parameters under the same lag range, as a reference for the fitted loss
    truth_m = {key: pars_true[key] for key in ('C', 'X', 'Pi', 'R')}
    rgt[run, 0] = f_l2_Hankel_nl(lag_range=lag_range, Qs=Qs,
                                 idx_grp=idx_grp, co_obs=co_obs,
                                 idx_a=idx_a, idx_b=idx_b, **truth_m)
    print('final error: ' + str(traces_m[0][-1]))
    print('ground-truth reference error: ' + str(rgt[run,0]))


    # fit our model with single time-lag
    print('\n - single lag')
    # from here on lag_range holds lag 0 only
    lag_range = np.array([0])
    pars_est_s = 'default'
    # keyword arguments shared by both optimisation phases of this fit
    shared_s = dict(lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
                    obs_scheme=obs_scheme, verbose=verbose)
    t = time.time()
    # phase 1: short run with the initial settings, starting from the 'default' initialisation
    _, pars_est_s, traces_s = run_bad(init=pars_est_s,
                                      alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e,
                                      max_iter=max_iter, batch_size=batch_size,
                                      max_zip_size=max_zip_size, **shared_s)
    # phase 2: continue from the phase-1 estimate with the *_late settings
    _, pars_est_s, traces_s2 = run_bad(init=pars_est_s,
                                       alpha_C=a_late, alpha_R=a_R_late,
                                       b1_C=b1_late, b2_C=b2_late, e_C=e_late,
                                       max_iter=max_iter_late, batch_size=batch_size_late,
                                       max_zip_size=max_zip_size_late, **shared_s)
    # concatenate the first two trace arrays of both phases
    traces_s = tuple(np.hstack((traces_s[part], traces_s2[part])) for part in (0, 1))
    t_s = time.time() - t
    print_slim(Qs,lag_range,pars_est_s,idx_a,idx_b,traces_s,mmap,data_path)
    print('fitting time was ', t_s, 's')

    # loss of the true parameters under the single-lag range, as a reference for the fitted loss
    truth_s = {key: pars_true[key] for key in ('C', 'X', 'Pi', 'R')}
    rgt[run, 1] = f_l2_Hankel_nl(lag_range=lag_range, Qs=Qs,
                                 idx_grp=idx_grp, co_obs=co_obs,
                                 idx_a=idx_a, idx_b=idx_b, **truth_s)

    print('final error: ' + str(traces_s[0][-1]))
    print('ground-truth reference error: ' + str(rgt[run,1]))

    # fit GROUSE
    t = time.time()
    print('\n - GROUSE')
    tracker = Grouse(p, n, a_grouse)
    # error[k] holds the projection error of tracker.U after the k-th pass over the data
    error = np.zeros(max_iter_grouse)
    report_every = max_iter_grouse // 10
    for sweep in range(max_iter_grouse):
        if verbose and np.mod(sweep, report_every) == 0:
            print('finished % ' + str((100 * sweep) // max_iter_grouse))
        # one pass: feed every time point once, in a freshly shuffled order,
        # together with its row of the observation mask
        for t_idx in np.random.permutation(T):
            tracker.consume(y[t_idx, :].reshape(p, 1), mask[t_idx, :].reshape(p, 1))

        error[sweep] = calc_subspace_proj_error(pars_true['C'], tracker.U)
    t_g = time.time() - t
    pars_est_g = dict(C=tracker.U)
    
    
    # Subspace projection errors of both fits w.r.t. the true C. They are computed once
    # and reused for the result row and the figure titles (originally each was computed twice).
    err_m = calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])
    err_s = calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])
    res[run,:] = np.array([err_m, err_s, error[-1]])

    # Per-run diagnostic figure: error traces of the two fits and of GROUSE, on log-log axes.
    plt.figure(figsize=(20,10))
    plt.subplot(1,3,1)
    plt.loglog(traces_m[0])
    # Fix: 'norm. SE' names the plotted error values, so it labels the y-axis
    # (the original put it on the x-axis, which is the trace index).
    plt.ylabel('norm. SE')
    plt.title('final error multiple time-lags: ' + str(err_m))
    plt.subplot(1,3,2)
    plt.loglog(traces_s[0])
    plt.ylabel('norm. SE')
    plt.title('final error single time-lag: ' + str(err_s))
    plt.subplot(1,3,3)
    # x runs from 1 so that the first pass is representable on the logarithmic axis
    plt.loglog(range(1,max_iter_grouse+1), error)
    plt.title('final error GROUSE: ' + str(error[-1]))
    plt.show()
    

    # Everything needed to inspect this run later. Note that res and rgt are the full
    # (num_runs x ...) arrays, so rows of runs not yet executed are still zero.
    save_dict = {'p' : p,
                 'n' : n,
                 'T' : T,
                 'snr' : snr,
                 'obs_scheme' : obs_scheme,
                 'lag_range' : lag_range_full,
                 'x' : x,
                 'y' : y,
                 'pars_true' : pars_true,
                 'pars_est_s' : pars_est_s,
                 'pars_est_m' : pars_est_m,
                 'pars_est_g' : pars_est_g,
                 'res' : res,
                 'rgt' : rgt,
                 't_s' : t_s,
                 't_m' : t_m,
                 't_g' : t_g,
                 'traces_m' : traces_m,
                 'traces_s' : traces_s,
                 'traces_g' : error
                }
    # Fix: use the builtin int. np.int was only an alias for it and was removed in
    # NumPy 1.24, where the original line raises AttributeError.
    snr_tag = str(int(np.mean(snr)//1))
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + snr_tag + '_run' + str(run) + '_partial_dat'
    # The dict is passed positionally, so np.savez stores it (pickled) under the key 'arr_0'
    # and appends '.npz' to the file name.
    np.savez(data_path + file_name, save_dict)

    

In [None]:
# Keep the rows of the runs that precede the current loop index `run`.
# NOTE(review): this drops row `run` itself. That is right when the loop above was
# interrupted in the middle of run `run`, but it also discards the last run when the
# loop finished normally -- confirm which is intended.
res = res[:run,:]
rgt = rgt[:run,:]
algorithms = ['SSID 4 lags', 'SSID 1 lag', 'GROUSE']

# Pairwise comparison of per-run projection errors: one figure per (x, y) pair of
# columns of res. The black diagonal marks equal error; points below it are runs
# where the y-axis algorithm has the smaller error.
# Fix: the plt.hold(True) calls are gone. pyplot.hold was removed in Matplotlib 3.0,
# and successive plot calls draw into the same axes by default. The argument-less
# plt.plot() calls, which draw nothing, are dropped as well.
for col_x, col_y in ((2, 0), (2, 1), (1, 0)):
    plt.figure(figsize=(6,6))
    plt.plot([0,1], [0,1], 'k')
    plt.plot(res[:,col_x], res[:,col_y], '.')
    plt.xlabel(algorithms[col_x])
    plt.ylabel(algorithms[col_y])

# One histogram of projection errors per algorithm, side by side with identical
# bin edges (20 bins of width 0.05 on [0, 1]).
hist_bins = np.linspace(0, 1, 21)
plt.figure(figsize=(14,4))
for col in range(3):
    plt.subplot(1, 3, col + 1)
    plt.hist(res[:, col], bins=hist_bins)
    plt.title('projection errors ' + algorithms[col] )
plt.show()
