In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

# Seed NumPy's and PyTorch's global random number generators with the same
# value so that repeated executions draw the same random numbers.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# --- Simulator and True Likelihood Definitions ------------------------------

def f_1d(theta):
    """
    Evaluate the simulator's mean function, a shifted and scaled cubic.

    Computes ((1.5 * theta + 0.5) ** 3) / 200. The arithmetic is element-wise,
    so theta may be a scalar or a numpy array.
    """
    shifted = 1.5 * theta + 0.5
    cubed = shifted ** 3
    return cubed / 200

def simulator_1d(theta, sigma=0.31622777):
    """
    Draw a simulated observation x ~ N(f_1d(theta), sigma^2).

    The sample comes from NumPy's global random state (np.random.normal),
    so results depend on the current np.random seed.
    """
    return np.random.normal(loc=f_1d(theta), scale=sigma)

def true_log_likelihood_1d(theta, x0=2, sigma=0.31622777):
    """
    Exact Gaussian log-likelihood of the observation x0 given theta:

        log p(x0 | theta) = log N(x0; f_1d(theta), sigma^2)

    Works element-wise, so theta may be a scalar or a numpy array.
    """
    variance = sigma ** 2
    residual = x0 - f_1d(theta)
    log_norm_const = -0.5 * np.log(2 * np.pi * variance)
    return log_norm_const - 0.5 * (residual ** 2) / variance

def true_posterior_1d(theta_grid, x0=2, sigma=0.31622777):
    """
    Compute the normalized true posterior p(theta | x0) on a grid.

    The prior is uniform, so the posterior is proportional to the likelihood;
    it is normalized numerically with the trapezoidal rule over theta_grid.

    theta_grid: 1D numpy array of theta values at which to evaluate.
    x0: the observed value.
    sigma: standard deviation of the simulator noise.
    Returns: array of posterior density values, one per entry of theta_grid.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use whichever one the installed NumPy provides.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz
    likelihood = np.exp(true_log_likelihood_1d(theta_grid, x0, sigma))
    # The small constant guards against division by zero when the likelihood
    # underflows to zero everywhere on the grid.
    evidence = integrate(likelihood, theta_grid) + 1e-10
    return likelihood / evidence

# --- Neural Network Ensemble Definition ---------------------------------------

class EmulatorNet(nn.Module):
    """
    Small feedforward emulator that outputs the parameters of a Gaussian.

    Input:  theta, shape (batch, 1)
    Output: shape (batch, 2), columns [mu, log_sigma] with sigma = exp(log_sigma)
    Architecture: Linear(1 -> 10), tanh, Linear(10 -> 2).
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 10)   # input -> hidden
        self.act = nn.Tanh()
        self.fc2 = nn.Linear(10, 2)   # hidden -> [mu, log_sigma]

    def forward(self, x):
        """Map theta of shape (batch, 1) to [mu, log_sigma] of shape (batch, 2)."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)

def nll_loss(output, target):
    """
    Mean Gaussian negative log-likelihood over a batch.

    output: tensor of shape (batch, 2) holding [mu, log_sigma] per row
    target: tensor of shape (batch, 1)
    Returns: scalar tensor, the batch mean of
             0.5*log(2*pi*sigma^2) + 0.5*(target - mu)^2 / sigma^2.
    """
    # Slicing with 0:1 / 1:2 keeps the trailing axis, giving (batch, 1) columns.
    mu = output[:, 0:1]
    log_sigma = output[:, 1:2]
    # The small offset keeps sigma strictly positive even if exp underflows.
    sigma = torch.exp(log_sigma) + 1e-6
    variance = sigma ** 2
    log_term = 0.5 * torch.log(2 * np.pi * variance)
    quad_term = 0.5 * ((target - mu) ** 2) / variance
    return (log_term + quad_term).mean()

# --- Training Functions for the Ensemble ------------------------------------

def train_ensemble(models, optimizers, theta_tensor, x_tensor, epochs=200):
    """
    Fit every ensemble member to the current dataset with full-batch updates.

    models: list of EmulatorNet instances.
    optimizers: list of optimizers, paired with models by position.
    theta_tensor: inputs, tensor of shape (N, 1)
    x_tensor: targets, tensor of shape (N, 1)
    epochs: number of passes; each pass takes one gradient step per model.
    """
    for _ in range(epochs):
        for net, opt in zip(models, optimizers):
            # Switch to training mode; ensemble_predict puts models in eval mode.
            net.train()
            opt.zero_grad()
            batch_loss = nll_loss(net(theta_tensor), x_tensor)
            batch_loss.backward()
            opt.step()

def ensemble_predict(models, theta_candidates, x0=2, sigma_min=1e-6):
    """
    Evaluate each ensemble member's synthetic likelihood q(x0 | theta).

    For every network the likelihood is the Gaussian density
        L = N(x0; mu(theta), sigma(theta)^2),
    where sigma = exp(log_sigma) + sigma_min.

    models: list of networks mapping (n, 1) inputs to (n, 2) [mu, log_sigma].
    theta_candidates: candidate theta values; any array-like (numpy array,
        list, or scalar) is accepted and flattened to a column.
    x0: the observed value at which the density is evaluated.
    sigma_min: positive floor added to sigma to avoid division by zero.
    Returns: numpy array of shape (n_candidates, n_models).

    Side effect: every model is switched to eval mode.
    """
    # Plain loop rather than a throwaway list: eval() is called only for its
    # side effect.
    for model in models:
        model.eval()

    # np.asarray lets callers pass lists or scalars, not only numpy arrays.
    theta_column = np.asarray(theta_candidates).reshape(-1, 1)
    theta_tensor = torch.tensor(theta_column, dtype=torch.float32)

    per_model = []
    with torch.no_grad():
        for model in models:
            outputs = model(theta_tensor)  # (n_candidates, 2)
            mu = outputs[:, 0].numpy()
            sigma = np.exp(outputs[:, 1].numpy()) + sigma_min
            # Gaussian PDF evaluated at x0
            density = (1.0 / np.sqrt(2 * np.pi * sigma ** 2)) * np.exp(-0.5 * ((x0 - mu) ** 2) / (sigma ** 2))
            per_model.append(density)
    # (n_models, n_candidates) -> (n_candidates, n_models)
    return np.array(per_model).T

# --- Acquisition Rules (MaxVar and Uniform) -----------------------------------
def acquisition_maxvar_ensemble(models, grid, x0=2):
    """
    MaxVar acquisition: pick the grid point where the ensemble disagrees most.

    The synthetic likelihood of x0 is evaluated for every ensemble member at
    every theta in grid; the theta whose predictions have the largest variance
    across members (NumPy's default variance, ddof=0) is returned.
    """
    likelihoods = ensemble_predict(models, grid, x0)  # (len(grid), n_models)
    disagreement = likelihoods.var(axis=1)
    return grid[disagreement.argmax()]

def acquisition_uniform_ensemble(low=-8, high=8):
    """
    Uniform acquisition: return a random theta drawn from the uniform prior.

    low, high: bounds of the prior support; the defaults give the interval
        [-8, 8) used elsewhere in this file.
    Returns: a single float drawn from NumPy's global random state.
    """
    return np.random.uniform(low, high)

# --- Main Simulation Loop using the Neural Network Ensemble -----------------
def run_simulation_ensemble(acquisition_rule='maxvar', N_initial=10, N_acq=100, x0=2, sigma=0.31622777,
                            ensemble_size=10, epochs_per_round=200):
    """
    Run one simulation run using a neural network ensemble.

    Each round trains the ensemble on all data so far, records the total
    variation (TV) distance between the synthetic and true posteriors on a
    400-point grid over [-8, 8], then acquires and simulates one new theta.

      - acquisition_rule: either 'uniform' or 'maxvar'
      - N_initial: initial sample count (thetas drawn uniformly from [-8, 8))
      - N_acq: number of acquisition rounds
      - x0: the observed value the posterior conditions on
      - sigma: standard deviation of the simulator noise
      - ensemble_size: number of ensemble members
      - epochs_per_round: training epochs for each acquisition round
    Returns:
      tv_errors: list of total variation distances (one per acquisition round)
      data: (theta, x) arrays for inspection if needed.
    Raises:
      ValueError: if acquisition_rule is neither 'uniform' nor 'maxvar'.
    """
    # Fail fast: reject an unknown rule before any (expensive) training is done.
    if acquisition_rule not in ('uniform', 'maxvar'):
        raise ValueError("Unknown acquisition rule: choose 'uniform' or 'maxvar'")

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # use whichever one the installed NumPy provides.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    def as_column(values):
        """Convert a 1D numpy array into a float32 tensor of shape (N, 1)."""
        return torch.tensor(values.reshape(-1, 1), dtype=torch.float32)

    # --- Initialize the training data -------------------------------
    thetas = np.random.uniform(-8, 8, size=N_initial)
    x_samples = np.array([simulator_1d(t, sigma) for t in thetas])
    theta_tensor = as_column(thetas)
    x_tensor = as_column(x_samples)

    # Initialize ensemble: one Adam optimizer per member
    models = [EmulatorNet() for _ in range(ensemble_size)]
    optimizers = [optim.Adam(model.parameters(), lr=0.01) for model in models]

    # Set up evaluation grid over theta for posterior reconstruction
    grid = np.linspace(-8, 8, 400)
    dtheta = grid[1] - grid[0]
    true_post = true_posterior_1d(grid, x0, sigma)

    tv_errors = []

    for _ in trange(N_acq):
        # Train the ensemble on the current dataset (training continues from
        # the previous round's weights; models are not re-initialized)
        train_ensemble(models, optimizers, theta_tensor, x_tensor, epochs=epochs_per_round)

        # Predict synthetic likelihood on the grid: shape (len(grid), ensemble_size)
        preds = ensemble_predict(models, grid, x0)
        # Average across the ensemble to get the synthetic likelihood estimate
        L_est = np.mean(preds, axis=1)
        # Synthetic posterior (proportional to L_est under the uniform prior);
        # the small constant guards against division by zero
        post_est = L_est / (integrate(L_est, grid) + 1e-10)
        # TV distance = 0.5 * integral |post_est - true_post|, as a Riemann sum
        tv = 0.5 * np.sum(np.abs(post_est - true_post)) * dtheta
        tv_errors.append(tv)

        # Acquisition: choose new theta according to the (already validated) rule
        if acquisition_rule == 'uniform':
            new_theta = acquisition_uniform_ensemble()
        else:
            new_theta = acquisition_maxvar_ensemble(models, grid, x0)

        # Simulate a new observation at the new theta and grow the training set
        new_x = simulator_1d(new_theta, sigma)
        thetas = np.append(thetas, new_theta)
        x_samples = np.append(x_samples, new_x)
        theta_tensor = as_column(thetas)
        x_tensor = as_column(x_samples)

    return tv_errors, (thetas, x_samples)

# --- Evaluate Ensemble Inference over Multiple Runs --------------------------
def evaluate_ensemble(N_runs=5, N_initial=10, N_acq=100, rule='maxvar', ensemble_size=10, epochs_per_round=200):
    """
    Repeat run_simulation_ensemble N_runs times and stack the TV-error curves.

    rule is forwarded as acquisition_rule; the remaining keyword arguments are
    forwarded unchanged. Prints the run index before each run.
    Returns: numpy array built from the per-run tv_errors lists
             (one row per run, one column per acquisition round).
    """
    run_kwargs = dict(acquisition_rule=rule, N_initial=N_initial, N_acq=N_acq,
                      ensemble_size=ensemble_size, epochs_per_round=epochs_per_round)
    tv_curves = []
    for run in range(N_runs):
        print(f'run: {run}')
        # Only the TV curve is kept; the acquired (theta, x) data is discarded.
        tv_curve, _data = run_simulation_ensemble(**run_kwargs)
        tv_curves.append(tv_curve)
    return np.array(tv_curves)

# Run ensemble simulations with maxvar and uniform acquisitions.
# Each banner names the acquisition rule of the runs that follow it.
print('---- max var ----')
tv_maxvar_ensemble = evaluate_ensemble(rule='maxvar')
print('---- uniform ----')
tv_uniform_ensemble = evaluate_ensemble(rule='uniform')

# x-axis for the TV curves: one step per acquisition round, derived from the
# result shape so it stays correct if N_acq is changed above.
acq_steps = np.arange(1, tv_maxvar_ensemble.shape[1] + 1)
# Mean and standard error of the mean across runs (axis 0 indexes runs).
mean_maxvar = tv_maxvar_ensemble.mean(axis=0)
mean_uniform = tv_uniform_ensemble.mean(axis=0)
sem_maxvar = tv_maxvar_ensemble.std(axis=0) / np.sqrt(tv_maxvar_ensemble.shape[0])
sem_uniform = tv_uniform_ensemble.std(axis=0) / np.sqrt(tv_uniform_ensemble.shape[0])

# --- Plotting the Results -----------------------------------------------------
# Two side-by-side panels: the simulator function (left) and the TV-distance
# curves for both acquisition rules (right).
fig, (ax_sim, ax_tv) = plt.subplots(1, 2, figsize=(12, 5))

# (a) Simulator function f(theta) together with the observed value x0
theta_grid = np.linspace(-8, 8, 400)
ax_sim.plot(theta_grid, f_1d(theta_grid), label='f(theta)')
ax_sim.axhline(2, color='r', linestyle='--', label='x0 = 2')
ax_sim.set_xlabel('theta')
ax_sim.set_ylabel('f(theta)')
ax_sim.set_title('Simulator Function')
ax_sim.legend()

# (b) Mean TV distance per acquisition step, with SEM error bars
ax_tv.errorbar(acq_steps, mean_uniform, yerr=sem_uniform, label='Uniform Acq', capsize=3)
ax_tv.errorbar(acq_steps, mean_maxvar, yerr=sem_maxvar, label='MaxVar Acq', capsize=3)
ax_tv.set_xlabel('Acquisitions')
ax_tv.set_ylabel('TV Distance')
ax_tv.set_title('Ensemble-Based Inference (1D)')
ax_tv.legend()

fig.tight_layout()
plt.show()


run: 0


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.04it/s]


run: 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  3.03it/s]


run: 2


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.03it/s]


run: 3


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  2.99it/s]


run: 4


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  2.99it/s]


run: 0


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.03it/s]


run: 1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  3.02it/s]


run: 2


 38%|██████████████████████████████████████████████████████████████████▉                                                                                                             | 38/100 [00:12<00:20,  3.05it/s]