In [2]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


## Read the vetting label table, indexed by 'Astro ID'
fname = 'vetting-v02'
fpath = f'../mnt/tess/labels/{fname}.csv'
all_table = pd.read_csv(fpath, header=0, low_memory=False).set_index('Astro ID')

## Drop unused columns and rename the rest to the short names used below
# DataFrame.drop returns a new frame rather than modifying in place, so the
# result must be reassigned; otherwise 'Split' is silently carried into every
# CSV written at the end of this cell.
all_table = all_table.drop(columns=['Split'])
all_table = all_table.rename(columns={'filename': 'File',
                                      'Period': 'Per',
                                      'Duration': 'Dur',
                                      'Transit_Depth': 'Depth',
                                      'star_rad': 'SRad',
                                      'star_rad_est': 'SRadEst',
                                      'star_mass': 'SMass'})

## Make label columns: one integer counter per disposition letter.
## They are filled in by set_labels, either from 'Final' or from the
## per-user label columns listed in `users`.
disps = ['e', 'p', 'n', 'b', 't', 'u', 'j']
users = ['mk', 'ch', 'et', 'md', 'as', 'dm', 'Tansu', 'Shishir']
for d in disps:
    all_table[f'disp_{d}'] = 0
## Set labels
def set_labels(row):
    """Fill the ``disp_<letter>`` columns of one table row in place.

    If ``row['Final']`` is ``'i'`` the row is returned untouched.

    If ``row['Final']`` is present, both of its first two letters are flagged
    with 1. No counts are accumulated, so ``'jj'`` gives ``disp_j == 1``.

    Otherwise, every non-missing, non-empty label in the ``users`` columns adds
    one vote for each of its first two letters.

    Args:
        row: one row of ``all_table`` (as passed by ``DataFrame.apply(axis=1)``).

    Returns:
        The same ``row`` object, mutated.
    """
    present = row.notna()
    final = row['Final']

    # "Inside the star" objects keep all-zero disposition columns.
    if final == 'i':
        return row

    if present['Final']:
        # A final label takes precedence over individual votes.
        row[f'disp_{final[0]}'] = 1
        row[f'disp_{final[1]}'] = 1
        return row

    # No final label: tally each vetter's vote letter by letter.
    for user in users:
        vote = row[user]
        if not present[user] or not vote:
            continue
        row[f'disp_{vote[0]}'] += 1
        row[f'disp_{vote[1]}'] += 1
    return row

all_table = all_table.apply(set_labels, axis=1)

## Keep only rows that received at least one disposition vote
print(f'Total entries: {len(all_table)}')
vote_total = sum(all_table[f'disp_{d}'] for d in disps)
all_table = all_table[vote_total > 0]
print(f'Total labeled entries: {len(all_table)}')

## Keep only rows where every one of 'File', 'Per', 'Dur', 'Depth', 'Tmag' is present
required_cols = ['File', 'Per', 'Dur', 'Depth', 'Tmag']
all_table = all_table[all_table[required_cols].notna().all(axis=1)]
print(f'Total after removing rows with missing File, Per, Dur, Depth, or Tmag: {len(all_table)}')


## Train / validation / test split: 80% / 10% / 10%
# The first split holds out 10% of the table for test; taking 1/9 of the
# remaining 90% for validation yields another 10% of the full table.
t_train, t_test = train_test_split(all_table, test_size=0.1, random_state=42)
t_train, t_val = train_test_split(t_train, test_size=1./9, random_state=42)

## Report split sizes and duplicate stars
print(f'Split sizes. Train: {len(t_train)}; Valid: {len(t_val)}; Test: {len(t_test)}')
# Count repeated values in the 'TIC ID' column. The index is 'Astro ID', which
# is unique per row, so checking the index would always report 0. Rows sharing
# a TIC ID can end up in different splits.
print(f'Duplicate TICs: {all_table["TIC ID"].duplicated().sum()}')

## Sanity checks: every row in every split carries at least one label vote,
## and the three splits together cover the labeled table exactly.
disp_cols = [f'disp_{d}' for d in disps]
for split_name, frame in [('train', t_train), ('val', t_val), ('test', t_test)]:
    n_votes = sum(frame[col] for col in disp_cols)
    assert not (n_votes == 0).any(), f'{split_name} split contains unlabeled rows'
assert len(t_train) + len(t_val) + len(t_test) == len(all_table), 'splits do not cover the table'

## Save train, validation, test, and full-table CSV files
out_dir = '../mnt/tess/astronet'
for split_name, frame in [('train', t_train), ('val', t_val), ('test', t_test), ('all', all_table)]:
    frame.to_csv(f'{out_dir}/tces-{fname}-{split_name}.csv')




Total entries: 9344
Total labeled entries: 4071
Total after removing rows with missing File, Per, Dur, Depth, or Tmag: 4016
Split sizes. Train: 3212; Valid: 402; Test: 402
Duplicate TICs: 0
Splits
  train: 3212
  val: 402
  test: 402


In [6]:
# Inspect the labeled, filtered table that was written to tces-*-all.csv
all_table

Unnamed: 0_level_0,TIC ID,Final,Decision,Distinct,mk,ch,et,md,as,dm,...,SRadEst,File,comment,disp_e,disp_p,disp_n,disp_b,disp_t,disp_u,disp_j
Astro ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,101179364,et,,1,et,et,et,et,et,,...,1.098840,mk_hlsp_qlp_tess_ffi-s0013-0000000101179364_te...,E labels from triage,1,0,0,0,1,0,0
2,101255974,eb,eb,2,eb,eb,eb,eb,et,,...,4.870666,mk_hlsp_qlp_tess_ffi-s0018-0000000101255974_te...,E labels from triage,1,0,0,1,0,0,0
3,101404344,eb,eb,2,eb,eb,eb,eb,pb,,...,1.728549,mk_hlsp_qlp_tess_ffi-s0027-0000000101404344_te...,E labels from triage,1,0,0,1,0,0,0
4,101427335,et,et,2,et,et,et,eb,et,,...,1.679462,mk_hlsp_qlp_tess_ffi-s0027-0000000101427335_te...,E labels from triage,1,0,0,0,1,0,0
5,10150705,eb,eb,2,eb,et,eb,eb,eb,,...,0.794777,mk_hlsp_qlp_tess_ffi-s0022-0000000010150705_te...,E labels from triage,1,0,0,1,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4088,22221375,pt,,1,pt,,,,,,...,,astronet_hlsp_qlp_tess_ffi-s0064-0000000022221...,TOI-652.01 (CP) detected by UNKNOWN pipeline,0,1,0,0,1,0,0
4089,22233480,pt,,1,pt,,,,,,...,,astronet_hlsp_qlp_tess_ffi-s0064-0000000022233...,TOI-4438.01 (PC) detected by QLP pipeline,0,1,0,0,1,0,0
4090,22384839,pt,,1,pt,,,,,,...,,astronet_hlsp_qlp_tess_ffi-s0064-0000000022384...,TOI-3120.01 (PC) detected by QLP pipeline,0,1,0,0,1,0,0
4091,22529346,pt,,1,pt,,,,,,...,,astronet_hlsp_qlp_tess_ffi-s0064-0000000022529...,TOI-495.01 (KP) detected by QLP pipeline,0,1,0,0,1,0,0


In [5]:
# Inspect the training split; row order comes from train_test_split's shuffle
t_train

Unnamed: 0_level_0,TIC ID,Final,Decision,Distinct,mk,ch,et,md,as,dm,...,SRadEst,File,comment,disp_e,disp_p,disp_n,disp_b,disp_t,disp_u,disp_j
Astro ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3358,393233743,et,,1,et,et,,,,,...,2.450910,mk_hlsp_qlp_tess_ffi-s0020-0000000393233743_te...,EBs from TOI group vetting,1,0,0,0,1,0,0
1784,43770392,et,,1,et,,et,,et,et,...,3.295305,mk_hlsp_qlp_tess_ffi-s0014-0000000043770392_te...,EBs from TOI group vetting,1,0,0,0,1,0,0
524,167651110,et,,1,et,et,et,et,et,,...,2.050283,mk_hlsp_qlp_tess_ffi-s0034-0000000167651110_te...,E labels from triage,1,0,0,0,1,0,0
3251,377763672,jj,jj,2,jj,eb,,,,,...,2.775174,mk_hlsp_qlp_tess_ffi-s0011-0000000377763672_te...,EBs from TOI group vetting,0,0,0,0,0,0,1
126,136040527,et,,1,et,et,et,et,et,,...,3.087170,mk_hlsp_qlp_tess_ffi-s0015-0000000136040527_te...,E labels from triage,1,0,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3918,7020254,pt,,1,pt,pt,,,,,...,,astronet_hlsp_qlp_tess_ffi-s0064-0000000007020...,TOI-4316.01 (KP) detected by QLP pipeline,0,1,0,0,1,0,0
3488,425538187,eb,,1,eb,eb,,,,,...,9.192106,mk_hlsp_qlp_tess_ffi-s0011-0000000425538187_te...,EBs from TOI group vetting,1,0,0,1,0,0,0
1069,314865962,pt,pt,2,pt,pu,pt,pt,pt,,...,0.970018,mk_hlsp_qlp_tess_ffi-s0033-0000000314865962_te...,E labels from triage,0,1,0,0,1,0,0
2333,217783951,eb,,1,eb,eb,eb,,,eb,...,7.127803,mk_hlsp_qlp_tess_ffi-s0012-0000000217783951_te...,EBs from TOI group vetting,1,0,0,1,0,0,0
