# checking gradients

# generate toy data

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# define problem size
# (p sizes the observed dimension of C and the mask, n the latent dimension,
#  T the number of time steps; lag_range lists the time lags that are fitted)
lag_range = np.arange(0,4)
kl_ = np.max(lag_range)+1
p, n, T = 10, 2, 10000 + kl_

nr = 0 # number of real eigenvalues
snr = (10.0, 10.0)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.9, 0.95, 0.90, 0.95

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations
num_pops = 5
reps = 10
sub_pops = [np.arange(i*p//num_pops, (i+1)*p//num_pops) for i in range(num_pops)]
obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)

# NOTE: the multi-population schedule above is replaced right away by a single
# population that contains all p variables and lasts for the whole recording.
sub_pops = (np.arange(p), )
obs_pops = (0,)
obs_time = (T,)

obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)
# Pass the obs_idx returned by get_subpop_stats; the original handed obs_time
# to the obs_idx argument, leaving the computed obs_idx unused.
obs_scheme = ObservationScheme(p=p, T=T, 
                               sub_pops=sub_pops, obs_pops=obs_pops, 
                               obs_time=obs_time, obs_idx=obs_idx, 
                               idx_grp=idx_grp)

# Number of variables observed per time step: half of them, rounded up.
# np.ceil returns a float, but np.random.choice needs an integer sample size.
n_obs = int(np.ceil(p * 0.5))

# mask[t, j] == 1 marks variable j as observed at time t
mask = np.zeros((T,p))
for t in range(T):
    for i in range(len(obs_time)):
        if t < obs_time[i]:
            # observe a fresh random subset of n_obs variables at this time step
            mask[t, np.random.choice(p, n_obs, replace=False)] = 1
            #mask[t,sub_pops[obs_pops[i]]] = 1
            break

obs_scheme.mask = mask
plt.figure(figsize=(20,10))
plt.imshow(mask.T)
plt.show()

# draw system matrices / data
data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'
pars_true, x, y, Qs, idx_a, idx_b = gen_data(
    p, n, lag_range, T, nr,
    eig_m_r, eig_M_r,
    eig_m_c, eig_M_c,
    mmap, chunksize,
    data_path,
    snr=snr, whiten=whiten)

# empirical cross-covariance between x shifted by each lag and unshifted x,
# one n x n block per lag, stacked vertically
x_unshifted = x[:T-kl_]
lagged_cov_blocks = []
for lag in lag_range:
    joint_cov = np.cov(x[lag:T-kl_+lag].T, x_unshifted.T)
    lagged_cov_blocks.append(joint_cov[:n, n:])
pars_true['X'] = np.vstack(lagged_cov_blocks)

# blank out unobserved entries, then centre each column over its observed entries
unobserved = (mask == 0)
y[unobserved] = np.nan
y -= np.nanmean(y, axis=0)

# Pick the rule that maps a time index to the indices of the observed variables.
if obs_scheme.mask is not None:
    def get_observed(t):
        """Return the column indices where row ``t`` of ``obs_scheme.mask`` is non-zero."""
        return np.nonzero(obs_scheme.mask[t, :])[0]
elif any(item is None for item in (obs_time, obs_pops, sub_pops)):
    def get_observed(t):
        """Return all ``p`` variable indices (no observation schedule given)."""
        return range(p)
else:
    def get_observed(t):
        """Return the subpopulation scheduled at time ``t`` (looked up via ``obs_time``)."""
        schedule_pos = np.digitize(t, obs_time)
        return sub_pops[obs_pops[schedule_pos]]

# quick look at the first (at most 200) time steps of x
plt.figure(figsize=(20,6))
plt.plot(x[:min(T, 200), :])
plt.show()

pars_est = 'default'

# lagged cross-covariances of x, one n x n block per lag
x_cov_blocks = []
for lag in lag_range:
    x_cov_blocks.append(np.cov(x[lag:T-kl_+lag].T, x[:T-kl_].T)[:n, n:])
pars_true['X'] = np.vstack(x_cov_blocks)

# Per lag: W counts how often each pair (i at t+lag, j at t) was observed
# together, S accumulates the corresponding products of y.
W = [np.zeros((p, p), dtype=int) for _ in lag_range]
S = [np.zeros((p, p)) for _ in lag_range]
for m, m_ in enumerate(lag_range):
    for t in range(T-kl_):
        a, b = get_observed(t+m_), get_observed(t)
        pair_block = np.ix_(a, b)
        W[m][pair_block] += 1
        S[m][pair_block] += np.outer(y[t+m_,a], y[t,b])
    # turn the counts into weights 1/(count-1), with the divisor floored at 1
    W[m] = 1./np.maximum(W[m]-1, 1)
    


# For every lag: covariance predicted by the true parameters vs. the
# empirical pairwise estimate S*W, shown side by side, as ratio, and as scatter.
for m, lag in enumerate(lag_range):

    C_true = pars_true['C']
    X_lag = pars_true['X'][m*n:(m+1)*n, :]
    Q_pred = C_true.dot(X_lag).dot(C_true.T)
    if lag == 0:
        # the noise variances R only enter the lag-0 covariance
        Q_pred += np.diag( pars_true['R'] )
    Q_emp = S[m] * W[m]

    plt.figure(figsize=(20,5))

    plt.subplot(1,3,1)
    side_by_side = np.hstack((Q_pred, Q_emp))
    plt.imshow(side_by_side, interpolation='None')

    plt.subplot(1,3,2)
    plt.imshow(Q_pred / Q_emp, interpolation='None')
    plt.colorbar()

    plt.subplot(1,3,3)
    plt.plot(Q_pred.reshape(-1), Q_emp.reshape(-1), '.')

    plt.show()

    # element-wise residual; the value is computed but not stored or shown
    Q_pred.reshape(-1)-Q_emp.reshape(-1)

In [None]:
# Compare numpy's masked covariance with the pairwise estimate S[0] * W[0].
#
# A masked_array expects True where an entry is to be ignored. The observation
# mask is built with np.zeros and filled with ones (non-zero = observed), and
# np.invert (bitwise NOT) is not defined for floats, so mark the zero entries.
yma = np.ma.masked_array(y, mask=(obs_scheme.mask == 0))

# computed once and reused by all plots below
cov_masked = np.ma.cov(yma[:T-kl_].T)
cov_pairwise = S[0] * W[0]

plt.imshow(cov_masked, interpolation='None')
plt.show()

plt.imshow(cov_pairwise, interpolation='None')
plt.show()

plt.plot(cov_masked.reshape(-1), cov_pairwise.reshape(-1))
plt.show()

plt.imshow(cov_masked - cov_pairwise, interpolation='None')
plt.colorbar()
plt.show()

In [None]:
data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/'

# NOTE(review): `traces` is only assigned by the fitting cells further down,
# so this cell presumably has to be run after a fit -- confirm.
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'x' : x,
             'y' : y,
             'pars_true' : pars_true,
             'pars_est_m' : pars_est,
             'traces' : traces
            } if False else {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'x' : x,
             'y' : y,
             'pars_true' : pars_true,
             'pars_est_m' : pars_est,
             'traces_m' : traces
            }
# np.int was removed from NumPy; the builtin int performs the same conversion.
snr_tag = str(int(np.mean(snr)//1))
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + snr_tag + '_partial_dat'
# the whole dict is stored as one positional (object) array
np.savez(data_path + file_name, save_dict)


In [None]:
# one constant weight matrix per lag, every entry equal to 1/(T-kl_-1)
uniform_weight = 1.0 / (T - kl_ - 1)
W = [np.full((p, p), uniform_weight) for _ in lag_range]
pars_est = 'default'

In [None]:
# Per lag, count how often each pair (i at t+lag, j at t) is observed together
# and convert the counts into weights 1/(count-1), divisor floored at 1.
W = [np.zeros((p, p), dtype=int) for _ in lag_range]
for m, m_ in enumerate(lag_range):
    pair_counts = W[m]
    for t in range(T-kl_):
        a, b = get_observed(t+m_), get_observed(t)
        pair_counts[np.ix_(a, b)] += 1
    W[m] = 1./np.maximum(pair_counts-1, 1)

In [None]:
from scipy import linalg as la

# settings for quick initial SGD fitting phase for our model
batch_size = 1
max_zip_size = np.inf
max_iter = 10
a = 0.01              # handed to run_bad as alpha_C
b1, b2 = 0.9, 0.99    # handed to run_bad as b1_C, b2_C
e = 1e-8              # handed to run_bad as e_C
a_R = a * 1           # handed to run_bad as alpha_R

# one row per iteration, filled by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised with ``la.orth`` first; 1-D inputs are
    treated as single columns. The angles are divided by pi/2, so 0 means
    aligned directions and 1 means orthogonal directions.
    """
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    basis_a = la.orth(A)
    basis_b = la.orth(B)
    _, cosines, _ = la.svd(basis_a.T.dot(basis_b))
    # singular values can exceed 1 by rounding error; clip before arccos
    clipped = np.minimum(cosines, 1.0)
    return np.arccos(clipped) / (np.pi/2)
def pars_track(C, X, R, t):
    """Store diagnostics for iteration ``t`` in row ``t`` of ``proj_errors``.

    The row holds the value of ``calc_subspace_proj_error`` followed by the
    principal angles, both computed between the true C and the current C.
    X and R are accepted but not used.
    """
    C_true = pars_true['C']
    subspace_err = calc_subspace_proj_error(C_true, C)
    angles = principal_angle(C_true, C)
    proj_errors[t] = np.hstack((subspace_err, angles))
        
    
# run the fit and measure the wall-clock time it takes
t = time.time()
_, pars_est, traces = run_bad(
    lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_zip_size=max_zip_size,
    pars_track=pars_track, W=W)
t = time.time() - t

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')

C_est = pars_est['C']
print('rank of final C_est: ', sp.linalg.orth(C_est).shape[1])

# loss at the true parameters, with lagged covariances of x recomputed here
X_true = np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:]
                    for k_ in lag_range])
err_true = f_l2_Hankel_nl(C=pars_true['C'], X=X_true, R=pars_true['R'],
                          lag_range=lag_range, y=y, ts=np.arange(T-kl_),
                          ms=range(len(lag_range)),
                          idx_a=idx_a, idx_b=idx_b, W=W, get_observed=get_observed)
print('ground-truth error: ', err_true)

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], C_est)))


In [None]:
# all columns of proj_errors except the first, one curve per column
angle_traces = proj_errors[:, 1:]
plt.plot(angle_traces)
plt.show()

In [None]:
from scipy import linalg as la

# settings for quick initial SGD fitting phase for our model
batch_size = 500
max_zip_size = np.inf
max_iter = 10
a = 0.001             # handed to run_bad as alpha_C
b1, b2 = 0.95, 0.99   # handed to run_bad as b1_C, b2_C
e = 1e-8              # handed to run_bad as e_C
a_R = a * 1           # handed to run_bad as alpha_R

# one row per iteration, filled by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised with ``la.orth`` first; 1-D inputs are
    treated as single columns. The angles are divided by pi/2, so 0 means
    aligned directions and 1 means orthogonal directions.
    """
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    basis_a = la.orth(A)
    basis_b = la.orth(B)
    _, cosines, _ = la.svd(basis_a.T.dot(basis_b))
    # singular values can exceed 1 by rounding error; clip before arccos
    clipped = np.minimum(cosines, 1.0)
    return np.arccos(clipped) / (np.pi/2)
def pars_track(C, X, R, t):
    """Store diagnostics for iteration ``t`` in row ``t`` of ``proj_errors``.

    The row holds the value of ``calc_subspace_proj_error`` followed by the
    principal angles, both computed between the true C and the current C.
    X and R are accepted but not used.
    """
    C_true = pars_true['C']
    subspace_err = calc_subspace_proj_error(C_true, C)
    angles = principal_angle(C_true, C)
    proj_errors[t] = np.hstack((subspace_err, angles))
        
    
# run the fit and measure the wall-clock time it takes
t = time.time()
_, pars_est, traces = run_bad(
    lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_zip_size=max_zip_size,
    pars_track=pars_track, W=W)
t = time.time() - t

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')

C_est = pars_est['C']
print('rank of final C_est: ', sp.linalg.orth(C_est).shape[1])

# loss at the true parameters, with lagged covariances of x recomputed here
X_true = np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:]
                    for k_ in lag_range])
err_true = f_l2_Hankel_nl(C=pars_true['C'], X=X_true, R=pars_true['R'],
                          lag_range=lag_range, y=y, ts=np.arange(T-kl_),
                          ms=range(len(lag_range)),
                          idx_a=idx_a, idx_b=idx_b, W=W, get_observed=get_observed)
print('ground-truth error: ', err_true)

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], C_est)))


In [None]:
# all columns of proj_errors except the first, one curve per column
angle_traces = proj_errors[:, 1:]
plt.plot(angle_traces)
plt.show()

In [None]:
from scipy import linalg as la

# settings for quick initial SGD fitting phase for our model
batch_size = 200
max_zip_size = np.inf
max_iter = 100
a = 0.0005            # handed to run_bad as alpha_C
b1, b2 = 0.9, 0.99    # handed to run_bad as b1_C, b2_C
e = 1e-8              # handed to run_bad as e_C
a_R = a * 1           # handed to run_bad as alpha_R

# one row per iteration, filled by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised with ``la.orth`` first; 1-D inputs are
    treated as single columns. The angles are divided by pi/2, so 0 means
    aligned directions and 1 means orthogonal directions.
    """
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    basis_a = la.orth(A)
    basis_b = la.orth(B)
    _, cosines, _ = la.svd(basis_a.T.dot(basis_b))
    # singular values can exceed 1 by rounding error; clip before arccos
    clipped = np.minimum(cosines, 1.0)
    return np.arccos(clipped) / (np.pi/2)
def pars_track(C, X, R, t):
    """Store diagnostics for iteration ``t`` in row ``t`` of ``proj_errors``.

    The row holds the value of ``calc_subspace_proj_error`` followed by the
    principal angles, both computed between the true C and the current C.
    X and R are accepted but not used.
    """
    C_true = pars_true['C']
    subspace_err = calc_subspace_proj_error(C_true, C)
    angles = principal_angle(C_true, C)
    proj_errors[t] = np.hstack((subspace_err, angles))
        
    
# run the fit and measure the wall-clock time it takes
t = time.time()
_, pars_est, traces = run_bad(
    lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_zip_size=max_zip_size,
    pars_track=pars_track, W=W)
t = time.time() - t

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')

C_est = pars_est['C']
print('rank of final C_est: ', sp.linalg.orth(C_est).shape[1])

# loss at the true parameters, with lagged covariances of x recomputed here
X_true = np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:]
                    for k_ in lag_range])
err_true = f_l2_Hankel_nl(C=pars_true['C'], X=X_true, R=pars_true['R'],
                          lag_range=lag_range, y=y, ts=np.arange(T-kl_),
                          ms=range(len(lag_range)),
                          idx_a=idx_a, idx_b=idx_b, W=W, get_observed=get_observed)
print('ground-truth error: ', err_true)

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], C_est)))


In [None]:
# all columns of proj_errors except the first, one curve per column
angle_traces = proj_errors[:, 1:]
plt.plot(angle_traces)
plt.show()

In [None]:
from scipy import linalg as la

# settings for quick initial SGD fitting phase for our model
batch_size = None
max_zip_size = np.inf
max_iter = 500
a = 0.025             # handed to run_bad as alpha_C
b1, b2 = 0.9, 0.99    # handed to run_bad as b1_C, b2_C
e = 1e-8              # handed to run_bad as e_C
a_R = a * 1           # handed to run_bad as alpha_R

# one row per iteration, filled by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Return the principal angles between the column spaces of A and B.

    Both inputs are orthonormalised with ``la.orth`` first; 1-D inputs are
    treated as single columns. The angles are divided by pi/2, so 0 means
    aligned directions and 1 means orthogonal directions.
    """
    if A.ndim < 2:
        A = np.atleast_2d(A).T
    if B.ndim < 2:
        B = np.atleast_2d(B).T
    basis_a = la.orth(A)
    basis_b = la.orth(B)
    _, cosines, _ = la.svd(basis_a.T.dot(basis_b))
    # singular values can exceed 1 by rounding error; clip before arccos
    clipped = np.minimum(cosines, 1.0)
    return np.arccos(clipped) / (np.pi/2)
def pars_track(C, X, R, t):
    """Store diagnostics for iteration ``t`` in row ``t`` of ``proj_errors``.

    The row holds the value of ``calc_subspace_proj_error`` followed by the
    principal angles, both computed between the true C and the current C.
    X and R are accepted but not used.
    """
    C_true = pars_true['C']
    subspace_err = calc_subspace_proj_error(C_true, C)
    angles = principal_angle(C_true, C)
    proj_errors[t] = np.hstack((subspace_err, angles))
        
    
# run the fit and measure the wall-clock time it takes
t = time.time()
_, pars_est, traces = run_bad(
    lag_range=lag_range, n=n, y=y, Qs=Qs, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    alpha_C=a, alpha_R=a_R, b1_C=b1, b2_C=b2, e_C=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_zip_size=max_zip_size,
    pars_track=pars_track, W=W)
t = time.time() - t

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')

C_est = pars_est['C']
print('rank of final C_est: ', sp.linalg.orth(C_est).shape[1])

# loss at the true parameters, with lagged covariances of x recomputed here
X_true = np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:]
                    for k_ in lag_range])
err_true = f_l2_Hankel_nl(C=pars_true['C'], X=X_true, R=pars_true['R'],
                          lag_range=lag_range, y=y, ts=np.arange(T-kl_),
                          ms=range(len(lag_range)),
                          idx_a=idx_a, idx_b=idx_b, W=W, get_observed=get_observed)
print('ground-truth error: ', err_true)

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], C_est)))


In [None]:
# all columns of proj_errors except the first, one curve per column
angle_traces = proj_errors[:, 1:]
plt.plot(angle_traces)
plt.show()

In [None]:
# the weight matrices W[0] .. W[kl_-1] drawn next to each other
plt.figure(figsize=(20,8))
weight_blocks = [W[lag_idx] for lag_idx in range(kl_)]
plt.imshow(np.hstack(weight_blocks), interpolation='None')
plt.colorbar()
plt.show()

In [None]:
# only the first weight matrix, W[0]
plt.figure(figsize=(20,8))
first_block_only = [W[0]]
plt.imshow(np.hstack(first_block_only), interpolation='None')
plt.colorbar()
plt.show()

In [None]:
# Per lag: W counts how often each pair (i at t+lag, j at t) was observed
# together, S accumulates the corresponding products of y.
W = [np.zeros((p, p), dtype=int) for _ in lag_range]
S = [np.zeros((p, p)) for _ in lag_range]
for m, m_ in enumerate(lag_range):
    for t in range(T-kl_):
        a, b = get_observed(t+m_), get_observed(t)
        pair_block = np.ix_(a, b)
        W[m][pair_block] += 1
        S[m][pair_block] += np.outer(y[t+m_,a], y[t,b])
    # weights 1/count here (no -1), with the divisor floored at 1
    W[m] = 1./np.maximum(W[m], 1)
    
# Scatter model covariances against the empirical estimate S*W for the
# estimated (panel 1) and the true (panel 3) parameters, restricted to
# sub_pops[0], and accumulate the squared errors in f (estimated) and ft (true).
f, ft = 0., 0.
for m in range(len(lag_range)):
    plt.figure(figsize=(18,4))

    # --- estimated parameters ---
    pars = pars_est
    C, Xm, R = pars['C'], pars['X'][m*n:(m+1)*n,:], pars['R']
    plt.subplot(1,4,1)
    a, b = sub_pops[0], sub_pops[0]
    model_cov = C[a,:].dot(Xm).dot(C[b,:].T) + (m==0)*np.diag(R)[np.ix_(a,b)]
    empirical_cov = S[m][np.ix_(a,b)]*W[m][np.ix_(a,b)]
    plt.plot(model_cov, empirical_cov, '.')
    f += 0.5*np.sum((model_cov - empirical_cov)**2)
    # panel 2 stays empty (the comparison for sub_pops[1] is disabled)
    plt.subplot(1,4,2)

    # --- true parameters ---
    pars = pars_true
    C, Xm, R = pars['C'], pars['X'][m*n:(m+1)*n,:], pars['R']
    plt.subplot(1,4,3)
    a, b = sub_pops[0], sub_pops[0]
    model_cov = C[a,:].dot(Xm).dot(C[b,:].T) + (m==0)*np.diag(R)[np.ix_(a,b)]
    empirical_cov = S[m][np.ix_(a,b)]*W[m][np.ix_(a,b)]
    plt.plot(model_cov, empirical_cov, '.')
    ft += 0.5*np.sum((model_cov - empirical_cov)**2)
    # panel 4 stays empty (the comparison for sub_pops[1] is disabled)
    plt.subplot(1,4,4)

    plt.show()
print(f, ft)

In [None]:
# Plot the fit summary for the current estimate (pars_est) and its traces;
# takes the same arguments as the print_slim calls in the fitting cells.
plot_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)


In [None]:
# settings for GROUSE
a_grouse = 0.00001
max_iter_grouse = 10

# The constructor call was commented out, so `tracker` was undefined on a
# fresh run. Re-use a tracker left over from an earlier execution of this
# cell (keeps training it), otherwise create a new one.
try:
    tracker
except NameError:
    tracker = Grouse(p, n, a_grouse)

# fit GROUSE
t = time.time()
print('\n - GROUSE')
tracker.step = a_grouse
ct = 1.
# one row per pass: [subspace projection error, principal angles]
error = np.zeros((max_iter_grouse, n+1))
# report progress roughly every 10%; at least every pass so the modulus
# is never zero when max_iter_grouse < 10
report_every = max(1, max_iter_grouse // 10)
for i in range(max_iter_grouse):
    if verbose and i % report_every == 0:
        print('finished % ' + str((100*i)//max_iter_grouse))
    # visit the time points in a fresh random order on every pass
    idx = np.random.permutation(T-np.max(lag_range)-1)
    for s in idx:
        tracker.consume(y[s,:].reshape(p,1), mask[s,:].reshape(p,1))
        # step size decays as a_grouse / ct, with ct growing by 0.1 per sample
        tracker.step = a_grouse / ct
        ct += 0.1

    error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
pars_est_g = {'C' : tracker.U}

print('final proj. error (est.): ', str(error[-1][0]))

plt.plot(error[:,1:])
plt.title('subspace proj. error (GROUSE)')
plt.show()

# corrected gradients

In [None]:
##################
# Loss functions #
##################


# loss function that the non-weighted stochastic gradient descent actually minimizes
def f_sg(C, X, R, ts, ms, get_observed, W):
    """Mean over samples of half the squared per-sample reconstruction error.

    For each pair (ts[i], ms[i]) the model block C_a X_m C_b^T (plus diag(R)
    for lag 0) is compared with the outer product y[t+lag, a] y[t, b]^T on
    the observed indices a, b. Reads the module-level p, n, lag_range and y.
    W is accepted for a uniform signature but not used.
    """
    C = C.reshape(p, n)
    X = X.reshape(len(lag_range)*n, n)
    R = R.reshape(p)
    total = 0
    for i, t in enumerate(ts):
        m = ms[i]
        lag = lag_range[m]
        fut, past = get_observed(t+lag), get_observed(t)
        X_lag = X[m*n:(m+1)*n,:].copy()
        predicted = C[fut,:].dot(X_lag).dot(C[past,:].T) + \
            (lag==0)* np.diag(R)[np.ix_(fut,past)]
        residual = predicted - np.outer(y[t+lag,fut], y[t,past])
        total += (residual**2).sum()
    return 0.5 * total / len(ts)

# loss function based on L2 error for Hankel cov matrix  
def f_ha(C,X,R, ts, ms, get_observed, W):
    """Half the squared error between model and averaged empirical covariances.

    The products y[t+lag, a] y[t, b]^T are summed per lag into S, divided by
    the module-level pair counts ``t_ij`` and compared with
    C X_m C^T (+ diag(R) for lag 0) on every entry observed at least once.
    Reads the module-level p, n, lag_range, y and t_ij; W is not used.
    """
    C = C.reshape(p, n)
    X = X.reshape(len(lag_range)*n, n)
    R = R.reshape(p)
    n_lags = len(lag_range)
    S = [np.zeros((p, p)) for _ in range(n_lags)]
    Om = [np.zeros((p, p), dtype=bool) for _ in range(n_lags)]
    for i, t in enumerate(ts):
        m = ms[i]
        lag = lag_range[m]
        fut, past = get_observed(t+lag), get_observed(t)
        S[m][np.ix_(fut, past)] += np.outer(y[t+lag,fut], y[t,past])
        Om[m][np.ix_(fut, past)] = True

    per_lag_loss = []
    for m in range(n_lags):
        model = C.dot(X[m*n:(m+1)*n,:]).dot(C.T) + (lag_range[m]==0)*np.diag(R)
        residual = model - S[m]/(t_ij[m])
        per_lag_loss.append(np.sum(residual[Om[m]]**2))
    return 0.5*np.sum(per_lag_loss)

# loss function for re-weighted stochastic gradient descent
def f_rw(C,X,R, ts, ms, get_observed, W):
    """Re-weighted per-sample loss.

    Like the unweighted per-sample loss, but every squared residual entry
    (i, j) is divided by the module-level pair count ``t_ij[m][i, j]`` and
    the total is not divided by the number of samples.
    Reads the module-level p, n, lag_range, y and t_ij; W is not used.

    Fix: the future observation is taken at ``t + m_`` (the lag) instead of
    ``t + m`` (the index into lag_range); the two only coincide when
    lag_range == arange(k).
    """
    C = C.reshape(p,n)
    X = X.reshape(len(lag_range)*n,n)
    R = R.reshape(p)
    f = 0

    for i in range(len(ts)):
        t,m = ts[i],ms[i]
        m_ = lag_range[m]
        a, b = get_observed(t+m_), get_observed(t)
        Xm = X[m*n:(m+1)*n,:]
        residual = C[a,:].dot(Xm).dot(C[b,:].T) + \
            (m_==0)* np.diag(R)[np.ix_(a,b)] - \
            np.outer(y[t+m_,a], y[t,b])
        f += (residual**2 / t_ij[m][np.ix_(a,b)]).sum()
    return 0.5*f


#############
# Gradients #
#############

# unweighted stochastic gradients w.r.t. C 
def g_C(C, X, R, ts, ms, get_observed, W): 
    """Unweighted per-sample gradient w.r.t. C, averaged over the samples.

    Returns the flattened (p*n,) gradient. Reads the module-level p, n,
    lag_range and y. X and R are used as passed (not reshaped); W is unused.
    """
    C = C.reshape(p, n)
    grad = np.zeros((p, n))

    for i, t in enumerate(ts):
        m = ms[i]
        lag = lag_range[m]
        fut, past = get_observed(t+lag), get_observed(t)

        X_lag = X[m*n:(m+1)*n,:].copy()
        CX = C.dot(X_lag)       # C X_m
        CXt = C.dot(X_lag.T)    # C X_m^T
        # rows observed at t+lag
        grad[fut,:] += C[fut,:].dot( CXt[past,:].T.dot(CXt[past,:]) ) - np.outer(y[t+lag,fut],y[t,past].dot(CXt[past,:]))
        # rows observed at t
        grad[past,:] += C[past,:].dot( CX[fut,:].T.dot(CX[fut,:]) ) - np.outer(y[t,past],y[t+lag,fut].dot(CX[fut,:]))
        if lag == 0:
            # noise term: only for variables observed at both time points
            both = np.intersect1d(fut, past)
            grad[both,:] += R[both].reshape(-1,1)*(CX[both,:] + CXt[both,:])

    return grad.reshape(-1) / len(ts)

# re-weighted stochastic gradients w.r.t. C 
def g_rw_C(C, X, R, ts, ms, get_observed, W): 
    """Re-weighted per-sample gradient w.r.t. C (flattened, not averaged).

    Each contribution of a pair (i, j) is divided by the module-level count
    ``t_ij[m][i, j]`` (floored at 1 in the two main terms). Reads the
    module-level p, n, lag_range, y and t_ij; W is not used.
    """
    C = C.reshape(p, n)
    grad = np.zeros((p, n))

    for i, t in enumerate(ts):
        m = ms[i]
        lag = lag_range[m]
        fut, past = get_observed(t+lag), get_observed(t)

        X_lag = X[ms[i]*n:(ms[i]+1)*n,:].copy()
        CX = C.dot(X_lag)       # C X_m
        CXt = C.dot(X_lag.T)    # C X_m^T

        # rows observed at t+lag, one at a time because the weights depend on k
        for k in fut:
            counts = np.maximum(t_ij[m][k,past],1).reshape(-1,1)
            weighted = CXt[past,:] / counts
            grad[k,:] += C[k,:].dot( CXt[past,:].T.dot(weighted) )
            grad[k,:] -= y[t+lag,k] * y[t,past].dot(weighted)

        # rows observed at t
        for k in past:
            counts = np.maximum(t_ij[m][fut,k],1).reshape(-1,1)
            weighted = CX[fut,:] / counts
            grad[k,:] += C[k,:].dot( CX[fut,:].T.dot(weighted) )
            grad[k,:] -= y[t,k] * y[t+lag,fut].dot(weighted)

        if lag == 0:
            # noise term: only for variables observed at both time points
            both = np.intersect1d(fut, past)
            scaled_R = (R[both]/t_ij[m][both,both]).reshape(-1,1)
            grad[both,:] += scaled_R*(CX[both,:] + CXt[both,:])

    return grad.reshape(-1) 

# re-weighted stochastic gradients w.r.t. X
def g_rw_X(C, X, R, ts, ms, get_observed, W): 
    """Re-weighted per-sample gradient w.r.t. X (flattened).

    X is reshaped to (-1, n); block m (rows m*n .. (m+1)*n) belongs to
    lag_range[m]. Every pair contribution is divided by the module-level
    count ``t_ij[m][i, j]``. Reads the module-level n, lag_range, y and
    t_ij; W is not used.

    Fix: the lag-0 noise term is added to the block of the current lag index
    m instead of the hard-coded first block, which was only correct when
    lag 0 is the first entry of lag_range.
    """
    X = X.reshape(-1, n)
    grad_X = np.zeros_like(X)

    for i in range(len(ts)):
        t,m = ts[i],ms[i]
        m_ = lag_range[m]
        a, b = get_observed(t+m_), get_observed(t)

        rows = slice(m*n, (m+1)*n)   # block of X / grad_X for this lag
        Xm = X[rows,:]
        C_b, y_b = C[b,:], y[t,b]
        for k in a:
            # the weights depend on the row k, hence the explicit loop
            C_b_weighted = C_b / t_ij[m][k,b].reshape(-1,1)

            grad_X[rows,:] += np.outer(C[k,:], C[k,:]).dot(Xm).dot(C_b.T.dot(C_b_weighted))
            grad_X[rows,:] -= np.outer(y[t+m_,k] * C[k,:], y_b.dot(C_b_weighted))

        if m_ == 0:
            # noise term: only for variables observed at both time points
            anb  = np.intersect1d(a,b)
            grad_X[rows,:] += C[anb,:].T.dot( (R[anb]/t_ij[m][anb,anb]).reshape(-1,1)*C[anb,:])

    return grad_X.reshape(-1)

# re-weighted stochastic gradients w.r.t. R
def g_rw_R(C, X, R, ts, ms, get_observed, W):
    """Re-weighted per-sample gradient w.r.t. R (length-p vector).

    Only lag-0 samples contribute. For those, each variable observed at time
    t adds (R + diag(C X_0 C^T) - y^2) times its diagonal weight W[m][k, k].
    Reads the module-level p, n, lag_range and y.

    Fix: the observed indices are taken from ``get_observed(t)``. The original
    unpacked ``_, b = get_observed(i)``, which both used the sample counter
    instead of the time point and tried to split a single index array in two.
    The lag-0 block of X is addressed through its index m rather than assumed
    to be the first block.
    """
    grad_R = np.zeros(p)

    for i in range(len(ts)):
        t, m = ts[i], ms[i]
        m_ = lag_range[m]
        if m_ != 0:
            continue

        b = get_observed(t)
        X0_block = X[m*n:(m+1)*n,:]
        model_var = np.sum(C[b,:] * C[b,:].dot(X0_block.T),axis=1)
        grad_R[b] += (R[b] + model_var - y[t,b]**2) * W[m][b,b]

    return grad_R

In [None]:
from ssidid.SSID_Hankel_loss import f_l2_Hankel_nl, g_l2_Hankel_sgd

# random order of all time points and of all lag indices
ts1 = np.random.choice(T-kl_, T-kl_, replace=False)
msu = np.random.choice(kl_, kl_, replace=False)

# full sample list: every lag index in msu paired with every time point in ts1
ms = np.repeat(msu, len(ts1))
ts = np.tile(ts1, len(msu))

# random parameters at which the gradients are checked (R is non-negative)
X = np.random.normal(size=(len(lag_range)*n,n))
R = np.random.normal(size=(p))**2
C = np.random.normal(size=(p,n))

# t_ij[m][i, j]: how often pair (i at t+lag, j at t) occurs in the sample list
t_ij = [np.zeros((p,p), dtype=int) for _ in range(len(lag_range))]
for i in range(len(ts)):
    t, m = ts[i], ms[i]
    m_ = lag_range[m]
    a, b = get_observed(t+m_), get_observed(t)
    t_ij[ms[i]][np.ix_(a, b)] += 1
# floor at 1 so the counts can be used as divisors
t_ij = [np.maximum(counts, 1) for counts in t_ij]

#t_ij = [len(ts1)*np.ones((p,p)) for m in range(len(lag_range))]

W = [1./counts for counts in t_ij]

# decorating
def fC(C):
    """f_sg as a function of C only; everything else comes from the cell globals."""
    return f_sg(C, X, R, ts, ms, get_observed, W)


def fC_ha(C):
    """f_ha as a function of C only."""
    return f_ha(C, X, R, ts, ms, get_observed, W)


def fC_rw(C):
    """f_rw as a function of C only."""
    return f_rw(C, X, R, ts, ms, get_observed, W)


def gC_rw(C):
    """g_rw_C as a function of C only."""
    return g_rw_C(C, X, R, ts, ms, get_observed, W)


def gC(C):
    """C-part of the library gradient g_l2_Hankel_sgd, flattened."""
    C_mat = C.reshape(p, n)
    dC, _dX, _dR = g_l2_Hankel_sgd(C_mat, X, R, y, lag_range, ts1, ms=msu,
                                   get_observed=get_observed, linear=False, W=W)
    return dC.reshape(-1)


def fC_impl(C): 
    """Library loss f_l2_Hankel_nl as a function of C only."""
    C_mat = C.reshape(p, n)
    every_var = np.arange(p)
    return f_l2_Hankel_nl(C_mat, X, R, y, lag_range=lag_range, ms=msu,
                          get_observed=get_observed,
                          idx_a=every_var, idx_b=np.arange(p), W=W, ts=ts1)
def fX(X):
    """f_sg as a function of X only; everything else comes from the cell globals."""
    return f_sg(C, X, R, ts, ms, get_observed, W)


def fX_ha(X):
    """f_ha as a function of X only."""
    return f_ha(C, X, R, ts, ms, get_observed, W)


def fX_rw(X):
    """f_rw as a function of X only."""
    return f_rw(C, X, R, ts, ms, get_observed, W)


def gX_rw(X):
    """g_rw_X as a function of X only."""
    return g_rw_X(C, X, R, ts, ms, get_observed, W)


def gX(X):
    """X-part of the library gradient g_l2_Hankel_sgd, flattened."""
    X_mat = X.reshape(len(lag_range)*n, n)
    _dC, dX, _dR = g_l2_Hankel_sgd(C, X_mat, R, y, lag_range, ts1, msu,
                                   get_observed, linear=False, W=W)
    return dX.reshape(-1)


def fX_impl(X): 
    """Library loss f_l2_Hankel_nl as a function of X only."""
    X_mat = X.reshape(len(lag_range)*n, n)
    every_var = np.arange(p)
    return f_l2_Hankel_nl(C, X_mat, R, y, lag_range=lag_range, ms=msu,
                          get_observed=get_observed,
                          idx_a=every_var, idx_b=np.arange(p), W=W, ts=ts1)
def fR(R):
    """f_sg as a function of R only; everything else comes from the cell globals."""
    return f_sg(C, X, R, ts, ms, get_observed, W)


def fR_ha(R):
    """f_ha as a function of R only."""
    return f_ha(C, X, R, ts, ms, get_observed, W)


def fR_rw(R):
    """f_rw as a function of R only."""
    return f_rw(C, X, R, ts, ms, get_observed, W)


def gR_rw(R):
    """g_rw_R as a function of R only."""
    return g_rw_R(C, X, R, ts, ms, get_observed, W)


def gR(R):
    """R-part of the library gradient g_l2_Hankel_sgd (lag indices: unique values of ms)."""
    R_vec = R.reshape(p)
    _dC, _dX, dR = g_l2_Hankel_sgd(C, X, R_vec, y, lag_range, ts1, np.unique(ms),
                                   get_observed, linear=False, W=W)
    return dR.reshape(-1)


def fR_impl(R): 
    """Library loss f_l2_Hankel_nl as a function of R only (lag indices: unique values of ms)."""
    R_vec = R.reshape(p)
    every_var = np.arange(p)
    return f_l2_Hankel_nl(C, X, R_vec, y, lag_range=lag_range, ms=np.unique(ms),
                          get_observed=get_observed,
                          idx_a=every_var, idx_b=np.arange(p), W=W, ts=ts1)

print('m:', np.sort(msu))

# (label, loss, gradient, parameter): the library gradient of each parameter
# is checked against all four candidate losses; a blank line separates the
# three parameter groups.
gradient_checks = (
    ('grad C (Pr. Er)', fC,      gC, C),
    ('grad C (Hankel)', fC_ha,   gC, C),
    ('grad C (corr. )', fC_rw,   gC, C),
    ('grad C (impl. )', fC_impl, gC, C),
    ('grad X (Pr. Er)', fX,      gX, X),
    ('grad X (Hankel)', fX_ha,   gX, X),
    ('grad X (corr. )', fX_rw,   gX, X),
    ('grad X (impl. )', fX_impl, gX, X),
    ('grad R (Pr. Er)', fR,      gR, R),
    ('grad R (Hankel)', fR_ha,   gR, R),
    ('grad R (corr. )', fR_rw,   gR, R),
    ('grad R (impl. )', fR_impl, gR, R),
)
for check_no, (label, loss_fn, grad_fn, par) in enumerate(gradient_checks):
    if check_no % 4 == 0:
        print('\n')
    print(label, sp.optimize.check_grad(loss_fn, grad_fn, par.reshape(-1)))

np.sort(ts1[ts1<=250]), np.sort(ts1[ts1>250]), np.min(np.diff(np.sort(ts1)))

In [None]:
# Does the library gradient / loss for lag indices (1, 3) equal the sum of the
# results for (1,) and (3,)?
def _grad_C_for(lag_indices):
    """C-gradient from g_l2_Hankel_sgd for the given lag indices."""
    return g_l2_Hankel_sgd(C, X, R, y, lag_range, ts, lag_indices, get_observed, W=W)[0]


def _loss_for(lag_indices):
    """Loss from f_l2_Hankel_nl for the given lag indices."""
    return f_l2_Hankel_nl(C, X, R, y, lag_range=lag_range, ms=lag_indices,
                          get_observed=get_observed,
                          idx_a=np.arange(p), idx_b=np.arange(p), W=W, ts=ts1)


g1 = _grad_C_for((1,))
g2 = _grad_C_for((3,))
g12 = _grad_C_for((1,3))
# (this tuple is evaluated but not displayed; only the last line of the cell is)
g1+g2, g12

f1 = _loss_for((1,))
f2 = _loss_for((3,))
f12 = _loss_for((1,3))
f1+f2, f12, g1+g2, g12

# d/dA

In [None]:
from ssidid.SSID_Hankel_loss import f_l2_Hankel_nl, g_l2_Hankel_sgd

# a single random time point and a single random lag index
ts1 = np.random.choice(T-kl_, 1, replace=False)
msu = np.random.choice(kl_, 1, replace=False)

# sample list: every lag index in msu paired with every time point in ts1
ms = np.repeat(msu, len(ts1))
ts = np.tile(ts1, len(msu))

# random parameters at which the gradient is checked (R is non-negative)
A = np.random.normal(size=(n,n))
X0 = np.random.normal(size=(n,n))
R = np.random.normal(size=(p))**2
C = np.random.normal(size=(p,n))

# pair counts from the sample list ...
t_ij = [np.zeros((p,p), dtype=int) for _ in range(len(lag_range))]
for i in range(len(ts)):
    t, m = ts[i], ms[i]
    m_ = lag_range[m]
    a, b = get_observed(t+m_), get_observed(t)
    t_ij[ms[i]][np.ix_(a, b)] += 1
t_ij = [np.maximum(counts, 1) for counts in t_ij]

# ... which are then replaced by a constant count for every pair
t_ij = [len(ts1)*np.ones((p,p)) for _ in range(len(lag_range))]

W = [1./counts for counts in t_ij]

# summed products y[t+lag, a] y[t, b]^T per lag index
S = [np.zeros((p,p)) for _ in range(len(lag_range))]
for i in range(len(ts)):
    t, m = ts[i], ms[i]
    m_ = lag_range[m]
    a, b = get_observed(t+m_), get_observed(t)
    S[m][np.ix_(a, b)] += np.outer(y[t+m_,a], y[t,b])

# decorating
def fA(A):
    """f_ha as a function of A, with the lag blocks of X set to A^m X0 for m < kl_."""
    A_mat = A.reshape(n, n)
    lag_blocks = []
    for power in range(kl_):
        lag_blocks.append(np.linalg.matrix_power(A_mat, power).dot(X0))
    X = np.vstack(lag_blocks)
    return f_ha(C, X, R, ts, ms, get_observed, t_ij)


# gradient of the Hankel loss w.r.t. A

def gA(A):
    """g_A as a function of A only; everything else comes from the cell globals."""
    grad_flat = g_A(C, A, X0, R, ts, ms, get_observed, t_ij)
    return grad_flat
    
def g_A(C, A, X0, R, ts, ms, get_observed, t_ij): 
    """Flattened gradient w.r.t. A, delegated to g_l2_Hankel with k=kl_, l=0.

    R, ts, ms, get_observed and t_ij are accepted but not used; Qs, idx_grp
    and co_obs are taken from the module level.
    """
    A_mat = A.reshape(n, n)
    return g_l2_Hankel(C, A_mat, X0, kl_, 0, n, Qs, idx_grp, co_obs)


def g_l2_Hankel(C, A, Pi,k,l,n,Qs,idx_grp,co_obs):
    """Return the flattened l2 Hankel reconstruction gradient w.r.t. A.

    For m = 1 .. k+l-2 the residual C A^m Pi C^T - S[m]*W[m] is projected
    back to latent space over the index pairs (idx_grp[i], co_obs[i]) and
    passed to g_A_l2_block. S and W are read from the module level.

    The argument n is overwritten by C.shape[1]; Qs is accepted but not used.
    NOTE(review): the loop stops at m = k+l-2, so the last of the k+l
    precomputed powers never enters the gradient -- confirm this is intended.
    """
    p,n = C.shape

    # Aexpm[:, :, m] = A^m
    Aexpm = np.zeros((n,n,k+l))
    Aexpm[:,:,0]= np.eye(n)
    for m in range(1,k+l):
        Aexpm[:,:,m] = A.dot(Aexpm[:,:,m-1])

    grad_A = np.zeros((n,n))
    for m in range(1,k+l-1):

        AmPi = Aexpm[:,:,m].dot(Pi)

        # the expensive part: handling p x p observed-space matrices
        CAPiC_L = C.dot(AmPi).dot(C.T) - S[m]*W[m]
        CTC = np.zeros((n,n))
        for i, a in enumerate(idx_grp):
            b = co_obs[i]
            # (the original also computed an unused transposed product here)
            Ci = CAPiC_L[np.ix_(a,b)].dot(C[b,:])
            CTC += C[a,:].T.dot(Ci)
        grad_A += g_A_l2_block(CTC,Aexpm,m,Pi)
    return grad_A.reshape(-1)
def g_A_l2_block(CTC,Aexpm,m,Pi):
    """Return the l2 Hankel reconstruction gradient w.r.t. A for one Hankel block.

    Computes sum over q < m of (A^q)^T (CTC Pi) (A^(m-1-q))^T, where
    Aexpm[:, :, j] holds A^j. For m == 0 the result is all zeros.
    """
    projected = CTC.dot(Pi)
    total = np.zeros(Aexpm.shape[:2])
    for left, right in zip(range(m), range(m-1, -1, -1)):
        inner = projected.dot(Aexpm[:,:,right].T)
        total += Aexpm[:,:,left].T.dot(inner)
    return total


# lag index used for this check
print('m:', np.sort(msu))
print('\n')

# discrepancy between gA and a finite-difference gradient of fA
A_flat = A.reshape(-1)
print('grad A ', sp.optimize.check_grad(fA, gA, A_flat))

In [None]:
import numpy as np
import scipy as sp
from scipy import optimize

# Small self-contained example: noise-free linear dynamics, all variables
# observed, one lag index m.
m = 1
lag_range = np.arange(m+1)
kl_ = len(lag_range) + 1
T = 10 + kl_

p,n = 5, 2
A = np.random.normal(size=(n,n))
C = np.random.normal(size=(p,n))

def get_observed(t):
    """Every variable is observed at every time point."""
    return np.arange(p)


# every pair is seen T-kl_ times; W uses the 1/(N-1) normalisation
t_ij = [(T-kl_)*np.ones((p,p)) for m in range(len(lag_range))]
W = [1/(t_ij[m]-1) for m in range(len(lag_range))]

# latent trajectory x[t] = A x[t-1]; the width follows n (was hard-coded as 2,
# which breaks the recursion for any other latent dimension)
x = np.zeros((T, n))
x[0] = np.random.normal(size=n)
for t in range(1,T):
    x[t] = A.dot(x[t-1])
y = x.dot(C.T)
y -= y[:T-kl_].mean(axis=0)


X0 = np.cov(x[:T-kl_].T)
X = np.vstack([np.linalg.matrix_power(A,m).dot(X0) for m in range(kl_)])
R = np.zeros(p)

# sample list: every time point paired with the single lag index m
ts = np.arange(T-kl_)
ms = m * np.ones_like(ts)

# per lag index: summed (mean-corrected) products and the set of observed pairs
S  = [np.zeros((p,p)) for m in range(len(lag_range))]
Om = [np.zeros((p,p), dtype=bool) for m in range(len(lag_range))]
ym = [ y[m:T-kl_+m].mean(axis=0) for m in range(len(lag_range))]
for i in range(len(ts)):        
    t,m = ts[i],ms[i]
    m_ = lag_range[m]
    a, b = get_observed(t+m_), get_observed(t)
    S[ m][np.ix_(a, b)] += np.outer(y[t+m_,a] - ym[m], y[t,b])
    Om[m][np.ix_(a, b)] = True
    
def f_ha(C,X,R, ts, ms, get_observed, t_ij):
    """Half the squared error between model and averaged empirical covariances.

    Compares C X_m C^T (+ diag(R) for lag 0) with S[m] / t_ij[m] on the
    entries flagged in Om[m]. S, Om, lag_range, p and n are read from the
    module level; ts, ms and get_observed are accepted but not used.
    """
    C = C.reshape(p, n)
    R = R.reshape(p)
    per_lag_loss = []
    for m in range(len(lag_range)):
        model = C.dot(X[m*n:(m+1)*n,:]).dot(C.T) + (lag_range[m]==0)*np.diag(R)
        residual = model - S[m]/t_ij[m]
        per_lag_loss.append(np.sum(residual[Om[m]]**2))
    return 0.5*np.sum(per_lag_loss)

# decorating
def fA(A):
    """f_ha as a function of A, with the lag blocks of X set to A^m X0 for m < kl_."""
    A_mat = A.reshape(n, n)
    lag_blocks = []
    for power in range(kl_):
        lag_blocks.append(np.linalg.matrix_power(A_mat, power).dot(X0))
    X = np.vstack(lag_blocks)
    return f_ha(C, X, R, ts, ms, get_observed, t_ij)


# analytic gradient of fA w.r.t. A
def gA(A):    
    """Return the flattened gradient of fA w.r.t. A for the lag index ``m``.

    With E = C A^m X0 C^T - S[m]/t_ij[m], the gradient is
        sum_{q<m} (A^q)^T  C^T E C  X0^T  (A^(m-1-q))^T .
    Only the module-level lag index m is handled; that matches fA as long as
    m is the only lag index with observed pairs.

    Fixes: the empirical term uses S[m]/t_ij[m], i.e. the same normalisation
    as the loss (the original used S[m]*W[m] with W = 1/(t_ij-1), so gradient
    and loss disagreed); and the right-hand power A^(m-1-q) is transposed,
    which matters for m >= 2.
    """
    A = A.reshape(n,n)
    grad_A = np.zeros((n,n))    
    CC = C.T.dot(C)
    Am = np.linalg.matrix_power(A, m)
    # C^T E C, identical for every q
    CEC = CC.dot(Am.dot(X0)).dot(CC) - C.T.dot(S[m]/t_ij[m]).dot(C)
    CEC_X0 = CEC.dot(X0.T)
    for q in range(m):
        Aq   = np.linalg.matrix_power(A, q)
        Am1q = np.linalg.matrix_power(A, m-1-q)
        grad_A += Aq.T.dot(CEC_X0).dot(Am1q.T)
    return grad_A.reshape(-1)

# loss and analytic gradient at A
loss_and_grad = (fA(A), gA(A))
print('f, g', loss_and_grad)

# discrepancy between gA and a finite-difference gradient of fA
A_flat = A.reshape(-1)
print('grad A ', sp.optimize.check_grad(fA, gA, A_flat))

# formula check

- Numerically compare the implemented (vectorised) gradient with the non-vectorised analytic formula.

In [None]:
# one random time index (drawn from range(p), kept as a length-1 array) at lag 0
t, m = np.random.choice(p, 1), 0
m_ = m

C = pars_true['C']
Xm = pars_true['X'][m*n:(m+1)*n, :]
R = pars_true['R']
p,n = C.shape
grad_C = np.zeros((p,n))
grad_X = np.zeros(pars_true['X'].shape)
grad_R = np.zeros(p)
idx_ct = np.zeros((p,2),dtype=np.int32)

C___ = C.dot(Xm)   # C X_m
C_tr = C.dot(Xm.T) # C X_m^T

# observed variables at t+m and at t; t is an array, so the mask rows are 2-D
# and the column indices are the second output of np.where
a,b = obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]
a,b = np.where(a)[1], np.where(b)[1]

anb = np.intersect1d(a,b)   # observed at both time points
a_ = np.setdiff1d(a,b)      # observed only at t+m
b_ = np.setdiff1d(b,a)      # observed only at t

yf = y[t+m_,a]
yp = y[t,b]

# Om[i, j]: pair (i at t+m, j at t) observed; L[i, j]: product of the two values
Om = np.outer(obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]).astype(bool)
L = np.outer(y[t+m,:], y[t,:])

# element-by-element gradient: every observed pair (i, j) contributes to row i
# through its first index and to row j through its second index
grad_C = np.zeros_like(C)
for k in range(p):
    for i in range(p):
        for j in range(p):
            if not Om[i,j]:
                continue
            Ci, Cj = C[i,:], C[j,:]
            if k == i:
                grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
            if k == j:
                grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
            if k == i == j and m == 0:
                # noise term on the diagonal at lag 0
                grad_C[k,:] += R[i] * (Cj.dot(Xm.T)+Ci.dot(Xm))

print('non-vectorised', grad_C)
grad_C_blunt = grad_C.copy()

# the same gradient, computed with the vectorised formula
grad_C = np.zeros((p,n))

C___ = C.dot(Xm)   # C X_m
C_tr = C.dot(Xm.T) # C X_m^T

# rows observed at t+m_
future_rows = C[a,:].dot( C_tr[b,:].T.dot(C_tr[b,:]) ) - np.outer(yf,yp.dot(C_tr[b,:]))
grad_C[a,:] += future_rows
# rows observed at t
past_rows = C[b,:].dot( C___[a,:].T.dot(C___[a,:]) ) - np.outer(yp,yf.dot(C___[a,:]))
grad_C[b,:] += past_rows

# correction for variables not observed both at t+m_ and t  
#if a_.size > 0:
#    grad_C[a_,:] -= (np.sum(C[a_,:]*C_tr[a_,:],axis=1) - y[t+m,a_]*y[t,a_]).reshape(-1,1) * C_tr[a_,:]
#if b_.size > 0:
#    grad_C[b_,:] -= (np.sum(C[b_,:]*C___[b_,:],axis=1) - y[t+m,b_]*y[t,b_]).reshape(-1,1) * C___[b_,:]

if m_ == 0:
    # noise term for variables observed at both time points
    noise_rows = R[anb].reshape(-1,1)*(C___[anb,:] + C_tr[anb,:])
    grad_C[anb,:] += noise_rows
print('vectorised', grad_C)

print('overlap', anb)

# both computations must agree
assert np.allclose(grad_C_blunt, grad_C)

# pairs (i at t+m, j at t) that were observed together
plt.imshow(Om, interpolation='None')
# raw string: '\O' is an invalid escape sequence in a normal string literal;
# the title text itself is unchanged
plt.title(r'observation scheme (\Omega)')
plt.show()