In [1]:
import numpy as np

import pymc as pm
import aesara
import aesara.tensor as at

import arviz as az



In [2]:
print(pm.__version__)
print(aesara.__version__)

4.1.4
2.7.9


In [3]:
c, sigma = 10., 1.5
X = np.random.normal(size=(100, 2))
beta = np.asarray([-3., 4.])
y = X @ beta + c + np.random.normal(0, sigma, size=100)

In [4]:
with pm.Model() as m:
    i = pm.Normal('c', 0, 100)
    β = pm.Normal('β', 0, 100, size=2)
    σ = pm.HalfNormal('σ', 10)
    obs = pm.Normal('y', X @ β + i, σ, observed=y)
    # trace = pm.sample(1000)
# az.plot_trace(trace);

In [5]:
init_point = m.initial_point()

val_and_grad = m.logp_dlogp_function()
val_and_grad.set_extra_values({})
q = pm.blocking.DictToArrayBijection.map({v.name: init_point[v.name] for v in m.value_vars})
val_and_grad(q)  # val_and_grad(q.data)

(array(-392.46791302),
 array([ 9.25599452, -1.72046969,  2.09698287,  6.03486578]))

In [6]:
# same output as:
m.compile_logp()(init_point), m.compile_dlogp()(init_point)

(array(-392.46791302),
 array([ 9.25599452, -1.72046969,  2.09698287,  6.03486578]))

In [7]:
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
value_var = at.concatenate([at.flatten(v) for v in vars], axis=0)
# value_var = [m.rvs_to_values.get(var) for var in m.free_RVs]
hvp = hessian @ value_var
hvp_fn = m.compile_fn(hvp)
hvp_fn(init_point)

array([ 42.62543002,  -7.92305572,   9.65696298, 492.91377277])

In [8]:
# same output as:
m.compile_d2logp()(init_point) @ q.data

array([ 42.62543002,  -7.92305572,   9.65696298, 492.91377277])

In [9]:
# Flatten and replace value (similar to ValueGradFunction in pm.Model)
flatten_var = at.vector(name='flatten_var')
split_point = np.concatenate([
    np.asarray([0]), 
    np.cumsum([
        np.prod(v) 
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)

vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(flatten_var[split_point[i]:split_point[i+1]], v))

hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))
hvp_fn2 = aesara.function([flatten_var], [hvp_clone])

In [10]:
hvp_fn2(q.data)

[array([ 42.62543002,  -7.92305572,   9.65696298, 492.91377277])]

In [11]:
b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

hvp_fn = pm.aesaraf.compile_pymc(vars + [b], [hvp])
hvp_fn(*init_point.values(), q.data)

[array([ 42.62543002,  -7.92305572,   9.65696298, 492.91377277])]

In [12]:
b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    np.asarray([0]), 
    np.cumsum([
        np.prod(v) 
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

hvp_fn = pm.aesaraf.compile_pymc([theta, b], [hvp_clone])
hvp_fn(q.data, q.data)

[array([ 42.62543002,  -7.92305572,   9.65696298, 492.91377277])]

In [13]:
import pandas as pd

data = pd.read_csv(pm.get_data('radon.csv'))
data['log_radon'] = data['log_radon'].astype(aesara.config.floatX)
county_names = data.county.unique()
county_idx = data.county_code.values.astype('int32')

n_counties = len(data.county.unique())

with pm.Model() as m:
    # Hyperpriors for group nodes
    mu_a = pm.Normal('mu_a', mu=0., sigma=100.)
    sigma_a = pm.HalfNormal('sigma_a', 5.)
    mu_b = pm.Normal('mu_b', mu=0., sigma=100.)
    sigma_b = pm.HalfNormal('sigma_b', 5.)

    # Intercept for each county, distributed around group mean mu_a
    # Above we just set mu and sd to a fixed value while here we
    # plug in a common group distribution for all a and b (which are
    # vectors of length n_counties).
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=n_counties)
    # Intercept for each county, distributed around group mean mu_a
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=n_counties)

    # Model error
    eps = pm.HalfCauchy('eps', 5.)

    radon_est = a[county_idx] + b[county_idx] * data.floor.values

    # Data likelihood
    radon_like = pm.Normal('radon_like', mu=radon_est,
                           sigma=eps, observed=data.log_radon)

In [14]:
init_point = m.initial_point()
q = pm.blocking.DictToArrayBijection.map({v.name: init_point[v.name] for v in m.value_vars})

b = at.vector(name='b')
hessian = m.d2logp()
vars = pm.aesaraf.cont_inputs(hessian)
hvp = hessian @ b

hvp_fn = pm.aesaraf.compile_pymc(vars + [b], [hvp])
hvp_fn(*init_point.values(), q.data)

[array([ 0.00000000e+00,  3.21887582e+00,  0.00000000e+00,  3.21887582e+00,
         3.68206010e-01,  5.96446425e+00,  4.21032191e-01,  1.07550628e+00,
         6.60378322e-01,  5.93607912e-01,  3.47250108e+00,  8.50531124e-01,
         1.25762188e+00,  9.43563427e-01,  9.19367227e-01,  9.04300167e-01,
         8.37980965e-01,  3.26088098e+00,  5.27221106e-01,  1.84774355e-01,
         3.86311649e-01,  1.52605298e+00,  1.07742936e+01,  7.01776778e-01,
         1.93979617e+00,  5.12649478e-01,  2.76750434e-01,  2.27094911e+00,
         3.36422278e+00,  1.78534244e+01,  1.20299390e+00,  5.55181070e-01,
         4.21369393e-01,  1.37365814e+00,  1.31104855e+00,  6.53499228e-01,
         1.06909385e+00,  4.47033155e-01,  4.28367023e-01,  6.71126935e-01,
         4.75198372e-01,  7.88802257e-01,  1.04657729e+00,  1.10710247e+00,
         1.94976862e+00,  1.78492376e-01,  1.45519489e+00,  9.08820254e-01,
         1.85387254e+00,  8.02403590e-01,  1.60561684e-01,  1.27950801e+00,
         2.7

In [15]:
# Flatten and replace value (similar to ValueGradFunction in pm.Model)

# theta = at.vector(name='theta')
# b = at.vector(name='b')
# hessian = m.d2logp()
# vars = pm.aesaraf.cont_inputs(hessian)
# hvp = hessian @ b

# split_point = np.concatenate([
#     np.asarray([0]), 
#     np.cumsum([
#         np.prod(v) 
#         for _, v, _ in q.point_map_info
#     ])
# ], axis=-1).astype(int)
# vars_replace = []
# for i, (_, v, _) in enumerate(q.point_map_info):
#     vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
# hvp_clone = aesara.clone_replace(hvp, dict(zip(vars, vars_replace)))

# hvp_fn = pm.aesaraf.compile_pymc([theta, b], [hvp_clone])
# hvp_fn(q.data, q.data)

In [46]:
# Flatten and replace value (similar to ValueGradFunction in pm.Model)
theta = at.vector(name='theta')
split_point = np.concatenate([
    np.asarray([0]), 
    np.cumsum([
        np.prod(v) 
        for _, v, _ in q.point_map_info
    ])
], axis=-1).astype(int)
vars_replace = []
for i, (_, v, _) in enumerate(q.point_map_info):
    vars_replace.append(at.reshape(theta[split_point[i]:split_point[i+1]], v))
logp_clone = aesara.clone_replace(m.logp(), dict(zip(vars, vars_replace)))

# test_fn = pm.aesaraf.compile_pymc([theta], [logp_clone])
# test_fn(q.data)

gradient = aesara.grad(logp_clone, theta)
hessian = -aesara.gradient.jacobian(gradient, theta)
# hessian = aesara.gradient.hessian(logp_clone, theta)

test_fn = aesara.function([theta], [hessian])
test_fn(q.data)

# b = at.vector(name='b')
# vars = m.value_vars

# # hessian = pm.aesaraf.hessian(m.logp(), None)
# def hessian_fn(f, vars=None):
#     if vars is None:
#         vars = pm.aesaraf.cont_inputs(f)

#     gradient = at.concatenate([pm.aesaraf.gradient1(f, v) for v in vars], axis=0)
#     jacobian = at.concatenate([pm.aesaraf.jacobian1(gradient, v) for v in vars], axis=1)
#     return -jacobian

# hessian = hessian_fn(m.logp(), vars=m.value_vars)
# hessian = m.logp()


# # hvp = hessian_clone @ b

# hvp_fn = pm.aesaraf.compile_pymc([theta], [hessian_clone])
# hvp_fn(q.data)

ERROR (aesara.graph.opt): Optimization failure due to: local_IncSubtensor_serialize
ERROR (aesara.graph.opt): node: Elemwise{add,no_inplace}(Elemwise{add,no_inplace}.0, AdvancedIncSubtensor{inplace=False,  set_instead_of_inc=False}.0)
ERROR (aesara.graph.opt): TRACEBACK:
ERROR (aesara.graph.opt): Traceback (most recent call last):
  File "/Users/junpenglao/Documents/OSS/aesara/aesara/graph/opt.py", line 1861, in process_node
    replacements = lopt.transform(fgraph, node)
  File "/Users/junpenglao/Documents/OSS/aesara/aesara/graph/opt.py", line 1066, in transform
    return self.fn(fgraph, node)
  File "/Users/junpenglao/Documents/OSS/aesara/aesara/tensor/subtensor_opt.py", line 1203, in local_IncSubtensor_serialize
    assert mi.owner.inputs[0].type.is_super(tip.type)
AssertionError



AssertionError: 
Apply node that caused the error: for{cpu,scan_fn}(Shape_i{0}.0, Subtensor{int64:int64:int8}.0, Shape_i{0}.0, (d__logp/dtheta), MakeVector{dtype='int64'}.0, eps_log___log, InplaceDimShuffle{x}.0, Elemwise{Composite{(i0 - (i1 + (i2 * i3)))}}[(0, 1)].0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{Mul}[(0, 0)].0, Elemwise{true_div,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{Switch(i0, (i1 / i2), i3)}}.0, Elemwise{Composite{(Switch(i0, ((i1 * i2) / i3), i4) + (i5 / i2) + (i6 / i2))}}[(0, 5)].0, Alloc.0, Elemwise{sqr,no_inplace}.0, Elemwise{Sqr}[(0, 0)].0, Elemwise{mul,no_inplace}.0, Elemwise{neg,no_inplace}.0, Elemwise{neg,no_inplace}.0, MakeVector{dtype='int64'}.0, sigma_b_log___log, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{TrueDiv}[(0, 0)].0, MakeVector{dtype='int64'}.0, MakeVector{dtype='int64'}.0, sigma_a_log___log, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{true_div,no_inplace}.0, Elemwise{switch,no_inplace}.0, Elemwise{Composite{(Switch(i0, (i1 * i2), i3) + (i4 / i2) + (i5 / i2))}}[(0, 4)].0, Elemwise{sqr,no_inplace}.0, Elemwise{mul,no_inplace}.0, Elemwise{TrueDiv}[(0, 0)].0, MakeVector{dtype='int64'}.0)
Toposort index: 112
Inputs types: [TensorType(int64, ()), TensorType(int64, (None,)), TensorType(int64, ()), TensorType(float64, (None,)), TensorType(int64, (1,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(int64, (1,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(int64, (1,)), TensorType(int64, (1,)), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(float64, (None,)), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, ()), TensorType(float64, (1,)), TensorType(float64, (None,)), TensorType(int64, (1,))]

HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.