In [None]:
# packages needed
import numpy as np

# np.seterr(over = 'ignore', invalid = 'ignore')  # to remove error warning
import matplotlib.pyplot as plt

import h5py  # to input data (used by extract_data)
from gatspy import periodic  # find period (used by find_period)
from numpy.core.fromnumeric import std  # used to trim data
import ellc  # light curve generation
import emcee  # for monte carlo
from scipy.optimize import minimize
import corner  # to plot for emcee
import lmfit as lm  # for initial fit
from lmfit import Parameters

%matplotlib inline                              # to change the plot to the notebook itself

In [None]:
def extract_data(
    stat_list,
    file="/home/tungtran/summer_2021_project/sdb.hdf5",
    star_ID="[81842830750193450]",
):
    """Read selected light-curve columns and attributes for one star.

    The HDF5 file is opened read-only, the dataset ``file["Objects"][star_ID]``
    is looked up, and every requested entry is copied out before the file is
    closed again.

    Parameters
    ----------
    stat_list : iterable of str
        Names to extract. ``"t"``, ``"y"`` and ``"dy"`` map to columns 0, 1
        and 2 of the dataset; any other name is read from the dataset's
        ``attrs``.
    file : str
        Path of the HDF5 file.
    star_ID : str
        Key of the star inside the ``"Objects"`` group.

    Returns
    -------
    dict
        ``{name: value}`` for every name in ``stat_list``.

    Raises
    ------
    KeyError
        If ``star_ID`` or a requested attribute is missing (raised by the
        underlying lookups).
    """
    # Dataset column index of the three per-observation series.
    column_of = {"t": 0, "y": 1, "dy": 2}

    stat_dict = {}
    # The context manager guarantees the file handle is released even when a
    # lookup fails; the original left the file open for the kernel's lifetime.
    with h5py.File(file, "r") as hdf:
        LC = hdf["Objects"][star_ID]
        for stat in stat_list:
            if stat in column_of:
                # Slicing the dataset reads the values into memory, so they
                # remain valid after the file is closed.
                stat_dict[stat] = LC[:, column_of[stat]]
            else:
                stat_dict[stat] = LC.attrs[stat]

    return stat_dict


# Period search and phase folding
def find_period(t_obs, y, dy, range=(0.015, 100), double_period=False):
    """Find the best Lomb-Scargle period and fold the observation times on it.

    Despite its name, this function returns the folded phases, not the period.

    Parameters
    ----------
    t_obs, y, dy : sequence
        Observation times, measurements and measurement uncertainties, passed
        unchanged to ``LombScargleFast.fit``.
    range : tuple
        Period search interval assigned to the optimizer's ``period_range``.
        The name shadows the builtin ``range`` inside this function; it is kept
        so existing keyword callers continue to work.
    double_period : bool
        If true, fold on twice the best period instead of the best period.

    Returns
    -------
    list
        ``(t % period) / period`` for every entry of ``t_obs``, i.e. phases in
        ``[0, 1)`` for non-negative times and a positive period.
    """
    model_scargle = periodic.LombScargleFast(fit_period=True)
    # Honour the caller's search interval; it used to be hard-coded to the
    # default value, which silently ignored the `range` argument.
    model_scargle.optimizer.period_range = range
    model_scargle.fit(t_obs, y, dy)

    period = model_scargle.best_period
    if double_period:
        period = 2 * period

    phases = [(day % period) / period for day in t_obs]

    return phases


# Light-curve model shared by the lmfit initial fit and the emcee likelihood
def lightcurve_model(
    t_obs,
    A,
    t_zero,
    radius_1,
    radius_2,
    incl,
    sbratio,
    q,
    heat_1,
    heat_2,
    inverse_sbratio=False,
    inverse_q=False,
):
    """Evaluate ``A * ellc.lc(...)`` at ``t_obs`` with a penalty fallback.

    Parameters
    ----------
    t_obs : array-like
        Times (or phases) at which the light curve is evaluated.
    A : float
        Multiplicative scale applied to the ``ellc.lc`` output.
    t_zero, radius_1, radius_2, incl, heat_1, heat_2
        Forwarded unchanged to ``ellc.lc`` under the same keyword names.
    sbratio, q : float
        Forwarded to ``ellc.lc``; replaced by their reciprocal first when the
        matching ``inverse_*`` flag is set.
    inverse_sbratio, inverse_q : bool
        Whether ``sbratio`` / ``q`` are given as reciprocals.

    Returns
    -------
    numpy.ndarray
        The scaled model. If ``ellc.lc`` raises, or the model contains any
        non-finite value (NaN or +/-inf), ``t_obs * 10**9`` is returned
        instead so that the fit sees a very poor match rather than crashing.
    """
    sbratio = 1 / sbratio if inverse_sbratio else sbratio
    q = 1 / q if inverse_q else q

    try:
        y_model = A * ellc.lc(
            t_obs=t_obs,
            t_zero=t_zero,
            radius_1=radius_1,
            radius_2=radius_2,
            incl=incl,
            sbratio=sbratio,
            q=q,
            heat_1=heat_1,
            heat_2=heat_2,
            verbose=0,
        )
        # NaN is rejected together with +/-inf: it would otherwise propagate
        # into the residuals / log-likelihood of the callers.
        model_is_finite = np.all(np.isfinite(y_model))
    except Exception:
        # `Exception` (not a bare except) so Ctrl-C / SystemExit still
        # interrupt a long-running fit.
        model_is_finite = False

    if not model_is_finite:
        # np.asarray makes the scaling element-wise for list input too;
        # `list * 10**9` would try to repeat the list a billion times.
        return np.asarray(t_obs) * 10**9
    return y_model


def fit_light_curve(model=lightcurve_model):
    """Run the initial lmfit fit of ``model`` to the folded light curve.

    Reads the notebook-level globals ``priors`` (name -> (value, min, max)),
    ``lmfit_params`` (names allowed to vary), ``data`` and ``phases``.

    Parameters
    ----------
    model : callable
        Model function wrapped in ``lm.Model``; its independent variable is
        supplied as ``t_obs``.

    Returns
    -------
    The lmfit fit result. As side effects the fit is plotted and the fit
    report is printed.
    """
    # Use the function the caller asked for; it used to be ignored in favour
    # of a hard-coded `lightcurve_model`.
    lm_model = lm.Model(model)
    params = Parameters()

    for var, (value, min_bound, max_bound) in priors.items():
        # Parameters outside `lmfit_params` stay fixed at their prior value.
        params.add(
            var,
            value=value,
            min=min_bound,
            max=max_bound,
            vary=var in lmfit_params,
        )

    result = lm_model.fit(data=data["y"], params=params, t_obs=phases)

    # `result.plot` opens its own figure when none is passed, so no explicit
    # `plt.figure()` is needed (it only produced an extra empty figure).
    result.plot(xlabel="Time", ylabel="Amplitude")
    print(result.fit_report())

    return result


def find_log_prior(theta):
    """Flat log-prior over the box defined by ``priors``.

    ``theta[k]`` is checked against the (min, max) bounds stored in
    ``priors[emcee_params[k]]``; both bounds are inclusive.

    Returns ``0.0`` when every component lies inside its bounds and
    ``-np.inf`` as soon as one component falls outside.
    """
    for position in range(theta.shape[0]):
        _start, lower, upper = priors[emcee_params[position]]
        component = theta[position]
        if component < lower or component > upper:
            return -np.inf

    return 0.0


def find_log_probability(theta, phases, y, dy):
    """Log-posterior (flat prior + Gaussian log-likelihood) for emcee.

    Parameters
    ----------
    theta : array
        Current values of the sampled parameters, ordered like the global
        ``emcee_params``.
    phases, y, dy : array
        Model evaluation points, observed values and their uncertainties.

    Returns
    -------
    float
        ``-np.inf`` when the prior rejects ``theta``; otherwise the prior term
        plus ``-0.5 * sum(((y - y_model) / dy) ** 2)``.
    """
    prior_term = find_log_prior(theta)
    if not np.isfinite(prior_term):
        return -np.inf

    # Sampled parameters come from theta; all others are frozen at the
    # starting value stored first in their `priors` entry.
    model_kwargs = {}
    for name in (
        "A",
        "t_zero",
        "radius_1",
        "radius_2",
        "incl",
        "sbratio",
        "q",
        "heat_1",
        "heat_2",
    ):
        if name in emcee_params:
            model_kwargs[name] = theta[emcee_params.index(name)]
        else:
            model_kwargs[name] = priors[name][0]

    # likelihood function
    y_model = lightcurve_model(t_obs=phases, **model_kwargs)
    scaled_residuals = (y - y_model) / dy
    log_likelihood = -0.5 * np.sum(scaled_residuals**2)

    return prior_term + log_likelihood


def fit_emcee(n_walker, n_run):
    """Sample the ``emcee_params`` with emcee and draw a corner plot.

    Reads the notebook-level globals ``emcee_params``, ``lmfit_params``,
    ``priors``, ``result`` (the lmfit fit), ``phases`` and ``data``.

    Parameters
    ----------
    n_walker : int
        Number of ensemble walkers.
    n_run : int
        Number of MCMC steps per walker.

    Returns
    -------
    None
        The corner plot of the flattened, burn-in-trimmed chain is the output.
    """
    from multiprocessing import Pool

    n_dim = len(emcee_params)
    # Centre of the initial ball: the lmfit best-fit value where the parameter
    # was fitted, otherwise the starting value from `priors`.
    centre = np.array(
        [
            result.best_values[key] if key in lmfit_params else priors[key][0]
            for key in emcee_params
        ]
    ).reshape(-1, n_dim)
    # Scatter every walker by ~1% (relative) around that centre.
    pos = centre * (1 + 0.01 * np.random.randn(n_walker, n_dim))

    with Pool() as pool:
        sampler = emcee.EnsembleSampler(
            n_walker,
            n_dim,
            find_log_probability,
            args=(phases, data["y"], data["dy"]),
            pool=pool,
        )
        sampler.run_mcmc(pos, n_run, progress=True)

    # Burn-in: twice the mean autocorrelation time when it can be estimated,
    # otherwise the first 10% of the chain.
    fallback_cut = n_run // 10
    try:
        tau = sampler.get_autocorr_time()
    except Exception:
        # Typically raised when the chain is too short for a reliable
        # estimate; `Exception` keeps Ctrl-C working.
        cut = fallback_cut
    else:
        if np.all(np.isfinite(tau)):
            # `discard` is used as a slice index and must be an integer;
            # the float from np.mean used to make get_chain fail.
            cut = int(2 * np.mean(tau))
        else:
            cut = fallback_cut

    flat_samples = sampler.get_chain(discard=cut, thin=5, flat=True)

    corner.corner(flat_samples, labels=list(emcee_params), show_titles=True)

In [None]:
# Star to analyse and the quantities to pull from the HDF5 file: the three
# light-curve columns first, then the dataset attributes.
star_ID = "[81842830750193450]"
stat_list = ["t", "y", "dy"] + [
    "RA",
    "Dec",
    "ref_r_flux",
    "parallax",
    "parallax_error",
    "bp_rp",
    "G",
    "pm",
]
data = extract_data(stat_list, star_ID=star_ID)

# Fold the observation times on the best Lomb-Scargle period.
phases = find_period(t_obs=data["t"], y=data["y"], dy=data["dy"], double_period=False)

# Each entry is (starting value, lower bound, upper bound).
priors = {
    "A": (73026, 50000, 90000),
    "t_zero": (0.3, -1, 1),
    "radius_1": (0.18, 0.1, 0.3),
    "radius_2": (0.17, 0.1, 0.3),
    "incl": (82, 60, 90),
    "sbratio": (0.2, 0, 10),
    "q": (2, 0.1, 10),
    "heat_1": (1, 0, 2),
    "heat_2": (2.6, 1, 3),
}

# Parameters lmfit is allowed to vary; everything else in `priors` (here
# heat_1) stays fixed at its starting value.
lmfit_params = [
    "A",
    "t_zero",
    "radius_1",
    "radius_2",
    "incl",
    "sbratio",
    "q",
    "heat_2",
]  # any order works
result = fit_light_curve()

# Parameters sampled by emcee, starting from the lmfit best-fit values.
emcee_params = ["A", "t_zero", "radius_1", "radius_2", "incl"]  # any combination works
fit_emcee(n_walker=64, n_run=2000)