In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import corner
import os
import sys
import glob
from   copy import deepcopy
import astropy.constants as apc
from   astropy.io import fits
from   scipy import stats
import warnings
import seaborn as sns

PROJECT_DIR = '/Users/research/projects/kepler-ecc-rp/'

sys.path.append(PROJECT_DIR)
from utils.stats import weighted_percentile
from utils.io import load_posteriors, extract_posteriors

sys.path.append('/Users/research/projects/alderaan/')
sys.path.append('/data/user/gjgilbert/projects/alderaan/')
from alderaan.Results import Results

pi = np.pi

# Unit-conversion constants derived from astropy's physical constants.
RSAU = (apc.R_sun/apc.au).value                                 # solar radius [AU]
RSRE = (apc.R_sun/apc.R_earth).value                            # R_sun/R_earth
# Mean solar density. Assumes astropy's default SI constants, so .value is in kg/m^3;
# dividing by 1000 converts kg/m^3 -> g/cm^3.
RHOSUN_GCM3 = (3*apc.M_sun/(4*pi*apc.R_sun**3)).value/1000      # solar density [g/cm^3]

## Load data

#### Input catalog

In [None]:
DATA_SOURCE  = 'ALDERAAN-INJECTION'
DATA_DIR     = '/Users/research/projects/alderaan/Results/2024-09-25-SIMULATED-singles-ecc-physical/'

# The catalog CSV is named after the run directory: <DATA_DIR>/<run-name>.csv
CATALOG = os.path.join(DATA_DIR, '{0}.csv'.format(os.path.basename(os.path.normpath(DATA_DIR))))
catalog = pd.read_csv(CATALOG, index_col=0)

if DATA_SOURCE != 'ALDERAAN-INJECTION':
    raise ValueError("Unsupported DATA_SOURCE {0!r}; this notebook only handles 'ALDERAAN-INJECTION'".format(DATA_SOURCE))

# make injection catalog look like real catalog
catalog.columns = catalog.columns.str.lower()

# track ground-truth injected values (R_p in Earth radii via RSRE = R_sun/R_earth)
catalog['true_rp'] = catalog.ror * catalog.rstar * RSRE
catalog['true_rstar'] = catalog.rstar.copy()
catalog['true_rhostar'] = 10**catalog.logrho

# Assign planet names: every injected planet gets the suffix '.01' (the run directory
# name suggests single-planet systems). Vectorized string concatenation avoids the
# silent truncation a fixed-width 'U9' array imposes on longer IDs.
catalog['planet_name'] = catalog.koi_id.astype(str) + '.01'

# keep only injected planets with 1 < R_p < 4 R_earth
use = (catalog.true_rp > 1.0) & (catalog.true_rp < 4.0)
catalog = catalog.loc[use].reset_index(drop=True)
targets = catalog.planet_name.to_numpy()

#### Transit fit posterior chains

In [None]:
def infer_planet_koi_from_period(star_koi, P_samp, catalog):
    """Return the name of the planet on `star_koi` whose catalog period is closest to `P_samp`.

    Parameters
    ----------
    star_koi : object
        Value matched against ``catalog.koi_id``.
    P_samp : float
        Period to match (same units as ``catalog.period``).
    catalog : pandas.DataFrame
        Must have ``koi_id``, ``period`` and ``planet_name`` columns.

    Returns
    -------
    The ``planet_name`` of the nearest-period row belonging to `star_koi`
    (first such row on ties).

    Raises
    ------
    ValueError
        If `catalog` has no rows for `star_koi`.
    """
    on_star = catalog.loc[catalog.koi_id == star_koi]
    if on_star.empty:
        raise ValueError("no catalog entries for star {0}".format(star_koi))

    # Pick the row positionally within this star's rows. Matching on period
    # equality over the whole catalog could return a planet on another star
    # that happens to have an identical period.
    nearest = np.argmin(np.abs(on_star['period'].to_numpy() - P_samp))
    return on_star['planet_name'].iloc[nearest]

In [None]:
def infer_index_from_planet_koi(planet_koi, results, catalog):
    """Return the index of the planet in `results` whose median PERIOD is closest to
    the catalog period of `planet_koi`.

    Parameters
    ----------
    planet_koi : str
        Value matched against ``catalog.planet_name``; must match exactly one row.
    results : object
        Provides ``npl`` and ``samples(n)``, where ``samples(n)`` exposes a ``PERIOD`` column.
    catalog : pandas.DataFrame
        Must have ``planet_name`` and ``period`` columns.

    Returns
    -------
    Integer index in ``range(results.npl)``.

    Raises
    ------
    ValueError
        If `planet_koi` does not match exactly one catalog row, or ``results.npl == 0``.
    """
    matches = catalog.loc[catalog.planet_name == planet_koi, 'period'].to_numpy()
    # Insist on a single match. Otherwise `periods - matches` would broadcast silently
    # whenever the number of matches happens to equal results.npl.
    if matches.size != 1:
        raise ValueError("expected exactly one catalog row for {0}, found {1}".format(planet_koi, matches.size))
    P_cat = matches[0]

    periods = np.array([np.median(results.samples(n).PERIOD) for n in range(results.npl)])
    return np.argmin(np.abs(periods - P_cat))

In [None]:
chains  = {}
failure = []

if DATA_SOURCE != 'ALDERAAN-INJECTION':
    raise ValueError("Unsupported DATA_SOURCE {0!r}; this notebook only handles 'ALDERAAN-INJECTION'".format(DATA_SOURCE))

# Posterior samples are resampled (with replacement, using the posterior weights) into
# N_DRAWS equally-weighted draws. A single seeded RandomState makes this reproducible on
# re-run while still giving each target an independent draw.
N_DRAWS = 8000
resample_rng = np.random.RandomState(42)

for t in targets:
    try:
        # planet names are '<koi_id>.01'; the results ID replaces the leading
        # character with 'S' and drops the '.01' suffix
        results = Results('S'+t[1:-3], DATA_DIR)
        idx = infer_index_from_planet_koi(t, results, catalog)

        chains[t] = (results.samples(idx)
                     .sample(n=N_DRAWS, replace=True, weights=results.posteriors.weights(),
                             ignore_index=True, random_state=resample_rng)
                     .drop(columns='LN_WT'))

    except FileNotFoundError:
        warnings.warn("{0} failed to load".format(t))
        failure.append(t)


# update targets and catalog
targets = [t for t in targets if t not in failure]
catalog = catalog.loc[catalog.planet_name.isin(targets)].reset_index(drop=True)

print("{0} targets loaded".format(len(targets)))

#### H-Bayes posteriors

In [None]:
# H-Bayes result directories, one per radius-gap width. The first entry (gap00) is the
# run that the figure cell loads separately as its no-gap baseline.
PATHS = [
    os.path.join(PROJECT_DIR, run)
    for run in (
        'Results/20241004/injection-test-rp-slices-fwhm10-empirical-gap00/',
        'Results/20241004/injection-test-rp-slices-fwhm10-empirical-gap10/',
        'Results/20241004/injection-test-rp-slices-fwhm10-empirical-gap20/',
        'Results/20241004/injection-test-rp-slices-fwhm10-empirical-gap40/',
    )
]

## Make figures

In [None]:
# Per-target period, injected radius and recovered radius (median posterior ROR scaled
# by the catalog stellar radius), in Earth radii and in the same order as `targets`.
# Label-based lookup avoids stuffing length-1 Series into numpy scalar slots, which
# numpy (>=1.25) and pandas (>=2.1) deprecate.
catalog_by_name = catalog.set_index('planet_name')
assert catalog_by_name.index.is_unique, "duplicate planet_name entries in catalog"
matched = catalog_by_name.loc[list(targets)]

period  = matched['period'].to_numpy(dtype=float)
rp_true = (matched['ror']*matched['rstar']*RSRE).to_numpy(dtype=float)
rp_obs  = np.array([np.median(chains[t].ROR) for t in targets])*matched['rstar'].to_numpy(dtype=float)*RSRE

fig, ax = plt.subplots(figsize=(4,3))
ax.plot([0.8, 4.2], [0.8, 4.2], color='lightgrey', ls='--', zorder=0)   # 1:1 reference line
ax.plot(rp_true, rp_obs, 'k.')
ax.set_xlabel("Injected $R_p$", fontsize=16)
ax.set_ylabel("Recovered $R_p$", fontsize=16)
ax.set_xlim(0.8,4.2)
ax.set_ylim(0.8,4.2)
plt.show()

In [None]:
sns.set_context("paper", font_scale=1.5)


# The first run (gap00) is the no-gap baseline that each gapped run is differenced against.
baseline_path = PATHS[0]
files = sorted(glob.glob(os.path.join(baseline_path, '*.fits')))
samples, headers, bin_edges = load_posteriors(files)
rp0, ecc0, mult0, nobj0 = extract_posteriors(samples, headers)


for path in PATHS:
    # load H-Bayes data; the baseline was already read above, so reuse it rather than re-reading
    if path == baseline_path:
        rp, ecc, mult, nobj = rp0, ecc0, mult0, nobj0
    else:
        files = sorted(glob.glob(os.path.join(path, '*.fits')))
        samples, headers, bin_edges = load_posteriors(files)
        rp, ecc, mult, nobj = extract_posteriors(samples, headers)

    # run directories end in '-gapNN', where NN is the gap width in percent
    label = os.path.basename(os.path.normpath(path)).rsplit('-', 1)[-1]   # e.g. 'gap20'
    gap = float(label[3:])/100

    # radius interval (R_earth) treated as the gap: 1.84/(1+gap/2) .. 1.84*(1+gap/2)
    gap_lo, gap_hi = 1.84/(1+gap/2), 1.84*(1+gap/2)

    # 16/50/84th percentiles along axis 1; gapped runs are shown relative to the baseline
    if gap > 0:
        # ecc - ecc0 pairs rows by position (presumably radius bins), so the radius
        # grids of the two runs must agree for the difference to mean anything
        if np.shape(rp) != np.shape(rp0) or not np.allclose(rp, rp0):
            raise ValueError("radius bins in {0} do not match the no-gap baseline".format(path))
        ecc_pct = np.percentile(ecc-ecc0, [16,50,84], axis=1).T
    else:
        ecc_pct = np.percentile(ecc, [16,50,84], axis=1).T

    sel = (mult == 1)   # NOTE(review): meaning of mult == 1 comes from extract_posteriors — confirm
    x = np.mean(rp[sel], axis=1)
    y = ecc_pct[sel][:,1]
    yerr = np.abs(ecc_pct[sel][:,(0,2)].T - y)

    # catalog planets whose injected radius lies outside the gap and whose
    # recovered radius lies in 1-4 R_earth
    use = ((rp_true < gap_lo) | (rp_true > gap_hi)) & (rp_obs > 1.0) & (rp_obs < 4.0)
    per_ = period[use]
    rpt_ = rp_true[use]
    rpo_ = rp_obs[use]

    # make plot: eccentricity vs radius (2/3 width) and radius vs period (1/3 width)
    fig = plt.figure(figsize=(18,5))
    grid = fig.add_gridspec(1, 3)
    ax = [fig.add_subplot(grid[0, :2]), fig.add_subplot(grid[0, 2])]

    ax[0].errorbar(x, y, yerr=yerr, fmt='ko')
    if gap > 0:
        ax[0].axvspan(gap_lo, gap_hi, color='C0', alpha=0.2)
        ax[0].plot(np.linspace(1,4,20), np.zeros(20), color='grey', ls='--', zorder=0)
        ax[0].set_ylabel(r"$\langle e \rangle - \langle e \rangle_{\rm no\ gap}$", fontsize=16)
    else:
        # NOTE(review): 0.023 is presumably the injected population's mean eccentricity — confirm
        ax[0].plot(np.linspace(1,4,20), 0.023*np.ones(20), color='grey', ls='--', zorder=0)
        ax[0].set_ylabel(r"$\langle e \rangle$", fontsize=16)

    ax[0].set_xscale('log')
    ax[0].set_xticks([1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0], [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0])
    ax[0].set_ylim(-0.1,0.3)
    ax[0].set_xlabel(r"$R_p\ (R_\oplus)$", fontsize=16)
    ax[0].minorticks_off()

    # grey segment joins each planet's injected and recovered radius at its period
    ax[1].vlines(per_, rpt_, rpo_, color='lightgrey')
    ax[1].plot(per_, rpo_, 'k.')
    if gap > 0:
        ax[1].axhspan(gap_lo, gap_hi, color='C0', alpha=0.2)
    ax[1].set_xscale('log')
    ax[1].set_yscale('log')
    ax[1].set_xticks([1,3,10,30,100], [1,3,10,30,100])
    ax[1].set_yticks([1,2,4], [1,2,4])
    ax[1].minorticks_off()
    ax[1].set_xlabel(r"$P$ (days)", fontsize=16)
    ax[1].set_ylabel(r"$R_p\ (R_\oplus)$", fontsize=16)

    fig.savefig(os.path.join(PROJECT_DIR, 'Figures/injection-test-{0}.pdf'.format(label)), bbox_inches='tight')
    plt.show()
    plt.close(fig)