In [1]:
import numpy as np
import scipy.stats as st

import pymc as pm
import arviz as az

import aesara
import aesara.tensor as aet

import aeppl

In [2]:
# Generative (unconditioned) funnel model: θ -> z -> x.
with pm.Model() as funnel:
    # Scalar parameter controlling the scale of the latent field.
    θ = pm.Normal("θ", mu=0, sigma=3)
    # Latent field of 512 elements whose standard deviation is exp(θ / 2).
    z = pm.Normal("z", mu=0, sigma=aet.exp(θ / 2), size=512)
    # Unit-noise observation centred on the latent field.
    x = pm.Normal("x", mu=z, sigma=1)

Generate data for `x` conditioned on some true $\theta$. There are a few ways to do this, as explained in https://github.com/pymc-devs/pymc/discussions/5280

In [3]:
# Compile a forward sampler that takes θ as an input and returns draws of (z, x).
sample_x_z: callable = aesara.function(inputs=[θ], outputs=[z, x])

# Simulate once at the "true" value θ = 0; only the draw of x is kept as data.
_, x_obs = sample_x_z(0)

## Define a regular PyMC model that is conditioned on some observations

In [4]:
# Same funnel structure as before, but with x tied to the simulated data.
# NOTE: this rebinds `funnel`, `θ`, `z` and `x` to the conditioned model.
with pm.Model() as funnel:
    θ = pm.Normal("θ", mu=0, sigma=3)
    z = pm.Normal("z", mu=0, sigma=aet.exp(θ / 2), size=512)
    x = pm.Normal("x", mu=z, sigma=1, observed=x_obs)

## Forward sampling function (for generating `x`)

In [5]:
# Generic version of `aesara.function([θ], [z, x])` that discovers the roles of
# the variables from the model instead of naming them by hand.
model_graph = pm.model_graph.ModelGraph(funnel)

# Parameters: the random variables that have no parents in the model graph.
theta = [rv for rv in funnel.basic_RVs if not model_graph.get_parents(rv)]

# Latent field: every free random variable that is not a parameter.
# A set difference over the model's variables would be shorter, but its
# ordering is unpredictable, so ordered lists are built instead.
latent_field = [rv for rv in funnel.free_RVs if rv not in theta]

# Outputs of the sampler, observed variables first (i.e. x before z).
z_x: list = funnel.observed_RVs + latent_field

sample_x_z: callable = aesara.function(inputs=theta, outputs=z_x)

In [6]:
# Display the model's basic random variables, in the order the model stores them.
funnel.basic_RVs

[θ, z, x]

In [7]:
# Display the sampler's inputs (`theta`) and outputs (`z_x`) to check their order.
theta, z_x

([θ], [x, z])

In [8]:
# Smoke test: simulate at θ = 1 and report the shape of each returned draw.
output_test = sample_x_z(1.)
[draw.shape for draw in output_test]

[(512,), (512,)]

Alternative `sample_x` that only outputs the simulation of the observed variable (`x`)

In [9]:
# Forward sampler restricted to the observed random variable(s).
sample_x: callable = aesara.function(inputs=theta, outputs=funnel.observed_RVs)

# Smoke test at θ = 1: one output per observed variable.
output_test = sample_x(1.)
[draw.shape for draw in output_test]

[(512,)]

## Likelihood function `logP(x,z|θ)`

In [10]:
# Fix the argument order used by every compiled function below: z, x, θ.
ordered_input_var = latent_field + funnel.observed_RVs + theta

# Symbolic value at which each random variable's log-density is evaluated.
# Observed RVs get a fresh placeholder of their own type so that the observation
# can be passed in as a function argument. The placeholder is derived from `var`
# itself (not from the global `x`), so this cell does not rely on a name left
# over from the model cell and stays correct with several observed variables.
# All other RVs reuse the value variable registered in the model.
new_input_var = [
    var.type() if var in funnel.observed_RVs else funnel.rvs_to_values[var]
    for var in ordered_input_var
]

# RV -> value mapping handed to aeppl, plus the value variable(s) of θ alone.
rv_pairs = dict(zip(ordered_input_var, new_input_var))
new_theta = [funnel.rvs_to_values[var] for var in theta]

# One log-density term per value variable; the cells below index this dict by
# the entries of `new_input_var` / `new_theta`.
logprob_tensor: dict = aeppl.factorized_joint_logprob(rv_pairs)

In [11]:
# Likelihood logP(x, z | θ): add up every factor except the one(s) belonging to θ.
likelihood_terms = [
    logprob_tensor[value_var]
    for value_var in new_input_var
    if value_var not in new_theta
]
logpt = aet.sum(likelihood_terms)

# Compiled with the full (z, x, θ) argument list.
compile_logp_fn = aesara.function(inputs=new_input_var, outputs=[logpt])

In [12]:
# Fixed test point: z = 0 and x = 1 for all 512 elements, θ = 0.
z_val = np.zeros(512)
x_val = np.ones(512)

# Value from the compiled aesara function ...
print(compile_logp_fn(z_val, x_val, 0.))

# ... and the same quantity written out with scipy, one term per factor.
logp_z_given_theta = st.norm.logpdf(z_val, loc=0., scale=np.exp(0 / 2)).sum()
logp_x_given_z = st.norm.logpdf(x_val, loc=z_val, scale=1.).sum()
print(logp_z_given_theta + logp_x_given_z)

[array(-1196.993058)]
-1196.9930580015848


### For comparison: the full unnormalized log-posterior (and its gradient)

In [13]:
# Full joint log-density: the likelihood terms plus the prior term of θ.
# NOTE: this rebinds `logpt` and `compile_logp_fn`, which referred to the
# likelihood-only versions up to this point.
logpt = aeppl.joint_logprob(rv_pairs)
compile_logp_fn = aesara.function(inputs=new_input_var, outputs=[logpt])
print(compile_logp_fn(z_val, x_val, 0.))

# Reference value with scipy, evaluated at θ = 0.
logp_theta = st.norm.logpdf(0., loc=0., scale=3.)
logp_z_given_theta = st.norm.logpdf(z_val, loc=0., scale=np.exp(0 / 2)).sum()
logp_x_given_z = st.norm.logpdf(x_val, loc=z_val, scale=1.).sum()
print(logp_theta + logp_z_given_theta + logp_x_given_z)

[array(-1199.01060882)]
-1199.0106088234575


## ∇θ_logLike is the gradient of θ -> logP(x,z|θ)

In [14]:
# Build the likelihood logP(x, z | θ) explicitly instead of reusing `logpt`.
# By this point `logpt` has been rebound to the full joint, which also contains
# the prior term logP(θ); differentiating it would give the posterior gradient,
# which only coincides with the likelihood gradient where the prior gradient
# vanishes (θ = 0).
loglike_tensor = aet.sum(
    [logprob_tensor[var] for var in new_input_var if var not in new_theta]
)

grad_theta_tensor = aesara.grad(
    loglike_tensor,
    wrt=[rv_pairs[var] for var in theta],
    # z and x are plain inputs here; do not back-propagate through them.
    consider_constant=[rv_pairs[var] for var in latent_field + funnel.observed_RVs],
)

# Same (z, x, θ) argument order as the log-density functions. The x term does not
# depend on θ, so tolerate inputs that end up unused in the gradient graph.
compile_grad_theta_fn = aesara.function(
    new_input_var, grad_theta_tensor, on_unused_input="ignore"
)

# At z = 0 and θ = 0 each of the 512 latent elements contributes -1/2, i.e. -256.
compile_grad_theta_fn(np.zeros(512), np.ones(512), 0.)

[array(-256.)]

## Prior function `logP(θ)`

In [15]:
# Prior logP(θ): keep only the factor(s) of the factorized joint that belong to θ.
# (An alternative would be to call aeppl.joint_logprob on
# {var: funnel.rvs_to_values[var] for var in theta}.)
prior_terms = [logprob_tensor[value_var] for value_var in new_theta]
theta_prior_logpt = aet.sum(prior_terms)

# NOTE: this rebinds `grad_theta_tensor`, now holding the gradient of the prior.
grad_theta_tensor = aesara.grad(theta_prior_logpt, wrt=new_theta)
compile_grad_theta_prior_fn: callable = aesara.function(
    inputs=new_theta, outputs=grad_theta_tensor
)

# For θ ~ Normal(0, 3) the gradient is -θ / 9, so -1/9 at θ = 1.
compile_grad_theta_prior_fn(1.)

[array(-0.11111111)]