# Bayesian Imputation

#### Note:
I took this notebook from the NumPyro documentation and edited it to demonstrate a problem I'm having when imputing missing data to make predictions on new data.

Real-world datasets often contain many missing values. In those situations, we have to either remove the entries with missing data (also known as "complete case" analysis) or replace the missing values with some estimates. Although complete-case analysis is straightforward, it is only applicable when the number of missing entries is small enough that discarding them does not noticeably reduce the power of the analysis we are conducting. The second strategy, known as [imputation](https://en.wikipedia.org/wiki/Imputation_%28statistics%29), is more widely applicable and is the focus of this tutorial.

Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its corresponding feature. In that case, we implicitly assume that the feature containing missing values has no correlation with the remaining features of our dataset. This is a strong assumption that often does not hold. In addition, it does not encode any uncertainty about the filled-in values. Below, we construct a *Bayesian* setting to resolve those issues. In particular, given a model on the dataset, we will

+ create a generative model for the feature with missing values
+ and treat the missing values as unobserved latent variables.
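To make the point about correlation concrete, here is a minimal sketch on a tiny made-up table (the values are hypothetical, not Titanic data). Filling with the global mean gives a child and an adult the same age, while filling with the mean of the missing entry's `Title` group does not. Neither approach, however, says anything about how uncertain the filled value is.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: "Master." passengers are children, "Mr." are adults.
df = pd.DataFrame(
    {
        "Title": ["Master.", "Master.", "Master.", "Mr.", "Mr.", "Mr."],
        "Age": [4.0, 6.0, np.nan, 30.0, 40.0, np.nan],
    }
)

# Global mean imputation ignores the Title/Age correlation.
global_fill = df["Age"].fillna(df["Age"].mean())

# Group-wise mean imputation uses the correlated feature (Title).
group_fill = df["Age"].fillna(df.groupby("Title")["Age"].transform("mean"))

print(global_fill.tolist())  # [4.0, 6.0, 20.0, 30.0, 40.0, 20.0]
print(group_fill.tolist())   # [4.0, 6.0, 5.0, 30.0, 40.0, 35.0]
```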

In [1]:
#!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

In [2]:
# first, we need some imports
import os

from IPython.display import set_matplotlib_formats
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from jax import numpy as jnp
from jax import random
from jax.scipy.special import expit

import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, Predictive

plt.style.use("seaborn")
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

assert numpyro.__version__.startswith("0.9.1")

## Dataset

The data is taken from the competition [Titanic: Machine Learning from Disaster](https://www.kaggle.com/c/titanic) hosted on [Kaggle](https://www.kaggle.com/). It contains information about the passengers in the [Titanic accident](https://en.wikipedia.org/wiki/Sinking_of_the_RMS_Titanic), such as name, age, and gender. Our goal is to predict whether a passenger survived.

In [3]:
train_df = pd.read_csv(
    "https://raw.githubusercontent.com/agconti/kaggle-titanic/master/data/train.csv"
)
train_df.info()
train_df.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB


| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.25 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.925 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.05 | NaN | S |


Looking at the data info, we see that the `Age`, `Cabin`, and `Embarked` columns contain missing data. Although `Cabin` is an important feature (the position of a cabin on the ship can affect its occupants' chance of survival), we will skip it in this tutorial for simplicity. The dataset has many categorical columns and two numerical columns, `Age` and `Fare`. Let's first look at the distribution of the categorical columns:
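The missing counts can also be computed directly with `isna().sum()`. On `train_df`, the `info()` output above implies 177 missing ages (891 rows minus 714 non-null values), 687 missing cabins and 2 missing `Embarked` values. A sketch on a small hypothetical frame:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for train_df, with a few missing entries
df = pd.DataFrame(
    {
        "Age": [22.0, np.nan, 26.0, np.nan],
        "Cabin": [None, "C85", None, None],
        "Embarked": ["S", "C", None, "S"],
        "Fare": [7.25, 71.2833, 7.925, 8.05],
    }
)

# count missing entries per column and keep only columns that have any
missing = df.isna().sum()
missing_counts = {col: int(n) for col, n in missing.items() if n > 0}
print(missing_counts)  # {'Age': 2, 'Cabin': 3, 'Embarked': 1}
```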

In [4]:
for col in ["Survived", "Pclass", "Sex", "SibSp", "Parch", "Embarked"]:
    print(train_df[col].value_counts(), end="\n\n")

0    549
1    342
Name: Survived, dtype: int64

3    491
1    216
2    184
Name: Pclass, dtype: int64

male      577
female    314
Name: Sex, dtype: int64

0    608
1    209
2     28
4     18
3     16
8      7
5      5
Name: SibSp, dtype: int64

0    678
1    118
2     80
5      5
3      5
4      4
6      1
Name: Parch, dtype: int64

S    644
C    168
Q     77
Name: Embarked, dtype: int64



## Prepare data

First, we merge the rare groups in the `SibSp` and `Parch` columns. In addition, we fill the 2 missing entries in `Embarked` with the mode, `S`. We could also build a generative model for those missing `Embarked` entries, but we skip that for simplicity.

In [5]:
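# assign back instead of using inplace=True on a column, which may not update
# train_df under pandas copy-on-write
train_df["SibSp"] = train_df["SibSp"].clip(0, 1)
train_df["Parch"] = train_df["Parch"].clip(0, 2)
train_df["Embarked"] = train_df["Embarked"].fillna("S")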
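To see what `clip` does to the rare groups, here is a sketch on made-up values: every `SibSp` value above 1 becomes 1 (meaning "one or more") and every `Parch` value above 2 becomes 2. Computing the mode with `Series.mode()` instead of hard-coding `"S"` is our own variation, not what the cell above does.

```python
import numpy as np
import pandas as pd

# Hypothetical values mimicking the SibSp / Parch / Embarked columns
df = pd.DataFrame(
    {
        "SibSp": [0, 1, 3, 8, 0],
        "Parch": [0, 2, 5, 6, 1],
        "Embarked": ["S", "C", np.nan, "S", "Q"],
    }
)

df["SibSp"] = df["SibSp"].clip(0, 1)  # {0, 1, 2, ...} -> {0, 1}
df["Parch"] = df["Parch"].clip(0, 2)  # {0, 1, 2, 3, ...} -> {0, 1, 2}
# mode() ignores NaN by default; take the most frequent port
df["Embarked"] = df["Embarked"].fillna(df["Embarked"].mode()[0])

print(df["SibSp"].tolist())     # [0, 1, 1, 1, 0]
print(df["Parch"].tolist())     # [0, 2, 2, 2, 1]
print(df["Embarked"].tolist())  # ['S', 'C', 'S', 'S', 'Q']
```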

Looking closer at the data, we can see that each name contains a title. Age is likely correlated with the title: for example, passengers titled `Mrs.` are on average older than those titled `Miss.`, so it makes sense to create a title feature. The distribution of titles is:

In [6]:
train_df.Name.str.split(", ").str.get(1).str.split(" ").str.get(0).value_counts()

Mr.          517
Miss.        182
Mrs.         125
Master.       40
Dr.            7
Rev.           6
Mlle.          2
Major.         2
Col.           2
the            1
Capt.          1
Ms.            1
Sir.           1
Lady.          1
Mme.           1
Don.           1
Jonkheer.      1
Name: Name, dtype: int64

We will make a new column `Title`, where rare titles are merged into one group `Misc.`.

In [7]:
train_df["Title"] = (
    train_df.Name.str.split(", ")
    .str.get(1)
    .str.split(" ")
    .str.get(0)
    .apply(lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
)
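The split chain keeps the first word after the first `", "` in each name. A sketch on a few names in the same `"Last, Title. First"` format also suggests where the odd `the` entry in the value counts comes from: a name written like `"Doe, the Countess. of (Jane)"` (a hypothetical example) yields `the`, which is then merged into `Misc.`.

```python
import pandas as pd

names = pd.Series(
    [
        "Braund, Mr. Owen Harris",
        "Heikkinen, Miss. Laina",
        "Doe, the Countess. of (Jane)",  # hypothetical name
    ]
)

# text after "Last, " -> first word of it
raw_title = names.str.split(", ").str.get(1).str.split(" ").str.get(0)
# merge everything except the four common titles into "Misc."
title = raw_title.apply(
    lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc."
)

print(raw_title.tolist())  # ['Mr.', 'Miss.', 'the']
print(title.tolist())      # ['Mr.', 'Miss.', 'Misc.']
```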

Now we are ready to turn the dataframe, which includes categorical values, into numpy arrays. We also standardize the `Age` column (a good practice for regression models).

In [8]:
title_cat = pd.CategoricalDtype(
    categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True
)
embarked_cat = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
age_mean, age_std = train_df.Age.mean(), train_df.Age.std()
data = dict(
    age=train_df.Age.pipe(lambda x: (x - age_mean) / age_std).values,
    pclass=train_df.Pclass.values - 1,
    title=train_df.Title.astype(title_cat).cat.codes.values,
    sex=(train_df.Sex == "male").astype(int).values,
    sibsp=train_df.SibSp.values,
    parch=train_df.Parch.values,
    embarked=train_df.Embarked.astype(embarked_cat).cat.codes.values,
)
survived = train_df.Survived.values
# compute the age mean for each title
age_notnan = data["age"][jnp.isfinite(data["age"])]
title_notnan = data["title"][jnp.isfinite(data["age"])]
age_mean_by_title = jnp.stack([age_notnan[title_notnan == i].mean() for i in range(5)])
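Standardization maps each age to `(age - age_mean) / age_std` and leaves `NaN` in place, which is what the model later detects with `np.isnan`. The inverse, `age_mean + age_std * z`, brings standardized values back to years. A small round-trip sketch on made-up ages:

```python
import numpy as np
import pandas as pd

ages = pd.Series([22.0, 38.0, np.nan, 35.0])  # hypothetical ages
age_mean, age_std = ages.mean(), ages.std()   # pandas skips NaN by default

z = ((ages - age_mean) / age_std).values      # NaN stays NaN
nan_idx = np.nonzero(np.isnan(z))[0]          # positions the model would impute

back = age_mean + age_std * z                 # undo the standardization

print(nan_idx)  # [2]
print(np.allclose(back[np.isfinite(back)], [22.0, 38.0, 35.0]))  # True
```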

## Modelling

First, we want to note that in NumPyro, the following models
```python
def model1a():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]))
```
and
```python
def model1b():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
    numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)
```
are equivalent in the sense that both of them have

+ the same latent sites `x` drawn from `dist.Normal(0, 1)` prior,
+ and the same log densities `dist.Normal(0, 1).log_prob(x)`.

Now suppose that we only observe the last 6 values of `x` (the unobserved entries take the value `NaN`). A typical model would be
```python
def model2a(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]))
    x_obs = numpyro.sample("x_obs", dist.Normal(0, 1).expand([6]), obs=x[4:])
    x_imputed = jnp.concatenate([x_impute, x_obs])
```
or, using `mask`,
```python
def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    numpyro.sample("x", dist.Normal(0, 1).expand([10]), obs=x_imputed)
```

Both approaches to modelling the partially observed data `x` are equivalent. In the model below, we use the latter.

In [9]:
def model(
    age, pclass, title, sex, sibsp, parch, embarked, survived=None, bayesian_impute=True
):
    b_pclass = numpyro.sample("b_Pclass", dist.Normal(0, 1).expand([3]))
    b_title = numpyro.sample("b_Title", dist.Normal(0, 1).expand([5]))
    b_sex = numpyro.sample("b_Sex", dist.Normal(0, 1).expand([2]))
    b_sibsp = numpyro.sample("b_SibSp", dist.Normal(0, 1).expand([2]))
    b_parch = numpyro.sample("b_Parch", dist.Normal(0, 1).expand([3]))
    b_embarked = numpyro.sample("b_Embarked", dist.Normal(0, 1).expand([3]))

    # impute age by Title
    isnan = np.isnan(age)
    age_nanidx = np.nonzero(isnan)[0]
    print('index', age_nanidx)
    if bayesian_impute:
        age_mu = numpyro.sample("age_mu", dist.Normal(0, 1).expand([5]))
        age_mu = age_mu[title]
        age_sigma = numpyro.sample("age_sigma", dist.Normal(0, 1).expand([5]))
        age_sigma = age_sigma[title]
        age_impute = numpyro.sample(
            "age_impute",
            dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),
        )
        print('age impute size', age_impute.shape)
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)
        numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)
    else:
        # fill missing data by the mean of ages for each title
        age_impute = age_mean_by_title[title][age_nanidx]
        age = jnp.asarray(age).at[age_nanidx].set(age_impute)

    a = numpyro.sample("a", dist.Normal(0, 1))
    b_age = numpyro.sample("b_Age", dist.Normal(0, 1))
    logits = a + b_age * age
    logits = logits + b_title[title] + b_pclass[pclass] + b_sex[sex]
    logits = logits + b_sibsp[sibsp] + b_parch[parch] + b_embarked[embarked]
    numpyro.sample("survived", dist.Bernoulli(logits=logits), obs=survived)

Note that in the model, the prior for `age` is `dist.Normal(age_mu, age_sigma)`, where the values of `age_mu` and `age_sigma` depend on `title`. Because there are missing values in `age`, we will encode those missing values in the latent parameter `age_impute`. Then we can replace `NaN` entries in `age` with the vector `age_impute`.

## Sampling

We use MCMC with the NUTS kernel to sample both the regression coefficients and the imputed values.

In [10]:
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), **data, survived=survived)
mcmc.print_summary()

index [  5  17  19  26  28  29  31  32  36  42  45  46  47  48  55  64  65  76
  77  82  87  95 101 107 109 121 126 128 140 154 158 159 166 168 176 180
 181 185 186 196 198 201 214 223 229 235 240 241 250 256 260 264 270 274
 277 284 295 298 300 301 303 304 306 324 330 334 335 347 351 354 358 359
 364 367 368 375 384 388 409 410 411 413 415 420 425 428 431 444 451 454
 457 459 464 466 468 470 475 481 485 490 495 497 502 507 511 517 522 524
 527 531 533 538 547 552 557 560 563 564 568 573 578 584 589 593 596 598
 601 602 611 612 613 629 633 639 643 648 650 653 656 667 669 674 680 692
 697 709 711 718 727 732 738 739 740 760 766 768 773 776 778 783 790 792
 793 815 825 826 828 832 837 839 846 849 859 863 868 878 888]
age impute size (177,)


sample: 100%|█| 2000/2000 [00:17<00:00, 113.11it/s, 63 steps of size 6.11e-02. a



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
              a      0.12      0.83      0.12     -1.28      1.40   1077.21      1.00
  age_impute[0]      0.20      0.81      0.20     -1.11      1.52   2068.95      1.00
  age_impute[1]     -0.07      0.84     -0.09     -1.34      1.29   1524.15      1.00
  age_impute[2]      0.38      0.77      0.35     -0.77      1.73   2369.57      1.00
  age_impute[3]      0.25      0.84      0.24     -1.26      1.56   1209.11      1.00
  age_impute[4]     -0.64      0.91     -0.62     -2.16      0.72   1417.63      1.00
  age_impute[5]      0.21      0.88      0.20     -1.07      1.79   1612.38      1.00
  age_impute[6]      0.45      0.84      0.45     -0.81      1.94   1532.53      1.00
  age_impute[7]     -0.65      0.84     -0.65     -2.15      0.66   1795.94      1.00
  age_impute[8]     -0.14      0.90     -0.16     -1.49      1.49   1931.31      1.00
  age_impute[9]      0.24      0.85      0.23     -1.

To double-check that the assumption "age is correlated with title" is reasonable, let's look at the inferred age by title. Recall that we standardized `age`, so we need to scale back to the original domain.

In [11]:
age_by_title = age_mean + age_std * mcmc.get_samples()["age_mu"].mean(axis=0)
dict(zip(title_cat.categories, age_by_title))

{'Mr.': DeviceArray(32.42303, dtype=float32),
 'Miss.': DeviceArray(21.783125, dtype=float32),
 'Mrs.': DeviceArray(35.86772, dtype=float32),
 'Master.': DeviceArray(4.6148586, dtype=float32),
 'Misc.': DeviceArray(42.061874, dtype=float32)}

The inferred result supports our assumption that `Age` is correlated with `Title`:

+ passengers with the `Master.` title are quite young (in other words, they are the children on the ship) compared to the other groups,
+ passengers with the `Mrs.` title are older on average than those with the `Miss.` title.

The result is also close to the empirical mean of `Age` given `Title` in the training dataset:

In [12]:
train_df.groupby("Title")["Age"].mean()

Title
Master.     4.574167
Misc.      42.384615
Miss.      21.773973
Mr.        32.368090
Mrs.       35.898148
Name: Age, dtype: float64

So far so good: we have a lot of information about the regression coefficients, together with the imputed values and their uncertainties. Let's inspect those results a bit:

+ The mean value `-0.44` of `b_Age` suggests that younger passengers had a better chance of surviving.
+ The mean value `(1.11, -1.07)` of `b_Sex` suggests that female passengers had a higher chance of surviving than male passengers.
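Since the model is a logistic regression, a coefficient translates into a multiplicative change in the odds of survival via `exp`. A back-of-the-envelope sketch using the posterior means quoted above, with two caveats: exponentiating a posterior mean only approximates the posterior mean of the odds ratio, and `b_Age` acts on *standardized* age, so its odds ratio refers to one standard deviation of age rather than one year.

```python
import math

b_age_mean = -0.44                       # posterior mean of b_Age quoted above
b_sex_female, b_sex_male = 1.11, -1.07   # posterior means of b_Sex (sex=0 is female)

# odds multiplier for a one-standard-deviation increase in age
age_odds_ratio = math.exp(b_age_mean)
# odds of surviving for a female vs. a male passenger, all else equal
sex_odds_ratio = math.exp(b_sex_female - b_sex_male)

print(round(age_odds_ratio, 2), round(sex_odds_ratio, 2))  # 0.64 8.85
```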

## Prediction

In NumPyro, we can use the [Predictive](http://num.pyro.ai/en/stable/utilities.html#numpyro.infer.util.Predictive) utility to make predictions from posterior samples. Let's check how well the model performs on the training dataset. For simplicity, we get a `survived` prediction for each posterior sample and take a majority vote over the predictions.
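The majority vote amounts to averaging the 0/1 predictions over posterior samples and thresholding at 0.5. A sketch on a hypothetical array of shape (4, 3) (4 posterior samples, 3 passengers), together with a row-normalized confusion matrix built with `pd.crosstab(..., normalize="index")`, so that each row (actual class) sums to 1:

```python
import numpy as np
import pandas as pd

# hypothetical Predictive output: shape (num_posterior_samples, num_passengers)
samples = np.array(
    [
        [1, 0, 1],
        [1, 0, 0],
        [0, 0, 1],
        [1, 1, 1],
    ]
)
actual = np.array([1, 0, 0])

# majority vote across posterior samples
pred = (samples.mean(axis=0) >= 0.5).astype(np.uint8)
print(pred)  # [1 0 1]

accuracy = (pred == actual).mean()
cm = pd.crosstab(
    pd.Series(actual, name="actual"),
    pd.Series(pred, name="predict"),
    normalize="index",  # divide each row by its total
)
print(accuracy)  # 0.6666666666666666
```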

In [13]:
posterior = mcmc.get_samples()
survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data)["survived"]
survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred, name="predict")
)
# divide each row (actual class) by its total
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)

index [  5  17  19  26  28  29  31  32  36  42  45  46  47  48  55  64  65  76
  77  82  87  95 101 107 109 121 126 128 140 154 158 159 166 168 176 180
 181 185 186 196 198 201 214 223 229 235 240 241 250 256 260 264 270 274
 277 284 295 298 300 301 303 304 306 324 330 334 335 347 351 354 358 359
 364 367 368 375 384 388 409 410 411 413 415 420 425 428 431 444 451 454
 457 459 464 466 468 470 475 481 485 490 495 497 502 507 511 517 522 524
 527 531 533 538 547 552 557 560 563 564 568 573 578 584 589 593 596 598
 601 602 611 612 613 629 633 639 643 648 650 653 656 667 669 674 680 692
 697 709 711 718 727 732 738 739 740 760 766 768 773 776 778 783 790 792
 793 815 825 826 828 832 837 839 846 849 859 863 868 878 888]
age impute size (177,)
Accuracy: 0.8249158


| actual \ predict | 0 | 1 |
|---|---|---|
| 0 | 0.874317 | 0.125683 |
| 1 | 0.254386 | 0.745614 |


This is a pretty good result using a simple logistic regression model. Let's see how the model performs if we don't use Bayesian imputation here.

### *** Added: predictions on "new data", to demonstrate the problem with imputing new missing values ***

In [14]:
np.random.seed(1)
train_df2 = train_df.sample(frac=1)

In [15]:
train_df2.Age

862    48.0
223     NaN
84     17.0
680     NaN
535     7.0
       ... 
715    19.0
767    30.5
72     21.0
235     NaN
37     21.0
Name: Age, Length: 891, dtype: float64

In [16]:
title_cat2 = pd.CategoricalDtype(
    categories=["Mr.", "Miss.", "Mrs.", "Master.", "Misc."], ordered=True
)
embarked_cat2 = pd.CategoricalDtype(categories=["S", "C", "Q"], ordered=True)
# add one more missing value to Age (for the passenger with index label 0);
# copy first so that train_df2 itself is not modified
age = train_df2.Age.copy()
age.loc[0] = np.nan
# standardize with the *training* statistics so the new data is on the same scale
data2 = dict(
    age=age.pipe(lambda x: (x - age_mean) / age_std).values,
    pclass=train_df2.Pclass.values - 1,
    title=train_df2.Title.astype(title_cat2).cat.codes.values,
    sex=(train_df2.Sex == "male").astype(int).values,
    sibsp=train_df2.SibSp.values,
    parch=train_df2.Parch.values,
    embarked=train_df2.Embarked.astype(embarked_cat2).cat.codes.values,
)
survived2 = train_df2.Survived.values
# compute the age mean for each title
age_notnan2 = data2["age"][jnp.isfinite(data2["age"])]
title_notnan2 = data2["title"][jnp.isfinite(data2["age"])]
age_mean_by_title2 = jnp.stack([age_notnan2[title_notnan2 == i].mean() for i in range(5)])


In [17]:
posterior = mcmc.get_samples()
survived_pred = Predictive(model, posterior)(random.PRNGKey(1), **data2)["survived"]
survived_pred = (survived_pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived2).sum() / survived2.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived2, name="actual"), pd.Series(survived_pred, name="predict")
)
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)

index [  1   3   9  13  14  31  37  50  60  61  73  77  81  82  86  87  91  93
  94  97  99 100 102 122 128 129 143 144 147 154 167 175 176 179 186 188
 190 192 196 202 208 216 220 221 224 235 236 237 239 261 262 270 272 287
 294 297 299 307 310 311 313 327 328 336 338 339 352 356 361 363 366 379
 384 395 397 399 406 409 417 419 421 423 426 427 432 433 438 443 444 455
 456 460 461 464 474 487 492 497 500 504 514 523 527 531 541 542 544 547
 552 565 568 569 577 586 593 594 605 606 613 615 618 620 631 639 643 644
 651 661 663 664 665 667 669 671 675 681 683 684 687 690 691 695 699 702
 719 723 727 733 735 738 739 744 753 774 782 791 797 798 802 805 806 813
 821 822 824 825 827 832 835 836 840 845 846 848 850 870 872 889]
age impute size (177,)


ValueError: Incompatible shapes for broadcasting: (177,) and requested shape (178,)

Back to the original notebook:

In [None]:
mcmc.run(random.PRNGKey(2), **data, survived=survived, bayesian_impute=False)
posterior_1 = mcmc.get_samples()
# pass bayesian_impute=False here too; otherwise Predictive runs the Bayesian-imputation branch
survived_pred_1 = Predictive(model, posterior_1)(
    random.PRNGKey(2), **data, bayesian_impute=False
)["survived"]
survived_pred_1 = (survived_pred_1.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred_1 == survived).sum() / survived.shape[0])
confusion_matrix = pd.crosstab(
    pd.Series(survived, name="actual"), pd.Series(survived_pred_1, name="predict")
)
confusion_matrix.div(confusion_matrix.sum(axis=1), axis=0)

We can see that Bayesian imputation does a little bit better here.

**Remark.** When using `posterior` samples to make predictions on new data, we need to marginalize out `age_impute`, because those imputed values are specific to the training data (including their shape, which equals the number of missing ages in the training data):
```python
posterior.pop("age_impute")
survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)
```

## References

1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.
2. Kaggle competition: [Titanic: Machine Learning from Disaster](https://www.kaggle.com/c/titanic)