In [1]:
import os

chkpt_root = '/mnt/tess/astronet/checkpoints/fa1_38_run_1'
data_files = '/mnt/tess/astronet/tfrecords-38-train/*'
tces_file = '/mnt/tess/astronet/tces-v14-train.csv'

nruns = 10

def load_ensemble(chkpt_root, nruns):
    checkpts = []
    for i in range(nruns):
        parent = os.path.join(chkpt_root, str(i + 1))
        if not os.path.exists(parent):
            break
        all_dirs = os.listdir(parent)
        if not all_dirs:
            break
        d, = all_dirs
        checkpts.append(os.path.join(parent, d))
    return checkpts

paths = load_ensemble(chkpt_root, nruns)
paths

['/mnt/tess/astronet/checkpoints/fa1_38_run_1/1/AstroCNNModel_final_alpha_1_20220504_164445',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/2/AstroCNNModel_final_alpha_1_20220504_172032',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/3/AstroCNNModel_final_alpha_1_20220504_175419',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/4/AstroCNNModel_final_alpha_1_20220504_182735',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/5/AstroCNNModel_final_alpha_1_20220504_190055',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/6/AstroCNNModel_final_alpha_1_20220504_193432',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/7/AstroCNNModel_final_alpha_1_20220504_200824',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/8/AstroCNNModel_final_alpha_1_20220504_204203',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/9/AstroCNNModel_final_alpha_1_20220504_211537',
 '/mnt/tess/astronet/checkpoints/fa1_38_run_1/10/AstroCNNModel_final_alpha_1_20220504_214944']

In [2]:
import getpass
import os
from astronet import predict
import tensorflow as tf


def run_predictions(path):
    predict.FLAGS = predict.parser.parse_args([
      '--model_dir', path,
      '--data_files', data_files,
      '--output_file', '',
    ])

    return predict.predict()


paths = load_ensemble(chkpt_root, nruns)
ensemble_preds = []
config = None
for i, path in enumerate(paths):
    print(f'Running model {i + 1}')
    preds, config = run_predictions(path)
    ensemble_preds.append(preds.set_index('astro_id'))
    print()

Running model 1
19919 records
Running model 2
19919 records
Running model 3
19919 records
Running model 4
19919 records
Running model 5
19919 records
Running model 6
19919 records
Running model 7
19919 records
Running model 8
19919 records
Running model 9
19919 records
Running model 10
19919 records


In [3]:
labels = ['disp_e', 'disp_n', 'disp_j', 'disp_s', 'disp_b']

col_e = labels.index('disp_e')
thresh = 0.2

In [4]:
import numpy as np
import pandas as pd

agg_preds = {}

for preds in ensemble_preds:
    for ex_id in preds.index:
        if ex_id not in agg_preds:
            agg_preds[ex_id] = []

        row = preds[preds.index == ex_id]
        pred_v = row.values[0]
        if len(row.values) > 1:
            print(f'Warning: duplicate predictions for {ex_id}')
        if pred_v[col_e] >= thresh:
            agg_preds[ex_id].append('disp_e')
        else:
            masked_v = [v if i != col_e else 0 for i, v in enumerate(pred_v)]
            agg_preds[ex_id].append(preds.columns[np.argmax(masked_v)])

In [5]:
final_preds = []
for ex_id in list(agg_preds.keys()):
    counts = {l: 0 for l in labels}
    for e in agg_preds[ex_id]:
        counts[e] += 1
    maxcount = max(counts.values())
    counts.update({
        'astro_id': ex_id,
        'maxcount': maxcount,
    })
    final_preds.append(counts)
    
final_preds = pd.DataFrame(final_preds).set_index('astro_id')

In [6]:
tce_table = pd.read_csv(tces_file, header=0, low_memory=False)
tce_table['astro_id'] = tce_table['Astro ID']
tce_table = tce_table.set_index('astro_id')
for l in labels:
    tce_table[l] = tce_table[l[:-1] + l[-1].upper()]
tce_labels = tce_table[labels + ['TIC ID']]

pl = final_preds.join(tce_labels, on='astro_id', how='left', lsuffix='_p')

pl.head()
pd.set_option('display.max_columns', None)

In [7]:
ppos = (pl['disp_e_p'] > 0)
pos = (pl['disp_e'] > 0)

pneg = (pl['disp_e_p'] == 0)
neg = (pl['disp_e'] == 0)

print('Recall:', len(pl[ppos & pos]) / len(pl[pos]))
print('Precision:', len(pl[ppos & pos]) / len(pl[ppos]))

Recall: 0.9952426260704091
Precision: 0.8


In [8]:
for i in pl[pos & pneg]['TIC ID']:
    print(i)

49799681
50905927
761960972
379464439
369860950
376936788
381366555
176582931
383716793
387829772


In [9]:
for i in pl[neg & ppos]['TIC ID']:
    print(i)

30313096
231792014
469269031
404851658
50269985
339959047
31415158
89307726
279614421
384077498
149928367
260700439
186302615
424277160
457139941
173008159
404355884
252809234
1884267302
235562906
154006988
1718006824
105378745
252733538
240341734
446012335
308958576
332867524
170932338
372186057
98937256
85517073
388906923
262169297
154699047
445417446
83408987
356822872
99382119
380289423
43667308
343463316
190622748
405460905
322399290
277298580
404931649
306890368
257772292
410228822
425083264
284196025
86144938
88353444
349572054
201563776
396967912
31391041
340217291
300557619
279615956
276793626
373843857
301955257
409639832
261136311
40515061
90161731
30641264
363414722
312028996
360961597
363979429
271597452
327595687
101821550
86516107
233896030
148392230
198536950
514546468
284701173
60385535
405864709
32000625
254158487
351472293
30267088
384195094
295372927
396957480
389363745
270365850
323170106
349635294
420174832
164464417
239551545
92328347
417057112
1715180951
4683718

In [10]:
def compare(ensemble_preds, filter):
    result = ensemble_preds[0][filter]
    for preds in ensemble_preds[1:]:
        result = result.append(preds[filter])
    return result

compare(ensemble_preds, preds.index == pl[pl['TIC ID'] == 118412801].index.values[0])

IndexError: index 0 is out of bounds for axis 0 with size 0

In [None]:
pl[pl.index == pl[pl['TIC ID'] == 1254504863].index.values[0]]

### PR curve

In [None]:
ids = set(ensemble_preds[0].index.values)

index = {v: i for i, v in enumerate(ids)}

pred_es = np.zeros([len(ensemble_preds), len(index)])
for i, preds in enumerate(ensemble_preds):
    for row in preds.iterrows():
        ex_id, pred_e = row[0], row[1][col_e]
        pred_es[i][index[ex_id]] = pred_e

lbl_es = np.zeros([len(index)], dtype=np.bool)
for row in tce_labels.iterrows():
    ex_id, lbl_e = row[0], row[1]['disp_e']
    lbl_es[index[ex_id]] = (lbl_e > 0)

In [None]:
num_cond_pos = int(np.sum(lbl_es))

def pr_at_th(th):
    pred_pos = np.any(pred_es >= th, axis=0)
    true_pos = pred_pos & lbl_es
    num_pred_pos = int(np.sum(pred_pos))
    num_true_pos = int(np.sum(true_pos))
    if num_pred_pos == 0:
        return 1.0, 0.0
    return float(num_true_pos) / float(num_pred_pos), float(num_true_pos) / float(num_cond_pos)

In [None]:
from matplotlib import pyplot as plt

ps, rs, ths = ([], [], [])
th = np.max(pred_es)
while th >= 0.0:
    p, r = pr_at_th(th)
    ps.append(p)
    rs.append(r)
    ths.append(th)
    th -= 0.0005
    
from sklearn import metrics

print(f'AUC: {metrics.auc(rs, ps)}, max R: {max(rs)}, max P: {max(ps)}')
    
i = len(rs) - 1
while rs[i] == 1.0:
    i -= 1
i += 1
print(f'100% recall at: {int(ps[i] * 100)}%, threshold: {ths[i]}')

fig, ax = plt.subplots(figsize=(6, 3.7), dpi=200)

ax.spines['top'].set_color('#808080')
ax.spines['right'].set_color('#808080')
ax.spines['left'].set_color('#808080')
ax.spines['bottom'].set_color('#808080')
ax.tick_params(direction='in', color='#808080')

plt.grid(color='#c0c0c0', linestyle='--', linewidth=0.5)

plt.ylabel('Precision', fontweight='bold')
plt.xlabel('Recall', fontweight='bold')

plt.xlim((0.0, 1.0))
plt.ylim((0.0, 1.0))

_ = plt.plot(rs, ps)