# UBER Numpyro on Windows
#### Celso Axelrud
#### Revision 1.0 - 4/24/2021

This document describes my efforts to run UBER Numpyro (https://github.com/pyro-ppl/numpyro) on Windows.

Numpyro depends on JAX.
JAX is available for Linux, including the Google Colab environment, but it is not officially supported on Windows.

I have been contributing by compiling and testing JAX builds for Windows.

Currently, I am able to run Numpyro correctly on both CPU and GPU.
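To confirm which backend a Windows build of JAX actually picked up, a quick check like the one below can help. It uses only standard JAX calls (`jax.default_backend`, `jax.devices`); the printed values depend on your installation, and a CPU-only build will simply report `cpu`.

```python
import jax
import jax.numpy as jnp

# Report the JAX version and the backend JAX selected by default (e.g. "cpu" or "gpu").
print("JAX version:", jax.__version__)
print("Default backend:", jax.default_backend())

# List every device JAX can see; a working GPU build should list at least one GPU device.
for device in jax.devices():
    print(device.platform, device)

# Run a tiny computation to make sure the backend actually executes code.
x = jnp.arange(4.0)
print("sum =", float(jnp.sum(x)))  # 0 + 1 + 2 + 3
```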

Along with this notebook, several of the original project's test notebooks are available.

Also, notebooks for the book "Statistical Rethinking" by Richard McElreath are available.

NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for automatic differentiation and JIT compilation to GPU / CPU.
Pyro is a universal probabilistic programming language (PPL) written in Python. Pyro enables flexible and expressive deep probabilistic modeling, unifying the best of modern deep learning and Bayesian modeling. 
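The two JAX features mentioned above, automatic differentiation and JIT compilation, can be seen in isolation in the sketch below, which differentiates a Normal log density with `jax.grad` and compiles the gradient with `jax.jit`. The helper `normal_logpdf` is a hypothetical illustration, not part of NumPyro; NumPyro's samplers apply the same machinery to a model's log joint density.

```python
import jax
import jax.numpy as jnp

def normal_logpdf(mu, x, sigma=1.0):
    # Hypothetical helper: summed log density of Normal(mu, sigma) at data points x.
    z = (x - mu) / sigma
    return jnp.sum(-0.5 * z**2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi))

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad differentiates with respect to the first argument (mu);
# jax.jit compiles the resulting gradient function with XLA for CPU or GPU.
grad_fn = jax.jit(jax.grad(normal_logpdf))

# Analytically, d/dmu = sum(x - mu) / sigma**2, which is 6.0 at mu = 0.
print(grad_fn(0.0, x))
```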

In [1]:
#===================================================
# Numpyro Eight Schools example
# https://github.com/pyro-ppl/numpyro

import numpyro

import numpyro.distributions as dist

import numpy as np

#numpyro.set_platform("cpu")

In [2]:
J = 8

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])


def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [3]:
from jax import random
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(eight_schools)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))

mcmc.print_summary()

sample: 100%|█| 1500/1500 [00:05<00:00, 254.76it/s, 15 steps of size 1.94e-01. 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
        mu      4.09      3.19      4.01     -0.72      9.46    155.10      1.03
       tau      4.66      3.70      3.87      0.76      8.84    114.24      1.02
  theta[0]      6.46      6.23      5.73     -3.35     15.98    300.05      1.00
  theta[1]      4.83      5.17      4.62     -3.23     14.04    307.49      1.00
  theta[2]      3.74      5.87      3.78     -5.59     13.06    375.82      1.01
  theta[3]      4.63      5.09      4.39     -1.97     14.31    392.76      1.00
  theta[4]      3.15      4.56      3.26     -3.85     10.41    194.56      1.03
  theta[5]      3.63      5.29      3.70     -3.84     12.01    483.57      1.00
  theta[6]      6.65      5.67      6.06     -2.12     15.33    184.48      1.00
  theta[7]      4.58      5.79      4.41     -3.72     13.76    452.71      1.00

Number of divergences: 13


In [4]:
pe = mcmc.get_extra_fields()['potential_energy']
print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))
#Expected log joint density: -56.14

Expected log joint density: -56.14


In [5]:
from numpyro.infer.reparam import TransformReparam

# Eight Schools example - Non-centered Reparametrization
def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
            theta = numpyro.sample(
                'theta',
                dist.TransformedDistribution(dist.Normal(0., 1.),
                                             dist.transforms.AffineTransform(mu, tau)))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

nuts_kernel = NUTS(eight_schools_noncentered)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
#sample: 100%|██████████| 1500/1500 [00:06<00:00, 229.91it/s, 7 steps of size 3.79e-01. acc. prob=0.91]

sample: 100%|█| 1500/1500 [00:05<00:00, 253.22it/s, 7 steps of size 3.79e-01. a


In [6]:
mcmc.print_summary(exclude_deterministic=False)


                   mean       std    median      5.0%     95.0%     n_eff     r_hat
           mu      4.07      3.49      3.95     -1.29     10.09    750.88      1.00
          tau      3.88      3.31      2.93      0.03      8.24    551.26      1.00
     theta[0]      6.18      5.36      5.69     -2.37     14.50    935.24      1.00
     theta[1]      4.74      5.03      4.81     -4.20     12.21   1155.03      1.00
     theta[2]      3.78      5.65      4.01     -5.87     11.28    959.08      1.00
     theta[3]      4.51      4.92      4.31     -2.53     12.78   1104.54      1.00
     theta[4]      3.40      4.74      3.55     -4.00     11.36   1025.76      1.00
     theta[5]      3.61      4.76      3.76     -3.69     11.35    787.76      1.00
     theta[6]      6.20      5.15      5.74     -1.59     14.26    911.82      1.00
     theta[7]      5.03      5.36      4.69     -4.39     12.22   1129.99      1.00
theta_base[0]      0.38      0.98      0.41     -1.30      1.83    775.55  

In [7]:
pe = mcmc.get_extra_fields()['potential_energy']
# Compare with the earlier value
print('Expected log joint density: {:.2f}'.format(np.mean(-pe)))  
#Expected log joint density: -46.09

Expected log joint density: -46.15


In [8]:
from numpyro.infer.reparam import LocScaleReparam
# Non-centered eight_schools: reparameterize the 'theta' site with LocScaleReparam
eight_schools_reparam = numpyro.handlers.reparam(
    eight_schools, config={'theta': LocScaleReparam(centered=0)})

In [9]:
from numpyro.infer import Predictive

# New School
def new_school():
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    return numpyro.sample('obs', dist.Normal(mu, tau))

predictive = Predictive(new_school, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1))
print(np.mean(samples_predictive['obs']))

4.09555
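What `Predictive` computes for `new_school` can be spelled out directly: for each posterior draw of `mu` and `tau`, it draws one new school effect from `Normal(mu, tau)`. A sketch of that step with `jax.random`; the helper `predict_new_school` is hypothetical, and the `mu` and `tau` arrays below are placeholder values standing in for `mcmc.get_samples()['mu']` and `['tau']`, so the printed mean will not match 4.09555.

```python
import jax.numpy as jnp
from jax import random

def predict_new_school(rng_key, mu, tau):
    # Hypothetical helper: one draw of a new school's effect per posterior draw,
    # i.e. a sample from Normal(mu_s, tau_s) for each index s.
    return mu + tau * random.normal(rng_key, shape=jnp.shape(mu))

# Placeholder "posterior" draws for illustration only.
mu = jnp.array([3.5, 4.2, 4.8])
tau = jnp.array([2.0, 3.1, 0.5])

obs = predict_new_school(random.PRNGKey(1), mu, tau)
print(obs.shape, float(jnp.mean(obs)))
```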