# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.spatial.distance import cdist
import plotly.graph_objects as go

# 1. Define Distance-Aware Priors
class DistanceAwarePrior:
    """Prior whose standard deviation grows linearly with distance to the training set.

    For a point whose nearest training input lies at Euclidean distance ``d``,
    the prior variance is ``(sigma0 + exp(phi) * d) ** 2``.
    """

    def __init__(self, sigma0):
        """Store ``sigma0``, the standard deviation used at distance zero."""
        self.sigma0 = sigma0  # Base standard deviation

    def compute_distance(self, x_train, x_val):
        """Compute the nearest-neighbour distance (identity map as feature extractor).

        Args:
            x_train: 2-D tensor of training inputs, one point per row.
            x_val: 2-D tensor of query inputs with the same number of columns.

        Returns:
            A 1-D ``float32`` CPU tensor holding, for every row of ``x_val``,
            the Euclidean distance to the closest row of ``x_train``.

        Raises:
            ValueError: If ``x_train`` contains no points, so no nearest
                neighbour exists.
        """
        # detach() first: Tensor.numpy() refuses tensors that require grad.
        x_train_np = x_train.detach().cpu().numpy()
        x_val_np = x_val.detach().cpu().numpy()
        if x_train_np.ndim == 2 and x_train_np.shape[0] == 0:
            raise ValueError("x_train must contain at least one point")
        # Pairwise distances have shape [n_val, n_train]; reduce over training points.
        distances = np.min(cdist(x_val_np, x_train_np, 'euclidean'), axis=1)
        return torch.tensor(distances, dtype=torch.float32)

    def distance_aware_variance(self, distances, phi):
        """Compute the distance-aware variance ``(sigma0 + exp(phi) * distances) ** 2``.

        Args:
            distances: Tensor of nearest-neighbour distances.
            phi: Log-scale of the distance term; a tensor (gradients flow
                through it) or a plain Python number.

        Returns:
            Tensor of variances with the same shape as ``distances``.
        """
        if not torch.is_tensor(phi):
            # torch.exp only accepts tensors; promote plain numbers.
            phi = torch.tensor(float(phi))
        return (self.sigma0 + torch.exp(phi) * distances).pow(2)

# 2. Importance Sampling for Posterior Predictive
class ImportanceSamplingPredictor:
    """Re-weights predictions from posterior samples using distance-aware variances."""

    def __init__(self, posterior_samples, prior_sigma0):
        """
        Args:
            posterior_samples: Iterable of state dicts, each loadable into the
                model passed to ``predictive_mean_and_variance``.
            prior_sigma0: Standard deviation of the reference prior.
        """
        self.posterior_samples = posterior_samples  # Samples from posterior p(theta | x, y)
        self.prior_sigma0 = prior_sigma0

    def compute_importance_weights(self, distance_variances):
        """Compute importance weights to correct predictions.

        Args:
            distance_variances: Tensor of distance-aware prior variances.

        Returns:
            Tensor of the same shape, equal to
            ``exp(-0.5 * (1 / distance_variances - 1 / prior_sigma0 ** 2))``.
        """
        # Gaussian weight ratios between distance-aware prior and posterior
        prior_variances = self.prior_sigma0**2
        weights = torch.exp(-0.5 * (1 / distance_variances - 1 / prior_variances))
        return weights

    def predictive_mean_and_variance(self, x_val, model, weights):
        """Estimate the predictive mean and variance using importance sampling.

        Each posterior sample is loaded into ``model`` in turn, so on return
        the model holds the parameters of the last sample.

        Args:
            x_val: Inputs to predict on, one point per row.
            model: Module producing one scalar output per input row.
            weights: 1-D tensor with one weight per row of ``x_val``.

        Returns:
            ``(mean, variance)``, each a 1-D tensor with one entry per input.
            ``mean`` is the sample average of the weight-scaled predictions
            (the weights are not normalised); ``variance`` is the sample
            average of the weight-scaled squared deviations from ``mean``.
        """
        preds = []
        with torch.no_grad():
            for theta in self.posterior_samples:  # Iterate through posterior samples
                model.load_state_dict(theta)
                # reshape(-1) rather than squeeze(): squeeze() would also drop
                # the batch axis when x_val holds a single point, and the
                # result would then broadcast against the sample axis.
                preds.append(model(x_val).reshape(-1))
        preds = torch.stack(preds, dim=0)  # Shape: [num_samples, num_inputs]
        per_input_weights = weights.unsqueeze(0)  # Shape: [1, num_inputs]
        mean = (preds * per_input_weights).mean(dim=0)
        variance = ((preds - mean.unsqueeze(0)) ** 2 * per_input_weights).mean(dim=0)
        return mean, variance


# 3. Calibration of Distance-Aware Priors
class DAPCalibration:
    """Fits the log-scale ``phi`` of a distance-aware prior by gradient descent."""

    def __init__(self, model, prior_sigma0, x_train, x_val):
        """
        Args:
            model: Stored on the instance; not used by the calibration itself.
            prior_sigma0: Base standard deviation handed to ``DistanceAwarePrior``.
            x_train: Training inputs used as the reference set for distances.
            x_val: Inputs on which the prior standard deviation is calibrated.
        """
        self.model = model
        self.prior = DistanceAwarePrior(prior_sigma0)
        self.x_train = x_train
        self.x_val = x_val
        self.phi = torch.tensor(0.0, requires_grad=True)  # Initialize phi

    def calibration_loss(self, distances, target_uncertainty=1.0):
        """Calibration loss to match target uncertainty.

        Args:
            distances: Tensor of nearest-neighbour distances.
            target_uncertainty: Standard deviation the prior should reach.

        Returns:
            Scalar tensor: mean squared gap between the prior standard
            deviation and ``target_uncertainty``.
        """
        distance_variances = self.prior.distance_aware_variance(distances, self.phi)
        # The derivative of sqrt is infinite at 0, which turns into NaN once it
        # is multiplied by the zero derivative of the square. Clamping to the
        # smallest positive normal float keeps the gradient finite (clamped
        # entries contribute zero gradient) and leaves other values unchanged.
        floor = torch.finfo(distance_variances.dtype).tiny
        stds = distance_variances.clamp_min(floor).sqrt()
        return torch.mean((stds - target_uncertainty) ** 2)

    def calibrate(self, target_uncertainty=1.0, lr=0.01, epochs=100):
        """Optimize phi using calibration loss.

        Runs ``epochs`` Adam steps on ``self.phi`` and prints progress every
        tenth epoch.

        Args:
            target_uncertainty: Standard deviation the prior should reach.
            lr: Adam learning rate.
            epochs: Number of optimisation steps.

        Returns:
            A detached copy of the fitted ``phi``; later calls to
            ``calibrate`` do not modify it.
        """
        optimizer = optim.Adam([self.phi], lr=lr)
        # Distances do not depend on phi, so compute them once up front.
        distances = self.prior.compute_distance(self.x_train, self.x_val)

        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = self.calibration_loss(distances, target_uncertainty)
            loss.backward()
            optimizer.step()
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Phi = {self.phi.item():.4f}")

        print(f"Calibration complete. Optimal Phi = {self.phi.item():.4f}")
        # clone() so the returned value does not alias the live parameter.
        return self.phi.detach().clone()

# 4. Example Pipeline
if __name__ == "__main__":
    # Toy example: calibrate the prior on 1D data, then plot the re-weighted
    # predictive mean and its 95% band.
    class SimpleBNN(nn.Module):
        """One-hidden-layer tanh network mapping a scalar input to a scalar output."""

        def __init__(self):
            super().__init__()
            self.fc = nn.Sequential(nn.Linear(1, 10), nn.Tanh(), nn.Linear(10, 1))

        def forward(self, x):
            return self.fc(x)

    # Generate Toy 1D Data
    torch.manual_seed(42)
    x_train = torch.cat([
        torch.linspace(-3, -2, 5),  # Imbalanced regions
        torch.linspace(0, 1, 10),
        torch.linspace(2, 3, 5)
    ]).unsqueeze(1)
    y_train = torch.sin(x_train) + 0.1 * torch.randn_like(x_train)
    x_val = torch.linspace(-6, 6, 100).unsqueeze(1)  # Out-of-distribution points

    # Initialize BNN
    bnn = SimpleBNN()

    # Generate fake posterior samples (stand-in for variational inference or MCMC).
    # Each sample is an independent, jittered copy of the initial parameters.
    # Re-using bnn.state_dict() directly would give ten references to the same
    # live tensors, i.e. identical samples and a predictive variance of zero.
    num_posterior_samples = 10
    jitter_scale = 0.1
    base_state = bnn.state_dict()
    posterior_samples = [
        {
            name: param.detach().clone() + jitter_scale * torch.randn_like(param)
            for name, param in base_state.items()
        }
        for _ in range(num_posterior_samples)
    ]

    # Step 1: Calibrate Distance-Aware Priors
    dap = DAPCalibration(bnn, prior_sigma0=1.0, x_train=x_train, x_val=x_val)
    phi_opt = dap.calibrate(target_uncertainty=1.0)

    # Step 2: Apply Importance Sampling for Predictions
    predictor = ImportanceSamplingPredictor(posterior_samples, prior_sigma0=1.0)
    distances = dap.prior.compute_distance(x_train, x_val)
    distance_variances = dap.prior.distance_aware_variance(distances, phi_opt)
    weights = predictor.compute_importance_weights(distance_variances)
    predictive_mean, predictive_variance = predictor.predictive_mean_and_variance(x_val, bnn, weights)

    # Convert predictions and intervals to NumPy
    x_val_np = x_val.squeeze().numpy()
    x_train_np = x_train.squeeze().numpy()
    y_train_np = y_train.squeeze().numpy()
    pred_mean_np = predictive_mean.detach().numpy()
    pred_var_np = predictive_variance.detach().numpy()

    # Compute confidence intervals (mean +/- 1.96 standard deviations)
    half_width = 1.96 * np.sqrt(pred_var_np)
    lower = pred_mean_np - half_width
    upper = pred_mean_np + half_width

    # Plot using Plotly
    fig = go.Figure()

    # Training data points
    fig.add_trace(
        go.Scatter(
            x=x_train_np,
            y=y_train_np,
            mode="markers",
            name="Training Data",
            marker=dict(color="blue"),
        )
    )

    # Predictive mean
    fig.add_trace(
        go.Scatter(
            x=x_val_np,
            y=pred_mean_np,
            mode="lines",
            name="Predictive Mean",
            line=dict(color="red"),
        )
    )

    # Confidence interval as a filled band: trace the upper edge left to
    # right, then the lower edge right to left, to form a closed polygon.
    fig.add_trace(
        go.Scatter(
            x=np.concatenate([x_val_np, x_val_np[::-1]]),
            y=np.concatenate([upper, lower[::-1]]),
            fill="toself",
            fillcolor="rgba(255, 0, 0, 0.2)",
            line=dict(width=0),
            name="95% Confidence Interval",
        )
    )

    # Layout settings
    fig.update_layout(
        title="Distance-Aware Prior Calibration",
        xaxis_title="Input x",
        yaxis_title="Output y",
        legend_title="Legend",
    )

    fig.show()
