In [None]:
!pip install laplace-torch vbll > out.txt

# Laplace

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import plotly.graph_objects as go
from laplace import Laplace
from torch.utils.data import TensorDataset, DataLoader

# -------------------
# 1) Create a synthetic dataset
# -------------------
# We'll use a simple non-linear function: y = x^3 + noise

N_TRAIN = 1  # number of training points
N_TEST = 200   # number of test points

# Random seed for reproducibility
torch.manual_seed(42)

# Generate training data
x_train = torch.linspace(-3, 3, steps=N_TRAIN).unsqueeze(-1)
y_train = x_train**3 + 2.0 * torch.randn_like(x_train)  # add some noise

# Generate test data for evaluating and plotting
x_test = torch.linspace(-4, 4, steps=N_TEST).unsqueeze(-1)
y_true = x_test**3  # ground truth (no noise)

# Create DataLoader for training
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# -------------------
# 2) Define the MLP architecture
# -------------------
# 3 hidden layers, 128 units, tanh activation

class MLP(nn.Module):
    def __init__(
        self,
        input_dim=1,
        num_hidden_layer=3,
        hidden_dim=128,
        output_dim=1,
        activation="elu"
    ):
        super().__init__()
        self.activation = self._get_activation(activation)

        # Dynamically create the layers
        layers = []

        # Input layer
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(self.activation)

        # Hidden layers
        for _ in range(num_hidden_layer - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(self.activation)

        # Output layer
        layers.append(nn.Linear(hidden_dim, output_dim))

        # Combine all layers into a Sequential module
        self.net = nn.Sequential(*layers)

    def _get_activation(self, activation_name):
        if activation_name == 'relu':
            return nn.ReLU()
        elif activation_name == 'tanh':
            return nn.Tanh()
        elif activation_name == 'elu':
            return nn.ELU()
        elif activation_name == 'sigmoid':
            return nn.Sigmoid()
        else:
            raise ValueError(f"Unknown activation: {activation_name}")

    def forward(self, x):
        return self.net(x)

model = MLP()

# -------------------
# 3) Train the MLP
# -------------------
# We will use a standard MSE loss and Adam optimizer

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
criterion = nn.MSELoss()

n_epochs = 1500
for epoch in range(n_epochs):
    model.train()
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optimizer.step()

    if (epoch + 1) % 50 == 0:
        print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}")

# -------------------
# 4) Apply Last-Layer Laplace Approximation
# -------------------
la = Laplace(
    model,
    'regression',
    subset_of_weights='last_layer',
    hessian_structure='full',
    prior_precision=0.01,
    sigma_noise=1.0,
    temperature=1.0,
)
la.fit(train_loader)

# Optimize the prior precision
# Using the training set as a validation dataset (you can split data if needed)
# la.optimize_prior_precision(method='gridsearch', val_loader=DataLoader(TensorDataset(x_train, y_train), batch_size=32))

# Function to predict with Laplace
def predict_with_laplace(model_laplace, x):
    """Return predictive mean and variance given input x."""
    mean, var = model_laplace(x)
    mean = mean.detach().squeeze(-1)
    var = var.squeeze(1).squeeze(-1)
    return mean, var

model.eval()
with torch.no_grad():
    y_mean, y_var = predict_with_laplace(la, x_test)
y_std = torch.sqrt(y_var)

# -------------------
# 5) Visualization with Plotly
# -------------------
fig = go.Figure()

# True function
fig.add_trace(go.Scatter(
    x=x_test.squeeze().cpu().numpy(),
    y=y_true.squeeze().cpu().numpy(),
    mode='lines',
    name='True function (x^3)',
    line=dict(color='green', width=2)
))

# Predictive mean
fig.add_trace(go.Scatter(
    x=x_test.squeeze().cpu().numpy(),
    y=y_mean.cpu().numpy(),
    mode='lines',
    name='Predictive mean',
    line=dict(color='blue', width=2)
))

# Uncertainty band: mean +/- 2 std
fig.add_trace(go.Scatter(
    x=np.concatenate([x_test.squeeze().cpu().numpy(), x_test.squeeze().cpu().numpy()[::-1]]),
    y=np.concatenate([(y_mean - 2*y_std).cpu().numpy(), (y_mean + 2*y_std).cpu().numpy()[::-1]]),
    fill='toself',
    fillcolor='rgba(135,206,250,0.4)',  # light-blue shade
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo='skip',
    showlegend=True,
    name='Confidence (±2σ)'
))

# Training points
fig.add_trace(go.Scatter(
    x=x_train.squeeze().cpu().numpy(),
    y=y_train.squeeze().cpu().numpy(),
    mode='markers',
    name='Train data',
    marker=dict(color='red', size=4, opacity=0.8)
))

fig.update_layout(
    title='Last-layer Laplace MLP Regression',
    xaxis_title='x',
    yaxis_title='y',
    width=800,
    height=500,
    plot_bgcolor='white',
)

fig.show()

Epoch [50/1500], Loss: 2.7948
Epoch [100/1500], Loss: 0.0080
Epoch [150/1500], Loss: 0.0000
Epoch [200/1500], Loss: 0.0000
Epoch [250/1500], Loss: 0.0000
Epoch [300/1500], Loss: 0.0000
Epoch [350/1500], Loss: 0.0000
Epoch [400/1500], Loss: 0.0000
Epoch [450/1500], Loss: 0.0000
Epoch [500/1500], Loss: 0.0000
Epoch [550/1500], Loss: 0.0000
Epoch [600/1500], Loss: 0.0000
Epoch [650/1500], Loss: 0.0000
Epoch [700/1500], Loss: 0.0000
Epoch [750/1500], Loss: 0.0000
Epoch [800/1500], Loss: 0.0000
Epoch [850/1500], Loss: 0.0000
Epoch [900/1500], Loss: 0.0000
Epoch [950/1500], Loss: 0.0000
Epoch [1000/1500], Loss: 0.0000
Epoch [1050/1500], Loss: 0.0000
Epoch [1100/1500], Loss: 0.0000
Epoch [1150/1500], Loss: 0.0000
Epoch [1200/1500], Loss: 0.0000
Epoch [1250/1500], Loss: 0.0000
Epoch [1300/1500], Loss: 0.0000
Epoch [1350/1500], Loss: 0.0000
Epoch [1400/1500], Loss: 0.0000
Epoch [1450/1500], Loss: 0.0000
Epoch [1500/1500], Loss: 0.0000


In [12]:
mean, covariance = la(x_test, joint=True)
mean.shape, covariance.shape

(torch.Size([200]), torch.Size([200, 200]))

In [16]:
*mean.shape[:-1]

SyntaxError: can't use starred expression here (488112645.py, line 1)

# VBLL

In [None]:
import torch
import torch.nn as nn
import numpy as np
import plotly.graph_objects as go
from torch.utils.data import TensorDataset, DataLoader
import vbll

# -------------------
# 1) Create a synthetic dataset
# -------------------
# We'll use a simple non-linear function: y = x^3 + noise

N_TRAIN = 2  # number of training points
N_TEST = 200   # number of test points

# Random seed for reproducibility
torch.manual_seed(42)

# Generate training data
x_train = torch.linspace(-3, 3, steps=N_TRAIN).unsqueeze(-1)
y_train = x_train**3 + 2.0 * torch.randn_like(x_train)  # add some noise

# Generate test data for evaluating and plotting
x_test = torch.linspace(-4, 4, steps=N_TEST).unsqueeze(-1)
y_true = x_test**3  # ground truth (no noise)

# Create DataLoader for training
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# -------------------
# 2) Define VBLLMLP architecture
# -------------------
class VBLLMLP(nn.Module):
    def __init__(self, cfg):
        super(VBLLMLP, self).__init__()

        self.params = nn.ModuleDict({
            'in_layer': nn.Linear(cfg.IN_FEATURES, cfg.HIDDEN_FEATURES),
            'core': nn.ModuleList([
                nn.Linear(cfg.HIDDEN_FEATURES, cfg.HIDDEN_FEATURES) for _ in range(cfg.NUM_LAYERS)
            ]),
            'out_layer': vbll.Regression(
                cfg.HIDDEN_FEATURES,
                cfg.OUT_FEATURES,
                cfg.REG_WEIGHT,
                prior_scale=cfg.PRIOR_SCALE,
                wishart_scale=cfg.WISHART_SCALE
            )
        })

        self.activations = nn.ModuleList([nn.Tanh() for _ in range(cfg.NUM_LAYERS)])
        self.cfg = cfg

    def forward(self, x):
        x = self.params['in_layer'](x)
        for layer, activation in zip(self.params['core'], self.activations):
            x = activation(layer(x))
        return self.params['out_layer'](x)

# Configuration class
class Config:
    IN_FEATURES = 1
    HIDDEN_FEATURES = 128
    OUT_FEATURES = 1
    NUM_LAYERS = 3
    REG_WEIGHT = (1.0 / N_TRAIN)
    PARAM = 'dense'
    PRIOR_SCALE = 10
    WISHART_SCALE = 1

cfg = Config()
vbll_model = VBLLMLP(cfg)

# -------------------
# 3) Train the VBLL Model
# -------------------
class TrainConfig:
    NUM_EPOCHS = 2000
    BATCH_SIZE = 32
    LR = 1e-3
    WD = 0
    OPT = torch.optim.Adam
    CLIP_VAL = 1.0
    VAL_FREQ = 100

train_cfg = TrainConfig()

# Training function
def train_vbll(dataloader, model, train_cfg, verbose=True):
    param_list = [
        {'params': model.params['in_layer'].parameters(), 'weight_decay': train_cfg.WD},
        {'params': model.params['core'].parameters(), 'weight_decay': train_cfg.WD},
        {'params': model.params['out_layer'].parameters(), 'weight_decay': 0.0},
    ]

    optimizer = train_cfg.OPT(param_list, lr=train_cfg.LR)

    for epoch in range(train_cfg.NUM_EPOCHS + 1):
        model.train()
        running_loss = []

        for x_batch, y_batch in dataloader:
            optimizer.zero_grad()
            out = model(x_batch)
            loss = out.train_loss_fn(y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg.CLIP_VAL)
            optimizer.step()
            running_loss.append(loss.item())

        if epoch % train_cfg.VAL_FREQ == 0 and verbose:
            print(f"Epoch: {epoch}, Loss: {np.mean(running_loss):.4f}")

train_vbll(train_loader, vbll_model, train_cfg)

# -------------------
# 4) Make Predictions
# -------------------
@torch.no_grad()
def predict_vbll(model, x):
    pred_dist = model(x).predictive
    mean = pred_dist.mean.squeeze(-1)
    variance = pred_dist.variance.squeeze(-1)
    return mean, variance

vbll_model.eval()
y_mean, y_var = predict_vbll(vbll_model, x_test)
y_std = torch.sqrt(y_var)

# -------------------
# 5) Visualization with Plotly
# -------------------
fig = go.Figure()

# True function
fig.add_trace(go.Scatter(
    x=x_test.squeeze().cpu().numpy(),
    y=y_true.squeeze().cpu().numpy(),
    mode='lines',
    name='True function (x^3)',
    line=dict(color='green', width=2)
))

# Predictive mean
fig.add_trace(go.Scatter(
    x=x_test.squeeze().cpu().numpy(),
    y=y_mean.cpu().numpy(),
    mode='lines',
    name='Predictive mean',
    line=dict(color='blue', width=2)
))

# Uncertainty band: mean +/- 2 std
fig.add_trace(go.Scatter(
    x=np.concatenate([x_test.squeeze().cpu().numpy(), x_test.squeeze().cpu().numpy()[::-1]]),
    y=np.concatenate([(y_mean - 2 * y_std).cpu().numpy(), (y_mean + 2 * y_std).cpu().numpy()[::-1]]),
    fill='toself',
    fillcolor='rgba(135,206,250,0.4)',  # light-blue shade
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo='skip',
    showlegend=True,
    name='Confidence (±2σ)'
))

# Training points
fig.add_trace(go.Scatter(
    x=x_train.squeeze().cpu().numpy(),
    y=y_train.squeeze().cpu().numpy(),
    mode='markers',
    name='Train data',
    marker=dict(color='red', size=4, opacity=0.8)
))

fig.update_layout(
    title='VBLL MLP Regression',
    xaxis_title='x',
    yaxis_title='y',
    width=800,
    height=500,
    plot_bgcolor='white',
)

fig.show()


Epoch: 0, Loss: 913.0225
Epoch: 100, Loss: 392.6157
Epoch: 200, Loss: 335.5594
Epoch: 300, Loss: 287.7934
Epoch: 400, Loss: 246.8021
Epoch: 500, Loss: 211.2416
Epoch: 600, Loss: 180.3139
Epoch: 700, Loss: 153.5719
Epoch: 800, Loss: 130.6942
Epoch: 900, Loss: 112.1152
Epoch: 1000, Loss: 97.2979
Epoch: 1100, Loss: 85.8449
Epoch: 1200, Loss: 77.1717
Epoch: 1300, Loss: 70.6484
Epoch: 1400, Loss: 65.7498
Epoch: 1500, Loss: 62.0622
Epoch: 1600, Loss: 59.2813
Epoch: 1700, Loss: 57.1864
Epoch: 1800, Loss: 55.5869
Epoch: 1900, Loss: 54.3563
Epoch: 2000, Loss: 53.3990


In [None]:
temp = vbll_model(x_test)
temp

VBLLReturn(predictive=Normal(loc: torch.Size([200, 1]), scale: torch.Size([200, 1])), train_loss_fn=<function Regression._get_train_loss_fn.<locals>.loss_fn at 0x7a49041c9480>, val_loss_fn=<function Regression._get_val_loss_fn.<locals>.loss_fn at 0x7a47d4df4280>, ood_scores=None)

In [None]:
type(temp.predictive)

In [None]:

class train_cfg:
  NUM_EPOCHS = 1000
  BATCH_SIZE = 32
  LR = 1e-3
  WD = 1e-4
  OPT = torch.optim.AdamW
  CLIP_VAL = 1
  VAL_FREQ = 100

class cfg:
    IN_FEATURES = 1
    HIDDEN_FEATURES = 64
    OUT_FEATURES = 1
    NUM_LAYERS = 4
    REG_WEIGHT = 1./dataset.__len__()
    PARAM = 'dense'
    PRIOR_SCALE = 1.
    WISHART_SCALE = .1

In [None]:
import torch
import torch.nn as nn

class RegNet(nn.Sequential):
    def __init__(self, dimensions, activation, input_dim=1, output_dim=1,
                 dtype=torch.float64, device="cpu"):
        super(RegNet, self).__init__()

        # Combine input, hidden, and output dimensions
        self.dimensions = [input_dim, *dimensions, output_dim]

        # Loop to create layers
        for i in range(len(self.dimensions) - 1):
            # Add linear layer
            self.add_module(f'linear{i}', nn.Linear(
                self.dimensions[i], self.dimensions[i + 1], dtype=dtype, device=device
            ))

            # Add activation function for hidden layers
            if i < len(self.dimensions) - 2:
                if activation == "tanh":
                    self.add_module(f'tanh{i}', nn.Tanh())
                elif activation == "relu":
                    self.add_module(f'relu{i}', nn.ReLU())
                elif activation == "elu":
                    self.add_module(f'elu{i}', nn.ELU())
                else:
                    raise NotImplementedError(f"Activation type '{activation}' is not supported.")

# Example usage
if __name__ == "__main__":
    # Create a RegNet with specific parameters
    model = RegNet(
        dimensions=[32, 64, 32],  # Hidden layers with 32, 64, and 32 units
        activation="relu",       # Use ReLU activation
        input_dim=10,            # Input has 10 features
        output_dim=1,            # Single output for regression
        dtype=torch.float32,     # Use 32-bit precision
        device="cpu"            # Run on CPU
    )

    # Print the model architecture
    print(model)


In [None]:
import torch
import torch.nn as nn
from typing import Union


class MLP(nn.Sequential):
    def __init__(
        self,
        dimensions: list[int],
        activation: str,
        input_dim: int = 1,
        output_dim: int = 1,
        dtype: torch.dtype = torch.float64,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """
        A Multi-Layer Perceptron (MLP) model.

        Args:
            dimensions (list[int]): List of hidden layer dimensions.
            activation (str): Activation function ('tanh', 'relu', 'elu').
            input_dim (int): Dimension of the input.
            output_dim (int): Dimension of the output.
            dtype (torch.dtype): Data type for the model's parameters.
            device (Union[str, torch.device]): Device for the model.
        """
        super(MLP, self).__init__()

        # Combine input, hidden, and output dimensions
        self.dimensions: list[int] = [input_dim, *dimensions, output_dim]

        # Loop to create layers
        for i in range(len(self.dimensions) - 1):
            # Add linear layer
            self.add_module(
                f"linear{i}",
                nn.Linear(
                    self.dimensions[i],
                    self.dimensions[i + 1],
                    dtype=dtype,
                    device=device,
                ),
            )

            # Add activation function for hidden layers
            if i < len(self.dimensions) - 2:
                if activation == "tanh":
                    self.add_module(f"tanh{i}", nn.Tanh())
                elif activation == "relu":
                    self.add_module(f"relu{i}", nn.ReLU())
                elif activation == "elu":
                    self.add_module(f"elu{i}", nn.ELU())
                else:
                    raise NotImplementedError(
                        f"Activation type '{activation}' is not supported."
                    )

In [None]:
import vbll
from typing import Union


class VBLLMLP(MLP):
    def __init__(
        self,
        dimensions: list[int],
        activation: str,
        input_dim: int = 1,
        output_dim: int = 1,
        dtype: torch.dtype = torch.float64,
        device: Union[str, torch.device] = "cpu",
        reg_weight: float = 1.0,
        prior_scale: float = 1.0,
        wishart_scale: float = 0.1,
    ) -> None:
        """
        A Multi-Layer Perceptron (MLP) with a VBLL last layer.

        Args:
            dimensions (list[int]): List of hidden layer dimensions.
            activation (str): Activation function ('tanh', 'relu', 'elu').
            input_dim (int): Dimension of the input.
            output_dim (int): Dimension of the output.
            dtype (torch.dtype): Data type for the model's parameters.
            device (Union[str, torch.device]): Device for the model.
            reg_weight (float): Regularization weight for VBLL.
            prior_scale (float): Prior scale for VBLL.
            wishart_scale (float): Wishart scale for VBLL.
        """
        super().__init__(dimensions, activation, input_dim, output_dim, dtype, device)

        # Replace the last layer with VBLL regression
        last_layer_input_dim = self.dimensions[-2]
        self.out_layer = vbll.Regression(
            last_layer_input_dim,
            output_dim,
            reg_weight,
            prior_scale=prior_scale,
            wishart_scale=wishart_scale,
        )

    def forward(self, x: torch.Tensor) -> vbll.VBLLReturn:
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            vbll.VBLLReturn: Predictive distribution and loss functions from VBLL.
        """
        for name, module in self.named_children():
            # Pass through all layers except the last VBLL layer
            if name != "out_layer":
                x = module(x)

        # Pass through the VBLL regression layer
        return self.out_layer(x)


In [None]:
import os
from typing import Any, Callable, List, Optional

import numpy as np
import torch
from botorch.models.model import Model
from botorch.posteriors import Posterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch import distributions as gdists
from laplace import Laplace
from torch import Tensor

from .nn_utils import RegNet, get_best_hyperparameters
from .nn_utils import augmented_and_regularized_trimmed_loss
# from .nn_utils import EarlyStopping
from early_stopping_pytorch import EarlyStopping


class LaplacePosterior(Posterior):
    def __init__(self, posterior, output_dim):
        super().__init__()
        self.post = posterior
        self.output_dim = output_dim

    def rsample(
        self,
        sample_shape: Optional[torch.Size] = None,
    ) -> Tensor:
        samples = self.post.rsample(sample_shape).squeeze(-1)
        new_shape = samples.shape[:-1]
        return samples.reshape(*new_shape, -1, self.output_dim)

    @property
    def mean(self) -> Tensor:
        r"""The posterior mean."""
        post_mean = self.post.mean.squeeze(-1)
        shape = post_mean.shape
        return post_mean.reshape(*shape[:-1], -1, self.output_dim)

    @property
    def variance(self) -> Tensor:
        r"""The posterior variance."""
        post_var = self.post.variance.squeeze(-1)
        shape = post_var.shape
        return post_var.reshape(*shape[:-1], -1, self.output_dim)

    @property
    def device(self) -> torch.device:
        return self.post.device

    @property
    def dtype(self) -> torch.dtype:
        r"""The torch dtype of the distribution."""
        return self.post.dtype


class LaplaceBNN(Model):
    def __init__(
        self,
        args: dict,
        input_dim,
        output_dim,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float64
    ):
        super().__init__()
        self.likelihood = "regression"
        self.regnet_dims = args.get("regnet_dims", [128, 128, 128])
        self.regnet_activation = args.get("regnet_activation", "tanh")
        self.prior_var = args.get("prior_var", 10.0)
        self.noise_var = args.get("noise_var", 1.0)
        self.iterative = args.get("iterative", True)
        self.loss_params = args.get("loss_params", {})
        self.nn = RegNet(
            dimensions=self.regnet_dims,
            activation=self.regnet_activation,
            input_dim=input_dim,
            output_dim=output_dim,
            dtype=dtype,
            device=device
        )
        self.bnn = None
        self.output_dim = output_dim

    def posterior_predictive(self, X, bnn):
        if len(X.shape) < 3:
            B, D = X.shape
            Q = 1
        else:
            # Transform to `(batch_shape*q, d)`
            B, Q, D = X.shape
            X = X.reshape(B * Q, D)

        K = self.num_outputs
        # Posterior predictive distribution
        mean_y, cov_y = self._get_prediction(X, bnn)

        # Reshape mean
        mean_y = mean_y.reshape(B, Q * K)

        # Reshape covariance
        cov_y += 1e-4 * torch.eye(B * Q * K).to(X)
        cov_y = cov_y.reshape(B, Q, K, B, Q, K)
        cov_y = torch.einsum('bqkbrl->bqkrl', cov_y)  # (B, Q, K, Q, K)
        cov_y = cov_y.reshape(B, Q * K, Q * K)

        dist = gdists.MultivariateNormal(mean_y, covariance_matrix=cov_y)
        post_pred = GPyTorchPosterior(dist)

        # Return a custom LaplacePosterior if multiple outputs in a batched scenario
        if K > 1 and Q > 1:
            return LaplacePosterior(post_pred, self.output_dim)
        else:
            return post_pred

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[List[int]] = None,
        observation_noise: bool = False,
        posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
        **kwargs: Any,
    ) -> Posterior:
        return self.posterior_predictive(X, self.bnn)

    @property
    def num_outputs(self) -> int:
        return self.output_dim

    def _get_prediction(self, test_x: torch.Tensor, bnn):
        """
        Batched Laplace prediction.

        Args:
            test_x: Tensor of size `(batch_shape, d)`.

        Returns:
            Tuple of (mean, cov) with shapes:
            - mean: `(batch_shape, k)`
            - cov: `(batch_shape*k, batch_shape*k)`
        """
        mean_y, cov_y = bnn(test_x, joint=True)
        return mean_y, cov_y

    def get_likelihood(self, train_x, train_y, prior_var, noise_var):
        # fit to 80% of the data, and evaluate on the rest
        n = len(train_x)
        n_train = int(0.8 * n)
        train_x, val_x = train_x[:n_train], train_x[n_train:]
        train_y, val_y = train_y[:n_train], train_y[n_train:]
        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(train_x, train_y),
            batch_size=min(32, len(train_x)),  # smaller batch size
            shuffle=True
        )

        model = self.fit_laplace(train_loader, prior_var, noise_var)
        posterior = self.posterior_predictive(val_x, model)

        predictions_mean = posterior.mean
        predictions_std = torch.sqrt(posterior.variance + self.noise_var)
        # get log likelihood
        likelihood = torch.distributions.Normal(predictions_mean, predictions_std).log_prob(val_y).sum()
        return likelihood

    def fit_laplace(self, train_loader, prior_var, noise_var):
        """
        Fit a Laplace approximation on the last layer of the neural network.
        """
        bnn = Laplace(
            self.nn,
            self.likelihood,
            sigma_noise=np.sqrt(noise_var),
            prior_precision=(1 / prior_var),
            subset_of_weights='last_layer',
            hessian_structure='full',
            enable_backprop=True
        )
        bnn.fit(train_loader)
        bnn.optimize_prior_precision(n_steps=50)

        return bnn

    def fit(self, train_x, original_train_y, model_param_path=None):
        """
        Train the neural network (MSE + optional ARTL) and optionally fit Laplace approximation.
        """
        # Use a smaller batch_size to avoid training with the entire dataset at once
        batch_size = min(32, len(train_x))
        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(train_x, original_train_y),
            batch_size=batch_size,
            shuffle=True
        )

        n_epochs = self.loss_params.get("n_epochs", 10000)
        lr = self.loss_params.get("lr", 1e-2)
        weight_decay = self.loss_params.get("weight_decay", 0)
        momentum = self.loss_params.get("momentum", 0)
        artl_weight = self.loss_params.get("artl_weight", 0)  # Weight for ARTL loss
        h = self.loss_params.get("h", int(0.9 * len(train_x)))
        lambd = self.loss_params.get("lambd", 1e-3)
        k = self.loss_params.get("k", (1, 2, 3))
        q = self.loss_params.get("q", 2)
        M = self.loss_params.get("M", 10)

        # Load model state if provided and the file exists
        if model_param_path and os.path.isfile(model_param_path):
            try:
                model_state = torch.load(model_param_path, weights_only=True)
                self.nn.load_state_dict(model_state)
            except EOFError:
                # File is empty or corrupted
                model_state = None

        # Use a simple SGD without scheduling
        optimizer = torch.optim.SGD(
            self.nn.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum
        )
        mse_loss_func = torch.nn.MSELoss()

        early_stopping = EarlyStopping(patience=1000, verbose=True, path=model_param_path)

        for epoch in range(n_epochs):
            for x, y in train_loader:
                optimizer.zero_grad()

                # Compute MSE loss
                mse_loss = mse_loss_func(self.nn(x), y)

                # Compute ARTL loss if needed
                if artl_weight != 0:
                    artl_loss_val = augmented_and_regularized_trimmed_loss(
                        model=self.nn,
                        X_tensor=x,
                        y_tensor=y,
                        h=h,
                        lambd=lambd,
                        k=k,
                        q=q,
                        M=M
                    )
                else:
                    artl_loss_val = 0

                # Combine both losses
                total_loss = mse_loss + artl_weight * artl_loss_val

                # Backpropagation
                total_loss.backward()
                optimizer.step()

            # Logging
            mse_loss_value = mse_loss.item()
            artl_loss_value = (
                artl_loss_val.item() if isinstance(artl_loss_val, torch.Tensor) else artl_loss_val
            )

            print(f"Epoch {epoch+1}/{n_epochs}: MSE Loss: {mse_loss_value}, ARTL Loss: {artl_loss_value}")

            # Use training loss as "validation" loss for early stopping
            # val_loss = total_loss.item()
            val_loss = mse_loss_value
            early_stopping(val_loss, self.nn)

            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

        # Reload the best model weights
        self.nn.load_state_dict(torch.load(model_param_path, weights_only=True))
        self.nn.eval()

        # Fit Laplace approximation if iterative
        if self.iterative:
            llh_fn = self.get_likelihood
            self.prior_var, self.noise_var = get_best_hyperparameters(
                train_x, original_train_y, llh_fn
            )

        self.bnn = self.fit_laplace(train_loader, self.prior_var, self.noise_var)
