# Normal-Normal Model

In [None]:
import torch
import torch.distributions as dist
import pandas as pd
import seaborn as sns

# Apply seaborn's "darkgrid" theme to every plot produced below.
sns.set_style(style="darkgrid")

import beanmachine.ppl as bm

class NormalNormal:
    """Model with a Normal(0, 1) variable ``mu`` and Normal(mu, 1) variables ``x(i)``."""

    @bm.random_variable
    def mu(self):
        """Return the Normal(0, 1) distribution assigned to ``mu``."""
        prior = dist.Normal(loc=0, scale=1)
        return prior

    @bm.random_variable
    def x(self, i):
        """Return the distribution of ``x(i)``: Normal centred on ``mu`` with scale 1.

        ``i`` only distinguishes one ``x`` variable from another; it does not
        enter the distribution itself.
        """
        center = self.mu()
        return dist.Normal(loc=center, scale=1)

In [None]:
from beanmachine.ppl.experimental.vi.VariationalInfer import MeanFieldVariationalInference

# Build the model instance whose random variables are queried and observed.
model = NormalNormal()

vi = MeanFieldVariationalInference()

# Query mu, conditioning on a single observed value x(0) = 10.
queries = [model.mu()]
obs = {model.x(0): torch.tensor(10.0)}

# Run inference; the result is indexed by queried variable (see its use below).
vi_dicts = vi.infer(
    queries,
    obs,
    num_iter=500,
    lr=1e-2,
    num_flows=1,
)

In [None]:
# Evaluate the density returned by VI for mu on a 100-point grid over [0, 12]
# and draw it as a line plot.
#
# `steps` is passed explicitly: relying on the implicit default of 100 was
# deprecated, and newer PyTorch releases raise a TypeError when it is omitted.
# unsqueeze(1) turns the grid into a (100, 1) column before calling log_prob.
x = torch.linspace(0, 12, steps=100).unsqueeze(1)
p = torch.exp(vi_dicts[model.mu()].log_prob(x))
sns.relplot(data=pd.DataFrame({
    "x": x.flatten().detach().numpy(),
    "p": p.flatten().detach().numpy(),
}), x="x", y="p", kind="line")

In [None]:
# Draw a (100, 1)-shaped batch of samples from the VI result for mu and plot
# their distribution.
sns.displot(
    vi_dicts[model.mu()]
    .sample((100, 1))
    .detach()
    .numpy()
)

In [None]:
import beanmachine.ppl as bm

# Run single-site HMC on the same queries and observations used for VI above,
# then plot the flattened draws for mu.
mcmc = bm.SingleSiteHamiltonianMonteCarlo(path_length=10)
samples = mcmc.infer(
    queries,
    obs,
    num_samples=100,
    num_chains=1,
)
sns.displot(
    samples[model.mu()]
    .flatten()
    .numpy()
)