In [1]:
# 3rd party
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import lightkurve as lk
from collections import defaultdict
import textwrap
import logging

In [2]:
logging.basicConfig(level=logging.INFO,
                    format='%(levelname)s - %(message)s')

logger = logging.getLogger(__name__) 

In [3]:
logger.info("test")

INFO - test


In [None]:
plot_dir = Path(f"/Users/jochoa4/Desktop/studies/study_model_preds_07-16-2025")
plot_dir.mkdir(parents=True, exist_ok=True)

tce_tbl_fp = Path(
    "/Users/jochoa4/Projects/exoplanet_transit_classification/ephemeris_tables/preprocessing_tce_tables/tess_2min_tces_dv_s1-s68_all_msectors_11-29-2023_2157_newlabels_nebs_npcs_bds_ebsntps_to_unks.csv"
)

scores_tbls_dir = Path(
    "/Users/jochoa4/Desktop/pfe_transfers/predict_model_TESS_exoplanet_dataset_07-11-2025_no_ntp_no_detrend_split_norm_filtered"
)

show_plots = False
# -------------------- Load Data --------------------
tce_tbl = pd.read_csv(tce_tbl_fp)
tce_tbl = tce_tbl.rename(columns={"uid": "tce_uid", "label": "disposition"})

if 'target_id' not in tce_tbl.columns:
    logger.debug("target_id not in tce_tbl, adding it")
    tce_tbl['target_id'] = (
        tce_tbl['tce_uid'].astype(str)
               .str.split('-')
               .str[0]
               .astype(int)
    )
# -------------------- Helper Functions --------------------
def _get_confusion_label(row):
    label, pred_label = row['label'], row['pred_label']
    if label == 0 and pred_label == 0:
        return 'TN'
    if label == 0 and pred_label == 1:
        return 'FP'
    if label == 1 and pred_label == 0:
        return 'FN'
    if label == 1 and pred_label == 1:
        return 'TP'
    return 'UNKNOWN'

# -------------------- Mixed-Target Flag --------------------
mixed_targets = {
    tid for tid, grp in tce_tbl.groupby('target_id')
    if grp['disposition'].isin(['EB','CP','KP']).any()
    and grp['disposition'].isin(['NTP','NEB','NPC']).any()
}


tce_tbl_cols = list(tce_tbl.columns)
log_features = ['tce_depth', 'tce_max_mult_ev', 'tce_maxmes', 'tce_period', 'tce_duration', 'tce_num_transits', 'tce_dikco_msky_original', 'tce_dikco_msky_err_original', 'tce_dicco_msky_original','tce_dicco_msky_err_original']
linear_features = []
feature_cols = log_features + linear_features
feature_cols = [f for f in feature_cols if f in tce_tbl_cols]

summary_rows = []

all_preds_tbl = []
for split_name in ['train', 'val', 'test']:
    preds_tbl = preds_tbl_fp = scores_tbls_dir / f"preds_{split_name}.csv"
    preds_tbl = pd.read_csv(preds_tbl_fp)
    
     # Extract identifiers and flags
    preds_tbl['tce_uid'] = preds_tbl['uid'].str.split('_').str[0]
    preds_tbl['target_id'] = preds_tbl['tce_uid'].str.split('-').str[0].astype(int)
    preds_tbl['confusion'] = preds_tbl.apply(_get_confusion_label, axis=1)
    preds_tbl['mixed_target_flag'] = preds_tbl['target_id'].apply(lambda t: 1 if t in mixed_targets else 0)

    merge_cols = ['tce_uid'] + feature_cols
    preds_tbl = preds_tbl.merge(
        tce_tbl[merge_cols],
        on="tce_uid",
        how="left",
        validate="many_to_one",
    )

    preds_tbl['split'] = split_name
    all_preds_tbl.append(preds_tbl)
all_preds_tbl = pd.concat(all_preds_tbl)

for split_name in ['train', 'val', 'test']:
    split_plot_dir = plot_dir / split_name
    split_plot_dir.mkdir(parents=True, exist_ok=True)

    # Load predictions
    preds_tbl = all_preds_tbl[all_preds_tbl['split'] == split_name].copy()

    logger.info(f"\n=== Split: {split_name} ===")
    split_dir = plot_dir / split_name
    split_dir.mkdir(exist_ok=True)

    # Summaries
    logger.info(f"Total examples: {len(preds_tbl)}")
    logger.info(f"Confusion counts:\n{preds_tbl['confusion'].value_counts()}" )

    # Build CF df
    for cf, df_cf in preds_tbl.groupby("confusion"):
        summary_rows.append({
            "split":      split_name,
            "disposition": "ANY",
            "confusion":  cf,
            "count":      len(df_cf)
        })

    # per‐disposition confusion counts
    for disp, df_disp in preds_tbl.groupby("disposition"):
        for cf, df_cf in df_disp.groupby("confusion"):
            summary_rows.append({
                "split":       split_name,
                "disposition": disp,
                "confusion":   cf,
                "count":       len(df_cf)
            })

    for disp in preds_tbl['disposition'].unique():
        df_disp = preds_tbl[preds_tbl['disposition'] == disp]
        logger.info(f"\n{disp} ({len(df_disp)} examples)")
        logger.info(df_disp['confusion'].value_counts().to_string())
        logger.info(df_disp[feature_cols].describe().to_string())

        for disp in preds_tbl['disposition'].unique():
            logger.info(f"{disp} Summary")
            disp_preds_tbl = preds_tbl[preds_tbl['disposition'] == disp].copy()
            print(f"{split_name} set consists of {len(disp_preds_tbl)} {disp} examples.")
            print(disp_preds_tbl['confusion'].value_counts())
            
        worst_examples = {
            cf : defaultdict(list) for cf in preds_tbl["confusion"].unique()
        }

        for cf in ['FP', 'FP']:
            worst_examples[cf] = preds_tbl[preds_tbl["confusion"] == cf]

            worst_targets_set = set([])
            worst_targets_list = []
            worst_targets_map = defaultdict(list)
            for _, cf_data in worst_examples[cf].iterrows():
                worst_targets_set.add(cf_data['target_id'])
                worst_targets_list.append(cf_data['target_id'])
                worst_targets_map[cf_data['target_id']].append(cf_data['uid'])

            for t, exs in worst_targets_map.items():
                logger.info(f"{t}: {len(exs)} {'mixed' if int(t) in mixed_targets else ''}")

summary_df = pd.DataFrame(summary_rows)
logger.info(summary_df.to_string(index=False))


  tce_tbl = pd.read_csv(tce_tbl_fp)
INFO - 
=== Split: train ===
INFO - Total examples: 207557
INFO - Confusion counts:
confusion
TP    173003
TN     33964
FN       575
FP        15
Name: count, dtype: int64
INFO - 
CP (14148 examples)
INFO - confusion
TN    13239
TP      733
FN      165
FP       11
INFO -            tce_depth  tce_max_mult_ev    tce_maxmes    tce_period  tce_duration  tce_num_transits  tce_dikco_msky_original  tce_dikco_msky_err_original  tce_dicco_msky_original  tce_dicco_msky_err_original
count   14148.000000     14148.000000  14148.000000  14148.000000  14148.000000      14148.000000             14148.000000                 14148.000000             14148.000000                 14148.000000
mean    55759.697478        35.717691      2.720668      5.544775      1.856815        264.988196                 3.555450                     3.755043                 8.195008                     4.323769
std    147208.866463        47.028759      0.938297      9.169517      1.4

train set consists of 14148 CP examples.
confusion
TN    13239
TP      733
FN      165
FP       11
Name: count, dtype: int64
train set consists of 183443 EB examples.
confusion
TP    172202
TN     10845
FN       392
FP         4
Name: count, dtype: int64
train set consists of 9966 KP examples.
confusion
TN    9880
TP      68
FN      18
Name: count, dtype: int64
train set consists of 14148 CP examples.
confusion
TN    13239
TP      733
FN      165
FP       11
Name: count, dtype: int64


INFO - KP Summary
INFO - 55525572: 1 
INFO - 267574918: 10 mixed
INFO - 48507019: 1 
INFO - 213047427: 1 mixed
INFO - 152223725: 1 
INFO - 93000166: 1 
INFO - 55525572: 1 
INFO - 267574918: 10 mixed
INFO - 48507019: 1 
INFO - 213047427: 1 mixed
INFO - 152223725: 1 
INFO - 93000166: 1 
INFO - 
KP (9966 examples)
INFO - confusion
TN    9880
TP      68
FN      18
INFO -            tce_depth  tce_max_mult_ev   tce_maxmes   tce_period  tce_duration  tce_num_transits  tce_dikco_msky_original  tce_dikco_msky_err_original  tce_dicco_msky_original  tce_dicco_msky_err_original
count    9966.000000      9966.000000  9966.000000  9966.000000   9966.000000       9966.000000              9966.000000                  9966.000000              9966.000000                  9966.000000
mean    13735.472611       120.236833     2.864413     3.907444      2.805767        139.638671                 2.619244                     3.360797                 7.083460                     3.853381
std     17575.1393

train set consists of 183443 EB examples.
confusion
TP    172202
TN     10845
FN       392
FP         4
Name: count, dtype: int64
train set consists of 9966 KP examples.
confusion
TN    9880
TP      68
FN      18
Name: count, dtype: int64
train set consists of 14148 CP examples.
confusion
TN    13239
TP      733
FN      165
FP       11
Name: count, dtype: int64
train set consists of 183443 EB examples.
confusion
TP    172202
TN     10845
FN       392
FP         4
Name: count, dtype: int64
train set consists of 9966 KP examples.
confusion
TN    9880
TP      68
FN      18
Name: count, dtype: int64
val set consists of 11135 EB examples.
confusion
TP    9828
TN    1284
FN      21
FP       2
Name: count, dtype: int64
val set consists of 1571 CP examples.
confusion
TN    1571
Name: count, dtype: int64
val set consists of 1162 KP examples.
confusion
TN    1111
FN      47
TP       4
Name: count, dtype: int64
val set consists of 11135 EB examples.
confusion
TP    9828
TN    1284
FN      21
FP  

INFO - KP Summary
INFO - 262412046: 2 
INFO - 262412046: 2 
INFO - 
KP (1162 examples)
INFO - confusion
TN    1111
FN      47
TP       4
INFO -            tce_depth  tce_max_mult_ev   tce_maxmes   tce_period  tce_duration  tce_num_transits  tce_dikco_msky_original  tce_dikco_msky_err_original  tce_dicco_msky_original  tce_dicco_msky_err_original
count    1162.000000      1162.000000  1162.000000  1162.000000   1162.000000       1162.000000              1162.000000                  1162.000000              1162.000000                  1162.000000
mean    23775.219253        59.802284     2.454043     3.016013      2.482307        109.193632                 1.072503                     2.702674                12.187570                     3.459172
std     32962.576153        65.335093     0.598753     1.750502      0.798585        133.254833                 0.848834                     0.284350                17.189056                     2.886731
min      1293.281419         7.172248   

val set consists of 1571 CP examples.
confusion
TN    1571
Name: count, dtype: int64
val set consists of 1162 KP examples.
confusion
TN    1111
FN      47
TP       4
Name: count, dtype: int64
val set consists of 11135 EB examples.
confusion
TP    9828
TN    1284
FN      21
FP       2
Name: count, dtype: int64
val set consists of 1571 CP examples.
confusion
TN    1571
Name: count, dtype: int64
val set consists of 1162 KP examples.
confusion
TN    1111
FN      47
TP       4
Name: count, dtype: int64
test set consists of 20250 EB examples.
confusion
TP    18877
TN     1360
FN       12
FP        1
Name: count, dtype: int64
test set consists of 1088 CP examples.
confusion
TN    1088
Name: count, dtype: int64
test set consists of 1143 KP examples.
confusion
TN    1143
Name: count, dtype: int64
test set consists of 20250 EB examples.
confusion
TP    18877
TN     1360
FN       12
FP        1
Name: count, dtype: int64
test set consists of 1088 CP examples.
confusion
TN    1088
Name: count, dtyp

In [None]:
train_preds_tbl = all_preds_tbl[all_preds_tbl['split'] == 'train']
train_preds_tbl[(train_preds_tbl['tce_dikco_msky'] > 20.0)]['tce_uid']

In [None]:
for col in tce_tbl.columns:
    if 'FP' in list(tce_tbl[col]):
        print(f'col: {col}')


In [None]:
for col in tce_tbl.columns:
    if 'dikco' in col:
        print(col)

In [None]:
abs(all_preds_tbl['pred_prob'] - all_preds_tbl['label'])

In [None]:
all_preds_tbl['tce_dikco_msky_err_original'].describe()

In [None]:
all_preds_tbl['pred_err'] = all_preds_tbl.apply(lambda r: abs(r['pred_prob'] - r['label']), axis=1)
all_preds_tbl['ratio'] = all_preds_tbl.apply(lambda r: r['tce_dikco_msky_err_original'] / (r['tce_dikco_msky_original'] + 1e-8), axis=1)
all_preds_tbl['err_cap'] = all_preds_tbl.apply(lambda r: 0.33 * (r['tce_dikco_msky_original'] ), axis=1)
all_preds_tbl['tic_offset_estimate'] = all_preds_tbl.apply(lambda r: r['tce_dikco_msky_err_original'] + (r['tce_dikco_msky_original']), axis=1)
all_preds_tbl['uncer_err'] = all_preds_tbl.apply(lambda r: r['tce_dikco_msky_err_original'] / (abs(r['tce_dikco_msky_err_original'] + r['tce_dikco_msky_original'])), axis=1)
eb_preds = all_preds_tbl[all_preds_tbl['disposition'] == 'EB']
eb_preds['pred_err'].describe()
cp_preds = all_preds_tbl[all_preds_tbl['disposition'] == 'CP']
cp_preds['pred_err'].describe()

In [None]:
ratio_preds = eb_preds[(eb_preds['ratio'] <  0.25) & (eb_preds['split'] == 'train')]
ratio_preds['confusion'].value_counts()

In [None]:
filt_eb_preds = eb_preds[(eb_preds['tce_dikco_msky_original'] < 4.2) & (eb_preds['tce_dikco_msky_err_original'] > 0)]
print(filt_eb_preds['confusion'].value_counts())

In [None]:
filt_cp_preds = cp_preds[((cp_preds['tce_dikco_msky_original'] + cp_preds['tce_dikco_msky_err_original']) < 20) & (cp_preds['tce_dikco_msky_err_original'] >= 0) ]
print(filt_cp_preds['confusion'].value_counts())

In [None]:
cp_preds['confusion'].value_counts()

In [None]:
filt_eb_preds = eb_preds[((eb_preds['tce_dikco_msky_original'] + eb_preds['tce_dikco_msky_err_original']) < 5.6) & (eb_preds['tce_dikco_msky_err_original'] >= 0) ]
print(filt_eb_preds['confusion'].value_counts())

In [None]:
eb_preds[((eb_preds['tce_dikco_msky_original'] < 4.2) & ( (eb_preds['tce_dikco_msky_original'] + eb_preds['tce_dikco_msky_err_original']) < 5.6) & (eb_preds['tce_dikco_msky_err_original'] >= 0)) ]['confusion'].value_counts()


In [None]:
bad_eb_examples = eb_preds[(( (eb_preds['tce_dikco_msky_original'] + eb_preds['tce_dikco_msky_err_original']) >= 5.6) & (eb_preds['tce_dikco_msky_err_original'] >= 0)) ]
bad_eb_examples['confusion'].value_counts()


In [None]:
bad_eb_examples[bad_eb_examples['confusion'] == 'FN']['target_id'].value_counts()

In [None]:
offset_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_dikco_msky_original', 'tce_dikco_msky_err_original', 'tce_dicco_msky_original', 'tce_dicco_msky_err_original']
tce_tbl[tce_tbl['target_id'] == 420114776][offset_cols]

In [None]:
tce_tbl[tce_tbl['disposition'] == 'EB']['tce_dikco_msky_err_original'].describe()

In [None]:
filt_eb_preds[(filt_eb_preds['tce_dikco_msky_original'] < 5.6)]['confusion'].value_counts()

In [None]:
filt_eb_preds[((filt_eb_preds['tce_dikco_msky_original'] + filt_eb_preds['tce_dikco_msky_err_original']) < 4.2)]['confusion'].value_counts()

In [None]:
filt_eb_preds[((filt_eb_preds['tce_dikco_msky_original'] + filt_eb_preds['tce_dikco_msky_err_original']) < 4.2)]['confusion'].value_counts()

In [None]:
eb_preds['confusion'].value_counts()

In [None]:
test = filt_eb_preds[(filt_eb_preds['tce_dikco_msky_original'] + filt_eb_preds['tce_dikco_msky_err_original'] < (4.2 ))]
test['confusion'].value_counts()

In [None]:
filt_eb_preds[filt_eb_preds['confusion'] == 'FN']['target_id'].value_counts()

In [None]:
test[test['confusion'] == 'FN']['target_id'].value_counts()

In [None]:
test[test['target_id'] == 193831684]

In [None]:
eb_preds[(eb_preds['tce_dikco_msky_original'] / eb_preds['tce_dikco_msky_err_original'] > 0)]['confusion'].value_counts()

In [None]:
filt_eb_preds = eb_preds[
    (eb_preds['tce_dikco_msky_original'] < 4.2) &
    (eb_preds['tce_dikco_msky_err_original'] + eb_preds['tce_dikco_msky_original'] < (5.6))
]

print(filt_eb_preds['confusion'].value_counts())
print((filt_eb_preds['tce_dikco_msky_err_original'] > 0).sum())
print((filt_eb_preds['tce_dikco_msky_err_original'] <= 0).sum())

In [None]:
(eb_preds['tce_dikco_msky_original'] < 0).sum()

In [None]:
eb_preds[eb_preds['tce_dikco_msky_err_original'] > eb_preds['tce_dikco_msky_original']]['confusion'].value_counts()

In [None]:
cp_preds = all_preds_tbl[all_preds_tbl['disposition'] == 'CP']
cp_preds['confusion'].value_counts()

In [None]:
(cp_preds['ratio'] < 1/3).sum()

In [None]:
thr_cp_preds = cp_preds[cp_preds['tce_dikco_msky_original'] < 4.2]
thr_cp_preds['confusion'].value_counts()

In [None]:
thr_eb_preds = eb_preds[eb_preds['tce_dikco_msky_original'] < 4.2]
thr_eb_preds['confusion'].value_counts()

In [None]:
thr_eb_preds[thr_eb_preds['tce_dikco_msky_err_original'] <= -1.0]['confusion'].value_counts()

In [None]:
print(((eb_preds['label'] == 1)
       &
       (eb_preds['tce_dikco_msky_err_original'] >= 0)).sum())

In [None]:
print((eb_preds['tce_dikco_msky_original'] < 0).sum())
print((eb_preds['tce_dikco_msky_err_original'] < 0).sum())

In [None]:
print((eb_preds['tce_dikco_msky_original'] >= 0).sum())
print((eb_preds['tce_dikco_msky_err_original'] >= 0).sum())

In [None]:
ratio_preds[ratio_preds['tce_dikco_msky_original'] < 4.2]['confusion'].value_counts()

In [None]:
uncer_preds = eb_preds[(eb_preds['ratio'] < 0.3333) & eb_preds['']]
uncer_preds['confusion'].value_counts()

In [None]:
20944 + 207

In [None]:
uncer_preds[uncer_preds['split'] == 'train']['confusion'].value_counts()

In [None]:
20759 + 196

In [None]:
20759 + 196

In [None]:
20944 + 207

In [None]:
20935 + 207

In [None]:
ratio_preds[ratio_preds['split'] == 'train']['confusion'].value_counts()

In [None]:
20361 + 207

In [None]:
20935 + 207

In [None]:
filt_eb_preds = eb_preds[(eb_preds['tce_dikco_msky_original'] < 4.2) & (eb_preds['tce_dikco_msky_err_original'] < eb_preds['tce_dikco_msky_err_original'])]
print(filt_eb_preds['confusion'].value_counts())

In [None]:
filt_all_preds = all_preds_tbl[
    (all_preds_tbl['tce_dikco_msky_original'] < 4.2) &
    (all_preds_tbl['tce_dikco_msky_err_original'] < (1/3) * all_preds_tbl['tce_dikco_msky_original'])
]
print(filt_all_preds['confusion'].value_counts())

In [None]:
filt_eb_preds = eb_preds[
    (eb_preds['tce_dikco_msky_original'] < 4.2) &
    ((eb_preds['tce_dikco_msky_err_original'] / eb_preds['tce_dikco_msky_original']) < (1/3) )
]
print(filt_eb_preds['confusion'].value_counts())

In [None]:
filt_eb_preds = eb_preds[(eb_preds['tce_dikco_msky_original'] < 4.2) & (eb_preds['tce_dikco_msky_err_original'] > ( 0.33 * eb_preds['tce_dikco_msky_original']))]
print(filt_eb_preds['confusion'].value_counts())

In [None]:
tce_tbl[tce_tbl['disposition'] == 'EB']['tce_dikco_msky_err_original'].describe()

In [None]:
eb_preds[(eb_preds['tce_dicco_msky_original'] <  4.2) & (eb_preds['tce_dicco_msky_err_original'] >  0.33)]['confusion'].value_counts()

In [None]:
eb_preds['confusion'].value_counts()

In [None]:
all_preds_tbl['confusion'].value_counts()

In [None]:
all_preds_tbl[(all_preds_tbl['ratio'] < 0.25)]['confusion'].value_counts()

In [None]:
eb_preds['tce_dicco_msky_err_original'].describe()

In [None]:
eb_preds[eb_preds['tce_dicco_msky_err_original'] == -1.0]['confusion'].value_counts()

In [None]:
eb_preds[( ((eb_preds['tce_dikco_msky_original'] <  4.2) & (eb_preds['tce_dikco_msky_err_original'] <  (0.33 * eb_preds['tce_dikco_msky_original']))))]['confusion'].value_counts()


In [None]:
eb_preds[( ((eb_preds['tce_dicco_msky_original'] <  4.2) & (eb_preds['tce_dicco_msky_err_original'] <  0.33 * eb_preds['offset_estimate']) ))]['confusion'].value_counts()


In [None]:
all_preds_tbl[( ((all_preds_tbl['tce_dicco_msky_original'] >  4.2) & (all_preds_tbl['tce_dicco_msky_err_original'] >  0.33) & (all_preds_tbl['tce_dicco_msky_err_original'] != -1.0)))]['confusion'].value_counts()


In [None]:
for split in [ 'ANY']:
    print(f"\n{split} preds:")
    if split == 'ANY':
        split_preds_tbl = all_preds_tbl
    else:
        split_preds_tbl = all_preds_tbl[all_preds_tbl['split'] == split]

    for disp in ['EB']:
        print(f"\n{disp} preds:")
        disp_preds = split_preds_tbl[split_preds_tbl['disposition'] == disp]
        print(disp_preds[(disp_preds['tce_dicco_msky_original'] <  4.2) & ((disp_preds['tce_dicco_msky_err_original'] <  0.33)) ]['confusion'].value_counts())


In [None]:
eb_preds[( (eb_preds['tce_dicco_msky_original'] >  4.2) & ((eb_preds['tce_dicco_msky_err_original'] / eb_preds['tce_dicco_msky_original']) < 0.33))]['confusion'].value_counts()


In [None]:
eb_preds[eb_preds['tce_dicco_msky_err_original'] <  0.25]['confusion'].value_counts()

In [None]:
eb_preds[eb_preds['ratio'] < 0.2]['confusion'].value_counts()

In [None]:
est = eb_preds.apply(lambda r: r['tce_dikco_msky_err_original'] / abs(r['tce_dikco_msky_original'] + 1e-8), axis=1)

In [None]:
est[est > 0.2]

In [None]:
eb_preds[eb_preds['tce_dikco_msky_original'] > 0]['pred_err'].describe()

In [None]:
tce_subset_df['tce_dikco_msky']

In [None]:
for 

In [None]:
tce_subset_df = tce_tbl[tce_tbl['tce_uid'] == '420114776-1-S24']
for col in tce_subset_df.columns:
    if pd.api.types.is_numeric_dtype(tce_subset_df[col]):
        if (tce_subset_df[col] > 20).any() and (tce_subset_df[col] < 21).any():
            print(col)

In [None]:
tce_tbl['TFOPWG Disposition'].unique()

In [None]:
for split, df in summary_df[summary_df['disposition'] != 'ANY'].groupby("split"):
    pivot_split = (
        df.pivot(index="disposition", columns="confusion", values="count")
          .fillna(0)
          .loc[:, ["FN","FP"]]
    )
    fig, ax = plt.subplots(figsize=(5,4))
    pivot_split.plot(kind="bar", stacked=False, ax=ax)
    ax.set_title(f"{split.capitalize()} Set")
    ax.set_xlabel("Disposition")
    ax.set_ylabel("Count")
    ax.legend(title="Outcome", bbox_to_anchor=(1.02,1), loc="upper left")
    plt.xticks(rotation=0)
    plt.tight_layout()
    plt.show()


In [None]:
train_preds = all_preds_tbl[all_preds_tbl['split'] == 'train']
train_preds[train_preds['disposition'] == 'KP']['label'].sum()

In [None]:
for split_name, preds_tbl in all_preds_tbl.groupby('split'):
    logger.info(f"{split_name} set consists of {len(preds_tbl)} examples.")
    logger.info(preds_tbl['label'].value_counts())
    for disp in preds_tbl['disposition'].unique():
        logger.info(f"{disp} Summary")
        disp_preds_tbl = preds_tbl[preds_tbl['disposition'] == disp].copy()
        logger.info(f"Train set consists of {len(disp_preds_tbl)} {disp} examples.")
        logger.info(disp_preds_tbl['label'].value_counts())
    false_examples = {
        cf : defaultdict(list) for cf in ['FN', 'FP']
    }
    false_examples["FP"] = preds_tbl[preds_tbl["confusion"] == "FP"]
    false_examples["FN"] = preds_tbl[preds_tbl["confusion"] == "FN"]

    logger.info(f"{split_name} set has {len(false_examples['FP']) + len(false_examples['FN'])} incorrect predictions.")
    
    for cf, cf_df in false_examples.items():
        logger.info(f"{split_name} has {len(cf_df)} {cf} examples coming from {len(cf_df['tce_uid'].unique())} unique TCEs and {len(cf_df['target_id'].unique())} unique targets.")
        logger.info(f"Target list: {cf_df['target_id'].unique()}")
        

In [None]:
for split_name, preds_tbl in all_preds_tbl.groupby('split'):
    logger.info(f"{split_name} set consists of {len(preds_tbl)} examples.")
    logger.info(preds_tbl['confusion'].value_counts())
    for disp in preds_tbl['disposition'].unique():
        logger.info(f"{disp} Summary")
        disp_preds_tbl = preds_tbl[preds_tbl['disposition'] == disp].copy()
        logger.info(f"Train set consists of {len(disp_preds_tbl)} {disp} examples.")
        logger.info(disp_preds_tbl['confusion'].value_counts())
    false_examples = {
        cf : defaultdict(list) for cf in ['FN', 'FP']
    }
    false_examples["FP"] = preds_tbl[preds_tbl["confusion"] == "FP"]
    false_examples["FN"] = preds_tbl[preds_tbl["confusion"] == "FN"]

    logger.info(f"{split_name} set has {len(false_examples['FP']) + len(false_examples['FN'])} incorrect predictions.")
    
    for cf, cf_df in false_examples.items():
        logger.info(f"{split_name} has {len(cf_df)} {cf} examples coming from {len(cf_df['tce_uid'].unique())} unique TCEs and {len(cf_df['target_id'].unique())} unique targets.")
        logger.info(f"Target list: {cf_df['target_id'].unique()}")
    

# Analyzing Val

In [None]:
example_map = {}
for split_name in ['val', 'test', 'train']:
    logger.info(f"Analyzing split: {split_name}")
    df_split = all_preds_tbl[all_preds_tbl['split'] == split_name]

    # only need FN & FP
    false_examples = {
        'FP': df_split[df_split['confusion'] == 'FP'],
        'FN': df_split[df_split['confusion'] == 'FN']
    }

    example_map[split_name] = {}
    for cf, df_cf in false_examples.items():
        example_map[split_name][cf] = {}
        for t in df_cf['target_id'].unique():
            df_t = df_cf[df_cf['target_id'] == t]
            example_map[split_name][cf][t] = {}
            for tce_uid in df_t['tce_uid'].unique():
                # grab all rows for this target+TCE
                uids = df_t[df_t['tce_uid'] == tce_uid]['uid'].tolist()
                example_map[split_name][cf][t][tce_uid] = uids

        logger.info(f"{split_name} | {cf} Summary:")
        logger.info(df_cf[['pred_label','pred_prob']].describe().to_string())
        logger.info(f"{split_name} | {cf} Unique TCEs: {df_cf['tce_uid'].unique()}")
        logger.info(f"{split_name} | {cf} Unique Targets: {df_cf['target_id'].unique()}")


In [None]:
for info in example_map['val'].items():
    print(info)

In [None]:
nl = '   '
for split in ['train', 'test', 'val']:
    print(f"{nl * 0} {split}")
    for cf in ['FP', 'FN']:
        print(f"{nl * 1} {cf}")
        for target in example_map[split][cf]:
            print(f"{nl * 2} {target}")
            for tce_uid, examples in example_map[split][cf][target].items():
                print(f"{nl * 3} {tce_uid} : {[str(round(float(e.split('_t_')[-1]), 2)) for e in examples]}")
            

print(example_map['train']['FP'])

In [None]:

splits = ['val', 'test', 'train']
for split in splits:
    df_split = all_preds_tbl[all_preds_tbl['split'] == split]
    for cf in ['FP','FN']:
        df_cf = df_split[df_split['confusion'] == cf]

        # 1) Compute per-(target, tce_uid) mean confidence
        tce_means = (
            df_cf
            .groupby(['target_id','tce_uid'])['pred_prob']
            .mean()
            .rename('tce_avg')
            .reset_index()
        )

        # 2) Compute per-target mean of those tce_uids
        target_means = (
            tce_means
            .groupby('target_id')['tce_avg']
            .mean()
            .rename('target_avg')
            .sort_values(ascending=False)
        )

        print(f"\n=== {split} | {cf} ===")
        for target_id, target_avg in target_means.items():
            print(f"Target {target_id:>10}  ➜  avg prob = {target_avg:.3f}")

            # pull out all the tce's for this target and sort them
            tces = (
                tce_means[tce_means['target_id'] == target_id]
                .set_index('tce_uid')['tce_avg']
                .sort_values(ascending=False)
            )

            for tce_uid, tce_avg in tces.items():
                # list all the underlying uids for context
                uids = (
                    df_cf
                    .loc[
                        (df_cf['target_id']==target_id)&
                        (df_cf['tce_uid']==tce_uid),
                        'uid'
                    ]
                    .tolist()
                )
                print(f"    TCE {tce_uid:>8}  ➜  avg prob = {tce_avg:.3f} from {len(uids)} examples ")

In [None]:
def _summarize_tce_examples(tce_uid: str):
    interest_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']
    t_tce_tbl_filt = tce_tbl[(tce_tbl['target_id'] == int(tce_uid.split('-')[0]))].copy()
    logger.info(t_tce_tbl_filt[interest_cols])
    
    tce_tbl_filt = tce_tbl[(tce_tbl['target_id'] == int(tce_uid.split('-')[0])) & (tce_tbl['sector_run'] == tce_uid.split('S')[-1].split('_')[0])]
    logger.info(tce_tbl_filt[interest_cols])
    tce_preds_tbl = all_preds_tbl[all_preds_tbl["tce_uid"] == tce_uid].copy()
    logger.info(f"Dataset has {len(tce_preds_tbl)} examples corresponding to {tce_preds_tbl['disposition'].unique()[0]}, {tce_uid} ")
    logger.info(f"Confusion counts: \n {textwrap.indent(tce_preds_tbl['confusion'].value_counts().to_string(), ' ' * 4)}")
    logger.info(f"Transit Window Counts: \n {textwrap.indent(tce_preds_tbl['tw_flag'].value_counts().to_string(), ' ' * 4)}")
    logger.info(f"Stats by Confusion Label: ")
    for cf in tce_preds_tbl['confusion'].unique():
        logger.info(f"{' ' * 4}{cf} Stats:")
        desc = tce_preds_tbl[tce_preds_tbl['confusion'] == cf]['pred_prob'].describe()[1:]
        logger.info(f"{textwrap.indent(desc.to_string(), ' ' * 8)}")
    
_summarize_tce_examples("30450412-1-S12")

# Looking at average correctness by target and TCE

In [None]:
all_preds_tbl.columns

In [None]:
splits = ['val','test','train']
for split in splits:
    # 1) select split
    df = all_preds_tbl[all_preds_tbl['split']==split].copy()

    # 2) define correctness: if true_label==1, use pred_prob; else use 1–pred_prob
    df['correctness'] = (
        df['pred_prob'] * df['label'] +
        (1 - df['pred_prob']) * (1 - df['label'])
    )

    # 3) average up to the TCE level
    tce_corr = (
        df
        .groupby(['target_id','tce_uid'])['correctness']
        .mean()
        .rename('tce_corr')
        .reset_index()
    )

    # 4) average those TCE‐means to the target level
    target_corr = (
        tce_corr
        .groupby('target_id')['tce_corr']
        .mean()
        .rename('target_corr')
        .reset_index()
        .sort_values('target_corr', ascending=False)
    )

    # 5) display
    print(f"\n=== Split: {split} — target‐level correctness ===")
    for _, row in target_corr.iterrows():
        print(f"Target {int(row['target_id'])}  ➜  avg correctness = {row['target_corr']:.3f}")


In [None]:
splits = ['val', 'test', 'train']
for split in splits:
    df = all_preds_tbl[all_preds_tbl['split'] == split].copy()

    for cf in ['FP', 'FN']:
        df_wrong = df[df['confusion'] == cf]

        # Count wrong examples per TCE (target_id + tce_uid)
        tce_counts = (
            df_wrong
            .groupby(['target_id', 'tce_uid'])
            .size()
            .rename('num_wrong_examples')
            .reset_index()
        )

        # Count total wrong examples per target
        target_counts = (
            tce_counts
            .groupby('target_id')['num_wrong_examples']
            .sum()
            .rename('total_wrong_examples')
            .reset_index()
            .sort_values('total_wrong_examples', ascending=False)
        )

        # === DISPLAY per-target ===
        print(f"\n=== Split: {split} | {cf} — total wrong examples per target ===")
        for row in target_counts.itertuples(index=False):
            print(f"Target {int(row.target_id)}  ➜  total wrong = {row.total_wrong_examples}")

        # === DISPLAY per-TCE within target ===
        print(f"\n=== Split: {split} | {cf} — wrong examples per TCE ===")
        tce_counts_sorted = tce_counts.sort_values(['target_id', 'num_wrong_examples'], ascending=[True, False])
        for row in tce_counts_sorted.itertuples(index=False):
            print(f"Target {int(row.target_id)} | TCE {row.tce_uid}  ➜  wrong examples = {row.num_wrong_examples}")


In [None]:
interest_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']
tce_tbl[tce_tbl['target_id'] == 356473034][interest_cols]
# tce_tbl[(tce_tbl['target_id'] == 410418820) & (tce_tbl['sector_run'] == '1-36')][interest_cols]

In [None]:
for split_name, preds_tbl in all_preds_tbl.groupby('split'):
    logger.info(f"{split_name} set consists of {len(preds_tbl)} examples.")
    logger.info(preds_tbl['confusion'].value_counts())
    for disp in preds_tbl['disposition'].unique():
        logger.info(f"{disp} Summary")
        disp_preds_tbl = preds_tbl[preds_tbl['disposition'] == disp].copy()
        logger.info(f"Train set consists of {len(disp_preds_tbl)} {disp} examples.")
        logger.info(disp_preds_tbl['confusion'].value_counts())
    false_examples = worst_examples = {
        cf : defaultdict(list) for cf in ['FN', 'FP']
    }
    false_examples["FP"] = preds_tbl[preds_tbl["confusion"] == "FP"]
    false_examples["FN"] = preds_tbl[preds_tbl["confusion"] == "FN"]

    logger.info(f"{split_name} set has {len(false_examples['FP']) + len(false_examples['FN'])} incorrect predictions.")
    
    for cf, cf_df in false_examples.items():
        logger.info(f"{split_name} has {len(cf_df)} {cf} examples coming from {len(cf_df['tce_uid'].unique())} unique TCEs and {len(cf_df['target_id'].unique())} unique targets.")
        logger.info(f"Target list: {cf_df['target_id'].unique()}")

In [None]:
logger.info(f"Train set consists of {len(preds_tbl)} examples.")
logger.info(preds_tbl['confusion'].value_counts())

for disp in preds_tbl['disposition'].unique():
    logger.info(f"{disp} Summary")
    disp_preds_tbl = preds_tbl[preds_tbl['disposition'] == disp].copy()
    logger.info(f"Train set consists of {len(disp_preds_tbl)} {disp} examples.")
    logger.info(disp_preds_tbl['confusion'].value_counts())
worst_examples = {
    cf : defaultdict(list) for cf in preds_tbl["confusion"].unique()
}

num_examples = 200
# Get most confident FP/TN predictions
fp_df = preds_tbl[preds_tbl["confusion"] == "FP"]
worst_examples["FP"] = fp_df.nlargest(num_examples, columns=["pred_prob"])

tn_df = preds_tbl[preds_tbl["confusion"] == "TN"]
worst_examples["TN"] = tn_df.nlargest(num_examples, columns=["pred_prob"])

# Get least confident TP/FN predictions
tp_df = preds_tbl[preds_tbl["confusion"] == "TP"]
worst_examples["TP"] = tp_df.nsmallest(num_examples, columns=["pred_prob"])

fn_df = preds_tbl[preds_tbl["confusion"] == "FN"]
worst_examples["FN"] = fn_df.nsmallest(num_examples, columns=["pred_prob"])


In [None]:


len(fn_df["target_id"].unique())
logger.info(f"{len(fn_df[fn_df['tw_flag'] < 0])}")
fn_df.describe()
fn_df["mixed"] = fn_df.apply(lambda r: 1 if int(r['target_id']) in mixed_targets else 0, axis=1)

logger.info(f"MIXED FNs: {fn_df['mixed'].sum()}/{len(fn_df)}")

fn_df.head()

In [None]:

ex_at_mixed_target = {}
for cf, exs_df in worst_examples.items():
    exs_df["mixed_target_flag"] = exs_df.apply(lambda r: 1 if int(r['target_id']) in mixed_targets else 0, axis=1)
    # logger.info(f"{cf}: {exs_df['mixed_target_flag'].sum()} / {} are mixed")

In [None]:
len(worst_examples["FN"]["target_id"].unique())

In [None]:
# worst_examples["FP"]["target_id"].unique()
worst_examples["FN"][20:40]#["target_id"].unique()

In [None]:

worst_targets_set = set([])
worst_targets_list = []
worst_targets_map = defaultdict(list)
for _, fn in worst_examples["FN"].iterrows():
    logger.info(f"{fn['target']}, {fn['uid']}, {fn['pred_prob']}")
    worst_targets_set.add(fn['target'])
    worst_targets_list.append(fn['target'])
    worst_targets_map[fn['target']].append(fn['uid'])


In [None]:
len(worst_targets_set)


In [None]:
for t, exs in worst_targets_map.items():
    logger.info(f"{t}: {len(exs)} {'mixed' if int(t) in mixed_targets else ''}")

In [None]:
worst_targets_map["358232450"]

In [None]:
interest_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']
tce_tbl[tce_tbl['target_id'] == 358232450][interest_cols]
# tce_tbl[(tce_tbl['target_id'] == 410418820) & (tce_tbl['sector_run'] == '1-36')][interest_cols]

In [None]:
def _summarize_tce_examples(tce_uid: str):
    interest_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']
    t_tce_tbl_filt = tce_tbl[(tce_tbl['target_id'] == int(tce_uid.split('-')[0]))].copy()
    logger.info(t_tce_tbl_filt[interest_cols])
    
    tce_tbl_filt = tce_tbl[(tce_tbl['target_id'] == int(tce_uid.split('-')[0])) & (tce_tbl['sector_run'] == tce_uid.split('S')[-1].split('_')[0])]
    logger.info(tce_tbl_filt[interest_cols])
    tce_preds_tbl = preds_tbl[preds_tbl["tce_uid"] == tce_uid].copy()
    logger.info(f"Dataset has {len(tce_preds_tbl)} examples corresponding to {tce_preds_tbl['disposition'].unique()[0]}, {tce_uid} ")
    logger.info(f"Confusion counts: \n {textwrap.indent(tce_preds_tbl['confusion'].value_counts().to_string(), ' ' * 4)}")
    logger.info(f"Transit Window Counts: \n {textwrap.indent(tce_preds_tbl['tw_flag'].value_counts().to_string(), ' ' * 4)}")
    logger.info(f"Stats by Confusion Label: ")
    for cf in tce_preds_tbl['confusion'].unique():
        logger.info(f"{' ' * 4}{cf} Stats:")
        desc = tce_preds_tbl[tce_preds_tbl['confusion'] == cf]['pred_prob'].describe()[1:]
        logger.info(f"{textwrap.indent(desc.to_string(), ' ' * 8)}")
    
_summarize_tce_examples("189476500-1-S1-36")

In [None]:
def _summarize_target_examples(target_id: str):    
    target_preds_tbl = preds_tbl[preds_tbl["target_id"] == int(target_id)].copy()
    logger.info(f"Dataset has {len(target_preds_tbl)} examples corresponding to {target_id}\n")
    for tce_uid in target_preds_tbl["tce_uid"].unique():
        tce_preds_tbl = target_preds_tbl[target_preds_tbl["tce_uid"] == tce_uid].copy()
        logger.info(f"Dataset has {len(tce_preds_tbl)} examples corresponding to {tce_preds_tbl['disposition'].unique()[0]}, {tce_uid} ")
        logger.info(f"Confusion counts: \n {textwrap.indent(tce_preds_tbl['confusion'].value_counts().to_string(), ' ' * 4)}")
        logger.info(f"Transit Window Counts: \n {textwrap.indent(tce_preds_tbl['tw_flag'].value_counts().to_string(), ' ' * 4)}")
        logger.info(f"Stats by Confusion Label: ")
        for cf in tce_preds_tbl['confusion'].unique():
            logger.info(f"{' ' * 4}{cf} Stats:")
            desc = tce_preds_tbl[tce_preds_tbl['confusion'] == cf]['pred_prob'].describe()[1:]
            logger.info(f"{textwrap.indent(desc.to_string(), ' ' * 8)}")
        

_summarize_target_examples('4164713')

In [None]:
temp_df = preds_tbl[preds_tbl["target_id"] == 189476500]
temp_df[temp_df['confusion'] == 'FP'][['uid', 'confusion', 'pred_prob']]

In [None]:
def _logger.info_tce_preds(tce_uid):
    # interest_cols = ['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']
    preds_tbl_filt = preds_tbl[preds_tbl['tce_uid'] == tce_uid].copy()
    for _, ex in preds_tbl_filt.iterrows():
        if float(ex['time']) > 2899 and float(ex['time']) < 2900:
            logger.info(ex[['uid', 'confusion', 'disposition', 'label', 'pred_prob']])

_logger.info_tce_preds('388431711-4-S14-60')
_logger.info_tce_preds('388431711-1-S14-60')


In [None]:
# Comparing Examples

tce_tbl[(tce_tbl['tce_uid'] == '148158540-3-S11') | (tce_tbl['tce_uid'] == '148158540-1-S11')][['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd']]

In [None]:
tce_tbl[(tce_tbl['tce_uid'] == '26489741-1-S40')][['tce_uid', 'disposition', 'tce_time0bk','tce_period', 'tce_duration', 'tce_maxmes', 'tce_maxmesd', 'tec_fluxtriage_comment']]


In [None]:
def mean_maxmes_by_target(df):
    return df.groupby('target_id')['maxmes'].mean().dropna()


In [None]:
eb = tce_tbl[tce_tbl['disposition'] == 'EB']
cp = tce_tbl[tce_tbl['disposition'] == 'CP']
kp = tce_tbl[tce_tbl['disposition'] == 'KP']
combined = tce_tbl[tce_tbl['disposition'].isin(['EB', 'CP', 'KP'])]

for name, df in [('EB', eb), ('CP', cp), ('KP', kp), ('ALL', combined)]:
    plt.figure()
    data = df['tce_maxmes'].dropna()
    plt.hist(data, bins=100)
    plt.title(f'Distribution of maxmes for {name} TCEs')
    plt.xlabel('maxmes')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.xlim((0,data.quantile(0.99) ))
    plt.show()
    

# OLD

In [None]:

tce_uid = "188768068-1-S14-26"  # "198408416-1-S14-60"  # '425064757-1-S1-65'   # '198408416-1-S14-60'

examples_tce = preds_tbl.loc[preds_tbl["tce_uid"] == tce_uid]
disp_tce = examples_tce["disposition"].values[0]
mes_tce = examples_tce["tce_max_mult_ev"].values[0]

transit_window_examples = examples_tce.loc[examples_tce["label"] == 1]
not_transit_window_examples = examples_tce.loc[examples_tce["label"] == 0]

f, ax = plt.subplots(figsize=(10, 5))
ax.scatter(
    transit_window_examples["time"],
    transit_window_examples["pred_prob"],
    s=8,
    alpha=0.3,
    edgecolors="k",
    label="Transit Window Examples",
)
ax.scatter(
    not_transit_window_examples["time"],
    not_transit_window_examples["pred_prob"],
    s=8,
    alpha=0.3,
    edgecolors="k",
    label="Not-Transit Window Examples",
)
ax.set_ylabel("Model Score")
ax.set_xlabel("Timestamp [BTJD]")
ax.set_ylim(bins_scores[[0, -1]])
ax.legend()
ax.set_title(
    f"TCE {tce_uid}\nDisposition {disp_tce}\nNumber of examples"
    f" {len(examples_tce)} | TCE Max MES {mes_tce:.3f}"
)
f.tight_layout()
f.savefig(
    plot_dir / f"scatter_transit_nottransit_examples_scores_{tce_uid}_{disp_tce}.png"
)
plt.show()


bins_scores = np.linspace(0, 1, 11)

f, ax = plt.subplots()
ax.hist(
    transit_window_examples["pred_prob"],
    bins_scores,
    histtype="step",
    label="Transit Window Examples",
)
ax.hist(
    not_transit_window_examples["pred_prob"],
    bins_scores,
    histtype="step",
    label="Not-Transit Window Examples",
)
ax.set_xlabel("Model Score")
ax.set_ylabel("Example Count")
ax.set_xlim(bins_scores[[0, -1]])
ax.legend()
ax.set_title(
    f"TCE {tce_uid}\nDisposition {disp_tce}\nNumber of examples"
    f" {len(examples_tce)} | TCE Max MES {mes_tce:.3f}"
)
f.tight_layout()
f.savefig(
    plot_dir / f"hist_transit_nottransit_examples_scores_{tce_uid}_{disp_tce}.png"
)
plt.show()


In [None]:

tce_uid =  "188768068-1-S14-26" #"352954787-1-S14-26"  # "198408416-1-S14-60"  # '425064757-1-S1-65'   # '198408416-1-S14-60'
lc_dir = Path("/Users/jochoa4/Downloads/")
sector_arr = list(range(14, 27))


tce = tce_tbl.loc[tce_tbl["tce_uid"] == tce_uid]

# find light curve data for target
search_lc_res = lk.search_lightcurve(
    target=f"tic{tce['target_id'].values[0]}",
    mission="TESS",
    author=("TESS-SPOC", "SPOC"),
    exptime=120,
    cadence="long",
    sector=sector_arr,
)

lcf = search_lc_res.download_all(
    download_dir=str(lc_dir), quality_bitmask="default", flux_column="pdcsap_flux"
)


def lcf_masked_quantity_corrector(lcf: lk.LightCurve) -> lk.LightCurve:
    lcf = lk.LightCurve({"time": lcf.time.value, "flux": np.array(lcf.flux.value)})
    return lcf.normalize()


lcf = lcf.stitch(corrector_func=lcf_masked_quantity_corrector)


In [None]:
lcf

In [None]:

t0, win_label = 1966.07, 0  # 2425.47, 0  # 2424.57, 1

dur_f = 5
win_len = tce["tce_duration"].values[0] / 24 * dur_f

t_start, t_end = t0 - win_len / 2, t0 + win_len / 2

f, ax = plt.subplots()
lcf.plot(ax=ax)
ax.set_xlim([t_start, t_end])
ax.set_title(f"TCE {tce_uid} | Disposition {disp_tce}\nt0={t0} | Label {win_label}")
f.savefig(plot_dir / f"plot_{tce_uid}_{disp_tce}_timestamp{t0}_label{win_label}.png")


In [None]:
tfrec_fp = Path("/Users/jochoa4/Desktop/study_transfers/study_model_preds_05-22-2025/tfrecords/norm_train_shard_2990-8611.tfrecord")


t0, win_label = 1966.07, 0  # 2425.47, 0  # 2424.57, 1

dur_f = 5
win_len = tce["tce_duration"].values[0] / 24 * dur_f

t_start, t_end = t0 - win_len / 2, t0 + win_len / 2

f, ax = plt.subplots()
lcf.plot(ax=ax)
ax.set_xlim([t_start, t_end])
ax.set_title(f"TCE {tce_uid} | Disposition {disp_tce}\nt0={t0} | Label {win_label}")
f.savefig(plot_dir / f"plot_{tce_uid}_{disp_tce}_timestamp{t0}_label{win_label}.png")


In [None]:

tce_uid = "161687311-2-S24"

examples_tce = preds_tbl.loc[preds_tbl["tce_uid"] == tce_uid]
disp_tce = examples_tce["disposition"].values[0]
mes_tce = examples_tce["tce_max_mult_ev"].values[0]

transit_window_examples = examples_tce.loc[examples_tce["label"] == 1]
not_transit_window_examples = examples_tce.loc[examples_tce["label"] == 0]

f, ax = plt.subplots(figsize=(10, 5))
ax.scatter(
    transit_window_examples["time"],
    transit_window_examples["raw_pred"],
    s=8,
    alpha=0.3,
    edgecolors="k",
    label="Transit Window Examples",
)
ax.scatter(
    not_transit_window_examples["time"],
    not_transit_window_examples["raw_pred"],
    s=8,
    alpha=0.3,
    edgecolors="k",
    label="Not-Transit Window Examples",
)
ax.set_ylabel("Model Score")
ax.set_xlabel("Timestamp [BTJD]")
ax.set_ylim(bins_scores[[0, -1]])
ax.legend()
ax.set_title(
    f"TCE {tce_uid}\nDisposition {disp_tce}\nNumber of examples"
    f" {len(examples_tce)} | TCE Max MES {mes_tce:.3f}"
)
f.tight_layout()
f.savefig(
    plot_dir / f"scatter_transit_nottransit_examples_scores_{tce_uid}_{disp_tce}.png"
)
plt.show()


bins_scores = np.linspace(0, 1, 11)

f, ax = plt.subplots()
ax.hist(
    transit_window_examples["raw_pred"],
    bins_scores,
    histtype="step",
    label="Transit Window Examples",
)
ax.hist(
    not_transit_window_examples["raw_pred"],
    bins_scores,
    histtype="step",
    label="Not-Transit Window Examples",
)
ax.set_xlabel("Model Score")
ax.set_ylabel("Example Count")
ax.set_xlim(bins_scores[[0, -1]])
ax.legend()
ax.set_title(
    f"TCE {tce_uid}\nDisposition {disp_tce}\nNumber of examples"
    f" {len(examples_tce)} | TCE Max MES {mes_tce:.3f}"
)
f.tight_layout()
f.savefig(
    plot_dir / f"hist_transit_nottransit_examples_scores_{tce_uid}_{disp_tce}.png"
)
plt.show()


In [None]:

tce_uid =  "161687311-2-S24" #"352954787-1-S14-26"  # "198408416-1-S14-60"  # '425064757-1-S1-65'   # '198408416-1-S14-60'
lc_dir = Path("/Users/jochoa4/Downloads/")
sector_arr = list(range(24, 25))


tce = tce_tbl.loc[tce_tbl["tce_uid"] == tce_uid]

# find light curve data for target
search_lc_res = lk.search_lightcurve(
    target=f"tic{tce['target_id'].values[0]}",
    mission="TESS",
    author=("TESS-SPOC", "SPOC"),
    exptime=120,
    cadence="long",
    sector=sector_arr,
)

lcf = search_lc_res.download_all(
    download_dir=str(lc_dir), quality_bitmask="default", flux_column="pdcsap_flux"
)


def lcf_masked_quantity_corrector(lcf: lk.LightCurve) -> lk.LightCurve:
    lcf = lk.LightCurve({"time": lcf.time.value, "flux": np.array(lcf.flux.value)})
    return lcf.normalize()


lcf = lcf.stitch(corrector_func=lcf_masked_quantity_corrector)


In [None]:

t0, win_label = 1970.34, 1 #1696.34, 0  # 2425.47, 0  # 2424.57, 1

dur_f = 5
win_len = tce["tce_duration"].values[0] / 24 * dur_f

t_start, t_end = t0 - win_len / 2, t0 + win_len / 2

f, ax = plt.subplots()
lcf.plot(ax=ax)
ax.set_xlim([t_start, t_end])
ax.set_title(f"TCE {tce_uid} | Disposition {disp_tce}\nt0={t0} | Label {win_label}")
f.savefig(plot_dir / f"plot_{tce_uid}_{disp_tce}_timestamp{t0}_label{win_label}.png")
