In [1]:
import pyro
import torch
import numpy as np
from pyro.infer.importance import Importance

# Whatsapp puzzle

Probability of a day being Sunday is givn by Bernoulli distribution
$$P(A) = p_{A}(x) = \mathrm{Bern}\left(x;1/7\right)$$

If the day is Sunday, number of messages $B$ per hour is given by the Poisson distribution
$$P(B|A=1) = p_{B}(x|A=1) = \mathrm{Poi}(x;3)$$
otherwise
$$P(B|A=0) = p_{B}(x|A=0) = \mathrm{Poi}(x;10)$$

If receives $4$ messages on a day, what is the probability that the day is Sunday
$$P(A=1|B=4)$$

# Calculate

In [2]:
def model():
    is_sunday = pyro.sample("is_sunday", pyro.distributions.Bernoulli(1/7.0))
    rate = torch.tensor(3.) if is_sunday == 1 else torch.tensor(10.)
    numbers = pyro.sample("numbers", pyro.distributions.Poisson(rate))
    pyro.sample("obs", pyro.distributions.Poisson(rate), obs=torch.tensor(4))
    return is_sunday.item()

In [3]:
importance = Importance(model, guide=None, num_samples=20000)
importance.run()

<pyro.infer.importance.Importance at 0x7f59e1f12e20>

In [4]:
log_weights = np.array([w.item() for w in importance.log_weights])
values = np.array([t.nodes['is_sunday']['value'].item() for t in importance.exec_traces])
weights = np.exp(log_weights)
np.sum(weights * values) / np.sum(weights)

0.5970305704350286

# Sample from posterior via systematic resampling

In [5]:
import math
def systematic_resampling(log_weights, values):
    import torch

    mx = max(log_weights)
    weight_sum = sum(math.exp(log_weight - mx) for log_weight in log_weights)
    u_n = torch.distributions.Uniform(0, 1).sample().item()
    sum_acc = 0.0
    resamples = []
    for (log_weight, value) in zip(log_weights, values):
        weight = math.exp(log_weight - mx) * len(values) / weight_sum
        sum_acc += weight
        while u_n < sum_acc:
            u_n += 1
            resamples.append(value)
    return resamples


In [6]:
resamples = systematic_resampling(log_weights, values)                                                               

In [4]:
from pyro.infer.importance import Importance
import pyro.infer.mcmc as pyromcmc
from pathlib import Path
import pyro
import torch
import numpy as np

def walk_model():
    import pyro
    import torch

    start = pyro.sample("start", pyro.distributions.Uniform(0, 3))
    t = 0
    position = start
    distance = torch.tensor(0.0)
    while position > 0 and position < 10:
        step = pyro.sample(f"step_{t:03d}", pyro.distributions.Uniform(-1, 1))
        distance = distance + torch.abs(step)
        position = position + step
        t = t + 1
    pyro.sample("obs", pyro.distributions.Normal(1.1, 0.1), obs=distance)
    return start.item()
kernel = pyromcmc.HMC(
    walk_model,
    step_size=0.1,
    num_steps=50,
    adapt_step_size=False,
)
count = 100
mcmc = pyromcmc.MCMC(kernel, num_samples=count, warmup_steps=count // 10)
mcmc.run()
samples = mcmc.get_samples()
for key, value in samples.items():
    print(key, value.shape, value)
mcmc.summary() 

Sample: 100%|██████████| 110/110 [01:06,  1.66it/s, step size=1.00e-01, acc. prob=0.000]


start torch.Size([100]) tensor([1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559, 1.2559,
        1.2559])
step_000 torch.Size([100]) tensor([0.7614, 0.7614, 0.7614, 0.7614, 0.7614, 0.76