[**AN INTRODUCTION TO MODELS IN PYRO**](http://pyro.ai/examples/intro_part_i.html#An-Introduction-to-Models-in-Pyro)

In [1]:
import torch
import pyro
pyro.set_rng_seed(3)

# Stochastic functions - models

## Primitive stochastic functions

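* **primitive stochastic functions** are distribution objects from `torch.distributions` or `pyro.distributions`: they can be sampled, and the probability of any output can be computed explicitly (e.g. via `log_prob`)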

* **probabilistic programs** are built by composing **primitive stochastic functions** and **deterministic computation**.
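As a minimal sketch of the second point, using only `torch.distributions`, the function below composes a primitive `Normal` draw with a deterministic `exp`, which yields a log-normal random variable. The name `log_normal_draw` is hypothetical, chosen for illustration.

```python
import torch

def log_normal_draw(loc=0., scale=1.):
    # primitive stochastic function: one draw from a Normal
    z = torch.distributions.Normal(loc, scale).sample()
    # deterministic computation applied to the random draw
    return torch.exp(z)

torch.manual_seed(0)
draws = torch.stack([log_normal_draw() for _ in range(1000)])
print(bool((draws > 0).all()))  # True: exp(z) is always positive
```

The result is itself a stochastic function, so it can be composed further in the same way.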

### ex1. Unit normal

In [3]:
# Set mean and std
loc = 0.
scale = 1.

# Create distribution object - the primitive stochastic function
normal = torch.distributions.Normal(loc, scale)

In [13]:
# Draw a sample
x = normal.rsample()
print(f'Sample {x}')
print(f"Log prob of getting sample: {normal.log_prob(x)}")

Sample 0.9774317741394043
Log prob of getting sample: -1.3966249227523804


**NOTE**:
* `pyro.distributions` objects are wrappers around `torch.distributions`

### ex2. Simple weather model

In [29]:
def weather1():
    
    # Sample how the sky is
    sky = torch.distributions.Bernoulli(.3).sample()
    sky = 'cloudy' if sky.item() == 1. else 'sunny'
    return sky

In [35]:
[weather1() for i in range(5)]

['sunny', 'sunny', 'sunny', 'cloudy', 'cloudy']

In [28]:
def weather2():
    
    # Sample how the sky is
    sky = torch.distributions.Bernoulli(.3).sample()
    sky = 'cloudy' if sky.item() == 1. else 'sunny'
    
    # Sample temperature based on sky
    sky2temp = {
        'cloudy': (55., 10.),
        'sunny': (75., 15.)
    }
    
    temp = torch.distributions.Normal(*sky2temp[sky]).rsample()
    
    return sky, temp.item()

In [37]:
[weather2() for i in range(5)]

[('cloudy', 66.80128479003906),
 ('cloudy', 50.844390869140625),
 ('sunny', 67.84062957763672),
 ('sunny', 66.93984985351562),
 ('cloudy', 59.198150634765625)]

## `pyro.sample` Primitive

* to turn ordinary stochastic functions into Pyro programs, replace:
    * `torch.distributions` with `pyro.distributions`
    * direct `*.sample()` and `*.rsample()` calls with `pyro.sample(<name>, <distribution>)`, where `<name>` uniquely identifies the random choice

### ex2 (Pyro version). pyro-weather

In [39]:
def pyro_weather():
    
    # Sample how the sky is
    sky = pyro.sample('sky', pyro.distributions.Bernoulli(.3))
    sky = 'cloudy' if sky.item() == 1. else 'sunny'
    
    # Sample temperature based on sky
    sky2temp = {
        'cloudy': (55., 10.),
        'sunny': (75., 15.)
    }
    
    temp = pyro.sample('temp', pyro.distributions.Normal(*sky2temp[sky]))
    
    return sky, temp.item()

In [40]:
[pyro_weather() for i in range(5)]

[('sunny', 64.62797546386719),
 ('sunny', 57.524993896484375),
 ('sunny', 64.90321350097656),
 ('sunny', 74.11502075195312),
 ('cloudy', 68.44393920898438)]

**NOTE**:
* now, `pyro_weather()` specifies a joint probability distribution over `sky` and `temp`, i.e. $P(sky, temp) = P(temp|sky)P(sky)$

* we can now ask: if I observe a temperature of 70 degrees, how likely is it to be cloudy? i.e. $P(sky=\text{cloudy}|temp=70)$

* by Bayes' rule, $P(sky=\text{cloudy}|temp=70) = \frac{P(sky=\text{cloudy},temp=70)}{P(temp=70)} = 
\frac{P(sky=\text{cloudy},temp=70)}{P(sky=\text{cloudy},temp=70)+P(sky=\text{sunny},temp=70)}$, where the terms involving $temp=70$ are densities because `temp` is continuous

* Pyro's inference algorithms (e.g. variational inference) approximate such posteriors numerically when exact computation is intractable
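For this two-state model the posterior can also be computed exactly, which gives a reference value for any approximate method. A minimal sketch with `torch.distributions`, reusing the prior and per-sky temperature parameters from `pyro_weather()`:

```python
import torch
from torch.distributions import Normal

p_cloudy = torch.tensor(0.3)   # P(sky=cloudy), from Bernoulli(.3)
temp = torch.tensor(70.)

# joint densities p(sky, temp=70) = p(temp=70 | sky) * P(sky)
joint_cloudy = Normal(55., 10.).log_prob(temp).exp() * p_cloudy
joint_sunny = Normal(75., 15.).log_prob(temp).exp() * (1 - p_cloudy)

# Bayes' rule: normalise by the sum over both sky states
posterior_cloudy = joint_cloudy / (joint_cloudy + joint_sunny)
print(f'P(cloudy | temp=70) = {posterior_cloudy.item():.3f}')
```

This gives roughly 0.18: a 70-degree reading is more typical of sunny days (mean 75) than cloudy ones (mean 55), and sunny days are also more common a priori.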

## Stochastic recursion

* stochastic functions can call other stochastic functions, so models compose like ordinary Python functions

### ex3. Ice-cream sales

In [44]:
def ice_cream_sales():
    # Call stochastic function
    sky, temp = pyro_weather()
    
    # Expected sales based on weather
    expected_sales = 200. if sky == 'sunny' and temp > 80. else 50.
    
    # Stochastic ice cream sales
    sales = pyro.sample('sales', pyro.distributions.Normal(expected_sales, 10.))
    return sales.item()
    
    

In [45]:
[ice_cream_sales() for i in range(10)]

[51.20252227783203,
 30.857507705688477,
 44.13951110839844,
 40.52595138549805,
 39.3331413269043,
 45.06064224243164,
 53.64537811279297,
 185.06890869140625,
 51.9947509765625,
 65.69964599609375]
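In `ice_cream_sales()` the mean jumps to 200 only on sunny days warmer than 80 degrees. As a sketch using the same parameters as `pyro_weather()` (P(sunny) = 0.7, sunny temperatures ~ Normal(75, 15)), the probability of that regime and the overall expected sales follow from the Normal CDF:

```python
import torch
from torch.distributions import Normal

p_sunny = 0.7
# P(temp > 80 | sunny) = 1 - CDF of Normal(75, 15) at 80
p_hot_given_sunny = 1 - Normal(75., 15.).cdf(torch.tensor(80.))
p_high = p_sunny * p_hot_given_sunny

# E[sales] = 200 * P(high regime) + 50 * P(low regime); the sales noise has mean 0
expected_sales = 200. * p_high + 50. * (1 - p_high)
print(f'P(high-sales regime) = {p_high.item():.3f}')
print(f'E[sales] = {expected_sales.item():.1f}')
```

So roughly one draw in four should land near 200.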

## Random control flow

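* stochastic functions can use arbitrary Python control flow (`if`/`else`, loops, recursion), including branches that depend on sampled values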

### ex4. random control flow

In [46]:
def geometric(p, t=None):
    t = t or 0
    
    x = pyro.sample(f'x_{t}', pyro.distributions.Bernoulli(p))
    
    if x.item() == 1:
        return 0
    else:
        return 1 + geometric(p, t+1)

In [83]:
# empirical mean of 200 draws vs. the analytic mean (1-p)/p
sum(geometric(.3) for i in range(200)) / 200, 1/.3*.7

(2.325, 2.3333333333333335)
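`geometric(p)` counts the failures before the first success, whose mean is $(1-p)/p$; that is the `1/.3*.7` term above. As a sketch, the same distribution can be sampled with a loop instead of recursion, and `torch.distributions.Geometric` (which uses the same failures-before-first-success convention) gives the analytic mean directly. `geometric_iter` is a hypothetical name.

```python
import torch
from torch.distributions import Bernoulli, Geometric

def geometric_iter(p):
    # loop instead of recursion: count failures until the first success
    failures = 0
    while Bernoulli(p).sample().item() == 0:
        failures += 1
    return failures

torch.manual_seed(0)
p = 0.3
n = 5000
empirical = sum(geometric_iter(p) for _ in range(n)) / n
analytic = Geometric(probs=torch.tensor(p)).mean.item()  # (1 - p) / p
print(empirical, analytic)
```

The loop avoids Python's recursion limit when `p` is very small. In a Pyro program each draw would still need a distinct site name such as `f'x_{t}'`, just as in the recursive version.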

## Higher-order stochastic functions

* stochastic functions can also take or return other stochastic functions
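A minimal sketch of the idea in plain `torch.distributions`, with hypothetical names: `make_noisy_scaler` draws a random scale once and returns a new stochastic function that reuses that scale on every call.

```python
import torch
from torch.distributions import Normal

def make_noisy_scaler():
    # outer randomness: one latent scale, drawn when the function is built
    scale = Normal(1., 0.1).sample().abs()
    def noisy_scaler(x):
        # inner randomness: fresh noise on each call, sharing the latent scale
        return x * scale + Normal(0., 0.01).sample()
    return noisy_scaler

torch.manual_seed(0)
f = make_noisy_scaler()   # a random function
a, b = f(torch.tensor(2.)), f(torch.tensor(2.))
print(a, b)  # same scale, different noise
```

In a Pyro version, the `pyro.sample` calls inside the returned function would need distinct names on each call, for example by passing an index.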