In [3]:
import numpy as np
import matplotlib.pyplot as plt
import thztools as thz

from numpy import pi
from scipy.optimize import approx_fprime

In [4]:
# Test problem: m noise-free copies of one period of a unit cosine, sampled
# at n equally spaced points t = 0, 1/n, ..., (n-1)/n.
m, n = 2, 16
ts = dt = 1.0 / n  # ts, passed to thz.tdnll below, equals the sample spacing dt
t = dt * np.arange(n)
mu = np.cos(2 * pi * t)
x = np.tile(mu, (m, 1))  # m identical rows, each equal to mu
# Log-variances: exp(-inf) == 0, so only the first variance (exp(0) == 1) is
# nonzero. Which noise term each entry controls is defined by thz.tdnll — confirm.
logv = [0, -np.inf, -np.inf]
# Per-waveform parameters (presumably amplitude and delay; confirm vs thz.tdnll).
a, eta = np.ones(m), np.zeros(m)

In [63]:
def grad_nll_approx(x_, mu_, logv_, a_, eta_, ts_, eps_eta=3e-9):
    """Finite-difference estimate of the gradient of ``thz.tdnll``.

    Each parameter group is perturbed in turn with
    ``scipy.optimize.approx_fprime`` while the other groups stay at the given
    values. The partial gradients are concatenated in the order
    ``(logv, mu, a[1:], eta[1:])``.

    Parameters
    ----------
    x_ : array_like
        Data passed unchanged to ``thz.tdnll``.
    mu_ : array_like
        Signal vector; every entry is differentiated.
    logv_ : array_like
        Log-variance vector; every entry is differentiated. An entry equal to
        ``-inf`` makes the forward difference ``(-inf + h) - (-inf)``
        undefined. scipy then emits a RuntimeWarning, and that component is
        presumably ``nan``.
    a_, eta_ : array_like
        Only ``a_[1:]`` and ``eta_[1:]`` are differentiated. ``a_[0]`` and
        ``eta_[0]`` are held fixed and re-inserted before every NLL
        evaluation.
    ts_ : float
        Passed unchanged to ``thz.tdnll``.
    eps_eta : float, optional
        Finite-difference step for the ``eta`` derivatives. The derivative is
        inaccurate at scipy's default step, hence the smaller default of
        ``3e-9``. All other groups use scipy's default step.

    Returns
    -------
    numpy.ndarray
        1-D array of length
        ``len(logv_) + len(mu_) + (len(a_) - 1) + (len(eta_) - 1)``.
    """
    # Every fix_* flag is set and only element 0 of tdnll's return (the NLL
    # value) is used; the gradient itself is estimated numerically below.
    fix_all = {'fix_logv': True, 'fix_mu': True, 'fix_a': True, 'fix_eta': True}

    def nll(logv_var=logv_, mu_var=mu_, a_var=a_, eta_var=eta_):
        # NLL value with (at most) one parameter group replaced by a trial value.
        return thz.tdnll(
            x_, mu_var, logv_var, a_var, eta_var, ts_, **fix_all
        )[0]

    grad_logv = approx_fprime(logv_, lambda v: nll(logv_var=v))
    grad_mu = approx_fprime(mu_, lambda v: nll(mu_var=v))
    grad_a = approx_fprime(
        a_[1:],
        lambda v: nll(a_var=np.insert(v, 0, a_[0])),
    )
    grad_eta = approx_fprime(
        eta_[1:],
        lambda v: nll(eta_var=np.insert(v, 0, eta_[0])),
        epsilon=eps_eta,
    )
    return np.concatenate((grad_logv, grad_mu, grad_a, grad_eta))

In [64]:
# Finite-difference gradient of the NLL at the test point. The scipy
# RuntimeWarning printed below (inf - inf in the step computation) presumably
# comes from the -inf entries of logv.
grad_nll_approx_val = grad_nll_approx(
    x_=x, mu_=mu, logv_=logv, a_=a, eta_=eta, ts_=ts
)

  dx = ((x0 + h) - x0)
  dx = x[i] - x0[i]  # Recompute dx as exactly representable number.


In [65]:
# Analytic NLL gradient with every parameter group left free (all fix_* flags
# False), for comparison with the finite-difference estimate.
_, grad_nll = thz.tdnll(
    x, mu, logv, a, eta, ts,
    fix_logv=False, fix_mu=False, fix_a=False, fix_eta=False,
)
grad_nll

array([ 1.60000000e+01,  0.00000000e+00,  0.00000000e+00, -2.22044605e-16,
       -2.22044605e-16,  2.22044605e-16,  1.23259516e-32, -3.08148791e-32,
        2.22044605e-16,  0.00000000e+00, -1.18479872e-32,  2.22044605e-16,
       -1.23259516e-32,  0.00000000e+00,  1.11022302e-16,  5.54667824e-32,
       -1.11022302e-16,  2.22044605e-16, -4.77964436e-34, -2.52579384e-16,
        3.77524166e-16])

In [66]:
# Last component of the analytic gradient (assumed to be d/d eta[1], matching
# the (logv, mu, a[1:], eta[1:]) layout of grad_nll_approx — confirm).
grad_nll[-1]

3.7752416578900986e-16

In [67]:
# Finite-difference estimate of the same component. The analytic value above
# is ~4e-16, so the two agree to rounding error.
grad_nll_approx_val[-1]

0.0