In [1]:
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, TraceEnum_ELBO, Predictive, config_enumerate
from pyro.optim import ClippedAdam
from pyro.infer.autoguide import AutoNormal, AutoDiagonalNormal, AutoDelta
from pprint import pprint

In [2]:
# `@config_enumerate` marks discrete sample sites for parallel enumeration under TraceEnum_ELBO
# (see https://pyro.ai/examples/enumeration.html). It is not strictly required here because the
# only discrete site, A, already requests infer={"enumerate": "parallel"} explicitly.
@config_enumerate
def BN_model(A_obs=None, B_obs=None, C_obs=None, N=None):
    """Generative model of a three-node Bayesian network in which C depends on A and B.

    Global latent variables:
        weights             ~ Dirichlet([1, 1, 1])                -- category probabilities of A
        beta_concentrations ~ Gamma(2, rate=0.5), two entries     -- parameters of B's Beta
        C_weights           ~ Beta(1, 1), three entries           -- factor k(A) in p_C = B * k(A)
    Per-datapoint variables (inside the 'data' plate of size N):
        A ~ Categorical(weights)
        B ~ Beta(beta_concentrations[0], beta_concentrations[1])
        C ~ Binomial(probs=B * C_weights[A])   (total_count defaults to 1, i.e. a 0/1 outcome)

    Args:
        A_obs, B_obs, C_obs: optional observed values, either 1-D tensors of length N or (N, 1)
            tensors such as per-site samples returned by `Predictive`. They are flattened to
            shape (N,) before use. All observations that are given must have the same length.
        N: number of data points. If omitted it is taken from the observations, falling back
            to 1 when nothing is observed. With N == 0 only the global sites are sampled.

    Raises:
        AssertionError: if the observation lengths disagree with each other or with N.
    """
    # Take N from the first observation given (unless passed explicitly) and require every
    # other observation to agree with it.
    for name, obs in (('A_obs', A_obs), ('B_obs', B_obs), ('C_obs', C_obs)):
        if obs is None:
            continue
        if N is None:
            N = len(obs)
        assert N == len(obs), f"{name} has length {len(obs)}, but N = {N}"

    # Flatten (N, 1) -> (N,) to match the plate's batch shape. Unlike .squeeze(), reshape(-1)
    # keeps a length-1 dimension when N == 1 instead of collapsing the tensor to a scalar.
    A_obs, B_obs, C_obs = (None if obs is None else obs.reshape(-1) for obs in (A_obs, B_obs, C_obs))

    if N is None:
        N = 1

    # Prior over the category probabilities of A. Pyro distinguishes a distribution's
    # "batch_shape" (independent copies) from its "event_shape" (the shape of a single draw);
    # see https://pyro.ai/examples/tensor_shapes.html. Dirichlet(torch.ones(3)) already has
    # event_shape (3,) and an empty batch_shape, so .to_event() is a no-op here. It matters for
    # the Gamma and Beta priors below, whose vector parameters would otherwise form a batch of
    # independent scalar random variables.
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(3)).to_event())

    # Prior over the two (positive) concentration parameters of the Beta distribution of B.
    beta_concentrations = pyro.sample('beta_concentrations', dist.Gamma(concentration=torch.tensor([2., 2.]),
                                                                        rate=torch.tensor([0.5, 0.5])).to_event())

    # Prior over the per-category weights k in p_C = B * k(A).
    C_weights = pyro.sample('C_weights', dist.Beta(torch.tensor([1., 1., 1.]), torch.tensor([1., 1., 1.])).to_event())

    if N > 0:
        with pyro.plate('data', N):
            A = pyro.sample('A', dist.Categorical(weights), obs=A_obs, infer={"enumerate": "parallel"})
            B = pyro.sample('B', dist.Beta(beta_concentrations[0], beta_concentrations[1]), obs=B_obs)
            C = pyro.sample('C', dist.Binomial(probs=B * C_weights[A]), obs=C_obs)

In [3]:
# Visualize the model. This requires graphviz, which is not installed here, so the call below is
# left commented out and has not been tested:
#pyro.render_model(lambda: BN_model(N=100), render_distributions=True)

In [4]:
# Helper function
def summarize_samples(samples, quantiles=(0.05, 0.95)):
    """Summarize samples per site: mean, standard deviation and empirical quantiles over dim 0.

    Adapted from https://pyro.ai/examples/bayesian_regression.html#Model-Evaluation.

    Args:
        samples: dict mapping site name -> tensor whose first dimension indexes the samples
            (e.g. the output of `Predictive`).
        quantiles: quantile levels in (0, 1] to report. Each level q is stored under the key
            f"{q * 100:g}%", so the default gives the keys "5%" and "95%".

    Returns:
        dict mapping site name -> {"mean", "std", "<q>%", ...} tensors reduced over dim 0.
        Sites with a non-floating-point dtype (e.g. the categorical A) are skipped and a message
        is printed for each of them.
    """
    param_stats = {}
    for k, v in samples.items():
        if not torch.is_floating_point(v):
            print(f'Dropping variable {k} from summary statistics since it is not a float.')
            continue
        num_samples = len(v)
        stats = {"mean": torch.mean(v, 0), "std": torch.std(v, 0)}
        for q in quantiles:
            # kthvalue is 1-indexed (k-th smallest). Clamp k to [1, num_samples]: with fewer than
            # 20 samples int(num_samples * 0.05) is 0, and kthvalue(0) raises an error.
            kth = min(max(1, int(num_samples * q)), num_samples)
            stats[f"{q * 100:g}%"] = v.kthvalue(kth, dim=0)[0]
        param_stats[k] = stats
    return param_stats

In [9]:
# Fix the RNG seed so that Restart & Run All reproduces the samples and the SVI fit below.
# This is the first cell in the notebook that draws random numbers.
SEED = 0
pyro.set_rng_seed(SEED)

# Sample from the prior distribution, see here: https://forum.pyro.ai/t/samples-from-prior-distribution/1740/2
prior_samples = Predictive(BN_model, posterior_samples={}, num_samples=1000)()
print("SUMMARY: SAMPLES FROM PRIOR DISTRIBUTION\n")
pprint(summarize_samples(prior_samples))

SUMMARY: SAMPLES FROM PRIOR DISTRIBUTION

Dropping variable A from summary statistics since it is not a float.
{'B': {'5%': tensor([0.0547]),
       '95%': tensor([0.9386]),
       'mean': tensor([0.4934]),
       'std': tensor([0.2751])},
 'C': {'5%': tensor([0.]),
       '95%': tensor([1.]),
       'mean': tensor([0.2200]),
       'std': tensor([0.4145])},
 'C_weights': {'5%': tensor([[0.0509, 0.0539, 0.0510]]),
               '95%': tensor([[0.9461, 0.9481, 0.9499]]),
               'mean': tensor([[0.4949, 0.5013, 0.5111]]),
               'std': tensor([[0.2937, 0.2879, 0.2885]])},
 'beta_concentrations': {'5%': tensor([[0.6700, 0.7640]]),
                         '95%': tensor([[8.9570, 9.3770]]),
                         'mean': tensor([[3.8575, 3.8982]]),
                         'std': tensor([[2.7063, 2.7227]])},
 'weights': {'5%': tensor([[0.0238, 0.0261, 0.0258]]),
             '95%': tensor([[0.7493, 0.7722, 0.7515]]),
             'mean': tensor([[0.3356, 0.3325, 0.3319]]

In [10]:
# Fix ground-truth values for the global parameters and draw a synthetic dataset from the model
# conditioned on them. The SVI cell below then checks whether these values can be recovered
# from the data alone.
weights = torch.tensor([0.2, 0.2, 0.6])
beta_concentrations = torch.tensor([0.5, 2.0])
C_weights = torch.tensor([0.5, 1.0, 0.2])
true_params = dict(weights=weights,
                   beta_concentrations=beta_concentrations,
                   C_weights=C_weights)

BN_model_conditioned = pyro.poutine.condition(BN_model, data=true_params)
synthetic_predictive = Predictive(BN_model_conditioned, posterior_samples={}, num_samples=5000)
parametrized_samples = synthetic_predictive()

print("SUMMARY: SAMPLES FROM CONDITIONED DISTRIBUTION\n")
pprint(summarize_samples(parametrized_samples))

SUMMARY: SAMPLES FROM CONDITIONED DISTRIBUTION

Dropping variable A from summary statistics since it is not a float.
{'B': {'5%': tensor([0.0010]),
       '95%': tensor([0.6618]),
       'mean': tensor([0.2024]),
       'std': tensor([0.2170])},
 'C': {'5%': tensor([0.]),
       '95%': tensor([1.]),
       'mean': tensor([0.0792]),
       'std': tensor([0.2701])},
 'C_weights': {'5%': tensor([[0.5000, 1.0000, 0.2000]]),
               '95%': tensor([[0.5000, 1.0000, 0.2000]]),
               'mean': tensor([[0.5000, 1.0000, 0.2000]]),
               'std': tensor([[0., 0., 0.]])},
 'beta_concentrations': {'5%': tensor([[0.5000, 2.0000]]),
                         '95%': tensor([[0.5000, 2.0000]]),
                         'mean': tensor([[0.5000, 2.0000]]),
                         'std': tensor([[0., 0.]])},
 'weights': {'5%': tensor([[0.2000, 0.2000, 0.6000]]),
             '95%': tensor([[0.2000, 0.2000, 0.6000]]),
             'mean': tensor([[0.2000, 0.2000, 0.6000]]),
           

In [7]:
# Now let's try to estimate those parameters using SVI.
# The guide covers only the global sites (weights, beta_concentrations, C_weights); the
# per-datapoint sites A, B and C are hidden from it via poutine.block.
NUM_SVI_STEPS = 5000
pyro.clear_param_store()
guide = AutoNormal(pyro.poutine.block(BN_model, hide=["A", "B", "C"]))

svi = SVI(model=BN_model,
          guide=guide,
          optim=ClippedAdam({"lr": 0.01, 'clip_norm': 1.0}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))  # if we didn't have a discrete variable, we'd use Trace_ELBO

# Keep every step's loss estimate so convergence can be inspected afterwards (e.g. by plotting
# `losses`), and report a few checkpoints while training.
losses = []
for step in range(NUM_SVI_STEPS):
    loss = svi.step(parametrized_samples['A'], parametrized_samples['B'], parametrized_samples['C'])
    losses.append(loss)
    if step % 1000 == 0 or step == NUM_SVI_STEPS - 1:
        print(f"step {step:5d}  loss = {loss:.2f}")

In [11]:
# Did we estimate the parameters correctly? Draw samples of the global parameters from the
# fitted guide and summarize them. Compare the summary against the ground-truth values the
# synthetic data was generated with (weights, beta_concentrations, C_weights).
fitted_sites = ("weights", "beta_concentrations", "C_weights")
posterior_predictive = Predictive(
    BN_model,
    guide=guide,
    num_samples=1000,
    return_sites=fitted_sites,
)
posterior_samples = posterior_predictive()

print("SUMMARY: SAMPLES FROM POSTERIOR DISTRIBUTION\n")
pprint(summarize_samples(posterior_samples))

SUMMARY: SAMPLES FROM POSTERIOR DISTRIBUTION

{'C_weights': {'5%': tensor([[0.4659, 0.8215, 0.1502]]),
               '95%': tensor([[0.6467, 0.9938, 0.2309]]),
               'mean': tensor([[0.5590, 0.9421, 0.1885]]),
               'std': tensor([[0.0570, 0.0659, 0.0243]])},
 'beta_concentrations': {'5%': tensor([[0.4759, 1.8436]]),
                         '95%': tensor([[0.5078, 2.0155]]),
                         'mean': tensor([[0.4919, 1.9300]]),
                         'std': tensor([[0.0098, 0.0518]])},
 'weights': {'5%': tensor([[0.1916, 0.1859, 0.5772]]),
             '95%': tensor([[0.2224, 0.2128, 0.6109]]),
             'mean': tensor([[0.2070, 0.1990, 0.5939]]),
             'std': tensor([[0.0092, 0.0080, 0.0105]])}}
