# Passband Demo

In [None]:
from tdastro.astro_utils.passbands import PassbandGroup

### Set Up PassbandGroup

In [None]:
# Build the passband group from the "LSST" preset and print it for inspection.
passband_group = PassbandGroup(preset="LSST")
print(passband_group)

### Set Up Spline Model

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tdastro.sources.spline_model import SplineModel

# Build (rather than load) a spline model from a small hand-written 5x5 grid
# of flux values over five input times and five input wavelengths.
input_times = np.arange(1001.0, 1006.0, 1.0)
input_wavelengths = np.linspace(4000.0, 8000.0, 5)
input_fluxes = np.vstack(
    (
        [1.0, 5.0, 2.0, 3.0, 1.0],
        [5.0, 10.0, 6.0, 7.0, 5.0],
        [2.0, 6.0, 3.0, 4.0, 2.0],
        [1.0, 5.0, 2.0, 3.0, 1.0],
        [1.0, 5.0, 2.0, 3.0, 1.0],
    )
)

spline_model = SplineModel(
    input_times,
    input_wavelengths,
    input_fluxes,
    time_degree=3,
    wave_degree=3,
)

# Evaluate on a denser time grid (40 samples) and on the passband group's
# wavelength grid. NOTE(review): 1000-1006 extends past the 1001-1005 input
# range; how SplineModel treats those end points is defined by SplineModel.
times = np.linspace(1000.0, 1006.0, 40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

In [None]:
def plot_flux_spectrogram():
    """Plot a spectrogram to visualize the fluxes.

    Reads the notebook-level ``fluxes``, ``times`` and ``wavelengths``
    variables at call time. ``fluxes`` is transposed before display so that
    time runs along the x axis and wavelength along the y axis (assumes
    ``fluxes`` has shape (len(times), len(wavelengths)), which is what the
    tick setup below expects).
    """
    flux_grid = fluxes.T  # rows: wavelength index, columns: time index
    plt.figure(figsize=(12, 5))
    plt.imshow(flux_grid, cmap="plasma", interpolation="nearest", aspect="auto")

    # Add title, axis labels, and ticks. Time ticks keep one decimal place:
    # with the notebook's 40 samples over 6 days, every 4th sample is ~0.6 days
    # apart, so rounding to whole days produced repeated, misleading labels.
    plt.title("Flux Spectrogram")
    plt.xlabel("Time (days)")
    plt.ylabel("Wavelength (Angstrom)")
    plt.xticks(np.arange(len(times))[::4], [f"{time:.1f}" for time in times[::4]])
    plt.yticks(np.arange(len(wavelengths))[::50], [f"{round(wave)}" for wave in wavelengths[::50]])

    # Annotate a sparse subset of cells with their flux values: odd time
    # columns and every 40th wavelength row starting at row 20. Iterating only
    # over those indices avoids a Python-level pass over every grid cell.
    n_rows, n_cols = flux_grid.shape
    for j in range(20, n_rows, 40):
        for i in range(1, n_cols, 2):
            plt.text(i, j, round(flux_grid[j, i], 1), ha="center", va="center", size=8)

    plt.show()


plot_flux_spectrogram()

### Plot Lightcurve

In [None]:
# Convert the evaluated fluxes into per-passband band fluxes; the result is
# consumed by plot_lightcurve, which draws one line per key.
bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)

In [None]:
def plot_lightcurve(bandfluxes):
    """Plot the passband-normalized lightcurve.

    Parameters
    ----------
    bandfluxes : dict
        Mapping from passband label to the band flux at each entry of the
        notebook-level ``times`` array, which is read at call time for the
        x axis. One line is drawn per passband.
    """
    _, ax = plt.subplots()
    ax.set_title("Passband-Normalized Lightcurve")
    ax.set_xlabel("Time (days)")
    ax.set_ylabel("Flux")

    for label, band_flux in bandfluxes.items():
        ax.plot(times, band_flux, marker="o", label=label)

    ax.legend()
    plt.show()


plot_lightcurve(bandfluxes)

### Update the Wavelength Grid

In [None]:
# Re-process the passband transmission tables with a wavelength step of 20.0
# (presumably Angstrom, matching the spectrogram's axis label) and
# trim_percentile=None (presumably no trimming; confirm in PassbandGroup).
# NOTE(review): this presumably rebuilds passband_group.waves, which is why
# the fluxes are re-evaluated on that grid below before converting to band fluxes.
passband_group.process_transmission_tables(delta_wavelength=20.0, trim_percentile=None)

# Same time grid as the earlier cell; reassigned so this cell stands alone.
times = np.linspace(1000.0, 1006.0, 40)
wavelengths = passband_group.waves
fluxes = spline_model.evaluate(times, wavelengths)

bandfluxes = passband_group.fluxes_to_bandfluxes(fluxes)

# plot_lightcurve reads the global `times` for its x axis.
plot_lightcurve(bandfluxes)