In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim
from ssidid.utility import get_subpop_stats, gen_data

from subtracking import Grouse, calc_subspace_proj_error

# np.random.seed(0)  # uncomment for reproducible runs

# ---------------------------------------------------------------------------
# Problem size
# ---------------------------------------------------------------------------
# p, n and T are handed to gen_data(); p and n are also handed to Grouse().
# NOTE(review): presumably p = observed dimension, n = latent dimension,
# T = number of time points -- confirm against gen_data.
p = 5
n = 2
T = 100

# Full set of time-lags; `lag_range` starts out as an independent copy of it.
lag_range_full = np.arange(0, 4)
lag_range = lag_range_full.copy()
kl_ = np.max(lag_range) + 1

# Parameters forwarded to gen_data() for drawing the ground-truth system.
nr = 0  # number of real eigenvalues
snr = (1., 1.)
# NOTE(review): presumably (min, max) eigenvalue magnitudes for the real (_r)
# and complex (_c) eigenvalues -- confirm against gen_data.
eig_m_r = 0.9
eig_M_r = 0.99
eig_m_c = 0.9
eig_M_c = 0.99

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# ---------------------------------------------------------------------------
# I/O settings
# ---------------------------------------------------------------------------
mmap = True
chunksize = np.min((p, 2000))  # never larger than p
data_path = '../fits/'
save_file = 'test'
verbose = False

# ---------------------------------------------------------------------------
# Subpopulations: one subpopulation holding all p indices
# ---------------------------------------------------------------------------
sub_pops = (np.arange(0, p),)
obs_pops = np.array([0])
obs_time = np.array([T])

# Only four of the nine values returned by get_subpop_stats are kept.
subpop_stats = get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = subpop_stats

# ---------------------------------------------------------------------------
# Quick initial SGD fitting phase for our model
# ---------------------------------------------------------------------------
batch_size = 50
max_zip_size = np.inf
max_iter = 20
# Forwarded to run_bad() as alpha_C, b1_C, b2_C and e_C respectively.
a = 0.001
b1 = 0.9
b2 = 0.99
e = 1e-8
a_R = 1 * a  # forwarded as alpha_R; currently identical to `a`

# ---------------------------------------------------------------------------
# Fine-tuning of our model fits with batch gradients
# ---------------------------------------------------------------------------
eps_conv = 0.999
a_batch = 0.01
max_iter_batch = 2000

# ---------------------------------------------------------------------------
# GROUSE
# ---------------------------------------------------------------------------
a_grouse = 0.00002       # third constructor argument of Grouse
max_iter_grouse = 2000   # number of full passes over the T samples



def _fit_two_phase(lags, y, Qs, idx_a, idx_b):
    """Fit our model for the time-lags `lags` in two phases and report the result.

    Phase 1 is a short stochastic run (mini-batches of `batch_size`, at most
    `max_iter` iterations) starting from the 'default' initialisation.
    Phase 2 continues from the phase-1 estimate with batch gradients
    (`batch_size=None`) for up to `max_iter_batch` iterations, using
    `eps_conv` as the convergence setting.  Only phase 2 is timed, matching
    the timing that was reported before this helper was extracted.

    Parameters
    ----------
    lags : time-lags passed to run_bad() / print_slim() as `lag_range`.
    y, Qs, idx_a, idx_b : the corresponding outputs of gen_data().

    Returns
    -------
    (pars_est, traces) : second and third return values of the phase-2
        run_bad() call.
    """
    # Arguments common to both phases; everything is passed by keyword.
    shared = dict(lag_range=lags, n=n, y=y, Qs=Qs, Om=Om,
                  idx_a=idx_a, idx_b=idx_b,
                  sub_pops=sub_pops, idx_grp=idx_grp, co_obs=co_obs,
                  obs_idx=obs_idx, obs_pops=obs_pops, obs_time=obs_time,
                  verbose=verbose)

    # phase 1: quick stochastic initialisation
    _, pars_init, _ = run_bad(init='default',
                              alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e,
                              max_iter=max_iter, batch_size=batch_size,
                              max_zip_size=max_zip_size, **shared)

    # phase 2: fine-tuning with batch gradients (timed)
    t_start = time.time()
    _, pars_est, traces = run_bad(init=pars_init,
                                  alpha_C=a_batch, max_iter=max_iter_batch,
                                  batch_size=None, max_zip_size=np.inf,
                                  eps_conv=eps_conv, **shared)

    print_slim(Qs, lags, pars_est, idx_a, idx_b, traces, mmap, data_path)
    print('fitting time was ', time.time() - t_start, 's')

    return pars_est, traces


num_runs = 20
for run in range(num_runs):

    print('run ' + str(run+1) + '/' + str(num_runs))

    # Data is always generated for the full lag range.  Rebinding it here
    # makes that explicit instead of relying on what the previous iteration
    # happened to leave in `lag_range`.
    lag_range = lag_range_full.copy()

    # draw system matrices
    pars_true, x, y, Qs, idx_a, idx_b = gen_data(p, n, lag_range, T, nr,
                                                 eig_m_r, eig_M_r,
                                                 eig_m_c, eig_M_c,
                                                 mmap, chunksize,
                                                 data_path, snr=snr)

    # fit our model with single time-lag
    pars_est_s, traces_s = _fit_two_phase(np.array([0]), y, Qs, idx_a, idx_b)

    # fit our model with multiple time-lags
    pars_est_m, traces_m = _fit_two_phase(lag_range_full.copy(),
                                          y, Qs, idx_a, idx_b)

    # fit GROUSE
    nan_count = 0  # number of coordinates masked out per sample
    tracker = Grouse(p, n, a_grouse)
    error = np.zeros(max_iter_grouse)
    # guard against a zero modulus when max_iter_grouse < 10
    progress_step = max(1, max_iter_grouse // 10)
    for i in range(max_iter_grouse):
        if verbose and i % progress_step == 0:
            print('finished % ' + str((100*i)//max_iter_grouse))
        # one pass over all T samples in random order
        for sample in np.random.permutation(T):
            # sampling_vec has one entry per observed dimension (p rows), so
            # the masked coordinates must be drawn from range(p), not range(n).
            # NOTE(review): presumably 1 = observed, 0 = missing -- confirm
            # against Grouse.consume.
            sampling_vec = np.ones((p, 1))
            sampling_vec[np.random.choice(p, nan_count, replace=False), 0] = 0

            tracker.consume(y[sample, :].reshape(p, 1), sampling_vec)

        # subspace error against the true C after each pass
        error[i] = calc_subspace_proj_error(pars_true['C'], tracker.U)
    pars_est_g = {'C' : tracker.U}

    # Compare the three fits side by side.  The plotted values are on the
    # y-axis, so that is where the 'norm. SE' label belongs.
    plt.figure(figsize=(20,10))
    plt.subplot(1,3,1)
    plt.loglog(traces_m[0])
    plt.ylabel('norm. SE')
    plt.title('final error multiple time-lags: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))
    plt.subplot(1,3,2)
    plt.loglog(traces_s[0])
    plt.ylabel('norm. SE')
    plt.title('final error single time-lag: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))
    plt.subplot(1,3,3)
    plt.loglog(range(1,max_iter_grouse+1), error)
    plt.title('final error GROUSE: ' + str(error[-1]))
    plt.show()


(p,n,k+l,T) =  (5, 2, 4, 100) 

run 1/20
using batch gradients - switching to plain gradient descent