# Goals
- Access and filter DESI EDR galaxy spectra data from a database using SPARCL.
- Process and normalis the spectra data to prepare it for model training.
- Develop a CAE with skip connections to perform dimensionality reduction and reconstruction of the spectra.
- Train the autoencoder model using a weighted mean squared error (MSE) loss function to emphasise critical spectral features.
- Identify and visualise anomalies in the galaxy spectra based on high reconstruction errors.
- Provide visual representations of detected anomalies and evaluate the model's performance through training loss metrics.


# Summary
This project leverages the Dark Energy Spectroscopic Instrument (DESI) Early Data Release (EDR) dataset to train a convolutional autoencoder (CAE) for anomaly detection in galaxy spectra. The code retrieves galaxy spectra data from a database, processes and normaliss it, and then applies an autoencoder with skip connections to reconstruct the spectra. The reconstruction errors are used to identify anomalous spectra, which may indicate unusual features or observational issues in the data.

## Commands to run in terminal to upload changes to GitHub 
##### Note, to save and exit the commit comment section: Esc, ":wq", Enter


cd "/Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto"

git add SuperCAE.ipynb

output_dir="SpectralCNNAutoencoder_output"

latest_mean_reconstruction_error=$(ls -t output_images/mean_reconstruction_error_*.png | head -n 1)

latest_anomalous_spectra=$(ls -t output_images/anomalous_spectra_*.png | head -n 1)

latest_sampling_info=$(ls -t anomalous_regions/sampling_info_*.json | head -n 1)

if [ -f "$latest_mean_reconstruction_error" ]; then
    mv "$latest_mean_reconstruction_error" "$output_dir/"
    git add "$output_dir/$(basename "$latest_mean_reconstruction_error")"
fi

if [ -f "$latest_anomalous_spectra" ]; then
    mv "$latest_anomalous_spectra" "$output_dir/"
    git add "$output_dir/$(basename "$latest_anomalous_spectra")"
fi

if [ -f "$latest_sampling_info" ]; then
    mv "$latest_sampling_info" "$output_dir/"
    git add "$output_dir/$(basename "$latest_sampling_info")"
fi

git commit 

git push origin main


# Change log
<i> Place to log changes before they are recorded in a github update:

As the deadline to the project became closer, the order within this report fell away therefore, this is the tidied version of the code used for the report.

Beyond tidying, I have changed the normalisation from min max to robust scalar.

# Imports and Global Variables

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from tqdm import tqdm
from sparcl.client import SparclClient
from dl import queryClient as qc, authClient as ac
from getpass import getpass
import os
import re
import csv 
import torch.nn.functional as F
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import AutoMinorLocator
from torchviz import make_dot
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed 
import logging
import json
import seaborn as sns
import requests
from io import BytesIO
from PIL import Image
import matplotlib.patches as patches
import requests
from PIL import Image, ImageDraw
from deap import base, creator, tools, algorithms
import subprocess
from skopt.space import Real
from skopt.utils import use_named_args
from skopt import gp_minimize
from skopt.plots import plot_convergence, plot_objective, plot_evaluations
from skopt.plots import plot_evaluations
from scipy.stats import gaussian_kde
from scipy.interpolate import griddata
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import RobustScaler

#token = ac.login(input("Enter user name: (+ENTER) "),getpass("Enter password: (+ENTER) "))
ac.whoAmI()

# Configure file directories
DATA_DIR = '/Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/'
OUT_DIR = os.path.join(DATA_DIR, 'SpectralCNNAutoencoder_output')
CSV_PATH = os.path.join(DATA_DIR, 'spectra_data_2.csv')
JSON_DIR = os.path.join(DATA_DIR, 'anomalous_regions')
os.makedirs(OUT_DIR, exist_ok=True)
# os.environ["PATH"] += os.pathsep + "/opt/homebrew/bin" # uncomment when CNN visualisation is needed

# Initialise SPARCL client
client = SparclClient()

# Database and Data Query Functions 

In [2]:
def query_spectra_data():
    """
    Queries the DESI Early Data Release (EDR) database to retrieve galaxy spectra data, 
    filtering for primary spectra with a minimum number of coadded spectra and good quality flags.
    
    This function selects observational fields from the DESI EDR database:
        - `targetid`: Unique identifier for each galaxy.
        - `z`: Spectroscopic redshift.
        - `zwarn`: Redshift quality flag, where `zwarn = 0` indicates reliable data.
        - `coadd_fiberstatus`: Status of the fiber used for observation, with `0` indicating no issues.
        - `spectype`: Spectral type, filtered here to include only galaxies.
        - `mean_fiber_ra`, `mean_fiber_dec`: Mean Right Ascension and Declination of the fiber position.
        - `zcat_nspec`: Number of coadded spectra for the target.
        - 'zp.wavelength': Wavelength data for the object.
    Returns:
        pd.DataFrame: DataFrame with galaxy spectra data including flux error if available.
    """
    query = """
    SELECT zp.targetid, zp.z, zp.zwarn, zp.coadd_fiberstatus, zp.spectype, 
    zp.mean_fiber_ra, zp.mean_fiber_dec
    FROM desi_edr.zpix AS zp
    WHERE zp.zcat_primary = 't'
      AND zp.zcat_nspec > 2
      AND zp.z <= 0.5
      AND zp.spectype = 'GALAXY'
      AND zp.zwarn = '0'
    """
    try:
        print("Querying SPARCL database...")
        spec_data = qc.query(sql=query, fmt='table')
        df = spec_data.to_pandas()
        print(f"Retrieved {len(df)} records from the database.")
        return df
    except Exception as e:
        print(f"Error querying data: {e}")
        return None


def load_or_query_data(csv_path):
    """
    Loads data from CSV if available; otherwise, queries SPARCL and saves the results.

    Parameters:
    - csv_path (str): Path to the CSV file to save or load data.

    Returns:
    - pd.DataFrame: Loaded or queried data as a pandas DataFrame.
    """
    if os.path.exists(csv_path):
        print(f"Loading data from {csv_path}...")
        try:
            return pd.read_csv(csv_path)
        except Exception as e:
            print(f"Error loading CSV file: {e}")
            print("Attempting to query the database instead...")
    data = query_spectra_data()
    if data is not None:
        try:
            print(f"Saving queried data to {csv_path}...")
            data.to_csv(csv_path, index=False)
        except Exception as e:
            print(f"Error saving data to CSV: {e}")
    return data


# Data Retrieval and Parallel Processing

In [3]:
def retrieve_flux(targetid, inc, retries=5, delay=3):
    """
    Retrieves and normalises flux data using RobustScaler, wavelength, and calculates error for a single target ID from the DESI database, with retry logic.
    """
    scaler = RobustScaler()

    for attempt in range(1, retries + 1):
        try:
            res = client.retrieve_by_specid(specid_list=[targetid], include=inc, dataset_list=['DESI-EDR'])
            for record in res.records:
                if record['specprimary']:
                    flux = record['flux']
                    wavelength = record['wavelength']
                    ivar = record['ivar']

                    flux = flux.reshape(-1, 1)
                    normalised_flux = scaler.fit_transform(flux).flatten()
                    error = np.sqrt(1 / np.where(ivar == 0, 1e-10, ivar))

                    return normalised_flux, wavelength, error  
        except Exception:
            if attempt < retries:
                time.sleep(delay * (2 ** (attempt - 1)))  # Exponential backoff
    return None, None, None


def process_spectra_data(spec_data, batch_size=20, max_workers=10, output_csv_path=CSV_PATH):
    """
    Retrieves and processes flux data in parallel for a DataFrame of galaxy spectra from the DESI database,
    including flux, wavelength, and error data.

    Parameters:
    - spec_data (pd.DataFrame): DataFrame containing metadata for the galaxy spectra to be processed.
    - batch_size (int): Number of spectra to retrieve in each batch.
    - max_workers (int): Max number of parallel threads.
    - output_csv_path (str): Path to save the updated DataFrame with flux and wavelength data.

    Returns:
    - tuple (list, list, list): Lists of flux, wavelength, and error values.
    """
    all_fluxes, all_wavelengths, all_errors = [], [], []
    total_records = len(spec_data)

    inc = ['specid', 'redshift', 'flux', 'wavelength', 'ivar', 'spectype', 
           'specprimary', 'survey', 'program', 'targetid', 'coadd_fiberstatus']

    for start_idx in tqdm(range(0, total_records, batch_size), desc="Processing spectra in batches"):
        batch = spec_data.iloc[start_idx:start_idx + batch_size]
        batch_fluxes, batch_wavelengths, batch_errors = [], [], []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(retrieve_flux, int(row['targetid']), inc): row['targetid']
                for _, row in batch.iterrows()
            }

            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    flux, wavelength, error = result
                    batch_fluxes.append(flux)
                    batch_wavelengths.append(wavelength)
                    batch_errors.append(error)
                else:
                    print(f"Warning: Missing flux for target ID {futures[future]}")

        all_fluxes.extend(batch_fluxes)
        all_wavelengths.extend(batch_wavelengths)
        all_errors.extend(batch_errors)

    if not all_fluxes:
        raise ValueError("Error: No flux data retrieved. Check data source or retrieval logic.")

    spec_data['flux'] = [list(f) for f in all_fluxes]
    spec_data['wavelength'] = [list(w) for w in all_wavelengths]
    spec_data['error'] = [list(e) for e in all_errors]

    print(f"Saving updated DataFrame with flux and wavelength data to {output_csv_path}...")
    spec_data.to_csv(output_csv_path, index=False)

    return all_fluxes, all_wavelengths, all_errors


# Data Preprocessing and Normalisation

In [4]:
def pad_or_truncate_generator(arrays, target_length):
    """
    Generator that yields arrays padded with zeros or truncated to a fixed length.

    Parameters:
    - arrays (iterable): A collection of 1D arrays (e.g., lists or numpy arrays) of varying lengths.
    - target_length (int): The desired length for each output array.

    Yields:
    - np.ndarray: A 1D array of length `target_length`, either padded with zeros or truncated as needed.
    """
    for arr in arrays:
        if arr is None:
            yield np.zeros(target_length)
        else:
            yield np.pad(arr, (0, max(target_length - len(arr), 0)), mode='constant')[:target_length]

def pad_or_truncate(arrays, target_length):
    """
    Pads or truncates a list of 1D arrays to a uniform target length and stacks them into a 2D array.

    Parameters:
    - arrays (list of array-like): A list of arrays or lists to be padded or truncated.
    - target_length (int): The length to which each array should be padded or truncated.

    Returns:
    - np.ndarray: A 2D array where each row has shape (target_length,).
    """
    padded_list = list(pad_or_truncate_generator(arrays, target_length))
    return np.vstack(padded_list)

def filter_spectral_data(wavelengths, fluxes, errors, min_wavelength=4000):
    """
    Filters spectral data to retain only wavelengths above a specified minimum value.

    Parameters:
    - wavelengths (list of lists or np.ndarray): 2D array or list of wavelength values.
    - fluxes (list of lists or np.ndarray): 2D array or list of flux values.
    - errors (list of lists or np.ndarray): 2D array or list of error values.
    - min_wavelength (float): Minimum wavelength to retain in Ångstrom.

    Returns:
    - filtered_wavelengths (list of np.ndarray): List of filtered wavelength arrays.
    - filtered_fluxes (list of np.ndarray): List of filtered flux arrays corresponding to wavelengths.
    - filtered_errors (list of np.ndarray): List of filtered error arrays corresponding to wavelengths.
    """
    filtered_wavelengths, filtered_fluxes, filtered_errors = [], [], []

    for i in range(len(wavelengths)):
        w_array = np.array(wavelengths[i])
        f_array = np.array(fluxes[i])
        e_array = np.array(errors[i])

        mask = w_array >= min_wavelength
        filtered_wavelengths.append(w_array[mask])
        filtered_fluxes.append(f_array[mask])
        filtered_errors.append(e_array[mask])

    return filtered_wavelengths, filtered_fluxes, filtered_errors
def create_padding_mask(fluxes, target_length):
    """
    Generates a binary mask indicating the padded regions in each flux array.
    
    Parameters:
    - fluxes (list of np.array): List of flux arrays for each spectrum.
    - target_length (int): Length of the output mask arrays, matching the padded flux array length.
    
    Returns:
    - np.array: Array of binary masks for each flux array, where 1 represents original data and 0 represents padded values.
    
    Notes:
    - The mask is used during loss calculation to ignore the padded regions of the spectra.
    """
    masks = []
    for flux in fluxes:
        mask = np.ones_like(flux)
        if len(flux) < target_length:
            mask = np.concatenate([mask, np.zeros(target_length - len(flux))])
        masks.append(mask[:target_length])
    return np.array(masks)

#  Dataset and DataLoader Setup

In [5]:
class SpectralDataset(Dataset):
    def __init__(self, fluxes, wavelengths, errors, masks):
        def to_tensor(x):
            if isinstance(x, torch.Tensor):
                return x.clone().detach()
            return torch.tensor(x, dtype=torch.float32)

        def ensure_3d(x):
            x = to_tensor(x)
            return x if x.ndim == 3 else x.unsqueeze(1)

        self.fluxes = ensure_3d(fluxes)
        self.wavelengths = ensure_3d(wavelengths)
        self.errors = ensure_3d(errors)
        self.masks = ensure_3d(masks)

    def __len__(self):
        return len(self.fluxes)

    def __getitem__(self, idx):
        return self.fluxes[idx], self.wavelengths[idx], self.errors[idx], self.masks[idx]


# Convolutional Autoencoder

In [6]:
class CNNAutoencoderWithSkip(nn.Module):
    """
    A Convolutional Neural Network (CNN)-based autoencoder with skip connections for reconstructing spectral data.
    
    This model is designed to reduce the dimensionality of spectral data, capturing essential features 
    during encoding and reconstructing the data during decoding. Skip connections help retain 
    detailed information that may otherwise be lost in deeper layers, improving reconstruction 
    quality for tasks like anomaly detection or dimensionality reduction.
    
    Attributes:
    - encoder1 (nn.Conv1d): First convolutional layer of the encoder, reducing input dimensions.
    - encoder2 (nn.Conv1d): Second convolutional layer of the encoder, further reducing dimensions.
    - encoder3 (nn.Conv1d): Final convolutional layer of the encoder, creating a compressed representation.
    - decoder3 (nn.ConvTranspose1d): First transposed convolutional layer of the decoder.
    - decoder2 (nn.ConvTranspose1d): Second transposed convolutional layer of the decoder.
    - decoder1 (nn.ConvTranspose1d): Final transposed convolutional layer of the decoder, outputting the reconstructed data.
    
    Methods:
    - forward(x): Defines the forward pass of the autoencoder. The input data `x` passes through the encoder 
      layers to compress it, and then through the decoder layers to reconstruct it. Skip connections are 
      used to combine encoder and decoder layers at corresponding depths.
      
    Returns:
    - torch.Tensor: The reconstructed tensor, with values squashed between 0 and 1 using a sigmoid activation.
    """
    def __init__(self):
        super(CNNAutoencoderWithSkip, self).__init__()
        # Encoder layers with reduced channels
        self.encoder1 = nn.Conv1d(1, 64, kernel_size=3, stride=2, padding=1)
        self.encoder2 = nn.Conv1d(64, 32, kernel_size=3, stride=2, padding=1)
        self.encoder3 = nn.Conv1d(32, 16, kernel_size=3, stride=2, padding=1)  

        # Decoder layers with reduced channels
        self.decoder3 = nn.ConvTranspose1d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder2 = nn.ConvTranspose1d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder1 = nn.ConvTranspose1d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=0)

    def forward(self, x):
        # Encoding with skip connections
        x1 = F.relu(self.encoder1(x))  
        x2 = F.relu(self.encoder2(x1))  
        x3 = F.relu(self.encoder3(x2))  

        # Decoding with skip connections
        x = F.relu(self.decoder3(x3))
        
        # Adjust sizes for skip connection with x2
        if x2.size(2) > x.size(2):
            x2 = x2[:, :, :x.size(2)]
        elif x2.size(2) < x.size(2):
            x = x[:, :, :x2.size(2)]
        x = F.relu(self.decoder2(x + x2))
        
        # Adjust sizes for skip connection with x1
        if x1.size(2) > x.size(2):
            x1 = x1[:, :, :x.size(2)]
        elif x1.size(2) < x.size(2):
            x = x[:, :, :x1.size(2)]
        x = self.decoder1(x + x1)

        return x

# Weighted Loss Function

In [7]:
def weighted_mse_loss(observed_flux, target_flux, mask, error_on_observed_flux):
    """
    Calculates a modified Mean Squared Error (MSE) loss based on the "significance difference"
    between observed and target flux, adjusted by the observational error.

    Parameters:
    - observed_flux (torch.Tensor): The model's predicted flux values (reconstructed spectra).
    - target_flux (torch.Tensor): The ground truth flux values to be compared against observed_flux.
    - mask (torch.Tensor): A binary mask that identifies which regions in the spectra should be included
                           in the loss calculation, typically to ignore padded regions.
    - error_on_observed_flux (torch.Tensor): Observational error associated with each flux point.

    Returns:
    - torch.Tensor: The mean significance-based MSE loss, adjusted by the mask to ignore specific regions.

    Process:
    1. Compute the "significance difference" by normalising the difference between observed and target flux
       using `error_on_observed_flux`, which represents the observational uncertainty.
    2. Square the resulting "significance difference" to get a MSE-like loss, which is focused on relative
       significance rather than absolute difference.
    3. Apply the mask to this significance-based MSE loss to focus only on relevant parts of the spectrum.
    4. Return the mean of the masked, significance-based MSE loss.
    """
    significance_difference = (observed_flux - target_flux) / (error_on_observed_flux + 1e-5)
    mse_loss = significance_difference ** 2
    
    return (mse_loss * mask).mean()


# Model Training

In [8]:
def train_autoencoder(model, data, mask, errors, epochs=50, batch_size=8, lr=0.001, 
                      grad_clip=1.0, l1_lambda=1e-5):
    """
   Trains an autoencoder model using mini-batch gradient descent with gradient clipping and L1 regularization
    and a custom weighted MSE loss function that incorporates errors.

    Parameters:
    - model (torch.nn.Module): Autoencoder model.
    - data (torch.Tensor): Input tensor (spectra).
    - mask (torch.Tensor): Padding mask tensor.
    - errors (torch.Tensor): Observational errors.
    - epochs (int): Number of training epochs.
    - batch_size (int): Batch size for training.
    - lr (float): Learning rate.
    - grad_clip (float): Gradient clipping value.
    - l1_lambda (float): L1 regularization strength.
    
    Returns:
    - None
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        start_time = time.time()
        epoch_loss = 0
        optimizer.zero_grad()  
        
        for i in range(0, len(data), batch_size):
            batch_data = data[i:i+batch_size]
            mask_batch = mask[i:i+batch_size]
            error_batch = errors[i:i+batch_size]
            
            reconstructed = model(batch_data)
            mse_loss = weighted_mse_loss(reconstructed, batch_data, mask_batch, error_batch)

            l1_norm = sum(p.abs().sum() for p in model.parameters() if p.requires_grad)
            total_loss = mse_loss + (l1_lambda * l1_norm)

            total_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()
            
            epoch_loss += total_loss.item()

        epoch_duration = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(data):.8f}, Time: {epoch_duration:.2f}s")

    print(f"Final Loss after {epochs} epochs: {epoch_loss/len(data):.8f}")


# Anomaly Detection and Evaluation

In [9]:
def detect_anomalous_regions(original_fluxes, reconstructed_fluxes, window_size=50, 
                             abs_residual_threshold=0.1, rel_residual_threshold=0.1, 
                             range_mismatch_factor=1.5, overall_anomaly_threshold=0.3):
    """
    Identifies anomalous regions within spectra using residuals and mismatch factors,
    with added logging to justify each detected anomaly and the triggering value.

    Parameters:
    ----------
    original_fluxes : np.ndarray
        A 2D array of original flux values from the spectra, with shape (num_spectra, spectrum_length).
        Each row represents a spectrum, and each column corresponds to a specific wavelength or flux point.

    reconstructed_fluxes : np.ndarray
        A 2D array of reconstructed flux values from the model, with the same shape as `original_fluxes`.
        These are the model’s best attempt to recreate the original spectra, used to compute errors.

    window_size : int, optional, default=50
        The size of the sliding window (in data points) used to analyse local regions within each spectrum.

    abs_residual_threshold : float, optional, default=0.1
        The fixed threshold value for absolute residuals beyond which regions are flagged as anomalous.

    rel_residual_threshold : float, optional, default=0.1
        The fixed threshold value for relative residuals beyond which regions are flagged as anomalous.

    range_mismatch_factor : float, optional, default=1.5
        A factor to identify anomalies based on the range mismatch between original and reconstructed fluxes.

    overall_anomaly_threshold : float, optional, default=0.3
        The minimum fraction of a spectrum that must be flagged as anomalous to classify the entire spectrum as anomalous.

    Returns:
    -------
    anomalies : list of np.ndarray
        A list of boolean arrays, each representing a spectrum. Each array has True at positions corresponding to
        flagged regions in the spectrum and False otherwise.

    spectrum_anomalies : np.ndarray
        A 1D boolean array where each element represents whether the corresponding spectrum is classified as anomalous.

    anomaly_metadata : list of list of dict
        Metadata for each flagged region in each spectrum, with information on the range of the anomaly,
        the type of anomaly (high absolute residual, high relative residual, or range mismatch), and the value that triggered the flag.

    abs_residual_threshold : float
        The fixed threshold value for absolute residuals beyond which regions are flagged as anomalous.

    rel_residual_threshold : float
        The fixed threshold value for relative residuals beyond which regions are flagged as anomalous.

    range_mismatch_factor : float
        The factor used to determine if the reconstructed flux range mismatch flags a region as anomalous.
    """
    num_spectra, spectrum_length = original_fluxes.shape
    absolute_residuals = np.abs(original_fluxes - reconstructed_fluxes)
    relative_residuals = np.abs((original_fluxes - reconstructed_fluxes) / (original_fluxes + 1e-5))

    anomalies, spectrum_anomalies = [], np.zeros(num_spectra, dtype=bool)
    anomaly_metadata = []

    for i in range(num_spectra):
        spectrum_anomalies_count = 0
        spectrum_anomaly_flags = np.zeros(spectrum_length, dtype=bool)
        anomaly_justification = []

        for start in range(0, spectrum_length - window_size + 1, window_size // 2):
            end = start + window_size
            window_abs_residual = float(np.mean(absolute_residuals[i, start:end]))
            window_rel_residual = float(np.mean(relative_residuals[i, start:end]))
            original_range = float(np.ptp(original_fluxes[i, start:end]))
            reconstructed_range = float(np.ptp(reconstructed_fluxes[i, start:end]))

            if (window_abs_residual > abs_residual_threshold or
                window_rel_residual > rel_residual_threshold or
                reconstructed_range > original_range * range_mismatch_factor):
                
                spectrum_anomaly_flags[start:end] = True
                spectrum_anomalies_count += end - start

                reason = {
                    "range_start": start,
                    "range_end": end,
                    "reason_type": (
                        "high_absolute_residual" if window_abs_residual > abs_residual_threshold else
                        "high_relative_residual" if window_rel_residual > rel_residual_threshold else
                        "range_mismatch"
                    ),
                    "trigger_value": (
                        window_abs_residual if window_abs_residual > abs_residual_threshold else
                        window_rel_residual if window_rel_residual > rel_residual_threshold else
                        reconstructed_range / (original_range + 1e-5)
                    )
                }
                anomaly_justification.append(reason)

        spectrum_anomalies[i] = spectrum_anomalies_count / spectrum_length > overall_anomaly_threshold 
        anomalies.append(spectrum_anomaly_flags)
        anomaly_metadata.append(anomaly_justification)

    return anomalies, spectrum_anomalies, anomaly_metadata, abs_residual_threshold, rel_residual_threshold, range_mismatch_factor

def fitness_function(params, data):
    """
    Calculates the fitness score for anomaly detection thresholds.
    
    Parameters:
    - params (dict): Contains thresholds for:
        - abs_residual_threshold
        - rel_residual_threshold
        - range_mismatch_factor
        - overall_anomaly_threshold
    - data (dict): Contains:
        - fluxes (original spectra)
        - reconstructed_fluxes (from the model)
        
    Returns:
    - float: Fitness score (closer to 0 is better).
    """
    try:
        abs_residual_threshold = params["abs_residual_threshold"]
        rel_residual_threshold = params["rel_residual_threshold"]
        range_mismatch_factor = params["range_mismatch_factor"]
        overall_anomaly_threshold = params["overall_anomaly_threshold"]

        original_fluxes = data["fluxes"]
        reconstructed_fluxes = data["reconstructed_fluxes"]

        anomalies, spectrum_anomalies, _, _, _, _ = detect_anomalous_regions(
            original_fluxes=original_fluxes,
            reconstructed_fluxes=reconstructed_fluxes,
            abs_residual_threshold=abs_residual_threshold,
            rel_residual_threshold=rel_residual_threshold,
            range_mismatch_factor=range_mismatch_factor,
            overall_anomaly_threshold=overall_anomaly_threshold,
        )

        reconstruction_errors = compute_reconstruction_errors(original_fluxes, reconstructed_fluxes, method="mse")
        mean_reconstruction_error = np.mean(reconstruction_errors)

        anomaly_ratio = np.mean(spectrum_anomalies)

        false_negative_penalty = np.mean([
            1 if not spectrum_anomalies[i] and np.any(anomalies[i]) else 0
            for i in range(len(original_fluxes))
        ])
        false_positive_penalty = np.mean([
            1 if spectrum_anomalies[i] and not np.any(anomalies[i]) else 0
            for i in range(len(original_fluxes))
        ])

        # Add a clear penalty for excessive anomalies
        excessive_anomaly_penalty = max(0, anomaly_ratio - 0.5) ** 2  # Penalize ratios above 50%

        fitness = (
            -mean_reconstruction_error * 10
            - false_negative_penalty * 5
            - false_positive_penalty * 5
            - excessive_anomaly_penalty * 20  
        )
        return fitness,

    except Exception as e:
        print(f"[ERROR] Fitness function failed: {e}")
        return -1e6,


def detection_parameters(data, param_bounds, pop_size=20, generations=50, mutation_rate=0.2):
    """
    Optimise anomaly detection thresholds using a genetic algorithm.

    Parameters:
    - data (dict): Contains 'fluxes' and 'reconstructed_fluxes'.
    - param_bounds (dict): Parameter boundaries for optimisation.
    - pop_size (int): Population size.
    - generations (int): Number of generations.
    - mutation_rate (float): Probability of mutation.

    Returns:
    - dict: Best parameter configuration found.
    """
    creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    creator.create("Individual", list, fitness=creator.FitnessMax)

    toolbox = base.Toolbox()

    for key, bounds in param_bounds.items():
        toolbox.register(f"attr_{key}", random.uniform, bounds[0], bounds[1])

    toolbox.register(
        "individual",
        tools.initCycle,
        creator.Individual,
        tuple(getattr(toolbox, f"attr_{key}") for key in param_bounds.keys()),
        n=1,
    )
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    def deap_fitness(individual):
        params = {key: individual[idx] for idx, key in enumerate(param_bounds.keys())}
        return fitness_function(params, data)

    toolbox.register("evaluate", deap_fitness)

    min_bounds = [bounds[0] for bounds in param_bounds.values()]
    max_bounds = [bounds[1] for bounds in param_bounds.values()]
    toolbox.register("mate", tools.cxBlend, alpha=0.5)
    toolbox.register("mutate", tools.mutPolynomialBounded, eta=20.0, low=min_bounds, up=max_bounds, indpb=mutation_rate)
    toolbox.register("select", tools.selTournament, tournsize=3)

    population = toolbox.population(n=pop_size)

    for gen in tqdm(range(generations), desc="Genetic Algorithm Progress"):
        offspring = toolbox.select(population, len(population))
        offspring = list(map(toolbox.clone, offspring))

        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < 0.5:
                toolbox.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        for mutant in offspring:
            if random.random() < mutation_rate:
                toolbox.mutate(mutant)
                del mutant.fitness.values

        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        fitnesses = map(toolbox.evaluate, invalid_ind)
        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit

        population[:] = offspring

    best_individual = tools.selBest(population, k=1)[0]
    best_params = {key: best_individual[idx] for idx, key in enumerate(param_bounds.keys())}
    best_fitness = best_individual.fitness.values[0]
    print(f"Optimised Parameters: {best_params}")
    print(f"Best Fitness: {best_fitness}")
    return best_params, best_fitness
def compute_dynamic_anomaly_ratio(reconstruction_errors, mad_multiplier=3):
    """
    Dynamically computes the anomaly ratio using the Median Absolute Deviation (MAD).

    Parameters:
    - reconstruction_errors (np.ndarray): Array of reconstruction errors.
    - mad_multiplier (float): Multiplier for the MAD to define the threshold.

    Returns:
    - float: Target anomaly ratio based on the error distribution.
    """
    median_error = np.median(reconstruction_errors)
    mad_error = np.median(np.abs(reconstruction_errors - median_error))
    threshold = median_error + mad_multiplier * mad_error

    anomalies = reconstruction_errors > threshold
    target_anomaly_ratio = np.mean(anomalies)  

    print(f"Dynamic Anomaly Ratio: {target_anomaly_ratio:.4f} (Threshold: {threshold:.4f})")
    return target_anomaly_ratio
def compute_reconstruction_errors(original_fluxes, reconstructed_fluxes, method="mse"):
    """
    Computes reconstruction errors between original and reconstructed fluxes.

    Parameters:
    - original_fluxes (np.ndarray): Original flux values.
    - reconstructed_fluxes (np.ndarray): Reconstructed flux values from the model.
    - method (str): Error computation method, either 'mse' for mean squared error or 'mae' for mean absolute error.

    Returns:
    - np.ndarray: Array of errors with shape (num_spectra, num_wavelengths).
    """
    if method == "mse":
        errors = (original_fluxes - reconstructed_fluxes) ** 2
    elif method == "mae":
        errors = np.abs(original_fluxes - reconstructed_fluxes)
    else:
        raise ValueError("Method must be 'mse' or 'mae'")
    
    return errors

# Bayesian Optimization

In [10]:
# --- Define Parameter Space ---
param_space = [
    Real(0.0, 1.0, name="abs_residual_threshold"),
    Real(0.05, 2.0, name="rel_residual_threshold"),
    Real(1.0, 3.0, name="range_mismatch_factor"),
    Real(0.0, 1.0, name="overall_anomaly_threshold"),
]

# --- Define Objective Function ---
@use_named_args(param_space)
def objective_function(**params):
    fitness, = fitness_function(params, data_for_bo)  
    return -fitness
def plot_bo_results(bo_results, param_names):
    """
    Generates visualizations for Bayesian Optimization results:
    1. Convergence plot (Fitness improvement over iterations)
    2. Parameter evaluations (How parameters were sampled)
    3. Objective function heatmap (Relationship between parameters)

    Parameters:
    - bo_results: The result object from gp_minimize
    - param_names: List of parameter names corresponding to bo_results
    """

    fig, axes = plt.subplots(3, 1, figsize=(10, 18))  # More vertical space
    
    # Plot convergence (Fitness over iterations)
    plot_convergence(bo_results, ax=axes[0])
    axes[0].set_title("Bayesian Optimization Convergence", fontsize=14)
    axes[0].grid(True)

    # Plot evaluations (Parameter search distribution)
    plot_evaluations(bo_results, dimensions=param_names, ax=axes[1])
    axes[1].set_title("Parameter Evaluations", fontsize=14)
    
    # Plot objective function relationships
    plot_objective(bo_results, dimensions=param_names, ax=axes[2])
    axes[2].set_title("Parameter Relationships", fontsize=14)

    # Adjust layout for better spacing
    plt.tight_layout(pad=4.0, h_pad=4.0)  # Adds space between subplots
    plt.show()
def plot_bo_3d_grid(bo_results, save_directory='SpectralCNNAutoencoder_output'):
    """
    Generates a structured grid of visualizations for Bayesian Optimization results:
    - Diagonal: 1D Histograms of individual parameter distributions.
    - Lower triangle: 3D surface plots with fitness as the Z-axis.
    - Upper triangle: 2D KDE density plots to show optimizer search behavior.

    Parameters:
    - bo_results: The result object from `gp_minimize`
    - save_directory: Directory to save the figure.
    """

    param_names = ["Absolute Threshold", "Relative Threshold", "Range Disparity", "Overall Threshold"]
    param_values = np.array(bo_results.x_iters)
    fitness_vals = np.array(bo_results.func_vals)
    num_params = len(param_names)

    fig, axes = plt.subplots(num_params, num_params, figsize=(18, 18)) 

    for i in range(num_params):
        for j in range(num_params):
            ax = axes[i, j]

            if i == j:
                ax.hist(param_values[:, i], bins=20, color="steelblue", alpha=0.7, edgecolor="black")
                ax.set_xlabel(param_names[i], fontsize=14, labelpad=10)
                ax.set_ylabel("Frequency", fontsize=14, labelpad=10)
                ax.tick_params(axis='both', which='major', labelsize=14)

            elif i > j:
                ax.remove()
                ax = fig.add_subplot(num_params, num_params, i*num_params + j + 1, projection='3d')
                x_vals, y_vals = param_values[:, j], param_values[:, i]
                z_vals = fitness_vals  
                grid_x, grid_y = np.meshgrid(
                    np.linspace(min(x_vals), max(x_vals), 50),
                                    np.linspace(min(y_vals), max(y_vals), 50)
                )
                grid_z = griddata((x_vals, y_vals), z_vals, (grid_x, grid_y), method='cubic', fill_value=np.nan)
                grid_z = np.nan_to_num(grid_z, nan=np.nanmin(grid_z))
                
                ax.plot_surface(grid_x, grid_y, grid_z, cmap="coolwarm_r", alpha=0.8, edgecolor="k", linewidth=0.3)

                ax.set_xlabel(param_names[j], fontsize=14, labelpad=10)
                ax.set_ylabel(param_names[i], fontsize=14, labelpad=10)
                ax.set_zlabel("Fitness", fontsize=14, labelpad=10)

                ax.zaxis.labelpad = -150 
                ax.zaxis.label.set_rotation(90)  
                ax.zaxis.label.set_verticalalignment('bottom')  
                
                ax.tick_params(axis='both', which='major', labelsize=14)

            else:
                x_vals, y_vals = param_values[:, j], param_values[:, i]

                xy = np.vstack([x_vals, y_vals])
                density = gaussian_kde(xy)(xy)

                grid_x, grid_y = np.meshgrid(
                    np.linspace(min(x_vals), max(x_vals), 100),
                    np.linspace(min(y_vals), max(y_vals), 100)
                )
                grid_z = griddata((x_vals, y_vals), density, (grid_x, grid_y), method='cubic')

                ax.contourf(grid_x, grid_y, grid_z, levels=10, cmap="coolwarm", alpha=0.7)  
                ax.scatter(x_vals, y_vals, c=density, cmap="coolwarm", s=50, alpha=0.8, edgecolor="black")

                ax.set_xlabel(param_names[j], fontsize=14, labelpad=10)
                ax.set_ylabel(param_names[i], fontsize=14, labelpad=10)
                ax.tick_params(axis='both', which='major', labelsize=12)
                
    cbar_ax = fig.add_axes([0.15, 0.02, 0.7, 0.015])
    sm = plt.cm.ScalarMappable(cmap="coolwarm_r", norm=plt.Normalize(min(fitness_vals), max(fitness_vals)))
    sm.set_array([])  
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')
    cbar.set_label("Fitness Score (Lower is better)", fontsize=16)

    plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.08, wspace=0.3, hspace=0.3)

    base_filename = "bo_results_3d_grid"
    save_path = create_save_path(save_directory, base_filename)

    plt.savefig(save_path, dpi=600, bbox_inches='tight')
    plt.close(fig)

    print(f"Plot saved to {save_path}")

# Visualisation and Plotting Functions

In [11]:
def plot_spectra_combined(
    original_fluxes, reconstructed_fluxes, wavelengths, spec_data,
    anomalous_regions, spectrum_anomalies,
    save_directory='SpectralCNNAutoencoder_output',
    max_samples=10, seed=None, only_anomalous=False
):
    def fetch_image(ra, dec, pixscale=0.262, width=256, height=256, survey="ls-dr10"):
        url = f"https://www.legacysurvey.org/viewer/jpeg-cutout?ra={ra}&dec={dec}&pixscale={pixscale}&layer={survey}&size={max(width, height)}"
        try:
            response = requests.get(url, timeout=5)
            if response.status_code == 200:
                return Image.open(BytesIO(response.content)).convert("RGBA")
        except Exception:
            return None

    if seed is None:
        seed = random.randint(0, 10000)
    random.seed(seed)
    np.random.seed(seed)

    indices = [i for i, is_anom in enumerate(spectrum_anomalies) if is_anom] if only_anomalous else list(range(len(original_fluxes)))
    if not indices:
        print("No spectra to plot.")
        return [], [], seed

    selected_indices = random.sample(indices, min(max_samples, len(indices)))
    fig, axes = plt.subplots(len(selected_indices), 2, figsize=(20, 6 * len(selected_indices)), gridspec_kw={'width_ratios': [3, 1]})
    if len(selected_indices) == 1:
        axes = [axes]

    for idx, i in enumerate(selected_indices):
        ra, dec = spec_data.iloc[i]['mean_fiber_ra'], spec_data.iloc[i]['mean_fiber_dec']
        x_axis = wavelengths if wavelengths is not None else range(len(original_fluxes[i]))

        ax_spectra = axes[idx][0] if isinstance(axes[idx], (list, tuple)) else axes[0]
        ax_spectra.plot(x_axis, original_fluxes[i], label="Original", color='#2c7bb6', linewidth=0.5)
        ax_spectra.plot(x_axis, reconstructed_fluxes[i], label="Reconstructed", color='#d7191c', alpha=0.7, linewidth=0.5)

        in_anomaly = False
        anomaly_start = 0
        for j in range(len(x_axis)):
            if anomalous_regions[i][j] and not in_anomaly:
                anomaly_start = j
                in_anomaly = True
            elif not anomalous_regions[i][j] and in_anomaly:
                ax_spectra.axvspan(x_axis[anomaly_start], x_axis[j], color='#fdae61', alpha=0.7)
                in_anomaly = False
        if in_anomaly:
            ax_spectra.axvspan(x_axis[anomaly_start], x_axis[-1], color='#fdae61', alpha=0.7)

        if spectrum_anomalies[i]:
            ax_spectra.set_facecolor('#fee090')
            ax_spectra.set_title(f"Spectrum ID: {spec_data['targetid'].iloc[i]} - Anomalous", color='red', fontsize=23)
        else:
            ax_spectra.set_title(f"Spectrum ID: {spec_data['targetid'].iloc[i]}", color='black', fontsize=23)

        ax_spectra.xaxis.set_visible(False)
        ax_spectra.set_ylabel("Flux (normalised)", fontsize=20)
        ax_spectra.legend()

        divider = make_axes_locatable(ax_spectra)
        ax_residual = divider.append_axes("bottom", size="25%", pad=0, sharex=ax_spectra)
        ax_residual.plot(x_axis, original_fluxes[i] - reconstructed_fluxes[i], color='#4dac26', linewidth=0.5)
        ax_residual.axhline(0, color='black', linestyle='-', linewidth=0.5)
        ax_residual.set_ylabel("Residual", fontsize=20)
        ax_residual.set_xlabel("Wavelength (Å)", fontsize=20)

        ax_image = axes[idx][1] if isinstance(axes[idx], (list, tuple)) else axes[1]
        img_data = fetch_image(ra, dec)
        if img_data:
            ax_image.imshow(img_data)
            ax_image.set_title(f"Galaxy Image\nRA={ra:.4f}, Dec={dec:.4f}", fontsize=23)
            circle = patches.Circle((img_data.width / 2, img_data.height / 2), 30, edgecolor="white", facecolor="none", linewidth=2)
            ax_image.add_patch(circle)
        else:
            ax_image.text(0.5, 0.5, "Image Not Found", ha="center", va="center", fontsize=14)
        ax_image.axis("off")

    plt.tight_layout()
    base_filename = "anomalous_spectra" if only_anomalous else "spectra_reconstruction"
    save_path = create_save_path(save_directory, base_filename)
    plt.savefig(save_path, dpi=300)
    plt.close(fig)
    print(f"Saved plot to {save_path}, seed: {seed}")
    return indices, selected_indices, seed

def plot_reconstruction_error_distribution(reconstruction_errors, wavelengths, save_directory='output_images'):
    """
    Visualises the reconstruction error distribution across wavelengths with both linear and log-scale plots.

    Parameters:
    - reconstruction_errors (np.ndarray): Array of reconstruction errors per spectrum and wavelength.
    - wavelengths (array-like): Array of wavelength values corresponding to each error point.
    - save_directory (str, optional): Directory path to save the generated plot image.
    """
    if wavelengths is None or len(wavelengths) != reconstruction_errors.shape[1]:
        raise ValueError("Wavelengths array must be provided and must match the number of wavelengths in reconstruction errors.")
    
    mean_errors = np.mean(reconstruction_errors, axis=0)

    os.makedirs(save_directory, exist_ok=True)

    # --- Linear Scale Plot ---
    plt.figure(figsize=(12, 7))
    plt.plot(wavelengths, mean_errors, label='Mean Reconstruction Error', color='#0571b0', linewidth=0.5)
    plt.fill_between(wavelengths, mean_errors, color='#67a9cf', alpha=0.7)
    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Mean Reconstruction Error")
    plt.legend()
    
    linear_plot_path = create_save_path(save_directory, 'mean_reconstruction_error_linear')
    plt.savefig(linear_plot_path, dpi=300)
    plt.close()
    print(f"Linear plot saved to {linear_plot_path}")

    # --- Log Scale Plot ---
    plt.figure(figsize=(12, 7))
    plt.plot(wavelengths, mean_errors, label='Mean Reconstruction Error (Log Scale)', color='#0571b0', linewidth=0.5)
    plt.fill_between(wavelengths, mean_errors, color='#67a9cf', alpha=0.7)
    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Mean Reconstruction Error (Log Scale)")
    plt.yscale("log")  # Log scale applied here
    plt.legend()
    log_plot_path = create_save_path(save_directory, 'mean_reconstruction_error_log')
    plt.savefig(log_plot_path, dpi=300)
    plt.close()
    print(f"Log plot saved to {log_plot_path}")



# External Database Cross-Matching 

In [12]:
def search_for_anomalies(spectrum_anomalies, spec_data, radius_arcsec=1.5):
    """
    Searches multiple spectral databases (ZTF, Gaia Alerts) for known anomalies and retrieves images.
    Identifies whether an object has had an alert and extracts alert classifications.

    Parameters:
    ----------
    spectrum_anomalies : np.ndarray
        Boolean array indicating which spectra are anomalous.

    spec_data : pd.DataFrame
        Metadata containing RA/DEC for each spectrum, allowing for database lookups.

    radius_arcsec : float, default=1.5
        Search radius in arcseconds.

    Returns:
    -------
    results : dict
        Dictionary mapping coordinates to anomaly data, alert sources, and retrieved images.
    """
    anomalous_indices = [i for i, is_anomalous in enumerate(spectrum_anomalies) if is_anomalous]

    if not anomalous_indices:
        print("No anomalous spectra detected.")
        return {}

    anomalous_coords = [
        (float(spec_data.iloc[i]['mean_fiber_ra']), float(spec_data.iloc[i]['mean_fiber_dec']))
        for i in anomalous_indices
    ]

    results = {}

    for ra, dec in tqdm(anomalous_coords, desc="Searching anomalies", unit="coord"):
        query_results = {}
        alert_sources = []

        print(f"\nSearching for anomalies at RA: {ra}, DEC: {dec} (±{radius_arcsec} arcsec)")

        ztf_url = f"https://irsa.ipac.caltech.edu/cgi-bin/ZTF/nph_light_curves?POS={ra},{dec}&SIZE={radius_arcsec/3600}&FORMAT=json"
        print(f"ZTF Query URL: {ztf_url}")

        try:
            ztf_response = requests.get(ztf_url, timeout=5)
            if ztf_response.status_code == 200:
                ztf_data = ztf_response.json()
                if ztf_data and "data" in ztf_data and len(ztf_data["data"]) > 0:
                    alert_sources.append("ZTF")
                    print(f"ZTF Alert Found: {ztf_data['data']}")
                else:
                    print("No ZTF alerts found.")
            else:
                print("ZTF query failed.")
        except requests.RequestException:
            print("ZTF request failed.")
        
        query_results["ZTF"] = ztf_data if "ztf_data" in locals() else "Query failed"

        gaia_url = f"http://gsaweb.ast.cam.ac.uk/alerts/alerts/cone_search?ra={ra}&dec={dec}&radius={radius_arcsec/3600}"
        print(f"Gaia Query URL: {gaia_url}")

        try:
            gaia_response = requests.get(gaia_url, timeout=5)
            if gaia_response.status_code == 200 and "transient" in gaia_response.text:
                alert_sources.append("Gaia")
                print(f"Gaia Alert Found: {gaia_response.text}")
            else:
                print("No Gaia alerts found.")
        except requests.RequestException:
            print("Gaia request failed.")

        query_results["Gaia"] = gaia_response.text if "gaia_response" in locals() else "Query failed"

        image_sources = {
            "SDSS": f"https://skyserver.sdss.org/dr16/SkyServerWS/ImgCutout/getjpeg?ra={ra}&dec={dec}&scale=0.2&width=300&height=300",
            "Legacy Survey": f"https://www.legacysurvey.org/viewer/jpeg-cutout?ra={ra}&dec={dec}&layer=ls-dr10&size=300"
        }

        fig, axes = plt.subplots(1, len(image_sources), figsize=(10, 5))

        if len(image_sources) == 1:
            axes = [axes]  

        for idx, (source, url) in enumerate(image_sources.items()):
            try:
                img = Image.open(BytesIO(requests.get(url, timeout=5).content)).convert("RGBA")
                draw = ImageDraw.Draw(img)
                width, height = img.size
                center = (width // 2, height // 2)

                radius_pixels = int((radius_arcsec / 0.04) * (width / 300))

                draw.ellipse([
                    (center[0] - radius_pixels, center[1] - radius_pixels),
                    (center[0] + radius_pixels, center[1] + radius_pixels)
                ], outline="white", width=2)

                axes[idx].imshow(img)
                axes[idx].set_title(f"{source} Image")
                axes[idx].axis("off")
            except:
                axes[idx].set_title(f"{source} Image Not Found")

        plt.show()

        results[(ra, dec)] = {
            "database_results": query_results,
            "alert_sources": alert_sources,
            "image_urls": image_sources
        }

    return results

# Information Saving

In [13]:
def create_save_path(save_directory, base_filename):
    """
    Generates a unique file path for PNG files in the specified directory by appending
    a sequential numeric suffix to avoid overwriting existing files. This ensures that
    saved file names are uniquely numbered.

    Parameters:
    - save_directory (str): The directory where the file should be saved. It will be created if it doesn't exist.
    - base_filename (str): The base name for the file, to which a numeric suffix will be added.

    Returns:
    - str: A full path to the new file with a sequentially numbered suffix in the specified directory,
           following the format '{base_filename}_{number}.png'.
    """

    os.makedirs(save_directory, exist_ok=True)
    existing_files = os.listdir(save_directory)
    numbers = [
        int(re.search(r'\d+', f).group())
        for f in existing_files if re.search(fr'{base_filename}_(\d+)\.png', f)
    ]
    next_number = max(numbers) + 1 if numbers else 1
    full_path = os.path.join(save_directory, f'{base_filename}_{next_number}.png')
    
    relative_path = os.path.relpath(full_path, start=os.path.dirname(save_directory))
    
    return relative_path
    
def create_json_save_path(save_directory, base_filename):
    """
    Generates a unique file path for JSON files in the specified directory by appending
    a sequential numeric suffix to avoid overwriting existing files. This ensures that
    saved JSON file names are uniquely numbered.

    Parameters:
    - save_directory (str): The directory where the JSON file should be saved. It will be created if it doesn't exist.
    - base_filename (str): The base name for the JSON file, to which a numeric suffix will be added.

    Returns:
    - str: A full path to the new JSON file with a sequentially numbered suffix in the specified directory,
           following the format '{base_filename}_{number}.json'.
    """
    os.makedirs(save_directory, exist_ok=True)
    existing_files = os.listdir(save_directory)
    numbers = [
        int(re.search(rf'{base_filename}_(\d+)\.json', f).group(1))
        for f in existing_files if re.search(rf'{base_filename}_(\d+)\.json', f)
    ]
    next_number = max(numbers) + 1 if numbers else 1
    full_path = os.path.join(save_directory, f'{base_filename}_{next_number}.json')
    relative_path = os.path.relpath(full_path, start=os.path.dirname(save_directory))
    
    return relative_path
def save_sampling_info(galaxy_ids, plot_galaxy_ids, seed, anomaly_metadata, abs_residual_threshold, 
                       rel_residual_threshold, range_mismatch_factor, json_directory, spec_data):
    """
    Saves detailed sampling information for detected anomalous spectra to a JSON file. This includes 
    metadata on each detected anomaly, threshold values, and unique target IDs associated with the anomalies.

    Parameters:
    - galaxy_ids (list of int): Indices of galaxies identified as anomalous, corresponding to row indices in `spec_data`.
    - plot_galaxy_ids (list of int): Indices of galaxies selected for plotting, used to match target IDs if needed.
    - seed (int): Random seed used during sampling to ensure reproducibility.
    - anomaly_metadata (list of list of dict): Metadata for each galaxy's anomaly detection results, where each inner list
      contains dictionaries with specific anomaly information (e.g., range, justification, trigger value).
    - abs_residual_threshold (float): Threshold for detecting high absolute residual anomalies.
    - rel_residual_threshold (float): Threshold for detecting high relative residual anomalies.
    - range_mismatch_factor (float): Factor threshold for range mismatches between original and reconstructed spectra.
    - json_directory (str): Directory path where the JSON file should be saved. Will be created if it doesn't exist.
    - spec_data (DataFrame): DataFrame containing galaxy catalog information, including `targetid` for each galaxy.

    Returns:
    - None: The function saves a JSON file containing the sampling and anomaly data.

    Process:
    1. Maps `galaxy_ids` to actual `targetid` values using `spec_data` to provide unique identifiers for each galaxy.
    2. Constructs a dictionary `sampling_info` to store:
       - Count of anomalies detected.
       - Sampling seed for reproducibility.
       - Thresholds for absolute residual, relative residual, and range mismatch.
       - An `anomalous_spectra` dictionary keyed by `targetid`, where each entry holds `anomaly_ranges` metadata.
    3. For each `targetid`, populates `anomaly_ranges` with details on each anomaly's range, justification, 
       and the value triggering the anomaly.
    4. Uses `create_json_save_path` to generate a unique, sequentially numbered file path for saving.
    5. Writes `sampling_info` as JSON to the generated path and confirms save location with a printed message.
    """
    target_ids = [int(spec_data['targetid'].iloc[idx]) for idx in galaxy_ids]

    sampling_info = {
        "Number of Anomalies": len(target_ids),
        "Seed": int(seed),
        "absolute_residual_threshold": float(abs_residual_threshold),
        "relative_residual_threshold": float(rel_residual_threshold),
        "range_mismatch_factor": float(range_mismatch_factor),
        "anomalous_spectra": {}
    }

    for idx, target_id in enumerate(target_ids):
        metadata = anomaly_metadata[idx]
        anomaly_ranges = []

        for reason in metadata:
            anomaly_ranges.append({
                "range_start": reason["range_start"],
                "range_end": reason["range_end"],
                "justification": reason["reason_type"],
                "value": reason.get("trigger_value", None)  
            })

        sampling_info["anomalous_spectra"][target_id] = {
            "anomaly_ranges": anomaly_ranges
        }

    json_path = create_json_save_path(json_directory, 'sampling_info')
    secondary_path = create_json_save_path(OUT_DIR, 'sampling_info')

    with open(json_path, 'w') as json_file:
        json.dump(sampling_info, json_file, indent=4)
    with open(secondary_path, 'w') as json_file:
        json.dump(sampling_info, json_file, indent=4)
    
    print(f"Sampling information saved to {json_path} and {secondary_path}")

# Main Execution Flow

## --- Data Loading ---

In [14]:
spec_data = load_or_query_data(CSV_PATH)
all_fluxes, all_wavelengths, all_errors = process_spectra_data(spec_data)
max_length = max(len(f) for f in all_fluxes if f is not None)

padded_fluxes = pad_or_truncate(all_fluxes, max_length)
padded_wavelengths = pad_or_truncate(all_wavelengths, max_length)
padded_errors = pad_or_truncate(all_errors, max_length)
mask = np.array(create_padding_mask(padded_fluxes, max_length), dtype=np.float32)

Loading data from /Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/spectra_data_2.csv...


Processing spectra in batches: 100%|██████████| 28/28 [08:56<00:00, 19.15s/it]


Saving updated DataFrame with flux and wavelength data to /Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/spectra_data_2.csv...


In [15]:

# --- Dataset/Dataloader ---
dataset = SpectralDataset(padded_fluxes, padded_wavelengths, padded_errors, mask)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

# --- Extract for Filtering ---
full_fluxes, full_wavelengths, full_errors = [], [], []
for batch in dataloader:
    f, w, e, _ = batch
    full_fluxes.extend(f.squeeze(1).numpy())
    full_wavelengths.extend(w.squeeze(1).numpy())
    full_errors.extend(e.squeeze(1).numpy())

filtered_wavelengths, filtered_fluxes, filtered_errors = filter_spectral_data(
    np.array(full_wavelengths), np.array(full_fluxes), np.array(full_errors), min_wavelength=4000
)

# --- Pad Again Post-Filtering ---
max_length = max(len(f) for f in filtered_fluxes)
padded_fluxes = pad_or_truncate(filtered_fluxes, max_length)
padded_wavelengths = pad_or_truncate(filtered_wavelengths, max_length)
padded_errors = pad_or_truncate(filtered_errors, max_length)
filtered_mask_tensor = (np.array(padded_fluxes) > 0).astype(np.float32)

# --- Tensors ---
flux_tensor = torch.tensor(padded_fluxes, dtype=torch.float32).unsqueeze(1)
error_tensor = torch.tensor(padded_errors, dtype=torch.float32).unsqueeze(1)
mask_tensor = torch.tensor(filtered_mask_tensor, dtype=torch.float32).unsqueeze(1)
wavelength_tensor = torch.tensor(padded_wavelengths, dtype=torch.float32).unsqueeze(1)

# --- Train Model ---
autoencoder = CNNAutoencoderWithSkip()
train_autoencoder(autoencoder, flux_tensor, mask_tensor, error_tensor)

# --- Reconstruct ---
reconstructed_fluxes = []
for batch in DataLoader(SpectralDataset(flux_tensor, wavelength_tensor, error_tensor, mask_tensor), batch_size=16):
    inputs, _, _, _ = batch
    with torch.no_grad():
        outputs = autoencoder(inputs).squeeze(1).cpu().numpy()
        reconstructed_fluxes.extend(outputs)

# --- Anomaly Detection with Optimised Parameters ---
filtered_flux_np = flux_tensor.squeeze(1).cpu().numpy()
reconstructed_flux_np = np.array(reconstructed_fluxes)

data_for_bo = {"fluxes": filtered_flux_np, "reconstructed_fluxes": reconstructed_flux_np}

# Run BO
bo_results = gp_minimize(
    func=objective_function,
    dimensions=param_space,
    acq_func="EI",
    n_calls=300,
    n_random_starts=50,
    n_jobs=15,
    random_state=42
)

# Extract best parameters
best_params = {dim.name: val for dim, val in zip(param_space, bo_results.x)}
best_fitness = -bo_results.fun

# Detect anomalies
(anomalous_regions, spectrum_anomalies, anomaly_metadata, 
 abs_thresh, rel_thresh, range_factor) = detect_anomalous_regions(
    original_fluxes=filtered_flux_np,
    reconstructed_fluxes=reconstructed_flux_np,
    window_size=50,
    abs_residual_threshold=best_params["abs_residual_threshold"],
    rel_residual_threshold=best_params["rel_residual_threshold"],
    range_mismatch_factor=best_params["range_mismatch_factor"],
    overall_anomaly_threshold=best_params["overall_anomaly_threshold"]
)

# --- Visualisation ---
wavelengths = np.mean(padded_wavelengths, axis=0)

plot_spectra_combined(
    original_fluxes=filtered_flux_np,
    reconstructed_fluxes=reconstructed_flux_np,
    anomalous_regions=anomalous_regions,
    spectrum_anomalies=spectrum_anomalies,
    spec_data=spec_data,
    wavelengths=wavelengths,
    only_anomalous=True,
    max_samples=10
)

plot_reconstruction_error_distribution(
    reconstruction_errors=compute_reconstruction_errors(filtered_flux_np, reconstructed_flux_np),
    wavelengths=wavelengths,
    save_directory=OUT_DIR
)
plot_bo_3d_grid(bo_results, save_directory=OUT_DIR)
# plot_bo_results(bo_results, [dim.name for dim in param_space]) -- commented out for storage reasons

# search_for_anomalies(spectrum_anomalies, spec_data) -- commented out as this can take a lot of storage when training is not perfect

Epoch [1/50], Loss: 0.19003381, Time: 3.70s
Epoch [2/50], Loss: 0.08363794, Time: 3.70s
Epoch [3/50], Loss: 0.06887204, Time: 3.26s
Epoch [4/50], Loss: 0.06665526, Time: 3.17s
Epoch [5/50], Loss: 0.06624749, Time: 3.12s
Epoch [6/50], Loss: 0.06588427, Time: 3.26s
Epoch [7/50], Loss: 0.06579071, Time: 3.23s
Epoch [8/50], Loss: 0.06545078, Time: 3.11s
Epoch [9/50], Loss: 0.06530042, Time: 3.10s
Epoch [10/50], Loss: 0.06537464, Time: 3.11s
Epoch [11/50], Loss: 0.06508280, Time: 3.25s
Epoch [12/50], Loss: 0.06514013, Time: 3.30s
Epoch [13/50], Loss: 0.06509494, Time: 3.24s
Epoch [14/50], Loss: 0.06499017, Time: 3.09s
Epoch [15/50], Loss: 0.06495630, Time: 3.37s
Epoch [16/50], Loss: 0.06493631, Time: 3.15s
Epoch [17/50], Loss: 0.06495822, Time: 3.11s
Epoch [18/50], Loss: 0.06484919, Time: 3.13s
Epoch [19/50], Loss: 0.06473063, Time: 3.08s
Epoch [20/50], Loss: 0.06483282, Time: 3.21s
Epoch [21/50], Loss: 0.06480446, Time: 3.27s
Epoch [22/50], Loss: 0.06480990, Time: 3.20s
Epoch [23/50], Loss