In [None]:
from models.laplace import LaplaceBNN

: 

In [None]:
import numpy as np
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Import your LaplaceBNN class
from laplace_bnn import LaplaceBNN

# Generate synthetic data
def generate_data(n_train=20, n_test=100, noise_std=0.2, rng=None):
    """Build a 1-D noisy-sine regression dataset as float64 column tensors.

    Training inputs are ``n_train`` evenly spaced points on [-3, 3] with
    targets ``sin(x) + noise_std * eps`` where ``eps ~ N(0, 1)``. Test inputs
    are ``n_test`` evenly spaced points on the wider range [-3.5, 3.5] (so
    predictions slightly outside the training range are visible) with
    noise-free targets ``sin(x)``.

    Args:
        n_train: Number of training points.
        n_test: Number of test points.
        noise_std: Standard deviation of the Gaussian noise added to the
            training targets.
        rng: Optional random source exposing ``standard_normal`` (e.g.
            ``np.random.default_rng(seed)``) for reproducible noise. When
            ``None``, the global NumPy RNG is used, exactly as before, so
            ``np.random.seed`` still controls the draw.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)`` of ``torch.float64``
        tensors with shapes ``(n_train, 1)``, ``(n_train, 1)``,
        ``(n_test, 1)`` and ``(n_test, 1)``.
    """
    if rng is None:
        # Legacy global RNG: standard_normal(n) draws the same stream as randn(n).
        rng = np.random
    x_train = np.linspace(-3, 3, n_train)
    y_train = np.sin(x_train) + noise_std * rng.standard_normal(n_train)
    x_test = np.linspace(-3.5, 3.5, n_test)
    y_test = np.sin(x_test)
    # Trailing unsqueeze turns each 1-D array into an (n, 1) column tensor.
    return tuple(
        torch.tensor(arr, dtype=torch.float64).unsqueeze(-1)
        for arr in (x_train, y_train, x_test, y_test)
    )

# Visualize the training data, true function, predictive mean and interval
def _as_plot_array(values):
    """Return *values* as a flat 1-D NumPy array suitable for Plotly.

    Tensors are detached and moved to the CPU first, so inputs that still
    require grad or live on another device do not make ``.numpy()`` raise.
    ``reshape(-1)`` (rather than ``squeeze()``) keeps a single-point input
    1-D instead of collapsing it to a 0-d scalar.
    """
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu()
    return np.asarray(values).reshape(-1)


def plot_results(x_train, y_train, x_test, y_test, mean, lower, upper):
    """Plot the training data, true function, predictive mean and band.

    Args:
        x_train, y_train: Training inputs/targets (tensors or arrays; column
            vectors such as those from ``generate_data`` are flattened).
        x_test, y_test: Test grid and noise-free targets. The shaded band is
            drawn by walking ``x_test`` forward then backward, so it is
            assumed to be sorted.
        mean: Predictive mean on ``x_test``.
        lower, upper: Lower/upper edges of the interval on ``x_test``.

    Displays the figure with ``fig.show()``; returns ``None``.
    """
    x_train_np = _as_plot_array(x_train)
    y_train_np = _as_plot_array(y_train)
    x_test_np = _as_plot_array(x_test)
    y_test_np = _as_plot_array(y_test)
    mean_np = _as_plot_array(mean)
    lower_np = _as_plot_array(lower)
    upper_np = _as_plot_array(upper)

    fig = make_subplots(rows=1, cols=1)

    # True function
    fig.add_trace(
        go.Scatter(
            x=x_test_np,
            y=y_test_np,
            mode="lines",
            name="True Function",
            line=dict(dash="dash"),
        )
    )

    # Training data
    fig.add_trace(
        go.Scatter(
            x=x_train_np,
            y=y_train_np,
            mode="markers",
            name="Training Data",
        )
    )

    # Predictive mean
    fig.add_trace(
        go.Scatter(
            x=x_test_np,
            y=mean_np,
            mode="lines",
            name="Predictive Mean",
        )
    )

    # Predictive credible interval: a closed polygon that traces the upper
    # edge left-to-right and the lower edge right-to-left, filled by
    # fill="toself"; the outline itself is made transparent.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_test_np, x_test_np[::-1]]),
            y=np.concatenate([upper_np, lower_np[::-1]]),
            fill="toself",
            fillcolor="rgba(0,100,250,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            hoverinfo="skip",
            name="Credible Interval",
        )
    )

    fig.update_layout(
        title="Laplace-BNN Regression",
        xaxis_title="X",
        yaxis_title="Y",
        legend=dict(x=0, y=1.1, orientation="h"),
    )
    fig.show()


# Define the main script
def main():
    """Fit a LaplaceBNN to synthetic sine data and plot its predictive band.

    Generates the dataset with ``generate_data``, builds and trains a
    ``LaplaceBNN`` via ``fit_and_save`` (with ``save_dir="./"``), queries the
    posterior on the test grid, and plots the mean with a mean ± 2 std band.
    """
    # Generate data
    x_train, y_train, x_test, y_test = generate_data()

    # LaplaceBNN hyper-parameters; the meaning of each key is defined by the
    # LaplaceBNN class (not visible here).
    args = {
        "regnet_dims": [32, 32],
        "regnet_activation": torch.nn.ReLU,
        "prior_var": 1.0,
        "noise_var": 0.1,
        "iterative": False,
    }

    # Instantiate the model
    device = torch.device("cpu")
    model = LaplaceBNN(args, input_dim=1, output_dim=1, device=device)

    # Train the model (presumably writes artefacts under save_dir — confirm).
    model.fit_and_save(x_train, y_train, save_dir="./")

    # Predictive posterior. Detach because the returned tensors may still
    # require grad, and plotting converts them with .numpy(), which refuses
    # such tensors.
    posterior = model.posterior(x_test)
    mean = posterior.mean.detach()
    # Clamp tiny negative variances from floating-point round-off so the
    # square root cannot produce NaN.
    std_dev = posterior.variance.detach().clamp_min(0.0).sqrt()

    # Credible interval: mean ± 2*std_dev (≈95% if the predictive is Gaussian).
    lower = mean - 2 * std_dev
    upper = mean + 2 * std_dev

    # Visualize the results
    plot_results(x_train, y_train, x_test, y_test, mean, lower, upper)


if __name__ == "__main__":
    main()
