In [None]:
# NOTE(review): the wildcard import below appears to be what supplies the names used
# later in this notebook without an explicit import (e.g. `starry`, `Rainbow`, `u`,
# `np`, `plt`, `pickle`, `eval_in_model`, `PHOENIXLibrary`) — confirm, and consider
# importing them explicitly so the dependencies are visible.
from chromatic_fitting import *
import pymc3_ext as pmx
# starry configuration flags, set here ahead of the model-building cells.
starry.config.lazy = True
starry.config.quiet = True
# NOTE(review): `bintoR` and `bintogrid` are not referenced anywhere else in the
# visible notebook.
from chromatic import bintoR, bintogrid
from chromatic_fitting.jwst import planets
# NOTE(review): a function with this same name is defined again in the notebook's
# last cell; once that cell has been run it rebinds this imported name.
from chromatic_fitting.jwst.jwst_utils import get_spot_contrast

In [None]:
# !pip install git+https://github.com/kevin218/Astraeus.git

### Load in Kepler-51d data from Jessica

In [None]:
k51 = Rainbow("kepler51_stage3/S3_kepler51_ap3_bg7_SpecData.h5")
k51

In [None]:
k51.dt.to_value('s')

Bin in time (to speed up everything later on)

In [None]:
k51 = k51.bin(dt=2*u.minute)

In [None]:
k51

In [None]:
k51.normalize().bin(R=100).imshow()

In [None]:
k51.normalize().bin(R=5).plot_lightcurves()
# plt.savefig("wlc.png")

Load in planetary parameters

In [None]:
k51_params, k51d_params = planets.kepler51()
k51_params, k51d_params

### Create transit + spot + polynomial model

In [None]:
# Transit + starspot component, built with ydeg=20 and a single spot (nspots=1).
# Parameters given as plain numbers or looked-up values are passed as-is; the ones
# wrapped in WavelikeFitted carry a prior instead. The commented-out `Fitted(...)`
# expressions are kept as alternative priors for the values that are currently fixed.
sp = TransitSpotModel(ydeg=20, nspots=1)
sp.setup_parameters(
    # A = Fitted(Normal, mu=1.0, sd=0.01),
    # stellar parameters
    rs=0.8636679, # stellar radius in Sun radius (fixed). Alternative prior: Fitted(TruncatedNormal,lower=0, mu=k51_params["r"].to_value('R_sun'),testval=k51_params["r"].to_value('R_sun'), sigma=0.011)
    ms=0.9893804, # stellar mass in Sun masses (fixed). Alternative prior: Fitted(Normal, mu=k51_params["m"].to_value('M_sun'), sigma=0.012)
    prot=k51_params["prot"], # stellar rotation period
    u=WavelikeFitted(TruncatedNormal, mu=k51_params['u'], sigma=0.1, lower=[0,-1], upper=[2,1], testval=k51_params['u'], shape=2), # limb-darkening coeffs (two values, shape=2). Alternative prior: Fitted(Normal, mu=[0.3,0.3], sigma=0.1, shape=2)
    stellar_inc=80, # fixed. Alternative prior: Fitted(Uniform, lower=75, upper=85, testval=80)
    
    # spot 1 parameters: only the contrast carries a prior; radius/latitude/longitude are fixed numbers
    spot_contrast = WavelikeFitted(Uniform, lower=0.0, upper=1.0, testval=0.3),
    spot_1_radius = 6.81, # Alternative prior: Fitted(Uniform, lower=5.0, upper=30.0, testval=10)
    spot_1_latitude = 2.627367, # Alternative prior: Fitted(Uniform,lower=-90, upper=90, testval=0.1)
    spot_1_longitude = 6.432652, # Alternative prior: Fitted(Uniform, lower=-180, upper=180, testval=0.1)
    
    # planet parameters
    mp=k51d_params['mp'].to_value('M_earth'), # planet mass in Earth masses (fixed). Alternative prior: Fitted(Normal, mu=k51d_params['mp'].to_value('M_earth'), sigma=1.12)
    rp=WavelikeFitted(TruncatedNormal, lower=0.0, mu=k51d_params['rp'].to_value('R_earth'), sigma=0.5,
                     testval=k51d_params['rp'].to_value('R_earth')), # planet radius in Earth radii
    inc=89.87501, # fixed. Alternative prior: Fitted(Normal, mu=k51d_params["inc"],sigma=0.1)
    period=k51d_params['porb'], 
    # omega=Fitted(Normal, mu=k51d_params['omega'], sigma=68), 
    ecc=k51d_params['ecc'], 
    t0=2460121.8473359, # fixed. Alternative prior: Fitted(Normal, mu=2460121.847, sigma=0.01)
    )

# Polynomial component of degree 2; each coefficient p_0..p_2 gets a Normal prior
# (p_0 centred on 1.0, the higher-order terms centred on 0.0).
p = PolynomialModel(degree=2)
p.setup_parameters(
    p_0 = WavelikeFitted(Normal, mu=1.0, sigma=1e-2),
    p_1 = WavelikeFitted(Normal,mu=0.0,sigma=1e-2),
    p_2 = WavelikeFitted(Normal,mu=0.0,sigma=1e-3),
    )

# Combine the two components with `*` into the single model object used below.
s = sp * p

In [None]:
s.attach_data(k51.normalize())

In [None]:
s.choose_optimization_method("separate")

In [None]:
# s.plot_lightcurves()

In [None]:
s.data

Include a parameter for inflating uncertainties per wavelength

In [None]:
# Prior for the uncertainty-inflation factor: a TruncatedNormal centred on 1.0 and
# bounded to [1.0, 3.0].
nsig = WavelikeFitted(TruncatedNormal, mu=1.0, sigma=0.005, lower=1.0, upper=3.0, testval=1.01)
# NOTE(review): `inflate_uncertainties=False` is passed together with the prior above.
# Presumably the prior is ignored unless the flag is True — confirm whether the
# inflation was meant to be enabled for this fit.
s.setup_likelihood(inflate_uncertainties=False, inflate_uncertainties_prior=nsig)

MAP-optimize the models to get initial values for sampling

In [None]:
opt = s.optimize(plot=False)

In [None]:
# s._pymc3_model

Sample the posterior with MCMC (NUTS). I have an issue running on multiple cores on my machine, so `cores=1` is used here - this might not be an issue for everyone. You can also try removing the `mp_ctx` keyword and running the cell again.

In [None]:
s.sample(start=opt, sampling_method=pmx.sample, draws=500, tune=300, chains=2, cores=1, mp_ctx="spawn",)

### Plot and Save Results

In [None]:
s.plot_with_model_and_residuals(model_plotkw={'zorder':0, 'color':'orange'})
# plt.savefig("modelfit.png")

In [None]:
# Persist the fit products (PyMC3 models, summary table, trace) as pickle files.
# Each file is opened in a `with` block so its handle is flushed and closed as soon
# as the dump finishes, instead of being left open until garbage collection as the
# bare `open(...)` calls did.
with open("pymc3model.pkl", 'wb') as model_file:
    pickle.dump(s._pymc3_model, model_file)
with open("summary.pkl", 'wb') as summary_file:
    pickle.dump(s.summary, summary_file)
with open("trace.pkl", 'wb') as trace_file:
    pickle.dump(s.trace, trace_file)

In [None]:
s.imshow_with_models(vlimits_data=[0.99, 1.02], vspan_residuals=0.0005)
plt.savefig("imshow.png")

In [None]:
# !pip install ffmpeg
# !conda install -c conda-forge ffmpeg

In [None]:
# s._chromatic_models['transitspot'].keplerian_system.show(t=s.data.time)#t=2460121.83)

In [None]:
s._chromatic_models['transitspot'].plot_spectrum()
# plt.savefig("R10_transspec.png")
transspec = s.make_transmission_spectrum_table()
# transspec.to_csv("R10_transspec.csv")
# plt.plot(
# transspec

In [None]:
transspec = s.make_transmission_spectrum_table()
# Largest value in the 'transitspot_radius_ratio' column, multiplied by a stellar
# radius in R_sun and converted to Earth radii.
# NOTE(review): the stellar radius used here (0.8648831) differs from the
# `rs=0.8636679` fixed in the model setup — confirm which value is intended.
(max(transspec['transitspot_radius_ratio'].values) * 0.8648831*u.R_sun).to_value('R_earth')

In [None]:
import math

# Angle (in degrees) whose sine is used as the spot radius passed to
# `get_spot_contrast(..., spot_radii=[sr])` further down.
# NOTE(review): this value differs from the `spot_1_radius = 6.81` fixed in the
# model setup — confirm which spot radius the comparison should use.
SPOT_RADIUS_DEG = 9.339763

# The unused `contrasts = []` placeholder that used to live here was dropped; nothing
# in the notebook read it.
sr = math.sin((SPOT_RADIUS_DEG * u.degree).to_value('radian'))
print(sr)

### Extract spot contrast from model

In [None]:
# Plot the fitted spot-contrast spectrum using the transit+spot model's own helper.
spot_contrast = s._chromatic_models['transitspot'].plot_spectrum(
    param="spot_contrast",
    name_of_spectrum='Spot Contrast',
)

# Overlay the contrast curves returned by get_spot_contrast for six trial spot
# temperatures evenly spaced between 4000 and 5000, drawn behind the fitted points.
wavelength_grid = s.data.wavelength
for trial_teff in np.linspace(4000, 5000, 6):
    trial_contrast = get_spot_contrast(
        wavelengths=wavelength_grid,
        star_teff=5670,
        spot_teff=trial_teff,
        spot_radii=[sr],
        logg=4.7,
        metallicity=0.0,
        visualize=False,
    )
    plt.plot(
        wavelength_grid,
        trial_contrast,
        label=f"{round(trial_teff)}K",
        zorder=0,
        alpha=0.6,
    )

plt.legend()
# spot_contrast = s._chromatic_models['transitspot'].make_spectrum_table(param="spot_contrast")
# plt.savefig("R10_spot_contrast.png")
# pd.DataFrame(spot_contrast).to_csv("R10_spot_contrast.csv")
# spot_contrast

### Plotting hack to get individual imshow components

In [None]:
# Build a transit model with the spot contrast forced to zero, one row per wavelength.
# For each wavelength index: pull that wavelength's parameters out of the posterior
# summary, overwrite the spot contrast with 0, rebuild the flux model, and evaluate it
# inside that wavelength's PyMC3 model.
transitmod = []
data = s.get_data()
for i in range(data.nwave):
    params = s._chromatic_models['transitspot'].extract_from_posteriors(s.summary, i=i)
    # Zero contrast — presumably this makes the spot invisible so only the transit remains.
    params['transitspot_spot_contrast'] = 0
    # NOTE(review): `sys` is just the second return value here, but the name shadows
    # any `sys` module already bound in the notebook namespace.
    flux_model, sys = s._chromatic_models['transitspot'].setup_star_and_planet("transitspot_", 
                                                                               s._chromatic_models['transitspot'].method, 
                                                                               params, 
                                                                               s.data.time.to_value('d'), 
                                                                               [])
    # Keep element [0] of the evaluated result as this wavelength's row.
    transitmod.append(list(eval_in_model(flux_model, model=s._pymc3_model[i])[0]))
transitmod = np.array(transitmod)
# Attach four normalised model arrays to a copy of the data.
# `np.transpose([first_column] * data.ntime)` stacks ntime copies of a first-column
# vector and transposes them, so dividing by it divides every row by that row's own
# first element (assumes the arrays are laid out (nwave, ntime) — TODO confirm).
#   model             : full fitted model, normalised as above
#   planet_model      : the zero-contrast transit computed above, normalised
#   spot_model        : fitted planet_model divided by the zero-contrast transit,
#                       then normalised the same way
#   systematics_model : fitted systematics model, normalised
new_model_rainbow = data._create_copy()
new_model_rainbow = new_model_rainbow.attach_model(model=s.data_with_model.model/(np.transpose([s.data_with_model.model[:,0]]*data.ntime)), 
                               planet_model=transitmod/np.transpose([transitmod[:,0]]*data.ntime),
                              spot_model=(s.data_with_model.planet_model/transitmod) / np.transpose([s.data_with_model.planet_model[:,0]/transitmod[:,0]]*data.ntime),
                              systematics_model=s.data_with_model.systematics_model/np.transpose([s.data_with_model.systematics_model[:,0]]*data.ntime)
                                                  )

In [None]:
# Divide the whole flux array by a single number: the mean of `flux[:, :50]`.
# NOTE(review): presumably the first 50 samples along the second axis are the
# out-of-transit baseline — confirm. This cell overwrites `flux` in place, so
# running it again rescales the already-rescaled values.
new_model_rainbow.flux = new_model_rainbow.flux/np.mean(new_model_rainbow.flux[:,:50])

In [None]:
new_model_rainbow.imshow_with_models(models=['model','planet_model','spot_model', 'systematics_model'], 
                                     figsize=(12,4),
                                     vlimits_data=[0.988,1.002],
                                    vspan_residuals=0.005)
plt.savefig("R10_imshow_all_mods.png")

In [None]:
spot_contrast = s._chromatic_models['transitspot'].make_spectrum_table(param="spot_contrast")
max(spot_contrast['transitspot_spot_contrast'].values)

### Corner Plots

In [None]:
# s.corner_plot()
import corner

# One corner plot per entry of s._pymc3_model, each saved as R10_<index>_corner.png.
# A failure for one entry is printed and skipped so the remaining entries are still
# attempted.
for wave_index, wave_model in enumerate(s._pymc3_model):
    try:
        with wave_model:
            corner.corner(s.trace[wave_index])
        plt.savefig(f"R10_{wave_index}_corner.png")
    except Exception as err:
        print(err)
# plt.savefig("wlc_corner_with_poly.png")

In [None]:
# s.data_with_model.animate_with_models()

In [None]:
def get_spot_contrast(wavelengths, star_teff=3180, spot_teff=2900, spot_radii=[0.29, 0.16, 0.09],
                      logg=4.97, metallicity=0.0, visualize=False):
    """
    Compute the contrast between a stellar spectrum and a spot spectrum.

    Two spectra are requested from a ``PHOENIXLibrary`` instance, one at
    ``star_teff`` and one at ``spot_teff``, sharing the same ``logg``,
    ``metallicity`` and ``wavelengths``. Element ``[1]`` of each returned
    spectrum is used as the flux, element ``[0]`` as its wavelength axis.

    NOTE(review): a function with this name is also imported from
    ``chromatic_fitting.jwst.jwst_utils`` at the top of the notebook; running
    this cell rebinds the name to this definition.

    Parameters
    ----------
    wavelengths
        Wavelength grid passed to ``get_spectrum`` and used as the x-axis of
        the diagnostic plots.
    star_teff, spot_teff
        Temperatures requested for the star and spot spectra.
    spot_radii
        Iterable of spot radii; the sum of their squares is printed as the
        "spot covering fraction" (presumably radii are in units of the stellar
        radius — confirm). This fraction only affects the printed message and
        the "Mixed Spectrum" diagnostic curve, not the returned contrast.
    logg, metallicity
        Forwarded unchanged to both ``get_spectrum`` calls.
    visualize
        Forwarded to ``get_spectrum``; when true, also draws the star, spot and
        mixed spectra and then the contrast curve in a new figure.

    Returns
    -------
    The contrast ``(star_flux - spot_flux) / star_flux``.
    """
    library = PHOENIXLibrary()
    shared_kw = dict(logg=logg, metallicity=metallicity,
                     wavelength=wavelengths, visualize=visualize)
    S_star = library.get_spectrum(temperature=star_teff, **shared_kw)
    S_spot = library.get_spectrum(temperature=spot_teff, **shared_kw)
    star_flux = S_star[1]
    spot_flux = S_spot[1]

    # Covering fraction: sum of squared spot radii.
    covering_fraction = sum(radius ** 2 for radius in spot_radii)
    print(f"Spot covering fraction = {covering_fraction}")

    # Area-weighted blend of spot and star flux (used only for the diagnostic plot).
    mixed_flux = (covering_fraction * spot_flux) + ((1 - covering_fraction) * star_flux)

    if visualize:
        plt.plot(S_star[0], star_flux, label=f"Star (T={star_teff}K)")
        plt.plot(S_spot[0], spot_flux, label=f"Spot (T={spot_teff}K)")
        plt.plot(wavelengths, mixed_flux, 'k', alpha=0.3,
                 label=f"Mixed Spectrum (f={covering_fraction})")
        plt.plot(wavelengths, mixed_flux, 'k.', label="Mixed Spectrum")

        plt.ylabel("Surface Flux [Photons / (s * m**2 * nm)]")
        plt.xlabel("Wavelength [micron]")
        plt.legend()

    contrast = (star_flux - spot_flux) / star_flux

    if visualize:
        # Separate figure for the contrast curve itself.
        plt.figure()
        plt.plot(wavelengths, contrast)
        plt.ylabel("Spot Contrast")
        plt.xlabel("Wavelength [micron]")

    return contrast