# Goals
- Access and filter DESI EDR galaxy spectra data from a database using SPARCL.
- Process and normalize the spectra data to prepare it for model training.
- Develop a CNN autoencoder with skip connections to perform dimensionality reduction and reconstruction of the spectra.
- Train the autoencoder model using a weighted mean squared error (MSE) loss function to emphasize critical spectral features.
- Identify and visualize anomalies in the galaxy spectra based on high reconstruction errors.
- Provide visual representations of detected anomalies and evaluate the model's performance through training loss metrics.


# Summary
This project leverages the Dark Energy Spectroscopic Instrument (DESI) Early Data Release (EDR) dataset to train a Convolutional Neural Network (CNN) autoencoder for anomaly detection in galaxy spectra. The code retrieves galaxy spectra data from a database, processes and normalizes it, and then applies an autoencoder with skip connections to reconstruct the spectra. The reconstruction errors are used to identify anomalous spectra, which may indicate unusual features or observational issues in the data.

# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from tqdm import tqdm
from sparcl.client import SparclClient
from dl import queryClient as qc, authClient as ac
from getpass import getpass
import os
import re
import csv 
import torch.nn.functional as F
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import AutoMinorLocator
from torchviz import make_dot
import time
import random
from concurrent.futures import ThreadPoolExecutor, as_completed 
import logging



In [2]:
# Configure file directories
DATA_DIR = '/Users/elicox/Desktop/Mac/Work/Yr4 Work/Project/CNN-auto/'
IMG_DIR = os.path.join(DATA_DIR, 'output_images')
CSV_PATH = os.path.join(DATA_DIR, 'spectra_data.csv')
os.makedirs(IMG_DIR, exist_ok=True)
#os.environ["PATH"] += os.pathsep + "/opt/homebrew/bin" # uncomment when CNN visualisation is needed

# Initialize SPARCL client
client = SparclClient()

# Data Loading and Saving Functions

In [3]:
# Function to query the database
def query_spectra_data():
    """
    Queries the DESI EDR database to retrieve primary galaxy spectra.
    Returns a DataFrame with the spectra data.
    """
    query = """
    SELECT zp.targetid, zp.survey, zp.program, zp.healpix,  
           zp.z, zp.zwarn, zp.coadd_fiberstatus, zp.spectype, 
           zp.mean_fiber_ra, zp.mean_fiber_dec, zp.zcat_nspec, 
           zp.desi_target, zp.sv1_desi_target, zp.sv2_desi_target, zp.sv3_desi_target
    FROM desi_edr.zpix AS zp
    WHERE zp.zcat_primary = 't'
      AND zp.zcat_nspec > 2
      AND zp.spectype = 'GALAXY'
    """
    try:
        zpix_cat = qc.query(sql=query, fmt='table')
        df = zpix_cat.to_pandas()
        print(f"Retrieved {len(df)} records from the database.")
        return df
    except Exception as e:
        print(f"Error querying data: {e}")
        return None

# Data Preparation Functions

In [4]:
def retrieve_flux(targetid, inc, retries=5, delay=2):
    """
    Helper function to retrieve and normalize flux for a single target ID.
    Attempts multiple retries on failure.
    """
    for attempt in range(1, retries + 1):
        try:
            # Attempt to retrieve the flux data
            res = client.retrieve_by_specid(specid_list=[targetid], include=inc, dataset_list=['DESI-EDR'])
            for record in res.records:
                if record['specprimary']:
                    flux = record['flux']
                    flux_min, flux_max = np.min(flux), np.max(flux)
                    return (flux - flux_min) / (flux_max - flux_min)  # Return normalized flux if successful
        except Exception:
            if attempt < retries:
                time.sleep(delay * (2 ** (attempt - 1)))  # Exponential backoff
    return None  # Return None if all attempts fail


def process_spectra_data(zpix_cat, batch_size=20, max_workers=10):
    """
    Processes the spectra data in batches with parallel retrieval.
    Assumes that every galaxy should have flux data, logs if missing.
    """
    all_fluxes = []
    total_records = len(zpix_cat)
    inc = ['specid', 'redshift', 'flux', 'wavelength', 'spectype', 
           'specprimary', 'survey', 'program', 'targetid', 'coadd_fiberstatus']

    for start_idx in tqdm(range(0, total_records, batch_size), desc="Processing spectra in large batches"):
        batch = zpix_cat.iloc[start_idx:start_idx+batch_size]
        batch_fluxes = []

        # Parallelize requests within the batch
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {executor.submit(retrieve_flux, int(row['targetid']), inc): row['targetid'] for _, row in batch.iterrows()}
            for future in as_completed(futures):
                flux = future.result()
                if flux is not None:
                    batch_fluxes.append(flux)
                else:
                    print(f"Warning: Missing flux for target ID {futures[future]}")  # Log missing data if any

        all_fluxes.extend(batch_fluxes)

    if not all_fluxes:
        raise ValueError("Error: No flux data retrieved. Check data source or retrieval logic.")

    return np.array(all_fluxes)  # Return as numpy array for efficiency
#################################################

 
# Pad spectra data
def pad_spectra(fluxes, target_length):
    """
    Pads or truncates each flux array to match the target length.
    Returns a numpy array of padded flux data.
    """
    padded_fluxes = []
    num_padded = 0  # Track only the padded spectra

    for flux in fluxes:
        if len(flux) < target_length:
            padding = np.zeros(target_length - len(flux))
            padded_flux = np.concatenate([flux, padding])
            num_padded += 1
        else:
            padded_flux = flux[:target_length]
        padded_fluxes.append(padded_flux)

    if num_padded > 0:
        print(f"Padding applied to {num_padded} spectra to match the target length.")
    else:
        print("No padding was necessary; all spectra are of equal length.")

    return np.array(padded_fluxes)


# Create padding mask
def create_padding_mask(fluxes, target_length):
    """
    Generates a binary mask for padded values in the flux arrays.
    Used in the loss calculation to ignore padded regions.
    """
    masks = []
    for flux in fluxes:
        mask = np.ones_like(flux)
        if len(flux) < target_length:
            mask = np.concatenate([mask, np.zeros(target_length - len(flux))])
        masks.append(mask[:target_length])
    return np.array(masks)


# Autoencoder Model Definition

In [5]:
class CNNAutoencoderWithSkip(nn.Module):
    """
    A CNN-based autoencoder with skip connections to reconstruct spectra data.
    """
    def __init__(self):
        super(CNNAutoencoderWithSkip, self).__init__()
        # Encoder layers with reduced channels
        self.encoder1 = nn.Conv1d(1, 64, kernel_size=3, stride=2, padding=1)
        self.encoder2 = nn.Conv1d(64, 32, kernel_size=3, stride=2, padding=1)
        self.encoder3 = nn.Conv1d(32, 16, kernel_size=3, stride=2, padding=1)  # Adjusted from 128 max channels

        # Decoder layers with reduced channels
        self.decoder3 = nn.ConvTranspose1d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder2 = nn.ConvTranspose1d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.decoder1 = nn.ConvTranspose1d(64, 1, kernel_size=3, stride=2, padding=1, output_padding=0)

    def forward(self, x):
        # Encoding with skip connections
        x1 = F.relu(self.encoder1(x))  # Save this for skip connection
        x2 = F.relu(self.encoder2(x1))  # Save this for skip connection
        x3 = F.relu(self.encoder3(x2))  # Last encoding layer

        # Decoding with skip connections
        x = F.relu(self.decoder3(x3))
        
        # Adjust sizes for skip connection with x2
        if x2.size(2) > x.size(2):
            x2 = x2[:, :, :x.size(2)]
        elif x2.size(2) < x.size(2):
            x = x[:, :, :x2.size(2)]
        x = F.relu(self.decoder2(x + x2))
        
        # Adjust sizes for skip connection with x1
        if x1.size(2) > x.size(2):
            x1 = x1[:, :, :x.size(2)]
        elif x1.size(2) < x.size(2):
            x = x[:, :, :x1.size(2)]
        x = self.decoder1(x + x1)

        return torch.sigmoid(x)

# Custom Loss Function


In [6]:
# Custom weighted MSE loss function
def weighted_mse_loss(output, target, mask, weight_factor=10):
    """
    Calculates MSE loss, amplifying high-residual areas with weight based on spectral gradient.
    """
    mse_loss = (output - target) ** 2
    gradient = torch.abs(target[:, :, 1:] - target[:, :, :-1])
    weighted_loss = mse_loss[:, :, 1:] * (1 + weight_factor * gradient)
    return (weighted_loss * mask[:, :, 1:]).mean()


# Training the Model


In [17]:
# Optimized train_autoencoder with gradient clipping
def train_autoencoder(model, data, mask, epochs=50, batch_size=32, lr=0.001, grad_clip=1.0):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(epochs):
        start_time = time.time()
        epoch_loss = 0
        optimizer.zero_grad()  # Zero gradients at the start of each epoch

        for i in range(0, len(data), batch_size):
            batch_data, mask_batch = data[i:i+batch_size], mask[i:i+batch_size]
            loss = weighted_mse_loss(model(batch_data), batch_data, mask_batch)
            loss.backward()

            # Clip gradients to prevent explosions
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()  # Zero gradients after each batch

            epoch_loss += loss.item()

        epoch_duration = time.time() - start_time
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(data):.4f}, Time: {epoch_duration:.2f}s")
    # Print final loss after all epochs
    print(f"Final Loss after {epochs} epochs: {final_loss:.8f}")



# Anomaly Detection

In [8]:
# def detect_anomalous_regions(original_fluxes, reconstructed_fluxes, window_size=50, percentile_threshold=95, range_mismatch_factor=1.5, overall_anomaly_threshold=0.3):
#     """
#     Identifies anomalous regions within spectra using residuals and a mismatch factor.
#     """
#     num_spectra, spectrum_length = original_fluxes.shape
#     absolute_residuals = np.abs(original_fluxes - reconstructed_fluxes)
#     relative_residuals = np.abs((original_fluxes - reconstructed_fluxes) / (original_fluxes + 1e-5))

#     abs_residual_threshold = np.percentile(
#         [np.mean(absolute_residuals[:, i:i+window_size], axis=1)
#         for i in range(0, spectrum_length - window_size + 1, window_size // 2)], percentile_threshold)
#     rel_residual_threshold = np.percentile(
#         [np.mean(relative_residuals[:, i:i+window_size], axis=1)
#         for i in range(0, spectrum_length - window_size + 1, window_size // 2)], percentile_threshold)
    
#     anomalies, spectrum_anomalies = [], np.zeros(num_spectra, dtype=bool)
#     for i in range(num_spectra):
#         spectrum_anomalies_count = 0
#         spectrum_anomaly_flags = np.zeros(spectrum_length, dtype=bool)
        
#         for start in range(0, spectrum_length - window_size + 1, window_size // 2):
#             end = start + window_size
#             window_abs_residual, window_rel_residual = np.mean(absolute_residuals[i, start:end]), np.mean(relative_residuals[i, start:end])
#             original_range, reconstructed_range = np.ptp(original_fluxes[i, start:end]), np.ptp(reconstructed_fluxes[i, start:end])
            
#             if (window_abs_residual > abs_residual_threshold or
#                 window_rel_residual > rel_residual_threshold or
#                 reconstructed_range > original_range * range_mismatch_factor):
#                 spectrum_anomaly_flags[start:end] = True
#                 spectrum_anomalies_count += end - start

#         spectrum_anomalies[i] = spectrum_anomalies_count / spectrum_length > overall_anomaly_threshold
#         anomalies.append(spectrum_anomaly_flags)
#     return anomalies, spectrum_anomalies
############################################
# Enhanced anomaly detection with justification for each anomaly
def detect_anomalous_regions(original_fluxes, reconstructed_fluxes, window_size=50, percentile_threshold=95, 
                             range_mismatch_factor=1.5, overall_anomaly_threshold=0.3):
    """
    Identifies anomalous regions within spectra using residuals and mismatch factors,
    with added logging to justify each detected anomaly.
    """
    num_spectra, spectrum_length = original_fluxes.shape
    absolute_residuals = np.abs(original_fluxes - reconstructed_fluxes)
    relative_residuals = np.abs((original_fluxes - reconstructed_fluxes) / (original_fluxes + 1e-5))

    abs_residual_threshold = np.percentile(
        [np.mean(absolute_residuals[:, i:i+window_size], axis=1)
         for i in range(0, spectrum_length - window_size + 1, window_size // 2)], percentile_threshold)
    rel_residual_threshold = np.percentile(
        [np.mean(relative_residuals[:, i:i+window_size], axis=1)
         for i in range(0, spectrum_length - window_size + 1, window_size // 2)], percentile_threshold)

    anomalies, spectrum_anomalies = [], np.zeros(num_spectra, dtype=bool)
    anomaly_metadata = []

    for i in range(num_spectra):
        spectrum_anomalies_count = 0
        spectrum_anomaly_flags = np.zeros(spectrum_length, dtype=bool)
        anomaly_justification = []  # To log reasons for each anomaly

        for start in range(0, spectrum_length - window_size + 1, window_size // 2):
            end = start + window_size
            window_abs_residual = np.mean(absolute_residuals[i, start:end])
            window_rel_residual = np.mean(relative_residuals[i, start:end])
            original_range = np.ptp(original_fluxes[i, start:end])
            reconstructed_range = np.ptp(reconstructed_fluxes[i, start:end])

            if (window_abs_residual > abs_residual_threshold or
                window_rel_residual > rel_residual_threshold or
                reconstructed_range > original_range * range_mismatch_factor):
                
                spectrum_anomaly_flags[start:end] = True
                spectrum_anomalies_count += end - start

                # Log specific reasons why this region was flagged as anomalous
                reason = {
                    "window_start": start,
                    "window_end": end,
                    "abs_residual": window_abs_residual,
                    "rel_residual": window_rel_residual,
                    "range_mismatch": reconstructed_range / (original_range + 1e-5)
                }
                anomaly_justification.append(reason)

        spectrum_anomalies[i] = spectrum_anomalies_count / spectrum_length > overall_anomaly_threshold
        anomalies.append(spectrum_anomaly_flags)
        anomaly_metadata.append(anomaly_justification)  # Add metadata for each spectrum's anomalies

    return anomalies, spectrum_anomalies, anomaly_metadata


# Spectra Visualisation

In [9]:
# Create dynamic save path with numbering
def create_save_path(save_directory, base_filename):
    """
    Creates a new file path in the specified directory with a unique numeric suffix.
    Ensures saved file names are sequentially numbered.
    """
    os.makedirs(save_directory, exist_ok=True)
    existing_files = os.listdir(save_directory)
    numbers = [
        int(re.search(r'\d+', f).group())
        for f in existing_files if re.search(fr'{base_filename}_(\d+)\.png', f)
    ]
    next_number = max(numbers) + 1 if numbers else 1
    return os.path.join(save_directory, f'{base_filename}_{next_number}.png')


In [10]:
def plot_spectra(original_fluxes, reconstructed_fluxes, anomalous_regions, spectrum_anomalies, zpix_cat, wavelengths=None, save_directory='output_images', max_samples=20):
    num_spectra = len(original_fluxes)
    indices = list(range(num_spectra))

    # Sample only max_samples if dataset is too large
    if num_spectra > max_samples:
        print(f"Too many spectra to plot ({num_spectra}). Sampling {max_samples} spectra.")
        indices = random.sample(indices, max_samples)
    else:
        print(f"Plotting all {num_spectra} spectra.")

    fig, axes = plt.subplots(len(indices), 1, figsize=(10, 6 * len(indices)), sharex=True)
    if len(indices) == 1:
        axes = [axes]  # Ensure axes is iterable for a single plot

    for idx, i in enumerate(indices):
        ax = axes[idx]
        x_axis = wavelengths if wavelengths is not None else range(len(original_fluxes[i]))
        
        # Plot original and reconstructed spectra
        ax.plot(x_axis, original_fluxes[i], label="Original", color='#2c7bb6', linewidth=0.5)
        ax.plot(x_axis, reconstructed_fluxes[i], label="Reconstructed", color='#d7191c', alpha=0.7, linewidth=0.5)

        # Highlight anomalous regions
        in_anomaly, anomaly_start = False, 0
        for j in range(len(x_axis)):
            if anomalous_regions[i][j] and not in_anomaly:
                anomaly_start = j
                in_anomaly = True
            elif not anomalous_regions[i][j] and in_anomaly:
                ax.axvspan(x_axis[anomaly_start], x_axis[j], color='#fdae61', alpha=0.7)
                in_anomaly = False
        if in_anomaly:
            ax.axvspan(x_axis[anomaly_start], x_axis[-1], color='#fdae61', alpha=0.7)

        # Set background color for anomalous spectra
        if spectrum_anomalies[i]:
            ax.set_facecolor('#fee090')  # Yellow background
            ax.set_title(f"Spectrum {i+1} (ID: {zpix_cat['targetid'].iloc[i]}) - Anomalous", color='red')
        else:
            ax.set_title(f"Spectrum {i+1} (ID: {zpix_cat['targetid'].iloc[i]})", color='black')

        # Add residuals plot
        divider = make_axes_locatable(ax)
        ax_residual = divider.append_axes("bottom", size="25%", pad=0, sharex=ax)
        ax_residual.plot(x_axis, original_fluxes[i] - reconstructed_fluxes[i], color='#4dac26', linewidth=0.5)
        ax_residual.set_ylabel("Residuals")
        ax.set_ylabel("Flux (normalized)")
        ax.set_xlabel("Wavelength (Å)")
        ax.legend()

    plt.tight_layout()
    save_path = create_save_path(save_directory, 'spectra_reconstruction')
    plt.savefig(save_path, dpi=300)
    plt.close(fig)  # Close figure to free memory
    print(f"Figure saved to {save_path}")


def plot_anomalous_spectra(original_fluxes, reconstructed_fluxes, anomalous_regions, spectrum_anomalies, zpix_cat, wavelengths=None, save_directory='output_images', max_samples=20):
    """
    Plots only the spectra that are flagged as anomalous.
    """
    # Filter for only anomalous spectra
    anomalous_indices = [i for i, is_anomalous in enumerate(spectrum_anomalies) if is_anomalous]
    
    if not anomalous_indices:
        print("No anomalous spectra detected.")
        return
    
    # Limit the number of spectra to plot if needed
    if len(anomalous_indices) > max_samples:
        print(f"Too many anomalous spectra to plot ({len(anomalous_indices)}). Sampling {max_samples} spectra.")
        anomalous_indices = random.sample(anomalous_indices, max_samples)
    else:
        print(f"Plotting all {len(anomalous_indices)} anomalous spectra.")

    # Extract only the anomalous spectra data
    filtered_original_fluxes = [original_fluxes[i] for i in anomalous_indices]
    filtered_reconstructed_fluxes = [reconstructed_fluxes[i] for i in anomalous_indices]
    filtered_anomalous_regions = [anomalous_regions[i] for i in anomalous_indices]
    filtered_zpix_cat = zpix_cat.iloc[anomalous_indices]  # Select rows for anomalous spectra

    # Call the original plotting function with filtered data
    plot_spectra(
        original_fluxes=filtered_original_fluxes,
        reconstructed_fluxes=filtered_reconstructed_fluxes,
        anomalous_regions=filtered_anomalous_regions,
        spectrum_anomalies=[True] * len(filtered_original_fluxes),  # Mark all as anomalous for consistent background
        zpix_cat=filtered_zpix_cat,
        wavelengths=wavelengths,
        save_directory=save_directory,
        max_samples=max_samples
    )

#Save anomalies to CSV for analysis
def save_anomaly_data(zpix_cat, spectrum_anomalies, file_path='anomaly_data.csv'):
    """
    Saves information about detected anomalies to a CSV file for further analysis.
    Each row includes the target ID and an anomaly indicator.
    """
    with open(file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['targetid', 'is_anomalous'])  # Header row
        for i in range(len(zpix_cat)):
            writer.writerow([zpix_cat['targetid'].iloc[i], int(spectrum_anomalies[i])])
    print(f"Anomaly data saved to {file_path}")


In [11]:
# Visualize the full autoencoder
def visualize_autoencoder(autoencoder, input_data):
    """
    Visualizes the entire autoencoder model structure, saving it as an image.
    Takes input data, passes it through the model, and generates a network graph.
    """
    outputs = autoencoder(input_data)
    model_viz = make_dot(outputs, params=dict(autoencoder.named_parameters()))
    model_viz.format = "png"
    save_path = create_save_path(IMG_DIR, 'autoencoder_visualization')
    model_viz.render(save_path.replace(".png", ""))
    print(f"Full autoencoder visualization saved to {save_path}")

# Visualize only the encoder part
def visualize_encoder(autoencoder, input_data):
    """
    Visualizes the encoder portion of the autoencoder model, saving it as an image.
    Passes input data through only the encoder layers and generates a graph.
    """
    encoder_output = autoencoder.encoder3(autoencoder.encoder2(autoencoder.encoder1(input_data)))
    encoder_viz = make_dot(encoder_output, params=dict(autoencoder.named_parameters()), 
                           show_attrs=True, show_saved=True)
    encoder_viz.format = "png"
    save_path = create_save_path(IMG_DIR, 'encoder_visualization')
    encoder_viz.render(save_path.replace(".png", ""))
    print(f"Encoder visualization saved to {save_path}")


# Visualize only the decoder part, given an encoded input
def visualize_decoder(autoencoder, encoded_input):
    """
    Visualizes the decoder portion of the autoencoder model, saving it as an image.
    Takes an encoded input and generates a graph of the decoder layers.
    """
    decoder_output = autoencoder.decoder1(autoencoder.decoder2(autoencoder.decoder3(encoded_input)))
    decoder_viz = make_dot(decoder_output, params=dict(autoencoder.named_parameters()), 
                           show_attrs=True, show_saved=True)
    decoder_viz.format = "png"
    save_path = create_save_path(IMG_DIR, 'decoder_visualization')
    decoder_viz.render(save_path.replace(".png", ""))
    print(f"Decoder visualization saved to {save_path}")


# Autoencoder Memory Wipe:

In [12]:
def reset_model_weights(model):
    """
    Wipes the memory of the model by reinitialising all weights.
    This allows the model to be retrained from scratch.
    """
    for layer in model.modules():
        if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            # Reinitialize weights and biases for layers that have them
            layer.reset_parameters()
    print("Model weights and biases have been reset.")

# Main Execution Flow


In [13]:
zpix_cat = query_spectra_data()
if zpix_cat is not None:
    all_fluxes = process_spectra_data(zpix_cat)
    max_length = max(len(f) for f in all_fluxes)
    all_fluxes_padded = pad_spectra(all_fluxes, max_length)
    mask = create_padding_mask(all_fluxes, max_length)

    # Prepare tensors for model training
    all_fluxes_tensor = torch.tensor(all_fluxes_padded, dtype=torch.float32).unsqueeze(1)
    mask_tensor = torch.tensor(mask, dtype=torch.float32).unsqueeze(1)
else:
    print("No data available for processing.")



Retrieved 739 records from the database.


Processing spectra in large batches: 100%|██████████| 37/37 [10:59<00:00, 17.82s/it]

No padding was necessary; all spectra are of equal length.





In [16]:
# Model setup and initialization
autoencoder = CNNAutoencoderWithSkip()
reset_model_weights(autoencoder)

# Train the model
train_autoencoder(autoencoder, all_fluxes_tensor, mask_tensor)

# Evaluate the model and perform anomaly detection
autoencoder.eval()
reconstructed_fluxes = autoencoder(all_fluxes_tensor).detach().numpy().squeeze()
# Detect anomalies
anomalous_regions, spectrum_anomalies = detect_anomalous_regions(
    original_fluxes=all_fluxes_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes,
    window_size=50,
    percentile_threshold=95,
    range_mismatch_factor=1.5,
    overall_anomaly_threshold=0.3
)

# Save anomaly information
save_anomaly_data(zpix_cat, spectrum_anomalies, file_path='anomalous_spectra_data.csv')

# Plot only the anomalous spectra
plot_anomalous_spectra(
    original_fluxes=all_fluxes_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes,
    anomalous_regions=anomalous_regions,
    spectrum_anomalies=spectrum_anomalies,
    zpix_cat=zpix_cat,
    wavelengths=None,  # or provide the wavelength range if available
    save_directory=IMG_DIR
)

# Plot the results, using dynamic naming
plot_spectra(
    original_fluxes=all_fluxes_tensor.numpy().squeeze(),
    reconstructed_fluxes=reconstructed_fluxes,
    anomalous_regions=anomalous_regions,
    spectrum_anomalies=spectrum_anomalies,
    zpix_cat=zpix_cat,
    save_directory=IMG_DIR
)

# # Optional: Visualize the autoencoder and its components if needed
#real_input = all_fluxes_tensor[0].unsqueeze(0)  # Select a real input sample

####### Uncomment when visualizations are required #########
# visualize_autoencoder(autoencoder, real_input)
# visualize_encoder(autoencoder, real_input)
# encoded_input = autoencoder.encoder3(autoencoder.encoder2(autoencoder.encoder1(real_input)))
# visualize_decoder(autoencoder, encoded_input)

Model weights and biases have been reset.
Epoch [1/50], Loss: 0.0026, Time: 4.89s
Epoch [2/50], Loss: 0.0007, Time: 4.43s
Epoch [3/50], Loss: 0.0001, Time: 4.16s
Epoch [4/50], Loss: 0.0001, Time: 4.27s
Epoch [5/50], Loss: 0.0001, Time: 4.19s
Epoch [6/50], Loss: 0.0000, Time: 4.12s
Epoch [7/50], Loss: 0.0000, Time: 3.90s
Epoch [8/50], Loss: 0.0000, Time: 3.84s
Epoch [9/50], Loss: 0.0000, Time: 4.10s
Epoch [10/50], Loss: 0.0000, Time: 3.83s
Epoch [11/50], Loss: 0.0000, Time: 3.95s
Epoch [12/50], Loss: 0.0000, Time: 3.81s
Epoch [13/50], Loss: 0.0000, Time: 3.95s
Epoch [14/50], Loss: 0.0000, Time: 3.80s
Epoch [15/50], Loss: 0.0000, Time: 3.87s
Epoch [16/50], Loss: 0.0000, Time: 4.21s
Epoch [17/50], Loss: 0.0000, Time: 4.06s
Epoch [18/50], Loss: 0.0000, Time: 3.91s
Epoch [19/50], Loss: 0.0000, Time: 3.84s
Epoch [20/50], Loss: 0.0000, Time: 3.92s
Epoch [21/50], Loss: 0.0000, Time: 3.99s
Epoch [22/50], Loss: 0.0000, Time: 4.04s
Epoch [23/50], Loss: 0.0000, Time: 3.92s
Epoch [24/50], Loss: 0.0

# vvvvvvvvvvv DEBUG ZONE vvvvvvvvvvv