Notes:
 * Run the 10-model ensemble from run 17 on the TOI targets, comparing BLS with TEV.

In [1]:
import os

# Root of the run-17 ensemble checkpoints: one numbered sub-directory (1, 2, ...) per model.
chkpt_root = '/mnt/tess/astronet/checkpoints/extended_25_run_17'
# Glob of the input records scored by every ensemble member (passed to predict as --data_files).
data_files = '/mnt/tess/astronet/tfrecords-toi-bls-vs-tev/*'
# Presumably the TCE table for the same targets; not read by any of the cells shown here.
tces_file = '/mnt/tess/astronet/tces-toi-bls-vs-tev.csv'

# Maximum number of ensemble members to look for under chkpt_root.
nruns = 10

def load_ensemble(chkpt_root, nruns):
    """Collect the checkpoint directory of each ensemble member.

    Run ``i`` (1-based) is expected at ``<chkpt_root>/<i>/``, containing exactly
    one entry: that model's checkpoint directory. Scanning stops at the first
    run directory that is missing or empty, so fewer than ``nruns`` paths may be
    returned.

    Args:
        chkpt_root: directory holding one numbered sub-directory per run.
        nruns: maximum number of runs to look for.

    Returns:
        List of checkpoint paths, in run order.

    Raises:
        ValueError: if a run directory holds more than one entry, which makes
            the checkpoint to use ambiguous.
    """
    checkpts = []
    for run in range(1, nruns + 1):
        parent = os.path.join(chkpt_root, str(run))
        if not os.path.exists(parent):
            break
        entries = os.listdir(parent)
        if not entries:
            break
        if len(entries) > 1:
            # Previously an opaque "too many values to unpack" error.
            raise ValueError(
                f'Expected exactly one checkpoint in {parent}, '
                f'found {len(entries)}: {sorted(entries)}')
        checkpts.append(os.path.join(parent, entries[0]))
    return checkpts

# Resolve one checkpoint directory per training run and show them.
paths = load_ensemble(chkpt_root=chkpt_root, nruns=nruns)
paths

['/mnt/tess/astronet/checkpoints/extended_25_run_17/1/AstroCNNModel_extended_20210321_154527',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/2/AstroCNNModel_extended_20210321_161606',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/3/AstroCNNModel_extended_20210321_164646',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/4/AstroCNNModel_extended_20210321_171723',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/5/AstroCNNModel_extended_20210321_174729',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/6/AstroCNNModel_extended_20210321_181710',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/7/AstroCNNModel_extended_20210321_184652',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/8/AstroCNNModel_extended_20210321_191655',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/9/AstroCNNModel_extended_20210321_194703',
 '/mnt/tess/astronet/checkpoints/extended_25_run_17/10/AstroCNNModel_extended_20210321_201726']

In [2]:
import getpass
import os
from astronet import predict
import tensorflow as tf


def run_predictions(path, data_glob=None):
    """Score the input records with the checkpoint at ``path`` via ``astronet.predict``.

    The predict module is driven by its ``FLAGS`` attribute, so the flags are
    rebuilt here for every model before ``predict.predict()`` is called.

    Args:
        path: checkpoint directory of one ensemble member (passed as --model_dir).
        data_glob: file pattern of the input records (passed as --data_files).
            When omitted, the notebook-level ``data_files`` setting is used,
            looked up at call time exactly as before.

    Returns:
        Whatever ``predict.predict()`` returns; the ensemble loop in this
        notebook unpacks it as ``(preds, config)``.
    """
    if data_glob is None:
        data_glob = data_files
    predict.FLAGS = predict.parser.parse_args([
      '--model_dir', path,
      '--data_files', data_glob,
      # Empty output path; presumably disables writing a predictions file — confirm in predict.
      '--output_file', '',
    ])

    return predict.predict()


# Score the records with every ensemble member, collecting one prediction frame
# per model; `config` ends up holding the last model's config.
paths = load_ensemble(chkpt_root, nruns)
ensemble_preds = []
config = None
for run_idx, ckpt_path in enumerate(paths, start=1):
    print(f'Running model {run_idx}')
    preds, config = run_predictions(ckpt_path)
    ensemble_preds.append(preds)
    print()

Running model 1
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 2
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 3
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 4
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 5
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 6
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 7
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 8
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 9
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records
Running model 10
Binary prediction threshold: 0.2152499407880693 (orientative)
1515 records


In [3]:
# Label columns of the model output, in the order they appear in each prediction frame.
labels = ['disp_E', 'disp_N', 'disp_J', 'disp_S', 'disp_B']

# Position of disp_E within `labels` -- NOT within a full prediction frame,
# which also carries a leading `tic_id` column.
col_e = labels.index('disp_E')
# Decision threshold on the disp_E score. Earlier candidates kept for reference:
# thresh = config.hparams.prediction_threshold
# thresh = 0.030485098838860747  # From the validation numbers - maximum threshold for 100% recall
thresh = 0.31245827674871207  # Relaxed to match Liang's precision value

In [4]:
import numpy as np
import pandas as pd

# Per-record votes from each ensemble member:
agg_preds = {}  # row_id -> list of voted labels, one per model
tic_ids = {}    # row_id -> TIC ID of that record

for preds in ensemble_preds:
    # Each frame also has a `tic_id` column, so index positionally only into the
    # label columns: `col_e` is a position within `labels`. Indexing the full row
    # with it would read the TIC ID (far above `thresh`) instead of the disp_E
    # score, and every record would be voted disp_E.
    scores = preds[labels]
    for row_id in preds.index:
        if row_id not in agg_preds:
            agg_preds[row_id] = []
            tic_ids[row_id] = preds.loc[row_id, 'tic_id']

        # `.loc[[label]]` returns all rows with that index label, like the boolean
        # mask it replaces, but without scanning the whole index for every record.
        row = scores.loc[[row_id]]
        if len(row) > 1:
            print(f'Warning: duplicate predictions for {row_id}')
        pred_v = row.values[0]
        if pred_v[col_e] >= thresh:
            # disp_E wins whenever its score clears the threshold...
            agg_preds[row_id].append('disp_E')
        else:
            # ...otherwise vote for the highest-scoring of the remaining labels.
            masked_v = [v if i != col_e else 0 for i, v in enumerate(pred_v)]
            agg_preds[row_id].append(labels[np.argmax(masked_v)])

In [5]:
# One row per record: how many ensemble members voted for each label, plus the
# size of the largest voting bloc (`maxcount`), indexed by TIC ID.
final_preds = []
for row_id, votes in agg_preds.items():
    counts = dict.fromkeys(labels, 0)
    for e in votes:
        counts[e] += 1
    maxcount = max(counts.values())
    counts['row_id'] = row_id
    counts['tic_id'] = tic_ids[row_id]
    counts['maxcount'] = maxcount
    final_preds.append(counts)

final_preds = pd.DataFrame(final_preds).set_index('tic_id')

In [6]:
# Show every row of displayed frames; use the full option key rather than relying
# on pandas' pattern matching of a partial key.
pd.set_option('display.max_rows', None)

def compare(ensemble_preds, filter):
    """Stack the rows selected by ``filter`` from every model's predictions.

    Args:
        ensemble_preds: sequence of per-model prediction DataFrames.
        filter: any row selector accepted by ``DataFrame.__getitem__`` -- e.g. a
            boolean mask aligned with each frame's index, or a callable that
            receives each frame and returns such a mask.

    Returns:
        One DataFrame holding the selected rows of each model, in model order.

    Raises:
        ValueError: if ``ensemble_preds`` is empty.
    """
    if not ensemble_preds:
        raise ValueError('ensemble_preds is empty; nothing to compare')
    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; a single
    # concat is also linear rather than quadratic in the number of models.
    return pd.concat([preds[filter] for preds in ensemble_preds])

# Every ensemble member's scores for one TIC. The filter is a callable, so it is
# evaluated against each model's own frame instead of depending on a `preds`
# variable left over from an earlier cell's loop; a stable sort keeps model
# order among rows that share an index label.
compare(ensemble_preds, lambda df: df.tic_id == 161477033).sort_index(kind='mergesort')

Unnamed: 0,tic_id,disp_E,disp_N,disp_J,disp_S,disp_B
296,161477033,0.000106,0.016336,0.999656,0.00041,0.003393
296,161477033,2.2e-05,0.015427,0.999886,0.000118,0.000959
296,161477033,0.000504,0.046386,0.998169,0.001177,0.002472
296,161477033,9e-06,0.009137,0.999971,2.1e-05,0.000209
296,161477033,8.6e-05,0.023264,0.999718,0.000168,0.001246
296,161477033,0.000258,0.022865,0.999367,0.000636,0.000821
296,161477033,0.000102,0.02321,0.999795,7.2e-05,0.000807
296,161477033,8.8e-05,0.020062,0.999612,0.000441,0.001083
296,161477033,0.000246,0.028025,0.999518,0.000431,0.001365
296,161477033,3.7e-05,0.020676,0.999821,0.000123,0.001264
