# Passband Interface Refactor

In [None]:
import logging
import os

import numpy as np
import scipy.integrate
from tdastro.sources.spline_model import SplineModel

# Configure the root logger for this session: emit INFO-and-above records,
# each prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


class PassbandGroup:
    """A group of passbands, keyed by label.

    Attributes
    ----------
    passbands : dict
        A dictionary mapping each passband's label to its Passband object.
    """

    def __init__(self, preset=None, passbands=None):
        """Initialize a PassbandGroup object.

        Parameters
        ----------
        preset : str, optional
            A pre-defined set of passbands to load (currently only "LSST").
        passbands : list, optional
            A list of Passband objects assigned to the group. A passband whose
            label matches one loaded from ``preset`` replaces that preset entry.

        Raises
        ------
        ValueError
            If ``preset`` is given but is not a recognized preset name.
        """
        # Always start from an empty mapping so a group built only from explicit
        # passbands (or with no arguments at all) has a valid ``passbands`` dict.
        self.passbands = {}

        if preset is not None:
            self.load_preset(preset)

        if passbands is not None:
            for passband in passbands:
                # Later entries (and explicit passbands) override preset ones with the same label.
                self.passbands[passband.label] = passband

    def load_preset(self, preset):
        """Load a pre-defined set of passbands, replacing the group's current passbands.

        Parameters
        ----------
        preset : str
            Name of the preset. Only "LSST" (bands u, g, r, i, z, y) is supported.

        Raises
        ------
        ValueError
            If ``preset`` is not a recognized preset name.
        """
        if preset == "LSST":
            self.passbands = {
                "u": Passband("LSST", "u"),
                "g": Passband("LSST", "g"),
                "r": Passband("LSST", "r"),
                "i": Passband("LSST", "i"),
                "z": Passband("LSST", "z"),
                "y": Passband("LSST", "y"),
            }
        else:
            raise ValueError(f"Unknown passband preset: {preset}")


class Passband:
    """A passband contains information about its transmission curve and calculates its normalization.

    Attributes
    ----------
    label : str
        The passband's label within its survey (e.g. "u").
    survey : str
        The survey the passband belongs to (e.g. "LSST").
    full_name : str
        The combined name ``"{survey}_{label}"``.
    transmission_table : np.ndarray
        2D array; column 0 is wavelength (Angstrom), column 1 is transmission.
    normalized_transmission_table : np.ndarray
        1D array of φ_b values, one per row of ``transmission_table``.
    """

    def __init__(self, survey, label, table_path=None, table_url=""):
        """Create a passband and compute its normalized transmission.

        Parameters
        ----------
        survey : str
            Survey name, used in ``full_name`` and the default table path.
        label : str
            Passband label, used in ``full_name`` and the default table path.
        table_path : str, optional
            Path to a whitespace-delimited transmission table. Defaults to
            ``passbands/{survey}/{label}.dat`` next to this module.
        table_url : str, optional
            Location to fetch the table from if ``table_path`` does not exist
            (downloading is not implemented yet).
        """
        self.label = label
        self.survey = survey
        self.full_name = f"{survey}_{label}"
        self.transmission_table = self._load_transmission_table(table_path, table_url)
        self.normalized_transmission_table = self._normalize_transmission_table()

    def _load_transmission_table(self, table_path, table_url) -> np.ndarray:
        """Load a transmission table from a file or URL.

        Returns
        -------
        np.ndarray
            A 2D array with at least two columns (wavelength, transmission).

        Raises
        ------
        NotImplementedError
            If the file does not exist (downloading is not implemented yet).
        ValueError
            If the loaded table has fewer than two columns.
        """
        if table_path is None:
            table_path = os.path.join(os.path.dirname(__file__), "passbands", self.survey, f"{self.label}.dat")
        if not os.path.exists(table_path):
            self._download_transmission_table(table_url, table_path)
        # ndmin=2 keeps a single-row file as a (1, ncols) table rather than a 1D
        # array, so the column indexing used elsewhere works uniformly.
        table = np.loadtxt(table_path, ndmin=2)
        if table.shape[1] < 2:
            raise ValueError(
                f"Transmission table {table_path} must have at least two columns "
                f"(wavelength, transmission); got shape {table.shape}."
            )
        return table

    def _download_transmission_table(self, table_url, table_path) -> None:
        """Download a transmission table from ``table_url`` to ``table_path``.

        Raises
        ------
        NotImplementedError
            Always; downloading is not implemented yet.
        """
        raise NotImplementedError("Downloading passband tables is not yet implemented.")

    def _normalize_transmission_table(self) -> np.ndarray:
        """Calculate the value of phi_b for all wavelengths in a transmission table.

        This is eq. 8 from "On the Choice of LSST Flux Units" (Ivezić et al.):

        φ_b(λ) = S_b(λ)λ⁻¹ / ∫ S_b(λ)λ⁻¹ dλ

        where S_b(λ) is the system response of the passband.

        Returns
        -------
        np.ndarray
            1D array of φ_b values, in the same row order as ``transmission_table``.

        Raises
        ------
        ValueError
            If the normalization integral is zero, negative, or not finite
            (e.g. all-zero transmission or a zero wavelength).

        Notes
        -----
        - We use transmission table here to represent S_b(λ).
        - There is currently no interpolation implemented (but coming very soon).
        """
        wavelengths_angstrom = self.transmission_table[:, 0]
        transmissions = self.transmission_table[:, 1]
        # Calculate the numerators
        numerators = transmissions / wavelengths_angstrom
        # Integrate over wavelength-sorted samples: the trapezoid rule on a
        # descending or shuffled x axis gives a wrongly-signed or wrong area.
        # A stable sort leaves already-sorted input (including ties) unchanged.
        order = np.argsort(wavelengths_angstrom, kind="stable")
        denominator = scipy.integrate.trapezoid(numerators[order], x=wavelengths_angstrom[order])
        if not np.isfinite(denominator) or denominator <= 0:
            raise ValueError(
                f"Cannot normalize passband {self.label!r}: integral of S_b(λ)/λ is {denominator}."
            )
        # Calculate phi_b for each wavelength (original row order is preserved)
        return numerators / denominator