# Passband Demo

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np

from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.utils.plotting import plot_flux_spectrogram, plot_bandflux_lightcurves

### Set Up PassbandGroup

Load the default LSST passbands, then get the group's wavelength grid and its bounds.

In [None]:
# Instantiate the LSST preset and echo a summary of the resulting group.
passband_group = PassbandGroup(preset="LSST")
print(passband_group)

# Keep the group's wavelength grid and report its lower/upper bounds.
wavelengths = passband_group.waves
bounds = passband_group.wave_bounds()
min_wave, max_wave = bounds
print(f"Wavelengths range [{min_wave}, {max_wave}]")

### Set Up Spline Model

Create a simple spline model that computes flux densities by interpolating a predefined grid of (time, wavelength) flux values.

In [None]:
from tdastro.sources.spline_model import SplineModel

# Knot grid for the spline: six evenly spaced times (1001.0 through 1006.0)
# and five wavelengths spanning the passband group's full range.
input_times = np.arange(1001.0, 1007.0, 1.0)
input_wavelengths = np.linspace(min_wave, max_wave, 5)

# Flux value at each knot: one row per entry of input_times (6 rows) and one
# column per entry of input_wavelengths (5 columns). The source peaks at the
# second time step and is zero at the last one.
quiet_row = [1.0, 5.0, 2.0, 3.0, 1.0]
input_fluxes = np.array(
    [
        quiet_row,
        [5.0, 10.0, 6.0, 7.0, 5.0],
        [2.0, 6.0, 3.0, 4.0, 2.0],
        quiet_row,
        quiet_row,
        [0.0] * 5,
    ]
)

# Degree-3 (cubic) spline along both the time and the wavelength axes.
spline_model = SplineModel(input_times, input_wavelengths, input_fluxes, time_degree=3, wave_degree=3)

### Query and Plot the Model

Evaluate the model's flux densities on a grid of times and at the passband group's wavelengths, then plot the resulting flux spectrogram.

In [None]:
# Sample 40 evenly spaced times over [1000.0, 1006.0]. NOTE(review): 1000.0 lies
# before the first spline knot (1001.0), so the earliest samples are presumably
# extrapolated — confirm SplineModel's behavior outside its knot range.
times = np.linspace(start=1000.0, stop=1006.0, num=40)

# Evaluate the model on the passband group's wavelength grid.
fluxes = spline_model.evaluate(times, wavelengths)

plot_flux_spectrogram(fluxes, times, wavelengths, title="Flux Spectrogram")

### Plot Lightcurves

Compute the lightcurves in each band and plot them.

In [None]:
# Convert the flux densities into per-band fluxes and draw all bands on one plot.
bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
lightcurve_title = "Passband-Normalized Lightcurve"
plot_bandflux_lightcurves(bandfluxes, times, title=lightcurve_title)

Alternatively, we can plot each band's light curve in its own panel.

In [None]:
# Lay the bands out on a grid of panels, `num_cols` per row. The panels share
# their axes so the light curves can be compared directly.
band_names = list(bandfluxes.keys())
num_cols = 3
num_rows = math.ceil(len(band_names) / num_cols)

# Scale the figure with the grid size (4 x 2 inches per panel) instead of using a
# fixed height, so the panels do not get squashed when there are more rows.
fig = plt.figure(figsize=(4 * num_cols, 2 * num_rows))

# squeeze=False makes subplots() always return a 2-D array of Axes. Without it, a
# single row of panels comes back as a 1-D array and axes[row][col] would fail.
axes = fig.subplots(num_rows, num_cols, sharex=True, sharey=True, squeeze=False)

for idx, band_name in enumerate(band_names):
    row, col = divmod(idx, num_cols)
    plot_bandflux_lightcurves(bandfluxes[band_name], times, ax=axes[row, col], title=band_name)

# Hide any unused trailing panels when the band count is not a multiple of num_cols.
for unused_ax in axes.ravel()[len(band_names):]:
    unused_ax.set_visible(False)

### Update Wave Grid

Increasing the `delta_wave` parameter increases the grid step of the transmission tables. This changes the wavelength grid `passband_group.waves`, and therefore the fluxes calculated on that grid.

In [None]:
# Re-process the transmission tables with a grid step of 30.0 (units presumably
# match the passband wavelengths — confirm). passband_group.waves is re-read
# below because this step rebuilds the grid.
passband_group.process_transmission_tables(delta_wave=30.0)

# Recompute the sample times, the wavelength grid and the model fluxes.
times = np.linspace(1000.0, 1006.0, num=40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

# Convert to per-band fluxes and plot every band together.
bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
plot_bandflux_lightcurves(
    bandfluxes,
    times,
    title="Passband-Normalized Lightcurve",
)

Setting the `trim_quantile` parameter to `None` disables the automatic trimming that normally removes the upper and lower tails of the transmission tables.

In [None]:
# Re-process the tables with the same 30.0 grid step, but pass
# trim_quantile=None so the transmission tails are left untrimmed.
passband_group.process_transmission_tables(delta_wave=30.0, trim_quantile=None)

# Refresh the sample times, the wavelength grid and the model fluxes for the new tables.
times = np.linspace(1000.0, 1006.0, num=40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

# Convert to per-band fluxes and plot every band together.
bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
plot_bandflux_lightcurves(
    bandfluxes,
    times,
    title="Passband-Normalized Lightcurve",
)