In [None]:
import numpy as np
import pickle
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


In [None]:

# Load the synthetic light-curve dataset. It is expected to be a mapping keyed by
# tic_id whose values hold "flux" and "labels" sequences.
# SECURITY: unpickling can execute arbitrary code — only open files you created yourself.
synthetic_pickle_path = "synthetic_lightcurves.pkl"
with open(synthetic_pickle_path, mode="rb") as pickle_file:
    synthetic_lightcurves = pickle.load(pickle_file)

# Prepare time-series dataset for LSTM.
# Each sample is `sequence_length` consecutive flux values (min-max scaled per
# light curve); its target is the label of the timestep right after the window.
sequence_length = 100  # Number of past timesteps to use for prediction
TEST_SIZE = 0.2        # fraction of light curves held out for testing
SPLIT_SEED = 42        # makes the train/test split reproducible


def make_windows(flux_scaled, labels, window):
    """Slice one light curve into (windows, targets).

    windows[i] = flux_scaled[i : i + window] and targets[i] = labels[i + window]
    for i in range(len(flux_scaled) - window). Assumes len(flux_scaled) > window.
    Windows containing NaN (if any) are dropped because they would make the
    training loss NaN.

    Raises ValueError if `labels` is shorter than `flux_scaled`.
    """
    if len(labels) < len(flux_scaled):
        raise ValueError(
            f"labels ({len(labels)}) is shorter than flux ({len(flux_scaled)})"
        )
    n_windows = len(flux_scaled) - window
    # Strided view of every length-`window` slice; the final slice has no
    # following timestep to predict, so keep only the first `n_windows`.
    windows = np.lib.stride_tricks.sliding_window_view(flux_scaled, window)[:n_windows]
    targets = labels[window : window + n_windows]
    finite = ~np.isnan(windows).any(axis=1)
    return windows[finite], targets[finite]


windows_by_curve = {}  # tic_id -> (windows, targets)
for tic_id, data in synthetic_lightcurves.items():
    flux = np.array(data["flux"])
    labels = np.array(data["labels"])
    if len(flux) <= sequence_length:
        continue  # too short to yield even one (window, next-label) pair

    # Normalize flux values; each curve gets its own scaler.
    flux_scaled = MinMaxScaler().fit_transform(flux.reshape(-1, 1)).ravel()
    windows, targets = make_windows(flux_scaled, labels, sequence_length)
    if len(targets):
        windows_by_curve[tic_id] = (windows, targets)

if len(windows_by_curve) < 2:
    raise ValueError(
        f"Need at least 2 light curves longer than {sequence_length} points to split "
        f"by curve; got {len(windows_by_curve)}."
    )

# Full (unsplit) dataset: X is (n_windows, sequence_length), y is (n_windows,).
X = np.concatenate([w for w, _ in windows_by_curve.values()])
y = np.concatenate([t for _, t in windows_by_curve.values()])

# Split by light curve, not by window: consecutive windows share
# sequence_length - 1 points, so a window-level random split would place
# near-identical samples in both train and test and inflate test scores.
train_ids, test_ids = train_test_split(
    list(windows_by_curve), test_size=TEST_SIZE, random_state=SPLIT_SEED
)


def stack_curves(ids):
    """Concatenate windows/targets for `ids`; X gets a trailing feature axis for the LSTM."""
    X_part = np.concatenate([windows_by_curve[i][0] for i in ids])
    y_part = np.concatenate([windows_by_curve[i][1] for i in ids])
    return X_part.reshape((X_part.shape[0], X_part.shape[1], 1)), y_part


X_train, y_train = stack_curves(train_ids)
X_test, y_test = stack_curves(test_ids)

print(
    f"train: {len(y_train)} windows from {len(train_ids)} curves; "
    f"test: {len(y_test)} windows from {len(test_ids)} curves"
)
if np.unique(y_test).size < 2:
    print("WARNING: the test set contains only one class; accuracy will not be informative.")

# Fix all RNG seeds (Python, NumPy, TensorFlow) so weight initialization,
# dropout masks and batch shuffling are reproducible on Restart & Run All.
tf.keras.utils.set_random_seed(42)

# Define LSTM model: two stacked LSTMs read a (sequence_length, 1) window and a
# sigmoid head outputs the probability that the next timestep is a flare.
# The input shape is declared with an Input layer rather than `input_shape=`
# on the first LSTM, which current Keras discourages.
model = Sequential([
    tf.keras.Input(shape=(sequence_length, 1)),
    LSTM(64, return_sequences=True),
    Dropout(0.2),
    LSTM(32, return_sequences=False),
    Dropout(0.2),
    Dense(16, activation="relu"),
    Dense(1, activation="sigmoid"),  # Binary classification (flare/no flare)
])

# Compile model
model.compile(loss="binary_crossentropy", optimizer=Adam(learning_rate=0.001), metrics=["accuracy"])

# Train model.
# Class weights follow the "balanced" scheme, n_samples / (n_classes * count_c),
# so a rare class is not ignored. When the classes are balanced, every weight is ~1.
train_classes, train_counts = np.unique(y_train, return_counts=True)
class_weights = {
    int(c): float(len(y_train) / (len(train_classes) * n))
    for c, n in zip(train_classes, train_counts)
}

history = model.fit(
    X_train,
    y_train,
    epochs=20,
    batch_size=32,
    # Validate on a slice of the training data so the test set stays unseen
    # until the final evaluation below.
    validation_split=0.1,
    class_weight=class_weights,
    verbose=1,
)

# Evaluate model
loss, accuracy = model.evaluate(X_test, y_test)
print(f"Test Accuracy: {accuracy:.4f}")

# Accuracy alone is misleading when one class dominates. Also report precision,
# recall and F1 for the positive (flare) class at a 0.5 probability threshold.
y_pred = (model.predict(X_test, verbose=0).ravel() >= 0.5).astype(int)
y_true = np.asarray(y_test).astype(int)
tp = int(np.sum((y_pred == 1) & (y_true == 1)))
fp = int(np.sum((y_pred == 1) & (y_true == 0)))
fn = int(np.sum((y_pred == 0) & (y_true == 1)))
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print(f"Test Precision: {precision:.4f}  Recall: {recall:.4f}  F1: {f1:.4f}")
