# Lab 7: Probabilistic Programming

So up until now, we've been working on neural networks, which I think of as hitting a problem with a hammer until it gets fixed... However, another school of thought in machine learning is probabilistic graphical models (PGMs). In PGMs, we propose a generative process that our data arose from, then learn the parameters of that process that best fit the data. A generative process can be thought of as a composition of probability distributions (rather than the composition of linear and non-linear transformations in a neural network).

There are many ways to do probabilistic programming. A lot of models can actually be coded manually (if you take Foundations of Graphical Models with David Blei, you'll have to do that!). Fortunately for us, we can take advantage of probabilistic programming languages, which greatly simplify the process of defining a model and performing inference on it.

In Computational Methods, our probabilistic programming language of choice will be Pyro. Pyro sits on top of PyTorch, and since we're all experts in PyTorch by now, this seems like a logical next step. If you haven't already, install Pyro into your `Computational_methods` environment. A note on installation: you install Pyro with pip, and the package is named `pyro-ppl`, so the command is `pip install pyro-ppl` (the package called `pyro` on PyPI is a different project). If you use the version of pip that's installed in your environment, it will still install packages into the same conda environment.

In [1]:
import torch
import pyro

## Part 1: Sampling

A simple form of inference when working with graphical models involves sampling from the conditional distributions that make up the larger model. One of the main reasons for using a probabilistic programming language is to simplify this process. Consider the following graphical model, which represents a mixture of multivariate Gaussians:

<img src="new_model.png" alt="Drawing" style="width: 600px;"/>

\begin{align*}
\theta & \sim \text{Dirichlet}(\alpha) \\
z & \sim \text{Categorical}(\theta) \\
\beta & \sim \text{MultivariateNormal}(\mu, \Sigma) \\
x & \sim \text{MultivariateNormal}_{z_k}(\beta_k, I)
\end{align*}
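
Read top to bottom, the four statements above are a recipe for generating data. Here is a minimal sketch of that recipe using plain `torch.distributions` (not Pyro yet). The sizes `K`, `D`, and `N` and the values of `alpha`, `mu`, and `Sigma` are made up for illustration; they are not given anywhere in the lab.

```python
import torch
from torch import distributions as dist

torch.manual_seed(0)

# Hypothetical sizes: K mixture components, D dimensions, N data points
K, D, N = 3, 2, 5

# Assumed hyperparameters, chosen only for illustration
alpha = torch.ones(K)        # Dirichlet concentration
mu = torch.zeros(D)          # prior mean shared by every beta_k
Sigma = 4.0 * torch.eye(D)   # prior covariance shared by every beta_k

# theta ~ Dirichlet(alpha): mixture weights, shape (K,)
theta = dist.Dirichlet(alpha).sample()

# beta_k ~ MultivariateNormal(mu, Sigma): one mean vector per component, shape (K, D)
beta = dist.MultivariateNormal(mu, covariance_matrix=Sigma).sample((K,))

# z ~ Categorical(theta): one component index per data point, shape (N,)
z = dist.Categorical(probs=theta).sample((N,))

# x ~ MultivariateNormal(beta_z, I): each point is drawn around the mean
# of the component it was assigned to, shape (N, D)
x = dist.MultivariateNormal(beta[z], covariance_matrix=torch.eye(D)).sample()

print(theta, z, x, sep="\n")
```

Indexing `beta[z]` picks out the mean vector for each point's component, so all `N` points are drawn in one call.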

In this model, there are $K$ different multivariate normal distributions, each with its own mean vector $\beta_k$ and, to keep things simple, the identity covariance matrix. If you've never seen the multivariate normal before, check out the [Wikipedia article](https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Notation_and_parametrization). Additionally, every data point $x$ that we get to observe has an associated mixture component $z$, which is drawn from a categorical distribution. You can think of a categorical as a spinner, like in Twister. It randomly assigns a category, based on some distribution. That distribution, $\theta$, is itself drawn from a Dirichlet, whose parameter $\alpha$ controls how big each section of the spinner tends to be and how much those sizes can vary.

<img src="https://lh3.googleusercontent.com/wDEhZwDSbSWzALuGrRrGnh9PaBoFUyiUX7HwT0MCODiIykPmAjca_YL-GS5O1T-Ti2Qj" alt="Drawing" style="width: 400px;"/>
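
To make the spinner picture concrete, here is a small sketch that draws the section sizes $\theta$ from a Dirichlet and then spins many times. The helper `spin`, the number of sections `K`, and the two `alpha` settings are hypothetical choices for illustration only.

```python
import torch
from torch import distributions as dist

torch.manual_seed(0)

K = 4  # hypothetical number of spinner sections


def spin(alpha, n_spins=1000):
    """Draw section sizes theta ~ Dirichlet(alpha), then spin n_spins times."""
    theta = dist.Dirichlet(alpha).sample()
    z = dist.Categorical(probs=theta).sample((n_spins,))
    # Fraction of spins that landed in each section; should be close to theta
    freqs = torch.bincount(z, minlength=K).float() / n_spins
    return theta, freqs


# Large, equal alpha: the sections come out nearly the same size
theta_even, freqs_even = spin(torch.full((K,), 100.0))

# Small alpha: the section sizes vary a lot, and a few sections tend to dominate
theta_uneven, freqs_uneven = spin(torch.full((K,), 0.1))

print(theta_even, freqs_even)
print(theta_uneven, freqs_uneven)
```

Rerunning with different seeds should show `theta_even` barely changing while `theta_uneven` changes a lot from draw to draw, which is the "variation" that $\alpha$ controls.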


At this point, I could probably take off with the amount of hand waving I'm doing, but let's try to implement this model in Pyro. Because Pyro is built on PyTorch, all of the distributions you need come from the `torch.distributions` module.

In [None]:
def alice():
    