# From Seismic Waves to Neural Inference: A Hands-On Introduction to Simulation-Based Inference (SBI)

## Installation and packages

Create a new conda environment:

```zsh
conda create -n compearth-workshop python=3.11 -y
conda activate compearth-workshop
```

Install jupyter:

```zsh
pip install jupyter
```

Link your kernel to the conda environment:

```zsh
python -m ipykernel install --user --name compearth-workshop --display-name "CompEarth Workshop"
```

In [None]:
!pip install torch sbi matplotlib numpy pandas ruptures huggingface_hub tqdm
!pip install git+https://github.com/nschaetti/CompEarth_Workshop.git --upgrade

In [None]:
import os
import random
import numpy as np
import torch

seed = 42
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
rng = np.random.default_rng(seed)

In [None]:
!nvidia-smi

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

## Introduction — The Inverse Problem in Earth Sciences

### Understanding the inverse problem

### The Bayesian view

### Why Simulation-Based Inference (SBI)?

### What you will learn today

## 2 The Forward Model — Simulating Surface Wave Dispersion

### 2.1 Physical background

### 2.2 Using the simulator `surfdisp2k25`

#### Playing with the simulator

HuggingFace Space: https://huggingface.co/spaces/MIGRATE/surfdisp2k25

The `surfdisp2k25` simulator computes **surface-wave dispersion curves** (Rayleigh or Love waves, selected with `iwave`), showing how the **group velocity** (or phase velocity) of surface waves varies with **period** for a given layered Earth model.

Each model is described by a set of physical and numerical parameters.

**Physical parameters**

* `n_layers` – number of layers, including the half-space at the bottom.
* `vpvs` – ratio ($V_p/V_s$), typically between 1.7 and 1.9 for crustal materials.
* `thicknesses` – list of layer thicknesses in kilometers. The last layer has a thickness of 0 to represent an infinite half-space.
* `vs_layers` – list of shear-wave velocities ($V_s$) in kilometers per second, usually increasing with depth.
* `ρ` – density of each layer, estimated here from an empirical relation ($\rho = 0.32\,V_P + 0.77$, with $V_P = (V_P/V_S)\,V_S$ in km/s and $\rho$ in g/cm³).

Apart from the density, which is computed from $V_P$, these physical parameters are concatenated into a single input tensor:

$$
\theta = [\, n,\ V_P/V_S,\ h_1, h_2, \dots, h_{N_{\max}},\ V_{S,1}, V_{S,2}, \dots, V_{S,N_{\max}} \,]
$$

For a single model, `theta` has shape `(1, 2 + 2*Nmax)`.
When simulating several models at once, the batch dimension corresponds to the number of models.
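To make this layout concrete, the sketch below packs layered models into the flat vector and unpacks one again. It assumes that models with fewer than $N_{\max}$ layers are zero-padded, which the text above does not specify; `pack_theta` and `unpack_theta` are hypothetical helpers, not part of `compearth`.

```python
import numpy as np


def pack_theta(vpvs, thicknesses, vs_layers, n_max):
    """Pack one layered model into [n, vpvs, h_1..h_Nmax, vs_1..vs_Nmax].

    Assumption: slots for layers beyond n are filled with zeros.
    """
    n = len(vs_layers)
    if len(thicknesses) != n:
        raise ValueError("need one thickness per layer (0.0 for the half-space)")
    if n > n_max:
        raise ValueError("more layers than n_max")
    h = np.zeros(n_max)
    vs = np.zeros(n_max)
    h[:n] = thicknesses
    vs[:n] = vs_layers
    return np.concatenate([[n, vpvs], h, vs])


def unpack_theta(theta):
    """Split one flat vector back into (n, vpvs, thicknesses, velocities)."""
    theta = np.asarray(theta, dtype=float)
    n_max = (theta.size - 2) // 2
    n = int(theta[0])
    vpvs = float(theta[1])
    h = theta[2:2 + n_max][:n]
    vs = theta[2 + n_max:][:n]
    return n, vpvs, h, vs


# Two models with different layer counts stacked into one batch of shape (B, 2 + 2 * n_max)
batch = np.stack([
    pack_theta(1.75, [2.0, 0.0], [2.5, 3.5], n_max=4),
    pack_theta(1.75, [0.9, 1.2, 0.34, 0.0], [1.5, 1.2, 2.5, 3.5], n_max=4),
])
print(batch.shape)  # (2, 10)
print(unpack_theta(batch[0]))
```

Wrapping `batch` with `torch.tensor(batch, dtype=torch.float32)` would give a tensor in the shape described above; check the output of `sample_models` before relying on the zero-padding assumption.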

**Numerical parameters**

* `p_min`, `p_max` – minimum and maximum periods in seconds defining the frequency range.
* `kmax` – number of discrete period samples between `p_min` and `p_max`.
* `iflsph` – Earth geometry flag: 0 for flat Earth, 1 for spherical correction.
* `iwave` – wave type: 1 for Love waves, 2 for Rayleigh waves.
* `mode` – mode number: 1 for the fundamental mode.
* `igr` – 1 to compute group velocity (dispersion curve), 0 for phase velocity only.

The simulator returns a tensor of shape `(B, kmax)` where `B` is the number of models.
Each row is one dispersion curve, here the group velocity $U(T)$ (in km/s) as a function of period $T$ (in seconds).

We now test the simulator on a simple layered model: three near-surface layers (the second one slower than the first, i.e. a low-velocity zone) over a stiffer half-space. The values below are those used in the code cell that follows.

| Parameter        | Description                  | Value                    |
| ---------------- | ---------------------------- | ------------------------ |
| `n_layers`       | Number of layers             | 4 (including half-space) |
| `vpvs`           | $V_P/V_S$ ratio              | 1.75                     |
| `thicknesses`    | Layer thicknesses (km)       | [0.9, 1.2, 0.34, 0.0]    |
| `vs_layers`      | Shear-wave velocities (km/s) | [1.5, 1.2, 2.5, 3.5]     |
| `p_min`, `p_max` | Period range (s)             | [1.0, 5.0]               |
| `kmax`           | Number of periods            | 108                      |
| `iflsph`         | Earth geometry               | 0 (flat Earth)           |
| `iwave`          | Wave type                    | 2 (Rayleigh)             |
| `mode`           | Mode number                  | 1 (fundamental)          |
| `igr`            | Velocity type                | 1 (group velocity)       |


#### 2.2.1. Testing the simulator with a simple layered model

Let’s start by defining a model with **three layers** over a half-space.

In [None]:
from compearth.extensions.surfdisp2k25 import dispsurf2k25_simulator

# Physical model parameters
n_layers = 4                # Number of layers (including the half-space)
vpvs = 1.75                 # Vp/Vs ratio (typical for crustal rocks)
thicknesses = [0.9, 1.2, 0.34, 0.0]    # Layer thicknesses in km (0 for half-space)
vs_layers = [1.5, 1.2, 2.5, 3.5]      # Shear-wave velocities in km/s (layer 2 is a low-velocity zone)

# Combine all parameters into θ = [n, vpvs, h..., vs...]
theta = torch.tensor([[n_layers, vpvs] + thicknesses + vs_layers], dtype=torch.float32)

# Numerical simulation parameters
p_min, p_max = 1.0, 5.0    # Period range in seconds
kmax = 108                   # Number of discrete period samples
iflsph = 0                  # Flat Earth approximation (0 = flat, 1 = spherical)
iwave = 2                   # Wave type (2 = Rayleigh, 1 = Love)
mode = 1                    # Fundamental mode
igr = 1                     # Compute group velocity (1 = group, 0 = phase)

print(f"Theta shape: {theta.shape}")

In [None]:
# Run the simulator
disp_curve = dispsurf2k25_simulator(
    theta=theta,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    iflsph=iflsph,
    iwave=iwave,
    mode=mode,
    igr=igr,
)

print("Dispersion curve shape:", disp_curve.shape)

In [None]:
disp_curve

In [None]:
np.linspace(p_min, p_max, disp_curve.shape[1])

In [None]:
theta

#### 2.2.2 Visualizing the model and its generated dispersion curve

To better understand the simulator output, we use a helper function `plot_velocity_and_dispersion` that displays both the **velocity structure** (shear-wave velocity versus depth) and the resulting **dispersion curve** (group velocity versus period).

The function takes as input:

* `theta` – the parameter tensor of shape `(B, 2 + 2*Nmax)` described above, from which the layer thicknesses (km, the last one 0 for the half-space) and shear-wave velocities ($V_S$, km/s) are read.
* `disp_curve` – the dispersion curve(s) returned by `dispsurf2k25_simulator`, either a `torch.Tensor` or a NumPy array.
* `p_min`, `p_max` – minimum and maximum periods (in seconds) defining the period range.
* `kmax` – number of discrete periods used to compute the curve.
* `z_vnoi` – optional; the second output of `sample_models`, passed when plotting randomly sampled models (see 2.2.4).
* `fig_size` – scaling factor for the figure size (default = 2).
* `dpi` – rendering resolution for the figure (default = 300).

The plot on the **left** shows the 1-D velocity model as a step function of depth.
Each horizontal segment corresponds to a homogeneous layer.
The plot on the **right** shows the simulated Rayleigh-wave group velocity as a function of period.
Short periods sample the shallow layers, while long periods are sensitive to deeper structures.
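The link between period and depth is often summarized by a rule of thumb: fundamental-mode Rayleigh waves are most sensitive to $V_S$ at a depth of roughly one third of their wavelength, $z \approx \lambda/3 = cT/3$ (estimates between a third and a half are common). The sketch applies it with an assumed constant velocity of 3 km/s; it is a heuristic for intuition, not something `surfdisp2k25` computes, and `approx_sensitivity_depth` is a hypothetical helper.

```python
import numpy as np


def approx_sensitivity_depth(periods_s, velocity_km_s, fraction=1.0 / 3.0):
    """Rule-of-thumb depth (km) of peak fundamental-mode Rayleigh sensitivity.

    Uses depth ~ fraction * wavelength, with wavelength = velocity * period.
    """
    wavelength_km = np.asarray(velocity_km_s, dtype=float) * np.asarray(periods_s, dtype=float)
    return fraction * wavelength_km


example_periods = np.array([1.0, 5.0, 15.0])  # s, covering the period ranges used in this notebook
example_depths = approx_sensitivity_depth(example_periods, velocity_km_s=3.0)  # assumed c = 3 km/s
for T, d in zip(example_periods, example_depths):
    print(f"T = {T:4.1f} s  ->  peak sensitivity near {d:.1f} km")
```

Under this approximation, the 1 to 5 s band of the first example mainly constrains the upper few kilometers of the model.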

We now display both the model and the corresponding dispersion curve for our example model.

In [None]:
from compearth.utils import plot_velocity_and_dispersion

# Plot the model and the curve
plot_velocity_and_dispersion(
    theta=theta,
    disp_curve=disp_curve,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    dpi=300
)


#### 2.2.3. Generating multiple random models from a prior

To explore how different subsurface structures affect the dispersion curves, we now generate **random velocity models** using a simple **prior distribution** over the model parameters.

The goal is to produce a small set of plausible Earth models that will later be used to train or test an inference network.

Each model is defined by the same structure as before:

* `n` – number of layers, drawn randomly between `layers_min` = 2 and `layers_max` = 10.
* `vpvs` – fixed P-to-S velocity ratio (here 1.75).
* `h` – random layer thicknesses, constrained by the depth range `z_min` to `z_max` (0 to 5 km) and a minimum thickness `thick_min` = 0.1 km.
* `vs` – random shear-wave velocities ($V_s$), uniformly sampled between 1.5 km/s and 4.5 km/s and, with `sort_vs=True`, ordered to increase with depth.

For each model, these values are concatenated into a parameter vector

$$
\theta = [\, n,\ V_P/V_S,\ h_1, \dots, h_{N_{\max}},\ V_{S,1}, \dots, V_{S,N_{\max}} \,]
$$

We draw several such vectors to obtain a small **ensemble of synthetic models**.
Note that the last layer represents the **half-space** (infinite depth), so its thickness must remain zero and its velocity nonzero to represent the solid substrate.
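The sketch below shows one way to sample such a prior with NumPy. It is not the implementation of `compearth.utils.sample_models`: the interface-placement scheme (uniform interface depths, rejecting draws with a layer thinner than `thick_min`), the zero padding of unused slots, and the name `sample_layered_prior` are all assumptions made for illustration.

```python
import numpy as np


def sample_layered_prior(rng, n_samples, layers_min=2, layers_max=10,
                         z_min=0.0, z_max=5.0, vs_min=1.5, vs_max=4.5,
                         thick_min=0.1, vpvs=1.75, sort_vs=True):
    """Draw zero-padded vectors [n, vpvs, h_1..h_max, vs_1..vs_max] (illustrative prior)."""
    thetas = np.zeros((n_samples, 2 + 2 * layers_max))
    for i in range(n_samples):
        n = rng.integers(layers_min, layers_max + 1)  # upper bound of integers() is exclusive
        # Place n - 1 interfaces; resample until every finite layer is thick enough.
        while True:
            interfaces = np.sort(rng.uniform(z_min, z_max, size=n - 1))
            h = np.diff(np.concatenate([[z_min], interfaces]))
            if np.all(h >= thick_min):
                break
        vs = rng.uniform(vs_min, vs_max, size=n)
        if sort_vs:
            vs = np.sort(vs)  # velocity increases with depth
        thetas[i, 0] = n
        thetas[i, 1] = vpvs
        thetas[i, 2:2 + n - 1] = h                        # half-space thickness stays 0
        thetas[i, 2 + layers_max:2 + layers_max + n] = vs
    return thetas


demo_rng = np.random.default_rng(0)
theta_demo = sample_layered_prior(demo_rng, n_samples=8)
print(theta_demo.shape)  # (8, 22)
```

The rejection loop is simple but needs more retries as `layers_max * thick_min` approaches `z_max - z_min`.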

In [None]:
import numpy as np
from compearth.utils import sample_models

# Sample from the prior
theta_prior, z_vnoi = sample_models(
    n_samples=8,
    layers_min=2,
    layers_max=10,
    z_min=0.0,
    z_max=5.0,
    vs_min=1.5,
    vs_max=4.5,
    thick_min=0.1,
    sort_vs=True
)
theta_prior.shape

In [None]:
theta_prior

In [None]:
z_vnoi

#### 2.2.4. Running the simulator on random models

We now run the simulator on the randomly sampled models.
Each parameter vector in `theta_prior` is passed to `dispsurf2k25_simulator`, which computes the corresponding dispersion curve.
The result `disp_curves` is a tensor of shape `(n_samples, kmax)`, where each row is the group-velocity curve $U(T)$ of one model.
We also define the array of sampled periods to use for plotting.

In [None]:
# Run the simulator on the sampled models
disp_curves = dispsurf2k25_simulator(
    theta=theta_prior,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    iflsph=iflsph,
    iwave=iwave,
    mode=mode,
    igr=igr
)
periods = np.linspace(p_min, p_max, kmax)
disp_curves.shape

In [None]:
theta_prior[0]

In [None]:
z_vnoi[0]

In [None]:
# Plot the model and the curve
plot_velocity_and_dispersion(
    theta=theta_prior,
    disp_curve=disp_curves,
    z_vnoi=z_vnoi,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    dpi=300,
    fig_size=1.5
)

*Observation:*

Each dispersion curve corresponds to a distinct velocity model.
Comparing them hints at the **non-uniqueness** of the inverse problem: different subsurface structures can yield similar dispersion curves.

#### 2.2.5. Discussion — Why this matters

This simulator encapsulates our **forward physical model**, $x = f(\theta)$: it maps a model $\theta$ to a dispersion curve $x$.
The next step will be to **invert** this mapping: given an observed curve ($x$), infer plausible models ($\theta$).
That’s precisely what *Simulation-Based Inference (SBI)* aims to achieve.

## Generating a small dataset

### Defining the problem

In this step, we take one of the previously sampled random models (the first one).
Each model is represented as a parameter vector ($\theta$), which encodes:

1. the **number of layers** ($n$),
2. the **($V_p/V_s$)** ratio,
3. the **layer thicknesses** $(h_1, h_2, \dots )$,
4. and the **shear-wave velocities** $(V_{s,1}, V_{s,2}, \dots )$.

This vector ($\theta$) fully defines the layered Earth model.
We’ll inspect it to see how the model parameters are structured before turning it into a continuous velocity profile.

In [None]:
from compearth.utils import theta_to_velocity_profile

# Take one of the random models
theta_example = theta_prior[0]
theta_example

### Conversion to continuous velocity profile

We now convert the parametric model ($\theta$) into a **continuous velocity profile** ($V_s(z)$),
that is, a regularly sampled velocity map as a function of depth.

The function `theta_to_velocity_profile` performs the following steps:

* it takes a vector ($\theta$) as input,
* reconstructs the layer boundaries from their thicknesses,
* assigns each depth interval its corresponding shear velocity ($V_s$),
* and returns two arrays:

  * `depth`: the depth samples (in km),
  * `vs_profile`: the corresponding ($V_s$) values.

Here, we sample 60 points uniformly between 0 and 5 km depth (`depth_max=5.0`).
The resulting plot shows how shear-wave velocity varies with depth for this particular model.
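The conversion steps listed above can be sketched in a few lines of NumPy. This is not the code of `compearth.utils.theta_to_velocity_profile`: it assumes the zero-padded $\theta$ layout, the assignment of a point lying exactly on an interface to the layer below is an arbitrary choice, and `theta_to_profile_sketch` is a hypothetical name.

```python
import numpy as np


def theta_to_profile_sketch(theta, depth_max=5.0, n_points=60):
    """Sample the piecewise-constant Vs(z) of one (zero-padded) theta on a regular depth grid."""
    theta = np.asarray(theta, dtype=float)
    n_max = (theta.size - 2) // 2
    n = int(theta[0])
    h = theta[2:2 + n_max][:n - 1]          # finite layers only; the half-space has no thickness
    vs = theta[2 + n_max:2 + n_max + n]
    boundaries = np.cumsum(h)                # depths of the layer interfaces (km)
    depth = np.linspace(0.0, depth_max, n_points)
    # Layer index of each depth sample; a point exactly on an interface goes to the layer below,
    # and everything deeper than the last interface falls in the half-space.
    layer_idx = np.searchsorted(boundaries, depth, side="right")
    return depth, vs[layer_idx]


theta_two_layer = [2, 1.75, 2.0, 0.0, 2.5, 3.5]   # 2 km layer over a half-space, n_max = 2
demo_depth, demo_vs = theta_to_profile_sketch(theta_two_layer, depth_max=5.0, n_points=6)
print(demo_depth)  # [0. 1. 2. 3. 4. 5.]
print(demo_vs)     # [2.5 2.5 3.5 3.5 3.5 3.5]
```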

In [None]:
# Convert to continuous velocity profile
depth, vs_profile = theta_to_velocity_profile(
    theta_example, 
    depth_max=5.0, 
    n_points=60
)
print(f"Depth shape: {depth.shape}")
print(f"Vs profile shape: {vs_profile.shape}")

In [None]:
# Plot Vs against depth at the sampled points
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 4), dpi=300)
plt.scatter(depth, vs_profile, color='crimson', s=25, marker='o', zorder=3, label='Sample points')
plt.ylabel("Vs [km/s]")
plt.xlabel("Depth [km]")
plt.title("Continuous velocity profile from θ")
plt.grid(True, linestyle='--', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Generate dataset
# n_samples = 10_000
n_samples = 100
max_layers = 10
z_min = 0.0
z_max = 5.0
vs_min = 0.5
vs_max = 4.0
thick_min = 0.1
n_points_depth = 60
kmax = 108

# Period range (s) used by the simulator
p_min, p_max = 1.0, 15.0

In [None]:
theta_models, z_vnoi = sample_models(
    n_samples=n_samples,
    layers_min=2,
    layers_max=max_layers,
    z_min=z_min,
    z_max=z_max,
    vs_min=vs_min,
    vs_max=vs_max,
    thick_min=thick_min,
    sort_vs=True
)
print(f"Theta models: {theta_models.shape}")

In [None]:
# Run simulator to get dispersion curves
disp_curves = dispsurf2k25_simulator(
    theta=theta_models,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    iflsph=iflsph,
    iwave=iwave,
    mode=mode,
    igr=igr,
    progress=True
)

In [None]:
# Build velocity profiles
vel_maps = []
for i in range(n_samples):
    _, vs_profile = theta_to_velocity_profile(
        theta_models[i],
        depth_max=z_max,
        n_points=n_points_depth
    )
    vel_maps.append(vs_profile)
# end for

In [None]:
vel_maps = np.stack(vel_maps, axis=0)   # shape: (B, 60)
z = np.linspace(0, z_max, n_points_depth)
periods = np.linspace(p_min, p_max, kmax)

# Convert to torch tensors.
# From here on, `theta` is the Vs profile sampled on the depth grid (the inference target),
# not the layered parameter vector returned by `sample_models`.
theta = torch.tensor(vel_maps, dtype=torch.float32)   # (B, 60)
x = disp_curves.to(torch.float32)                     # (B, 108)
z = torch.tensor(z, dtype=torch.float32)              # (60,)
periods = torch.tensor(periods, dtype=torch.float32)  # (108,)

# Display a summary
print(f"theta (velocity maps): {tuple(theta.shape)}")
print(f"x (dispersion curves): {tuple(x.shape)}")
print(f"z (depth samples):     {tuple(z.shape)}")
print(f"periods:               {tuple(periods.shape)}")

In [None]:
print(f"z: {z}")
print(f"periods: {periods}")

## Learning the Inverse Model — Neural Posterior Estimation (NPE)

### Conceptual overview

### Setting up the SBI pipeline

#### Defining the prior

In [None]:
from sbi.utils import BoxUniform

# Define the prior used for generation
prior = BoxUniform(
    low=torch.full((60,), vs_min, device=device),
    high=torch.full((60,), vs_max, device=device)
)

#### Defining the model and adding data

In [None]:
from sbi.inference import SNPE

inference = SNPE(
    prior=prior,
    density_estimator="maf",
    device=device
)

In [None]:
theta = theta.to(device)
x = x.to(device)

In [None]:
inference = inference.append_simulations(theta, x)
inference

#### Training

In [None]:
batch_size = 1024
learning_rate = 1e-4
validation_fraction = 0.1
stop_after_epochs = 200
max_num_epochs = 1000
show_train_summary = True

density_estimator = inference.train(
    training_batch_size=batch_size,
    learning_rate=learning_rate,
    validation_fraction=validation_fraction,
    stop_after_epochs=stop_after_epochs,
    max_num_epochs=max_num_epochs,
    show_train_summary=show_train_summary,
)

In [None]:
# inference._summary["training_log_probs"]
# inference._summary["validation_log_probs"]
print(f"Epochs trained: {inference._summary['epochs_trained'][0]}")
print(f"Best validation loss: {inference._summary['best_validation_loss'][0]}")

#### Building the posterior

In [None]:
# Build posterior
posterior = inference.build_posterior(density_estimator)
print(posterior)

#### Visualising training and validation curves

In [None]:
from compearth.utils import plot_training_summary
plot_training_summary(inference)

### Training with a bigger dataset

#### Loading the data

In [None]:
from huggingface_hub import hf_hub_download
import pickle

repo_id = "MIGRATE/Dispsurf96-Roccastrada-10k"

path = hf_hub_download(repo_id=repo_id, filename="data.pkl", repo_type="dataset")
with open(path, "rb") as f:
    data = pickle.load(f)
# end with

path = hf_hub_download(repo_id=repo_id, filename="val_data.pkl", repo_type="dataset")
with open(path, "rb") as f:
    val_data = pickle.load(f)
# end with

In [None]:
theta = data['theta']
x = data['x']
z = data['z']
periods = data['periods']
prior = data['prior']

vs_min = prior['vs'][0]
vs_max = prior['vs'][1]
z_min = prior['z'][0]
z_max = prior['z'][1]
layer_min = prior['layers'][0]
layer_max = prior['layers'][1]

In [None]:
print("Training:")
print(f"theta: {theta.shape}")
print(f"x: {x.shape}")
print(f"z: {z.shape}")
print(f"periods: {periods.shape}")
print(f"prior: {prior}")

print("Validation:")
print(f"theta: {val_data['theta'].shape}")
print(f"x: {val_data['x'].shape}")

#### Training the model

In [None]:
from sbi.inference import SNPE

# Define the prior used for generation
prior = BoxUniform(
    low=torch.full((60,), vs_min, device=device),
    high=torch.full((60,), vs_max, device=device)
)

inference = SNPE(
    prior=prior,
    density_estimator="maf",
    device=device
)

theta = theta.to(device)
x = x.to(device)

inference = inference.append_simulations(theta, x)

batch_size = 1024
learning_rate = 1e-4
validation_fraction = 0.2
stop_after_epochs = 300
max_num_epochs = 3000
show_train_summary = True

density_estimator = inference.train(
    training_batch_size=batch_size,
    learning_rate=learning_rate,
    validation_fraction=validation_fraction,
    stop_after_epochs=stop_after_epochs,
    max_num_epochs=max_num_epochs,
    show_train_summary=show_train_summary
)

# Build posterior
posterior = inference.build_posterior(density_estimator)

In [None]:
# inference._summary["training_log_probs"]
# inference._summary["validation_log_probs"]
print(f"Epochs trained: {inference._summary['epochs_trained'][0]}")
print(f"Best validation loss: {inference._summary['best_validation_loss'][0]}")

In [None]:
plot_training_summary(inference)

### Visualizing the learned posterior

In [None]:
# Use the held-out validation set loaded above (not the training data)
theta_val = val_data['theta']
x_val = val_data['x']

theta_val = theta_val.to(device)
x_val = x_val.to(device)

In [None]:
obs_idx = 6
x_obs = x_val[obs_idx:obs_idx+1, :]
print(f"x_obs: {x_obs.shape}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 3), dpi=300)

axes[1].plot(periods, x_obs[0].cpu(), "blue", linewidth=2, alpha=1)
axes[1].set_xlim(float(periods.min()), float(periods.max()))
axes[1].set_xlabel("Period [s]")
axes[1].set_ylabel("Vg [km/s]")
axes[1].set_title(f"Dispersion curve (Vg) for observation {obs_idx}")
axes[1].grid(True, linestyle="--", alpha=0.5)

axes[0].plot(z, theta_val[obs_idx].cpu(), 'red', linewidth=2)
axes[0].set_xlim(0, z_max)
axes[0].set_ylim(vs_min, vs_max)
axes[0].set_xlabel("Depth [km]")
axes[0].set_ylabel("Vs [km/s]")
axes[0].set_title(f"Velocity map (Vs) for model {obs_idx}")
axes[0].grid(True, linestyle="--", alpha=0.5)

plt.show()

In [None]:
samples = posterior.sample((100,), x=x_obs)
print(f"Sample shape: {samples.shape}")

In [None]:
plt.figure(figsize=(8, 4), dpi=300)
for i in range(100):
    plt.plot(z, samples[i].cpu(), linewidth=1, alpha=0.5)
# end for
plt.plot(z, theta_val[obs_idx].cpu(), 'red', linewidth=2)
plt.xlim(0, z_max)
plt.ylim(vs_min, vs_max)
plt.xlabel("Depth [km]")
plt.ylabel("Vs [km/s]")
plt.title(f"Posterior samples and ground truth for model {obs_idx}")
plt.grid(True, linestyle="--", alpha=0.5)
plt.show()

In [None]:
from compearth.utils import plot_posterior_grid
plot_posterior_grid(
    posterior=posterior,
    x=x_val,
    theta=theta_val,
    z=z,
    n_row=5,
    n_col=3,
    n_samples=100,
    vs_min=vs_min,
    vs_max=vs_max,
    figsize=(9, 12),
    device=device
)

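### Inferring flat models

The posterior samples are continuous, noisy $V_S$ profiles, whereas the models used to generate the data are piecewise constant (layered). To recover a layered model from a sample, we segment it with the PELT change-point algorithm from the `ruptures` package and replace each segment by its mean. The penalty `pen` controls how many layers are kept: the larger it is, the fewer breakpoints PELT accepts.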
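Every flattening cell in this section ends with the same step: PELT returns a list of breakpoints, and each segment between consecutive breakpoints is replaced by its mean. The sketch isolates that step in a standalone function so it can be checked without `ruptures`. It assumes the `ruptures` convention that `predict` returns sorted segment end indices, the last one being the signal length; `flatten_with_breakpoints` is a hypothetical name, not part of `compearth` or `ruptures`.

```python
import numpy as np


def flatten_with_breakpoints(signal, bkps):
    """Replace each segment [start, end) of `signal` by its mean.

    `bkps` are sorted segment end indices, the last one equal to len(signal).
    """
    signal = np.asarray(signal, dtype=float)
    flat = np.empty_like(signal)
    start = 0
    for end in bkps:
        flat[start:end] = signal[start:end].mean()
        start = end
    return flat


noisy = np.array([2.0, 2.5, 1.5, 3.0, 3.5, 2.5])
print(flatten_with_breakpoints(noisy, [3, 6]))  # [2. 2. 2. 3. 3. 3.]
```

Combined with PELT this would read `flatten_with_breakpoints(s, rpt.Pelt(model="l2").fit(s).predict(pen=penalty))`, which is what the loops in `plot_flatten_models`, `flatten_models` and their random-penalty variants do inline.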

In [None]:
# Plot a single posterior sample against the ground truth
plt.figure(figsize=(10, 5))

# One posterior sample (Vs profile)
plt.plot(z, samples[0].cpu().numpy(), alpha=1, color="blue", label="Posterior sample")

# Ground-truth profile as a thick red line
plt.plot(z, theta_val[obs_idx].cpu().numpy(), color="red", linewidth=2, label="Ground truth")
plt.ylim(vs_min - 0.1, vs_max + 0.1)
plt.xlim(z.min(), z.max())
plt.xlabel("Depth [km]")
plt.ylabel("Vs [km/s]")
plt.title("A single posterior sample vs ground truth")
plt.legend()
plt.show()

In [None]:
import ruptures as rpt

sample = samples[0].cpu().numpy()

algo = rpt.Pelt(model="l2").fit(sample)
bkps = algo.predict(pen=5)
vs_flat = np.zeros_like(sample)
start = 0
for end in bkps:
    print(f"Region found with range {start}-{end}")
    vs_flat[start:end] = np.mean(sample[start:end])
    start = end
# end for

# Ground truth, raw posterior sample and its flattened version
plt.figure(figsize=(8, 6), dpi=300)
plt.plot(z, theta_val[obs_idx].cpu().numpy(), color="red", linewidth=3, label="Ground truth", alpha=0.2)
plt.plot(z, samples[0].cpu().numpy(), alpha=0.5, color="blue", label="Posterior sample", linewidth=1)
plt.plot(z, vs_flat, alpha=1, color="orange", label="Flattened posterior sample", linewidth=2)
plt.ylim(vs_min - 0.1, vs_max + 0.1)
plt.xlim(z.min(), z.max())
plt.xlabel("Depth [km]")
plt.ylabel("Vs [km/s]")
plt.title("Flattened posterior sample vs ground truth")
plt.legend()
plt.show()

In [None]:
print(theta_val.shape)
print(z.shape)

In [None]:
from compearth.utils import plot_flatten_grid

# Show the effect of different values
# of the PELT penalty parameter
plot_flatten_grid(
    sample=samples[0].cpu().numpy(),
    theta=theta_val[obs_idx].cpu(),
    z=z.cpu(),
    n_row=5,
    n_col=3,
    penalty_min=0,
    penalty_max=15,
    vs_min=vs_min,
    vs_max=vs_max,
    figsize=(8, 11)
)

In [None]:
from typing import Union

def plot_flatten_models(
    samples: Union[np.ndarray, torch.Tensor],
    theta: Union[np.ndarray, torch.Tensor],
    z: Union[np.ndarray, torch.Tensor],
    penalty: float = 1.0,
    vs_min: float = 1.5,
    vs_max: float = 4.5,
    figsize: tuple = (8, 5),
    dpi: int = 200,
):
    """
    Flatten all posterior samples using PELT with a fixed penalty
    and display them together with the ground truth profile.

    Parameters
    ----------
    samples : np.ndarray or torch.Tensor
        Posterior samples (N, D_z)
    theta : np.ndarray or torch.Tensor
        Ground truth velocity profile (D_z,)
    z : np.ndarray or torch.Tensor
        Depth coordinates (D_z,)
    penalty : float
        Penalty value for PELT algorithm.
    vs_min, vs_max : float
        Limits for Vs axis.
    figsize : tuple
        Figure size.
    dpi : int
        Plot resolution.
    """
    # --- Convert tensors to numpy ---
    if hasattr(samples, "detach"):
        samples = samples.detach().cpu().numpy()
    if hasattr(theta, "detach"):
        theta = theta.detach().cpu().numpy()
    if hasattr(z, "detach"):
        z = z.detach().cpu().numpy()

    n_samples = samples.shape[0]

    # --- Prepare figure ---
    plt.figure(figsize=figsize, dpi=dpi)

    # Plot flattened posterior samples
    for i in range(n_samples):
        s = samples[i]
        algo = rpt.Pelt(model="l2").fit(s)
        bkps = algo.predict(pen=penalty)

        vs_flat = np.zeros_like(s)
        start = 0
        for end in bkps:
            vs_flat[start:end] = np.mean(s[start:end])
            start = end

        plt.plot(z, vs_flat, alpha=0.3, linewidth=1)
    # end for

    # Plot ground truth
    plt.plot(z, theta, color="red", linewidth=2, label="Ground truth", alpha=0.7)

    # Style
    plt.xlabel("Depth [km]")
    plt.ylabel("Vs [km/s]")
    plt.title(f"Flattened posterior samples (penalty={penalty:.2f})")
    plt.ylim(vs_min - 0.1, vs_max + 0.1)
    plt.xlim(z.min(), z.max())
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.legend()
    plt.show()
# end def plot_flatten_models

In [None]:
plot_flatten_models(
    samples=samples[:20],   # e.g. the first 20 posterior samples
    theta=theta_val[obs_idx],
    z=z,
    penalty=0.0,
    vs_min=vs_min,
    vs_max=vs_max,
)

In [None]:
def plot_random_flatten_models(
    samples: Union[np.ndarray, torch.Tensor],
    theta: Union[np.ndarray, torch.Tensor],
    z: Union[np.ndarray, torch.Tensor],
    penalty_min: float = 0.1,
    penalty_max: float = 5.0,
    vs_min: float = 1.5,
    vs_max: float = 4.5,
    figsize: tuple = (8, 5),
    dpi: int = 200,
    seed: int = 42,
):
    """
    Flatten each posterior sample with a random penalty drawn uniformly
    between `penalty_min` and `penalty_max`, and display all results.

    Parameters
    ----------
    samples : np.ndarray or torch.Tensor
        Posterior samples (N, D_z)
    theta : np.ndarray or torch.Tensor
        Ground truth velocity profile (D_z,)
    z : np.ndarray or torch.Tensor
        Depth coordinates (D_z,)
    penalty_min, penalty_max : float
        Range for randomly sampled penalties.
    vs_min, vs_max : float
        Velocity axis limits.
    figsize : tuple
        Figure size.
    dpi : int
        Plot resolution.
    seed : int
        Random seed for reproducibility.
    """
    # --- Convert tensors to numpy ---
    if hasattr(samples, "detach"):
        samples = samples.detach().cpu().numpy()
    if hasattr(theta, "detach"):
        theta = theta.detach().cpu().numpy()
    if hasattr(z, "detach"):
        z = z.detach().cpu().numpy()

    rng = np.random.default_rng(seed)
    n_samples = samples.shape[0]

    plt.figure(figsize=figsize, dpi=dpi)

    # --- Process each sample ---
    for i in range(n_samples):
        s = samples[i]
        penalty = rng.uniform(penalty_min, penalty_max)

        algo = rpt.Pelt(model="l2").fit(s)
        bkps = algo.predict(pen=penalty)

        vs_flat = np.zeros_like(s)
        start = 0
        for end in bkps:
            vs_flat[start:end] = np.mean(s[start:end])
            start = end

        plt.plot(z, vs_flat, alpha=0.3, linewidth=1)
    # end for

    # Plot ground truth
    plt.plot(z, theta, color="red", linewidth=2, alpha=0.7, label="Ground truth")

    plt.xlabel("Depth [km]")
    plt.ylabel("Vs [km/s]")
    plt.title(f"Random flattening (penalty ∈ [{penalty_min}, {penalty_max}])")
    plt.ylim(vs_min - 0.1, vs_max + 0.1)
    plt.xlim(z.min(), z.max())
    plt.grid(True, linestyle="--", alpha=0.4)
    plt.legend()
    plt.show()
# end def plot_random_flatten_models

In [None]:
plot_random_flatten_models(
    samples=samples[:50],
    theta=theta_val[obs_idx],
    z=z,
    penalty_min=0.0,
    penalty_max=10.0,
    vs_min=vs_min,
    vs_max=vs_max,
)

In [None]:
def flatten_models(
    samples: Union[np.ndarray, torch.Tensor],
    penalty: float = 1.0,
    model: str = "l2",
) -> np.ndarray:
    """
    Flatten posterior samples using PELT segmentation with a fixed penalty.

    Parameters
    ----------
    samples : np.ndarray or torch.Tensor
        Posterior samples of shape (N, D_z)
    penalty : float
        Penalty value for PELT algorithm.
    model : str
        Cost model for ruptures (default: "l2").

    Returns
    -------
    vs_flat_all : np.ndarray
        Flattened velocity profiles of shape (N, D_z)
    """
    # --- Convert to numpy ---
    if hasattr(samples, "detach"):
        samples = samples.detach().cpu().numpy()

    n_samples, depth_points = samples.shape
    vs_flat_all = np.zeros_like(samples)

    # --- Apply PELT to each posterior sample ---
    for i in range(n_samples):
        s = samples[i]
        algo = rpt.Pelt(model=model).fit(s)
        bkps = algo.predict(pen=penalty)

        vs_flat = np.zeros_like(s)
        start = 0
        for end in bkps:
            vs_flat[start:end] = np.mean(s[start:end])
            start = end

        vs_flat_all[i] = vs_flat
    # end for

    return vs_flat_all
# end def flatten_models

In [None]:
vs_flat_all = flatten_models(samples=samples, penalty=1.0)
print(vs_flat_all.shape)  # (100, D_z)

In [None]:
def random_flatten_models(
    samples: Union[np.ndarray, torch.Tensor],
    penalty_min: float = 0.1,
    penalty_max: float = 5.0,
    model: str = "l2",
    seed: int = 42,
) -> np.ndarray:
    """
    Flatten posterior samples using PELT segmentation, with a random penalty
    drawn uniformly between `penalty_min` and `penalty_max` for each sample.

    Parameters
    ----------
    samples : np.ndarray or torch.Tensor
        Posterior samples of shape (N, D_z)
    penalty_min, penalty_max : float
        Range of random penalties for the PELT algorithm.
    model : str
        Cost model for ruptures (default: "l2").
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    vs_flat_all : np.ndarray
        Flattened velocity profiles of shape (N, D_z)
    """
    # --- Convert to numpy ---
    if hasattr(samples, "detach"):
        samples = samples.detach().cpu().numpy()

    n_samples, depth_points = samples.shape
    vs_flat_all = np.zeros_like(samples)

    rng = np.random.default_rng(seed)

    # --- Flatten each sample with a random penalty ---
    for i in range(n_samples):
        s = samples[i]
        penalty = rng.uniform(penalty_min, penalty_max)

        algo = rpt.Pelt(model=model).fit(s)
        bkps = algo.predict(pen=penalty)

        vs_flat = np.zeros_like(s)
        start = 0
        for end in bkps:
            vs_flat[start:end] = np.mean(s[start:end])
            start = end

        vs_flat_all[i] = vs_flat
    # end for

    return vs_flat_all
# end def random_flatten_models

In [None]:
vs_flat_rand = random_flatten_models(
    samples=samples,
    penalty_min=0.0,
    penalty_max=10.0,
    seed=123,
)
print(vs_flat_rand.shape)  # (100, D_z)

### Saving the model

In [None]:
# Save the trained posterior as a whole with pickle.
# A posterior trained on GPU holds CUDA tensors, so reload it on a machine with a GPU.
# import pickle
# with open("posterior.pkl", "wb") as f:
#     pickle.dump(posterior, f)

In [None]:
# Reload the posterior later without retraining.
# (Restoring `inference._neural_net.state_dict()` into a fresh `SNPE` object fails,
#  because its `_neural_net` is only created when `train()` is called.)
# import pickle
# with open("posterior.pkl", "rb") as f:
#     posterior = pickle.load(f)

## Working with Pre-Trained Models

### Motivation

### Loading a pre-trained model

### Sampling and visualization

## Model Validation and Diagnostics

### Log-probability

Once the posterior distribution has been trained, there are **two complementary ways to use it**:

1. **Sampling** from the posterior: given an observed dispersion curve $x_{obs}$, we can draw samples from $p(\theta \mid x_{obs})$. Each sample corresponds to a possible **Earth model**, a velocity profile consistent with the observed data. Sampling allows us to *visualize uncertainty*: it shows the range of models that could explain the same observation.

2. **Evaluating the posterior density**: conversely, we can compute the **log-probability** of a specific model $\theta$ given the observation $x_{obs}$:

   $$
   \large
   \log p(\theta \mid x_{obs})
   $$
   
   This value tells us **how plausible** that model is according to the learned posterior: the higher the log-probability, the more plausible $\theta$ is given $x_{obs}$. Because $\theta$ is continuous, this is the logarithm of a *density*, so it can be positive as well as negative.

#### Intuition

Think of the posterior as a *landscape of plausibility*:

* **Sampling** means picking random points from this landscape, exploring regions of high probability.
* **Evaluating the log-probability** means checking how "high" or "low" a given point lies in this landscape.

Mathematically, the log-probability corresponds to the logarithm of the posterior density at a specific location $\theta$.
It is often preferred over the probability itself because densities can be extremely small, and the log-scale gives a more stable numerical measure.
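A quick illustration with made-up numbers: multiplying many small density values in double precision underflows to zero, whereas summing their logarithms stays finite.

```python
import numpy as np

# Made-up example: 400 independent factors, each with density value 1e-3
densities = np.full(400, 1e-3)

joint = np.prod(densities)             # true value 1e-1200 is far below the float64 range
log_joint = np.sum(np.log(densities))  # 400 * log(1e-3), comfortably representable

print(joint)      # 0.0
print(log_joint)
```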

#### How to interpret log-probability values

* **High** → The posterior assigns *high plausibility* to this model: it is consistent with the observation $x_{obs}$.
* **Low (very negative)** → The posterior considers this model *unlikely* under the given observation.
* **Average log-probability** of the true parameters over many held-out pairs $(\theta_i, x_i)$ → A useful quantitative measure of posterior quality (higher is better).
  If the average log-probability is too low, the posterior might be **overconfident** (too narrow, so the true parameters fall in its tails), **biased**, or otherwise **miscalibrated**.
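The last point can be made concrete with a toy sketch. It uses a conjugate Gaussian model chosen for illustration, not the seismic problem: prior $\theta \sim \mathcal{N}(0, 1)$ and observation $x = \theta + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$, so the exact posterior is $\mathcal{N}(x/2, 1/2)$. We score the true $\theta$ under the exact posterior and under artificially narrowed and widened versions of it. Both over- and under-confidence lower the average log-probability:

```python
import numpy as np

rng_toy = np.random.default_rng(0)
n_pairs = 100_000

# Simulate held-out (theta, x) pairs from the toy model
theta_toy = rng_toy.normal(0.0, 1.0, size=n_pairs)
x_toy = theta_toy + rng_toy.normal(0.0, 1.0, size=n_pairs)

def gaussian_log_pdf(v, mean, var):
    """Log-density of N(mean, var) evaluated at v."""
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (v - mean) ** 2 / var

# The exact posterior variance is 0.5; the other two are too narrow / too wide
posterior_vars = {"overconfident": 0.05, "calibrated": 0.5, "underconfident": 5.0}
mean_log_prob = {
    name: float(np.mean(gaussian_log_pdf(theta_toy, x_toy / 2, var)))
    for name, var in posterior_vars.items()
}

for name, value in mean_log_prob.items():
    print(f"{name:>15s}: {value:.2f}")
```

In expectation, the three scores are roughly -4.42, -1.07 and -1.77, so the calibrated posterior scores highest.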

#### How the log-probability is computed

Under the hood, the posterior is represented by a **normalizing flow**: an *invertible neural network* that transforms a simple base distribution (usually a standard Gaussian) into a complex, structured one that approximates the true posterior $p(\theta \mid x)$.

Because this transformation $f$ is **invertible** and **differentiable**, the probability density of a sample $\theta$ can be computed exactly using the **change of variables** formula:

$$
p(\theta \mid x) = p_z(f^{-1}(\theta)) \Big| \det J_{f^{-1}}(\theta) \Big|
$$

where:

* $p_z$ is the base density (often $\mathcal{N}(0, I)$),
* $f^{-1}(\theta)$ maps the parameter sample back into the latent space,
* and $J_{f^{-1}}(\theta)$ is the Jacobian of the inverse transformation.

Taking the logarithm gives the **log-probability**:

$$
\log p(\theta \mid x) = \log p_z(f^{-1}(\theta)) + \log \Big| \det J_{f^{-1}}(\theta) \Big|
$$

This property is what makes normalizing flows so powerful:
they allow both **sampling** (by applying $f$) and **density evaluation** (via $f^{-1}$) in a mathematically consistent way.
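The change-of-variables formula can be checked on the simplest possible flow. This sketch assumes an elementwise affine map $\theta = f(z) = \mu + \sigma \odot z$. The values of $\mu$ and $\sigma$ are made up and stand in for what a conditional flow would output for a given $x$.

For this map:

* $f^{-1}(\theta) = (\theta - \mu)/\sigma$
* $\log|\det J_{f^{-1}}| = -\sum_d \log \sigma_d$

The result must match the log-density of $\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ computed directly:

```python
import numpy as np

rng_toy = np.random.default_rng(0)

# Hypothetical flow parameters (a real conditional flow would compute these from x)
mu_toy = np.array([3.0, 3.5, 4.2])
sigma_toy = np.array([0.2, 0.1, 0.3])

def log_base(z):
    """Log-density of the standard Gaussian base distribution p_z."""
    return -0.5 * np.sum(z**2, axis=-1) - 0.5 * z.shape[-1] * np.log(2 * np.pi)

def flow_log_prob(theta):
    """log p(theta) = log p_z(f^{-1}(theta)) + log|det J_{f^{-1}}(theta)|."""
    z = (theta - mu_toy) / sigma_toy          # inverse transform f^{-1}
    log_det = -np.sum(np.log(sigma_toy))      # Jacobian of f^{-1} is diag(1 / sigma)
    return log_base(z) + log_det

# Sampling: push base samples through f
z_base = rng_toy.standard_normal((5, 3))
theta_toy = mu_toy + sigma_toy * z_base

# Direct Gaussian log-density for comparison
direct = np.sum(
    -0.5 * ((theta_toy - mu_toy) / sigma_toy) ** 2
    - np.log(sigma_toy)
    - 0.5 * np.log(2 * np.pi),
    axis=-1,
)
print(np.allclose(flow_log_prob(theta_toy), direct))  # True
```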


In [None]:
from tqdm import tqdm

# Average log-probability of the true parameters over 20 held-out (theta, x) pairs
n_samples = 20

total_log_probs = 0.0
for i in tqdm(range(n_samples), desc="Computing log-prob"):
    # Evaluate log p(theta_i | x_i) for one validation pair at a time
    log_p = posterior.log_prob(theta_val[i:i+1], x=x_val[i:i+1])
    total_log_probs += log_p.item()
# end for

# Average log-probability (higher is better)
print(f"Average log-prob: {total_log_probs / n_samples:.3f}")

### Posterior Predictive Check (PPC)

In [None]:
from typing import Union, Tuple

import ruptures as rpt  # PELT change-point detection (pip install ruptures)

def posterior_to_theta(
    z: Union[np.ndarray, torch.Tensor],
    vs_batch: Union[np.ndarray, torch.Tensor],
    vpvs: float = 1.75,
    penalty_min: float = 0.1,
    penalty_max: float = 5.0,
    model: str = "l2",
    max_layers: int = 20,
    seed: int = 42,
) -> Tuple[torch.Tensor, list, np.ndarray]:
    """
    Convert multiple posterior velocity samples into layered Earth models (θ),
    using PELT segmentation with a random penalty drawn uniformly for each sample.
    Breakpoints are returned in depth (km) rather than indices.

    Parameters
    ----------
    z : np.ndarray or torch.Tensor
        Depth coordinates (D_z,), in km.
    vs_batch : np.ndarray or torch.Tensor
        Velocity profiles (N, D_z)
    vpvs : float
        Fixed Vp/Vs ratio for all models
    penalty_min, penalty_max : float
        Range for random penalties
    model : str
        Cost model for ruptures (default: 'l2')
    max_layers : int
        Maximum number of layers (for padding in θ)
    seed : int
        Random seed for reproducibility

    Returns
    -------
    theta_all : torch.Tensor
        Model parameters of shape (N, 2 + 2 * max_layers)
        [n_layers, vpvs, h_1...h_max, vs_1...vs_max]
    bkps_all_km : list[list[float]]
        List of breakpoint depths (km) for each sample
    penalties : np.ndarray
        Penalties drawn for each sample
    """
    # --- Convert to numpy ---
    if hasattr(vs_batch, "detach"):
        vs_batch = vs_batch.detach().cpu().numpy()
    # end if
    
    if hasattr(z, "detach"):
        z = z.detach().cpu().numpy()
    # end if

    n_samples, n_depths = vs_batch.shape
    theta_all = []
    bkps_all_km = []
    penalties = []

    rng = np.random.default_rng(seed)

    for i in range(n_samples):
        vs = vs_batch[i]
        penalty = rng.uniform(penalty_min, penalty_max)
        penalties.append(penalty)

        algo = rpt.Pelt(model=model).fit(vs)
        bkps = algo.predict(pen=penalty)

        # --- Convert breakpoints (indices) -> depths (km)
        bkps_depth = [z[min(end - 1, len(z) - 1)] for end in bkps if end <= len(z)]
        bkps_all_km.append(bkps_depth)

        # --- Build layers: one layer per PELT segment ---
        vs_layers = []
        h_layers = []
        start = 0

        for end in bkps:
            # Layer velocity = mean Vs over the segment
            vs_layers.append(float(np.mean(vs[start:end])))
            # Layer thickness = distance from this segment's top to the next segment's top
            if end < n_depths:
                h_layers.append(float(z[end] - z[start]))
            else:
                h_layers.append(0.0)
            # end if
            start = end
        # end for

        # Keep at most max_layers layers so that n_layers matches the padded θ
        vs_layers = vs_layers[:max_layers]
        h_layers = h_layers[:max_layers]

        # Half-space (last layer)
        if h_layers:
            h_layers[-1] = 0.0
        # end if

        # Padding
        n_layers = len(vs_layers)
        h_padded = np.zeros(max_layers)
        vs_padded = np.zeros(max_layers)
        h_padded[:n_layers] = h_layers
        vs_padded[:n_layers] = vs_layers

        # Assemble θ vector
        theta = [n_layers, vpvs] + h_padded.tolist() + vs_padded.tolist()
        theta_all.append(theta)
    # end for

    theta_all = torch.tensor(theta_all, dtype=torch.float32)
    penalties = np.array(penalties)

    return theta_all, bkps_all_km, penalties
# end def posterior_to_theta
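The key step in `posterior_to_theta` is converting segment boundaries into the padded vector $[n_{layers}, V_p/V_s, h_1 \dots h_{max}, V_{s,1} \dots V_{s,max}]$. Here is a minimal sketch of that step on a hand-made profile with fixed breakpoints, so it runs without `ruptures`. The helper name `segments_to_theta` and the toy depth grid are illustrative assumptions.

The sketch makes these choices:

* Each layer's thickness is the distance between consecutive segment tops.
* Breakpoints follow the `ruptures` convention: each entry is the exclusive end index of a segment, and the last entry equals the number of samples.

```python
import numpy as np

def segments_to_theta(z, vs, bkps, vpvs=1.75, max_layers=5):
    """Hypothetical helper: turn a Vs profile and its breakpoints into a padded θ vector."""
    n = len(vs)
    vs_layers, h_layers = [], []
    start = 0
    for end in bkps:
        vs_layers.append(float(np.mean(vs[start:end])))                # mean Vs of the segment
        h_layers.append(float(z[end] - z[start]) if end < n else 0.0)  # last = half-space
        start = end
    vs_layers, h_layers = vs_layers[:max_layers], h_layers[:max_layers]
    h_layers[-1] = 0.0                                                 # bottom layer is the half-space
    n_layers = len(vs_layers)
    h_padded = np.zeros(max_layers)
    vs_padded = np.zeros(max_layers)
    h_padded[:n_layers] = h_layers
    vs_padded[:n_layers] = vs_layers
    return np.concatenate([[n_layers, vpvs], h_padded, vs_padded])

# Toy profile sampled every 1 km: 3.0 km/s at 0-3 km, 3.5 km/s at 4-6 km, 4.5 km/s below
z_toy = np.arange(10.0)
vs_toy = np.array([3.0] * 4 + [3.5] * 3 + [4.5] * 3)
theta_toy = segments_to_theta(z_toy, vs_toy, bkps=[4, 7, 10])
print(theta_toy)  # 3 layers, vpvs 1.75, thicknesses 4, 3, 0 km, velocities 3.0, 3.5, 4.5 km/s
```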

In [None]:
samples = posterior.sample((2000,), x=x_obs)
print(f"Sample shape: {samples.shape}")

In [None]:
theta_models, bkps, _ = posterior_to_theta(
    z=z,
    vs_batch=samples,
    vpvs=1.75,
    penalty_min=0.0,
    penalty_max=15.0,
    max_layers=layer_max
)

print(theta_models.shape)  # (2000, 2 + 2 * layer_max)
print(bkps[1])  # breakpoint depths (km) of the second sample (index 1)

In [None]:
theta_models[1]

In [None]:
bkps[1]

In [None]:
# Plot the model and the curve
plot_velocity_and_dispersion(
    theta=theta_models[1],
    disp_curve=x_obs[0],
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    dpi=300
)


In [None]:
# Run the simulator on the sampled models
disp_curves = dispsurf2k25_simulator(
    theta=theta_models,
    p_min=p_min,
    p_max=p_max,
    kmax=kmax,
    iflsph=iflsph,
    iwave=iwave,
    mode=mode,
    igr=igr,
    progress=True
)

In [None]:
disp_curves.shape

In [None]:
results = []
x_ref = x_obs[0].detach().cpu()
for i in range(disp_curves.shape[0]):
    # RMS misfit between the simulated and observed dispersion curves
    rms = torch.sqrt(torch.mean((disp_curves[i].cpu() - x_ref) ** 2)).item()
    results.append({
        "disp_curve": disp_curves[i].cpu(),
        "rms": rms
    })
# end for

# Sort by RMS (best fit first)
results_sorted = sorted(results, key=lambda r: r["rms"])
best = results_sorted[:20]
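The misfit used for ranking is computed over the $P$ periods as

$$
\mathrm{RMS}_i = \sqrt{\frac{1}{P}\sum_{p=1}^{P} \big(d_i(T_p) - x_{obs}(T_p)\big)^2}
$$

As a sketch, the loop can be replaced by a single broadcast operation followed by `np.argsort`. Toy arrays stand in for `disp_curves` and `x_obs[0]`. Each toy curve is shifted by a known offset, so its RMS equals the absolute offset:

```python
import numpy as np

# Toy stand-ins: 6 simulated curves and one observed curve on 50 periods
obs_toy = np.linspace(2.8, 4.0, 50)
offsets = np.array([0.3, -0.1, 0.05, 0.2, -0.4, 0.0])
sim_toy = obs_toy + offsets[:, None]

# RMS misfit of every curve at once (obs_toy is broadcast over the rows)
rms_toy = np.sqrt(np.mean((sim_toy - obs_toy) ** 2, axis=1))

# Indices of the k best-fitting curves, smallest misfit first
k = 3
best_idx = np.argsort(rms_toy)[:k]
print(best_idx)
```

With the notebook's tensors, the same lines would apply to `disp_curves.cpu().numpy()` and `x_obs[0].cpu().numpy()`.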

In [None]:
plt.figure(figsize=(10,5), dpi=300)

# Plot best curves
for r in best:
    plt.plot(periods, r["disp_curve"], alpha=0.3)
# end for

# Plot the true dispersion curve
plt.plot(periods, x_obs[0].cpu().numpy(), color='red', linewidth=3, label="Observation (x_obs)")

# Average misfit of the 20 best-fitting dispersion curves
mean_rms = np.mean([r["rms"] for r in best])

plt.xlim(p_min, p_max)
plt.xlabel("Period (s)")
plt.ylabel("Phase velocity (km/s)")
plt.title(f"Posterior Predictive Check - 20 best-fitting dispersion curves\nMean misfit (RMS) = {mean_rms:.3f}")
plt.legend()
plt.show()

### Posterior Predictive Distribution

In [None]:
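# Pointwise 5th, 50th and 95th percentiles over the simulated curves
disp_np = torch.as_tensor(disp_curves).detach().cpu().numpy()
q05 = np.percentile(disp_np, 5, axis=0)
q50 = np.percentile(disp_np, 50, axis=0)
q95 = np.percentile(disp_np, 95, axis=0)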

In [None]:
plt.figure(figsize=(10,5), dpi=300)
plt.fill_between(periods, q05, q95, color="lightblue", alpha=0.5, label="90% credible interval")
plt.plot(periods, q50, color="blue", linewidth=2, label="Median prediction")
plt.plot(periods, x_obs[0].cpu().numpy(), color="red", linewidth=2, label="Observation")
plt.xlabel("Period (s)")
plt.ylabel("Phase velocity (km/s)")
plt.title("Posterior Predictive Distribution (PPD) of dispersion curves")
plt.legend()
plt.grid(True)
plt.show()

---

### Expected Coverage Probability and Posterior Calibration

The following cells evaluate the **Expected Coverage Probability (ECP)** of the posterior distribution, a standard diagnostic for **posterior calibration** in simulation-based inference (SBI).


#### Practical insights

Expected coverage provides a **simple and interpretable** way to diagnose issues in the posterior. Compared to other diagnostics such as **L-C2ST**, it requires relatively few additional simulations (~200) and does **not** rely on extra hyperparameters (as **TARP** does) or additional neural network training.

It allows us to evaluate whether the posterior is, *on average across many prior-predictive observations*, **over-confident** or **under-confident**.

![Example expected-coverage plot](images/sbc_rank_plot.png)

The plot can be interpreted as follows:

* The blue line is below the gray region → the posterior is, on average, over-confident.
* The blue line is above the gray region → the posterior is, on average, under-confident.
* The blue line is within the gray region → we cannot reject the null hypothesis that the posterior is well calibrated.

#### Method

For each test observation $x_i$ with its corresponding ground-truth parameter $\theta_i^\ast$,  
we draw multiple posterior samples $\{ \theta_{i,j} \}_{j=1}^{N} \sim p_\phi(\theta|x_i)$.  
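For a given credibility level $\alpha$ (e.g. 0.9), we compute the **central credible interval** that contains a fraction $\alpha$ (i.e. $100\alpha\%$) of the posterior mass, where $q_p$ denotes the empirical $p$-quantile of the posterior samples: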

$$
\mathcal{I}_{\alpha}(x_i) = \Big[ q_{(1-\alpha)/2}, \; q_{1 - (1-\alpha)/2} \Big]
$$

Then we check whether the true parameter lies inside this interval:

$$
\mathbf{1}\big( \theta_i^\ast \in \mathcal{I}_{\alpha}(x_i) \big)
$$

Repeating this for all test observations gives the **empirical coverage**:

$$
\hat{C}(\alpha) = \frac{1}{N_{\text{obs}}} \sum_i \mathbf{1}\big( \theta_i^\ast \in \mathcal{I}_{\alpha}(x_i) \big)
$$
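The procedure above can be sketched end to end on a toy problem where the exact posterior is known. We assume the conjugate Gaussian model $\theta \sim \mathcal{N}(0,1)$, $x = \theta + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0,1)$, whose posterior is $\mathcal{N}(x/2, 1/2)$; this model is chosen purely for illustration.

The sketch proceeds in three steps:

1. Draw samples from the exact posterior.
2. Compute central intervals with `np.quantile`.
3. Average the indicators.

Because the posterior is exact, the result should give $\hat{C}(\alpha) \approx \alpha$:

```python
import numpy as np

rng_toy = np.random.default_rng(0)
n_toy, n_post = 2000, 1000
levels_toy = np.array([0.1, 0.3, 0.5, 0.7, 0.9])

# Ground-truth parameters and observations from the toy model
theta_star = rng_toy.normal(0.0, 1.0, size=n_toy)
x_toy = theta_star + rng_toy.normal(0.0, 1.0, size=n_toy)

# N posterior samples per observation from the exact posterior N(x/2, 1/2)
post_toy = x_toy[:, None] / 2 + np.sqrt(0.5) * rng_toy.standard_normal((n_toy, n_post))

emp_coverage = []
for alpha in levels_toy:
    # Central credible interval [q_{(1-alpha)/2}, q_{1-(1-alpha)/2}] for each observation
    low = np.quantile(post_toy, (1 - alpha) / 2, axis=1)
    high = np.quantile(post_toy, 1 - (1 - alpha) / 2, axis=1)
    # Empirical coverage: fraction of observations whose theta* lies in its interval
    emp_coverage.append(float(np.mean((theta_star >= low) & (theta_star <= high))))

for alpha, c in zip(levels_toy, emp_coverage):
    print(f"alpha = {alpha:.1f}  ->  C_hat = {c:.3f}")
```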

#### Interpretation

- If the posterior is **well calibrated**, the observed coverage matches the expected credibility level:
  $$
  \hat{C}(\alpha) \approx \alpha
  $$
- If $\hat{C}(\alpha) < \alpha$ → **overconfident posterior** (too narrow intervals).  
- If $\hat{C}(\alpha) > \alpha$ → **underconfident posterior** (too wide intervals).
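We can show in closed form why a too-narrow posterior gives $\hat{C}(\alpha) < \alpha$, assuming an illustrative Gaussian case. Suppose the true posterior is $\mathcal{N}(m, s^2)$ and the approximation is $\mathcal{N}(m, (k s)^2)$. The approximate central interval is then $m \pm k s\, z_\alpha$ with $z_\alpha = \Phi^{-1}\big((1+\alpha)/2\big)$, so its coverage is $2\Phi(k z_\alpha) - 1$.

How this coverage compares with $\alpha$ depends on $k$:

* It equals $\alpha$ only for $k = 1$.
* It falls below $\alpha$ for $k < 1$ (overconfident).
* It rises above $\alpha$ for $k > 1$ (underconfident).

The sketch uses Python's standard-library `statistics.NormalDist`:

```python
from statistics import NormalDist

std_normal = NormalDist()  # provides Phi (cdf) and its inverse (inv_cdf)

def central_coverage(alpha, k):
    """Coverage of the central alpha-interval when the posterior width is scaled by k."""
    z_alpha = std_normal.inv_cdf((1 + alpha) / 2)
    return 2 * std_normal.cdf(k * z_alpha) - 1

for k in (0.5, 1.0, 2.0):  # too narrow, exact, too wide
    row = ", ".join(f"{central_coverage(a, k):.3f}" for a in (0.1, 0.5, 0.9))
    print(f"k = {k}: C(0.1), C(0.5), C(0.9) = {row}")
```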

Plotting the **observed coverage** (y-axis) versus the **expected credibility levels** (x-axis) yields the *calibration curve*.  
A perfectly calibrated model lies on the diagonal $y = x$.

#### Notes

- The code computes this for several credibility levels (e.g. `[0.1, 0.3, 0.5, 0.7, 0.9]`).
- For each level, it builds the central credible interval of every parameter separately (using `torch.quantile`) and records the fraction of parameters whose true value falls inside their interval. Averaging these marginal indicators keeps the expected value at $\alpha$ for a calibrated posterior. Requiring *all* parameters to be inside at once would instead give a joint coverage well below $\alpha$ (about $\alpha^D$ for $D$ independent parameters), even for a perfect posterior.

#### Reference

Hermans, J., Delaunoy, A., Rozet, F., Wehenkel, A., & Louppe, G. (2022). [A trust crisis in simulation-based inference? Your posterior approximations can be unfaithful](https://arxiv.org/pdf/2110.06581). Transactions on Machine Learning Research.

In [None]:
levels = [0.1, 0.3, 0.5, 0.7, 0.9]
n_obs = min(100, x_val.shape[0])
observed_counts = torch.zeros(len(levels), device=device)

for i in range(n_obs):
    # Get the observation and the true parameters
    x = x_val[i].unsqueeze(0).to(device)
    theta_true = theta_val[i].to(device)

    # Draw 1000 samples from the posterior
    samples = posterior.sample((1000,), x=x, show_progress_bars=False).to(device)

    # For each credibility level
    for k, alpha in enumerate(levels):
        # Central interval of each parameter (marginal quantiles)
        low = torch.quantile(samples, (1 - alpha) / 2, dim=0)
        high = torch.quantile(samples, 1 - (1 - alpha) / 2, dim=0)
        # Fraction of parameters whose true value lies inside its interval
        in_interval = ((theta_true >= low) & (theta_true <= high)).float().mean()
        observed_counts[k] += in_interval
    # end for
# end for

# Empirical (marginal) coverage for each level
observed = observed_counts / n_obs
expected = torch.tensor(levels, device=device)

In [None]:
# --- PLOT ---
fig = plt.figure(figsize=(6, 6))
plt.plot(expected.cpu(), observed.cpu(), "o-", label="Model", linewidth=2)
plt.plot([0, 1], [0, 1], "k--", label="Perfect calibration")
plt.xlabel("Expected coverage")
plt.ylabel("Observed coverage")
plt.title("Posterior Calibration Curve (Expected vs Observed)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
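Scoring coverage needs care when $\theta$ has many components. Checking that *all* $D$ components fall inside their own marginal $\alpha$-intervals is a joint event. For a calibrated posterior with independent components, its probability is about $\alpha^D$ rather than $\alpha$. Averaging the per-component indicators keeps the calibrated target at $\alpha$.

Here is a minimal Monte Carlo sketch, assuming an idealized calibrated posterior with $D = 5$ independent standard-normal components. No posterior samples are needed, since the exact marginal quantiles are known:

```python
import numpy as np
from statistics import NormalDist

rng_toy = np.random.default_rng(0)
n_toy, dim, level = 20_000, 5, 0.9

# For a calibrated N(0, 1) marginal, the central interval at this level is [-z_alpha, z_alpha]
z_alpha = NormalDist().inv_cdf((1 + level) / 2)

# True parameters drawn from the same distribution as the idealized posterior
theta_star = rng_toy.standard_normal((n_toy, dim))
inside = np.abs(theta_star) <= z_alpha        # per-component indicators

marginal_coverage = inside.mean()             # averaged over components and observations
joint_coverage = inside.all(axis=1).mean()    # all components inside at once

print(f"marginal: {marginal_coverage:.3f} (target {level})")
print(f"joint:    {joint_coverage:.3f} (level**dim = {level ** dim:.3f})")
```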

### Visual summary

## Advanced Discussion — Sensitivity and Generalization

### Effect of the prior

### Sequential methods

### Out-of-distribution and generalization