In [None]:
import numpy as np
import pandas as pd
import os
import sys
from astropy import units as u

import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
from matplotlib import rc
# Plotting defaults for every figure in this notebook: matplotlib's "classic"
# style, LaTeX-rendered serif text (needs a working LaTeX install because of
# text.usetex), a white figure background and large tick labels.
plt.style.use('classic')
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'text.usetex': True,
    'figure.facecolor': 'w',
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
})

from dtaidistance import dtw, clustering
from dtaidistance import dtw_visualisation as dtwvis

from sklearn.preprocessing import MinMaxScaler
from scipy.cluster.hierarchy import dendrogram, linkage

# Make the local `tess_binaries` package importable and point it at the TESS
# data directory. Both locations may be overridden through environment
# variables; the defaults are the original cluster paths.
TB_PROJECT_DIR = os.environ.get('TESS_BINARIES_DIR',
                                '/astro/users/jbirky/projects/tess_binaries')
if TB_PROJECT_DIR not in sys.path:
    # Guard so that re-running this cell does not append duplicate entries.
    sys.path.append(TB_PROJECT_DIR)
# Set before `import tess_binaries` below; presumably the package reads
# TESS_DATA at import time (TODO confirm).
os.environ.setdefault('TESS_DATA', '/data/epyc/projects2/tess')

import tess_binaries as tb

In [None]:
# ASAS-SN x TESS cross-match catalogue from the tess_binaries catalogue dir.
tess_xm = pd.read_csv(os.path.join(tb.cat_dir, 'asassn_tess_xm.csv.gz'))

# Upper period cut in days (periods are later used as `per*u.day`).
# NOTE(review): presumably chosen to be about one TESS sector (~27 d) so that a
# full cycle is observed — confirm.
MAX_PERIOD = 28

# Sources with a catalogued period; `.notna()` also works for non-float dtypes.
psamp = tess_xm[tess_xm['period'].notna()]
ref = psamp[psamp['period'] < MAX_PERIOD]

In [None]:
# Number of bins each phase-folded light curve is reduced to.
tsteps = 100
tarr = np.arange(0, tsteps, 1)     # bin indices 0 .. tsteps-1
pharr = np.linspace(0, 1, tsteps)  # evenly spaced grid on [0, 1] (presumably phase)

farr = []     # aligned, [0, 1]-scaled binned flux, one array per source
ids = []      # TIC id for each entry of farr
periods = []  # period [days] for each entry of farr
types = []    # catalogue 'Type' for each entry of farr (used as plot labels)

# Row positions in `ref` of the sources to cluster.
demo = [5,16,23, 20,25,28, 4,6,9, 14,21,29, 62,117,123]
# Alternative selections:
# demo = np.arange(30,70)
# demo = [5,16,20,25,4,62]

# Convert the needed columns to lists once, not once per loop iteration.
ref_tic_ids = list(ref['tic_id'])
ref_sectors = list(ref['sector'])
ref_types = list(ref['Type'])
ref_periods = list(ref['period'])


def _align_and_scale(bin_flux):
    """Roll binned flux so its minimum is at bin 0, then min-max scale it to [0, 1].

    NaN bins (if any) are ignored when locating the minimum and the range, and
    stay NaN. A flat curve (zero range) maps to all zeros, as sklearn's
    MinMaxScaler does.
    """
    flux = np.asarray(bin_flux, dtype=float)
    # Shift by len - argmin so the minimum lands on index 0 (mod len).
    flux = np.roll(flux, len(flux) - np.nanargmin(flux))
    lo, hi = np.nanmin(flux), np.nanmax(flux)
    span = hi - lo
    return (flux - lo) / (span if span > 0 else 1.0)


for i in demo:
    tic_id, sec, typ, per = ref_tic_ids[i], ref_sectors[i], ref_types[i], ref_periods[i]

    # First item returned by readSourceFiles for this TIC id and sector.
    data_full = tb.readSourceFiles(tic_id, sector=sec)[0]
    data_folded = data_full.fold(period=per*u.day)
    bin_flux = tb.binData(data_folded, tsteps)

    farr.append(_align_and_scale(bin_flux))
    ids.append(tic_id)
    periods.append(per)
    types.append(typ)

In [None]:
# Hierarchical clustering tree over all light curves in `farr`, with pairwise
# distances from dtaidistance's fast DTW distance-matrix routine.
model = clustering.HierarchicalTree(
    dists_fun=dtw.distance_matrix_fast,
    dists_options=dict(),
)
cluster_idx = model.fit(farr)

In [None]:
# Two side-by-side axes handed to the tree's plot method; series are labelled
# by their catalogue type. The first argument is the output filename
# (presumably the figure is saved there — TODO confirm).
fig, ax = plt.subplots(1, 2, figsize=(10, 10))
model.plot(
    "myplot.png",
    axes=ax,
    show_tr_label=True,
    show_ts_label=types,
    ts_sample_length=1,
    ts_left_margin=10,
    ts_label_margin=-10,
)

In [None]:
# Second clustering of the same light curves with a LinkageTree, using the
# same fast DTW distance matrix as the HierarchicalTree above.
model3 = clustering.LinkageTree(
    dists_fun=dtw.distance_matrix_fast,
    dists_options={},
)
# NOTE(review): this rebinds `cluster_idx` from the HierarchicalTree fit; the
# value LinkageTree.fit returns may be a different kind of object — confirm
# before reusing it.
cluster_idx = model3.fit(farr)

In [None]:
# Plot the LinkageTree (`model3`) fitted in the previous cell. This is distinct
# from the HierarchicalTree `model`, which is plotted to myplot.png.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 20))
model3.plot("myplot2.png", axes=ax, show_ts_label=types,
            show_tr_label=True, ts_label_margin=0,
            ts_left_margin=0, ts_sample_length=1)

In [None]:
# DTW distance between the first two processed light curves (demo order).
series_a, series_b = farr[0], farr[1]
d = dtw.distance_fast(series_a, series_b)

In [None]:
# Display the DTW distance between farr[0] and farr[1] computed above.
d