In [None]:
from contextlib import AbstractContextManager
import rich
import torch
import numpy as np
import pyro
from pyro import distributions as dist
from pyro import poutine
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, SVI
from pyro.infer import config_enumerate, infer_discrete
from pyro.distributions import constraints
from pyro.ops.indexing import Vindex
from pyro.infer.autoguide import AutoNormal, AutoDelta, AutoGuide
from matplotlib import pyplot as plt
from IPython import display
%matplotlib inline


class catch(AbstractContextManager):
    """Context manager that reports an error from its body instead of raising it.

    Meant for notebook cells that demonstrate a failing call: the exception
    type is printed and execution continues after the ``with`` block.

    Only ``Exception`` subclasses are suppressed. ``KeyboardInterrupt``,
    ``SystemExit`` and other ``BaseException``-only types propagate, so a
    cell can still be interrupted while it runs inside ``with catch():``.

    Args:
        info: optional label printed as a ``=== label ===`` banner on entry.
    """

    def __init__(self, info=None):
        self.info = info

    def __enter__(self):
        if self.info is not None:
            print(f"=== {self.info} ===")
        # Return the manager so ``with catch(...) as c`` binds something useful.
        return self

    def __exit__(self, exctype, excinst, exctb):
        if exctype is None:
            return False
        if not issubclass(exctype, Exception):
            # Never swallow interpreter-level signals such as Ctrl-C.
            return False
        print(f"Error: {exctype}")
        return True


def render_model_plus(model, model_args=None):
    """Render `model` as a graph and display it inline.

    The graph is produced by `pyro.render_model` with both parameters and
    distributions shown; `model_args` is forwarded unchanged.
    """
    graph = pyro.render_model(
        model,
        model_args,
        render_params=True,
        render_distributions=True,
    )
    display.display(graph)


## Fit a Gaussian Distribution

In this section, we use `pyro` to fit a Gaussian distribution. It is a deliberately tiny toy example for building an understanding of a few primitives: `sample`, `param`, `model` and `guide`.

### Difference between `sample` and `param`

- connection with `model` and `guide`:
  - every latent (unobserved) `sample` site of the `model` must also appear, under the same name, in the `guide`
  - a `param` can be `model`-only or `guide`-only
- trainable? If the argument of `sample` is:
  - a constant distribution: it never changes during training
  - a `param`eterized distribution: it changes during training

### What is trained

See the case *model_2*: the `model` treats `loc` and `scale` as random variables and specifies their **prior**. Even after training to get a good estimate of the posterior (in this case only the `guide` gets trained, while in *model_3* the `model` gets trained too, because it contains a `param`), the `model` itself is still a **prior** model! (Look at the generated data!)

In [None]:
# Observed data set shared by every experiment below: 100 draws from N(2, 1).
data = dist.Normal(loc=2.0, scale=1.0).sample(sample_shape=(100,))

In [None]:
def test(model, num_steps=500, lr=0.02):
    """Fit `model` to the global `data` with SVI and visualise the result.

    The model is conditioned on `data` at the site named "x", an `AutoDelta`
    guide is built for the conditioned model, and `Trace_ELBO` is optimised
    with AdamW. Afterwards the function plots the loss curve, prints the
    parameter store, renders model and guide, and plots three fresh draws
    from the *unconditioned* `model` next to the observed data.

    Args:
        model: callable taking no arguments that contains a sample site "x"
            and returns a tensor (it is called as `model()` and converted
            with `.detach().numpy()`).
        num_steps: number of SVI optimisation steps (default 500).
        lr: learning rate passed to AdamW (default 0.02).
    """
    # Start from a clean slate so params of a previous experiment don't leak in.
    pyro.clear_param_store()

    cond_model = pyro.condition(model, data={"x": data})
    guide = AutoDelta(cond_model)
    elbo = Trace_ELBO()
    optim = pyro.optim.AdamW({"lr": lr})
    svi = SVI(cond_model, guide, optim, elbo)

    # training: each step returns the loss of that step
    losses = [svi.step() for _ in range(num_steps)]
    print("loss:")
    plt.plot(losses)
    plt.show()

    # show result
    rich.print(dict(pyro.get_param_store()))
    print("model:")
    render_model_plus(cond_model)
    print("guide:")
    render_model_plus(guide)

    print("generated data:")
    observed = data.numpy()
    plt.figure(figsize=(9, 3))
    for i in range(3):
        plt.subplot(1, 3, i + 1)
        # Sample from the original (unconditioned) model to see what it generates.
        generated_data = model().detach().numpy()
        # Semi-transparent bars keep both histograms visible where they overlap.
        plt.hist(generated_data, bins=20, density=True, alpha=0.6, label="generated")
        plt.hist(observed, bins=20, density=True, alpha=0.6, label="observed")
        plt.title(f"mean={round(float(np.mean(generated_data)), 3)}\n"  #
                  + f"var={round(float(np.var(generated_data)), 3)}")
        plt.legend()
    plt.show()

In [None]:
"""
    For a `model` without random variables (aka. without `pyro.sample`),
    and without `pyro.param`:
    
    - the `model` cannot be trained
    - the `guide` contains nothing
    - the generated data will always be N(0.0, 1.0)
"""
def model_0():
    with pyro.plate("num", size=len(data)):
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return x
    
test(model_0)

In [None]:
"""
    For a `model` containing only params
    
    - the `param` will be trained
    - the `guide` is `None`
    - the generated data will be deterministic: N(trained_loc, trained_scale)
"""
def model_1():
    loc = pyro.param("loc", torch.tensor(0.0))
    scale = pyro.param("scale", torch.tensor(1.0))
    with pyro.plate("num", size=len(data)):
        x = pyro.sample("x", dist.Normal(loc, scale))
    return x

test(model_1)

In [None]:
"""
    For a `model` containing random variables
    
    - the `model` will `NOT` be trained, since no params exist
    - the `guide` contains variables (and corresponding params)
    - the generated data will be stochastic, although the guide gives a 
      good estimation of `loc` and `scale`. Note that no params exist in
      the model!

"""
def model_2():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 3.0))
    with pyro.plate("num", size=len(data)):
        x = pyro.sample("x", dist.Normal(loc, scale))
    return x

test(model_2)

In [None]:
"""
    For a `model` containing both `pyro.sample` and `pyro.param`:
    
    - the `model` will be trained (but only the `scale` is trained)
    - the `guide` contains variables about `loc`, but not about `scale`
    - the generated data will be stochastic about `loc` but deterministic
      about `scale`. 
"""
def model_3():
    loc = pyro.sample("loc", dist.Normal(0.0, 5.0))
    scale = pyro.param("scale", torch.tensor(10.0))
    with pyro.plate("num", size=len(data)):
        x = pyro.sample("x", dist.Normal(loc, scale))
    return x

test(model_3)