# AstroPrism Tutorial: Forward Model

The forward model chains three components into a single differentiable model:

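1. **GP Model** (MixtureGP): $s = \text{activation}(A \cdot u + c)$, the sky signal obtained by mixing the latent GPs $u$ with matrix $A$ and offset $c$
2. **Response Model** (InstrumentResponse): $\mu_i = P_i \ast R_i(s)$, the noiseless prediction for channel $i$, where $R_i$ applies the channel's response to $s$ and $P_i$ is its PSF
3. **Noise Model** (NoiseModel): $\sigma_i^2 = b_i^2 + g_i \cdot |\mu_i|$, heteroscedastic noise with a constant background term $b_i^2$ and a signal-dependent term scaled by $g_i$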
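The noise model has two regimes. Where $b_i^2 \gg g_i|\mu_i|$ the background term dominates and the noise is constant; where $g_i|\mu_i| \gg b_i^2$ the signal-dependent term dominates and $\sigma_i \propto \sqrt{|\mu_i|}$, as for Poisson-like noise. The two terms are equal at $|\mu_i| = b_i^2 / g_i$. The NumPy sketch below uses made-up values for $b$ and $g$ (not fitted to the tutorial data) to evaluate the formula on both sides of that crossover.

```python
import numpy as np

def noise_std(mu, b, g):
    """sigma = sqrt(b^2 + g * |mu|), evaluated elementwise."""
    return np.sqrt(b**2 + g * np.abs(mu))

# Illustrative values only (not fitted to any data)
b, g = 0.5, 0.1
mu_cross = b**2 / g  # background and signal-dependent variances are equal here

mu = np.array([0.0, mu_cross, 100 * mu_cross])
sigma = noise_std(mu, b, g)
# sigma[0] == b (background only); at the crossover the variance is 2 * b^2,
# so sigma[1] == sqrt(2) * b; far above it sigma[2] is close to sqrt(g * mu)
print(sigma)
```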

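The full generative model for the observed data $d_i$ of channel $i$ is therefore $d_i \sim \mathcal{N}(\mu_i, \sigma_i^2)$, with $\mu_i$ and $\sigma_i^2$ evaluated pixel by pixel.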
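To make the three steps concrete, here is a minimal NumPy sketch of drawing data from this generative model. It is not AstroPrism's implementation: it assumes $\exp$ as the activation, takes $R_i$ to be the identity (sky and data on the same grid), models each $P_i$ as an isotropic Gaussian PSF applied by circular FFT convolution, and uses random stand-ins for the latent GP fields $u$. All names and values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_latent, n_channels, H, W = 2, 3, 64, 64

# Latent fields u (random stand-ins for GP samples), mixing matrix A, offset c
u = rng.normal(size=(n_latent, H, W))
A = rng.normal(scale=0.3, size=(n_channels, n_latent))
c = np.full(n_channels, -1.0)

# GP model: s = activation(A . u + c), with exp assumed as the activation
s = np.exp(np.einsum("kl,lhw->khw", A, u) + c[:, None, None])

def gaussian_psf(shape, sigma_pix):
    """Isotropic Gaussian kernel centred on pixel (0, 0) with periodic wrap; sums to 1."""
    y = np.fft.fftfreq(shape[0]) * shape[0]  # signed pixel offsets 0, 1, ..., -1
    x = np.fft.fftfreq(shape[1]) * shape[1]
    yy, xx = np.meshgrid(y, x, indexing="ij")
    k = np.exp(-0.5 * (yy**2 + xx**2) / sigma_pix**2)
    return k / k.sum()

# Response model: mu_i = P_i * R_i(s), with R_i = identity (assumption)
psf_sigmas = [1.0, 1.5, 2.0]  # PSF widths in pixels, illustrative
mu = np.stack([
    np.real(np.fft.ifft2(np.fft.fft2(s[i]) * np.fft.fft2(gaussian_psf((H, W), psf_sigmas[i]))))
    for i in range(n_channels)
])

# Noise model: sigma_i^2 = b_i^2 + g_i |mu_i|, then d_i ~ N(mu_i, sigma_i^2) per pixel
b = np.array([0.05, 0.05, 0.1])
g = np.array([0.01, 0.02, 0.01])
sigma = np.sqrt(b[:, None, None] ** 2 + g[:, None, None] * np.abs(mu))
d = mu + sigma * rng.normal(size=mu.shape)
```

Because each kernel sums to one, the convolution conserves the total flux of every channel, which is a quick sanity check on the PSF step.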

The forward model is then used to construct the likelihood for Bayesian inference (see the next tutorial).
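As a preview, under the Gaussian model above the log-likelihood is a sum over channels and pixels. The sketch below writes it in terms of the prediction $\mu$ and the inverse standard deviation $\sigma^{-1}$; the function name and signature are hypothetical and not part of AstroPrism's API.

```python
import numpy as np

def gaussian_log_likelihood(d, mu, sigma_inv):
    """Sum over all pixels of log N(d | mu, sigma^2), given sigma_inv = 1 / sigma."""
    r = (d - mu) * sigma_inv  # whitened residuals
    return np.sum(-0.5 * r**2 + np.log(sigma_inv) - 0.5 * np.log(2 * np.pi))

# Single pixel: d = 1, mu = 0, sigma = 2 (so sigma_inv = 0.5)
ll = gaussian_log_likelihood(np.array([1.0]), np.array([0.0]), np.array([0.5]))
```

Working with $\sigma^{-1}$ avoids a division in the residual, and the normalisation term $-\log\sigma$ appears directly as $+\log\sigma^{-1}$.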


## Imports

```python
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import warnings
from astropy.utils.exceptions import AstropyWarning
import jax

# Silence astropy warnings (e.g. from FITS/WCS header parsing) when loading the data
warnings.filterwarnings('ignore', category=AstropyWarning)
```

## Demonstration

```python
# Load Dataset (a SingleInstrumentDataset for JWST/MIRI)
from astroprism.io import load_dataset, SingleInstrumentDataset
path = "../data/tutorial/jwst_miri_cutout/"
dataset = load_dataset(path=path, instrument="JWST_MIRI", extension="fits")
print(dataset.summary())

# GP Sky Model: latent spatial GPs on the first channel's pixel grid and pixel scale
from astroprism.models.gp import SpatialGP, MixtureGP
spatial_gp = SpatialGP(
    n_channels=len(dataset),
    shape=dataset.shapes[0],
    distances=dataset.pixel_scales[0],
)
# Mix the latent GPs into the sky signal s = activation(A u + c)
mixture = MixtureGP(spatial_gps=spatial_gp)

# Instrument Response Model: maps the sky signal (defined on the first
# channel's WCS and shape) to the noiseless prediction of every channel
from astroprism.models.response import InstrumentResponse
response = InstrumentResponse(
    dataset=dataset,
    signal_wcs=dataset.wcs[0],
    signal_shape=dataset.shapes[0],
)

# Noise Model: noise parameters for each channel
from astroprism.models.noise import NoiseModel
noise = NoiseModel(n_channels=len(dataset))
```


```text
SingleInstrumentDataset Summary:
--------------------------------
Number of channels: 4
Channel keys: ['jwst_miri_ngc1566_f1000w', 'jwst_miri_ngc1566_f1130w', 'jwst_miri_ngc1566_f2100w', 'jwst_miri_ngc1566_f770w']
Channel shapes: [(2125, 1814), (2125, 1814), (2125, 1814), (2125, 1814)]
Pixel scales: [(0.11091449975820492, 0.1109144997630491), (0.11091449949791077, 0.11091449946799463), (0.11091449907273217, 0.11091449905375952), (0.11091450012231652, 0.11091450010349585)]
```
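The summary numbers translate directly into the size of the cutout on the sky. Assuming the pixel scales are in arcseconds per pixel (consistent with the roughly 0.11″ pixels of the MIRI imager) and that the first shape entry and the first pixel-scale entry refer to the same axis, the field of view of the first channel works out as follows.

```python
# Values copied from the dataset summary above (first channel)
shape = (2125, 1814)
pixel_scale = (0.11091449975820492, 0.1109144997630491)  # assumed arcsec / pixel

fov_arcsec = tuple(n * s for n, s in zip(shape, pixel_scale))
fov_arcmin = tuple(f / 60 for f in fov_arcsec)
print(f"FOV: {fov_arcsec[0]:.1f} x {fov_arcsec[1]:.1f} arcsec "
      f"({fov_arcmin[0]:.2f} x {fov_arcmin[1]:.2f} arcmin)")
# prints: FOV: 235.7 x 201.2 arcsec (3.93 x 3.35 arcmin)
```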



### Assemble Forward Model

The `ForwardModel` wraps all three components and combines their parameter domains into a single $\theta$:

$$\theta = \{\theta_{\text{GP}}, \theta_{\text{response}}, \theta_{\text{noise}}\}$$

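When called with a parameter set $\theta$, it runs the full pipeline and returns the pair $(\mu(\theta), \sigma^{-1}(\theta))$: the noiseless prediction and the inverse noise standard deviation, which together define the Gaussian likelihood.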
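Merging the three parameter domains into one $\theta$ can be pictured with plain dictionaries. This is an illustration only, not how `ForwardModel` is implemented: `merge_domains` and `split_params` are hypothetical helpers, and the grouping of parameter names by component and their shapes are assumptions.

```python
import numpy as np

def merge_domains(*domains):
    """Union of several {name: shape} dicts into one flat parameter domain."""
    merged = {}
    for dom in domains:
        clash = merged.keys() & dom.keys()
        if clash:
            raise ValueError(f"duplicate parameter names: {sorted(clash)}")
        merged.update(dom)
    return merged

def split_params(theta, *domains):
    """Route a flat parameter dict back to each component by key."""
    return tuple({k: theta[k] for k in dom} for dom in domains)

# Hypothetical per-component domains (name -> shape); grouping and shapes are assumptions
gp_domain = {"xi": (2, 64, 64), "mixture_matrix": (4, 2), "mixing_offset": (4,)}
response_domain = {"psf_sigma": (4,), "psf_rotation": (4,)}
noise_domain = {"background_std": (4,), "poisson_scale": (4,)}

theta_domain = merge_domains(gp_domain, response_domain, noise_domain)
theta = {k: np.zeros(shape) for k, shape in theta_domain.items()}
gp_params, response_params, noise_params = split_params(
    theta, gp_domain, response_domain, noise_domain
)
print(list(theta_domain))
```

Refusing duplicate names keeps the flat $\theta$ unambiguous, so each component always receives exactly its own parameters.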

```python
# Imports
from astroprism.models.forward import ForwardModel

# ForwardModel: Combines all models
model = ForwardModel(mixture, response, noise)

# Print the domain of the ForwardModel (parameters)
print("ForwardModel domain keys:", list(model.domain.keys()))
```

```text
ForwardModel domain keys: ['zeromode', 'fluctuations', 'loglogavgslope', 'spectrum', 'flexibility', 'xi', 'mixture_matrix', 'mixing_offset', 'psf_sigma', 'psf_rotation', 'background_std', 'poisson_scale']
```
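Several of these names (`zeromode`, `fluctuations`, `loglogavgslope`, `flexibility`, `spectrum`, `xi`) appear to match the parameters of the correlated-field model in NIFTy, which also takes grid `distances` like `SpatialGP` does. Assuming `SpatialGP` follows the same idea, a latent field is produced by colouring white excitations `xi` in Fourier space with a power-law amplitude spectrum. The NumPy sketch below is a heavily simplified stand-in: `power_law_field` is a hypothetical helper that keeps only a fixed log-log slope and an overall fluctuation amplitude, ignores `flexibility` and `spectrum` (deviations from a pure power law), and treats the zero mode as a plain additive offset rather than NIFTy's exact parameterization.

```python
import numpy as np

def power_law_field(xi, distances, loglogavgslope, fluctuations, offset):
    """Colour white noise xi so that its power spectrum scales as |k|**loglogavgslope."""
    ny, nx = xi.shape
    ky = np.fft.fftfreq(ny, d=distances[0])
    kx = np.fft.fftfreq(nx, d=distances[1])
    k = np.hypot(*np.meshgrid(ky, kx, indexing="ij"))
    amp = np.zeros_like(k)
    # Amplitude ~ k^(slope/2); k = 0 (the mean) is left to the separate offset
    amp[k > 0] = k[k > 0] ** (loglogavgslope / 2.0)
    # For unit-variance xi the expected pixel variance equals mean(amp**2),
    # so this rescaling gives an expected pixel std of `fluctuations`
    amp *= fluctuations / np.sqrt(np.mean(amp**2))
    field = np.real(np.fft.ifft2(amp * np.fft.fft2(xi)))
    return offset + field

rng = np.random.default_rng(1)
xi = rng.normal(size=(128, 128))  # white excitations ('xi')
pixel_scale = (0.1109, 0.1109)    # arcsec, rounded from the summary above
u = power_law_field(xi, pixel_scale, loglogavgslope=-4.0, fluctuations=1.0, offset=0.0)
```

More negative slopes put more power on large scales and give smoother fields; a slope of zero returns (rescaled) white noise.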
