# In[5]:
from matplotlib import pyplot as plt
from plotting import plot_regression_with_uncertainty
import pickle
from jax import numpy as jnp
import numpy as np

# In[4]:
# Load the saved (X_pred, pred) tuple. The `with` block guarantees the file
# handle is closed; the original `pickle.load(open(...))` leaked it.
# NOTE: unpickling can execute arbitrary code - only load trusted files.
with open("../prob_predictive.pkl", "rb") as _pkl_file:
    X_pred, pred = pickle.load(_pkl_file)

# In[None]:
def unpack_preds(pred):
    """Split a predictive-distribution dict into ``(mean, std, samples)``.

    The first column of ``pred["pred_mean"]`` is returned as the mean. The first
    column of ``pred["pred_var"]`` is turned into a standard deviation with
    ``jnp.sqrt``. ``pred["samples"]`` is passed through unchanged.
    """
    mean, variance = (pred[key][:, 0] for key in ("pred_mean", "pred_var"))
    return mean, jnp.sqrt(variance), pred["samples"]

def _plot_kernel_prediction(ax, x_sorted, sort_idx, pred, name, color):
    """Draw one kernel's predictive mean, +/-2 std band and posterior samples.

    Args:
        ax: Matplotlib axes to draw on.
        x_sorted: 1-D prediction inputs, already sorted ascending.
        sort_idx: Permutation that sorts the prediction inputs. It is applied
            to every per-point array taken from ``pred``.
        pred: Predictive dict accepted by ``unpack_preds``.
        name: Kernel name used as the legend prefix.
        color: Matplotlib colour shared by all artists of this kernel.
    """
    mean, std, samples = unpack_preds(pred)
    mean = np.asarray(mean)[sort_idx]
    std = np.asarray(std)[sort_idx]

    ax.plot(x_sorted, mean, color=color, label=f"{name} prediction")
    # mean +/- 2 std covers ~95% of a Gaussian predictive distribution.
    ax.fill_between(
        x_sorted,
        mean - 2 * std,
        mean + 2 * std,
        color=color,
        alpha=0.2,
        label=f"{name} 95% confidence interval",
    )

    if samples is None:
        return
    # NOTE(review): assumes samples are laid out (n_points, n_draws), i.e. one
    # column per posterior draw - confirm against whatever produced the pickle.
    samples = np.asarray(samples)[sort_idx]
    if samples.ndim == 1:
        samples = samples[:, None]
    for i, draw in enumerate(samples.T):
        ax.plot(
            x_sorted,
            draw,
            color=color,
            alpha=0.1,
            linewidth=1,
            # Label only the first draw so the legend gets one entry per kernel.
            label=f"{name} posterior samples" if i == 0 else None,
        )


def plot_regression_with_uncertainty(
    data,
    X_pred,
    pred_rbf,
    pred_matern,
    pred_periodic,
    X_test=None,
    y_test=None,
    title=None,
    y_max=None,
):
    """Plot the RBF, Matern and periodic GP fits on a single axes.

    The plot shows the training data (and the test data when given). For each
    kernel it draws the predictive mean, a +/-2 standard-deviation band and the
    posterior sample paths. It also draws the reference function ``sin(x)``
    on [0, 8].

    Note: this definition shadows the ``plot_regression_with_uncertainty``
    imported from ``plotting`` at the top of the notebook.

    Args:
        data: Mapping with ``"input"`` (training inputs) and ``"target"``
            (training targets).
        X_pred: Inputs at which the three predictive distributions were
            evaluated.
        pred_rbf, pred_matern, pred_periodic: Predictive dicts as accepted by
            ``unpack_preds`` (keys ``"pred_mean"``, ``"pred_var"``,
            ``"samples"``). A kernel whose prediction is ``None`` is skipped.
        X_test, y_test: Optional held-out data, plotted only when both are
            given.
        title: Optional plot title; defaults to "Regression with Uncertainty".
        y_max: If given, the y-axis is limited to ``[-y_max, y_max]``.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    # np.asarray accepts both NumPy and JAX arrays, so no attribute sniffing
    # is needed to decide whether to convert.
    X_train = np.asarray(data["input"])
    y_train = np.asarray(data["target"])
    X_plot = np.asarray(X_pred)

    # Sort once so every line is drawn left to right instead of zig-zagging.
    sort_idx = np.argsort(X_plot.ravel())
    x_sorted = X_plot.ravel()[sort_idx]

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.scatter(X_train, y_train, color="blue", alpha=0.6, label="Training data")
    if X_test is not None and y_test is not None:
        ax.scatter(
            np.asarray(X_test),
            np.asarray(y_test),
            color="green",
            alpha=0.6,
            label="Test data",
        )

    kernels = (
        ("RBF", pred_rbf, "red"),
        ("Matern", pred_matern, "darkorange"),
        ("Periodic", pred_periodic, "purple"),
    )
    for name, kernel_pred, color in kernels:
        if kernel_pred is not None:
            _plot_kernel_prediction(ax, x_sorted, sort_idx, kernel_pred, name, color)

    # Reference function the data is compared against.
    x_true = np.linspace(0, 8, 1000).reshape(-1, 1)
    ax.plot(x_true, np.sin(x_true), color="black", linestyle="--", label="True function")

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(title if title else "Regression with Uncertainty")
    if y_max is not None:
        ax.set_ylim(bottom=-y_max, top=y_max)

    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
