Now that you've seen how Pyro can be used to solve an inference problem with variational inference (VI),
try it for yourself!

For this exercise we consider a model from behavioural economics: _Constant
Elasticity of Substitution_, or CES for short. The parameters of the model are
$\rho, \alpha, u$, and the input is a design $\xi = \{x, x^\prime\}$, where $x$ and $x^\prime$
are two baskets of goods with $x, x^\prime \in [0,100]^3$. The priors over
the parameters are

\begin{aligned}
&\rho \sim \mathrm{Beta}(1,1)\\
&\alpha \sim \mathrm{Dirichlet}([1,1,1])\\
&\log(u) \sim \mathcal{N}(1,3)
\end{aligned}

and the predictive model is

\begin{aligned}
&U(x) = \Big(\sum_{i} x_i^\rho\alpha_i\Big)^{\frac{1}{\rho}}\\
&\mu_\eta = u \cdot (U(x) - U(x^\prime))\\
&\sigma_\eta = 0.005 \cdot u \cdot (1 + \|x - x^\prime\|_2^2) \\
&\eta \sim \mathcal{N}(\mu_\eta, \sigma^2_\eta)\\
&y = \operatorname{sigmoid}(\eta)
\end{aligned}

Use what you've learned from the `source_location` notebook to build a model
in Pyro, generate data, and infer the parameters of the data-generating
distribution.

In [None]:
import pyro
import pyro.distributions as dist
import torch

from contextlib import ExitStack
from pyro.contrib.util import iter_plates_to_shape, rexpand, rmv

def make_ces_model(rho_concentration, alpha_concentration, slope_mu, slope_sigma, observation_sd,
                   observation_label="y"):
    """Return a Pyro model for the Constant Elasticity of Substitution (CES) experiment.

    Args:
        rho_concentration: Dirichlet concentration for ``rho``. Only the first
            component ``r`` of each Dirichlet draw is used, rescaled to
            ``0.01 + 0.99 * r`` (for a two-entry concentration ``[a, b]`` that
            component is Beta(a, b) distributed).
        alpha_concentration: Dirichlet concentration for the basket weights
            ``alpha``. Its last dimension sets the number of goods per basket.
        slope_mu: ``loc`` of the LogNormal prior on the slope ``u``.
        slope_sigma: ``scale`` of the LogNormal prior on the slope ``u``.
        observation_sd: Factor multiplying the observation noise scale.
        observation_label: Name of the observation sample site (default ``"y"``).

    Returns:
        ``ces_model(design)``, a Pyro model function (see its docstring).
    """
    # Goods per basket. This was previously hard-coded as 3; deriving it from
    # alpha keeps the basket split consistent with the weights it is contracted with.
    n_goods = alpha_concentration.shape[-1]

    def ces_model(design):
        """Sample the CES latents and the noisy utility difference for ``design``.

        ``design[..., :n_goods]`` is basket ``x`` and
        ``design[..., n_goods:2 * n_goods]`` is basket ``x'``. The second-to-last
        dimension indexes the designs observed together (it becomes the event
        dimension of the observation). All earlier dimensions are batch
        dimensions; the plates from ``iter_plates_to_shape`` for them are entered
        before sampling.

        Returns:
            ``sigmoid(eta)``. The value recorded at the ``observation_label``
            site is ``eta`` itself (before the sigmoid), so data used to
            condition that site must be on the ``eta`` scale.
        """
        batch_shape = design.shape[:-2]
        with ExitStack() as stack:
            for plate in iter_plates_to_shape(batch_shape):
                stack.enter_context(plate)
            rho_shape = batch_shape + (rho_concentration.shape[-1],)
            # Keep rho within [0.01, 1] so the 1 / rho exponent below stays finite.
            rho = 0.01 + 0.99 * pyro.sample("rho", dist.Dirichlet(rho_concentration.expand(rho_shape))).select(-1, 0)
            alpha_shape = batch_shape + (n_goods,)
            alpha = pyro.sample("alpha", dist.Dirichlet(alpha_concentration.expand(alpha_shape)))
            slope = pyro.sample("slope", dist.LogNormal(slope_mu.expand(batch_shape), slope_sigma.expand(batch_shape)))
            # Expand the per-batch rho and slope along a trailing design axis of size design.shape[-2].
            rho, slope = rexpand(rho, design.shape[-2]), rexpand(slope, design.shape[-2])
            d1, d2 = design[..., :n_goods], design[..., n_goods:2 * n_goods]
            # Per basket: U(x) = (sum_i alpha_i * x_i ** rho) ** (1 / rho), with rmv
            # contracting the goods axis against alpha.
            U1rho = rmv(d1.pow(rho.unsqueeze(-1)), alpha).pow(1. / rho)
            U2rho = rmv(d2.pow(rho.unsqueeze(-1)), alpha).pow(1. / rho)
            mean = slope * (U1rho - U2rho)
            # NOTE(review): the exercise text writes sigma_eta with the squared norm
            # ||x - x'||_2^2, but this uses the unsquared 2-norm -- confirm which is intended.
            sd = slope * observation_sd * (1 + torch.norm(d1 - d2, dim=-1, p=2))
            # eta is what gets sampled (and conditioned) under observation_label;
            # the sigmoid is applied only to the returned value.
            emission_dist = dist.Normal(mean, sd).to_event(1)
            y = pyro.sample(observation_label, emission_dist)
            return torch.sigmoid(y)

    return ces_model

def elboguide(design, dim=10):
    """Variational guide for the latent sites ``rho``, ``alpha`` and ``slope``.

    Registers learnable parameters through ``pyro.param`` and samples each latent
    from its own factor under plates covering the batch dimensions of
    ``design``:

    - ``rho``: Dirichlet(``rho_concentration``);
    - ``alpha``: Dirichlet(``alpha_concentration``);
    - ``slope``: LogNormal(``slope_mu``, ``slope_sigma``).

    Args:
        design: Design tensor. Every dimension except the last two is treated as
            a batch dimension. The parameter tensors are expanded to that batch
            shape, so it must be expand-compatible with their ``(dim, 1)``
            leading shape (TODO confirm the intended design shape).
        dim: Leading size of the initial parameter tensors passed to
            ``pyro.param``.
    """
    positive = torch.distributions.constraints.positive

    # Variational parameters: concentrations and the LogNormal scale are kept positive.
    rho_conc = pyro.param("rho_concentration", torch.ones(dim, 1, 2), constraint=positive)
    alpha_conc = pyro.param("alpha_concentration", torch.ones(dim, 1, 3), constraint=positive)
    loc = pyro.param("slope_mu", torch.ones(dim, 1))
    scale = pyro.param("slope_sigma", 3. * torch.ones(dim, 1), constraint=positive)

    plate_shape = design.shape[:-2]
    with ExitStack() as plates:
        for p in iter_plates_to_shape(plate_shape):
            plates.enter_context(p)
        rho_dist = dist.Dirichlet(rho_conc.expand(plate_shape + (rho_conc.shape[-1],)))
        pyro.sample("rho", rho_dist)
        alpha_dist = dist.Dirichlet(alpha_conc.expand(plate_shape + (alpha_conc.shape[-1],)))
        pyro.sample("alpha", alpha_dist)
        slope_dist = dist.LogNormal(loc.expand(plate_shape), scale.expand(plate_shape))
        pyro.sample("slope", slope_dist)