In [None]:
import astropy.units as u
import h5py
import lightkurve as lk
import matplotlib.pyplot as plt
import numpy as np
import os
from os.path import exists
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import sys
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

sys.path.append('..')
from file_editing import *
from period_finding import *

In [None]:
# Global configuration.
# Exposure time, in seconds, of the TESS light curves to keep (120 s = 2-minute cadence).
cadence = 120

In [None]:
# One-time conversion of the raw catalogue into a comma-separated file;
# the step is skipped when the converted file is already on disk.
csv_filename = 'cnn_data.csv'
if not os.path.exists(csv_filename):
    commaize('raw_cnn_data.csv', csv_filename)

In [None]:
# Load the catalogue, keeping only the star name, `i` (used as the star's
# i magnitude), the orbital period `porb` (in hours) and `porbe`
# (presumably the uncertainty on porb — not used further here).
df = pd.read_csv(csv_filename).loc[:, ['iau_name', 'i', 'porb', 'porbe']]

In [None]:
def _pad_or_truncate(values, target_length):
    """Return ``values`` zero-padded at the end, or cut, to exactly ``target_length`` entries."""
    values = np.asarray(values)
    if len(values) >= target_length:
        return values[:target_length]
    return np.pad(values, (0, target_length - len(values)), mode='constant')


def save_data_hdf5(file_name, lightcurve, periodogram, best_period, literature_period, star_name, star_imag, target_length):
    """Compute the features of one star and append them to an HDF5 file.

    The light curve is folded and binned at ``best_period``, and a sine model
    (``sine_wave``) is fitted to the raw, unpadded samples. The time, flux, fitted
    model and residual arrays are then zero-padded or truncated to
    ``target_length``. Everything is written to a new group named ``star_name``.

    Parameters
    ----------
    file_name : str
        HDF5 file to write to; opened in append (``'a'``) mode.
    lightcurve : light-curve object
        Must provide ``.fold(period=...)``, ``.time.value`` and ``.flux.value``
        (a lightkurve light curve when called from this notebook).
    periodogram : periodogram object
        Must provide ``.period.value`` and ``.power.value``.
    best_period : float
        Candidate period, in the same unit as ``periodogram.period``.
    literature_period : float or None
        Catalogue period; stored as the ``literature_period`` attribute when given.
    star_name : str
        Name of the group to create; must not already exist in the file.
    star_imag : float
        Stored as the ``star_imag`` attribute.
    target_length : int
        Length of the stored ``time``, ``flux``, ``fitted_sine_wave`` and
        ``residuals`` datasets.

    Raises
    ------
    ValueError
        If a group named ``star_name`` already exists in ``file_name``.
    """
    from scipy.interpolate import interp1d

    # Label from the project helper is_real_period (whether best_period looks real);
    # the returned cutoff is not stored.
    is_real, _cutoff = is_real_period(periodogram, best_period)

    # Define folded and binned lightcurve
    phase_lightcurve = lightcurve.fold(period=best_period)
    bin_value = find_bin_value(phase_lightcurve, 50)
    binned_lightcurve = phase_lightcurve.bin(bin_value*u.min)

    # Raw (unpadded) lightcurve samples
    time = lightcurve.time.value
    flux = lightcurve.flux.value

    # Initial sine amplitude: periodogram power interpolated at best_period
    interp_func = interp1d(periodogram.period.value, periodogram.power.value, kind='linear', bounds_error=False, fill_value=np.nan)
    power_at_best_period = interp_func(best_period)

    # Fit the sine model to the real samples only. Fitting after zero-padding would
    # treat every padded (time=0, flux=0) point as an observation and bias the fit.
    model = lmfit.Model(sine_wave)
    params = model.make_params(amplitude=power_at_best_period, frequency=1/best_period, phase=0.0)
    result = model.fit(flux, params, x=time)
    fitted = np.asarray(result.best_fit)
    residual = flux - fitted

    # Fixed-length arrays for storage; padded entries are 0 in every dataset
    time = _pad_or_truncate(time, target_length)
    flux = _pad_or_truncate(flux, target_length)
    fitted = _pad_or_truncate(fitted, target_length)
    residual = _pad_or_truncate(residual, target_length)

    # Save data to HDF5
    with h5py.File(file_name, 'a') as f:
        if star_name in f:
            raise ValueError(f'Group {star_name!r} already exists in {file_name}')
        grp = f.create_group(star_name)
        grp.create_dataset('periodogram_period', data=periodogram.period.value)
        grp.create_dataset('periodogram_power', data=periodogram.power.value)
        grp.create_dataset('time', data=time)
        grp.create_dataset('flux', data=flux)
        grp.create_dataset('binned_phase', data=binned_lightcurve.phase.value)
        grp.create_dataset('binned_flux', data=binned_lightcurve.flux.value)
        grp.create_dataset('fitted_sine_wave', data=fitted)
        grp.create_dataset('residuals', data=residual)
        grp.attrs['star_imag'] = star_imag
        grp.attrs['real_label'] = is_real
        grp.attrs['best_period'] = best_period
        if literature_period is not None:
            grp.attrs['literature_period'] = literature_period

In [None]:
# Build the training set (one lk.search_lightcurve call per star, so this is slow).
# Delete training_data.h5 to force a rebuild.
training_file = 'training_data.h5'
partial_file = training_file + '.partial'
n_stars = 2000          # number of catalogue rows to process
target_length = 10000   # samples per stored time/flux array

if exists(training_file):
    print('File training_data.h5 already exists.')
else:
    # Write into a temporary file and rename it only after the loop finishes, so an
    # interrupted run cannot leave a half-filled training_data.h5 that the exists()
    # check above would then treat as complete.
    if exists(partial_file):
        os.remove(partial_file)  # stale leftover; save_data_hdf5 appends ('a' mode)

    stars = df.head(n_stars)
    skipped = []  # (iau_name, reason) for every star that was not saved
    for _, row in tqdm(stars.iterrows(), desc="Processing lightcurves", total=len(stars)):
        # Pull data for that star
        try:
            result = lk.search_lightcurve(row['iau_name'], mission='TESS')
            result_exposures = result.exptime
        except Exception as e:
            skipped.append((row['iau_name'], f'search failed: {e}'))
            continue

        lightcurve = append_lightcurves(result, result_exposures, cadence)
        if not lightcurve:  # no result with the cadence needed
            skipped.append((row['iau_name'], f'no light curve with {cadence} s cadence'))
            continue

        try:
            # Star data
            star_name = 'TIC ' + str(lightcurve.meta['TICID'])
            star_imag = row['i']
            literature_period = (row['porb']*u.hour).to(u.day).value

            # Periodogram from twice the cadence (the Nyquist limit) up to 14, in the
            # same unit as minimum_period (days)
            periodogram = lightcurve.to_periodogram(oversample_factor=10,
                                                    minimum_period=(2*cadence*u.second).to(u.day).value,
                                                    maximum_period=14)

            # Candidate period: the periodogram peak
            best_period = periodogram.period_at_max_power.value

            save_data_hdf5(partial_file, lightcurve, periodogram, best_period, literature_period, star_name, star_imag, target_length)
        except Exception as e:
            # One bad star (failed fit, TIC id already saved, ...) must not abort the whole run.
            skipped.append((row['iau_name'], f'processing failed: {e}'))

    if exists(partial_file):
        os.replace(partial_file, training_file)
    print(f'Saved {len(stars) - len(skipped)} of {len(stars)} stars to {training_file}; skipped {len(skipped)}.')

In [None]:
import h5py
import numpy as np

def load_data(hdf5_filename):
    """Read every star group of an HDF5 training file back into Python lists.

    Each top-level group (one per star, iterated in ``keys()`` order) contributes
    one entry to each list.

    Returns
    -------
    tuple of 9 lists
        (periodogram_periods, periodogram_powers, times, fluxes, binned_phases,
        binned_fluxes, fitted_sine_waves, residuals, labels); the first eight hold
        the full contents of the matching datasets, ``labels`` holds each group's
        ``real_label`` attribute.
    """
    dataset_keys = ('periodogram_period', 'periodogram_power', 'time', 'flux',
                    'binned_phase', 'binned_flux', 'fitted_sine_wave', 'residuals')
    collected = {key: [] for key in dataset_keys}
    real_labels = []

    with h5py.File(hdf5_filename, 'r') as h5:
        for name in list(h5.keys()):
            star_group = h5[name]
            for key in dataset_keys:
                collected[key].append(star_group[key][:])
            real_labels.append(star_group.attrs['real_label'])

    return tuple(collected[key] for key in dataset_keys) + (real_labels,)

In [None]:
# Read every star group back from the HDF5 training file
(periodogram_periods, periodogram_powers, times, fluxes,
 binned_phases, binned_fluxes, fitted_sine_waves, residuals,
 labels) = load_data('training_data.h5')

In [None]:
def _pad_rows(rows, length):
    """Return a float array of shape (len(rows), length), zero-padding or truncating each row."""
    matrix = np.zeros((len(rows), length))
    for i, row in enumerate(rows):
        values = np.ravel(row)[:length]
        matrix[i, :values.size] = values
    return matrix


def preprocess_data(periodogram_periods, periodogram_powers, times, fluxes, binned_phases, binned_fluxes, fitted_sine_waves, residuals, labels, test_size=0.2, random_state=42):
    """Scale the per-star features and split them into training/validation sets.

    Each feature argument is a sequence with one array per star, as returned by
    ``load_data``. The rows of every feature are zero-padded (or truncated) to one
    common length, so features of different lengths (e.g. periodograms vs. the
    fixed-length time/flux arrays) can be stacked. When all rows already share a
    length, this is a no-op.

    The train/validation split is made *before* scaling, and each feature's
    ``StandardScaler`` is fitted on the training rows only. This keeps validation
    statistics from leaking into the training data.

    Parameters
    ----------
    periodogram_periods, periodogram_powers, times, fluxes, binned_phases,
    binned_fluxes, fitted_sine_waves, residuals : sequence of array-like
        One entry per star for each feature.
    labels : sequence
        One label per star.
    test_size : float, default 0.2
        Fraction of stars placed in the validation set.
    random_state : int, default 42
        Seed passed to ``train_test_split`` for a reproducible split.

    Returns
    -------
    X_train, X_val : ndarray, shape (n_rows, length, 8)
        Scaled features; the last axis follows the feature argument order above.
    y_train, y_val : ndarray
        Labels of the corresponding rows.
    """
    features = [periodogram_periods, periodogram_powers, times, fluxes,
                binned_phases, binned_fluxes, fitted_sine_waves, residuals]
    labels = np.array(labels)

    # Ensure all arrays have the same number of samples
    n_samples = labels.shape[0]
    for feature in features:
        assert len(feature) == n_samples, 'every feature needs exactly one entry per label'

    # One common length for every feature so np.stack below can combine them
    length = max((np.size(row) for feature in features for row in feature), default=0)
    matrices = [_pad_rows(feature, length) for feature in features]

    # Split row indices first. The partition depends only on the number of samples
    # and the seed, so it matches splitting the stacked array directly.
    train_idx, val_idx, y_train, y_val = train_test_split(
        np.arange(n_samples), labels, test_size=test_size, random_state=random_state)

    # Normalize each feature with statistics from the training rows only
    train_columns, val_columns = [], []
    for matrix in matrices:
        scaler = StandardScaler().fit(matrix[train_idx])
        train_columns.append(scaler.transform(matrix[train_idx]))
        val_columns.append(scaler.transform(matrix[val_idx]))

    # Stack the different features along a trailing channel axis
    X_train = np.stack(train_columns, axis=-1)
    X_val = np.stack(val_columns, axis=-1)
    return X_train, X_val, y_train, y_val

In [None]:
# Normalize the features and split them into training and validation sets
X_train, X_val, y_train, y_val = preprocess_data(
    periodogram_periods=periodogram_periods,
    periodogram_powers=periodogram_powers,
    times=times,
    fluxes=fluxes,
    binned_phases=binned_phases,
    binned_fluxes=binned_fluxes,
    fitted_sine_waves=fitted_sine_waves,
    residuals=residuals,
    labels=labels,
)

In [None]:

# def create_cnn_model(input_shape):
#     model = Sequential([
#         Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
#         MaxPooling2D((2, 2)),
#         Dropout(0.25),
#         Conv2D(64, (3, 3), activation='relu'),
#         MaxPooling2D((2, 2)),
#         Dropout(0.25),
#         Flatten(),
#         Dense(128, activation='relu'),
#         Dropout(0.5),
#         Dense(1, activation='sigmoid')
#     ])
#     return model