In [None]:
from ipywidgets import interactive, HBox, VBox

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import numpy as np
import astropy.units as u
from expecto import get_spectrum

from jax import numpy as jnp
from fleck.jax import ActiveStar, bin_spectrum


# Time grid (plotted later on an axis labelled in days), symmetric about zero
times = np.linspace(-0.04, 0.04, 250)

# Log-spaced wavelength grid from 0.5 to 5 micron, handed to `bin_spectrum`
# as its `bins` argument
wavelength = np.geomspace(0.5, 5, 101) * u.um

# Download and bin PHOENIX model spectra to compute contrast.
# The same binning options are shared by every spectrum:
kwargs = {
    'bins': wavelength,
    'min': wavelength.min(),
    'max': wavelength.max(),
    'log': False,
}

# One binned spectrum per T_eff, unpacked in order as the photosphere,
# the cool component and the hot component:
phot, cool, hot = (
    bin_spectrum(get_spectrum(T_eff=T_eff, log_g=5.0, cache=True), **kwargs)
    for T_eff in (2600, 2400, 4400)
)

In [None]:
def plot_transit_contamination(
    r, lc, contam, X, Y, spectrum_at_transit, 
    norm_oot_per_wavelength, norm_stellar_spectrum
):
    """
    Draw a three-panel figure: light curves at a subset of wavelengths,
    the transit depth as a function of wavelength, and the stellar surface.

    Parameters
    ----------
    r : object
        Star model. This function reads ``r.times`` and ``r.wavelength``
        and calls ``r.plot_star``.
    lc : array
        Light curves indexed as ``[time, wavelength]``.
    contam : array
        Transit depth per wavelength. It is multiplied by 1e6 and plotted
        on an axis labelled in ppm.
    X, Y
        Accepted so the signature matches the values unpacked by the
        caller; not used here.
    spectrum_at_transit : array
        Only used when ``norm_stellar_spectrum`` is False, in which case
        the light curves are multiplied by this spectrum divided by its
        mean. Assumed to broadcast against the wavelength axis of ``lc``
        -- TODO confirm against the model that produces it.
    norm_oot_per_wavelength : bool
        If True, each plotted light curve is divided by its own mean.
    norm_stellar_spectrum : bool
        If True, the light curves are plotted without the spectral scaling
        described under ``spectrum_at_transit``.

    Notes
    -----
    The module-level names ``t0``, ``rp``, ``a``, ``ecc`` and
    ``inclination`` are read when drawing the stellar surface, so they
    must be defined before this function is called.
    """
    fig = plt.figure(figsize=(8, 4), dpi=100)
    gs = GridSpec(2, 2, figure=fig)

    # Two stacked panels on the left, one full-height panel on the right.
    ax = [
        fig.add_subplot(gs[0, 0]),
        fig.add_subplot(gs[1, 0]),
        fig.add_subplot(gs[:, 1]),
    ]

    # Only every `skip`-th wavelength is drawn as a curve / marker.
    skip = 15
    # Exactly the indices selected by `[::skip]`. Building them with a
    # stepped arange (rather than `len // skip + 1`) avoids an extra,
    # out-of-range index when the length is a multiple of `skip`.
    shown = np.arange(0, len(r.wavelength), skip)

    # max - min instead of the `.ptp()` method, which NumPy 2 removed.
    wl_min = r.wavelength.min()
    wl_span = r.wavelength.max() - wl_min

    def cmap(i):
        """Colour for wavelength index (or index array) `i`."""
        return plt.cm.Spectral_r((r.wavelength[i] - wl_min) / wl_span)

    if norm_stellar_spectrum:
        scale_relative_to_flux_at_wavelength = 1
    else:
        # Keep the full wavelength axis here: the product with `lc` is
        # sub-sampled once, below. Sub-sampling the scale first would make
        # it misaligned with `lc`.
        scale_relative_to_flux_at_wavelength = (
            spectrum_at_transit / spectrum_at_transit.mean()
        )

    scaled_lc = (lc * scale_relative_to_flux_at_wavelength)[:, ::skip]

    for i, lc_i in enumerate(scaled_lc.T):
        if norm_oot_per_wavelength:
            # Rebind rather than divide in place, so no array is mutated.
            lc_i = lc_i / lc_i.mean()

        ax[0].plot(r.times, lc_i, color=cmap(skip * i))

    ax[0].set(
        xlabel='Time [d]',
        ylabel='$\\left(F(t)/\\bar{F}\\right)_{\\lambda}$',
    )

    contaminated_depth = 1e6 * contam

    # `r.wavelength` is scaled by 1e6 for the micron-labelled axis.
    ax[1].plot(r.wavelength * 1e6, contaminated_depth, zorder=-3, lw=2.5, color='silver')
    ax[1].scatter(
        r.wavelength[::skip] * 1e6, contaminated_depth[::skip].T,
        c=cmap(shown),
        s=50, edgecolor='k', zorder=4
    )
    ax[1].set(
        xlabel='Wavelength [µm]',
        ylabel='Transit depth [ppm]',
        xscale='log',
        xlim=[1e6 * 0.9 * r.wavelength.min(), 1e6 * 1.1 * r.wavelength.max()],
    )

    # Planet parameters come from module scope (see Notes).
    r.plot_star(t0=t0, rp=rp, a=a, ecc=ecc, inclination=inclination, ax=ax[2])

    for sp in ['right', 'top']:
        for axis in ax:
            axis.spines[sp].set_visible(False)

    fig.tight_layout()
    plt.show()

In [None]:
# stellar parameters: the photosphere spectrum and its wavelength grid
# (converted to metres) come from the binned `phot` spectrum
r = ActiveStar(
    times=times,
    inclination=np.pi / 2,
    T_eff=phot.meta['PHXTEFF'],
    wavelength=phot.wavelength.to_value(u.m),
    phot=phot.flux.value,
)

# spot parameters, one value per spot (1, 2, 3)

# longitudes in radians (negative values are used here)
lon1, lon2, lon3 = -0.45, 0, 0.34

# latitudes in radians [0, pi]
lat1 = lat2 = lat3 = np.pi / 2

# rspot/rstar
rad1, rad2, rad3 = 0.03, 0.08, 0.005

# spots 1 and 2 use the `cool` spectrum, spot 3 uses the `hot` spectrum
r.add_spot(
    lon=lon1, lat=lat1, rad=rad1,
    spectrum=cool.flux.value, temperature=cool.meta['PHXTEFF'],
)
r.add_spot(
    lon=lon2, lat=lat2, rad=rad2,
    spectrum=cool.flux.value, temperature=cool.meta['PHXTEFF'],
)
r.add_spot(
    lon=lon3, lat=lat3, rad=rad3,
    spectrum=hot.flux.value, temperature=hot.meta['PHXTEFF'],
)

# planet parameters for TRAPPIST-1 c from Agol 2021
inclination = np.radians(89.778)
a = 28.549
rp = 0.08440
period = 2.421937
t0 = 0
ecc = 0

def lc_interact(
    lon1=lon1,
    lon2=lon2,
    lon3=lon3,
    rad1=rad1,
    rad2=rad2,
    rad3=rad3,
    lat1=lat1,
    lat2=lat2,
    lat3=lat3,
    norm_oot_per_wavelength=True,
    norm_stellar_spectrum=True
):
    """
    Update the three spots on the module-level star ``r``, recompute the
    transit model and redraw the summary figure.

    Parameters
    ----------
    lon1, lon2, lon3 : float
        Spot longitudes written to ``r.lon``.
    rad1, rad2, rad3 : float
        Spot radii written to ``r.rad``.
    lat1, lat2, lat3 : float
        Spot latitudes written to ``r.lat``.
    norm_oot_per_wavelength, norm_stellar_spectrum : bool
        Forwarded unchanged to ``plot_transit_contamination``.

    Notes
    -----
    This mutates the module-level ``r`` and reads the module-level planet
    parameters ``t0``, ``period``, ``rp``, ``a``, ``inclination`` and
    ``ecc``. The defaults are the module-level spot values at the time the
    function is defined.
    """
    spot_lons = jnp.array([lon1, lon2, lon3])
    spot_lats = jnp.array([lat1, lat2, lat3])
    spot_rads = jnp.array([rad1, rad2, rad3])

    # Replace the spot arrays on the shared star object.
    r.lon = spot_lons
    r.lat = spot_lats
    r.rad = spot_rads

    # Planet parameters come from module scope; u1 and u2 are fixed here.
    planet_kwargs = dict(
        t0=t0,
        period=period,
        rp=rp,
        a=a,
        inclination=inclination,
        ecc=ecc,
    )
    light_curves, depths, grid_x, grid_y, transit_spectrum = r.transit_model(
        **planet_kwargs, u1=0.1, u2=0.3
    )

    plot_transit_contamination(
        r, light_curves, depths, grid_x, grid_y, transit_spectrum,
        norm_oot_per_wavelength, norm_stellar_spectrum
    )


In [None]:
# Build one slider per spot parameter; each tuple is (min, max, step).
# The three longitudes share one range and the three latitudes another.
widget = interactive(
    lc_interact,
    **{f'lon{n}': (-np.pi, 3 * np.pi, 0.1) for n in (1, 2, 3)},
    **{f'lat{n}': (0, np.pi, 0.05) for n in (1, 2, 3)},
    rad1=(0, 0.3, 0.01),
    rad2=(0, 0.3, 0.01),
    rad3=(0, 0.05, 0.001),
)

# The last child of the interactive container is the output area; every
# child before it is a control. Show the output on the left and the
# controls stacked on the right.
output = widget.children[-1]
controls = VBox(widget.children[:-1])
display(HBox([output, controls]))