In [None]:
import numpy as np
import pickle
import scipy.stats as stats
from scipy.signal import find_peaks
from altaipony.flarelc import FlareLightCurve

In [None]:
# Load the processed light curves written by the preprocessing step.
# The injection loop below iterates this with .items() (keys used as TIC IDs)
# and reads each value's "time" and "detrended_flux" entries.
# SECURITY: pickle.load can execute arbitrary code -- only open files you trust.
pickle_path = "processed_lightcurves.pkl"  # Replace with actual path
with open(pickle_path, mode="rb") as pickle_file:
    processed_lightcurves = pickle.load(pickle_file)

In [None]:


def generate_synthetic_flare(time, flux, num_flares_range=(1, 10), amplitude_range=(0.01, 0.1),
                             decay_time_range=(0.1, 0.5), rng=None):
    """
    Inject synthetic exponential-decay flares into a light curve.

    ``flux`` itself is never modified; a new array with the flares added is
    returned. Every sample before a flare's onset keeps its original value,
    and each flare only adds flux from its onset sample onwards.

    Parameters
    ----------
    time : array-like
        Sample timestamps, same length as ``flux``. Assumed sorted ascending,
        since the decay timescale is derived from the first/last finite
        timestamp (TODO confirm for all inputs). NaN timestamps never receive
        injected flux.
    flux : array-like
        Flux values. May contain NaNs; the amplitude scale uses
        ``np.nanmedian`` so a single missing sample does not turn every
        injected value into NaN. Integer input is promoted to float.
    num_flares_range : tuple of int
        ``(low, high)`` for the number of flares. ``high`` is exclusive
        (NumPy ``randint``/``integers`` semantics), so the default injects 1-9.
    amplitude_range : tuple of float
        Peak amplitude as a fraction of the median flux, drawn uniformly.
    decay_time_range : tuple of float
        e-folding decay time in units of the mean sampling interval
        ``(t_last - t_first) / len(time)``, drawn uniformly.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Randomness source for reproducible injections. When ``None`` the
        legacy global ``np.random`` state is used (seed via ``np.random.seed``).

    Returns
    -------
    synthetic_flux : numpy.ndarray
        Float copy of ``flux`` with the flares added.
    flare_labels : numpy.ndarray
        Same dtype/shape as ``synthetic_flux``; 1 where any injected flare
        exceeds 1% of its peak amplitude, otherwise 0.

    Raises
    ------
    ValueError
        If ``time`` and ``flux`` have different shapes.
    """
    time_arr = np.asarray(time, dtype=float)
    flux_arr = np.asarray(flux)
    if flux_arr.dtype.kind != "f":
        # In-place addition into an integer array would silently truncate flares.
        flux_arr = flux_arr.astype(float)
    if time_arr.shape != flux_arr.shape:
        raise ValueError(
            f"time and flux must have the same shape, got {time_arr.shape} and {flux_arr.shape}"
        )

    synthetic_flux = flux_arr.copy()
    flare_labels = np.zeros_like(flux_arr)
    n_samples = flux_arr.size
    if n_samples == 0:
        return synthetic_flux, flare_labels

    if rng is None:
        randint, uniform = np.random.randint, np.random.uniform
    else:
        randint = rng.integers if hasattr(rng, "integers") else rng.randint
        uniform = rng.uniform

    # Loop-invariant quantities: the amplitude scale and the mean cadence.
    median_flux = np.nanmedian(flux_arr)
    finite_time = time_arr[np.isfinite(time_arr)]
    time_span = finite_time[-1] - finite_time[0] if finite_time.size else 0.0
    mean_cadence = time_span / n_samples
    # Onsets are drawn from [0, n-2] as before; max() keeps the draw valid
    # for a single-sample light curve.
    start_high = max(n_samples - 1, 1)

    num_flares = randint(*num_flares_range)  # Vary the number of flares per star
    for _ in range(num_flares):
        flare_start_idx = randint(0, start_high)
        amplitude = uniform(*amplitude_range) * median_flux
        decay_time = uniform(*decay_time_range) * mean_cadence

        # Samples at/after the onset index whose time is not before the onset;
        # NaN time differences compare False and are skipped.
        time_diff_all = time_arr[flare_start_idx:] - time_arr[flare_start_idx]
        idx = flare_start_idx + np.flatnonzero(time_diff_all >= 0)
        time_diff = time_arr[idx] - time_arr[flare_start_idx]

        if decay_time > 0:
            profile = np.exp(-time_diff / decay_time)
        else:
            # Degenerate span (single sample / constant timestamps): avoid 0/0
            # and confine the flare to samples coincident with the onset.
            profile = (time_diff == 0).astype(float)

        decay_flux = amplitude * profile
        synthetic_flux[idx] += decay_flux
        # abs() keeps the 1%-of-peak threshold meaningful if the median is negative.
        flare_labels[idx[np.abs(decay_flux) > 0.01 * np.abs(amplitude)]] = 1

    return synthetic_flux, flare_labels

# Apply synthetic flare injection to each star.
# Seed the global NumPy RNG (used by generate_synthetic_flare's random draws)
# so Restart & Run All reproduces the same synthetic dataset.
INJECTION_SEED = 42
np.random.seed(INJECTION_SEED)

synthetic_lightcurves = {}
for tic_id, data in processed_lightcurves.items():
    time = data["time"]
    flux = data["detrended_flux"]

    # Run AltaiPony detection on the un-injected curve so the real flares it
    # finds are stored alongside the synthetic labels.
    # NOTE(review): `flux` is already the preprocessed "detrended_flux", so this
    # detrends a second time -- confirm that is intended.
    flc = FlareLightCurve(time=time, flux=flux)
    flcd = flc.detrend("savgol")
    flcd = flcd.find_flares()

    # Inject synthetic flares; samples before each flare's onset are untouched.
    # NOTE(review): real flares in flcd.flares are not merged into `labels`.
    synthetic_flux, labels = generate_synthetic_flare(time, flcd.detrended_flux)

    # Store new dataset. Flux arrays are copied so later edits to this dict
    # cannot mutate processed_lightcurves (previously "original_flux" aliased it).
    synthetic_lightcurves[tic_id] = {
        "time": time,
        "flux": synthetic_flux,
        "original_flux": flux.copy(),
        "detrended_flux": flcd.detrended_flux.copy(),
        "flares": flcd.flares.copy() if flcd.flares is not None else None,
        "labels": labels,
        "pre_flare_flux": flux.copy()  # Un-injected flux, kept for ML predictions
    }



In [None]:
# Persist the synthetic dataset as a pickle (relative path, i.e. the current
# working directory), mirroring the format of the input file.
synthetic_pickle_path = "synthetic_lightcurves.pkl"
with open(synthetic_pickle_path, mode="wb") as output_file:
    pickle.dump(synthetic_lightcurves, output_file)

print("Synthetic dataset created with variable flares, preserved pre-flare properties, and consistent flare detection.")
