# AstroPrism Tutorial: Noise Model

The noise model describes the uncertainty in the observed data. AstroPrism uses a heteroscedastic noise model, meaning the noise variance differs between channels and pixels, built from two components:

$$\sigma_i^2 = b_i^2 + g_i \cdot |\mu_i|$$

Where:
- $\sigma_i^2$ = noise variance for channel $i$
- $b_i$ = background noise standard deviation (readout, sky background)
- $g_i$ = Poisson-like gain (signal-dependent noise scaling)
- $\mu_i$ = noiseless model prediction (from instrument response)
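The formula can be evaluated directly on an array of model predictions. The sketch below uses NumPy and assumes the prediction is stored with shape `(n_channels, H, W)` and that there is one $b_i$ and one $g_i$ per channel. The function name `noise_std` is hypothetical and not part of the AstroPrism API.

```python
import numpy as np

def noise_std(mu, background_std, poisson_scale):
    """Per-pixel noise standard deviation sigma = sqrt(b^2 + g * |mu|).

    mu:             noiseless model prediction, shape (n_channels, H, W)
    background_std: b_i, one value per channel, shape (n_channels,)
    poisson_scale:  g_i, one value per channel, shape (n_channels,)
    """
    mu = np.asarray(mu, dtype=float)
    # Reshape per-channel parameters to (n_channels, 1, 1) so they broadcast over pixels
    b = np.asarray(background_std, dtype=float)[:, None, None]
    g = np.asarray(poisson_scale, dtype=float)[:, None, None]
    variance = b**2 + g * np.abs(mu)
    return np.sqrt(variance)

# Two channels with 1 x 2 pixels each: a faint pixel and a bright pixel
mu = np.array([[[0.0, 4.0]],
               [[0.0, 100.0]]])
sigma = noise_std(mu, background_std=[3.0, 1.0], poisson_scale=[4.0, 0.0])
print(sigma)
```

In channel 0 ($b = 3$, $g = 4$) the noise grows from 3 at the faint pixel to $\sqrt{9 + 16} = 5$ at the bright one. Channel 1 has $g = 0$, so its noise stays at the background level of 1 regardless of flux. This is the heteroscedastic behaviour described above.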

## Imports

```python
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import warnings
from astropy.utils.exceptions import AstropyWarning
import jax

# Silence Astropy warnings emitted while reading the data
warnings.filterwarnings('ignore', category=AstropyWarning)
```

## Demonstration

```python
# Load the JWST MIRI tutorial cutout
from astroprism.io import load_dataset, SingleInstrumentDataset

path = "../data/tutorial/jwst_miri_cutout/"
dataset = load_dataset(path=path, instrument="JWST_MIRI", extension="fits")
print(dataset.summary())
```

```
SingleInstrumentDataset Summary:
--------------------------------
Number of channels: 4
Channel keys: ['F1000W_full', 'F1130W_full', 'F2100W_full', 'F770W_full']
Channel shapes: [(180, 180), (180, 180), (180, 180), (180, 180)]
Pixel scales: [(3600.0, 3600.0), (3600.0, 3600.0), (3600.0, 3600.0), (3600.0, 3600.0)]
```

### NoiseModel

The `NoiseModel` learns two parameters per channel:
- `background_std` ($b_i$): flux-independent noise (LogNormal prior)
- `poisson_scale` ($g_i$): flux-dependent noise scaling (LogNormal prior)

Given the instrument response output $\mu$, it computes the noise standard deviation $\sigma_i = \sqrt{b_i^2 + g_i \cdot |\mu_i|}$ for every pixel of each channel $i$.
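Both $b_i$ and $g_i$ must be positive, which a LogNormal prior guarantees. A common way to implement such a prior is to draw a standard-normal excitation $\xi$ and map it through $\exp(\mu_{\log} + \sigma_{\log}\,\xi)$. The sketch below assumes the prior is specified by the mean and standard deviation of the LogNormal variable itself. Whether `NoiseModel` uses this parameterization, and which hyperparameters it uses, is not stated in this tutorial. The function names and prior values are therefore hypothetical.

```python
import numpy as np

def lognormal_params(mean, std):
    """Convert the mean and std of a LogNormal variable to its log-space (mu, sigma)."""
    mean = np.asarray(mean, dtype=float)
    std = np.asarray(std, dtype=float)
    sigma_log = np.sqrt(np.log1p((std / mean) ** 2))
    mu_log = np.log(mean) - 0.5 * sigma_log**2
    return mu_log, sigma_log

def lognormal_from_standard(xi, mean, std):
    """Map standard-normal excitations xi to LogNormal samples (always positive)."""
    mu_log, sigma_log = lognormal_params(mean, std)
    return np.exp(mu_log + sigma_log * np.asarray(xi, dtype=float))

n_channels = 4
# Hypothetical prior settings, not AstroPrism's actual defaults
rng = np.random.default_rng(0)
background_std = lognormal_from_standard(rng.standard_normal(n_channels), mean=1.0, std=0.5)
poisson_scale = lognormal_from_standard(rng.standard_normal(n_channels), mean=1.0, std=0.5)
print(background_std, poisson_scale)
```

Sampling in the unconstrained $\xi$ space and exponentiating keeps every drawn $b_i$ and $g_i$ strictly positive. An optimizer or sampler can work with $\xi$ without needing explicit bounds.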

```python
from astroprism.models.noise import NoiseModel

# One (background_std, poisson_scale) pair is learned per channel
noise = NoiseModel(n_channels=len(dataset))

# The domain lists the model's learnable parameters
print("NoiseModel domain keys:", list(noise.domain.keys()))
```

```
NoiseModel domain keys: ['background_std', 'poisson_scale']
```
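This tutorial does not show how $\sigma$ enters the fit. A common choice for a heteroscedastic noise model is an independent Gaussian likelihood per pixel, $\log p(d \mid \mu) = -\tfrac{1}{2} \sum_{\text{pixels}} \left[ (d - \mu)^2 / \sigma^2 + \log(2\pi\sigma^2) \right]$. The sketch below assumes that likelihood. It also assumes a parameter dict keyed by the two domain names printed above. `gaussian_log_likelihood` is a hypothetical helper, not AstroPrism's API.

```python
import numpy as np

def gaussian_log_likelihood(data, mu, params):
    """Heteroscedastic Gaussian log-likelihood of `data` given prediction `mu`.

    data, mu: arrays of shape (n_channels, H, W)
    params:   dict keyed like the NoiseModel domain, each entry of shape (n_channels,)
    """
    data = np.asarray(data, dtype=float)
    mu = np.asarray(mu, dtype=float)
    b = np.asarray(params["background_std"], dtype=float)[:, None, None]
    g = np.asarray(params["poisson_scale"], dtype=float)[:, None, None]
    sigma = np.sqrt(b**2 + g * np.abs(mu))
    residual = (data - mu) / sigma
    # log(2*pi*sigma^2) = 2*log(sigma) + log(2*pi)
    return -0.5 * np.sum(residual**2 + 2.0 * np.log(sigma) + np.log(2.0 * np.pi))

# Toy example: 4 channels of 180 x 180 pixels, matching the dataset shapes above
rng = np.random.default_rng(1)
mu = np.abs(rng.normal(10.0, 2.0, size=(4, 180, 180)))
params = {"background_std": np.full(4, 0.5), "poisson_scale": np.full(4, 0.1)}
true_sigma = np.sqrt(params["background_std"][:, None, None] ** 2
                     + params["poisson_scale"][:, None, None] * mu)
data = mu + rng.normal(size=mu.shape) * true_sigma
print(gaussian_log_likelihood(data, mu, params))
```

Because $\sigma$ depends on the learned parameters, the $\log \sigma$ normalization term cannot be dropped. Without it, inflating $b_i$ or $g_i$ would always shrink the residual term and improve the fit.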
