In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import ipywidgets as widgets
from IPython.display import display, clear_output

# Parameters of the three-dimensional Gaussian over (x1, x2, x3).
# Every component has mean zero.
mean = np.array([0] * 3)
# Unit variances on the diagonal; the off-diagonal entries are the pairwise
# covariances between the three variables.
cov = np.array([
    [1.0, 0.7, 0.4],
    [0.7, 1.0, 0.6],
    [0.4, 0.6, 1.0],
])

# Module-level list of (x1, x2, x3) samples, shared by the plotting helper
# and the widget callbacks defined below.
clicked_points = list()

def conditional_mvn(x3_value, mu=None, sigma=None):
    """Compute the conditional distribution of (x1, x2) given x3.

    For a jointly Gaussian vector partitioned into a = (x1, x2) and b = x3,
    the conditional a | b is Gaussian with

        mean = mu_a + Sigma_ab / Sigma_bb * (x3 - mu_b)
        cov  = Sigma_aa - Sigma_ab Sigma_ab^T / Sigma_bb

    Parameters
    ----------
    x3_value : float
        Value x3 is fixed to.
    mu : array-like of shape (3,), optional
        Mean of the joint distribution. Defaults to the module-level ``mean``.
    sigma : array-like of shape (3, 3), optional
        Covariance of the joint distribution. Defaults to the module-level
        ``cov``.

    Returns
    -------
    conditional_mean : ndarray of shape (2,)
    conditional_cov : ndarray of shape (2, 2)

    Raises
    ------
    ValueError
        If the shapes are not (3,) / (3, 3) or the variance of x3 is not
        strictly positive (the conditional would be undefined).
    """
    # Fall back to the notebook-level parameters when none are supplied, so
    # existing calls of the form conditional_mvn(x3) behave as before.
    mu = np.asarray(mean if mu is None else mu, dtype=float)
    sigma = np.asarray(cov if sigma is None else sigma, dtype=float)
    if mu.shape != (3,) or sigma.shape != (3, 3):
        raise ValueError("expected a mean of shape (3,) and a covariance of shape (3, 3)")

    # Partition into the free block a = (x1, x2) and the conditioned block b = x3
    mu_a = mu[:2]
    mu_b = mu[2]
    sigma_aa = sigma[:2, :2]
    sigma_ab = sigma[:2, 2:3]  # kept as a (2, 1) column for the outer product below
    var_b = sigma[2, 2]
    if not var_b > 0:
        raise ValueError("the variance of x3 must be strictly positive")

    # Sigma_bb is 1x1, so its inverse is a plain reciprocal
    sigma_bb_inv = 1.0 / var_b
    conditional_mean = mu_a + sigma_ab.flatten() * sigma_bb_inv * (x3_value - mu_b)
    conditional_cov = sigma_aa - (sigma_ab @ sigma_ab.T) * sigma_bb_inv

    return conditional_mean, conditional_cov

def plot_mvn_and_samples(x3_value=0.0):
    """Draw the conditional density of (x1, x2) given x3 and the stored samples.

    The left panel is a filled contour plot of the conditional density on the
    square [-3, 3] x [-3, 3]. The right panel draws every stored sample as a
    line through (x1, x2, x3) against the variable index. Only samples whose
    stored x3 matches ``x3_value`` are shown.

    Parameters
    ----------
    x3_value : float
        Value x3 is conditioned on.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (all previously open figures are closed).
    """
    # Close any existing figures to prevent duplicates
    plt.close('all')

    fig, (ax_contour, ax_line) = plt.subplots(1, 2, figsize=(15, 6), dpi=500)

    # --- Left plot: contour of the conditional distribution ---
    cond_mean, cond_cov = conditional_mvn(x3_value)

    grid_axis = np.linspace(-3, 3, 100)
    X1, X2 = np.meshgrid(grid_axis, grid_axis)
    density = multivariate_normal(cond_mean, cond_cov).pdf(np.dstack((X1, X2)))

    contour = ax_contour.contourf(X1, X2, density, 20, cmap='viridis')
    fig.colorbar(contour, ax=ax_contour)

    ax_contour.set_title(f'Multivariate Normal (x1, x2 | x3={x3_value:.2f})')
    ax_contour.set_xlabel('x1')
    ax_contour.set_ylabel('x2')
    ax_contour.grid(alpha=0.3)

    # --- Right plot: line representation of the samples ---
    ax_line.set_title('Sampled Function Values (GP Perspective)')
    ax_line.set_xlabel('Variable Index')
    ax_line.set_ylabel('Function Value')
    ax_line.set_xlim(-0.5, 2.5)
    ax_line.set_ylim(-3, 3)
    ax_line.set_xticks([0, 1, 2])
    ax_line.set_xticklabels(['x1', 'x2', 'x3'])
    ax_line.grid(alpha=0.3)

    # Draw each stored sample in both panels with the same colour. The colour
    # index follows the position in the full list, so a sample keeps its colour
    # regardless of which other samples are currently visible.
    for i, (x1, x2, x3) in enumerate(clicked_points):
        # Tolerant comparison: the stored x3 and the slider value are floats on
        # a 0.1 grid, so exact equality can miss values that differ only by
        # rounding error.
        if not np.isclose(x3, x3_value):
            continue
        colour = f'C{i%10}'
        ax_contour.plot(x1, x2, 'o', markersize=8, color=colour, alpha=0.7)
        ax_line.plot([0, 1, 2], [x1, x2, x3], 'o-', markersize=8, color=colour, alpha=0.7)

    fig.tight_layout()
    fig.suptitle('From Multivariate Normal to Gaussian Process', fontsize=16)
    # Leave room above the axes for the suptitle added after tight_layout
    fig.subplots_adjust(top=0.9)

    return fig

def on_click(event):
    """Handle a mouse click on the contour plot by storing the clicked point.

    Clicks outside the first (contour) axes of the clicked figure, or outside
    the square [-3, 3] x [-3, 3], are ignored. An accepted click is appended to
    ``clicked_points`` together with the current slider value of x3, and the
    figure is redrawn inside ``output``.

    Parameters
    ----------
    event : matplotlib mouse event
        Event delivered by a ``'button_press_event'`` connection.
    """
    # Use the figure the event came from rather than plt.gcf(), which is
    # whatever figure happens to be current at the time of the click.
    figure_axes = event.canvas.figure.axes
    if not figure_axes or event.inaxes is not figure_axes[0]:
        return

    x1, x2 = event.xdata, event.ydata
    if x1 is None or x2 is None:
        return
    if not (-3 <= x1 <= 3 and -3 <= x2 <= 3):
        return

    # Pair the clicked (x1, x2) with the x3 currently selected on the slider
    x3_value = x3_slider.value
    clicked_points.append((x1, x2, x3_value))

    with output:
        clear_output(wait=True)
        fig = plot_mvn_and_samples(x3_value)
        # The redraw creates a brand-new figure; without reconnecting, further
        # clicks on it would never reach this handler.
        fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()

        print(f"Added point: x1={x1:.2f}, x2={x2:.2f}, x3={x3_value:.2f}")
        print(f"Total points: {len(clicked_points)}")

def update_plot(x3):
    """Slider callback: redraw the figure for the newly selected x3.

    The previous content of ``output`` is replaced by a fresh figure, and the
    click handler is attached to that figure's canvas before it is shown.
    """
    with output:
        clear_output(wait=True)
        figure = plot_mvn_and_samples(x3)
        # Each redraw yields a new canvas, so the handler is attached again
        figure.canvas.mpl_connect('button_press_event', on_click)
        plt.show()

def add_random_sample(b):
    """Button callback: draw one sample of (x1, x2) given the current x3.

    The sample is drawn from the conditional distribution returned by
    ``conditional_mvn`` for the slider's x3 value. It is stored in
    ``clicked_points`` and the figure is redrawn only if it lies inside the
    plotted square [-3, 3] x [-3, 3]; otherwise it is discarded and a short
    note is printed so the button press does not appear to do nothing.

    Parameters
    ----------
    b : ipywidgets.Button
        The button that triggered the callback (unused).
    """
    x3_value = x3_slider.value

    cond_mean, cond_cov = conditional_mvn(x3_value)
    x1, x2 = np.random.multivariate_normal(cond_mean, cond_cov)

    if not (-3 <= x1 <= 3 and -3 <= x2 <= 3):
        # Samples outside the plotted range are not stored; tell the user why
        # nothing changed in the figure.
        with output:
            print(f"Sample x1={x1:.2f}, x2={x2:.2f} is outside the plotted range and was discarded. Try again.")
        return

    clicked_points.append((x1, x2, x3_value))

    with output:
        clear_output(wait=True)
        fig = plot_mvn_and_samples(x3_value)
        # Re-attach the click handler: the redraw produced a new figure
        fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()

        print(f"Added random sample: x1={x1:.2f}, x2={x2:.2f}, x3={x3_value:.2f}")
        print(f"Total points: {len(clicked_points)}")

def reset_visualization(b):
    """Button callback: forget all stored points and redraw the figure.

    Parameters
    ----------
    b : ipywidgets.Button
        The button that triggered the callback (unused).
    """
    # Empty the shared list in place so every reference to it sees the reset
    clicked_points.clear()

    with output:
        clear_output(wait=True)
        fig = plot_mvn_and_samples(x3_slider.value)
        # Re-attach the click handler to the new figure; the message below
        # invites the user to click on the contour plot.
        fig.canvas.mpl_connect('button_press_event', on_click)
        plt.show()
        print("Visualization reset. Click on the contour plot to add points.")

# Create widgets
# Slider selecting the value x3 is conditioned on. continuous_update=False
# makes the plot redraw only when the slider is released, not while dragging.
x3_slider = widgets.FloatSlider(
    value=0.0,
    min=-2.0,
    max=2.0,
    step=0.1,
    description='x3 value:',
    continuous_update=False
)

random_button = widgets.Button(
    description='Add Random Sample',
    button_style='',
    tooltip='Add a random sample from the distribution'
)
random_button.on_click(add_random_sample)

reset_button = widgets.Button(
    description='Reset',
    button_style='',
    tooltip='Clear all points'
)
reset_button.on_click(reset_visualization)

# Create output area for the plot; every callback draws into this widget
output = widgets.Output()

# Set up the interactive component: update_plot runs once immediately and
# again on every slider change. It renders into `output` itself, so the
# widget returned by interactive_output is not displayed.
widgets.interactive_output(update_plot, {'x3': x3_slider})

# Display everything
controls = widgets.HBox([x3_slider, random_button, reset_button])
print("Exercise: Below you see a multivariate normal distribution. Use the button to add random samples and see them visualised on the right.")
print("We have a conditional distribution (x3 is fixed). Can you think of scenarios in regression in which this might be useful?")
display(controls)
display(output)

print("Click on the button above to add points. You can also change the fixed value of x3.")

Exercise: Below you see a multivariate normal distribition. Use the button to add random samples and see them visualised on the right.
We have a conditional distribution (x3 is fixed). Can you think of scenarios in regression in which this might be useful?


HBox(children=(FloatSlider(value=0.0, continuous_update=False, description='x3 value:', max=2.0, min=-2.0), Bu…

Output()

Click on the button above to add points. You can also change the fixed value of x3.


In [4]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ipywidgets as widgets
from IPython.display import display, clear_output

# Function to compute kernel matrices
def rbf_kernel(x1, x2, length_scale=1.0, sigma=1.0):
    """Radial Basis Function (squared-exponential) kernel matrix.

    Evaluates sigma^2 * exp(-0.5 * ||a - b||^2 / length_scale^2) for every
    row a of ``x1`` and every row b of ``x2``. Both inputs are indexed as 2-D
    arrays with one point per row; the result has one row per point in ``x1``
    and one column per point in ``x2``.
    """
    # Broadcast to all pairwise differences, then reduce over the feature axis
    pairwise_diff = x1[:, None, :] - x2[None, :, :]
    sq_distance = (pairwise_diff ** 2).sum(axis=-1)
    variance = sigma**2
    return variance * np.exp(-0.5 * sq_distance / length_scale**2)

def periodic_kernel(x1, x2, length_scale=1.0, sigma=1.0, period=1.0):
    """Periodic kernel matrix.

    Evaluates sigma^2 * exp(-2 * sin^2(pi * ||a - b|| / period) / length_scale^2)
    for every row a of ``x1`` and every row b of ``x2``. Both inputs are
    indexed as 2-D arrays with one point per row.
    """
    # Euclidean distance between every pair of rows
    pairwise_diff = x1[:, None, :] - x2[None, :, :]
    distance = np.sqrt((pairwise_diff ** 2).sum(axis=-1))
    # The distance enters only through a sine, which makes the kernel repeat
    # with the given period
    sin_term = np.sin(np.pi * distance / period)
    variance = sigma**2
    return variance * np.exp(-2.0 * (sin_term**2) / length_scale**2)

def linear_kernel(x1, x2, sigma=1.0, c=0.0):
    """Linear kernel matrix: sigma^2 * (a . b + c) for all rows a of x1, b of x2.

    ``c`` is a constant offset added to every inner product before scaling.
    """
    inner_products = np.matmul(x1, x2.T)
    variance = sigma**2
    return variance * (inner_products + c)

# Function to generate samples from a Gaussian process
def sample_from_gp(x_points, kernel_matrix, num_samples=5):
    """Draw zero-mean Gaussian process samples at the given points.

    ``kernel_matrix`` is used as the covariance of a multivariate normal whose
    dimension is ``len(x_points)``. The draws come from NumPy's global random
    state, and the result has one row per sample.
    """
    dimension = kernel_matrix.shape[0]
    # A tiny diagonal jitter keeps the covariance numerically well behaved
    stabilised_cov = kernel_matrix + 1e-10 * np.eye(dimension)
    zero_mean = np.zeros(len(x_points))
    return np.random.multivariate_normal(zero_mean, stabilised_cov, size=num_samples)

# Helper and entry point for the kernel comparison figure
def _plot_kernel_row(fig, gs, row, x, name, matrix_title, line_style,
                     K, k_origin, samples, num_samples, vmin, vmax):
    """Draw one row of the comparison figure: kernel matrix, k(0, x), GP samples."""
    # Column 0: the covariance matrix as an image, both axes running over x
    ax_kernel = fig.add_subplot(gs[row, 0])
    image = ax_kernel.imshow(K, cmap='viridis', aspect='auto',
                             extent=[x.min(), x.max(), x.max(), x.min()],
                             vmin=vmin, vmax=vmax)
    ax_kernel.set_title(matrix_title)
    ax_kernel.set_xlabel('x')
    ax_kernel.set_ylabel('x′')
    colorbar = fig.colorbar(image, ax=ax_kernel)
    colorbar.set_label('Covariance Strength')

    # Column 1: covariance between the origin and every grid point
    ax_slice = fig.add_subplot(gs[row, 1])
    ax_slice.plot(x, k_origin, line_style, linewidth=2)
    ax_slice.axvline(x=0, color='r', linestyle='--', alpha=0.5)
    ax_slice.set_title(f'{name} Kernel from Origin')
    ax_slice.set_xlabel('x')
    ax_slice.set_ylabel('k(0, x)')
    ax_slice.grid(True, alpha=0.3)
    # With the fixed colour scale the slice shares the same value range
    if vmax is not None:
        ax_slice.set_ylim(0, vmax)

    # Column 2: one line per GP sample
    ax_samples = fig.add_subplot(gs[row, 2])
    for sample in samples:
        ax_samples.plot(x, sample, alpha=0.7)
    ax_samples.set_title(f'{name} GP Samples (n = {num_samples})')
    ax_samples.set_xlabel('x')
    ax_samples.set_ylabel('f(x)')
    ax_samples.grid(True, alpha=0.3)


def create_visualization(rbf_length=1.0, rbf_sigma=1.0, 
                         periodic_length=1.0, periodic_sigma=1.0, periodic_period=1.0,
                         linear_sigma=1.0, linear_c=0.0, num_points=50, num_samples=5,
                         fixed_colorbar=True):
    """Create a 3x3 figure comparing the RBF, periodic and linear kernels.

    Each row shows one kernel evaluated on ``num_points`` equally spaced
    points in [-3, 3]: the kernel matrix, the kernel against the origin
    k(0, x), and ``num_samples`` zero-mean GP samples drawn with that matrix
    as covariance (from NumPy's global random state, in the order RBF,
    periodic, linear).

    Parameters
    ----------
    rbf_length, rbf_sigma : float
        Length scale and amplitude sigma of the RBF kernel.
    periodic_length, periodic_sigma, periodic_period : float
        Length scale, amplitude sigma and period of the periodic kernel.
    linear_sigma, linear_c : float
        Amplitude sigma and constant offset of the linear kernel.
    num_points : int
        Number of grid points in [-3, 3].
    num_samples : int
        Number of GP samples per kernel.
    fixed_colorbar : bool
        If True, kernel images use the fixed colour range 0-9 and the slice
        plots the y-range 0-9. If False, both limits are chosen automatically.

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    The ``*_sigma`` arguments are passed to the kernels unchanged, and the
    kernels multiply by ``sigma**2``. The panel titles therefore report the
    value as sigma, not as the variance sigma^2.
    """
    # Generate 1D points for kernel evaluation, as a column for the kernels
    x = np.linspace(-3, 3, num_points)
    X = x.reshape(-1, 1)
    # The grid need not contain 0 (it does not for an even num_points), so the
    # "from origin" slice is evaluated against an explicit origin point
    # instead of the middle row of the kernel matrix.
    origin = np.zeros((1, 1))

    # (name, title of the matrix panel, line style of the slice, kernel)
    kernel_rows = [
        ('RBF',
         f'RBF Kernel (ℓ = {rbf_length:.2f}, σ = {rbf_sigma:.2f})',
         'b-',
         lambda a, b: rbf_kernel(a, b, rbf_length, rbf_sigma)),
        ('Periodic',
         f'Periodic Kernel (ℓ = {periodic_length:.2f}, σ = {periodic_sigma:.2f}, p = {periodic_period:.2f})',
         'g-',
         lambda a, b: periodic_kernel(a, b, periodic_length, periodic_sigma, periodic_period)),
        ('Linear',
         f'Linear Kernel (σ = {linear_sigma:.2f}, c = {linear_c:.2f})',
         'r-',
         lambda a, b: linear_kernel(a, b, linear_sigma, linear_c)),
    ]

    # Evaluate all kernels and draw all samples before any plotting
    computed = []
    for name, matrix_title, line_style, kernel in kernel_rows:
        K = kernel(X, X)
        k_origin = kernel(origin, X)[0]
        samples = sample_from_gp(X, K, num_samples)
        computed.append((name, matrix_title, line_style, K, k_origin, samples))

    # Create a new figure
    fig = plt.figure(figsize=(18, 12), dpi=100)
    gs = GridSpec(3, 3, figure=fig)
    fig.suptitle('Gaussian Process Kernel Comparison', fontsize=16, y=0.98)

    # Both limits are fixed, or both are left to matplotlib. The parentheses
    # matter: without them only vmax would depend on the flag and vmin would
    # always be 0, clipping negative covariances of the linear kernel.
    vmin, vmax = (0, 9) if fixed_colorbar else (None, None)

    for row, (name, matrix_title, line_style, K, k_origin, samples) in enumerate(computed):
        _plot_kernel_row(fig, gs, row, x, name, matrix_title, line_style,
                         K, k_origin, samples, num_samples, vmin, vmax)

    # Adjust layout, leaving room for the suptitle
    fig.tight_layout()
    fig.subplots_adjust(top=0.93)

    return fig

# Interactive widgets with button update
def interactive_kernel_visualization():
    """Build the interactive kernel-comparison UI.

    The UI consists of one tab of parameter sliders per kernel (RBF, periodic,
    linear), a slider for the number of samples, a checkbox for the fixed
    colour scale, and two buttons:

    * "Update Visualization" redraws the figure with the current parameters
      using the *same* random seed, so changes in the samples reflect the
      parameter change and not fresh randomness.
    * "New Random Samples" picks a new seed and redraws.

    The figure is drawn once on creation with seed 42. Seeding uses NumPy's
    global random state.

    Returns
    -------
    ipywidgets.VBox
        Container holding the controls and the output area with the figure.
    """
    def make_slider(description, value, lower, upper, step=0.1, slider_cls=widgets.FloatSlider):
        """Create a slider with the description width used throughout this UI."""
        return slider_cls(
            value=value, min=lower, max=upper, step=step,
            description=description,
            style={'description_width': 'initial'}
        )

    # The sigma sliders hold the amplitude sigma; the kernels square it, so
    # the maximum of 3.0 corresponds to the variance 9 at the top of the
    # fixed colour scale.
    rbf_length_slider = make_slider('RBF ℓ:', 1.0, 0.1, 3.0)
    rbf_sigma_slider = make_slider('RBF σ:', 1.0, 0.1, 3.0)

    periodic_length_slider = make_slider('Periodic ℓ:', 1.0, 0.1, 3.0)
    periodic_sigma_slider = make_slider('Periodic σ:', 1.0, 0.1, 3.0)
    periodic_period_slider = make_slider('Period p:', 1.0, 0.1, 3.0)

    linear_sigma_slider = make_slider('Linear σ:', 1.0, 0.1, 3.0)
    linear_c_slider = make_slider('Constant c:', 0.0, 0.0, 2.0)

    # Number of samples slider
    samples_slider = make_slider('Samples:', 5, 1, 10, step=1, slider_cls=widgets.IntSlider)

    # Checkbox for fixed color scale
    fixed_scale_checkbox = widgets.Checkbox(
        value=True,
        description='Use fixed color scale (0-9)',
        style={'description_width': 'initial'}
    )

    # One tab per kernel parameter group
    tab = widgets.Tab()
    tab.children = [
        widgets.VBox([rbf_length_slider, rbf_sigma_slider]),
        widgets.VBox([periodic_length_slider, periodic_sigma_slider, periodic_period_slider]),
        widgets.VBox([linear_sigma_slider, linear_c_slider]),
    ]
    for index, title in enumerate(['RBF', 'Periodic', 'Linear']):
        tab.set_title(index, title)

    display_options = widgets.VBox([samples_slider, fixed_scale_checkbox])

    update_button = widgets.Button(
        description='Update Visualization',
        button_style='primary',
        tooltip='Click to update the visualization with current parameter values'
    )
    random_seed_button = widgets.Button(
        description='New Random Samples',
        button_style='info',
        tooltip='Click to generate new random samples'
    )

    controls = widgets.VBox([
        tab,
        display_options,
        widgets.HBox([update_button, random_seed_button])
    ])

    # Output area for the plots
    output = widgets.Output()

    # Seed used for every redraw; only "New Random Samples" changes it
    seed_state = {'seed': 42}

    def update_visualization(button):
        """Redraw the figure from the current widget values and current seed."""
        with output:
            clear_output(wait=True)
            # Re-seed before every draw so that an update with unchanged seed
            # reuses the same underlying random numbers.
            np.random.seed(seed_state['seed'])
            create_visualization(
                rbf_length=rbf_length_slider.value,
                rbf_sigma=rbf_sigma_slider.value,
                periodic_length=periodic_length_slider.value,
                periodic_sigma=periodic_sigma_slider.value,
                periodic_period=periodic_period_slider.value,
                linear_sigma=linear_sigma_slider.value,
                linear_c=linear_c_slider.value,
                num_samples=samples_slider.value,
                fixed_colorbar=fixed_scale_checkbox.value
            )
            plt.show()

    def new_random_samples(button):
        """Pick a new seed, then redraw."""
        seed_state['seed'] = int(np.random.randint(2**31 - 1))
        update_visualization(button)

    # Connect buttons to their functions
    update_button.on_click(update_visualization)
    random_seed_button.on_click(new_random_samples)

    # Initial draw with the default seed; the callback ignores its argument
    update_visualization(None)

    # Return the complete UI
    return widgets.VBox([controls, output])

# Investigating different kernel functions
### Radial Basis Function (RBF) kernel

$k(x, x') = \sigma^2 \exp\left(-\frac{1}{2} \frac{\|x - x'\|^2}{\ell^2}\right)$

with $\sigma^2$ as output variance, $\ell$ as length scale and $\|x - x'\|^2$ as the squared Euclidean distance between the vectors.

### Periodic kernel
$k(x, x') = \sigma^2 \exp\left(-\frac{2\sin^2\left(\pi\ \|x - x'\| / p\right)}{\ell^2}\right)$

with $\sigma^2$ as output variance, $\ell$ as length scale, $p$ as the period and $\|x - x'\|$ as the Euclidean distance between the vectors.

### Linear kernel
$k(x, x') = \sigma^2 (x^T x' + c)$

with $\sigma^2$ as output variance and $c$ as the constant offset.

## Exercise

Let's have a closer look at these three kernels. Play around with the different parameters and see how they influence the output.

In [None]:
# Build the interactive kernel comparison UI and render it below this cell
visualization = interactive_kernel_visualization()
display(visualization)
print("Can you think of different situations in which these kernels might be useful?")

VBox(children=(VBox(children=(Tab(children=(VBox(children=(FloatSlider(value=1.0, description='RBF ℓ:', max=3.…

Can you think about different situation in which these kernels might be useful?
