# Experiment 'e1': recovering LDS parameters from varying amounts of data

- system size fixed, i.e. p = 1k, n = 20, and signal-to-noise ratio such that R gives 90% of total variance. 
- fit first with the dynamics-agnostic parameterization, then switch to the linearized parameterization, initializing in particular $A$ with values extracted from the agnostic fit
- direct comparison with GROUSE on the subspace-identification task (measured by principal angles).

## notes:
- This is a master file. The individual runs for the different data-set lengths $T \in \{10^3, 10^4, 10^5\}$ were executed on copies of this file that may have been slightly altered (e.g. a smaller maximum batch size and more epochs for small $T$). These copies should be backed up in the lab Dropbox folder.

# Load stored full data, extract observed data for this experiment

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from subtracking import Grouse, calc_subspace_proj_error

data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/'
mmap, verbose = True, True

T = 10000                  # number of time steps used in this run (prefix of the stored data)
lag_range = np.arange(20)  # time lags entering the Hankel loss

# ground-truth parameters and index sets stored by the 'Data generation' section
run = '_e1_init'
p, T_full, n, snr = 1000, 100000, 20, (9., 9.)
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T_full) + 'snr' + str(int(np.mean(snr)//1)) + run
# np.savez stored a dict, i.e. a pickled 0-d object array under 'arr_0'; NumPy >= 1.16.3
# refuses to load it unless allow_pickle=True. Only load files written by this pipeline.
load_file = np.load(data_path + file_name + '.npz', allow_pickle=True)['arr_0'].tolist()
pars_true = load_file['pars_true']
idx_a, idx_b = load_file['idx_a'].copy(), load_file['idx_b'].copy()

# Copy the first T samples of the full (T_full, p) recording into a new memmap file 'y'.
# (np.float was an alias of float64 and was removed in NumPy 1.24.)
y_full = np.memmap(data_path + 'y_full', dtype=np.float64, mode='r', shape=(T_full, p))
y = np.memmap(data_path + 'y', dtype=np.float64, mode='w+', shape=(T, p))
y[:] = y_full[:T, :]
del y_full
del y  # release the handle before the file is re-opened below

chunksize = np.minimum(p, 100)
if mmap:
    print('ensuring zero-mean data for given observation scheme')
    # subtract column means in chunks of `chunksize` columns to bound memory use
    for i in progprint_xrange(p//chunksize, perline=10):
        y = np.memmap(data_path + 'y', dtype=np.float64, mode='r+', shape=(T, p))
        cols = slice(i*chunksize, (i+1)*chunksize)
        y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
        del y
    if (p//chunksize)*chunksize < p:
        # left-over columns when p is not a multiple of chunksize
        y = np.memmap(data_path + 'y', dtype=np.float64, mode='r+', shape=(T, p))
        cols = slice((p//chunksize)*chunksize, None)
        y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
        del y
    y = np.memmap(data_path + 'y', dtype=np.float64, mode='r', shape=(T, p))
else:
    # `y` was deleted above, so read the truncated data into memory before centring it
    y = np.array(np.memmap(data_path + 'y', dtype=np.float64, mode='r', shape=(T, p)))
    y -= y.mean(axis=0)

run = '_e1_final'

print('re-computing observation scheme')
# one population containing all p units; obs_time=(T,) presumably means it is
# observed for the whole time window -- confirm against ObservationScheme
obs_scheme = ObservationScheme(p=p,
                               T=T,
                               obs_time=(T,),
                               obs_pops=(0,),
                               sub_pops=(np.arange(p),))
obs_scheme.comp_subpop_stats()

W = obs_scheme.comp_coocurrence_weights(lag_range, idx_a=idx_a, idx_b=idx_b, sso=True)

print('re-computing observed covariance matrices')
# time-lagged covariances of the observed data used by the Hankel loss
Qs, Om = f_l2_Hankel_comp_Q_Om(n=n, y=y, lag_range=lag_range, obs_scheme=obs_scheme,
                               W=W, idx_a=idx_a, idx_b=idx_b, sso=True,
                               ts=None, ms=None,
                               mmap=mmap, data_path=data_path)


def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    A and B do not need orthonormal columns: both are orthonormalised with
    ``scipy.linalg.orth`` first, which also drops rank-deficient directions.
    1-D inputs are treated as a single column vector; array-likes are accepted.

    Returns
    -------
    ndarray
        min(rank(A), rank(B)) angles in ascending order, divided by pi/2 so
        that 0 means the directions coincide and 1 means they are orthogonal.
    """
    from scipy import linalg as la

    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    A = la.orth(A)
    B = la.orth(B)
    # the cosines of the principal angles are the singular values of A^T B;
    # only the singular values are needed, so U and V are not computed
    cosines = la.svdvals(A.T.dot(B))
    # guard against round-off pushing a cosine slightly above 1 (arccos -> nan)
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)


# analytic lagged latent covariances X_m = A^m Pi, stacked vertically over the lags
pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'], m).dot(pars_true['Pi'])
                            for m in lag_range])
print('true param. loss: ',
      f_l2_Hankel_nl(C=pars_true['C'], X=pars_true['X'], R=pars_true['R'],
                     Qs=Qs, Om=Om,
                     lag_range=lag_range, ms=range(len(lag_range)),
                     idx_a=idx_a, idx_b=idx_b))
print_slim(Qs, Om, lag_range, pars_true, idx_a, idx_b, None, False, data_path)

# start both fitters from 'default' (presumably their built-in initialisation)
pars_est_g = 'default'
pars_est = 'default'

rnd_seed = 1

# per-stage history: traces[0]/traces[1] collect the traces returned by run_bad,
# traces[2] the principal-angle curves; ts collects the fitting times
ts = []
traces = [[], [], []]



# Fit models

# dynamics-agnostic stage

In [None]:

# dynamics-agnostic ('nl') parameterisation, first stage
parametrization = 'nl'
sso = True

# optimiser settings for the quick initial SGD phase; a/b1/b2/e are passed to run_bad
# as alpha/b1/b2/e (presumably Adam-style step size, moment decays, epsilon -- confirm)
batch_size = 1
max_zip_size = 1000
max_iter = 100
a = 0.01
b1 = 0.98
b2 = 0.99
e = 1e-8
a_decay = 0.98

# one row per iteration: column 0 is a zero placeholder, columns 1..n the principal angles
proj_errors = np.zeros((max_iter, n+1))
def pars_track(pars, t):
    """Tracking callback for run_bad: record principal angles at iteration t.

    pars[0] is used as the current estimate of C; row t of the global
    `proj_errors` becomes [0, angle_1, ..., angle_n] against the true C.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# run the optimiser, initialised from pars_est ('default' or the previous stage's estimate)
_, pars_est, traces_, _, _, _, t = run_bad(
    lag_range=lag_range, n=n, y=y,
    idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, W=W,
    Qs=Qs, Om=Om,
    init=pars_est,
    parametrization=parametrization, sso=sso,
    alpha=a, b1=b1, b2=b2, e=e, a_decay=a_decay,
    max_iter=max_iter, batch_size=batch_size, max_epoch_size=max_zip_size,
    verbose=verbose,
    pars_track=pars_track,
)

# append this stage's history to the running record of all stages
traces[0] += [traces_[0]]
traces[1] += [traces_[1]]
traces[2] += [proj_errors.copy()]
ts += [t]

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

plt.plot(proj_errors[:, 1:])
plt.show()


# second 'nl' stage, continuing from the first: batches of 10, half the step size,
# 150 iterations (optimiser arguments are forwarded to run_bad)
batch_size = 10
max_zip_size = 100
max_iter = 150
a = 0.005
b1 = 0.95
b2 = 0.99
e = 1e-8
a_decay = 0.98

# one row per iteration: column 0 is a zero placeholder, columns 1..n the principal angles
proj_errors = np.zeros((max_iter, n+1))
def pars_track(pars, t):
    """Tracking callback for run_bad: record principal angles at iteration t.

    pars[0] is used as the current estimate of C; row t of the global
    `proj_errors` becomes [0, angle_1, ..., angle_n] against the true C.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# run the optimiser, initialised from the previous stage's estimate (init=pars_est)
_, pars_est, traces_, _, _, _, t = run_bad(
    lag_range=lag_range, n=n, y=y,
    idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, W=W,
    Qs=Qs, Om=Om,
    init=pars_est,
    parametrization=parametrization, sso=sso,
    alpha=a, b1=b1, b2=b2, e=e, a_decay=a_decay,
    max_iter=max_iter, batch_size=batch_size, max_epoch_size=max_zip_size,
    verbose=verbose,
    pars_track=pars_track,
)

# append this stage's history to the running record of all stages
traces[0] += [traces_[0]]
traces[1] += [traces_[1]]
traces[2] += [proj_errors.copy()]
ts += [t]

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

plt.figure(figsize=(12, 6))
plt.plot(proj_errors[:, 1:])
plt.show()



# linear stage

In [None]:

# Initialise the linear ('ln') parameterisation from the dynamics-agnostic fit.
# B is the lower Cholesky factor of the fitted latent covariance: B @ B.T == Pi.
pars_est['B'] = np.linalg.cholesky(pars_est['Pi'])

# pars_est['X'] stacks one n-by-n block per lag vertically, with the same convention
# used for pars_true['X']: X_m = A^m Pi, hence X_{m+1} = A X_m. A acts from the LEFT,
# so the least-squares problem is posed on the transposed blocks,
# X_m^T A^T = X_{m+1}^T, solved jointly over all consecutive lag pairs.
# (Solving X_m M = X_{m+1} instead would give M = Pi^{-1} A Pi, not A.)
X_blocks_T = [pars_est['X'][m*n:(m+1)*n, :].T for m in range(len(lag_range))]
pars_est['A'] = np.linalg.lstsq(np.vstack(X_blocks_T[:-1]),
                                np.vstack(X_blocks_T[1:]), rcond=None)[0].T

# sanity check: the fitted X (top) should resemble A^m Pi built from the extracted A (bottom)
plt.figure(figsize=(20,5))
plt.imshow( pars_est['X'].T )
plt.colorbar()
plt.show()

plt.figure(figsize=(20,5))
plt.imshow( np.vstack([ np.linalg.matrix_power(pars_est['A'],m).dot(pars_est['Pi']) for m in lag_range]).T )
plt.colorbar()
plt.show()

# linear ('ln') parameterisation, first stage after initialisation from the 'nl' fit;
# optimiser arguments are forwarded to run_bad
parametrization = 'ln'
batch_size = 10
max_zip_size = 100
max_iter = 250
a = 0.001
b1 = 0.95
b2 = 0.99
e = 1e-8
a_decay = 0.99

# one row per iteration: column 0 is a zero placeholder, columns 1..n the principal angles
proj_errors = np.zeros((max_iter, n+1))
def pars_track(pars, t):
    """Tracking callback for run_bad: record principal angles at iteration t.

    pars[0] is used as the current estimate of C; row t of the global
    `proj_errors` becomes [0, angle_1, ..., angle_n] against the true C.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# run the optimiser, initialised from the current estimate (init=pars_est)
_, pars_est, traces_, _, _, _, t = run_bad(
    lag_range=lag_range, n=n, y=y,
    idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, W=W,
    Qs=Qs, Om=Om,
    init=pars_est,
    parametrization=parametrization, sso=sso,
    alpha=a, b1=b1, b2=b2, e=e, a_decay=a_decay,
    max_iter=max_iter, batch_size=batch_size, max_epoch_size=max_zip_size,
    verbose=verbose,
    pars_track=pars_track,
)

# append this stage's history to the running record of all stages
traces[0] += [traces_[0]]
traces[1] += [traces_[1]]
traces[2] += [proj_errors.copy()]
ts += [t]

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

plt.figure(figsize=(12, 6))
plt.plot(proj_errors[:, 1:])
plt.show()


In [None]:

# final 'ln' refinement stage: larger batches, small step size, slow decay,
# many iterations (optimiser arguments are forwarded to run_bad)
batch_size = 50
max_zip_size = 20
max_iter = 1250
a = 0.0001
b1 = 0.95
b2 = 0.99
e = 1e-8
a_decay = 0.998

# one row per iteration: column 0 is a zero placeholder, columns 1..n the principal angles
proj_errors = np.zeros((max_iter, n+1))
def pars_track(pars, t):
    """Tracking callback for run_bad: record principal angles at iteration t.

    pars[0] is used as the current estimate of C; row t of the global
    `proj_errors` becomes [0, angle_1, ..., angle_n] against the true C.
    """
    angles = principal_angle(pars_true['C'], pars[0])
    proj_errors[t] = np.concatenate(([0], angles))
            
# run the optimiser, initialised from the current estimate (init=pars_est)
_, pars_est, traces_, _, _, _, t = run_bad(
    lag_range=lag_range, n=n, y=y,
    idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, W=W,
    Qs=Qs, Om=Om,
    init=pars_est,
    parametrization=parametrization, sso=sso,
    alpha=a, b1=b1, b2=b2, e=e, a_decay=a_decay,
    max_iter=max_iter, batch_size=batch_size, max_epoch_size=max_zip_size,
    verbose=verbose,
    pars_track=pars_track,
)

# append this stage's history to the running record of all stages
traces[0] += [traces_[0]]
traces[1] += [traces_[1]]
traces[2] += [proj_errors.copy()]
ts += [t]

print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

plt.figure(figsize=(12, 6))
plt.plot(proj_errors[:, 1:])
plt.show()


In [None]:
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'mmap' : mmap,
             'y' : data_path if mmap else y,
             'pars_true' : pars_true,
             'pars_est' : pars_est,
             'pars_g' : pars_est_g,
             'idx_a' : idx_a,
             'idx_b' : idx_b,
             'traces' : traces,
             'W' : W,
             'Qs' : Qs,
             'Om' : Om,
             'rnd_seed' : rnd_seed  # store the seed actually used, not a hard-coded copy
            }
# np.int was removed in NumPy 1.24; the builtin int() truncates the same way
file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+'snr'+str(int(np.mean(snr)//1))+'_run'+str(run)
# np.savez wraps the dict in a 0-d object array stored as 'arr_0';
# reload with np.load(..., allow_pickle=True)['arr_0'].tolist()
np.savez(data_path + file_name, save_dict)

# Compare with GROUSE

In [None]:
# reset the GROUSE estimate so that the next GROUSE cell constructs a fresh tracker
pars_est_g = 'default'

In [None]:
from subtracking import Grouse, calc_subspace_proj_error
from scipy import linalg as la

# settings for GROUSE
a_grouse = 10  # initial step size; decays as a_grouse / (1 + number of samples consumed)
if pars_est_g == 'default' or 'tracker' not in globals():
    # fresh tracker (also needed for a warm start when no tracker exists yet)
    tracker = Grouse(p, n, a_grouse )
if pars_est_g != 'default':
    # warm start from the previous GROUSE estimate of the subspace
    tracker.U = pars_est_g['C'].copy()
max_epoch_size = 1000  # samples consumed per outer iteration
max_iter_grouse = 1000
get_obs = obs_scheme.gen_get_observed()

# fit GROUSE
print('\n - GROUSE')
tracker.step = a_grouse
ct = 1.
error = np.zeros((max_iter_grouse, n+1))
report_every = max(1, max_iter_grouse // 10)  # avoids a modulo by zero for < 10 iterations
for i in range(max_iter_grouse):
    if verbose and i % report_every == 0:
        print('finished % ' + str((100*i)//max_iter_grouse))
    # random subset of time points, excluding the last max(lag_range)+1 steps
    idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
    for t_obs in idx:
        tracker.consume(y[t_obs,:].reshape(p,1), get_obs(t_obs).reshape(p,1))
        ct += 1
        tracker.step = a_grouse / ct
    # column 0: calc_subspace_proj_error; columns 1..n: normalised principal angles
    error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U),
                          principal_angle(pars_true['C'], tracker.U)))
pars_est_g = {'C' : tracker.U.copy()}
traces_g = [error.copy()]

print('final proj. error (est.): ', str(error[-1][0]))

plt.plot(error[:,1:])
plt.title('subspace proj. error (GROUSE)')
plt.show()

In [None]:
from subtracking import Grouse, calc_subspace_proj_error
from scipy import linalg as la

# settings for GROUSE (second pass with a much smaller step size)
a_grouse = 1e-6  # initial step size; decays as a_grouse / (1 + number of samples consumed)
if pars_est_g == 'default' or 'tracker' not in globals():
    # fresh tracker (also needed for a warm start when no tracker exists yet)
    tracker = Grouse(p, n, a_grouse )
if pars_est_g != 'default':
    # warm start from the previous GROUSE estimate of the subspace
    tracker.U = pars_est_g['C'].copy()
max_epoch_size = 1000  # samples consumed per outer iteration
max_iter_grouse = 1000
get_obs = obs_scheme.gen_get_observed()

# fit GROUSE
print('\n - GROUSE')
tracker.step = a_grouse
ct = 1.
error = np.zeros((max_iter_grouse, n+1))
report_every = max(1, max_iter_grouse // 10)  # avoids a modulo by zero for < 10 iterations
for i in range(max_iter_grouse):
    if verbose and i % report_every == 0:
        print('finished % ' + str((100*i)//max_iter_grouse))
    # random subset of time points, excluding the last max(lag_range)+1 steps
    idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
    for t_obs in idx:
        tracker.consume(y[t_obs,:].reshape(p,1), get_obs(t_obs).reshape(p,1))
        ct += 1
        tracker.step = a_grouse / ct
    # column 0: calc_subspace_proj_error; columns 1..n: normalised principal angles
    error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U),
                          principal_angle(pars_true['C'], tracker.U)))

print('final proj. error (est.): ', str(error[-1][0]))

plt.plot(error[:,1:])
plt.title('subspace proj. error (GROUSE)')
plt.show()

In [None]:
from scipy import linalg as la
# normalised principal angles:
#   blue circles  - true C vs. GROUSE estimate
#   red circles   - true C vs. our estimate
#   green crosses - our estimate vs. GROUSE estimate
plt.plot(principal_angle(pars_true['C'], pars_est_g['C']), color='b', marker='o', linestyle='None')
plt.plot(principal_angle(pars_true['C'], pars_est['C']), color='r', marker='o', linestyle='None')
plt.plot(principal_angle(pars_est['C'], pars_est_g['C']), color='g', marker='x', linestyle='None')
plt.show()

In [None]:
# keep the result of the latest GROUSE run (estimate and error curves) for saving
traces_g = [error.copy()]
pars_est_g = dict(C=tracker.U.copy())

In [None]:
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'mmap' : mmap,
             'y' : data_path if mmap else y,
             'pars_true' : pars_true,
             'pars_est' : pars_est,
             'pars_g' : pars_est_g,
             'idx_a' : idx_a,
             'idx_b' : idx_b,
             'traces' : traces,
             'traces_g' : traces_g,
             'W' : W,
             'Qs' : Qs,
             'Om' : Om,
             'rnd_seed' : rnd_seed  # store the seed actually used, not a hard-coded copy
            }
# np.int was removed in NumPy 1.24; the builtin int() truncates the same way
file_name = 'p'+str(p)+'n'+str(n)+'T'+str(T)+'snr'+str(int(np.mean(snr)//1))+'_e1_final'
# np.savez wraps the dict in a 0-d object array stored as 'arr_0';
# reload with np.load(..., allow_pickle=True)['arr_0'].tolist()
np.savez(data_path + file_name, save_dict)

# Data generation (to be run once!)

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error
from ssidid import progprint_xrange
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om

rnd_seed = 1  # fixed for reproducibility; int(np.mod(time.time(), 1) * 1e9) gives a time-based seed
np.random.seed(rnd_seed)

run = '_test_fullyobs'

# define problem size
lag_range = np.arange(0,20)
kl_ = np.max(lag_range)+1  # number of lags
T_full = 100000
p, n, T = 1000, 20, T_full

nr = 0 # number of real eigenvalues
snr = (9., 9.)
whiten = True
# eigenvalue magnitude ranges, presumably (min, max) for real and complex eigenvalues of A
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.95, 0.99, 0.95, 0.99


# I/O matter
mmap, chunksize = True, np.min((p,100))
verbose=True

# create subpopulations

sso = True
sub_pops = (np.arange(p),)

reps = 1
obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T,
                                sub_pops=sub_pops,
                                obs_pops=obs_pops,
                                obs_time=obs_time)
obs_scheme.comp_subpop_stats()

missing_at_random, frac_obs = False, 0.5
if missing_at_random:
    # np.random.choice requires an integer sample size
    n_obs = int(np.ceil(p * frac_obs))
    mask = np.zeros((T,p))
    for t in range(T):
        # draw a fresh random subset of n_obs observed units for every t inside the observed window
        for i in range(len(obs_time)):
            if t < obs_time[i]:
                mask[t, np.random.choice(p, n_obs, replace=False)] = 1
                break
    obs_scheme.mask = mask
    del mask
else:
    if p*T < 1e8:
        obs_scheme.gen_mask_from_scheme()
        obs_scheme.use_mask = False

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

if p*T < 1e8:
    plt.figure(figsize=(20,10))
    plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
    plt.grid(False)  # a string such as 'off' is truthy and does not reliably disable the grid
    plt.show()

pars_est = 'default'
pars_est_g = 'default'

def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    A and B do not need orthonormal columns: both are orthonormalised with
    ``scipy.linalg.orth`` first, which also drops rank-deficient directions.
    1-D inputs are treated as a single column vector; array-likes are accepted.

    Returns
    -------
    ndarray
        min(rank(A), rank(B)) angles in ascending order, divided by pi/2 so
        that 0 means the directions coincide and 1 means they are orthogonal.
    """
    # imported locally: the data-generation cell does not import scipy.linalg as `la`
    from scipy import linalg as la

    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    A = la.orth(A)
    B = la.orth(B)
    # the cosines of the principal angles are the singular values of A^T B;
    # only the singular values are needed, so U and V are not computed
    cosines = la.svdvals(A.T.dot(B))
    # guard against round-off pushing a cosine slightly above 1 (arccos -> nan)
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    


In [None]:
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import progprint_xrange

data_path = '/home/mackelab/Desktop/Projects/Stitching/results/icml_e1/'

np.random.seed(rnd_seed)
# x presumably holds the latent states (used below for the latent lagged covariances)
pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                 eig_m_r, eig_M_r,
                                 eig_m_c, eig_M_c,
                                 mmap, chunksize,
                                 data_path,
                                 snr=snr, whiten=whiten)

# np.float was an alias of float64 and was removed in NumPy 1.24
if mmap:
    print('ensuring zero-mean data for given observation scheme')
    # subtract column means in chunks of `chunksize` columns to bound memory use
    for i in progprint_xrange(p//chunksize, perline=10):
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
        cols = slice(i*chunksize, (i+1)*chunksize)
        y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
        del y
    if (p//chunksize)*chunksize < p:
        # left-over columns when p is not a multiple of chunksize
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
        cols = slice((p//chunksize)*chunksize, None)
        y[:, cols] = y[:, cols] - y[:, cols].mean(axis=0)
        del y
    y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
else:
    y -= y.mean(axis=0)

# all units if p <= 1000, otherwise a sorted random subset of 1000 units
idx_a = np.sort(np.random.choice(p, 1000, replace=False)) if p > 1000 else np.arange(p)
idx_b = idx_a.copy()

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
print('computing time-lagged covariances')
Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                      idx_a=idx_a,idx_b=idx_b,W=W,sso=sso,
                      mmap=mmap,data_path=data_path,ts=None,ms=None)

# empirical lagged latent cross-covariances Cov(x_{t+m}, x_t), stacked over the lags
# (the analytic alternative is A^m Pi for each lag m)
pars_true['X'] = np.vstack([ np.cov(x[m:T-kl_+m].T, x[:T-kl_].T)[:n,n:] for m in lag_range])
print('true param. loss: ', f_l2_Hankel_nl(C=pars_true['C'],
                               X=pars_true['X'],
                               R=pars_true['R'],
                               Qs=Qs,
                               Om=Om,
                               lag_range=lag_range,
                               ms=range(len(lag_range)),
                               idx_a=idx_a,
                               idx_b=idx_b))
# no optimisation traces exist yet; `_` would be gen_data's last return value here
print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)


In [None]:
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'x' : x,
             'mmap' : mmap,
             'y' : data_path if mmap else y,
             'pars_true' : pars_true,
             'pars_est' : pars_est,
             'idx_a' : idx_a,
             'idx_b' : idx_b,
             'W' : W,
             'Qs' : Qs,
             'Om' : Om,
             'rnd_seed' : rnd_seed
            }
# The loading cell at the top of this notebook reads '...snr<k>_e1_init.npz', so keep the
# separating underscore. np.int was removed in NumPy 1.24; int() truncates the same way.
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_e1_init'
np.savez(data_path + file_name, save_dict)