In [None]:
import os
import tensorflow as tf

# Alternative configuration for the vetting model; swap these in (and comment
# out the block below) to inspect the vetting tfrecords instead.
# tces_file = '/mnt/tess/astronet/tces-vetting-v4-toi-train.csv'
# file_pattern = '/mnt/tess/astronet/tfrecords-vetting-5-toi-train/*'
# model_name = 'AstroCNNModelVetting'
# config_name = 'vrevised'
# labels = ['p', 'e', 'n']
tces_file = '/mnt/tess/astronet/tces-v12-y2-train.csv'
file_pattern = '/mnt/tess/astronet/tfrecords-35-y2-train/*'
model_name = 'AstroCNNModel'
config_name = 'revised_tuned'
# Label suffixes; each is used below as a `disp_<label>` column of both the
# TCE CSV and the tfrecord examples.
labels = ['E', 'N', 'J', 'S', 'B']

# Sort so record order (and everything derived from it) is stable across runs;
# glob does not guarantee any particular order.
filenames = sorted(tf.io.gfile.glob(file_pattern))
# Fail loudly: with no files every later cell would silently do nothing.
assert filenames, f'No tfrecord files match {file_pattern!r}'

filenames

In [None]:
import pandas as pd

# Load the TCE table (first CSV row is the header) and report its size.
tce_table = pd.read_csv(tces_file, header=0)
num_tces = tce_table.shape[0]
print(num_tces)
tce_table.head(3)

In [None]:
# Flatten every tf.train.Example into one row of scalars. Only the first
# element of each feature is kept, so array-valued features (presumably the
# light-curve views) are represented by their first sample only.
rows = []

for filename in filenames:
  num_records = 0
  for record in tf.data.TFRecordDataset(filename):
    num_records += 1
    ex = tf.train.Example.FromString(record.numpy())
    row = {}
    for k, f in ex.features.feature.items():
      if f.int64_list.value:
        row[k] = f.int64_list.value[0]
      elif f.float_list.value:
        row[k] = f.float_list.value[0]
      elif f.bytes_list.value:
        row[k] = f.bytes_list.value[0].decode()
      # Features with no values are omitted here and become NaN below.
    rows.append(row)
  print(filename, num_records)

# Build the column lists through a DataFrame so that a feature missing from
# some records is padded with NaN. Appending per-key lists directly would
# yield lists of unequal length or shift values into the wrong rows.
series = pd.DataFrame(rows).to_dict('list')

In [None]:
import pandas as pd

# One row per tfrecord example, one column per feature.
examples_table = pd.DataFrame(series)

# The table is wide; show every column rather than an elided summary.
pd.set_option('display.max_columns', None)
examples_table.describe()

In [None]:
from matplotlib import pyplot as plt

# Number of examples with a positive value in each `disp_<label>` column.
# NOTE: percentages are relative to the sum of these counts, which equals the
# number of examples only if each example has exactly one positive label.
counts = [int((examples_table[f'disp_{l}'] > 0).sum()) for l in labels]
total = sum(counts)

fig, ax = plt.subplots()
bars = ax.bar(labels, counts)
for bar, count in zip(bars, counts):
    share = count / total if total else 0.0
    # Offset in points (not data units) so the text sits just above the bar
    # whatever the y-axis scale is.
    ax.annotate(
        '{} - {:.0%}'.format(count, share),
        (bar.get_x() + bar.get_width() / 2, bar.get_height()),
        xytext=(0, 3), textcoords='offset points', ha='center', va='bottom')
ax.set(title='Examples per label', xlabel='Label', ylabel='Count')
plt.show()

In [None]:
# First three parsed examples.
examples_table.iloc[:3]

In [None]:
# Select by the `tic_id` column: read_csv was called without index_col, so the
# index is pandas' default positional index, not the TIC ID.
tce_table[tce_table['tic_id'] == 91152385]

In [None]:
import numpy as np

print('TICs with labels mismatched between TCE and tfrecords:')

# For every label, collect TICs that are positive in one source but not the
# other (symmetric difference), as a sorted array so the result is readable.
mismatched_tics = {}
for label in labels:
    col = f'disp_{label}'
    positive_in_tces = set(tce_table.loc[tce_table[col] > 0, 'tic_id'])
    positive_in_records = set(examples_table.loc[examples_table[col] > 0, 'tic_id'])
    mismatched_tics[label] = np.array(sorted(positive_in_tces ^ positive_in_records))

mismatched_tics

In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

from astronet.preprocess import preprocess


tess_data_dir = '/mnt/tess/lc'


def find_tce(tic_id):
  """Scan the tfrecord files for the first example with the given TIC ID.

  On a match, prints the TIC ID, every `disp_<label>` value and the duration.

  Args:
    tic_id: integer TIC ID to look for.

  Returns:
    The matching tf.train.Example.

  Raises:
    ValueError: if no example in `filenames` has this TIC ID.
  """
  with tf.device('cpu'):
    for path in filenames:
      for raw in tf.data.TFRecordDataset(path):
        example = tf.train.Example.FromString(raw.numpy())
        feats = example.features.feature
        if feats["tic_id"].int64_list.value[0] != tic_id:
          continue
        print('TIC ID:', tic_id)
        for label in labels:
          print(f'{label}:', feats[f"disp_{label}"].int64_list.value[0])
        print('Duration:', feats["Duration"].float_list.value[0])
        return example

    raise ValueError("{} not found in files: {}".format(tic_id, filenames))

In [None]:
import tensorflow as tf

from astronet import models
from astronet.astro_cnn_model import input_ds

config = models.get_model_config(model_name, config_name)

# Both datasets share every option except whether labels are included. They
# are iterated in lockstep below, which presumes build_dataset yields the same
# deterministic order for both (filenames are not shuffled).
dataset_kwargs = dict(
    file_pattern=file_pattern,
    input_config=config.inputs,
    batch_size=1,
    shuffle_filenames=False,
    repeat=1,
    include_identifiers=True,
)
ds = input_ds.build_dataset(include_labels=False, **dataset_kwargs)
labels_ds = input_ds.build_dataset(include_labels=True, **dataset_kwargs)
labels_iter = iter(labels_ds)

label_index = {k.lower(): i for i, k in enumerate(config.inputs.label_columns)}
# Derive label columns from `labels` so this check follows the config cell
# instead of a hard-coded copy of the 5-class label set.
cols = [f'disp_{l}' for l in labels]


def assert_no_nans(e):
  """Return `e` unchanged, raising ValueError if it is a float tensor with NaNs.

  Non-floating tensors (integer IDs, strings, ...) are passed through, since
  tf.math.is_nan is only defined for floating dtypes.
  """
  if not e.dtype.is_floating:
    return e
  if tf.reduce_any(tf.math.is_nan(e)):
    tf.print(e, summarize=-1)
    raise ValueError('data has NaNs.')
  return e


# Walk the unlabeled and labeled datasets in lockstep, checking that they stay
# in sync, that labels agree with the TCE table, and that no features are NaN.
all_tics = []
bad_labels = []
for d in ds:
  lab = next(labels_iter)
  tic_id = d[1].numpy().item()
  all_tics.append(tic_id)

  assert bool(tf.reduce_all(lab[0]['duration'] == d[0]['duration'])), (
      f'Labeled and unlabeled datasets are out of sync at TIC {tic_id}')

  rec = tce_table[tce_table.tic_id == tic_id]
  if rec.empty:
    # Without a TCE row the labels cannot be verified; report and stop.
    bad_labels.append(tic_id)
    print('tic missing from TCE table: ', tic_id)
    break
  for c in cols:
    if (lab[1][0][label_index[c.lower()]].numpy() == 0) != (rec[c].values[0] == 0):
      bad_labels.append(tic_id)
      print('bad tic: ', tic_id)
      print(rec)
      print(cols)
      print(lab[1][0])
      break
  if bad_labels:
    break

  try:
    tf.nest.map_structure(assert_no_nans, d)
  except ValueError as e:
    print(e)
    print(d[1])
    break
else:
  print('No NaNs or mismatched labels found.')

from collections import Counter

# Count occurrences once (O(n)) instead of calling list.count per TIC.
tic_counts = Counter(all_tics)
duplicate_tics = [t for t, n in tic_counts.items() if n > 1]
if not duplicate_tics:
  print('No duplicates found.')
else:
  print('Found duplicates!', len(all_tics) - len(tic_counts))
  print(duplicate_tics)

In [None]:
# NOTE(review): tf.data only keeps a cache once the dataset has been iterated
# to the end; plot_ds_tce returns early, so the cache may never be completed
# and each call may re-read the files. Verify before relying on it for speed.
ds = ds.cache()


def plot_ds_tce(ds, tic_id):
    """Print the scalar features of one TIC in `ds` and plot its views.

    Args:
      ds: dataset of (features, identifiers, ...) elements with batch size 1,
        as built by input_ds.build_dataset(..., include_identifiers=True).
      tic_id: integer TIC ID to display.

    Raises:
      ValueError: if no element of `ds` has this TIC ID.
    """
    # Array-valued view features are plotted rather than printed.
    skipped_prefixes = ('local_', 'global_', 'secondary_', 'sample_')
    # (feature key) in row-major order of the 2x3 subplot grid.
    panels = ['global_view', 'local_view', 'secondary_view',
              'global_mask', 'global_view_0.3', 'global_view_5.0']
    for d in ds:
        if d[1].numpy().item() != tic_id:
            continue
        features = d[0]
        for k, v in features.items():
            if k.startswith(skipped_prefixes):
                continue
            print(f'{k:25}: {v.numpy()}')
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        for ax, key in zip(axes.flat, panels):
            ax.plot(features[key][0].numpy(), '.-')
            ax.set_title(key)
        fig.suptitle(f'TIC {tic_id}')
        plt.show()
        plt.close(fig)
        return
    raise ValueError(f'{tic_id} not found in dataset.')

In [None]:
tic_id = 348962922
plot_ds_tce(ds, tic_id)
# Only a cell's last expression is rendered automatically, so the TCE-table
# rows need an explicit display() to be shown as well.
display(tce_table[tce_table['tic_id'] == tic_id])
examples_table[examples_table['tic_id'] == tic_id]

In [None]:
# All TCE-table rows for the TIC under inspection.
tce_table.loc[tce_table['tic_id'].eq(tic_id)]

In [None]:
# All parsed tfrecord examples for the TIC under inspection.
examples_table.loc[examples_table['tic_id'].eq(tic_id)]

In [None]:
!ls /mnt/tess/lc-v | grep 237320326

In [None]:
# Fetch the raw tf.train.Example for one TIC and list its feature names.
tce = find_tce(368435330)
feature_names = list(tce.features.feature.keys())
feature_names