### Run upon export from spreadsheet

In [None]:
import os

from astroquery.mast import Catalogs
import numpy as np
import pandas as pd


# Input CSVs: one with the TCEs, one with the per-TCE labels.
tces_file = '/mnt/tess/labels/tces-triage-v12.csv'
labels_file = '/mnt/tess/labels/labels-triage-v12.csv'

# Load the TCE sheet keyed by TIC ID and drop rows flagged with Exclude == 'Y'.
tce_table = pd.read_csv(tces_file, header=0, low_memory=False)
tce_table = tce_table.set_index('tic_id')
tce_table = tce_table.loc[tce_table.Exclude != 'Y']

# Keep only the columns carried through to the output tables; 'tic_id' is
# already the index, so it is preserved without being listed here.
joined_table = tce_table[[
  'Tmag', 'Epoc', 'Period', 'Duration',
  'Transit_Depth', 'star_rad', 'star_mass',
  'filename', 'Split',
]]


# Load the label sheet and mirror its 'TIC ID' column under the 'tic_id' name
# used by the TCE table.
labels_table = pd.read_csv(labels_file, header=0, low_memory=False)
labels_table['tic_id'] = labels_table['TIC ID']

# Disposition codes (each gets a 'disp_<code>' counter column) and the per-user
# columns that set_labels reads when tallying votes.
disps = ['E', 'J', 'N', 'S', 'B']
users = ['av', 'md', 'ch', 'as', 'mk', 'et', 'dm', 'td']

# Start every disposition counter at zero.
labels_table = labels_table.assign(**{f'disp_{d}': 0 for d in disps})

def set_labels(row):
  """Fill the ``disp_*`` counters of one labels row from its dispositions.

  A non-missing ``Final`` value other than ``'U'`` takes precedence: its
  ``disp_<Final>`` entry is set to 1 and the per-user columns are ignored.
  Otherwise, every column named in the module-level ``users`` list that holds
  a non-missing, non-empty value other than ``'U'`` adds 1 to the matching
  ``disp_<value>`` entry. A row that ends up with no label at all is flagged
  with ``Exclude = 'Y'``.

  Args:
    row: one row of the labels table (a pandas Series); modified in place.

  Returns:
    The same row, so the function can be passed to ``DataFrame.apply(axis=1)``.
  """
  present = row.notna()
  final = row['Final']

  # An explicit final disposition overrides the individual votes.
  if present['Final'] and final != 'U':
    row[f'disp_{final}'] = 1
    return row

  n_votes = 0
  for name in users:
    vote = row[name]
    # Skip missing cells, empty values and explicit 'U' entries.
    if not present[name] or not vote or vote == 'U':
      continue
    row[f'disp_{vote}'] += 1
    n_votes += 1

  if n_votes == 0:
    row['Exclude'] = 'Y'
  return row
# Tally the dispositions row by row.
labels_table = labels_table.apply(set_labels, axis=1)

# Drop rows flagged as unlabeled and keep only the key plus the counters.
# NOTE(review): this needs an 'Exclude' column to exist after the apply
# (read from the CSV, or set by set_labels on at least one row) — confirm.
disp_columns = [f'disp_{d}' for d in disps]
labels_table = labels_table.loc[
    labels_table.Exclude != 'Y', ['tic_id'] + disp_columns
]
labels_table = labels_table.set_index('tic_id')

# Attach the counters to the TCEs; the inner join keeps only TIC IDs that
# appear in both tables.
joined_table = joined_table.join(labels_table, on='tic_id', how='inner')
print(f'Total entries: {len(joined_table)}')

# Keep rows with at least one count in the known disposition columns.
label_totals = sum(joined_table[column] for column in disp_columns)
joined_table = joined_table.loc[label_totals > 0]
print(f'Total labeled entries: {len(joined_table)}')


# Partition by the 'Split' column; the column itself is dropped from every
# exported table.
t_train, t_val, t_test = (
    joined_table[joined_table['Split'] == split].drop(columns=['Split'])
    for split in ('train', 'val', 'test')
)
all_table = joined_table.drop(columns=['Split'])


print(f'Split sizes. Train: {len(t_train)}; Valid: {len(t_val)}; Test: {len(t_test)}')
print(f'Duplicate TICs: {len(all_table.index.values) - len(set(all_table.index.values))}')

# Write each table to its own CSV, in the same order as before.
outputs = {
  '/mnt/tess/astronet/tces-v12-train.csv': t_train,
  '/mnt/tess/astronet/tces-v12-val.csv': t_val,
  '/mnt/tess/astronet/tces-v12-test.csv': t_test,
  '/mnt/tess/astronet/tces-v12-all.csv': all_table,
}
for path, table in outputs.items():
  table.to_csv(path)

In [None]:
# Lift the column display limit, then preview five random training rows.
pd.options.display.max_columns = None
t_train.sample(n=5)

In [None]:
# Preview five random rows of the 'val' split.
t_val.sample(n=5)

In [None]:
# Preview five random rows of the 'test' split.
t_test.sample(n=5)