## Introduction

This notebook introduces the NumPyro and JAX NumPy (`jax.numpy`) functions used to solve the exercises in chapter 2 of the book Statistical Rethinking. Many other useful functions and methods are covered in the [official NumPyro documentation](https://num.pyro.ai/en/latest/index.html#introductory-tutorials). `jax.numpy` closely mirrors the NumPy API and is often faster for large array operations. NumPyro provides modules for defining and analyzing models and distributions with a friendly syntax.

## Summary


- The tutorial covers
    - Declaring arrays
    - Simple algebraic expressions
    - Filtering arrays based on a condition
    - Declaring distributions
    - Sampling from distributions
    - Modeling
        - Applying the automatic Laplace approximation to a defined model
        - Running Stochastic Variational Inference with the Laplace approximation as the guide
        - Using the Stochastic Variational Inference result to sample from the posterior



## Module `jax.numpy`

`jax.numpy` exposes a NumPy-like API whose operations are executed by the JAX library. The cell below compares NumPy and `jax.numpy` syntax side by side:

In [11]:
import jax.numpy as jnp
import numpy as np

# Declare an array
arr = np.array([0.0, 3, 8, 9, 0])
arr_j = jnp.array([0.0, 3, 8, 9, 0])

# Sum elements of array
sum_arr = np.sum(arr)
sum_arr_j = jnp.sum(arr_j)

# Exponentiate a number
e = np.exp(1)
e_j = jnp.exp(1)

# Exponentiate an array
exp_arr = np.exp(arr)
exp_arr_j = jnp.exp(arr_j)

# Conditionally change values in an array
cond_arr = np.where(arr < 0.5, 0, 1)
cond_arr_j = jnp.where(arr_j < 0.5, 0, 1)

# Create a vector with sequential values
seq = np.linspace(start=0, stop=1, num=100)
seq_j = jnp.linspace(start=0, stop=1, num=100)

# Create a vector with repeated values
rep = np.repeat(1, 100)
rep_j = jnp.repeat(1, 100)



`jax.numpy` functions can also be applied directly to a NumPy array, for example to replace values based on a condition.

In [13]:
cond_from_np = jnp.where(arr < 0.5, 0, 1)
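
As a small check of the cell above, the sketch below applies `jnp.where` to a NumPy array and inspects what comes back. The names `mixed` and `back_to_np` are illustrative and do not appear in the original notebook. The result is a JAX array, which `np.asarray` can convert back to NumPy when needed.

```python
import jax
import jax.numpy as jnp
import numpy as np

arr = np.array([0.0, 3, 8, 9, 0])

# jnp.where accepts the NumPy array and returns a JAX array
mixed = jnp.where(arr < 0.5, 0, 1)
print(type(arr).__name__)            # the input is a NumPy ndarray
print(isinstance(mixed, jax.Array))  # the output is a JAX array

# Convert back to NumPy when a NumPy array is required downstream
back_to_np = np.asarray(mixed)
print(back_to_np.tolist())  # [0, 1, 1, 1, 0]
```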

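`jax.numpy` can be considerably faster than `numpy` for large operations. The timings below, which create an array of one billion elements, show roughly an 8x difference in this run; results will vary with hardware.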

In [14]:
%timeit jnp.linspace(start=0, stop=1, num=1000000000)

420 ms ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%timeit np.linspace(start=0, stop=1, num=1000000000)

3.48 s ± 355 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
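
One caveat when timing JAX: it dispatches work asynchronously, so a call can return before the computation has finished. The sketch below is one way to benchmark outside of `%timeit`, calling `.block_until_ready()` on the JAX result so the full computation is measured. The array size `NUM` and the helper names `run_numpy` and `run_jax` are assumptions made for illustration; the size is much smaller than in the cells above so the example runs quickly.

```python
import timeit

import jax.numpy as jnp
import numpy as np

NUM = 1_000_000  # illustrative size, far smaller than the notebook's run


def run_numpy():
    return np.linspace(start=0, stop=1, num=NUM)


def run_jax():
    # block_until_ready() waits for the asynchronous computation to finish
    return jnp.linspace(start=0, stop=1, num=NUM).block_until_ready()


run_jax()  # warm-up call so one-time setup cost is not included in the timing

repeats = 10
t_np = timeit.timeit(run_numpy, number=repeats) / repeats
t_jax = timeit.timeit(run_jax, number=repeats) / repeats
print(f"numpy:     {t_np * 1e3:.3f} ms per call")
print(f"jax.numpy: {t_jax * 1e3:.3f} ms per call")
```

Which library wins depends on the array size, the hardware, and the data type, so the numbers should be read as a measurement of one setup rather than a general ranking.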


## Module `numpyro`

To use `numpyro`, start by selecting the hardware platform (CPU, GPU, or TPU) that will run the code.

In [16]:
import numpyro

numpyro.set_platform("cpu")

Declaring distributions in `numpyro` is simple: the distribution parameters are passed as arguments. Quantities derived from the distribution, such as log probabilities and samples, are then available through its methods.

For example, consider the Binomial distribution. 

$P(x) = \frac{n!}{(n-x)!\,x!}\,p^{x}(1-p)^{n-x}$

The binomial distribution in the equation above is parameterized by the total count of trials `n` and the probability `p` of success in a single trial; `x` is the number of positive events whose probability is evaluated. In NumPyro, `n` and `p` are the arguments used to declare the distribution. For example, a binomial distribution over 9 trials with a 50% success probability is declared as:

In [17]:
import numpyro.distributions as dist

n=9
p=0.5

# Considering 9 events with a 50% probability
binomial_dist = dist.Binomial(total_count=n, probs=p)

To get the probability of a given number of positive events, call the method `.log_prob`, which returns the log of the probability under the declared distribution. For example, to get the log probability of 6 positive outcomes:

In [18]:
positive_events = 6
log_prob = binomial_dist.log_prob(positive_events)

Usually we need the actual probability rather than its log, so we exponentiate the result with `jax.numpy`:

In [19]:
prob = jnp.exp(log_prob)
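
The value of `prob` can be cross-checked by evaluating the binomial formula by hand. The sketch below uses only the Python standard library; the names `prob_by_hand` and `log_prob_by_hand` are illustrative and not part of the notebook.

```python
import math

n, p, x = 9, 0.5, 6

# Binomial probability mass: n! / ((n - x)! x!) * p**x * (1 - p)**(n - x)
prob_by_hand = math.comb(n, x) * p**x * (1 - p) ** (n - x)
log_prob_by_hand = math.log(prob_by_hand)

print(prob_by_hand)  # 84 / 512 = 0.1640625

# Exponentiating the log probability recovers the probability
print(math.exp(log_prob_by_hand))
```

Up to floating-point precision, `prob` from the cell above should agree with `prob_by_hand`.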

Now that you have a distribution, how do you take samples from it? Call the `sample` method of the distribution object, passing a JAX PRNG key (which plays the role of the seed) as the first argument and the sample shape as the second.

For example, draw 1000 samples and store them in the vector `samples`:

In [20]:
import numpyro 
from jax import random

samples = binomial_dist.sample(random.PRNGKey(0), (1000,))


Each value in the vector `samples` is the number of positive events in 9 trials, drawn from the binomial distribution defined above.

#### Models

When defining a model, a distribution's parameter can itself follow another distribution. We can express the relationship between distributions in algebraic form inside a Python function. For example, suppose we want a Binomial distribution in which the probability `p` of success in a trial follows a uniform distribution:

$$
\begin{aligned}
x &\sim \text{Binomial}(n, p) \\
p &\sim \text{Uniform}(0, 1)
\end{aligned}
$$

In code:

In [21]:
def model(n:int, x:int):
    """Declare a binomial distribution with the probability parameter as a uniform distribution

    Args:
        n (int): total count of events
        x (int): total count of positive outcomes
    """    
    p = numpyro.sample(name="p", fn=dist.Uniform(0, 1))
    numpyro.sample(name="x", fn=dist.Binomial(n, p), obs=x)
 

Note that when declaring the related distributions we did not use the distribution's `.sample` method but the `numpyro.sample` function, which takes no argument for the number of samples. Instead, observed values are passed through the `obs` argument.

With the model declared as a function, computing the posterior takes four steps:

1. Define the posterior approximation;
2. Define the optimization;
3. Optimize;
4. Take samples from the posterior.


For example:

In [22]:
import numpyro.optim as optim
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation
from jax import random

# 1. Define the posterior approximation
guide = AutoLaplaceApproximation(model)

# 2. Define optimization
svi = SVI(model, guide, optim.Adam(1), Trace_ELBO(), n=6, x=3)

# 3. Optimize
svi_result = svi.run(random.PRNGKey(0), 1000)

100%|██████████| 1000/1000 [00:00<00:00, 1689.82it/s, init loss: 2.5989, avg. loss [951-1000]: 2.5494]


In [26]:
# 4. Take samples from the posterior
params = svi_result.params
samples = guide.sample_posterior(random.PRNGKey(0), params, sample_shape=(1000,))


The samples can then be summarized with `numpyro.diagnostics.print_summary`.

In [28]:
numpyro.diagnostics.print_summary(samples, prob=0.89, group_by_chain=False)


                mean       std    median      5.5%     94.5%     n_eff     r_hat
         p      0.50      0.15      0.50      0.28      0.77    954.19      1.00
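
As a rough cross-check of the summary, the same posterior can be computed by grid approximation with only the standard library. This is a sketch under the model's assumptions (flat prior, binomial likelihood, `n=6`, `x=3`); the grid size and all variable names are illustrative choices.

```python
import math

n, x = 6, 3
grid_size = 1001
grid = [i / (grid_size - 1) for i in range(grid_size)]

# Flat prior times binomial likelihood at each grid point
unnormalized = [math.comb(n, x) * p**x * (1 - p) ** (n - x) for p in grid]
total = sum(unnormalized)
posterior = [u / total for u in unnormalized]

# Posterior mean and standard deviation on the grid
mean = sum(p * w for p, w in zip(grid, posterior))
std = math.sqrt(sum((p - mean) ** 2 * w for p, w in zip(grid, posterior)))
print(round(mean, 2), round(std, 2))  # 0.5 0.17
```

The grid result is close to the summary above. Small differences are expected, since the Laplace approximation is only an approximation to the posterior and the summary is computed from a finite number of samples.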

