# Joint Probability (Discrete)

In [7]:
# Import some helper functions (please ignore this!)
from utils import * 

**Context:** So far, you've spent some time conducting a preliminary exploratory data analysis (EDA) of IHH's ER data. You noticed that considering variables separately can result in misleading information. As a result, you decided to use *conditional distributions* to model the *relationship between variables*. Using these conditional distributions, you were able to develop *predictive models* (e.g. predicting the probability of intoxication given the day of the week). These predictive models are useful for the IHH administration to make decisions. 

However, you've noticed that your modeling toolkit is still limited. The conditional distributions we introduced can model how the probability of one variable changes given a *set* of variables. What if we wanted to describe how the probability of a *set* of variables (i.e. more than one) changes given a *set* of variables? For example, we may want to answer questions like: "how does the probability that a patient is hospitalized for an allergic reaction change given the day of the week?" In this question, we're inquiring about two variables---that the condition is an allergic reaction, *and* that the patient was hospitalized---given the day of the week.

**Challenge:** We need to expand our modeling toolkit to include yet another tool---joint probabilities. 

**Outline:**
1. Introduce and practice the concepts, terminology, and notation behind discrete joint probability distributions (leaving continuous distributions to a later time).
2. Introduce a graphical representation to describe joint distributions.
3. Translate this graphical representation directly into code in a probabilistic programming language (using `NumPyro`) that we can then use to fit the data.

## Terminology and Notation

We again introduce the statistical language---terminology and notation---to precisely specify to a computer how to model our data. We will then translate statements in this language directly into code in `NumPyro` that a computer can run.

**Concept.** The concept behind a joint probability is elegant; it allows us to build complicated distributions over many variables using simple conditional and non-conditional distributions (that we already covered). 

We can illustrate this using an example with just two variables. Suppose you have two RVs, $A$ and $B$. The probability that $A = a$ and $B = b$ are *both* satisfied is called their *joint probability*. It is denoted by $p_{A,B}(a, b)$. This joint distribution can be *factorized* to a product of conditional and non-conditional (or "marginal") distributions as follows:
\begin{align*}
p_{A, B}(a, b) &= p_{A | B}(a | b) \cdot p_B(b) \quad \text{(Option 1)} \\
\underbrace{\phantom{p_{A, B}(a, b)}}_{\text{joint}} &= \underbrace{p_{B | A}(b | a)}_{\text{conditional}} \cdot \underbrace{p_A(a)}_{\text{marginal}} \quad \text{(Option 2)}
\end{align*}
Notice that the joint is now described in terms of conditional and marginal distributions, which we already know how to work with! 
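
To make the two factorizations concrete, here is a minimal numerical sketch. The joint table `p_AB` is made up for illustration (it is not IHH data), and we use plain NumPy arrays to hold the probabilities. The sketch computes the marginals and conditionals from the joint and checks that both options multiply back to the same joint.

```python
import numpy as np

# Hypothetical joint PMF over two binary RVs; rows index a, columns index b.
p_AB = np.array([
    [0.10, 0.30],   # a = 0
    [0.20, 0.40],   # a = 1
])

# Marginals: sum the joint over the other variable.
p_A = p_AB.sum(axis=1)   # p_A(a)
p_B = p_AB.sum(axis=0)   # p_B(b)

# Conditionals: divide the joint by the marginal of the conditioning variable.
p_A_given_B = p_AB / p_B[None, :]   # entry [a, b] is p_{A|B}(a | b)
p_B_given_A = p_AB / p_A[:, None]   # entry [a, b] is p_{B|A}(b | a)

# Both factorizations recover the same joint.
option_1 = p_A_given_B * p_B[None, :]   # p_{A|B}(a | b) * p_B(b)
option_2 = p_B_given_A * p_A[:, None]   # p_{B|A}(b | a) * p_A(a)
print(np.allclose(option_1, p_AB), np.allclose(option_2, p_AB))
```

Each column of `p_A_given_B` and each row of `p_B_given_A` sums to 1, as a conditional distribution should.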

**Intuition.** So what's the intuition behind this formula? Let's depict events $A$ and $B$ as follows:

<img align="center" width="500px" src="figs/joint-probability-venn.png" />

In this diagram, different shaded areas represent the probabilities of different events. We use it to pictorially represent the marginal, conditional, and joint distributions. The marginal $p_B(b)$ is the ratio of the blue circle relative to the whole space (the gray square):

<img align="center" width="300px" src="figs/joint-probability-eq-marginal.png" />

The conditional $p_{A | B}(a | b)$ is the ratio of the purple intersection relative to the blue circle. This is because the blue circle represents us conditioning on $B = b$, and the intersection of the circles represents the observations for which we *also* have $A = a$.

<img align="center" width="300px" src="figs/joint-probability-eq-conditional.png" />

Finally, the joint $p_{A, B}(a, b)$ is the ratio between the purple intersection and the whole space (the gray square). This is because the intersection is the place where both $A = a$ and $B = b$.

<img align="center" width="300px" src="figs/joint-probability-eq-joint.png" />

Now we can see that the joint is the product of the conditional and the marginal because the blue circles "cancel out":

<img align="center" width="500px" src="figs/joint-probability-eq-joint-expanded.png" />

**Choice of Factorization.** Lastly, notice that we have a *choice* between two ways of factorizing the distribution. How do you know which one to use? Typically, we choose the factorization that is most *intuitive to us* and whose factors we can actually compute. 
> For example, suppose you want to model the joint distribution of the day of the week, $D$, and whether a patient arrives with intoxication, $I$. The joint distribution can be factorized in two ways:
> \begin{align*}
p_{D, I}(d, i) &= p_{I | D}(i | d) \cdot p_D(d) \quad \text{(Option 1)} \\
&= p_{D | I}(d | i) \cdot p_I(i) \quad \text{(Option 2)} \\
\end{align*}
> Which one makes more intuitive sense? Well, it's a little weird to try to predict the day of the week given whether a patient arrives with intoxication; we typically know what the day of the week is and we don't need to predict it. In contrast, given the day of the week, it makes a lot of sense to wonder about the probability of a patient arriving with intoxication. As such, Option 1 makes more sense here. 

**Generalizing to More than Two RVs.** So now we have the tools to work with joint distributions with two RVs. What do we do if we have three or more? The same ideas apply. The joint distribution for random variables $A$, $B$, and $C$ can be factorized in a number of ways. For example, we can condition on two variables at a time:
\begin{align*}
p_{A, B, C}(a, b, c) &= p_{A | B, C}(a | b, c) \cdot p_{B, C}(b, c) \quad \text{(Option 1)} \\
&= p_{B | A, C}(b | a, c) \cdot p_{A, C}(a, c) \quad \text{(Option 2)} \\
&= p_{C | A, B}(c | a, b) \cdot p_{A, B}(a, b) \quad \text{(Option 3)}
\end{align*}
where, in the above, we already know how to factorize $p_{B, C}(b, c)$, $p_{A, C}(a, c)$, and $p_{A, B}(a, b)$ (since they are joint distributions with two variables).

We can also condition on one variable at a time:
\begin{align*}
p_{A, B, C}(a, b, c) &= p_{A, B | C}(a, b | c) \cdot p_C(c) \quad \text{(Option 1)} \\
&= p_{A, C | B}(a, c | b) \cdot p_B(b) \quad \text{(Option 2)} \\
&= p_{B, C | A}(b, c | a) \cdot p_A(a) \quad \text{(Option 3)}
\end{align*}
And how do we further factorize distributions of the form $p_{A, B | C}(a, b | c)$? We apply the same factorization for a joint distribution with two variables, and simply add a "conditioned on $C$" to each one:
\begin{align*}
p_{A, B | C}(a, b | c) &= p_{A | B, C}(a | b, c) \cdot p_{B | C}(b | c) \quad \text{(Option 1)} \\
&= p_{B | A, C}(b | a, c) \cdot p_{A | C}(a | c) \quad \text{(Option 2)} \\
\end{align*}
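
The same kind of check works for three variables. The sketch below is only an illustration: it builds an arbitrary joint table `p_ABC` from random numbers (not real data), then rebuilds it from the factors $p_{A | B, C}$, $p_{B | C}$, and $p_C$, i.e. "Option 1" of the one-variable-at-a-time factorization combined with "Option 1" for $p_{A, B | C}$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical joint over A (2 values), B (3 values), C (2 values), indexed [a, b, c].
p_ABC = rng.random((2, 3, 2))
p_ABC /= p_ABC.sum()   # normalize so all entries sum to 1

# Marginals obtained by summing out variables.
p_C = p_ABC.sum(axis=(0, 1))   # p_C(c), shape (2,)
p_BC = p_ABC.sum(axis=0)       # p_{B,C}(b, c), shape (3, 2)

# Conditionals obtained by dividing by the marginal of what we condition on.
p_B_given_C = p_BC / p_C[None, :]          # p_{B|C}(b | c)
p_A_given_BC = p_ABC / p_BC[None, :, :]    # p_{A|B,C}(a | b, c)

# p_{A,B,C} = p_{A|B,C} * p_{B|C} * p_C
rebuilt = p_A_given_BC * p_B_given_C[None, :, :] * p_C[None, None, :]
print(np.allclose(rebuilt, p_ABC))
```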

## Directed Graphical Models (DGMs)

As you may have already noticed, the number of possible ways to factorize a joint distribution increases *very quickly* with the number of RVs. In fact, the more RVs we have, the more unwieldy it becomes for us as data analysts to specify each component in the factorization. What can we do to simplify our model? Often, we can use our *domain knowledge* (knowledge of the specifics of the problem) to simplify the joint distribution. Specifically, we'll use our knowledge of "conditional independence" to do this. Let's get started by first introducing the idea of *statistical independence*.

**Statistical Independence.** We say two variables $A$ and $B$ are statistically independent if their joint can be factorized to the product of their marginals:
\begin{align*}
p_{A, B}(a, b) &= p_B(b) \cdot p_A(a)
\end{align*}
This equation implies that to sample $A$ and $B$ jointly, we don't have to consider their relationship (since the conditional isn't used)---they are entirely independent. 

Another way to understand this equation is by thinking about its implications for the conditionals. We do this by factorizing $p_{A, B}(a, b)$ into the product of the conditional and marginal:
\begin{align*}
p_{A, B}(a, b) &= \underbrace{p_{B | A}(b | a)}_{\text{must equal } p_B(b)} \cdot p_A(a)
\end{align*}
We then observe that $p_{B | A}(b | a)$ must equal $p_B(b)$ to satisfy our definition of statistical independence. And $p_{B | A}(b | a) = p_B(b)$ implies that having observed $A = a$ does not affect the probability of $B = b$.
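
As a small sketch of this definition, we can test whether a joint table equals the outer product of its own marginals. The helper name `is_independent` and both example tables are assumptions made for illustration.

```python
import numpy as np

def is_independent(p_AB, atol=1e-8):
    """Check whether p_{A,B}(a, b) == p_A(a) * p_B(b) for every (a, b)."""
    p_A = p_AB.sum(axis=1)   # marginal of A (sum over b)
    p_B = p_AB.sum(axis=0)   # marginal of B (sum over a)
    return bool(np.allclose(p_AB, np.outer(p_A, p_B), atol=atol))

# A joint built as a product of marginals is independent by construction.
independent = np.outer([0.4, 0.6], [0.3, 0.7])

# A joint with the same marginals, but where knowing A changes the odds of B.
dependent = np.array([
    [0.10, 0.30],
    [0.20, 0.40],
])

print(is_independent(independent), is_independent(dependent))
```

Both tables have the same marginals, so the marginals alone cannot tell us whether two variables are independent; we need the joint (or a conditional).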

**Graphical Representation of Statistical Dependencies.** Since reasoning about many variables jointly is difficult, we introduce a graphical representation to aid with it. This representation is called a *directed graphical model* (DGM), and it will help us convey which variables depend on one another in what way. 

A DGM is represented using a *graph* (or network) in which nodes represent RVs and arrows represent conditional dependencies. For example, consider the following DGM for some hypothetical joint distribution, $p_{A, B, C}(\cdot)$:

<img align="center" width="500px" src="figs/joint-probability-example-dgm.png" />

In this DGM, there are three nodes, corresponding to our three RVs. Our factorization then consists of one factor for each node:
* The factor corresponding to $B$ is $p_{B | A}(\cdot)$, since there's an arrow from $A$ to $B$, indicating a conditional dependence.
* The factor corresponding to $A$ is $p_{A}(\cdot)$, since there aren't any arrows pointing into $A$.
* The factor corresponding to $C$ is $p_{C}(\cdot)$, since there aren't any arrows pointing into $C$.

In total, the DGM represents the following factorization:
\begin{align*}
p_{A, B, C}(a, b, c) &= p_{B | A}(b | a) \cdot p_A(a) \cdot p_C(c)
\end{align*}
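
The "one factor per node" rule is mechanical enough to sketch in a few lines of Python. Here we assume a DGM is stored as a dictionary mapping each node to the list of its parents (the nodes with arrows pointing into it); this representation and the helper name `factorization` are illustrative choices, not part of `NumPyro`.

```python
def factorization(parents):
    """Return the factorization implied by a DGM, with one factor per node."""
    factors = []
    for node, node_parents in parents.items():
        if node_parents:
            # Arrows point into this node: condition on its parents.
            factors.append(f"p({node} | {', '.join(node_parents)})")
        else:
            # No incoming arrows: the factor is a marginal.
            factors.append(f"p({node})")
    return " * ".join(factors)

# The example DGM above: an arrow from A to B, and C with no arrows.
example_dgm = {"A": [], "B": ["A"], "C": []}
print(factorization(example_dgm))  # prints: p(A) * p(B | A) * p(C)
```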

```{admonition} Exercise
**Part 1:** Write down the factorization for $p_{A, B, C, D, E, F, G, H}(\cdot)$ implied by the following DGM:

<img align="center" width="500px" src="figs/joint-probability-exercise-dgm.png" />

**Part 2:** Draw a DGM representing the following joint distribution:
\begin{align*}
p_{A, B, C, D, E}(a, b, c, d, e) &= p_{C | A, B, D}(c | a, b, d) \cdot p_{A | D}(a | d) \cdot p_{D | B, E}(d | b, e) \cdot p_{B}(b) \cdot p_{E}(e)
\end{align*}
```

TODO write your answer here

## Translating Math to Code with `NumPyro`

**What is `NumPyro`?** `NumPyro` is a "Probabilistic Programming Language" built on `Jax`. It provides an interface for (nearly) direct translation of the stats/math we wrote above into code that we can use to fit data, make predictions, and more. This will allow us to focus on the conceptual ideas behind probabilistic ML. 

**Instantiating Distributions in `NumPyro`.** `NumPyro` comes with many distributions already implemented. For a complete list of all available discrete distributions, check out [this part of the documentation](https://num.pyro.ai/en/stable/distributions.html#discrete-distributions). So why use `NumPyro` instead of implementing the distributions on our own? It's easy to write subtle bugs that are hard to catch when implementing mathematical formulas in code. Also, using `NumPyro`'s distributions will help us highlight the overall *logic* of the code, instead of getting bogged down by the mathematical details. 

Distributions in `NumPyro` have several notable properties and methods we will rely on. Let's explore them together. First, we import the necessary components of `NumPyro`:

In [8]:
import jax.numpy as jnp
import jax.random as jrandom
import numpyro
import numpyro.distributions as D

Now, let's instantiate the simplest discrete distribution we know: the Bernoulli distribution. 
\begin{align*}
p_X(x) &= \mathrm{Bern}(\rho) = \rho^x \cdot (1 - \rho)^{1 - x}
\end{align*}
Recall that a Bernoulli distribution takes in just one parameter, $\rho \in [0, 1]$, which determines the probability of sampling $X = 1$ vs. $X = 0$. Here let's instantiate the Bernoulli distribution with $\rho = 0.7$.

In [9]:
rho = jnp.array(0.7)
p_X = D.Bernoulli(rho)

That's it! 

**Evaluating the PMF of `NumPyro` Distributions.** Now, if we want to evaluate the PMF, we can use the `log_prob` method as follows (note that this returns the *log* of the PMF, so we'll have to exponentiate the result):

In [10]:
log_p_x_eq_1 = p_X.log_prob(jnp.array(1.0))
print('Probability of sampling a 1:', jnp.exp(log_p_x_eq_1))

log_p_x_eq_0 = p_X.log_prob(jnp.array(0.0))
print('Probability of sampling a 0:', jnp.exp(log_p_x_eq_0))

Probability of sampling a 1: 0.7
Probability of sampling a 0: 0.3


**Sampling from `NumPyro` Distributions.** `NumPyro` distributions all have a `sample` method which can be used to draw samples. It takes in two arguments:
1. A random-generator "key," which controls the randomness of the sample.
2. A shape, describing the number of i.i.d. samples you want to draw.

Let's give it a go:

In [11]:
shape = (15,) # Shape of i.i.d samples we wish to draw 

key1 = jrandom.PRNGKey(seed=0) # Create a random-generator key
print('First batch drawn with key1: ', p_X.sample(key1, shape))
print('Second batch drawn with key1:', p_X.sample(key1, shape))

key2 = jrandom.PRNGKey(seed=1) # Create a random-generator key
print('Third batch drawn with key2: ', p_X.sample(key2, shape))

First batch drawn with key1:  [1 0 1 0 1 1 0 1 1 1 0 0 1 1 1]
Second batch drawn with key1: [1 0 1 0 1 1 0 1 1 1 0 0 1 1 1]
Third batch drawn with key2:  [1 1 1 0 1 1 1 1 1 1 1 1 0 1 1]


Notice in the above code, when using the same key twice (or the same `seed`), we get the *exact same batch of samples*. This is both a blessing and a curse. It's a blessing because this allows us to precisely control the randomness of our ML code. This will prove crucial for debugging later on. However, it can also be a curse if we accidentally use the same key in a place where we need two different sources of randomness.

**Best Practice: How to Manage Your Keys.** We will follow two rules of thumb:
1. Make only ONE CALL to `jrandom.PRNGKey` in your entire code.
2. Never use the same key twice.

But if we're restricting ourselves to only creating one key with `jrandom.PRNGKey`, how can we possibly call `sample` multiple times with different keys? `Jax` allows us to take a random key and split it into multiple different keys, each of which can be used for different purposes. This means we can create ONE KEY to control the randomness of our entire code. We can then split this key into multiple keys as needed. Here's how we can do this:

In [12]:
# Create ONE KEY to be used by your ENTIRE CODE
key = jrandom.PRNGKey(seed=0)

# Whenever you need to use the key for multiple purposes, split it into parts:
key_first, key_second, key_third = jrandom.split(key, 3) 

# Use a different key for each need
print('First batch drawn with key_first:  ', p_X.sample(key_first, shape))
print('Second batch drawn with key_second:', p_X.sample(key_second, shape))
print('Third batch drawn with key_third:  ', p_X.sample(key_third, shape))

First batch drawn with key_first:   [1 1 1 1 1 1 1 1 0 1 1 0 1 1 1]
Second batch drawn with key_second: [1 1 0 0 1 1 0 1 1 1 0 1 1 0 1]
Third batch drawn with key_third:   [0 0 1 1 1 1 1 1 1 0 1 1 1 0 0]


**Sampling from Conditional Distributions:**

TODO

**Sampling from Joint Distributions:**

TODO

```{admonition} Exercise
**Context:** Your friend is an ML researcher at a nearby university. She heard all about the interesting data you have from the IHH ER and wants to help with the analysis. However, because this is sensitive medical data, she needs to obtain the right credentials, undergo a lengthy training on secure data management, and more, before obtaining access to the data. Realistically, this means she'll only be able to gain access to the IHH ER data in several months. 

**Idea:** To help her out, you have an idea: instead of sending her the data directly, you will develop a *generative model* of the IHH ER data and send that to her instead. This generative model will allow your friend to generate (or sample) realistic data with the same characteristics as the real data without violating any privacy constraints. But what exactly is a generative model? A generative model is a joint probability distribution over all variables in the data: $D$, $C$, $H$, $M$, and $A$. 

**Problem:** Use what you already know about the marginal and conditional distributions of the IHH ER data, in tandem with what you learned here about joint distributions to implement this generative model. Your model should take the form of a function that takes in a `key` and outputs a Python dictionary with a single sample from the joint distribution $p_{D, C, H, M, A}(\cdot)$.

**Note:** `NumPyro` discrete distributions only work with integers, not strings. For example, instead of using $d = \text{Monday}$, you should convert the days of the week into integers from 0 to 6, and instead use $d = 0$. We provide some code below to help you do this.
```

In [13]:
# A list mapping numbers 0 through 6 to days of the week
# You may need to implement a mapping (e.g. using a dictionary) for the opposite way
IDX_TO_DAY_OF_WEEK = [
    'Monday',
    'Tuesday',
    'Wednesday',
    'Thursday',
    'Friday',
    'Saturday',
    'Sunday',
]

# A list mapping numbers 0 through 4 to conditions
# You may need to implement a mapping (e.g. using a dictionary) for the opposite way
IDX_TO_CONDITION = [
    'High Fever',    
    'Broken Limb',    
    'Entangled Antennas',
    'Allergic Reaction',
    'Intoxication',
]

# A list mapping numbers 0 and 1 to No and Yes
# You may need to implement a mapping (e.g. using a dictionary) for the opposite way
IDX_TO_BOOL = [
    'No',
    'Yes',
]

In [14]:
def sample_IHH_ER_generative_model(key):
    pass # TODO implement