# AstroPrism Tutorial: Likelihood & Variational Inference

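The **likelihood** connects the forward model to the observed data. AstroPrism uses a variable-covariance Gaussian, in which both the mean and the noise level of each pixel depend on the model parameters, and pixels are treated as independent: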

$$\mathcal{L}(d | \theta) = \prod_i \mathcal{N}(d_i \,|\, \mu_i(\theta), \sigma_i^2(\theta))$$

where:
- $d_i$ = observed value of pixel $i$
- $\mu_i(\theta)$ = model prediction (from instrument response)
- $\sigma_i(\theta)$ = noise standard deviation (from noise model)
- $\theta$ = all model parameters
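
The formula above can be sketched numerically. The following is a minimal NumPy illustration, not AstroPrism's implementation: the function name `gaussian_log_likelihood` and the toy arrays are assumptions made for this example. It works with the log of the product, which turns into a sum over pixels.

```python
import numpy as np


def gaussian_log_likelihood(d, mu, sigma):
    """Log of prod_i N(d_i | mu_i, sigma_i^2) for independent pixels.

    Hypothetical helper for illustration; not part of AstroPrism.
    """
    d, mu, sigma = (np.asarray(a, dtype=float) for a in (d, mu, sigma))
    residual = (d - mu) / sigma
    # The -log(sigma) term must be kept: sigma depends on theta, so it is
    # not a constant that can be dropped during optimization.
    per_pixel = -0.5 * residual**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
    return float(np.sum(per_pixel))


# Toy data: three pixels with different noise levels (made-up numbers).
d = np.array([1.0, 2.0, 3.0])
mu = np.array([1.1, 1.9, 3.2])
sigma = np.array([0.1, 0.2, 0.4])

log_like = gaussian_log_likelihood(d, mu, sigma)
print(f"log-likelihood: {log_like:.4f}")
```

Because $\sigma_i$ is itself a function of $\theta$, inflating the noise lowers the penalty on residuals but is paid for through the $-\log \sigma_i$ term, which is what keeps the noise parameters identifiable.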

Given the likelihood and prior $p(\theta)$, we want the posterior:

$$p(\theta | d) \propto \mathcal{L}(d | \theta) \, p(\theta)$$

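The posterior is intractable because its normalizing constant, the evidence $p(d)$, cannot be computed in practice. We therefore use **Variational Inference (VI)**: we approximate the posterior with a simpler distribution $q(\theta)$, chosen to minimize the KL divergence: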

$$\text{KL}(q \| p) = \mathbb{E}_q \left[ \log q(\theta) - \log p(\theta | d) \right]$$
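
To make the expectation in this definition concrete, here is a small NumPy sketch that estimates $\text{KL}(q \| p)$ by averaging $\log q - \log p$ over samples drawn from $q$, and compares the result with the closed form for two 1D Gaussians. The function names and the toy distributions are assumptions for illustration. In a real VI problem the target $p(\theta | d)$ is only known up to its normalization, so the KL is minimized up to an additive constant; this toy uses a fully known target so the estimate can be checked.

```python
import numpy as np


def log_normal_pdf(x, mean, std):
    """Log-density of a 1D Gaussian."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)


def kl_monte_carlo(q_mean, q_std, p_mean, p_std, n_samples=200_000, seed=0):
    """Estimate KL(q || p) = E_q[log q - log p] using samples from q."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(q_mean, q_std, size=n_samples)
    log_ratio = log_normal_pdf(theta, q_mean, q_std) - log_normal_pdf(theta, p_mean, p_std)
    return float(np.mean(log_ratio))


def kl_closed_form(q_mean, q_std, p_mean, p_std):
    """Analytic KL divergence between two 1D Gaussians."""
    return float(
        np.log(p_std / q_std)
        + (q_std**2 + (q_mean - p_mean) ** 2) / (2.0 * p_std**2)
        - 0.5
    )


# Toy choice: q = N(0, 1), p = N(0.5, 1.5^2).
kl_mc = kl_monte_carlo(0.0, 1.0, 0.5, 1.5)
kl_exact = kl_closed_form(0.0, 1.0, 0.5, 1.5)
print(f"Monte Carlo: {kl_mc:.4f}, closed form: {kl_exact:.4f}")
```

The sample-based estimate is the part that carries over to realistic models, where no closed form exists and the expectation over $q$ has to be approximated with a finite number of samples.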

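AstroPrism uses NIFTy8's **Geometric VI (GeoVI)**, which uses the Fisher information metric to build a local coordinate system in which the posterior is approximately Gaussian. Samples drawn in these coordinates can capture non-Gaussian structure in the original parameters.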

## Imports

In [1]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import warnings
from astropy.utils.exceptions import AstropyWarning
import jax

warnings.filterwarnings('ignore', category=AstropyWarning)

## Demonstration

In [2]:
# Load Dataset
from astroprism.io import load_dataset, SingleInstrumentDataset
path = "../data/tutorial/jwst_miri_cutout/"
dataset = load_dataset(path=path, instrument="JWST_MIRI", extension="fits")
print(dataset.summary())

# GP Sky Model
from astroprism.models.gp import SpatialGP, MixtureGP
spatial_gp = SpatialGP(
    n_channels=len(dataset),
    shape=dataset.shapes[0],
    distances=dataset.pixel_scales[0],
)
mixture = MixtureGP(spatial_gps=spatial_gp)

# Instrument Response Model
from astroprism.models.response import InstrumentResponse
response = InstrumentResponse(
    dataset=dataset,
    signal_wcs=dataset.wcs[0],
    signal_shape=dataset.shapes[0],
)

# Noise Model
from astroprism.models.noise import NoiseModel
noise = NoiseModel(n_channels=len(dataset))

# Forward Model
from astroprism.models.forward import ForwardModel
model = ForwardModel(mixture, response, noise)
print("ForwardModel domain keys:", list(model.domain.keys()))


SingleInstrumentDataset Summary:
--------------------------------
Number of channels: 4
Channel keys: ['F1000W_full', 'F1130W_full', 'F2100W_full', 'F770W_full']
Channel shapes: [(600, 600), (600, 600), (600, 600), (600, 600)]
Pixel scales: [(0.11091449975820492, 0.1109144997630491), (0.11091449949791077, 0.11091449946799463), (0.11091449907273217, 0.11091449905375952), (0.11091450012231652, 0.11091450010349585)]

ForwardModel domain keys: ['zeromode', 'fluctuations', 'loglogavgslope', 'spectrum', 'flexibility', 'xi', 'mixture_matrix', 'mixing_offset', 'psf_sigma', 'psf_rotation', 'background_std', 'poisson_scale']


### Build Likelihood

The `build_likelihood` function constructs the likelihood from the dataset, forward model, and an optional mask to exclude bad pixels (e.g., detector edges, cosmic rays).

In [3]:
# Imports
from astroprism.models.likelihood import build_likelihood

# Build likelihood: Combines data, model, and mask (readout)
likelihood = build_likelihood(dataset, model, mask=dataset.readout)
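
As a rough picture of what masking does to the likelihood, the sketch below drops flagged pixels from the Gaussian log-likelihood sum. This is not how `build_likelihood` is implemented, which this tutorial does not show; the helper name, the toy arrays, and the convention that `True` means "keep this pixel" are all assumptions, and the actual convention of `dataset.readout` may differ.

```python
import numpy as np


def masked_gaussian_log_likelihood(d, mu, sigma, keep):
    """Gaussian log-likelihood summed only over pixels where keep is True.

    Hypothetical helper for illustration; not part of AstroPrism.
    """
    keep = np.asarray(keep, dtype=bool)
    residual = (d[keep] - mu[keep]) / sigma[keep]
    per_pixel = -0.5 * residual**2 - np.log(sigma[keep]) - 0.5 * np.log(2.0 * np.pi)
    return float(np.sum(per_pixel))


# Toy 2x2 image: the bottom-right pixel is corrupted (e.g. a cosmic-ray hit).
d = np.array([[1.0, 2.0], [3.0, 1.0e6]])
mu = np.array([[1.0, 2.0], [3.0, 4.0]])
sigma = np.full((2, 2), 0.5)
keep = np.array([[True, True], [True, False]])  # assumed convention: True = use pixel

masked = masked_gaussian_log_likelihood(d, mu, sigma, keep)
unmasked = masked_gaussian_log_likelihood(d, mu, sigma, np.ones_like(keep))
print(f"masked: {masked:.4f}, unmasked: {unmasked:.4e}")
```

Without the mask, the single corrupted pixel dominates the sum and would pull the fit toward explaining an artifact rather than the sky.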

### Run Variational Inference

The `run_inference` function runs GeoVI optimization, returning posterior samples and the optimizer state.

In [None]:
# Imports
from astroprism.inference.vi import run_inference

# Run inference: fit the GeoVI approximation to the posterior
samples, state = run_inference(
    likelihood,
    n_iterations=10,  # number of VI iterations
    n_samples=2,      # number of posterior samples
    seed=42,
    output_directory="tutorial_results"
)

OPTIMIZE_KL: Starting 0001


Starting VI optimization: 2 iterations, 2 samples.


linear_solver: Iteration 0 ⛰:+1.6211e+10 Δ⛰:inf ➽:1.0000e-04
linear_solver: Iteration 1 ⛰:+9.4390e+08 Δ⛰:1.5267e+10 ➽:1.0000e-04
linear_solver: Iteration 2 ⛰:+4.6147e+08 Δ⛰:4.8243e+08 ➽:1.0000e-04
linear_solver: Iteration 3 ⛰:+1.4064e+08 Δ⛰:3.2083e+08 ➽:1.0000e-04
linear_solver: Iteration 4 ⛰:+5.6275e+07 Δ⛰:8.4363e+07 ➽:1.0000e-04
linear_solver: Iteration 5 ⛰:+5.1143e+07 Δ⛰:5.1326e+06 ➽:1.0000e-04
linear_solver: Iteration 6 ⛰:+3.6643e+07 Δ⛰:1.4499e+07 ➽:1.0000e-04
linear_solver: Iteration 7 ⛰:+3.4348e+07 Δ⛰:2.2959e+06 ➽:1.0000e-04
linear_solver: Iteration 8 ⛰:+2.4230e+07 Δ⛰:1.0117e+07 ➽:1.0000e-04
linear_solver: Iteration 9 ⛰:+1.1515e+07 Δ⛰:1.2715e+07 ➽:1.0000e-04
linear_solver: Iteration 10 ⛰:+1.0501e+07 Δ⛰:1.0141e+06 ➽:1.0000e-04
linear_solver: Iteration 11 ⛰:+1.0494e+07 Δ⛰:7.0885e+03 ➽:1.0000e-04
linear_solver: Iteration 12 ⛰:+7.3519e+06 Δ⛰:3.1421e+06 ➽:1.0000e-04
linear_solver: Iteration 13 â