In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.optimize import fmin_bfgs, check_grad
import glob, os

os.chdir('../core')
import stitching_ssid as ssid
os.chdir('../dev')

# problem size: p observed variables (rows of C), n latent dimensions;
# k, l are passed to the ssid Hankel-matrix routines
p,n = 10,3
k,l = 3,3

# Adam step size, first/second moment decay rates, and epsilon.
# NOTE: adam_zip also hands e to `converged`, where it acts as the stopping tolerance.
a, b1, b2, e = 0.0001, 0.9, 0.99, 1e-8
# initial second-moment estimate: one entry per parameter of the flat vector
# [A (n*n), B (n*n), C (p*n)] -> 48 entries for p=10, n=3
v_0 = np.zeros(2*n*n + p*n)
max_iter = 1000

def adam_zip(f_i,g_i,theta_0,a,b1,b2,e,converged,Om,idx_grp,co_obs,m_0=None,v_0=None):
    """Adam optimisation over stochastically drawn stitching entries (batch size 1).

    Each outer iteration draws indices with ``ssid.l2_sis_draw`` and then takes
    one Adam step per drawn element, using the gradient ``g_i`` evaluated on
    that single element.

    Parameters
    ----------
    f_i : callable
        Objective, ``f_i(theta) -> float``; evaluated once per outer iteration
        to build the returned trace.
    g_i : callable
        Gradient, called as ``g_i(theta, idx_use, idx_co)``; its result is
        combined elementwise with the theta-sized moment vectors.
    theta_0 : ndarray
        Initial parameter vector (copied, not modified).
    a, b1, b2 : float
        Step size and first/second moment decay rates. ``b1 == 1`` or
        ``b2 == 1`` skips the corresponding bias correction.
    e : float
        Added to ``sqrt(vh)`` in the update denominator, and also passed to
        ``converged`` as its tolerance argument.
    converged : callable
        ``converged(theta_old, theta, e, t_iter) -> bool``; the loop stops
        when it returns True. It is the only bound on the number of iterations.
    Om : ndarray
        Observation mask; ``np.where(Om)`` supplies the candidate index pairs.
    idx_grp, co_obs :
        Subpopulation bookkeeping, forwarded to ``ssid.l2_sis_draw``.
    m_0, v_0 : ndarray, optional
        Initial first/second moment estimates (copied); zeros if omitted.

    Returns
    -------
    theta : ndarray
        Final parameter vector.
    fun : ndarray, shape (max_iter,)
        ``f_i(theta)`` after each of the first ``max_iter`` outer iterations;
        entries for iterations that were never reached are NaN.

    Notes
    -----
    Reads the notebook-global ``max_iter`` to size the trace.
    """
    N = theta_0.size
    batch_size = 1

    # candidate (row, col) pairs of the observation mask
    is_, js_ = np.where(Om)

    # Adam state: t_iter counts outer iterations, t counts individual Adam steps
    t_iter, t = 0, 0
    m = np.zeros(N) if m_0 is None else m_0.copy()
    v = np.zeros(N) if v_0 is None else v_0.copy()
    theta, theta_old = theta_0.copy(), np.inf * np.ones(N)

    # trace of function values. NaN-initialised so that iterations skipped by an
    # early stop show up as gaps rather than uninitialised memory (np.empty).
    fun = np.full(max_iter, np.nan)
    print_every = max_iter // 10

    while not converged(theta_old, theta, e, t_iter):

        theta_old = theta.copy()

        t_iter += 1
        idx_use, idx_co = ssid.l2_sis_draw(batch_size, idx_grp, co_obs, is_, js_)

        for idx_zip in range(idx_use.size):
            t += 1

            # gradient for one drawn element, passed as a one-element tuple
            # holding a length-1 array for each index argument
            g = g_i(theta, (np.array((idx_use[idx_zip],)),), (np.array((idx_co[idx_zip],)),))
            m = b1 * m + (1 - b1) * g
            v = b2 * v + (1 - b2) * (g ** 2)
            # bias correction; skipped for a decay of exactly 1, where 1 - b**t == 0
            mh = m / (1 - b1 ** t) if b1 != 1. else m
            vh = v / (1 - b2 ** t) if b2 != 1. else v

            theta = theta - a * mh / (np.sqrt(vh) + e)

        if t_iter <= max_iter:
            fun[t_iter - 1] = f_i(theta)
            # progress report; only reads entries that were just written, and
            # skips reporting when max_iter < 10 (would be a modulo by zero)
            if print_every > 0 and t_iter % print_every == 2:
                print('f = ', fun[t_iter - 1])

    print('total iterations: ', t)

    return theta, fun



# create subpopulations: two overlapping, contiguous index ranges over the p
# observed variables (for p = 10: 0..5 and 4..9, sharing 4 and 5)
sub_pops = (np.arange(p//2 + 1), np.arange(p//2 - 1, p))
print('sub_pops', sub_pops)
(obs_idx, idx_grp, co_obs, overlaps, overlap_grp,
 idx_overlap, Om, Ovw, Ovc) = ssid.get_subpop_stats(sub_pops, p, verbose=True)

# Base seed for the repetitions below. Repetition r seeds np.random with SEED + r,
# so its model, initialisation and index draws can be reproduced in isolation.
SEED = 0

for rep in range(10):
    np.random.seed(SEED + rep)

    # ground-truth system: C (p x n); A = V diag(lambda) V^-1 with real eigenvalues
    # lambda spread over [0.7, 0.95]; noise factor B with covariance Pi = B B^T
    C_true = np.random.normal(size=(p,n))

    V = np.random.normal(size=(n,n))
    V /= np.sqrt(np.sum(V**2,axis=0)).reshape(1,-1)   # unit-norm columns
    A_true = V.dot(np.diag(np.linspace(0.7, 0.95, n))).dot(np.linalg.inv(V))

    B_true = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi_true = B_true.dot(B_true.T)

    # model covariances (k+l passed through), once with the mask Om and once without
    Qs = ssid.comp_model_covariances({'A': A_true, 'Pi': Pi_true, 'C': C_true}, k+l, Om)
    Qs_full = ssid.comp_model_covariances({'A': A_true, 'Pi': Pi_true, 'C': C_true}, k+l)

    # initial guess: random diagonal A, identity B, random C
    A_0 = np.diag(np.random.uniform(low=0.7, high=0.8, size=n))
    B_0 = np.eye(n)
    Pi_0 = B_0.dot(B_0.T)
    C_0 = np.random.normal(size=(p,n))
    # flat parameter vector laid out as [A (n*n), B (n*n), C (p*n)]
    pars_0 = np.hstack((A_0.reshape(n*n,),
                        B_0.reshape(n*n,),
                        C_0.reshape(p*n,)))
    # NOTE(review): H_0 is not used anywhere below -- candidate for removal
    H_0 = ssid.yy_Hankel_cov_mat(C_0,A_0,Pi_0,k,l,~Om)

    f_i, _ = ssid.l2_sis_setup(k,l,n,Qs,Om,idx_grp,obs_idx)
    def g_i(theta, idx_use, idx_co):
        return ssid.g_l2_Hankel_sis(theta,k,l,n,Qs,idx_use,idx_co)

    # stop after max_iter outer iterations, or once f changes by less than e
    def converged(theta_old, theta, e, t):
        if t > max_iter:
            return True
        return np.abs(f_i(theta_old) - f_i(theta)) < e

    pars_est_vec, fs = adam_zip(f_i,g_i,pars_0.copy(),a,b1,b2,e,converged,Om,idx_grp,co_obs,v_0=v_0)

    # unpack the estimate using the same [A, B, C] layout as pars_0
    A_est = pars_est_vec[:n*n].reshape(n,n)
    B_est = pars_est_vec[n*n:2*n*n].reshape(n,n)
    Pi_est = B_est.dot(B_est.T)
    C_est = pars_est_vec[-p*n:].reshape(p,n)

    pars_init = {'A': A_0, 'C': C_0, 'Pi': Pi_0, 'B': B_0}
    pars_est  = {'A': A_est, 'C': C_est, 'Pi': Pi_est, 'B': B_est}
    pars_true = {'A': A_true, 'C': C_true, 'Pi': Pi_true, 'B': B_true}
    ssid.plot_outputs_l2_gradient_test(pars_true, pars_init, pars_est, k, l, Qs,
                                       Qs_full, Om, Ovc, Ovw, f_i, g_i, if_flip = True)

    # objective trace of this repetition
    plt.figure(figsize=(20,8))
    plt.plot(fs[:max_iter])
    plt.title('rep %d: objective value per outer iteration' % rep)
    plt.xlabel('outer iteration')
    plt.ylabel('f')
    plt.show()
    


# Comparison: using a batch size of p*n

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.optimize import fmin_bfgs, check_grad
import glob, os

os.chdir('../core')
import stitching_ssid as ssid
os.chdir('../dev')

# problem size: p observed variables (rows of C), n latent dimensions;
# k, l are passed to the ssid Hankel-matrix routines
p,n = 10,3
k,l = 3,3

# With b1 = b2 = 0 the adam_zip update reduces to theta -= a * g / (|g| + e),
# i.e. elementwise sign descent.
# NOTE(review): with b2 == 0 the second moment is overwritten by g**2 on the first
# step, so v_0 below has no effect. If plain gradient descent (v held at v_0 = 1)
# was intended, b2 should be 1.0 -- confirm. Also, with e == 0 any exactly-zero
# gradient component yields 0/0 = nan in the update.
a, b1, b2, e = 0.00001, 0.0, 0.0, 0.0
# one entry per parameter of the flat vector [A (n*n), B (n*n), C (p*n)]
v_0 = np.ones(2*n*n + p*n)
max_iter = 10000

def adam_zip(f_i,g_i,theta_0,a,b1,b2,e,converged,Om,idx_grp,co_obs,m_0=None,v_0=None):
    """Adam-style optimisation using gradients over whole drawn index sets.

    ``ssid.l2_sis_draw`` is called with ``batch_size=None``. Every step then
    evaluates ``g_i`` on all drawn indices ``(idx_use, idx_co)`` at once. Each
    outer iteration takes ``len(idx_use)`` such steps, and the gradient is
    recomputed at every step because theta moves between steps.

    Parameters
    ----------
    f_i : callable
        Objective, ``f_i(theta) -> float``; evaluated once per outer iteration
        to build the returned trace.
    g_i : callable
        Gradient, called as ``g_i(theta, idx_use, idx_co)``; its result is
        combined elementwise with the theta-sized moment vectors.
    theta_0 : ndarray
        Initial parameter vector (copied, not modified).
    a, b1, b2 : float
        Step size and first/second moment decay rates. ``b1 == 1`` or
        ``b2 == 1`` skips the corresponding bias correction.
    e : float
        Added to ``sqrt(vh)`` in the update denominator, and also passed to
        ``converged`` as its tolerance argument. With ``e == 0`` a zero entry
        of ``sqrt(vh)`` makes the update divide by zero.
    converged : callable
        ``converged(theta_old, theta, e, t_iter) -> bool``; the loop stops
        when it returns True. It is the only bound on the number of iterations.
    Om : ndarray
        Observation mask; ``np.where(Om)`` supplies the candidate index pairs.
    idx_grp, co_obs :
        Subpopulation bookkeeping, forwarded to ``ssid.l2_sis_draw``.
    m_0, v_0 : ndarray, optional
        Initial first/second moment estimates (copied); zeros if omitted.

    Returns
    -------
    theta : ndarray
        Final parameter vector.
    fun : ndarray, shape (max_iter,)
        ``f_i(theta)`` after each of the first ``max_iter`` outer iterations;
        entries for iterations that were never reached are NaN.

    Notes
    -----
    Reads the notebook-global ``max_iter`` to size the trace.
    """
    pn = theta_0.size
    # forwarded to ssid.l2_sis_draw; presumably means "no subsampling" -- confirm
    batch_size = None

    # candidate (row, col) pairs of the observation mask
    is_, js_ = np.where(Om)

    # Adam state: t_iter counts outer iterations, t counts individual Adam steps
    t_iter, t = 0, 0
    m = np.zeros(pn) if m_0 is None else m_0.copy()
    v = np.zeros(pn) if v_0 is None else v_0.copy()
    theta, theta_old = theta_0.copy(), np.inf * np.ones(pn)

    # trace of function values. NaN-initialised so that iterations skipped by an
    # early stop show up as gaps rather than uninitialised memory (np.empty).
    fun = np.full(max_iter, np.nan)
    print_every = max_iter // 10

    while not converged(theta_old, theta, e, t_iter):

        theta_old = theta.copy()

        t_iter += 1
        idx_use, idx_co = ssid.l2_sis_draw(batch_size, idx_grp, co_obs, is_, js_)

        for _ in range(len(idx_use)):
            t += 1

            # gradient over the whole drawn index set at the current theta
            g = g_i(theta, idx_use, idx_co)
            m = b1 * m + (1 - b1) * g
            v = b2 * v + (1 - b2) * (g ** 2)
            # bias correction; skipped for a decay of exactly 1, where 1 - b**t == 0
            mh = m / (1 - b1 ** t) if b1 != 1. else m
            vh = v / (1 - b2 ** t) if b2 != 1. else v

            theta = theta - a * mh / (np.sqrt(vh) + e)

        if t_iter <= max_iter:
            fun[t_iter - 1] = f_i(theta)
            # progress report; only reads entries that were just written, and
            # skips reporting when max_iter < 10 (would be a modulo by zero)
            if print_every > 0 and t_iter % print_every == 2:
                print('f = ', fun[t_iter - 1])

    print('total iterations: ', t)

    return theta, fun



# create subpopulations: two overlapping, contiguous index ranges over the p
# observed variables (for p = 10: 0..5 and 4..9, sharing 4 and 5)
sub_pops = (np.arange(p//2 + 1), np.arange(p//2 - 1, p))
print('sub_pops', sub_pops)
(obs_idx, idx_grp, co_obs, overlaps, overlap_grp,
 idx_overlap, Om, Ovw, Ovc) = ssid.get_subpop_stats(sub_pops, p, verbose=True)

# Base seed for the repetitions below. Repetition r seeds np.random with SEED + r,
# so its model, initialisation and index draws can be reproduced in isolation.
SEED = 0

for rep in range(10):
    np.random.seed(SEED + rep)

    # ground-truth system: C (p x n); A = V diag(lambda) V^-1 with real eigenvalues
    # lambda spread over [0.7, 0.95]; noise factor B with covariance Pi = B B^T
    C_true = np.random.normal(size=(p,n))

    V = np.random.normal(size=(n,n))
    V /= np.sqrt(np.sum(V**2,axis=0)).reshape(1,-1)   # unit-norm columns
    A_true = V.dot(np.diag(np.linspace(0.7, 0.95, n))).dot(np.linalg.inv(V))

    B_true = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi_true = B_true.dot(B_true.T)

    # model covariances (k+l passed through), once with the mask Om and once without
    Qs = ssid.comp_model_covariances({'A': A_true, 'Pi': Pi_true, 'C': C_true}, k+l, Om)
    Qs_full = ssid.comp_model_covariances({'A': A_true, 'Pi': Pi_true, 'C': C_true}, k+l)

    # initial guess: random diagonal A, identity B, random C
    A_0 = np.diag(np.random.uniform(low=0.7, high=0.8, size=n))
    B_0 = np.eye(n)
    Pi_0 = B_0.dot(B_0.T)
    C_0 = np.random.normal(size=(p,n))
    # flat parameter vector laid out as [A (n*n), B (n*n), C (p*n)]
    pars_0 = np.hstack((A_0.reshape(n*n,),
                        B_0.reshape(n*n,),
                        C_0.reshape(p*n,)))
    # NOTE(review): H_0 is not used anywhere below -- candidate for removal
    H_0 = ssid.yy_Hankel_cov_mat(C_0,A_0,Pi_0,k,l,~Om)

    f_i, _ = ssid.l2_sis_setup(k,l,n,Qs,Om,idx_grp,obs_idx)
    def g_i(theta, idx_use, idx_co):
        return ssid.g_l2_Hankel_sis(theta,k,l,n,Qs,idx_use,idx_co)

    # stop after max_iter outer iterations, or once f changes by less than e
    def converged(theta_old, theta, e, t):
        if t > max_iter:
            return True
        return np.abs(f_i(theta_old) - f_i(theta)) < e

    pars_est_vec, fs = adam_zip(f_i,g_i,pars_0.copy(),a,b1,b2,e,converged,Om,idx_grp,co_obs,v_0=v_0)

    # unpack the estimate using the same [A, B, C] layout as pars_0
    A_est = pars_est_vec[:n*n].reshape(n,n)
    B_est = pars_est_vec[n*n:2*n*n].reshape(n,n)
    Pi_est = B_est.dot(B_est.T)
    C_est = pars_est_vec[-p*n:].reshape(p,n)

    pars_init = {'A': A_0, 'C': C_0, 'Pi': Pi_0, 'B': B_0}
    pars_est  = {'A': A_est, 'C': C_est, 'Pi': Pi_est, 'B': B_est}
    pars_true = {'A': A_true, 'C': C_true, 'Pi': Pi_true, 'B': B_true}
    ssid.plot_outputs_l2_gradient_test(pars_true, pars_init, pars_est, k, l, Qs,
                                       Qs_full, Om, Ovc, Ovw, f_i, g_i, if_flip = True)

    # objective trace of this repetition
    plt.figure(figsize=(20,8))
    plt.plot(fs[:max_iter])
    plt.title('rep %d: objective value per outer iteration' % rep)
    plt.xlabel('outer iteration')
    plt.ylabel('f')
    plt.show()

    # conditioning diagnostics of the true system
    print('singular values of partial observability matrix for first subpop \n')
    print(np.linalg.svd(ssid.observability_mat((A_true, C_true[sub_pops[0],:]), n),
                        compute_uv=False))

    print('\n singular values of partial observability matrix for second subpop \n')
    print(np.linalg.svd(ssid.observability_mat((A_true, C_true[sub_pops[1],:]), n),
                        compute_uv=False))

    # The reachability matrix [B, AB, ..., A^(n-1) B] is the transpose of the
    # observability matrix of (A^T, B^T), so the two share singular values.
    # Passing (A_true, B_true) directly would build [B; BA; ...] instead, whose
    # singular values differ in general.
    # NOTE(review): assumes observability_mat((A, C), n) stacks C A^i for i < n,
    # as its use with rows of C_true above suggests -- confirm.
    print('\n singular values of (noise) reachability matrix \n')
    print(np.linalg.svd(ssid.observability_mat((A_true.T, B_true.T), n),
                        compute_uv=False))
