In [None]:
from ipywidgets import interactive
from expecto import get_spectrum

import matplotlib.pyplot as plt
import numpy as np
import astropy.units as u

from jax import numpy as jnp
from fleck.jax import ActiveStar, bin_spectrum


# Time samples (plotted later as 'Time [d]') and the wavelength grid, 0.5-5 µm,
# that every spectrum is binned onto:
times = np.linspace(-0.04, 0.04, 100)
wavelength = np.geomspace(0.5, 5, 101) * u.um

# Binning options shared by all three spectra below.
kwargs = {
    "bins": wavelength,
    "min": wavelength.min(),
    "max": wavelength.max(),
    "log": False,
}

# Download (with caching) PHOENIX model spectra at log g = 5.0 and bin them to
# compute contrast. The effective temperatures [K] map, in order, onto the
# photosphere, the cool-spot spectrum and the hot-spot spectrum.
phot, cool, hot = (
    bin_spectrum(get_spectrum(T_eff=teff, log_g=5.0, cache=True), **kwargs)
    for teff in (2600, 2300, 4000)
)

In [None]:
def lc_interact(
    lon1=-0.25, lon2=0, lon3=0.25, 
    rad1=0.15, rad2=0.3, rad3=0.005,
    lat1=1.3 * np.pi/2, lat2=1.3 * np.pi/2, lat3=1.3 * np.pi/2
):
    """Plot a spotted-star transit: light curves, depth spectrum and star map.

    Builds an ``ActiveStar`` with two cool spots and one hot spot. It then
    evaluates the rotational modulation and a planetary transit, and draws
    three panels:
    light curves for every ``skip``-th wavelength bin, the transit depth versus
    wavelength at the time sample closest to ``t0``, and
    ``ActiveStar.plot_star``'s rendering of the system.

    This function is driven by ``ipywidgets.interactive``, which creates one
    widget per keyword argument. Do not add keyword arguments here unless a
    new control is wanted for each of them.

    Parameters
    ----------
    lon1, lon2, lon3 : float
        ``lon`` of spot 1 and spot 2 (cool) and spot 3 (hot), in radians
        (the widget ranges are multiples of pi).
    rad1, rad2, rad3 : float
        ``rad`` of each spot, passed straight to ``ActiveStar.add_spot``
        (presumably in units of the stellar radius; confirm with fleck.jax).
    lat1, lat2, lat3 : float
        ``lat`` of each spot in radians, passed straight to ``add_spot``.

    Relies on the notebook globals ``times``, ``phot``, ``cool`` and ``hot``.
    """
    # Planet's orbital inclination; the star's own inclination is pi/2 below.
    inclination = 0.982 * np.pi/2

    r = ActiveStar(
        times=times,
        inclination=np.pi/2,
        T_eff=2600,
        wavelength=phot.wavelength.to_value(u.m),
        phot=phot.flux.value,
        u_ld=[0.1, 0.3]
    )

    # Two cool spots, then one hot spot (same order as the widget arguments).
    for lon, lat, rad, spot in (
        (lon1, lat1, rad1, cool),
        (lon2, lat2, rad2, cool),
        (lon3, lat3, rad3, hot),
    ):
        r.add_spot(
            lon=lon, lat=lat, rad=rad,
            spectrum=spot.flux.value, temperature=spot.meta['PHXTEFF']
        )

    t0 = 0
    a = 15
    ecc = 0
    rp = 0.06
    rotation, f_S = r.rotation_model(t0=t0)
    # The last two outputs (X, Y) are not used by this plot.
    transit, contam, _, _ = r.transit_model(
        t0=t0,
        ecc=ecc,
        rp=rp,
        period=1.5,
        a=a,
        inclination=inclination,
        f_S=f_S
    )

    lc = np.squeeze((1 + np.sum(transit, axis=1)) * (1 + rotation))
    fig, ax = plt.subplots(1, 3, figsize=(12, 4), dpi=100)

    skip = 15
    # Indices of the wavelength bins that get plotted/scattered. Deriving every
    # colour from this one array keeps the colour count equal to the point
    # count. A count of ``len // skip + 1`` would overshoot by one (and index
    # past the end) whenever len is an exact multiple of skip.
    wl_idx = np.arange(len(r.wavelength))[::skip]

    # max - min instead of the ``.ptp()`` method, which NumPy 2.0 removed from
    # ndarray; this form works for numpy and jax arrays alike.
    wl_min = r.wavelength.min()
    wl_span = r.wavelength.max() - wl_min

    def cmap(i):
        """Map wavelength index/indices to Spectral_r colours by relative wavelength."""
        return plt.cm.Spectral_r((r.wavelength[i] - wl_min) / wl_span)

    for idx, lc_i in zip(wl_idx, lc[:, ::skip].T):
        ax[0].plot(r.times, lc_i, color=cmap(idx))

    ax[0].set(
        xlabel='Time [d]',
        ylabel='Flux [relative to inactive]'
    )

    # Transit depth [ppm] at the time sample closest to t0.
    t_ind = np.argmin(np.abs(r.times - t0))
    contaminated_depth = 1e6 * contam[t_ind, 0, :]

    ax[1].plot(r.wavelength * 1e6, contaminated_depth, zorder=-3, lw=2.5, color='silver')

    ax[1].scatter(
        r.wavelength[::skip] * 1e6, contaminated_depth[::skip],
        c=cmap(wl_idx),
        s=50, edgecolor='k', zorder=4
    )
    ax[1].set(
        xlabel='Wavelength [µm]',
        ylabel='Transit depth [ppm]',
        xlim=[1e6 * r.wavelength.min(), 1e6 * r.wavelength.max()],
        ylim=[contaminated_depth.min() * 0.999, contaminated_depth.max() * 1.001]
    )

    r.plot_star(ax=ax[2], rp=rp, a=a, ecc=ecc, inclination=inclination)

    for sp in ['right', 'top']:
        for axis in ax[:2]:
            axis.spines[sp].set_visible(False)

    fig.tight_layout()
    plt.show()

In [None]:
# Widget specs are (min, max, step) tuples, one per lc_interact keyword.
# All three spots share the same longitude and latitude ranges; the hot spot
# (spot 3) gets a much smaller radius range than the two cool spots.
interactive(
    lc_interact,
    **{f"lon{n}": (-np.pi, 3*np.pi, 0.1) for n in (1, 2, 3)},
    **{f"lat{n}": (0, np.pi, 0.05) for n in (1, 2, 3)},
    rad1=(0, 0.3, 0.01),
    rad2=(0, 0.3, 0.01),
    rad3=(0, 0.05, 0.001),
)

In [None]:
# Widget-free evaluation of the same spotted-star transit model used in
# lc_interact. The resulting objects (r, rotation, f_S, transit, contam, X, Y,
# lc) stay in the notebook namespace for inspection. Here all three spots are
# placed at lat = 0, whereas the interactive widget defaults every spot to
# lat = 1.3 * np.pi/2.
# NOTE(review): the widget's lat slider spans (0, pi), which suggests `lat` may
# be measured from a pole. If so, lat = 0 puts all three spots at the same point
# regardless of lon. Confirm against fleck.jax's `add_spot` convention.
lon1 = -0.25
lon2 = 0
lon3 = 0.25

rad1 = 0.15
rad2 = 0.3
rad3 = 0.005

lat1 = 0
lat2 = 0
lat3 = 0

# Planet's orbital inclination; the star's own inclination is pi/2 below.
inclination = 0.982 * np.pi/2

r = ActiveStar(
    times=times,
    inclination=np.pi/2,
    T_eff=2600,
    wavelength=phot.wavelength.to_value(u.m),
    phot=phot.flux.value,
    u_ld=[0.1, 0.3]
)

# Two cool spots, then one hot spot.
for lon, lat, rad, spot in (
    (lon1, lat1, rad1, cool),
    (lon2, lat2, rad2, cool),
    (lon3, lat3, rad3, hot),
):
    r.add_spot(
        lon=lon, lat=lat, rad=rad,
        spectrum=spot.flux.value, temperature=spot.meta['PHXTEFF']
    )

t0 = 0
a = 15
ecc = 0
rp = 0.06

rotation, f_S = r.rotation_model(t0=t0)
transit, contam, X, Y = r.transit_model(
    t0=t0,
    ecc=ecc,
    rp=rp,
    period=1.5,
    a=a,
    inclination=inclination,
    f_S=f_S
)

lc = np.squeeze((1 + np.sum(transit, axis=1)) * (1 + rotation))