In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

run = '_real_stitch'

# define problem size
lag_range = np.arange(10)
kl_ = np.max(lag_range)+1          # number of lags (k+l)
p, n, T = 10000, 20, 10000-kl_ + kl_

nr = 0 # number of real eigenvalues
snr = (9., 9.)
whiten = True
# eigenvalue range bounds (min/max for real and complex eigenvalues) passed to
# gen_data -- presumably magnitudes; confirm in gen_data
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.95, 0.90, 0.95


# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations: a single subpopulation containing all p units
sub_pops = (np.arange(p),)

# each repetition cycles once through all subpopulations; obs_time holds the
# boundary (end) time of each consecutive observation period
reps = 1
obs_pops = np.concatenate([ np.arange(len(sub_pops)) for r in range(reps) ])
obs_time = np.linspace(0,T, len(obs_pops)+1)[1:].astype(int)
obs_scheme = ObservationScheme(p=p, T=T, 
                                sub_pops=sub_pops, 
                                obs_pops=obs_pops, 
                                obs_time=obs_time)
obs_scheme.comp_subpop_stats()

missing_at_random, frac_obs = False, 0.5
if missing_at_random:
    # observe a fresh random subset of ceil(p*frac_obs) units at every time
    # point that lies before the end of some observation period
    n_obs = int(np.ceil(p * frac_obs))  # np.random.choice needs an integer size
    mask = np.zeros((T,p))
    for t in range(T):
        if np.any(t < obs_time):
            mask[t, np.random.choice(p, n_obs, replace=False)] = 1
    obs_scheme.mask = mask
    del mask
else:
    # only materialise the full T x p mask when it is reasonably small
    if p*T < 1e8:
        obs_scheme.gen_mask_from_scheme()
        obs_scheme.use_mask = False

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

if p*T < 1e8:
    plt.figure(figsize=(20,10))
    plt.imshow(obs_scheme.mask.T, interpolation='None', aspect='auto')
    # a boolean is required: the string 'off' is truthy and turns the grid ON
    plt.grid(False)
    plt.show()
    
# initialisation passed to the fitting routine (saved alongside the results)
pars_est = 'default'



In [None]:
# Build the full observation mask and display every 100th unit (starting at
# unit 1, last unit excluded) over time.
obs_scheme.gen_mask_from_scheme()
obs_scheme.use_mask = False
plt.figure(figsize=(20,10))
plt.imshow(obs_scheme.mask[:,1:-1:100].T, interpolation='None', aspect='auto')
plt.grid(False)  # boolean, not 'off' (a truthy string would enable the grid)
plt.show()

In [None]:
from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl, f_l2_Hankel_comp_Q_Om

data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'

# draw synthetic data with gen_data: true parameters, latents x, observations y
# (y is memory-mapped at data_path+'y' when mmap is True)
pars_true, x, y, _, _ = gen_data(p,n,lag_range,T, nr,
                                 eig_m_r, eig_M_r, 
                                 eig_m_c, eig_M_c,
                                 mmap, chunksize,
                                 data_path,
                                 snr=snr, whiten=whiten)    

if len(obs_scheme.sub_pops) > 1:
    print('ensuring zero-mean data for given observation scheme')
    if mmap:
        # centre the on-disk data one column chunk at a time; the memmap is
        # re-opened per chunk so written pages are flushed when it is deleted.
        # range(0, p, chunksize) also covers a final partial chunk.
        for start in range(0, p, chunksize):
            stop = min(start + chunksize, p)
            y = np.memmap(data_path+'y', dtype=np.float64, mode='r+', shape=(T,p))
            y[:, start:stop] -= y[:, start:stop].mean(axis=0)
            del y
        y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p))
    else:
        y -= y.mean(axis=0)

# evaluate the loss on a random subset of at most 1000 units (same set for
# both Hankel index dimensions)
idx_a = np.sort(np.random.choice(p, 1000, replace=False)) if p > 1000 else np.arange(p)
idx_b = idx_a.copy()

W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b)
Qs, Om = f_l2_Hankel_comp_Q_Om(n=n,y=y,lag_range=lag_range,obs_scheme=obs_scheme,
                      idx_a=idx_a,idx_b=idx_b,W=W,
                      mmap=mmap,data_path=data_path,ts=None,ms=None)


In [None]:
# Bundle problem settings, data (or its memmap directory) and fit results.
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'x' : x,
             'mmap' : mmap,
             'y' : data_path if mmap else y,
             'pars_true' : pars_true,
             'pars_est' : pars_est,
             'idx_a' : idx_a,
             'idx_b' : idx_b,
             'W' : W,
             'Qs' : Qs,
             'Om' : Om
            }
# np.int was removed in NumPy 1.24; the builtin int is the equivalent.
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)
# the dict is stored as a 0-d object array under 'arr_0' (needs allow_pickle on load)
np.savez(data_path + file_name, save_dict)

# Load stored data

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# Locate the stored run: the file name encodes problem size, SNR and run tag.
data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'

p = 10000
T = 10000

n = 20

lag_range = np.arange(0,10)
kl_ = np.max(lag_range)+1

snr = (9., 9.)
run = '_test'
verbose=True

# np.int was removed in NumPy 1.24; the builtin int is the equivalent.
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)

# The settings dict was saved as a 0-d object array under 'arr_0'; reading it
# back requires unpickling (allow_pickle=True, NumPy >= 1.16.3).
# Only load files you trust. The context manager closes the archive.
with np.load(data_path + file_name + '.npz', allow_pickle=True) as npz_file:
    load_file = npz_file['arr_0'].item()
p,n,T,lag_range = load_file['p'], load_file['n'], load_file['T'], load_file['lag_range']
y, x, snr, idx_a, idx_b = load_file['y'], load_file['x'], load_file['snr'], load_file['idx_a'], load_file['idx_b'] 
pars_true, pars_est, obs_scheme = load_file['pars_true'], load_file['pars_est'],load_file['obs_scheme']
W, Om= load_file['W'], load_file['Om']
# Qs are read from per-lag .npy files in data_path, not from the archive
Qs = [np.load(data_path+'Qs_'+str(lag_range[m])+'.npy') for m in range(len(lag_range)) ]
W = obs_scheme.comp_coocurrence_weights(lag_range, sso=True, idx_a=idx_a, idx_b=idx_b) if W is None else W


mmap = load_file['mmap']
# for memory-mapped runs the archive stores the data directory in place of y
y = np.memmap(data_path+'y', dtype=np.float64, mode='r', shape=(T,p)) if isinstance(load_file['y'], str) else load_file['y']

print('(T,p,n)', (T,p,n))
# NOTE(review): '_' (traces argument) is whatever the kernel last bound to '_';
# a loaded run has no traces -- confirm print_slim tolerates this.
print_slim(Qs,Om,lag_range,pars_est,idx_a,idx_b,_,False,data_path)
print('loss: ', f_l2_Hankel_nl(C=pars_est['C'],
                               X=pars_est['X'],
                               R=pars_est['R'],
                               Qs=Qs,
                               Om=Om,
                               lag_range=lag_range,
                               ms=range(len(lag_range)),
                               idx_a=idx_a,
                               idx_b=idx_b))

In [None]:
# Re-display problem size, the fit summary and the Hankel loss of the estimate.
print('(T,p,n)', (T,p,n))
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, _, False, data_path)
loaded_loss = f_l2_Hankel_nl(
    C=pars_est['C'], X=pars_est['X'], R=pars_est['R'],
    Qs=Qs, Om=Om,
    lag_range=lag_range, ms=range(len(lag_range)),
    idx_a=idx_a, idx_b=idx_b,
)
print('loss: ', loaded_loss)

# Fit models

In [None]:
from scipy import linalg as la

# SGD configuration for this fitting run
parametrization = 'nl'   # parametrization handed to run_bad
sso = True               # serial-subset-observation representation

# mini-batch size, epoch size (passed as max_epoch_size) and iteration count
batch_size = 1
max_zip_size = 100
max_iter = 20

# step size and moment parameters -- presumably Adam-style (alpha, beta1,
# beta2, epsilon); confirm in run_bad
a = 0.05
b1, b2 = 0.9, 0.99
e = 1e-8
a_R = a  # step size for R (same as a; not passed to run_bad below)

# row t: [0, normalised principal angles] recorded by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Normalised principal angles between the column spaces of A and B.

    A and B need NOT be orthonormal: both are orthonormalised with
    ``la.orth`` first. 1-D inputs are treated as single column vectors.

    Returns min(rank(A), rank(B)) angles, each divided by pi/2 so that 0 means
    aligned and 1 means orthogonal, ordered from smallest to largest.
    """
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Q_A = la.orth(A)
    Q_B = la.orth(B)
    # Singular values of Q_A^T Q_B are the cosines of the principal angles
    # (descending). Clip at 1 to keep arccos defined under round-off.
    cosines = la.svdvals(Q_A.T.dot(Q_B))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars,t): 
    """Store the principal angles between the true and current C in row t.

    ``pars[0]`` is taken as the current loading matrix C; the row written to
    the global ``proj_errors`` is ``[0, angles...]``.
    """
    C_now = pars[0]
    angles = principal_angle(pars_true['C'], C_now)
    proj_errors[t] = np.hstack((0, angles))
            
_, pars_est, traces, Qs, Om, W, t = run_bad(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_epoch_size=max_zip_size,
    pars_track=pars_track,
)

# summary: fit report, wall-clock time (t) and numerical rank of the estimate
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal-angle trajectories per latent dimension (column 0 is a zero pad)
plt.plot(proj_errors[:, 1:])
plt.show()

In [None]:
from scipy import linalg as la

# SGD configuration for this fitting run
parametrization = 'nl'   # parametrization handed to run_bad
sso = True               # serial-subset-observation representation

# mini-batch size, epoch size (passed as max_epoch_size) and iteration count
batch_size = 10
max_zip_size = 50
max_iter = 50

# step size and moment parameters -- presumably Adam-style (alpha, beta1,
# beta2, epsilon); confirm in run_bad
a = 0.005
b1, b2 = 0.9, 0.99
e = 1e-8
a_R = a  # step size for R (same as a; not passed to run_bad below)

# row t: [0, normalised principal angles] recorded by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Normalised principal angles between the column spaces of A and B.

    A and B need NOT be orthonormal: both are orthonormalised with
    ``la.orth`` first. 1-D inputs are treated as single column vectors.

    Returns min(rank(A), rank(B)) angles, each divided by pi/2 so that 0 means
    aligned and 1 means orthogonal, ordered from smallest to largest.
    """
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Q_A = la.orth(A)
    Q_B = la.orth(B)
    # Singular values of Q_A^T Q_B are the cosines of the principal angles
    # (descending). Clip at 1 to keep arccos defined under round-off.
    cosines = la.svdvals(Q_A.T.dot(Q_B))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars,t): 
    """Store the principal angles between the true and current C in row t.

    ``pars[0]`` is taken as the current loading matrix C; the row written to
    the global ``proj_errors`` is ``[0, angles...]``.
    """
    C_now = pars[0]
    angles = principal_angle(pars_true['C'], C_now)
    proj_errors[t] = np.hstack((0, angles))
            
_, pars_est, traces, Qs, Om, W, t = run_bad(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_epoch_size=max_zip_size,
    pars_track=pars_track,
)

# summary: fit report, wall-clock time (t) and numerical rank of the estimate
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal-angle trajectories per latent dimension (column 0 is a zero pad)
plt.plot(proj_errors[:, 1:])
plt.show()

In [None]:
from scipy import linalg as la

# SGD configuration for this fitting run
parametrization = 'nl'   # parametrization handed to run_bad
sso = True               # serial-subset-observation representation

# mini-batch size, epoch size (passed as max_epoch_size) and iteration count
batch_size = 10
max_zip_size = 100
max_iter = 50

# step size and moment parameters -- presumably Adam-style (alpha, beta1,
# beta2, epsilon); confirm in run_bad
a = 0.001
b1, b2 = 0.9, 0.99
e = 1e-8
a_R = a  # step size for R (same as a; not passed to run_bad below)

# row t: [0, normalised principal angles] recorded by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Normalised principal angles between the column spaces of A and B.

    A and B need NOT be orthonormal: both are orthonormalised with
    ``la.orth`` first. 1-D inputs are treated as single column vectors.

    Returns min(rank(A), rank(B)) angles, each divided by pi/2 so that 0 means
    aligned and 1 means orthogonal, ordered from smallest to largest.
    """
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Q_A = la.orth(A)
    Q_B = la.orth(B)
    # Singular values of Q_A^T Q_B are the cosines of the principal angles
    # (descending). Clip at 1 to keep arccos defined under round-off.
    cosines = la.svdvals(Q_A.T.dot(Q_B))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars,t): 
    """Store the principal angles between the true and current C in row t.

    ``pars[0]`` is taken as the current loading matrix C; the row written to
    the global ``proj_errors`` is ``[0, angles...]``.
    """
    C_now = pars[0]
    angles = principal_angle(pars_true['C'], C_now)
    proj_errors[t] = np.hstack((0, angles))
            
_, pars_est, traces, Qs, Om, W, t = run_bad(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_epoch_size=max_zip_size,
    pars_track=pars_track,
)

# summary: fit report, wall-clock time (t) and numerical rank of the estimate
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal-angle trajectories per latent dimension (column 0 is a zero pad)
plt.plot(proj_errors[:, 1:])
plt.show()

In [None]:
from scipy import linalg as la

# SGD configuration for this fitting run
parametrization = 'nl'   # parametrization handed to run_bad
sso = True               # serial-subset-observation representation

# mini-batch size, epoch size (passed as max_epoch_size) and iteration count
batch_size = 100
max_zip_size = 100
max_iter = 50

# step size and moment parameters -- presumably Adam-style (alpha, beta1,
# beta2, epsilon); confirm in run_bad
a = 0.001
b1, b2 = 0.9, 0.99
e = 1e-8
a_R = a  # step size for R (same as a; not passed to run_bad below)

# row t: [0, normalised principal angles] recorded by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Normalised principal angles between the column spaces of A and B.

    A and B need NOT be orthonormal: both are orthonormalised with
    ``la.orth`` first. 1-D inputs are treated as single column vectors.

    Returns min(rank(A), rank(B)) angles, each divided by pi/2 so that 0 means
    aligned and 1 means orthogonal, ordered from smallest to largest.
    """
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Q_A = la.orth(A)
    Q_B = la.orth(B)
    # Singular values of Q_A^T Q_B are the cosines of the principal angles
    # (descending). Clip at 1 to keep arccos defined under round-off.
    cosines = la.svdvals(Q_A.T.dot(Q_B))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars,t): 
    """Store the principal angles between the true and current C in row t.

    ``pars[0]`` is taken as the current loading matrix C; the row written to
    the global ``proj_errors`` is ``[0, angles...]``.
    """
    C_now = pars[0]
    angles = principal_angle(pars_true['C'], C_now)
    proj_errors[t] = np.hstack((0, angles))
            
_, pars_est, traces, Qs, Om, W, t = run_bad(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_epoch_size=max_zip_size,
    pars_track=pars_track,
)

# summary: fit report, wall-clock time (t) and numerical rank of the estimate
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal-angle trajectories per latent dimension (column 0 is a zero pad)
plt.plot(proj_errors[:, 1:])
plt.show()

In [None]:
principal_angle(A=pars_est['C'], B=pars_true['C'])  # normalised angles, estimate vs truth

In [None]:
from scipy import linalg as la

# SGD configuration for this fitting run
parametrization = 'nl'   # parametrization handed to run_bad
sso = True               # serial-subset-observation representation

# mini-batch size, epoch size (passed as max_epoch_size) and iteration count
batch_size = 100
max_zip_size = 100
max_iter = 20

# step size and moment parameters -- presumably Adam-style (alpha, beta1,
# beta2, epsilon); confirm in run_bad
a = 0.0001
b1, b2 = 0.9, 0.99
e = 1e-8
a_R = a  # step size for R (same as a; not passed to run_bad below)

# row t: [0, normalised principal angles] recorded by pars_track
proj_errors = np.zeros((max_iter, n + 1))
def principal_angle(A, B):
    """Normalised principal angles between the column spaces of A and B.

    A and B need NOT be orthonormal: both are orthonormalised with
    ``la.orth`` first. 1-D inputs are treated as single column vectors.

    Returns min(rank(A), rank(B)) angles, each divided by pi/2 so that 0 means
    aligned and 1 means orthogonal, ordered from smallest to largest.
    """
    A = np.atleast_2d(A).T if A.ndim < 2 else A
    B = np.atleast_2d(B).T if B.ndim < 2 else B
    Q_A = la.orth(A)
    Q_B = la.orth(B)
    # Singular values of Q_A^T Q_B are the cosines of the principal angles
    # (descending). Clip at 1 to keep arccos defined under round-off.
    cosines = la.svdvals(Q_A.T.dot(Q_B))
    return np.arccos(np.minimum(cosines, 1.0)) / (np.pi / 2)
    
def pars_track(pars,t): 
    """Store the principal angles between the true and current C in row t.

    ``pars[0]`` is taken as the current loading matrix C; the row written to
    the global ``proj_errors`` is ``[0, angles...]``.
    """
    C_now = pars[0]
    angles = principal_angle(pars_true['C'], C_now)
    proj_errors[t] = np.hstack((0, angles))
            
_, pars_est, traces, Qs, Om, W, t = run_bad(
    lag_range=lag_range, n=n, y=y, idx_a=idx_a, idx_b=idx_b,
    obs_scheme=obs_scheme, init=pars_est,
    parametrization=parametrization, sso=sso,
    Qs=Qs, Om=Om, W=W,
    alpha=a, b1=b1, b2=b2, e=e, max_iter=max_iter,
    batch_size=batch_size, verbose=verbose, max_epoch_size=max_zip_size,
    pars_track=pars_track,
)

# summary: fit report, wall-clock time (t) and numerical rank of the estimate
print_slim(Qs, Om, lag_range, pars_est, idx_a, idx_b, traces, False, data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])

# principal-angle trajectories per latent dimension (column 0 is a zero pad)
plt.plot(proj_errors[:, 1:])
plt.show()

In [None]:
# Bundle problem settings, data (or its memmap directory) and fit results.
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'snr' : snr,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'x' : x,
             'mmap' : mmap,
             'y' : data_path if mmap else y,
             'pars_true' : pars_true,
             'pars_est' : pars_est,
             'idx_a' : idx_a,
             'idx_b' : idx_b,
             'W' : W,
             'Qs' : Qs,
             'Om' : Om
            }
# np.int was removed in NumPy 1.24; the builtin int is the equivalent.
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run)
# the dict is stored as a 0-d object array under 'arr_0' (needs allow_pickle on load)
np.savez(data_path + file_name, save_dict)

In [None]:
# settings for GROUSE
a_grouse = 0.000005
max_iter_grouse = 20
# Continue training an existing tracker (warm start) unless reset_grouse is
# set; a fresh kernel has no tracker, so one is created in that case.
reset_grouse = False
if reset_grouse or 'tracker' not in globals():
    tracker = Grouse(p, n, a_grouse)

# fit GROUSE
t_grouse = time.time()
print('\n - GROUSE')
tracker.step = a_grouse
ct = 1.
error = np.zeros((max_iter_grouse, n+1))
report_every = max(1, max_iter_grouse // 10)  # avoid modulo by zero for < 10 passes
n_steps = T - np.max(lag_range) - 1
for i in range(max_iter_grouse):
    if verbose and i % report_every == 0:
        print('finished % ' + str((100*i)//max_iter_grouse))
    # one pass over the time points in random order, one sample per update
    for j in np.random.permutation(n_steps):
        tracker.consume(y[j,:].reshape(p,1), obs_scheme.mask[j,:].reshape(p,1))
    # decay the step size as a_grouse / ct, with ct growing by 0.1 per pass
    tracker.step = a_grouse / ct
    ct += 0.1

    # column 0: subspace projection error; remaining columns: principal angles
    error[i] = np.hstack((calc_subspace_proj_error(pars_true['C'], tracker.U), principal_angle(pars_true['C'], tracker.U)))
pars_est_g = {'C' : tracker.U}

print('fitting time was ', time.time() - t_grouse, 's')
print('final proj. error (est.): ', str(error[-1][0]))

plt.plot(error[:,1:])
plt.title('subspace proj. error (GROUSE)')
plt.show()

In [None]:
# Store the GROUSE estimate together with the settings needed to compare it.
save_dict = {'p' : p,
             'n' : n,
             'T' : T,
             'obs_scheme' : obs_scheme,
             'lag_range' : lag_range,
             'pars_true' : pars_true,
             'pars_g' : pars_est_g
            }
# np.int was removed in NumPy 1.24; the builtin int is the equivalent.
file_name = 'p' + str(p) + 'n' + str(n) + 'T' + str(T) + 'snr' + str(int(np.mean(snr)//1)) + '_run' + str(run) + '_GROUSE'
np.savez(data_path + file_name, save_dict)

In [None]:
# plot the fit summary for the latest SSID estimate
plot_slim(Qs, Om, lag_range,
          pars_est, idx_a, idx_b,
          traces, False, data_path)

# Check gradients

In [None]:
from ssidid.SSID_Hankel_loss import f_l2_Hankel_nl, g_l2_Hankel_sgd_nl, g_l2_Hankel_sgd_ln 
from ssidid.SSID_Hankel_loss import g_l2_Hankel_sgd_nl_sso, g_l2_Hankel_sgd_ln_sso 
from ssidid.SSID_Hankel_loss import f_l2_Hankel_comp_Q_Om

##################
# Grab test data #
##################

# Sample 10 distinct time points and a random permutation of all lag indices.
ts1 = np.random.choice(T-kl_, 10, replace=False)
msu = np.random.choice(kl_, kl_, replace=False)

# Flattened (ts, ms) pairs: every sampled time point paired with every lag index.
ms = np.hstack([msu[i] * np.ones_like(ts1) for i in range(len(msu))])
ts = np.hstack([ ts1 for i in range(len(msu))])

# gen system pars
# Random dynamics A = Q diag(D) Q^{-1} built from nr real eigenvalues and
# nc_u complex-conjugate pairs, so A is real.
# NOTE(review): if n - nr is odd, one column of Q stays zero and inv(Q) fails.
nr = 0
nc, nc_u = n - nr, (n - nr)//2
Q, D = np.zeros((n,n), dtype=complex), np.zeros(n, dtype=complex)
# draw real eigenvalues and eigenvectors
D[:nr] = np.linspace(0.8, 0.99, nr)
Q[:,:nr] = np.random.normal(size=(n,nr))
Q[:,:nr] /= np.sqrt((Q[:,:nr]**2).sum(axis=0)).reshape(1,nr)
# draw complex eigenvalues and eigenvectors
# (phases concentrated near 0 via von Mises kappa=1000, moduli uniform in [0, 1))
circs = np.exp(2 * 1j * np.pi * np.random.vonmises(mu=0, kappa=1000, size=nc_u))
scales = np.random.uniform(size=nc_u)
ev_c_r, ev_c_c = scales * np.real(circs), scales * np.imag(circs)
V = np.random.normal(size=(n,n))
for i in range(nc_u):
    # conjugate eigenvector pair from two normalised random real columns
    Vi = V[:,i*2:(i+1)*2] / np.sqrt( np.sum(V[:,i*2:(i+1)*2]**2) )
    Q[:,nr+i], Q[:,nr+nc_u+i] = Vi[:,0]+1j*Vi[:,1], Vi[:,0]-1j*Vi[:,1] 
    D[nr+i], D[nr+i+nc_u] = ev_c_r[i]+1j*ev_c_c[i], ev_c_r[i]-1j*ev_c_c[i]

A = Q.dot(np.diag(D)).dot(np.linalg.inv(Q))
assert np.allclose(A, np.real(A))
A = np.real(A)
B = np.random.normal(size=(n,n))
# X stacks one n x n block per lag; R is a non-negative per-unit noise variance
X, R = np.random.normal(size=(len(lag_range)*n,n)), np.random.normal(size=(p))**2
X_ = X.copy() # solves namespace issue in decoration of loss functions for d/dC and d/dR
C = np.random.normal(size=(p,n))

get_observed = obs_scheme.gen_get_observed()

##################
# Loss functions #
##################

# loss function that the non-weighted stochastic gradient descent actually minimizes
def f_sg(C, X, R, y, ts, ms, get_observed, W):
    """Unweighted stochastic Hankel loss, averaged over the sampled pairs.

    For every pair (ts[i], ms[i]) with lag = lag_range[ms[i]], the squared
    residual between C[a] X_m C[b]^T (+ diag(R) at lag 0) and the outer
    product y[t+lag, a] y[t, b]^T is summed, where a and b are the units
    observed at t+lag and t. W is accepted for signature parity but unused.
    """
    C, X, R = C.reshape(p, n), X.reshape(len(lag_range)*n, n), R.reshape(p)
    total = 0
    for i, t in enumerate(ts):
        m = ms[i]
        lag = lag_range[m]
        a = get_observed(t + lag)
        b = get_observed(t)
        X_m = X[m*n:(m+1)*n, :].copy()
        residual = (C[a, :].dot(X_m).dot(C[b, :].T)
                    + (lag == 0) * np.diag(R)[np.ix_(a, b)]
                    - np.outer(y[t + lag, a], y[t, b]))
        total += (residual**2).sum()
    return 0.5 * total / len(ts)

# loss function for re-weighted stochastic gradient descent
def f_rw(C,X,R, y, ts, ms, get_observed, W):
    """Re-weighted stochastic Hankel loss.

    Each sampled pair (ts[i], ms[i]) with lag m_ = lag_range[ms[i]]
    contributes the squared residual between C[a] X_m C[b]^T (+ diag(R) at
    lag 0) and y[t+m_, a] y[t, b]^T, weighted entry-wise by W[m][a, b];
    a, b are the units observed at t+m_ and t. W[m] must be indexable by
    unit indices (per-unit weights, i.e. the non-sso representation).
    """
    C = C.reshape(p,n)
    X = X.reshape(len(lag_range)*n,n)
    R = R.reshape(p)
    f = 0
    
    for i in range(len(ts)):
        t,m = ts[i],ms[i]
        m_ = lag_range[m]
        a, b = get_observed(t+m_), get_observed(t)
        Xm = X[m*n:(m+1)*n,:].copy()
        # the data term uses the actual lag m_ (as f_sg does), not the index m
        f += (((C[a,:].dot(Xm).dot(C[b,:].T) + \
                (m_==0)* np.diag(R)[np.ix_(a,b)] - \
                np.outer(y[t+m_,a], y[t,b]))**2)*W[m][np.ix_(a,b)]).sum()
    return 0.5*f

def get_WSQsOm(sso):
    """Empirical lagged second moments over the sampled (ts1, msu) pairs.

    For each lag index m in msu and time t in ts1 (lag m_ = lag_range[m]),
    S[m] accumulates outer(y[t+m_, a], y[t, b]) over the units a, b observed
    at t+m_ and t, Om[m] marks those entries, and co-observation counts give
    weights W[m] = 1 / max(count, 1).

    Parameters
    ----------
    sso : bool
        If True, counts/weights are kept per pair of observation groups
        (obs_scheme.idx_grp, W[m] of shape (ng, ng)); otherwise per pair of
        units (W[m] of shape (p, p)).

    Returns
    -------
    W, S, Qs, Om : lists indexed by lag index; Qs[m] is S[m] scaled
        entry-wise by the matching weight.
    """
    n_lags = len(lag_range)
    if sso:

        idx_grp,get_idx_grp = obs_scheme.idx_grp, obs_scheme.gen_get_idx_grp()
        ng = len(idx_grp)
        W = [np.zeros((ng,ng), dtype=int) for m in range(n_lags)]
        for m in msu:            
            m_ = lag_range[m]
            for t in ts1:
                is_, js_ = get_idx_grp(t+m_), get_idx_grp(t)
                W[m][np.ix_(is_,js_)] += 1
            W[m] = 1./(np.maximum(W[m], 1))

        Om = [np.zeros((p,p), dtype=bool) for m in range(n_lags)]
        S  = [np.zeros((p,p)) for m in range(n_lags)]
        Qs  = [np.zeros((p,p)) for m in range(n_lags)]

        for m in msu:
            m_ = lag_range[m]
            for t in ts1:
                a, b = get_observed(t+m_), get_observed(t)
                Om[m][np.ix_(a,b)] = True
                S[m][np.ix_(a,b)] += np.outer(y[t+m_,a], y[t,b])
            # scale each group-pair block of S by its group-level weight
            for i in range(len(idx_grp)):
                for j in range(len(idx_grp)):
                    Qs[m][np.ix_(idx_grp[i], idx_grp[j])] = S[m][np.ix_(idx_grp[i], idx_grp[j])] * W[m][i,j]            
    else:

        W  = [np.zeros((p,p), dtype=int) for m in range(n_lags)]
        S  = [np.zeros((p,p)) for m in range(n_lags)]
        Om = [np.zeros((p,p), dtype=bool) for m in range(n_lags)]
        for m in msu:            
            m_ = lag_range[m]
            for t in ts1:
                a, b = get_observed(t+m_), get_observed(t)
                W[m][np.ix_(a,b)] += 1
                S[m][np.ix_(a,b)] += np.outer(y[t+m_,a], y[t,b])
                Om[m][np.ix_(a,b)] = True
            W[m] = 1./(np.maximum(W[m], 1))
        Qs = [W[m]*S[m] for m in range(len(W))]    
    return W,S,Qs,Om

def decorate_fs_gs(W,S,Qs,Om,parametrization,sso):
    """Build closures for scipy's check_grad over the global test problem.

    For each parameter P in (C, X, R, A, B) this returns four loss functions
    of a flattened P -- f_sg-based, full-Hankel ('_ha'), re-weighted ('_rw')
    and the library loss f_l2_Hankel_nl ('_impl') -- and the library
    gradient gP. All other parameters are read from notebook globals
    (C, X, X_, R, A, B, y, ts, ms, ts1, msu, obs_scheme, p, n, kl_, ...).

    Parameters
    ----------
    W, S, Qs, Om : lists as returned by get_WSQsOm(sso).
    parametrization : 'nl' (X used directly) or 'ln' (X built from A and B).
    sso : if True, the *_sso gradient implementations are used.

    Returns
    -------
    Tuple of 25 callables, in the order unpacked by check_grads.

    Raises
    ------
    ValueError from gC / gR when called with an unknown parametrization.
    """
    g_l2_Hankel_ln = g_l2_Hankel_sgd_ln_sso if sso else g_l2_Hankel_sgd_ln
    g_l2_Hankel_nl = g_l2_Hankel_sgd_nl_sso if sso else g_l2_Hankel_sgd_nl

    def _stack_X(A_arg, B_arg):
        """Stack the blocks A^m B B^T for m = 0..kl_-1 ('ln' parametrization)."""
        return np.vstack([np.linalg.matrix_power(A_arg,m).dot(B_arg.dot(B_arg.T)) for m in range(kl_)])

    def _fixed_X():
        """X held fixed while differentiating with respect to C or R."""
        return X_.copy() if parametrization=='nl' else _stack_X(A, B)

    def _impl_loss(C_arg, X_arg, R_arg):
        """Library loss over all p units and the sampled lag indices msu."""
        return f_l2_Hankel_nl(C_arg,X_arg,R_arg,Qs=Qs,Om=Om,lag_range=lag_range,ms=msu,
                              idx_a=np.arange(p), idx_b=np.arange(p), anb=np.arange(p),
                              idx_Ra=np.arange(p), idx_Rb=np.arange(p))

    def _bad_parametrization():
        """Error for a parametrization without a gradient implementation."""
        return ValueError("parametrization must be 'nl' or 'ln', got %r" % (parametrization,))

    def f_ha(C,X,R, y, ts, ms, get_observed, W):
        # full-Hankel loss on the observed entries Om[m] against the weighted S[m]
        C = C.reshape(p,n)
        X = X.reshape(len(lag_range)*n,n)
        R = R.reshape(p)
        return 0.5*np.sum([np.sum( (C.dot(X[m*n:(m+1)*n,:]).dot(C.T) + (lag_range[m]==0)*np.diag(R) - S[m]*W[m])[Om[m]]**2) for m in range(len(lag_range))]), S, Om

    # --- d/dC ---
    def fC(C):
        return f_sg(C, _fixed_X(), R, y, ts, ms, get_observed, W)
    def fC_ha(C):
        return f_ha(C, _fixed_X(), R, y, ts, ms, get_observed, W)[0]
    def fC_rw(C):
        return f_rw(C, _fixed_X(), R, y, ts, ms, get_observed, W)
    def fC_impl(C):
        return _impl_loss(C.reshape(p,n), _fixed_X(), R)
    def gC(C):
        C = C.reshape(p,n)
        if parametrization == 'nl':
            grad_C,_,_ = g_l2_Hankel_nl(C,X,R,y,lag_range,ts=ts1,ms=msu,
                                        obs_scheme=obs_scheme,W=W)
        elif parametrization == 'ln':
            grad_C,_,_,_ = g_l2_Hankel_ln(C,A,B,R,y,lag_range,ts=ts1,ms=msu,
                                          obs_scheme=obs_scheme,W=W)
        else:
            raise _bad_parametrization()
        return grad_C.reshape(-1)

    # --- d/dX ---
    def fX(X):
        return f_sg(C,X,R, y, ts, ms, get_observed, W)
    def fX_ha(X):
        return f_ha(C,X,R, y, ts, ms, get_observed, W)[0]
    def fX_rw(X):
        return f_rw(C,X,R, y, ts, ms, get_observed, W)
    def fX_impl(X):
        return _impl_loss(C, X.reshape(len(lag_range)*n,n), R)
    def gX(X):
        X = X.reshape(len(lag_range)*n,n)
        _,grad_X,_ = g_l2_Hankel_nl(C,X,R,y,lag_range,ts1,msu,
                                    obs_scheme=obs_scheme,W=W)
        return grad_X.reshape(-1)

    # --- d/dR ---
    def fR(R):
        return f_sg(C, _fixed_X(), R, y, ts, ms, get_observed, W)
    def fR_ha(R):
        return f_ha(C, _fixed_X(), R, y, ts, ms, get_observed, W)[0]
    def fR_rw(R):
        return f_rw(C, _fixed_X(), R, y, ts, ms, get_observed, W)
    def fR_impl(R):
        return _impl_loss(C, _fixed_X(), R.reshape(p))
    def gR(R):
        R = R.reshape(p)
        if parametrization == 'nl':
            _,_,grad_R = g_l2_Hankel_nl(C,X,R,y,lag_range,ts=ts1,ms=msu,
                                        obs_scheme=obs_scheme,W=W)
        elif parametrization == 'ln':
            _,_,_,grad_R = g_l2_Hankel_ln(C,A,B,R,y,lag_range,ts=ts1,ms=msu,
                                          obs_scheme=obs_scheme,W=W)
        else:
            raise _bad_parametrization()
        return grad_R.reshape(-1)

    # --- d/dA (X rebuilt from the trial A and the global B) ---
    def fA(A):
        return f_sg(C, _stack_X(A.reshape(n,n), B), R, y, ts, ms, get_observed, W)
    def fA_ha(A):
        return f_ha(C, _stack_X(A.reshape(n,n), B), R, y, ts, ms, get_observed, W)[0]
    def fA_rw(A):
        return f_rw(C, _stack_X(A.reshape(n,n), B), R, y, ts, ms, get_observed, W)
    def fA_impl(A):
        return _impl_loss(C, _stack_X(A.reshape(n,n), B), R)
    def gA(A):
        A = A.reshape(n,n)
        return g_l2_Hankel_ln(C,A,B,R,y,lag_range,ts1,msu,
                              obs_scheme=obs_scheme,W=W)[1].reshape(-1)

    # --- d/dB (X rebuilt from the global A and the trial B) ---
    def fB(B):
        return f_sg(C, _stack_X(A, B.reshape(n,n)), R, y, ts, ms, get_observed, W)
    def fB_ha(B):
        return f_ha(C, _stack_X(A, B.reshape(n,n)), R, y, ts, ms, get_observed, W)[0]
    def fB_rw(B):
        return f_rw(C, _stack_X(A, B.reshape(n,n)), R, y, ts, ms, get_observed, W)
    def fB_impl(B):
        return _impl_loss(C, _stack_X(A, B.reshape(n,n)), R)
    def gB(B):
        B = B.reshape(n,n)
        return g_l2_Hankel_ln(C,A,B,R,y,lag_range,ts1,msu,
                              obs_scheme=obs_scheme,W=W)[2].reshape(-1)
    
    return (fC, fC_ha, fC_rw, fC_impl, 
            fX, fX_ha, fX_rw, fX_impl, 
            fR, fR_ha, fR_rw, fR_impl, 
            fA, fA_ha, fA_rw, fA_impl, 
            fB, fB_ha, fB_rw, fB_impl,
            gC, gX, gR, gA, gB)

def check_grads(parametrization, sso):
    """Print finite-difference gradient checks for C, X, R, A and B.

    Every library gradient is compared with scipy.optimize.check_grad against
    the library loss ('impl.'). When sso is False it is also compared against
    the three reference losses ('Pr. Er', 'Hankel', 'corr. '). The squared
    norm of each gradient at the current global parameter value is printed.
    """
    W, S, Qs, Om = get_WSQsOm(sso)
    (fC, fC_ha, fC_rw, fC_impl,
     fX, fX_ha, fX_rw, fX_impl,
     fR, fR_ha, fR_rw, fR_impl,
     fA, fA_ha, fA_rw, fA_impl,
     fB, fB_ha, fB_rw, fB_impl,
     gC, gX, gR, gA, gB) = decorate_fs_gs(W, S, Qs, Om, parametrization, sso)

    header = [
        ('m:', np.sort(msu)),
        ('parametrization (affects choice of gradient implementation for d/dC and d/dR) = ', parametrization),
        ('explicit representation as serial subset observations?', sso),
    ]
    for label, value in header:
        print(label, value)
        print('\n')

    # (name, current value, reference losses, library loss, library gradient)
    checks = [
        ('C', C, (fC, fC_ha, fC_rw), fC_impl, gC),
        ('X', X, (fX, fX_ha, fX_rw), fX_impl, gX),
        ('R', R, (fR, fR_ha, fR_rw), fR_impl, gR),
        ('A', A, (fA, fA_ha, fA_rw), fA_impl, gA),
        ('B', B, (fB, fB_ha, fB_rw), fB_impl, gB),
    ]
    ref_labels = ('Pr. Er', 'Hankel', 'corr. ')
    for k, (name, value, ref_losses, impl_loss, grad) in enumerate(checks):
        if k > 0:
            print('\n')
        if not sso:
            for ref_label, ref_loss in zip(ref_labels, ref_losses):
                print('grad %s (%s)' % (name, ref_label),
                      sp.optimize.check_grad(ref_loss, grad, value.reshape(-1)))
        print('grad %s (impl. )' % name,
              sp.optimize.check_grad(impl_loss, grad, value.reshape(-1)))
        print('||d/d%s|| ' % name, np.sum(grad(value.reshape(-1))**2))

In [None]:
check_grads(sso=True, parametrization='nl')  # SSO gradients, 'nl' parametrization

In [None]:

# squared norm of the d/dC part of the SSO gradient on 10 random time points
np.sum(
    g_l2_Hankel_sgd_nl_sso(
        C, X, R, y,
        lag_range=lag_range,
        ts=np.random.choice(T - kl_, 10, replace=False),
        ms=range(len(lag_range)),
        obs_scheme=obs_scheme,
        W=W,
    )[0] ** 2
)

In [None]:
# run the gradient checks with and without the SSO representation
for use_sso in (True, False):
    check_grads(parametrization='nl', sso=use_sso)

# Formula check

- Numerically compare the vectorised gradient implementation for $C$ against a non-vectorised, element-wise evaluation of the analytic formula.

In [None]:
# Compare the vectorised d/dC gradient at a single (t, m) pair against an
# element-wise evaluation of the analytic formula, using the true parameters.
# *_chk names keep the globals C, R, Om, L, p, n (used by the fitting and
# gradient-check cells) from being overwritten.
# NOTE: the element-wise loop costs O(#observed pairs * n^2) Python work, so
# it is only practical for small p.

# t is a TIME index (drawn so that t + m_ stays inside the data); it is a
# length-1 array, so the mask rows below stay 2-D.
t, m = np.random.choice(T - kl_, 1), 0
m_ = m

C_chk, Xm, R_chk = pars_true['C'], pars_true['X'][m*n:(m+1)*n, :], pars_true['R']
p_chk, n_chk = C_chk.shape

# units observed at t+m_ (a), at t (b), and at both (anb)
a, b = obs_scheme.mask[t+m_,:], obs_scheme.mask[t,:]
a, b = np.where(a)[1], np.where(b)[1]
anb = np.intersect1d(a,b)

yf = y[t+m_,a]
yp = y[t,b]

# Om_chk[i, j]: unit i observed at t+m_ and unit j observed at t
Om_chk = np.outer(obs_scheme.mask[t+m_,:], obs_scheme.mask[t,:]).astype(bool)
L_chk = np.outer(y[t+m_,:], y[t,:])

# Element-wise formula. Only rows i and j receive a contribution from pair
# (i, j), so iterating the observed pairs in row-major order yields exactly
# the same additions (in the same order) as looping over every row k.
grad_C = np.zeros_like(C_chk)
for i in range(p_chk):
    for j in np.flatnonzero(Om_chk[i]):
        Ci, Cj = C_chk[i,:], C_chk[j,:]
        # contribution to row i (unit i in the lagged/future role)
        grad_C[i,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L_chk[i,j]*Cj.dot(Xm.T)
        # contribution to row j (unit j in the past role)
        grad_C[j,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L_chk[i,j]*Ci.dot(Xm)
        if i == j and m == 0:
            # the noise term R only enters on the diagonal at lag 0
            grad_C[i,:] += R_chk[i] * (Cj.dot(Xm.T)+Ci.dot(Xm))

print('non-vectorised', grad_C)
grad_C_blunt = grad_C.copy()

# Vectorised formula.
grad_C = np.zeros((p_chk,n_chk))

C___ = C_chk.dot(Xm)   # C X_m
C_tr = C_chk.dot(Xm.T) # C X_m^T

grad_C[a,:] += C_chk[a,:].dot( C_tr[b,:].T.dot(C_tr[b,:]) ) - np.outer(yf,yp.dot(C_tr[b,:]))
grad_C[b,:] += C_chk[b,:].dot( C___[a,:].T.dot(C___[a,:]) ) - np.outer(yp,yf.dot(C___[a,:]))

# (An earlier draft also subtracted a correction for units observed at only
# one of t+m_ and t; it is not part of the gradient checked here.)

if m_==0: 
    grad_C[anb,:] += R_chk[anb].reshape(-1,1)*(C___[anb,:] + C_tr[anb,:])
print('vectorised', grad_C)

print('overlap', anb)

assert np.allclose(grad_C_blunt, grad_C)

plt.imshow(Om_chk, interpolation='None')
plt.title(r'observation scheme ($\Omega$)')
plt.show()