# Goals
- Access and filter DESI EDR galaxy spectra data from a database using SPARCL.
- Process and normalize the spectra data to prepare it for model training.
- Develop a CAE with skip connections to perform dimensionality reduction and reconstruction of the spectra.
- Train the autoencoder model using a weighted mean squared error (MSE) loss function to emphasize critical spectral features.
- Identify and visualize anomalies in the galaxy spectra based on high reconstruction errors.
- Provide visual representations of detected anomalies and evaluate the model's performance through training loss metrics.


# Summary
This project leverages the Dark Energy Spectroscopic Instrument (DESI) Early Data Release (EDR) dataset to train a convolutional autoencoder (CAE) for anomaly detection in galaxy spectra. The code retrieves galaxy spectra data from a database, processes and normalizes it, and then applies an autoencoder with skip connections to reconstruct the spectra. The reconstruction errors are used to identify anomalous spectra, which may indicate unusual features or observational issues in the data.

## Commands to run in terminal to upload changes to GitHub 
##### Note, to save and exit the commit comment section: Esc, ":wq", Enter


cd "/Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto"

git add SpectralCNNAutoencoder.ipynb

output_dir="SpectralCNNAutoencoder_output"

latest_mean_reconstruction_error=$(ls -t output_images/mean_reconstruction_error_*.png | head -n 1)

latest_anomalous_spectra=$(ls -t output_images/anomalous_spectra_*.png | head -n 1)

latest_sampling_info=$(ls -t anomalous_regions/sampling_info_*.json | head -n 1)

if [ -f "$latest_mean_reconstruction_error" ]; then
    mv "$latest_mean_reconstruction_error" "$output_dir/"
    git add "$output_dir/$(basename "$latest_mean_reconstruction_error")"
fi

if [ -f "$latest_anomalous_spectra" ]; then
    mv "$latest_anomalous_spectra" "$output_dir/"
    git add "$output_dir/$(basename "$latest_anomalous_spectra")"
fi

if [ -f "$latest_sampling_info" ]; then
    mv "$latest_sampling_info" "$output_dir/"
    git add "$output_dir/$(basename "$latest_sampling_info")"
fi

git commit 

git push origin main


# Change log
<i> Place to log changes before they are recorded in a github update:

A better fitness function has been developed.

Integration with the BlueBEAR supercomputer has been improved.

# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from tqdm import tqdm
from sparcl.client import SparclClient
from dl import queryClient as qc, authClient as ac
from getpass import getpass
import os
import re
import csv 
import torch.nn.functional as F
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import AutoMinorLocator
from torchviz import make_dot
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed 
import logging
import json
import seaborn as sns
import requests
from io import BytesIO
from PIL import Image
import matplotlib.patches as patches
import requests
from PIL import Image, ImageDraw
from deap import base, creator, tools, algorithms

In [2]:
#token = ac.login(input("Enter user name: (+ENTER) "),getpass("Enter password: (+ENTER) "))
ac.whoAmI()

'emc180'

In [3]:
# Configure file directories
DATA_DIR = '/Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/'
OUT_DIR = os.path.join(DATA_DIR, 'SpectralCNNAutoencoder_output')
CSV_PATH = os.path.join(DATA_DIR, 'spectra_data.csv')
JSON_DIR = os.path.join(DATA_DIR, 'anomalous_regions')
os.makedirs(OUT_DIR, exist_ok=True)
# os.environ["PATH"] += os.pathsep + "/opt/homebrew/bin" # uncomment when CNN visualisation is needed

# Initialize SPARCL client
client = SparclClient()

# Data Loading and Saving Functions

In [4]:
def query_spectra_data():
    """
    Queries the DESI Early Data Release (EDR) database to retrieve galaxy spectra data, 
    filtering for primary spectra with a minimum number of coadded spectra and good quality flags.
    
    This function selects observational fields from the DESI EDR database:
        - `targetid`: Unique identifier for each galaxy.
        - `z`: Spectroscopic redshift.
        - `zwarn`: Redshift quality flag, where `zwarn = 0` indicates reliable data.
        - `coadd_fiberstatus`: Status of the fiber used for observation, with `0` indicating no issues.
        - `spectype`: Spectral type, filtered here to include only galaxies.
        - `mean_fiber_ra`, `mean_fiber_dec`: Mean Right Ascension and Declination of the fiber position.
        - `zcat_nspec`: Number of coadded spectra for the target.
        - `zcat_primary`: Indicates the primary spectrum for the target.
    Returns:
        pd.DataFrame: DataFrame with galaxy spectra data including flux error if available.
    """
    query = """
    SELECT zp.targetid, zp.z, zp.zwarn, zp.coadd_fiberstatus, zp.spectype, 
    zp.mean_fiber_ra, zp.mean_fiber_dec
    FROM desi_edr.zpix AS zp
    WHERE zp.zcat_primary = 't'
      AND zp.zcat_nspec > 2
      AND zp.z <= 0.5
      AND zp.spectype = 'GALAXY'
      AND zp.zwarn = '0'
    """
    try:
        print("Querying SPARCL database...")
        zpix_cat = qc.query(sql=query, fmt='table')
        df = zpix_cat.to_pandas()
        print(f"Retrieved {len(df)} records from the database.")
        return df
    except Exception as e:
        print(f"Error querying data: {e}")
        return None


def load_or_query_data(csv_path):
    """
    Loads data from CSV if available; otherwise, queries SPARCL and saves the results.

    Parameters:
    - csv_path (str): Path to the CSV file to save or load data.

    Returns:
    - pd.DataFrame: Loaded or queried data as a pandas DataFrame.
    """
    if os.path.exists(csv_path):
        print(f"Loading data from {csv_path}...")
        try:
            return pd.read_csv(csv_path)
        except Exception as e:
            print(f"Error loading CSV file: {e}")
            print("Attempting to query the database instead...")
    data = query_spectra_data()
    if data is not None:
        try:
            print(f"Saving queried data to {csv_path}...")
            data.to_csv(csv_path, index=False)
        except Exception as e:
            print(f"Error saving data to CSV: {e}")
    return data


## Information about the query
- targetid, survey, program -- unique identifiers for a given spectrum
- healpix -- healpix number for the target
- z -- spectroscopic redshift of the target
- zwarn -- encoded information regarding the redshift (zwarn = 0 is good)
- coadd_fiberstatus -- encoded information regarding the fiber that is assigned to the target (coadd_fiberstatus = 0 is good)
- spectype -- Spectral type of the target: STAR | GALAXY | QSO
- mean_fiber_ra, mean_fiber_dec -- Mean R.A. and Dec. of the fiber position from all the observations of the target
- zcat_nspec -- Number of coadded spectra that are available for a given target
- zcat_primary -- Whether or not a given coadded spectrum is the primary spectrum. zcat_primary = True for the "best" spectrum.
- CASTing this column as an INT: zcat_primary = 1 for the "best" spectrum.
- desi_target -- encodes main survey's DESI targeting information - explained in detail below
- sv1_desi_target -- encodes sv1 desi targeting information
- sv2_desi_target -- encodes sv2 desi targeting information
- sv3_desi_target -- encodes sv3 desi targeting information

# Data Preparation Functions

In [5]:
def retrieve_flux(targetid, inc, retries=5, delay=3):
    """
    Retrieves and normalizes flux data, wavelength, and calculates error for a single target ID from the DESI database, with retry logic.
    
    Parameters:
    - targetid (int): The ID of the target for which flux data is requested.
    - inc (list of str): List of attributes to include in the data retrieval (e.g., ['specid', 'flux', 'wavelength', 'ivar']).
    - retries (int, optional): The number of retry attempts if data retrieval fails (default is 5).
    - delay (int, optional): The base delay in seconds between retries, with exponential backoff (default is 3).

    Returns:
    - tuple (np.array, np.array, np.array) or None: Normalized flux, wavelength, and error arrays if retrieval is successful; otherwise, None if all retries fail.
    """
    for attempt in range(1, retries + 1):
        try:
            res = client.retrieve_by_specid(specid_list=[targetid], include=inc, dataset_list=['DESI-EDR'])
            for record in res.records:
                if record['specprimary']:
                    flux = record['flux']
                    wavelength = record['wavelength']
                    ivar = record['ivar'] 
                    
                    flux_min, flux_max = np.min(flux), np.max(flux)
                    normalized_flux = (flux - flux_min) / (flux_max - flux_min)
                    
                    error = np.sqrt(1 / np.where(ivar == 0, 1e-10, ivar))

                    
                    return normalized_flux, wavelength, error  
        except Exception:
            if attempt < retries:
                time.sleep(delay * (2 ** (attempt - 1)))  # Exponential backoff
    return None  


def process_spectra_data(zpix_cat, batch_size=20, max_workers=10):
    """
    Retrieves and processes flux data in parallel for a DataFrame of galaxy spectra from the DESI database,
    including flux, wavelength, and error data.

    Parameters:
    - zpix_cat (pd.DataFrame): DataFrame containing metadata for the galaxy spectra to be processed, including target IDs.
    - batch_size (int, optional): The number of spectra to retrieve in each batch (default is 20).
    - max_workers (int, optional): The maximum number of parallel threads for retrieving data (default is 10).

    Returns:
    - tuple (np.array, np.array, np.array): Arrays of normalized flux data, corresponding wavelength data, and error data for each galaxy spectrum.
    """
    all_fluxes = []
    all_wavelengths = []
    all_errors = []
    total_records = len(zpix_cat)
    inc = ['specid', 'redshift', 'flux', 'wavelength', 'ivar', 'spectype', 
           'specprimary', 'survey', 'program', 'targetid', 'coadd_fiberstatus']

    for start_idx in tqdm(range(0, total_records, batch_size), desc="Processing spectra in batches"):
        batch = zpix_cat.iloc[start_idx:start_idx+batch_size]
        batch_fluxes = []
        batch_wavelengths = []
        batch_errors = []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(retrieve_flux, int(row['targetid']), inc): row['targetid'] for _, row in batch.iterrows()}
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    flux, wavelength, error = result  
                    batch_fluxes.append(flux)
                    batch_wavelengths.append(wavelength)
                    batch_errors.append(error)
                else:
                    print(f"Warning: Missing flux for target ID {futures[future]}") 

        all_fluxes.extend(batch_fluxes)
        all_wavelengths.extend(batch_wavelengths)
        all_errors.extend(batch_errors)

    if not all_fluxes:
        raise ValueError("Error: No flux data retrieved. Check data source or retrieval logic.")

    return np.array(all_fluxes), np.array(all_wavelengths), np.array(all_errors)
    
def pad_spectra(fluxes, target_length):
    """
    Pads or truncates each flux array in a collection to match a specified target length.
    
    Parameters:
    - fluxes (list of np.array): List of flux arrays for each spectrum.
    - target_length (int): Desired length for each flux array, used for padding or truncation.
    
    Returns:
    - np.array: Array of padded or truncated flux data.
    
    Notes:
    - Adds zeros to the end of shorter spectra to match the target length.
    - Truncates longer spectra to fit the specified target length.
    - Prints a count of spectra that required padding.
    """
    padded_fluxes = []
    num_padded = 0  

    for flux in fluxes:
        if len(flux) < target_length:
            padding = np.zeros(target_length - len(flux))
            padded_flux = np.concatenate([flux, padding])
            num_padded += 1
        else:
            padded_flux = flux[:target_length]
        padded_fluxes.append(padded_flux)

    if num_padded > 0:
        print(f"Padding applied to {num_padded} spectra to match the target length.")
    else:
        print("No padding was necessary; all spectra are of equal length.")

    return np.array(padded_fluxes)



def create_padding_mask(fluxes, target_length):
    """
    Generates a binary mask indicating the padded regions in each flux array.
    
    Parameters:
    - fluxes (list of np.array): List of flux arrays for each spectrum.
    - target_length (int): Length of the output mask arrays, matching the padded flux array length.
    
    Returns:
    - np.array: Array of binary masks for each flux array, where 1 represents original data and 0 represents padded values.
    
    Notes:
    - The mask is used during loss calculation to ignore the padded regions of the spectra.
    """
    masks = []
    for flux in fluxes:
        mask = np.ones_like(flux)
        if len(flux) < target_length:
            mask = np.concatenate([mask, np.zeros(target_length - len(flux))])
        masks.append(mask[:target_length])
    return np.array(masks)

In [6]:
def filter_spectral_data(wavelengths, fluxes, errors, min_wavelength=4000):
    """
    Filters spectral data to retain only wavelengths above a specified minimum value.

    Parameters:
    - wavelengths (np.ndarray): 2D array of wavelength values (one spectrum per row).
    - fluxes (np.ndarray): 2D array of flux values (one spectrum per row).
    - errors (np.ndarray): 2D array of error values (one spectrum per row).
    - min_wavelength (float): Minimum wavelength to retain in Ångstrom.

    Returns:
    - filtered_wavelengths (list of np.ndarray): List of filtered wavelength arrays.
    - filtered_fluxes (list of np.ndarray): List of filtered flux arrays corresponding to wavelengths.
    - filtered_errors (list of np.ndarray): List of filtered error arrays corresponding to wavelengths.
    """
    filtered_wavelengths, filtered_fluxes, filtered_errors = [], [], []

    for i in range(len(wavelengths)):
        mask = wavelengths[i] >= min_wavelength
        filtered_wavelengths.append(wavelengths[i][mask])
        filtered_fluxes.append(fluxes[i][mask])
        filtered_errors.append(errors[i][mask])

    return filtered_wavelengths, filtered_fluxes, filtered_errors

# Autoencoder Model Definition

In [7]:
class CNNAutoencoderWithSkip(nn.Module):
    """
    A Convolutional Neural Network (CNN)-based autoencoder with skip connections for reconstructing spectral data.
    
    This model is designed to reduce the dimensionality of spectral data, capturing essential features 
    during encoding and reconstructing the data during decoding. Skip connections help retain 
    detailed information that may otherwise be lost in deeper layers, improving reconstruction 
    quality for tasks like anomaly detection or dimensionality reduction.
    
    Attributes:
    - encoder1 (nn.Conv1d): First convolutional layer of the encoder, reducing input dimensions.
    - encoder2 (nn.Conv1d): Second convolutional layer of the encoder, further reducing dimensions.
    - encoder3 (nn.Conv1d): Final convolutional layer of the encoder, creating a compressed representation.
    - decoder3 (nn.ConvTranspose1d): First transposed convolutional layer of the decoder.
    - decoder2 (nn.ConvTranspose1d): Second transposed convolutional layer of the decoder.
    - decoder1 (nn.ConvTranspose1d): Final transposed convolutional layer of the decoder, outputting the reconstructed data.
    
    Methods:
    - forward(x): Defines the forward pass of the autoencoder. The input data `x` passes through the encoder 
      layers to compress it, and then through the decoder layers to reconstruct it. Skip connections are 
      used to combine encoder and decoder layers at corresponding depths.
      
    Returns:
    - torch.Tensor: The reconstructed tensor, with values squashed between 0 and 1 using a sigmoid activation.
    """
    def __init__(self):
        super(CNNAutoencoderWithSkip, self).__init__()
        # Encoder layers with reduced channels
        self.encoder1 = nn.Conv1d(1, 64, kernel_size=3, stride=2, padding=1)
        self.encoder2 = nn.Conv1d(64, 32, kernel_size=3, stride=2, padding=1)
        self.encoder3 = nn.Conv1d(32, 16, kernel_size=3, stride=2, padding=1)  

        # Decoder layers with reduced channels
        self.decoder3 = nn.ConvTranspose1d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder2 = nn.ConvTranspose1d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder1 = nn.ConvTranspose1d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=0)

    def forward(self, x):
        # Encoding with skip connections
        x1 = F.relu(self.encoder1(x))  # Save this for skip connection
        x2 = F.relu(self.encoder2(x1))  # Save this for skip connection
        x3 = F.relu(self.encoder3(x2))  # Last encoding layer

        # Decoding with skip connections
        x = F.relu(self.decoder3(x3))
        
        # Adjust sizes for skip connection with x2
        if x2.size(2) > x.size(2):
            x2 = x2[:, :, :x.size(2)]
        elif x2.size(2) < x.size(2):
            x = x[:, :, :x2.size(2)]
        x = F.relu(self.decoder2(x + x2))
        
        # Adjust sizes for skip connection with x1
        if x1.size(2) > x.size(2):
            x1 = x1[:, :, :x.size(2)]
        elif x1.size(2) < x.size(2):
            x = x[:, :, :x1.size(2)]
        x = self.decoder1(x + x1)

        return torch.sigmoid(x)

# Custom Loss Function


In [8]:
def weighted_mse_loss(observed_flux, target_flux, mask, error_on_observed_flux):
    """
    Calculates a modified Mean Squared Error (MSE) loss based on the "significance difference"
    between observed and target flux, adjusted by the observational error.

    Parameters:
    - observed_flux (torch.Tensor): The model's predicted flux values (reconstructed spectra).
    - target_flux (torch.Tensor): The ground truth flux values to be compared against observed_flux.
    - mask (torch.Tensor): A binary mask that identifies which regions in the spectra should be included
                           in the loss calculation, typically to ignore padded regions.
    - error_on_observed_flux (torch.Tensor): Observational error associated with each flux point.

    Returns:
    - torch.Tensor: The mean significance-based MSE loss, adjusted by the mask to ignore specific regions.

    Process:
    1. Compute the "significance difference" by normalizing the difference between observed and target flux
       using `error_on_observed_flux`, which represents the observational uncertainty.
    2. Square the resulting "significance difference" to get a MSE-like loss, which is focused on relative
       significance rather than absolute difference.
    3. Apply the mask to this significance-based MSE loss to focus only on relevant parts of the spectrum.
    4. Return the mean of the masked, significance-based MSE loss.
    """
    significance_difference = (observed_flux - target_flux) / (error_on_observed_flux + 1e-5)
    mse_loss = significance_difference ** 2
    
    return (mse_loss * mask).mean()


# Training the Model


In [9]:
def train_autoencoder(model, data, mask, errors, epochs=50, batch_size=32, lr=0.001, grad_clip=1.0):
    """
    Trains an autoencoder model using mini-batch gradient descent with gradient clipping 
    and a custom weighted MSE loss function that incorporates errors.
    
    Parameters:
    - model (torch.nn.Module): The autoencoder model to be trained.
    - data (torch.Tensor): Input data tensor (spectra).
    - mask (torch.Tensor): Padding mask tensor.
    - errors (torch.Tensor): Errors associated with the data.
    - epochs, batch_size, lr, grad_clip: Training parameters.
    
    Process:
    1. Initializes the optimizer and trains the model over epochs.
    2. Passes errors to the loss function for significance-based weighting.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        start_time = time.time()
        epoch_loss = 0
        optimizer.zero_grad()  
        
        for i in range(0, len(data), batch_size):
            batch_data, mask_batch, error_batch = data[i:i+batch_size], mask[i:i+batch_size], errors[i:i+batch_size]
            loss = weighted_mse_loss(model(batch_data), batch_data, mask_batch, error_batch)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()  
            epoch_loss += loss.item()

        epoch_duration = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(data):.8f}, Time: {epoch_duration:.2f}s")
    print(f"Final Loss after {epochs} epochs: {epoch_loss/len(data):.8f}")


# Anomaly Detection

In [10]:
def detect_anomalous_regions(original_fluxes, reconstructed_fluxes, window_size=50, 
                             abs_residual_threshold=0.1, rel_residual_threshold=0.1, 
                             range_mismatch_factor=1.5, overall_anomaly_threshold=0.3):
    """
    Identifies anomalous regions within spectra using residuals and mismatch factors,
    with added logging to justify each detected anomaly and the triggering value.

    Parameters:
    ----------
    original_fluxes : np.ndarray
        A 2D array of original flux values from the spectra, with shape (num_spectra, spectrum_length).
        Each row represents a spectrum, and each column corresponds to a specific wavelength or flux point.

    reconstructed_fluxes : np.ndarray
        A 2D array of reconstructed flux values from the model, with the same shape as `original_fluxes`.
        These are the model’s best attempt to recreate the original spectra, used to compute errors.

    window_size : int, optional, default=50
        The size of the sliding window (in data points) used to analyze local regions within each spectrum.

    abs_residual_threshold : float, optional, default=0.1
        The fixed threshold value for absolute residuals beyond which regions are flagged as anomalous.

    rel_residual_threshold : float, optional, default=0.1
        The fixed threshold value for relative residuals beyond which regions are flagged as anomalous.

    range_mismatch_factor : float, optional, default=1.5
        A factor to identify anomalies based on the range mismatch between original and reconstructed fluxes.

    overall_anomaly_threshold : float, optional, default=0.3
        The minimum fraction of a spectrum that must be flagged as anomalous to classify the entire spectrum as anomalous.

    Returns:
    -------
    anomalies : list of np.ndarray
        A list of boolean arrays, each representing a spectrum. Each array has True at positions corresponding to
        flagged regions in the spectrum and False otherwise.

    spectrum_anomalies : np.ndarray
        A 1D boolean array where each element represents whether the corresponding spectrum is classified as anomalous.

    anomaly_metadata : list of list of dict
        Metadata for each flagged region in each spectrum, with information on the range of the anomaly,
        the type of anomaly (high absolute residual, high relative residual, or range mismatch), and the value that triggered the flag.

    abs_residual_threshold : float
        The fixed threshold value for absolute residuals beyond which regions are flagged as anomalous.

    rel_residual_threshold : float
        The fixed threshold value for relative residuals beyond which regions are flagged as anomalous.

    range_mismatch_factor : float
        The factor used to determine if the reconstructed flux range mismatch flags a region as anomalous.
    """
    num_spectra, spectrum_length = original_fluxes.shape
    absolute_residuals = np.abs(original_fluxes - reconstructed_fluxes)
    relative_residuals = np.abs((original_fluxes - reconstructed_fluxes) / (original_fluxes + 1e-5))

    anomalies, spectrum_anomalies = [], np.zeros(num_spectra, dtype=bool)
    anomaly_metadata = []

    for i in range(num_spectra):
        spectrum_anomalies_count = 0
        spectrum_anomaly_flags = np.zeros(spectrum_length, dtype=bool)
        anomaly_justification = []

        for start in range(0, spectrum_length - window_size + 1, window_size // 2):
            end = start + window_size
            window_abs_residual = float(np.mean(absolute_residuals[i, start:end]))
            window_rel_residual = float(np.mean(relative_residuals[i, start:end]))
            original_range = float(np.ptp(original_fluxes[i, start:end]))
            reconstructed_range = float(np.ptp(reconstructed_fluxes[i, start:end]))

            if (window_abs_residual > abs_residual_threshold or
                window_rel_residual > rel_residual_threshold or
                reconstructed_range > original_range * range_mismatch_factor):
                
                spectrum_anomaly_flags[start:end] = True
                spectrum_anomalies_count += end - start

                reason = {
                    "range_start": start,
                    "range_end": end,
                    "reason_type": (
                        "high_absolute_residual" if window_abs_residual > abs_residual_threshold else
                        "high_relative_residual" if window_rel_residual > rel_residual_threshold else
                        "range_mismatch"
                    ),
                    "trigger_value": (
                        window_abs_residual if window_abs_residual > abs_residual_threshold else
                        window_rel_residual if window_rel_residual > rel_residual_threshold else
                        reconstructed_range / (original_range + 1e-5)
                    )
                }
                anomaly_justification.append(reason)

        spectrum_anomalies[i] = spectrum_anomalies_count / spectrum_length > overall_anomaly_threshold 
        anomalies.append(spectrum_anomaly_flags)
        anomaly_metadata.append(anomaly_justification)

    return anomalies, spectrum_anomalies, anomaly_metadata, abs_residual_threshold, rel_residual_threshold, range_mismatch_factor


In [12]:
def fitness_function(params, data):
    """
    Calculates the fitness score for anomaly detection thresholds.
    
    Parameters:
    - params (dict): Contains thresholds for:
        - abs_residual_threshold
        - rel_residual_threshold
        - range_mismatch_factor
        - overall_anomaly_threshold
    - data (dict): Contains:
        - fluxes (original spectra)
        - reconstructed_fluxes (from the model)
        
    Returns:
    - float: Fitness score (closer to 0 is better).
    """
    try:
        abs_residual_threshold = params["abs_residual_threshold"]
        rel_residual_threshold = params["rel_residual_threshold"]
        range_mismatch_factor = params["range_mismatch_factor"]
        overall_anomaly_threshold = params["overall_anomaly_threshold"]

        original_fluxes = data["fluxes"]
        reconstructed_fluxes = data["reconstructed_fluxes"]

        anomalies, spectrum_anomalies, _, _, _, _ = detect_anomalous_regions(
            original_fluxes=original_fluxes,
            reconstructed_fluxes=reconstructed_fluxes,
            abs_residual_threshold=abs_residual_threshold,
            rel_residual_threshold=rel_residual_threshold,
            range_mismatch_factor=range_mismatch_factor,
            overall_anomaly_threshold=overall_anomaly_threshold,
        )

        reconstruction_errors = compute_reconstruction_errors(original_fluxes, reconstructed_fluxes, method="mse")
        mean_reconstruction_error = np.mean(reconstruction_errors)

        anomaly_ratio = np.sum(spectrum_anomalies) / len(original_fluxes)  
        false_negative_penalty = np.mean([
            1 if not spectrum_anomalies[i] and np.any(anomalies[i]) else 0
            for i in range(len(original_fluxes))
        ])
        false_positive_penalty = np.mean([
            1 if spectrum_anomalies[i] and not np.any(anomalies[i]) else 0
            for i in range(len(original_fluxes))
        ])

        threshold_penalty = overall_anomaly_threshold ** 2  

        compactness_penalty = np.mean([
            1 - (np.sum(anomalies[i]) / len(anomalies[i]))
            for i in range(len(anomalies))
        ])

        fitness = (
            -mean_reconstruction_error * 10  # Strong emphasis on reconstruction
            - false_negative_penalty * 5  # Reduce false negatives
            - false_positive_penalty * 2  # Balance false positives
            - threshold_penalty * 5  # Penalize high thresholds
            - compactness_penalty * 1  # Light penalty for scattered anomalies
        )
        return fitness,

    except Exception as e:
        print(f"[ERROR] Fitness function failed: {e}")
        return -1e6,  # Large penalty for failures

def detection_parameters(data, param_bounds, pop_size=20, generations=50, mutation_rate=0.2):
    """
    Optimize anomaly detection thresholds using a genetic algorithm.

    Parameters:
    - data (dict): Contains 'fluxes' and 'reconstructed_fluxes'.
    - param_bounds (dict): Parameter boundaries for optimization.
    - pop_size (int): Population size.
    - generations (int): Number of generations.
    - mutation_rate (float): Probability of mutation.

    Returns:
    - dict: Best parameter configuration found.
    """
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", list, fitness=creator.FitnessMax)

    toolbox = base.Toolbox()

    for key, bounds in param_bounds.items():
        toolbox.register(f"attr_{key}", random.uniform, bounds[0], bounds[1])

    toolbox.register(
        "individual",
        tools.initCycle,
        creator.Individual,
        tuple(getattr(toolbox, f"attr_{key}") for key in param_bounds.keys()),
        n=1,
    )
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    def deap_fitness(individual):
        params = {key: individual[idx] for idx, key in enumerate(param_bounds.keys())}
        return fitness_function(params, data)

    toolbox.register("evaluate", deap_fitness)

    min_bounds = [bounds[0] for bounds in param_bounds.values()]
    max_bounds = [bounds[1] for bounds in param_bounds.values()]
    toolbox.register("mate", tools.cxBlend, alpha=0.5)
    toolbox.register("mutate", tools.mutPolynomialBounded, eta=20.0, low=min_bounds, up=max_bounds, indpb=mutation_rate)
    toolbox.register("select", tools.selTournament, tournsize=3)

    population = toolbox.population(n=pop_size)

    for gen in tqdm(range(generations), desc="Genetic Algorithm Progress"):
        offspring = toolbox.select(population, len(population))
        offspring = list(map(toolbox.clone, offspring))

        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < 0.5:
                toolbox.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        for mutant in offspring:
            if random.random() < mutation_rate:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        fitnesses = map(toolbox.evaluate, invalid_ind)
        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit

        population[:] = offspring

    best_individual = tools.selBest(population, k=1)[0]
    best_params = {key: best_individual[idx] for idx, key in enumerate(param_bounds.keys())}
    best_fitness = best_individual.fitness.values[0]
    print(f"Optimized Parameters: {best_params}")
    print(f"Best Fitness: {best_fitness}")
    return best_params, best_fitness
def compute_dynamic_anomaly_ratio(reconstruction_errors, mad_multiplier=3):
    """
    Dynamically computes the anomaly ratio using the Median Absolute Deviation (MAD).

    Parameters:
    - reconstruction_errors (np.ndarray): Array of reconstruction errors.
    - mad_multiplier (float): Multiplier for the MAD to define the threshold.

    Returns:
    - float: Target anomaly ratio based on the error distribution.
    """
    median_error = np.median(reconstruction_errors)
    mad_error = np.median(np.abs(reconstruction_errors - median_error))
    threshold = median_error + mad_multiplier * mad_error

    anomalies = reconstruction_errors > threshold
    target_anomaly_ratio = np.mean(anomalies)  

    print(f"Dynamic Anomaly Ratio: {target_anomaly_ratio:.4f} (Threshold: {threshold:.4f})")
    return target_anomaly_ratio

# Spectra Visualisation

In [13]:
def create_save_path(save_directory, base_filename):
    """
    Generates a unique file path for PNG files in the specified directory by appending
    a sequential numeric suffix to avoid overwriting existing files. This ensures that
    saved file names are uniquely numbered.

    Parameters:
    - save_directory (str): The directory where the file should be saved. It will be created if it doesn't exist.
    - base_filename (str): The base name for the file, to which a numeric suffix will be added.

    Returns:
    - str: A full path to the new file with a sequentially numbered suffix in the specified directory,
           following the format '{base_filename}_{number}.png'.
    """

    os.makedirs(save_directory, exist_ok=True)
    existing_files = os.listdir(save_directory)
    numbers = [
        int(re.search(r'\d+', f).group())
        for f in existing_files if re.search(fr'{base_filename}_(\d+)\.png', f)
    ]
    next_number = max(numbers) + 1 if numbers else 1
    full_path = os.path.join(save_directory, f'{base_filename}_{next_number}.png')
    
    relative_path = os.path.relpath(full_path, start=os.path.dirname(save_directory))
    
    return relative_path
    
def create_json_save_path(save_directory, base_filename):
    """
    Generates a unique file path for JSON files in the specified directory by appending
    a sequential numeric suffix to avoid overwriting existing files. This ensures that
    saved JSON file names are uniquely numbered.

    Parameters:
    - save_directory (str): The directory where the JSON file should be saved. It will be created if it doesn't exist.
    - base_filename (str): The base name for the JSON file, to which a numeric suffix will be added.

    Returns:
    - str: A full path to the new JSON file with a sequentially numbered suffix in the specified directory,
           following the format '{base_filename}_{number}.json'.
    """
    os.makedirs(save_directory, exist_ok=True)
    existing_files = os.listdir(save_directory)
    numbers = [
        int(re.search(rf'{base_filename}_(\d+)\.json', f).group(1))
        for f in existing_files if re.search(rf'{base_filename}_(\d+)\.json', f)
    ]
    next_number = max(numbers) + 1 if numbers else 1
    full_path = os.path.join(save_directory, f'{base_filename}_{next_number}.json')
    relative_path = os.path.relpath(full_path, start=os.path.dirname(save_directory))
    
    return relative_path

In [14]:
def save_sampling_info(galaxy_ids, plot_galaxy_ids, seed, anomaly_metadata, abs_residual_threshold, 
                       rel_residual_threshold, range_mismatch_factor, json_directory, zpix_cat):
    """
    Saves detailed sampling information for detected anomalous spectra to a JSON file. This includes 
    metadata on each detected anomaly, threshold values, and unique target IDs associated with the anomalies.

    Parameters:
    - galaxy_ids (list of int): Indices of galaxies identified as anomalous, corresponding to row indices in `zpix_cat`.
    - plot_galaxy_ids (list of int): Indices of galaxies selected for plotting, used to match target IDs if needed.
    - seed (int): Random seed used during sampling to ensure reproducibility.
    - anomaly_metadata (list of list of dict): Metadata for each galaxy's anomaly detection results, where each inner list
      contains dictionaries with specific anomaly information (e.g., range, justification, trigger value).
    - abs_residual_threshold (float): Threshold for detecting high absolute residual anomalies.
    - rel_residual_threshold (float): Threshold for detecting high relative residual anomalies.
    - range_mismatch_factor (float): Factor threshold for range mismatches between original and reconstructed spectra.
    - json_directory (str): Directory path where the JSON file should be saved. Will be created if it doesn't exist.
    - zpix_cat (DataFrame): DataFrame containing galaxy catalog information, including `targetid` for each galaxy.

    Returns:
    - None: The function saves a JSON file containing the sampling and anomaly data.

    Process:
    1. Maps `galaxy_ids` to actual `targetid` values using `zpix_cat` to provide unique identifiers for each galaxy.
    2. Constructs a dictionary `sampling_info` to store:
       - Count of anomalies detected.
       - Sampling seed for reproducibility.
       - Thresholds for absolute residual, relative residual, and range mismatch.
       - An `anomalous_spectra` dictionary keyed by `targetid`, where each entry holds `anomaly_ranges` metadata.
    3. For each `targetid`, populates `anomaly_ranges` with details on each anomaly's range, justification, 
       and the value triggering the anomaly.
    4. Uses `create_json_save_path` to generate a unique, sequentially numbered file path for saving.
    5. Writes `sampling_info` as JSON to the generated path and confirms save location with a printed message.
    """
    target_ids = [int(zpix_cat['targetid'].iloc[idx]) for idx in galaxy_ids]

    sampling_info = {
        "Number of Anomalies": len(target_ids),
        "Seed": int(seed),
        "absolute_residual_threshold": float(abs_residual_threshold),
        "relative_residual_threshold": float(rel_residual_threshold),
        "range_mismatch_factor": float(range_mismatch_factor),
        "anomalous_spectra": {}
    }

    for idx, target_id in enumerate(target_ids):
        metadata = anomaly_metadata[idx]
        anomaly_ranges = []

        for reason in metadata:
            anomaly_ranges.append({
                "range_start": reason["range_start"],
                "range_end": reason["range_end"],
                "justification": reason["reason_type"],
                "value": reason.get("trigger_value", None)  
            })

        sampling_info["anomalous_spectra"][target_id] = {
            "anomaly_ranges": anomaly_ranges
        }

    json_path = create_json_save_path(json_directory, 'sampling_info')
    secondary_path = create_json_save_path(OUT_DIR, 'sampling_info')

    with open(json_path, 'w') as json_file:
        json.dump(sampling_info, json_file, indent=4)
    with open(secondary_path, 'w') as json_file:
        json.dump(sampling_info, json_file, indent=4)
    
    print(f"Sampling information saved to {json_path} and {secondary_path}")

In [15]:
def fetch_image(ra, dec, pixscale=0.262, width=256, height=256, survey="ls-dr10"):
    """
    Fetches an image from the Legacy Survey for the specified coordinates and overlays a circle representing 
    the DESI fiber diameter (1.5 arcseconds).

    Parameters:
    - ra (float): Right Ascension (RA) in degrees for the target.
    - dec (float): Declination (DEC) in degrees for the target.
    - pixscale (float): Pixel scale in arcseconds per pixel. Default is 0.262 for Legacy Survey.
    - width (int): Width of the returned image in pixels. Default is 256.
    - height (int): Height of the returned image in pixels. Default is 256.
    - survey (str): Legacy Survey data release, e.g., "ls-dr10".

    Returns:
    - Image: PIL Image object with the circle overlayed.
    - None: Returns None if the image retrieval fails.
    """
    fiber_diameter_arcsec = 1.5

    fiber_radius_pixels = fiber_diameter_arcsec / (2 * pixscale)

    url = f"https://www.legacysurvey.org/viewer/jpeg-cutout?ra={ra}&dec={dec}" \
          f"&pixscale={pixscale}&layer={survey}&size={max(width, height)}"

    response = requests.get(url)
    if response.status_code == 200:
        img_data = BytesIO(response.content)
        img = Image.open(img_data).convert("RGBA")  

        draw = ImageDraw.Draw(img)

        center_x, center_y = width // 2, height // 2

        draw.ellipse(
            [
                (center_x - fiber_radius_pixels, center_y - fiber_radius_pixels),
                (center_x + fiber_radius_pixels, center_y + fiber_radius_pixels)
            ],
            outline="white",  
            width=2           
        )
        return img
    else:
        return None

def plot_spectra(
    original_fluxes, reconstructed_fluxes, anomalous_regions, spectrum_anomalies,
    zpix_cat, wavelengths=None, save_directory='SpectralCNNAutoencoder_output', max_samples=20, 
    seed=None, only_anomalous=False
):
    """
    Plots original and reconstructed spectra for multiple galaxies, highlighting anomalous spectra.
    
    Parameters:
    - original_fluxes (list of ndarray): The original flux values for each spectrum.
    - reconstructed_fluxes (list of ndarray): Reconstructed flux values, for comparison with original fluxes.
    - anomalous_regions (list of lists of bool): Each list contains boolean values indicating anomalous regions in each spectrum.
    - spectrum_anomalies (list of bool): Indicates if a given spectrum is anomalous. Used to filter if only_anomalous is set to True.
    - zpix_cat (DataFrame): Metadata with RA/Dec information for each spectrum, allowing for SDSS image retrieval.
    - wavelengths (ndarray, optional): Wavelengths to use for x-axis; defaults to index if not provided.
    - save_directory (str, optional): Directory path to save the generated plot image. Default is 'output_images'.
    - max_samples (int, optional): Maximum number of spectra to plot. Useful for limiting data when dataset is large. Default is 20.
    - seed (int, optional): Seed for random sampling, ensuring reproducibility. If None, a random seed is used. Default is None.
    - only_anomalous (bool, optional): If True, plots only the spectra identified as anomalous. Default is False.

    Returns:
    - tuple: (anomalous_indices, plot_indices, seed)
        - anomalous_indices (list): List of indices for spectra flagged as anomalous.
        - plot_indices (list): Indices of spectra that were plotted.
        - seed (int): Random seed used for sampling.
    """
    
    anomalous_indices = [i for i, is_anomalous in enumerate(spectrum_anomalies) if is_anomalous] if only_anomalous else list(range(len(original_fluxes)))
    
    if not anomalous_indices:
        print("No anomalous spectra detected.")
        return [], [], None  
        
    if seed is None:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print(f"Random sampling seed: {seed}")

    plot_indices = random.sample(anomalous_indices, min(max_samples, len(anomalous_indices)))

    image_width, image_height = 256, 256

    fig, axes = plt.subplots(len(plot_indices), 2, figsize=(20, 6 * len(plot_indices)), gridspec_kw={'width_ratios': [3, 1]})
    if len(plot_indices) == 1:
        axes = [axes] 

    for idx, i in enumerate(plot_indices):
        ra, dec = zpix_cat.loc[zpix_cat['targetid'] == zpix_cat['targetid'].iloc[i], ['mean_fiber_ra', 'mean_fiber_dec']].values[0]
        x_axis = wavelengths if wavelengths is not None else range(len(original_fluxes[i]))

        ax_spectra = axes[idx, 0]
        ax_spectra.plot(x_axis, original_fluxes[i], label="Original", color='#2c7bb6', linewidth=0.5)
        ax_spectra.plot(x_axis, reconstructed_fluxes[i], label="Reconstructed", color='#d7191c', alpha=0.7, linewidth=0.5)

        in_anomaly, anomaly_start = False, 0
        for j in range(len(x_axis)):
            if anomalous_regions[i][j] and not in_anomaly:
                anomaly_start = j
                in_anomaly = True
            elif not anomalous_regions[i][j] and in_anomaly:
                ax_spectra.axvspan(x_axis[anomaly_start], x_axis[j], color='#fdae61', alpha=0.7)
                in_anomaly = False
        if in_anomaly:
            ax_spectra.axvspan(x_axis[anomaly_start], x_axis[-1], color='#fdae61', alpha=0.7)

        if spectrum_anomalies[i]:
            ax_spectra.set_facecolor('#fee090')
            ax_spectra.set_title(f"Spectrum ID: {zpix_cat['targetid'].iloc[i]} - Anomalous", color='red',fontsize=23)
        else:
            ax_spectra.set_title(f"Spectrum ID: {zpix_cat['targetid'].iloc[i]}", color='black',fontsize=23)
        ax_spectra.xaxis.set_visible(False)  

        divider = make_axes_locatable(ax_spectra)
        ax_residual = divider.append_axes("bottom", size="25%", pad=0, sharex=ax_spectra)
        ax_residual.plot(x_axis, original_fluxes[i] - reconstructed_fluxes[i], color='#4dac26', linewidth=0.5)
        ax_residual.set_ylabel("Residuals",fontsize=20)
        ax_spectra.set_ylabel("Flux (normalized)",fontsize=20)
        ax_residual.set_xlabel("Wavelength (Å)",fontsize=20)
        ax_spectra.legend()

        img_data = fetch_image(ra, dec)
        ax_image = axes[idx, 1]
        if img_data:
            img = img_data
            ax_image.imshow(img)
            ax_image.axis("off")
            ax_image.set_title(f"Galaxy Image\nRA={ra:.4f}, Dec={dec:.4f}",fontsize=23)
            circle_radius = 30
            circle = patches.Circle((image_width / 2, image_height / 2), circle_radius, transform=ax_image.transData, 
                                    edgecolor="white", facecolor="none", linewidth=2)
            ax_image.add_patch(circle)
        else:
            ax_image.text(0.5, 0.5, "Image Not Found", ha="center", va="center")
        ax_image.axis("off")

    plt.tight_layout()
    base_filename = "anomalous_spectra" if only_anomalous else "spectra_reconstruction"
    save_path = create_save_path(save_directory, base_filename)
    plt.savefig(save_path, dpi=300)
    plt.close(fig)
    print(f"Figure saved to {save_path}")
    print(f"Number of anomalous spectra: {len(anomalous_indices)}")
    return anomalous_indices, plot_indices, seed


In [16]:
def visualize_autoencoder(autoencoder, input_data):
    """
    Visualizes the entire structure of the autoencoder model by generating a graph of the
    network's architecture, including both encoder and decoder layers, and saves it as an image.

    Parameters:
    - autoencoder (nn.Module): The autoencoder model instance to visualize.
    - input_data (torch.Tensor): A sample input tensor for passing through the model to generate 
                                 the visualization. This input should match the model's expected input shape.
    Returns:
    - None. The function saves the visualization image file to `OUT_DIR` and outputs the path.
    """
    outputs = autoencoder(input_data)
    model_viz = make_dot(outputs, params=dict(autoencoder.named_parameters()))
    model_viz.format = "png"
    save_path = create_save_path(OUT_DIR, 'autoencoder_visualization')
    model_viz.render(save_path.replace(".png", ""))
    print(f"Full autoencoder visualization saved to {save_path}")

def visualize_encoder(autoencoder, input_data):
    """
    Visualizes the encoder section of the autoencoder model, which compresses input data into a 
    lower-dimensional representation. Saves the encoder visualization as an image.

    Parameters:
    - autoencoder (nn.Module): The autoencoder model instance containing the encoder layers.
    - input_data (torch.Tensor): A sample input tensor for passing through only the encoder 
                                 layers. The input shape should match the encoder's expected input.

    Returns:
    - None. The function saves the encoder visualization image to `OUT_DIR` and outputs the path.
    """
    encoder_output = autoencoder.encoder3(autoencoder.encoder2(autoencoder.encoder1(input_data)))
    encoder_viz = make_dot(encoder_output, params=dict(autoencoder.named_parameters()), 
                           show_attrs=True, show_saved=True)
    encoder_viz.format = "png"
    save_path = create_save_path(OUT_DIR, 'encoder_visualization')
    encoder_viz.render(save_path.replace(".png", ""))
    print(f"Encoder visualization saved to {save_path}")


def visualize_decoder(autoencoder, encoded_input):
    """
    Visualizes the decoder portion of the autoencoder model, which reconstructs the original 
    data from a lower-dimensional encoding. Saves the decoder visualization as an image.

    Parameters:
    - autoencoder (nn.Module): The autoencoder model instance containing the decoder layers.
    - encoded_input (torch.Tensor): An encoded representation that serves as the input to the 
                                    decoder layers. This should match the expected shape of the 
                                    decoder's input.
    Returns:
    - None. The function saves the decoder visualization image to `OUT_DIR` and outputs the path.
    """
    decoder_output = autoencoder.decoder1(autoencoder.decoder2(autoencoder.decoder3(encoded_input)))
    decoder_viz = make_dot(decoder_output, params=dict(autoencoder.named_parameters()), 
                           show_attrs=True, show_saved=True)
    decoder_viz.format = "png"
    save_path = create_save_path(OUT_DIR, 'decoder_visualization')
    decoder_viz.render(save_path.replace(".png", ""))
    print(f"Decoder visualization saved to {save_path}")

In [17]:
def compute_reconstruction_errors(original_fluxes, reconstructed_fluxes, method="mse"):
    """
    Computes reconstruction errors between original and reconstructed fluxes.

    Parameters:
    - original_fluxes (np.ndarray): Original flux values.
    - reconstructed_fluxes (np.ndarray): Reconstructed flux values from the model.
    - method (str): Error computation method, either 'mse' for mean squared error or 'mae' for mean absolute error.

    Returns:
    - np.ndarray: Array of errors with shape (num_spectra, num_wavelengths).
    """
    if method == "mse":
        errors = (original_fluxes - reconstructed_fluxes) ** 2
    elif method == "mae":
        errors = np.abs(original_fluxes - reconstructed_fluxes)
    else:
        raise ValueError("Method must be 'mse' or 'mae'")
    
    return errors

def plot_reconstruction_error_distribution(reconstruction_errors, wavelengths, save_directory='output_images'):
    """
    Visualizes the reconstruction error distribution across wavelengths.

    Parameters:
    - reconstruction_errors (np.ndarray): Array of reconstruction errors per spectrum and wavelength.
    - wavelengths (array-like): Array of wavelength values corresponding to each error point.
    - save_directory (str, optional): Directory path to save the generated plot image.
    """
    if wavelengths is None or len(wavelengths) != reconstruction_errors.shape[1]:
        raise ValueError("Wavelengths array must be provided and must match the number of wavelengths in reconstruction errors.")
    
    mean_errors = np.mean(reconstruction_errors, axis=0)

    plt.figure(figsize=(12, 7))
    plt.plot(wavelengths, mean_errors, label='Mean Reconstruction Error', color='#0571b0', linewidth=0.5)
    plt.fill_between(wavelengths, mean_errors, color='#67a9cf', alpha=1)
    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Mean Reconstruction Error")
    plt.legend()

    os.makedirs(save_directory, exist_ok=True)
    line_plot_path = create_save_path(save_directory, 'mean_reconstruction_error')
    plt.savefig(line_plot_path, dpi=300)
    plt.close()
    print(f"Plot of mean reconstruction error saved to {line_plot_path}")
    
def plot_redshift_distribution(zpix_cat, output_directory):
    """
    Plots and saves a histogram of redshift values from the zpix_cat DataFrame 
    and prints key redshift statistics.

    Parameters:
    - zpix_cat (pd.DataFrame): DataFrame containing redshift values in a column named 'z'.
    - output_directory (str): Path to the directory where the plot will be saved.
    
    Returns:
    - None
    """
    redshifts = zpix_cat['z'].to_numpy()

    plt.figure(figsize=(12, 6))
    sns.histplot(redshifts, bins=50, kde=True, color="blue", alpha=0.7)
    plt.xlabel("Redshift (z)")
    plt.ylabel("Count")
    plt.title("Distribution of Redshifts in the Dataset")
    plt.grid(True)

    os.makedirs(OUT_DIR, exist_ok=True)
    line_plot_path = create_save_path(OUT_DIR, 'redshift_dist')
    plt.savefig(line_plot_path, dpi=300)
    plt.close()
    print(f"Redshift distribution plot saved to {line_plot_path}")

    print(f"Total Galaxies: {len(redshifts)}")
    print(f"Redshift Range: {redshifts.min():.2f} - {redshifts.max():.2f}")
    print(f"Mean Redshift: {redshifts.mean():.2f}")
    print(f"Median Redshift: {np.median(redshifts):.2f}")
    print(f"Number of Galaxies with Redshift <= 0.5: {np.sum(redshifts <= 0.5)}")

# Main Execution Flow


In [18]:
zpix_cat = load_or_query_data(CSV_PATH)
if zpix_cat is not None:
    all_fluxes, all_wavelengths, all_errors = process_spectra_data(zpix_cat)
    
    max_length = max(len(f) for f in all_fluxes)
    
    all_fluxes_padded = pad_spectra(all_fluxes, max_length)
    all_wavelengths_padded = pad_spectra(all_wavelengths, max_length)
    all_errors_padded = pad_spectra(all_errors, max_length)
    
    mask = create_padding_mask(all_fluxes, max_length)
    

    all_fluxes_tensor = torch.tensor(all_fluxes_padded, dtype=torch.float32).unsqueeze(1)
    all_wavelengths_tensor = torch.tensor(all_wavelengths_padded, dtype=torch.float32).unsqueeze(1)
    all_errors_tensor = torch.tensor(all_errors_padded, dtype=torch.float32).unsqueeze(1)
    mask_tensor = torch.tensor(mask, dtype=torch.float32).unsqueeze(1)
    
    wavelengths = np.mean(all_wavelengths, axis=0) if len(set(map(len, all_wavelengths))) == 1 else all_wavelengths[0]
else:
    print("No data available for processing.")

Loading data from /Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/spectra_data>2.csv...


Processing spectra in batches: 100%|██████████| 28/28 [09:32<00:00, 20.44s/it]


No padding was necessary; all spectra are of equal length.
No padding was necessary; all spectra are of equal length.
No padding was necessary; all spectra are of equal length.


In [19]:
filtered_wavelengths, filtered_fluxes, filtered_errors = filter_spectral_data(
    wavelengths=all_wavelengths,
    fluxes=all_fluxes_padded,
    errors=all_errors_padded,
    min_wavelength=4000
)

max_length = max(len(f) for f in filtered_fluxes)
padded_fluxes = np.array([np.pad(f, (0, max_length - len(f)), constant_values=0) for f in filtered_fluxes])
padded_errors = np.array([np.pad(e, (0, max_length - len(e)), constant_values=0) for e in filtered_errors])

filtered_flux_tensor = torch.tensor(padded_fluxes, dtype=torch.float32).unsqueeze(1)  
filtered_error_tensor = torch.tensor(padded_errors, dtype=torch.float32).unsqueeze(1)  
filtered_mask_tensor = (filtered_flux_tensor > 0).float()

autoencoder = CNNAutoencoderWithSkip()
train_autoencoder(autoencoder, filtered_flux_tensor, filtered_mask_tensor, filtered_error_tensor)

autoencoder.eval()
reconstructed_fluxes = autoencoder(filtered_flux_tensor).detach().numpy().squeeze()


Epoch [1/50], Loss: 0.02630807, Time: 5.08s
Epoch [2/50], Loss: 0.01291959, Time: 4.13s
Epoch [3/50], Loss: 0.00353922, Time: 3.95s
Epoch [4/50], Loss: 0.00123957, Time: 3.79s
Epoch [5/50], Loss: 0.00050564, Time: 3.95s
Epoch [6/50], Loss: 0.00027383, Time: 5.95s
Epoch [7/50], Loss: 0.00019097, Time: 4.60s
Epoch [8/50], Loss: 0.00014494, Time: 4.44s
Epoch [9/50], Loss: 0.00012458, Time: 4.93s
Epoch [10/50], Loss: 0.00010752, Time: 5.11s
Epoch [11/50], Loss: 0.00010162, Time: 4.16s
Epoch [12/50], Loss: 0.00009732, Time: 4.09s
Epoch [13/50], Loss: 0.00009356, Time: 5.12s
Epoch [14/50], Loss: 0.00008985, Time: 4.31s
Epoch [15/50], Loss: 0.00008416, Time: 4.14s
Epoch [16/50], Loss: 0.00007706, Time: 4.32s
Epoch [17/50], Loss: 0.00007088, Time: 3.85s
Epoch [18/50], Loss: 0.00006284, Time: 4.30s
Epoch [19/50], Loss: 0.00005951, Time: 5.10s
Epoch [20/50], Loss: 0.00005538, Time: 4.40s
Epoch [21/50], Loss: 0.00005192, Time: 4.14s
Epoch [22/50], Loss: 0.00004899, Time: 4.18s
Epoch [23/50], Loss

In [33]:
data = {
    "fluxes": filtered_flux_tensor.numpy().squeeze(),
    "reconstructed_fluxes": reconstructed_fluxes,
}
parameter_bounds = {
    "abs_residual_threshold": (0.05, 2.0),      
    "rel_residual_threshold": (0.05, 2.0),      
    "range_mismatch_factor": (1.0, 3.0),        
    "overall_anomaly_threshold": (0.0, 1.0)     
}
best_params, best_fitness = detection_parameters(
    data=data,
    param_bounds=parameter_bounds,
    pop_size=20,
    generations=50,
    mutation_rate=0.1,
)

Genetic Algorithm Progress: 100%|██████████| 50/50 [1:05:34<00:00, 78.69s/it]

Optimized Parameters: {'abs_residual_threshold': 1.6401159722358578, 'rel_residual_threshold': 1.4051370923060555, 'range_mismatch_factor': 2.275631548528253, 'overall_anomaly_threshold': 0.46746769381116254}
Best Fitness: 0.0022660213473495094





In [34]:

# Plot redshift distribution for the data
plot_redshift_distribution(zpix_cat, OUT_DIR)

(anomalous_regions, spectrum_anomalies, anomaly_metadata, 
 abs_residual_threshold, rel_residual_threshold, range_mismatch_factor) = detect_anomalous_regions(
    original_fluxes=filtered_flux_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes,
    window_size=50,
    abs_residual_threshold=best_params["abs_residual_threshold"],
    rel_residual_threshold=best_params["rel_residual_threshold"],
    range_mismatch_factor=best_params["range_mismatch_factor"],
    overall_anomaly_threshold=best_params["overall_anomaly_threshold"]
)

anomalous_indices, plot_indices, seed = plot_spectra(
    original_fluxes=filtered_flux_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes,
    anomalous_regions=anomalous_regions,
    spectrum_anomalies=spectrum_anomalies,
    zpix_cat=zpix_cat,
    wavelengths=np.mean(filtered_wavelengths, axis=0),  
    save_directory='SpectralCNNAutoencoder_output',
    max_samples = 10,
    only_anomalous=False
)
reconstruction_errors = compute_reconstruction_errors(
    original_fluxes=filtered_flux_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes
)
plot_reconstruction_error_distribution(
    reconstruction_errors=reconstruction_errors,
    wavelengths=np.mean(filtered_wavelengths, axis=0),  
    save_directory=OUT_DIR
)
default_params = {
    "abs_residual_threshold": (parameter_bounds["abs_residual_threshold"][0] + parameter_bounds["abs_residual_threshold"][1]) / 2,
    "rel_residual_threshold": (parameter_bounds["rel_residual_threshold"][0] + parameter_bounds["rel_residual_threshold"][1]) / 2,
    "range_mismatch_factor": (parameter_bounds["range_mismatch_factor"][0] + parameter_bounds["range_mismatch_factor"][1]) / 2,
    "overall_anomaly_threshold": (parameter_bounds["overall_anomaly_threshold"][0] + parameter_bounds["overall_anomaly_threshold"][1]) / 2,
}
original_params = {"abs_residual_threshold":0.1, "rel_residual_threshold":0.1, 
                             "range_mismatch_factor":1.5, "overall_anomaly_threshold":0.3}
original_fitness = fitness_function(original_params, data)
print(f"Original Fitness: {original_fitness[0]}")

baseline_fitness = fitness_function(default_params, data)
print(f"Initial Fitness: {baseline_fitness[0]}")

print(f"Optimized Fitness: {best_fitness}")


Redshift distribution plot saved to SpectralCNNAutoencoder_output/redshift_dist_14.png
Total Galaxies: 543
Redshift Range: -0.00 - 0.50
Mean Redshift: 0.36
Median Redshift: 0.37
Number of Galaxies with Redshift <= 0.5: 543
Random sampling seed: 1846
Figure saved to SpectralCNNAutoencoder_output/spectra_reconstruction_33.png
Number of anomalous spectra: 543
Plot of mean reconstruction error saved to SpectralCNNAutoencoder_output/mean_reconstruction_error_73.png
Original Fitness: 0.00045046416843572894
Initial Fitness: 0.002213027552324865
Optimized Fitness: 0.0022660213473495094


# vvvvvvvvvvv DEBUG ZONE vvvvvvvvvvv