In [1]:
import numpy as np

import pymc as pm
import aesara
import aesara.tensor as at

import arviz as az



In [2]:
# Record the library versions this notebook was run with (PyMC, then Aesara).
for module in (pm, aesara):
    print(module.__version__)

4.1.4
2.7.9


In [4]:
# Simulate a linear-regression dataset: y = X @ beta + c + N(0, sigma^2) noise.
# A seeded Generator makes the simulated data -- and therefore every result
# computed from it below -- reproducible under Restart & Run All
# (the previous unseeded global np.random state was not).
rng = np.random.default_rng(42)
c, sigma = 10., 1.5
n_obs = 100
X = rng.normal(size=(n_obs, 2))
beta = np.asarray([-3., 4.])
y = X @ beta + c + rng.normal(0, sigma, size=n_obs)

In [5]:
# Linear-regression model mirroring the simulation: intercept 'c', two slopes
# 'β' and noise scale 'σ', each with a wide prior.
with pm.Model() as m:
    # Named `intercept` rather than `i`: a single-letter handle is easily and
    # silently rebound (e.g. by a loop index in a later cell).
    intercept = pm.Normal('c', 0, 100)
    β = pm.Normal('β', 0, 100, size=2)
    σ = pm.HalfNormal('σ', 10)
    obs = pm.Normal('y', X @ β + intercept, σ, observed=y)
# Sampling is not needed for the logp/gradient/Hessian checks below.
# To fit the model instead:
#   trace = pm.sample(1000)
#   az.plot_trace(trace);

In [6]:
# Evaluate logp and its gradient at the model's initial point using the joint
# logp/dlogp function, which is called on a raveled (flattened) point.
init_point = m.initial_point()

val_and_grad = m.logp_dlogp_function()
val_and_grad.set_extra_values({})  # no extra inputs to supply

# Key the point by value-variable name, in m.value_vars order, before raveling.
ordered_point = dict((vv.name, init_point[vv.name]) for vv in m.value_vars)
q = pm.blocking.DictToArrayBijection.map(ordered_point)

val_and_grad(q)  # alternative call: val_and_grad(q.data)

(array(-397.28626688),
 array([ 9.24825597, -3.12307039,  3.3779111 , 15.6715735 ]))

In [7]:
# same output as:
# Cross-check: the separately compiled logp and dlogp functions, called on the
# point dict, should reproduce the previous cell's (logp, grad) pair.
logp_fn = m.compile_logp()
dlogp_fn = m.compile_dlogp()
(logp_fn(init_point), dlogp_fn(init_point))

(array(-397.28626688),
 array([ 9.24825597, -3.12307039,  3.3779111 , 15.6715735 ]))

In [33]:
# Symbolic product of the d2logp matrix (evaluated at the current point) with
# the flattened point itself, compiled so it can be called on a point dict.
# NOTE(review): pm.aesaraf.hessian-based d2logp may use a negated sign
# convention -- confirm before interpreting signs.
hessian = m.d2logp()
# Take the value variables from the model instead of
# `pm.aesaraf.cont_inputs(hessian)`. cont_inputs orders inputs by graph
# traversal, which is not guaranteed to match; m.value_vars is the order used
# to build `q`. For this model, every value variable is continuous, so this also
# matches d2logp's default variables -- re-check for models with discrete vars.
# NOTE: `vars` shadows the builtin; the name is kept because other cells may
# reference it.
vars = list(m.value_vars)
flat_value_vars = at.concatenate([at.flatten(v) for v in vars], axis=0)
hvp = hessian @ flat_value_vars
hvp_fn = m.compile_fn(hvp)
hvp_fn(init_point)

array([ 42.58979265, -14.38227066,  15.5558555 , 537.29245186])

In [34]:
# same output as:
# Cross-check: the dense d2logp matrix at init_point, multiplied by the flat
# point, should match hvp_fn(init_point) from the previous cell.
d2logp_at_init = m.compile_d2logp()(init_point)
d2logp_at_init @ q.data

array([ 42.58979265, -14.38227066,  15.5558555 , 537.29245186])

In [41]:
# Flatten and replace value (similar to ValueGradFunction in pm.Model):
# substitute every value variable in `hvp` with a reshaped slice of one flat
# input vector, so the compiled function accepts `q.data` directly.
# Use the raveled point's dtype so the input matches the data it is fed.
flatten_var = at.vector(name='flatten_var', dtype=str(q.data.dtype))

# Start offset of each variable's slice: [0, n_0, n_0 + n_1, ...].
sizes = [int(np.prod(shape)) for _, shape, _ in q.point_map_info]
split_point = np.cumsum([0] + sizes)

# Match replacements to value variables by name, rather than relying on
# positional agreement between two separately ordered lists.
value_vars_by_name = {vv.name: vv for vv in m.value_vars}
replacements = {}
for k, (name, shape, _) in enumerate(q.point_map_info):
    flat_slice = flatten_var[split_point[k]:split_point[k + 1]]
    replacements[value_vars_by_name[name]] = at.reshape(flat_slice, shape)

hvp_clone = aesara.clone_replace(hvp, replacements)
hvp_fn2 = aesara.function([flatten_var], [hvp_clone])

In [42]:
# Evaluate the flat-input HVP on q.data and verify it against the point-dict
# version compiled earlier. hvp_fn2 was compiled with a list of outputs, so it
# returns a one-element list.
hvp_from_flat = hvp_fn2(q.data)
np.testing.assert_allclose(hvp_from_flat[0], hvp_fn(init_point))
hvp_from_flat

[array([ 42.58979265, -14.38227066,  15.5558555 , 537.29245186])]