Based on https://bambinos.github.io/bambi/notebooks/splines_cherry_blossoms.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Use ArviZ's dark-grid plotting style for the figures below.
az.style.use("arviz-darkgrid")
# Fixed seed passed to each `model.fit` call so sampling results are reproducible.
SEED = 7355608

In [None]:
# Load the cherry-blossom dataset bundled with Bambi and display it.
data = bmb.load_data("cherry_blossoms")
data

In [None]:
# Keep only rows where the response `doy` (day of first bloom) is observed,
# and renumber the index 0..n-1 after dropping rows.
data = data.dropna(subset=["doy"]).reset_index(drop=True)
data.shape

In [None]:
# Reusable helper: the raw-data scatter is drawn again underneath later plots.
def plot_scatter(data, figsize=(10, 6)):
    """Draw first-bloom day versus year as a scatter on a new figure.

    Parameters
    ----------
    data : pandas.DataFrame
        Table with a ``year`` column (x axis) and a ``doy`` column (y axis).
    figsize : tuple of float, default (10, 6)
        Figure size, forwarded to ``plt.subplots``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the scatter, so callers can add more layers to it.
    """
    _fig, axes = plt.subplots(figsize=figsize)
    axes.set_xlabel("Year")
    axes.set_ylabel("Days of the first bloom")
    axes.set_title("Day of the first bloom per year")
    axes.scatter(data["year"], data["doy"], s=30, alpha=0.4)
    return axes

In [None]:
# The trailing `;` suppresses the notebook echo of the returned Axes.
plot_scatter(data);

In [None]:
# Place 15 knots at evenly spaced quantiles of `year` (0% through 100%).
# The first and last knots are therefore the minimum and maximum year, and
# knots are denser where observations are denser.
num_knots = 15
knots = np.quantile(data["year"], np.linspace(0, 1, num_knots))

In [None]:
def plot_knots(knots, ax):
    """Draw a faint vertical line on ``ax`` at every position in ``knots``.

    Parameters
    ----------
    knots : iterable of float
        x positions of the lines.
    ax : matplotlib.axes.Axes
        Axes to draw on; modified in place.

    Returns
    -------
    The same ``ax``, so the call can be the last expression of a cell.
    """
    line_style = {"color": "0.1", "alpha": 0.4}
    for position in knots:
        ax.axvline(position, **line_style)
    return ax

In [None]:
# Raw data with the knot positions overlaid.
ax = plot_scatter(data)
plot_knots(knots, ax);

In [None]:
# We only pass the internal knots to the `bs()` function.
# The first and last quantile knots (min and max year) are dropped here;
# `bs()` presumably derives the boundary knots from the data range itself.
iknots = knots[1:-1]

# Define dictionary of priors
priors = {
    # Normal prior centred on day 100 with sd 10.
    "Intercept": bmb.Prior("Normal", mu=100, sigma=10),
    # Prior for the common (non-intercept) terms, here the spline weights.
    "common": bmb.Prior("Normal", mu=0, sigma=10),
    # Prior for the response's `sigma` parameter.
    "sigma": bmb.Prior("Exponential", lam=1),
}

# Define model
# The intercept=True means the basis also spans the intercept, as originally done in the book example.
# The term text "bs(year, knots=iknots, intercept=True)" is reused verbatim as a
# lookup key in later cells, so keep those keys in sync with this formula.
model = bmb.Model("doy ~ bs(year, knots=iknots, intercept=True)", data, priors=priors)
model

In [None]:
def plot_spline_basis(basis, year, figsize=(10, 6)):
    """Draw every column of a spline basis matrix as a curve over ``year``.

    Parameters
    ----------
    basis : 2-D array-like
        Rows aligned with ``year``, one column per basis function. The columns
        may already be scaled by weights.
    year : array-like
        x-values shared by every basis column.
    figsize : tuple of float, default (10, 6)
        Figure size, forwarded to ``plt.subplots``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding one line per basis function.
    """
    long_form = pd.DataFrame(basis).assign(year=year).melt(
        id_vars="year", var_name="basis_idx", value_name="value"
    )

    _fig, axes = plt.subplots(figsize=figsize)

    # groupby(sort=False) yields groups in first-appearance (column) order,
    # the same order as iterating over `.unique()`, so line colours match.
    for _idx, curve in long_form.groupby("basis_idx", sort=False):
        axes.plot(curve["year"], curve["value"])

    return axes

In [None]:
# Spline-basis columns of the `mu` component's common design matrix. The key
# must match the term text in the model formula exactly.
B = model.components["mu"].design.common["bs(year, knots=iknots, intercept=True)"]
# One curve per basis function, with the knots overlaid.
ax = plot_spline_basis(B, data["year"].values)
plot_knots(knots, ax);

In [None]:
# The seed is to make results reproducible
# `log_likelihood=True` stores the pointwise log-likelihood in `idata`;
# `az.compare` and `az.loo` in later cells rely on it.
idata = model.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True})

In [None]:
# Posterior summary table for the model parameters.
az.summary(idata)

In [None]:
# Trace plots, for inspecting how the chains mixed.
az.plot_trace(idata);

In [None]:
# Flatten chains and draws into a single `sample` dimension.
posterior_stacked = az.extract(idata)
# Posterior mean weight of each spline basis function.
wp = posterior_stacked["bs(year, knots=iknots, intercept=True)"].mean("sample").values

# Each basis column is scaled by its mean weight. `wp` should be 1-D here
# (one weight per basis column), in which case `.T` is a no-op.
ax = plot_spline_basis(B * wp.T, data["year"].values)
# Black line: the weighted sum of the basis functions, i.e. the spline term's
# contribution to `mu`. The separate `Intercept` parameter is not added here.
ax.plot(data.year.values, np.dot(B, wp.T), color="black", lw=3)
plot_knots(knots, ax);

In [None]:
def plot_predictions(data, idata, model):
    """Plot the posterior mean of ``mu`` and a 94% credible band over the data.

    Predictions are evaluated on 500 evenly spaced years spanning the observed
    range of ``data["year"]``. They are drawn on top of the ``plot_scatter``
    scatter, together with vertical knot lines.

    Parameters
    ----------
    data : pandas.DataFrame
        Observed data with ``year`` and ``doy`` columns.
    idata : arviz.InferenceData
        Posterior from ``model.fit``. It is NOT modified: predictions are
        computed on a copy (``inplace=False``).
    model : bambi.Model
        The model that produced ``idata``.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing data, mean curve, band and knot lines.

    Notes
    -----
    The knot lines come from the notebook-level ``knots`` array, not from
    ``model``. They are only meaningful if ``model`` was built from them.
    """
    # Grid of years spanning the observed range, used as prediction inputs.
    new_data = pd.DataFrame(
        {"year": np.linspace(data.year.min(), data.year.max(), num=500)}
    )

    # Predict the mean day of first blossom on the grid. `inplace=False`
    # returns a new InferenceData instead of writing grid-based `mu` values
    # into the caller's posterior. That in-place write would otherwise leak
    # into later diagnostics computed from `idata`.
    idata_pred = model.predict(idata, data=new_data, inplace=False)

    # Stack chain and draw into one `sample` dimension. `az.extract` is used
    # because `az.extract_dataset` is deprecated in its favour.
    posterior_stacked = az.extract(idata_pred)
    y_hat = posterior_stacked["mu"]

    # Posterior mean of mu, plotted as a single line.
    y_hat_mean = y_hat.mean("sample")

    # 94% equal-tailed credible interval (3rd and 97th percentiles; not an
    # HDI), plotted as a band. Reduce over the named `sample` dimension so the
    # result does not depend on dimension order.
    band_lower, band_upper = np.quantile(
        y_hat, [0.03, 0.97], axis=y_hat.dims.index("sample")
    )

    # Plot observed data
    ax = plot_scatter(data)

    # Plot predicted mean line
    ax.plot(new_data["year"], y_hat_mean, color="firebrick")

    # Plot credible band
    ax.fill_between(
        new_data["year"], band_lower, band_upper, alpha=0.4, color="firebrick"
    )

    # Add knots (module-level `knots`; see Notes)
    plot_knots(knots, ax)

    return ax

In [None]:
# Posterior mean and 3%-97% quantile band of `mu` for the original model.
plot_predictions(data, idata, model);

In [None]:
# Common design matrix of `mu` (expected: Intercept column plus the spline
# basis columns), rounded for display.
np.round(model.components["mu"].design.common.design_matrix, 3)

In [None]:
# (number of observations, number of columns)
model.components["mu"].design.common.design_matrix.shape

In [None]:
# Compare the rank with the column count above. With `intercept=True`, the
# B-spline basis columns sum to 1 row-wise and so duplicate the Intercept
# column. The rank is therefore expected to be one less than the number of
# columns, which makes the coefficients non-identifiable.
np.linalg.matrix_rank(model.components["mu"].design.common.design_matrix)

In [None]:
# Note we use the same priors
# Same spline, but without `intercept=True`: the model's own Intercept is the
# only constant term.
model_new = bmb.Model("doy ~ bs(year, knots=iknots)", data, priors=priors)
idata_new = model_new.fit(random_seed=SEED, idata_kwargs={"log_likelihood": True})

In [None]:
az.summary(idata_new)

In [None]:
# Sampling time recorded on each fit's posterior, to compare the two fits' cost.
idata.posterior.sampling_time

In [None]:
idata_new.posterior.sampling_time

In [None]:
plot_predictions(data, idata_new, model_new);

In [None]:
# Compare the two fits by out-of-sample predictive accuracy (ArviZ uses LOO by
# default), using the log-likelihood stored at fit time.
models_dict = {"Original": idata, "New": idata_new}
df_compare = az.compare(models_dict)
df_compare

In [None]:
# Visual version of the comparison table, without in-sample deviance markers.
az.plot_compare(df_compare, insample_dev=False);

In [None]:
# Compute pointwise LOO
loo_1 = az.loo(idata, pointwise=True)
loo_2 = az.loo(idata_new, pointwise=True)

In [None]:
# Plot the Pareto k-hat diagnostic per observation (original model); large
# values flag observations where the PSIS-LOO estimate is unreliable.
az.plot_khat(loo_1.pareto_k);

In [None]:
# Same Pareto k-hat diagnostic for the model without `intercept=True`.
az.plot_khat(loo_2.pareto_k);