In [None]:
import matplotlib.pyplot as plt
%matplotlib inline


from funcs.model import aflare, calculate_specific_flare_flux
from funcs.flarefit import convert_posterior_units


import copy
import emcee

In [None]:
from funcs.model import FlareModulator

import numpy as np
import pandas as pd

import astropy.units as u
from astropy.constants import R_sun

In [None]:
def log_probability(params, FM, g):
    """Posterior log probability to pass to the MCMC sampler.

    Parameters:
    -----------
    params : list
        list of fitted model parameters
    FM : FlareModulator object
        must provide ``log_prior(params, g=g)`` and
        ``log_likelihood(params)``
    g : astropy CompoundModel
        empirical prior on inclination in rad

    Return:
    -------
    float
        log prior + log likelihood. -inf is returned when the prior is
        not finite, when the likelihood evaluation raises an exception,
        or when the likelihood is NaN.
    """
    lp = FM.log_prior(params, g=g)

    # Outside the prior support: skip the likelihood evaluation entirely.
    if not np.isfinite(lp):
        print("noted: no prior")
        return -np.inf

    try:
        ll = FM.log_likelihood(params)
    except Exception:
        # Best-effort: a failed model evaluation rejects the proposal.
        # Only Exception is caught, so KeyboardInterrupt / SystemExit
        # still stop a running chain.
        print("noted: no loglike")
        return -np.inf

    if np.isnan(ll):
        print("noted: nan loglike")
        return -np.inf

    return lp + ll

In [None]:
import os
import pickle

# Project root: the working directory with its last two path components removed.
CWD = "/".join(os.getcwd().split("/")[:-2])
inits = pd.read_csv(f"{CWD}/data/summary/inits_decoupled_GP.csv")

# --- run configuration -------------------------------------------------
tstamp =  "07_12_2020_19_42"#,"07_12_2020_18_46"
nflares = 2
iscoupled = False
ID = "44984200"

# Select the rows of the inits table that belong to this run.
cond1 = (inits.tstamp == tstamp)

if nflares == 2:
    # With two flares the table IDs carry one extra trailing character,
    # so compare against the ID with that last character removed.
    cond2 = (inits.ID.str[:-1] == ID)
elif nflares == 1:
    cond2 = (inits.ID == ID)
else:
    # Without this branch cond2 would be undefined and fail below
    # with a NameError.
    raise ValueError(f"nflares must be 1 or 2, got {nflares!r}")

init = inits[cond1 & cond2]

if init.empty:
    raise ValueError(f"no rows in inits for tstamp={tstamp!r}, ID={ID!r}")

# Parameter vector layout: three shared values taken from the first row,
# followed by one group of flare parameters per selected row.
row1 = init.iloc[0]
params = [row1.theta_a, row1.phi0, row1.i_mu]

for _, row in init.iterrows():
    params.append(row.a)
    params.append(row.phi_a)
    params.append(row.fwhm1)
    # The decoupled model carries a second FWHM per flare.
    if not iscoupled:
        params.append(row.fwhm2)

In [None]:
# Display the rows selected from the inits table for a visual check.
init

In [None]:
# Display the assembled list of starting parameter values.
params

In [None]:
# get g
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file -- only load pickles that were produced locally by this project.
inclination_path =  f"{CWD}/data/inclinations/{ID}_post_compound.p"
with open(inclination_path, "rb") as inclination_file:
    gincl = pickle.load(inclination_file)

# get lc
try:
    lc = pd.read_csv(f"{CWD}/data/lcs/{tstamp}_{ID}.csv")
except FileNotFoundError:
    # Fall back to the file named after the ID without its last character.
    # Only a missing file triggers the fallback; parse errors and other
    # failures surface instead of being hidden. The previous assignment to
    # `target.ID` was dropped because `target` is not defined in this
    # notebook and made this branch fail with a NameError.
    lc = pd.read_csv(f"{CWD}/data/lcs/{tstamp}_{ID[:-1]}.csv")
median = lc.median_[0]

# set up FlareModulator

FM = FlareModulator(lc.phi.values, lc.flux.values, lc.flux_err.values,
                    row1.qlum_erg_s*u.erg/u.s, (row1.R_Rsun*R_sun).to("cm"),
                    median, nflares, iscoupled,)


In [None]:
# MCMC configuration: ensemble size, chain length, relative start scatter
nwalkers, Nsteps, wiggle = 32, 100, 1e-4
ndim = len(params)

# Starting points: every walker begins at the initial parameter vector,
# perturbed multiplicatively by a small Gaussian factor.
scatter = np.random.randn(nwalkers, ndim)
pos = params * (1. + wiggle * scatter)

# Fresh in-memory emcee backend sized for this run
backend = emcee.backends.Backend()
backend.reset(nwalkers, ndim)

# The sampler evaluates log_probability(params, FM, gincl)
sampler = emcee.EnsembleSampler(
    nwalkers,
    ndim,
    log_probability,
    args=(FM, gincl,),
    backend=backend,
)

# Run the chain, storing every step
sampler.run_mcmc(pos, Nsteps, progress=True, store=True)

# Flattened chain with the first tenth of the steps discarded
samples = sampler.get_chain(discard=Nsteps//10, flat=True)

In [None]:
# plot the result to check for convergence
multi_samples = sampler.get_chain(discard=Nsteps//10)
fig, axes = plt.subplots(ndim, figsize=(10, 20), sharex=True)

# Build the trace labels from the model layout instead of hard-coding the
# decoupled case: three shared parameters, then one group per flare. The
# coupled model has no second FWHM per flare.
# NOTE(review): these labels say "deg" while the posterior column names
# defined further down say "rad" -- confirm the units of the sampled values.
labels = ["latitude_deg", "phase_deg","i_deg"]
flare_labels = ["a","t0_d","fwhm1"] if iscoupled else ["a","t0_d","fwhm1","fwhm2"]
n_flare_groups = (ndim - len(labels)) // len(flare_labels)
for k in range(n_flare_groups):
    # First flare is unsuffixed, later ones get their running number.
    suffix = "" if k == 0 else str(k + 1)
    labels += [name + suffix for name in flare_labels]

for j in range(multi_samples.shape[2]):
    ax = axes[j]
    ax.plot(multi_samples[:, :, j], "k", alpha=0.3)
    ax.set_xlim(0, len(multi_samples))
    ax.set_ylabel(labels[j], fontsize=16)
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes[-1].set_xlabel("step number");
plt.tight_layout()

In [None]:
# ------------------------------------------------------
# POST-ANALYSIS

# Column suffixes per flare, keyed by the total number of fitted parameters
suffdict = {
    7: [""],
    11: ["_a", "_b"],
    6: [""],
    9: ["_a", "_b"],
}

# Per-flare parameter names for the coupled (True) and decoupled (False) model
coupleddict = {
    True: ["a","phase_peak","fwhm"],
    False: ["a","phase_peak","fwhmi","fwhmg"],
}

# Flare column names, grouped flare by flare (suffix varies slowest)
fl = [x + suff
      for suff in suffdict[ndim]
      for x in coupleddict[iscoupled]]

# Flaring-region and stellar parameters lead, flare parameters follow
columns = ["latitude_rad", "phase_0","i_rad"]
columns = columns + fl


In [None]:
# Get Prot_d
props = pd.read_csv(f"{CWD}/data/summary/inclination_input.csv")

prot = props[props["id"] == int(ID)].iloc[0].prot_d
print(prot)

# define DataFrame from chain
r = pd.DataFrame(data=samples, columns=columns)


# get raw and converted posterior tables for each flare in the light curve
for s in suffdict[ndim]:
    # The first three columns are shared by all flares. Flare columns are
    # matched on their suffix only: with the empty single-flare suffix a
    # substring test matches every column and would select the shared
    # columns twice.
    shared_cols = columns[:3]
    flare_cols = [c for c in columns[3:] if c.endswith(s)]
    cols = shared_cols + flare_cols
    print(cols)
    rs = r[cols]

    # Strip the flare suffix from the end of the flare column names only.
    unsuffixed = {c: c[:len(c) - len(s)] for c in flare_cols}
    rs = rs.rename(index=str, columns=unsuffixed)
    
#     rs.to_csv(f"{CWD}/analysis/results/mcmc/"
#               f"{tstamp}_{target.ID}{s}"
#               f"_raw_mcmc_sample.csv",
#               index=False)
    
    rsconv = convert_posterior_units(rs, prot, lc.phi, lc.t)
    print(rsconv.head())
    
#     rsconv.to_csv(f"{CWD}/analysis/results/mcmc/"
#                   f"{tstamp}_{target.ID}{s}"
#                   f"_converted_mcmc_sample.csv",
#                   index=False)