# Bayesian Linear Regression: Updating Beliefs about Model Parameters (Plotly Interactive)

In this notebook, we will explore Bayesian Linear Regression. This is a fundamental probabilistic model where we place a probability distribution over the possible values of our model parameters and update this distribution as we observe data. We'll see how a Gaussian prior distribution, combined with a linear model and Gaussian noise, leads to a Gaussian posterior distribution that can be computed analytically.

We will use JAX for numerical computation, `ipywidgets` for interactive controls, and **Plotly** for dynamic and interactive plotting.

In [2]:
%load_ext autoreload
%autoreload 2

In [1]:
# Import necessary libraries
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np  # Useful for converting JAX arrays for plotting

import scipy.io  # To load the .mat data file
import ipywidgets as widgets  # For interactive controls
from IPython.display import display  # To display widgets and output
from gaussians import Gaussian  # Local module; the notebook uses .mu, .Sigma, .L, .sample() and .condition()

# Import Plotly
import plotly.graph_objects as go

# Colors for the Plotly traces (RGBA strings; *_alpha variants are semi-transparent)
PLOTLY_COLORS = {
    "dark": "rgba(0,0,0,1.0)",  # Black
    "gray": "rgba(128,128,128,1.0)",  # Gray
    "blue": "rgba(0,0,255,1.0)",  # Blue
    "red": "rgba(255,0,0,1.0)",  # Red
    "dark_alpha": "rgba(0,0,0,0.2)",  # Black with transparency for bands/samples
    "blue_alpha": "rgba(0,0,255,0.2)",  # Blue with transparency
    "red_alpha": "rgba(255,0,0,0.5)",  # Red with transparency
}


# Optional: Configure JAX for 64-bit precision for potential numerical stability
jax.config.update("jax_enable_x64", True)

# Set a random seed for reproducibility
initial_key = jrandom.PRNGKey(0)

## Loading Data

We'll use the data provided in the `lindata.mat` file. This file contains the input features $X$, the corresponding output values $Y$, and the known standard deviation $\sigma$ of the observation noise.

data = scipy.io.loadmat("lindata.mat")
X_all = data["X"]  # inputs (N, 1)
Y_all = data["Y"][:, 0]  # outputs (N,)
sigma_noise = data["sigma"][0].flatten()[0]  # Noise standard deviation (scalar)
N_total = X_all.shape[0]  # Total number of data points

print(f"Data loaded: {N_total} points")
print(f"Input shape (X): {X_all.shape}")
print(f"Output shape (Y): {Y_all.shape}")
print(f"Noise standard deviation (sigma): {sigma_noise}")

Data loaded: 20 points
Input shape (X): (20, 1)
Output shape (Y): (20,)
Noise standard deviation (sigma): 1.5


## The Linear Model

We are modeling the relationship between the input $x$ and output $y$ using a simple linear model:
$$ y = w_0 + w_1 x + \epsilon $$
where $w_0$ is the intercept, $w_1$ is the slope, and $\epsilon$ is observation noise. We combine the parameters into a vector $w = \begin{bmatrix} w_0 \\ w_1 \end{bmatrix}$.

To write this in a more general linear regression form, we use a **feature function** $\phi(x)$ that transforms the input $x$ into a feature vector. For this simple linear model, $\phi(x) = \begin{bmatrix} 1 \\ x \end{bmatrix}$. Then the model becomes:
$$ y = \phi(x)^T w + \epsilon $$

In our probabilistic setting, we assume the noise $\epsilon$ is Gaussian, $\epsilon \sim \mathcal{N}(0, \sigma^2)$, where $\sigma^2$ is the noise variance (the square of `sigma_noise`). This means the likelihood of observing $y$ given $x$ and the parameters $w$ is a Gaussian centered at $\phi(x)^T w$ with variance $\sigma^2$:
$$ p(y | x, w) = \mathcal{N}(y; \phi(x)^T w, \sigma^2) $$

For a set of selected data points $(X_{select}, Y_{select})$, assumed independent given $w$, the joint likelihood $p(Y_{select} | X_{select}, w)$ is a multivariate Gaussian:
$$ p(Y_{select} | X_{select}, w) = \mathcal{N}(Y_{select}; \Phi_{select} w, \sigma^2 I) $$
where $\Phi_{select}$ is the matrix whose rows are $\phi(x_i)^T$ for the selected inputs $x_i$, and $I$ is the identity matrix.
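As a quick check that the independent per-point Gaussians combine into a single multivariate Gaussian with covariance $\sigma^2 I$, here is a minimal NumPy sketch. The toy inputs, outputs, and parameter vector are made up for illustration (they are not taken from `lindata.mat`), and `log_likelihood` is a hypothetical helper, not part of the notebook.

```python
import numpy as np
from scipy.stats import multivariate_normal

def phi(X):
    """Feature map [1, x] applied row-wise to an (N, 1) array of inputs -> (N, 2)."""
    X = np.asarray(X, dtype=float).reshape(-1, 1)
    return np.hstack([np.ones_like(X), X])

def log_likelihood(w, X, Y, sigma):
    """Hypothetical helper: log N(Y; Phi w, sigma^2 I) as a sum of per-point log-densities."""
    resid = np.asarray(Y, dtype=float) - phi(X) @ w
    n = resid.shape[0]
    return -0.5 * n * np.log(2 * np.pi * sigma**2) - 0.5 * (resid @ resid) / sigma**2

# Toy values for illustration only
X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([0.8, 3.1, 4.9])
w = np.array([1.0, 2.0])
sigma = 1.5

ll_sum = log_likelihood(w, X, Y, sigma)
# The same quantity as one multivariate Gaussian with covariance sigma^2 I
ll_mvn = multivariate_normal(mean=phi(X) @ w, cov=sigma**2 * np.eye(len(Y))).logpdf(Y)
print(ll_sum, ll_mvn)  # the two values agree
```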

In [None]:
# --- Define the feature function ---
def phi(x):
  """
  Feature function for simple linear regression: [1, x]
  Accepts a single scalar x or a JAX array of shape (N, 1).
  Returns a JAX array of shape (N, num_features) or (num_features,).
  """
  if jnp.ndim(x) == 0: # Handle scalar input
      return jnp.array([1.0, x])
  else: # Handle array input (N, 1)
      return jnp.hstack([jnp.ones_like(x), x])

# Example usage:
x_example = 2.0
phi_example = phi(x_example)
print(f"phi({x_example}) = {phi_example}")

X_subset_example = X_all[:3] if X_all is not None else jnp.array([[0.],[1.],[2.]])
Phi_subset_example = phi(X_subset_example)
print(f"phi(subset of X):\n{Phi_subset_example}")

## The Gaussian Prior

In the Bayesian approach, we start with a **prior distribution** over the parameters $w$. This prior represents our beliefs about the parameters before observing any data. For mathematical convenience, and because it is often a reasonable choice when we have some idea about the range and relationship of parameters, we choose a **Gaussian prior** for $w$:
$$ p(w) = \mathcal{N}(w; \mu_{prior}, \Sigma_{prior}) $$
where $\mu_{prior} \in \mathbb{R}^2$ is the prior mean vector and $\Sigma_{prior} \in \mathbb{R}^{2 \times 2}$ is the prior covariance matrix.

  * $\mu_{prior}$: Our initial best guess for the values of $w_0$ and $w_1$.
  * $\Sigma_{prior}$: Our initial uncertainty about $w$. The diagonal elements are the variances of our beliefs about $w_0$ and $w_1$ individually; a large diagonal value means high uncertainty. The off-diagonal elements are the covariance between $w_0$ and $w_1$, i.e. how our belief about one parameter is correlated with our belief about the other.

In the interactive plot, you can adjust the parameters of this prior distribution. We define the 2x2 prior covariance matrix using the variance of $w_0$ ($\Sigma_{11}$), the variance of $w_1$ ($\Sigma_{22}$), and the correlation coefficient $\rho$ between them:
$$ \Sigma_{prior} = \begin{bmatrix} \Sigma_{11} & \rho \sqrt{\Sigma_{11}\Sigma_{22}} \\ \rho \sqrt{\Sigma_{11}\Sigma_{22}} & \Sigma_{22} \end{bmatrix} $$
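The parameterization above can be sketched in a few lines of NumPy. `make_prior_cov` is a hypothetical helper mirroring what the update function does with the slider values, and the numbers are illustrative. The determinant $\Sigma_{11}\Sigma_{22}(1 - \rho^2)$ shows why the correlation must stay strictly inside $(-1, 1)$: at $|\rho| = 1$ the matrix becomes singular, which is presumably why the `rho` slider stops at ±0.99 and the update function clips it to ±0.999.

```python
import numpy as np

def make_prior_cov(s11, s22, rho):
    """Hypothetical helper: 2x2 covariance from two variances and a correlation coefficient."""
    s12 = rho * np.sqrt(s11 * s22)
    return np.array([[s11, s12], [s12, s22]])

Sigma_prior = make_prior_cov(1.0, 4.0, 0.5)
print(Sigma_prior)  # off-diagonal entry is 0.5 * sqrt(1.0 * 4.0) = 1.0

# A valid covariance must be symmetric positive definite; np.linalg.cholesky
# raises LinAlgError when the matrix is not positive definite.
L = np.linalg.cholesky(Sigma_prior)
print(np.allclose(L @ L.T, Sigma_prior))  # True

# det = s11 * s22 * (1 - rho^2), positive only for |rho| < 1
print(np.linalg.det(Sigma_prior), 1.0 * 4.0 * (1 - 0.5**2))
```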

## Bayesian Inference: Computing the Posterior

The goal of Bayesian inference is to update our prior beliefs about $w$ using the observed data $(X_{select}, Y_{select})$. The updated belief is represented by the **posterior distribution**:
$$ p(w | X_{select}, Y_{select}) = \frac{p(Y_{select} | X_{select}, w) p(w)}{p(Y_{select} | X_{select})} $$
where $p(Y_{select} | X_{select})$ is the marginal likelihood, a normalization constant.
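In this model the marginal likelihood has a closed form: integrating $w$ out gives $p(Y_{select} | X_{select}) = \mathcal{N}(Y_{select}; \Phi_{select}\mu_{prior}, \Phi_{select}\Sigma_{prior}\Phi_{select}^T + \sigma^2 I)$. The sketch below checks this against Bayes' rule rearranged, $\log p(Y) = \log p(Y|w) + \log p(w) - \log p(w|Y)$, which must hold for any $w$. It uses the precision form of the posterior (an equivalent form of the conditioning update), made-up data, and a hypothetical `log_evidence` helper.

```python
import numpy as np
from scipy.stats import multivariate_normal

def phi(X):
    X = np.asarray(X, dtype=float).reshape(-1, 1)
    return np.hstack([np.ones_like(X), X])   # rows are [1, x_i]

def log_evidence(mu, Sigma, X, Y, sigma):
    """log p(Y | X) = log N(Y; Phi mu, Phi Sigma Phi^T + sigma^2 I)."""
    Phi = phi(X)
    cov = Phi @ Sigma @ Phi.T + sigma**2 * np.eye(len(Y))
    return multivariate_normal(mean=Phi @ mu, cov=cov).logpdf(Y)

# Illustrative prior and data (not from lindata.mat)
mu, Sigma, sigma = np.zeros(2), np.eye(2), 1.5
X = np.array([[-1.0], [0.5], [2.0]])
Y = np.array([-0.5, 1.0, 2.5])
Phi = phi(X)

# Posterior in precision form, needed only for the check below
Sigma_post = np.linalg.inv(np.linalg.inv(Sigma) + Phi.T @ Phi / sigma**2)
mu_post = Sigma_post @ (np.linalg.solve(Sigma, mu) + Phi.T @ Y / sigma**2)

# Bayes' rule rearranged, evaluated at an arbitrary w
w = np.array([0.3, -0.7])
rhs = (multivariate_normal(Phi @ w, sigma**2 * np.eye(len(Y))).logpdf(Y)
       + multivariate_normal(mu, Sigma).logpdf(w)
       - multivariate_normal(mu_post, Sigma_post).logpdf(w))
lhs = log_evidence(mu, Sigma, X, Y, sigma)
print(lhs, rhs)  # the two values agree
```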

Since the prior $p(w)$ is Gaussian and the likelihood $p(Y_{select} | X_{select}, w)$ is Gaussian with a mean that is linear in $w$ and a known noise variance, the Gaussian prior is conjugate to this likelihood, and the resulting posterior distribution $p(w | X_{select}, Y_{select})$ is also **Gaussian**:
$$ p(w | X_{select}, Y_{select}) = \mathcal{N}(w; \mu_{posterior}, \Sigma_{posterior}) $$

The parameters of the posterior distribution, $\mu_{posterior}$ and $\Sigma_{posterior}$, are computed from the prior parameters ($\mu_{prior}, \Sigma_{prior}$) and the selected data $(X_{select}, Y_{select})$ using analytical formulas derived from Gaussian conditioning. These are the same formulas we discussed in Lecture 06 for conditioning a Gaussian variable on another, linearly related Gaussian variable: if $w \sim \mathcal{N}(\mu_{prior}, \Sigma_{prior})$ and $Y_{select} \mid w \sim \mathcal{N}(\Phi_{select} w, \sigma^2 I)$, then $w \mid Y_{select} \sim \mathcal{N}(\mu_{posterior}, \Sigma_{posterior})$.
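For reference, the standard Gaussian-conditioning result in this notebook's notation reads as follows (the `Gaussian` class may compute an algebraically equivalent form):

$$ \mu_{posterior} = \mu_{prior} + \Sigma_{prior}\Phi_{select}^T \left(\Phi_{select}\Sigma_{prior}\Phi_{select}^T + \sigma^2 I\right)^{-1} \left(Y_{select} - \Phi_{select}\mu_{prior}\right) $$
$$ \Sigma_{posterior} = \Sigma_{prior} - \Sigma_{prior}\Phi_{select}^T \left(\Phi_{select}\Sigma_{prior}\Phi_{select}^T + \sigma^2 I\right)^{-1} \Phi_{select}\Sigma_{prior} $$

By the matrix inversion lemma, this is equivalent to the precision (information) form:

$$ \Sigma_{posterior}^{-1} = \Sigma_{prior}^{-1} + \sigma^{-2}\Phi_{select}^T\Phi_{select}, \qquad \mu_{posterior} = \Sigma_{posterior}\left(\Sigma_{prior}^{-1}\mu_{prior} + \sigma^{-2}\Phi_{select}^T Y_{select}\right) $$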

This notebook uses the `Gaussian` class from the local `gaussians` module. Its `.condition()` method encapsulates these analytical formulas: given the prior Gaussian, the feature matrix $\Phi_{select}$, the observed data $Y_{select}$, and the noise covariance $\sigma^2 I$, it returns the posterior Gaussian.
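Since the `gaussians` module is not shown here, the following is only a sketch of what such a conditioning step can look like, assuming the covariance (gain) form of the update. `condition_gaussian` is a hypothetical stand-in, not the class's actual implementation, and the data are made up. The result is cross-checked against the precision form.

```python
import numpy as np

def condition_gaussian(mu, Sigma, A, y, Lambda):
    """Hypothetical stand-in for Gaussian.condition().

    Prior w ~ N(mu, Sigma), observation model y | w ~ N(A w, Lambda).
    Returns the posterior mean and covariance in covariance (gain) form.
    """
    S = A @ Sigma @ A.T + Lambda          # covariance of the predicted observations
    K = np.linalg.solve(S, A @ Sigma).T   # gain Sigma A^T S^{-1}, using symmetry of S and Sigma
    mu_post = mu + K @ (y - A @ mu)
    Sigma_post = Sigma - K @ A @ Sigma
    return mu_post, Sigma_post

# Illustrative prior and data
sigma = 1.5
mu = np.zeros(2)
Sigma = np.array([[1.0, 0.3], [0.3, 2.0]])
x = np.array([-2.0, 0.0, 1.0, 3.0])
y = np.array([-1.0, 0.5, 1.2, 3.1])
A = np.stack([np.ones_like(x), x], axis=1)    # plays the role of Phi_select, rows [1, x_i]

mu_c, Sigma_c = condition_gaussian(mu, Sigma, A, y, sigma**2 * np.eye(len(y)))

# Cross-check against the precision (information) form of the same update
Sigma_p = np.linalg.inv(np.linalg.inv(Sigma) + A.T @ A / sigma**2)
mu_p = Sigma_p @ (np.linalg.solve(Sigma, mu) + A.T @ y / sigma**2)
print(np.allclose(mu_c, mu_p), np.allclose(Sigma_c, Sigma_p))  # True True
```

The covariance form solves a system with the $N \times N$ matrix $S$, while the precision form inverts a $2 \times 2$ matrix, so which one is cheaper depends on how many points are selected relative to the number of features.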

## The Role of Selected Data

The power of interactive exploration here comes from selecting *which* data points we condition our posterior on. Initially, with no data selected, the posterior is the same as the prior. As you select data points, the posterior distribution (and the corresponding function space) will update to reflect the information gained from those specific observations.
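Because the observations are conditionally independent given $w$, conditioning on the selected points all at once yields the same posterior as conditioning on them one at a time, with each posterior serving as the prior for the next point. Selecting no points simply skips the update, which is what the notebook does when `selected_indices` is empty. The sketch below checks the batch/sequential equivalence numerically; `condition` is a hypothetical helper implementing the covariance-form update, and the data are made up.

```python
import numpy as np

def condition(mu, Sigma, A, y, noise_var):
    """Hypothetical helper: posterior of w ~ N(mu, Sigma) given y | w ~ N(A w, noise_var * I)."""
    S = A @ Sigma @ A.T + noise_var * np.eye(len(y))
    K = np.linalg.solve(S, A @ Sigma).T   # gain Sigma A^T S^{-1}
    return mu + K @ (y - A @ mu), Sigma - K @ A @ Sigma

mu0, Sigma0, noise_var = np.zeros(2), np.eye(2), 1.5**2
x = np.array([-3.0, -1.0, 0.5, 2.0])
y = np.array([-2.0, -0.4, 1.1, 2.3])
Phi = np.stack([np.ones_like(x), x], axis=1)

# Batch: condition on all selected points at once
mu_b, Sigma_b = condition(mu0, Sigma0, Phi, y, noise_var)

# Sequential: one point at a time; each posterior becomes the next prior
mu_s, Sigma_s = mu0, Sigma0
for i in range(len(y)):
    mu_s, Sigma_s = condition(mu_s, Sigma_s, Phi[i:i + 1], y[i:i + 1], noise_var)

print(np.allclose(mu_b, mu_s), np.allclose(Sigma_b, Sigma_s))  # True True
```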

## Interactive Bayesian Linear Regression with Plotly

We will now set up the interactive controls using `ipywidgets` and link them to a function that performs the Bayesian update and generates the Plotly plots.

In [None]:
# --- Define interactive widgets ---

# Sliders for the prior mean (2D)
mu0_prior_slider = widgets.FloatSlider(min=-5., max=5., value=0., step=.1, description='Prior Mu_0:')
mu1_prior_slider = widgets.FloatSlider(min=-5., max=5., value=0., step=.1, description='Prior Mu_1:')

# Sliders for the prior covariance matrix parameters
# Sigma_prior = [[S11, rho*sqrt(S11*S22)], [rho*sqrt(S11*S22), S22]]
s11_prior_slider = widgets.FloatSlider(min=0.1, max=5., value=1., step=.1, description='Prior Sigma_11:')
s22_prior_slider = widgets.FloatSlider(min=0.1, max=5., value=1., step=.1, description='Prior Sigma_22:')
rho_prior_slider = widgets.FloatSlider(min=-0.99, max=0.99, value=0., step=.01, description='Prior rho:')

# Widget to select data points
# Display point indices to select
data_selector = widgets.SelectMultiple(
    options=[(f'Point {i+1} (x={X_all[i,0]:.2f}, y={Y_all[i]:.2f})', i) for i in range(N_total)] if X_all is not None else [],
    description='Select Data Points:',
    disabled=(X_all is None),
    layout={'width': '400px'}
)


# --- Function to update plots based on widget values ---

def update_regression_plot_plotly(mu0_prior, mu1_prior, s11_prior, s22_prior, rho_prior, selected_indices):
    if X_all is None or Y_all is None or sigma_noise is None:
        print("Data not loaded. Cannot update plot.")
        # Display an empty plot or message in the output area
        fig = go.Figure()
        fig.update_layout(title="Data not loaded.")
        fig.show()
        return

    # Construct the prior mean and covariance matrix
    mu_prior = jnp.asarray([mu0_prior, mu1_prior])
    # Ensure s11 and s22 are positive for sqrt
    s11_safe = jnp.maximum(s11_prior, 1e-6)
    s22_safe = jnp.maximum(s22_prior, 1e-6)
    # Ensure rho is within valid range [-1, 1]
    rho_safe = jnp.clip(rho_prior, -0.999, 0.999)

    S12_prior = rho_safe * jnp.sqrt(s11_safe * s22_safe)
    Sigma_prior = jnp.asarray([[s11_safe, S12_prior], [S12_prior, s22_safe]])

    # Create the prior Gaussian distribution object
    prior_dist = Gaussian(mu=mu_prior, Sigma=Sigma_prior)

    # Select the data points based on indices
    X_select = X_all[list(selected_indices)]
    Y_select = Y_all[list(selected_indices)]
    # Noise covariance matrix for selected data (assuming iid noise)
    Lambda_select_sq = sigma_noise**2 * jnp.eye(len(selected_indices)) if len(selected_indices) > 0 else None


    # Compute the posterior distribution
    if len(selected_indices) > 0:
        # The .condition() method of the Gaussian class computes the posterior
        posterior_dist = prior_dist.condition(phi(X_select), Y_select, Lambda_select_sq)
    else:
        # If no data is selected, the posterior is the same as the prior
        posterior_dist = prior_dist

    # Regenerate key for samples each update
    global initial_key
    initial_key, subkey = jrandom.split(initial_key)

    # --- Prepare data for Plotly plotting ---

    # Parameter space plot data (Contours and Samples)
    n_contour_levels = 3 # Plot contours at 1, 2, 3 standard deviations
    theta = jnp.linspace(0, 2 * jnp.pi, 100)
    circle_pts = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1) # Unit circle points (100, 2)

    # Prior contour points
    prior_contour_pts = []
    if prior_dist.L is not None:
        for i in range(1, n_contour_levels + 1):
             pts = prior_dist.mu + i * jnp.dot(circle_pts, prior_dist.L.T)
             prior_contour_pts.append(pts)

    # Posterior contour points
    posterior_contour_pts = []
    if posterior_dist.L is not None:
         for i in range(1, n_contour_levels + 1):
             pts = posterior_dist.mu + i * jnp.dot(circle_pts, posterior_dist.L.T)
             posterior_contour_pts.append(pts)

    # Sample from prior and posterior parameter space
    num_samples_param_space = 10 # Plot a few samples
    prior_samples_param_space = prior_dist.sample(subkey, num_samples_param_space) if prior_dist.L is not None else None
    key, subkey = jrandom.split(subkey) # Use new key for posterior samples
    posterior_samples_param_space = posterior_dist.sample(subkey, num_samples_param_space) if posterior_dist.L is not None else None


    # Function space plot data (Data, Mean Functions, Uncertainty Bands, Sample Functions)
    x_plot = jnp.linspace(-5, 5, 100)[:, None] # X values for plotting functions
    phi_plot = phi(x_plot) # Feature matrix for plotting

    # Prior mean function and uncertainty band (+/- 2 std dev)
    prior_mean_f = jnp.dot(phi_plot, prior_dist.mu)
    prior_var_f = jnp.sum(phi_plot * jnp.dot(phi_plot, prior_dist.Sigma), axis=1)
    prior_std_f = jnp.sqrt(prior_var_f)
    prior_upper_f = prior_mean_f + 2 * prior_std_f
    prior_lower_f = prior_mean_f - 2 * prior_std_f

    # Posterior mean function and uncertainty band (+/- 2 std dev)
    posterior_mean_f = jnp.dot(phi_plot, posterior_dist.mu)
    posterior_var_f = jnp.sum(phi_plot * jnp.dot(phi_plot, posterior_dist.Sigma), axis=1)
    posterior_std_f = jnp.sqrt(posterior_var_f)
    posterior_upper_f = posterior_mean_f + 2 * posterior_std_f
    posterior_lower_f = posterior_mean_f - 2 * posterior_std_f

    # Sample functions from prior and posterior
    num_samples_func_space = 5 # Plot a few sample functions
    key, subkey = jrandom.split(subkey) # Use new key for function samples
    prior_samples_param_space_func = prior_dist.sample(subkey, num_samples_func_space) if prior_dist.L is not None else None
    prior_func_samples = jnp.dot(phi_plot, prior_samples_param_space_func.T) if prior_samples_param_space_func is not None else None

    key, subkey = jrandom.split(subkey) # Use new key for posterior function samples
    posterior_samples_param_space_func = posterior_dist.sample(subkey, num_samples_func_space) if posterior_dist.L is not None else None
    posterior_func_samples = jnp.dot(phi_plot, posterior_samples_param_space_func.T) if posterior_samples_param_space_func is not None else None


    # --- Create Plotly Figures ---

    # Create figure with two subplots
    fig = go.Figure(
        layout=go.Layout(
            title="Bayesian Linear Regression",
            grid=dict(
                rows=1,
                columns=2,
                pattern="independent",
            ),
            showlegend=True,
            width=1000, # Adjust figure width as needed
            height=500, # Adjust figure height as needed
        )
    )

    # --- Plot 1: Parameter Space ---

    # Plot prior contours
    for i, pts in enumerate(prior_contour_pts):
        fig.add_trace(go.Scattergl(x=np.array(pts[:, 0]), y=np.array(pts[:, 1]), mode='lines',
                                  line=dict(color=PLOTLY_COLORS["gray"], dash='dash', width=2.5/(i+1)),
                                  name=f'Prior {i+1}σ Contour', showlegend=(i==0), # Show legend only once
                                  xaxis='x1', yaxis='y1'))
    # Plot prior mean
    fig.add_trace(go.Scattergl(x=[prior_dist.mu[0]], y=[prior_dist.mu[1]], mode='markers',
                               marker=dict(color=PLOTLY_COLORS["gray"], size=8),
                               name='Prior Mean', showlegend=True,
                               xaxis='x1', yaxis='y1'))
    # Plot prior samples
    if prior_samples_param_space is not None:
         fig.add_trace(go.Scattergl(x=np.array(prior_samples_param_space[:, 0]), y=np.array(prior_samples_param_space[:, 1]),
                                   mode='markers', name='Prior Samples',
                                   marker=dict(size=4, opacity=0.8, color=PLOTLY_COLORS["dark"]),
                                   showlegend=True, xaxis='x1', yaxis='y1'))


    # Plot Likelihood (MLE point) if data is selected
    w_mle = None
    if len(selected_indices) > 0:
         try:
             Phi_select = phi(X_select)
             # Check if Phi_select.T @ Phi_select is invertible
             if jnp.linalg.det(jnp.dot(Phi_select.T, Phi_select)) > 1e-6:
                 w_mle = jnp.linalg.solve(jnp.dot(Phi_select.T, Phi_select), jnp.dot(Phi_select.T, Y_select))
                 fig.add_trace(go.Scattergl(x=[w_mle[0]], y=[w_mle[1]], mode='markers',
                                          marker=dict(color=PLOTLY_COLORS["blue"], size=8),
                                          name='Likelihood (MLE)', showlegend=True,
                                          xaxis='x1', yaxis='y1'))
             # Else: not enough distinct points, MLE is not unique, don't plot point
         except Exception as e:
             print(f"Error calculating MLE point: {e}")


    # Plot posterior contours
    for i, pts in enumerate(posterior_contour_pts):
        fig.add_trace(go.Scattergl(x=np.array(pts[:, 0]), y=np.array(pts[:, 1]), mode='lines',
                                  line=dict(color=PLOTLY_COLORS["red"], dash='dash', width=2.5/(i+1)),
                                  name=f'Posterior {i+1}σ Contour', showlegend=(i==0), # Show legend only once
                                  xaxis='x1', yaxis='y1'))
    # Plot posterior mean
    fig.add_trace(go.Scattergl(x=[posterior_dist.mu[0]], y=[posterior_dist.mu[1]], mode='markers',
                               marker=dict(color=PLOTLY_COLORS["red"], size=8),
                               name='Posterior Mean', showlegend=True,
                               xaxis='x1', yaxis='y1'))
    # Plot posterior samples
    if posterior_samples_param_space is not None:
        fig.add_trace(go.Scattergl(x=np.array(posterior_samples_param_space[:, 0]), y=np.array(posterior_samples_param_space[:, 1]),
                                   mode='markers', name='Posterior Samples',
                                   marker=dict(size=4, opacity=0.8, color=PLOTLY_COLORS["red"]),
                                   showlegend=True, xaxis='x1', yaxis='y1'))


    # Update layout for parameter space plot
    fig.update_layout(
        xaxis1=dict(
            title=r'$w_0 \text{ (intercept)}$',  # Plotly renders LaTeX only when the whole label is wrapped in $...$
            range=[-3, 3], # Fixed limits
            scaleanchor="y", scaleratio=1, # Equal aspect ratio; "y" is the id of the first y-axis
            domain=[0, 0.48] # Position in the subplot grid
        ),
        yaxis1=dict(
            title=r'$w_1 \text{ (slope)}$',
            range=[-3, 3], # Fixed limits
            domain=[0, 1] # Position in the subplot grid
        ),
        title_x=0.5, # Center the main title
        title_y=0.95 # Position the main title slightly lower
    )


    # --- Plot 2: Function Space ---

    # Plot all data points
    fig.add_trace(go.Scattergl(x=np.array(X_all[:, 0]), y=np.array(Y_all),
                               mode='markers', name='All Data',
                               marker=dict(color=PLOTLY_COLORS["dark"], size=5),
                               error_y=dict(type='data', array=np.array(sigma_noise * jnp.ones_like(Y_all))),
                               showlegend=True, xaxis='x2', yaxis='y2'))
    # Highlight selected data points
    if len(selected_indices) > 0:
         fig.add_trace(go.Scattergl(x=np.array(X_select[:, 0]), y=np.array(Y_select),
                                    mode='markers', name='Selected Data',
                                    marker=dict(color=PLOTLY_COLORS["red"], size=7, line=dict(width=1, color='DarkRed')),
                                    error_y=dict(type='data', array=np.array(sigma_noise * jnp.ones_like(Y_select))),
                                    showlegend=True, xaxis='x2', yaxis='y2'))


    # Plot prior mean function
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(prior_mean_f),
                               mode='lines', name='Prior Mean',
                               line=dict(color=PLOTLY_COLORS["dark"], width=2),
                               showlegend=True, xaxis='x2', yaxis='y2'))
    # Plot prior uncertainty band (+/- 2 std dev)
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(prior_upper_f), mode='lines',
                               line=dict(width=0), name='Prior +2σ', showlegend=False, xaxis='x2', yaxis='y2'))
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(prior_lower_f), mode='lines',
                               line=dict(width=0), name='Prior -2σ', showlegend=False, xaxis='x2', yaxis='y2',
                               fill='tonexty', fillcolor=PLOTLY_COLORS["dark_alpha"]))  # Fill down to the previous (+2σ) trace


    # Plot prior function samples
    if prior_func_samples is not None:
         for i in range(prior_func_samples.shape[1]):
              fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(prior_func_samples[:, i]), mode='lines',
                                       line=dict(color=PLOTLY_COLORS["dark"], width=1),
                                       opacity=0.3,  # opacity is a trace property; scattergl lines have no opacity key
                                       name='Prior Sample Functions', showlegend=False, # Don't show legend for each sample
                                       xaxis='x2', yaxis='y2'))


    # Plot posterior mean function
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(posterior_mean_f),
                               mode='lines', name='Posterior Mean',
                               line=dict(color=PLOTLY_COLORS["red"], width=2),
                               showlegend=True, xaxis='x2', yaxis='y2'))
    # Plot posterior uncertainty band (+/- 2 std dev)
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(posterior_upper_f), mode='lines',
                               line=dict(width=0), name='Posterior +2σ', showlegend=False, xaxis='x2', yaxis='y2'))
    fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(posterior_lower_f), mode='lines',
                               line=dict(width=0), name='Posterior -2σ', showlegend=False, xaxis='x2', yaxis='y2',
                               fill='tonexty', fillcolor=PLOTLY_COLORS["red_alpha"]))  # Fill down to the previous (+2σ) trace


    # Plot posterior function samples
    if posterior_func_samples is not None:
         for i in range(posterior_func_samples.shape[1]):
             fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(posterior_func_samples[:, i]), mode='lines',
                                       line=dict(color=PLOTLY_COLORS["red"], width=1),
                                       opacity=0.5,  # opacity is a trace property; scattergl lines have no opacity key
                                       name='Posterior Sample Functions', showlegend=False, # Don't show legend for each sample
                                       xaxis='x2', yaxis='y2'))

    # Plot Likelihood (MLE Function) if calculated
    if w_mle is not None: # Check if MLE was calculated and stored
         mle_func = jnp.dot(phi_plot, w_mle)
         fig.add_trace(go.Scattergl(x=np.array(x_plot[:, 0]), y=np.array(mle_func), mode='lines',
                                    line=dict(color=PLOTLY_COLORS["blue"], dash='dash', width=2),
                                    name='Likelihood (MLE Function)', showlegend=True,
                                    xaxis='x2', yaxis='y2'))


    # Update layout for function space plot
    fig.update_layout(
        xaxis2=dict(
            title='$x$',
             range=[-5, 5], # Fixed limits
            domain=[0.52, 1] # Position in the subplot grid
        ),
        yaxis2=dict(
            title='$y$',
            range=[-10, 10], # Fixed limits
            domain=[0, 1] # Position in the subplot grid
        )
    )

    # Final layout adjustments
    fig.update_layout(
         hovermode='closest',
         legend=dict(x=0.01, y=0.99), # Position the legend (adjust as needed)
         margin=dict(l=20, r=20, t=40, b=20), # Adjust margins
         title='Bayesian Linear Regression: Parameter and Function Space'
    )


    # Show the plot
    fig.show()


# --- Display widgets and link to update function ---

if X_all is not None: # Only display widgets if data was loaded
    # Arrange widgets
    prior_controls = widgets.VBox([
        widgets.Label("Prior Parameters:"),
        mu0_prior_slider,
        mu1_prior_slider,
        s11_prior_slider,
        s22_prior_slider,
        rho_prior_slider,
    ])

    data_selection_control = widgets.VBox([
        widgets.Label("Data Selection:"),
        data_selector,
    ])

    # Using HBox to place controls side-by-side
    controls = widgets.HBox([prior_controls, data_selection_control])


    # Link widgets to the update function
    # Ensure widget names match function argument names
    interactive_plot = widgets.interactive_output(
        update_regression_plot_plotly,
        {
            'mu0_prior': mu0_prior_slider,
            'mu1_prior': mu1_prior_slider,
            's11_prior': s11_prior_slider,
            's22_prior': s22_prior_slider,
            'rho_prior': rho_prior_slider,
            'selected_indices': data_selector,
        }
    )


    # Display controls and the plot output
    display(controls, interactive_plot)

else:
    print("Cannot display interactive widgets because data was not loaded.")
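The contour loop in the update function draws the $k$-th ellipse as $\mu + k L u$ for unit vectors $u$. This assumes `Gaussian.L` is a Cholesky factor with $L L^T = \Sigma$ (the class is not shown, so that is an assumption). Every such point lies at Mahalanobis distance exactly $k$ from the mean, as the standalone NumPy check below confirms with made-up values. Note that in two dimensions the $k$-th ellipse encloses probability $1 - e^{-k^2/2}$, roughly 39%, 86% and 99% for $k = 1, 2, 3$, rather than the familiar one-dimensional 68/95/99.7%.

```python
import numpy as np

# Illustrative mean and covariance
mu = np.array([0.5, -1.0])
Sigma = np.array([[1.0, 0.6], [0.6, 2.0]])
L = np.linalg.cholesky(Sigma)            # lower triangular, Sigma = L @ L.T
Sigma_inv = np.linalg.inv(Sigma)

theta = np.linspace(0, 2 * np.pi, 100)
circle_pts = np.stack([np.cos(theta), np.sin(theta)], axis=1)   # (100, 2) unit circle

checks = []
for k in (1, 2, 3):
    pts = mu + k * circle_pts @ L.T      # same construction as the notebook's contour loop
    d = pts - mu
    maha_sq = np.einsum("ni,ij,nj->n", d, Sigma_inv, d)   # squared Mahalanobis distances
    checks.append(np.allclose(maha_sq, k**2))
print(checks)  # [True, True, True]

# Probability mass inside the k-th ellipse of a 2D Gaussian
print([1 - np.exp(-k**2 / 2) for k in (1, 2, 3)])
```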

### Explanation of Bayesian Linear Regression

At its core, Bayesian linear regression is about finding a probability distribution over the possible straight lines (or hyperplanes in higher dimensions) that could have generated the data. Instead of finding a single "best" line, we maintain a belief about what the parameters $w_0$ (intercept) and $w_1$ (slope) could be, represented by a joint probability distribution $p(w_0, w_1)$.

1.  **The Model:** We assume the data is generated by a linear function corrupted by Gaussian noise: $y = w_0 + w_1 x + \epsilon$, where $\epsilon \sim \mathcal{N}(0, \sigma^2)$.
2.  **The Prior:** We start with a prior belief about the parameters $w$. A common and mathematically convenient choice is a Gaussian prior $p(w) = \mathcal{N}(w; \mu_{prior}, \Sigma_{prior})$. This distribution reflects our initial uncertainty about the intercept and slope before seeing any data.
      - In the **parameter space plot** (left), the contours of the prior Gaussian show regions of higher probability for the pair $[w_0, w_1]$. The ellipses are centered at the prior mean $\mu_{prior}$; their shape and orientation are determined by the prior covariance $\Sigma_{prior}$.
      - In the **function space plot** (right), the prior mean function $f(x) = \phi(x)^T \mu_{prior}$ is shown as a line. The band around it spans $\pm 2$ standard deviations of $f(x)$ under the prior. Each sample from the prior in parameter space corresponds to one line in function space, illustrating the variety of functions the prior considers plausible.
3.  **The Likelihood:** The likelihood $p(y | x, w)$ gives the probability density of observing $y$ at input $x$ for specific parameter values $w$. Under the Gaussian noise assumption it is Gaussian, so the observed $y$ is likely to be close to $\phi(x)^T w$. For multiple independent data points, the joint likelihood $p(Y_{select} | X_{select}, w)$ is also Gaussian.
      - In the **parameter space plot**, the "Likelihood (MLE)" point marks the parameter values that maximize the likelihood of the selected data, i.e. the least-squares line through *only* the selected points. It is drawn only when $\Phi_{select}^T \Phi_{select}$ is numerically invertible, which requires at least two distinct $x$ values; otherwise the MLE is not unique. (Note: the likelihood is a function of $w$ for fixed data, not a distribution over $w$ that you can sample from directly.)
      - In the **function space plot**, the "Likelihood (MLE Function)" shows the line corresponding to the MLE parameters. The selected data points are also highlighted.
4.  **The Posterior:** When we observe data, we update our prior belief to obtain the posterior distribution $p(w | X_{select}, Y_{select})$. Because the Gaussian prior is conjugate to the Gaussian likelihood, the posterior is also Gaussian, with updated mean $\mu_{posterior}$ and covariance $\Sigma_{posterior}$.
      - The **posterior mean** $\mu_{posterior}$ is a matrix-weighted average of the prior mean and the MLE (when the MLE exists), with weights set by the prior precision $\Sigma_{prior}^{-1}$ and the data precision $\sigma^{-2}\Phi_{select}^T \Phi_{select}$. As you add more data, the data term dominates and the posterior mean moves towards the MLE.
      - The **posterior covariance** $\Sigma_{posterior}$ is never larger than the prior covariance: $\Sigma_{prior} - \Sigma_{posterior}$ is positive semidefinite, reflecting the reduction in uncertainty about the parameters after observing data. Adding data never enlarges the posterior ellipses in parameter space and in general shrinks them.
      - In the **parameter space plot** (left subplot), the posterior contours and samples show the updated belief about $w$.
      - In the **function space plot** (right subplot), the posterior mean function is the line that best fits the selected data in light of the prior. The posterior uncertainty band is never wider than the prior band and is usually much narrower, reflecting increased certainty about the function after seeing data. Samples from the posterior show lines that are plausible given both the data and the prior.
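The "weighted average" statement in item 4 can be made precise with the precision form of the update: using the normal equations $\Phi^T Y = \Phi^T \Phi\, w_{MLE}$, the posterior mean is $W_{prior}\,\mu_{prior} + W_{data}\,w_{MLE}$ with matrix weights that sum to the identity. The NumPy sketch below checks this on made-up data; all names are illustrative.

```python
import numpy as np

sigma = 1.5
mu_prior = np.array([1.0, -1.0])
Sigma_prior = np.array([[2.0, 0.5], [0.5, 1.0]])
x = np.array([-2.0, -0.5, 1.0, 2.5])     # at least two distinct x values, so the MLE is unique
y = np.array([-1.5, 0.2, 1.4, 3.0])
Phi = np.stack([np.ones_like(x), x], axis=1)

# MLE = least-squares solution of the normal equations (Phi^T Phi) w = Phi^T y
w_mle = np.linalg.solve(Phi.T @ Phi, Phi.T @ y)

# Posterior in precision form
P_prior = np.linalg.inv(Sigma_prior)
P_data = Phi.T @ Phi / sigma**2
Sigma_post = np.linalg.inv(P_prior + P_data)
mu_post = Sigma_post @ (P_prior @ mu_prior + Phi.T @ y / sigma**2)

# Matrix weights on the prior mean and on the MLE; they sum to the identity
W_prior = Sigma_post @ P_prior
W_data = Sigma_post @ P_data
print(np.allclose(W_prior + W_data, np.eye(2)))                    # True
print(np.allclose(mu_post, W_prior @ mu_prior + W_data @ w_mle))   # True
```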

By selecting data points using the interactive widget, you can observe how the posterior distribution shifts and shrinks, and how this translates into a more certain belief about the linear relationship between $x$ and $y$ in the function space. The posterior mean function becomes a better fit to the selected data, and the uncertainty band narrows, particularly in regions where data has been observed.
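The narrowing of the band can be checked directly. The notebook computes the pointwise variance of $f(x)$ as `sum(phi_plot * (phi_plot @ Sigma), axis=1)`, which is the diagonal of $\Phi \Sigma \Phi^T$ without forming the full matrix. The NumPy sketch below, on made-up data and with a hypothetical `f_var` helper, confirms that identity and that the posterior variance is no larger than the prior variance at every plotted $x$.

```python
import numpy as np

sigma = 1.5
Sigma_prior = np.eye(2)
x_obs = np.array([1.0, 1.5, 2.0])        # illustrative inputs; the outputs do not affect the covariance
Phi_obs = np.stack([np.ones_like(x_obs), x_obs], axis=1)

# Posterior covariance (precision form); it depends only on the inputs, not on y
Sigma_post = np.linalg.inv(np.linalg.inv(Sigma_prior) + Phi_obs.T @ Phi_obs / sigma**2)

x_plot = np.linspace(-5, 5, 101)
phi_plot = np.stack([np.ones_like(x_plot), x_plot], axis=1)

def f_var(Sigma):
    """Hypothetical helper: diag(phi_plot @ Sigma @ phi_plot.T), the variance of f(x) at each x."""
    return np.sum(phi_plot * (phi_plot @ Sigma), axis=1)

var_prior, var_post = f_var(Sigma_prior), f_var(Sigma_post)
print(np.allclose(var_prior, np.diag(phi_plot @ Sigma_prior @ phi_plot.T)))  # True
print(np.all(var_post <= var_prior))                                        # True
```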

This interactive notebook allows you to visualize this core process of Bayesian learning: starting with a belief (prior), observing evidence (data), and updating that belief (posterior).

To run this notebook you need `ipywidgets`, `plotly`, `jax`, `jaxlib`, and `scipy` installed, the `lindata.mat` data file, and the local `gaussians` module that provides the `Gaussian` class.