In [None]:
import os
import tensorflow as tf

# Path of the CSV that a later cell loads into `tce_table`.
tces_file = '/mnt/tess/astronet/tces-v3-train.csv'
# Glob pattern for the record files that later cells open with TFRecordDataset.
file_pattern = '/mnt/tess/astronet/tfrecords-18-train/*'
# Expand the pattern into the concrete list of matching paths.
filenames = tf.io.gfile.glob(file_pattern)
    
# Show the matched paths as the cell output.
filenames

In [None]:
import pandas as pd

# Read the TCE CSV (first row is the header), report its row count and
# preview the first three rows as the cell output.
tce_table = pd.read_csv(tces_file, header=0)
print(tce_table.shape[0])
tce_table.head(n=3)

In [None]:
def _first_feature_value(feature):
  """Return the first value stored in a tf.train.Feature, or None if it is empty.

  int64, float and bytes lists are checked in that order; bytes values are
  decoded to str.
  """
  if feature.int64_list.value:
    return feature.int64_list.value[0]
  if feature.float_list.value:
    return feature.float_list.value[0]
  if feature.bytes_list.value:
    return feature.bytes_list.value[0].decode()
  return None


# Column name -> list with one entry per record, across all files. Every list
# is kept at the same length (padded with None) so rows stay aligned when the
# dict is turned into a DataFrame.
series = {}
total_records = 0

for filename in filenames:
  tfr = tf.data.TFRecordDataset(filename)
  num_records = 0
  for record in tfr:
    ex = tf.train.Example.FromString(record.numpy())
    for k in ex.features.feature.keys():
      v = _first_feature_value(ex.features.feature[k])
      if v is None:
        continue

      if k not in series:
        # First time this feature is seen: backfill the records already read.
        series[k] = [None] * total_records
      series[k].append(v)

    num_records += 1
    total_records += 1
    # Pad every column this record did not contribute to, so a missing or
    # empty feature cannot shift later values into the wrong row.
    for column in series.values():
      if len(column) < total_records:
        column.append(None)
  print(filename, num_records)

In [None]:
import pandas as pd

# Show every column when a frame is rendered; the table has many columns.
pd.set_option('display.max_columns', None)

# One row per record, one column per feature collected in `series`.
examples_table = pd.DataFrame(series)

examples_table.describe()

In [None]:
from matplotlib import pyplot as plt

# Bar chart of how many examples have each disp_<label> column > 0, with each
# bar annotated by its share of the summed counts.
labels = ['E', 'J', 'N', 'S', 'B']
counts = [int((examples_table['disp_{}'.format(l)] > 0).sum()) for l in labels]
total = sum(counts)

fig, ax = plt.subplots()
bars = ax.bar(labels, counts)
ax.set_xlabel('disp_* label')
ax.set_ylabel('examples with disp_* > 0')
ax.set_title('Examples per label')

for bar, count in zip(bars, counts):
    # Guard against an empty table so the cell does not divide by zero.
    share = count / total if total else 0.0
    # Offset in points (not data units) keeps the text just above the bar
    # regardless of how large or small the counts are.
    ax.annotate(
        '{:.0%}'.format(share),
        (bar.get_x() + bar.get_width() / 2, bar.get_height()),
        xytext=(0, 3),
        textcoords='offset points',
        ha='center',
        va='bottom')

In [None]:
# Summary statistics for examples_table (same call as the one made right after
# the table is built).
examples_table.describe()

In [None]:
# Preview the first three rows of examples_table.
examples_table.head(3)

In [None]:
import numpy as np

# TIC IDs that have disp_E > 0 in the TCE table but no disp_E > 0 row in the
# examples table. np.setdiff1d returns the sorted, unique IDs as a real
# 1-D array (np.array(<set>) would only wrap the set in a 0-d object array).
tce_e_ids = tce_table.loc[tce_table['disp_E'] > 0, 'tic_id'].values
example_e_ids = examples_table.loc[examples_table['disp_E'] > 0, 'tic_id'].values
np.setdiff1d(tce_e_ids, example_e_ids)

In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

from astronet.preprocess import preprocess


# Directory path constant. NOTE(review): not referenced by any cell visible
# here — confirm whether it is still needed.
tess_data_dir = '/mnt/tess/lc'

def find_tce(tic_id):
  """Return the first tf.train.Example in `filenames` whose tic_id matches.

  Files are scanned in order on the CPU device. When a match is found, the
  TIC ID, the disp_E/N/J/S values and the Duration are printed before the
  example is returned.

  Args:
    tic_id: Integer compared against each record's "tic_id" feature.

  Returns:
    The matching tf.train.Example.

  Raises:
    ValueError: If no record in any file has the requested TIC ID.
  """
  # (printed caption, int64 feature key) pairs, in output order.
  summary_fields = (
      ('E:', "disp_E"),
      ('N:', "disp_N"),
      ('J:', "disp_J"),
      ('S:', "disp_S"),
  )
  with tf.device('cpu'):
    for path in filenames:
      for raw in tf.data.TFRecordDataset(path):
        example = tf.train.Example.FromString(raw.numpy())
        features = example.features.feature
        if features["tic_id"].int64_list.value[0] != tic_id:
          continue
        print('TIC ID:', tic_id)
        for caption, key in summary_fields:
          print(caption, features[key].int64_list.value[0])
        print('Duration:', features["Duration"].float_list.value[0])
        return example

    raise ValueError("{} not found in files: {}".format(tic_id, filenames))

In [None]:
import tensorflow as tf

from astronet import models
from astronet.astro_cnn_model import input_ds

config = models.get_model_config('AstroCNNModel', 'extended')

# One example per batch, single pass, unshuffled, without labels but with
# identifiers, so each element can be checked individually.
ds = input_ds.build_dataset(
      file_pattern=file_pattern,
      input_config=config.inputs,
      batch_size=1,
      include_labels=False,
      shuffle_filenames=False,
      repeat=1,
      include_identifiers=True)


def check_for_nans(e):
  """Return `e` unchanged, raising ValueError if a floating tensor has NaNs.

  Non-floating tensors (ints, strings, bools, ...) are passed through
  untouched because tf.math.is_nan only accepts floating dtypes.
  """
  if not e.dtype.is_floating:
    return e
  if tf.reduce_any(tf.math.is_nan(e)):
    tf.print(e, summarize=-1)
    raise ValueError('data has NaNs.')
  return e


all_tics = []
for d in ds:
  # d[1] is treated as the example's identifier (a single-element tensor,
  # given batch_size=1) — presumably the TIC ID; verify against input_ds.
  tic_id = d[1].numpy().item()
  all_tics.append(tic_id)
  try:
    tf.nest.map_structure(check_for_nans, d)
  except ValueError as e:
    print(e)
    print(d[1])
    break
else:
  # Only reached when the loop finished without hitting `break`.
  print('No NaNs found.')

# Note: if the scan stopped early on a NaN, this only covers the identifiers
# collected up to that point.
num_duplicates = len(all_tics) - len(set(all_tics))
if num_duplicates == 0:
  print('No duplicates found.')
else:
  print('Found duplicates!', num_duplicates)

In [None]:
# Rebind ds to its cached version before plot_ds_tce iterates over it.
# NOTE(review): re-running this cell calls .cache() on the already-cached ds.
ds = ds.cache()
def plot_ds_tce(ds, tic_id):
    """Print the scalar features and plot the views of one example in `ds`.

    Iterates `ds` until an element whose identifier (`d[1]`) equals `tic_id`
    is found. For that element, every feature whose name does not start with
    'local_', 'global_' or 'secondary_' is printed, and six series are
    plotted on a 2x3 grid, each panel titled with its feature name.

    Args:
        ds: Iterable of (features, identifier) pairs; `features` is indexed
            by feature name and each plotted entry is indexed with [0]
            (assumes batch size 1 — confirm against how `ds` is built).
        tic_id: Identifier to look for.

    Returns:
        None. If no element matches, a message is printed instead of
        silently doing nothing.
    """
    skipped_prefixes = ('local_', 'global_', 'secondary_')
    panels = (
        ('global_view', 'local_view', 'secondary_view'),
        ('global_mask', 'global_view_0.3', 'global_view_5.0'),
    )
    for d in ds:
        if not d[1] == tic_id:
            continue
        features = d[0]
        for k, v in features.items():
            # The view/mask series are plotted below rather than printed.
            if k.startswith(skipped_prefixes):
                continue
            print(f'{k:25}: {v.numpy().item()}')
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        for row, keys in enumerate(panels):
            for col, key in enumerate(keys):
                axes[row, col].plot(features[key][0].numpy(), '.-')
                axes[row, col].set_title(key)
        plt.show()
        # Close only this figure so other open figures are left alone.
        plt.close(fig)
        return

    print('TIC ID {} not found in dataset.'.format(tic_id))

In [None]:
# Plot one example from the dataset, then show its row(s) in the TCE table.
tic_id = 188825647
plot_ds_tce(ds, tic_id)
tce_table.loc[tce_table['tic_id'] == tic_id]

In [None]:
# Fetch the raw example for this TIC ID and list the feature names it carries.
tce = find_tce(188825647)

[feature_name for feature_name in tce.features.feature.keys()]