# Advanced Examples

```python
%config InlineBackend.print_figure_kwargs = {'bbox_inches': 'tight', 'dpi': 110}
%load_ext autoreload
%autoreload 2
import logging, warnings
logging.getLogger("pymc").setLevel(logging.FATAL)
warnings.filterwarnings("ignore")
```

## PyMC

The [Example](example.html) page shows how to use *muse-inference* for a problem defined with PyMC. Here we consider a more complex problem to highlight additional features. In particular:

* We can estimate any number of parameters with any shapes. Here we have a length-2 vector $\mu$ and a scalar $\sigma$. Note that by default, *muse-inference* treats any variables that do not depend on others as "parameters" (i.e. the "leaves" of the probabilistic graph). The algorithm is not limited to such parameters, however: any other choice can be made by passing a list of `params` to the `PyMCMuseProblem` constructor.

* We can work with distributions with limited domain support. For example, below we use the $\rm Beta$ distribution with support on $(0,1)$ and the $\rm LogNormal$ distribution with support on $(0,\infty)$. All necessary transformations are handled internally.

* The data and latent space can include any number of variables, with any shapes. Below, $x$ and $z$ are both 2-dimensional arrays of shape $(100, 2)$.
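The phrase "handled internally" refers to a change of variables. The following minimal NumPy sketch shows one such transformation, assuming a logit (log-odds) map from $(0,1)$ to the real line; a log map plays the same role for $(0,\infty)$. It illustrates the general idea only. It is not the actual code used by *muse-inference* or PyMC, and the function names are hypothetical.

```python
import math
import numpy as np

def beta_logpdf(u, a, b):
    """Log density of Beta(a, b) on (0, 1)."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return log_norm + (a - 1) * np.log(u) + (b - 1) * np.log1p(-u)

def sigmoid(y):
    """Inverse of the logit map: sends the real line to (0, 1)."""
    return 1.0 / (1.0 + np.exp(-y))

def beta_logpdf_unconstrained(y, a, b):
    """Log density of y = logit(u) when u ~ Beta(a, b).

    Change of variables: p_Y(y) = p_U(sigmoid(y)) * |du/dy|,
    where du/dy = sigmoid(y) * (1 - sigmoid(y)).
    """
    u = sigmoid(y)
    log_jacobian = np.log(u) + np.log1p(-u)
    return beta_logpdf(u, a, b) + log_jacobian

# The transformed density lives on the whole real line and still integrates to 1
ys = np.linspace(-30.0, 30.0, 60001)
dy = ys[1] - ys[0]
total = np.sum(np.exp(beta_logpdf_unconstrained(ys, 2.0, 5.0))) * dy
print(total)
```

The log-Jacobian term keeps the density normalized after the change of variables. Dropping it would give a function that no longer integrates to 1 over $y$.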

First, load the relevant packages:

```python
# Explicit imports in place of `%pylab inline`, which previously supplied np and RandomState
import numpy as np
from numpy.random import RandomState
import pymc as pm
from muse_inference.pymc import PyMCMuseProblem
```


Then define the problem,

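```python
def gen_funnel(x=None, σ=None, μ=None, rng=None):
    # Arguments left as None become random variables; passing a value fixes them.
    # (`rng` is accepted but not used inside this function.)
    with pm.Model() as model:
        μ = pm.Beta("μ", 2, 5, size=2) if μ is None else μ     # parameter on (0, 1)
        σ = pm.Normal("σ", 0, 3) if σ is None else σ           # unconstrained parameter
        z = pm.LogNormal("z", μ, np.exp(σ/2), size=(100, 2))  # latent on (0, ∞)
        x = pm.Normal("x", z, 1, observed=x)                  # data
    return model
```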

generate the model and some data, given chosen true values of the parameters,

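```python
θ_true = dict(μ=[0.3, 0.7], σ=1)
# Simulate one dataset with the parameters fixed at their true values
with gen_funnel(rng=RandomState(0), **θ_true):
    x_obs = pm.sample_prior_predictive(1, random_seed=0).prior.x[0,0]  # chain 0, draw 0
# Rebuild the model with x observed and the parameters left free
model = gen_funnel(x=x_obs)
prob = PyMCMuseProblem(model)
```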

and finally, run MUSE:

```python
θ_start = dict(μ=[0.5, 0.5], σ=0)
result = prob.solve(θ_start=θ_start, progress=True)
```

```
MUSE: 100%|██████████| 5050/5050 [00:06<00:00, 783.30it/s]
get_H: 100%|██████████| 70/70 [00:00<00:00, 115.91it/s]
```

When there are multiple parameters, the starting guess should be specified as a dictionary, as above.

The parameter estimate is returned as a dictionary,

```python
result.θ
```

```
{'μ': array([0.40806398, 0.41726603]), 'σ': array(0.91835154)}
```

and the covariance as a matrix, with parameters concatenated in the order they appear in the model (or in the order specified in `params`, if that was used):

```python
result.Σ
```

```
array([[ 0.02243357,  0.00089992, -0.00288293],
       [ 0.00089992,  0.03016997, -0.00925313],
       [-0.00288293, -0.00925313,  0.02195891]])
```
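The covariance is easier to read after normalizing it to a correlation matrix. The following short NumPy sketch uses the values printed above, assuming the parameter order μ[0], μ[1], σ described in the text.

```python
import numpy as np

# Covariance printed above (parameter order: μ[0], μ[1], σ)
Σ = np.array([
    [ 0.02243357,  0.00089992, -0.00288293],
    [ 0.00089992,  0.03016997, -0.00925313],
    [-0.00288293, -0.00925313,  0.02195891],
])

std = np.sqrt(np.diag(Σ))        # marginal standard deviations
corr = Σ / np.outer(std, std)    # correlation matrix, unit diagonal
print(np.round(corr, 3))
```

The largest correlation in magnitude, roughly -0.36, is between μ[1] and σ.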

The `result.ravel` and `result.unravel` functions can be used to convert between dictionary and vector representations of the parameters. For example, to compute the standard deviation for each parameter (the square root of the diagonal of the covariance):

```python
result.unravel(np.sqrt(np.diag(result.Σ)))
```

```
{'μ': array([0.14977839, 0.17369504]), 'σ': array(0.14818539)}
```

or to convert the mean parameters to a vector:

```python
result.ravel(result.θ)
```

```
array([0.40806398, 0.41726603, 0.91835154])
```
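Conceptually, such a pair of functions flattens and concatenates the entries in order, then splits the vector back up and restores each shape. Below is a minimal NumPy sketch for a dictionary of arrays. It assumes dictionary insertion order matches the parameter order. `make_ravel_unravel` is a hypothetical helper, not part of *muse-inference*, which may implement this differently.

```python
import numpy as np

def make_ravel_unravel(template):
    """Build ravel/unravel helpers for a dict of arrays (hypothetical helper)."""
    keys = list(template)                       # dict insertion order
    shapes = [np.shape(template[k]) for k in keys]
    sizes = [int(np.prod(s)) for s in shapes]   # a scalar shape () has size 1

    def ravel(d):
        # Flatten each entry and concatenate into one 1-D vector
        return np.concatenate([np.ravel(d[k]) for k in keys])

    def unravel(v):
        # Split the vector back into pieces and restore each original shape
        out, start = {}, 0
        for k, shape, size in zip(keys, shapes, sizes):
            out[k] = np.reshape(v[start:start + size], shape)
            start += size
        return out

    return ravel, unravel

# Usage with the estimate printed above
θ = {"μ": np.array([0.40806398, 0.41726603]), "σ": np.array(0.91835154)}
ravel, unravel = make_ravel_unravel(θ)
vec = ravel(θ)      # length-3 vector: μ[0], μ[1], σ
back = unravel(vec)
print(vec, back)
```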

## Jax

We can also use [Jax](https://jax.readthedocs.io/) to define the problem. In this case we write our own functions to generate forward samples and to evaluate the log-likelihood and log-prior, and Jax provides the necessary gradients automatically. To use Jax, load the necessary packages:

```python
from functools import partial
import jax
import jax.numpy as jnp
from muse_inference.jax import JittableJaxMuseProblem, JaxMuseProblem
```

Let's implement the noisy funnel problem from the [Example](example.html) page. To do so, subclass either `JaxMuseProblem` or, if your code can be JIT-compiled by Jax, `JittableJaxMuseProblem`, in which case the methods should also be decorated with `jax.jit`:

```python
class JaxFunnelMuseProblem(JittableJaxMuseProblem):

    def __init__(self, N):
        super().__init__()
        self.N = N  # number of latent dimensions

    @partial(jax.jit, static_argnums=0)
    def sample_x_z(self, key, θ):
        keys = jax.random.split(key, 2)
        z = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ/2)  # z ~ N(0, e^θ)
        x = z + jax.random.normal(keys[1], (self.N,))             # x ~ N(z, 1)
        return (x, z)

    @partial(jax.jit, static_argnums=0)
    def logLike(self, x, z, θ):
        # log p(x | z) + log p(z | θ), dropping θ-independent constants;
        # N*θ is the normalization of z ~ N(0, e^θ) over N dimensions
        return -(jnp.sum((x - z)**2) + jnp.sum(z**2) / jnp.exp(θ) + self.N*θ) / 2

    @partial(jax.jit, static_argnums=0)
    def logPrior(self, θ):
        return -θ**2 / (2*3**2)  # θ ~ N(0, 3²), up to a constant
```
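As a sanity check on the funnel log-likelihood, independent of *muse-inference*, the NumPy sketch below compares it with the exact Gaussian log density for $z_i \sim \mathcal{N}(0, e^{\theta})$ and $x_i \sim \mathcal{N}(z_i, 1)$. With the normalization term written as $N\theta$, the two should differ only by the $\theta$-independent constant $-N\log(2\pi)$, which does not affect any gradient. The helper names are illustrative.

```python
import numpy as np

def funnel_loglike(x, z, θ):
    """Funnel log-likelihood with the N*θ normalization term."""
    N = x.size
    return -(np.sum((x - z)**2) + np.sum(z**2) / np.exp(θ) + N * θ) / 2

def gaussian_logpdf(v, mean, var):
    """Elementwise log density of N(mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mean)**2 / var)

def exact_logjoint(x, z, θ):
    # log p(x | z) + log p(z | θ) with x_i ~ N(z_i, 1) and z_i ~ N(0, e^θ)
    return np.sum(gaussian_logpdf(x, z, 1.0)) + np.sum(gaussian_logpdf(z, 0.0, np.exp(θ)))

rng = np.random.default_rng(0)
N = 1000
θ = 0.5
z = rng.normal(size=N) * np.exp(θ / 2)  # z ~ N(0, e^θ)
x = z + rng.normal(size=N)              # x ~ N(z, 1)

# The difference should be the same constant, -N*log(2π), for every θ
for θ_test in [-1.0, 0.0, 2.0]:
    print(θ_test, exact_logjoint(x, z, θ_test) - funnel_loglike(x, z, θ_test))
```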

Now generate some simulated data and assign it to `prob.x`. Note the use of Jax's `PRNGKey` (rather than NumPy's `RandomState`, as in the PyMC example) for random number generation.

```python
prob = JaxFunnelMuseProblem(10000)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, 0)
prob.x = x
```

And finally, run MUSE:

```python
# warmup: the first call triggers JIT compilation, so the next run excludes it
prob.solve(θ_start=0., rng=jax.random.PRNGKey(1))
```

```python
result = prob.solve(θ_start=0., rng=jax.random.PRNGKey(1), progress=True)
```

```
MUSE: 100%|██████████| 5050/5050 [00:00<00:00, 8883.46it/s]
get_H: 100%|██████████| 30/30 [00:00<00:00, 398.59it/s]
```

Note that the solution here is obtained roughly 10X faster than the PyMC version of this problem on the [Example](example.html) page. The cloud machines that build these docs don't always achieve the full 10X, but you should see it when running these examples locally. The Jax interface has much lower overhead, which is noticeable for very fast posteriors like the one above.

One convenient aspect of using Jax is that the parameters, `θ`, and latent space, `z`, can be any [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), i.e. tuples, dictionaries, nested combinations of them, etc. (there is no requirement on the data format of `x`). To demonstrate, consider a problem which is just two copies of the noisy funnel problem:

```python
class JaxPyTreeFunnelMuseProblem(JittableJaxMuseProblem):

    def __init__(self, N):
        super().__init__()
        self.N = N  # number of latent dimensions per copy

    @partial(jax.jit, static_argnums=0)
    def sample_x_z(self, key, θ):
        (θ1, θ2) = (θ["θ1"], θ["θ2"])
        keys = jax.random.split(key, 4)
        z1 = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ1/2)
        z2 = jax.random.normal(keys[1], (self.N,)) * jnp.exp(θ2/2)
        x1 = z1 + jax.random.normal(keys[2], (self.N,))
        x2 = z2 + jax.random.normal(keys[3], (self.N,))
        # data and latent space are both returned as dictionaries (pytrees)
        return ({"x1":x1, "x2":x2}, {"z1":z1, "z2":z2})

    @partial(jax.jit, static_argnums=0)
    def logLike(self, x, z, θ):
        # sum of two independent single-funnel log-likelihoods
        return (
            -(jnp.sum((x["x1"] - z["z1"])**2) + jnp.sum(z["z1"]**2) / jnp.exp(θ["θ1"]) + self.N*θ["θ1"]) / 2
            -(jnp.sum((x["x2"] - z["z2"])**2) + jnp.sum(z["z2"]**2) / jnp.exp(θ["θ2"]) + self.N*θ["θ2"]) / 2
        )

    @partial(jax.jit, static_argnums=0)
    def logPrior(self, θ):
        return - θ["θ1"]**2 / (2*3**2) - θ["θ2"]**2 / (2*3**2)
```

Here, `x`, `θ`, and `z` are all dictionaries. We generate the problem as usual, passing in parameters as dictionaries,

```python
θ_true = dict(θ1=-1., θ2=2.)
θ_start = dict(θ1=0., θ2=0.)
```

```python
prob = JaxPyTreeFunnelMuseProblem(10000)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, θ_true)
prob.x = x
```

and run MUSE:

```python
# warmup: the first call triggers JIT compilation, so the next run excludes it
prob.solve(θ_start=θ_start, rng=jax.random.PRNGKey(0))
```

```python
result = prob.solve(θ_start=θ_start, rng=jax.random.PRNGKey(0), progress=True)
```

```
MUSE: 100%|██████████| 5050/5050 [00:09<00:00, 554.27it/s]
get_H: 100%|██████████| 50/50 [00:00<00:00, 195.24it/s]
```

The result is returned as a pytree:

```python
result.θ
```

```
{'θ1': DeviceArray(-1.0020632, dtype=float32),
 'θ2': DeviceArray(2.027134, dtype=float32)}
```

and the covariance as a matrix:

```python
result.Σ
```

```
array([[ 3.3823610e-03, -4.6946143e-06],
       [-4.6946157e-06,  2.4767642e-04]], dtype=float32)
```
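The off-diagonal entries are tiny compared with the diagonal ones. This is consistent with the structure of the problem: the two copies share no latent variables, so the log-likelihood separates into two single-funnel terms, and the derivative with respect to θ1 does not depend on copy 2 at all.

The following small NumPy sketch checks that separability. It compares a finite-difference derivative against the analytic one, $\partial_{\theta_1} \log p = \tfrac{1}{2} e^{-\theta_1}\sum_i z_{1,i}^2 - N/2$. The helper names are hypothetical.

```python
import numpy as np

def funnel_loglike(x, z, θ):
    """Single noisy-funnel log-likelihood, up to a θ-independent constant."""
    return -(np.sum((x - z)**2) + np.sum(z**2) / np.exp(θ) + x.size * θ) / 2

def two_funnel_loglike(x, z, θ):
    """Two independent copies, with dict ("pytree") inputs as in the class above."""
    return (funnel_loglike(x["x1"], z["z1"], θ["θ1"])
            + funnel_loglike(x["x2"], z["z2"], θ["θ2"]))

def dloglike_dθ1(x, z, θ, h=1e-5):
    """Central finite-difference derivative with respect to θ1 only."""
    up = dict(θ, θ1=θ["θ1"] + h)
    dn = dict(θ, θ1=θ["θ1"] - h)
    return (two_funnel_loglike(x, z, up) - two_funnel_loglike(x, z, dn)) / (2 * h)

rng = np.random.default_rng(1)
N = 1000
θ = {"θ1": -1.0, "θ2": 2.0}
z = {"z1": rng.normal(size=N) * np.exp(θ["θ1"] / 2),
     "z2": rng.normal(size=N) * np.exp(θ["θ2"] / 2)}
x = {"x1": z["z1"] + rng.normal(size=N),
     "x2": z["z2"] + rng.normal(size=N)}

# Replace copy 2 with unrelated values: the θ1-derivative should not change
x_other = dict(x, x2=rng.normal(size=N))
z_other = dict(z, z2=rng.normal(size=N))

g = dloglike_dθ1(x, z, θ)
g_other = dloglike_dθ1(x_other, z_other, θ)
g_exact = 0.5 * np.exp(-θ["θ1"]) * np.sum(z["z1"]**2) - N / 2
print(g, g_other, g_exact)
```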

The `result.ravel` and `result.unravel` functions can be used to convert between pytree and vector representations of the parameters. For example, to compute the standard deviation for each parameter (the square root of the diagonal of the covariance):

```python
result.unravel(np.sqrt(np.diag(result.Σ)))
```

```
{'θ1': DeviceArray(0.05815807, dtype=float32),
 'θ2': DeviceArray(0.01573774, dtype=float32)}
```

or to convert the mean parameters to a vector:

```python
result.ravel(result.θ)
```

```
DeviceArray([-1.0020632,  2.027134 ], dtype=float32)
```