# Microlensing Example

In this notebook we look at how to apply a simple microlensing effect, which wraps the [VBMicrolensing package](https://github.com/valboz/VBMicrolensing), to different source models.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

from tdastro.astro_utils.passbands import PassbandGroup
from tdastro.effects.microlensing import Microlensing
from tdastro.sources.basic_sources import StaticSource, SinWaveSource
from tdastro.sources.lightcurve_source import LightcurveSource
from tdastro.utils.plotting import plot_lightcurves


# Usually we would not hardcode the path to the passband files, but for this demo we will use a relative path
# to the test data directory so that we do not have to download the files.
data_dir = Path("../../tests/tdastro/data")

## Basic Application

In the most basic form, the microlensing effect can be added to any source. Here we start with a static source with no microlensing.

In [None]:
source = StaticSource(brightness=100.0)

t_start = 60676.0
times = np.arange(100.0) + t_start
wavelengths = np.array([7000.0])
fluxes = source.evaluate(times, wavelengths)

plt.plot(times, fluxes)
plt.xlabel("Time")
plt.ylabel("Flux")
plt.show()

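Now we add a microlensing effect. The magnification peaks at `microlensing_t0`, 20 days after the start of the light curve. The impact parameter `u_0` (the minimum lens-source separation in units of the Einstein radius) controls how strong the peak is, and the Einstein crossing time `t_E` (in days) controls how long the event lasts.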

In [None]:
ml_effect = Microlensing(microlensing_t0=t_start + 20.0, u_0=0.1, t_E=10.0)
source.add_effect(ml_effect)

fluxes = source.evaluate(times, wavelengths)

plt.plot(times, fluxes)
plt.xlabel("Time")
plt.ylabel("Flux")
plt.show()
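For intuition about the shape of this curve, here is a minimal sketch of the standard point-source point-lens (Paczyński) magnification. We assume, but have not verified, that this is what the effect computes from `microlensing_t0`, `u_0`, and `t_E` for a single point lens. The helper name `pspl_magnification` is hypothetical.

```python
import numpy as np


def pspl_magnification(times, t0, u_0, t_E):
    """Point-source point-lens (Paczynski) magnification A(t) (hypothetical helper)."""
    # Lens-source separation in units of the Einstein radius.
    u = np.sqrt(u_0**2 + ((np.asarray(times, dtype=float) - t0) / t_E) ** 2)
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))


t_start = 60676.0
times = np.arange(100.0) + t_start
magnification = pspl_magnification(times, t0=t_start + 20.0, u_0=0.1, t_E=10.0)

# Scaling a constant baseline of 100 gives a curve like the one plotted above.
lensed_flux = 100.0 * magnification
print(f"Peak at day {times[np.argmax(magnification)] - t_start:.0f}, A_max = {magnification.max():.2f}")
```

A smaller `u_0` gives a higher, sharper peak (A ≈ 1/`u_0` when `u_0` ≪ 1), and a larger `t_E` stretches the event in time.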

The source underneath the microlensing effect does not need to be constant. For example, we can simulate a small amount of variability by using a sine-wave-based model.

In [None]:
source2 = SinWaveSource(brightness=100.0, frequency=0.05, t0=t_start)
source2.add_effect(ml_effect)
fluxes = source2.evaluate(times, wavelengths)

plt.plot(times, fluxes)
plt.xlabel("Time")
plt.ylabel("Flux")
plt.show()
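Because a point lens magnifies the source rather than adding light, the effect should simply scale whatever the underlying model emits at each time. A minimal sketch under that assumption, with a hypothetical sine baseline (its mean of 100 and amplitude of 10 are invented and may not match `SinWaveSource`):

```python
import numpy as np


def pspl_magnification(times, t0, u_0, t_E):
    # Point-source point-lens (Paczynski) magnification.
    u = np.sqrt(u_0**2 + ((times - t0) / t_E) ** 2)
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))


t_start = 60676.0
times = np.arange(100.0) + t_start

# Hypothetical variable baseline: mean 100, amplitude 10, frequency 0.05 per day.
baseline = 100.0 + 10.0 * np.sin(2.0 * np.pi * 0.05 * (times - t_start))
magnification = pspl_magnification(times, t_start + 20.0, 0.1, 10.0)

# The lensed flux is the baseline scaled by A(t) at every time step.
lensed = baseline * magnification

# Dividing out the baseline recovers the magnification curve.
recovered = lensed / baseline
```

Away from the event (A ≈ 1) the sine variability is unchanged; near the peak it is amplified by the same factor as the mean flux.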

## More Complex Models

We can apply the microlensing effect to any of our SED-based or bandflux-based models. For example, consider a `LightcurveSource`, which takes sample data points in each band and returns interpolated values.

We start by loading the passbands, which are needed by the lightcurve model. Here we use (potentially older) data from the test directory to avoid a download. Users will generally want to download the most recent passbands; see the passbands notebook for more details.

In [None]:
# Load the passband data for the griz filters only.
filters = ["g", "r", "i", "z"]
passband_group = PassbandGroup.from_preset(
    preset="LSST",
    filters=filters,
    units="nm",
    trim_quantile=0.001,
    delta_wave=1,
    table_dir=data_dir / "passbands",
)
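The passbands matter because a bandflux is an average of the spectrum weighted by the filter's transmission. As a rough sketch of one common convention (a simple transmission-weighted mean; `PassbandGroup` may normalize differently, for example with a photon-counting factor of wavelength), using an invented top-hat filter and invented spectra on a 1 nm grid:

```python
import numpy as np


def trapezoid(y, x):
    # Trapezoidal-rule integral of y over x.
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def bandflux(flux_density, transmission, wave):
    # Transmission-weighted mean of the flux density across the band.
    return trapezoid(flux_density * transmission, wave) / trapezoid(transmission, wave)


# Hypothetical filter: 80% transmission from 420 to 530 nm on a 1 nm grid.
wave_nm = np.arange(400.0, 551.0, 1.0)
transmission = np.where((wave_nm >= 420.0) & (wave_nm <= 530.0), 0.8, 0.0)

# Hypothetical spectra: one flat, one rising linearly with wavelength.
flat = np.full_like(wave_nm, 2.0)
rising = 1.0 + 0.01 * (wave_nm - 400.0)

flat_band = bandflux(flat, transmission, wave_nm)
rising_band = bandflux(rising, transmission, wave_nm)
```

A flat spectrum returns its own value, and a linear spectrum under this symmetric filter returns its value at the band center (475 nm).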

Next we define a series of light curves, one for each of the griz filters, to use as our background source.

In [None]:
dts = np.arange(100.0)
times = dts + t_start

lightcurves = {
    "g": np.array([times, 3.0 * np.ones_like(times)]).T,  # Constant at 3.0
    "r": np.array([times, 0.02 * dts + 1.0]).T,  # Slight linear increase
    "i": np.array([times, np.sin(dts / 10.0) + 1.5]).T,  # Sin wave
    "z": np.array([times, 2.0 * np.ones_like(times)]).T,  # Constant at 2.0
}

lc_source = LightcurveSource(lightcurves, passband_group, t0=0)
graph_state = lc_source.sample_parameters(num_samples=1)

query_filters = np.array([filters[i % 4] for i in range(len(times))])
fluxes = lc_source.get_band_fluxes(passband_group, times, query_filters, graph_state)

plot_lightcurves(fluxes, times, fluxerrs=None, filters=query_filters)
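Conceptually, each query pairs a time with a filter, and the model returns the value of that band's curve at that time. The function below is a hypothetical, simplified stand-in that linearly interpolates each band's table directly; the real `LightcurveSource` works through the passbands, so its values need not match this exactly.

```python
import numpy as np

t_start = 60676.0
dts = np.arange(100.0)
times = dts + t_start

# Same per-band (time, flux) tables as in the cell above.
lightcurves = {
    "g": np.array([times, 3.0 * np.ones_like(times)]).T,
    "r": np.array([times, 0.02 * dts + 1.0]).T,
    "i": np.array([times, np.sin(dts / 10.0) + 1.5]).T,
    "z": np.array([times, 2.0 * np.ones_like(times)]).T,
}


def interp_band_fluxes(lightcurves, query_times, query_filters):
    # Hypothetical stand-in: interpolate each query within its own band's table.
    fluxes = np.empty(len(query_times))
    for band, table in lightcurves.items():
        mask = query_filters == band
        fluxes[mask] = np.interp(query_times[mask], table[:, 0], table[:, 1])
    return fluxes


query_times = t_start + np.array([0.5, 10.25, 31.0, 50.0])
query_filters = np.array(["g", "r", "i", "z"])
fluxes = interp_band_fluxes(lightcurves, query_times, query_filters)
```

Passing a filter array parallel to the times, as `query_filters` does above, lets a single call mix observations from several bands.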

We can add microlensing to this source model as we would any other model.

In [None]:
lc_source.add_effect(ml_effect)

# We need to resample to include the effect’s parameters.
graph_state = lc_source.sample_parameters(num_samples=1)
fluxes = lc_source.get_band_fluxes(passband_group, times, query_filters, graph_state)

plot_lightcurves(fluxes, times, fluxerrs=None, filters=query_filters)
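For a point source lensed by a point lens, the magnification does not depend on wavelength, so every band is scaled by the same factor and the colors stay fixed during the event. A quick check of that property, assuming the Paczyński magnification and reusing the g and r curves defined above (finite-source effects with limb darkening, which VBMicrolensing can model, would break this):

```python
import numpy as np


def pspl_magnification(times, t0, u_0, t_E):
    # Point-source point-lens (Paczynski) magnification, identical in every band.
    u = np.sqrt(u_0**2 + ((times - t0) / t_E) ** 2)
    return (u**2 + 2.0) / (u * np.sqrt(u**2 + 4.0))


t_start = 60676.0
dts = np.arange(100.0)
times = dts + t_start

# Unlensed g and r curves, matching the constant and linear examples above.
g_flux = 3.0 * np.ones_like(times)
r_flux = 0.02 * dts + 1.0

magnification = pspl_magnification(times, t_start + 20.0, 0.1, 10.0)
g_lensed = g_flux * magnification
r_lensed = r_flux * magnification

# g - r color in magnitudes before and after lensing.
color_before = -2.5 * np.log10(g_flux / r_flux)
color_after = -2.5 * np.log10(g_lensed / r_lensed)
```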

## Create the model

We want to create models based on an existing LCLIB file. As with the passbands, the file of interest needs to be downloaded to the data directory. In this example we use the LCLIB_RRL-LSST.TEXT.gz data from https://zenodo.org/records/6672739. We load this into a `MultiLightcurveSource` object, which represents a set of light curves from which we can sample.

The `MultiLightcurveSource` stores a series of multi-band light curves, each corresponding to the observer-frame bandfluxes of a single real or simulated object. New observations are created by randomly choosing one of the light curves and interpolating it at new times. The starting time of the activity is controlled by the model's `t0` parameter, so we can, for example, generate a simulation where the object's activity starts halfway through our observations.

Note that currently only periodic and non-recurring non-periodic light curves are supported. We treat recurring non-periodic light curves as non-recurring (they occur only once in the simulated output).
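The difference between the two supported cases comes down to what happens outside the span of the template. Below is a hypothetical, single-band sketch: a periodic template is wrapped with a modulo so it repeats indefinitely, while a non-recurring one is played once starting at `t0`. It assumes the periodic template covers exactly one period and that a non-recurring template's first value is its quiescent baseline; neither detail is taken from `MultiLightcurveSource`.

```python
import numpy as np


def evaluate_template(t, t0, lc_times, lc_flux, periodic):
    """Evaluate a single-band template light curve at times t (hypothetical)."""
    rel = np.asarray(t, dtype=float) - t0  # time since the activity starts
    phase_times = lc_times - lc_times[0]  # template times starting at zero
    duration = phase_times[-1]
    if periodic:
        # Wrap every query into one period so the pattern repeats forever.
        return np.interp(np.mod(rel, duration), phase_times, lc_flux)
    # Non-recurring: play the template once and use its first value elsewhere.
    out = np.full(rel.shape, lc_flux[0], dtype=float)
    inside = (rel >= 0.0) & (rel <= duration)
    out[inside] = np.interp(rel[inside], phase_times, lc_flux)
    return out


# Hypothetical template sampled at 0..4 days; its period is 4 days.
lc_times = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
lc_flux = np.array([1.0, 2.0, 3.0, 2.0, 1.0])

periodic_vals = evaluate_template(
    np.array([5.0, 4.5, -1.0]), 0.0, lc_times, lc_flux, periodic=True
)
single_vals = evaluate_template(
    np.array([-1.0, 2.0, 5.0]), 0.0, lc_times, lc_flux, periodic=False
)
```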

Since `MultiLightcurveSource` is a `PhysicalModel`, we can specify other parameters such as RA and Dec. In this example, we generate the position by sampling from the OpSim fields (using an `OpSimRADECSampler` node), and we sample the starting time of the light curve uniformly from the time range covered by the OpSim.

In [None]:
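from tdastro import _TDASTRO_BASE_DATA_DIR
from tdastro.math_nodes.np_random import NumpyRandomFunc
from tdastro.math_nodes.ra_dec_sampler import OpSimRADECSampler
from tdastro.simulate import simulate_lightcurves
from tdastro.sources.lightcurve_source import MultiLightcurveSource

# Assumes `opsim_db` (an OpSim object) and `t_min`/`t_max` (the MJD range it covers)
# have already been defined, for example as in the OpSim notebook.
lc_file = _TDASTRO_BASE_DATA_DIR / "models" / "LCLIB_RRL-LSST.TEXT.gz"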

# Use an OpSim based sampler for position.
ra_dec_sampler = OpSimRADECSampler(
    opsim_db,
    radius=3.0,  # degrees
    node_label="ra_dec_sampler",
)

# Use a uniform sampler for the starting time (t0) of activity.
time_sampler = NumpyRandomFunc("uniform", low=t_min, high=t_max, node_label="time_sampler")

# Load the light curves from the LCLIB file. Only load the filters that are present in the OpSim data.
source = MultiLightcurveSource.from_lclib_file(
    lc_file,
    passband_group,
    ra=ra_dec_sampler.ra,
    dec=ra_dec_sampler.dec,
    t0=time_sampler,
    filters=filters,
    node_label="source",
)

print(f"Loaded {len(source)} light curves from {lc_file}")
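We do not show the internals of `OpSimRADECSampler` or `NumpyRandomFunc` here. As an illustration of the idea only, the sketch below picks a random field center for each object, draws a position uniformly (by area) within 3 degrees of it, and draws `t0` uniformly within a survey window. The field centers and the `t_min`/`t_max` values are invented, and the helper names are hypothetical.

```python
import numpy as np


def sample_in_cap(ra0_deg, dec0_deg, radius_deg, rng):
    """Draw one point uniformly (by area) within radius_deg of each (ra0, dec0)."""
    ra0, dec0 = np.radians(ra0_deg), np.radians(dec0_deg)
    size = np.shape(ra0)
    # On a spherical cap, cos(theta) is uniform in [cos(radius), 1].
    cos_theta = rng.uniform(np.cos(np.radians(radius_deg)), 1.0, size)
    sin_theta = np.sqrt(1.0 - cos_theta**2)
    bearing = rng.uniform(0.0, 2.0 * np.pi, size)
    # Step an angular distance theta from the center along each bearing.
    sin_dec = np.sin(dec0) * cos_theta + np.cos(dec0) * sin_theta * np.cos(bearing)
    dec = np.arcsin(np.clip(sin_dec, -1.0, 1.0))
    ra = ra0 + np.arctan2(
        np.sin(bearing) * sin_theta * np.cos(dec0), cos_theta - np.sin(dec0) * sin_dec
    )
    return np.degrees(ra) % 360.0, np.degrees(dec)


def angular_sep_deg(ra1, dec1, ra2, dec2):
    # Great-circle separation from the spherical law of cosines.
    ra1, dec1, ra2, dec2 = (np.radians(x) for x in (ra1, dec1, ra2, dec2))
    cos_sep = np.sin(dec1) * np.sin(dec2) + np.cos(dec1) * np.cos(dec2) * np.cos(ra1 - ra2)
    return np.degrees(np.arccos(np.clip(cos_sep, -1.0, 1.0)))


rng = np.random.default_rng(seed=42)
field_ra = np.array([10.0, 150.0, 300.0])  # hypothetical field centers (deg)
field_dec = np.array([-30.0, -5.0, -60.0])
t_min, t_max = 60796.0, 64448.0  # hypothetical survey window (MJD)

n = 1_000
field_idx = rng.integers(0, len(field_ra), size=n)
ra, dec = sample_in_cap(field_ra[field_idx], field_dec[field_idx], 3.0, rng)
t0 = rng.uniform(t_min, t_max, size=n)
```

Sampling cos θ uniformly, rather than θ itself, keeps the density of objects uniform per unit of sky area.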

## Generate the simulations

We can now generate random simulations with all the information defined above. The `simulate_lightcurves` function takes four arguments: the source from which we want to sample (here the collection of light curves), the number of results to simulate (1,000), the OpSim data that determines when and in which filters each object is observed, and the passband information. The passband information is not actually needed in this example, because the bandfluxes are interpolated directly from the underlying light curves.

In [None]:
lightcurves = simulate_lightcurves(source, 1_000, opsim_db, passband_group)
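Part of what the OpSim contributes is the set of visits that actually cover each object. A rough, hypothetical sketch of that matching step, using an invented five-visit pointing table and a 1.75 degree match radius (about the radius of the LSST field of view); `simulate_lightcurves` may do this differently:

```python
import numpy as np


def angular_sep_deg(ra1, dec1, ra2, dec2):
    # Great-circle separation from the spherical law of cosines.
    ra1, dec1, ra2, dec2 = (np.radians(x) for x in (ra1, dec1, ra2, dec2))
    cos_sep = np.sin(dec1) * np.sin(dec2) + np.cos(dec1) * np.cos(dec2) * np.cos(ra1 - ra2)
    return np.degrees(np.arccos(np.clip(cos_sep, -1.0, 1.0)))


# Hypothetical pointing table: one entry per visit.
visit_ra = np.array([10.0, 10.5, 150.0, 12.0, 300.0])
visit_dec = np.array([-30.0, -29.0, -5.0, -32.0, -60.0])
visit_mjd = np.array([60800.0, 60801.0, 60802.0, 60803.0, 60804.0])
visit_filter = np.array(["g", "r", "i", "z", "g"])


def visits_covering(ra, dec, radius_deg=1.75):
    # Keep visits whose pointing center lies within radius_deg of the object.
    mask = angular_sep_deg(visit_ra, visit_dec, ra, dec) <= radius_deg
    return visit_mjd[mask], visit_filter[mask]


obs_mjd, obs_filter = visits_covering(10.0, -30.0)
```

In this sketch, each matched visit supplies a time and a filter at which the source model would then be evaluated.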

The results are written in the [nested-pandas](https://github.com/lincc-frameworks/nested-pandas) format for easy analysis. Each row corresponds to a single simulated object, with a unique id, ra, dec, etc. The `params` column includes all the internal state, including hyperparameter settings, that was used to generate this object.

We can print the first row:

In [None]:
print(lightcurves.loc[0])

The nested `lightcurve` column contains the times, filters, and fluxes for each observation of that object.  We can treat it as a table:

In [None]:
print(lightcurves.loc[0]["lightcurve"])
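If the nested layout is unfamiliar, the same information can be pictured as a flat table with one row per observation. The plain-pandas sketch below (with invented data; it does not use the nested-pandas API) derives a per-object observation count, analogous to the `nobs` column, and extracts one object's sub-table, analogous to `lightcurves.loc[0]["lightcurve"]`:

```python
import pandas as pd

# Hypothetical flat table: one row per observation of any object.
obs = pd.DataFrame(
    {
        "id": [0, 0, 0, 1, 1],
        "mjd": [60800.0, 60801.0, 60805.0, 60802.0, 60803.0],
        "filter": ["g", "r", "g", "i", "z"],
        "flux": [10.2, 11.0, 9.8, 5.1, 4.9],
    }
)

# Number of observations per object (like the `nobs` column).
nobs = obs.groupby("id").size()

# All observations of object 0 as their own table (like one nested `lightcurve` entry).
lc0 = obs[obs["id"] == 0].drop(columns="id").reset_index(drop=True)
```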

Now let's plot some random light curves. Note that all of the light curves in the "LCLIB_RRL-LSST.TEXT.gz" file are periodic, so we expect to see observations throughout the time range of the survey.

In [None]:
# Pick 5 distinct objects to plot.
random_ids = np.random.choice(len(lightcurves), 5, replace=False)

for random_id in random_ids:
    # Extract the row for this object.
    lc = lightcurves.loc[random_id]

    if lc["nobs"] > 0:
        # Unpack the nested columns (filters, mjd, flux, and flux error).
        lc_filters = np.asarray(lc["lightcurve"]["filter"], dtype=str)
        lc_mjd = np.asarray(lc["lightcurve"]["mjd"], dtype=float)
        lc_flux = np.asarray(lc["lightcurve"]["flux"], dtype=float)
        lc_fluxerr = np.asarray(lc["lightcurve"]["fluxerr"], dtype=float)

        # Look up which lightcurve was used.
        graph_state = lc["params"]
        lc_id = graph_state["source.selected_lightcurve"]
        ra = graph_state["source.ra"]
        dec = graph_state["source.dec"]

        plot_lightcurves(
            fluxes=lc_flux,
            times=lc_mjd,
            fluxerrs=lc_fluxerr,
            filters=lc_filters,
            title=f"Sample {random_id} from Lightcurve {lc_id} at ({ra:.2f}, {dec:.2f})",
        )