## Baseball

This example comes from the [documentation](http://num.pyro.ai/en/stable/examples/baseball.html), although I have tried to simplify it a bit and make it more amenable to poking around.

__Note:__
The example in the documentation is better from a software engineering point of view; this version is more about learning how to predict on new samples and figuring out how things work.

In [1]:
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import logsumexp

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.infer import HMC, MCMC, NUTS, SA, Predictive, log_likelihood

## Data fetching

In [3]:
_, fetch_train = load_dataset(BASEBALL, split='train', shuffle=False)
train, player_names = fetch_train()
_, fetch_test = load_dataset(BASEBALL, split='test', shuffle=False)
test, _ = fetch_test()
at_bats, hits = train[:, 0], train[:, 1]
season_at_bats, season_hits = test[:, 0], test[:, 1]

Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt.
Download complete.


In [7]:
train, train.shape

(array([[45, 18],
        [45, 17],
        [45, 16],
        [45, 15],
        [45, 14],
        [45, 14],
        [45, 13],
        [45, 12],
        [45, 11],
        [45, 11],
        [45, 10],
        [45, 10],
        [45, 10],
        [45, 10],
        [45, 10],
        [45,  9],
        [45,  8],
        [45,  7]]),
 (18, 2))

In [8]:
test, test.shape

(array([[412, 145],
        [471, 144],
        [566, 160],
        [320,  76],
        [463, 128],
        [511, 140],
        [631, 168],
        [183,  41],
        [555, 148],
        [245,  57],
        [583, 152],
        [231,  52],
        [480, 142],
        [322,  83],
        [636, 205],
        [603, 168],
        [453, 137],
        [115,  21]]),
 (18, 2))

## Define the model

In [9]:
def partially_pooled_with_logit(at_bats, hits=None):
    r"""
    Number of hits has a Binomial distribution with a logit link function.
    The logits $\alpha$ for each player is normally distributed with the
    mean and scale parameters sharing a common prior.

    :param (jnp.DeviceArray) at_bats: Number of at bats for each player.
    :param (jnp.DeviceArray) hits: Number of hits for the given at bats.
    :return: Number of hits predicted by the model.
    """
    loc = numpyro.sample("loc", dist.Normal(-1, 1))
    scale = numpyro.sample("scale", dist.HalfCauchy(1))
    num_players = at_bats.shape[0]
    with numpyro.plate("num_players", num_players):
        alpha = numpyro.sample("alpha", dist.Normal(loc, scale))
        return numpyro.sample("obs", dist.Binomial(at_bats, logits=alpha), obs=hits)
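
Written out, the model in the cell above reads as follows. This is only a transcription of the code: $i$ indexes players, $\sigma$ is the logistic function, and the second argument of each Normal is a standard deviation, as in `dist.Normal`.

$$
\begin{aligned}
\mathrm{loc} &\sim \mathrm{Normal}(-1,\ 1) \\
\mathrm{scale} &\sim \mathrm{HalfCauchy}(1) \\
\alpha_i &\sim \mathrm{Normal}(\mathrm{loc},\ \mathrm{scale}) \\
\mathrm{hits}_i &\sim \mathrm{Binomial}\big(\mathrm{at\_bats}_i,\ \sigma(\alpha_i)\big), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}
\end{aligned}
$$

Note that there is one $\alpha_i$ per player, so the shape of `alpha` is tied to the number of rows passed in as `at_bats`.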

## Do the training

In [11]:
training_rnd_key = random.PRNGKey(3)
chain = MCMC(NUTS(partially_pooled_with_logit), num_warmup=500, num_samples=2500, num_chains=1)

In [12]:
chain.run(training_rnd_key, at_bats=train[:,0], hits=train[:, 1])

sample: 100%|██████████| 3000/3000 [00:06<00:00, 495.23it/s, 15 steps of size 3.12e-01. acc. prob=0.73]


In [23]:
chain.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
  alpha[0]     -0.88      0.20     -0.91     -1.19     -0.54    344.49      1.00
  alpha[1]     -0.90      0.19     -0.92     -1.19     -0.56    365.78      1.00
  alpha[2]     -0.93      0.18     -0.94     -1.20     -0.61    544.93      1.00
  alpha[3]     -0.95      0.18     -0.96     -1.26     -0.68    711.61      1.00
  alpha[4]     -0.97      0.18     -0.98     -1.25     -0.67    772.31      1.00
  alpha[5]     -0.98      0.17     -0.98     -1.26     -0.68    765.73      1.00
  alpha[6]     -1.00      0.17     -1.01     -1.29     -0.72    809.75      1.00
  alpha[7]     -1.03      0.17     -1.02     -1.29     -0.74    778.07      1.00
  alpha[8]     -1.05      0.18     -1.04     -1.35     -0.78    658.15      1.00
  alpha[9]     -1.05      0.18     -1.04     -1.35     -0.75    818.47      1.00
 alpha[10]     -1.08      0.18     -1.06     -1.37     -0.79    608.12      1.00
 alpha[11]     -1.07      0
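
The `alpha` values in the summary are logits, not batting averages. As a rough sanity check, the sketch below maps the posterior mean of `alpha[0]` back to a probability and compares it with that player's raw training average (18 hits in 45 at-bats). The helper `inv_logit` is a name I made up for this illustration, and applying it to the posterior mean only approximates the posterior mean of the probability itself.

```python
import math


def inv_logit(x):
    """Logistic function: maps a logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# Posterior mean of alpha[0], copied from the summary table above.
alpha_0_mean = -0.88

# Raw training average for the same player: 18 hits in 45 at-bats.
raw_average = 18 / 45

model_average = inv_logit(alpha_0_mean)
print(f"model: {model_average:.3f}, raw: {raw_average:.3f}")
```

The model's estimate comes out at roughly 0.29, well below the raw 0.40. That is the partial pooling at work: each player's estimate is pulled toward the shared `loc`.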

Let's push through the _training_ data:

In [35]:
predictions = Predictive(partially_pooled_with_logit,
                         posterior_samples=chain.get_samples())(random.PRNGKey(2), train[:, 0])['obs']

In [36]:
predictions.shape

(2500, 18)

Okay, now the _testing_ data:

In [37]:
predictions = Predictive(partially_pooled_with_logit, 
                          posterior_samples=chain.get_samples())(random.PRNGKey(2), test[:, 0])['obs']

In [38]:
predictions.shape

(2500, 18)

Wait a second... the testing data just so happens to be the same size as the training data! What happens if we only try to predict 10 of the batters?

In [39]:
predictions = Predictive(partially_pooled_with_logit, 
                          posterior_samples=chain.get_samples())(random.PRNGKey(2), test[:10, 0])['obs']

ValueError: Incompatible shapes for broadcasting: ((1,), (18,), (10,))

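Everything breaks: this vectorization _only_ works if the test set has the same number of elements as the training set. The posterior samples of `alpha` have one column per training player (18 of them), so they cannot broadcast against 10 at-bats, which is exactly what the shapes in the error message say. The documentation is pretty misleading on this point.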

## The example leads us astray... let's try the documentation

What we seem to need is the `guide` argument, as described in the documentation for [Predictive](http://num.pyro.ai/en/stable/utilities.html?highlight=Predictive#predictive).

In [None]:
numpyro.plate()