In [None]:
"""
TESS/ZTF Transient Classification Project

This notebook can be used to run the entire ML pipeline in one place.
"""

In [None]:
"""
Step 1: Preprocessing

Reads in raw light curve data and classification info and processes the data.

Processing steps include:
    - numeric encodings of class
    - use of TESS/ZTF filter IDs
    - timestep creation for each filter ID occurrence
    - cut light curves to a specific time range
    - skip light curves with no data
    - data augmentation so all light curves have same # of timesteps
"""
from pre_process import plot_specific, prepare_NN_data, read_raw_data

# Input locations: the light-curve directory and the transients CSV.
# NOTE(review): both are absolute, machine-specific paths, and lc_path has a
# doubled slash after "Documents". The strings are kept exactly as written
# because how read_raw_data consumes them is not visible from this notebook.
lc_path = '/Users/drewj/Documents//Urops/Muthukrishna/data/processed_curves/'
transient_path = '/Users/drewj/Documents/Urops/Muthukrishna/data/all_transients.csv'

# Read the raw files into DataFrames; the three results are unpacked in order.
light_curves, original_curves, all_transients = read_raw_data(
    lc_path,
    transient_path,
)

# Plot one named light curve from the original (unprocessed) curves.
plot_specific(original_curves, '2018fzi')

# Build the neural-network input from the light curves.
prepared_data = prepare_NN_data(light_curves)


In [None]:
"""
Step 2: Recurrent Variational Autoencoder
*Unsupervised*
    
Builds a variational autoencoder that takes in time-series
light curve data and produces lower-dimensional representations to be used
for classification.

Trains and tests the model, extracts the encoder.

Plots a 2D t-SNE representation of light curves in their latent space.
"""
from NN_model import RVAE

# Wrap the prepared data (produced by the preprocessing step) in an RVAE object.
rvae = RVAE(prepared_data)

# Partition the prepared data into training and testing sets.
x_train, x_test, y_train, y_test = rvae.split_prep_data()

# Construct the full model together with its encoder.
model, encoder = rvae.build_connected_model()

# Train the model on the split produced above.
trained_model = rvae.train_model(
    model,
    x_train,
    x_test,
    y_train,
    y_test,
)

# Evaluate the trained model on the test split.
rvae.test_model(trained_model, x_test, y_test)

# t-SNE plot of the light curves, using the encoder.
rvae.t_SNE_plot(light_curves, encoder)

In [None]:
"""
Step 3: Balanced Random Forest Classifier
*Supervised*

Creates a Balanced Random Forest Classifier that takes in encoded light curves and classifies them.

Uses the trained encoder from the RVAE model to encode light curves.

Trains the classifier on labeled data, then tests it on both labeled and unlabeled data.
"""
from classify import RandomForest

# Classifier wrapper built from the light curves, the prepared data and the
# encoder obtained in the RVAE step.
rf = RandomForest(light_curves, prepared_data, encoder)

# Split the data set for supervised training; a fifth result holds the
# unclassified inputs.
# NOTE(review): x_train / x_test / y_train / y_test are also assigned by the
# RVAE step earlier in this notebook, so running this cell rebinds them.
x_train, x_test, y_train, y_test, x_unclassified = rf.create_test_train()

# Encode each input set.
x_train_enc, x_test_enc, x_unclassified_enc = rf.make_encodings(
    x_train,
    x_test,
    x_unclassified,
)

# Build and train the classifier.
# NOTE(review): `build_classier` looks like a typo for `build_classifier`;
# confirm against classify.RandomForest before renaming.
rf.build_classier(
    x_train_enc,
    x_test_enc,
    x_unclassified_enc,
    y_train,
    y_test,
)

# Classify one example light curve, selected by name.
rf.classify(original_curves, filename='2018evo')