In [None]:
import os

# Pin the BLAS/OpenMP/numexpr thread pools to one thread each. These variables are
# typically read when numpy/theano load, so this cell must run before the imports
# in the next cell.
os.environ.update(
    {
        var_name: "1"
        for var_name in (
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "MKL_NUM_THREADS",
        )
    }
)

In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import exoplanet as xo
import pymc3 as pm
import pymc3_ext as pmx
import theano.tensor as tt
import arviz as az
import corner

from util import phasefold, get_stats, plot_binned

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Ephemeris and transit shape from the SPOC S1-39 multisector search:
# https://tev.mit.edu/data/delivered-signal/i226894/
period = 1.74469    # orbital period, days
epoch = 1326.9843   # mid-transit time, same time system as lc["time"] (presumably BTJD — TODO confirm)
ror = 0.0123        # planet-to-star radius ratio (starting value for the fit)

duration_hours = 1.1574
duration = duration_hours / 24  # transit duration converted to days

In [None]:
# TIC catalogue information for the target, transposed (rows <-> columns) after loading.
# NOTE(review): `tic_info` is not used in the visible cells.
tic_info = pd.read_csv("data/tic_82.csv", index_col=0).transpose()

In [None]:
# Load the light curve and keep only cadences with quality flag 0.
lc = pd.read_csv("data/lc.csv", index_col=0)
lc = lc.query("quality == 0")

# Restrict to cadences near transit: |phase| < 2 * duration / period.
# NOTE(review): reads as "within ±2 durations" if phasefold returns phase in units
# of the period — confirm against util.phasefold.
phase = phasefold(lc["time"], period, epoch)
intran = np.abs(phase) < 2 * duration / period
lc = lc.loc[intran]
print(lc.shape)

In [None]:
# Phase-folded light curve on the SPOC ephemeris. Bin counts are chosen so that each bin
# spans ~1 minute (grey) and ~15 minutes of phase, assuming plot_binned bins one full period.
one_min_bins = int(period / (1 / 60 / 24))
fifteen_min_bins = int(period / (15 / 60 / 24))
plot_binned(lc["time"], lc["flux"], period, epoch, bins=one_min_bins, color="0.8")
plot_binned(lc["time"], lc["flux"], period, epoch, bins=fifteen_min_bins)

In [None]:
with pm.Model() as model:
    # Ephemeris: uniform priors centred on the SPOC values, +/-10% of the period.
    t0 = pm.Uniform("t0", lower=epoch - 0.1 * period, upper=epoch + 0.1 * period)
    per = pm.Uniform("per", lower=0.9 * period, upper=1.1 * period)

    # Radius ratio, started at the SPOC value. The random variable is bound to `ror_rv`
    # (its PyMC name is still "ror") so it does not shadow the float `ror` from the
    # constants cell. With `ror = pm.Uniform(..., testval=ror)`, re-running this cell
    # passed the previous model's tensor as `testval` instead of the SPOC number.
    ror_rv = pm.Uniform("ror", lower=0, upper=2, testval=ror)
    # Signed impact parameter; bounds |b| < 1 + ror admit grazing geometries.
    b = pm.Uniform("b", lower=-(ror_rv + 1), upper=(ror_rv + 1))

    # Stellar mass and radius, from Table 1 (solar units — TODO confirm).
    m_star = pm.Normal("m_star", mu=0.421, sd=0.021)
    r_star = pm.Normal("r_star", mu=0.427, sd=0.021)

    orbit = xo.orbits.KeplerianOrbit(
        r_star=r_star,
        m_star=m_star,
        period=per,
        t0=t0,
        b=b,
    )
    # Quadratic limb-darkening coefficients.
    u = xo.distributions.QuadLimbDark("u")
    lcs = pm.Deterministic(
        "lcs",
        xo.LimbDarkLightCurve(u).get_light_curve(
            orbit=orbit,
            r=ror_rv * r_star,  # planet radius, same units as r_star
            t=lc["time"],
            texp=0.00138889,  # 2-minute exposure, in days
            # NOTE(review): the saved trace is named "...oversample7.nc" but this uses
            # oversample=3 — confirm which setting was intended.
            oversample=3,
        ),
    )
    # The summed transit signal is added to a free baseline flux level.
    mean = pm.Normal("mean", mu=1.0, sd=0.5)
    full_lc = tt.sum(lcs, axis=-1) + mean
    pm.Deterministic("full_lc", full_lc)

    # Gaussian likelihood using the per-cadence reported flux uncertainties.
    pm.Normal("obs", mu=full_lc, sd=lc["flux_err"], observed=lc["flux"], shape=len(lc))

In [None]:
# MAP fit in two stages: optimise only the baseline flux `mean` from the model's
# test point, then optimise all free parameters jointly from that result. This
# checks that the model is sensible and provides the MCMC starting point.
with model:
    offset_only = pmx.optimize(start=model.test_point, vars=[mean])
    map_soln = pmx.optimize(start=offset_only)

In [None]:
# Data folded on the MAP ephemeris, with the MAP model overplotted.
# Bin counts are still derived from the SPOC `period`, as before.
map_per, map_t0 = map_soln["per"], map_soln["t0"]
one_min_bins = int(period / (1 / 60 / 24))
fifteen_min_bins = int(period / (15 / 60 / 24))
plot_binned(lc["time"], lc["flux"], map_per, map_t0, bins=one_min_bins, color="0.8")
plot_binned(lc["time"], lc["flux"], map_per, map_t0, bins=fifteen_min_bins, s=100)
plot_binned(lc["time"], map_soln["full_lc"], map_per, map_t0, s=1)

In [None]:
# Run MCMC starting from the MAP solution, with 5 chains on 5 cores.
# The global NumPy RNG is seeded first; presumably pm.sample derives its per-chain
# seeds from it when no random_seed is given — confirm for the installed PyMC3.
sample_kwargs = dict(
    tune=1000,
    draws=1000,
    chains=5,
    cores=5,
    start=map_soln,
    init="adapt_full",
    target_accept=0.90,
    return_inferencedata=True,
)
np.random.seed(42)
with model:
    trace = pm.sample(**sample_kwargs)

In [None]:
# Save the posterior (ArviZ InferenceData, from return_inferencedata=True) to NetCDF.
# NOTE(review): the filename says "oversample7" but the model cell uses oversample=3 —
# confirm which setting produced this trace before relying on the name.
trace.to_netcdf("data/230309-oversample7.nc")

In [None]:
sampled_params = ["t0", "per", "ror", "b", "u", "r_star", "m_star"]
summary = az.summary(trace, var_names=sampled_params)

# Convergence check: r_hat should be close to 1 for every sampled parameter.
# 1.01 is the threshold commonly recommended for rank-normalised r_hat.
max_r_hat = summary["r_hat"].max()
if max_r_hat > 1.01:
    print(f"WARNING: max r_hat = {max_r_hat:.3f} > 1.01; chains may not have converged")

summary

In [None]:
# Corner plot of the joint posterior for the sampled parameters.
# The returned figure is bound to `_`, presumably only to keep its repr out of the cell output.
_ = corner.corner(trace, var_names=sampled_params)

In [None]:
from util import (
    get_a,
    get_aor,
    get_inclination,
    get_radius,
    get_teq,
    get_insolation,
    get_duration,
    get_rho,
    get_transit_shape,
    print_stats,
)

In [None]:
# Merge the ("chain", "draw") dimensions into a single "sample" dimension so that
# the cells below can treat every posterior draw uniformly.
samples = trace.posterior.stack({"sample": ("chain", "draw")})

In [None]:
ignore_params = ["full_lc", "lcs", "ecs", "mean", "u0", "u1"]
sampled_params = ["t0", "per", "ror", "b", "u", "r_star", "m_star"]

# Teff is not fit by the model. Its uncertainty is propagated by drawing one value per
# posterior sample; values from Table 1 (averaged stds).
# A dedicated, seeded generator makes the draws identical every time this cell runs.
# Previously the global RNG gave different Teff draws on each re-execution.
teff_rng = np.random.default_rng(42)
teff_data = samples["per"].copy()
teff_data.data = teff_rng.normal(3485, 138.5, len(samples["t0"]))
samples["teff"] = teff_data

samples["a"] = get_a(samples["per"], samples["m_star"])
samples["depth"] = samples["ror"]**2
samples["aor"] = get_aor(samples["per"], samples["r_star"], samples["m_star"])
samples["inc"] = get_inclination(samples["per"], samples["b"], samples["r_star"], samples["m_star"])
samples["r_p"] = get_radius(samples["ror"], samples["r_star"])
samples["teq"] = get_teq(samples["per"], samples["r_star"], samples["teff"], samples["m_star"])
samples["irr"] = get_insolation(samples["per"], samples["r_star"], samples["teff"], samples["m_star"])
samples["dur"] = get_duration(samples["per"], samples["ror"], samples["b"], samples["r_star"], samples["m_star"])
samples["t_shape"] = get_transit_shape(samples["ror"], samples["b"])
samples["qin"] = (1 - samples["t_shape"]) / 2
samples["rho"] = get_rho(samples["per"], samples["dur"], samples["depth"], samples["t_shape"])

# Everything that is neither sampled nor explicitly ignored counts as derived.
# Iterate in the Dataset's insertion order. The previous set difference produced a
# different print order from run to run because of string-hash randomisation.
excluded = set(sampled_params) | set(ignore_params)
derived_params = [name for name in samples.data_vars if name not in excluded]
print("Sampled:", sampled_params)
print("Derived:", derived_params)

In [None]:
# Print summary statistics: fitted parameters first, then derived ones.
print("Sampled:")
for name in sampled_params:
    if name == "b":
        # b is sampled with a symmetric sign, so report its magnitude.
        # NOTE(review): |b| and u0/u1 are printed without sigma=1, unlike every
        # other parameter — confirm that is intended.
        print_stats("|b|", np.abs(samples[name]))
    elif name == "u":
        print_stats("u0", samples["u"][0])
        print_stats("u1", samples["u"][1])
    else:
        print_stats(name, samples[name], sigma=1)
print()

print("Derived:")
for name in derived_params:
    print_stats(name, samples[name], sigma=1)

In [None]:
# Older diagnostic cells below: posterior-median model overlay and residual plot.

In [None]:
# Posterior-median model light curve and ephemeris for the diagnostic plot below.
# Reduce over the named "sample" dimension created by the stack, not a positional
# axis, so the result does not depend on the dimension order of `full_lc`.
# xarray's median skips NaNs for float data, matching np.nanmedian.
median_lc = samples["full_lc"].median(dim="sample").values
median_per = np.nanmedian(samples["per"])
median_t0 = np.nanmedian(samples["t0"])

In [None]:
# Two-panel figure folded on the posterior-median ephemeris:
# top = data with the median model, bottom = residuals about zero.
one_min_bins = int(period / (1 / 60 / 24))
fifteen_min_bins = int(period / (15 / 60 / 24))
residuals = lc["flux"] - median_lc

plt.figure(figsize=(5, 5))

plt.subplot(2, 1, 1)
plot_binned(lc["time"], lc["flux"], median_per, median_t0, bins=one_min_bins, color="0.8")
plot_binned(lc["time"], lc["flux"], median_per, median_t0, bins=fifteen_min_bins, s=100)
plot_binned(lc["time"], median_lc, median_per, median_t0, s=1)

plt.subplot(2, 1, 2)
plot_binned(lc["time"], residuals, median_per, median_t0, bins=one_min_bins)
plot_binned(lc["time"], residuals, median_per, median_t0, bins=fifteen_min_bins, s=100)
plt.axhline(0, color="k")