# Generate spectra


# MultiREx Autoencoder for Exoplanet Spectra Denoising (Optimized Version)

This script trains a deep learning model to remove noise and stellar contamination
from exoplanet transit spectra, with significant memory and performance optimizations.


## 1. Initial Setup and Imports


In [None]:
import multirex as mrex
import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
# import sys
import pandas as pd
import os
import re
import gc
import warnings
# import joblib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


## 2. Configuration and Helper Functions


In [None]:

def remove_warnings():
    """Silence the warning categories this notebook is known to trigger.

    Registers three "ignore" filters, in this order: every
    ``DeprecationWarning``, the pandas ``UserWarning`` emitted when an
    attribute is assigned on a DataFrame, and pandas ``PerformanceWarning``.
    """
    filter_specs = (
        {"category": DeprecationWarning},
        {
            "category": UserWarning,
            "message": "Pandas doesn't allow columns to be created via a new attribute name*",
        },
        {"category": pd.errors.PerformanceWarning},
    )
    # Each call inserts at the front of the filter list, so the iteration
    # order here fixes the resulting precedence.
    for spec in filter_specs:
        warnings.filterwarnings("ignore", **spec)

def configure_gpu_memory_growth():
    """Ask TensorFlow to grow GPU memory on demand instead of reserving it all.

    Does nothing when no GPU is listed. A ``RuntimeError`` raised while
    enabling memory growth is reported on stdout rather than propagated.
    """
    detected = tf.config.list_physical_devices('GPU')
    if not detected:
        return

    try:
        for device in detected:
            tf.config.experimental.set_memory_growth(device, True)
        print(f"TensorFlow memory growth set for {len(detected)} GPU(s).")
    except RuntimeError as err:
        # Memory growth has to be configured before the GPUs are initialized.
        print(f"GPU memory growth could not be set: {err}")

# Initial setup: install the warning filters and request on-demand GPU memory
# allocation. These run first because memory growth cannot be changed once the
# GPUs have been initialized.
remove_warnings()
configure_gpu_memory_growth()

# Load the wavelength grid from disk. Its length is used throughout the
# notebook as the number of spectral points, i.e. how many trailing DataFrame
# columns hold spectrum values.
waves = np.loadtxt("waves.txt")
n_points = len(waves)
# Note: The original notebook had redundant lines for 'indices' and 'puntos_seleccionados'
# which essentially re-assigned 'waves' to itself. Keeping it minimal here.
# Ascending grid of 10000 / wavelength.
# NOTE(review): presumably a micron -> cm^-1 wavenumber conversion; confirm the units of waves.txt.
wn_grid = np.sort((10000 / waves))

## 3. Data Loading and Preprocessing Functions


In [None]:
def apply_contaminations_from_files(contamination_files, df, n_points):
    """Build one DataFrame holding every stellar-contamination variant of ``df``.

    The result stacks, in order, an uncontaminated copy of ``df``
    (``f_spot = f_fac = 0.0``) followed by one contaminated copy per entry of
    ``contamination_files``. Each block keeps the row order of ``df``.

    Args:
        contamination_files: Paths to text files whose names end in
            ``fspot<number>_ffac<number>.txt``. The second column of a file
            (or its only column) must hold exactly ``n_points`` factors.
        df: Source DataFrame whose last ``n_points`` columns are the spectra.
            It is not modified.
        n_points: Number of trailing spectral columns.

    Returns:
        A DataFrame with ``f_spot`` and ``f_fac`` as leading columns. Two
        plain attributes are attached: ``.data`` (last ``n_points`` columns)
        and ``.params`` (all remaining columns).

    Raises:
        FileNotFoundError: If a listed file does not exist.
        ValueError: If a filename does not match the expected pattern, a file
            cannot be parsed, or it holds the wrong number of values.
    """
    # Compiled once instead of being re-resolved for every file.
    filename_pattern = re.compile(r"fspot(?P<f_spot>[0-9.]+)_ffac(?P<f_fac>[0-9.]+)\.txt$")

    # Output column order: contamination parameters first, then the original columns.
    cols = ["f_spot", "f_fac"] + [col for col in df.columns if col not in ("f_spot", "f_fac")]
    data_columns = df.columns[-n_points:]

    # Multiplying by the float64 factors read from disk would silently upcast
    # float32 spectra; remember the input dtypes so they can be restored.
    original_dtypes = df.dtypes[data_columns]
    restore_dtypes = all(pd.api.types.is_float_dtype(dtype) for dtype in original_dtypes)

    # Uncontaminated case (assign already returns a new frame).
    df_list = [df.assign(f_spot=0.0, f_fac=0.0)[cols]]

    for file_path in contamination_files:
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        filename = os.path.basename(file_path)
        match = filename_pattern.search(filename)
        if not match:
            raise ValueError(f"Filename '{filename}' does not match expected pattern.")

        f_spot = float(match.group("f_spot"))
        f_fac = float(match.group("f_fac"))

        try:
            contam_data = np.loadtxt(file_path, ndmin=2)
        except Exception as e:
            # Keep the original cause attached to the reported error.
            raise ValueError(f"Error reading {file_path}: {e}") from e

        # Take the second column if available, otherwise treat the file as a single column.
        contam_values = contam_data[:, 1] if contam_data.shape[1] >= 2 else contam_data.flatten()
        if len(contam_values) != n_points:
            raise ValueError(
                f"Error reading {file_path}: Contamination values in '{filename}' "
                f"({len(contam_values)}) != n_points ({n_points})."
            )

        # The factors are applied in reverse file order.
        # NOTE(review): assumes the file is ordered opposite to the spectral columns; confirm.
        contam_values = contam_values[::-1]

        scaled = df[data_columns].multiply(contam_values, axis=1)
        if restore_dtypes:
            scaled = scaled.astype(original_dtypes.to_dict())

        df_contam = df.copy()
        df_contam[data_columns] = scaled
        df_contam = df_contam.assign(f_spot=f_spot, f_fac=f_fac)
        df_list.append(df_contam[cols])

    # A single concat builds the final, consolidated frame.
    df_final = pd.concat(df_list, ignore_index=True)
    # Convenience views used by the rest of the notebook.
    df_final.data = df_final.iloc[:, -n_points:]
    df_final.params = df_final.iloc[:, :-n_points]
    return df_final

def filter_rows(df, n_points=None):
    """Keep only rows whose atmospheric abundance columns are all >= -8.

    The columns "atm CH4", "atm O3" and "atm H2O" are checked when present;
    missing ones are ignored. The input frame is never modified: a filtered
    copy is returned even when no filter column exists.

    Args:
        df: DataFrame whose last ``n_points`` columns hold the spectra.
        n_points: Number of trailing spectral columns. When omitted, the
            module-level ``n_points`` is used, which keeps existing
            ``filter_rows(df)`` calls working.

    Returns:
        A new DataFrame with ``.data`` (last ``n_points`` columns) and
        ``.params`` (remaining columns) attached as plain attributes.

    Raises:
        NameError: If ``n_points`` is omitted and no module-level value exists.
    """
    if n_points is None:
        try:
            n_points = globals()["n_points"]
        except KeyError:
            raise NameError("name 'n_points' is not defined") from None

    present = [col for col in ("atm CH4", "atm O3", "atm H2O") if col in df.columns]
    if present:
        # One combined mask and one copy instead of a full copy per column.
        keep = (df[present] >= -8).all(axis=1)
        filtered = df[keep].copy()
    else:
        filtered = df.copy()

    # Convenience views used by the rest of the notebook.
    filtered.data = filtered.iloc[:, -n_points:]
    filtered.params = filtered.iloc[:, :-n_points]
    return filtered

def load_and_prep_data(filepath, n_points):
    """Load a CSV and downcast its trailing ``n_points`` columns to float32.

    Only the spectral columns (the last ``n_points``) are converted; every
    other column keeps the dtype inferred by ``pandas.read_csv``.
    """
    frame = pd.read_csv(filepath)
    float32_map = {name: 'float32' for name in frame.columns[-n_points:]}
    return frame.astype(float32_map)

def normalize_min_max_by_row(df):
    """Rescale every row of ``df`` to the [0, 1] interval.

    Each row has its minimum subtracted and is divided by its own
    (max - min). Rows whose values are all equal are divided by 1 instead,
    so they come out as zeros rather than NaN.
    """
    row_min = df.min(axis=1)
    span = df.max(axis=1) - row_min
    # Constant rows have zero span; dividing by 1 avoids a 0/0.
    span.loc[span == 0] = 1
    return df.sub(row_min, axis=0).div(span, axis=0)

def generate_df_with_noise_std(df, n_repeat, noise_std, seed=None):
    """Replicate each spectrum ``n_repeat`` times and add zero-mean Gaussian noise.

    Spectra are read from ``df.data`` when that attribute exists, otherwise
    the whole of ``df`` is treated as spectra. Parameters are read from
    ``df.params`` when present. Every source row is repeated ``n_repeat``
    times consecutively, so output rows ``i * n_repeat`` through
    ``(i + 1) * n_repeat - 1`` all derive from source row ``i``.

    Args:
        df: Source DataFrame, optionally carrying ``.data`` / ``.params``.
        n_repeat: Number of noisy realizations per source row.
        noise_std: Noise standard deviation. Either a scalar (Python or NumPy)
            applied to every row, or a sequence with one value per source row.
        seed: Optional seed passed to ``np.random.seed`` for reproducibility.

    Returns:
        A DataFrame with columns ``noise_std``, ``n_repeat``, the replicated
        parameters (if any) and the noisy float32 spectra. ``.data`` holds the
        spectral columns and ``.params`` everything before them.

    Raises:
        ValueError: If ``noise_std`` is a sequence whose length differs from
            the number of source rows.
    """
    df_params = getattr(df, 'params', pd.DataFrame())
    df_spectra = getattr(df, 'data', df).astype('float32')
    n_rows, n_spectral = df_spectra.shape

    if seed is not None:
        np.random.seed(seed)

    # np.repeat returns a fresh array, so it can be modified in place below.
    noisy_values = np.repeat(df_spectra.values, n_repeat, axis=0)
    total_rows = noisy_values.shape[0]

    if np.ndim(noise_std) == 0:
        # Covers Python numbers and NumPy scalars of any precision.
        sigma_per_row = np.full(total_rows, float(noise_std))
        noise = np.random.normal(0, noise_std, noisy_values.shape)
    else:
        sigma = np.asarray(noise_std, dtype=float).ravel()
        if sigma.shape[0] != n_rows:
            raise ValueError(
                f"noise_std has {sigma.shape[0]} values but the data has {n_rows} rows."
            )
        sigma_per_row = np.repeat(sigma, n_repeat)
        # A column vector broadcasts one sigma across all points of a row.
        scale = sigma_per_row.astype('float32')[:, np.newaxis]
        noise = np.random.normal(0, scale, noisy_values.shape)

    noisy_values += noise.astype('float32')
    del noise
    noisy_spectra = pd.DataFrame(noisy_values, columns=df_spectra.columns)

    # Explicit index: a dict of scalars alone cannot build a DataFrame.
    new_cols_df = pd.DataFrame(
        {"noise_std": sigma_per_row, "n_repeat": n_repeat},
        index=pd.RangeIndex(total_rows),
    )

    parts = [new_cols_df]
    if df_params.shape[1] > 0:
        # Positional take keeps each parameter column's dtype and is safe
        # with non-unique index labels.
        row_positions = np.repeat(np.arange(len(df_params)), n_repeat)
        parts.append(df_params.iloc[row_positions].reset_index(drop=True))
    parts.append(noisy_spectra)

    df_final = pd.concat(parts, axis=1)

    # Split on the actual spectra width rather than a module-level constant.
    split_at = df_final.shape[1] - n_spectral
    df_final.data = df_final.iloc[:, split_at:]
    df_final.params = df_final.iloc[:, :split_at]
    return df_final



## 4. Load and Augment Source Data


In [None]:
# List of stellar contamination files
contamination_files = [
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.70.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.70.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.70.txt",
]


def _tile_clean_targets(df, n_copies):
    """Stack ``n_copies`` copies of ``df`` and attach ``.data`` / ``.params`` views."""
    tiled = pd.concat([df] * n_copies, ignore_index=True)
    tiled.data = tiled.iloc[:, -n_points:]
    tiled.params = tiled.iloc[:, :-n_points]
    return tiled


# Read each source CSV once; both the contaminated inputs and the clean
# targets are derived from the same in-memory frame.
try:
    airless_source = load_and_prep_data("spec_data/airless_data.csv", n_points)
    CO2_source = load_and_prep_data("spec_data/CO2_data.csv", n_points)
    # Only the CH4 dataset is filtered on "atm CH4" / "atm O3" / "atm H2O".
    CH4_source = filter_rows(load_and_prep_data("spec_data/CH4_data.csv", n_points))

    airless_data = apply_contaminations_from_files(contamination_files, airless_source, n_points)
    CO2_data = apply_contaminations_from_files(contamination_files, CO2_source, n_points)
    CH4_data = apply_contaminations_from_files(contamination_files, CH4_source, n_points)
except Exception as e:
    print(f"Error during initial data loading and processing: {e}")
    # Later cells need these frames, so stop here instead of failing with a NameError.
    raise

# Create "clean" versions for training targets (no stellar contamination).
# apply_contaminations_from_files returns the source rows once uncontaminated
# plus once per contamination file, stacked block after block. The targets
# are tiled the same number of times so that row i of every clean frame is the
# uncontaminated spectrum belonging to row i of its contaminated counterpart.
try:
    n_variants = len(contamination_files) + 1
    airless_data_clean = _tile_clean_targets(airless_source, n_variants)
    CO2_data_clean = _tile_clean_targets(CO2_source, n_variants)
    CH4_data_clean = _tile_clean_targets(CH4_source, n_variants)
except Exception as e:
    print(f"Error processing clean data for targets: {e}")
    raise

# The un-augmented source frames are no longer needed.
del airless_source, CO2_source, CH4_source
    


## 5. OPTIMIZED: Generate, Process, and Cache Full Dataset

This cell implements the caching logic. It first checks if the final NumPy arrays
(X_noisy and X_no_noisy) exist on disk. If they do, it loads them directly.
If not, it generates them in a memory-efficient loop and saves them for future runs.


In [None]:

    
output_dir = "processed_data"
os.makedirs(output_dir, exist_ok=True)  # Ensure the directory exists
noisy_path = os.path.join(output_dir, "X_noisy_full_dataset.npy")
clean_path = os.path.join(output_dir, "X_clean_full_dataset.npy")


def _normalized_spectra(sources, noise_std):
    """Return row-normalized float32 spectra for every ``(frame, n_repeat)`` pair.

    Each frame is expanded with ``generate_df_with_noise_std``, its last
    ``n_points`` columns are min-max normalized per row, and the resulting
    arrays are stacked in the order given by ``sources``.
    """
    blocks = []
    for frame, repeats in sources:
        generated = generate_df_with_noise_std(df=frame, n_repeat=repeats, noise_std=noise_std)
        # Slice the spectra per dataset, before stacking: a row-wise concat of
        # frames with different parameter columns would append the extra
        # columns after the spectra and shift the trailing n_points window.
        spectra = generated.iloc[:, -n_points:]
        blocks.append(normalize_min_max_by_row(spectra).values.astype("float32"))
        del generated, spectra
        gc.collect()
    return np.concatenate(blocks, axis=0)


# Caching Logic: Check if files exist
if os.path.exists(noisy_path) and os.path.exists(clean_path):
    print("Found cached data. Loading from disk...")
    X_noisy = np.load(noisy_path)
    X_no_noisy = np.load(clean_path)
    print("...Loading complete.")
else:
    print("Cached data not found. Starting data generation process (this may take a while)...")

    # Realizations per source row. Defined once so the noisy inputs and the
    # clean targets always use the same counts and stay row-aligned.
    repeats_CO2, repeats_CH4, repeats_airless = 5000, 500, 5000
    noisy_sources = [
        (CO2_data, repeats_CO2),
        (CH4_data, repeats_CH4),
        (airless_data, repeats_airless),
    ]
    clean_sources = [
        (CO2_data_clean, repeats_CO2),
        (CH4_data_clean, repeats_CH4),
        (airless_data_clean, repeats_airless),
    ]

    # The clean targets do not depend on the SNR level (noise_std is always 0),
    # so they are generated once and reused for every level.
    clean_block = _normalized_spectra(clean_sources, noise_std=0)

    list_of_noisy_arrays = []
    # Using None for the 'Nan' (no additional noise) case
    snr_values = [1, 3, 6, 10, None]

    for snr in snr_values:
        print(f"--- Processing SNR = {snr if snr is not None else 'inf (no additional noise)'} ---")
        gc.collect()  # Explicitly collect garbage before processing each SNR level

        # The noise level is taken from the first row of the "noise" column
        # that mrex.generate_df_SNR_noise returns for CO2_data, and is then
        # applied to all three datasets.
        noise = mrex.generate_df_SNR_noise(df=CO2_data, n_repeat=1, SNR=snr)["noise"][0] if snr is not None else 0.0

        list_of_noisy_arrays.append(_normalized_spectra(noisy_sources, noise))
        gc.collect()

    print("\n--- Final Data Concatenation ---")
    X_noisy = np.concatenate(list_of_noisy_arrays, axis=0)
    del list_of_noisy_arrays  # Free memory from the list of arrays

    # One copy of the clean block per SNR level, in the same order as X_noisy.
    X_no_noisy = np.tile(clean_block, (len(snr_values), 1))
    del clean_block

    gc.collect()  # Final garbage collection after concatenation

    print("\n--- Saving generated data to disk for future runs ---")
    np.save(noisy_path, X_noisy)
    np.save(clean_path, X_no_noisy)
    print(f"Data saved to directory: '{output_dir}'")

# Print final shapes and verify consistency
print(f"\nFinal noisy data shape: {X_noisy.shape}")
print(f"Final clean data shape: {X_no_noisy.shape}")
# Explicit check (not `assert`) so it still runs under `python -O`.
if X_noisy.shape[0] != X_no_noisy.shape[0]:
    raise ValueError("The number of samples does not match between noisy and clean datasets.")


## 6. Build and Train Autoencoder Model


### Create Optimized TensorFlow Datasets

Using `tf.data.Dataset` is the most memory-efficient way to handle large datasets for training.
It streams data to the GPU in batches, preventing VRAM overflow during validation.


In [None]:
# Constants shared by the split, the tf.data pipelines and the later
# evaluation cell (which repeats the split with the same values).
BATCH_SIZE = 64
TEST_SIZE = 0.2
RANDOM_STATE = 42

# Split the paired (noisy, clean) arrays into training and test subsets.
# Passing both arrays in one call keeps each noisy row paired with its clean row.
# NOTE(review): train_test_split is a scikit-learn/NumPy operation, so the
# tf.device scope probably has no effect on it — confirm before relying on it.
with tf.device('/CPU:0'):
    X_train_noisy, X_test_noisy, X_train_clean, X_test_clean = train_test_split(
        X_noisy, X_no_noisy, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

# Drop the references to the full arrays now that the splits exist.
del X_noisy, X_no_noisy
gc.collect()

# Build the input pipelines:
#   .shuffle(1024)  draws training samples from a 1024-element buffer
#   .batch()        groups samples into batches of BATCH_SIZE
#   .prefetch()     prepares upcoming batches while the current one is consumed
train_dataset = tf.data.Dataset.from_tensor_slices((X_train_noisy, X_train_clean))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# The validation pipeline is not shuffled.
validation_dataset = tf.data.Dataset.from_tensor_slices((X_test_noisy, X_test_clean))
validation_dataset = validation_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Drop the NumPy splits; from here on the data is reached through the datasets.
del X_train_noisy, X_test_noisy, X_train_clean, X_test_clean
gc.collect()

print("TensorFlow datasets created and source NumPy arrays cleared from RAM.")
# Note: the shapes of X_train_noisy/X_test_noisy cannot be printed here because
# those names were just deleted; use the datasets' element_spec instead.


### Define and Compile the Dense Autoencoder


### Final Data


In [None]:
# The NumPy arrays were deleted earlier, so the spectrum length is read from
# the dataset's element specification instead.
input_dim = train_dataset.element_spec[0].shape[1]

input_spectrum = keras.Input(shape=(input_dim,))

# Encoder: Dense(swish) + Dropout(0.2) pairs with widths 512, 512, 512, 300, 300.
encoded = input_spectrum
for units in (512, 512, 512, 300, 300):
    encoded = layers.Dense(units, activation="swish")(encoded)
    encoded = layers.Dropout(0.2)(encoded)

# Decoder: the same pairs with the widths mirrored, 300, 300, 512, 512, 512.
decoded = encoded
for units in (300, 300, 512, 512, 512):
    decoded = layers.Dense(units, activation="swish")(decoded)
    decoded = layers.Dropout(0.2)(decoded)

# Linear output layer with one unit per spectral point.
decoded = layers.Dense(input_dim, activation="linear")(decoded)

autoencoder = keras.Model(inputs=input_spectrum, outputs=decoded)

# Adam with a small learning rate, trained on mean absolute error.
optimizer = Adam(learning_rate=1e-5)
autoencoder.compile(optimizer=optimizer, loss="mae")

autoencoder.summary()


## 7. Train the Model


In [None]:
# Stop once the validation loss has not improved for 5 consecutive epochs and
# roll the model back to the weights of its best epoch.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

# Fit on the tf.data pipelines for at most 100 epochs.
history = autoencoder.fit(
    train_dataset,
    epochs=100,
    validation_data=validation_dataset,
    callbacks=[early_stopping],
)

# Persist the trained model in the native Keras format.
autoencoder.save("AE_CH4.keras")
print("Model training complete and saved to AE_CH4.keras")


## 8. Evaluate and Visualize Results


In [None]:
# Plot the per-epoch training and validation loss recorded by fit().
# The model was compiled with loss="mae", so both curves are MAE values.
print("Plotting training history...")
plt.figure(figsize=(10, 6))
plt.plot(history.history["loss"], label="Training MAE")
plt.plot(history.history["val_loss"], label="Validation MAE")
plt.title("Model MAE Progress During Training")
plt.xlabel("Epochs")
plt.ylabel("Mean Absolute Error (MAE)")
plt.legend()
plt.grid(True)
plt.show()

# The NumPy test split was deleted after the tf.data pipelines were built, so
# the full arrays are read back from the cache files and split again.
print("\nLoading test data for final evaluation and prediction...")
# noisy_path and clean_path are defined in the caching cell.
X_noisy_full = np.load(noisy_path)
X_clean_full = np.load(clean_path)

# Repeat the split with the same TEST_SIZE and RANDOM_STATE as the training cell.
# NOTE(review): this yields the same test rows only if the cache files hold
# exactly the arrays that were split for training — confirm.
_, X_test_noisy_eval, _, X_test_clean_eval = train_test_split(
    X_noisy_full, X_clean_full, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

# Only the test split is needed from here on.
del X_noisy_full, X_clean_full
gc.collect()
print("Test data re-loaded successfully.")

print("\nPredicting on test data...")
# Run the autoencoder over the noisy test spectra in batches of BATCH_SIZE.
decoded_spectra = autoencoder.predict(X_test_noisy_eval, batch_size=BATCH_SIZE)

# Visualize a few reconstructions.
print("Visualizing sample reconstructions...")
num_samples = 5 # Number of samples to visualize
# Draw distinct test indices at random (unseeded, so they differ between runs).
indices = np.random.choice(len(X_test_noisy_eval), num_samples, replace=False)

# One figure per sample: clean target, noisy input and model output,
# all plotted against the wavelength grid loaded at the top of the notebook.
for idx in indices:
    plt.figure(figsize=(12, 5))
    plt.plot(waves, X_test_clean_eval[idx].flatten(), label="Original Clean Spectrum", color='blue')
    plt.plot(waves, X_test_noisy_eval[idx].flatten(), label="Noisy Input Spectrum", color='gray', alpha=0.6)
    plt.plot(
        waves,
        decoded_spectra[idx].flatten(),
        label="Denoised (Reconstructed) Spectrum",
        linestyle="--",
        color='red'
    )
    plt.xlabel("Wavelength")
    plt.ylabel("Normalized Intensity")
    plt.title(f"Spectrum Reconstruction - Sample {idx}")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

print("Evaluation complete.")

### Final Performance Metrics


In [None]:
print("\nCalculating final performance metrics...")

# MAE between the clean test spectra and the reconstructions.
mae = mean_absolute_error(X_test_clean_eval, decoded_spectra)
print(f"Final Mean Absolute Error (MAE): {mae:.6f}")

# MSE over the same pair of arrays.
mse = mean_squared_error(X_test_clean_eval, decoded_spectra)
print(f"Final Mean Squared Error (MSE): {mse:.6f}")

# R² computed over all points at once: both arrays are collapsed to 1-D so
# every spectral point of every sample counts as one observation.
targets_flat = X_test_clean_eval.ravel()
predictions_flat = decoded_spectra.ravel()
r2 = r2_score(targets_flat, predictions_flat)
print(f"Final Coefficient of Determination (R²): {r2:.6f}")
