## [Example: Toy Mixture Model With Discrete Enumeration](http://pyro.ai/examples/toy_mixture_model_discrete_enumeration.html#example-toy-mixture-model-with-discrete-enumeration)

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import constraints
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.beta import Beta
from tqdm import tqdm

In [2]:
import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim
from pyro.ops.indexing import Vindex

In [12]:
def generate_data(num_obs):
    """Draw a synthetic dataset from the binary chain A -> B -> C.

    The Beta hyper-parameters are fixed below. One Bernoulli probability is
    sampled for A, and one per parent value (0 or 1) for B given A and for
    C given B. Then ``num_obs`` joint observations are drawn ancestrally.

    Args:
        num_obs: number of observations to draw.

    Returns:
        Tuple ``(prior, CPDs, data)``:
          * prior: dict with 'A' -> tensor of 2 Beta parameters, and
            'B'/'C' -> 2x2 tensors whose row k holds the (alpha, beta) used
            when the parent variable equals k.
          * CPDs: dict with 'p_A' -> scalar tensor, and 'p_B'/'p_C' ->
            tensors of 2 probabilities indexed by the parent value.
          * data: dict with 'A'/'B'/'C' -> float tensors of length num_obs
            holding 0./1. samples.
    """
    # NOTE(review): 10.6 breaks the symmetry with prior['B'] -- confirm it is
    # intentional and not a typo for 10.0.
    prior = {
        'A': torch.tensor([1.0, 10.0]),
        'B': torch.tensor([[10.0, 1.0], [1.0, 10.0]]),
        'C': torch.tensor([[10.0, 1.0], [1.0, 10.6]]),
    }

    # Unpacking a 1-D tensor yields (alpha, beta). Unpacking the transpose of
    # a 2x2 matrix yields its two columns, i.e. a batch of two Betas, one per
    # row (parent value).
    CPDs = {
        'p_A': Beta(*prior['A']).sample(),
        'p_B': Beta(*prior['B'].T).sample(),
        'p_C': Beta(*prior['C'].T).sample(),
    }

    # Ancestral sampling: each child's success probability is looked up by
    # its parent's sampled value.
    a_obs = Bernoulli(CPDs['p_A'] * torch.ones(num_obs)).sample()
    b_obs = Bernoulli(CPDs['p_B'][a_obs.long()]).sample()
    c_obs = Bernoulli(CPDs['p_C'][b_obs.long()]).sample()

    data = {'A': a_obs, 'B': b_obs, 'C': c_obs}
    return prior, CPDs, data


In [28]:
# Along dim=1, gather picks out[i][j] = t[i][index[i][j]]: each row of the
# index selects columns from the matching row of t (only rows 0 and 1 here,
# because the index has two rows).
t = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])
t.gather(1, torch.tensor([[0, 1], [1, 0]]))

tensor([[1, 2],
        [5, 4]])

In [13]:
# Seed torch's global RNG so the sampled CPDs and data shown below are
# reproducible under Restart & Run All.
torch.manual_seed(0)
generate_data(num_obs=5)

({'A': tensor([ 1., 10.]),
  'B': tensor([[10.,  1.],
          [ 1., 10.]]),
  'C': tensor([[10.0000,  1.0000],
          [ 1.0000, 10.6000]])},
 {'p_A': tensor(0.0444),
  'p_B': tensor([0.6791, 0.0967]),
  'p_C': tensor([0.8047, 0.0657])},
 {'A': tensor([0., 0., 0., 0., 0.]),
  'B': tensor([1., 1., 1., 0., 1.]),
  'C': tensor([1., 0., 0., 1., 1.])})

In [16]:
@pyro.infer.config_enumerate
def model(prior, obs, num_obs):
    """Pyro model for the binary chain A -> B -> C, with B latent.

    Every conditional probability gets a uniform Beta(1, 1) prior. A and C
    are observed inside a plate of size ``num_obs``. B is not observed and
    is marked for parallel enumeration so it can be summed out.

    Args:
        prior: currently unused; kept for signature compatibility.
            NOTE(review): presumably meant to carry Beta hyper-parameters
            (cf. generate_data's ``prior``) -- confirm.
        obs: dict read under keys 'A' and 'C'. The values are used as
            Bernoulli observations, so they are expected to hold 0./1. entries
            (presumably length ``num_obs``, e.g. generate_data's ``data``).
        num_obs: size of the data plate.
    """
    p_A = pyro.sample('p_A', dist.Beta(1, 1))
    # One probability per value (0/1) of the parent variable, hence 2 entries
    # declared as a single event.
    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    # Bug fix: previously spelled `tirch` and bound to `p_B`, which clobbered
    # the B CPD and left `p_C` undefined below.
    p_C = pyro.sample('p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))

    with pyro.plate('data_plate', num_obs):
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(num_obs)), obs=obs['A'])
        # B is latent. Vindex keeps the lookup correct when B carries extra
        # enumeration dimensions.
        B = pyro.sample(
            'B',
            dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]),
            infer={'enumerate': 'parallel'},
        )
        pyro.sample('C', dist.Bernoulli(Vindex(p_C)[B.type(torch.long)]), obs=obs['C'])