In [1]:
import numpy as np
import scipy.stats as st

import pymc as pm
import arviz as az

import aesara
import aesara.tensor as at

In [2]:
# Which model variant `gen_model` builds: True -> one parameter θ with a single latent z;
# False -> three parameters (θ0, θ1, σ) with two latent groups (z0, z1).
SIMPLE_EXAMPLE = False

In [3]:
def gen_model(SIMPLE_EXAMPLE, x_obs=None):
    """Build the funnel-style hierarchical model used throughout this notebook.

    Parameters
    ----------
    SIMPLE_EXAMPLE : bool
        True: θ ~ Normal(0, 3), z[512] ~ Normal(0, exp(θ / 2)), x ~ Normal(z, 1).
        False: θ0 ~ Normal(0, 3), z0[256] ~ Normal(0, exp(θ0 / 2)),
        θ1 ~ HalfNormal(5), z1[128, 2] ~ Normal(0, θ1), σ ~ HalfNormal(1),
        and x ~ Normal(stack([z0, flatten(z1)]), σ).
    x_obs : array-like or None
        Passed as ``observed=`` for ``x``. With the default ``None`` the model
        is built without data (used below for forward simulation).

    Returns
    -------
    pm.Model
    """
    with pm.Model() as model:
        if SIMPLE_EXAMPLE:
            θ = pm.Normal("θ", 0, 3)
            z = pm.Normal("z", 0, at.exp(θ / 2), size=512)
            pm.Normal("x", z, 1, observed=x_obs)
        else:
            θ0 = pm.Normal("θ0", 0, 3)
            z0 = pm.Normal("z0", 0, at.exp(θ0 / 2), size=256)
            θ1 = pm.HalfNormal("θ1", 5)
            z1 = pm.Normal("z1", 0, θ1, size=(128, 2))
            σ = pm.HalfNormal("σ", 1)
            # Both latent groups hold 256 values; stacking gives a (2, 256) mean.
            x_mean = at.stack([z0, at.flatten(z1)])
            pm.Normal("x", x_mean, σ, observed=x_obs)
    return model


m = gen_model(SIMPLE_EXAMPLE)

Generate data for `x` conditioned on some true $\theta$. There are a few ways to do this, as explained in https://github.com/pymc-devs/pymc/discussions/5280

In [4]:
# Simulate the "observed" data by forward-sampling `x` at fixed, known parameter values.
if SIMPLE_EXAMPLE:
    theta_inputs, theta_true = [m.θ], [1.0]
else:
    theta_inputs, theta_true = [m.θ0, m.θ1, m.σ], [0.0, 1.0, 2.0]
sample_x = aesara.function(theta_inputs, [m.x])
(x_obs,) = sample_x(*theta_true)

In [5]:
# Shape of the simulated observation ((2, 256) for the non-simple model).
x_obs.shape

(2, 256)

## Define a regular PyMC model conditioned on the observation

In [6]:
# Same model structure as `m`, but with `x` observed at the simulated data.
funnel = gen_model(SIMPLE_EXAMPLE, x_obs)

## Forward sampling function (for generating `x`)

In [7]:
# Split the model's random variables into global parameters θ and the latent field z.
model_graph = pm.model_graph.ModelGraph(funnel)
# θ: free (unobserved) variables with no parent in the model graph.
# Selecting from `free_RVs` rather than `basic_RVs` keeps an observed variable
# that happens to have no parents from being misclassified as a parameter.
theta = [var for var in funnel.free_RVs if model_graph.get_parents(var) == set()]
# z: every remaining free variable.
latent_field = [var for var in funnel.free_RVs if var not in theta]

In [8]:
# All random variables registered on the model, the observed `x` included.
funnel.basic_RVs

[θ0, z0, θ1, z1, σ, x]

In [9]:
# Inspect the θ / z split computed above.
theta, latent_field

([θ0, θ1, σ], [z0, z1])

In [10]:
# Forward sampler (x, z) | θ. The output list is written out explicitly (rather than
# built from a set difference) so that the order of the outputs is deterministic.
z_x = [*funnel.observed_RVs, *latent_field]
sample_x_z = aesara.function(theta, z_x)

# Smoke test: evaluate each θ random variable, then simulate (x, z) given those values.
theta_val = [rv.eval() for rv in theta]
output_test = sample_x_z(*theta_val)
[out.shape for out in output_test]

[(2, 256), (256,), (128, 2)]

Alternative `sample_x` that only outputs simulations of the observed variable (`x`)

In [11]:
# Simulator that returns only the observed variable(s) given θ.
sample_x = aesara.function(theta, funnel.observed_RVs)
theta_val = [rv.eval() for rv in theta]
output_test = sample_x(*theta_val)
list(map(np.shape, output_test))

[(2, 256)]

## Likelihood function `logP(x,z|θ)` and Prior function `logP(θ)`

In [12]:
from pymc.distributions import logpt as joint_logpt

# Copy and small modification of self.logp_elemwiset in a pm.Model
def generate_logpt_allnodes(model, vars: list, jacobian: bool = True):
    """Build one summed log-probability graph per requested model variable.

    Adapted from ``pm.Model.logp_elemwiset``. Each variable's element-wise
    log-probability is summed to a scalar instead of being returned
    element-wise.

    Parameters
    ----------
    model : pm.Model
    vars : list
        Random variables of ``model`` (free or observed). The returned dicts
        follow this order.
    jacobian : bool
        Forwarded to ``joint_logpt``: whether to include the Jacobian term of
        transformed (e.g. log-transformed) variables.

    Returns
    -------
    logpt_nodes : dict
        ``var -> scalar tensor`` named ``"<var.name>_logpt"``.
    rv_values : dict
        ``var -> value variable`` that the log-probability is a function of.
        Free variables use the model's own (possibly transformed) value
        variable. Observed variables get a fresh symbolic input of the same
        type, so data can be supplied at call time.

    Raises
    ------
    Exception
        If the model contains potentials (not handled here).
    ValueError
        If a requested variable has no value variable in ``model``.
    """
    if model.potentials:
        raise Exception("Does not work with model that contains potentials")

    rv_values = {}
    for var in vars:
        if var in model.observed_RVs:
            # Fresh symbolic input in place of the model's constant data.
            value_var = var.type()
            value_var.name = var.name
        else:
            # `.get` so an unknown variable reaches the ValueError below
            # instead of surfacing as a bare KeyError.
            value_var = model.rvs_to_values.get(var)
        if value_var is None:
            raise ValueError(
                f"Requested variable {var} not found among the model variables"
            )
        rv_values[var] = value_var

    rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    logpt_nodes = {}
    for k, logp in zip(rv_values.keys(), rv_logps):
        node_logp = logp.sum()
        node_logp.name = k.name + "_logpt"
        logpt_nodes[k] = node_logp

    return logpt_nodes, rv_values

In [13]:
# keep the order of z, x, θ
ordered_input_var: list = latent_field + funnel.observed_RVs + theta

logpt_nodes, rv_values = generate_logpt_allnodes(funnel, ordered_input_var)
input_var = [rv_values[v] for v in ordered_input_var]

In [14]:
# One summed log-probability graph per variable.
logpt_nodes

{z0: z0_logpt,
 z1: z1_logpt,
 x: x_logpt,
 θ0: θ0_logpt,
 θ1: θ1_logpt,
 σ: σ_logpt}

In [15]:
# Value variables the graphs depend on; bounded variables appear in transformed (log) space.
rv_values

{z0: z0, z1: z1, x: x, θ0: θ0, θ1: θ1_log__, σ: σ_log__}

In [16]:
# Some testing
compile_logp_fn_per_node = aesara.function(
    input_var, [logpt_nodes[v] for v in ordered_input_var])
test_point = funnel.recompute_initial_point()

z_val = [test_point[rv_values[v].name] for v in latent_field]
θ_val = [test_point[rv_values[v].name] for v in theta]

x_val = [funnel.rvs_to_values[v].data for v in funnel.observed_RVs]

print(compile_logp_fn_per_node(*z_val, *x_val, *θ_val))
print(ordered_input_var)
print("\nFrom PyMC model itself")
print(funnel.point_logps(test_point))

[array(-235.2482645), array(-647.26437008), array(-1710.18395397), array(-2.01755082), array(-0.72579135), array(-0.72579135)]
[z0, z1, x, θ0, θ1, σ]

From PyMC model itself
θ0      -2.02
z0    -235.25
θ1      -0.73
z1    -647.26
σ       -0.73
x    -1710.18
Name: Point log-probability, dtype: float64


In [17]:
# Split the per-node terms: logP(x, z | θ) collects every non-θ node, logP(θ) the θ nodes.
logpt_z_x = [logpt_nodes[rv] for rv in ordered_input_var if rv not in theta]
theta_nodes = [rv for rv in ordered_input_var if rv in theta]
logpt_theta = [logpt_nodes[rv] for rv in theta_nodes]
input_theta = [rv_values[rv] for rv in theta_nodes]

# Conditional log-likelihood logP(x, z | θ), a function of all inputs (z, x, θ).
condition_logpt = at.sum(logpt_z_x)
compile_logp_fn = aesara.function(input_var, [condition_logpt])
# Prior log-density logP(θ), a function of θ only.
theta_logpt = at.sum(logpt_theta)
compile_logp_fn_theta = aesara.function(input_theta, [theta_logpt])

# Check: likelihood + prior must reproduce the model's own log-probability at test_point.
joint_from_parts = sum(
    compile_logp_fn(*z_val, *x_val, *θ_val) + compile_logp_fn_theta(*θ_val)
)
np.testing.assert_almost_equal(joint_from_parts, funnel.logp(test_point))

## ∇θ_logLike and ∇θ_logPrior(θ)

∇θ_logLike is gradient of θ -> logP(x,z|θ)

In [18]:
# ∇θ logP(x, z | θ): differentiate the conditional log-likelihood w.r.t. the θ inputs,
# holding every other input (z and x) constant.
non_theta_inputs = [v for v in input_var if v not in input_theta]
grad_theta_logpt = aesara.grad(
    condition_logpt,
    wrt=input_theta,
    consider_constant=non_theta_inputs,
)
compile_grad_theta_fn = aesara.function(input_var, grad_theta_logpt)
compile_grad_theta_fn(*z_val, *x_val, *θ_val)

[array(-128.), array(-256.), array(1967.37484993)]

∇θ_logPrior(θ) is gradient of θ -> logPrior(θ)

In [19]:
# ∇θ logP(θ): gradient of the prior log-density.
grad_theta_prior = aesara.grad(theta_logpt, wrt=input_theta)
# Compile under its own name: reusing `compile_grad_theta_fn` here would silently
# replace the likelihood gradient compiled in the previous cell.
compile_grad_theta_prior_fn = aesara.function(input_theta, grad_theta_prior)
compile_grad_theta_prior_fn(*θ_val)

[array(-0.), array(1.11022302e-16), array(0.)]

## zMAP that maximizes the function z -> logP(x,z|θ)

We can use [scipy.optimize.minimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize) to find the MAP.
We have logP(x,z|θ) from above and can differentiate it with respect to z. To make this easy to use with scipy `minimize`, we do some additional formatting:
- `cost_fun` returns a tuple (f, g) containing the objective value and its gradient (so `minimize` can be called with `jac=True`)
- `z` is flattened into a 1-D array with shape (n,)

For similar logic and API usage, see [find_MAP in PyMC](https://github.com/pymc-devs/pymc/blob/1d7130d8cf6e419120e192f8308cf154f4c44074/pymc/tuning/starting.py#L148-L150) and [ValueGradFunction](https://github.com/pymc-devs/pymc/blob/5626a04a1e064ad615e1765e37bcb7ea52887ab7/pymc/model.py#L358)

In [20]:
# Split the value variables into the optimisation target z and the fixed inputs (x, θ),
# both kept in `ordered_input_var` order.
input_z = [rv_values[rv] for rv in ordered_input_var if rv in latent_field]
input_x_theta = [rv_values[rv] for rv in ordered_input_var if rv not in latent_field]

# scipy works on one flat vector: expose z as a single 1-D input and rebuild each z
# block from a slice of it (same idea as ValueGradFunction in pm.Model).
flatten_z = at.vector(name='flatten_z')
split_point = np.concatenate([np.asarray([0]), np.cumsum([v.size for v in z_val])], axis=-1)
z_replace = [
    at.reshape(flatten_z[start:stop], block.shape)
    for start, stop, block in zip(split_point[:-1], split_point[1:], z_val)
]

# flat vector -> list of z arrays, and the flattened initial z
mapping_fn = aesara.function([flatten_z], z_replace)
flatten_z_val = np.concatenate([block.ravel() for block in z_val], axis=-1)

# Objective for scipy.optimize.minimize: -logP(x, z | θ) and its gradient w.r.t. the flat z.
condition_logpt_clone = -1.0 * aesara.clone_replace(
    condition_logpt, dict(zip(input_z, z_replace)))
grad_z_clone_tensor = aesara.grad(
    condition_logpt_clone, wrt=flatten_z, consider_constant=input_x_theta)
cost_fun_with_grad = aesara.function(
    [flatten_z, *input_x_theta],
    [condition_logpt_clone, grad_z_clone_tensor])
value_test, grad_test = cost_fun_with_grad(flatten_z_val, *x_val, *θ_val)

In [21]:
# Some testing
grad_z_tensor = aesara.grad(
    -condition_logpt,
    wrt=input_z, 
    consider_constant=input_x_theta
)

output_tensors = [-condition_logpt] + grad_z_tensor
cost_fun_with_grad_ = aesara.function(input_var, output_tensors)
value_test2, *grad_test2 = cost_fun_with_grad_(*z_val, *x_val, *θ_val)

assert value_test == value_test2
_ = [np.testing.assert_almost_equal(v1, v2) for v1, v2 in zip(mapping_fn(grad_test), grad_test2)]

### Using PyMC idioms

However, neither approach below has exactly the input -> output signature we want. Both take only `z` as an explicit input and hold `x` and θ fixed.

In [22]:
# Same objective via PyMC helpers: x and θ become shared variables, and the
# objective is compiled as a function of one joined input array (`inarray0`).
# `test_point` is passed to join_nonshared_inputs and read for x below, so it must
# contain the observed data. NOTE: this mutates `test_point`; the next cell relies on it.
for obs_rv, obs_data in zip(funnel.observed_RVs, x_val):
    test_point[obs_rv.name] = obs_data

# Hold x and θ fixed as shared variables initialised from test_point.
# (Alternative: pm.make_shared_replacements(test_point, input_z, funnel).)
shared = {
    var: aesara.shared(test_point[var.name],
                       var.name + "_shared",
                       broadcastable=var.broadcastable)
    for var in input_x_theta
}
out_list, inarray0 = pm.join_nonshared_inputs(
    test_point, [-condition_logpt] + grad_z_tensor, input_var, shared)
cost_fun_with_grad_ = pm.aesaraf.compile_pymc([inarray0], out_list)
# Skip input validation/conversion on each call.
cost_fun_with_grad_.trust_input = True
value_test3, *grad_test3 = cost_fun_with_grad_(flatten_z_val)

# Separately compiled graphs: compare with a tolerance.
np.testing.assert_almost_equal(value_test2, value_test3)
assert len(grad_test2) == len(grad_test3)
for ref_block, pm_block in zip(grad_test2, grad_test3):
    np.testing.assert_almost_equal(ref_block, pm_block)

In [23]:
# Same objective again via pm.model.ValueGradFunction, with x and θ as fixed "extra"
# values. Relies on `test_point` containing the observed data (added in the previous cell).
extra_vars_and_values = {
    var: test_point[var.name]
    for var in input_x_theta
}
pm_val_grad_fn = pm.model.ValueGradFunction([-condition_logpt], input_z, extra_vars_and_values)
pm_val_grad_fn.set_extra_values(test_point)
value_test4, grad_test4 = pm_val_grad_fn(z_val)

# Check both outputs of this cell against the hand-built flat objective.
np.testing.assert_almost_equal(value_test, value_test4)
np.testing.assert_almost_equal(grad_test, grad_test4)

Once we have a logp function that also outputs its gradient, optimizing it is straightforward.

In [24]:
from scipy.optimize import minimize
def zmap_optimization(
    cost_fun_with_grad: callable,
    initial_z: list,
    x: list,
    theta: list,
    random_init=True,
    method='L-BFGS-B',
    *,
    rng=None,
    **kwargs):
    """Find z_MAP = argmax_z logP(x, z | θ) by minimising a negative log-likelihood.

    Parameters
    ----------
    cost_fun_with_grad : callable
        ``f(flat_z, *x, *theta) -> (value, gradient)`` of the objective to
        minimise (here -logP(x, z | θ)); the gradient is w.r.t. ``flat_z``.
    initial_z : list of np.ndarray
        Starting values of each latent block. They are flattened and
        concatenated, and fix the size of the search space even when
        ``random_init`` is used.
    x, theta : list
        Extra positional arguments forwarded to ``cost_fun_with_grad`` after
        ``flat_z``.
    random_init : bool
        If true, start from i.i.d. standard-normal draws instead of ``initial_z``.
    method : str
        Passed to ``scipy.optimize.minimize``.
    rng : np.random.Generator, int or None
        Random source for ``random_init``. ``None`` keeps the previous
        behaviour of drawing from the global ``np.random`` state. Pass a seed
        or Generator for reproducible starting points.
    **kwargs
        Forwarded to ``scipy.optimize.minimize``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        ``result.x`` is the flat z_MAP.
    """
    x0 = np.concatenate([v.ravel() for v in initial_z], axis=-1)
    if random_init:
        if rng is None:
            x0 = np.random.randn(*x0.shape)
        else:
            x0 = np.random.default_rng(rng).standard_normal(x0.shape)
    # jac=True: the objective returns (value, gradient) from a single call.
    return minimize(
        cost_fun_with_grad, x0, args=(*x, *theta), method=method, jac=True, **kwargs
    )

# z_MAP for the observed data at the initial-point θ, starting from the initial z.
output = zmap_optimization(cost_fun_with_grad, z_val, x_val, θ_val, random_init=False)

In [25]:
# Unflatten the optimiser's solution back into the list of z arrays.
zmap_val = mapping_fn(output.x)

In [26]:
# Recompile ∇θ logP(x, z | θ) so this name refers to the likelihood gradient (an
# earlier cell may have rebound it), then evaluate it at z_MAP for the observed data.
compile_grad_theta_fn = aesara.function(inputs=input_var, outputs=grad_theta_logpt)
compile_grad_theta_fn(*zmap_val, *x_val, *θ_val)

[array(31.17454171), array(-211.40023662), array(-191.86660666)]

In [27]:
# Repeat z_MAP and ∇θ logP(x, z_MAP | θ) for datasets simulated at the same θ.
# MUSE averages these simulated gradients (its mean(g_sims) term).
N_DEMO_SIMS = 10
for _ in range(N_DEMO_SIMS):
    x_sim = sample_x(*θ_val)
    output_sim = zmap_optimization(
        cost_fun_with_grad, z_val, x_sim, θ_val, random_init=False)
    zmap_sim = mapping_fn(output_sim.x)
    print(compile_grad_theta_fn(*zmap_sim, *x_sim, *θ_val))

[array(-95.47144803), array(-231.80552749), array(-445.97514265)]
[array(-96.52885436), array(-230.06259566), array(-448.02024561)]
[array(-95.29392648), array(-230.65255955), array(-445.57398436)]
[array(-92.6283119), array(-232.24828555), array(-440.3065717)]
[array(-98.54675517), array(-230.43545728), array(-452.07096384)]
[array(-98.24847932), array(-232.33774037), array(-451.55049804)]
[array(-97.99397109), array(-232.32691742), array(-451.04104818)]
[array(-95.82500258), array(-233.07840573), array(-446.7331629)]
[array(-94.24946974), array(-228.88255645), array(-443.41427403)]
[array(-94.35220437), array(-232.47478704), array(-443.76342034)]


In [28]:
# x_sim[0][:4], x_val[0][:4]

In [29]:
# First entries of z0's MAP for the last simulated dataset vs. for the observed data.
zmap_sim[0][:4], zmap_val[0][:4]

(array([ 0.07251638,  0.30974791, -0.45845313,  0.29817301]),
 array([ 0.63471893,  0.48443256,  0.06441641, -0.21877776]))

## Putting everything together

```python
θ = # initial guess for θ
H = # some guess for Hessian of θ -> logP(θ|x)

while norm(θ - θlast) > θtol:
    # a bunch of simulated x's generated from P(x,z|θ)
    x_sims = [sample_prior_predictive(θ).x for i in 1:nsims] 

    # zMAP maximizes the function z -> logP(x,z|θ)
    zMAP_data = zMAP(x, θ)
    zMAP_sims = [zMAP(x_sim, θ) for x_sim in x_sims]

    # ∇θ_logLike is gradient of θ -> logP(x,z|θ)
    g_data = ∇θ_logLike(zMAP_data, x, θ)
    g_sims = [∇θ_logLike(zMAP_sim, x_sim, θ) for (zMAP_sim, x_sim) in zip(zMAP_sims,x_sims)]

    # gradient of θ -> logP(θ)
    g_prior = ∇θ_logPrior(θ) 

    θlast = θ
    θ -= H \ (g_data - mean(g_sims) + g_prior)
```

In [30]:
def create_flatten_replace_var(values: list, name=''):
    """Expose a list of arrays as one flat symbolic vector plus per-array views.

    Parameters
    ----------
    values : list of np.ndarray
        Example values. Their sizes and shapes define the graph, and their
        contents give the returned flat value.
    name : str
        Suffix for the symbolic vector's name (``"flatten_" + name``).

    Returns
    -------
    replace_var : list
        For each array in ``values``, the matching slice of ``flatten_var``
        reshaped back to that array's shape.
    flatten_var : TensorVariable
        The 1-D symbolic input.
    flatten_var_value : np.ndarray
        ``values`` raveled and concatenated in order. It is an empty 1-D array
        when ``values`` is empty.
    """
    flatten_var = at.vector(name='flatten_' + name)
    # Boundaries of each array inside the flat vector: [0, n0, n0 + n1, ...]
    split_point = np.concatenate([np.asarray([0]),
                                  np.cumsum([v.size for v in values])],
                                 axis=-1)
    replace_var = [
        at.reshape(flatten_var[start:stop], np_val.shape)
        for start, stop, np_val in zip(split_point[:-1], split_point[1:], values)
    ]

    if len(values) > 0:
        flatten_var_value = np.concatenate([v.ravel() for v in values], axis=-1)
    else:
        # np.concatenate raises on an empty list; nothing to flatten means an
        # empty vector.
        flatten_var_value = np.zeros(0, dtype=flatten_var.dtype)
    return replace_var, flatten_var, flatten_var_value

In [31]:
import warnings
from scipy.optimize import OptimizeWarning

def MUSE(model: pm.Model, tol=1e-2, nsims=100,
         max_iter=50,
         minimize_method='L-BFGS-B',
         *,
         H=None,
         verbose=True,
         **minimize_kwargs):
    """Estimate the global parameters θ of ``model`` with MUSE.

    Variables are split into θ (free variables with no parents in the model
    graph), latent z (the other free variables) and observed x. θ starts from a
    standard-normal draw (global ``np.random`` state) in unconstrained space.
    Each iteration:

    1. z_MAP for the observed data: maximise z -> logP(x, z | θ).
    2. The same for ``nsims`` datasets simulated from the model at the current θ.
    3. θ <- θ - H^{-1} (g_data - mean(g_sims) + g_prior), where g is
       ∇θ logP(x, z_MAP | θ) and g_prior is ∇θ logP(θ).

    θ and z are handled in the model's unconstrained (transformed) space.

    Parameters
    ----------
    model : pm.Model
        Model with observed data attached. Potentials are not supported.
    tol : float
        Stop once the norm of a θ update (unconstrained space) is <= ``tol``.
    nsims : int
        Number of simulated datasets per iteration.
    max_iter : int
        Maximum number of θ updates. A RuntimeWarning is issued if this limit
        is reached before ``tol`` is met.
    minimize_method : str
        ``method`` for ``scipy.optimize.minimize`` when computing z_MAP.
    H : array-like, optional
        Fixed matrix used in the update (a guess for the Hessian of
        θ -> logP(θ | x)). Defaults to ``-500 * I``.
    verbose : bool
        Print θ (mapped back to the original space) after every update.
    **minimize_kwargs
        Forwarded to ``scipy.optimize.minimize``.

    Returns
    -------
    n_iter : int
        Number of θ updates performed.
    estimates : dict
        Variable name -> θ estimate in the original (constrained) space.

    Raises
    ------
    Exception
        If the model contains potentials.
    """
    if model.potentials:
        raise Exception("MUSE does not work with model that contains potentials")

    # Categorize variables into θ, z, x
    model_graph = pm.model_graph.ModelGraph(model)
    # θ: free variables with no parent. Observed variables are excluded, so a
    # parentless observation is not mistaken for a parameter.
    theta = [var for var in model.free_RVs if model_graph.get_parents(var) == set()]
    # z: the remaining free variables
    z = [var for var in model.free_RVs if var not in theta]
    x: list = model.observed_RVs

    # For each variable, the value variable its log-probability is a function of.
    # Observed variables get a fresh symbolic input so simulated data can be fed in.
    rv_values = {}
    for var in z + x + theta:
        if var in x:
            value_var = var.type()
            value_var.name = var.name
        else:
            value_var = model.rvs_to_values[var]
        rv_values[var] = value_var

    rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=True)
    logpt_nodes = {}
    for k, logp in zip(rv_values.keys(), rv_logps):
        node_logp = logp.sum()
        node_logp.name = k.name + "_logpt"
        logpt_nodes[k] = node_logp

    # Initial point: gives shapes and the starting value of z.
    test_point = model.recompute_initial_point()
    # θ and z inputs are the (possibly transformed) value variables.
    input_theta = [rv_values[v] for v in theta]
    theta_val = [test_point[rv_values[v].name] for v in theta]
    input_z = [rv_values[v] for v in z]
    z_val = [test_point[rv_values[v].name] for v in z]
    input_x = [rv_values[v] for v in x]
    x_val = [model.rvs_to_values[v].data for v in x]

    # Flatten and concat θ into one 1-D input
    replace_theta, flatten_theta, init_theta = create_flatten_replace_var(
        theta_val, name='theta')
    # θ mapped back to the original (constrained) space
    replace_theta_org = {}
    for org_var, val_var, replace_var in zip(theta, input_theta, replace_theta):
        if hasattr(val_var.tag, "transform"):
            replace_theta_org[org_var] = val_var.tag.transform.backward(replace_var)
        else:
            replace_theta_org[org_var] = replace_var
    # Simulator x | θ, with θ given as the flat unconstrained vector
    x_clone = aesara.clone_replace(x, replace_theta_org)
    sample_x = aesara.function([flatten_theta], x_clone)
    mapping_theta_fn = aesara.function([flatten_theta], list(replace_theta_org.values()))

    # Flatten z into one 1-D input
    replace_z, flatten_z, init_z = create_flatten_replace_var(z_val, name='z')

    # logP(x, z | θ) as a function of the flat z and the flat θ
    condition_logpt = at.sum([logpt_nodes[var] for var in x + z])
    condition_logpt_clone = aesara.clone_replace(
        condition_logpt, dict(zip(input_z + input_theta,
                                  replace_z + replace_theta)))

    # Objective for z_MAP: -logP(x, z | θ) and its gradient w.r.t. the flat z
    neg_condition_logpt = -1.0 * condition_logpt_clone
    grad_z_clone_tensor = aesara.grad(
        neg_condition_logpt,
        wrt=flatten_z,
        consider_constant=input_x + input_theta
    )
    cost_fun_with_grad = aesara.function(
        [flatten_z] + input_x + [flatten_theta],
        [neg_condition_logpt, grad_z_clone_tensor])

    # ∇θ logP(x, z | θ) w.r.t. the flat θ
    grad_theta_clone_tensor = aesara.grad(
        condition_logpt_clone,
        wrt=flatten_theta,
        consider_constant=input_x + input_z
    )
    compile_grad_theta_fn = aesara.function(
        [flatten_z] + input_x + [flatten_theta],
        grad_theta_clone_tensor)

    # ∇θ logP(θ) w.r.t. the flat θ
    theta_logpt = at.sum([logpt_nodes[var] for var in theta])
    theta_logpt_clone = aesara.clone_replace(
        theta_logpt, dict(zip(input_theta, replace_theta)))
    grad_theta_prior = aesara.grad(theta_logpt_clone, wrt=flatten_theta)
    compile_grad_theta_prior_fn = aesara.function([flatten_theta], grad_theta_prior)

    def zmap(x_data, theta_flat):
        """Minimise -logP(x_data, z | θ) over the flat z, starting from the initial z."""
        return minimize(
            cost_fun_with_grad, init_z, args=(*x_data, theta_flat),
            method=minimize_method, jac=True, **minimize_kwargs
        )

    # MUSE iterations
    theta_est = np.random.randn(*init_theta.shape)
    H = -500. * np.eye(len(init_theta)) if H is None else np.asarray(H)
    n_iter = 0
    converged = False
    for n_iter in range(1, max_iter + 1):
        output = zmap(x_val, theta_est)
        if not output.success:
            warnings.warn("zMAP did not converge.", OptimizeWarning)
        g_data = compile_grad_theta_fn(output.x, *x_val, theta_est)
        g_prior = compile_grad_theta_prior_fn(theta_est)

        g_sims = np.zeros([nsims, *g_data.shape], dtype=g_data.dtype)
        n_failed_sims = 0
        for j in range(nsims):
            x_sim = sample_x(theta_est)
            output_sim = zmap(x_sim, theta_est)
            if not output_sim.success:
                n_failed_sims += 1
            g_sims[j] = compile_grad_theta_fn(output_sim.x, *x_sim, theta_est)
        if n_failed_sims:
            warnings.warn(
                f"zMAP did not converge for {n_failed_sims}/{nsims} simulations.",
                OptimizeWarning)

        expect_g_sim = np.mean(g_sims, axis=0)
        step = np.linalg.solve(H, g_data - expect_g_sim + g_prior)
        theta_est = theta_est - step
        if verbose:
            print(mapping_theta_fn(theta_est))
        if np.linalg.norm(step) <= tol:
            converged = True
            break

    if max_iter > 0 and not converged:
        warnings.warn(
            f"MUSE did not converge within max_iter={max_iter} iterations.",
            RuntimeWarning)

    return n_iter, {
        org_var.name: est
        for org_var, est in zip(
            replace_theta_org.keys(),
            mapping_theta_fn(theta_est))
    }

# Fit θ on the observed data. `tol` applies to the norm of each θ update in unconstrained
# space. In the recorded run the first returned value equals max_iter (50), i.e. that
# tolerance was not reached.
MUSE(funnel, nsims=100, tol=1e-5, max_iter=50)

[array(-1.10469299), array(0.26289206), array(7.74322965)]
[array(-1.1057464), array(0.26327243), array(2.67192363)]
[array(-1.10895806), array(0.26335883), array(1.91502821)]
[array(-1.10352368), array(0.26458809), array(2.46085678)]
[array(-1.10628583), array(0.26475283), array(1.93849005)]
[array(-1.10164844), array(0.26587059), array(2.39913608)]
[array(-1.10383368), array(0.26607752), array(1.97857484)]
[array(-1.1003633), array(0.26707612), array(2.33860552)]
[array(-1.1021535), array(0.26734759), array(2.01116737)]
[array(-1.09922355), array(0.26818032), array(2.27311521)]
[array(-1.10055724), array(0.26854053), array(2.04832752)]
[array(-1.09831612), array(0.26930316), array(2.24781228)]
[array(-1.09939085), array(0.26970165), array(2.06671234)]
[array(-1.09754161), array(0.27044131), array(2.23512953)]
[array(-1.09836194), array(0.27085112), array(2.07851889)]
[array(-1.09683527), array(0.27154901), array(2.20860284)]
[array(-1.09720508), array(0.27197305), array(2.09171319)]


(50,
 {'θ0': array(-1.08309009), 'θ1': array(0.29117404), 'σ': array(2.19508577)})