In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy.optimize import fmin_bfgs, check_grad
import glob, os

os.chdir('../core')
import stitching_ssid as ssid
os.chdir('../dev')

import pprint

# Covariance matrix completion 
- $\mbox{cov}(y) = C C^\top$
- if cov($y$) has missing off-diagonal blocks, parts of $C$ are underdetermined: the rows of $C$ belonging to each subpopulation can undergo their own (orthogonal) change of latent basis without altering any observed block
- Bishop et al. (2014) introduced a basic algorithm for rotating latent bases based on overlap
- Srini proposed learning all chunks of $C$ jointly, i.e. at the same time, using gradient descent

In [None]:

# Toy problem for zero-lag covariance completion: p observed units, n latent dims.
# k = l = 1 and A = Pi = I (set below) reduce the Hankel objective to fitting C C^T.
p,n = 20,4
k,l = 1,1
overlap_size = 4  # number of units shared by the two subpopulations (units 8..11)

# create subpopulations
sub_pops = (np.arange(0,12), np.arange(8,p))
print('sub_pops', sub_pops)

# index bookkeeping from the project module; Om, Ovw, Ovc are used below as
# elementwise masks on p x p matrices (presumably boolean, given the use of ~Om)
obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
print('idx_grp:', idx_grp)
print('obs_idx:', obs_idx)
        
Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)    
    
# visualise the observed / overlap / cross-overlap masks
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.imshow(Om,interpolation='none')
plt.title('Observation pattern')
plt.subplot(1,3,2)
plt.imshow(Ovw,interpolation='none')
plt.title('Overlap pattern')
plt.subplot(1,3,3)
plt.imshow(Ovc,interpolation='none')
plt.title('Cross-overlap pattern')
plt.show()

# repeat the experiment for several random ground truths / initialisations
for rep in range(10):
    
    # rank-n ground truth; split into observed part (data) and unobserved part (to be stitched)
    C_true = np.random.normal(size=(p,n))
    Q_true = C_true.dot(C_true.T)
    Q_obs = Q_true * np.asarray( Om, dtype=float)
    Q_sti = Q_true * np.asarray(~Om, dtype=float)
    A, Pi = np.eye(n), np.eye(n)
    
    C_0 = np.random.normal(size=(p,n))
        
    # objective and gradient w.r.t. (flattened) C; the objective is rescaled by the
    # number of observed entries -- presumably ssid.f_l2_Hankel returns a mean while
    # ssid.g_C_l2_Hankel differentiates the sum (TODO confirm against the module)
    def f_i(C):        
        return ssid.f_l2_Hankel(C,A,Pi,k,l,[Q_obs], Om)*np.sum(Om)
    def g_i(C):
        return ssid.g_C_l2_Hankel(C,A,Pi,k,l,[Q_obs],idx_grp, obs_idx)
    
    
    
    # sanity check of the analytic gradient against finite differences
    print('difference in gradient to finite-differencing value:', check_grad(f_i, g_i, C_0.reshape(p*n,)))
    
    # gtol=1e-20 makes the gradient-norm stopping test practically unreachable,
    # so BFGS runs until maxiter or loss of precision
    C_est = fmin_bfgs(f_i, C_0.reshape(p*n,), fprime=g_i, gtol=1e-20).reshape(p,n)
    

    # errors of C_est C_est^T on the different parts of the covariance matrix;
    # only the last one (~Om) measures the actual completion ("stitching") quality
    print('final squared error on observed parts:', 
          ssid.f_l2_block(C_est, A, Q_obs, Om))
    print('final squared error on overlapping parts:', 
          ssid.f_l2_block(C_est, A, Q_obs, Ovw))
    print('final squared error on cross-overlapping parts:',
          ssid.f_l2_block(C_est, A, Q_obs, Ovc))
    print('final squared error on stitched parts:',
          ssid.f_l2_block(C_est, A, Q_sti,~Om))

    plt.figure(figsize=(18,12))
    plt.subplot(2,2,1)
    plt.imshow(C_0.dot(C_0.T),interpolation='none')
    plt.title('Initial matrix (C_0 C_0^T)')
    plt.subplot(2,2,3)
    plt.imshow(Q_true,interpolation='none')
    plt.title('True  matrix')
    plt.subplot(2,2,4)
    plt.imshow(C_est.dot(C_est.T),interpolation='none')
    plt.title('Estimated matrix')
    plt.subplot(2,2,2)
    # assumes the Ovw entries can be laid out in rows of overlap_size values -- TODO confirm
    plt.imshow(Q_true[Ovw].reshape(-1,overlap_size), interpolation='none')
    plt.title('overlaps (extracted)')
    plt.show()

    # NOTE(review): stale diagnostic -- reshape(-1,2) assumes an overlap of size 2,
    # but overlap_size is 4 in this cell
    #_,s,_ = np.linalg.svd(Q_true[Ovw].reshape(-1,2)[:2,:])
    #print('singular values first overlap:', s)
    #_,s,_ = np.linalg.svd(Q_true[Ovw].reshape(-1,2)[2:,:])
    #print('singular values second overlap:', s)
    
    print('\n')
    print('\n')


# Hankel covariance matrix completion 
- $ H_{k,l} = (I_k \otimes C) H^{xx}_{k,l} (I_l \otimes C)^\top $
- $ H^{xx}_{k,l} = \left[\begin{array}{llll} A \Pi & A^2 \Pi & \ldots & A^l \Pi\\ A^2 \Pi & A^3 \Pi & \ldots & A^{l+1} \Pi\\ \vdots & \vdots & \ddots & \vdots \\ A^{k} \Pi & A^{k+1} \Pi & \ldots & A^{k+l-1} \Pi \end{array} \right] $
- if cov($y$) has missing off-diagonal blocks, parts of $C$ are underdetermined (change of latent basis)
- each block of the Hankel cov matrix $H_{k,l}$ exhibits the same structure of missing entries as cov($y$) does.
- We can combine the overlaps of the $k \times l$ many blocks of $H_{k,l}$ when collecting constraints on the latent basis.
- We here first assume $A,\Pi$ to be known, and apply Srini's idea of joint gradient descent on the whole $C$ to $H_{k,l}$. Then we learn $A, \Pi$ with the other parameters fixed.

# Derivative w.r.t. $C$
- needs code base from commit a00b35e to run, as I later changed the function signatures

In [None]:
# Hankel covariance completion with A, Pi known: gradient descent on C only.
p,n = 11,2
k,l = 3,3

# create subpopulations (they share a single unit, index 5)
sub_pops = (np.arange(0,6), np.arange(5,p))
print('sub_pops', sub_pops)

obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
print('idx_grp:', idx_grp)
print('obs_idx:', obs_idx)
        
Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)    
    
# visualise the observed / overlap / cross-overlap masks
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.imshow(Om,interpolation='none')
plt.title('Observation pattern')
plt.subplot(1,3,2)
plt.imshow(Ovw,interpolation='none')
plt.title('Overlap pattern')
plt.subplot(1,3,3)
plt.imshow(Ovc,interpolation='none')
plt.title('Cross-overlap pattern')
plt.show()

for rep in range(10):
    
    # random ground truth: Pi is made PSD as a Gram matrix; A is an unconstrained random matrix
    C_true = np.random.normal(size=(p,n))
    Pi     = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi     = Pi.dot(Pi.T) #np.eye(n) 
    A      = np.random.normal(size=(n,n))    
    # lagged covariances C A^m Pi C^T for m = 1 .. k+l-1, i.e. the distinct blocks
    # of H_{k,l}; Qs keeps only the observed entries, Qs_full keeps everything
    Qs = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A, kl_)
        Qs.append(C_true.dot(Akl.dot(Pi)).dot(C_true.T) *np.asarray( Om,dtype=int) )
    Qs_full = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A, kl_)
        Qs_full.append(C_true.dot(Akl.dot(Pi)).dot(C_true.T) )       
    # NOTE(review): H_true, H_obs and H_sti are not used further in this cell
    H_true = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l)
    H_obs = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l, Om)
    H_obs[np.where(H_obs==0)] = np.nan
    H_sti = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l,~Om)
    
    C_0 = np.random.normal(size=(p,n))
        
    # objective rescaled by (#observed entries) * (#Hankel blocks) -- presumably to
    # turn a mean into the sum that ssid.g_C_l2_Hankel differentiates (TODO confirm)
    def f_i(C):        
        return ssid.f_l2_Hankel(C,A,Pi,k,l,Qs, Om)*np.sum(Om)*(k*l)
    def g_i(C):
        return ssid.g_C_l2_Hankel(C,A,Pi,k,l,Qs,idx_grp, obs_idx)
    
    
    print('difference in gradient to finite-differencing value:', check_grad(f_i, g_i, C_0.reshape(p*n,)))
    
    # gtol=1e-20: BFGS effectively stops only on maxiter or loss of precision
    C_est = fmin_bfgs(f_i, C_0.reshape(p*n,), fprime=g_i, gtol=1e-20).reshape(p,n)    

    # the stitched-part error compares against the unmasked Qs_full
    print('final squared error on observed parts:', 
          ssid.f_l2_Hankel(C_est,A,Pi,k,l,Qs, Om))
    print('final squared error on overlapping parts:', 
          ssid.f_l2_Hankel(C_est,A,Pi,k,l,Qs,Ovw))
    print('final squared error on cross-overlapping parts:',
          ssid.f_l2_Hankel(C_est,A,Pi,k,l,Qs,Ovc))
    print('final squared error on stitched parts:',
          ssid.f_l2_Hankel(C_est,A,Pi,k,l,Qs_full,~Om))
    
    plt.figure(figsize=(16,12))
    plt.subplot(2,2,1)
    # zeros of the masked Hankel matrix are shown as NaN (blank) to expose the missing blocks
    tmp = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l,Om)
    tmp[np.where(tmp==0)] = np.nan
    plt.imshow(tmp,interpolation='none')
    plt.title('Given data matrix (C_true, masked)')
    plt.subplot(2,2,2)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l),interpolation='none')
    plt.title('True  matrix (C_true)')    
    plt.subplot(2,2,3)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_0,A,Pi,k,l),interpolation='none')
    plt.title('Initial matrix (C_0)')
    plt.subplot(2,2,4)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_est,A,Pi,k,l),interpolation='none')
    plt.title('Estimated matrix (C_est)')
    plt.show()
    
    # closely compare stitched blocks:
    # NOTE(review): this diagnostic is disabled (it is a bare string expression).
    # Its reshape(k*6, l*3) hard-codes the stitched block size, and the titles of the
    # first two panels look swapped (tmpt comes from C_true, tmpe from C_est).
    """
    plt.figure(figsize=(20,6))
    plt.subplot(1,3,1)
    tmpt = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l,~Om)
    tmpt = tmpt[np.where(tmpt != 0)].reshape(k*6,l*3)
    plt.imshow(tmpt,interpolation='none')
    plt.title('est. stitched Hankel subblocks (assembled)')

    plt.subplot(1,3,2)
    tmpe = ssid.yy_Hankel_cov_mat(C_est, A,Pi,k,l,~Om)
    tmpe = tmpe[np.where(tmpe != 0)].reshape(k*6,l*3)
    plt.imshow(tmpe,interpolation='none')
    plt.title('true stitched Hankel subblocks (assembled)')
    
    plt.subplot(1,3,3)
    plt.plot(tmpt, tmpe, 'r.')
    plt.xlabel('true')
    plt.ylabel('est')

    plt.show()    
    """

    # disabled: singular values of the overlap constraints
    #if np.sum(Ovw) > 0:
    #    _,s,_ = np.linalg.svd(Qs[0][Ovw].reshape(-1,1)[:2,:])
    #    print('singular values first overlap:', s)
    #    #_,s,_ = np.linalg.svd(Q_obs[Ovw].reshape(-1,2)[2:,:])
    #    #print('singular values second overlap:', s)    
    

# Derivative w.r.t. $\Pi$
- needs code base from commit a00b35e to run, as I later changed the function signatures

In [None]:
# Fit Pi = B B^T by gradient descent on B, with C and A known.
p,n = 10,4
k,l = 2,2

# create subpopulations: two disjoint halves (no overlap) -- presumably acceptable
# here because C is held fixed at its true value
sub_pops = (np.arange(0,p//2), np.arange(p//2,p))
print('sub_pops', sub_pops)

obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
print('idx_grp:', idx_grp)
print('obs_idx:', obs_idx)
        
Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)    
    
# visualise the observed / overlap / cross-overlap masks
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.imshow(Om,interpolation='none')
plt.title('Observation pattern')
plt.subplot(1,3,2)
plt.imshow(Ovw,interpolation='none')
plt.title('Overlap pattern')
plt.subplot(1,3,3)
plt.imshow(Ovc,interpolation='none')
plt.title('Cross-overlap pattern')
plt.show()

for rep in range(5):
    
    # ground truth; Pi_true is PSD by construction as B_true B_true^T
    C      = np.random.normal(size=(p,n))
    B_true = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi_true = B_true.dot(B_true.T) #np.eye(n) 
    A      = np.random.normal(size=(n,n)) # np.diag(np.linspace(0.5, 0.9, n)) #
    # lagged covariances C A^m Pi C^T for m = 1 .. k+l-1 (observed-only and full)
    Qs = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A, kl_)
        Qs.append(C.dot(Akl.dot(Pi_true)).dot(C.T) *np.asarray( Om,dtype=int) )
    Qs_full = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A, kl_)
        Qs_full.append(C.dot(Akl.dot(Pi_true)).dot(C.T) )       
    # NOTE(review): H_true, H_obs and H_sti are not used further in this cell
    H_true = ssid.yy_Hankel_cov_mat(C,A,Pi_true,k,l)
    H_obs = ssid.yy_Hankel_cov_mat(C,A,Pi_true,k,l, Om)
    H_obs[np.where(H_obs==0)] = np.nan
    H_sti = ssid.yy_Hankel_cov_mat(C,A,Pi_true,k,l,~Om)
    
    # B_0 is the optimisation start; Pi_0 is only used for plotting
    B_0  = np.random.normal(size=(n,n))
    Pi_0 = B_0.dot(B_0.T)
    
    # objective in terms of B (optimisers pass it flattened, hence the reshape)
    def f_i(B):        
        if len(B.shape)<2:
            B = B.reshape(A.shape[0], A.shape[0])
        return ssid.f_l2_Hankel(C,A,B.dot(B.T),k,l,Qs, Om)*np.sum(Om)*(k*l)
    def g_i(B):
        return ssid.g_B_l2_Hankel(C,A,B,k,l,Qs,Om)
    
    
    print('difference in gradient to finite-differencing value:', check_grad(f_i, g_i, B_0.reshape(n*n,)))
    
    # gtol=1e-20: BFGS effectively stops only on maxiter or loss of precision
    B_est = fmin_bfgs(f_i, B_0.reshape(n*n,), fprime=g_i, gtol=1e-20).reshape(n,n)    
    Pi_est = B_est.dot(B_est.T)
    
    print('final squared error on observed parts:', 
          ssid.f_l2_Hankel(C,A,Pi_est,k,l,Qs, Om))
    print('final squared error on overlapping parts:', 
          ssid.f_l2_Hankel(C,A,Pi_est,k,l,Qs,Ovw))
    print('final squared error on cross-overlapping parts:',
          ssid.f_l2_Hankel(C,A,Pi_est,k,l,Qs,Ovc))
    print('final squared error on stitched parts:',
          ssid.f_l2_Hankel(C,A,Pi_est,k,l,Qs_full,~Om))
    
    plt.figure(figsize=(16,12))
    plt.subplot(2,2,1)
    # zeros of the masked Hankel matrix are shown as NaN (blank)
    tmp = ssid.yy_Hankel_cov_mat(C,A,Pi_true,k,l,Om)
    tmp[np.where(tmp==0)] = np.nan
    plt.imshow(tmp,interpolation='none')
    plt.title('Given data matrix (Pi_true, masked)')
    plt.subplot(2,2,2)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A,Pi_true,k,l),interpolation='none')
    plt.title('True  matrix (Pi_true)')    
    plt.subplot(2,2,3)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A,Pi_0,k,l),interpolation='none')
    plt.title('Initial matrix (Pi_0)')
    plt.subplot(2,2,4)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A,Pi_est,k,l),interpolation='none')
    plt.title('Estimated matrix (Pi_est)')
    plt.show()
    
    # direct comparison of the Pi matrices themselves
    plt.figure(figsize=(16,12))
    plt.subplot(1,3,1)
    plt.imshow(Pi_0,interpolation='none')
    plt.title('Pi init')
    plt.subplot(1,3,2)
    plt.imshow(Pi_est,interpolation='none')
    plt.title('Pi est')
    plt.subplot(1,3,3)
    plt.imshow(Pi_true,interpolation='none')
    plt.title('Pi true')
    plt.show()


# Derivative w.r.t. $A$

- This one is nasty: we need the derivative of $A^m$ w.r.t. $A$, i.e. $\frac{\partial A^m}{\partial A} \in \mathbb{R}^{n \times n \times n \times n}$
- needs code base from commit a00b35e to run, as I later changed the function signatures

In [None]:
# Fit A by gradient descent, with C and Pi known.
p,n = 5,3
k,l = 2,2

# create subpopulations: two disjoint groups (no overlap) -- presumably acceptable
# here because C is held fixed at its true value
sub_pops = (np.arange(0,p//2), np.arange(p//2,p))
print('sub_pops', sub_pops)

obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
print('idx_grp:', idx_grp)
print('obs_idx:', obs_idx)
        
Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)    
    
# visualise the observed / overlap / cross-overlap masks
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.imshow(Om,interpolation='none')
plt.title('Observation pattern')
plt.subplot(1,3,2)
plt.imshow(Ovw,interpolation='none')
plt.title('Overlap pattern')
plt.subplot(1,3,3)
plt.imshow(Ovc,interpolation='none')
plt.title('Cross-overlap pattern')
plt.show()

for rep in range(3):
    
    # ground truth; Pi is PSD by construction as B B^T
    C      = np.random.normal(size=(p,n))
    B      = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi     = B.dot(B.T) #np.eye(n) 
    A_true = np.random.normal(size=(n,n)) # np.diag(np.linspace(0.5, 0.9, n)) #
    # lagged covariances C A^m Pi C^T for m = 1 .. k+l-1 (observed-only and full)
    Qs = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A_true, kl_)
        Qs.append(C.dot(Akl.dot(Pi)).dot(C.T) *np.asarray( Om,dtype=int) )
    Qs_full = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A_true, kl_)
        Qs_full.append(C.dot(Akl.dot(Pi)).dot(C.T) )       
    # NOTE(review): H_true, H_obs and H_sti are not used further in this cell
    H_true = ssid.yy_Hankel_cov_mat(C,A_true,Pi,k,l)
    H_obs = ssid.yy_Hankel_cov_mat( C,A_true,Pi,k,l, Om)
    H_obs[np.where(H_obs==0)] = np.nan
    H_sti = ssid.yy_Hankel_cov_mat( C,A_true,Pi,k,l,~Om)
    
    A_0  = np.random.normal(size=(n,n))
    
    # objective in terms of A (optimisers pass it flattened, hence the reshape)
    def f_i(A):        
        if len(A.shape)<2:
            A = A.reshape(Pi.shape[0], Pi.shape[0])
        return ssid.f_l2_Hankel(C,A,Pi,k,l,Qs, Om)*np.sum(Om)*(k*l)
    def g_i(A):
        return ssid.g_A_l2_Hankel(C,A,Pi,k,l,Qs,Om)
    
    
    print('difference in gradient to finite-differencing value:', check_grad(f_i, g_i, A_0.reshape(n*n,)))
    
    # gtol=1e-20: BFGS effectively stops only on maxiter or loss of precision
    A_est = fmin_bfgs(f_i, A_0.reshape(n*n,), fprime=g_i, gtol=1e-20).reshape(n,n)    
    
    print('final squared error on observed parts:', 
          ssid.f_l2_Hankel(C,A_est,Pi,k,l,Qs, Om))
    print('final squared error on overlapping parts:', 
          ssid.f_l2_Hankel(C,A_est,Pi,k,l,Qs,Ovw))
    print('final squared error on cross-overlapping parts:',
          ssid.f_l2_Hankel(C,A_est,Pi,k,l,Qs,Ovc))
    print('final squared error on stitched parts:',
          ssid.f_l2_Hankel(C,A_est,Pi,k,l,Qs_full,~Om))
    
    plt.figure(figsize=(16,12))
    plt.subplot(2,2,1)
    # zeros of the masked Hankel matrix are shown as NaN (blank)
    tmp = ssid.yy_Hankel_cov_mat(C,A_true,Pi,k,l,Om)
    tmp[np.where(tmp==0)] = np.nan
    plt.imshow(tmp,interpolation='none')
    plt.title('Given data matrix (A_true, masked)')
    plt.subplot(2,2,2)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A_true,Pi,k,l),interpolation='none')
    plt.title('True  matrix (A_true)')    
    plt.subplot(2,2,3)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A_0,Pi,k,l),interpolation='none')
    plt.title('Initial matrix (A_0)')
    plt.subplot(2,2,4)
    plt.imshow(ssid.yy_Hankel_cov_mat(C,A_est,Pi,k,l),interpolation='none')
    plt.title('Estimated matrix (A_est)')
    plt.show()
    
    # direct comparison of the A matrices themselves
    plt.figure(figsize=(16,12))
    plt.subplot(1,3,1)
    plt.imshow(A_0,interpolation='none')
    plt.title('A init')
    plt.subplot(1,3,2)
    plt.imshow(A_est,interpolation='none')
    plt.title('A est')
    plt.subplot(1,3,3)
    plt.imshow(A_true,interpolation='none')
    plt.title('A true')
    plt.show()
    

# Derivative w.r.t. $A, \Pi$ and $C$
- needs code base from commit a00b35e to run, as I later changed the function signatures

In [None]:
# Joint gradient descent on A, B (Pi = B B^T) and C, all packed into one flat vector.
p,n = 15,3
k,l = 2,2

# create subpopulations (they share unit 7)
sub_pops = (np.arange(0,8), np.arange(7,p))
print('sub_pops', sub_pops)

obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
print('idx_grp:', idx_grp)
print('obs_idx:', obs_idx)

Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)

# visualise the observed / overlap / cross-overlap masks
plt.figure(figsize=(20,10))
plt.subplot(1,3,1)
plt.imshow(Om,interpolation='none')
plt.title('Observation pattern')
plt.subplot(1,3,2)
plt.imshow(Ovw,interpolation='none')
plt.title('Overlap pattern')
plt.subplot(1,3,3)
plt.imshow(Ovc,interpolation='none')
plt.title('Cross-overlap pattern')
plt.show()

for rep in range(3):

    # ground truth; Pi_true is PSD by construction as B_true B_true^T
    C_true      = np.random.normal(size=(p,n))
    B_true      = np.random.normal(size=(n,n))/np.sqrt(n)
    Pi_true     = B_true.dot(B_true.T) #np.eye(n)
    A_true      = np.random.normal(size=(n,n)) # np.diag(np.linspace(0.5, 0.9, n)) #
    # lagged covariances C A^m Pi C^T for m = 1 .. k+l-1 (observed-only and full)
    Qs = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A_true, kl_)
        Qs.append(C_true.dot(Akl.dot(Pi_true)).dot(C_true.T) *np.asarray( Om,dtype=int) )
    Qs_full = []
    for kl_ in range(1,k+l):
        Akl = np.linalg.matrix_power(A_true, kl_)
        Qs_full.append(C_true.dot(Akl.dot(Pi_true)).dot(C_true.T) )
    # NOTE(review): H_true, H_obs and H_sti are not used further in this cell
    H_true = ssid.yy_Hankel_cov_mat(C_true,A_true,Pi_true,k,l)
    H_obs = ssid.yy_Hankel_cov_mat( C_true,A_true,Pi_true,k,l, Om)
    H_obs[np.where(H_obs==0)] = np.nan
    H_sti = ssid.yy_Hankel_cov_mat( C_true,A_true,Pi_true,k,l,~Om)

    # initial parameters, packed as [vec(A), vec(B), vec(C)]
    A_0  = np.random.normal(size=(n,n))
    B_0  = np.random.normal(size=(n,n))
    Pi_0 = B_0.dot(B_0.T)
    C_0  = np.random.normal(size=(p,n))
    pars_0 = np.hstack((A_0.reshape(n*n,),
                        B_0.reshape(n*n,),
                        C_0.reshape(p*n,)))


    def f_i(pars):
        # unpack the flat parameter vector into matrices
        A,B,C = pars[:n*n].reshape(n,n), pars[n*n:2*n*n].reshape(n,n), pars[-p*n:].reshape(p,n)
        Pi = B.dot(B.T)
        return ssid.f_l2_Hankel(C,A,Pi,k,l,Qs, Om)*np.sum(Om)*(k*l)

    def g_i(pars):
        # NOTE(review): unlike f_i, the slices are passed flat (not reshaped);
        # presumably ssid.g_l2_Hankel reshapes them itself -- confirm against the module
        A,B,C = pars[:n*n], pars[(n*n):(2*n*n)], pars[-p*n:]
        return ssid.g_l2_Hankel(C,A,B,k,l,Qs,Om,idx_grp, obs_idx)


    print('difference in gradient to finite-differencing value:', check_grad(f_i, g_i, pars_0))

    # gtol=1e-20: BFGS effectively stops only on maxiter or loss of precision
    pars_est = fmin_bfgs(f_i, pars_0, fprime=g_i, gtol=1e-20)
    A_est = pars_est[:n*n].reshape(n,n)
    B_est = pars_est[n*n:2*n*n].reshape(n,n)
    Pi_est = B_est.dot(B_est.T)
    C_est = pars_est[-p*n:].reshape(p,n)

    print('final squared error on observed parts:',
          ssid.f_l2_Hankel(C_est,A_est,Pi_est,k,l,Qs, Om))
    print('final squared error on overlapping parts:',
          ssid.f_l2_Hankel(C_est,A_est,Pi_est,k,l,Qs,Ovw))
    print('final squared error on cross-overlapping parts:',
          ssid.f_l2_Hankel(C_est,A_est,Pi_est,k,l,Qs,Ovc))
    print('final squared error on stitched parts:',
          ssid.f_l2_Hankel(C_est,A_est,Pi_est,k,l,Qs_full,~Om))

    # every panel depends on all three parameters, so the titles name all of them
    plt.figure(figsize=(16,12))
    plt.subplot(2,2,1)
    tmp = ssid.yy_Hankel_cov_mat(C_true,A_true,Pi_true,k,l,Om)
    tmp[np.where(tmp==0)] = np.nan   # show unobserved entries as blanks
    plt.imshow(tmp,interpolation='none')
    plt.title('Given data matrix (C_true, A_true, Pi_true, masked)')
    plt.subplot(2,2,2)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_true,A_true,Pi_true,k,l),interpolation='none')
    plt.title('True matrix (C_true, A_true, Pi_true)')
    plt.subplot(2,2,3)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_0,A_0,Pi_0,k,l),interpolation='none')
    plt.title('Initial matrix (C_0, A_0, Pi_0)')
    plt.subplot(2,2,4)
    plt.imshow(ssid.yy_Hankel_cov_mat(C_est,A_est,Pi_est,k,l),interpolation='none')
    plt.title('Estimated matrix (C_est, A_est, Pi_est)')
    plt.show()

    # direct comparison of the A matrices themselves
    plt.figure(figsize=(16,12))
    plt.subplot(1,3,1)
    plt.imshow(A_0,interpolation='none')
    plt.title('A init')
    plt.subplot(1,3,2)
    plt.imshow(A_est,interpolation='none')
    plt.title('A est')
    plt.subplot(1,3,3)
    plt.imshow(A_true,interpolation='none')
    plt.title('A true')
    plt.show()
    

# Solving for $A^m * \Pi$ analytically
- gradient descent only for $C \in \mathbb{R}^{p \times n}$
- for each current estimate of $C$, we want to immediately find the corresponding $A^m \Pi$ such that $C A^m \Pi C^\top = \Lambda(m)$ 
- solving for $A,\Pi = BB^\top$ 'analytically': a truly analytic solution is only possible in the fully observed case
- for the partially observed case, a simple fixed-point iteration can generate the $A^m \Pi$ that match the current $C, \Lambda(m)$ 
- differs from gradient descent w.r.t. $A,B$ and $C$ in several ways:


(+) the accompanying latent bases as in $M A M^{-1}, M B, C M^{-1}$ are guaranteed to match at each step of the descent, starting at the very first step

(+) ... thus we can cancel the algorithm early and still get a consistent set of parameters to initialise EM with

(+) reduces dimensionality of the gradient step by $2 * n^2$ (though that indeed is the smallest chunk in practice...)

(+) numerically stable: avoids calculating $A^m$ explicitly

(-) Bad estimates for $C$ may cause bad estimates for $A, \Pi$

(-) Might not be as robust to non-low-rank time-lagged covariance matrices $\Lambda(m)$ as when doing descent on all of $A,B,C$

(+/-) Works for nonlinear latent dynamics. The estimates for $X_m = \mbox{cov}[x_{t+m}, x_t]$ need not be consistent across different $m$ in the sense that $A^m \Pi = A (A^{m-1} \Pi)$. In practice this may also give robustness to model mismatch and finite data.

**Showcasing that the fixed-point iteration for $A, \Pi$ works**

In [None]:
# Showcase: recover X_m = A^m Pi from the observed part of Lambda(m) = C A^m Pi C^T
# by iterating ssid.iter_X_m, with C known.

# pick problem size and time lag m
p,n,m = 50, 30, 1

# pick number of fixed-point iterations
num_fixed_point_iters = 10000

# create subpopulations (they share units p//2-1 .. p//2+1)
sub_pops = (np.arange(0,p//2+2), np.arange(p//2-1,p))
obs_idx, idx_grp = ssid.get_obs_index_groups(obs_scheme={'sub_pops': sub_pops,'obs_pops': (0,1)},p=p)
overlaps, overlap_grp, idx_overlap = ssid.get_obs_index_overlaps(idx_grp, sub_pops)
Om, Ovw, Ovc = ssid.comp_subpop_index_mats(sub_pops,idx_grp,overlap_grp,idx_overlap)
not_Om = np.invert(Om)

# draw system parameters: nr real values in [0.1, 0.2] and (n-nr)//2 complex values with
# random phase and moduli in [0.8, 0.9] -- presumably the eigenvalues of A built by gen_pars
nr = 2
ev_r = np.linspace(0.1, 0.2, nr)
ev_c = np.exp(2 * 1j * np.pi * np.random.uniform(size= (n - nr)//2))
ev_c = np.linspace(0.8, 0.9, (n - nr)//2) * ev_c
pars = ssid.gen_pars(p=p,n=n, nr=nr, ev_r = ev_r, ev_c = ev_c)
A, C, Pi = pars['A'], pars['C'], pars['Pi']

# ground truth that the fixed-point iteration should recover
AmPi = np.linalg.matrix_power(A,m).dot(Pi)

# Moore-Penrose pseudo-inverse of C, passed to ssid.iter_X_m
Cd = np.linalg.pinv(C)

# construct true, initial and estimated time-lagged covariance matrix Lambda(m)
Q_true = C.dot(AmPi).dot(C.T)

X_m = np.eye(n) # initialise X_m
Q_0 = Q_true.copy()
Q_0[not_Om] = C.dot(X_m).dot(C.T)[not_Om]

# Start from the initial guess rather than Q_true: only the observed entries are data,
# so the true unobserved (to-be-stitched) entries must never reach the iteration.
Q_est = Q_0.copy()
for i in range(num_fixed_point_iters):
    X_m = ssid.iter_X_m(Q_est, C, not_Om, Cd, X_m)
Q_est[not_Om] = C.dot(X_m).dot(C.T)[not_Om]

print('Visualising results for problem size (p,n) = ', (p,n))
print('Time lag is m = ', m)
print('Number of times we iterated the fixed-point equation: ', num_fixed_point_iters)

plt.figure(figsize=(20,10))
plt.subplot(2,3,1)
plt.imshow(Om, interpolation='none')
plt.title('observation pattern')
plt.subplot(2,3,4)
plt.imshow(Q_true, interpolation='none')
plt.title('true Lambda(m) = C A^m Pi C^T')
plt.subplot(2,3,2)
plt.imshow(Q_0, interpolation='none')
plt.title('initial guess for Lambda(m)')
plt.subplot(2,3,5)
plt.imshow(Q_est, interpolation='none')
plt.title('final guess for Lambda(m)')
plt.subplot(2,3,3)
plt.imshow(AmPi, interpolation='none')
plt.title('true A^m Pi')
plt.subplot(2,3,6)
plt.imshow(X_m, interpolation='none')
plt.title('X_m (estimated A^m Pi)')
plt.show()


**Showcasing full gradient descent along $C$ with $A^m \Pi$ computed at each gradient step**

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy.optimize import fmin_bfgs, check_grad
import glob, os

os.chdir('../core')
import stitching_ssid as ssid
os.chdir('../dev')

import pprint

# Gradient descent on C only; A^m Pi is re-estimated inside the objective at every
# evaluation (presumably by the fixed-point iteration, capped at max_iter_X_m steps).
p,n = 20,5
k,l = 5,5
max_iter_X_m = 50 # maximum number of iterations for fixed-point iteration

# create subpopulations (they share units p//2-n//2 .. p//2+n//2-1)
sub_pops = (np.arange(0,p//2+n//2), np.arange(p//2-n//2,p))
obs_idx, idx_grp, overlaps, overlap_grp, idx_overlap, Om, Ovw, Ovc = \
    ssid.get_subpop_stats(sub_pops, p, verbose=True)
not_Om = np.invert(Om)

for rep in range(50):

    # generate model and data; A = V diag(0.5..0.95) V^{-1} has real eigenvalues in [0.5, 0.95]
    C_true, Pi, A = np.random.normal(size=(p,n)), np.random.normal(size=(n,n))/np.sqrt(n), np.random.normal(size=(n,n))
    A, Pi = A.dot(np.diag(np.linspace(0.5, 0.95, n))).dot(np.linalg.inv(A)), Pi.dot(Pi.T) #np.eye(n)
    # lagged covariances C A^m Pi C^T for m = 0 .. k+l-1 (lag 0 included, unlike the earlier cells)
    Qs, Qs_full = [], []
    for kl_ in range(0,k+l):
        Akl = np.linalg.matrix_power(A, kl_)
        Qs.append(C_true.dot(Akl.dot(Pi)).dot(C_true.T) *np.asarray( Om,dtype=int) )
        Qs_full.append(C_true.dot(Akl.dot(Pi)).dot(C_true.T) )

    # fit the model
    C_0 = np.random.normal(size=(p,n))
    def f_i(C):
        return ssid.f_l2_Hankel_coord_asc(C, k, l, n, Qs, Om, not_Om, max_iter=max_iter_X_m)
    def g_i(C):
        return ssid.g_l2_coord_asc(C,k,l,n,Qs,not_Om,max_iter=max_iter_X_m)
    # gtol=1e-20: BFGS effectively stops only on maxiter or loss of precision
    C_est = fmin_bfgs(f_i, C_0.reshape(p*n,), fprime=g_i, gtol=1e-20).reshape(p,n)

    # evaluate results with exactly the objective that was optimised (same not_Om and
    # max_iter), evaluated on the flat parameter vector just as BFGS does
    print('initial squared error on observed parts:',
          f_i(C_0.reshape(p*n,)))
    print('final squared error on observed parts:',
          f_i(C_est.reshape(p*n,)))

    plt.figure(figsize=(16,12))
    plt.subplot(2,2,1)
    H_obs = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l, Om)
    H_obs[np.where(H_obs==0)] = np.nan   # show unobserved entries as blanks
    plt.imshow(H_obs,interpolation='none')
    plt.title('Given data matrix (C_true, masked)')
    plt.subplot(2,2,2)
    H_true = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l)
    plt.imshow(H_true,interpolation='none')
    plt.title('True  matrix (C_true)')
    plt.subplot(2,2,3)
    plt.imshow(ssid.yy_Hankel_cov_mat_coord_asc(C_0,k,l,Qs),interpolation='none')
    plt.title('Initial matrix (C_0)')
    plt.subplot(2,2,4)
    plt.imshow(ssid.yy_Hankel_cov_mat_coord_asc(C_est,k,l,Qs),interpolation='none')
    plt.title('Estimated matrix (C_est)')
    plt.show()

# Debug and unfinished stuff

In [None]:
def collect_constraints(H, Ov, k, l, ovl_size):
    """Collect the masked (overlap) entries of every p x p block of a block-Hankel matrix.

    H is viewed as a k x l grid of p x p blocks, where p is the side length of the
    mask Ov. The entries selected by Ov are taken from each block in row-major
    block order ((0,0), (0,1), ..., (k-1,l-1)) and concatenated.

    Parameters
    ----------
    H : ndarray, shape (k*p, l*p)
        Block-Hankel matrix.
    Ov : array_like of bool, shape (p, p)
        Mask selecting the entries to collect within one block.
    k, l : int
        Number of block rows / block columns of H.
    ovl_size : int
        Number of columns of the returned array; the total number of collected
        entries must be divisible by it.

    Returns
    -------
    ndarray of float, shape (k*l*sum(Ov) // ovl_size, ovl_size)

    Raises
    ------
    AssertionError
        If Ov is not square or H does not have shape (k*p, l*p).
    """
    Ov = np.asarray(Ov, dtype=bool)
    # derive p from the mask argument itself rather than from any notebook global
    p = Ov.shape[0]
    assert Ov.shape == (p, p) and H.shape == (k*p, l*p)

    # row-major over the k x l block grid: block (k_, l_) is the (k_*l + l_)-th chunk,
    # each chunk holding sum(Ov) entries (the stride must be l, not k, when k != l)
    blocks = [H[k_*p:(k_+1)*p, l_*p:(l_+1)*p][Ov]
              for k_ in range(k) for l_ in range(l)]
    cnstr = np.concatenate(blocks) if blocks else np.empty(0)

    return cnstr.astype(float, copy=False).reshape(-1, ovl_size)

# Inspect the overlap constraints collected from every block of the Hankel matrix.
# Uses C_true, A, Pi, k, l and Ovw as left behind by the previous cell.
ovl_size = 1
H = ssid.yy_Hankel_cov_mat(C_true,A,Pi,k,l)
x = collect_constraints(H, Ovw, k, l, ovl_size)
# NOTE(review): reshape(k,l) only works when each Hankel block contributes exactly one value
plt.imshow(x.reshape(k,l), interpolation='none')
# shared colour scale; distinct names so the notebook-global time lag `m` is not clobbered
x_min, x_max = np.min(x), np.max(x)

plt.figure(figsize=(20,10))
for k_ in range(k):
    for l_ in range(l):
        blk = k_*l + l_                               # row-major index of Hankel block (k_, l_)
        rows = x[ovl_size*blk:ovl_size*(blk+1), :]    # assumes ovl_size rows per block
        plt.subplot(k, l, blk + 1)                    # matplotlib subplot indices are 1-based
        plt.imshow(rows, interpolation='none', vmin=x_min, vmax=x_max)
        plt.title(str(rows[0,0]))
plt.show()

tmp = x.reshape(k,l)
_,s,_ = np.linalg.svd(tmp)
print(s)

_,s,_ = np.linalg.svd(ssid.observability_mat({'A': A, 'C': C_true[3,:]},k))
print(s)

ssid.observability_mat({'A': A, 'C': C_true[3,:]},k)
