In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider


def multivariate_gaussian_pdf(x, y, mean, cov):
    """
    Compute the value of the PDF of a 2D Gaussian at points (x, y).

    The covariance is factored with a Cholesky decomposition. This rejects
    matrices that are not positive definite, which would otherwise produce
    NaN or meaningless densities. It also avoids forming an explicit inverse.

    :param x: x-coordinates (scalar or array-like, e.g. a meshgrid array)
    :param y: y-coordinates; must be broadcast-compatible with ``x``
    :param mean: [mu_x, mu_y]
    :param cov: 2x2 symmetric positive-definite covariance matrix
    :return: Values of the PDF at each (x, y), shaped like the broadcast of
        ``x`` and ``y``
    :raises numpy.linalg.LinAlgError: if ``cov`` is not positive definite
    """
    # Broadcast so scalars, lists and compatible-but-different grids all work.
    x, y = np.broadcast_arrays(
        np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    )
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)

    # Flatten to an (N, 2) array of points for vectorized operations
    xy = np.column_stack([x.ravel(), y.ravel()])
    xy_centered = xy - mean

    # cov = L @ L.T. cholesky raises LinAlgError unless cov is positive
    # definite, so invalid covariances fail loudly instead of returning NaN.
    chol = np.linalg.cholesky(cov)

    # Mahalanobis distance: d^T cov^{-1} d == ||L^{-1} d||^2
    whitened = np.linalg.solve(chol, xy_centered.T)  # shape (2, N)
    mahal = np.sum(whitened**2, axis=0)

    # sqrt(det(cov)) == prod(diag(L)); the Cholesky diagonal is positive
    sqrt_det = np.prod(np.diag(chol))

    # PDF formula for the 2D Gaussian: (2*pi)^(k/2) with k=2 gives 2*pi
    pdf = np.exp(-0.5 * mahal) / (2.0 * np.pi * sqrt_det)

    return pdf.reshape(x.shape)

def plot_gaussian(mu_x=0.0, mu_y=0.0, sigma_x=1.0, sigma_y=1.0, rho=0.0):
    """
    Draw a filled contour plot of a 2D Gaussian density on [-5, 5] x [-5, 5].

    :param mu_x: Mean along the x axis
    :param mu_y: Mean along the y axis
    :param sigma_x: Standard deviation along the x axis
    :param sigma_y: Standard deviation along the y axis
    :param rho: Correlation coefficient between x and y (-1 < rho < 1)
    """
    # Off-diagonal covariance term, computed in the same order as the
    # diagonal-free formula rho * sigma_x * sigma_y.
    covariance_xy = rho * sigma_x * sigma_y
    center = np.array([mu_x, mu_y])
    covariance = np.array([
        [sigma_x**2, covariance_xy],
        [covariance_xy, sigma_y**2],
    ])

    # Square 100 x 100 evaluation grid shared by both axes
    axis_ticks = np.linspace(-5, 5, 100)
    grid_x, grid_y = np.meshgrid(axis_ticks, axis_ticks)
    density = multivariate_gaussian_pdf(grid_x, grid_y, center, covariance)

    fig, ax = plt.subplots(figsize=(6, 5))
    filled = ax.contourf(grid_x, grid_y, density, levels=30, cmap='viridis')
    fig.colorbar(filled, ax=ax, label='PDF Value')
    ax.set_title("2D Multivariate Gaussian")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_xlim([-5, 5])
    ax.set_ylim([-5, 5])
    ax.set_aspect('equal', 'box')
    plt.show()

# Create an interactive widget with sliders
# One slider per plot_gaussian parameter. The ranges keep both sigmas
# strictly positive and |rho| <= 0.9, so the covariance stays positive definite.
gaussian_sliders = dict(
    mu_x=FloatSlider(value=0.0, min=-3.0, max=3.0, step=0.1, description='Mean X'),
    mu_y=FloatSlider(value=0.0, min=-3.0, max=3.0, step=0.1, description='Mean Y'),
    sigma_x=FloatSlider(value=1.0, min=0.1, max=3.0, step=0.1, description='Sigma X'),
    sigma_y=FloatSlider(value=1.0, min=0.1, max=3.0, step=0.1, description='Sigma Y'),
    rho=FloatSlider(value=0.0, min=-0.9, max=0.9, step=0.1, description='Correlation'),
)
# Trailing semicolon suppresses the notebook's echo of interact's return value.
interact(plot_gaussian, **gaussian_sliders);


interactive(children=(FloatSlider(value=0.0, description='Mean X', max=3.0, min=-3.0), FloatSlider(value=0.0, …