In [None]:
import h5py
import numpy as np
import pandas as pd
import os, sys
from astropy import units as u
import random

import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
from matplotlib import rc
# Global plot styling: start from matplotlib's 'classic' style sheet, then
# switch to a LaTeX-rendered Computer Modern serif font, a white figure
# background and size-20 tick labels on both axes.
plt.style.use('classic')
rc('font', family='serif', serif=['Computer Modern'])
rc('text', usetex=True)
rc('figure', facecolor='w')
rc(('xtick', 'ytick'), labelsize=20)

from dtaidistance import dtw, clustering
from dtaidistance import dtw_visualisation as dtwvis

from sklearn.preprocessing import MinMaxScaler
from scipy.cluster.hierarchy import dendrogram, linkage

# Put the local tess_binaries checkout on the import path and set the
# TESS_DATA environment variable before the package is imported
# (presumably tess_binaries reads it to locate the data -- TODO confirm).
# NOTE(review): both values are machine-specific absolute paths; adjust them
# when running this notebook on another host.
TESS_BINARIES_SRC = '/astro/users/jbirky/projects/tess_binaries'
if TESS_BINARIES_SRC not in sys.path:
    # Guard so that re-running this cell does not keep appending duplicates.
    sys.path.append(TESS_BINARIES_SRC)
os.environ['TESS_DATA'] = '/data/epyc/projects2/tess'

import tess_binaries as tb

In [None]:
# Read the gzipped CSV catalog from the tess_binaries catalog directory
# (an ASAS-SN / TESS cross-match, judging by the file name).
tess_xm = pd.read_csv(tb.cat_dir + '/asassn_tess_xm.csv.gz')

# Keep only the rows that have a period ...
has_period = tess_xm['period'].notna()
psamp = tess_xm.loc[has_period]

# ... and of those, only periods below 28 (days -- the period is multiplied
# by u.day further down).  NOTE(review): 28 presumably reflects the length of
# a TESS sector -- confirm.
ref = psamp.loc[psamp['period'] < 28]

In [None]:
# Positional row numbers (0-based positions within `ref`, not TIC IDs) of
# sources to exclude -- presumably rejected on visual inspection; confirm.
# NOTE(review): the entry 11209 is far above every other ID in this list
# (all others are <= 1299) and looks like a typo (e.g. 1209, or "112, 09"?).
# It is kept as-is because the intended value cannot be determined from here;
# if it exceeds len(ref) it has no effect on good_ids.
bad_ids = np.array([
    368, 212, 276, 275, 483, 1103, 381, 1105, 957, 127, 1116, 13, 961, 1054, 840, 544, 136, 596, 17, 1122, 19, 231, 752, 967, 969, 1127, 547, 1142, 549, 554, 603,
    1154, 759, 662, 605, 413, 32, 856, 860, 861, 977, 766, 317, 1192, 984, 563, 869, 609, 680, 162, 47, 872, 990, 171, 684, 173, 11209, 1213, 1215, 512, 175, 514,
    1071, 997, 262, 441, 1002, 777, 78, 1003, 79, 520, 180, 778, 701, 444, 1011, 1015, 780, 709, 712, 190, 191, 714, 577, 784, 345, 902, 905, 529, 908, 1269, 1080,
    618, 619, 1018, 348, 1270, 1272, 622, 1274, 1277, 716, 792, 349, 578, 1028, 1090, 794, 626, 450, 1091, 923, 534, 628, 799, 802, 805, 806, 807, 808, 110, 934, 937,
    1299, 581, 942, 943, 539, 814, 949, 952, 1047, 1050, 737, 479, 1051, 0, 285, 8, 845, 749, 141, 655, 756, 501, 604, 857, 665, 562, 65, 775, 440, 1242, 613, 81,
    1267, 616, 353, 922, 533, 451, 1296, 936, 1040, 939, 118, 582, 473, 561, 657, 999, 1073, 1244, 516, 178, 1007, 1265, 1016, 1075, 1019, 198, 1021, 1023, 1037, 1039,
    580, 119, 1044, 1045, 480, 747, 966, 652, 1132, 560, 868, 1234, 572, 614, 87, 267, 1087, 1026, 1032, 200, 816, 821,
])

# Every remaining row position of `ref`.  np.setdiff1d returns the values
# sorted in ascending order, so the row order of everything derived from
# good_ids is deterministic (a plain set difference gives no ordering
# guarantee).
good_ids = np.setdiff1d(np.arange(len(ref)), bad_ids)

In [None]:
# Training table: the rows of `ref` at the positions listed in good_ids,
# reduced to four columns, with the catalog's 'Type' column renamed to 'type'
# and a fresh 0..N-1 index (the loop below looks rows up by that index).
# A single positional selection replaces the per-row `list(ref[col])[i]`
# lookups, which rebuilt each full column list on every iteration.
train = (
    ref.iloc[np.asarray(good_ids, dtype=int)][['tic_id', 'Type', 'period', 'sector']]
       .rename(columns={'Type': 'type'})
       .reset_index(drop=True)
)

In [None]:
# For every source in `train`: read its light curve, fold it on the catalog
# period, bin it into `tsteps` bins, rotate it so the minimum-flux bin comes
# first, and min-max scale the flux.  Sources that fail anywhere along the
# way are reported and skipped.
tsteps = 100                        # number of bins per folded light curve
tarr = np.arange(0, tsteps, 1)      # bin indices 0 .. tsteps-1
pharr = np.linspace(0, 1, tsteps)   # evenly spaced grid on [0, 1]; not used in this cell

table = train
sample = {'tic_id': [], 'type': [], 'period': [], 'flux': []}

for i in range(len(table)):
    tic_id = table['tic_id'][i]
    try:
        # First element of whatever readSourceFiles returns for this
        # source/sector (presumably a light-curve object -- it must
        # provide .fold(); TODO confirm).
        data_full = tb.readSourceFiles(tic_id, sector=table['sector'][i])[0]
        data      = data_full.fold(period=table['period'][i]*u.day)
        bin_flux  = tb.binData(data, tsteps)

        # Circularly shift so the smallest flux value lands at index 0.
        # NOTE(review): np.argmin returns the position of the first NaN if
        # any are present; if binData can yield NaN for empty bins,
        # np.nanargmin may be what is wanted here -- TODO confirm.
        flux = np.roll(np.array(bin_flux), tsteps-np.argmin(bin_flux))

        # MinMaxScaler scales each column independently, so only column 1
        # (the flux) matters.  Stacking against tarr also makes a flux array
        # whose length differs from tsteps fail here and get reported.
        data = np.vstack([tarr, flux]).T
        flux = MinMaxScaler().fit_transform(data).T[1]

        # Append only after every step succeeded so the four lists stay aligned.
        sample['flux'].append(flux)
        sample['tic_id'].append(tic_id)
        sample['period'].append(table['period'][i])
        sample['type'].append(table['type'][i])
    except Exception as err:
        # Best-effort processing: report the failing source together with the
        # reason and carry on.  `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit stop the loop.
        print('BAD ID', i, tic_id, repr(err))

In [None]:
# Output location for the processed sample, inside the tess_binaries catalog
# directory.
fname = f'{tb.cat_dir}/asassn_tess_inspected.hdf5'

# One row per successfully processed source: tic_id, type, period and the
# scaled flux array.
df = pd.DataFrame(sample)

In [None]:
# Write the processed sample to `fname`, one dataset per column.  Mode 'w'
# creates the file, replacing any existing one.  The context manager closes
# the file on exit, so no explicit close() call is needed afterwards.
with h5py.File(fname, 'w') as f:
    f.create_dataset('tic_id', data=df['tic_id'])
    f.create_dataset('period', data=df['period'])
    # Variable-length string dataset for the type labels.
    f.create_dataset('type', data=df['type'], dtype=h5py.special_dtype(vlen=str))
    # Each row holds one flux array, stored as a variable-length float64 entry.
    f.create_dataset('flux', data=df['flux'], dtype=h5py.special_dtype(vlen=np.dtype('float64')))

In [None]:
# Read every top-level dataset of the saved file back into memory as a dict
# keyed by dataset name.  `ds[()]` reads the full dataset and works across
# h5py versions (the `.value` attribute was removed in h5py 3.0).  The
# context manager closes the file even if one of the reads fails.
with h5py.File(f'{tb.cat_dir}/asassn_tess_inspected.hdf5', mode="r") as ff:
    dd = {key: ff[key][()] for key in ff}