# Generate spectra


# MultiREx Autoencoder for Exoplanet Spectra Denoising (All Materials)

This script trains a single, robust deep learning model to remove noise and stellar
contamination from a comprehensive set of exoplanet transit spectra, including various
biosignatures (CH4, H2O, O3), CO2, and airless scenarios.

This version is fully optimized for memory efficiency on both CPU and GPU.


## 1. Initial Setup and Imports


In [20]:
import multirex as mrex
import matplotlib.pyplot as plt
# import seaborn as sns
import numpy as np
# import sys
import pandas as pd
import os
import re
import gc
import warnings
# import joblib
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score


## 2. Configuration and Helper Functions


In [21]:
def remove_warnings():
    """Suppresses specified warnings for cleaner output."""
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="Pandas doesn't allow columns to be created via a new attribute name*",
    )
    warnings.filterwarnings("ignore", category=pd.errors.PerformanceWarning)

def configure_gpu_memory_growth():
    """Prevents TensorFlow from allocating all GPU memory at once."""
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
      try:
        for gpu in gpus:
          tf.config.experimental.set_memory_growth(gpu, True)
        print(f"TensorFlow memory growth set for {len(gpus)} GPU(s).")
      except RuntimeError as e:
        print(f"GPU memory growth could not be set: {e}")

# Initial setup
remove_warnings()
configure_gpu_memory_growth()

# Load and prepare wavelength data
waves = np.loadtxt("waves.txt")
n_points = len(waves)
wn_grid = np.sort((10000 / waves))


def apply_contaminations_from_files(contamination_files, df, n_points):
    """
    Applies multiple stellar contaminations to the spectral data.
    This version is optimized to prevent DataFrame fragmentation.
    """
    df_list = []
    # Add non-contaminated case
    df_no_contam = df.copy().assign(f_spot=0.0, f_fac=0.0)
    cols = ["f_spot", "f_fac"] + [col for col in df.columns if col not in ["f_spot", "f_fac"]]
    df_list.append(df_no_contam[cols])
    
    pattern = r"fspot(?P<f_spot>[0-9.]+)_ffac(?P<f_fac>[0-9.]+)\.txt$"

    for file_path in contamination_files:
        if not os.path.isfile(file_path): raise FileNotFoundError(f"File not found: {file_path}")
        
        filename = os.path.basename(file_path)
        match = re.search(pattern, filename)
        if not match: raise ValueError(f"Filename '{filename}' does not match pattern.")

        f_spot = float(match.group("f_spot"))
        f_fac = float(match.group("f_fac"))

        try:
            contam_data = np.loadtxt(file_path, ndmin=2)
            contam_values = contam_data[:, 1] if contam_data.shape[1] >= 2 else contam_data.flatten()
            if len(contam_values) != n_points: raise ValueError(f"Contamination values in '{filename}' != n_points.")
        except Exception as e:
            raise ValueError(f"Error reading {file_path}: {e}")

        df_contam = df.copy()
        data_columns = df_contam.columns[-n_points:]
        df_contam[data_columns] *= contam_values[::-1]
        df_contam = df_contam.assign(f_spot=f_spot, f_fac=f_fac)
        df_list.append(df_contam[cols])

    df_final = pd.concat(df_list, ignore_index=True).copy()
    df_final.data = df_final.iloc[:, -n_points:]
    df_final.params = df_final.iloc[:, :-n_points]
    return df_final

def filter_rows(df):
    """Filters rows based on atmospheric composition values."""
    filter_cols = ["atm CH4", "atm O3", "atm H2O"]
    for chem in [col for col in filter_cols if col in df.columns]:
        df = df[df[chem] >= -8].copy()
    df.data = df.iloc[:, -n_points:]
    df.params = df.iloc[:, :-n_points]
    return df

def load_and_prep_data(filepath, n_points):
    """Reads a CSV and sets appropriate data types for memory efficiency."""
    df = pd.read_csv(filepath)
    df[df.columns[-n_points:]] = df[df.columns[-n_points:]].astype('float32')
    return df

def normalize_min_max_by_row(df):
    """Normalizes each row of a DataFrame to a [0, 1] range."""
    min_vals = df.min(axis=1)
    max_vals = df.max(axis=1)
    range_vals = max_vals - min_vals
    range_vals[range_vals == 0] = 1  # Avoid division by zero
    return (df.sub(min_vals, axis=0)).div(range_vals, axis=0)

def generate_df_with_noise_std(df, n_repeat, noise_std, seed=None):
    """Generates a new DataFrame by applying Gaussian noise to the spectra."""
    df_params = getattr(df, 'params', pd.DataFrame())
    df_spectra = getattr(df, 'data', df).astype('float32')
    
    if seed is not None: np.random.seed(seed)
    
    replicated_spectra_vals = np.repeat(df_spectra.values, n_repeat, axis=0)
    
    if isinstance(noise_std, (int, float)):
        noise = np.random.normal(0, noise_std, replicated_spectra_vals.shape).astype('float32')
    else:
        noise_array = np.array(noise_std, dtype='float32')
        noise_replicated = np.tile(np.repeat(noise_array[:, np.newaxis], n_repeat, axis=0), (1, df_spectra.shape[1]))
        noise = np.random.normal(0, noise_replicated, replicated_spectra_vals.shape).astype('float32')
        
    noisy_spectra = pd.DataFrame(replicated_spectra_vals + noise, columns=df_spectra.columns)

    replicated_params = pd.DataFrame(np.repeat(df_params.values, n_repeat, axis=0), columns=df_params.columns)
    
    new_cols_data = {"noise_std": np.repeat(noise_std, n_repeat * len(df_params)) if isinstance(noise_std, (list, np.ndarray)) else noise_std, "n_repeat": n_repeat}
    new_cols_df = pd.DataFrame(new_cols_data)

    df_final = pd.concat([new_cols_df, replicated_params.reset_index(drop=True), noisy_spectra.reset_index(drop=True)], axis=1)
    df_final.data = df_final.iloc[:, -n_points:]
    df_final.params = df_final.iloc[:, :-n_points]
    return df_final

TensorFlow memory growth set for 1 GPU(s).


## 3. Load and Process All Source Data


In [22]:
contamination_files = [
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.01_ffac0.70.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.08_ffac0.70.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.08.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.54.txt",
    "stellar_contamination/TRAPPIST-1_contam_fspot0.26_ffac0.70.txt",
]

data_sources = {}
try:
    print("Loading and preparing source datasets...")
    # Load all base datasets
    data_sources['airless'] = load_and_prep_data("spec_data/airless_data.csv", n_points)
    data_sources['CO2'] = load_and_prep_data("spec_data/CO2_data.csv", n_points)
    data_sources['CH4'] = filter_rows(load_and_prep_data("spec_data/CH4_data.csv", n_points))
    data_sources['O3'] = filter_rows(load_and_prep_data("spec_data/O3_data.csv", n_points))
    data_sources['H2O'] = filter_rows(load_and_prep_data("spec_data/H2O_data.csv", n_points))
    data_sources['CH4_O3'] = filter_rows(load_and_prep_data("spec_data/CH4_O3_data.csv", n_points))
    data_sources['CH4_H2O'] = filter_rows(load_and_prep_data("spec_data/CH4_H2O_data.csv", n_points))
    data_sources['O3_H2O'] = filter_rows(load_and_prep_data("spec_data/O3_H2O_data.csv", n_points))
    data_sources['CH4_O3_H2O'] = filter_rows(load_and_prep_data("spec_data/CH4_O3_H2O_data.csv", n_points))

    # Prepare two lists: one for clean data, one for contaminated data
    clean_data_list = list(data_sources.values())
    contaminated_data_list = [apply_contaminations_from_files(d, contamination_files, n_points) for d in clean_data_list]
    
    print("...Data loading complete.")
except Exception as e:
    print(f"Error during data loading: {e}")
    
# Define the number of repetitions for each dataset type per SNR level
# These values are taken from the original unoptimized notebook.
n_repeats_per_snr = {
    1: [2000, 2000, 200, 200, 200, 20, 20, 20, 4],
    3: [2000, 2000, 200, 200, 200, 20, 20, 20, 4],
    6: [1500, 1500, 150, 150, 150, 15, 15, 15, 3],
    10: [1000, 1000, 100, 100, 100, 10, 10, 10, 2],
    None: [4000, 3000, 300, 300, 300, 30, 30, 30, 5]
}


Loading and preparing source datasets...
Error during data loading: 'list' object has no attribute 'assign'


## 4. Generate Full Combined Dataset with Caching


In [24]:
output_dir = "processed_data"
os.makedirs(output_dir, exist_ok=True)
# Use general filenames for the combined dataset
noisy_path = os.path.join(output_dir, "ALL_X_noisy_full_dataset.npy")
clean_path = os.path.join(output_dir, "ALL_X_clean_full_dataset.npy")

if os.path.exists(noisy_path) and os.path.exists(clean_path):
    print("Found cached combined data. Loading from disk...")
    X_noisy = np.load(noisy_path)
    X_no_noisy = np.load(clean_path)
    print("...Loading complete.")
else:
    # --- FIX: Check if data loading was successful before proceeding ---
    if not contaminated_data_list or not clean_data_list:
        raise ValueError("Data loading failed in the previous cell. Please check file paths and try again.")
        
    print("Cached data not found. Generating full combined dataset...")
    
    list_of_noisy_arrays, list_of_clean_arrays = [], []
    snr_values = [1, 3, 6, 10, None] 

    for snr in snr_values:
        print(f"--- Processing SNR = {snr if snr is not None else 'inf'} ---")
        # Use the contaminated CO2 data (index 1) as a consistent reference for noise calculation
        noise = mrex.generate_df_SNR_noise(df=contaminated_data_list[1], n_repeat=1, SNR=snr)["noise"][0] if snr is not None else 0.0
        
        n_repeats = n_repeats_per_snr[snr]
        temp_noisy_dfs, temp_clean_dfs = [], []

        # Iterate through all data types and apply noise with correct repetitions
        for i, (contam_df, clean_df) in enumerate(zip(contaminated_data_list, clean_data_list)):
            temp_noisy_dfs.append(generate_df_with_noise_std(df=contam_df, n_repeat=n_repeats[i], noise_std=noise))
            temp_clean_dfs.append(generate_df_with_noise_std(df=clean_df, n_repeat=n_repeats[i], noise_std=0))

        # Concatenate all data for this SNR level
        noisy_df = pd.concat(temp_noisy_dfs, ignore_index=True)
        clean_df = pd.concat(temp_clean_dfs, ignore_index=True)
        
        # Append normalized numpy arrays to the main list
        list_of_noisy_arrays.append(normalize_min_max_by_row(noisy_df.iloc[:, -n_points:]).values.astype("float32"))
        list_of_clean_arrays.append(normalize_min_max_by_row(clean_df.iloc[:, -n_points:]).values.astype("float32"))
        
        # Clean up memory for this iteration
        del temp_noisy_dfs, temp_clean_dfs, noisy_df, clean_df
        gc.collect()

    # Final concatenation of numpy arrays from all SNR levels
    X_noisy = np.concatenate(list_of_noisy_arrays, axis=0)
    X_no_noisy = np.concatenate(list_of_clean_arrays, axis=0)
    del list_of_noisy_arrays, list_of_clean_arrays
    gc.collect()
    
    print("\n--- Saving combined data to disk for future runs ---")
    np.save(noisy_path, X_noisy)
    np.save(clean_path, X_no_noisy)

print(f"\nFinal noisy data shape: {X_noisy.shape}")
print(f"Final clean data shape: {X_no_noisy.shape}")
assert X_noisy.shape[0] == X_no_noisy.shape[0], "Sample count mismatch."

NameError: name 'contaminated_data_list' is not defined

## 5. Build and Train Autoencoder Model


### Create Optimized TensorFlow Datasets


In [None]:
BATCH_SIZE = 64
TEST_SIZE = 0.2
RANDOM_STATE = 42

# Create datasets on CPU to prevent VRAM overflow on initialization
with tf.device('/CPU:0'):
    X_train_noisy, X_test_noisy, X_train_clean, X_test_clean = train_test_split(
        X_noisy, X_no_noisy, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    # Create efficient TensorFlow dataset pipelines
    train_dataset = tf.data.Dataset.from_tensor_slices((X_train_noisy, X_train_clean))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    validation_dataset = tf.data.Dataset.from_tensor_slices((X_test_noisy, X_test_clean))
    validation_dataset = validation_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Clean up large NumPy arrays from RAM immediately after creating TF datasets
del X_noisy, X_no_noisy, X_train_noisy, X_train_clean
gc.collect()

print("TensorFlow datasets created and source NumPy arrays cleared from RAM.")

NameError: name 'X_noisy' is not defined

### Define and Compile the Autoencoder


In [None]:
# Infer input dimension directly from the dataset specification
input_dim = validation_dataset.element_spec[0].shape[1]

input_spectrum = keras.Input(shape=(input_dim,))

# Encoder
encoded = layers.Dense(512, activation="swish")(input_spectrum)
encoded = layers.Dropout(0.2)(encoded)
encoded = layers.Dense(512, activation="swish")(encoded)
encoded = layers.Dropout(0.2)(encoded)
encoded = layers.Dense(512, activation="swish")(encoded)
encoded = layers.Dropout(0.2)(encoded)
encoded = layers.Dense(300, activation="swish")(encoded)
encoded = layers.Dropout(0.2)(encoded)
encoded = layers.Dense(300, activation="swish")(encoded)

# Decoder
decoded = layers.Dense(300, activation="swish")(encoded)
decoded = layers.Dropout(0.2)(decoded)
decoded = layers.Dense(512, activation="swish")(decoded)
decoded = layers.Dropout(0.2)(decoded)
decoded = layers.Dense(512, activation="swish")(decoded)
decoded = layers.Dropout(0.2)(decoded)
decoded = layers.Dense(512, activation="swish")(decoded)
decoded = layers.Dropout(0.2)(decoded)
decoded = layers.Dense(input_dim, activation="linear")(decoded)

autoencoder = keras.Model(inputs=input_spectrum, outputs=decoded)
optimizer = Adam(learning_rate=0.00001)
autoencoder.compile(optimizer=optimizer, loss="mae")

autoencoder.summary()


### Train the Model


In [None]:
history = autoencoder.fit(
    train_dataset,
    epochs=100,
    validation_data=validation_dataset,
    callbacks=[
        keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=5, restore_best_weights=True
        )
    ],
)

# Save the general model
autoencoder.save("AE_ALL_MATERIALS.keras")
print("Model training complete and saved to AE_ALL_MATERIALS.keras")


## 6. Evaluate and Visualize Results


In [None]:
print("Plotting training history...")
plt.figure(figsize=(10, 6))
plt.plot(history.history["loss"], label="Training MAE")
plt.plot(history.history["val_loss"], label="Validation MAE")
plt.title("Model MAE Progress During Training")
plt.xlabel("Epochs")
plt.ylabel("Mean Absolute Error (MAE)")
plt.legend()
plt.grid(True)
plt.show()

print("\nEvaluating model on test data...")
# Reload test data from cache for evaluation
print("Loading test data for evaluation...")
noisy_path = os.path.join(output_dir, "ALL_X_noisy_full_dataset.npy")
clean_path = os.path.join(output_dir, "ALL_X_clean_full_dataset.npy")

X_noisy_full = np.load(noisy_path)
X_clean_full = np.load(clean_path)
_, X_test_noisy_eval, _, X_test_clean_eval = train_test_split(
    X_noisy_full, X_clean_full, test_size=TEST_SIZE, random_state=RANDOM_STATE
)
del X_noisy_full, X_clean_full
gc.collect()


print("Predicting on test data...")
decoded_spectra = autoencoder.predict(X_test_noisy_eval, batch_size=BATCH_SIZE)

# Visualize a few reconstructions
print("Visualizing sample reconstructions...")
num_samples = 5
indices = np.random.choice(len(X_test_noisy_eval), num_samples, replace=False)

for idx in indices:
    plt.figure(figsize=(12, 5))
    plt.plot(waves, X_test_clean_eval[idx].flatten(), label="Original Clean Spectrum", color='blue')
    plt.plot(waves, X_test_noisy_eval[idx].flatten(), label="Noisy Input Spectrum", color='gray', alpha=0.6)
    plt.plot(waves, decoded_spectra[idx].flatten(), label="Denoised Spectrum", color='red', linestyle='--')
    plt.xlabel("Wavelength")
    plt.ylabel("Normalized Intensity")
    plt.title(f"Spectrum Reconstruction - Sample {idx}")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.show()

### Final Performance Metrics


In [None]:
print("\nCalculating final performance metrics...")
mae = mean_absolute_error(X_test_clean_eval, decoded_spectra)
print(f"Final Mean Absolute Error (MAE): {mae:.6f}")

mse = mean_squared_error(X_test_clean_eval, decoded_spectra)
print(f"Final Mean Squared Error (MSE): {mse:.6f}")

# Flatten for R² score
r2 = r2_score(X_test_clean_eval.flatten(), decoded_spectra.flatten())
print(f"Final Coefficient of Determination (R²): {r2:.6f}")
