In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm

import numpy as np
from numpy import ma

import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from pykalman import KalmanFilter
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import get_subpop_stats, gen_data, draw_data
from ssidid.icml_scripts import run_default


# Time lags 0..10 and the derived offset kl_ (largest lag + 1).
lag_range = np.arange(11)
kl_ = lag_range.max() + 1
snr = (0.5, 0.5)

# Problem size: p observed dimensions, n latent dimensions, T time points.
p = 10
n = 2
T = 1000 + kl_

mmap = False
verbose = False

# settings for quick initial SGD fitting phase for our model
# (forwarded to run_bad as batch_size / max_epoch_size / max_iter and
#  alpha / b1 / b2 / e / a_decay in the experiment loop)
batch_size = 1
max_zip_size = 1000
max_iter = 200
a = 0.005
b1 = 0.99
b2 = 0.99
e = 1e-8
a_decay = 0.98

# Both index sets cover all p observed dimensions.
idx_a = np.arange(p)
idx_b = np.arange(p)
obs_scheme = ObservationScheme(p=p, T=T)
W = obs_scheme.comp_coocurrence_weights(
    lag_range,
    sso=True,
    idx_a=idx_a,
    idx_b=idx_b,
)

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    Parameters
    ----------
    A, B : array-like
        Matrices whose columns span the two subspaces. 1-D input is treated
        as a single column vector. The columns do NOT need to be orthonormal:
        both inputs are orthonormalised with ``scipy.linalg.orth`` first.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, i.e. values in [0, 1] where 0 means
        aligned directions and 1 means orthogonal directions. They follow the
        order of the singular values returned by ``scipy.linalg.svd``.
    """
    # Accept lists/tuples as well as arrays (``.ndim`` needs an ndarray).
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # the singular vectors are not needed, so skip computing them.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip at 1.0: round-off can push a cosine marginally above 1, which
    # would make arccos return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# Directory that receives one .npz result archive per experiment run.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e0/'
# Rotation angle; only referenced by the commented-out 3-D dynamics below.
th = np.pi/20

numExps = 10

# Each iteration: draw a ground-truth model and data, fit the subspace with
# GROUSE and with run_bad, plot diagnostics, and save everything to disk.
for iExp in range(numExps):

    # Seed BEFORE any random draw of this run, so that the stored rnd_seed
    # covers the draw of the loading matrix C as well as the data.
    rnd_seed = np.mod(int(time.time() * 10000), 10000)
    np.random.seed(rnd_seed)

    pars_true = {
        'C' : np.random.normal(size=(p,n)) / np.sqrt(n), 
        #'A' : np.array([[0.99 * np.cos(th),  0.99 * np.sin(th), 0. ], 
        #                [-0.99 * np.sin(th), 0.99 * np.cos(th), 0. ], 
        #                [0.,                  0.,               0.]]), 
        #'Q' : np.diag((2e-3, 2e-3, 1)),
        #'mu0': np.array([0,0,0]),
        'mu0' : np.zeros(2),
        'A' : np.diag([0.95, 0.]),
        'Q' : np.diag([1e-2, 1.]),
        'R' : np.mean(snr) * np.ones(p)
    }
    # Orthonormalise the columns of C.
    pars_true['C'] = la.orth(pars_true['C'])
    # Solution of the discrete Lyapunov equation for (A, Q); also used as V0.
    pars_true['Pi'] = np.atleast_2d(sp.linalg.solve_discrete_lyapunov(pars_true['A'], pars_true['Q']))
    pars_true['V0'] = pars_true['Pi'].copy()

    x, _ = draw_data(pars_true, T)

    # Standardise each latent column, then shrink the first one by sqrt(10).
    x = (x - x.mean(axis=0))/np.std(x,axis=0)
    x[:,0] /= np.sqrt(10)

    # Observations: linear read-out through C plus noise scaled by sqrt(R),
    # mean-centred per observed dimension.
    y = x.dot(pars_true['C'].T) + np.sqrt(pars_true['R']).reshape(1,p) * np.random.normal(size=(x.shape[0],p))
    y -= y.mean(axis=0)

    plt.figure(figsize=(12,6))
    plt.subplot(1,2,1)
    plt.imshow(pars_true['C'], interpolation='None')
    plt.colorbar()
    plt.subplot(2,2,2)
    plt.plot(x[:,0], x[:,1])
    plot_range = np.arange(T)
    plt.subplot(2,2,4)
    plt.plot(x[plot_range,:])
    plt.show()


    # setting up observation scheme
    Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                          idx_a=idx_a,idx_b=idx_b,W=W,sso=True,
                          mmap=mmap,data_path=data_path,ts=None,ms=None)   
    print('var[x]', np.var(x, axis=0))

    plt.figure(figsize=(20,3))
    plt.imshow((np.hstack([Qs[m] for m in range(len(lag_range))])), interpolation='None')
    plt.show()


    # Stack the powers A^m for every lag m.
    pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],m) for m in lag_range])
    print_slim(Qs,Om,lag_range,pars_true,idx_a,idx_b,None,False,data_path)

    # settings for GROUSE
    a_grouse = 100
    tracker = Grouse(p, n, a_grouse )
    max_epoch_size = 1000
    max_iter_grouse = 100

    # fit GROUSE
    print('\n - GROUSE')
    tracker.step = a_grouse
    ct = 1.
    # Column 0: subspace projection error; columns 1..n: principal angles.
    traces_g = np.zeros((max_iter_grouse, n+1))
    t_g = time.time()
    get_obs = obs_scheme.gen_get_observed()

    # Print progress roughly every 10% of the epochs (never a zero modulus).
    progress_every = max(1, max_iter_grouse // 10)
    for i in range(max_iter_grouse):
        if i % progress_every == 0:
            print('finished % ' + str((100*i)//max_iter_grouse))
        # Random order of time points, capped at max_epoch_size per epoch.
        idx = np.random.permutation(T-np.max(lag_range)-1)[:max_epoch_size]
        for t_idx in idx:
            # Boolean column mask of the dimensions observed at this time.
            obs_idx =  np.zeros((p,1), dtype=bool)
            obs_idx[get_obs(t_idx)] = True
            tracker.consume(y[t_idx,:].reshape(-1,1), obs_idx)
            ct += 1     
            # Step size decays as a_grouse / (number of consumed samples).
            tracker.step = a_grouse / ct

        traces_g[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
    t_g = time.time() - t_g
    pars_est_g = {'C' : tracker.U.copy()}


    # traces_g[:,1:] holds the principal angles (column 0 is the projection
    # error), so the panels are titled accordingly.
    plt.subplot(1,2,1)
    plt.plot(traces_g[:,1:])
    plt.title('principal angles (GROUSE)')
    plt.subplot(1,2,2)
    plt.loglog(traces_g[:,1:])
    plt.title('principal angles (GROUSE)')
    plt.show()
    
    
    # Same layout as traces_g; filled in by the callback handed to run_bad.
    proj_errors = np.zeros((max_iter,n+1))
    def pars_track(pars,t): 
        """Record projection error and principal angles of pars[0] in row t."""
        C = pars[0]
        proj_errors[t] = np.hstack((calc_subspace_proj_error(pars_true['C'], C), 
                                    principal_angle(pars_true['C'], C)))

    # NOTE(review): this rebinds the module-level W, which the next iteration
    # passes on again - presumably run_bad returns it unchanged; confirm.
    _, pars_est, traces, Qs, Om, W, t = run_bad(lag_range=lag_range,n=n,y=y, idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,pars_init='default',
                                          parametrization='nl', sso=True,
                                          Qs=Qs, Om=Om, W=W,
                                          alpha=a,b1=b1,b2=b2,e=e,a_decay=a_decay,max_iter=max_iter,
                                          batch_size=batch_size,verbose=True, max_epoch_size=max_zip_size,
                                          pars_track=pars_track)

    # Append our own error traces as the last entry of the trace list.
    traces = list(traces)
    traces.append(proj_errors.copy())
    print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,traces,False,data_path)
    print('fitting time was ', t, 's')
    print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])
    plt.plot(proj_errors[:,1:])
    plt.show()

    save_dict = {'p' : p, 'n' : n, 'T' : T, 'snr' : snr,
                 'obs_scheme' : obs_scheme,
                 'lag_range' : lag_range,
                 'pars_true' : pars_true,
                 'pars_est' : pars_est,
                 'traces' : traces, 
                 'pars_est_g' : pars_est_g,
                 'traces_g' : traces_g, 
                 'idx_a' : idx_a,
                 'idx_b' : idx_b,
                 't' : t,
                 't_g' : t_g,
                 'W' : W,
                 'Qs' : Qs,
                 'Om' : Om,
                 'rnd_seed' : rnd_seed}

    # The dict is stored as a single pickled object array under key 'arr_0'.
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + '_runExp' + str(iExp)
    np.savez(data_path + file_name, save_dict)    
    
    

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm

import numpy as np
from numpy import ma

import scipy as sp
from scipy import linalg as la
import glob, os, psutil, time

from pykalman import KalmanFilter
from subtracking import Grouse, calc_subspace_proj_error
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid import ObservationScheme, progprint_xrange
from ssidid.utility import get_subpop_stats, gen_data, draw_data
from ssidid.icml_scripts import run_default


# Plot colours: RGB triple (orange) for our method, black for the
# bit-flipped variant, magenta for GROUSE.
color_us = tuple(channel / 255 for channel in (244, 152, 25))
color_us_bitflipped = 'k'
color_GROUSE = 'm'

# Time lags 0..10 and the derived offset kl_ (largest lag + 1).
lag_range = np.arange(11)
kl_ = lag_range.max() + 1
snr = (0.5, 0.5)

# Problem size: p observed dimensions, n latent dimensions, T time points.
p = 10
n = 2
T = 1000 + kl_

mmap = False
verbose = False

def principal_angle(A, B):
    """Return the normalised principal angles between span(A) and span(B).

    Parameters
    ----------
    A, B : array-like
        Matrices whose columns span the two subspaces. 1-D input is treated
        as a single column vector. The columns do NOT need to be orthonormal:
        both inputs are orthonormalised with ``scipy.linalg.orth`` first.

    Returns
    -------
    numpy.ndarray
        Principal angles divided by pi/2, i.e. values in [0, 1] where 0 means
        aligned directions and 1 means orthogonal directions. They follow the
        order of the singular values returned by ``scipy.linalg.svd``.
    """
    # Accept lists/tuples as well as arrays (``.ndim`` needs an ndarray).
    A = np.asarray(A)
    B = np.asarray(B)
    A = np.atleast_2d(A).T if (A.ndim<2) else A
    B = np.atleast_2d(B).T if (B.ndim<2) else B
    A = la.orth(A)
    B = la.orth(B)
    # The singular values of A^T B are the cosines of the principal angles;
    # the singular vectors are not needed, so skip computing them.
    cosines = la.svd(A.T.dot(B), compute_uv=False)
    # Clip at 1.0: round-off can push a cosine marginally above 1, which
    # would make arccos return NaN.
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi/2)

# Directory holding the per-run .npz archives, and the figure output folder.
data_path =  '/home/mackelab/Desktop/Projects/Stitching/results/icml_e0/'
fig_path =  '/home/mackelab/Desktop/Projects/Stitching/figures/'
numExps = 10


# One four-panel figure; every experiment run is overlaid on the same axes.
plt.figure(figsize=(20,5))

# Final principal angles per run: [:, 0, :] our fit, [:, 1, :] GROUSE.
prin_angl_final = np.zeros((numExps,2,n))
for iExp in range(numExps):

    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + '_runExp' + str(iExp)
    # The archive holds one pickled dict under 'arr_0'. NumPy >= 1.16.3
    # refuses to unpickle unless allow_pickle=True is given explicitly.
    # NOTE: unpickling executes arbitrary code - only load trusted files.
    with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz:
        load_file = npz['arr_0'].tolist()

    pars_true = load_file['pars_true']
    pars_est = load_file['pars_est']
    pars_est_g = load_file['pars_est_g']

    prin_angl_final[iExp,0,:] = principal_angle(pars_true['C'], 
                                              pars_est['C'])
    prin_angl_final[iExp,1,:] = principal_angle(pars_true['C'], 
                                              pars_est_g['C'])

    # Column 0 of each trace is the subspace projection error, the remaining
    # columns are the principal angles. Index 2 of 'traces' is the array read
    # here for our method.
    tr_g = load_file['traces_g']
    tr_us = load_file['traces'][2]

    # Panel 1: projection-error traces. GROUSE x-values are stretched by 2.
    plt.subplot(1,4,1)
    plt.plot(np.arange(0, 2*len(tr_g[:, 0]), 2),
             tr_g[:, 0], color=color_GROUSE, linewidth=1.5)
    plt.plot(tr_us[:,0], color=color_us, linewidth=1.5)
    plt.title('subspace projection errors')
    plt.xlabel('epoch')
    plt.ylabel('subsp. proj. err.')
    # plt.box expects a bool; the string 'off' is truthy and keeps the frame.
    plt.box(False)
    plt.legend(['GROUSE', 'ssidid'])
    #plt.plot(principal_angle[load_file['traces_g'][:, 1:]], 'm')

    # Panel 2: final projection error, GROUSE (x) against ours (y).
    plt.subplot(1,4,2)
    plt.plot(tr_g[-1, 0], tr_us[-1, 0], 'ko', linewidth=3)
    plt.xlabel('subsp. proj. err. GROUSE')
    plt.ylabel('subsp. proj. err. ssidid')
    plt.plot([0.1,1], [0.1,1], color='k')
    plt.box(False)
    plt.title('final subsp. proj. err, GROUSE vs. us')

    # Panel 3: final principal angles per latent mode; our markers are
    # shifted by 0.05 along x so they do not hide the GROUSE markers.
    plt.subplot(1,4,3)
    plt.plot(tr_g[-1, 1:], 'o', color=color_GROUSE, linewidth=3)
    plt.plot(np.arange(n)+0.05, tr_us[-1, 1:], 'o', color=color_us, linewidth=3)
    plt.axis([-0.1, 1.1, 0, 0.9])
    plt.title('final principal angles betw. est. and true C')
    plt.xlabel('latent mode')
    plt.ylabel('principal angle')
    plt.xticks([0, 1], ['#1', '#2'])
    plt.box(False)
    plt.legend(['GROUSE', 'ssidid'], loc=2)

    # Panel 4: final principal angles, GROUSE (x) against ours (y). The
    # legend is set before the diagonal is drawn so it labels the markers.
    plt.subplot(1,4,4)
    plt.plot(tr_g[-1, 1], tr_us[-1, 1], 'ko', linewidth=3)
    plt.plot(tr_g[-1, 2], tr_us[-1, 2], 'kx', linewidth=3)
    plt.xlabel('principal angle GROUSE')
    plt.ylabel('principal angle ssidid')
    plt.legend(['mode #1', 'mode #2'], loc =2)
    plt.plot([0,1], [0,1], color='k')
    plt.box(False)
    plt.title('final principal angles, GROUSE vs. us')

# Write the figure once, after all runs have been drawn, instead of rewriting
# the same file on every iteration.
plt.savefig(fig_path + 'illustration_dynamics_dominated.pdf')

plt.show()