# checking gradients

# generate toy data

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# define problem size
p, n, T = 5, 2, 200            # observed units, latent dimension, time bins
lag_range = np.arange(0,10)    # time lags m used in the lagged-covariance loss
kl_ = np.max(lag_range)+1      # number of trailing time bins reserved for the largest lag

nr = 0 # number of real eigenvalues
snr = (0.0, 0.0)
whiten = True
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.95, 0.90, 0.95

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations
sub_pops = (np.arange(0,p), np.arange(0,p))
obs_pops = np.array([0,1])
obs_time = np.array([T//2,T])
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)
# obs_idx must be the index structure returned by get_subpop_stats,
# not the observation times.
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops, obs_pops=obs_pops,
                               obs_time=obs_time, obs_idx=obs_idx,
                               idx_grp=idx_grp)

# integer count, so the random-mask alternative below can be used as-is
n_obs = int(np.ceil(p * 0.5))
mask = np.zeros((T,p))
for t in range(T):
    # the active subpopulation is the first one whose end time lies after t
    for i in range(len(obs_time)):
        if t < obs_time[i]:
            #mask[t, np.random.choice(p, n_obs, replace=False)] = 1
            mask[t,sub_pops[obs_pops[i]]] = 1
            break

obs_scheme.mask = mask
plt.figure(figsize=(20,10))
plt.imshow(mask.T)
plt.show()

# draw system matrices / data
# NOTE(review): machine-specific absolute path — adjust for local runs
data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'
pars_true, x, y, Qs, idx_a, idx_b = gen_data(p,n,lag_range,T, nr,
                                             eig_m_r, eig_M_r,
                                             eig_m_c, eig_M_c,
                                             mmap, chunksize,
                                             data_path,
                                             snr=snr, whiten=whiten)
# stacked true lagged latent covariances: X = [A^k Pi for k in lag_range]
pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],k).dot(pars_true['Pi']) for k in lag_range])

# get_observed(t) -> indices of the units observed at time bin t
if obs_scheme.mask is not None:
    def get_observed(t):
        """Return the units marked as observed in row t of the explicit mask."""
        return np.where(obs_scheme.mask[t,:])[0]
elif obs_time is None or obs_pops is None or sub_pops is None:
    def get_observed(t):
        """Without an observation scheme, every unit counts as observed."""
        return range(p)
else:
    def get_observed(t):
        """Return the subpopulation whose observation window contains t."""
        i = obs_pops[np.digitize(t, obs_time)]
        return sub_pops[i]

# W[m][i, j] = 1 / (number of time bins in which unit i at t+lag and unit j at
# t are both observed), with the count clamped to at least 1.
W = [np.zeros((p,p), dtype=int) for m in lag_range]
for m, m_ in enumerate(lag_range):
    for t in range(T-kl_):
        a, b = get_observed(t+m_), get_observed(t)
        W[m][np.ix_(a,b)] += 1
    W[m] = 1./np.maximum(W[m], 1)

pars_est = 'default'

In [None]:
# display the per-lag weights: W[m][i, j] = 1 / max(#co-observations of (i, j) at lag lag_range[m], 1)
W

In [None]:
# settings for quick initial SGD fitting phase for our model
batch_size, max_zip_size, max_iter = None, np.inf, 1000
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8   # step size and optimiser constants passed to run_bad
a_R = 1 * a

proj_errors = np.zeros(max_iter)
def pars_track(C,X,R,t):
    """Callback for run_bad: store the subspace projection error of C at iteration t."""
    proj_errors[t] = calc_subspace_proj_error(pars_true['C'], C)

# Explicit start mark: at this point `t` is only a leftover loop index from
# earlier cells, so it cannot serve as the start time.
t_start = time.time()
_, pars_est, traces = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,
                                      batch_size=batch_size,verbose=verbose, max_zip_size=max_zip_size,
                                      pars_track=pars_track, W=W)
t = time.time() - t_start

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])
# loss evaluated at the true C, R with empirical latent (cross-)covariances of x
print('ground-truth error: ', f_l2_Hankel_nl(C=pars_true['C'],
               X=np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]),
               Pi=np.cov(x.T),
               R=pars_true['R'],lag_range=lag_range,Qs=Qs,
               idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b))

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], pars_est['C'])))


In [None]:
# settings for quick initial SGD fitting phase for our model
batch_size, max_zip_size, max_iter = None, np.inf, 1000
a, b1, b2, e = 0.001, 0.9, 0.99, 1e-8   # step size and optimiser constants passed to run_bad
a_R = 1 * a

proj_errors = np.zeros(max_iter)
def pars_track(C,X,R,t):
    """Callback for run_bad: store the subspace projection error of C at iteration t."""
    proj_errors[t] = calc_subspace_proj_error(pars_true['C'], C)

# Explicit start mark: at this point `t` may still hold a previous elapsed time
# or loop index, so it cannot serve as the start time.
t_start = time.time()
_, pars_est, traces = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,
                                      batch_size=batch_size,verbose=verbose, max_zip_size=max_zip_size,
                                      pars_track=pars_track, W=W)
t = time.time() - t_start

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])
# loss evaluated at the true C, R with empirical latent (cross-)covariances of x
print('ground-truth error: ', f_l2_Hankel_nl(C=pars_true['C'],
               X=np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]),
               Pi=np.cov(x.T),
               R=pars_true['R'],lag_range=lag_range,Qs=Qs,
               idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b))

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], pars_est['C'])))


In [None]:
# settings for quick initial SGD fitting phase for our model (smaller step size)
batch_size, max_zip_size, max_iter = None, np.inf, 1000
a, b1, b2, e = 0.0001, 0.9, 0.99, 1e-8   # step size and optimiser constants passed to run_bad
a_R = 1 * a

proj_errors = np.zeros(max_iter)
def pars_track(C,X,R,t):
    """Callback for run_bad: store the subspace projection error of C at iteration t."""
    proj_errors[t] = calc_subspace_proj_error(pars_true['C'], C)

# Explicit start mark: at this point `t` may still hold a previous elapsed time
# or loop index, so it cannot serve as the start time.
t_start = time.time()
_, pars_est, traces = run_bad(lag_range=lag_range,n=n,y=y, Qs=Qs,idx_a=idx_a, idx_b=idx_b,
                                      obs_scheme=obs_scheme,init=pars_est,
                                      alpha_C=a,alpha_R=a_R,b1_C=b1,b2_C=b2,e_C=e,max_iter=max_iter,
                                      batch_size=batch_size,verbose=verbose, max_zip_size=max_zip_size,
                                      pars_track=pars_track, W=W)
t = time.time() - t_start

print_slim(Qs,lag_range,pars_est,idx_a,idx_b,traces,mmap,data_path)
print('fitting time was ', t, 's')
print('rank of final C_est: ', sp.linalg.orth(pars_est['C']).shape[1])
# loss evaluated at the true C, R with empirical latent (cross-)covariances of x
print('ground-truth error: ', f_l2_Hankel_nl(C=pars_true['C'],
               X=np.vstack([np.cov(x[k_:-(kl_+1)+k_, :].T, x[:-(kl_+1), :].T)[:n,n:] for k_ in lag_range]),
               Pi=np.cov(x.T),
               R=pars_true['R'],lag_range=lag_range,Qs=Qs,
               idx_grp=idx_grp,co_obs=co_obs,idx_a=idx_a,idx_b=idx_b))

print('final error (est.): ', traces[0][-1])
print('final proj. error (est.): ', str(calc_subspace_proj_error(pars_true['C'], pars_est['C'])))


# corrected (re-weighted) gradients

In [None]:
##################
# Loss functions #
##################


# loss function that the non-weighted stochastic gradient descent actually minimizes
def f_sg(C, X, R, ts, ms, get_observed, t_ij):
    """Unweighted sample loss minimised by plain stochastic gradient descent.

    For every sampled pair i (time ts[i], lag ms[i]) the observed block of
    C X_m C^T (+ diag(R) when the lag is 0) is compared with the single-sample
    outer product y[t+m] y[t]^T. Half the summed squared error, divided by
    the number of pairs, is returned.

    get_observed(i) must return the (row, column) unit indices for pair i.
    t_ij is accepted only for signature parity with the other losses.
    Relies on module-level p, n, lag_range and y.
    """
    C_mat = np.reshape(C, (p, n))
    X_stack = np.reshape(X, (len(lag_range) * n, n))
    R_diag = np.diag(np.reshape(R, p))

    total = 0
    for idx in range(len(ts)):
        rows, cols = get_observed(idx)
        t = ts[idx]
        m = ms[idx]
        X_lag = X_stack[m * n:(m + 1) * n, :]
        model = C_mat[rows, :].dot(X_lag).dot(C_mat[cols, :].T)
        noise = (m == 0) * R_diag[np.ix_(rows, cols)]
        resid = model + noise - np.outer(y[t + m, rows], y[t, cols])
        total += (resid ** 2).sum()
    return 0.5 * total / len(ts)

# loss function based on L2 error for Hankel cov matrix  
def f_ha(C,X,R, ts, ms, get_observed, t_ij):
    """L2 loss between the model covariances and the averaged Hankel estimate.

    The sampled outer products y[t+m] y[t]^T are first summed per lag into
    p x p matrices and divided elementwise by the counts t_ij[m]. Only
    entries touched by at least one sample enter the squared error, which
    is halved and summed over all lags.

    Relies on module-level p, n, lag_range and y.
    """
    n_lags = len(lag_range)
    C_mat = np.reshape(C, (p, n))
    X_stack = np.reshape(X, (n_lags * n, n))
    R_vec = np.reshape(R, p)

    sums = [np.zeros((p, p)) for _ in range(n_lags)]
    seen = [np.zeros((p, p), dtype=bool) for _ in range(n_lags)]
    for idx in range(len(ts)):
        rows, cols = get_observed(idx)
        t, m = ts[idx], ms[idx]
        cell = np.ix_(rows, cols)
        sums[m][cell] += np.outer(y[t + m, rows], y[t, cols])
        seen[m][cell] = True

    per_lag = []
    for m in range(n_lags):
        model = C_mat.dot(X_stack[m * n:(m + 1) * n, :]).dot(C_mat.T) + (m == 0) * np.diag(R_vec)
        resid = (model - sums[m] / t_ij[m])[seen[m]]
        per_lag.append(np.sum(resid ** 2))
    return 0.5 * np.sum(per_lag)

# loss function for re-weighted stochastic gradient descent
def f_rw(C,X,R, ts, ms, get_observed, t_ij):
    """Re-weighted sample loss targeted by the corrected stochastic gradients.

    Like f_sg, but every squared residual entry (i, j) at lag m is divided by
    t_ij[m][i, j] (the co-observation count), and the result is not averaged
    over the number of sampled pairs.

    Relies on module-level p, n, lag_range and y.
    """
    C_mat = np.reshape(C, (p, n))
    X_stack = np.reshape(X, (len(lag_range) * n, n))
    R_diag = np.diag(np.reshape(R, p))

    total = 0
    for idx in range(len(ts)):
        rows, cols = get_observed(idx)
        t = ts[idx]
        m = ms[idx]
        cell = np.ix_(rows, cols)
        X_lag = X_stack[m * n:(m + 1) * n, :]
        model = C_mat[rows, :].dot(X_lag).dot(C_mat[cols, :].T)
        resid = model + (m == 0) * R_diag[cell] - np.outer(y[t + m, rows], y[t, cols])
        total += ((resid ** 2) / t_ij[m][cell]).sum()
    return 0.5 * total


#############
# Gradients #
#############

# unweighted stochastic gradients w.r.t. C 
def g_C(C, X, R, ts, ms, get_observed, t_ij):
    """Gradient of the unweighted loss f_sg with respect to C (flattened).

    Like the loss functions, C, X and R may be passed flattened; they are
    reshaped to (p, n), (-1, n) and (p,). t_ij is unused and only kept for
    signature parity. Relies on module-level p, n and y.
    """
    C = np.reshape(C, (p, n))
    X = np.reshape(X, (-1, n))
    R = np.reshape(R, p)
    grad_C = np.zeros((p,n))

    for i in range(len(ts)):

        t,m = ts[i], ms[i]
        a,b = get_observed(i)

        Xm = X[m*n:(m+1)*n,:]
        CX = C.dot(Xm)     # row k: c_k X_m
        CXt = C.dot(Xm.T)  # row k: c_k X_m^T
        # derivative through the row factor C[a] (units observed at t+m)
        grad_C[a,:] += C[a,:].dot( CXt[b,:].T.dot(CXt[b,:]) ) - np.outer(y[t+m,a],y[t,b].dot(CXt[b,:]))
        # derivative through the column factor C[b] (units observed at t)
        grad_C[b,:] += C[b,:].dot( CX[a,:].T.dot(CX[a,:]) ) - np.outer(y[t,b],y[t+m,a].dot(CX[a,:]))
        if m == 0:
            # diag(R) only enters the diagonal entries observed on both sides
            anb = np.intersect1d(a,b)
            grad_C[anb,:] += R[anb].reshape(-1,1)*(CX[anb,:] + CXt[anb,:])

    return grad_C.reshape(-1) / len(ts)

# re-weighted stochastic gradients w.r.t. C 
def g_rw_C(C, X, R, ts, ms, get_observed, t_ij):
    """Gradient of the re-weighted loss f_rw with respect to C (flattened).

    Each residual entry (i, j) at lag m is weighted by 1 / t_ij[m][i, j]
    (counts clamped to >= 1 for the cross terms). C, X and R may be passed
    flattened. Relies on module-level p, n and y.
    """
    C = np.reshape(C, (p, n))
    X = np.reshape(X, (-1, n))
    R = np.reshape(R, p)
    grad_C = np.zeros((p,n))

    for i in range(len(ts)):

        a, b = get_observed(i)
        t, m = ts[i], ms[i]

        Xm = X[m*n:(m+1)*n,:]
        CX = C.dot(Xm)     # row k: c_k X_m
        CXt = C.dot(Xm.T)  # row k: c_k X_m^T

        # unit k appears as a row index (observed at t+m)
        for k in a:
            WC = CXt[b,:] / np.maximum(t_ij[m][k,b],1).reshape(-1,1)
            grad_C[k,:] += C[k,:].dot( CXt[b,:].T.dot(WC) )
            grad_C[k,:] -= y[t+m,k] * y[t,b].dot(WC)

        # unit k appears as a column index (observed at t)
        for k in b:
            WC = CX[a,:] / np.maximum(t_ij[m][a,k],1).reshape(-1,1)
            grad_C[k,:] += C[k,:].dot( CX[a,:].T.dot(WC) )
            grad_C[k,:] -= y[t,k] * y[t+m,a].dot(WC)

        if m == 0:
            # diag(R) only enters the diagonal entries observed on both sides
            anb = np.intersect1d(a,b)
            grad_C[anb,:] += (R[anb]/t_ij[m][anb,anb]).reshape(-1,1)*(CX[anb,:] + CXt[anb,:])

    return grad_C.reshape(-1)

# re-weighted stochastic gradients w.r.t. X
def g_rw_X(C, X, R, ts, ms, get_observed, t_ij):
    """Gradient of the re-weighted loss f_rw with respect to the stacked X (flattened).

    Only the n x n block of X belonging to each sampled lag receives a
    gradient. The diag(R) term contributes to the lag-0 block only. C, X and R
    may be passed flattened. Relies on module-level p, n and y.
    """
    C = np.reshape(C, (p, n))
    X = np.reshape(X, (-1, n))
    R = np.reshape(R, p)
    grad_X = np.zeros_like(X)

    for i in range(len(ts)):
        a, b = get_observed(i)
        t, m = ts[i], ms[i]

        lag = slice(m*n, (m+1)*n)
        Xm = X[lag,:]
        for k in a:
            # C[b] with each row j scaled by 1 / t_ij[m][k, j]
            Cb_w = C[b,:] / t_ij[m][k,b].reshape(-1,1)

            S_k = C[b,:].T.dot(Cb_w)
            grad_X[lag,:] += np.outer(C[k,:], C[k,:]).dot(Xm).dot(S_k)

            s_k = y[t,b].dot(Cb_w)
            grad_X[lag,:] -= np.outer(y[t+m,k] * C[k,:], s_k)

        if m == 0:
            anb = np.intersect1d(a,b)
            grad_X[:n,:] += C[anb,:].T.dot( (R[anb]/t_ij[m][anb,anb]).reshape(-1,1)*C[anb,:])

    return grad_X.reshape(-1)

# re-weighted stochastic gradients w.r.t. R
def g_rw_R(C, X, R, ts, ms, get_observed, t_ij):
    """Gradient of the re-weighted loss f_rw with respect to R.

    diag(R) only enters f_rw at lag 0, and only on the diagonal entries
    (k, k) with k observed on both sides, i.e. k in intersect(a, b). For each
    such k the contribution is the residual
    (c_k X_0 c_k^T + R_k - y[t+m, k] * y[t, k]) / t_ij[0][k, k].
    C, X and R may be passed flattened. Relies on module-level p, n and y.
    """
    C = np.reshape(C, (p, n))
    X = np.reshape(X, (-1, n))
    R = np.reshape(R, p)
    X0 = X[:n, :]
    grad_R = np.zeros(p)

    for i in range(len(ts)):
        t, m = ts[i], ms[i]
        if m != 0:
            continue
        a, b = get_observed(i)
        anb = np.intersect1d(a, b)
        if anb.size == 0:
            continue
        # c_k X_0 c_k^T for every k in anb
        quad = np.sum(C[anb, :] * C[anb, :].dot(X0.T), axis=1)
        resid = R[anb] + quad - y[t + m, anb] * y[t, anb]
        grad_R[anb] += resid / t_ij[m][anb, anb]

    return grad_R


In [None]:
n_t = 10
m_max = np.max(lag_range)
# n_t distinct start times, all paired with the same randomly drawn lag; the
# library gradients below are evaluated for the single lag (ms[0],)
ts,ms = np.random.choice(T-m_max, n_t, replace=False), np.random.choice(m_max, 1)[0] * np.ones(n_t,dtype=int)
#ts,ms = np.random.choice(T-m_max, n_t, replace=False), np.random.choice(m_max, n_t)

fully_obs = False
if fully_obs:
    print('assuming fully observed data \n')
    ais = [np.where(obs_scheme.mask[t+m,:])[0] for (t,m) in zip(ts, ms)]
    bis = [np.where(obs_scheme.mask[t,:])[0] for (t,m) in zip(ts, ms)]
else:
    print('assuming partially observed data \n')
    ais = [np.random.choice(p,p//2,replace=False) for (t,m) in zip(ts, ms)]
    bis = [ais[i].copy() if ms[i] == 0 else np.random.choice(p,p//2,replace=False) for i in range(len(ts))]
anbis = [np.intersect1d(ais[i],bis[i]) for i in range(len(ts)) ]
# NOTE(review): ais/bis/anbis are not consumed below — get_observed_i reads the
# mask-based get_observed instead. Confirm which observation pattern is intended.

# random evaluation point for the gradient checks
X, R = np.random.normal(size=(len(lag_range)*n,n)), np.random.normal(size=(p))**2
C = np.random.normal(size=(p,n))

def get_observed_i(i):
    """Return (units observed at ts[i]+ms[i], units observed at ts[i]) for sample i."""
    return get_observed(ts[i] + ms[i]), get_observed(ts[i])

# t_ij[m][i, j]: how often (i, j) is co-observed at lag m in this sample (clamped to >= 1)
t_ij = [np.zeros((p,p), dtype=int) for m in range(len(lag_range))]
for i in range(len(ts)):
    a, b = get_observed_i(i)
    t_ij[ms[i]][np.ix_(a, b)] += 1
t_ij = [np.maximum(t_ij[m], 1) for m in range(len(lag_range))]

W = [1./t_ij[m] for m in range(len(lag_range))]

# decorating: single-argument wrappers as required by scipy.optimize.check_grad
def fC(C):
    return f_sg(C,X,R, ts, ms, get_observed_i, t_ij)
def fC_ha(C):
    return f_ha(C,X,R, ts, ms, get_observed_i, t_ij)
def fC_rw(C):
    return f_rw(C,X,R, ts, ms, get_observed_i, t_ij)
def gC_rw(C):
    return g_rw_C(C, X, R,ts,ms,get_observed_i,t_ij)
def gC(C):
    C = C.reshape(p,n)
    grad_C,_,_ = g_l2_Hankel_sgd(C,X,R,y,lag_range,ts,(ms[0],),get_observed,linear=False, W=W)
    return grad_C.reshape(-1)

def fX(X):
    return f_sg(C,X,R, ts, ms, get_observed_i, t_ij)
def fX_ha(X):
    return f_ha(C,X,R, ts, ms, get_observed_i, t_ij)
def fX_rw(X):
    return f_rw(C,X,R, ts, ms, get_observed_i, t_ij)
def gX_rw(X):
    return g_rw_X(C, X, R, ts, ms, get_observed_i, t_ij)
def gX(X):
    X = X.reshape(len(lag_range)*n,n)
    _,grad_X,_ = g_l2_Hankel_sgd(C,X,R,y,lag_range,ts,(ms[0],),get_observed,linear=False, W=W)
    return grad_X.reshape(-1)

def fR(R):
    return f_sg(C,X,R, ts, ms, get_observed_i, t_ij)
def fR_ha(R):
    return f_ha(C,X,R, ts, ms, get_observed_i, t_ij)
def fR_rw(R):
    return f_rw(C,X,R, ts, ms, get_observed_i, t_ij)
def gR_rw(R):
    return g_rw_R(C, X, R, ts, ms, get_observed_i, t_ij)
def gR(R):
    R = R.reshape(p)
    _,_,grad_R = g_l2_Hankel_sgd(C,X,R,y,lag_range,ts,(ms[0],),get_observed,linear=False, W=W)
    return grad_R.reshape(-1)

print('m:', ms)
print('\n')

# hand-written re-weighted gradients vs. the re-weighted loss they target
# (these need nothing beyond the definitions in this notebook)
print('grad C (rw vs f_rw)', sp.optimize.check_grad(fC_rw, gC_rw, C.reshape(-1)))
print('grad X (rw vs f_rw)', sp.optimize.check_grad(fX_rw, gX_rw, X.reshape(-1)))
print('grad R (rw vs f_rw)', sp.optimize.check_grad(fR_rw, gR_rw, R.reshape(-1)))

print('\n')

# gC / gX / gR call g_l2_Hankel_sgd, which is not imported at the top of the notebook.
# NOTE(review): assumed to live next to run_bad in ssidid.SSID_Hankel_loss — confirm.
from ssidid.SSID_Hankel_loss import g_l2_Hankel_sgd

print('grad C (actual)', sp.optimize.check_grad(fC,    gC, C.reshape(-1)))
print('grad C (Hankel)', sp.optimize.check_grad(fC_ha, gC, C.reshape(-1)))
print('grad C (corr. )', sp.optimize.check_grad(fC_rw, gC, C.reshape(-1)))

print('\n')

print('grad X (actual)', sp.optimize.check_grad(fX,    gX, X.reshape(-1)))
print('grad X (Hankel)', sp.optimize.check_grad(fX_ha, gX, X.reshape(-1)))
print('grad X (corr. )', sp.optimize.check_grad(fX_rw, gX, X.reshape(-1)))

print('\n')

print('grad R (actual)', sp.optimize.check_grad(fR,    gR, R.reshape(-1)))
print('grad R (Hankel)', sp.optimize.check_grad(fR_ha, gR, R.reshape(-1)))
print('grad R (corr. )', sp.optimize.check_grad(fR_rw, gR, R.reshape(-1)))


# formula check

- numerically compare the vectorised gradient implementation against a non-vectorised (element-wise) analytic formula

In [None]:
# one time bin t (drawn so that t + m stays inside the recording) at lag m = 0;
# t is a scalar, so all mask / data rows below are 1-D
t, m = np.random.choice(T - kl_), 0
m_ = m

C, Xm, R = pars_true['C'], pars_true['X'][m*n:(m+1)*n, :], pars_true['R']
p,n = C.shape
grad_C = np.zeros((p,n))
grad_X = np.zeros(pars_true['X'].shape)
grad_R = np.zeros(p)
idx_ct = np.zeros((p,2),dtype=np.int32)

# units observed at t+m (rows, a) and at t (columns, b)
a = np.where(obs_scheme.mask[t+m,:])[0]
b = np.where(obs_scheme.mask[t,:])[0]

anb = np.intersect1d(a,b)
a_ = np.setdiff1d(a,b)
b_ = np.setdiff1d(b,a)

yf = y[t+m_,a]
yp = y[t,b]

# Om[i, j]: entry (i, j) is observed; L: single-sample lagged outer product
Om = np.outer(obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]).astype(bool)
L = np.outer(y[t+m,:], y[t,:])

# element-wise reference: d/dC_k of 0.5 * sum_{(i,j) in Om} (C_i X C_j^T + [i==j] R_i - L_ij)^2
grad_C = np.zeros_like(C)
for k in range(p):
    for i in range(p):
        for j in range(p):

            if Om[i,j]:
                Ci, Cj = C[i,:], C[j,:]
                if k==i and k!=j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                if k==j and k!=i:
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                if k==i and k==j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                    if m ==0:
                        grad_C[k,:] += R[i] * (Cj.dot(Xm.T)+Ci.dot(Xm))

print('non-vectorised', grad_C)
grad_C_blunt = grad_C.copy()
#g_C_l2_Hankel_vector_pair(grad_C, m_, C, Xm, R, a, b, ab, CC_a, CC_b, yp, yf)

# vectorised version of the same gradient
grad_C = np.zeros((p,n))

C___ = C.dot(Xm)   # row k: c_k X_m
C_tr = C.dot(Xm.T) # row k: c_k X_m^T

grad_C[a,:] += C[a,:].dot( C_tr[b,:].T.dot(C_tr[b,:]) ) - np.outer(yf,yp.dot(C_tr[b,:]))
grad_C[b,:] += C[b,:].dot( C___[a,:].T.dot(C___[a,:]) ) - np.outer(yp,yf.dot(C___[a,:]))

# correction for variables not observed both at t+m_ and t
#if a_.size > 0:
#    grad_C[a_,:] -= (np.sum(C[a_,:]*C_tr[a_,:],axis=1) - y[t+m,a_]*y[t,a_]).reshape(-1,1) * C_tr[a_,:]
#if b_.size > 0:
#    grad_C[b_,:] -= (np.sum(C[b_,:]*C___[b_,:],axis=1) - y[t+m,b_]*y[t,b_]).reshape(-1,1) * C___[b_,:]

if m_==0:
    grad_C[anb,:] += R[anb].reshape(-1,1)*(C___[anb,:] + C_tr[anb,:])
print('vectorised', grad_C)

print('overlap', anb)

assert np.allclose(grad_C_blunt, grad_C)

plt.imshow(Om, interpolation='None')
# raw string avoids the invalid '\O' escape; $...$ makes matplotlib render the symbol
plt.title(r'observation scheme ($\Omega$)')
plt.show()