# Spectrum Class GAIA-NIR



# In [3]:
import numpy as np
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt

class Spectrum:
    """
    Container for a one-dimensional spectrum (wavelength and flux samples)
    and the processing steps applied to it.

    Attributes:
        wavelength (np.ndarray): Wavelength samples; empty until loaded or assigned.
        flux (np.ndarray): Flux samples paired with ``wavelength``; empty until
            loaded or assigned.
    """

    def __init__(self):
        """
        Initialize an empty spectrum (no wavelength or flux samples).
        """
        self.wavelength = np.array([])
        self.flux = np.array([])

    def copy(self):
        """
        Create an independent copy of this spectrum.

        Returns:
            Spectrum: A new instance holding copies of the wavelength and flux
            arrays, so modifying one spectrum never affects the other.
        """
        new_spectrum = Spectrum()
        new_spectrum.wavelength = self.wavelength.copy()
        new_spectrum.flux = self.flux.copy()
        return new_spectrum

    def load_spectrum(self, file_name):
        """
        Load a spectrum from a whitespace-separated text file.

        The first line is skipped as a header (the file written by
        ``save_spectrum`` has exactly one header line). Column 0 is read as
        wavelength and column 1 as flux. On failure the error is printed and
        the current wavelength/flux data are left untouched.

        Args:
            file_name (str): Path of the file to read.
        """
        print(f"Loading spectrum from {file_name}")
        try:
            # ndmin=2 keeps a single-row file two-dimensional; without it
            # loadtxt returns a 1-D array and the column slicing below fails.
            data = np.loadtxt(file_name, skiprows=1, ndmin=2)
            # Slice both columns before assigning anything, so a malformed file
            # (e.g. only one column) cannot leave wavelength and flux out of sync.
            wavelength = data[:, 0]
            flux = data[:, 1]
        except Exception as e:
            # Best-effort loading: report the problem and keep the old data.
            print(f"Error loading spectrum: {e}")
            return

        self.wavelength = wavelength
        self.flux = flux
        print("Spectrum loaded correctly.")

    def convolve_spectrum(self, sigma_wavelength=5.0, verbose=True):
        """
        Apply a Gaussian convolution with a fixed sigma in wavelength units.

        Sigma is converted to pixels using the mean wavelength step, so the
        conversion is only exact for a (near-)uniformly sampled grid. After
        filtering, the flux is rescaled so that its sum equals the sum before
        convolution (skipped when the convolved sum is zero).

        Args:
            sigma_wavelength (float): Standard deviation of the Gaussian kernel,
                in the same units as ``self.wavelength``.
            verbose (bool): If True, print the sigma conversion and the flux
                conservation details.

        Returns:
            Spectrum: ``self``, to allow method chaining.

        Raises:
            ValueError: If the spectrum has fewer than two wavelength samples,
                so the wavelength step (and the pixel sigma) is undefined.
        """
        if np.size(self.wavelength) < 2:
            raise ValueError(
                "At least two wavelength samples are needed to convolve the spectrum."
            )

        # gaussian_filter1d keeps the input dtype, so integer flux would be
        # truncated and could not be rescaled in place; promote it to float.
        flux = np.asarray(self.flux)
        if not np.issubdtype(flux.dtype, np.inexact):
            flux = flux.astype(float)
        initial_flux = np.sum(flux)

        # Convert sigma from wavelength units to pixel units. abs() keeps the
        # pixel sigma positive for grids stored in descending wavelength order.
        lambda_step = abs(np.mean(np.diff(self.wavelength)))
        sigma_pixels = sigma_wavelength / lambda_step

        if verbose:
            print(f"Sigma wavelength units: {sigma_wavelength}, pixel units: {sigma_pixels}")

        convolved_flux = gaussian_filter1d(flux, sigma_pixels, mode='nearest')

        # Rescale so the summed flux is conserved. A zero convolved sum (e.g. an
        # all-zero spectrum) would make this 0/0 and fill the flux with NaN.
        convolved_sum = np.sum(convolved_flux)
        if convolved_sum != 0:
            convolved_flux *= initial_flux / convolved_sum

        self.flux = convolved_flux

        final_flux = np.sum(convolved_flux)
        conservation_ratio = final_flux / initial_flux if initial_flux != 0 else 0

        if verbose:
            print("Gaussian convolution completed.")
            print(f"Flux before & after convolution: {initial_flux:.3f} | {final_flux:.3f}")
            print(f"Flux conservation ratio: {conservation_ratio:.6f}")

        return self

    # TODO: calculate mean sampling, first and last, central
    # Disabled draft of resample_spectrum, kept as an inert string literal
    # until its effect on line cores has been checked (see the note at its end).
    """
    def resample_spectrum(self, num_points=1000, verbose=True):
        
        Resample the spectrum onto a new uniform wavelength grid.

        Args:
            num_points (int): Number of points for the new resampled grid.
            verbose (bool): If True, prints debugging information.
        

        new_wavelength_grid = np.linspace(min(self.wavelength), max(self.wavelength), num_points)
    
        interpolator = interp1d(self.wavelength, self.flux, kind="linear", bounds_error=False, fill_value=0)
        resampled_flux = interpolator(new_wavelength_grid)
        
        self.wavelength = new_wavelength_grid
        self.flux = resampled_flux
        
        if verbose:
            print("Resampling completed.")
            
        return self
        
        to test, run with convo and no resampling, and rerun with convo and resample save both and overplot, see where the resampling breaks, maybe in core of lines
    """

    def plot_comparison(self, reference_spectrum, verbose=True):
        """
        Plot this spectrum, optionally overlaid on a reference spectrum.

        Args:
            reference_spectrum (Spectrum or None): The original (unconvolved)
                spectrum to compare against; nothing is overlaid when None.
            verbose (bool): If True, the reference spectrum is drawn as a
                dashed line; if False, only this spectrum is plotted.
        """
        plt.figure(figsize=(15, 10))
        if verbose and reference_spectrum is not None:
            plt.plot(reference_spectrum.wavelength, reference_spectrum.flux,
                     label="Original Spectrum", color="pink", linestyle='--')
        plt.plot(self.wavelength, self.flux, label="Convolved Spectrum", color="red")
        plt.xlabel("Wavelength")
        plt.ylabel("Flux")
        plt.legend()
        plt.grid(True)
        plt.show()

    def convert_units(self, verbose=True):
        """
        Placeholder for unit conversion: only prints a message for now and
        leaves the data unchanged.

        Returns:
            Spectrum: ``self``, to allow method chaining.
        """
        if verbose:
            print("Converting units")
        return self

    def rescale_flux(self, verbose=True):
        """
        Placeholder for flux rescaling: only prints a message for now and
        leaves the data unchanged.

        Returns:
            Spectrum: ``self``, to allow method chaining.
        """
        if verbose:
            print("Rescaling flux levels")
        return self

    def radial_velocity_shift(self, verbose=True):
        """
        Placeholder for a radial-velocity shift: only prints a message for now
        and leaves the data unchanged.

        Returns:
            Spectrum: ``self``, to allow method chaining.
        """
        if verbose:
            print("Applying radial velocity shift")
        return self

    def resample_stochastic(self, verbose=True):
        """
        Placeholder for stochastic resampling: only prints a message for now
        and leaves the data unchanged.

        Returns:
            Spectrum: ``self``, to allow method chaining.
        """
        if verbose:
            print("Resampling spectrum for stochastic process")
        return self

    def generate_noise(self, verbose=True):
        """
        Placeholder for noise generation: only prints a message for now and
        leaves the data unchanged.

        Returns:
            Spectrum: ``self``, to allow method chaining.
        """
        if verbose:
            print("Adding noise to spectrum")
        return self

    def save_spectrum(self, output_file, fmt="%.10f"):
        """
        Save the spectrum as two whitespace-separated columns (wavelength, flux)
        preceded by a single commented header line.

        Args:
            output_file (str): Name of the output file.
            fmt (str): ``numpy.savetxt`` format applied to both columns. The
                default fixed-point format keeps 10 decimals, so values whose
                magnitude is below 5e-11 are written as zero; pass e.g.
                ``"%.10e"`` to preserve very small fluxes.

        Raises:
            ValueError: If there is no wavelength or flux data to save.
        """
        if self.wavelength.size == 0 or self.flux.size == 0:
            raise ValueError("No spectrum data to save. Ensure spectrum is processed before saving.")

        np.savetxt(output_file, np.column_stack((self.wavelength, self.flux)),
                   header="Wavelength Flux", fmt=fmt)

        print(f"Spectrum saved to {output_file}")
