[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/fonnesbeck/bayes_course_june_2024/blob/master/notebooks/Section2_1-MCMC.ipynb)

# Lesson: Bayesian Linear Regression

Now that we have covered the basics of Bayesian inference and Markov Chain Monte Carlo (MCMC) methods, we can apply these concepts to a specific class of statistical model, linear regression. Regression models are everywhere in data science, and Bayesian linear regression offers a powerful framework for incorporating prior knowledge and uncertainty into the modeling process.

Bayesian linear regression extends the traditional linear regression framework by incorporating prior beliefs about the parameters and updating these beliefs with data to return a posterior distribution of the model's latent parameters. These posterior distributions can be used to make predictions, estimate uncertainty, and evaluate hypotheses.

The model assumes that the response variable $y$ is generated from a normal distribution with a mean that is a linear function of the predictors and a constant variance. Mathematically, this can be expressed as:

$$
y = X\beta + \epsilon, \quad \epsilon \sim N(0, \sigma^2I_n)
$$

where $y$ is the vector of response variables, $X$ is the design matrix of predictors, $\beta$ is the vector of regression coefficients, and $\epsilon$ is the error term with a normal distribution. 

The likelihood of the data given the parameters is then:

$$
p(y|X, \beta, \sigma^2) = (2\pi\sigma^2)^{-n/2} \exp \left( -\frac{1}{2\sigma^2} (y - X\beta)'(y - X\beta) \right)
$$
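As a rough illustration of the likelihood above, the following sketch evaluates its logarithm with NumPy. The function name `gaussian_log_likelihood` and the toy data are assumptions made for this example; they are not part of the lesson's dataset.

```python
import numpy as np


def gaussian_log_likelihood(y, X, beta, sigma):
    """Log of p(y | X, beta, sigma^2) for the normal linear model."""
    n = y.shape[0]
    resid = y - X @ beta
    return -0.5 * n * np.log(2 * np.pi * sigma**2) - (resid @ resid) / (2 * sigma**2)


# Toy data (made up): an intercept column plus one predictor.
rng = np.random.default_rng(0)
X = np.column_stack([np.ones(50), rng.normal(size=50)])
true_beta = np.array([1.0, 2.0])
y = X @ true_beta + rng.normal(scale=0.5, size=50)

# Coefficients close to the data-generating values score a higher log-likelihood.
ll_true = gaussian_log_likelihood(y, X, true_beta, 0.5)
ll_wrong = gaussian_log_likelihood(y, X, np.zeros(2), 0.5)
print(ll_true, ll_wrong)
```

Working on the log scale avoids numerical underflow when $n$ is large, which is why samplers typically use log densities.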

To perform Bayesian inference, we specify prior distributions for the parameters, which include the regression coefficients $\beta$ and the error standard deviation $\sigma$ (the square root of the error variance $\sigma^2$).

$$
\beta \sim p(\beta), \quad \sigma \sim p(\sigma)
$$

The posterior distribution, which combines the prior information with the likelihood of the observed data, is then derived using Bayes' theorem. The joint posterior distribution of $\beta$ and $\sigma$ is:

$$
p(\beta, \sigma | y, X) \propto p(y | X, \beta, \sigma^2) p(\beta) p(\sigma)
$$

Note that the priors are generally assumed to be independent of one another.
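To make the proportionality concrete, here is a minimal sketch of an unnormalized log-posterior with independent priors. The specific prior choices (a zero-centered normal on each coefficient and a half-normal on $\sigma$) and all function names are assumptions for illustration only.

```python
import numpy as np


def log_normal_pdf(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)


def log_halfnormal_pdf(x, sigma):
    # The half-normal density is twice the Normal(0, sigma) density on x >= 0.
    return np.log(2.0) + log_normal_pdf(x, 0.0, sigma)


def log_posterior_unnorm(beta, sigma, y, X, beta_sd=1.0, sigma_sd=1.0):
    """log p(y | X, beta, sigma) + log p(beta) + log p(sigma), up to a constant."""
    if sigma <= 0:
        return -np.inf  # sigma must be positive
    log_lik = np.sum(log_normal_pdf(y, X @ beta, sigma))
    # Independent priors simply add on the log scale.
    log_prior = np.sum(log_normal_pdf(beta, 0.0, beta_sd)) + log_halfnormal_pdf(sigma, sigma_sd)
    return log_lik + log_prior


# Tiny made-up dataset that follows y = 1 + 2x exactly.
X = np.array([[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]])
y = np.array([1.0, 3.0, 5.0])
good = log_posterior_unnorm(np.array([1.0, 2.0]), 0.5, y, X)
poor = log_posterior_unnorm(np.array([0.0, 0.0]), 0.5, y, X)
print(good, poor)
```

An MCMC sampler only needs this unnormalized quantity, since the normalizing constant cancels when comparing parameter values.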


In [None]:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

az.style.use("arviz-darkgrid")
RANDOM_SEED = 8296
np.random.seed(RANDOM_SEED)

warnings.filterwarnings(action="ignore", category=UserWarning, module=r"seaborn")

## Example: Fish Weight Prediction

In this lesson, we'll imagine we are working in the data science team of an e-commerce company. In particular, we sell really good and fresh fish to our clients (mainly fancy restaurants). 

When we ship our products, there is a very important piece of information we need: the weight of the fish. This is important for two reasons: 

1. Because we _bill_ our clients according to weight. 

2. Because the company that delivers the fish to our clients has different price tiers for weights, and those tiers can get _really_ expensive. So we want to know the probability of an item exceeding each tier's weight threshold. In other words, estimating uncertainty is important here!

![](images/weighingfish.jpg)


The problem we face is that we purchase our fish in bulk. This means we only know the total weight of our entire order, but we don't have the weights of the individual fish. You might think the obvious solution is simply to weigh each fish one by one.

However, this approach has significant drawbacks. Manually weighing each fish is costly, requires a lot of time, and demands substantial labor. This process is inefficient and impractical for our needs.

Given these challenges, we need to explore alternative solutions. 

### A solution

While researching the problem, we discovered that our wholesale supplier has detailed information on the size of each individual fish, including their length, height, and width. Since it is infeasible to weigh individual fish, the supplier uses a **camera** to record the size of each fish. 

However, the company used to try to weigh each fish manually until costs became prohibitive. As a result, we have a valuable **training dataset** consisting of different types of fish with their accurately measured weights.

![](images/fishvideo.png)

### Exploratory data analysis

Let's import the data and take a look at it.

In [None]:
fish_market = pd.read_csv("../data/fish-market.csv")
fish_market.info()

We have collected 159 measurements, and all columns in our dataset have the appropriate data types.

For each observation, the dataset includes the following information: the species of the fish, its weight, height, and width, as well as three distinct length measurements. You might be wondering why we have three different measurements for the fish's length. Let's delve into some summary statistics to better understand the data and its significance.

In [None]:
fish_market.isnull().sum()

No missing values, which is nice.

Next let's peek at some summary statistics:

In [None]:
fish_market.describe().round(2)

Things to note:

- Though there are no missing data, there are some zero-weight fish! Either the fish was below the minimum weight for the scale, or there was a mistake during data collection.
- The standard deviations of the columns are very high, especially for weight.
- There are three columns for length, which is interesting. We will explore this further.

In [None]:
sns.heatmap(
    fish_market.drop(columns="Species").corr(),
    vmin=-1,
    vmax=1,
    center=0,
    annot=True,
    linewidths=4,
);

The three length measurements are highly correlated with each other. This means they essentially carry the same information. Without additional details to distinguish among them, we should arbitrarily choose one measurement and discard the other two. Keeping all three would be redundant and unnecessary since they do not provide unique information.

There is nothing inherently Bayesian about this step. The concept of *multicollinearity* is a fundamental concern in both Bayesian and frequentist statistics. In essence, if you include multiple variables that convey similar information in your regression model, you will end up with very unstable parameter estimates. This redundancy does not improve your model's predictive power and can, in fact, lead to misleading results. Thus, it is crucial to identify and address multicollinearity to maintain the robustness and reliability of your model.
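The instability described here can be seen in a small simulation. This is a sketch on made-up data (the variable names `length1` and `length2` only echo the dataset's columns): two nearly identical predictors are fit by least squares on bootstrap resamples, and the spread of the estimated slope is compared with a fit that uses just one of them.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200
length1 = rng.normal(30, 5, size=n)
length2 = length1 + rng.normal(0, 0.1, size=n)  # nearly a copy of length1
y = 2.0 * length1 + rng.normal(0, 1.0, size=n)


def ols_slopes(X, y):
    """Least-squares slopes (intercept fitted but dropped from the output)."""
    X1 = np.column_stack([np.ones(len(y)), X])
    coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
    return coef[1:]


def bootstrap_slope_sd(X, y, n_boot=200):
    """Standard deviation of the slopes across bootstrap resamples."""
    slopes = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y), size=len(y))
        slopes.append(ols_slopes(X[idx], y[idx]))
    return np.std(slopes, axis=0)


corr = np.corrcoef(length1, length2)[0, 1]
sd_both = bootstrap_slope_sd(np.column_stack([length1, length2]), y)
sd_single = bootstrap_slope_sd(length1[:, None], y)
print(corr, sd_both, sd_single)
```

With both columns included, the individual slopes swing widely from resample to resample even though their sum stays close to the true effect.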

In [None]:
fish_market = fish_market.drop(["Length2", "Length3"], axis="columns")

## Visual data exploration

It's always a good idea to plot your data! Seaborn's `pairplot` function is a great way to visualize the relationships between variables in your dataset. This function creates a matrix of scatterplots, with each variable plotted against every other variable. The diagonal of the matrix shows a histogram of each variable.

In [None]:
sns.pairplot(data=fish_market);

All variables exhibit linear relationships with each other, with one notable exception: weight. Weight appears to increase exponentially in relation to the other variables. However, this exponential growth is not limitless; it plateaus due to a natural upper limit on weight.

Additionally, we observe several trends within the data that may indicate differences in how these variables interact across various species. These trends suggest that the relationships between the variables are not uniform across all species, potentially due to unique biological or ecological factors influencing each species.

So, let's break down the data by species and see if we can identify any patterns.

In [None]:
sns.pairplot(data=fish_market, hue="Species");

Thus, it is clear that any model we build must account for the differences in the relationships between variables across species. This is where Bayesian linear regression comes in handy. By incorporating **domain knowledge** about the relationships between variables and the differences across species, we can build a more robust and reliable model.

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
for ax, var in zip(axes.ravel(), ["Length1", "Height", "Width", "Weight"]):
    sns.violinplot(x="Species", y=var, data=fish_market, ax=ax)

The most diverse species are Bream, Whitefish, Perch, and Pike. This diversity likely makes them more versatile for sale and cooking because they come in a wide range of sizes, including different weights, widths, and heights. This variety allows for more options in preparation methods and recipes, catering to various culinary needs.

On the other hand, the Smelt is a very small fish that is typically used in specialized recipes. A brief internet search reveals that Smelt is most commonly fried and served as an appetizer, particularly in European cuisine. Its smaller size and specific preparation methods make it less versatile than the more diverse species like Bream, Whitefish, Perch, and Pike.

## A non-Bayesian linear regression

Now that we have a clearer understanding of the data we're working with, let's move on to developing a predictive model. Our specific task is to **predict the weight of a fish based on its width, height, and length**. While we've chosen these particular variables for our analysis, it's important to note that different combinations of independent and dependent variables could also be used, depending on the specific requirements of the study.

The most promising approach for our task is to develop a **physical model**. This involves leveraging the inherent relationships between height, width, and weight, which are governed by physical proportions that impose natural lower and upper bounds on these variables. In a professional context, such a model would likely yield the most accurate and reliable predictions due to its basis in the physical characteristics of fish.

However, creating a detailed physical model can be quite complex. Therefore, for our initial attempt, we can use a simple **ordinary least squares (OLS)** regression to establish a relationship between the dependent variable (weight) and the independent variables (width, height, and length).

From our data exploration, we observed that weight is not linearly related to the other variables. This non-linear relationship suggests that a direct application of linear regression may not be effective. To address this issue, we often need to apply some form of data transformation to better fit the model to the data.

In this scenario, a **logarithmic transformation** of the data appears to be a suitable choice. This transformation can help counteract the exponential increase in weight as the fish's width, height, and length increase. By applying a log-transform, we can linearize the relationship between these variables, making it more appropriate for linear regression analysis.
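A quick sketch of why the log transform helps, assuming weight grows roughly as a power of a linear dimension (the cubic law and the constant `0.01` below are invented for illustration): taking logs of both sides turns the curve into a straight line whose slope is the exponent.

```python
import numpy as np

# Hypothetical power law: weight = 0.01 * length^3.
length = np.linspace(5.0, 60.0, 40)
weight = 0.01 * length**3

# In log space the relationship is log(weight) = log(0.01) + 3 * log(length).
slope, intercept = np.polyfit(np.log(length), np.log(weight), deg=1)
print(slope, np.exp(intercept))
```

A straight-line fit in log space recovers the exponent as the slope and the multiplicative constant as the exponentiated intercept.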

### Taking the log of all covariates

In [None]:
fish_market = fish_market.assign(log_width=np.log(fish_market.Width),
                                 log_height=np.log(fish_market.Height),
                                 log_length=np.log(fish_market.Length1),
                                 log_weight=np.log(fish_market.Weight))

Note that the transformation triggers a warning (raised by NumPy) because the logarithm of zero is negative infinity. This warning indicates a problem we need to address.

The simplest solution is to remove these observations from our dataset, a process known as **complete case analysis**.

This approach makes an implicit assumption that the missing values are not systematically different from the non-missing values. This assumption is often incorrect. For instance, consider that certain characteristics of the missing fish might make them particularly difficult to find in fish markets. If these characteristics contradict the trends observed in the available data, our analysis could become arbitrarily biased.

Or, imagine that a specific size of fish makes them ideal prey for predators. Consequently, these fish experience higher stress levels, resulting in them being thinner and lighter than their counterparts. However, their body shape makes them more challenging to weigh accurately, leading to frequent missing measurements.

By simply discarding these particular measurements, we might be ignoring a subset of fish that could significantly impact our trendline. This situation suggests that we might need a more sophisticated model to account for the dip in weight at certain size ranges, rather than merely excluding the data.
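For reference, the complete-case step itself is simple. The sketch below uses a made-up array of weights that includes one zero; it shows where the negative infinity comes from and how the affected entries can be filtered out.

```python
import numpy as np

weights = np.array([242.0, 0.0, 290.0, 340.0])  # made-up values, one zero

# Silence the divide-by-zero warning locally; log(0) evaluates to -inf.
with np.errstate(divide="ignore"):
    log_weights = np.log(weights)

# Keep only the finite entries (complete case analysis).
keep = np.isfinite(log_weights)
complete = log_weights[keep]
print(log_weights, complete)
```

Filtering on the original scale (`weights != 0`) gives the same result here; the caveats about systematically different missing values still apply.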

### Simple OLS regression

An easy way to perform OLS regression is via the `seaborn` graphics library. The `lmplot` function creates a scatterplot of the data and fits a regression line to the data.

In [None]:
fish_complete = fish_market[fish_market["Weight"] != 0].copy()

sns.lmplot(
    data=fish_complete,
    x="log_height",
    y="log_weight",
    hue="Species",
    col="Species",
    height=3,
    col_wrap=4,
);

The output here is purely visual, but in log space, our input variables seem linearly related to weight, so there is good reason to believe that a linear model is appropriate here.

Let's go ahead and fit a linear model to the data using PyMC.

### Baseline Model

Let's start with a very simple "null" model: just a global mean with no predictors. 
$$
\begin{aligned}
\log(\text{weight}) &\sim \mathrm{Normal}(\mu, \sigma)\\
\mu &\sim \mathrm{Normal}(0, 1)\\
\sigma &\sim \mathrm{HalfNormal}(1)\\
\end{aligned}
$$

This corresponds to `log(weight) ~ 1` in [Wilkinson notation](https://uk.mathworks.com/help/stats/wilkinson-notation.html).

In [None]:
with pm.Model() as fish_simple:

    # Prior
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", 1.0)

    # Likelihood
    pm.Normal(
        "log_weight",
        mu=mu,
        sigma=sigma,
        observed=fish_complete["log_weight"].to_numpy(),
    )

pm.model_to_graphviz(fish_simple)

Now to fit the model:

In [None]:
with fish_simple:
    trace_simple = pm.sample()

In [None]:
az.summary(trace_simple, round_to=2)

We will dig into model checking and diagnostics in a later section (which will explain most of the values in the `summary` table), but for now we can plot the posterior distribution of the model parameters and do some informal, visual checks.

In [None]:
az.plot_trace(trace_simple);

Traceplots are useful for evaluating the performance of our MCMC sampling. In these plots, we aim to see a "fuzzy caterpillar" pattern on the right side, which indicates that the chains are **mixing well** and exploring the parameter space effectively. This is evidence to suggest the chains have converged (to something!) and are providing a reasonable representation of the posterior distribution.

In addition to traceplots, **rank plots** serve as another diagnostic tool for assessing the quality of your MCMC samples. Rank plots display the ranks of sampled values for each parameter, and we look for the histograms in these plots to be approximately uniform. If one chain samples some values significantly more than the other chains, then the ranks of its samples will be markedly higher or lower than other chains and the histograms won't be uniform. Uniform histograms suggest that the sampler has performed well, meaning the samples are not exhibiting degeneracies. 
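The idea behind rank plots can be sketched in a few lines of NumPy. This is a simplified illustration on simulated draws, not ArviZ's implementation: all draws are ranked jointly, and then the ranks belonging to each chain are binned into a histogram.

```python
import numpy as np

rng = np.random.default_rng(2)
n_chains, n_draws = 4, 1000
good = rng.normal(size=(n_chains, n_draws))  # well-mixed chains
bad = good.copy()
bad[0] += 1.5  # pretend one chain is stuck in a shifted region


def chain_rank_histograms(draws, n_bins=10):
    """Histogram of pooled ranks, one row per chain."""
    n_chains, n_draws = draws.shape
    # Rank every draw among all draws from all chains (0 = smallest).
    ranks = draws.ravel().argsort().argsort().reshape(n_chains, n_draws)
    edges = np.linspace(0, n_chains * n_draws, n_bins + 1)
    return np.stack([np.histogram(r, bins=edges)[0] for r in ranks])


good_hist = chain_rank_histograms(good)
bad_hist = chain_rank_histograms(bad)
print(good_hist[0])
print(bad_hist[0])
```

For the well-mixed chains each bin holds roughly the same count, whereas the shifted chain piles up in the highest-rank bins.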

In [None]:
az.plot_rank(trace_simple);

## Interpreting parameters

This model is very simple: `mu` and `sigma` estimate the mean and standard deviation, respectively, of the log-weights, so their posterior means should sit close to the corresponding sample statistics.

If we go back to our trace plot, our posterior uncertainty doesn't seem big. But remember that we're on log scale, so it would be best to work on the nominal scale. Fortunately, `plot_trace` even has a `transform` argument we can use.

In [None]:
az.plot_trace(trace_simple, transform=np.exp, var_names="mu");

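So there is a reasonable amount of uncertainty in our estimate of the typical weight of a fish, which is perhaps surprising given we have pooled all the data. (Strictly speaking, `exp(mu)` is the median weight under this model rather than the mean, because exponentiating a normal variable gives a log-normal one.)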
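One subtlety when moving back from the log scale: exponentiating the mean of the log-weights gives the median (geometric mean) of the weights, not their arithmetic mean. The simulation below is a sketch with arbitrary parameter values (`mu = 5`, `sigma = 1`), chosen only to make the gap visible.

```python
import numpy as np

rng = np.random.default_rng(3)
mu, sigma = 5.0, 1.0  # arbitrary values for illustration
weights = rng.lognormal(mean=mu, sigma=sigma, size=200_000)

geometric_mean = np.exp(np.log(weights).mean())  # close to exp(mu)
arithmetic_mean = weights.mean()                 # close to exp(mu + sigma**2 / 2)
print(geometric_mean, arithmetic_mean)
```

If the arithmetic mean weight is the quantity of interest, the log-normal mean formula $\exp(\mu + \sigma^2/2)$ applies, evaluated per posterior draw.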

Now let's look at **model fit**. We will explore model checking in detail later in the course, but for now, we can use a simple technique: posterior predictive checks.

Posterior predictive checks (PPCs) are a great way to validate a model. The idea is to generate data from the model using parameters from the posterior distribution and compare these samples to the observed data.
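Conceptually, a posterior predictive simulation for the mean-only model looks like the sketch below. The "posterior draws" are stand-ins generated from arbitrary distributions, not output from the fitted model, and `n_obs` is an arbitrary dataset size.

```python
import numpy as np

rng = np.random.default_rng(4)
n_draws, n_obs = 500, 100

# Stand-in posterior draws (hypothetical values, not from the real trace).
mu_draws = rng.normal(5.4, 0.1, size=n_draws)
sigma_draws = np.abs(rng.normal(1.3, 0.07, size=n_draws))

# For every posterior draw, simulate one full replicated dataset.
y_rep = rng.normal(mu_draws[:, None], sigma_draws[:, None], size=(n_draws, n_obs))

# Summaries of the replicated datasets can be compared with the same
# summaries computed on the observed data.
rep_sd = y_rep.std(axis=1)
print(y_rep.shape, rep_sd.mean())
```

`pm.sample_posterior_predictive` carries out this kind of simulation using the actual posterior draws and the model's likelihood.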

Let's generate these simulated datasets now.

In [None]:
with fish_simple:
    ppc = pm.sample_posterior_predictive(trace_simple, extend_inferencedata=True)

ax = az.plot_ppc(trace_simple)
ax.set_xlabel("log_weight");

The data are clearly heterogeneous, as evidenced by the multiple peaks in the log-weight, but our model fails to capture them accurately. This discrepancy suggests that the model is struggling to fit the data properly. Consequently, the model resorts to increasing the posterior uncertainty and observational noise (`sigma`). Essentially, the model compensates for its inability to accurately represent the data by broadening its predictions, sacrificing precision for coverage.

### Adding predictors to our model

It is time to introduce predictors to our model, and see how much they improve prediction.


$$
\begin{aligned}
\text{priors}\\
\mu[s] &\sim \mathrm{Normal}(0, 1)\\
\beta[s, k] &\sim \mathrm{Normal}(0, 0.5)\\
\sigma &\sim \mathrm{HalfNormal}(1)\\
\text{linear model}\\
\mu_i &= \mu[s_i]\\
        & \quad + \beta[s_i, 0] \times \log(\text{width}_i)\\
        & \quad + \beta[s_i, 1] \times \log(\text{height}_i)\\
        & \quad + \beta[s_i, 2] \times \log(\text{length}_i)\\
\text{likelihood}\\
\log(\text{weight}_i) &\sim \mathrm{Normal}(\mu_i, \sigma)\\
\end{aligned}
$$


where $s_i$ is the species index corresponding to fish _i_:


$$
s_i \in \{ 0, 1, \ldots, {S-1} \}.
$$


In Wilkinson notation, the model can be written as:


`log(weight) ~ 0 + species + log(width):species + log(height):species + log(length):species`. 


The `0 + species` component means that we just have $S$ intercept terms, one for each species, with no global intercept. 

The remaining terms (e.g. `log(width):species`) represent an interaction between the predictor and the `species` category. So there will be one coefficient for the $\log(\text{width})$ slope (in this case) for each species.

So, each species has its own intercept and slopes for width, height, and length. This is an **unpooled model** because we are essentially fitting a separate regression for each species!
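The species-specific intercepts and slopes rely on integer-array indexing. Here is a small NumPy sketch with invented numbers (two species, three fish) showing how indexing by $s_i$ picks out the right parameters for each fish.

```python
import numpy as np

mu = np.array([1.0, 2.0])            # one intercept per species (S = 2)
beta = np.array([[0.5, 0.1, 0.2],    # species 0: width, height, length slopes
                 [0.3, 0.4, 0.6]])   # species 1
s = np.array([0, 1, 1])              # species index of each fish

log_width = np.array([1.0, 2.0, 0.0])
log_height = np.array([0.0, 1.0, 1.0])
log_length = np.array([2.0, 0.0, 1.0])

# mu[s] and beta[s, k] each have one entry per fish.
mu_i = (
    mu[s]
    + beta[s, 0] * log_width
    + beta[s, 1] * log_height
    + beta[s, 2] * log_length
)
print(mu_i)
```

The same indexing pattern works on PyMC random variables, which is what allows the model to be written as a single vectorized expression.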

In order to make this work, we need to encode the species as a categorical variable. We can do this using the `factorize` function in `pandas`.

In [None]:
fish_complete.Species.factorize(sort=True)

`factorize` encodes the species names into integer values that we will use as indices for the model parameters. In addition, we will use the category classes as dimension labels.
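To see what this encoding produces, here is a NumPy-only sketch on a made-up list of names. When no values are missing, `np.unique(..., return_inverse=True)` should yield the same sorted labels and integer codes as `factorize(sort=True)`, with the two return values in the opposite order.

```python
import numpy as np

species_names = np.array(["Perch", "Bream", "Pike", "Bream", "Perch"])

# Sorted unique labels, plus the integer code of each original entry.
species, species_idx = np.unique(species_names, return_inverse=True)
print(species)
print(species_idx)

# Indexing the labels by the codes reconstructs the original column.
roundtrip = species[species_idx]
```

The codes index into the sorted label array, so the label order also fixes the order of the `species` coordinate in the model.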

### Define dimensions & coordinates

In [None]:
species_idx, species = fish_complete.Species.factorize(sort=True)
coords = {
    "slopes": ["width_effect", "height_effect", "length_effect"],
    "species": species,
}

We will also make use of `Data` containers to include the data explicitly in the model. This will be useful later, when we want to predict out-of-sample.

In [None]:
with pm.Model(coords=coords) as fish_unpooled:
    # data
    log_width = pm.Data("log_width", fish_complete.log_width.to_numpy())
    log_height = pm.Data("log_height", fish_complete.log_height.to_numpy())
    log_length = pm.Data("log_length", fish_complete.log_length.to_numpy())
    log_weight = pm.Data("log_weight", fish_complete.log_weight.to_numpy())
    s = pm.Data("species_idx", species_idx)

    # priors
    mu = pm.Normal("mu", sigma=1.0, dims="species")

    # each species gets a slope for each predictor thx to `dims`:
    beta = pm.Normal("beta", sigma=0.5, dims=("species", "slopes"))

    # linear regression
    expected_weight = (
        mu[s]
        + beta[s, 0] * log_width
        + beta[s, 1] * log_height
        + beta[s, 2] * log_length
    )
    # observational noise
    sigma = pm.HalfNormal("sigma", 1.0)

    # likelihood
    pm.Normal(
        "log_obs",
        mu=expected_weight,
        sigma=sigma,
        observed=log_weight,
    )

It's always helpful to plot the model before fitting it. This can help you catch errors in the model specification, and also give you a sense of what the model is doing.

In [None]:
pm.model_to_graphviz(fish_unpooled)

In [None]:
with fish_unpooled:
    trace_unpooled = pm.sample()
    # Posterior predictive
    pm.sample_posterior_predictive(trace_unpooled, extend_inferencedata=True)

Let's inspect the posterior parameter estimates, beginning with the intercepts:

In [None]:
az.plot_trace(trace_unpooled, var_names='mu', transform=np.exp);

The intercepts look small (even on the nominal scale) which seems odd. But recall how intercepts are interpreted: they are the expected value of the outcome when all predictors are zero. In this case, that means when the log of the width, height, and length are zero. This is awkward from an interpretive standpoint.

How could we improve this?

Give it a try, and re-run the improved model.

In [None]:
# Write your answer here

Now we have meaningful intercepts -- the expected weight of a fish with average width, height, and length for each species.
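The effect of centering on the intercept can be checked with an ordinary least-squares fit. This is a sketch on simulated data (the coefficients `-4.0` and `3.0` are invented): subtracting the predictor's mean leaves the slope unchanged and moves the intercept to the average of the outcome.

```python
import numpy as np

rng = np.random.default_rng(5)
log_length = rng.normal(3.2, 0.3, size=300)  # made-up log-lengths
log_weight = -4.0 + 3.0 * log_length + rng.normal(0, 0.1, size=300)


def ols(x, y):
    """Return [intercept, slope] from a least-squares fit."""
    X = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef


raw_intercept, raw_slope = ols(log_length, log_weight)
centered_intercept, centered_slope = ols(log_length - log_length.mean(), log_weight)

# Raw intercept: expected log-weight at log-length 0, far outside the data.
# Centered intercept: expected log-weight at the average log-length.
print(np.exp(raw_intercept), np.exp(centered_intercept))
```

In the Bayesian model the priors also shrink the estimates, so the match with the sample average is approximate rather than exact, but the interpretation of the intercept is the same.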

In [None]:
az.plot_trace(trace_unpooled, var_names='mu', transform=np.exp);

When we have vector-valued parameters, a forest plot is convenient for visualizing them.

In [None]:
az.plot_forest(trace_unpooled, var_names="beta", transform=np.exp);

In [None]:
az.plot_trace(trace_unpooled, var_names="sigma", transform=np.exp);

There is a good sign here: the posterior uncertainty around `sigma` is much lower than before, i.e., we picked up much more information on the fish weights. But did this improve our posterior predictions?

In [None]:
with fish_unpooled:
    pm.sample_posterior_predictive(trace_unpooled, extend_inferencedata=True)
ax = az.plot_ppc(trace_unpooled)
ax.set_xlabel("log_obs");

## Predicting out-of-sample

In statistical workflows, a common task is to make predictions using new, unseen data, often referred to as "out-of-sample" data. In PyMC, the most straightforward approach to achieve this is by utilizing the `Data` container. This container registers the data used to fit the model as named variables inside the model, which lets you replace their values later (for example, with test data) without rebuilding the model.

#### Splitting Data into Training and Test Sets

To illustrate this functionality, let's randomly select 90% of our data as the training dataset for the model, while reserving the remaining 10% as the test data. This test data will be unseen by the model during the training process, allowing us to evaluate its performance on new, previously unseen data when making predictions.

By following this approach, you can effectively train your model on a subset of the available data and then assess its predictive capabilities on the held-out test data, mimicking real-world scenarios where predictions need to be made on new, unobserved data points.

In [None]:
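fish_test = fish_complete.sample(frac=0.1, random_state=1).sort_index()
# record the original row labels *before* resetting the index,
# otherwise the wrong rows would be excluded from the training set
test_idx = fish_test.index
fish_train = fish_complete.loc[fish_complete.index.difference(test_idx)].reset_index(
    drop=True
)
fish_test = fish_test.reset_index(drop=True)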

Since the dataset changed compared to the previous model, we also have to redefine our coordinates:

In [None]:
species_idx, species = fish_train.Species.factorize(sort=True)
coords["species"] = species

In [None]:
with pm.Model(
    coords=coords, coords_mutable={"obs_idx": fish_train.index}
) as fish_unpooled_oos:
    # data
    log_width = pm.Data(
        "log_width", fish_train.log_width.to_numpy() - fish_train.log_width.mean(), dims="obs_idx"
    )
    log_height = pm.Data(
        "log_height", fish_train.log_height.to_numpy() - fish_train.log_height.mean(), dims="obs_idx"
    )
    log_length = pm.Data(
        "log_length", fish_train.log_length.to_numpy() - fish_train.log_length.mean(), dims="obs_idx"
    )
    log_weight = pm.Data(
        "log_weight", fish_train.log_weight.to_numpy(), dims="obs_idx"
    )
    species_idx_ = pm.Data("species_idx", species_idx, dims="obs_idx")

    # priors
    mu = pm.Normal("mu", sigma=1.0, dims="species")
    beta = pm.Normal("beta", sigma=0.5, dims=("slopes", "species"))

    # linear regression
    expected_weight = (
        mu[species_idx_]
        + beta[0, species_idx_] * log_width
        + beta[1, species_idx_] * log_height
        + beta[2, species_idx_] * log_length
    )
    # observational noise
    sigma = pm.HalfNormal("sigma", 1.0)

    # likelihood
    log_obs = pm.Normal(
        "log_obs", mu=expected_weight, sigma=sigma, observed=log_weight, dims="obs_idx"
    )

    # sampling
    trace_unpooled_oos = pm.sample()
    pm.sample_posterior_predictive(trace_unpooled_oos, extend_inferencedata=True)

In [None]:
pm.model_to_graphviz(fish_unpooled_oos)

Checking the traceplots:

In [None]:
az.plot_trace(trace_unpooled_oos, transform=np.exp);

Now we want to see how this model would work in production: given some fish morphometrics, can we accurately predict the weight of the fish?

To do this, we just have to use `set_data` to change the inputs from the training set to the test set.

In [None]:
# Encode the species
species_idx_test = pd.Categorical(fish_test.Species, categories=species).codes.astype(
    np.int64
)

species_idx_test

Note that we are shifting the input variables using the training set mean. You always want to use the same transformation on the test set as you did on the training set!
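A tiny numeric sketch (with invented values) shows why the training statistics must be reused: centering the test inputs with their own mean would place them on a different scale from the one the model was fit on.

```python
import numpy as np

train = np.array([2.0, 3.0, 4.0])  # made-up training inputs
test = np.array([3.5, 5.0])        # made-up test inputs

train_mean = train.mean()

# Correct: shift the test inputs by the *training* mean.
test_centered = test - train_mean

# Incorrect: shifting by the test set's own mean changes what zero means.
test_self_centered = test - test.mean()
print(test_centered, test_self_centered)
```

Under the correct transformation, a centered value of zero still refers to the average fish in the training data, which is the reference point of the fitted intercepts.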

In [None]:
with fish_unpooled_oos:
    pm.set_data(
        coords={"obs_idx": fish_test.index},
        new_data={
            "log_height": fish_test.log_height.to_numpy() - fish_train.log_height.mean(),
            "log_length": fish_test.log_length.to_numpy() - fish_train.log_length.mean(),
            "log_width": fish_test.log_width.to_numpy() - fish_train.log_width.mean(),
            # placeholder values: the observed weights are not used when predicting
            "log_weight": np.zeros_like(fish_test.index),
            "species_idx": species_idx_test,
        },
    )

We now call `sample_posterior_predictive` once again, but this time we specify `predictions=True` since these are not posterior predictive checks, and they will be stored in a different group (`predictions`) of the trace.

### Use updated values to predict outcomes

In [None]:
with fish_unpooled_oos:
    pm.sample_posterior_predictive(
        trace_unpooled_oos,
        predictions=True,
        extend_inferencedata=True,
    )

In [None]:
trace_unpooled_oos

How good are these predictions? Glad you asked. Remember that our data are not _really_ out-of-sample; we just cut them out from our original dataset, so we can compare our predictions to the true weights. This is a simple line of code in ArviZ (we just exponentiate the predicted log weights to compare them to the true weights):

In [None]:
az.plot_posterior(
    trace_unpooled_oos.predictions,
    ref_val=fish_test["Weight"].tolist(),
    transform=np.exp,
);

So the true weights all fell within the predictive distributions -- not all within the 95% interval, but there were no extreme misses.

## Exercise: Refitting the model

Given the success of the model, you go back and try to fit it to data collected by another vendor, only to find that the predictions aren't nearly as good!

Frustrated, you go back to the drawing board... they deal with the same type of fish, but what's wrong with their data?

One of their colleagues mentions something about not having used the same equipment to weigh the fish, because the "old manager always tried to cut costs".
They used a much cheaper scale ...

With this information in hand, make the appropriate modifications to the model to accommodate the new data.

Here is the data:

In [None]:
new_fish = pd.read_csv("../data/new_fish.csv", index_col=0)

Try to diagnose the issue and propose a new model (a slight variation) that may help in dealing with the properties of this new dataset better!

In [None]:
# Write your answer here

## From predictions to business insights

Recall from the introduction that there are different price tiers for weights, and those tiers can get _really_ expensive, so we want to know the probability of an item being above any threshold.

- $> 250$
- $> 500$
- $> 750$
- $> 1000$

Since we have calculated posterior distributions, we have the ability to compute these probabilities for any new fish we observe.
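The calculation amounts to counting draws. As a sketch, the block below uses simulated predictive draws for three hypothetical fish (the locations and spread are invented) and computes, for each fish, the share of draws at or above each threshold.

```python
import numpy as np

rng = np.random.default_rng(6)

# Stand-in predictive draws on the log scale, shape (fish, sample).
log_pred = rng.normal(loc=[[5.0], [6.0], [7.0]], scale=0.3, size=(3, 4000))
pred_weight = np.exp(log_pred)

thresholds = np.array([250, 500, 750, 1000])

# Broadcast to (fish, sample, threshold), then average over the samples.
prob_above = (pred_weight[:, :, None] >= thresholds).mean(axis=1)
print(prob_above.round(2))
```

Each row corresponds to one fish and each column to one threshold; the probabilities can only decrease from left to right because the thresholds increase.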


In [None]:
# Extract projections to numpy array
predictions = (
    np.exp(
        az.extract(trace_unpooled_oos.predictions)
        .to_array()
        .to_numpy()
        .squeeze()
    )
)

Now we can see what proportion of the predictive draws for each fish are above $250$ grams.

In [None]:
threshold = 250
(predictions >= threshold).mean(axis=1).round(2)

If we take something like a 0.5 probability as being "above", we can make a decision about each:

In [None]:
(predictions >= threshold).mean(axis=1).round(2) > 0.5


But remember that there are four thresholds $(250, 500, 750, 1000)$, so let's generalize this approach to the other three thresholds. We'll also plot these probabilities of being above thresholds.

In [None]:
predictions = np.exp(trace_unpooled_oos.predictions)

axes = az.plot_posterior(predictions, color="k")

for k, threshold in enumerate([250, 500, 750, 1000]):
    probs_above_threshold = (predictions >= threshold).mean(dim=("chain", "draw"))

    for i, ax in enumerate(axes.ravel()):
        ax.axvline(threshold, color=f"C{k}")
        _, pdf = az.kde(
            predictions["log_obs"].sel(obs_idx=i).stack(sample=("chain", "draw")).data
        )
        ax.text(
            x=threshold - 35,
            y=pdf.max() / 2,
            s=f">={threshold}",
            color=f"C{k}",
            fontsize="16",
            fontweight="bold",
        )
        ax.text(
            x=threshold - 20,
            y=pdf.max() / 2.3,
            s=f"{probs_above_threshold.sel(obs_idx=i)['log_obs'].data * 100:.0f}%",
            color=f"C{k}",
            fontsize="16",
            fontweight="bold",
        )
        ax.set_title(f"New fish\n{i}", fontsize=16)
        ax.set(xlabel="Weight\n", ylabel="Plausible values")
plt.suptitle(
    "Probability of weighing more than thresholds", fontsize=26, fontweight="bold"
);

In [None]:
%load_ext watermark
%watermark -n -u -v -iv -w