In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.spatial.distance import cdist
import plotly.graph_objects as go

# 1. Define Distance-Aware Priors
class DistanceAwarePrior:
    """Prior whose standard deviation grows with distance from the training inputs.

    The standard deviation at a point is ``sigma0 + exp(phi) * d``, where ``d`` is
    the Euclidean distance from that point to its nearest training input.
    """

    def __init__(self, sigma0):
        self.sigma0 = sigma0  # Base standard deviation (the value at distance 0)

    def compute_distance(self, x_train, x_val):
        """Return the nearest-neighbour distance from each validation row to the training set.

        The identity map is used as the feature extractor, i.e. distances are
        measured directly in input space.

        Args:
            x_train: 2-D tensor whose rows are training inputs.
            x_val: 2-D tensor whose rows are the query inputs.

        Returns:
            A 1-D float32 CPU tensor with one distance per row of ``x_val``.
        """
        # detach() first: calling .numpy() on a tensor that tracks gradients raises.
        x_train_np = x_train.detach().cpu().numpy()
        x_val_np = x_val.detach().cpu().numpy()
        # cdist yields a (num_val, num_train) matrix; reduce over the training axis.
        distances = np.min(cdist(x_val_np, x_train_np, 'euclidean'), axis=1)
        return torch.tensor(distances, dtype=torch.float32)

    def distance_aware_variance(self, distances, phi):
        """Return the distance-aware variance ``(sigma0 + exp(phi) * distances) ** 2``.

        Args:
            distances: Tensor of nearest-neighbour distances.
            phi: Log-scale tensor; ``exp(phi)`` multiplies the distances.
        """
        return (self.sigma0 + torch.exp(phi) * distances).pow(2)

# 2. Importance Sampling for Posterior Predictive
class ImportanceSamplingPredictor:
    """Re-weight predictions from a set of posterior samples with distance-based weights."""

    def __init__(self, posterior_samples, prior_sigma0):
        # Iterable of state dicts, each loadable into the model passed to
        # predictive_mean_and_variance.
        self.posterior_samples = posterior_samples
        self.prior_sigma0 = prior_sigma0  # Standard deviation of the reference prior

    def compute_importance_weights(self, distance_variances):
        """Return one weight per entry of ``distance_variances``.

        The weight is ``exp(-0.5 * (1 / distance_variances - 1 / prior_sigma0**2))``,
        so it equals 1 wherever the distance-aware variance matches the reference
        prior variance and exceeds 1 where the distance-aware variance is larger.
        """
        prior_variances = self.prior_sigma0**2
        weights = torch.exp(-0.5 * (1 / distance_variances - 1 / prior_variances))
        return weights

    def predictive_mean_and_variance(self, x_val, model, weights):
        """Estimate a weighted predictive mean and variance over the posterior samples.

        Args:
            x_val: Inputs passed to ``model``.
            model: Module that receives each state dict in turn. It is left
                holding the last sample's parameters when this method returns.
            weights: 1-D tensor with one weight per validation input.

        Returns:
            ``(mean, variance)``, each a 1-D tensor with one entry per validation
            input. The weights scale the per-input averages; they are not
            normalised.
        """
        preds = []
        for theta in self.posterior_samples:  # Iterate through posterior samples
            model.load_state_dict(theta)
            with torch.no_grad():
                # reshape(-1) rather than squeeze(): squeeze() would reduce a single
                # validation point to a 0-dim tensor and break the broadcasting below.
                preds.append(model(x_val).reshape(-1))
        preds = torch.stack(preds, dim=0)  # Shape: [num_samples, num_inputs]
        weighted_preds = preds * weights.unsqueeze(0)
        mean = weighted_preds.mean(dim=0)  # Weighted mean across samples
        variance = ((preds - mean.unsqueeze(0)) ** 2 * weights.unsqueeze(0)).mean(dim=0)
        return mean, variance


# 3. Calibration of Distance-Aware Priors
class DAPCalibration:
    """Tune the log-scale ``phi`` of a distance-aware prior with Adam.

    ``phi`` starts at 0 and is the only quantity optimised; the model is stored
    but not used by the calibration itself.
    """

    def __init__(self, model, prior_sigma0, x_train, x_val):
        self.model = model
        self.x_train, self.x_val = x_train, x_val
        self.prior = DistanceAwarePrior(prior_sigma0)
        # Scalar leaf tensor so the optimiser can update it in place.
        self.phi = torch.tensor(0.0, requires_grad=True)

    def calibration_loss(self, distances, target_uncertainty=1.0):
        """Mean squared gap between the prior standard deviation and the target."""
        prior_std = torch.sqrt(self.prior.distance_aware_variance(distances, self.phi))
        gap = prior_std - target_uncertainty
        return gap.pow(2).mean()

    def calibrate(self, target_uncertainty=1.0, lr=0.01, epochs=100):
        """Run ``epochs`` Adam steps on ``phi`` and return the detached result.

        Progress is printed every tenth epoch, followed by a final summary line.
        """
        # Distances do not depend on phi, so they are computed once up front.
        nn_distances = self.prior.compute_distance(self.x_train, self.x_val)
        adam = optim.Adam([self.phi], lr=lr)

        for step in range(epochs):
            adam.zero_grad()
            current = self.calibration_loss(nn_distances, target_uncertainty)
            current.backward()
            adam.step()
            if step % 10 != 0:
                continue
            print(f"Epoch {step}: Loss = {current.item():.4f}, Phi = {self.phi.item():.4f}")

        print(f"Calibration complete. Optimal Phi = {self.phi.item():.4f}")
        return self.phi.detach()

# 4. Example Pipeline
if __name__ == "__main__":
    # End-to-end demo: build toy data, calibrate phi, re-weight predictions, plot.

    # Toy Example Setup
    # Small MLP (1 -> 10 -> 1, tanh) standing in for a Bayesian neural network.
    class SimpleBNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))

        def forward(self, x):
            return self.fc(x)

    # Generate Toy 1D Data
    # The seed is set before any random draw, so the noise and the model
    # initialisation below are reproducible.
    torch.manual_seed(42)
    # 20 training inputs in three separated clusters (5 + 10 + 5 points).
    x_train = torch.cat([
        torch.linspace(-3, -2, 5),  # Imbalanced regions
        torch.linspace(0, 1, 10),
        torch.linspace(2, 3, 5)
    ]).unsqueeze(1)
    # Targets: sin(x) plus Gaussian noise scaled by 0.1.
    y_train = torch.sin(x_train) + 0.1 * torch.randn_like(x_train)
    # 100 evenly spaced query points on [-6, 6], extending beyond the training range.
    x_val = torch.linspace(-6, 6, 100).unsqueeze(1)  # Out-of-distribution points

    # Initialize BNN
    # The network is never trained in this cell; it keeps its initial weights.
    bnn = SimpleBNN()

    # Generate fake posterior samples (e.g., variational inference or MCMC)
    # state_dict() is called ten times on the same unchanged model, so all ten
    # entries hold identical parameters.
    posterior_samples = [bnn.state_dict() for _ in range(10)]

    # Step 1: Calibrate Distance-Aware Priors
    dap = DAPCalibration(bnn, prior_sigma0=1.0, x_train=x_train, x_val=x_val)
    phi_opt = dap.calibrate(target_uncertainty=1.0)

    # Step 2: Apply Importance Sampling for Predictions
    predictor = ImportanceSamplingPredictor(posterior_samples, prior_sigma0=1.0)
    # Distances are recomputed here; calibrate() does not return them.
    distances = dap.prior.compute_distance(x_train, x_val)
    distance_variances = dap.prior.distance_aware_variance(distances, phi_opt)
    weights = predictor.compute_importance_weights(distance_variances)
    predictive_mean, predictive_variance = predictor.predictive_mean_and_variance(x_val, bnn, weights)



    # Convert predictions and intervals to NumPy
    x_val_np = x_val.squeeze().numpy()
    x_train_np = x_train.squeeze().numpy()
    y_train_np = y_train.squeeze().numpy()
    pred_mean_np = predictive_mean.detach().numpy()
    pred_var_np = predictive_variance.detach().numpy()

    # Compute confidence intervals
    # mean +/- 1.96 standard deviations; plotted below as the 95% band.
    lower = pred_mean_np - 1.96 * np.sqrt(pred_var_np)
    upper = pred_mean_np + 1.96 * np.sqrt(pred_var_np)

    # Plot using Plotly
    fig = go.Figure()

    # Training data points
    fig.add_trace(
        go.Scatter(
            x=x_train_np,
            y=y_train_np,
            mode="markers",
            name="Training Data",
            marker=dict(color="blue"),
        )
    )

    # Predictive mean
    fig.add_trace(
        go.Scatter(
            x=x_val_np,
            y=pred_mean_np,
            mode="lines",
            name="Predictive Mean",
            line=dict(color="red"),
        )
    )

    # Confidence interval as a filled band
    # The polygon runs left-to-right along the upper bound and back along the
    # reversed lower bound so that "toself" fills the region between them.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_val_np, x_val_np[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor="rgba(255, 0, 0, 0.2)",
            line=dict(width=0),
            name="95% Confidence Interval",
        )
    )

    # Layout settings
    fig.update_layout(
        title="Distance-Aware Prior Calibration",
        xaxis_title="Input x",
        yaxis_title="Output y",
        legend_title="Legend",
    )

    fig.show()


Epoch 0: Loss = 1.6392, Phi = -0.0100
Epoch 10: Loss = 1.3439, Phi = -0.1091
Epoch 20: Loss = 1.1100, Phi = -0.2042
Epoch 30: Loss = 0.9280, Phi = -0.2930
Epoch 40: Loss = 0.7870, Phi = -0.3747
Epoch 50: Loss = 0.6769, Phi = -0.4494
Epoch 60: Loss = 0.5898, Phi = -0.5177
Epoch 70: Loss = 0.5197, Phi = -0.5804
Epoch 80: Loss = 0.4624, Phi = -0.6383
Epoch 90: Loss = 0.4149, Phi = -0.6922
Calibration complete. Optimal Phi = -0.7376


In [23]:
if __name__ == "__main__":
    # Variant of the demo that first trains the network with an MSE loss and
    # reports the validation MSE against the noiseless target.

    # Toy Example Setup
    # Deeper MLP: three hidden layers of width 128 with tanh activations.
    class SimpleBNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(1, 128),
                nn.Tanh(),
                nn.Linear(128, 128),
                nn.Tanh(),
                nn.Linear(128, 128),
                nn.Tanh(),
                nn.Linear(128, 1)
            )

        def forward(self, x):
            return self.fc(x)

    # Generate Toy 1D Data
    torch.manual_seed(42)
    # 10 training inputs in three separated clusters (2 + 5 + 3 points).
    x_train = torch.cat([
        torch.linspace(-3, -2, 2),  # Imbalanced regions
        torch.linspace(0, 1, 5),
        torch.linspace(2, 3, 3)
    ]).unsqueeze(1)
    # Targets: sin(x) plus Gaussian noise scaled by 0.1.
    y_train = torch.sin(x_train) + 0.1 * torch.randn_like(x_train)
    x_val = torch.linspace(-6, 6, 100).unsqueeze(1)  # Out-of-distribution points
    y_val = torch.sin(x_val)  # True function values for validation data

    # Initialize BNN
    bnn = SimpleBNN()

    # Step 1: Train the BNN with MSE Loss
    # Full-batch training: every step uses all of x_train.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(bnn.parameters(), lr=0.01)
    epochs = 50

    for epoch in range(epochs):
        bnn.train()
        optimizer.zero_grad()
        y_pred = bnn(x_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # With epochs = 50 this condition only holds at epoch 0.
        if epoch % 50 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item():.4f}")

    # Step 2: Generate posterior samples (e.g., variational inference or MCMC)
    # Placeholder: state_dict() is called ten times on the same trained model,
    # so all ten entries hold identical parameters.
    posterior_samples = [bnn.state_dict() for _ in range(10)]

    # Step 3: Calibrate Distance-Aware Priors
    dap = DAPCalibration(bnn, prior_sigma0=1., x_train=x_train, x_val=x_val)
    # The mean of the noisy training targets is passed as the calibration target.
    phi_opt = dap.calibrate(target_uncertainty=y_train.mean())

    # Step 4: Apply Importance Sampling for Predictions
    predictor = ImportanceSamplingPredictor(posterior_samples, prior_sigma0=1.0)
    distances = dap.prior.compute_distance(x_train, x_val)
    distance_variances = dap.prior.distance_aware_variance(distances, phi_opt)
    weights = predictor.compute_importance_weights(distance_variances)
    predictive_mean, predictive_variance = predictor.predictive_mean_and_variance(x_val, bnn, weights)

    # Convert predictions and intervals to NumPy
    x_val_np = x_val.squeeze().numpy()
    x_train_np = x_train.squeeze().numpy()
    y_train_np = y_train.squeeze().numpy()
    y_val_np = y_val.squeeze().numpy()
    pred_mean_np = predictive_mean.detach().numpy()
    pred_var_np = predictive_variance.detach().numpy()

    # Compute confidence intervals
    # mean +/- 1.96 standard deviations; plotted below as the 95% band.
    lower = pred_mean_np - 1.96 * np.sqrt(pred_var_np)
    upper = pred_mean_np + 1.96 * np.sqrt(pred_var_np)

    # Compute MSE for validation
    # Compared against the noiseless sin(x) values, over the full [-6, 6] grid.
    mse = np.mean((y_val_np - pred_mean_np) ** 2)
    print(f"Validation Mean Squared Error (MSE): {mse:.4f}")

    # Plot using Plotly
    fig = go.Figure()

    # Training data points
    fig.add_trace(
        go.Scatter(
            x=x_train_np,
            y=y_train_np,
            mode="markers",
            name="Training Data",
            marker=dict(color="blue"),
        )
    )

    # Predictive mean
    fig.add_trace(
        go.Scatter(
            x=x_val_np,
            y=pred_mean_np,
            mode="lines",
            name="Predictive Mean",
            line=dict(color="red"),
        )
    )

    # Confidence interval as a filled band
    # Upper bound left-to-right, then the reversed lower bound, closes the polygon.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_val_np, x_val_np[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor="rgba(255, 0, 0, 0.2)",
            line=dict(width=0),
            name="95% Confidence Interval",
        )
    )

    # Layout settings
    fig.update_layout(
        title="Distance-Aware Prior Calibration with MSE Training",
        xaxis_title="Input x",
        yaxis_title="Output y",
        legend_title="Legend",
    )

    fig.show()


Epoch 0/50, Loss: 0.2979
Epoch 0: Loss = 3.4007, Phi = -0.0100
Epoch 10: Loss = 2.9836, Phi = -0.1092
Epoch 20: Loss = 2.6416, Phi = -0.2051
Epoch 30: Loss = 2.3651, Phi = -0.2956
Epoch 40: Loss = 2.1422, Phi = -0.3801
Epoch 50: Loss = 1.9613, Phi = -0.4585
Epoch 60: Loss = 1.8130, Phi = -0.5313
Epoch 70: Loss = 1.6896, Phi = -0.5992
Epoch 80: Loss = 1.5856, Phi = -0.6628
Epoch 90: Loss = 1.4969, Phi = -0.7226
Calibration complete. Optimal Phi = -0.7736
Validation Mean Squared Error (MSE): 0.8273


In [66]:
if __name__ == "__main__":
    # Variant of the demo with only two training points and a longer training run.

    # Toy Example Setup
    # Deeper MLP: three hidden layers of width 128 with tanh activations.
    class SimpleBNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(
                nn.Linear(1, 128),
                nn.Tanh(),
                nn.Linear(128, 128),
                nn.Tanh(),
                nn.Linear(128, 128),
                nn.Tanh(),
                nn.Linear(128, 1)
            )

        def forward(self, x):
            return self.fc(x)

    # Generate Toy 1D Data
    # The seed call is disabled, so the noise and initial weights differ per run.
    # torch.manual_seed(42)
    # Only the middle cluster is active: two training inputs, x = 0 and x = 1.
    x_train = torch.cat([
        # torch.linspace(-3, -2, 2),  # Imbalanced regions
        torch.linspace(0, 1, 2),
        # torch.linspace(2, 3, 3)
    ]).unsqueeze(1)
    # Targets: sin(x) plus Gaussian noise scaled by 0.1.
    y_train = torch.sin(x_train) + 0.1 * torch.randn_like(x_train)
    x_val = torch.linspace(-6, 6, 100).unsqueeze(1)  # Out-of-distribution points
    y_val = torch.sin(x_val)  # True function values for validation data

    # Initialize BNN
    bnn = SimpleBNN()

    # Step 1: Train the BNN with MSE Loss
    # Full-batch training: every step uses all of x_train.
    criterion = nn.MSELoss()
    optimizer = optim.Adam(bnn.parameters(), lr=0.01)
    epochs = 500

    for epoch in range(epochs):
        bnn.train()
        optimizer.zero_grad()
        y_pred = bnn(x_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        # Progress line every 50 epochs.
        if epoch % 50 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item():.4f}")

    # Step 2: Generate posterior samples (e.g., variational inference or MCMC)
    # Placeholder: state_dict() is called ten times on the same trained model,
    # so all ten entries hold identical parameters.
    posterior_samples = [bnn.state_dict() for _ in range(10)]

    # Step 3: Calibrate Distance-Aware Priors
    dap = DAPCalibration(bnn, prior_sigma0=1., x_train=x_train, x_val=x_val)
    # The mean of the noisy training targets is passed as the calibration target.
    phi_opt = dap.calibrate(target_uncertainty=y_train.mean())

    # Step 4: Apply Importance Sampling for Predictions
    predictor = ImportanceSamplingPredictor(posterior_samples, prior_sigma0=1.0)
    distances = dap.prior.compute_distance(x_train, x_val)
    distance_variances = dap.prior.distance_aware_variance(distances, phi_opt)
    weights = predictor.compute_importance_weights(distance_variances)
    predictive_mean, predictive_variance = predictor.predictive_mean_and_variance(x_val, bnn, weights)

    # Convert predictions and intervals to NumPy
    x_val_np = x_val.squeeze().numpy()
    x_train_np = x_train.squeeze().numpy()
    y_train_np = y_train.squeeze().numpy()
    y_val_np = y_val.squeeze().numpy()
    pred_mean_np = predictive_mean.detach().numpy()
    pred_var_np = predictive_variance.detach().numpy()

    # Compute confidence intervals
    # mean +/- 1.96 standard deviations; plotted below as the 95% band.
    lower = pred_mean_np - 1.96 * np.sqrt(pred_var_np)
    upper = pred_mean_np + 1.96 * np.sqrt(pred_var_np)

    # Compute MSE for validation
    # Compared against the noiseless sin(x) values, over the full [-6, 6] grid.
    mse = np.mean((y_val_np - pred_mean_np) ** 2)
    print(f"Validation Mean Squared Error (MSE): {mse:.4f}")

    # Plot using Plotly
    fig = go.Figure()

    # Training data points
    fig.add_trace(
        go.Scatter(
            x=x_train_np,
            y=y_train_np,
            mode="markers",
            name="Training Data",
            marker=dict(color="blue"),
        )
    )

    # Predictive mean
    fig.add_trace(
        go.Scatter(
            x=x_val_np,
            y=pred_mean_np,
            mode="lines",
            name="Predictive Mean",
            line=dict(color="red"),
        )
    )

    # Confidence interval as a filled band
    # Upper bound left-to-right, then the reversed lower bound, closes the polygon.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_val_np, x_val_np[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor="rgba(255, 0, 0, 0.2)",
            line=dict(width=0),
            name="95% Confidence Interval",
        )
    )

    # Layout settings
    fig.update_layout(
        title="Distance-Aware Prior Calibration with MSE Training",
        xaxis_title="Input x",
        yaxis_title="Output y",
        legend_title="Legend",
    )

    fig.show()


Epoch 0/500, Loss: 0.3256
Epoch 50/500, Loss: 0.0032
Epoch 100/500, Loss: 0.0000
Epoch 150/500, Loss: 0.0000
Epoch 200/500, Loss: 0.0000
Epoch 250/500, Loss: 0.0000
Epoch 300/500, Loss: 0.0000
Epoch 350/500, Loss: 0.0000
Epoch 400/500, Loss: 0.0000
Epoch 450/500, Loss: 0.0000
Epoch 0: Loss = 13.5293, Phi = -0.0100
Epoch 10: Loss = 11.4606, Phi = -0.1092
Epoch 20: Loss = 9.7904, Phi = -0.2047
Epoch 30: Loss = 8.4634, Phi = -0.2945
Epoch 40: Loss = 7.4124, Phi = -0.3777
Epoch 50: Loss = 6.5742, Phi = -0.4544
Epoch 60: Loss = 5.8971, Phi = -0.5251
Epoch 70: Loss = 5.3422, Phi = -0.5907
Epoch 80: Loss = 4.8806, Phi = -0.6518
Epoch 90: Loss = 4.4914, Phi = -0.7089
Calibration complete. Optimal Phi = -0.7574
Validation Mean Squared Error (MSE): 0.9211


In [None]:
from typing import Union, Optional, Callable, Any

import torch
from botorch.models.model import Model
from botorch.posteriors import Posterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch import distributions as gdists
from torch import Tensor
from torch.utils.data import DataLoader

from ._utils import EarlyStopping
from ._utils import MLP
from ._utils import hovr_loss_fn, trimmed_loss_fn


class DAPModel(Model):
    """MLP surrogate trained by MAP estimation and wrapped in a Laplace approximation.

    ``fit`` trains ``self.nn`` with the combined loss from ``loss_fn`` and then
    builds ``self.bnn`` with ``Laplace``. ``forward`` and ``posterior`` query
    ``self.bnn`` and therefore require ``fit`` to have been called.
    """

    def __init__(
        self,
        dimensions: list[int],
        activation: str,
        input_dim: int,
        output_dim: int,
        dtype: torch.dtype = torch.float64,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        """Build the underlying MLP; all arguments are forwarded to ``MLP``."""
        super().__init__()

        self.output_dim = output_dim

        self.nn = MLP(
            dimensions=dimensions,
            activation=activation,
            input_dim=input_dim,
            output_dim=output_dim,
            dtype=dtype,
            device=device,
        )

        # Set by ``fit``; remains None until then.
        self.bnn = None

    @property
    def num_outputs(self) -> int:
        """Number of model outputs (``output_dim``)."""
        return self.output_dim

    def _scaling_fn(self, x: Tensor) -> Tensor:
        """
        Function to compute the variance scaling factor.

        Currently the exponential; alternatives would be e.g. a scaled sigmoid
        or softplus.
        """
        return torch.exp(x)

    def _p(self, x: Tensor) -> Tensor:
        """
        Project the inputs to a lower dimensional space.
        However, we consider this as an identity map.
        """
        return x

    def _compute_distance(self, X1: Tensor, X2: Tensor) -> Tensor:
        """
        Compute the pairwise distance between two sets of points.
        """
        return torch.cdist(X1, X2)

    def forward(self, X: Tensor) -> tuple[Tensor, Tensor]:
        """Return the joint predictive ``(mean, covariance)`` from the Laplace model.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.bnn is None:
            raise RuntimeError("DAPModel.fit must be called before forward/posterior.")
        mean, covariance = self.bnn(X, joint=True)
        # Regroup the trailing axis into (points, output_dim).
        mean = mean.reshape(*mean.shape[:-1], -1, self.output_dim)
        return mean, covariance

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[list[int]] = None,
        observation_noise: bool = False,
        posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
        **kwargs: Any,
    ) -> Posterior:
        """Wrap the predictive mean and covariance in a ``GPyTorchPosterior``.

        ``output_indices``, ``observation_noise`` and ``**kwargs`` are accepted
        for interface compatibility but are not used. If ``posterior_transform``
        is given, it is applied to the posterior before returning.
        """
        mean, covariance = self.forward(X)
        dist = gdists.MultivariateNormal(mean, covariance)
        posterior = GPyTorchPosterior(dist)
        if posterior_transform is not None:
            return posterior_transform(posterior)
        return posterior

    def loss_fn(
        self,
        x: Tensor,
        y: Tensor,
        mse_coeff: float,
        trim_coeff: float,
        hovr_coeff: float,
        h: Optional[int] = None,
        k: tuple[int, ...] = (1, 2),
        q: int = 2,
        M: int = 10,
    ) -> Tensor:
        """
        Computes the combined loss based on MSE, trimmed loss, and HOVR loss.

        A component whose coefficient is not positive is skipped entirely (its
        loss function is not called) and contributes 0.

        Args:
            x (Tensor): Input features.
            y (Tensor): Target labels.
            mse_coeff (float): Coefficient for the MSE loss.
            trim_coeff (float): Coefficient for the trimmed loss.
            hovr_coeff (float): Coefficient for the HOVR loss.
            h (int, optional): Parameter for trimmed loss.
            k (tuple[int, ...], optional): Parameter for HOVR loss.
            q (int, optional): Parameter for HOVR loss.
            M (int, optional): Parameter for HOVR loss.

        Returns:
            Tensor: Combined loss value.
        """
        y_pred = self.nn(x)

        # Evaluate a component only when it is switched on.
        def calculate_loss(coeff: float, loss_fn: Callable[..., Tensor], *args, **kwargs):
            return coeff * loss_fn(*args, **kwargs) if coeff > 0 else 0

        # Loss calculations
        mse_loss = calculate_loss(mse_coeff, torch.nn.MSELoss(), y_pred, y)
        trim_loss = calculate_loss(trim_coeff, trimmed_loss_fn, self.nn, x, y, h=h)
        hovr_loss = calculate_loss(hovr_coeff, hovr_loss_fn, self.nn, x, y, k=k, q=q, M=M)

        # Combine losses
        total_loss = mse_loss + trim_loss + hovr_loss

        return total_loss

    def fit(
        self,
        x: Tensor,
        y: Tensor,
        config: Optional[dict] = None
    ) -> None:
        """
        MAP Estimation (deterministic training) followed by Uncertainty Estimation via Laplace approximation.

        Every key is optional; omitting ``config`` uses all defaults.

        config = {
            # Data loading
            "batch_size": 32,          # Batch size for training
            "val_split": 0.2,          # Validation split ratio
            "min_train_size": 5,       # Minimum dataset size for validation split

            # Optimization
            "optimizer": torch.optim.Adam,  # Optimizer class
            "lr": 1e-3,                    # Learning rate
            "weight_decay": 0,             # Weight decay for all parameters except nn.out_layer
            "epochs": 1000,                 # Number of training epochs

            # Loss coefficients
            "loss_coeffs": {
                "mse": 1,                 # MSE loss coefficient
                "trim": 0,                # Trimmed loss coefficient
                "hovr": 0,                # HOVR loss coefficient
            },

            # Loss parameters
            "loss_params": {
                "h": None,                # Number of points to keep after trimming
                "k": (1, 2),              # Tuple of derivative orders
                "q": 2,                   # Exponent in HOVR
                "M": 10,                  # Number of random points
            },

            # Early Stopping
            "patience": 20,           # Number of epochs to wait for improvement
            "verbose": True,         # Print early stopping messages
            "delta": 0,              # Minimum change to qualify as improvement

            # Laplace approximation
            "hessian_structure": "full",
            "prior_precision": 1e-2,  # Too small prior precision can lead to numerical instability
            "sigma_noise": 1e-1,
            "temperature": 1,
        }
        """
        # A None default avoids sharing one dict object across calls.
        config = {} if config is None else config

        train_loader = self.fit_map(x, y, config) # Share train loader with Laplace

        # NOTE(review): ``Laplace`` is not imported in the visible source; it must be
        # brought into scope (presumably from the laplace-torch package - confirm).
        self.bnn = Laplace(
            model=self.nn,
            likelihood="regression",
            subset_of_weights='last_layer',
            hessian_structure=config.get("hessian_structure", "full"),
            prior_precision=config.get("prior_precision", 1e-2),
            sigma_noise=config.get("sigma_noise", 1e-1),
            temperature=config.get("temperature", 1),
            enable_backprop=True
        )

        self.bnn.fit(train_loader)

    def fit_map(
        self,
        x: Tensor,
        y: Tensor,
        config: dict
    ) -> DataLoader:
        """
        Train ``self.nn`` by minimising ``loss_fn`` (MAP estimate).

        When the dataset has at least ``min_train_size`` rows and the validation
        split leaves both parts non-empty, a random validation subset is held
        out and early stopping is applied; otherwise all data is used for
        training without validation.

        Args:
            x: Training inputs, indexed along dimension 0.
            y: Training targets, aligned with ``x``.
            config: Options as described in ``fit``.

        Returns:
            The training ``DataLoader``, so the caller can reuse it.

        Raises:
            ValueError: If no loss coefficient is positive, or the coefficients
                plus weight decay do not sum to a positive value.
        """

        # Retrieve minimum training size for Early Stopping
        min_train_size = config.get("min_train_size", 5)
        batch_size = config.get("batch_size", 32)

        dataset_size = x.size(0)
        split = int(config.get("val_split", 0.2) * dataset_size)

        # An empty validation set would divide by zero when averaging the
        # validation loss, and an empty training set cannot be trained on.
        if dataset_size >= min_train_size and 0 < split < dataset_size:
            # Split data into training and validation sets
            indices = torch.randperm(dataset_size)
            train_indices, val_indices = indices[split:], indices[:split]
            train_dataset = torch.utils.data.TensorDataset(x[train_indices], y[train_indices])
            val_dataset = torch.utils.data.TensorDataset(x[val_indices], y[val_indices])
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

            # Initialize EarlyStopping
            early_stopping = EarlyStopping(
                patience=config.get("patience", 20),
                verbose=config.get("verbose", True),
                delta=config.get("delta", 0)
            )

        else:
            # Use the entire dataset for training without validation
            train_dataset = torch.utils.data.TensorDataset(x, y)
            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_loader = None  # No validation
            early_stopping = None  # EarlyStopping not applied

        weight_decay = config.get("weight_decay", 0)

        loss_coeffs = config.get("loss_coeffs", {})
        mse_coeff = loss_coeffs.get("mse", 1)
        trim_coeff = loss_coeffs.get("trim", 0)
        hovr_coeff = loss_coeffs.get("hovr", 0)
        if max(mse_coeff, trim_coeff, hovr_coeff) <= 0:
            # loss_fn skips non-positive components, leaving nothing to optimise.
            raise ValueError("At least one of the mse/trim/hovr loss coefficients must be positive.")
        total_coeff = mse_coeff + trim_coeff + hovr_coeff + weight_decay
        if total_coeff <= 0:
            raise ValueError("Loss coefficients plus weight_decay must sum to a positive value.")
        # Normalise so the loss coefficients and the weight decay sum to 1.
        mse_coeff /= total_coeff
        trim_coeff /= total_coeff
        hovr_coeff /= total_coeff
        weight_decay /= total_coeff

        loss_params = config.get("loss_params", {})
        h = loss_params.get("h", None)
        k = loss_params.get("k", (1, 2))
        q = loss_params.get("q", 2)
        M = loss_params.get("M", 10)

        # Weight decay is applied to every parameter except those whose name
        # starts with "nn.out_layer".
        # NOTE(review): assumes MLP exposes its final layer as ``out_layer`` - confirm.
        non_out_layer_params = []
        out_layer_params = []
        for name, param in self.named_parameters():
            if name.startswith("nn.out_layer"):
                out_layer_params.append(param)
            else:
                non_out_layer_params.append(param)

        param_list = [
            {"params": non_out_layer_params, "weight_decay": weight_decay},
            {"params": out_layer_params, "weight_decay": 0},
        ]

        optimizer_class = config.get("optimizer", torch.optim.Adam)
        optimizer = optimizer_class(param_list, lr=config.get("lr", 1e-3))

        epochs = config.get("epochs", 1000)
        for _ in range(epochs):
            self.train()
            for x_batch, y_batch in train_loader:

                optimizer.zero_grad()
                batch_loss = self.loss_fn(
                    x_batch,
                    y_batch,
                    mse_coeff,
                    trim_coeff,
                    hovr_coeff,
                    h=h,
                    k=k,
                    q=q,
                    M=M
                )
                batch_loss.backward()
                optimizer.step()

            if early_stopping and val_loader is not None:
                # Validation
                self.eval()
                val_loss = 0
                # NOTE(review): if hovr_loss_fn differentiates w.r.t. its inputs it may
                # not work under no_grad - confirm when hovr_coeff > 0.
                with torch.no_grad():
                    for x_val, y_val in val_loader:
                        loss = self.loss_fn(
                            x_val,
                            y_val,
                            mse_coeff,
                            trim_coeff,
                            hovr_coeff,
                            h=h,
                            k=k,
                            q=q,
                            M=M
                        )
                        val_loss += loss.item()
                # Average over validation batches.
                val_loss /= len(val_loader)

                # Check early stopping
                early_stopping(val_loss, self)
                if early_stopping.early_stop:
                    break

        return train_loader
    

In [None]:
from typing import Union, Optional, Callable, Any

import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from botorch.models.model import Model
from botorch.posteriors import Posterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch import distributions as gdists
import numpy as np

from _utils import MLP
from _utils import gaussian_filter_1d


def gaussian_filter_1d(input_tensor, sigma: float = 1.0):
    """
    1D Gaussian filter implementation in PyTorch

    The signal is zero-padded at both ends, so the output has the same length
    as the input and values near the edges are attenuated. When
    ``int(6 * sigma + 1)`` is 1 (e.g. ``sigma=0.1``) the kernel has a single
    tap and the input is returned unchanged.

    Args:
        input_tensor (torch.Tensor): 1D input data to be smoothed (shape: [N])
        sigma (float, optional): Standard deviation of the Gaussian distribution
                                 (default: 1.0, controls the smoothing strength)

    Returns:
        torch.Tensor: Smoothed 1D data

    Raises:
        ValueError: If ``sigma`` is not positive.
    """
    if sigma <= 0:
        # The kernel divides by sigma; reject values that would produce NaN/inf.
        raise ValueError("sigma must be positive")

    dtype = input_tensor.dtype
    device = input_tensor.device

    # Determine kernel size based on the 3-sigma rule
    kernel_size = int(6 * sigma + 1)
    if kernel_size % 2 == 0:
        kernel_size += 1  # Adjust to ensure an odd kernel size
    half_width = kernel_size // 2

    # Create the Gaussian kernel, centred on zero and normalised to sum to 1
    x = torch.arange(kernel_size, dtype=dtype, device=device) - half_width
    gaussian_kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    gaussian_kernel = gaussian_kernel / gaussian_kernel.sum()

    # conv1d expects (batch, channels, length) for the input and
    # (out_channels, in_channels, width) for the kernel.
    # reshape (not view) so non-contiguous inputs are accepted as well.
    signal = input_tensor.reshape(1, 1, -1)
    smoothed_tensor = torch.nn.functional.conv1d(
        signal, gaussian_kernel.view(1, 1, -1), padding=half_width
    )

    # Remove batch and channel dimensions to return the original shape
    return smoothed_tensor.reshape(-1)


class DAPModel(Model):
    """Regression model with a distance-aware predictive variance.

    An ``MLP`` produces the predictive mean. The predictive variance of a query
    point is ``(sigma_0 + exp(phi) * d) ** 2``, where ``d`` is the Euclidean
    distance from the query to its nearest training input. The class derives
    from the BoTorch ``Model`` base class and exposes its predictions through
    ``posterior``.
    """

    def __init__(self,
                 dimensions: list[int],
                 activation: str,
                 input_dim: int,
                 output_dim: int,
                 dtype: torch.dtype = torch.float64,
                 device: Union[str, torch.device] = "cpu",
                 sigma_0: float = 1.0,
                 phi: float = 0.1) -> None:
        """Build the mean network and initialise the variance parameters.

        Args:
            dimensions: Forwarded to ``MLP`` (presumably the hidden-layer
                widths -- confirm against ``_utils.MLP``).
            activation: Forwarded to ``MLP``.
            input_dim: Forwarded to ``MLP``.
            output_dim: Forwarded to ``MLP``; also reported by ``num_outputs``.
            dtype: dtype used for the network and for ``phi``.
            device: Device used for the network and for ``phi``.
            sigma_0: Base standard deviation, i.e. the predictive standard
                deviation at distance zero from the training data.
            phi: Initial value of the raw distance-scale parameter; it enters
                the variance through ``scaling_fn``.
        """
        super().__init__()

        self.output_dim = output_dim
        self.sigma_0 = sigma_0
        # phi is a plain tensor with requires_grad=True rather than an
        # nn.Parameter: it is optimised only by `calibrate_phi` and does not
        # appear in `self.nn.parameters()`.
        self.phi = torch.tensor(phi, dtype=dtype, device=device, requires_grad=True)

        self.nn = MLP(
            dimensions=dimensions,
            activation=activation,
            input_dim=input_dim,
            output_dim=output_dim,
            dtype=dtype,
            device=device,
        )

        # Populated by `fit`; the distance computation needs the training inputs.
        self.X_train = None
        self.Y_train = None

    @property
    def num_outputs(self) -> int:
        """Number of model outputs, as given by ``output_dim``."""
        return self.output_dim

    def scaling_fn(self, x: Tensor) -> Tensor:
        """Map the raw parameter to a strictly positive scale via ``exp``."""
        return torch.exp(x)

    def _require_fitted(self) -> None:
        """Raise a clear error when the training data has not been stored yet."""
        if self.X_train is None or self.Y_train is None:
            raise RuntimeError(
                "DAPModel has no training data; call fit() before using the model."
            )

    def _smoothed_distance(self, X: Tensor) -> Tensor:
        """Return the filtered nearest-training-point distance for each row of X."""
        distance = self._compute_distance(X, self.X_train)
        # NOTE(review): with sigma=1e-1 the gaussian_filter_1d defined in this
        # file builds a single-tap kernel, so this call currently leaves the
        # distances unchanged.
        return gaussian_filter_1d(distance, sigma=1e-1)

    def _prior_variance(self, distance: Tensor) -> Tensor:
        """Return ``(sigma_0 + scaling_fn(phi) * distance) ** 2``."""
        return (self.sigma_0 + self.scaling_fn(self.phi) * distance).pow(2)

    def forward(self, X: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the predictive mean and the diagonal covariance.

        Args:
            X: Query inputs, one point per row.

        Returns:
            A ``(mean, covariance)`` pair: the raw output of the MLP and a
            diagonal matrix holding one distance-aware variance per query point.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        self._require_fitted()
        mean = self.nn(X)
        prior_variance = self._prior_variance(self._smoothed_distance(X))
        covariance = torch.diag_embed(prior_variance)
        return mean, covariance

    def _compute_distance(self, X_test: Tensor, X_train: Tensor) -> Tensor:
        """Return, for each test row, the Euclidean distance to the closest training row."""
        distances = torch.cdist(X_test, X_train)
        min_distances, _ = distances.min(dim=1)  # Minimum distance to training points
        return min_distances

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[list[int]] = None,
        observation_noise: bool = False,
        posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
        **kwargs: Any,
    ) -> Posterior:
        """Wrap the prediction at ``X`` in a ``GPyTorchPosterior``.

        Args:
            X: Query inputs, one point per row.
            output_indices: Accepted for interface compatibility; ignored.
            observation_noise: Accepted for interface compatibility; ignored.
            posterior_transform: Optional callable applied to the posterior
                before it is returned.
            **kwargs: Ignored.

        Returns:
            The (optionally transformed) posterior.
        """
        mean, covariance = self.forward(X)
        # A single-output MultivariateNormal needs a mean whose last axis runs
        # over the N query points, matching the (N, N) covariance. Drop the
        # trailing singleton output axis when the network returns (N, 1).
        # NOTE(review): output_dim > 1 is passed through unchanged and would
        # need a multitask distribution to be represented correctly.
        if (
            self.output_dim == 1
            and mean.dim() == covariance.dim()
            and mean.shape[-1] == 1
        ):
            mean = mean.squeeze(-1)
        dist = gdists.MultivariateNormal(mean, covariance)
        posterior = GPyTorchPosterior(dist)
        if posterior_transform is not None:
            return posterior_transform(posterior)
        return posterior

    def fit(self, X: Tensor, Y: Tensor, epochs: int = 1000, lr: float = 1e-3) -> None:
        """Store the training data and fit the mean network with an MSE loss.

        Args:
            X: Training inputs.
            Y: Training targets, shaped like the network output for ``X``.
            epochs: Number of full-batch Adam steps.
            lr: Adam learning rate.

        The network is left in training mode afterwards; ``phi`` is not
        touched here.
        """
        self.X_train = X
        self.Y_train = Y

        optimizer = torch.optim.Adam(self.nn.parameters(), lr=lr)
        loss_fn = nn.MSELoss()

        self.nn.train()
        for _ in range(epochs):
            optimizer.zero_grad()
            loss = loss_fn(self.nn(X), Y)
            loss.backward()
            optimizer.step()

    def calibrate_phi(
        self,
        X_val: Tensor,
        Y_val: Tensor,
        lr: float = 1e-3,
        epochs: int = 100,
        eps: float = 1e-9,
    ) -> None:
        """Calibrate phi so the predicted uncertainty on ``X_val`` matches a target level.

        The target ``gamma`` is the variance of the training labels plus
        ``eps`` (or just ``eps`` when there is a single training point).
        Progress is printed every 10 epochs.

        Args:
            X_val (Tensor): Validation input data.
            Y_val (Tensor): Validation output data. Currently unused.
            lr (float): Learning rate for phi calibration.
            epochs (int): Number of epochs for calibration.
            eps (float): Small constant added to the target.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        self._require_fitted()

        if self.Y_train.shape[0] == 1:
            gamma = eps
        else:
            gamma = torch.var(self.Y_train) + eps  # Use training label variance as target uncertainty
        optimizer = torch.optim.Adam([self.phi], lr=lr)

        self.nn.eval()

        # The distances depend only on the data, not on phi, so they are
        # computed once instead of on every epoch.
        with torch.no_grad():
            distance = self._smoothed_distance(X_val)

        for epoch in range(epochs):
            # Predictive standard deviation at each validation point.
            # NOTE(review): this is a standard deviation while gamma above is
            # a variance -- confirm the intended target scale.
            uncertainty = torch.sqrt(self._prior_variance(distance))

            # Calibration loss: MSE between uncertainty and target gamma
            calibration_loss = torch.mean((uncertainty - gamma) ** 2)

            optimizer.zero_grad()
            calibration_loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Calibration Loss: {calibration_loss.item():.6f}")






# Example testing the DAPModel with 1D regression
if __name__ == "__main__":
    import plotly.graph_objects as go

    # Generate synthetic data
    np.random.seed(0)
    torch.manual_seed(0)

    # Create training data in the original input scale
    X_raw = np.linspace(-5, 5, 20).reshape(-1, 1)
    y = np.sin(X_raw) + 0.1 * np.random.normal(size=X_raw.shape)

    # Normalisation statistics come from the training set only and are reused
    # for validation and test inputs, so all three share one coordinate system.
    x_min, x_max = float(X_raw.min()), float(X_raw.max())
    X = (X_raw - x_min) / (x_max - x_min)  # Training inputs land in [0, 1]
    y_mean, y_std = float(y.mean()), float(y.std())
    y = (y - y_mean) / y_std  # Standardize labels

    X_train = torch.tensor(X, dtype=torch.float64)
    y_train = torch.tensor(y, dtype=torch.float64)

    # Define model configuration
    model = DAPModel(
        dimensions=[128, 128, 128],
        activation="tanh",
        input_dim=1,
        output_dim=1,
        sigma_0=1e-2 * 3
    )

    # Train model
    model.fit(X_train, y_train)

    # Calibrate phi using validation data. The validation grid extends slightly
    # beyond the training range, so it maps to just outside [0, 1].
    X_val_raw = torch.linspace(-5.5, 5.5, 1000).reshape(-1, 1).to(torch.float64)
    X_val = (X_val_raw - x_min) / (x_max - x_min)
    # Synthetic validation labels: generated from the raw inputs, then put on
    # the same standardized scale as the training labels.
    Y_val = (torch.sin(X_val_raw) + 0.2 * torch.randn_like(X_val_raw) - y_mean) / y_std
    model.calibrate_phi(X_val, Y_val)

    # Predict on a grid that extends further outside the training range
    X_test_raw = torch.linspace(-6, 6, 1000).reshape(-1, 1).to(torch.float64)
    X_test = (X_test_raw - x_min) / (x_max - x_min)
    with torch.no_grad():
        y_pred, covariance = model.forward(X_test)

    # Convert to numpy for plotting
    X_test_np = X_test.numpy()
    y_pred_np = y_pred.numpy().squeeze() * y_std + y_mean  # Re-scale predictions to original label scale
    std_dev = torch.sqrt(torch.diagonal(covariance, dim1=-2, dim2=-1)).numpy().squeeze() * y_std  # Re-scale std dev

    # Plot results (x axis is in normalized input units)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=X_train.numpy().squeeze(), y=(y_train.numpy().squeeze() * y_std + y_mean), mode="markers", name="Training Data"))
    fig.add_trace(go.Scatter(x=X_test_np.squeeze(), y=y_pred_np, mode="lines", name="Prediction"))

    # Add 2-sigma confidence intervals
    fig.add_trace(go.Scatter(
        x=X_test_np.squeeze(),
        y=(y_pred_np + 2 * std_dev),
        mode="lines",
        name="Upper 2-sigma",
        line=dict(dash="dash")
    ))
    fig.add_trace(go.Scatter(
        x=X_test_np.squeeze(),
        y=(y_pred_np - 2 * std_dev),
        mode="lines",
        name="Lower 2-sigma",
        line=dict(dash="dash")
    ))

    fig.update_layout(
        title="1D Function Regression with Distance-Aware Priors",
        xaxis_title="Input",
        yaxis_title="Output",
    )
    fig.show()



# # Disabled earlier variant of the 1D regression example (3 training points, no input/label normalization)
# if __name__ == "__main__":
#     import plotly.graph_objects as go

#     # Generate synthetic data
#     np.random.seed(0)
#     torch.manual_seed(0)

#     # x1 = np.random.uniform(-5, -3, 10)
#     # x2 = np.random.uniform(0, 1, 3)
#     # x3 = np.random.uniform(2, 5, 30)
#     # X = torch.tensor(np.concatenate([x1, x2, x3]).reshape(-1, 1), dtype=torch.float64)
#     X = np.linspace(-5, 5, 3).reshape(-1, 1)
#     y = np.sin(X) + 0.1 * np.random.normal(size=X.shape)

#     X_train = torch.tensor(X, dtype=torch.float64)
#     y_train = torch.tensor(y, dtype=torch.float64)

#     # Define model configuration
#     model = DAPModel(
#         dimensions=[128, 128, 128],
#         activation="tanh",
#         input_dim=1,
#         output_dim=1,
#         sigma_0=1e-2
#     )

#     # Train model
#     model.fit(X_train, y_train)

#     # Calibrate phi using validation data
#     X_val = torch.linspace(-5.5, 5.5, 1000).reshape(-1, 1).to(torch.float64)
#     Y_val = torch.sin(X_val) + 0.2 * torch.randn_like(X_val)  # Synthetic validation data
#     model.calibrate_phi(X_val, Y_val)

#     # Predict
#     X_test = torch.linspace(-6, 6, 1000).reshape(-1, 1).to(torch.float64)
#     with torch.no_grad():
#         y_pred, covariance = model.forward(X_test)

#     # Convert to numpy for plotting
#     X_test_np = X_test.numpy()
#     y_pred_np = y_pred.numpy().squeeze()
#     std_dev = torch.sqrt(torch.diagonal(covariance, dim1=-2, dim2=-1)).numpy().squeeze()

#     # Plot results
#     fig = go.Figure()
#     fig.add_trace(go.Scatter(x=X.squeeze(), y=y.squeeze(), mode="markers", name="Training Data"))
#     fig.add_trace(go.Scatter(x=X_test_np.squeeze(), y=y_pred_np, mode="lines", name="Prediction"))

#     # Add 2-sigma confidence intervals
#     fig.add_trace(go.Scatter(
#         x=X_test_np.squeeze(),
#         y=(y_pred_np + 2 * std_dev),
#         mode="lines",
#         name="Upper 2-sigma",
#         line=dict(dash="dash")
#     ))
#     fig.add_trace(go.Scatter(
#         x=X_test_np.squeeze(),
#         y=(y_pred_np - 2 * std_dev),
#         mode="lines",
#         name="Lower 2-sigma",
#         line=dict(dash="dash")
#     ))

#     fig.update_layout(
#         title="1D Function Regression with Distance-Aware Priors",
#         xaxis_title="Input",
#         yaxis_title="Output",
#     )
#     fig.show()
