# Checking gradients

# Generate toy data

In [None]:
% matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import glob, os, psutil, time

from ssidid.SSID_Hankel_loss import run_bad, plot_slim, print_slim, f_l2_Hankel_nl
from ssidid.utility import get_subpop_stats, gen_data
from ssidid import ObservationScheme
from subtracking import Grouse, calc_subspace_proj_error

# define problem size
p, n, T = 10, 2, 200
lag_range = np.arange(0,10)
kl_ = np.max(lag_range)+1  # max lag + 1

nr = 0 # number of real eigenvalues
snr = (0.0, 0.0)
whiten = True
# presumably lower/upper bounds on real (_r) and complex (_c) eigenvalue magnitudes -- see gen_data
eig_m_r, eig_M_r, eig_m_c, eig_M_c = 0.90, 0.95, 0.90, 0.95

print('(p,n,k+l,T) = ', (p,n,len(lag_range),T), '\n')

# I/O matter
mmap, chunksize = False, np.min((p,2000))
verbose=True

# create subpopulations (here: two sub-populations that both cover all p units)
sub_pops = (np.arange(0,p), np.arange(0,p))
obs_pops = np.array([0,1])
obs_time = np.array([T//2,T])
obs_idx, idx_grp, co_obs, _, _, _, Om, _, _ = \
    get_subpop_stats(sub_pops=sub_pops, p=p, verbose=False)
# obs_idx is the index structure returned by get_subpop_stats, not the observation times
obs_scheme = ObservationScheme(p=p, T=T,
                               sub_pops=sub_pops, obs_pops=obs_pops,
                               obs_time=obs_time, obs_idx=obs_idx,
                               idx_grp=idx_grp)

n_obs = int(np.ceil(p * 0.5))  # a count of units; only used by the commented-out random mask
mask = np.zeros((T,p))
for t in range(T):
    # time t belongs to the first observation period whose end time obs_time[i] lies after t
    for i in range(len(obs_time)):
        if t < obs_time[i]:
            #mask[t, np.random.choice(p, n_obs, replace=False)] = 1
            mask[t,sub_pops[obs_pops[i]]] = 1
            break

obs_scheme.mask = mask
plt.figure(figsize=(20,10))
plt.imshow(mask.T)
plt.show()

# draw system matrices / data
data_path = '/home/mackelab/Desktop/Projects/Stitching/code/le_stitch/python/fits/compare_vs_grouse/'
pars_true, x, y, Qs, idx_a, idx_b = gen_data(p,n,lag_range,T, nr,
                                             eig_m_r, eig_M_r,
                                             eig_m_c, eig_M_c,
                                             mmap, chunksize,
                                             data_path,
                                             snr=snr, whiten=whiten)
# stacked lagged latent covariances A^k Pi for every k in lag_range
pars_true['X'] = np.vstack([np.linalg.matrix_power(pars_true['A'],k).dot(pars_true['Pi']) for k in lag_range])


# formula check

- numerically compare the implemented (vectorised) gradient against a non-vectorised, element-wise analytic formula

In [None]:
# Draw a lag m and a random time point t such that t + m stays inside [0, T).
m = 0
t = np.random.choice(T - m, 1)  # 1-element array, so mask[t+m, :] below is 2-D
m_ = m

C, Xm, R = pars_true['C'], pars_true['X'][m*n:(m+1)*n, :], pars_true['R']
p,n = C.shape

a,b = obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]
a,b = np.where(a)[1], np.where(b)[1]  # units observed at t+m (a) and at t (b)

anb = np.intersect1d(a,b)  # observed at both times
a_ = np.setdiff1d(a,b)     # observed only at t+m
b_ = np.setdiff1d(b,a)     # observed only at t

yf = y[t+m_,a]
yp = y[t,b]

# Reference: element-wise gradient of
#   0.5 * sum_{(i,j) in Om} (C_i Xm C_j^T + [i==j and m==0] R_i - y_{t+m,i} y_{t,j})^2
Om = np.outer(obs_scheme.mask[t+m,:], obs_scheme.mask[t,:]).astype(bool)
L = np.outer(y[t+m,:], y[t,:])
grad_C = np.zeros_like(C)
for k in range(p):
    for i in range(p):
        for j in range(p):
            if Om[i,j]:
                Ci, Cj = C[i,:], C[j,:]
                if k==i and k!=j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                if k==j and k!=i:
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                if k==i and k==j:
                    grad_C[k,:] += Ci.dot(Xm.dot(np.outer(Cj,Cj)).dot(Xm.T)) - L[i,j]*Cj.dot(Xm.T)
                    grad_C[k,:] += Cj.dot(Xm.T.dot(np.outer(Ci,Ci)).dot(Xm)) - L[i,j]*Ci.dot(Xm)
                    if m ==0:
                        grad_C[k,:] += R[i] * (Cj.dot(Xm.T)+Ci.dot(Xm))

print('non-vectorised', grad_C)
grad_C_blunt = grad_C.copy()

# Vectorised implementation of the same gradient.
C___ = C.dot(Xm)   # rows: C_k Xm
C_tr = C.dot(Xm.T) # rows: C_k Xm^T

grad_C = np.zeros((p,n))
grad_C[a,:] += C[a,:].dot( C_tr[b,:].T.dot(C_tr[b,:]) ) - np.outer(yf,yp.dot(C_tr[b,:]))
grad_C[b,:] += C[b,:].dot( C___[a,:].T.dot(C___[a,:]) ) - np.outer(yp,yf.dot(C___[a,:]))

# correction for variables not observed both at t+m_ and t
#if a_.size > 0:
#    grad_C[a_,:] -= (np.sum(C[a_,:]*C_tr[a_,:],axis=1) - y[t+m,a_]*y[t,a_]).reshape(-1,1) * C_tr[a_,:]
#if b_.size > 0:
#    grad_C[b_,:] -= (np.sum(C[b_,:]*C___[b_,:],axis=1) - y[t+m,b_]*y[t,b_]).reshape(-1,1) * C___[b_,:]

if m_==0:
    # R only enters the diagonal (k, k) entries of units observed at both ends
    grad_C[anb,:] += R[anb].reshape(-1,1)*(C___[anb,:] + C_tr[anb,:])
print('vectorised', grad_C)

print('overlap', anb)

assert np.allclose(grad_C_blunt, grad_C)

plt.imshow(Om, interpolation='None')
plt.title(r'observation scheme (\Omega)')
plt.show()

# Reconstruction errors per time lag

In [None]:
def f_l2_block(C,AmPi,Q,idx_grp,co_obs,idx_a,idx_b,W=None):
    """Hankel reconstruction error on an individual Hankel block.

    For every index group ``i``, compares the model block
    ``C[a] @ AmPi @ C[b].T`` with the matching entries of ``Q``. Here ``a`` are
    the units of ``idx_grp[i]`` contained in ``idx_a`` and ``b`` the units of
    ``co_obs[i]`` contained in ``idx_b``. The squared errors are summed over
    all groups.

    Parameters
    ----------
    C : ndarray, shape (p, n)
        Loading matrix.
    AmPi : ndarray, shape (n, n)
        Latent (lagged) covariance for this block.
    Q : ndarray, shape (len(idx_a), len(idx_b))
        Observed block; rows follow the order of ``idx_a``, columns that of ``idx_b``.
    idx_grp, co_obs : sequences of index arrays
        Row and column unit groups, one entry per group.
    idx_a, idx_b : array_like of int
        Units spanned by the rows / columns of ``Q``.
    W : ndarray, optional
        Weights multiplied element-wise onto the flattened residual of each group.

    Returns
    -------
    float
        Sum of squared (optionally weighted) residuals.
    """
    idx_a, idx_b = np.asarray(idx_a), np.asarray(idx_b)

    err = 0.
    for i, grp in enumerate(idx_grp):
        a_Q = np.isin(idx_a, grp)
        b_Q = np.isin(idx_b, co_obs[i])
        # take units in the order of idx_a / idx_b so C rows line up with Q rows/cols
        a, b = idx_a[a_Q], idx_b[b_Q]

        v = C[a,:].dot(AmPi).dot(C[b,:].T) - Q[np.ix_(a_Q,b_Q)]
        v = v.reshape(-1) if W is None else W.reshape(-1) * v.reshape(-1)

        err += v.dot(v)

    return err

def f_l2_inst(C,Pi,R,Q,idx_grp,co_obs,idx_a,idx_b,W=None):
    """Reconstruction error on the instantaneous covariance.

    For every index group ``i``, compares ``C[a] @ Pi @ C[b].T`` with the matching
    entries of ``Q``, with ``R[k]`` added on every entry whose row unit and column
    unit are the same ``k``. Here ``a`` are the units of ``idx_grp[i]`` within
    ``idx_a`` and ``b`` the units of ``co_obs[i]`` within ``idx_b``. R is
    presumably the per-unit observation-noise variance (TODO confirm).

    Parameters
    ----------
    C : ndarray, shape (p, n)
        Loading matrix.
    Pi : ndarray, shape (n, n)
        Instantaneous latent covariance.
    R : ndarray, shape (p,)
        Per-unit diagonal term.
    Q : ndarray, shape (len(idx_a), len(idx_b)) or None
        Observed instantaneous covariance; rows follow ``idx_a``, columns ``idx_b``.
    idx_grp, co_obs : sequences of index arrays
        Row and column unit groups.
    idx_a, idx_b : array_like of int
        Units spanned by the rows / columns of ``Q``.
    W : ndarray, optional
        Element-wise weights on the flattened residual of each group.

    Returns
    -------
    float
        Sum of squared residuals; 0. if ``Q`` is None.
    """
    err = 0.
    if Q is None:
        return err

    idx_a, idx_b = np.asarray(idx_a), np.asarray(idx_b)
    for i, grp in enumerate(idx_grp):
        a_Q = np.isin(idx_a, grp)
        b_Q = np.isin(idx_b, co_obs[i])
        a, b = idx_a[a_Q], idx_b[b_Q]  # same order as the rows/cols of Q[a_Q, b_Q]

        v = C[a,:].dot(Pi).dot(C[b,:].T) - Q[np.ix_(a_Q,b_Q)]
        # R[k] belongs at (position of k in a, position of k in b) for each shared unit k
        rows, cols = np.nonzero(a[:, None] == b[None, :])
        v[rows, cols] += R[a[rows]]
        v = v.reshape(-1) if W is None else W.reshape(-1) * v.reshape(-1)

        err += v.dot(v)

    return err

def f_l2_Hankel_nl(C,X,Pi,R,lag_range,Qs,idx_grp,co_obs,
                   idx_a=None,idx_b=None,W=None):
    """Return the l2 Hankel reconstruction error separately for every time lag.

    Parameters
    ----------
    C : ndarray, shape (p, n)
        Loading matrix.
    X : ndarray, shape (len(lag_range) * n, n)
        Stacked latent covariances; block m is ``X[m*n:(m+1)*n]``.
    Pi : ndarray, shape (n, n) or None
        Instantaneous latent covariance used for the lag-0 error. If None,
        the lag-0 block of ``X`` is used instead.
    R : ndarray, shape (p,)
        Diagonal term added on the lag-0 block.
    lag_range : sequence
        Only its length is used here.
    Qs : sequence of ndarrays
        Observed covariance blocks, one per lag.
    idx_grp, co_obs : sequences of index arrays
        Row / column unit groups.
    idx_a, idx_b : array_like of int, optional
        Units spanned by rows / columns of every ``Qs[m]``. The defaults are all
        units and ``idx_a``.
    W : ndarray, optional
        Element-wise residual weights, applied to every lag.

    Returns
    -------
    ndarray, shape (len(lag_range),)
        ``err[0]`` is the instantaneous error (with R); ``err[m]`` for m >= 1
        is the error of the m-th Hankel block.
    """
    kl = len(lag_range)
    p,n = C.shape
    idx_a = np.arange(p) if idx_a is None else idx_a
    idx_b = idx_a if idx_b is None else idx_b
    assert (len(idx_a), len(idx_b)) == Qs[0].shape

    Pi = X[:n, :] if Pi is None else Pi

    err = np.zeros(kl)
    err[0] = f_l2_inst(C,Pi,R,Qs[0],idx_grp,co_obs,idx_a,idx_b,W)
    for m in range(1,kl):
        err[m] = f_l2_block(C,X[m*n:(m+1)*n, :],Qs[m],idx_grp,co_obs,idx_a,idx_b,W)

    return err
# Per-lag errors of estimated parameters -- only if an earlier fit produced `pars_est`.
if 'pars_est' in globals():
    pars_est = pars_est.copy()
    print('est. pars \n', f_l2_Hankel_nl(pars_est['C'], pars_est['X'], pars_est['X'][:n,:],
                                         pars_est['R'], lag_range, Qs, idx_grp, co_obs))
else:
    print('no estimated parameters (pars_est) available yet')

# Ground-truth C, R together with empirical latent lagged covariances cov(x_{t+k}, x_t).
pars_est = pars_true.copy()
X_emp = np.vstack([np.cov(x[k_:-kl_+k_, :].T, x[:-kl_, :].T)[:n, n:] for k_ in lag_range])
print('ground-truth pars \n', f_l2_Hankel_nl(pars_est['C'], X_emp, pars_est['X'][:n,:],
                                             pars_est['R'], lag_range, Qs, idx_grp, co_obs))


# Old gradients (not re-weighted by observation counts)

In [None]:
# Sample n_t time points and one lag shared by all samples (drawn from 0 .. m_max-1).
m_max = 4
n_t = 10
ts = np.random.choice(T-m_max, n_t)
ms = np.ones(n_t, dtype=int) * np.random.choice(m_max, 1)[0]

# Units observed at t+m (a) and at t (b) under the observation scheme ...
a = [np.flatnonzero(obs_scheme.mask[t_i+m_i, :]) for t_i, m_i in zip(ts, ms)]
b = [np.flatnonzero(obs_scheme.mask[t_i, :]) for t_i in ts]

# ... replaced by random halves of the population; lag-0 samples reuse the same units.
a = [np.random.choice(p, p//2, replace=False) for _ in ts]
b = [a_i.copy() if m_i == 0 else np.random.choice(p, p//2, replace=False)
     for a_i, m_i in zip(a, ms)]

anb = [np.intersect1d(a_i, b_i) for a_i, b_i in zip(a, b)]
pars_est = pars_true.copy()

# Random evaluation point for the gradient checks (R squared to keep it non-negative).
X = np.random.normal(size=(m_max*n, n))
R = np.random.normal(size=(p)) ** 2
C = np.random.normal(size=(p, n))

def fC(C):
    """Mean squared-residual objective as a function of the flattened loadings C.

    For every sample i the residual is C[a] X_m C[b]^T (+ diag(R) entries when
    m == 0) minus the outer product y[t+m, a] y[t, b]^T. Returns half the sum of
    squared residuals divided by the number of samples.
    """
    C = C.reshape(p, n)
    total = 0
    for t_i, m_i, a_i, b_i in zip(ts, ms, a, b):
        X_m = X[m_i*n:(m_i+1)*n, :]
        model = C[a_i, :].dot(X_m).dot(C[b_i, :].T)
        model = model + (m_i == 0) * np.diag(R)[np.ix_(a_i, b_i)]
        resid = model - np.outer(y[t_i+m_i, a_i], y[t_i, b_i])
        total += (resid ** 2).sum()

    return 0.5 * total / len(ts)

def fC_rw(C):
    """Count-reweighted squared-residual objective in the flattened loadings C.

    Each residual entry (u, v) of a lag-m sample is divided by how many
    sampled pairs with lag m cover (u, v). Returns half the weighted sum.
    """
    C = C.reshape(p, n)

    # counts[m][u, v]: number of lag-m samples whose index sets cover (u, v)
    counts = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for m_i, a_i, b_i in zip(ms, a, b):
        counts[m_i][np.ix_(a_i, b_i)] += 1

    total = 0
    for t_i, m_i, a_i, b_i in zip(ts, ms, a, b):
        X_m = X[m_i*n:(m_i+1)*n, :]
        resid = (C[a_i, :].dot(X_m).dot(C[b_i, :].T)
                 + (m_i == 0) * np.diag(R)[np.ix_(a_i, b_i)]
                 - np.outer(y[t_i+m_i, a_i], y[t_i, b_i]))
        total += (resid ** 2 / counts[m_i][np.ix_(a_i, b_i)]).sum()

    return 0.5 * total

def fC_(C):
    """Hankel-style objective: squared error against per-lag averaged outer products.

    Outer products y[t+m, a] y[t, b]^T are accumulated and averaged per lag.
    The model C X_m C^T (+ diag(R) for m == 0) is compared with the average
    only on entries that were observed at least once.
    """
    C = C.reshape(p, n)
    sums = [np.zeros((p, p)) for _ in range(m_max)]
    counts = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for t_i, m_i, a_i, b_i in zip(ts, ms, a, b):
        block = np.ix_(a_i, b_i)
        sums[m_i][block] += np.outer(y[t_i+m_i, a_i], y[t_i, b_i])
        counts[m_i][block] += 1

    per_lag = []
    for lag in range(m_max):
        observed = counts[lag] > 0
        avg = sums[lag] / np.maximum(counts[lag], 1)
        model = C.dot(X[lag*n:(lag+1)*n, :]).dot(C.T) + (lag == 0) * np.diag(R)
        per_lag.append(np.sum((model - avg)[observed] ** 2))

    return 0.5 * np.sum(per_lag)


def fX(X):
    """The fC objective viewed as a function of the flattened stacked covariances X."""
    X = X.reshape(-1, n)
    total = 0
    for t_i, m_i, a_i, b_i in zip(ts, ms, a, b):
        X_m = X[m_i*n:(m_i+1)*n, :]
        resid = C[a_i, :].dot(X_m).dot(C[b_i, :].T) - np.outer(y[t_i+m_i, a_i], y[t_i, b_i])
        resid = resid + (m_i == 0) * np.diag(R)[np.ix_(a_i, b_i)]
        total += (resid ** 2).sum()

    return 0.5 * total / len(ts)

def fR(R):
    """The fC objective viewed as a function of the diagonal term R.

    R only enters samples with lag 0, on diagonal entries of the residual.
    Uses the globals C, X, y, ts, ms, a, b.
    """
    f = 0
    for i in range(len(ts)):
        Xm = X[ms[i]*n:(ms[i]+1)*n,:]

        f += ((C[a[i],:].dot(Xm).dot(C[b[i],:].T) - np.outer(y[ts[i]+ms[i],a[i]], y[ts[i],b[i]]) + (ms[i]==0)* np.diag(R)[np.ix_(a[i],b[i])])**2).sum()

    return 0.5*f / len(ts)


def gC(C):
    """Analytic gradient of fC with respect to the flattened loadings C.

    For each sample, with residual E = C[a] X_m C[b]^T (+ R on shared diagonal
    entries for m == 0) - y[t+m, a] y[t, b]^T:
    rows a receive E @ C[b] X_m^T and rows b receive E^T @ C[a] X_m.
    """
    C = C.reshape(p,n)
    grad_C = np.zeros((p,n))

    for i in range(len(ts)):
        t_i, m_i = ts[i], ms[i]
        a_i, b_i, ab_i = a[i], b[i], anb[i]
        Xm = X[m_i*n:(m_i+1)*n,:]
        C___ = C.dot(Xm)   # rows C_k X_m
        C_tr = C.dot(Xm.T) # rows C_k X_m^T
        grad_C[a_i,:] += C[a_i,:].dot( C_tr[b_i,:].T.dot(C_tr[b_i,:]) ) - np.outer(y[t_i+m_i,a_i], y[t_i,b_i].dot(C_tr[b_i,:]))
        grad_C[b_i,:] += C[b_i,:].dot( C___[a_i,:].T.dot(C___[a_i,:]) ) - np.outer(y[t_i,b_i], y[t_i+m_i,a_i].dot(C___[a_i,:]))
        if m_i == 0:
            # R sits on the (k, k) entries of units observed at both ends
            grad_C[ab_i,:] += R[ab_i].reshape(-1,1)*(C___[ab_i,:] + C_tr[ab_i,:])
    return grad_C.reshape(-1) / len(ts)

def gX(X):
    """Analytic gradient of fX with respect to the flattened stacked covariances X."""
    X = X.reshape(-1, n)
    grad_X = np.zeros_like(X)
    for t_i, m_i, a_i, b_i, ab_i in zip(ts, ms, a, b, anb):
        rows = slice(m_i*n, (m_i+1)*n)
        C_a, C_b = C[a_i, :], C[b_i, :]
        gram_a = C_a.T.dot(C_a)
        gram_b = C_b.T.dot(C_b)
        proj_future = y[t_i+m_i, a_i].dot(C_a)
        proj_past = y[t_i, b_i].dot(C_b)
        grad_X[rows, :] += gram_a.dot(X[rows, :]).dot(gram_b) - np.outer(proj_future, proj_past)
        if m_i == 0:
            # diagonal R contribution for units observed at both ends
            C_ab = C[ab_i, :]
            grad_X[:n, :] += C_ab.T.dot(R[ab_i].reshape(-1, 1) * C_ab)
    return grad_X.reshape(-1) / len(ts)

def gR(R):
    """Analytic gradient of fR with respect to R.

    Only lag-0 samples contribute. For each such sample, unit k in the overlap
    anb[i] of a[i] and b[i] receives its diagonal residual
    C_k X_0 C_k^T + R_k - y[t, k]^2. Units observed at only one end are not
    touched by R.
    """
    grad_R = np.zeros(p)
    X0 = X[:n,:]
    for i in range(len(ts)):
        if ms[i]==0:
            k = anb[i]
            grad_R[k] += R[k] + np.sum(C[k,:] * C[k,:].dot(X0.T),axis=1) - y[ts[i]+ms[i],k] * y[ts[i],k]
    return grad_R / len(ts)

import scipy.optimize  # ensures sp.optimize is loaded; a bare `import scipy` may not load submodules

print(fC(C), fX(X), fR(R))

print('a, b, a \\cap b \n', a,b,anb)

print('ms \n', ms)

print('numerical gradient errors (sp.optimize.check_grad)')
print('grad C (actual)', sp.optimize.check_grad(fC, gC, C.reshape(-1)))
print('grad C (Hankel)', sp.optimize.check_grad(fC_, gC, C.reshape(-1)))
print('grad C (corr. )', sp.optimize.check_grad(fC_rw, gC, C.reshape(-1)))

print('grad X', sp.optimize.check_grad(fX, gX, X.reshape(-1)))
print('grad R', sp.optimize.check_grad(fR, gR, R.reshape(-1)))

print('function evaluation (actual f , Hankel f, corrected f )')
print(fC(C), fC_(C), fC_rw(C))

# V[i, j]: variance over the sampled time points of the product y[t+m, i] * y[t, j]
prods = y[ts+ms, :][:, :, None] * y[ts, :][:, None, :]
V = prods.var(axis=0)

fC(C), fC_(C) + 0.5 * V.sum()

# Corrected gradients (re-weighted by observation counts)

In [None]:
m_max = 4
n_t = 10
ts,ms = np.random.choice(T-m_max, n_t), np.random.choice(m_max, 1)[0] * np.ones(n_t,dtype=int)

# units observed at t+m (a) and at t (b) under the observation scheme
a = [np.where(obs_scheme.mask[t+m,:])[0] for (t,m) in zip(ts, ms)]
b = [np.where(obs_scheme.mask[t,:])[0] for (t,m) in zip(ts, ms)]

# random halves of the population used below; lag-0 samples use the same units at both ends
ais = [np.random.choice(p,p//2,replace=False) for (t,m) in zip(ts, ms)]
bis = [ais[i].copy() if ms[i] == 0 else np.random.choice(p,p//2,replace=False) for i in range(len(ts))]

# overlaps must be recomputed for the freshly drawn index sets
anbis = [np.intersect1d(ai, bi) for ai, bi in zip(ais, bis)]


def fC(C):
    """Mean squared-residual objective in the flattened loadings C (index sets ais/bis).

    Per sample the residual is C[a] X_m C[b]^T (+ diag(R) entries when m == 0)
    minus y[t+m, a] y[t, b]^T. Returns half the total squared residual divided
    by the number of samples.
    """
    C = C.reshape(p, n)
    total = 0
    for t_i, m_i, a_i, b_i in zip(ts, ms, ais, bis):
        X_m = X[m_i*n:(m_i+1)*n, :]
        model = C[a_i, :].dot(X_m).dot(C[b_i, :].T)
        model = model + (m_i == 0) * np.diag(R)[np.ix_(a_i, b_i)]
        resid = model - np.outer(y[t_i+m_i, a_i], y[t_i, b_i])
        total += (resid ** 2).sum()

    return 0.5 * total / len(ts)

def fC_rw(C):
    """Count-reweighted squared-residual objective in C (index sets ais/bis).

    Every residual entry (u, v) of a lag-m sample is divided by the number of
    lag-m samples covering (u, v); returns half the weighted sum.
    """
    C = C.reshape(p, n)

    counts = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for m_i, a_i, b_i in zip(ms, ais, bis):
        counts[m_i][np.ix_(a_i, b_i)] += 1

    total = 0
    for t_i, m_i, a_i, b_i in zip(ts, ms, ais, bis):
        X_m = X[m_i*n:(m_i+1)*n, :]
        resid = (C[a_i, :].dot(X_m).dot(C[b_i, :].T)
                 + (m_i == 0) * np.diag(R)[np.ix_(a_i, b_i)]
                 - np.outer(y[t_i+m_i, a_i], y[t_i, b_i]))
        total += (resid ** 2 / counts[m_i][np.ix_(a_i, b_i)]).sum()

    return 0.5 * total

def fC_(C):
    """Hankel-style objective against per-lag averaged outer products (ais/bis).

    The model C X_m C^T (+ diag(R) for m == 0) is compared with the averaged
    y[t+m, a] y[t, b]^T only on entries observed at least once.
    """
    C = C.reshape(p, n)
    sums = [np.zeros((p, p)) for _ in range(m_max)]
    counts = [np.zeros((p, p), dtype=int) for _ in range(m_max)]
    for t_i, m_i, a_i, b_i in zip(ts, ms, ais, bis):
        block = np.ix_(a_i, b_i)
        sums[m_i][block] += np.outer(y[t_i+m_i, a_i], y[t_i, b_i])
        counts[m_i][block] += 1

    per_lag = []
    for lag in range(m_max):
        observed = counts[lag] > 0
        avg = sums[lag] / np.maximum(counts[lag], 1)
        model = C.dot(X[lag*n:(lag+1)*n, :]).dot(C.T) + (lag == 0) * np.diag(R)
        per_lag.append(np.sum((model - avg)[observed] ** 2))

    return 0.5 * np.sum(per_lag)



def gC_rw(C):
    """Analytic gradient of the count-reweighted objective fC_rw w.r.t. the flattened C.

    For each sample (times t+m and t, index sets a = ais[i], b = bis[i]) the residual
    E = C[a] X_m C[b]^T (+ diag(R) entries when m == 0) - y[t+m, a] y[t, b]^T
    is divided entry-wise by the number of lag-m samples covering that entry.
    The weighted residual E_w then contributes E_w @ C[b] X_m^T to rows a and
    E_w^T @ C[a] X_m to rows b. The R term is weighted like every other
    residual entry, exactly as in fC_rw.
    """
    C = C.reshape(p,n)
    grad_C = np.zeros((p,n))

    # nC[m][u, v]: number of lag-m samples whose index sets cover (u, v)
    nC = [np.zeros((p,p), dtype=int) for _ in range(m_max)]
    for m, a, b in zip(ms, ais, bis):
        nC[m][np.ix_(a, b)] += 1

    for t, m, a, b in zip(ts, ms, ais, bis):
        Xm = X[m*n:(m+1)*n,:]
        C___ = C.dot(Xm)   # rows C_k X_m
        C_tr = C.dot(Xm.T) # rows C_k X_m^T

        E = C[a,:].dot(Xm).dot(C[b,:].T) - np.outer(y[t+m,a], y[t,b])
        if m == 0:
            E += np.diag(R)[np.ix_(a, b)]
        E /= np.maximum(nC[m][np.ix_(a, b)], 1)

        grad_C[a,:] += E.dot(C_tr[b,:])
        grad_C[b,:] += E.T.dot(C___[a,:])
    return grad_C.reshape(-1)


# Compare the reweighted gradient against each of the three objectives.
for label, objective in (('grad C (actual)', fC),
                         ('grad C (Hankel)', fC_),
                         ('grad C (corr. )', fC_rw)):
    print(label, sp.optimize.check_grad(objective, gC_rw, C.reshape(-1)))