# User API

In [1]:
%config InlineBackend.print_figure_kwargs = {'bbox_inches': 'tight', 'dpi': 110}
%load_ext autoreload
%autoreload 2
import logging, warnings
logging.getLogger("pymc").setLevel(logging.FATAL)
warnings.filterwarnings("ignore")

## PyMC

The [Example](example.html) page shows how to use *muse-inference* for a problem defined with PyMC. Here we consider a more complex problem that highlights additional features. In particular:

* We can estimate any number of parameters, with any shapes. Here we have a length-2 vector $\mu$ and a scalar $\theta$. Note that by default, *muse-inference* treats any variables which do not depend on others (i.e. the "leaves" of the probabilistic graph) as "parameters". However, the algorithm is not limited to such parameters, and any set of variables can be chosen instead by passing a list of `params` to the `PyMCMuseProblem` constructor.

* We can work with distributions with limited domain support. For example, below we use the $\rm Beta$ distribution with support on $(0,1)$ and the $\rm LogNormal$ distribution with support on $(0,\infty)$. All necessary transformations are handled internally.

* The data and latent space can include any number of variables, with any shapes. Below, $x$ and $z$ are both 2-dimensional arrays of shape $(100, 2)$.
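Here is roughly what "handled internally" involves. Each bounded variable is mapped to the real line, and any density evaluated there picks up a log-Jacobian term: $\log p_Y(y) = \log p_X(g(y)) + \log\lvert g'(y)\rvert$ with $x = g(y)$. The map is log-odds for support $(0,1)$ and log for support $(0,\infty)$; to our understanding these are PyMC's default transforms for `Beta` and `LogNormal`. The sketch below is a minimal NumPy illustration of that idea. It is not *muse-inference*'s or PyMC's code, and the helper names are hypothetical.

```python
import numpy as np

# Hypothetical helpers illustrating unconstrained reparameterizations.

def logodds(p):
    """Map p in (0, 1) to the real line."""
    return np.log(p) - np.log1p(-p)

def inv_logodds(y):
    """Map a real value back into (0, 1) (the logistic sigmoid)."""
    return 1.0 / (1.0 + np.exp(-y))

def log_jac_logodds(y):
    """log |dp/dy| for p = sigmoid(y); dp/dy = p * (1 - p)."""
    p = inv_logodds(y)
    return np.log(p) + np.log1p(-p)

def log_jac_log(y):
    """log |dz/dy| for z = exp(y), which is simply y."""
    return y

mu = np.array([0.3, 0.7])         # Beta-distributed values in (0, 1)
y_mu = logodds(mu)                 # corresponding unconstrained values
z = np.array([0.5, 2.0, 10.0])     # LogNormal-distributed values in (0, inf)
y_z = np.log(z)

# Round trips recover the constrained values.
print(inv_logodds(y_mu), np.exp(y_z))
```

An optimizer or gradient-based solver can then move freely in `y_mu` and `y_z` without ever leaving the valid domain.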

First, load the relevant packages:

In [2]:
import numpy as np
from numpy.random import RandomState
import pymc as pm
from muse_inference.pymc import PyMCMuseProblem


Then define the problem:

In [3]:
def gen_funnel(x=None, θ=None, μ=None, rng=None):
    with pm.Model(rng_seeder=rng) as model:
        μ = pm.Beta("μ", 2, 5, size=2) if μ is None else μ
        θ = pm.Normal("θ", 0, 3) if θ is None else θ
        z = pm.LogNormal("z", μ, np.exp(θ/2), size=(100, 2))
        x = pm.Normal("x", z, 1, observed=x)
    return model
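The sketch below makes the shapes in `gen_funnel` concrete. It is a NumPy-only analogue of the same generative process, written in the same "sample unless a value is supplied" style. The function `simulate_funnel` is hypothetical and builds no PyMC model. In PyMC, `LogNormal(μ, σ)` means $\log z \sim \mathcal{N}(\mu, \sigma)$. Broadcasting the length-2 $\mu$ against shape $(100, 2)$ therefore gives each column of $z$ its own location.

```python
import numpy as np

def simulate_funnel(rng, mu=None, theta=None):
    """Hypothetical NumPy analogue of gen_funnel: draw anything not supplied."""
    mu = rng.beta(2, 5, size=2) if mu is None else np.asarray(mu, dtype=float)
    theta = rng.normal(0, 3) if theta is None else float(theta)
    # log z ~ Normal(mu, exp(theta/2)); mu (shape (2,)) broadcasts over the 100 rows
    z = rng.lognormal(mean=mu, sigma=np.exp(theta / 2), size=(100, 2))
    # observations: z plus unit-variance Gaussian noise
    x = rng.normal(z, 1.0)
    return x, z

rng = np.random.default_rng(0)
x_sim, z_sim = simulate_funnel(rng, mu=[0.3, 0.7], theta=1)
print(x_sim.shape, z_sim.shape)
```

Passing values for `mu` and `theta` fixes them, just as passing `μ` or `θ` to `gen_funnel` conditions the PyMC model on them.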

Next, generate the model and some data, given some chosen true values of the parameters:

In [4]:
params_true = dict(μ=[0.3, 0.7], θ=1)
x_obs = pm.sample_prior_predictive(1, gen_funnel(rng=RandomState(0), **params_true)).prior.x[0,0]
model = gen_funnel(x=x_obs)
prob = PyMCMuseProblem(model)

Finally, run MUSE:

In [5]:
params_start = dict(μ=[0.5, 0.5], θ=0)
result = prob.solve(params_start, progress=True)

MUSE: 100%|██████████| 5050/5050 [00:08<00:00, 577.54it/s] 




get_H: 100%|██████████| 70/70 [00:01<00:00, 44.75it/s]




When there are multiple parameters, the starting guess should be specified as a dictionary, as above.

The solution `result.θ` is returned as a 1-dimensional vector of all parameters concatenated in the order they appear in the model (here the two entries of $\mu$ followed by $\theta$). `result.Σ` is the corresponding covariance matrix:

In [6]:
result.θ, result.Σ

(array([0.32164014, 0.52036869, 0.81806896]),
 array([[ 0.02828987,  0.00559278, -0.0216605 ],
        [ 0.00559278,  0.01515474, -0.00759072],
        [-0.0216605 , -0.00759072,  0.02698495]]))
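Because the vector is concatenated in model order, `result.θ` splits back into named parameters. The first two entries are the components of $\mu$ and the last is $\theta$. The square roots of the diagonal of `result.Σ` give the corresponding 1σ uncertainties. The short NumPy sketch below uses the values printed above; its slicing assumes exactly this layout.

```python
import numpy as np

# Values copied from the output above.
theta_hat = np.array([0.32164014, 0.52036869, 0.81806896])
Sigma = np.array([[ 0.02828987,  0.00559278, -0.0216605 ],
                  [ 0.00559278,  0.01515474, -0.00759072],
                  [-0.0216605 , -0.00759072,  0.02698495]])

# Split the flat vector back into named parameters (μ has 2 entries, θ has 1).
mu_hat, th_hat = theta_hat[:2], theta_hat[2]

sigma = np.sqrt(np.diag(Sigma))        # marginal 1σ uncertainties
corr = Sigma / np.outer(sigma, sigma)  # correlation matrix

for name, val, err in zip(["μ[0]", "μ[1]", "θ"], theta_hat, sigma):
    print(f"{name} = {val:.3f} ± {err:.3f}")
```

Compared with the true values used to simulate the data, $\mu = (0.3, 0.7)$ and $\theta = 1$, each estimate lies within about 1.5σ.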

## Jax

We can also use [Jax](https://jax.readthedocs.io/) to define the problem. In this case we write our own functions to generate forward samples and to compute the (unnormalized) log posterior, split into a log-likelihood and a log-prior. Jax then provides the necessary gradients automatically. To use Jax, load the necessary packages:

In [7]:
from functools import partial
import jax
import jax.numpy as jnp
from muse_inference.jax import JittableJaxMuseProblem, JaxMuseProblem
from muse_inference import XZSample

Let's implement the noisy funnel problem from the [Example](example.html) page. To do so, extend `JaxMuseProblem`. Alternatively, if your code can be JIT-compiled by Jax, extend `JittableJaxMuseProblem` and decorate the methods with `jax.jit`, using `static_argnums=0` so that `self` is treated as a static argument:

In [8]:
class JaxFunnelMuseProblem(JittableJaxMuseProblem):

    def __init__(self, N):
        super().__init__()
        self.N = N

    @partial(jax.jit, static_argnums=0)
    def sample_x_z(self, key, θ):
        keys = jax.random.split(key, 2)
        z = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ/2)
        x = z + jax.random.normal(keys[1], (self.N,))
        return XZSample(x, z)

    @partial(jax.jit, static_argnums=0)
    def logLike(self, x, z, θ):
        # self.N*θ is the normalization of the N Gaussian latent variables z
        return -(jnp.sum((x - z)**2) + jnp.sum(z**2) / jnp.exp(θ) + self.N*θ) / 2

    @partial(jax.jit, static_argnums=0)
    def logPrior(self, θ):
        return -θ**2 / (2*3**2)
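Jax differentiates `logLike` for us, but for this model the θ-gradient is simple enough to write by hand, which makes a useful sanity check. We write the θ-dependent normalization of the $N$ latent variables $z_i \sim \mathcal{N}(0, e^{\theta/2})$ as $N\theta$, which is exact up to a constant. The log-likelihood is then

$$\log P(x, z \mid \theta) = -\tfrac{1}{2}\Big[\textstyle\sum_i (x_i - z_i)^2 + e^{-\theta}\sum_i z_i^2 + N\theta\Big],$$

so

$$\partial_\theta \log P(x, z \mid \theta) = \tfrac{1}{2}\Big(e^{-\theta}\textstyle\sum_i z_i^2 - N\Big).$$

The NumPy sketch below compares this formula with a central finite difference. The function names are hypothetical.

```python
import numpy as np

def log_like(x, z, theta):
    """Funnel log-likelihood with the N*theta normalization written explicitly."""
    N = x.shape[0]
    return -(np.sum((x - z) ** 2) + np.sum(z ** 2) / np.exp(theta) + N * theta) / 2

def dlog_like_dtheta(x, z, theta):
    """Analytic partial derivative of log_like with respect to theta."""
    N = x.shape[0]
    return (np.sum(z ** 2) * np.exp(-theta) - N) / 2

rng = np.random.default_rng(0)
N, theta = 1000, 1.0
z = rng.normal(size=N) * np.exp(theta / 2)
x = z + rng.normal(size=N)

# Central finite difference in theta, holding x and z fixed.
eps = 1e-5
fd = (log_like(x, z, theta + eps) - log_like(x, z, theta - eps)) / (2 * eps)
print(fd, dlog_like_dtheta(x, z, theta))
```

Note that the normalization contributes the constant $-N/2$ to this gradient, independent of $x$ and $z$.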

Now generate some simulated data and assign it to `prob.x`. Note also the use of `PRNGKey` for random number generation, rather than the NumPy `RandomState` used with PyMC above.

In [9]:
prob = JaxFunnelMuseProblem(10000)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, jnp.array([1.]))
prob.x = x
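`jax.random.split` derives independent keys from a single key, so the two draws in `sample_x_z` never share random numbers. For readers more used to NumPy, a rough analogue is `np.random.SeedSequence(...).spawn(...)`. The sketch below uses it to simulate the same funnel; the helper is hypothetical and does not reproduce the Jax random streams bit-for-bit. It then checks the sample variances against $\mathrm{Var}(z) = e^{\theta}$ and $\mathrm{Var}(x) = e^{\theta} + 1$.

```python
import numpy as np

def sample_x_z_numpy(seed, theta, N):
    """Hypothetical NumPy counterpart of sample_x_z, using spawned child seeds."""
    # two independent child seeds, playing the role of jax.random.split(key, 2)
    ss_z, ss_x = np.random.SeedSequence(seed).spawn(2)
    z = np.random.default_rng(ss_z).normal(size=N) * np.exp(theta / 2)
    x = z + np.random.default_rng(ss_x).normal(size=N)
    return x, z

theta_true = 1.0
x, z = sample_x_z_numpy(0, theta_true, 10000)
print(z.var(), np.exp(theta_true))      # sample vs. expected Var(z)
print(x.var(), np.exp(theta_true) + 1)  # sample vs. expected Var(x)
```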



And finally, run MUSE:

In [10]:
result = prob.solve(θ_start=0., rng=jax.random.PRNGKey(1), progress=True)

MUSE: 100%|██████████| 5050/5050 [00:07<00:00, 697.00it/s] 




get_H: 100%|██████████| 30/30 [00:02<00:00, 11.91it/s]




Note that the solution here is obtained around 10X faster than with the PyMC version of this problem on the [Example](example.html) page. The Jax interface has much lower overhead, which is noticeable for very fast posteriors like the one above.

One powerful aspect of using Jax is that the parameters, `θ`, and the latent space, `z`, can be any [pytree](https://jax.readthedocs.io/en/latest/pytrees.html), i.e. tuples, dictionaries, nested combinations of these, and so on. The data `x` has no format requirement at all. To demonstrate, consider a problem which is just two copies of the noisy funnel problem:

In [11]:
class JaxPyTreeFunnelMuseProblem(JittableJaxMuseProblem):

    def __init__(self, N):
        super().__init__()
        self.N = N

    @partial(jax.jit, static_argnums=0)
    def sample_x_z(self, key, θ):
        (θ1, θ2) = (θ["θ1"], θ["θ2"])
        keys = jax.random.split(key, 4)
        z1 = jax.random.normal(keys[0], (self.N,)) * jnp.exp(θ1/2)
        z2 = jax.random.normal(keys[1], (self.N,)) * jnp.exp(θ2/2)
        x1 = z1 + jax.random.normal(keys[2], (self.N,))
        x2 = z2 + jax.random.normal(keys[3], (self.N,))
        return XZSample(x={"x1":x1, "x2":x2}, z={"z1":z1, "z2":z2})

    @partial(jax.jit, static_argnums=0)
    def logLike(self, x, z, θ):
        # one funnel term per copy; self.N*θ is the normalization of the N latent z's
        return (
            -(jnp.sum((x["x1"] - z["z1"])**2) + jnp.sum(z["z1"]**2) / jnp.exp(θ["θ1"]) + self.N*θ["θ1"]) / 2
            -(jnp.sum((x["x2"] - z["z2"])**2) + jnp.sum(z["z2"]**2) / jnp.exp(θ["θ2"]) + self.N*θ["θ2"]) / 2
        )

    @partial(jax.jit, static_argnums=0)
    def logPrior(self, θ):
        return - θ["θ1"]**2 / (2*3**2) - θ["θ2"]**2 / (2*3**2)
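The two copies share no latent variables or parameters, so this dictionary-valued log-likelihood is just the sum of two single-funnel terms. Its mixed second derivative $\partial^2/\partial\theta_1\partial\theta_2$ is therefore exactly zero. This is why we would expect the off-diagonal entry of the covariance reported at the end of this page to be small, up to Monte Carlo noise. The NumPy sketch below checks the mixed derivative by finite differences. Plain dictionaries stand in for Jax arrays, the normalization is written as `N*theta`, and all function names are hypothetical.

```python
import numpy as np

def log_like_single(x, z, theta):
    """Single noisy-funnel log-likelihood, with N*theta as the normalization."""
    N = x.shape[0]
    return -(np.sum((x - z) ** 2) + np.sum(z ** 2) / np.exp(theta) + N * theta) / 2

def log_like_tree(x, z, theta):
    """Dictionary version mirroring JaxPyTreeFunnelMuseProblem.logLike."""
    return (log_like_single(x["x1"], z["z1"], theta["θ1"])
            + log_like_single(x["x2"], z["z2"], theta["θ2"]))

rng = np.random.default_rng(1)
N = 500
theta = {"θ1": -1.0, "θ2": 2.0}
z = {"z1": rng.normal(size=N) * np.exp(theta["θ1"] / 2),
     "z2": rng.normal(size=N) * np.exp(theta["θ2"] / 2)}
x = {"x1": z["z1"] + rng.normal(size=N),
     "x2": z["z2"] + rng.normal(size=N)}

def f(t1, t2):
    return log_like_tree(x, z, {"θ1": t1, "θ2": t2})

# Central finite-difference estimate of the mixed partial d²/dθ1 dθ2.
eps = 1e-3
t1, t2 = theta["θ1"], theta["θ2"]
mixed = (f(t1 + eps, t2 + eps) - f(t1 + eps, t2 - eps)
         - f(t1 - eps, t2 + eps) + f(t1 - eps, t2 - eps)) / (4 * eps ** 2)
print(mixed)  # zero up to floating-point round-off
```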

Here, `x`, `θ`, and `z` are all dictionaries. We generate the problem as usual, passing in parameters as dictionaries,

In [12]:
θ_true = dict(θ1=-1., θ2=2.)
θ_start = dict(θ1=0., θ2=0.)

In [13]:
prob = JaxPyTreeFunnelMuseProblem(10000)
key = jax.random.PRNGKey(0)
(x, z) = prob.sample_x_z(key, θ_true)
prob.x = x

and run MUSE:

In [14]:
result = prob.solve(θ_start=θ_start, rng=jax.random.PRNGKey(0), progress=True)

MUSE: 100%|██████████| 5050/5050 [00:13<00:00, 376.10it/s] 




get_H: 100%|██████████| 50/50 [00:02<00:00, 16.97it/s]




The result is returned as a dictionary:

In [15]:
result.θ

{'θ1': DeviceArray(-1.0030082, dtype=float32),
 'θ2': DeviceArray(2.0271263, dtype=float32)}

and the covariance as a matrix:

In [16]:
result.Σ

array([[ 3.3970999e-03, -3.8145434e-05],
       [-3.8145430e-05,  2.4442433e-04]], dtype=float32)
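`result.θ` is a dictionary but `result.Σ` is a plain matrix, so its rows and columns follow the order in which the parameter pytree is flattened. Jax's `jax.tree_util.tree_flatten` orders dictionary leaves by sorted key. Assuming *muse-inference* uses that standard ordering, row/column 0 is `θ1` and row/column 1 is `θ2`. This fits the larger uncertainty on `θ1`, whose latent signal (variance $e^{-1}$) is weaker relative to the unit noise. The pure-Python sketch below mimics that ordering to label the covariance. The helpers are hypothetical and, unlike real pytrees, handle only flat dictionaries of scalars.

```python
import numpy as np

def flatten_dict(params):
    """Flatten a flat dict of scalars in sorted-key order (as Jax does for dicts)."""
    keys = sorted(params)
    return keys, np.array([params[k] for k in keys], dtype=float)

def unflatten_dict(keys, vec):
    """Inverse of flatten_dict."""
    return {k: float(v) for k, v in zip(keys, vec)}

# Estimates and covariance copied from the outputs above; the dict is
# written in a different insertion order to show that sorting decides the layout.
theta_hat = {"θ2": 2.0271263, "θ1": -1.0030082}
Sigma = np.array([[ 3.3970999e-03, -3.8145434e-05],
                  [-3.8145430e-05,  2.4442433e-04]])

keys, vec = flatten_dict(theta_hat)
errors = dict(zip(keys, np.sqrt(np.diag(Sigma))))
for k in keys:
    print(f"{k} = {theta_hat[k]:.3f} ± {errors[k]:.3f}")
```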