## Bayesian Linear Regression: Updating Beliefs about Model Parameters (Plotly Interactive)

In this notebook, we will explore Bayesian Linear Regression. This is a fundamental probabilistic model where we place a probability distribution over the possible values of our model parameters and update this distribution as we observe data. We'll see how a Gaussian prior distribution, combined with a linear model and Gaussian noise, leads to a Gaussian posterior distribution that can be computed analytically.

We will use JAX for numerical computation, `ipywidgets` for interactive controls, and **Plotly** for dynamic and interactive plotting.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Import necessary libraries
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np  # Useful for converting JAX arrays for plotting

# No matplotlib.pyplot needed in the main update function
import scipy.io  # To load the .mat data file
import ipywidgets as widgets  # For interactive controls
from IPython.display import display  # To display widgets and output
from jax.scipy.linalg import cholesky  # For plotting Gaussian contours
from gaussians import Gaussian
from gaussians_utils import (
    phi,
    _select_data,
    create_prior_distribution,
    compute_posterior,
    generate_parameter_space_data,
    generate_function_space_data,
)

# Import Plotly
import plotly.graph_objects as go
import plotly.express as px  # Can be useful for simpler plots
import plotly.io as pio  # For saving figures

pio.templates.default = "plotly_white"  # Set default template for Plotly

# Colors for the Plotly traces, defined here as RGBA strings (no tueplots color import is needed)
PLOTLY_COLORS = {
    "dark": "rgba(0,0,0,1.0)",  # Black
    "gray": "rgba(128,128,128,1.0)",  # Gray
    "blue": "rgba(0,0,255,1.0)",  # Blue
    "red": "rgba(255,0,0,1.0)",  # Red
    "dark_alpha": "rgba(0,0,0,0.2)",  # Black with transparency for bands/samples
    "blue_alpha": "rgba(0,0,255,0.2)",  # Blue with transparency
    "red_alpha": "rgba(255,0,0,0.5)",  # Red with transparency
}


# Optional: Configure JAX for 64-bit precision for potential numerical stability
jax.config.update("jax_enable_x64", True)

# Set a random seed for reproducibility
initial_key = jrandom.PRNGKey(0)

In [3]:
## Loading Data

# We'll use the data provided in the `lindata.mat` file.
# This file contains input features $X$, corresponding output values $Y$, and the known standard deviation of the observation noise $\sigma$.

data = scipy.io.loadmat("lindata.mat")
X_all = data["X"]  # inputs (N, 1)
Y_all = data["Y"][:, 0]  # outputs (N,)
sigma_noise = data["sigma"][0].flatten()[0]  # Noise standard deviation (scalar)
N_total = X_all.shape[0]  # Total number of data points

print(f"Data loaded: {N_total} points")
print(f"Input shape (X): {X_all.shape}")
print(f"Output shape (Y): {Y_all.shape}")
print(f"Noise standard deviation (sigma): {sigma_noise}")

Data loaded: 20 points
Input shape (X): (20, 1)
Output shape (Y): (20,)
Noise standard deviation (sigma): 1.5


### The Linear Model

We are modeling the relationship between the input $x$ and output $y$ using a simple linear model:
$$ y = w_0 + w_1 x + \epsilon $$
where $w_0$ is the intercept, $w_1$ is the slope, and $\epsilon$ is observation noise. We combine the parameters into a vector $w = \begin{bmatrix} w_0 \\ w_1 \end{bmatrix}$.

To write this in a more general linear regression form, we use a feature function $\phi(x)$ that transforms the input $x$ into a feature vector. For this simple linear model, $\phi(x) = \begin{bmatrix} 1 \\ x \end{bmatrix}$. Then the model becomes:
$$ y = \phi(x)^T w + \epsilon $$

In our probabilistic setting, we assume the noise $\epsilon$ is Gaussian, $\epsilon \sim \mathcal{N}(0, \sigma^2)$, where $\sigma^2$ is the noise variance (the square of `sigma_noise`). This means the likelihood of observing $y$ given $x$ and the parameters $w$ is a Gaussian centered at $\phi(x)^T w$ with variance $\sigma^2$:
$$ p(y | x, w) = \mathcal{N}(y; \phi(x)^T w, \sigma^2) $$

For several data points $(X_{select}, Y_{select})$ (the subset chosen later in the interactive widget), assumed independent given $w$, the joint likelihood $p(Y_{select} | X_{select}, w)$ is a multivariate Gaussian:
$$ p(Y_{select} | X_{select}, w) = \mathcal{N}(Y_{select}; \Phi_{select} w, \sigma^2 I) $$
where $\Phi_{select}$ is the matrix where each row is $\phi(x_i)^T$ for the selected inputs $x_i$, and $I$ is the identity matrix.
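
As a minimal sketch of the joint likelihood above, the snippet below builds $\Phi_{select}$ for a few made-up inputs and evaluates $\log p(Y_{select} | X_{select}, w)$ for two candidate parameter vectors. It is written with NumPy so that it runs standalone. The data values and the helper names `design_matrix` and `log_likelihood` are illustrative assumptions, not part of the notebook's helper modules.

```python
import numpy as np


def design_matrix(x):
    """Stack phi(x_i)^T = [1, x_i] row-wise for inputs of shape (N,) or (N, 1)."""
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    return np.hstack([np.ones_like(x), x])


def log_likelihood(w, x, y, sigma):
    """log N(y; Phi w, sigma^2 I) under independent Gaussian noise."""
    Phi = design_matrix(x)
    resid = np.asarray(y, dtype=float) - Phi @ w
    n = resid.shape[0]
    return -0.5 * n * np.log(2.0 * np.pi * sigma**2) - 0.5 * (resid @ resid) / sigma**2


# Made-up inputs and targets (illustration only); y lies exactly on y = 1 + 2x
x_demo = np.array([[-1.0], [0.0], [2.0]])
y_demo = np.array([-1.0, 1.0, 5.0])

ll_true = log_likelihood(np.array([1.0, 2.0]), x_demo, y_demo, sigma=1.5)
ll_off = log_likelihood(np.array([0.0, 0.0]), x_demo, y_demo, sigma=1.5)
print(ll_true > ll_off)  # parameters that reproduce the data score higher
```

Working with the log-likelihood rather than the likelihood itself avoids numerical underflow when many points are selected.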

In [4]:
# --- Define the feature function (rebinds the `phi` imported from gaussians_utils above) ---
def phi(x):
    """
    Feature function for simple linear regression: [1, x]
    Accepts a single scalar x or a JAX array of shape (N, 1).
    Returns a JAX array of shape (N, num_features) or (num_features,).
    """
    if jnp.ndim(x) == 0:  # Handle scalar input
        return jnp.array([1.0, x])
    else:  # Handle array input (N, 1)
        return jnp.hstack([jnp.ones_like(x), x])


# Example usage:
x_example = 2.0
phi_example = phi(x_example)
print(f"phi({x_example}) = {phi_example}")

X_subset_example = X_all[:3] if X_all is not None else jnp.array([[0.0], [1.0], [2.0]])
Phi_subset_example = phi(X_subset_example)
print(f"phi(subset of X):\n{Phi_subset_example}")

phi(2.0) = [1. 2.]
phi(subset of X):
[[ 1.         -5.        ]
 [ 1.         -4.47368421]
 [ 1.         -3.94736842]]


### The Gaussian Prior

In the Bayesian approach, we start with a prior distribution over the parameters $w$. This prior represents our beliefs about the parameters before observing any data. For mathematical convenience, and because it is often a reasonable choice when we have some idea about the range and relationship of parameters, we choose a Gaussian prior for $w$:
$$ p(w) = \mathcal{N}(w; \mu_{prior}, \Sigma_{prior}) $$
where $\mu_{prior} \in \mathbb{R}^2$ is the prior mean vector and $\Sigma_{prior} \in \mathbb{R}^{2 \times 2}$ is the prior covariance matrix.

*   $\mu_{prior}$: Our initial best guess for the values of $w_0$ and $w_1$.
*   $\Sigma_{prior}$: Our initial uncertainty about $w$. The diagonal elements are the variances of our beliefs about $w_0$ and $w_1$ individually; a large diagonal value means high uncertainty about that parameter. The off-diagonal elements are the covariance between $w_0$ and $w_1$, which captures how our beliefs about the two parameters are correlated.

In the interactive plot, you can adjust the parameters of this prior distribution. We define the $2 \times 2$ prior covariance matrix using the variance of $w_0$ ($\Sigma_{11}$), the variance of $w_1$ ($\Sigma_{22}$), and the correlation coefficient ($\rho$) between them:
$$
\Sigma_{prior} =
\begin{bmatrix}
\Sigma_{11} & \rho \sqrt{\Sigma_{11}\Sigma_{22}} \\
\rho \sqrt{\Sigma_{11}\Sigma_{22}} & \Sigma_{22}
\end{bmatrix}
$$
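
The parameterization above can be sketched in a few lines. This is an illustrative stand-in, assuming nothing about how the imported `create_prior_distribution` helper is implemented; the function name `prior_covariance` and the example values are hypothetical.

```python
import numpy as np


def prior_covariance(s11, s22, rho):
    """2x2 covariance built from two variances and a correlation coefficient."""
    cov = rho * np.sqrt(s11 * s22)
    return np.array([[s11, cov], [cov, s22]])


Sigma_prior = prior_covariance(s11=1.0, s22=4.0, rho=0.5)
print(Sigma_prior)

# For positive variances and |rho| < 1 the determinant s11 * s22 * (1 - rho^2)
# is positive, so the matrix is a valid (positive definite) covariance.
eigvals = np.linalg.eigvalsh(Sigma_prior)
print(eigvals.min() > 0)
```

This is also why the correlation slider further below is restricted to the range $[-0.99, 0.99]$: at $\rho = \pm 1$ the matrix becomes singular.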

### Bayesian Inference: Computing the Posterior

The goal of Bayesian inference is to update our prior beliefs about $w$ using the observed data $(X_{select}, Y_{select})$. The updated belief is represented by the posterior distribution:
$$ p(w | X_{select}, Y_{select}) = \frac{p(Y_{select} | X_{select}, w) p(w)}{p(Y_{select} | X_{select})} $$
where $p(Y_{select} | X_{select})$ is the marginal likelihood, a normalization constant.

Since we chose a Gaussian prior $p(w)$ and the likelihood $p(Y_{select} | X_{select}, w)$ is Gaussian with a mean that is linear in $w$ (as explained above), the prior is conjugate to the likelihood, and the resulting posterior distribution $p(w | X_{select}, Y_{select})$ is also Gaussian:
$$ p(w | X_{select}, Y_{select}) = \mathcal{N}(w; \mu_{posterior}, \Sigma_{posterior}) $$

The parameters of the posterior distribution, $\mu_{posterior}$ and $\Sigma_{posterior}$, are computed from the prior parameters ($\mu_{prior}, \Sigma_{prior}$) and the selected data $(X_{select}, Y_{select})$ using analytical formulas derived from Gaussian conditioning. These are the same formulas we discussed in Lecture 06 for conditioning a Gaussian variable on another linearly related Gaussian variable. If $w \sim \mathcal{N}(\mu_{prior}, \Sigma_{prior})$ and $Y_{select} \mid w \sim \mathcal{N}(\Phi_{select} w, \sigma^2 I)$, then $w \mid Y_{select} \sim \mathcal{N}(\mu_{posterior}, \Sigma_{posterior})$.
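
For reference, the standard conditioning result for this linear-Gaussian model can be written in two algebraically equivalent forms. This is a textbook identity stated as a reminder; which form the helper code evaluates is not shown in this notebook. Writing $\Phi = \Phi_{select}$ and $Y = Y_{select}$, the form that inverts a matrix of the size of the selected data is
$$ \mu_{posterior} = \mu_{prior} + \Sigma_{prior} \Phi^T \left( \Phi \Sigma_{prior} \Phi^T + \sigma^2 I \right)^{-1} \left( Y - \Phi \mu_{prior} \right) $$
$$ \Sigma_{posterior} = \Sigma_{prior} - \Sigma_{prior} \Phi^T \left( \Phi \Sigma_{prior} \Phi^T + \sigma^2 I \right)^{-1} \Phi \Sigma_{prior} $$
and the form that works with $2 \times 2$ precision matrices in parameter space is
$$ \Sigma_{posterior} = \left( \Sigma_{prior}^{-1} + \sigma^{-2} \Phi^T \Phi \right)^{-1}, \qquad \mu_{posterior} = \Sigma_{posterior} \left( \Sigma_{prior}^{-1} \mu_{prior} + \sigma^{-2} \Phi^T Y \right) $$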

The original script uses a `Gaussian` class with a `.condition()` method. This method encapsulates these analytical formulas: given the prior Gaussian, the feature matrix $\Phi_{select}$, the observed data $Y_{select}$, and the noise covariance $\sigma^2 I$, it computes the posterior mean and covariance.
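
To make the conditioning step concrete, here is a minimal NumPy sketch of what such an update could look like. It is not the `Gaussian.condition()` implementation, whose code is not shown here; the function name `condition_gaussian` and the data values are assumptions made for illustration. The result is checked against the equivalent precision-based formula.

```python
import numpy as np


def condition_gaussian(mu, Sigma, Phi, y, noise_cov):
    """Posterior over w given y = Phi w + eps, with eps ~ N(0, noise_cov)."""
    gram = Phi @ Sigma @ Phi.T + noise_cov        # marginal covariance of y
    gain = np.linalg.solve(gram, Phi @ Sigma).T   # Sigma Phi^T gram^{-1}
    mu_post = mu + gain @ (y - Phi @ mu)
    Sigma_post = Sigma - gain @ Phi @ Sigma
    return mu_post, Sigma_post


# Made-up prior and data (illustration only)
mu_prior = np.zeros(2)
Sigma_prior = np.eye(2)
sigma = 1.5
x = np.array([-1.0, 0.0, 2.0])
y = np.array([-1.0, 1.0, 5.0])
Phi = np.column_stack([np.ones_like(x), x])

mu_post, Sigma_post = condition_gaussian(
    mu_prior, Sigma_prior, Phi, y, sigma**2 * np.eye(len(x))
)

# Cross-check with the precision form of the same posterior
prec = np.linalg.inv(Sigma_prior) + Phi.T @ Phi / sigma**2
Sigma_check = np.linalg.inv(prec)
mu_check = Sigma_check @ (np.linalg.inv(Sigma_prior) @ mu_prior + Phi.T @ y / sigma**2)
print(np.allclose(mu_post, mu_check), np.allclose(Sigma_post, Sigma_check))
```

Using `np.linalg.solve` instead of forming an explicit inverse of the data-space matrix is the usual choice for numerical stability.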

### The Role of Selected Data

The power of interactive exploration here comes from selecting which data points we condition our posterior on. Initially, with no data selected, the posterior is the same as the prior. As you select data points, the posterior distribution (and the corresponding function space) will update to reflect the information gained from those specific observations.
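
The effect of the selection can be sketched without the widgets. The snippet below is an illustration under assumed data, and `posterior` is a hypothetical helper rather than one of the imported functions. It checks two properties: an empty selection returns the prior unchanged, and conditioning on points one at a time gives the same posterior as conditioning on them all at once.

```python
import numpy as np


def posterior(mu, Sigma, x, y, sigma):
    """Gaussian posterior over w = [w0, w1] for inputs x, targets y, noise std sigma."""
    x = np.asarray(x, dtype=float).reshape(-1)
    y = np.asarray(y, dtype=float).reshape(-1)
    if x.size == 0:  # nothing selected: the posterior is the prior
        return mu, Sigma
    Phi = np.column_stack([np.ones_like(x), x])
    prec_prior = np.linalg.inv(Sigma)
    Sigma_post = np.linalg.inv(prec_prior + Phi.T @ Phi / sigma**2)
    mu_post = Sigma_post @ (prec_prior @ mu + Phi.T @ y / sigma**2)
    return mu_post, Sigma_post


mu0, S0 = np.zeros(2), np.eye(2)
x_demo = np.array([-1.0, 0.0, 2.0])  # made-up data
y_demo = np.array([-1.0, 1.0, 5.0])
sigma = 1.5

# Empty selection
mu_none, S_none = posterior(mu0, S0, x_demo[:0], y_demo[:0], sigma)

# All points at once
mu_batch, S_batch = posterior(mu0, S0, x_demo, y_demo, sigma)

# One point at a time, using each posterior as the next prior
mu_seq, S_seq = mu0, S0
for xi, yi in zip(x_demo, y_demo):
    mu_seq, S_seq = posterior(mu_seq, S_seq, [xi], [yi], sigma)

print(np.allclose(mu_seq, mu_batch), np.allclose(S_seq, S_batch))
```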

### Interactive Bayesian Linear Regression with Plotly

We will now set up the interactive controls using ipywidgets and link them to a function that performs the Bayesian update and generates the Plotly plots.

#### Plotting utils

In [25]:
# --- Function to create interactive widgets ---
def create_regression_widgets(X_all, Y_all, N_total):
    """Creates and returns ipywidgets for Bayesian Linear Regression parameters and data selection."""

    # Sliders for the prior mean (2D)
    mu0_prior_slider = widgets.FloatSlider(
        min=-5.0,
        max=5.0,
        value=0.0,
        step=0.1,
        description="Prior Mu_0:",
        style={"description_width": "initial"},
    )
    mu1_prior_slider = widgets.FloatSlider(
        min=-5.0,
        max=5.0,
        value=0.0,
        step=0.1,
        description="Prior Mu_1:",
        style={"description_width": "initial"},
    )

    # Sliders for the prior covariance matrix parameters
    s11_prior_slider = widgets.FloatSlider(
        min=0.1,
        max=5.0,
        value=1.0,
        step=0.1,
        description="Prior Sigma_11:",
        style={"description_width": "initial"},
    )
    s22_prior_slider = widgets.FloatSlider(
        min=0.1,
        max=5.0,
        value=1.0,
        step=0.1,
        description="Prior Sigma_22:",
        style={"description_width": "initial"},
    )
    rho_prior_slider = widgets.FloatSlider(
        min=-0.99,
        max=0.99,
        value=0.0,
        step=0.01,
        description="Prior rho:",
        style={"description_width": "initial"},
    )

    # Widget to select data points
    data_selector_options = (
        [
            (f"Point {i + 1} (x={X_all[i, 0]:.2f}, y={Y_all[i]:.2f})", i)
            for i in range(N_total)
        ]
        if X_all is not None
        else []
    )
    data_selector = widgets.SelectMultiple(
        options=data_selector_options,
        description="Select Data Points:",
        disabled=(X_all is None),
        layout={"width": "500px"},
        style={"description_width": "initial"},
    )

    # Return a dictionary of widgets
    return {
        "mu0_prior": mu0_prior_slider,
        "mu1_prior": mu1_prior_slider,
        "s11_prior": s11_prior_slider,
        "s22_prior": s22_prior_slider,
        "rho_prior": rho_prior_slider,
        "selected_indices": data_selector,
    }

In [17]:
def _initialize_regression_figure():
    """Initializes two Plotly figures: one for parameter space and one for function space."""
    # Create two separate figures: one for parameter space, one for function space
    fig_param = go.Figure(
        layout=go.Layout(
            title="Bayesian Linear Regression: Parameter Space",
            xaxis=dict(
                title="w<sub>0</sub> (Intercept)",
                range=[-3, 3],
                scaleanchor="y",
                scaleratio=1,
            ),
            yaxis=dict(
                title="w<sub>1</sub> (Slope)",
                range=[-3, 3],
            ),
            showlegend=True,
            width=600,
            height=500,
        )
    )

    fig_func = go.Figure(
        layout=go.Layout(
            title="Bayesian Linear Regression: Function Space",
            xaxis=dict(
                title="x",
                range=[-5, 5],
            ),
            yaxis=dict(
                title="y",
                range=[-10, 10],
            ),
            showlegend=True,
            width=600,
            height=500,
        )
    )

    return fig_param, fig_func

In [18]:
def _add_parameter_space_traces(fig, data, PLOTLY_COLORS):
    """Adds traces for the parameter space plot to the figure."""
    # Plot prior contours
    for i, pts in enumerate(data["prior_contour_pts"]):
        fig.add_trace(
            go.Scattergl(
                x=np.array(pts[:, 0]),
                y=np.array(pts[:, 1]),
                mode="lines",
                line=dict(
                    color=PLOTLY_COLORS["gray"], dash="dash", width=2.5 / (i + 1)
                ),
                name=f"Prior {i + 1}σ Contour",
                showlegend=True,  # One legend entry per contour level
                xaxis="x1",
                yaxis="y1",
            )
        )
    # Plot prior mean
    fig.add_trace(
        go.Scattergl(
            x=[data["prior_dist"].mu[0]],
            y=[data["prior_dist"].mu[1]],
            mode="markers",
            marker=dict(color=PLOTLY_COLORS["gray"], size=8),
            name="Prior Mean",
            showlegend=True,
            xaxis="x1",
            yaxis="y1",
        )
    )
    # Plot prior samples
    if data["prior_samples"] is not None:
        fig.add_trace(
            go.Scattergl(
                x=np.array(data["prior_samples"][:, 0]),
                y=np.array(data["prior_samples"][:, 1]),
                mode="markers",
                name="Prior Samples",
                marker=dict(size=4, opacity=0.8, color=PLOTLY_COLORS["dark"]),
                showlegend=True,
                xaxis="x1",
                yaxis="y1",
            )
        )

    # Plot Likelihood (MLE point) if data is selected
    if data["w_mle"] is not None:
        fig.add_trace(
            go.Scattergl(
                x=[data["w_mle"][0]],
                y=[data["w_mle"][1]],
                mode="markers",
                marker=dict(color=PLOTLY_COLORS["blue"], size=8),
                name="Likelihood (MLE)",
                showlegend=True,
                xaxis="x1",
                yaxis="y1",
            )
        )

    # Plot posterior contours
    for i, pts in enumerate(data["posterior_contour_pts"]):
        fig.add_trace(
            go.Scattergl(
                x=np.array(pts[:, 0]),
                y=np.array(pts[:, 1]),
                mode="lines",
                line=dict(color=PLOTLY_COLORS["red"], dash="dash", width=2.5 / (i + 1)),
                name=f"Posterior {i + 1}σ Contour",
                showlegend=True,  # One legend entry per contour level
                xaxis="x1",
                yaxis="y1",
            )
        )
    # Plot posterior mean
    fig.add_trace(
        go.Scattergl(
            x=[data["posterior_dist"].mu[0]],
            y=[data["posterior_dist"].mu[1]],
            mode="markers",
            marker=dict(color=PLOTLY_COLORS["red"], size=8),
            name="Posterior Mean",
            showlegend=True,
            xaxis="x1",
            yaxis="y1",
        )
    )
    # Plot posterior samples
    if data["posterior_samples"] is not None:
        fig.add_trace(
            go.Scattergl(
                x=np.array(data["posterior_samples"][:, 0]),
                y=np.array(data["posterior_samples"][:, 1]),
                mode="markers",
                name="Posterior Samples",
                marker=dict(size=4, opacity=0.8, color=PLOTLY_COLORS["red"]),
                showlegend=True,
                xaxis="x1",
                yaxis="y1",
            )
        )
    return fig

In [19]:
def _add_function_space_traces(fig, data, PLOTLY_COLORS):
    """Adds traces for the function space plot to the figure."""
    # Plot all data points
    fig.add_trace(
        go.Scatter(
            x=np.array(data["all_data_x"]),
            y=np.array(data["all_data_y"]),
            mode="markers",
            name="All Data",
            marker=dict(color=PLOTLY_COLORS["dark"], size=5),
            error_y=dict(
                type="data",
                array=np.array(data["sigma_noise"] * jnp.ones_like(data["all_data_y"])),
            ),
            showlegend=True,
            xaxis="x2",
            yaxis="y2",
        )
    )
    # Highlight selected data points
    if data["selected_data_x"].shape[0] > 0:  # Check if there are selected points
        fig.add_trace(
            go.Scatter(
                x=np.array(data["selected_data_x"]),
                y=np.array(data["selected_data_y"]),
                mode="markers",
                name="Selected Data",
                marker=dict(
                    color=PLOTLY_COLORS["red"],
                    size=7,
                    line=dict(width=1, color="DarkRed"),
                ),
                error_y=dict(
                    type="data",
                    array=np.array(
                        data["sigma_noise"] * jnp.ones_like(data["selected_data_y"])
                    ),
                ),
                showlegend=True,
                xaxis="x2",
                yaxis="y2",
            )
        )

    # Plot prior mean function
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["prior_mean_f"]),
            mode="lines",
            name="Prior Mean",
            line=dict(color=PLOTLY_COLORS["dark"], width=2),
            showlegend=True,
            xaxis="x2",
            yaxis="y2",
        )
    )
    # Plot prior uncertainty band (+/- 2 std dev)
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["prior_upper_f"]),
            mode="lines",
            line=dict(width=0),
            name="Prior +2σ",
            showlegend=False,
            xaxis="x2",
            yaxis="y2",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["prior_lower_f"]),
            mode="lines",
            line=dict(width=0),
            name="Prior 2σ Uncertainty Band",
            showlegend=True,
            xaxis="x2",
            yaxis="y2",
            fill="tonexty",
            fillcolor=PLOTLY_COLORS["dark_alpha"],
        )
    )  # Fill between upper and lower

    # Plot prior function samples
    if data["prior_func_samples"] is not None:
        for i in range(data["prior_func_samples"].shape[1]):
            fig.add_trace(
                go.Scatter(
                    x=np.array(data["x_plot"][:, 0]),
                    y=np.array(data["prior_func_samples"][:, i]),
                    mode="lines",
                    line=dict(color=PLOTLY_COLORS["dark"], width=1),
                    opacity=0.3,
                    name="Prior Sample Functions",
                    showlegend=False,  # Don't show legend for each sample
                    xaxis="x2",
                    yaxis="y2",
                )
            )

    # Plot posterior mean function
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["posterior_mean_f"]),
            mode="lines",
            name="Posterior Mean",
            line=dict(color=PLOTLY_COLORS["red"], width=2),
            showlegend=True,
            xaxis="x2",
            yaxis="y2",
        )
    )
    # Plot posterior uncertainty band (+/- 2 std dev)
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["posterior_upper_f"]),
            mode="lines",
            line=dict(width=0),
            name="Posterior +2σ",
            showlegend=False,
            xaxis="x2",
            yaxis="y2",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=np.array(data["x_plot"][:, 0]),
            y=np.array(data["posterior_lower_f"]),
            mode="lines",
            line=dict(width=0),
            name="Posterior 2σ Uncertainty Band",
            showlegend=True,
            xaxis="x2",
            yaxis="y2",
            fill="tonexty",
            fillcolor=PLOTLY_COLORS["red_alpha"],
        )
    )  # Fill between upper and lower

    # Plot posterior function samples
    if data["posterior_func_samples"] is not None:
        for i in range(data["posterior_func_samples"].shape[1]):
            fig.add_trace(
                go.Scatter(
                    x=np.array(data["x_plot"][:, 0]),
                    y=np.array(data["posterior_func_samples"][:, i]),
                    mode="lines",
                    line=dict(color=PLOTLY_COLORS["red"], width=1),
                    opacity=0.5,
                    name="Posterior Sample Functions",
                    showlegend=False,  # Don't show legend for each sample
                    xaxis="x2",
                    yaxis="y2",
                )
            )

    # Plot Likelihood (MLE Function) if calculated
    if data["mle_func"] is not None:
        fig.add_trace(
            go.Scatter(
                x=np.array(data["x_plot"][:, 0]),
                y=np.array(data["mle_func"]),
                mode="lines",
                line=dict(color=PLOTLY_COLORS["blue"], dash="dash", width=2),
                name="Likelihood (MLE Function)",
                showlegend=True,
                xaxis="x2",
                yaxis="y2",
            )
        )
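    # This figure only defines the primary axes in its layout, so map every
    # trace onto them; otherwise the "x2"/"y2" references above would bypass
    # the configured axis ranges and titles.
    fig.update_traces(xaxis="x", yaxis="y")
    return fig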

In [20]:
def _finalize_and_show_figure(fig_param, fig_func):
    """Applies final layout adjustments (overriding the initial size) and displays both figures."""
    fig_param.update_layout(
        legend=dict(orientation="h", x=0.5, y=-0.15, xanchor="center", yanchor="top"),
        title_x=0.5,  # Center the main title
        title_y=0.95,  # Position the main title slightly lower
        title="Bayesian Linear Regression: Parameter Space",  # Set the main title here
        width=1200,
        height=600,
    )

    fig_param.show()

    fig_func.update_layout(
        legend=dict(orientation="h", x=0.5, y=-0.15, xanchor="center", yanchor="top"),
        title_x=0.5,  # Center the main title
        title_y=0.95,  # Position the main title slightly lower
        title="Bayesian Linear Regression: Function Space",  # Set the main title here
        width=1200,
        height=600,
    )

    fig_func.show()


In [None]:
# --- Function to set up and display the interactive interface ---
def setup_interactive_regression(X_all, Y_all, sigma_noise, N_total, update_plot_fn):
    """
    Creates widgets, links them to the update_plot_fn, and displays the interactive interface.
    """
    if X_all is None or Y_all is None or sigma_noise is None:
        print("Data not loaded. Cannot setup interactive interface.")
        return

    # Create the widgets
    regression_widgets = create_regression_widgets(X_all, Y_all, N_total)

    # Link widgets to the update function
    interactive_plot_output = widgets.interactive_output(
        update_plot_fn,
        regression_widgets,  # Pass the dictionary of widgets
    )

    # Arrange widgets: each control group is stacked vertically, and the two groups sit side by side
    prior_controls_box = widgets.VBox(
        [
            widgets.Label("Prior Parameters:"),
            regression_widgets["mu0_prior"],
            regression_widgets["mu1_prior"],
            regression_widgets["s11_prior"],
            regression_widgets["s22_prior"],
            regression_widgets["rho_prior"],
        ]
    )

    data_selection_control_box = widgets.VBox(
        [
            widgets.Label("Data Selection:"),
            regression_widgets["selected_indices"],
        ]
    )

    controls_box = widgets.HBox([prior_controls_box, data_selection_control_box])

    # Display the controls and the plot output
    print(
        "Adjust the sliders and select data points to explore Bayesian Linear Regression:"
    )
    display(controls_box, interactive_plot_output)

#### Main

In [21]:
# --- Main Update Function (Orchestrator) ---
def update_regression_plot_plotly(
    mu0_prior, mu1_prior, s11_prior, s22_prior, rho_prior, selected_indices
):
    """
    Updates the Plotly plots for Bayesian Linear Regression based on the selected parameters and data.
    This function orchestrates calls to smaller helper functions.
    """
    # X_all, Y_all, and sigma_noise are read from the notebook's global scope
    if X_all is None or Y_all is None or sigma_noise is None:
        # Display an empty plot or message in the output area if data is not loaded
        fig = go.Figure()
        fig.update_layout(title="Data not loaded.")
        fig.show()
        return

    # 1. Data Handling & Prior Setup
    prior_dist = create_prior_distribution(
        mu0_prior, mu1_prior, s11_prior, s22_prior, rho_prior
    )
    X_select, Y_select, Lambda_select_sq = _select_data(
        X_all, Y_all, sigma_noise, selected_indices
    )

    # 2. Posterior Calculation
    posterior_dist = compute_posterior(prior_dist, X_select, Y_select, Lambda_select_sq)

    # 3. Plotting Data Preparation
    # Regenerate key for samples each update
    global initial_key  # Assuming initial_key is a global JAX PRNG key
    initial_key, subkey = jrandom.split(initial_key)

    param_space_data = generate_parameter_space_data(
        prior_dist, posterior_dist, X_select, Y_select, subkey
    )
    # Update key based on usage in param_space_data generation
    subkey = param_space_data["key"]

    func_space_data = generate_function_space_data(
        prior_dist,
        posterior_dist,
        X_all,
        Y_all,
        sigma_noise,
        X_select,
        Y_select,
        param_space_data["w_mle"],
        subkey,
    )
    # Update key based on usage in func_space_data generation
    initial_key = func_space_data["key"]  # Update the global key for the next call

    # Add prior and posterior distributions to the data dictionaries for easier access in plotting
    param_space_data["prior_dist"] = prior_dist
    param_space_data["posterior_dist"] = posterior_dist

    # 4. Plotly Figure Creation & Population
    fig_param, fig_function = _initialize_regression_figure()
    fig_param = _add_parameter_space_traces(fig_param, param_space_data, PLOTLY_COLORS)
    fig_function = _add_function_space_traces(
        fig_function, func_space_data, PLOTLY_COLORS
    )

    # 5. Finalize and Show
    _finalize_and_show_figure(fig_param, fig_function)

#### Interactive

In [23]:
# --- Run the setup function ---
if X_all is not None:  # Check if data loaded before setting up
    setup_interactive_regression(
        X_all, Y_all, sigma_noise, N_total, update_regression_plot_plotly
    )

Adjust the sliders and select data points to explore Bayesian Linear Regression:



### Explanation of Bayesian Linear Regression

At its core, Bayesian linear regression is about finding a probability distribution over the possible straight lines (or hyperplanes in higher dimensions) that could have generated the data. Instead of finding a single "best" line, we maintain a belief about what the parameters $w_0$ (intercept) and $w_1$ (slope) could be, represented by a joint probability distribution $p(w_0, w_1)$.

1. The Model: We assume the data is generated by a linear function corrupted by Gaussian noise: $y = w_0 + w_1 x + \epsilon$, where $\epsilon \sim \mathcal{N}(0, \sigma^2)$.
2. The Prior: We start with a prior belief about the parameters $w$. A common and mathematically convenient choice is a Gaussian prior $p(w) = \mathcal{N}(w; \mu_{prior}, \Sigma_{prior})$. This distribution reflects our initial uncertainty about the intercept and slope before seeing any data.
    * In the parameter space plot (the first figure), the contours of the prior Gaussian show regions of higher probability for the $[w_0, w_1]$ pair. The center of the ellipse is the prior mean $\mu_{prior}$. The shape and orientation of the ellipse are determined by the prior covariance $\Sigma_{prior}$.
    * In the function space plot (the second figure), the prior mean function $f(x) = \phi(x)^T \mu_{prior}$ is shown as a line. The uncertainty band around it shows the range of function values ($\pm 2$ standard deviations) predicted by the prior distribution over parameters. Each sample from the prior in parameter space corresponds to a specific line in function space, illustrating the variety of functions considered plausible under the prior.
3. The Likelihood: The likelihood $p(y | x, w)$ tells us the probability of observing a data point $(x, y)$ given specific parameter values $w$. Due to the Gaussian noise assumption, this likelihood is Gaussian: the observed $y$ is likely to be close to $\phi(x)^T w$. For multiple independent data points, the joint likelihood $p(Y_{select} | X_{select}, w)$ is also Gaussian.
    * In the parameter space plot, the "Likelihood (MLE)" point represents the parameter values that maximize the likelihood for the selected data – effectively, the line that best fits only the selected data according to the least squares criterion. (Note: The likelihood itself is a function of $w$ for fixed data, not a distribution over $w$ that you can sample from directly).
    * In the function space plot, the "Likelihood (MLE Function)" shows the line corresponding to the MLE parameters. The selected data points are also highlighted.
4. The Posterior: When we observe data, we update our prior belief to get the posterior distribution $p(w | X_{select}, Y_{select})$. Thanks to the conjugate property of Gaussian priors with Gaussian likelihoods, the posterior is also Gaussian, but with updated mean $\mu_{posterior}$ and covariance $\Sigma_{posterior}$.
    * The posterior mean $\mu_{posterior}$ is a precision-weighted combination of the prior mean and the information from the data (specifically, the MLE). As you add more data, especially informative data, the posterior mean moves towards the MLE.
    * The posterior covariance $\Sigma_{posterior}$ is never larger than the prior covariance (in the positive semidefinite sense), reflecting a reduction in uncertainty about the parameters after observing data. As you add more data, the posterior ellipse in parameter space shrinks.
    * In the parameter space plot (the first figure), the posterior contours and samples show the updated belief about $w$.
    * In the function space plot (the second figure), the posterior mean function shows the line that best fits the selected data, considering the prior. The posterior uncertainty band is no wider than the prior band, reflecting increased certainty about the function after seeing data. Samples from the posterior in function space show lines that are plausible given the data and the prior.
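
The relationship between the MLE, the prior, and the posterior described in the list above can be checked numerically. The following is a minimal sketch with made-up data and a standard normal prior, both assumptions chosen for illustration. It computes the least-squares MLE, combines it with the prior mean using precision weights, and verifies that the posterior covariance does not exceed the prior covariance.

```python
import numpy as np

# Made-up data (illustration only)
x = np.array([-2.0, -1.0, 0.5, 1.0, 3.0])
y = np.array([-3.2, -0.7, 1.9, 3.4, 6.8])
sigma = 1.5
Phi = np.column_stack([np.ones_like(x), x])

# Maximum likelihood estimate = ordinary least squares
# (unique here because Phi has full column rank)
w_mle, *_ = np.linalg.lstsq(Phi, y, rcond=None)

mu_prior = np.zeros(2)
Sigma_prior = np.eye(2)
prec_prior = np.linalg.inv(Sigma_prior)  # precision contributed by the prior
prec_data = Phi.T @ Phi / sigma**2       # precision contributed by the data

# Posterior: precisions add, and the mean is a precision-weighted combination
Sigma_post = np.linalg.inv(prec_prior + prec_data)
mu_post = Sigma_post @ (prec_prior @ mu_prior + prec_data @ w_mle)

# Same mean computed directly from the targets
mu_direct = Sigma_post @ (prec_prior @ mu_prior + Phi.T @ y / sigma**2)
print(np.allclose(mu_post, mu_direct))

# Sigma_prior - Sigma_post is positive semidefinite: uncertainty never grows
shrink_eigs = np.linalg.eigvalsh(Sigma_prior - Sigma_post)
print(shrink_eigs.min() >= -1e-12)
```

As more points are added, `prec_data` grows relative to `prec_prior`, so the weights shift towards `w_mle`.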

By selecting data points using the interactive widget, you can observe how the posterior distribution shifts and shrinks, and how this translates into a more certain belief about the linear relationship between $x$ and $y$ in the function space. The posterior mean function becomes a better fit to the selected data, and the uncertainty band narrows, particularly in regions where data has been observed.
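
The function-space quantities discussed above follow directly from a Gaussian over $w$. The sketch below shows one way to compute the mean line, the $\pm 2\sigma$ band, and a few sampled lines. It is an illustration only: `predictive_band` is a hypothetical name, the imported `generate_function_space_data` helper may be implemented differently, and the band here describes the noise-free function $f(x) = \phi(x)^T w$.

```python
import numpy as np


def predictive_band(mu, Sigma, x_plot, n_std=2.0):
    """Mean and +/- n_std band of f(x) = phi(x)^T w for w ~ N(mu, Sigma)."""
    Phi = np.column_stack([np.ones_like(x_plot), x_plot])
    mean_f = Phi @ mu
    # Pointwise variance: the diagonal of Phi Sigma Phi^T
    var_f = np.einsum("ij,jk,ik->i", Phi, Sigma, Phi)
    std_f = np.sqrt(var_f)
    return mean_f, mean_f - n_std * std_f, mean_f + n_std * std_f


x_plot = np.linspace(-5.0, 5.0, 11)
mu = np.zeros(2)   # assumed parameters, e.g. a standard normal prior
Sigma = np.eye(2)

mean_f, lower_f, upper_f = predictive_band(mu, Sigma, x_plot)

# Each parameter sample corresponds to one straight line in function space
rng = np.random.default_rng(0)
w_samples = rng.multivariate_normal(mu, Sigma, size=5)  # shape (5, 2)
Phi_plot = np.column_stack([np.ones_like(x_plot), x_plot])
func_samples = Phi_plot @ w_samples.T                   # shape (11, 5)
print(func_samples.shape)
```

With an identity covariance the pointwise variance is $1 + x^2$, so the band is narrowest at $x = 0$ and widens away from it. After conditioning on data, passing the posterior mean and covariance to the same function gives the posterior band.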

This interactive notebook allows you to visualize this core process of Bayesian learning: starting with a belief (prior), observing evidence (data), and updating that belief (posterior).