# Passband Demo

A `Passband` object stores the information needed to transform observed flux densities sampled over multiple wavelengths into a single band flux for a given filter. A `PassbandGroup` object holds a collection of `Passband` objects and provides convenient helper functions for loading and processing multiple passbands at once.
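To make "a single band flux" concrete, here is a minimal sketch that collapses one spectrum into one number using the transmission-weighted mean $\int f(\lambda) T(\lambda)\, d\lambda \,/\, \int T(\lambda)\, d\lambda$. This is an assumed, common definition, not necessarily tdastro's exact normalization (photon-counting conventions, for example, weight by $\lambda T(\lambda)$). The helper names and the Angstrom units are illustrative assumptions.

```python
import numpy as np


def integrate_trapezoid(y, x):
    """Trapezoidal integral of y(x) along the last axis.

    Written out by hand to avoid differences between NumPy versions.
    """
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)


def band_flux(waves, flux_density, transmission):
    """Collapse one spectrum f(lambda) into a single value for a filter T(lambda).

    Assumes all three arrays share the same wavelength grid and uses the
    transmission-weighted mean: integral(f * T) / integral(T).
    """
    numerator = integrate_trapezoid(flux_density * transmission, waves)
    denominator = integrate_trapezoid(transmission, waves)
    return numerator / denominator


waves = np.linspace(4000.0, 5000.0, 101)  # wavelength grid (assumed Angstroms)
transmission = np.exp(-0.5 * ((waves - 4500.0) / 150.0) ** 2)  # toy Gaussian filter
flat_spectrum = np.full_like(waves, 3.0)

# A flat spectrum should come back (up to rounding) at its own level, 3.0.
print(band_flux(waves, flat_spectrum, transmission))
```

Because the result is a weighted mean, a flat spectrum is returned unchanged, which makes it a handy sanity check for any passband.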

In [None]:
import math
import matplotlib.pyplot as plt
import numpy as np

from tdastro.astro_utils.passbands import Passband, PassbandGroup
from tdastro.utils.plotting import plot_flux_spectrogram, plot_lightcurves

## Set Up PassbandGroup

Both the `Passband` and `PassbandGroup` classes provide multiple mechanisms for loading in the passband information. Users can manually specify the passband values, load from given files, or load from a preset (which will download the files if needed).
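As a rough illustration of the file-based route, the sketch below reads a two-column ASCII table (wavelength, transmission) into the same `(N, 2)` layout that `Passband` accepts through `table_values`. The file contents and format here are hypothetical; the real preset files may use different columns, comments, or units, and this is not a description of how tdastro parses them.

```python
import io

import numpy as np

# Hypothetical two-column file: wavelength, transmission (one sample per line).
table_text = "\n".join(
    [
        "# wavelength transmission",
        "3000.0 0.00",
        "3500.0 0.35",
        "4000.0 0.80",
        "4500.0 0.30",
        "5000.0 0.00",
    ]
)

# np.loadtxt skips lines starting with '#' and splits on whitespace by default.
table = np.loadtxt(io.StringIO(table_text))
print(table.shape)  # (5, 2)
```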

### Loading preset passbands

We start this notebook by loading the default passbands for LSST and printing basic information. 

In [None]:
passband_group = PassbandGroup(preset="LSST")
print(passband_group)

wavelengths = passband_group.waves
min_wave, max_wave = passband_group.wave_bounds()
print(f"Wavelengths range [{min_wave}, {max_wave}]")
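One plausible way to think about the group-level grid is as the sorted union of every passband's wavelengths, with `wave_bounds()` giving its minimum and maximum. The sketch below shows that idea with plain NumPy on two made-up filters; it is an assumption about the relationship, not tdastro's implementation.

```python
import numpy as np

# Toy wavelength grids for two hypothetical filters (Angstroms).
grids = {
    "toy_g": np.arange(4000.0, 5501.0, 500.0),
    "toy_r": np.arange(5000.0, 7001.0, 500.0),
}

# Sorted union of all grids; shared wavelengths (5000, 5500) appear only once.
combined = np.unique(np.concatenate(list(grids.values())))
bounds = (float(combined.min()), float(combined.max()))

print(combined)
print(bounds)  # (4000.0, 7000.0)
```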

We can access an individual `Passband` object by indexing the group with the passband's full name, such as `"LSST_g"`, and plot it using `Passband`'s plot functionality.

In [None]:
passband_group["LSST_g"].plot()

We can plot all of the passbands using `PassbandGroup`'s plot functionality.

In [None]:
passband_group.plot()

### Manually specified passbands

For testing, we might want to specify the passband information manually. We can do this by creating a 2-dimensional NumPy array where the first column holds wavelengths and the second column holds transmission values.

In [None]:
values = np.array(
    [
        [1000, 0.5],
        [1005, 0.6],
        [1010, 0.7],
        [1015, 0.5],
        [1020, 0.7],
        [1025, 0.8],
        [1030, 0.2],
        [1035, 0.2],
    ]
)

toy_passband = Passband(
    "toy_survey",  # Survey name.
    "a",  # Filter name
    table_values=values,  # The matrix of transmission data
)
toy_passband.plot()
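When building tables by hand, it can help to sanity check them first. The hypothetical helper below checks the shape, that wavelengths strictly increase, and that transmissions lie in [0, 1]. These rules are our own assumptions about a sensible table, not the validation `Passband` performs internally.

```python
import numpy as np


def check_transmission_table(table):
    """Sanity checks for an (N, 2) [wavelength, transmission] array (hypothetical helper)."""
    table = np.asarray(table, dtype=float)
    if table.ndim != 2 or table.shape[1] != 2:
        raise ValueError(f"expected an (N, 2) array, got shape {table.shape}")
    if not np.all(np.diff(table[:, 0]) > 0):
        raise ValueError("wavelengths must be strictly increasing")
    if np.any(table[:, 1] < 0) or np.any(table[:, 1] > 1):
        raise ValueError("transmission values should lie in [0, 1]")
    return table


toy_values = np.array([[1000, 0.5], [1005, 0.6], [1010, 0.7], [1015, 0.5]])
check_transmission_table(toy_values)  # passes silently

try:
    check_transmission_table([[1010, 0.5], [1000, 0.6]])
except ValueError as err:
    print(err)  # wavelengths must be strictly increasing
```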

## Applying Passbands

To apply passbands, we first need a 2-dimensional matrix of flux densities, with one row per time and one column per wavelength. We can specify these values manually or generate them with one of the physical models.

In this example, we use a simple `SplineModel`, which computes flux densities from a spline defined over a small, predefined grid of (time, wavelength) flux values.

In [None]:
from tdastro.sources.spline_model import SplineModel

# Define a spline model from a small grid of flux values (rows: times, columns: wavelengths).
input_times = np.array([1001.0, 1002.0, 1003.0, 1004.0, 1005.0, 1006.0])
input_wavelengths = np.linspace(min_wave, max_wave, 5)
input_fluxes = np.array(
    [
        [1.0, 5.0, 2.0, 3.0, 1.0],
        [5.0, 10.0, 6.0, 7.0, 5.0],
        [2.0, 6.0, 3.0, 4.0, 2.0],
        [1.0, 5.0, 2.0, 3.0, 1.0],
        [1.0, 5.0, 2.0, 3.0, 1.0],
        [0.0, 0.0, 0.0, 0.0, 0.0],
    ]
)
spline_model = SplineModel(input_times, input_wavelengths, input_fluxes, time_degree=3, wave_degree=3)

# Query the model at different time steps and all the wavelengths covered
# by the current passband group.
times = np.linspace(1000.0, 1006.0, 40)
fluxes = spline_model.evaluate(times, wavelengths)
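To make the expected layout concrete without any model, the sketch below builds a toy flux-density matrix with NumPy broadcasting: a Gaussian pulse in time multiplied by a linear slope in wavelength. It assumes the same (number of times, number of wavelengths) layout as `input_fluxes` above; the profiles themselves are arbitrary.

```python
import numpy as np

# One row per time, one column per wavelength.
toy_times = np.linspace(1000.0, 1006.0, 7)  # 7 time samples
toy_waves = np.linspace(3000.0, 10000.0, 4)  # 4 wavelength samples

time_profile = np.exp(-0.5 * ((toy_times - 1003.0) / 1.0) ** 2)  # shape (7,)
wave_profile = toy_waves / toy_waves.max()  # shape (4,)

# (7, 1) * (1, 4) broadcasts to a (7, 4) matrix.
toy_fluxes = time_profile[:, np.newaxis] * wave_profile[np.newaxis, :]
print(toy_fluxes.shape)  # (7, 4)
```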

To visualize the flux densities, we plot the flux spectrogram.

In [None]:
plot_flux_spectrogram(fluxes, times, wavelengths, title="Flux Spectrogram")

### Plot Lightcurves

Compute the lightcurves in each band and plot them.

In [None]:
bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
plot_lightcurves(bandfluxes, times, title="Passband-Normalized Lightcurve")
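The plotting cells index `bandfluxes` by band name, so `fluxes_to_bandfluxes` returns a mapping from band name to a light curve. The sketch below reproduces that shape of result with a hypothetical helper: each filter's transmission table is linearly interpolated onto the shared wavelength grid (zero outside its range), then used as weights along the wavelength axis of the `(T, W)` flux matrix. The weighting and interpolation are assumptions, not tdastro's exact algorithm.

```python
import numpy as np


def integrate_trapezoid(y, x):
    """Trapezoidal integral along the last axis."""
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)


def matrix_to_bandfluxes(fluxes, waves, filters):
    """Hypothetical helper: map each filter name to a light curve of length T.

    fluxes:  (T, W) flux densities on the shared grid `waves`.
    filters: dict of name -> (N, 2) [wavelength, transmission] table.
    """
    result = {}
    for name, table in filters.items():
        # Put this filter's transmission on the shared grid; zero outside its range.
        trans = np.interp(waves, table[:, 0], table[:, 1], left=0.0, right=0.0)
        # Transmission-weighted mean along the wavelength axis, one value per time.
        result[name] = integrate_trapezoid(fluxes * trans, waves) / integrate_trapezoid(trans, waves)
    return result


grid = np.linspace(4000.0, 7000.0, 301)
# Three flat spectra with brightness 1.0, 2.0, and 0.5 at successive times.
flat_fluxes = np.outer([1.0, 2.0, 0.5], np.ones_like(grid))
toy_filters = {
    "toy_g": np.array([[4000.0, 0.0], [4500.0, 1.0], [5000.0, 0.0]]),
    "toy_r": np.array([[5500.0, 0.0], [6000.0, 1.0], [6500.0, 0.0]]),
}
curves = matrix_to_bandfluxes(flat_fluxes, grid, toy_filters)
print(curves["toy_g"])  # approximately [1.0, 2.0, 0.5]
```

Since the spectra are flat, every band recovers the brightness at each time, which is an easy way to check the bookkeeping.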

Or we can plot each band's light curve on its own.

In [None]:
num_cols = 3
num_rows = math.ceil(len(bandfluxes) / num_cols)

fig = plt.figure(figsize=(12, 4))
# squeeze=False keeps `axes` 2-dimensional even when there is only one row.
axes = fig.subplots(num_rows, num_cols, sharex=True, sharey=True, squeeze=False)

for idx, band_name in enumerate(bandfluxes.keys()):
    # Fill the grid row by row.
    row, col = divmod(idx, num_cols)
    plot_lightcurves(bandfluxes[band_name], times, ax=axes[row][col], title=band_name)

## Modifying Passbands and PassbandGroups

In some cases, we might want to modify the passband information to fit our use case. In this section, we show several modifications: filtering the passbands used, updating the wavelength grid, and trimming the passbands.

### Filtering Passbands

If our analysis does not use all of the bands defined in a passband group, we can restrict the group to a subset of its passbands using `PassbandGroup`'s subset functionality. Pruning unused passbands avoids unnecessary work in later computations.

For example, we can drop the `LSST_u` and `LSST_z` filters:

In [None]:
bands_to_keep = ["LSST_g", "LSST_r", "LSST_i", "LSST_y"]
passband_group.subset(bands_to_keep)
passband_group.plot()
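Conceptually, subsetting is like filtering a dictionary by key. The sketch below is a plain-dictionary analogy using a hypothetical helper; it does not show whether `PassbandGroup.subset` works in place or how it treats unknown names. This version chooses to fail loudly on a typo rather than silently drop a band.

```python
def subset_bands(bands, names_to_keep):
    """Hypothetical helper: return a new dict containing only the requested bands.

    Raises KeyError for names that are not present instead of ignoring them.
    """
    missing = [name for name in names_to_keep if name not in bands]
    if missing:
        raise KeyError(f"unknown bands: {missing}")
    return {name: bands[name] for name in names_to_keep}


# Placeholder values stand in for the per-band transmission data.
all_bands = {name: f"{name} table" for name in ["LSST_u", "LSST_g", "LSST_r", "LSST_i", "LSST_z", "LSST_y"]}
kept = subset_bands(all_bands, ["LSST_g", "LSST_r", "LSST_i", "LSST_y"])
print(list(kept))  # ['LSST_g', 'LSST_r', 'LSST_i', 'LSST_y']
```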

### Update Wave Grid

Increasing the `delta_wave` parameter increases the wavelength step of the transmission tables. Because the group's wavelength grid (`passband_group.waves`) changes as a result, we re-evaluate the fluxes on the new grid before computing band fluxes.

In [None]:
passband_group.process_transmission_tables(delta_wave=30.0)

times = np.linspace(1000.0, 1006.0, 40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
plot_lightcurves(bandfluxes, times, title="Passband-Normalized Lightcurve")
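A minimal sketch of what a coarser `delta_wave` means for a single table: linearly resample it onto a uniform grid with that step. The helper is hypothetical, and tdastro's actual grid construction (endpoints, interpolation scheme) may differ. The example reuses the toy values from the manual passband above.

```python
import numpy as np


def resample_table(table, delta_wave):
    """Linearly resample an (N, 2) [wavelength, transmission] table onto a uniform grid.

    The new grid starts at the first wavelength, steps by `delta_wave`,
    and stops at or before the last wavelength.
    """
    table = np.asarray(table, dtype=float)
    waves, trans = table[:, 0], table[:, 1]
    n_steps = int(np.floor((waves[-1] - waves[0]) / delta_wave))
    new_waves = waves[0] + delta_wave * np.arange(n_steps + 1)
    new_trans = np.interp(new_waves, waves, trans)
    return np.column_stack([new_waves, new_trans])


toy_table = np.array(
    [[1000, 0.5], [1005, 0.6], [1010, 0.7], [1015, 0.5],
     [1020, 0.7], [1025, 0.8], [1030, 0.2], [1035, 0.2]]
)
coarse = resample_table(toy_table, delta_wave=10.0)
print(coarse)  # 4 rows at 1000, 1010, 1020, 1030
```

The coarser grid needs fewer flux evaluations, but it can miss narrow features: here the 0.8 peak at 1025 disappears.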

### Setting Trim Quantile

Setting the `trim_quantile` parameter to `None` disables the automatic trimming that removes the lower and upper tails of each transmission table.

In [None]:
passband_group.process_transmission_tables(delta_wave=30.0, trim_quantile=None)

times = np.linspace(1000.0, 1006.0, 40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)
plot_lightcurves(bandfluxes, times, title="Passband-Normalized Lightcurve")
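One way to read "trimming the tails by quantile" is: integrate the transmission cumulatively, normalize it to end at 1, and keep only samples whose cumulative fraction lies between `q` and `1 - q`. The hypothetical helper below implements that reading. It is an assumption about the idea, not tdastro's rule, and it is cruder than an interpolated cut because it can clip part of a rising or falling edge.

```python
import numpy as np


def trim_by_quantile(table, trim_quantile):
    """Drop the tails of an (N, 2) [wavelength, transmission] table (hypothetical helper).

    Keeps samples whose normalized cumulative transmission lies in
    [trim_quantile, 1 - trim_quantile]. Passing None returns the table unchanged.
    """
    if trim_quantile is None:
        return table
    waves, trans = table[:, 0], table[:, 1]
    # Cumulative trapezoidal integral of T(lambda), normalized to end at 1.
    steps = 0.5 * (trans[1:] + trans[:-1]) * np.diff(waves)
    cdf = np.concatenate([[0.0], np.cumsum(steps)]) / np.sum(steps)
    keep = (cdf >= trim_quantile) & (cdf <= 1.0 - trim_quantile)
    return table[keep]


tailed = np.array(
    [[1000, 0.0], [1010, 0.0], [1020, 1.0], [1030, 1.0], [1040, 0.0], [1050, 0.0]]
)
print(trim_by_quantile(tailed, 0.01)[:, 0])  # [1020. 1030.]
print(trim_by_quantile(tailed, None).shape)  # (6, 2)
```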