In [None]:
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
import astropy.units as u
import scienceplots
plt.style.use('science')

from exod.utils.path import data, data_results, savepaths_combined
from exod.post_processing.crossmatch import crossmatch_dr14_slim
from exod.post_processing.crossmatch_runs import split_subsets, crossmatch_simulation_subsets
from exod.post_processing.cluster_regions import ClusterRegions
from exod.post_processing.main import calc_df_lc_feat_filter_flags
from sklearn.preprocessing import scale

In [None]:
# Light-curve index table, indexed by the CSV's unnamed first column.
# Each row is later unpacked into (start, stop) and handed to pd.read_hdf.
df_lc_indexs = pd.read_csv(
    savepaths_combined['lc_idx'],
    index_col='Unnamed: 0',
)

# Cross-match table stored under the 'cmatch_simbad' key
# (its SEP_ARCSEC column is used in the inspection loop).
df_cmatch = pd.read_csv(savepaths_combined['cmatch_simbad'])

In [None]:
# Load every detected region, reduce them with ClusterRegions (presumably
# merging duplicates across runs into unique regions — confirm in exod),
# then cross-match the unique regions with crossmatch_dr14_slim.
df_regions = pd.read_csv(savepaths_combined['regions'])
cr = ClusterRegions(df_regions)
df_regions_unique = cr.df_regions_unique

tab_cmatch_xmm = crossmatch_dr14_slim(df_regions_unique)
# Attach a 0..N-1 integer id to each cross-match row.
tab_cmatch_xmm['idx'] = range(len(tab_cmatch_xmm))

In [None]:
# Light-curve feature table with the filter-flag columns added by
# calc_df_lc_feat_filter_flags.
df_lc_feat = calc_df_lc_feat_filter_flags(
    pd.read_csv(savepaths_combined['lc_features'])
)

In [None]:
# Numeric light-curve feature columns used for the exploratory plots below.
data_cols = [
    # ratio / Kolmogorov-Smirnov columns
    'ratio_bccd',
    'ratio_bti',
    'ks_stat',
    'ks_pval',
    # summary statistics of the 'n' series (presumably observed counts)
    'n_min',
    'n_max',
    'n_mean',
    'n_std',
    'n_sum',
    'n_skew',
    'n_kurt',
    # summary statistics of the 'mu' series (presumably expected counts)
    'mu_min',
    'mu_max',
    'mu_mean',
    'mu_std',
    'mu_skew',
    'mu_kurt',
    # peak / eclipse statistics
    'B_peak_log_max',
    'B_eclipse_log_max',
    'num_B_peak_above_6_4',
    'num_B_eclipse_above_5_5',
]

In [None]:
# Numeric-only view of the feature table, restricted to data_cols.
df_numeric = df_lc_feat.loc[:, data_cols]

In [None]:
# Quality cuts on the light-curve features, producing df_c for the scatter plots.
#
# 1) Runs whose id contains '_5_' (presumably the 5 s time-bin runs — confirm)
#    must reach n_max >= RUN5_MIN_NMAX. The original applied this cut at 5, 8
#    and 10 in sequence; the 10 cut subsumes the other two, so only it is kept.
# 2) Every column in LC_FEAT_UPPER_BOUNDS must be strictly below its bound.
#    An np.inf bound drops +inf rows; NaN fails every `<` comparison, so NaN
#    rows are dropped by all of these cuts, exactly as before.
RUN5_MIN_NMAX = 10
LC_FEAT_UPPER_BOUNDS = {
    'n_max': 100,
    'n_min': 50,
    'n_std': 50,
    'B_peak_log_max': np.inf,
    'B_eclipse_log_max': np.inf,
    'num_B_peak_above_6_4': 50,
    'num_B_eclipse_above_5_5': 50,
}

keep_mask = ~((df_lc_feat['n_max'] < RUN5_MIN_NMAX) & df_lc_feat['runid'].str.contains('_5_'))
for feat_col, upper in LC_FEAT_UPPER_BOUNDS.items():
    keep_mask &= df_lc_feat[feat_col] < upper

df_c = df_lc_feat[keep_mask].copy()
print(len(df_c))

In [None]:

# Scatter plot of every pair of feature columns (21 choose 2 = 210 figures),
# with one colour per time-bin subset of df_c.
# NOTE(review): `tbins` and `labs` are not defined in any visible cell, so this
# cell raises NameError on a fresh kernel. They are presumably run-id substrings
# and their matching legend labels — TODO define them in the config cell.

# Split df_c once per time bin instead of re-scanning 'runid' for every pair.
df_c_by_tbin = [
    (lab, df_c[df_c['runid'].str.contains(tbin)])
    for tbin, lab in zip(tbins, labs)
]

for col1, col2 in itertools.combinations(data_cols, r=2):
    fig, ax = plt.subplots(figsize=(5, 5))
    for lab, sub in df_c_by_tbin:
        ax.scatter(sub[col1], sub[col2], s=1.0, marker='.', label=lab)
    ax.set_xlabel(col1)
    ax.set_ylabel(col2)
    ax.legend()
    plt.show()
    # Release each figure; this loop creates 210 of them.
    plt.close(fig)

In [None]:
def clean(data):
    """Keep only the finite entries of ``data``.

    +inf and -inf are removed. NaN is removed as well, because it fails both
    comparisons. The input is filtered with a boolean mask, so a pandas Series
    keeps its original index labels.
    """
    finite = (data > -np.inf) & (data < np.inf)
    return data[finite]
    

# One histogram per feature column (log-scaled counts): a step curve per
# time-bin subset plus the full sample in black. Non-finite values are removed
# by clean() before binning.
# NOTE(review): `tbins` and `labs` are not defined in any visible cell, so this
# cell raises NameError on a fresh kernel — TODO define them in the config cell.

# Compute each time-bin mask once rather than once per feature column.
tbin_masks = [
    (lab, df_lc_feat['runid'].str.contains(tbin))
    for tbin, lab in zip(tbins, labs)
]

for col in data_cols:
    fig, ax = plt.subplots()
    for lab, mask in tbin_masks:
        ax.hist(clean(df_lc_feat.loc[mask, col]), bins=100, histtype='step', label=lab)
    ax.hist(clean(df_lc_feat[col]), bins=100, histtype='step', color='black', label='All')
    ax.set_yscale('log')
    ax.set_title(col)
    ax.set_xlabel(col)
    ax.set_ylabel('Count')
    ax.legend()
    plt.show()
    plt.close(fig)

In [None]:
# Look up which row of a region table lists a given element in its 'idxs' column.
def find_row_index(df, element):
    """Return the index label of the first row whose ``'idxs'`` entry contains ``element``.

    Rows are scanned in index order and the first hit wins. ``None`` is
    returned when no row contains ``element``, including when ``df`` is empty.
    """
    hits = (label for label, row in df.iterrows() if element in row['idxs'])
    return next(hits, None)


# Inspect hard-band light curves with a strong eclipse statistic.
# For each candidate the cell:
#   - skips it when its light-curve offsets cannot be found;
#   - skips it when the light curve is empty or never reaches 30 counts;
#   - skips it when its cross-match separation is < 15 arcsec;
#   - otherwise plots 'n' (black) and 'mu' (red) against time.
#
# NOTE(review): assumes the index labels of df_lc_feat are the values listed in
# df_regions_unique['idxs'], and that df_cmatch is row-aligned with
# df_regions_unique (df_cmatch.loc[row_index]) — confirm both.
# TODO: the HDF path is relative and hard-coded; move it into savepaths_combined.
LC_HDF_PATH = '../data/results_combined/merged_with_dr14/df_lc.h5'

# Literal substring match: as a regex, '.' would match any character.
is_hard_band = df_lc_feat['runid'].str.contains('2.0_12.0', regex=False)
sub_hard_strong_eclipses = df_lc_feat[(df_lc_feat['B_eclipse_log_max'] > 20) & is_hard_band]
sub_hard_strong_peaks = df_lc_feat[(df_lc_feat['B_peak_log_max'] > 20) & is_hard_band]  # kept for interactive use

for idx, row in sub_hard_strong_eclipses.iterrows():
    key = row['key']
    try:
        start, stop = df_lc_indexs.loc[key]
    except (KeyError, ValueError):
        # Must skip here: falling through would read with an undefined or stale
        # start/stop left over from the previous row.
        print(key, 'failed!')
        continue

    df_lc = pd.read_hdf(LC_HDF_PATH, start=start, stop=stop)
    if df_lc.empty or df_lc['n'].max() < 30:
        continue

    row_index = find_row_index(df_regions_unique, idx)
    if row_index is None:
        # df_cmatch.loc[None] would raise KeyError.
        print(key, 'no matching region')
        continue
    match = df_cmatch.loc[row_index]
    if match['SEP_ARCSEC'] < 15:
        continue

    print(f'{key} {row_index}')

    fig, ax = plt.subplots(figsize=(15, 3))
    ax.plot(df_lc['time'], df_lc['n'], color='black', label=key)
    ax.plot(df_lc['time'], df_lc['mu'], color='red')
    ax.legend()
    plt.show()
    plt.close(fig)