# AU Mic with `fleck`

Fit the TESS Sector 1 light curve of AU Mic with `fleck`, using three cool active regions. Then use the best-fit spot map to extrapolate the rotational modulation to other wavelengths.

In [None]:
%matplotlib inline
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro import distributions as dist

# Set the number of cores on your machine for parallel computing:
cpu_cores = 4
# Ask numpyro to expose `cpu_cores` host (CPU) devices to JAX.
# NOTE(review): numpyro documents that this only takes effect at the very start
# of a program (before JAX initializes its backend) — keep it above any JAX
# computation if cells are reordered.
numpyro.set_host_device_count(cpu_cores)

import jax
from jax import jit, numpy as jnp, config
from jax.random import PRNGKey, split

# we need float64 support:
# (JAX uses 32-bit floats unless this flag is enabled.)
config.update("jax_enable_x64", True)

import arviz
from corner import corner

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import numpy as np
import astropy.units as u
from astropy.time import Time
from expecto import get_spectrum
from lightkurve import search_lightcurve

# fleck's JAX implementation of the active-star model, plus its spectrum binning helper.
from fleck.jax import ActiveStar, bin_spectrum

The commented code below computes the transmittance-weighted mean wavelength of the TESS bandpass; its result is hard-coded so that `tynt` is not required.

In [None]:
# To recompute the constant below (requires `tynt`), uncomment and run:
#
# import numpy as np
# import astropy.units as u
# from tynt import FilterGenerator
#
# f = FilterGenerator()
# tess_filt = f.reconstruct('TESS/TESS.Red')
# tess_mean_wavelength = np.average(tess_filt.wavelength, weights=tess_filt.transmittance).to_value(u.m)
# tess_mean_wavelength

# Transmittance-weighted mean wavelength of the TESS bandpass, in meters,
# exactly as produced by the recipe above.
tess_mean_wavelength = 8.004867649770393e-07

The sigma clipping below uses a tight upper threshold (`sigma_upper=2.8`) to help remove flares:

In [None]:
# Fetch the SPOC TESS Sector 1 light curve(s) of AU Mic and keep the first one.
# Then normalize it, drop NaN cadences, and sigma-clip with a tight *upper*
# threshold (2.8 sigma) aimed at flares.
tess_lc = (
    search_lightcurve("AU Mic", mission="TESS", author="SPOC", sector=1)
    .download_all()[0]
    .normalize()
    .remove_nans()
    .remove_outliers(sigma_upper=2.8)
)

Assume the spots are *very* cold: 2300 K, against a 3700 K photosphere. Real spots are probably not *this* cold.

In [None]:
T_phot = 3700  # Plavchan (2009)
T_spot1 = T_spot2 = T_spot3 = 2300

# One ActiveStar instance is enough to borrow its blackbody helper; the
# previous lambda constructed a fresh ActiveStar on every call.
_blackbody_star = ActiveStar()


def blackbody(wavelength, temperature):
    """Return fleck's blackbody value at ``wavelength`` [m] for ``temperature`` [K].

    Thin wrapper around the private ``ActiveStar._blackbody`` helper.
    NOTE(review): relies on a private fleck method; confirm it is stateless.
    """
    return _blackbody_star._blackbody(wavelength, temperature)


# Photosphere and per-spot intensities at the single TESS mean wavelength.
phot = jnp.array([float(blackbody(tess_mean_wavelength, T_phot))])
spectrum = jnp.array(
    [[blackbody(tess_mean_wavelength, T_spot)] for T_spot in (T_spot1, T_spot2, T_spot3)]
)

Below we construct a model that you can tweak by hand, to see how the light curve changes as the spot parameters change:

In [None]:
# Limb-darkening coefficients passed to fleck as u1, u2 (presumably a
# quadratic law — confirm). They are reused by every model below.
u1 = 0.453
u2 = 0.207

# Star evaluated at the single TESS mean wavelength, with inclination = pi/2,
# using the photosphere/spot intensities computed above.
# NOTE(review): P_rot is presumably in days, the same units as `times`.
r = ActiveStar(
    times=tess_lc.time.value, 
    inclination=np.pi/2,
    T_eff=T_phot,
    wavelength=jnp.array([tess_mean_wavelength]),
    P_rot=4.863,
    phot=phot,
    spectrum=spectrum
)

# Hand-tuned spot map: one entry per spot, angles in radians.
# NOTE(review): confirm whether fleck's `lat` is latitude or colatitude.
r.lon = jnp.array([2.2, -0.5, 2.2])
r.lat = jnp.array([np.pi/2, np.pi/3, 0.1])
# Spot radii, presumably in units of the stellar radius — confirm.
r.rad = jnp.array([0.2, 0.2, 0.2])
r.temperature = jnp.array([T_spot1, T_spot2, T_spot3])

# Rotation phase is referenced to the first TESS timestamp; f0 is passed
# through to fleck (presumably a flux normalization).
lc, contam = r.rotation_model(f0=1.05, u1=u1, u2=u2, t0_rot=tess_lc.time.value[0])
lc = np.squeeze(lc)  # presumably drops the length-1 wavelength axis

# Overlay the hand-tuned model on the observed light curve.
ax = plt.gca()
tess_lc.plot(ax=ax)
ax.plot(tess_lc.time.value, lc)

In [None]:
# Draw the star at two rotational phases half a period apart.
# Titles are raw strings: "\p" in a normal string literal is an invalid
# escape sequence (SyntaxWarning on Python 3.12+).
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

r.plot_star(0, 2, 10, np.pi/2, ax=ax[0])
ax[0].set(
    title=r'$\phi = 0$'
)

# Shifting the rotation reference by P_rot / 2 advances the phase by pi.
r.plot_star(0, 2, 10, np.pi/2, t0_rot=r.P_rot/2, ax=ax[1])
ax[1].set(
    title=r'$\phi = \pi$'
);

In [None]:
def numpyro_model(
    y=jnp.array(np.array(tess_lc.flux.value)), 
    y_err=jnp.array(np.array(tess_lc.flux_err.value)), 
    n_spots=3, save_model=False
):
    """
    Rotational-modulation model of the TESS light curve with ``n_spots`` spots.

    Parameters
    ----------
    y, y_err : array
        Observed normalized fluxes and their uncertainties. The defaults are
        the TESS Sector 1 light curve, evaluated once when this function is
        defined.
    n_spots : int
        Number of spots to fit.
    save_model : bool
        If True, also record the model light curve (site ``"_lc_model"``) and
        ``1e6 * contam`` (site ``"_depth_model"``) as deterministic sites, for
        posterior predictive sampling.

    Notes
    -----
    Latent sites are ``theta`` (spot longitude angle), ``rad`` (spot radii),
    ``f0`` (passed to ``rotation_model``) and ``log_beta`` (log of the factor
    that inflates ``y_err``). The deterministic site ``lon`` records the spot
    longitudes actually passed to ``ActiveStar``, so its posterior samples can
    be fed straight back into a new ``ActiveStar``.
    """
    theta = numpyro.sample(
        'theta', dist.Uniform(-np.pi, np.pi), 
        sample_shape=(n_spots,)
    )

    # Wrap the angle via (sin, cos), then shift it from [-pi, pi] to [0, 2pi].
    # The shift must live *inside* the deterministic site; otherwise the
    # recorded "lon" samples are offset by pi from the longitudes used below.
    x1, x2 = jnp.sin(theta), jnp.cos(theta)
    lon = numpyro.deterministic(
        'lon', jnp.arctan2(x1, x2) + np.pi
    )

    rad = numpyro.sample(
        'rad', dist.Uniform(low=0, high=0.3), sample_shape=(n_spots,)
    )

    f0 = numpyro.sample(
        'f0', dist.Uniform(low=0.9, high=1.1)
    )

    # A fresh star is built for each model evaluation; all spots sit at lat = pi/2.
    # NOTE(review): `temperature` is not set here, unlike the other ActiveStar
    # setups in this notebook — presumably `spectrum` alone fixes the spot
    # contrast; confirm.
    r = ActiveStar(
        times=tess_lc.time.value, 
        inclination=np.pi/2,
        T_eff=T_phot,
        phot=phot,
        spectrum=spectrum,
        wavelength=jnp.array([tess_mean_wavelength]),
        P_rot=4.863
    )

    r.lon = lon
    r.lat = jnp.ones(n_spots) * np.pi / 2
    r.rad = rad

    lc, contam = r.rotation_model(f0=f0, u1=u1, u2=u2, t0_rot=tess_lc.time.value[0])
    lc = jnp.squeeze(lc)

    # Log of a multiplicative inflation factor on the reported flux errors.
    log_beta = numpyro.sample('log_beta', dist.Uniform(-1, 2))

    if save_model:
        # this gets used to produce posterior predictive samples later on:
        numpyro.deterministic("_lc_model", lc)
        contaminated_depth = 1e6 * contam
        numpyro.deterministic("_depth_model", contaminated_depth)

    # Normally distributed likelihood with inflated uncertainties
    numpyro.sample(
        "obs", dist.Normal(
            loc=lc, 
            scale=y_err * jnp.exp(log_beta)
        ), obs=y
    )

In [None]:
from numpyro.infer import (
    SVI, autoguide, Trace_ELBO, 
)
from numpyro import optim

# Variational posterior: a multivariate normal over all latent sites.
guide = autoguide.AutoMultivariateNormal(numpyro_model)

# # some alternatives to this guide are:
# guide = autoguide.AutoDAIS(numpyro_model)
# guide = autoguide.AutoBNAFNormal(numpyro_model)

svi = SVI(
    model=numpyro_model, 
    guide=guide, 
    optim=optim.Adagrad(step_size=0.15),
    loss=Trace_ELBO()
)
svi_result = svi.run(
    rng_key=PRNGKey(1), 
    num_steps=500
)
# Loss history shifted so its minimum is 1, which keeps it drawable on log axes.
plt.loglog(svi_result.losses + 1 - np.nanmin(svi_result.losses))
plt.show()

params = svi_result.params
posteriors = guide.sample_posterior(PRNGKey(1), params, sample_shape=(2_000,))
# "_"-prefixed sites are helper outputs, not parameters to show in the corner plot.
labels = [k for k in posteriors if not k.startswith("_")]

# Flatten every site into one column per scalar parameter. Sites are looked
# up by name: zipping `posteriors` with the filtered `labels` would pair
# arrays with the wrong names whenever a "_" site precedes the others.
samples = []
iter_labels = []
for label in labels:
    values = np.asarray(posteriors[label])
    if values.ndim > 1:
        # vector-valued site (e.g. one entry per spot): one column per entry
        for i, column in enumerate(values.reshape(values.shape[0], -1).T):
            samples.append(column)
            iter_labels.append(f"{label}_{i}")
    else:
        samples.append(values)
        iter_labels.append(label)

samples = np.column_stack(samples)


corner(samples, labels=iter_labels)
fig = plt.gcf()
fig.suptitle('SVI')
plt.show()

We fit with SVI above rather than HMC: this particular model is highly degenerate, and HMC fits won't converge. Below, we draw posterior predictive samples of the light curve from the SVI posterior:

In [None]:
# Re-run `numpyro_model` for each SVI posterior sample, returning only the
# helper sites that exist when save_model=True.
posterior_predictive = Predictive(
    model=numpyro_model, 
    posterior_samples=posteriors,
    return_sites=['_lc_model', '_depth_model'],
)

pred = posterior_predictive(
    rng_key=PRNGKey(1), 
    save_model=True
)

y_pred = pred['_lc_model']  # rotational-modulation model of the TESS light curve, one row per posterior sample
depth_pred = pred['_depth_model'] # 1e6 * `contam` from rotation_model, one row per posterior sample

In [None]:
# Per-time-step median and 16th-84th percentile band of the predicted light curves.
low, mid, high = np.percentile(y_pred, q=(16, 50, 84), axis=0)

In [None]:
# Observed TESS fluxes (silver points), posterior-predictive median (blue line)
# and 16th-84th percentile band (shaded).
t_obs = tess_lc.time.value

plt.scatter(t_obs, tess_lc.flux.value, color='silver', s=2)
plt.plot(t_obs, mid, color='DodgerBlue', lw=2)
plt.fill_between(t_obs, low, high, color='DodgerBlue', alpha=0.4, lw=0)

plt.gca().set(xlabel='Time [d]', ylabel='Flux')

### Extrapolate the wavelength dependence from the best TESS fit

This time, use PHOENIX model spectra (via `expecto`) to predict the wavelength dependence of the variability. The lowest-temperature model available in this package is 2300 K.

In [None]:
from fleck.jax import bin_spectrum

# Wavelength grid from 0.5 to 5 um (used as `bins` below), and 300 evenly
# spaced times spanning the TESS observations.
wavelengths = jnp.linspace(0.5, 5, 100) * u.um
times = jnp.linspace(tess_lc.time.value.min(), tess_lc.time.value.max(), 300)

# Keyword arguments shared by every `bin_spectrum` call.
kwargs = {
    'bins': wavelengths,
    'min': wavelengths.min(),
    'max': wavelengths.max(),
    'log': False,
}

# Binned PHOENIX spectra (second argument 5.0, presumably log g) for the
# photosphere and for the spots.
phot, spot = (
    bin_spectrum(get_spectrum(temperature, 5.0, cache=True), **kwargs)
    for temperature in (T_phot, T_spot1)
)

In [None]:
# Every spot uses the binned T_spot1 spectrum.
# NOTE(review): if T_spot2 / T_spot3 ever differ from T_spot1, their spectra
# must be binned too; the `temperature` argument alone won't change them here.
spectrum = jnp.vstack([
    spot.flux.value, 
    spot.flux.value, 
    spot.flux.value
])

r = ActiveStar(
    times=times, 
    inclination=np.pi/2,
    T_eff=T_phot,
    wavelength=phot.wavelength.to_value(u.m),
    phot=phot.flux.value,
    spectrum=spectrum,
    P_rot=4.863,
    temperature=jnp.array([T_spot1, T_spot2, T_spot3])
)

# Posterior-mean spot map from the SVI fit. Longitude is an angle, so average
# it on the unit circle: an arithmetic mean of samples straddling the
# wrap-around point would land on the opposite side of the star.
r.lon = jnp.arctan2(
    jnp.sin(posteriors['lon']).mean(axis=0),
    jnp.cos(posteriors['lon']).mean(axis=0),
)
r.lat = jnp.array([np.pi / 2] * len(r.lon))
r.rad = posteriors['rad'].mean(axis=0)

lc, contam = r.rotation_model(
    f0=posteriors['f0'].mean(), u1=u1, u2=u2, 
    t0_rot=tess_lc.time.value[0]
)
lc = np.squeeze(lc)

Plot the rotational modulation (without transits) at several wavelengths:

In [None]:
# blue to red colormap over `wavelengths`
cmap = lambda x: plt.cm.Spectral_r((x - wavelengths.min()) / np.ptp(wavelengths))

for i in range(0, len(wavelengths), 20):
    plt.plot(
        times, lc[:, i], 
        color=cmap(wavelengths[i]), 
        label=f"{wavelengths[i].to_value(u.um):.1f} µm"
    )
    
plt.legend(loc='upper right', framealpha=1)
plt.gca().set(
    xlabel='Time [d]',
    ylabel='Flux'
);

Predict how the best-fit spot model contaminates the transmission spectrum of AU Mic b:

In [None]:
# Wittrock (2023)
period = 8.46308
rp = 0.0488
inclination = np.radians(89.57917)
t0 = 2458322.77 - 2457000
a = 18.79

# compute model near a transit:
r.times = jnp.linspace(t0 - 0.25, t0 + 0.25, 250)

transit_lc, contam = r.transit_model(
    t0=t0, 
    period=period, 
    rp=rp,
    a=a,
    inclination=inclination, 
    u1=u1, 
    u2=u2,
    t0_rot=tess_lc.time.value[0]
)[:2]

In [None]:
# blue to red colormap over `wavelengths`
fig, ax = plt.subplots(1, 3, figsize=(15, 4))

cmap = lambda x: plt.cm.Spectral_r((x - wavelengths.min()) / np.ptp(wavelengths))

for i in range(0, len(wavelengths), 20):
    ax[0].plot(
        r.times, transit_lc[:, i], 
        color=cmap(wavelengths[i]), 
        label=f"{wavelengths[i].to_value(u.um):.1f} µm"
    )

ax[0].legend(loc='lower right', framealpha=1)
ax[0].set(
    xlabel='Time [d]',
    ylabel='Flux'
)

r.plot_star(t0, rp, a, inclination, ax=ax[1], t0_rot=tess_lc.time.value[0])

ax[2].plot(wavelengths.to_value(u.um)[:-1], 1e6 * contam)
ax[2].set(
    xlabel='Wavelength [µm]',
    ylabel='Contaminated transit\ndepth [ppm]'
)
fig.tight_layout()