In [2]:
import torch
import matplotlib.pyplot as plt
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Problem setup
# Seed torch's global RNG up front so that everything drawing from it below
# (prior samples, simulator noise, posterior samples) is repeatable after a
# kernel restart.
torch.manual_seed(0)

dim = 10  # dimensionality of theta (x has the same shape, since x = theta + eps)
# Box-uniform prior with lower bound -1 and upper bound +1 in every dimension.
prior = BoxUniform(low=-torch.ones(dim), high=torch.ones(dim))

# Simulator: x = theta + eps
def simulator(theta):
    """Return a noisy copy of ``theta``: ``x = theta + eps``.

    ``eps`` is standard-normal noise scaled by 1e-4, drawn with
    ``torch.randn_like`` so it matches ``theta`` in shape.
    """
    noise = torch.randn_like(theta).mul(1e-4)
    return theta + noise

# Generate training data
# Draw `num_simulations` parameter vectors from the prior, then push the whole
# batch through the simulator in a single call.
num_simulations = 50_000
theta = prior.sample(torch.Size([num_simulations]))
x = simulator(theta)

# --- Option 1: Regular NPE (-log q(theta|x)) ---
# Create the trainer and register the (theta, x) pairs in one chained
# expression; the value returned by append_simulations is what gets trained.
inference = SNPE(prior=prior, density_estimator="maf").append_simulations(theta, x)
# Fit the MAF density estimator, then wrap it into a posterior object.
density_estimator = inference.train()
posterior_npe = inference.build_posterior(density_estimator)

# --- Option 2: SNPE-C (atomic loss) ---
# `SNPE(...)` has no `loss` keyword; passing `loss="atomic"` raises a TypeError.
# sbi chooses the loss from the round index instead: simulations appended with
# a `proposal` that is not the prior object itself are treated as a later
# round, and `train()` then uses the atomic SNPE-C loss.
# NOTE(review): the proposal below has the same bounds as the prior, so the
# simulations are still distributed as in Option 1; only the loss differs.
# Confirm this round-selection behaviour against the installed sbi version.
proposal_atomic = BoxUniform(low=-torch.ones(dim), high=torch.ones(dim))
inference_atomic = SNPE(prior=prior, density_estimator="maf")
inference_atomic = inference_atomic.append_simulations(theta, x, proposal=proposal_atomic)
density_estimator_atomic = inference_atomic.train()
posterior_atomic = inference_atomic.build_posterior(density_estimator_atomic)

# Observation to infer from
# Simulate one observation at the true parameter (the origin): add a leading
# batch axis for the simulator call and strip it again afterwards.
theta_true = torch.zeros(dim)  # true parameter
x_obs = simulator(theta_true[None]).squeeze(0)

# Sample posteriors
# 10,000 draws from each posterior, conditioned on the same observation.
samples_npe = posterior_npe.sample(torch.Size([10_000]), x=x_obs)
samples_atomic = posterior_atomic.sample(torch.Size([10_000]), x=x_obs)

# --- Plot marginals for a few dims ---
# One panel per plotted dimension: overlaid histograms of both posteriors plus
# a dashed vertical line at the true parameter value.
dims_to_plot = [0, 1, 2, 3, 4, 5]
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for ax, d in zip(axes.flat, dims_to_plot):
    ax.hist(samples_npe[:, d].numpy(), bins=50, density=True, alpha=0.6, label="NPE")
    ax.hist(samples_atomic[:, d].numpy(), bins=50, density=True, alpha=0.6, label="SNPE-C")
    # Label the reference line so it appears in the legend too.
    ax.axvline(theta_true[d].item(), color="k", linestyle="--", label="true value")
    ax.set_title(f"θ[{d}] marginal")
    ax.set_xlabel(f"θ[{d}]")
    ax.set_ylabel("density")
# Attach the single shared legend to an explicitly chosen panel instead of
# relying on whatever `ax` the loop variable was last bound to.
axes.flat[0].legend()
plt.tight_layout()
plt.show()

 Training neural network. Epochs trained: 1

KeyboardInterrupt: 

In [None]:
import torch
import matplotlib.pyplot as plt
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Select device
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Seed torch's global RNG up front so that everything drawing from it below
# (prior samples, simulator noise, posterior samples) is repeatable after a
# kernel restart.
torch.manual_seed(0)

# Problem setup
dim = 10  # dimensionality of theta (x has the same shape, since x = theta + eps)
# Box-uniform prior with lower bound -1 and upper bound +1 in every dimension,
# with its bounds allocated on the selected device.
prior = BoxUniform(low=-torch.ones(dim, device=device),
                   high=torch.ones(dim, device=device))

# Simulator: x = theta + eps
def simulator(theta):
    """Return a noisy copy of ``theta``: ``x = theta + eps``.

    ``eps`` is standard-normal noise scaled by 1e-4 with the same shape,
    dtype and device as ``theta``.
    """
    # Let randn_like inherit theta's device rather than forcing the global
    # `device`: the noise then always lives where theta does, so the addition
    # cannot hit a device mismatch and the function no longer depends on a
    # module-level name.
    eps = 1e-4 * torch.randn_like(theta)
    return theta + eps

# Generate training data
# Draw `num_simulations` parameter vectors from the prior, then push the whole
# batch through the simulator in a single call.
num_simulations = 50_000
theta = prior.sample(torch.Size([num_simulations]))
x = simulator(theta)

# --- Option 1: Regular NPE (-log q(theta|x)) ---
# Create the trainer on the selected device and register the (theta, x) pairs
# in one chained expression; the value returned by append_simulations is what
# gets trained.
inference = SNPE(
    prior=prior,
    density_estimator="maf",
    device=device,
).append_simulations(theta, x)
# Fit the MAF density estimator, then wrap it into a posterior object.
density_estimator = inference.train()
posterior_npe = inference.build_posterior(density_estimator)

# --- Option 2: SNPE-C (atomic loss) ---
# `SNPE(...)` has no `loss` keyword; passing `loss="atomic"` raises a TypeError.
# sbi chooses the loss from the round index instead: simulations appended with
# a `proposal` that is not the prior object itself are treated as a later
# round, and `train()` then uses the atomic SNPE-C loss.
# NOTE(review): the proposal below has the same bounds as the prior, so the
# simulations are still distributed as in Option 1; only the loss differs.
# Confirm this round-selection behaviour against the installed sbi version.
proposal_atomic = BoxUniform(low=-torch.ones(dim, device=device),
                             high=torch.ones(dim, device=device))
inference_atomic = SNPE(prior=prior, density_estimator="maf", device=device)
inference_atomic = inference_atomic.append_simulations(theta, x, proposal=proposal_atomic)
density_estimator_atomic = inference_atomic.train()
posterior_atomic = inference_atomic.build_posterior(density_estimator_atomic)

# Observation to infer from
# Simulate one observation at the true parameter (the origin, on the selected
# device): add a leading batch axis for the simulator call and strip it again.
theta_true = torch.zeros(dim, device=device)
x_obs = simulator(theta_true[None]).squeeze(0)

# Sample posteriors
# 10,000 draws from each posterior, conditioned on the same observation and
# moved to the CPU for plotting.
samples_npe = posterior_npe.sample(torch.Size([10_000]), x=x_obs).cpu()
samples_atomic = posterior_atomic.sample(torch.Size([10_000]), x=x_obs).cpu()

# --- Plot marginals for a few dims ---
# One panel per plotted dimension: overlaid histograms of both posteriors plus
# a dashed vertical line at the true parameter value.
dims_to_plot = [0, 1, 2, 3, 4, 5]
fig, axes = plt.subplots(2, 3, figsize=(12, 6))
for ax, d in zip(axes.flat, dims_to_plot):
    ax.hist(samples_npe[:, d].numpy(), bins=50, density=True, alpha=0.6, label="NPE")
    ax.hist(samples_atomic[:, d].numpy(), bins=50, density=True, alpha=0.6, label="SNPE-C")
    # Label the reference line so it appears in the legend too.
    ax.axvline(theta_true[d].item(), color="k", linestyle="--", label="true value")
    ax.set_title(f"θ[{d}] marginal")
    ax.set_xlabel(f"θ[{d}]")
    ax.set_ylabel("density")
# Attach the single shared legend to an explicitly chosen panel instead of
# relying on whatever `ax` the loop variable was last bound to.
axes.flat[0].legend()
plt.tight_layout()
plt.show()

Using device: cpu


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x14d7855acfa0>>
Traceback (most recent call last):
  File "/home/weniger/.conda/envs/emri_few/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
