# Pyro

https://www.youtube.com/watch?v=aLFJ5ERxt2c

```
import pyro
from pyro import distributions

x = pyro.sample("x", distributions.Bernoulli(0.5))
x
```

Output (one random draw):

```
tensor(1.)
```

### Pyro Primitives

Passing `obs` to `pyro.sample` makes it an observed site. Its value is fixed to the observed data instead of being drawn, and during inference the learnable parameters are fit to that observed data.

`pyro.sample("data", distributions.Normal(0, 1), obs=data)`

We also have the `pyro.param` statement, which declares a learnable parameter. We register it with a name and an initial value. We can optionally pass a constraint (for example, positivity), and Pyro then keeps the parameter valid during optimization by updating an unconstrained tensor and mapping it into the constrained space.

`pyro.param("theta", torch.ones(100), constraint=constraints.positive)`

#### Plate

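Using `pyro.plate` we can declare that the sample sites inside it are conditionally independent (exchangeable). Used as a context manager, as below, the plate is vectorized: all points are sampled or scored in parallel along a batch dimension. Passing a subsample size draws a random minibatch of indices `ids` on each call.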

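```
# `func` stands for the per-point likelihood distribution.
# The plate name also registers a site, so the sample site needs a different name.
with pyro.plate("data", len(data), subsample_size=batch_size) as ids:
    pyro.sample("obs", func, obs=data[ids])
```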

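Sometimes we want to treat individual data points differently, for example with per-point control flow. Iterating over the plate gives a sequential plate: an ordinary Python loop over the (subsampled) indices, in which every site needs a unique name.

```
# `func(x)` stands for a per-point likelihood distribution
for i in pyro.plate("data", len(data), subsample_size=batch_size):
    pyro.sample("data_%i" % i, func(x), obs=data[i])
```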



### Pyro Models

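Pyro models are Python functions. Numbering the lines of the `model` body:

- Line 1: Declare a learnable parameter $p$, constrained to the simplex so it is a valid probability vector.
- Line 2: Sample a category $c$ from a Categorical distribution parameterized by $p$.
- Line 3: Control flow that depends on the sampled value of $c$. This is what makes the model a program; without it, our model is just a graphical model.
- Line 4: An observe statement that conditions our model on our data.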

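```
import torch
import pyro
from pyro.distributions import Categorical, Normal, constraints

def model(data):
    p = pyro.param("p", torch.ones(10) / 10, constraint=constraints.simplex)
    c = pyro.sample("c", Categorical(p))
    if c > 0:
        pyro.sample("obs", Normal(helper(c - 1), 1.), obs=data)

def helper(c):
    # draw a 10-dimensional vector so that x[c] is a valid index
    x = pyro.sample("x", Normal(torch.zeros(10), 10.).to_event(1))
    return x[c]
```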

### Pyro for Semi-Supervised Learning

Numbering the lines of the `model` body:

- Line 1: `pyro.module` goes through the neural network and calls `pyro.param` on each of its parameters, like a recursive param statement.
- Line 2: Sample the latent "style" embedding vector from a standard Gaussian prior.
- Line 3: Sample a random digit.
- Line 4: Conditioned on the digit and the style (decoded together), sample a binarized image. If no image is provided, this is a generative model. If we do provide an image, the model is conditioned on it, and we can then try to infer the conditional distribution of digit and style given the image, which is what the `guide` approximates.

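```
import torch
import pyro
from pyro.distributions import Bernoulli, Categorical, Normal

decoder = ...        # neural network: (style, digit) -> per-pixel probabilities
encoder_digit = ...  # neural network: image -> digit class probabilities
encoder_style = ...  # neural network: (image, digit) -> (loc, scale) of style

def model(image=None):
    pyro.module("decoder", decoder)  # fancy param statement. This declares pyro.param on all the neural net parameters
    style = pyro.sample("style", Normal(torch.zeros(20), torch.ones(20)).to_event(1))
    digit = pyro.sample("digit", Categorical(torch.ones(10) * 0.1))
    image = pyro.sample("image", Bernoulli(decoder(style, digit)).to_event(1), obs=image)
    return image

def guide(image):
    pyro.module("encoder_digit", encoder_digit)
    pyro.module("encoder_style", encoder_style)
    digit = pyro.sample("digit", Categorical(encoder_digit(image)))
    loc, scale = encoder_style(image, digit)
    style = pyro.sample("style", Normal(loc, scale).to_event(1))
    return digit, style
```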