Notes:
 * SOTA (run 14) on the sample TICs, cam2

In [7]:
import os

# Root directory that load_ensemble scans; it looks for sub-directories named "1", "2", ...
chkpt_root = '/mnt/tess/astronet/checkpoints/extended_23_run_14'
# Glob handed to the predictor as its --data_files argument.
data_files = '/mnt/tess/astronet/tfrecords-s33-cam2-ccd14-sample/*'
# NOTE(review): not referenced by any cell visible in this notebook - confirm it is still needed.
tces_file = '/mnt/tess/astronet/tces-s33_cam2ccd14_sample.csv'

# Upper bound on how many numbered run directories load_ensemble will look for.
nruns = 10

def load_ensemble(chkpt_root, nruns):
    """Collect the checkpoint directory of each ensemble member.

    Looks for ``<chkpt_root>/1``, ``<chkpt_root>/2``, ... up to ``nruns`` and
    stops at the first run directory that is missing or empty, so the result
    may hold fewer than ``nruns`` entries.

    Args:
        chkpt_root: Directory containing the numbered run sub-directories.
        nruns: Maximum number of runs to look for.

    Returns:
        List of paths, one per run found, each pointing at the single entry
        inside that run's directory.

    Raises:
        ValueError: If a run directory contains more than one entry.
    """
    found = []
    for run_number in range(1, nruns + 1):
        run_dir = os.path.join(chkpt_root, str(run_number))
        entries = os.listdir(run_dir) if os.path.exists(run_dir) else []
        if len(entries) == 0:
            # Runs are expected to be contiguous; the first gap ends the scan.
            break
        # Exactly one entry is expected; unpacking fails loudly otherwise.
        (only_entry,) = entries
        found.append(os.path.join(run_dir, only_entry))
    return found

# Resolve the per-run checkpoint directories and show them as the cell output.
paths = load_ensemble(chkpt_root, nruns)
paths

['/mnt/tess/astronet/checkpoints/extended_23_run_14/1/AstroCNNModel_extended_20210131_212427',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/2/AstroCNNModel_extended_20210131_215717',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/3/AstroCNNModel_extended_20210131_222957',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/4/AstroCNNModel_extended_20210131_230255',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/5/AstroCNNModel_extended_20210131_233546',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/6/AstroCNNModel_extended_20210201_000826',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/7/AstroCNNModel_extended_20210201_004107',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/8/AstroCNNModel_extended_20210201_011350',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/9/AstroCNNModel_extended_20210201_014655',
 '/mnt/tess/astronet/checkpoints/extended_23_run_14/10/AstroCNNModel_extended_20210201_021945']

In [8]:
import getpass
import os
from astronet import predict
import tensorflow as tf


def run_predictions(path):
    """Run the astronet predictor for a single checkpoint directory.

    Replaces the module-level ``predict.FLAGS`` with arguments parsed from a
    synthetic command line, then invokes ``predict.predict()``.

    Args:
        path: Checkpoint directory, passed as ``--model_dir``.

    Returns:
        Whatever ``predict.predict()`` returns; the caller in this notebook
        unpacks it as ``(preds, config)``.
    """
    cli_args = [
        '--model_dir', path,
        '--data_files', data_files,
        # Empty output path, as in the original call.
        # NOTE(review): presumably disables writing a file - confirm in predict.
        '--output_file', '',
    ]
    predict.FLAGS = predict.parser.parse_args(cli_args)
    return predict.predict()


# Same call as in the previous cell; rebinds `paths`.
paths = load_ensemble(chkpt_root, nruns)
# One predictions table per ensemble member, each re-indexed by tic_id.
ensemble_preds = []
# Rebound on every iteration; after the loop it holds the config returned for the last model.
config = None
for i, path in enumerate(paths):
    print(f'Running model {i + 1}')
    preds, config = run_predictions(path)
    ensemble_preds.append(preds.set_index('tic_id'))
    print()

Running model 1
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 2
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 3
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 4
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 5
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 6
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 7
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 8
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 9
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records
Running model 10
Binary prediction threshold: 0.2152499407880693 (orientative)
577 records


In [20]:
# Class labels. NOTE(review): cell In [21] uses `col_e` to index raw prediction
# rows, so this order is assumed to match the prediction columns - confirm.
labels = ['disp_E', 'disp_N', 'disp_J', 'disp_S', 'disp_B']

# Position of the 'disp_E' class within `labels`.
col_e = labels.index('disp_E')
# Alternative thresholds that were tried; only the last assignment is active.
# thresh = config.hparams.prediction_threshold
# thresh = 0.030485098838860747  # From the validation numbers - maximum threshold for 100% recall
thresh = 0.31245827674871207  # Relaxed to match Liang's precision value

In [21]:
import numpy as np
import pandas as pd

# Votes per TIC id: each ensemble member contributes one label per index entry.
agg_preds = {}

for preds in ensemble_preds:
    for tic_id in preds.index:
        votes = agg_preds.setdefault(tic_id, [])

        # All rows of this model for the TIC; only the first one is used.
        row = preds[preds.index == tic_id]
        pred_v = row.values[0]
        if len(row.values) > 1:
            print(f'Warning: duplicate predictions for {tic_id}')

        if pred_v[col_e] >= thresh:
            # The E score clears the threshold: vote E regardless of other scores.
            vote = 'disp_E'
        else:
            # Otherwise zero out the E score and vote for the best remaining class.
            masked_v = [0 if i == col_e else v for i, v in enumerate(pred_v)]
            vote = preds.columns[np.argmax(masked_v)]
        votes.append(vote)

In [22]:
# Turn each TIC's list of votes into one record of per-label counts.
vote_rows = []
for tic_id, votes in agg_preds.items():
    counts = dict.fromkeys(labels, 0)
    for e in votes:
        counts[e] += 1
    # Largest count is taken before the non-label keys are added.
    maxcount = max(counts.values())
    counts['tic_id'] = tic_id
    counts['maxcount'] = maxcount
    vote_rows.append(counts)

final_preds = pd.DataFrame(vote_rows).set_index('tic_id')

In [23]:
# Preview the per-TIC vote counts and the size of the winning vote.
final_preds.head()

Unnamed: 0_level_0,disp_E,disp_N,disp_J,disp_S,disp_B,maxcount
tic_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
123724341,10,0,0,0,0,10
407685853,0,0,10,0,0,10
124228912,0,0,0,0,10,10
59953460,0,0,0,0,10,10
93282952,0,0,10,0,0,10


In [24]:
def compare(ensemble_preds, filter):
    """Stack the rows selected by ``filter`` from every ensemble member.

    ``DataFrame.append`` was removed in pandas 2.0, so the selections are
    combined with a single ``pd.concat`` call instead.

    Args:
        ensemble_preds: Non-empty sequence of prediction DataFrames.
        filter: Row selector applied to each frame as ``frame[filter]``, e.g.
            a boolean mask. The same selector is used for every frame. (The
            name shadows the ``filter`` builtin; kept for compatibility.)

    Returns:
        DataFrame with the selected rows of each member, in ensemble order,
        with the original index values preserved.

    Raises:
        IndexError: If ``ensemble_preds`` is empty.
    """
    selected = [preds[filter] for preds in ensemble_preds]
    if not selected:
        # Keep the exception type the original raised on ``ensemble_preds[0]``.
        raise IndexError('ensemble_preds is empty')
    return pd.concat(selected)

# `preds` is not defined in this cell: it is whatever an earlier loop left bound
# (the `for preds in ensemble_preds` loop in In [21] leaves it on the last member).
# The resulting boolean mask is applied unchanged to every ensemble member.
compare(ensemble_preds, preds.index == 263337671)

Unnamed: 0_level_0,disp_E,disp_N,disp_J,disp_S,disp_B
tic_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1


In [25]:
def final_pred(row):
    """Reduce one joined row (vote counts + label columns) to a single letter.

    Returns ``'?'`` when ``row['Distinct'] > 1``, ``row['Decision']`` is not a
    string, and at least one of the ``av``/``md``/``ch``/``as``/``mk``/``et``
    columns is ``'E'`` or ``'S'``.
    NOTE(review): those six columns look like per-reviewer labels and a
    non-string ``Decision`` like a missing value - confirm against the CSV.

    Otherwise returns ``'E'`` if any ensemble member voted ``disp_E``; failing
    that, the letter of the label with the most votes (the first one wins a
    tie, in the order N, J, S, B).

    Args:
        row: Mapping-like row supporting ``row[column]`` lookups.

    Returns:
        One of ``'?'``, ``'E'``, ``'N'``, ``'J'``, ``'S'``, ``'B'``.
    """
    reviewer_cols = ('av', 'md', 'ch', 'as', 'mk', 'et')
    if (
        row['Distinct'] > 1
        and not isinstance(row['Decision'], str)
        and any(row[col] in ('E', 'S') for col in reviewer_cols)
    ):
        return '?'

    if row['disp_E'] > 0:
        return 'E'

    # 'disp_E' goes first so it is kept unless another label has strictly more
    # votes; max() returns the first maximal element, matching a strict '>' scan.
    candidates = ['disp_E', 'disp_N', 'disp_J', 'disp_S', 'disp_B']
    best = max(candidates, key=lambda col: row[col])
    # Character 5 is the class letter following the 'disp_' prefix.
    return best[5]

# Join the vote counts with the labels spreadsheet (indexed by its 'TIC ID'
# column), derive the `final` letter per row, and keep only that column.
# Note: this rebinds `agg_preds` from the vote dict built in In [21] to a DataFrame.
agg_preds = final_preds.join(
    pd.read_csv(
        '~/Astronet-Triage/Labels - extended mission test.csv',
        header=0,
        low_memory=False,
    ).set_index('TIC ID')
)
agg_preds['final'] = agg_preds.apply(final_pred, axis=1)
agg_preds = agg_preds[['final']]

In [26]:
# Show the rows that final_pred marked with '?'.
agg_preds[agg_preds['final'] == '?']

Unnamed: 0_level_0,final
tic_id,Unnamed: 1_level_1


In [27]:
# Write the single-column `final` table (indexed by tic_id) to the home-directory checkout.
agg_preds.to_csv('~/Astronet-Triage/tces-s33_cam2ccd14_sample.csv')