In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'

#np.random.seed(0)

# define problem size
p, n, T = 10, 2, 200
lag_range_full = np.arange(0,10)
lag_range = lag_range_full.copy()
kl_ = np.max(lag_range)+1

nr = 0 # number of real eigenvalues
snr = (0.0, 0.0)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.95, 0.90, 0.95

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations
sub_pops = (np.arange(0,p), np.arange(0,p))
obs_pops = np.array([0,1])
obs_time = np.array([T//2,T])

obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)

# obs_idx must be the index structure returned by get_subpop_stats
# (previously obs_time was passed here by mistake)
obs_scheme = ObservationScheme(p=p, T=T, 
                               sub_pops=sub_pops, obs_pops=obs_pops, 
                               obs_time=obs_time, obs_idx=obs_idx, 
                               idx_grp=idx_grp)

    
# settings for quick initial SGD fitting phase for our model
batch_size, max_zip_size, max_iter = 1, np.inf, 100
a, b1, b2, e = 0.01, 0.99, 0.99, 1e-8
a_R = 0 * a

# settings for the subsequent (longer, larger-batch) SGD refinement phase for our model
batch_size_late, max_zip_size_late, max_iter_late = 100, np.inf, 900
a_late, b1_late, b2_late, e_late = 0.005, 0.9, 0.99, 1e-8
a_R_late = 0 * a_late
    
# settings for GROUSE
a_grouse = 0.0005
max_iter_grouse = 200

# int so it can serve as the sample size of the (commented-out) random mask below
n_obs = int(np.ceil(p * 0.5))

# observation mask (T x p): row t marks the units observed at time t. Time t belongs
# to the first observation period i with t < obs_time[i], during which
# sub-population obs_pops[i] is observed.
mask = np.zeros((T,p))
for t in range(T):
    for i in range(len(obs_time)):
        if t < obs_time[i]:
            #mask[t, np.random.choice(p, n_obs, replace=False)] = 1
            mask[t,sub_pops[obs_pops[i]]] = 1
            break            
            
obs_scheme.mask = mask
plt.figure(figsize=(20,10))
plt.imshow(mask.T)
plt.show()

num_runs = 1
res = np.zeros((num_runs, 3))  # final subspace proj. errors: multi-lag, single-lag, GROUSE
rgt = np.zeros((num_runs, 2))  # loss of ground-truth parameters: multi-lag, single-lag
for run in range(num_runs):
    
    print('run ' + str(run+1) + '/' + str(num_runs))
    
    # draw system matrices 
    lag_range = lag_range_full.copy()
    pars_true, x, y, Qs, idx_a, idx_b = gen_data(p,n,lag_range,T, nr,
                                                 eig_m_r, eig_M_r, 
                                                 eig_m_c, eig_M_c,
                                                 mmap, chunksize,
                                                 data_path,
                                                 snr=snr, whiten=whiten)    
    y -= y.mean(axis=0)
    
    #y[mask==0] = np.nan
    plt.imshow(mask.T)
    plt.show()
    
    # true X: stacked blocks A^k Pi for every lag k in lag_range_full
    pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],k).dot(pars_true['Pi']) for k in lag_range_full])

    # ---- our model with multiple time-lags: quick initial phase, then refinement
    print('\n - multiple lags')
    pars_est_m = 'default'
    
    t = time.time()
    
    # subspace projection error of C per iteration;
    # column 0: multi-lag fit, column 1: single-lag fit
    proj_errors = np.zeros((max_iter, 2))
    def pars_track(C,X,R,t): 
        proj_errors[t,0] = calc_subspace_proj_error(pars_true['C'], C)
    _, pars_est_m, traces_m = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_m,
                                          alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,
                                          batch_size=batch_size,verbose=verbose, max_zip_size=max_zip_size,
                                          pars_track=pars_track)
    pars_est_m['R'] *= 0

    proj_errors_late = np.zeros((max_iter_late, 2))    
    def pars_track(C,X,R,t): 
        proj_errors_late[t,0] = calc_subspace_proj_error(pars_true['C'], C)
    _, pars_est_m, traces_m2 = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_m,
                                          alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                          max_iter=max_iter_late,batch_size=batch_size_late,
                                          verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
    # concatenate the traces of both phases
    traces_m = (np.hstack((traces_m[0], traces_m2[0])), np.hstack((traces_m[1], traces_m2[1])))
    t_m = time.time() - t
    print_slim(Qs,lag_range,pars_est_m,idx_a,idx_b,traces_m,mmap,data_path)
    print('fitting time was ', t_m, 's')
    print('rank of final C_est: ', sp.linalg.orth(pars_est_m['C']).shape[1])
    print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))

    # reference: loss of the true C with empirical latent cross-covariances
    rgt[run, 0] = f_l2_Hankel_nl(C=pars_true['C'],
                                 X=np.vstack([np.cov(x[k_:-(kl_)+k_, :].T, x[:-(kl_), :].T)[:n,n:] for k_ in lag_range]),
                                 Pi=np.cov(x.T),
                                 R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                                 idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)
    print('final error: ' + str(traces_m[0][-1]))    
    print('ground-truth reference error: ' + str(rgt[run,0]))


    # ---- our model with a single time-lag: quick initial phase, then refinement
    print('\n - single lag')
    lag_range, pars_est_s = np.array([0]), 'default'
    t = time.time()
    def pars_track(C,X,R,t): 
        proj_errors[t,1] = calc_subspace_proj_error(pars_true['C'], C)
    
    _, pars_est_s, traces_s = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_s,
                                          alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,batch_size=batch_size,
                                          verbose=verbose, max_zip_size=max_zip_size,pars_track=pars_track)
    
    pars_est_s['R'] *= 0
    
    def pars_track(C,X,R,t): 
        proj_errors_late[t,1] = calc_subspace_proj_error(pars_true['C'], C)
    _, pars_est_s, traces_s2 = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                          obs_scheme=obs_scheme,init=pars_est_s,
                                          alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                          max_iter=max_iter_late,batch_size=batch_size_late,
                                          verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
    traces_s = (np.hstack((traces_s[0], traces_s2[0])), np.hstack((traces_s[1], traces_s2[1])))
    # join projection-error traces of both phases (rows: iterations)
    proj_errors = np.vstack((proj_errors, proj_errors_late))
    t_s = time.time() - t
    print_slim(Qs,lag_range,pars_est_s,idx_a,idx_b,traces_s,mmap,data_path)
    print('fitting time was ', t_s, 's')
    print('rank of final C_est: ', sp.linalg.orth(pars_est_s['C']).shape[1])
    print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))

    rgt[run, 1] = f_l2_Hankel_nl(C=pars_true['C'],
                                 X=np.vstack([np.cov(x[k_:-(kl_)+k_, :].T, x[:-(kl_), :].T)[:n,n:] for k_ in lag_range]),
                                 Pi=np.cov(x.T),
                                 R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                                 idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)    
    
    print('final error: ' + str(traces_s[0][-1]))        
    print('ground-truth reference error: ' + str(rgt[run,1]))

    # ---- GROUSE baseline: max_iter_grouse passes over the data, each in random time order
    t = time.time()
    print('\n - GROUSE')
    tracker = Grouse(p, n, a_grouse )
    error = np.zeros(max_iter_grouse)
    report_every = max(1, max_iter_grouse//10)  # avoid modulo-by-zero when max_iter_grouse < 10
    for i in range(max_iter_grouse):
        if verbose and i % report_every == 0:
            print('finished % ' + str((100*i)//max_iter_grouse))
        idx = np.random.permutation(T)
        for j in range(T):
            tracker.consume(y[idx[j],:].reshape(p,1), mask[idx[j],:].reshape(p,1))

        error[i] = calc_subspace_proj_error(pars_true['C'], tracker.U)
    t_g = time.time() - t
    pars_est_g = {'C' : tracker.U}
    
    
    res[run,:] = np.array([ calc_subspace_proj_error(pars_true['C'], pars_est_m['C']),
                            calc_subspace_proj_error(pars_true['C'], pars_est_s['C']),
                            error[-1]])
    plt.figure(figsize=(20,10))
    plt.subplot(1,3,1)
    plt.loglog(traces_m[0])
    plt.xlabel('iteration')
    plt.ylabel('norm. SE')
    plt.title('final error multiple time-lags: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))    
    plt.subplot(1,3,2)
    plt.loglog(traces_s[0])
    plt.xlabel('iteration')
    plt.ylabel('norm. SE')
    plt.title('final error single time-lag: ' + str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))
    plt.subplot(1,3,3)
    plt.loglog(range(1,max_iter_grouse+1), error)
    plt.xlabel('pass over data')
    plt.ylabel('subspace proj. error')
    plt.title('final error GROUSE: ' + str(error[-1]))
    plt.show()
    
    # projection-error traces of both fits; pyplot.hold was removed in matplotlib 3.0
    # and successive plot() calls overlay on the same axes by default
    plt.figure(figsize=(20,10))
    plt.plot(proj_errors[:,0], label='multiple lags')
    plt.plot(proj_errors[:,1], label='single lag')
    plt.legend()
    plt.show()

    save_dict = {'p' : p,
                 'n' : n,
                 'T' : T,
                 'snr' : snr,
                 'obs_scheme' : obs_scheme,
                 'lag_range' : lag_range_full,
                 'x' : x,
                 'y' : y,
                 'pars_true' : pars_true,
                 'pars_est_s' : pars_est_s,
                 'pars_est_m' : pars_est_m,
                 'pars_est_g' : pars_est_g,
                 'res' : res,
                 'rgt' : rgt,
                 't_s' : t_s,
                 't_m' : t_m,
                 't_g' : t_g,
                 'traces_m' : traces_m,
                 'traces_s' : traces_s,
                 'traces_g' : error
                }
    # builtin int: the np.int alias was removed in NumPy 1.24
    file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run) + '_partial_dat'
    #np.savez(data_path + file_name, save_dict)

    

In [None]:
lag_range = lag_range_full.copy()
# fit our model with multiple time-lags (continues from the current pars_est_m)
print('\n - multiple lags')

# settings for a further SGD refinement phase for our model
# (batch_size None -- presumably full batch; confirm in run_bad)
batch_size_late, max_zip_size_late, max_iter_late = None, np.inf, 1000
a_late, b1_late, b2_late, e_late = 0.01, 0.9, 0.99, 1e-8
a_R_late = 1. * a_late  # non-zero learning rate for R here (the first cell used 0)


#pars_est_m = {'C' : pars_true['C'].copy(), 
#              'X' : np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]), 
#              'R' : pars_true['R'].copy()
#             }
t = time.time()

# subspace projection error of C, one row per call of pars_track (column 0: multi-lag fit)
proj_errors_late = np.zeros((max_iter_late, 2))    
def pars_track(C,X,R,t): 
    # callback handed to run_bad; `t` indexes rows of proj_errors_late (presumably the iteration)
    proj_errors_late[t,0] = calc_subspace_proj_error(pars_true['C'], C)
_, pars_est_m, traces_m = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est_m,
                                      alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                      max_iter=max_iter_late,batch_size=batch_size_late,
                                      verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
t_m = time.time() - t
print_slim(Qs,lag_range,pars_est_m,idx_a,idx_b,traces_m,mmap,data_path)
print('fitting time was ', t_m, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est_m['C']).shape[1])
print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_m['C'])))
# reference loss of the true C with empirical latent cross-covariances;
# `run` is the last index of the run loop in the first cell
rgt[run, 0] = f_l2_Hankel_nl(C=pars_true['C'],
                             X=np.vstack([np.cov(x[k_:(-kl_+k_), :].T, x[:-kl_, :].T)[:n,n:] for k_ in lag_range]),
                             Pi=np.cov(x.T),
                             R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                             idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)
print('final error: ' + str(traces_m[0][-1]))    
print('ground-truth reference error: ' + str(rgt[run,0]))

In [None]:
lag_range = np.array([0])
# fit our model with a single time-lag (lag 0 only), continuing from pars_est_s
print('\n - single lag')

# settings for a further SGD refinement phase for our model
batch_size_late, max_zip_size_late, max_iter_late = None, np.inf, 1000
a_late, b1_late, b2_late, e_late = 0.002, 0.9, 0.99, 1e-8
a_R_late = 1 * a_late

t = time.time()

# subspace projection error of C per pars_track call; column 1 holds the
# single-lag fit, matching the column convention of the first cell
proj_errors_late = np.zeros((max_iter_late, 2))    
def pars_track(C,X,R,t): 
    proj_errors_late[t,1] = calc_subspace_proj_error(pars_true['C'], C)
_, pars_est_s, traces_s = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est_s,
                                      alpha_C=a_late,alpha_R=a_R_late,b1_C=b1_late,b2_C=b2_late,e_C=e_late,
                                      max_iter=max_iter_late,batch_size=batch_size_late,
                                      verbose=verbose, max_zip_size=max_zip_size_late,pars_track=pars_track)
t_s = time.time() - t
print_slim(Qs,lag_range,pars_est_s,idx_a,idx_b,traces_s,mmap,data_path)
print('fitting time was ', t_s, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est_s['C']).shape[1])
print('final proj. error: ', str(calc_subspace_proj_error(pars_true['C'], pars_est_s['C'])))
rgt[run, 1] = f_l2_Hankel_nl(C=pars_true['C'],
                             X=np.vstack([np.cov(x[k_:-(kl_)+k_, :].T, x[:-(kl_), :].T)[:n,n:] for k_ in lag_range]),
                             Pi=np.cov(x.T),
                             R=pars_true['R'],lag_range=lag_range,Qs=Qs,
                             idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b)
print('final error: ' + str(traces_s[0][-1]))    
print('ground-truth reference error: ' + str(rgt[run,1]))

In [None]:
# fit GROUSE: continue with the tracker created in the first cell, using a smaller step
t = time.time()
print('\n - GROUSE')
tracker.step = 0.0001  # NOTE(review): assumes Grouse reads its step size from `.step` -- confirm
error = np.zeros(max_iter_grouse)
report_every = max(1, max_iter_grouse//10)  # avoid modulo-by-zero when max_iter_grouse < 10
for i in range(max_iter_grouse):
    if verbose and i % report_every == 0:
        print('finished % ' + str((100*i)//max_iter_grouse))
    idx = np.random.permutation(T)  # fresh random time order for every pass
    for j in range(T):
        tracker.consume(y[idx[j],:].reshape(p,1), mask[idx[j],:].reshape(p,1))

    error[i] = calc_subspace_proj_error(pars_true['C'], tracker.U)
t_g = time.time() - t
pars_est_g = {'C' : tracker.U}
print(error[-1])
plt.plot(error)
plt.xlabel('pass over data')
plt.ylabel('subspace proj. error')
plt.show()

In [None]:
# Gradient check for a single (t+m, t) sample pair: compare a brute-force,
# entry-by-entry gradient w.r.t. C with the vectorised formula.
m = 0                        # time lag of the checked pair
t = np.random.choice(T - m)  # scalar time index, drawn from the time axis so that t+m < T
m_ = m

C, Xm, R = pars_true['C'], pars_true['X'][m*n:(m+1)*n, :], pars_true['R']
p,n = C.shape

# units observed at time t+m (a) and at time t (b)
a = np.where(obs_scheme.mask[t+m,:])[0]
b = np.where(obs_scheme.mask[t,:])[0]

anb = np.intersect1d(a,b)  # observed at both times
a_ = np.setdiff1d(a,b)     # only at t+m (used by the commented-out correction below)
b_ = np.setdiff1d(b,a)     # only at t

yf = y[t+m_,a]
yp = y[t,b]

# --- brute force: for every output row k, accumulate the gradient of
#     0.5*(C_i Xm C_j^T [+ R_i if i==j and m==0] - y_i y_j)^2 over co-observed (i, j)
Om = np.outer(obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]).astype(bool)
L = np.outer(y[t+m,:], y[t,:])
grad_C = np.zeros_like(C)
for k in range(p):
    for i in range(p):
        for j in range(p):
            if Om[i,j]:
                Ci, Cj = C[i,:], C[j,:]
                if k==i and k!=j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                if k==j and k!=i:
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                if k==i and k==j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                    if m ==0:
                        grad_C[k,:] += R[i] * (Cj.dot(Xm.T)+Ci.dot(Xm))
                    
print(grad_C)
grad_C_blunt = grad_C.copy()

# --- vectorised formula
grad_C = np.zeros((p,n))

C___ = C.dot(Xm)   # rows C_i Xm
C_tr = C.dot(Xm.T) # rows C_i Xm^T

grad_C[a,:] += C[a,:].dot( C_tr[b,:].T.dot(C_tr[b,:]) ) - np.outer(yf,yp.dot(C_tr[b,:]))
grad_C[b,:] += C[b,:].dot( C___[a,:].T.dot(C___[a,:]) ) - np.outer(yp,yf.dot(C___[a,:]))

# correction for variables not observed both at t+m_ and t  
#if a_.size > 0:
#    grad_C[a_,:] -= (np.sum(C[a_,:]*C_tr[a_,:],axis=1) - y[t+m,a_]*y[t,a_]).reshape(-1,1) * C_tr[a_,:]
#if b_.size > 0:
#    grad_C[b_,:] -= (np.sum(C[b_,:]*C___[b_,:],axis=1) - y[t+m,b_]*y[t,b_]).reshape(-1,1) * C___[b_,:]

if m_==0: 
    grad_C[anb,:] += R[anb].reshape(-1,1)*(C___[anb,:] + C_tr[anb,:])
print(grad_C)

print(anb)

assert np.allclose(grad_C_blunt, grad_C)

plt.imshow(Om, interpolation='None')
plt.show()

In [None]:
def f_l2_block(C,AmPi,Q,idx_grp,co_obs,idx_a,idx_b,W=None):
    """Hankel reconstruction error on an individual (time-lagged) Hankel block.

    For every group i, the prediction C[a] @ AmPi @ C[b].T is compared with the
    matching entries of Q. Here a are the units of idx_grp[i] that lie in idx_a,
    and b are the units of co_obs[i] that lie in idx_b.

    Parameters
    ----------
    C : (p, n) array of loadings.
    AmPi : (n, n) array, the latent block for this lag (A^m Pi in the caller).
    Q : array of shape (len(idx_a), len(idx_b)); rows/columns are positions in
        idx_a / idx_b (assumed sorted so they line up with a / b).
    idx_grp, co_obs : sequences of index arrays, one pair per group.
    idx_a, idx_b : index arrays that select the rows / columns covered by Q.
    W : optional weights, multiplied elementwise into each group's flattened residual.

    Returns
    -------
    float
        Sum of squared (optionally weighted) residuals over all groups.
    """
    err = 0.
    for i, grp in enumerate(idx_grp):
        co = co_obs[i]
        a = np.intersect1d(grp, idx_a)
        b = np.intersect1d(co, idx_b)
        a_Q = np.isin(idx_a, grp)
        b_Q = np.isin(idx_b, co)

        v = (C[a,:].dot(AmPi).dot(C[b,:].T) - Q[np.ix_(a_Q,b_Q)]).reshape(-1)
        if W is not None:
            v = W.reshape(-1) * v

        err += v.dot(v)

    return err

def f_l2_inst(C,Pi,R,Q,idx_grp,co_obs,idx_a,idx_b,W=None):
    """Reconstruction error on the instantaneous covariance.

    Same as the time-lagged block error, but the prediction
    C[a] @ Pi @ C[b].T additionally gets R[u] on every entry (u, u), i.e. for
    units u that belong to both a and b.

    Parameters
    ----------
    C : (p, n) array of loadings.
    Pi : (n, n) array, the instantaneous latent block.
    R : (p,) array added on the diagonal entries.
    Q : array of shape (len(idx_a), len(idx_b)), or None (then the error is 0).
    idx_grp, co_obs, idx_a, idx_b, W : as in f_l2_block.

    Returns
    -------
    float
        Sum of squared (optionally weighted) residuals over all groups.
    """
    err = 0.
    if Q is None:
        return err

    for i, grp in enumerate(idx_grp):
        co = co_obs[i]
        a = np.intersect1d(grp, idx_a)
        b = np.intersect1d(co, idx_b)
        a_Q = np.isin(idx_a, grp)
        b_Q = np.isin(idx_b, co)

        v = C[a,:].dot(Pi).dot(C[b,:].T) - Q[np.ix_(a_Q,b_Q)]
        # Locate the diagonal entries by each shared unit's position in a (row) and
        # in b (column). The old code assumed a was a subset of b and used
        # rows 0..k-1, which mis-placed R or crashed otherwise.
        _, rows, cols = np.intersect1d(a, b, return_indices=True)
        v[rows, cols] += R[a[rows]]
        v = v.reshape(-1) if W is None else W.reshape(-1) * v.reshape(-1)

        err += v.dot(v)

    return err

def f_l2_Hankel_nl(C,X,Pi,R,lag_range,Qs,idx_grp,co_obs,
                   idx_a=None,idx_b=None,W=None):
    """Per-lag l2 Hankel reconstruction errors.

    Returns an array of length len(lag_range), not a single number.
    Entry 0 is the error on the instantaneous block (with the diagonal R term,
    via f_l2_inst). Entry m > 0 is the error on the m-th time-lagged block
    (via f_l2_block). Sum the array to get the overall error.

    Notes
    -----
    * Block m is read from X[m*n:(m+1)*n] and Qs[m]; the values stored in
      lag_range are not used, only its length. This presumes
      lag_range == arange(len(lag_range)).
    * `Pi` is accepted but unused: the instantaneous block uses X[:n].
      NOTE(review): confirm whether Pi was meant to be used there.
    * `W` is applied to the time-lagged blocks only, not to block 0.
    * idx_a defaults to all p units; idx_b defaults to idx_a.
    """
    kl = len(lag_range)
    p,n = C.shape
    idx_a = np.arange(p) if idx_a is None else idx_a
    idx_b = idx_a if idx_b is None else idx_b
    assert (len(idx_a), len(idx_b)) == Qs[0].shape

    err = np.zeros(kl)
    err[0] = f_l2_inst(C,X[:n, :],R,Qs[0],idx_grp,co_obs,idx_a,idx_b)
    for m in range(1,kl):
        err[m]= f_l2_block(C,X[m*n:(m+1)*n, :],Qs[m],idx_grp,co_obs,idx_a,idx_b,W)
            
    return err
# Per-lag loss of the multi-lag estimate vs. the true parameters.
# lag_range_full is used explicitly: the single-lag cell leaves the global
# lag_range set to [0], which would silently restrict the comparison to lag 0.
pars_est = pars_est_m.copy()
print(f_l2_Hankel_nl(pars_est['C'], pars_est['X'], pars_est['X'][:n,:], pars_est['R'], lag_range_full, Qs, idx_grp, co_obs))

pars_est = pars_true.copy()
# empirical latent cross-covariances cov(x_{t+k}, x_t) for every lag k
X_emp = np.vstack([np.cov(x[k_:-(kl_)+k_, :].T, x[:-(kl_), :].T)[:n,n:] for k_ in lag_range_full])
print(f_l2_Hankel_nl(pars_est['C'], X_emp, pars_est['X'][:n,:], pars_est['R'], lag_range_full, Qs, idx_grp, co_obs))


In [None]:
m_max = 10
t,m = np.random.choice(T-m_max, 1)[0], np.random.choice(m_max, 1)[0]

# Index sets a (observed at t+m) and b (observed at t). By default they are random
# halves of the units; set use_mask = True to take them from the observation mask.
# (Previously the mask-based sets were computed and then immediately overwritten.)
use_mask = False
if use_mask:
    a = np.where(obs_scheme.mask[t+m,:])[0]
    b = np.where(obs_scheme.mask[t,:])[0]
else:
    a = np.random.choice(p,p//2,replace=False)
    b = a.copy() if m == 0 else np.random.choice(p,p//2,replace=False)


anb = np.intersect1d(a,b)
pars_est = pars_true.copy()

# random test point (rather than pars_est['X'][m*n:(m+1)*n,:], pars_est['R'])
Xm, R = np.random.normal(size=(n,n)), np.random.normal(size=(p))**2
C = np.random.normal(size=(p,n))


def fC(C):
    """Single-pair loss as a function of the flattened loadings C (Xm, R fixed).

    Reads a, b, m, t, y, Xm, R, p, n from the notebook scope.
    """
    C_mat = C.reshape(p, n)
    fit = C_mat[a, :].dot(Xm).dot(C_mat[b, :].T)
    diag_term = (m == 0) * np.diag(R)[np.ix_(a, b)]
    target = np.outer(y[t + m, a], y[t, b])
    return 0.5 * np.sum((fit + diag_term - target) ** 2)
def fX(Xm):
    """Single-pair loss as a function of the flattened lag matrix Xm (C, R fixed)."""
    X_mat = Xm.reshape(n, n)
    fit = C[a, :].dot(X_mat).dot(C[b, :].T)
    diag_term = (m == 0) * np.diag(R)[np.ix_(a, b)]
    target = np.outer(y[t + m, a], y[t, b])
    return 0.5 * np.sum((fit + diag_term - target) ** 2)
def fR(R):
    """Single-pair loss as a function of the diagonal term R (C, Xm fixed)."""
    residual = (C[a, :].dot(Xm).dot(C[b, :].T)
                + (m == 0) * np.diag(R)[np.ix_(a, b)]
                - np.outer(y[t + m, a], y[t, b]))
    return 0.5 * np.sum(residual ** 2)



def gC(C): 
    """Analytic gradient of fC w.r.t. the flattened C."""
    C_mat = C.reshape(p, n)
    CX = C_mat.dot(Xm)     # rows c_i Xm
    CXt = C_mat.dot(Xm.T)  # rows c_i Xm^T
    g = np.zeros((p, n))
    y_a, y_b = y[t+m, a], y[t, b]
    g[a, :] += C_mat[a, :].dot(CXt[b, :].T.dot(CXt[b, :])) - np.outer(y_a, y_b.dot(CXt[b, :]))
    g[b, :] += C_mat[b, :].dot(CX[a, :].T.dot(CX[a, :])) - np.outer(y_b, y_a.dot(CX[a, :]))
    if m == 0:
        # diagonal R term only contributes at lag zero
        g[anb, :] += R[anb].reshape(-1, 1) * (CX[anb, :] + CXt[anb, :])
    return g.reshape(-1)

def gX(Xm): 
    """Analytic gradient of fX w.r.t. the flattened Xm."""
    X_mat = Xm.reshape(n, n)
    gram_a = C[a, :].T.dot(C[a, :])
    if a is b:
        gram_b = gram_a
    else:
        gram_b = C[b, :].T.dot(C[b, :])
    proj_a = y[t+m, a].dot(C[a, :])
    proj_b = y[t, b].dot(C[b, :])
    g = gram_a.dot(X_mat).dot(gram_b) - np.outer(proj_a, proj_b)
    if m == 0:
        g += C[anb, :].T.dot(R[anb].reshape(-1, 1) * C[anb, :])
    return g.reshape(-1)

def gR(R): 
    """Analytic gradient of fR w.r.t. R (zero unless m == 0)."""
    g = np.zeros(p)
    if m != 0:
        return g
    g[b] = R[b] + np.sum(C[b, :] * C[b, :].dot(Xm.T), axis=1) - y[t, b]**2
    return g


# evaluate each loss once, then compare analytic vs. finite-difference gradients
print(fC(C), fX(Xm), fR(R))

print(a,b,anb)
print('m', m)

for _label, _fun, _grad, _x0 in (('grad C', fC, gC, C.reshape(-1)),
                                 ('grad X', fX, gX, Xm.reshape(-1)),
                                 ('grad R', fR, gR, R.reshape(-1))):
    print(_label, sp.optimize.check_grad(_fun, _grad, _x0))


In [None]:
m_max = 4
n_t = 10
ts = np.random.choice(T-m_max, n_t)
ms = np.random.choice(m_max, 1)[0] * np.ones(n_t,dtype=int)  # one shared lag for all samples

# Index sets per sample: a[i] observed at ts[i]+ms[i], b[i] at ts[i]. By default
# they are random halves of the units; set use_mask = True to take them from the
# observation mask. (Previously the mask-based sets were computed and then
# immediately overwritten.)
use_mask = False
if use_mask:
    a = [np.where(obs_scheme.mask[t+m,:])[0] for (t,m) in zip(ts, ms)]
    b = [np.where(obs_scheme.mask[t,:])[0] for (t,m) in zip(ts, ms)]
else:
    a = [np.random.choice(p,p//2,replace=False) for _ in range(n_t)]
    b = [a[i].copy() if ms[i] == 0 else np.random.choice(p,p//2,replace=False) for i in range(len(ts))]


anb = [np.intersect1d(a[i],b[i]) for i in range(len(ts)) ]
pars_est = pars_true.copy()

# random test point (rather than pars_est['X'], pars_est['R'])
X, R = np.random.normal(size=(m_max*n,n)), np.random.normal(size=(p))**2
C = np.random.normal(size=(p,n))

def fC(C):
    """Mean loss over the sampled (ts[i]+ms[i], ts[i]) pairs, as a function of C."""
    C_mat = C.reshape(p, n)
    total = 0
    for i, (t_i, m_i) in enumerate(zip(ts, ms)):
        rows, cols = a[i], b[i]
        Xm = X[m_i*n:(m_i+1)*n, :].copy()
        residual = (C_mat[rows, :].dot(Xm).dot(C_mat[cols, :].T)
                    + (m_i == 0) * np.diag(R)[np.ix_(rows, cols)]
                    - np.outer(y[t_i+m_i, rows], y[t_i, cols]))
        total += (residual**2).sum()
    return 0.5*total / len(ts)

def fC_rw(C):
    """Summed loss over the sampled pairs, where each squared residual entry is divided
    by how many sampled pairs of the same lag cover that (row, col) entry."""
    C_mat = C.reshape(p, n)

    coverage = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for i in range(len(ts)):
        coverage[ms[i]][np.ix_(a[i], b[i])] += 1

    total = 0
    for i in range(len(ts)):
        lag, rows, cols = ms[i], a[i], b[i]
        Xm = X[lag*n:(lag+1)*n, :].copy()
        residual = (C_mat[rows, :].dot(Xm).dot(C_mat[cols, :].T)
                    + (lag == 0) * np.diag(R)[np.ix_(rows, cols)]
                    - np.outer(y[ts[i]+lag, rows], y[ts[i], cols]))
        total += ((residual**2) / coverage[lag][np.ix_(rows, cols)]).sum()

    return 0.5*total

def fC_(C):
    """Loss against per-lag averaged outer products (a Hankel-style objective):
    for every lag, the model covariance is compared with the mean of the sampled
    outer products on the entries that were sampled at least once."""
    C_mat = C.reshape(p, n)
    sums = [np.zeros((p, p)) for _ in range(m_max)]
    hits = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for i in range(len(ts)):
        lag, rows, cols = ms[i], a[i], b[i]
        sums[lag][np.ix_(rows, cols)] += np.outer(y[ts[i]+lag, rows], y[ts[i], cols])
        hits[lag][np.ix_(rows, cols)] += 1
    seen = [h > 0 for h in hits]
    denom = [np.maximum(h, 1) for h in hits]
    per_lag = [np.sum((C_mat.dot(X[k*n:(k+1)*n, :]).dot(C_mat.T) + (k == 0)*np.diag(R) - sums[k]/denom[k])[seen[k]]**2)
               for k in range(m_max)]
    return 0.5*np.sum(per_lag)


def fX(X):
    """Mean loss over the sampled pairs as a function of the flattened stacked X."""
    X_mat = X.reshape(-1, n)
    total = 0
    for i in range(len(ts)):
        lag, rows, cols = ms[i], a[i], b[i]
        Xm = X_mat[lag*n:(lag+1)*n, :].copy()
        residual = (C[rows, :].dot(Xm).dot(C[cols, :].T)
                    - np.outer(y[ts[i]+lag, rows], y[ts[i], cols])
                    + (lag == 0) * np.diag(R)[np.ix_(rows, cols)])
        total += (residual**2).sum()
    return 0.5*total / len(ts)

def fR(R):
    """Mean loss over the sampled pairs as a function of the diagonal term R (C, X fixed)."""
    f = 0
    for i in range(len(ts)):
        Xm = X[ms[i]*n:(ms[i]+1)*n,:]  # read-only view, no copy needed
        
        f += ((C[a[i],:].dot(Xm).dot(C[b[i],:].T) - np.outer(y[ts[i]+ms[i],a[i]], y[ts[i],b[i]]) + (ms[i]==0)* np.diag(R)[np.ix_(a[i],b[i])])**2).sum()
    
    return 0.5*f / len(ts)


def gC(C): 
    """Analytic gradient of fC (mean loss over the sampled pairs) w.r.t. the flattened C.

    No per-entry coverage weighting (as in fC_rw) is applied; the coverage counts
    that used to be computed here were never used and have been dropped.
    """
    C = C.reshape(p,n)
    grad_C = np.zeros((p,n))

    for i in range(len(ts)):
        Xm = X[ms[i]*n:(ms[i]+1)*n,:]
        C___ = C.dot(Xm)   # rows c_j Xm
        C_tr = C.dot(Xm.T) # rows c_j Xm^T
        grad_C[a[i],:] += C[a[i],:].dot( C_tr[b[i],:].T.dot(C_tr[b[i],:]) ) - np.outer(y[ts[i]+ms[i],a[i]],y[ts[i],b[i]].dot(C_tr[b[i],:]))
        grad_C[b[i],:] += C[b[i],:].dot( C___[a[i],:].T.dot(C___[a[i],:]) ) - np.outer(y[ts[i],b[i]],y[ts[i]+ms[i],a[i]].dot(C___[a[i],:]))
        if ms[i] ==0:
            # diagonal R term only contributes at lag zero
            grad_C[anb[i],:] += R[anb[i]].reshape(-1,1)*(C___[anb[i],:] + C_tr[anb[i],:])  
    return grad_C.reshape(-1) / len(ts)

def gX(X):
    """Analytic gradient of fX w.r.t. the flattened stacked X."""
    X_mat = X.reshape(-1, n)
    g = np.zeros_like(X_mat)
    for i in range(len(ts)):
        lag, rows, cols = ms[i], a[i], b[i]
        block = slice(lag*n, (lag+1)*n)
        Xm = X_mat[block, :].copy()
        gram_a = C[rows, :].T.dot(C[rows, :])
        gram_b = C[cols, :].T.dot(C[cols, :])
        g[block, :] += gram_a.dot(Xm).dot(gram_b) - np.outer(y[ts[i]+lag, rows].dot(C[rows, :]), y[ts[i], cols].dot(C[cols, :]))
        if lag == 0:
            g[:n, :] += C[anb[i], :].T.dot(R[anb[i]].reshape(-1, 1) * C[anb[i], :])
    return g.reshape(-1) / len(ts)

def gR(R): 
    """Analytic gradient of fR w.r.t. R; only lag-zero samples contribute."""
    g = np.zeros(p)
    for i in range(len(ts)):
        if ms[i] != 0:
            continue
        cols = b[i]
        g[cols] += R[cols] + np.sum(C[cols, :] * C[cols, :].dot(X[:n, :].T), axis=1) - y[ts[i], cols]**2
    return g / len(ts)

print(fC(C), fX(X), fR(R))

print(a,b,anb)
print('ms', ms)

# finite-difference check of gC against the three candidate objectives:
# per-sample mean (fC), averaged-Hankel form (fC_) and coverage-reweighted form (fC_rw)
print('grad C (actual)', sp.optimize.check_grad(fC, gC, C.reshape(-1)))
print('grad C (Hankel)', sp.optimize.check_grad(fC_, gC, C.reshape(-1)))
print('grad C (corr. )', sp.optimize.check_grad(fC_rw, gC, C.reshape(-1)))



print('grad X', sp.optimize.check_grad(fX, gX, X.reshape(-1)))
print('grad R', sp.optimize.check_grad(fR, gR, R.reshape(-1)))

# compare the three objective values (fC__ was never defined; fC_rw is the third variant)
fC(C), fC_(C), fC_rw(C)