In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# NOTE(review): absolute, machine-specific directory; it is handed to gen_data /
# print_slim / plot_slim below, so adjust it before running on another machine.
data_path = '/home/marcel/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'

# NOTE(review): the seed is disabled, so every execution draws different systems.
#np.random.seed(0)

# define problem size:
#   p - observed dimension (mask is (T, p), GROUSE tracks a p x n basis)
#   n - latent dimension
#   T - number of time steps
p, n, T = 500, 20, 100
lag_range_full = np.arange(5)
lag_range = lag_range_full.copy()
kl_ = np.max(lag_range)+1

nr = 0 # number of real eigenvalues
snr = (1., 1.)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.95, 0.90, 0.95

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations (both subpopulations cover all p variables here)
sub_pops = (np.arange(0,p), np.arange(0,p))
obs_pops = np.array([0,1])
obs_time = np.array([T//2,T])

obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

# Fix: pass the obs_idx computed by get_subpop_stats. The original passed
# obs_time for this argument, which left the computed obs_idx unused.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops, obs_pops=obs_pops,
                               obs_time=obs_time, obs_idx=obs_idx,
                               idx_grp=idx_grp)


# settings for quick initial SGD fitting phase for our model
batch_size, max_zip_size, max_iter = 1, np.inf, 100
a, b1, b2, e = 0.01, 0.99, 0.99, 1e-8
a_R = 1 * a

# settings for the second (late) fitting phase for our model:
# larger batches, more iterations, warm-started from the initial phase
batch_size_late, max_zip_size_late, max_iter_late = 25, np.inf, 400
a_late, b1_late, b2_late, e_late = 0.01, 0.9, 0.99, 1e-8
a_R_late = 1 * a_late

# settings for GROUSE
a_grouse = 0.002
max_iter_grouse = 500


# Build the (T, p) observation mask: at time t, the subpopulation that is active
# is the first one whose switch time obs_time[i] has not been reached yet.
mask = np.zeros((T,p))
for t in range(T):
    for i in range(len(obs_time)):
        if t < obs_time[i]:
            mask[t,sub_pops[obs_pops[i]]] = 1
            break

obs_scheme.mask = mask
plt.imshow(mask.T)
plt.show()

num_runs = 50
# res[run] : final subspace projection errors of (multi-lag fit, single-lag fit, GROUSE)
# rgt[run] : loss evaluated at the ground-truth parameters for (multi-lag, single-lag)
res = np.zeros((num_runs, 3))
rgt = np.zeros((num_runs, 2))
for run in range(num_runs):

    print('run ' + str(run+1) + '/' + str(num_runs))

    # draw system matrices
    lag_range = lag_range_full.copy()
    pars_true, x, y, Qs, idx_a, idx_b = gen_data(p,n,lag_range,T, nr,
                                                 eig_m_r, eig_M_r,
                                                 eig_m_c, eig_M_c,
                                                 mmap, chunksize,
                                                 data_path,
                                                 snr=snr, whiten=whiten)

    # fully observed data for this comparison (overrides the mask built during setup)
    mask = np.ones((T,p))
    #for t in range(T):
    #    mask[t,np.random.choice(p,num_unobs,replace=False)] = 0
    obs_scheme.mask = mask
    plt.imshow(mask.T)
    plt.show()

    # stack A^k Pi for every lag k in lag_range_full
    pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],k).dot(pars_true['Pi']) for k in lag_range_full])

    # fit our model with multiple time-lags
    print('\n - multiple lags')
    pars_est_m = 'default'
    t = time.time()

    # column 0: multi-lag fit, column 1: single-lag fit
    proj_errors = np.zeros((max_iter, 2))
    def pars_track(C,X,R,t):
        # callback handed to run_bad: record the projection error of the current C
        proj_errors[t,0] = calc_subspace_proj_error(pars_true['C'], C)
    # phase 1: initial fit with the early-phase settings
    _, pars_est_m, traces_m = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_m,
                                          alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,
                                          batch_size=batch_size,verbose=verbose, max_zip_size=max_zip_size,
                                          pars_track=pars_track)
    proj_errors_late = np.zeros((max_iter_late, 2))
    def pars_track(C,X,R,t):
        proj_errors_late[t,0] = calc_subspace_proj_error(pars_true['C'], C)
    # phase 2: continue from the phase-1 estimate with the late-phase settings
    _, pars_est_m, traces_m2 = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_m,
                                          alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                          max_iter=max_iter_late,batch_size=batch_size_late,
                                          verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
    # concatenate the traces of both phases
    traces_m = (np.hstack((traces_m[0], traces_m2[0])), np.hstack((traces_m[1], traces_m2[1])))
    t_m = time.time() - t
    print_slim(Qs,lag_range,pars_est_m,idx_a,idx_b,traces_m,mmap,data_path)
    print('fitting time was ', t_m, 's')
    print('rank of final C_est: ', sp.linalg.orth(pars_est_m['C']).shape[1])
    print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))

    # loss at the true C and R, with empirical (lagged) covariances of the latent x
    rgt[run, 0] = f_l2_Hankel_nl(C=pars_true['C'],
                                 X=np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]),
                                 Pi=np.cov(x.T),
                                 R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                                 idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)
    print('final error: ' + str(traces_m[0][-1]))
    print('ground-truth reference error: ' + str(rgt[run,0]))


    # fit our model with single time-lag
    print('\n - single lag')
    lag_range, pars_est_s = np.array([0]), 'default'
    t = time.time()
    def pars_track(C,X,R,t):
        proj_errors[t,1] = calc_subspace_proj_error(pars_true['C'], C)

    _, pars_est_s, traces_s = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_s,
                                          alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,batch_size=batch_size,
                                          verbose=verbose, max_zip_size=max_zip_size,pars_track=pars_track)
    def pars_track(C,X,R,t):
        proj_errors_late[t,1] = calc_subspace_proj_error(pars_true['C'], C)
    _, pars_est_s, traces_s2 = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_s,
                                          alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                          max_iter=max_iter_late,batch_size=batch_size_late,
                                          verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
    traces_s = (np.hstack((traces_s[0], traces_s2[0])), np.hstack((traces_s[1], traces_s2[1])))
    # both columns are filled now: stack early- and late-phase projection errors
    proj_errors = np.vstack((proj_errors, proj_errors_late))
    t_s = time.time() - t
    print_slim(Qs,lag_range,pars_est_s,idx_a,idx_b,traces_s,mmap,data_path)
    print('fitting time was ', t_s, 's')
    print('rank of final C_est: ', sp.linalg.orth(pars_est_s['C']).shape[1])
    print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))

    rgt[run, 1] = f_l2_Hankel_nl(C=pars_true['C'],
                                 X=np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]),
                                 Pi=np.cov(x.T),
                                 R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                                 idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)

    print('final error: ' + str(traces_s[0][-1]))
    print('ground-truth reference error: ' + str(rgt[run,1]))

    # fit GROUSE
    t = time.time()
    print('\n - GROUSE')
    tracker = Grouse(p, n, a_grouse )
    error = np.zeros(max_iter_grouse)
    for i in range(max_iter_grouse):
        if verbose and np.mod(i,max_iter_grouse//10) == 0:
            print('finished % ' + str((100*i)//max_iter_grouse))
        # one pass over all T samples in random order
        idx = np.random.permutation(T)
        for j in range(T):
            tracker.consume(y[idx[j],:].reshape(p,1), mask[idx[j],:].reshape(p,1))

        # projection error of the tracked basis after this pass
        error[i] = calc_subspace_proj_error(pars_true['C'], tracker.U)
    t_g = time.time() - t
    pars_est_g = {'C' : tracker.U}


    res[run,:] = np.array([ calc_subspace_proj_error(pars_true['C'], pars_est_m['C']),
                            calc_subspace_proj_error(pars_true['C'], pars_est_s['C']),
                            error[-1]])
    plt.figure(figsize=(20,10))
    plt.subplot(1,3,1)
    plt.loglog(traces_m[0])
    # 'norm. SE' names the plotted quantity, so it belongs on the y-axis
    plt.ylabel('norm. SE')
    plt.title('final error multiple time-lags: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))
    plt.subplot(1,3,2)
    plt.loglog(traces_s[0])
    plt.ylabel('norm. SE')
    plt.title('final error single time-lag: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))
    plt.subplot(1,3,3)
    plt.loglog(range(1,max_iter_grouse+1), error)
    plt.title('final error GROUSE: ' + str(error[-1]))
    plt.show()

    # projection error per iteration: column 0 = multi-lag, column 1 = single-lag.
    # plt.hold() was removed in Matplotlib 3.0; successive plot calls already
    # draw into the same axes, so no replacement is needed.
    plt.figure(figsize=(20,10))
    plt.plot(proj_errors[:,0])
    plt.plot(proj_errors[:,1])
    plt.show()

    save_dict = {'p' : p,
                 'n' : n,
                 'T' : T,
                 'snr' : snr,
                 'obs_scheme' : obs_scheme,
                 'lag_range' : lag_range_full,
                 'x' : x,
                 'y' : y,
                 'pars_true' : pars_true,
                 'pars_est_s' : pars_est_s,
                 'pars_est_m' : pars_est_m,
                 'pars_est_g' : pars_est_g,
                 'res' : res,
                 'rgt' : rgt,
                 't_s' : t_s,
                 't_m' : t_m,
                 't_g' : t_g,
                 'traces_m' : traces_m,
                 'traces_s' : traces_s,
                 'traces_g' : error
                }
    # np.int was removed in NumPy 1.24; the builtin int gives the same result here
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run) + '_partial_dat'
    # saving is currently disabled, so save_dict / file_name are built but not written
    #np.savez(data_path + file_name, save_dict)

    

In [None]:
# Left column: empirical time-lagged covariances of the latent x.
# Right column: the corresponding n x n blocks of the multi-lag estimate X.
# Each panel is divided by the outer product of the respective standard deviations.
plt.figure(figsize=(10,10))
kl_ = np.max(lag_range_full)
num_lags = len(lag_range_full)
x_ref = x[:-(kl_+1), :]   # un-shifted segment shared by all lags

tmp = np.sqrt(np.diag(np.cov(x_ref.T)))
print('empirical std: ' + str(tmp))
tmp = np.outer(tmp,tmp)
tmp_est = np.sqrt(np.diag(pars_est_m['X'][:n, :]))
tmp_est = np.outer(tmp_est, tmp_est)

for k, k_ in enumerate(lag_range_full):
    x_lag = x[k_:-(kl_+1)+k_, :]   # same length as x_ref, shifted by k_ steps
    cov_emp = np.cov(x_lag.T, x_ref.T)[:n,n:]
    cov_est = pars_est_m['X'][k*n:(k+1)*n,:]

    plt.subplot(num_lags, 2, 2*k+1)
    plt.imshow(cov_emp/tmp, interpolation='None')
    plt.colorbar()
    plt.title('k = ' + str(k))

    plt.subplot(num_lags, 2, 2*k+2)
    plt.imshow(cov_est / tmp_est, interpolation='None')
    plt.colorbar()
    plt.title('k = ' + str(k))

plt.show()

# first two latent coordinates against each other, and all latent traces over time
plt.figure(figsize=(7,3))
plt.subplot(1,2,1)
plt.plot(x[:,0], x[:,1], '.')
plt.subplot(1,2,2)
plt.plot(x)
plt.show()
    

In [None]:
# plot_slim called with the ground-truth parameters and the single-lag traces
plot_slim(
    Qs, lag_range_full, pars_true,
    idx_a, idx_b,
    traces_s, mmap, data_path,
)

In [None]:
# moduli of the eigenvalues of the true dynamics matrix A, in ascending order
np.sort(
    np.absolute(np.linalg.eigvals(pars_true['A']))
)

In [None]:
# display the true Q and Pi side by side
(pars_true['Q'], pars_true['Pi'])

In [None]:
# Keep only the rows of runs that actually finished.
# `run` is the index of the last run that was *started*. The original slice
# res[:run] silently dropped that run even when it had completed (e.g. after
# the loop ran through all num_runs). Row `run` is kept when it has been
# filled in; the bounds check keeps this cell idempotent on re-execution.
n_done = run + 1 if (run < res.shape[0] and np.any(res[run, :])) else run
res = res[:n_done,:]
rgt = rgt[:n_done,:]
algorithms = ['SSID 4 lags', 'SSID 1 lag', 'GROUSE']

# Pairwise scatter plots of the final projection errors against the diagonal.
# Each tuple is (column on the x-axis, column on the y-axis).
# plt.hold() was removed in Matplotlib 3.0 and is not needed; the no-op
# plt.plot() calls that ended each figure are replaced by plt.show().
for col_x, col_y in ((2, 0), (2, 1), (1, 0)):
    plt.figure(figsize=(6,6))
    plt.plot([0,1], [0,1], 'k')
    plt.plot(res[:,col_x], res[:,col_y], '.')
    plt.xlabel(algorithms[col_x])
    plt.ylabel(algorithms[col_y])
    plt.show()

# histograms of the final projection errors, one panel per algorithm
plt.figure(figsize=(14,4))
for col, name in enumerate(algorithms):
    plt.subplot(1,3,col+1)
    plt.hist(res[:,col], bins=np.linspace(0,1,21))
    plt.title('projection errors ' + name )
plt.show()


In [None]:
    # Loss at the true C and R, with the empirical covariance of x used for both
    # X and Pi and the empirical covariance of y as the only target matrix.
    # lag_range is whatever the last executed cell left it at (the main loop
    # ends with np.array([0])).
    ref_err = f_l2_Hankel_nl(C=pars_true['C'],
                             X=np.cov(x.T),
                             Pi=np.cov(x.T),
                             R=pars_true['R'],
                             lag_range=lag_range,
                             Qs=[np.cov(y.T)],
                             idx_grp=idx_grp,co_obs=co_obs,
                             idx_a=idx_a,idx_b=idx_b)
    print('ground-truth reference error: ' + str(ref_err))

In [None]:
# print_slim called with the ground-truth parameters and the multi-lag traces
print_slim(
    Qs, lag_range_full, pars_true,
    idx_a, idx_b,
    traces_m, mmap, data_path,
)

In [None]:
%matplotlib inline

# Panel 1: empirical covariance of y.
# Panels 2 and 3: C X_0 C^T + diag(R) for the single-lag and multi-lag estimates,
# where X_0 is the first n rows of the estimated X.
plt.figure(figsize=(20,8))
plt.subplot(1,3,1)
plt.imshow(np.cov(y.T), interpolation='None')
plt.colorbar()

for panel, pars_est in enumerate((pars_est_s, pars_est_m), start=2):
    plt.subplot(1,3,panel)
    C_est = pars_est['C']
    cov_model = np.dot(np.dot(C_est, pars_est['X'][:n,:]), C_est.T) + np.diag(pars_est['R'])
    plt.imshow(cov_model, interpolation='None')
    plt.colorbar()

plt.show()

# same expression evaluated at the ground-truth parameters
C_true = pars_true['C']
cov_true = np.dot(np.dot(C_true, pars_true['Pi']), C_true.T) + np.diag(pars_true['R'])
plt.imshow(cov_true, interpolation='None')
plt.colorbar()
plt.show()

# first entry of the target matrices Qs
plt.imshow(Qs[0], interpolation='None')
plt.colorbar()
plt.show()


In [None]:
%matplotlib notebook
from mpl_toolkits.mplot3d import Axes3D
# 3-D scatter of the first three observed coordinates:
# blue  - random latent points mapped through the single-lag estimate of C
# red   - the same latent points mapped through the true C
# black - the observed data y
pars_est = pars_est_s
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111, projection='3d')
X = np.random.normal(size=(n, 100))
Y = np.dot(pars_est['C'], X)
Yt = np.dot(pars_true['C'], X)
ax.scatter(Y[0], Y[1], Y[2], c='b')
ax.scatter(Yt[0], Yt[1], Yt[2], c='r')
ax.scatter(y.T[0], y.T[1], y.T[2], c='k')


In [None]:
from mpl_toolkits.mplot3d import Axes3D
# 3-D scatter of the first three observed coordinates:
# default colour - random latent points mapped through the multi-lag estimate of C
# red            - the same latent points mapped through the true C
# black          - the observed data y
pars_est = pars_est_m
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111, projection='3d')
X = np.random.normal(size=(n, 100))
Y = np.dot(pars_est['C'], X)
Yt = np.dot(pars_true['C'], X)
ax.scatter(Y[0], Y[1], Y[2])
#plt.hold(True)
ax.scatter(Yt[0], Yt[1], Yt[2], color='r')
ax.scatter(y.T[0], y.T[1], y.T[2], color='k')
plt.show()

In [None]:
from mpl_toolkits.mplot3d import Axes3D
# 3-D scatter of the first three observed coordinates:
# default colour - random latent points mapped through the GROUSE basis
# red            - the same latent points mapped through the true C
# black          - the observed data y
pars_est = pars_est_g
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(111, projection='3d')
X = np.random.normal(size=(n, 100))
Y = np.dot(pars_est['C'], X)
Yt = np.dot(pars_true['C'], X)
ax.scatter(Y[0], Y[1], Y[2])
#plt.hold(True)
ax.scatter(Yt[0], Yt[1], Yt[2], color='r')
ax.scatter(y.T[0], y.T[1], y.T[2], color='k')
plt.show()

In [None]:
# (C X_0 C^T for the single-lag fit,
#  C X_0 C^T for the multi-lag fit,
#  empirical covariance of y minus diag(true R))
(
    np.dot(np.dot(pars_est_s['C'], pars_est_s['X'][:n,:]), pars_est_s['C'].T),
    np.dot(np.dot(pars_est_m['C'], pars_est_m['X'][:n,:]), pars_est_m['C'].T),
    np.cov(y.T) - np.diag(pars_true['R']),
)