In [1]:
from scipy.signal import find_peaks
from astropy.io import fits
from specutils import Spectrum1D
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as u
import requests
from io import BytesIO
import os
import csv


## Define functions to inject, detect, and evaluate laser-like spectral peaks

In [2]:
def inject_and_plot_laser(obj_id, filt, folder="/datax/scratch/emmay/galah_spectra",
                          fwhm=1.5, laser_amp_percent=20, plot=True, center=None):
    """
    Inject a Gaussian "laser" spike into a GALAH FITS spectrum and optionally plot it.

    The flux is read from HDU 1 of ``{folder}/{obj_id}_{filt}.fits``. The wavelength
    axis is rebuilt from the linear header keywords CRVAL1 / CDELT1 / CRPIX1. A
    Gaussian with the requested FWHM (in pixels) is then added on top of the flux.

    Parameters:
        obj_id (str): Object ID (first part of the FITS filename).
        filt (str): Filter suffix of the FITS filename (e.g. 'B', 'V', 'R', 'I').
        folder (str): Directory containing the FITS files.
        fwhm (float): FWHM of the injected Gaussian, in pixels.
        laser_amp_percent (float): Peak amplitude of the injection as a percentage
            of the maximum (NaN-ignoring) flux of the original spectrum.
        plot (bool): Whether to show the original spectrum with the injected region.
        center (int, optional): Pixel index of the injection centre. Defaults to
            the middle pixel, ``npix // 2``.

    Returns:
        wavelength (np.ndarray): Wavelength array in Å (plain floats, units stripped).
        flux_injected (np.ndarray): Flux with the Gaussian added.
        flux (np.ndarray): Original flux as read from the file (cast to float).
        injected_wavelength (astropy Quantity): Wavelength of the injection centre, in Å.

    Raises:
        KeyError: If CRVAL1 or CDELT1 is missing from the HDU-1 header.
        IndexError: If ``center`` lies outside the spectrum.
    """
    path = os.path.join(folder, f"{obj_id}_{filt}.fits")

    with fits.open(path) as hdul:
        header = hdul[1].header
        flux = hdul[1].data.astype(float)

    # Without these keywords the wavelength axis cannot be built; fail with a clear
    # message instead of a TypeError from arithmetic on None.
    missing = [key for key in ('CRVAL1', 'CDELT1') if header.get(key) is None]
    if missing:
        raise KeyError(f"{path}: header is missing wavelength keyword(s) {missing}")
    crval1 = header['CRVAL1']
    cdelt1 = header['CDELT1']
    crpix1 = header.get('CRPIX1', 1)

    # FITS pixel coordinates are 1-based, hence the "+ 1" on the 0-based index.
    npix = len(flux)
    wavelength = (crval1 + (np.arange(npix) + 1 - crpix1) * cdelt1) * u.AA

    if center is None:
        center = npix // 2
    elif not 0 <= center < npix:
        raise IndexError(f"center={center} is outside the spectrum (0..{npix - 1})")

    # Convert FWHM (pixels) to the Gaussian standard deviation.
    sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
    x = np.arange(npix)
    gaussian = np.exp(-0.5 * ((x - center) / sigma) ** 2)
    gaussian *= (laser_amp_percent / 100.0) * np.nanmax(flux)

    flux_injected = flux + gaussian

    if plot:
        wl = wavelength.value
        # Only draw the injected curve where the Gaussian is non-negligible.
        threshold_mask = gaussian > 1e-4
        plt.figure(figsize=(13, 5))
        plt.plot(wl, flux, color='blue', linewidth=0.6, label="Original Spectrum")
        plt.plot(wl[threshold_mask], flux_injected[threshold_mask], color='red', linewidth=1.0, label="Injected Region")
        plt.xlabel("Wavelength (Å)")
        plt.ylabel("Flux")
        plt.title(f"Injected Spectrum — {obj_id} ({filt}) — Amplitude = {laser_amp_percent}%")
        plt.legend()
        plt.tight_layout()
        plt.show()

    return wavelength.value, flux_injected, flux, wavelength[center]


In [3]:
def measure_width_at_y(flux, wavelength, peak_idx, y_level=1):
    """
    Measure the width of a peak where its flux crosses a fixed ``y_level``.

    Starting at ``peak_idx``, walk left and right until the flux is no longer above
    ``y_level``. Each crossing is then located by linear interpolation between the
    last pixel above the level and the first pixel at or below it.

    Parameters:
        flux (sequence of float): Flux values.
        wavelength (sequence of float): Wavelengths, same length as ``flux``.
        peak_idx (int): Index of the peak.
        y_level (float): Flux level at which the width is measured.

    Returns:
        (left, right, width) in wavelength units, or None if the width cannot be
        measured. That happens when:
        - the peak is not above ``y_level``;
        - ``peak_idx`` is out of range;
        - the flux never drops to ``y_level`` on one side;
        - a NaN interrupts the walk.
    """
    n = len(flux)
    if not 0 <= peak_idx < n or not flux[peak_idx] > y_level:
        return None  # Nothing above the level to measure

    # Left side. The check is written as "not <=" so that a NaN stopper is rejected
    # as well as running off the array edge while still above the level.
    i = peak_idx
    while i > 0 and flux[i] > y_level:
        i -= 1
    if not flux[i] <= y_level:
        return None  # No valid left edge
    left = wavelength[i] + (y_level - flux[i]) / (flux[i+1] - flux[i]) * (wavelength[i+1] - wavelength[i])

    # Right side
    i = peak_idx
    while i < n - 1 and flux[i] > y_level:
        i += 1
    if not flux[i] <= y_level:
        return None  # No valid right edge
    right = wavelength[i-1] + (y_level - flux[i-1]) / (flux[i] - flux[i-1]) * (wavelength[i] - wavelength[i-1])

    return left, right, right - left

def detect_laser_peak_with_fixed_level_width(
    wavelength, 
    flux, 
    height_fraction=0.2, 
    y_level=1.0, 
    max_pixel_width=3,
    injected_wavelength=None,
    plot=True
):
    """
    Detect narrow, laser-like peaks and measure their width at a fixed flux level.

    Candidate peaks come from ``scipy.signal.find_peaks``. The minimum height is
    ``height_fraction * max|flux - 1| * max(flux) + 1``, with NaNs ignored. Each
    candidate's width is measured where the flux crosses ``y_level``. Candidates
    are rejected if their width cannot be measured, is not finite, or exceeds
    ``max_pixel_width`` pixels.

    Parameters
    ----------
    wavelength : array-like
        Wavelength array. Astropy units, if present, are stripped.
    flux : array-like
        Flux array, same length as ``wavelength``.
    height_fraction : float
        Scales the minimum peak height given to ``find_peaks`` (formula above).
    y_level : float
        Fixed flux level at which to measure peak width.
    max_pixel_width : float
        Maximum allowed width in pixel units for valid peaks.
    injected_wavelength : astropy Quantity or float, optional
        Wavelength of a known injected line. If given, it is drawn as a reference
        line on the full-spectrum plot.
    plot : bool
        If True, show zoomed-in and full-spectrum plots.

    Returns
    -------
    peak_wavelengths : list of float
        Wavelengths of the accepted peaks.
    widths : list of float
        Widths (wavelength units) measured at y_level.
    """
    # Remove units if present
    flux = np.asarray(flux)
    wavelength = np.asarray(wavelength)

    # np.nanmax, unlike the builtin max, cannot be poisoned by a leading NaN pixel;
    # with the builtin, a NaN threshold made find_peaks return nothing.
    # NOTE(review): the extra np.nanmax(flux) factor and "+ 1" look tailored to
    # continuum-normalised spectra; confirm the scaling is intended.
    height_threshold = height_fraction * np.nanmax(np.abs(flux - 1))
    peak_indices, _ = find_peaks(flux, height=height_threshold * np.nanmax(flux) + 1)

    if len(peak_indices) == 0:
        print("No peaks detected.")
        if plot:
            _plot_full_spectrum(wavelength, flux, [], injected_wavelength,
                                title="Full Spectrum (No Peaks Detected)", injected_color='r')
        return [], []

    # Estimate pixel scale (assuming evenly spaced wavelength)
    pixel_scale = np.median(np.diff(wavelength))

    peak_wavelengths = []
    widths = []
    valid_peak_indices = []

    for idx in peak_indices:
        result = measure_width_at_y(flux, wavelength, idx, y_level=y_level)
        if result is None:
            continue  # No measurable crossing of y_level
        left, right, width = result
        width_in_pixels = width / pixel_scale

        # Written as "not <=" so a NaN width is rejected too.
        if not width_in_pixels <= max_pixel_width:
            continue  # Skip wide peaks

        peak_wavelengths.append(wavelength[idx])
        widths.append(width)
        valid_peak_indices.append(idx)

        if plot:
            _plot_zoomed_peak(wavelength, flux, idx, left, right, width, width_in_pixels, y_level)

    if plot and valid_peak_indices:
        _plot_full_spectrum(wavelength, flux, valid_peak_indices, injected_wavelength,
                            title="Full Spectrum with Detected Peaks", injected_color='black')

    return peak_wavelengths, widths


def _plot_zoomed_peak(wavelength, flux, idx, left, right, width, width_in_pixels, y_level, half_window=20):
    """Plot a cut-out of about +/- half_window pixels around one peak, with its width markers."""
    i_min = max(0, idx - half_window)
    i_max = min(len(wavelength), idx + half_window)

    plt.figure(figsize=(10, 4))
    plt.plot(wavelength[i_min:i_max], flux[i_min:i_max], label="Flux")
    plt.axvline(left, color='C2', linestyle='--')
    plt.axvline(right, color='C2', linestyle='--')
    plt.hlines(y_level, left, right, color='C3', linewidth=2, label=f'Width @ y={y_level}')
    plt.plot(wavelength[idx], flux[idx], 'rx', label="Peak Center")
    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Flux")
    plt.title(f"Zoomed Peak at {wavelength[idx]:.2f} Å, Width = {width:.3f} Å ({width_in_pixels:.2f} px)")
    plt.legend()
    plt.tight_layout()
    plt.show()


def _plot_full_spectrum(wavelength, flux, peak_indices, injected_wavelength, title, injected_color):
    """Plot the whole spectrum, marking peak_indices and, if given, the injected wavelength."""
    plt.figure(figsize=(12, 4))
    plt.plot(wavelength, flux, label="Full Spectrum")
    if len(peak_indices):
        # A single plot call gives the legend exactly one "Detected Peaks" entry.
        plt.plot(wavelength[peak_indices], flux[peak_indices], 'rx', label="Detected Peaks")

    if injected_wavelength is not None:
        # Accept either an astropy Quantity (has .value) or a plain number.
        injected_x = float(getattr(injected_wavelength, 'value', injected_wavelength))
        plt.axvline(x=injected_x, color=injected_color, linestyle='--', linewidth=2, label="Injected Laser")

    plt.xlabel("Wavelength (Å)")
    plt.ylabel("Flux")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [4]:
def evaluate_detection(injected_wavelengths, detected_wavelengths, threshold_angstroms=2.0):
    """
    Compare injected wavelengths to detected peaks and compute detection statistics.

    Each detection can confirm at most one injection, and vice versa. A pair
    matches when the wavelengths differ by at most ``threshold_angstroms``.

    Both lists are sorted and swept together, which maximises the number of
    matched pairs. A plain first-come pairing can hand a detection to one
    injection even though that injection had another candidate, leaving a second
    injection unmatched; the sweep avoids this.

    Non-finite wavelengths (NaN/inf) never match. A non-finite detection still
    counts as a false positive.

    Parameters:
    - injected_wavelengths: list of float
    - detected_wavelengths: list of float
    - threshold_angstroms: float, max |detected - injected| for a match (inclusive)

    Returns:
    - stats: dict with TP, FP, FN, precision, recall, F1 score
    - matched_flags: list of bool per injected wavelength (True if detected),
      in the same order as ``injected_wavelengths``
    """
    matched_injected = [False] * len(injected_wavelengths)

    # Sort injection *indices* by wavelength so flags can be reported in input order.
    inj_order = sorted(
        (k for k, wl in enumerate(injected_wavelengths) if np.isfinite(wl)),
        key=lambda k: injected_wavelengths[k],
    )
    det_sorted = sorted(wl for wl in detected_wavelengths if np.isfinite(wl))

    i = j = 0
    while i < len(inj_order) and j < len(det_sorted):
        inj_wl = injected_wavelengths[inj_order[i]]
        det_wl = det_sorted[j]
        if abs(det_wl - inj_wl) <= threshold_angstroms:
            matched_injected[inj_order[i]] = True
            i += 1
            j += 1
        elif det_wl < inj_wl:
            # Too far blue for this injection and for every later (redder) one.
            j += 1
        else:
            # Too far red: no later (redder) detection can match this injection.
            i += 1

    TP = sum(matched_injected)
    FN = len(injected_wavelengths) - TP
    FP = len(detected_wavelengths) - TP

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    stats = {
        "TP": TP,
        "FP": FP,
        "FN": FN,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

    return stats, matched_injected
