In [None]:
import sys, os
from os.path import join, abspath, exists, pardir
import tomlkit
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from astropy.table import Table, join as tab_join, vstack
import copy
from numpy import array as a
from sklearn.neighbors import KernelDensity
import pycatch22
from tempfile import TemporaryDirectory
import subprocess
from astropy.stats import sigma_clipped_stats, sigma_clip
from supersmoother import SuperSmoother
import scipy.ndimage

import astropy.units as u
from astropy.timeseries import TimeSeries, aggregate_downsample
from astropy.time import TimeDelta, Time

from tqdm import tqdm

from lightcurve import ASASSN_Lightcurve
from utils import read_config, n_hist_bins, hist
from data_preparation import prepare_lc, plot_prepared
from metrics import fill_lc_gaps, calc_metrics, short_metric_names

In [None]:
import warnings
warnings.filterwarnings('ignore', message='.*dubious year')

In [None]:
# Read the notebook settings once, then pull out the two entries the next cells need.
cfg = read_config("config.toml")
data_dir, data_link = cfg["data_dir"], cfg["data_link"]

In [None]:
!cd $(data_dir) && curl $(data_link)

In [None]:
# Locations that sit next to the raw-data directory.
_data_parent = abspath(join(data_dir, os.pardir))
cleaned_dir = join(_data_parent, "cleaned")
csv_path = join(_data_parent, "asassn_rounded.csv")

# Output directory for files produced by this notebook; created if missing.
outdir = "out"
os.makedirs(outdir, exist_ok=True)


def out(fname):
    """Return the path of `fname` inside `outdir`."""
    return join(outdir, fname)


def savefig(fname):
    """Save the current matplotlib figure as `fname` in `outdir` (300 dpi, tight bounding box)."""
    plt.savefig(out(fname), dpi=300, bbox_inches="tight")


def load_old_lc(fname):
    """Load `fname` from `data_dir` via `ASASSN_Lightcurve.from_dat_file`."""
    return ASASSN_Lightcurve.from_dat_file(join(data_dir, fname))


def load_cleaned_lc(fname):
    """Load `fname` from `cleaned_dir` via `ASASSN_Lightcurve.from_cleaned_file`."""
    return ASASSN_Lightcurve.from_cleaned_file(join(cleaned_dir, fname))


rng = np.random.default_rng()
metadata = Table.read(csv_path)
# One light-curve file name per catalogue entry, in catalogue order.
names = list(map(ASASSN_Lightcurve.filename_from_id, metadata["ID"]))

In [None]:
def process_and_calculate_metrics(row,lc_is_cleaned=False):
    """Load the light curve belonging to one metadata row and compute its metrics.

    Parameters
    ----------
    row
        Metadata row; ``row["ID"]`` selects the light-curve file and the row
        itself is forwarded to ``prepare_lc``.
    lc_is_cleaned : bool, optional
        When true the file is read with ``load_cleaned_lc`` and ``prepare_lc``
        is called with ``do_preprocess=False``; otherwise the file is read with
        ``load_old_lc`` and preprocessing is enabled.

    Returns
    -------
    The result of ``calc_metrics`` applied to the two values returned by
    ``prepare_lc`` and the light curve's ``cadence``.
    """
    loader = load_cleaned_lc if lc_is_cleaned else load_old_lc
    lightcurve = loader(ASASSN_Lightcurve.filename_from_id(row["ID"]))

    prepared_mag, prepared_times = prepare_lc(
        lightcurve, row, do_preprocess=not lc_is_cleaned
    )
    return calc_metrics(prepared_mag, prepared_times, lightcurve.cadence)

In [None]:
# Quick check: compute the metrics for the first ten catalogue rows and display them.
colnames = short_metric_names
data, classes, ids = [], [], []

for row in metadata[:10]:
    row_metrics = process_and_calculate_metrics(row)
    row_class = row["ML_classification"]
    data.append(row_metrics)
    ids.append(row['ID'])
    classes.append(row_class)

# One row per light curve: metric columns first, then the label and the object ID.
tab = Table(data=np.array(data), names=colnames)
tab["class"] = classes
tab["ID"] = ids
tab

In [None]:
# Directory that receives the per-checkpoint metric tables; created if missing.
# The settings object loaded from config.toml is bound to `cfg` (no `config` name exists).
savedir = cfg["metric_savedir"]
os.makedirs(savedir, exist_ok=True)

In [None]:
# Compute metrics for every catalogue row from `start_at` onwards, flushing the
# buffered results to a CSV in `savedir` every `checkpoint_interval` rows.
colnames = short_metric_names
data = []
classes = []
ids = []

checkpoint_interval = 5000

# Index of the first metadata row handled by this run (hard-coded resume point).
start_at = 350001
error_log = join(savedir, "errors.txt")


def _metrics_table(rows, labels, object_ids):
    """Assemble buffered results into a Table with the metric columns plus "class" and "ID"."""
    table = Table(data=a(rows), names=colnames)
    table["class"] = labels
    table["ID"] = object_ids
    return table


current_iter = start_at
for current_iter, row in enumerate(tqdm(metadata[start_at:]), start=start_at):
    try:
        # Fetch everything first so a failure cannot leave the three lists out of step.
        metrics = process_and_calculate_metrics(row)
        classification = row["ML_classification"]
        object_id = row['ID']
    except Exception as e:
        # Best effort: log the failure and move on to the next light curve.
        with open(error_log, "a+") as f:
            f.write(f"Couldn't preprocess {row['ID']}: {e}\n")
    else:
        data.append(metrics)
        ids.append(object_id)
        classes.append(classification)

    if current_iter and current_iter % checkpoint_interval == 0:
        # If every row since the previous checkpoint failed there is nothing to
        # write; building a Table from an empty buffer would not match `colnames`.
        if data:
            tab = _metrics_table(data, classes, ids)
            tab.write(join(savedir, f"{current_iter-checkpoint_interval}_{current_iter}.csv"), overwrite=True)
        data = []
        classes = []
        ids = []

# Rows collected after the last checkpoint (if any).
if data:
    tab = _metrics_table(data, classes, ids)
    tab.write(join(savedir, "final.csv"), overwrite=True)

In [None]:
# Stack every metric table written to `savedir` into one table.
# Only *.csv files are read: the metrics loop also appends its failure log
# (errors.txt) to this directory, and that file is not a table.
# The list is sorted so the stacking order does not depend on the filesystem.
tables = []
dir_list = sorted(name for name in os.listdir(savedir) if name.endswith(".csv"))
for f in dir_list:
    print(f)
    tables.append(Table.read(join(savedir, f)))

t = vstack(tables)

# Distinct class labels and, aligned with them, the number of rows per class.
unique_classes, class_counts = np.unique(t["class"], return_counts=True)
n = class_counts.tolist()

In [None]:
# Minimum class size: classes with fewer than this many samples will be dropped (for now).
cutoff = 5_000

In [None]:
# Keep only classes with at least `cutoff` samples (a class with exactly `cutoff`
# samples is kept: only classes with fewer are meant to be dropped).
allowed_classes_idx = np.where(np.array(n) >= cutoff)[0]
allowed_classes = np.array(unique_classes[allowed_classes_idx])
t = t[np.isin(t["class"], allowed_classes)]

# Shuffle the rows with a fixed seed so the train/validation/test split that
# follows is the same on every run.
SPLIT_SEED = 42
rng = np.random.default_rng(SPLIT_SEED)
shuffle_idx = rng.permutation(len(t))
t = t[shuffle_idx]

In [None]:
# Fractions of the table assigned to the training, validation and test sets.
TRAIN, VAL, TEST = 0.7, 0.2, 0.1

In [None]:
# Cut the table into three consecutive slices: train, validation, then the rest as test.
n_rows = len(t)
train_end = int(TRAIN * n_rows)
valid_end = train_end + int(VAL * n_rows)

t_train, t_valid, t_test = t[:train_end], t[train_end:valid_end], t[valid_end:]

# Realised split fractions, shown as the cell output.
tuple(len(part) / n_rows for part in (t_train, t_valid, t_test))

In [None]:
# Directory for the train/validation/test tables; created if missing so the
# writes below cannot fail on a non-existent path.
# The settings object loaded from config.toml is bound to `cfg` (no `config` name exists).
TABDIR = cfg["table_dir"]
os.makedirs(TABDIR, exist_ok=True)

In [None]:
# Save the three splits as CSV. overwrite=True lets this cell be re-run
# without failing on files left by a previous run.
t_train.write(join(TABDIR, "big_train.csv"), overwrite=True)
t_valid.write(join(TABDIR, "valid.csv"), overwrite=True)
t_test.write(join(TABDIR, "test.csv"), overwrite=True)

In [None]:
# Balance the classes: keep at most `cutoff` randomly chosen training rows per class.
# A dedicated name is used for the per-class quota so the list of per-class
# counts held in `n` is not overwritten (the class-filter cell reads it on re-run).
samples_per_class = cutoff
rand_list = []
for c in allowed_classes:
    indices = np.where(t_train["class"] == c)[0]
    if len(indices) >= samples_per_class:
        # Enough rows: draw the quota without replacement.
        chosen_indices = rng.choice(indices, samples_per_class, replace=False)
    else:
        # Fewer rows than the quota in the training slice: keep them all.
        chosen_indices = indices
    rand_list.extend(chosen_indices)

split_table = t_train[rand_list]

In [None]:
# Save the class-balanced training table, replacing any previous copy.
train_csv = join(TABDIR, "train.csv")
split_table.write(train_csv, overwrite=True)