In [85]:
from IPython.display import Image
from IPython.core.display import HTML 
import pyro
import torch
from torch.distributions.constraints import unit_interval
from pyro.distributions import Bernoulli
import pyro.distributions as dist
import numpy as np
from torch.distributions import constraints
from pyro import poutine
import pandas as pd
from pyro.infer.mcmc import MCMC
from pyro.infer.mcmc.nuts import HMC
from pyro.infer import EmpiricalMarginal
import matplotlib.pyplot as plt
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.optim import Adam

In [2]:
pyro.enable_validation()
pyro.set_rng_seed(0)

In [3]:
%matplotlib inline

In [4]:
Image(url= "https://i.stack.imgur.com/t99mv.png")

In [30]:
def model():
    p_rain = pyro.param("p_rain", torch.tensor(0.2), constraint=unit_interval)
    p_sprinkler = pyro.param("p_sprinkler", torch.tensor([0.4, 0.01]),
                             constraint=unit_interval)
    p_wet = pyro.param("p_wet", torch.tensor([[0.0, 0.9], [0.8, 0.99]]),
                       constraint=unit_interval)

    rain = pyro.sample("rain", Bernoulli(p_rain))
    sprinkler = pyro.sample("sprinkler",
                            Bernoulli(p_sprinkler[rain.long()]))
    wet = pyro.sample("wet", Bernoulli(p_wet[rain.long(), sprinkler.long()]))
    return wet

## Inference: Going backwards

Condition the model on an observation of Grass being wet

In [32]:
conditioned_model = pyro.condition(model, data={'wet': torch.tensor(1.)})

In [34]:
posterior = pyro.infer.Importance(conditioned_model, num_samples=5000).run()

In [35]:
marginal = pyro.infer.EmpiricalMarginal(posterior, "rain")

In [36]:
rain_samples = np.array([marginal().item() for _ in range(10000)])

In [37]:
rain_samples.mean()

0.3491

Re-writing the params by using categorical variables

Following code samples from (https://github.com/robertness/causalML/blob/7c196d1bd21a9ac168a198aba149a7bb6fd4b69b/tutorials/introduction/intro_to_generative_ML_with_Pyro.ipynb)

In [21]:
Image(url= "https://i.stack.imgur.com/t99mv.png")

In [155]:
def model():
    rain = pyro.sample("rain", dist.Categorical(torch.tensor([0.8, 0.2])))
    
    p_sprinkler = torch.tensor([[0.6, 0.4], [0.99,0.01]])
    p_wet = torch.tensor([[[1, 0],[0.2,0.8]],[[0.1, 0.9],[0.01, 0.99]]])
    
    sprinkler = pyro.sample("sprinkler",
                            dist.Categorical(p_sprinkler[rain]))

    wet = pyro.sample("wet", 
                      dist.Categorical(p_wet[sprinkler][rain]))
                            
    return rain.item(), sprinkler.item(), wet.item()

In [115]:
samples = [model() for _ in range (10000)]

In [116]:
samples_df = pd.DataFrame(samples, columns=['rain', 'sprinkler', 'wet'])

Rain = True 

In [117]:
samples_df.query("rain == 1").shape[0] / samples_df.shape[0]

0.2039

Sprinkler = True & Rain = True

In [118]:
samples_df.query("rain == 1 & sprinkler == 1").shape[0] / samples_df.query("rain==1").shape[0]

0.010299166257969592

Sprinkler = True & Rain = False

In [119]:
samples_df.query("rain == 0 & sprinkler ==1").shape[0] / samples_df.query("rain==0").shape[0]

0.3995729179751287

Sprinkler = True & Rain = False & Wet = False

In [120]:
samples_df.query("rain == 0 & sprinkler ==1 & wet == 0").shape[0] / samples_df.query("rain==0 & sprinkler==1").shape[0]

0.095881798176674

Confirmed correct

In [156]:
conditioned_model = pyro.condition(model, data={'wet': torch.tensor(1)})

In [157]:
posterior = pyro.infer.Importance(conditioned_model, num_samples=5000).run()

In [158]:
marginal = pyro.infer.EmpiricalMarginal(posterior, sites=["rain"])

In [159]:
rain_samples = np.array([marginal().item() for _ in range(10000)])

In [160]:
rain_samples.mean()

0.3554

In [152]:
marginal = pyro.infer.EmpiricalMarginal(posterior, sites=["sprinkler"])

In [153]:
np.array([marginal().item() for _ in range(10000)]).mean()

0.6502

---

### Using SVI instead of Importance Sampling

In [141]:
def model(rain_prob, sprinkler_prob):
    rain = pyro.sample("rain", dist.Bernoulli(rain_prob))
    p_wet = torch.tensor([[[1, 0],[0.2,0.8]],[[0.1, 0.9],[0.01, 0.99]]])

    sprinkler = pyro.sample("sprinkler", dist.Bernoulli(sprinkler_prob))
    wet = pyro.sample("wet", 
                      dist.Categorical(p_wet[sprinkler.long()][rain.long()]))
                            
    return rain.item(), sprinkler.item(), wet.item()

In [142]:
def model_guide(rain_prob, sprinkler_prob):
    rain_prob = pyro.param('rain_prob', rain_prob, constraint=constraints.unit_interval)
    sprinkler_prob = pyro.param('sprinkler_prob', sprinkler_prob, constraint=constraints.unit_interval)
    try:
        rain = pyro.sample('rain', dist.Bernoulli(rain_prob))
    except RuntimeError as e:
        print("rain_prob: {}".format(rain_prob))
        raise e
    sprinkler = pyro.sample('sprinkler', dist.Bernoulli(sprinkler_prob))

In [147]:
def svi_test():
	rain_prob_prior = torch.tensor(.2)
	sprinkler_prob_prior = torch.tensor(.4)
	conditioned_lawn = pyro.condition(model, data={"wet":torch.tensor([1.])})
	# guide = AutoGuide(lawn)
	# set up the optimizer
	adam_params = {"lr": 0.05, "betas": (0.90, 0.999)}
	optimizer = Adam(adam_params)

	# setup the inference algorithm
	svi = SVI(conditioned_lawn, model_guide, optimizer, loss=pyro.infer.Trace_ELBO())

	n_steps = 5000
	# do gradient steps
	for step in range(n_steps):
		svi.step(rain_prob_prior, sprinkler_prob_prior)
		if step % 1000 == 0:
			print("step: ", step)
			for p in ['rain_prob', 'sprinkler_prob']:
				print(p, ": ", pyro.param(p).item())
	for p in ['rain_prob', 'sprinkler_prob']:
		print(p, ": ", pyro.param(p).item())

In [154]:
pyro.clear_param_store()
svi_test()

step:  0
rain_prob :  0.20812010765075684
sprinkler_prob :  0.4120577871799469
step:  1000
rain_prob :  0.26252007484436035
sprinkler_prob :  0.9929171800613403
step:  2000
rain_prob :  0.24675895273685455
sprinkler_prob :  0.9966957569122314
step:  3000
rain_prob :  0.20035420358181
sprinkler_prob :  0.9978237152099609
step:  4000
rain_prob :  0.16417889297008514
sprinkler_prob :  0.9994780421257019
rain_prob :  0.17275157570838928
sprinkler_prob :  0.9994427561759949
