# Checking the information gained from additional time lags for the reconstruction of $C$
- Fitting small systems ($p=5, n=2$) to short, noisy data ($T=50$, $\mathrm{SNR}=0.5$) using either several time lags ($\tau=0,\ldots,4$) or a single time lag ($\tau=0$).

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

os.chdir('../core')
import ssm_scripts, ssm_fit
from utility import get_subpop_stats, gen_data
from SSID_Hankel_loss import run_bad, plot_slim, print_slim
os.chdir('../dev')

#np.random.seed(0)

# Problem size. Later cells use y with shape (T, p) and C with shape (p, n).
p, n, T = 5, 2, 50
lag_range = np.arange(0, 5)      # time lags whose covariances are computed
kl_ = np.max(lag_range)+1        # largest lag + 1; reused for slicing below

# I/O matter
mmap, chunksize = True, np.min((p,2000))
save_file = 'test'
verbose=True

# Everything belonging to this run is written below data_path.  The path is
# nested, so create all missing parents; an already existing directory is
# fine, any other failure (e.g. permissions) is reported instead of being
# silently swallowed and resurfacing later inside gen_data.
data_path = '../fits/test_SSID_gain_from_lags/p5n2T50snr05/run1/'
os.makedirs(data_path, exist_ok=True)

# draw system matrices 
print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')
nr = n//2 if n > 2 else 0 # number of real eigenvalues
snr = (0.5, 0.5)  # NOTE(review): meaning of the two entries is defined by gen_data - confirm
# lower / upper bounds for real (r) and complex (c) eigenvalues, as passed to gen_data
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.8, 0.95, 0.8, 0.95
pars_true, x, y, Qs, idx_a, idx_b = gen_data(p,n,lag_range,T, nr,
                                             eig_m_r, eig_M_r, 
                                             eig_m_c, eig_M_c,
                                             mmap, chunksize,
                                             data_path,snr=snr)    

# report memory state after data generation
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())


(p,n,k+l,T) =  (5, 2, 5, 50) 

computing empirical covariances
computing time-lagged covariance for lag  0
computing time-lagged covariance for lag  1
computing time-lagged covariance for lag  2
computing time-lagged covariance for lag  3
computing time-lagged covariance for lag  4


svmem(total=12277600256, available=9273073664, percent=24.5, used=6638002176, free=5639598080, active=3293990912, inactive=2797096960, buffers=235278336, cached=3398197248, shared=412778496)
sswap(total=0, used=0, free=0, percent=0, sin=0, sout=0)


In [4]:
# Make sure the output directory (including missing parents) exists.
# An existing directory is not an error; real failures are raised.
os.makedirs(data_path, exist_ok=True)

# peek at data & observation scheme

In [None]:
%matplotlib inline

# create subpopulations
sub_pops = (np.arange(0,p),) 

#assert(p==105)
#sub_pops = (np.arange(0,25), np.arange(20,45),np.arange(40,65),np.arange(60,85),np.arange(80,105))

obs_sweeps = 50

#assert(p==105)
#sub_pops = (np.arange(0,65),np.arange(40,p)) 

obs_pops = np.hstack([np.arange(len(sub_pops)) for i in range(obs_sweeps)])
obs_time = np.array([i*T//len(obs_pops) for i in range(1,len(obs_pops)+1)])


obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)


y_masked = np.memmap(data_path+'y', dtype=np.float, mode='c', shape=(T,p))
y_masked[:obs_time[0], np.setdiff1d(np.arange(p), sub_pops[obs_pops[0]])] = np.nan
for i in range(1,len(obs_time)):
    y_masked[np.ix_(np.arange(obs_time[i-1],obs_time[i]), np.setdiff1d(np.arange(p), sub_pops[obs_pops[i]]))] = np.nan
plt.figure(figsize=(20,10))
plt.imshow(y_masked.T, interpolation='None')
plt.show()

y_masked = np.memmap(data_path+'y', dtype=np.float, mode='c', shape=(T,p))
y_masked[:obs_time[0], np.setdiff1d(np.arange(p), sub_pops[obs_pops[0]])] = 0
for i in range(1,len(obs_time)):
    y_masked[np.ix_(np.arange(obs_time[i-1],obs_time[i]), np.setdiff1d(np.arange(p), sub_pops[obs_pops[i]]))] = 0

del y
y = np.memmap(data_path+'y', dtype=np.float, mode='r', shape=(T,p))


# peek at latent covariances

In [None]:

def _obs_times(pop):
    """Return all time indices at which subpopulation `pop` is observed.

    Segment j covers obs_time[j-1] (0 for the first segment) up to, but not
    including, obs_time[j].
    """
    # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
    segments = [np.arange(obs_time[j-1] if j > 0 else 0, obs_time[j], dtype=int)
                for j in np.where(obs_pops==pop)[0]]
    return np.hstack([np.array([], dtype=int)] + segments)


plt.figure(figsize=(10,len(sub_pops)*5))

# reference: covariance of the latents x while the first subpopulation is observed
cov_1 = np.cov(x[_obs_times(0),:].T)

# left column: latent covariance during observation of subpopulation i,
# right column: its entries scattered against the reference covariance
for i in range(len(sub_pops)):
    plt.subplot(len(sub_pops),2,2*i+1)
    idx = _obs_times(i)
    cov_i = np.cov(x[idx,:].T)
    plt.imshow(cov_i, interpolation='None')
    plt.subplot(len(sub_pops),2,2*i+2)
    plt.plot(cov_1.reshape(-1), cov_i.reshape(-1), '.')
    plt.title('corr = ' + str(np.corrcoef(cov_1.reshape(-1), cov_i.reshape(-1))[0,1]))
plt.show()

# fit with several lags

In [None]:

# Choose the initialisation handed to run_bad via init=...
# cheat=True starts from a copy of the ground-truth parameters and adds the
# stacked blocks X_i = A^lag_i * Pi; otherwise the string 'default' is passed
# on (presumably run_bad then picks its own initialisation - confirm there).
cheat = False
if not cheat:
    pars_est = 'default'
else:
    pars_est = pars_true.copy()
    pars_est['X'] = np.zeros((len(lag_range)*n, n))
    for i, lag in enumerate(lag_range):
        A_pow = np.linalg.matrix_power(pars_est['A'], lag)
        pars_est['X'][i*n:(i+1)*n,:] = A_pow.dot(pars_est['Pi'])

In [None]:
%matplotlib inline

# settings for fitting algorithm
batch_size, max_zip_size, max_iter = None, np.inf, 20000
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = 1. * a

t = time.time()
_, pars_est, traces = run_bad(lag_range=lag_range,n=n,y=y_masked, Qs=Qs,Om=Om,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=obs_pops,obs_time=obs_time,init=pars_est,
                                      alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,batch_size=batch_size,
                                      max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)


In [None]:
# graphical summary of the fit obtained in the previous cell
plot_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)

In [None]:
# Call run_bad once more, starting at the fitted parameters, with zero step
# sizes and a single iteration; all three return values are discarded.
_, _, _ = run_bad(
    lag_range=lag_range, n=n, y=y_masked, Qs=Qs, Om=Om,
    idx_a=idx_a, idx_b=idx_b,
    sub_pops=sub_pops, idx_grp=idx_grp, co_obs=co_obs, obs_idx=obs_idx,
    obs_pops=obs_pops, obs_time=obs_time,
    init=pars_est,
    alpha_C=0, alpha_R=0, b1_C=b1, b2_C=b2, e_C=e,
    max_iter=1, batch_size=1,
    max_zip_size=1,
)

# Cross-check by hand: model covariance for the first lag block,
# C X_0 C^T + diag(R), and its summed squared deviation from Qs[0].
pars = pars_est
X0 = pars['X'][:n,:]
tmp = (pars['C'][idx_a,:].dot(X0).dot(pars['C'][idx_b,:].T)
       + np.diag(pars['R'])[np.ix_(idx_a,idx_b)])
np.sum((tmp - Qs[0])**2)


In [None]:
# store the fitted parameters as 'pars_est_<number of lags>lags.npy' in data_path
# (np.save appends the '.npy' suffix)
np.save(data_path + 'pars_est_' + str(len(lag_range)) + 'lags', pars_est)

# fit with single lag

In [None]:
# From here on only the instantaneous covariance (lag 0) is fitted.
lag_range =  np.array([0])

# Same choice of initialisation as for the multi-lag fit: ground truth plus
# stacked blocks X_i = A^lag_i * Pi when cheat=True, else the string 'default'.
cheat = False
if not cheat:
    pars_est = 'default'
else:
    pars_est = pars_true.copy()
    pars_est['X'] = np.zeros((len(lag_range)*n, n))
    for i, lag in enumerate(lag_range):
        A_pow = np.linalg.matrix_power(pars_est['A'], lag)
        pars_est['X'][i*n:(i+1)*n,:] = A_pow.dot(pars_est['Pi'])

In [None]:
%matplotlib inline

# settings for fitting algorithm
batch_size, max_zip_size, max_iter = None, np.inf, 20000
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8
a_R = 1. * a

t = time.time()
_, pars_est, traces = run_bad(lag_range=lag_range,n=n,y=y_masked, Qs=Qs,Om=Om,idx_a=idx_a, idx_b=idx_b,
                                      sub_pops=sub_pops,idx_grp=idx_grp,co_obs=co_obs,obs_idx=obs_idx,
                                      obs_pops=obs_pops,obs_time=obs_time,init=pars_est,
                                      alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,batch_size=batch_size,
                                      max_zip_size=max_zip_size)

print('fitting time was ', time.time() - t, 's')
print('\n')
print(psutil.virtual_memory())
print(psutil.swap_memory())

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)


In [None]:
# graphical summary of the single-lag fit obtained in the previous cell
plot_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)

In [None]:
# store the fitted parameters as 'pars_est_<number of lags>lags.npy' in data_path;
# with lag_range = [0] this is the 'pars_est_1lags.npy' file read further below
np.save(data_path + 'pars_est_' + str(len(lag_range)) + 'lags', pars_est)

# compare to PCA

In [None]:
from sklearn.decomposition import PCA

pca = PCA()
pca.fit(y_masked)

# cumulative fraction of explained variance; the ratios are zero-padded to
# length p in case fewer than p components exist (T < p)
explained_variance_ratio_ = np.hstack((pca.explained_variance_ratio_, np.zeros(np.max((p-T,0)))))
plt.plot(range(1,p+1), np.cumsum(explained_variance_ratio_)/np.sum(explained_variance_ratio_))
# No plt.hold(True) here: pyplot.hold was removed in Matplotlib 3.0 and
# successive plot calls draw into the same axes by default.
plt.plot(np.linspace(0,p+1,np.min((20,p))), 
         np.cumsum(pca.explained_variance_ratio_[:np.min((20,p))])/np.sum(pca.explained_variance_ratio_), 
         'r--')
plt.legend(('cum. var. expl.', 'first 20, x-axis rescaled'))

# express the PCA solution in the parameter format expected by plot_slim:
# first n principal axes as C, their variances as diagonal Pi, the same Pi
# repeated for every lag as X, and zero private noise R
pars_pca = {}
pars_pca['C'] = pca.components_[:n].T
pars_pca['Pi'] = np.diag(pca.explained_variance_[:n])    
pars_pca['X'] = np.vstack([pars_pca['Pi'] for m in range(len(lag_range))])
pars_pca['R'] = np.zeros(p)
plot_slim(Qs,lag_range,pars_pca,idx_a,idx_b,(0,0),mmap,data_path)

In [None]:
# covariance implied by the PCA parameters, C X_0 C^T, shown as an image
plt.imshow(pars_pca['C'].dot(pars_pca['X'][:n,:]).dot(pars_pca['C'].T), interpolation='None')
plt.show()

# compare to FA

In [None]:
%matplotlib inline
from sklearn.decomposition import FactorAnalysis

fa = FactorAnalysis()
fa.noise_variance_init=pars_true['R']
fa.fit(y_masked)
pars_fa = {}
pars_fa['C'] = fa.components_[:n].T
pars_fa['Pi'] = np.eye(n)
pars_fa['X'] = np.vstack([pars_fa['Pi'] for m in range(len(lag_range))])
pars_fa['R'] = fa.noise_variance_
plot_slim(Qs,lag_range,pars_fa,idx_a,idx_b,(0,0),mmap,data_path)


In [None]:
# Bundle data, ground truth and the baseline fits into one dictionary.
# It is passed positionally, so np.savez stores it under the key 'arr_0'
# of '<data_path>data.npz'; the loading cells below rely on that key.
savedict = dict(
    x=x,
    y=y,
    snr=snr,
    pars_true=pars_true,
    pars_pca=pars_pca,
    pars_fa=pars_fa,
)

np.savez(data_path + 'data', savedict)


# compare to ground-truth parameters

In [None]:
# Fill pars_true['X'] with the empirical time-lagged covariances of the
# latents x, one n-by-n block per lag in the current lag_range.
kl_ = np.max(lag_range)+1
pars_true['X'] = np.zeros((len(lag_range)*n, n))
for i, m_ in enumerate(lag_range):
    # joint covariance of x shifted by m_ and unshifted x; the upper-right
    # block [:n, n:] is the cross-covariance between the two
    joint_cov = np.cov(x[m_:m_-(kl_),:].T, x[:-(kl_),:].T)
    pars_true['X'][i*n:(i+1)*n,:] = joint_cov[:n,n:]

plot_slim(Qs,lag_range,pars_true,idx_a,idx_b,(0,0),mmap,data_path)


In [None]:
# One iteration of run_bad with zero step size, started at the ground-truth
# parameters; only the returned traces are kept.
_, _, traces = run_bad(
    lag_range=lag_range, n=n, y=y_masked, Qs=Qs, Om=Om,
    idx_a=idx_a, idx_b=idx_b,
    sub_pops=sub_pops, idx_grp=idx_grp, co_obs=co_obs, obs_idx=obs_idx,
    obs_pops=obs_pops, obs_time=obs_time,
    init=pars_true,
    alpha_C=0, max_iter=1, batch_size=1,
    max_zip_size=1,
)

print_slim(Qs,lag_range,pars_true,idx_a,idx_b,traces,mmap,data_path)

#plot_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)

# Cross-check by hand: model covariance for the first lag block,
# C X_0 C^T + diag(R), and its summed squared deviation from Qs[0].
pars = pars_true
X0 = pars['X'][:n,:]
tmp = (pars['C'][idx_a,:].dot(X0).dot(pars['C'][idx_b,:].T)
       + np.diag(pars['R'])[np.ix_(idx_a,idx_b)])
np.sum((tmp - Qs[0])**2)

# load and compare fits

In [None]:
# directory of the run to inspect
# NOTE(review): the loading cell below assigns data_path again (run 0) - confirm which run is intended
data_path = '../fits/test_SSID_gain_from_lags/p5n2T50snr05/run5/'  

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import glob, os, psutil, time

os.chdir('../core')
import ssm_scripts, ssm_fit
from utility import get_subpop_stats, gen_data
from SSID_Hankel_loss import run_bad, plot_slim, print_slim
os.chdir('../dev')

run = 0
data_path = '../fits/test_SSID_gain_from_lags/p5n2T50snr05/run' + str(run) + '/'  

# The dictionaries were written with np.savez / np.save as pickled object
# arrays.  np.load refuses to unpickle by default (allow_pickle=False since
# NumPy 1.16.3), so it has to be enabled explicitly.  Only do this for files
# produced by this notebook: unpickling untrusted files can execute code.
tmp = np.load(data_path + 'data.npz', allow_pickle=True)['arr_0'].tolist()
x,y,snr,pars_true,pars_pca = tmp['x'],tmp['y'],tmp['snr'],tmp['pars_true'],tmp['pars_pca']

# recover the problem size from the stored arrays
p,n = pars_true['C'].shape
T = x.shape[0]
pars_est1 = np.load(data_path + 'pars_est_1lags.npy', allow_pickle=True).tolist()
pars_est5 = np.load(data_path + 'pars_est_5lags.npy', allow_pickle=True).tolist()

# re-open the stored time-lagged covariances (one p-by-p memmap per lag)
lag_range = np.arange(5)
idx_a, idx_b = np.arange(p), np.arange(p)
pa, pb = p, p
mmap = True
kl = len(lag_range)
Qs = []
for m_ in range(kl):
    # np.float was removed in NumPy 1.24; it aliased the builtin float (float64)
    Qs.append(np.memmap(data_path+'Qs_'+str(m_), dtype=np.float64, 
                      mode='r', shape=(pa,pb)))




In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import glob, os, psutil, time
from scipy.io import savemat

os.chdir('../core')
import ssm_scripts, ssm_fit
from utility import get_subpop_stats, gen_data
from SSID_Hankel_loss import run_bad, plot_slim, print_slim
os.chdir('../dev')

def _subspace_mse(C_est, C_true):
    """Mean squared error between C_true and its least-squares fit within span(C_est).

    C_est is mapped onto C_true with W = pinv(C_est) C_true before comparing,
    so the error does not depend on the choice of latent basis of C_est.
    """
    W = np.linalg.pinv(C_est).dot(C_true)
    return np.mean((C_est.dot(W) - C_true)**2)


for run in range(10):

    data_path = '../fits/test_SSID_gain_from_lags/p5n2T50snr05/run' + str(run) + '/'  

    # Pickled object arrays: np.load needs allow_pickle=True (default is
    # False since NumPy 1.16.3).  Only load files written by this notebook.
    tmp = np.load(data_path + 'data.npz', allow_pickle=True)['arr_0'].tolist()
    x,y,snr,pars_true,pars_pca = tmp['x'],tmp['y'],tmp['snr'],tmp['pars_true'],tmp['pars_pca']
    # Take the FA fit of *this* run from the file as well; otherwise the
    # savemat call below would export whatever pars_fa happens to be left in
    # the session (or fail with a NameError in a fresh kernel).
    pars_fa = tmp['pars_fa']

    p,n = pars_true['C'].shape
    T = x.shape[0]
    pars_est1 = np.load(data_path + 'pars_est_1lags.npy', allow_pickle=True).tolist()
    pars_est5 = np.load(data_path + 'pars_est_5lags.npy', allow_pickle=True).tolist()

    print('1 lag :', _subspace_mse(pars_est1['C'], pars_true['C']))
    print('5 lags:', _subspace_mse(pars_est5['C'], pars_true['C']))
    
    # export all parameter sets of this run for use in MATLAB
    savemat(data_path + '_m', {'pars_true' : pars_true,
                               'pars_est1' : pars_est1, 
                               'pars_est5' : pars_est5, 
                               'pars_pca' : pars_pca, 
                               'pars_fa' : pars_fa
                              })



In [None]:
from scipy.io import savemat
# IPython introspection ('?' suffix): shows the help for savemat; not valid plain Python
savemat?

In [None]:
%matplotlib inline

print(pars_true['C'])
plt.imshow(pars_true['C'], interpolation='None')
plt.show()

W = np.linalg.pinv(pars_est1['C']).dot( pars_true['C'] )
print(pars_est1['C'].dot(W))
plt.imshow(pars_est1['C'].dot(W), interpolation='None')
plt.show()


W = np.linalg.pinv(pars_est5['C']).dot( pars_true['C'] )
print(pars_est5['C'].dot(W))
plt.imshow(pars_est5['C'].dot(W), interpolation='None')
plt.show()


# check details

In [None]:
# ---- choose the pair of index sets (rows idxc_a, columns idxc_b) to check ----
#idxc_a = np.arange(p)
#idxc_b = np.arange(p)

#idxc_a = sub_pops[1].copy()
#idxc_b = sub_pops[1].copy()

# non-co-observed
# NOTE(review): with a single subpopulation covering all p variables
# idxc_b is empty - pick one of the alternatives in that case.
i = 0
idxc_a = sub_pops[i]
idxc_b = np.setdiff1d(np.arange(p), sub_pops[i])

# only in first subpop
#idxc_a = np.setdiff1d(np.arange(p), sub_pops[1])
#idxc_b = np.setdiff1d(np.arange(p), sub_pops[1])

# only in second subpop
#idxc_a = np.setdiff1d(np.arange(p), sub_pops[0])
#idxc_b = np.setdiff1d(np.arange(p), sub_pops[0])

# get intersection
#idxc_a = np.intersect1d(sub_pops[0], sub_pops[1])
#idxc_b = idxc_a.copy()


#idx_i = 1
#if len(sub_pops) > 0:
#    idxc_b = np.setdiff1d(np.arange(p), sub_pops[idx_i]) # those only in subpop #2 
#else:
#    idxc_b = idx_a.copy()
#idxc_b = idx_a.copy()
    
pa, pb = len(idxc_a), len(idxc_b)


# ---- empirical vs. model time-lagged covariances, one entry per lag ----
# column 0 of corrs_est / MSE_est: fitted parameters, column 1: ground truth
Qsc = []
m_range = range(len(lag_range))
corrs_est = np.zeros((len(m_range), 2))
MSE_est = np.zeros((len(m_range), 2))    
for m in m_range:
    m_ = lag_range[m]
    print('computing time-lagged covariance for lag ', str(m_))
    Qsc.append(None)
    # np.float was removed in NumPy 1.24; it aliased the builtin float (float64)
    if mmap:
        Q = np.memmap(data_path+'Qsc_'+str(m_), dtype=np.float64, 
                      mode='w+', shape=(pa,pb))
    else:
        Q = np.empty((pa,pb))
    # cross-covariance block between y[idxc_a] shifted by m_ and y[idxc_b]
    Q[:] = np.cov(y[m_:m_-(kl_),idxc_a].T, y[:-(kl_),idxc_b].T)[:pa,pa:]     
    if mmap:
        # drop the writable map and re-open the file read-only
        del Q
        Qsc[m] = np.memmap(data_path+'Qsc_'+str(m_), dtype=np.float64, 
                          mode='r', shape=(pa,pb))
    else:
        Qsc[m] = Q
    
    # model covariance C[a] X_m C[b]^T (+ diag(R) at lag 0) for both
    # parameter sets; after the loop, pars and Q refer to the ground truth
    for col, pars in enumerate((pars_est, pars_true)):
        Q = (pars['C'][idxc_a,:].dot(pars['X'][m*n:(m+1)*n,:]).dot(pars['C'][idxc_b,:].T))
        if m_ == 0:
            Q += np.diag(pars['R'])[np.ix_(idxc_a, idxc_b)]
        corrs_est[m,col] = np.corrcoef(Qsc[m][:].reshape(-1), Q.reshape(-1))[0,1]
        MSE_est[m,col] = np.sum( (Qsc[m] - Q)**2 ) #/ np.mean(Qsc[m]**2)
    

print('SE est. ' + str(np.mean(MSE_est[:,0])))
print('SE true ' + str(np.mean(MSE_est[:,1])))
    
# ---- bar plots per lag (red: estimate, blue: ground truth) ----
# pyplot.hold was removed in Matplotlib 3.0; axes accumulate artists by default.
# plt.box needs a bool: the string 'off' is truthy and would keep the frame.
plt.figure(figsize=(11,6))
ax = plt.subplot(121)
plt.bar(-.4+np.arange(len(m_range)), corrs_est[:,0], color='r')
plt.bar(-.4+np.arange(len(m_range))+.1, corrs_est[:,1], color='b')
plt.xticks(np.arange(len(m_range)), [str(lag) for lag in lag_range], fontsize=14)
#plt.yticks(np.arange(0,1.01, 0.2), fontsize=14)
plt.xlabel('time-lag')
plt.ylabel('correlation')
plt.box(False)
ax.get_xaxis().tick_bottom()    
ax.get_yaxis().tick_left()    

ax = plt.subplot(122)
plt.bar(-.4+np.arange(len(m_range))+.1, MSE_est[:,1], color='b')
plt.bar(-.4+np.arange(len(m_range)), MSE_est[:,0], color='r')
plt.xticks(np.arange(len(m_range)), [str(lag) for lag in lag_range], fontsize=14)
#plt.yticks(np.arange(0,0.0013,0.0004), fontsize=14)
plt.xlabel('time-lag')
# NOTE(review): the plotted values are un-normalised sums of squared errors
# (the normalisation above is commented out) - the label may need updating.
plt.ylabel('norm. MSE [%]')
plt.box(False)
plt.legend(('ground-truth','stitched'))
ax.get_xaxis().tick_bottom()    
ax.get_yaxis().tick_left()    
plt.show()
print('corrs_est', corrs_est)

# scatter of model (last parameter set, last lag) against empirical entries;
# black dots mark the identity line
plt.figure()
plt.plot(Q.reshape(-1), Qsc[m].reshape(-1), '.')
plt.plot(Qsc[m].reshape(-1), Qsc[m].reshape(-1), 'k.')
plt.show()

In [None]:
# Squared error of the fitted model, split by (idx_grp[i], co_obs[i]) pairs.
# m and m_ are the lag position / lag value left over from the previous cell.
pars = pars_est

# The model covariance does not depend on the group, so build it once
# instead of once per loop iteration.
Q = (pars['C'][idxc_a,:].dot(pars['X'][m*n:(m+1)*n,:]).dot(pars['C'][idxc_b,:].T))
if m_ == 0:
    Q += np.diag(pars['R'])[np.ix_(idxc_a, idxc_b)]    

# NOTE(review): Q and Qsc[m] are indexed here with idx_grp / co_obs directly;
# presumably this is only meaningful when idxc_a and idxc_b cover all
# variables in order - confirm.
out = 0
for i, grp in enumerate(idx_grp):
    tmp = np.sum( (Qsc[m][np.ix_(grp,co_obs[i])] - Q[np.ix_(grp,co_obs[i])])**2 )
    print(tmp)
    print(grp,co_obs[i])
    out += tmp
print(out)

In [None]:
os.chdir('../core')
import ssm_scripts, ssm_fit
from SSID_Hankel_loss import f_l2_Hankel_nl, f_l2_inst, f_l2_block
os.chdir('../dev')


# recompute the observation statistics for one subpopulation holding all p variables
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = get_subpop_stats(
    sub_pops=(np.arange(p),), p=p, verbose=False)

# Hankel loss for the fitted parameters first, then for the ground truth
for pars in (pars_est, pars_true):
    print(f_l2_Hankel_nl(pars['C'],pars['X'],None,pars['R'],lag_range,Qs,idx_grp,co_obs,
                         idx_a=idx_a,idx_b=idx_b,W=None))

In [None]:
# Hand-written instantaneous (lag-0) squared error for the fitted parameters,
# accumulated over all (idx_grp[g], co_obs[g]) pairs.
pars = pars_est
C = pars['C']
Pi = pars['X'][:n,:]
R = pars['R']
Q = Qsc[0]
W = None

err = 0.
for g, grp in enumerate(idx_grp):
    # indices into C (rows / cols) and boolean selectors into Q
    rows = np.intersect1d(grp, idx_a)
    cols = np.intersect1d(co_obs[g], idx_b)
    row_sel = np.in1d(idx_a, grp)
    col_sel = np.in1d(idx_b, co_obs[g])

    resid = C[rows,:].dot(Pi).dot(C[cols,:].T) - Q[np.ix_(row_sel,col_sel)]
    # positions in cols that also occur in rows receive the R contribution
    shared = np.where(np.in1d(cols, rows))[0]
    resid[np.arange(len(shared)), shared] += R[rows]

    flat = resid.reshape(-1,)
    if W is not None:
        flat = W.reshape(-1,)*flat

    err += flat.dot(flat)
err

In [None]:
# Hand-written instantaneous (lag-0) squared error for the ground-truth
# parameters, accumulated over all (idx_grp[g], co_obs[g]) pairs.
pars = pars_true
C = pars['C']
Pi = pars['X'][:n,:]
R = pars['R']
Q = Qsc[0]
W = None

err = 0.
for g, grp in enumerate(idx_grp):
    # indices into C (rows / cols) and boolean selectors into Q
    rows = np.intersect1d(grp, idx_a)
    cols = np.intersect1d(co_obs[g], idx_b)
    row_sel = np.in1d(idx_a, grp)
    col_sel = np.in1d(idx_b, co_obs[g])

    resid = C[rows,:].dot(Pi).dot(C[cols,:].T) - Q[np.ix_(row_sel,col_sel)]
    # positions in cols that also occur in rows receive the R contribution
    shared = np.where(np.in1d(cols, rows))[0]
    resid[np.arange(len(shared)), shared] += R[rows]

    flat = resid.reshape(-1,)
    if W is not None:
        flat = W.reshape(-1,)*flat

    err += flat.dot(flat)
err

In [None]:
# Squared error of the lag-0 model covariance, C X_0 C^T + diag(R), summed
# over every entry that is co-observed within at least one subpopulation.
#
# The previous version hard-coded two subpopulations (sub_pops[0], sub_pops[1]
# and their intersection, via inclusion-exclusion) and raised an IndexError
# for the single-subpopulation setup.  Marking the co-observed entries in a
# mask counts each entry exactly once for any number of subpopulations and
# yields the same mask and error in the two-subpopulation case.
pars=pars_est
mask = np.zeros((p,p))
for idx in sub_pops:
    mask[np.ix_(idx,idx)] = 1

plt.imshow(mask, interpolation='None')
plt.show()

# assumes Qs[0] covers all p variables in order, as the original indexing did
Qest = pars['C'].dot(pars['X'][:n,:]).dot(pars['C'].T)+np.diag(pars['R'])
tmp = np.sum(((Qest - Qs[0])**2)[mask > 0])

tmp

In [None]:
# instantaneous (lag-0) loss of the fitted parameters as computed by the library routine
pars = pars_est
print(f_l2_inst(pars['C'],pars['X'][:n,:],pars['R'],Qs[0],idx_grp,co_obs,
                   idx_a=idx_a,idx_b=idx_b,W=None))