# LightcurveSource Demo

The `LightcurveSource` model is designed to reproduce user-provided lightcurves in specific passbands. It is specified by a separate lightcurve for each passband. Internally, the model produces an estimated SED at each time, so that all of TDAstro's effects can be applied.

In this notebook, we provide an introductory demo of setting up and using the `LightcurveSource` model.

## Setup

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.consts import lsst_filter_plot_colors
from tdastro.sources.lightcurve_source import LightcurveSource

We start by loading the passbands that we will use to define the model. In this case, we use the passbands from the LSST preset.

In [None]:
passband_group = PassbandGroup(preset="LSST")
filters = passband_group.passbands.keys()
print(passband_group)

wavelengths = passband_group.waves
min_wave, max_wave = passband_group.wave_bounds()
print(f"Wavelengths range [{min_wave}, {max_wave}]")

passband_group.plot()

## Creating the model

We create the model from a lightcurve for each passband of interest. These lightcurves are what we want the model to reproduce when we call `get_band_fluxes()`. The set of lightcurves does not need to match the passbands in the `PassbandGroup` exactly, but every lightcurve must have a corresponding entry in the `PassbandGroup`.
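The requirement above can be expressed as a small check. The sketch below is a hypothetical helper (`check_lightcurves` is not part of TDAstro): it assumes the `(N, 2)` layout of `[time, flux]` rows used in this notebook and only compares lightcurve names against a list of passband names.

```python
import numpy as np


def check_lightcurves(lightcurves, passband_names):
    """Hypothetical helper (not part of TDAstro).

    Checks that every lightcurve has a matching passband name and uses the
    (N, 2) [time, flux] layout assumed in this notebook.
    """
    passband_names = set(passband_names)
    for name, curve in lightcurves.items():
        if name not in passband_names:
            raise ValueError(f"Lightcurve '{name}' has no matching passband.")
        curve = np.asarray(curve)
        if curve.ndim != 2 or curve.shape[1] != 2:
            raise ValueError(f"Lightcurve '{name}' must have shape (N, 2), got {curve.shape}.")


times = np.linspace(0.0, 10.0, 5)

# Providing a lightcurve for only a subset of the passbands is fine...
good = {"g": np.column_stack([times, np.ones_like(times)])}
check_lightcurves(good, ["u", "g", "r", "i", "z", "y"])

# ...but a lightcurve without a matching passband is rejected.
bad = {"w": np.column_stack([times, np.ones_like(times)])}
try:
    check_lightcurves(bad, ["u", "g", "r", "i", "z", "y"])
except ValueError as err:
    print(err)  # Lightcurve 'w' has no matching passband.
```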

For simplicity, we create each curve in this demo as a randomly parameterized sine wave. Note that the lightcurves do not need to share the same time grid; below, each curve's times are shifted by its own random offset.
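As a minimal plain-NumPy sketch (not the TDAstro API), here are two lightcurves in the same `[time, flux]` layout with different lengths and time ranges, resampled onto a shared set of times with `np.interp`. The linear interpolation is only an illustrative way to compare curves on different grids; we are not asserting that `LightcurveSource` interpolates this way internally.

```python
import numpy as np

# Two lightcurves in the (N, 2) [time, flux] layout, sampled on
# different time grids with different numbers of points.
t_g = np.linspace(0.0, 20.0, 100)
t_r = np.linspace(3.0, 18.0, 40)
lightcurves = {
    "g": np.column_stack([t_g, 10.0 + np.sin(t_g)]),
    "r": np.column_stack([t_r, 20.0 + 2.0 * np.sin(t_r)]),
}

# Linearly interpolate each curve onto a common grid that lies
# inside both time ranges (an illustrative choice only).
common_times = np.linspace(5.0, 15.0, 11)
resampled = {
    name: np.interp(common_times, curve[:, 0], curve[:, 1])
    for name, curve in lightcurves.items()
}

for name, values in resampled.items():
    print(name, values.shape)
```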

In [None]:
num_times = 100
times = np.linspace(0, 20, num_times)

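lightcurves = {}
for filter in filters:
    # Draw a random amplitude, baseline flux, and phase for this filter.
    amp = 5.0 * np.random.random() + 1.0
    flux_offset = np.random.random() * 25 + 10
    time_offset = np.random.random() * 10
    filter_flux = amp * np.sin(times + time_offset) + flux_offset
    # In the printed formula, t refers to the unshifted grid in `times`.
    print(f"Filter {filter}: {amp:.2f} * sin(t + {time_offset:.2f}) + {flux_offset:.2f}")

    # Store an (N, 2) array whose rows are [time, flux]. Shifting the times by
    # time_offset gives each filter its own time grid.
    lightcurves[filter] = np.array([times + time_offset, filter_flux]).T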

In [None]:
# Plot the lightcurves
figure = plt.figure()
ax = figure.add_axes([0, 0, 1, 1])
for filter, lightcurve in lightcurves.items():
    color = lsst_filter_plot_colors.get(filter, "black")
    ax.plot(lightcurve[:, 0], lightcurve[:, 1], color=color, label=filter)

ax.set_xlabel("Time")
ax.set_ylabel("Flux")
ax.legend()
plt.show()

We then create the model from the dictionary of lightcurves and the passband group, setting `t0=0.0`.

In [None]:
model = LightcurveSource(lightcurves, passband_group, t0=0.0)

If we plot the model's underlying lightcurves, we can see that they match the provided ones.

In [None]:
model.plot_lightcurves()

We can also plot the SED basis for each filter to see how the underlying model computes the total SED. Each passband's SED basis is multiplied by the corresponding lightcurve's values over time to give its contribution to the overall SED. Note that, to avoid overcounting the contributions of some wavelengths, the SED basis functions only contain wavelengths where the filters do not overlap.
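To make the basis idea concrete, here is a minimal NumPy sketch with three hypothetical boxcar passbands (`a`, `b`, `c`) and a simplified band-flux definition (the transmission-weighted mean of the SED on a uniform grid). It keeps only the wavelengths covered by exactly one filter, scales each basis to unit flux in its own band, and sums the bases weighted by the lightcurve values at one time. This illustrates the approach described above; it is not TDAstro's implementation, and real passbands and band-flux normalizations differ.

```python
import numpy as np

# Hypothetical toy setup: three overlapping boxcar passbands on a uniform grid.
waves = np.arange(3000.0, 7000.0, 10.0)
edges = {"a": (3000.0, 4500.0), "b": (4000.0, 5500.0), "c": (5000.0, 6500.0)}
trans = {
    name: ((waves >= lo) & (waves < hi)).astype(float)
    for name, (lo, hi) in edges.items()
}


def band_flux(sed, transmission):
    """Simplified band flux: transmission-weighted mean of the SED (uniform grid)."""
    return np.sum(sed * transmission) / np.sum(transmission)


# Number of filters covering each wavelength.
coverage = sum(t > 0 for t in trans.values())

# Keep only wavelengths covered by exactly one filter, so no wavelength is
# counted twice, then scale each basis to unit flux in its own band.
basis = {}
for name, t in trans.items():
    masked = np.where(coverage == 1, t, 0.0)
    basis[name] = masked / band_flux(masked, t)

# Made-up lightcurve values for each band at a single time.
fluxes_at_t = {"a": 12.0, "b": 30.0, "c": 7.5}

# The total SED at that time is the lightcurve-weighted sum of the bases.
sed = sum(fluxes_at_t[name] * basis[name] for name in basis)

# Integrating the total SED through each passband recovers the input values.
for name in trans:
    print(name, round(float(band_flux(sed, trans[name])), 6))
```

Because each basis is zero wherever another filter has transmission, the other bases contribute nothing to a given band, which is why the recovered band fluxes equal the lightcurve values.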

In [None]:
model.plot_sed_basis()

## Generating Flux Densities

We evaluate the `LightcurveSource` model the same way we evaluate any `PhysicalModel`, with functions such as `evaluate()`, `sample_parameters()`, and `get_band_fluxes()`. This model was specifically designed for the `get_band_fluxes()` function, so we explore that below.
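In the call below, `get_band_fluxes()` returns one flux per query time, each in that observation's filter. A minimal plain-NumPy sketch of that idea follows: evaluate an SED at every time, then weight each row by the transmission of the filter used at that time. The `toy_sed` function, the boxcar passbands, and the transmission-weighted-mean band flux are all hypothetical simplifications, not TDAstro's implementation.

```python
import numpy as np

# Hypothetical toy passbands on a uniform wavelength grid (not LSST curves).
waves = np.arange(4000.0, 7000.0, 10.0)
passbands = {
    "g": ((waves >= 4000.0) & (waves < 5500.0)).astype(float),
    "r": ((waves >= 5500.0) & (waves < 7000.0)).astype(float),
}


def toy_sed(times):
    """Hypothetical SED model returning an array of shape (len(times), len(waves))."""
    times = np.asarray(times, dtype=float)
    return 10.0 + np.sin(times)[:, None] * (waves[None, :] / 5000.0)


def band_fluxes(sed_fn, times, filters):
    """Return one flux per observation, using that observation's filter."""
    seds = sed_fn(times)
    out = np.empty(len(times))
    for i, name in enumerate(filters):
        trans = passbands[name]
        # Simplified band flux: transmission-weighted mean of the SED.
        out[i] = np.sum(seds[i] * trans) / np.sum(trans)
    return out


query_times = np.array([0.0, 1.0, 2.0])
query_filters = np.array(["g", "r", "g"])
print(band_fluxes(toy_sed, query_times, query_filters))
```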

Let's model a sequence of observations in the g and r filters only.

In [None]:
# The state is not used since we do not have any random parameters, but we need it for get_band_fluxes().
state = model.sample_parameters(num_samples=1)

# Create query times over the middle of the range. Use g when the integer part of the
# time is divisible by 4 (times in [8, 9) and [12, 13), about 1/5 of them) and r otherwise.
query_times = np.linspace(5, 15, 200)
query_filters = np.array(["g" if int(t) % 4 == 0 else "r" for t in query_times])

# Get the fluxes for the query times and filters
flux = model.get_band_fluxes(passband_group, query_times, query_filters, state)
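To see how the `int(t) % 4 == 0` rule splits the query times between the two filters, we can count them directly. Within [5, 15], only times whose integer part is 8 or 12 map to g, so roughly one fifth of the observations are in g. This standalone snippet only re-creates the query arrays; it does not call TDAstro.

```python
import numpy as np

# Same rule as above: "g" when the integer part of the time is divisible by 4.
query_times = np.linspace(5, 15, 200)
query_filters = np.array(["g" if int(t) % 4 == 0 else "r" for t in query_times])

names, counts = np.unique(query_filters, return_counts=True)
print(dict(zip(names.tolist(), counts.tolist())))  # {'g': 40, 'r': 160}
```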

In [None]:
# Plot the fluxes
figure = plt.figure()
ax = figure.add_axes([0, 0, 1, 1])
for filter in ["g", "r"]:
    mask = query_filters == filter
    color = lsst_filter_plot_colors.get(filter, "black")
    label = f"Observation in {filter}"
    ax.plot(query_times[mask], flux[mask], color=color, label=label, linewidth=0, marker=".")

ax.set_xlabel("Time")
ax.set_ylabel("Flux")
ax.legend()
plt.show()