In [2]:
import numpy
import numpy as np
from abc import ABC, abstractmethod
from scipy.ndimage import gaussian_filter
from numpy.typing import NDArray


In [3]:
class Pipeline(ABC):
    """Abstract interface for Gaussian smoothing of a data cube.

    The object given at construction is kept unchanged on ``self.args``;
    concrete pipelines must implement ``gaussian_filtering``.
    """

    def __init__(self, args):
        # Stored as-is; this base class never reads it.
        self.args = args

    @abstractmethod
    def gaussian_filtering(
        self,
        data: NDArray[np.float64],
        sigma: float,
        ss_sigma: float,
    ) -> NDArray[np.float64]:
        """
        Apply a Gaussian filter to the data cube.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
        - sigma: standard deviation for the Gaussian filter (per the original
          description of this method, along the spectral axis).
        - ss_sigma: a second standard deviation; presumably the one meant for
          the spatial axes -- TODO confirm.

        Return:
        - np.ndarray holding the filtered cube.
        """
    

class RMSPipeline(ABC):
    """Abstract interface for RMS-based processing of a data cube.

    Supplies the shared ``use_mask`` helper and declares ``rms_filtering``,
    which concrete pipelines must implement.
    """

    def __init__(self, args):
        # ``use_mask`` reads ``args.ContinuumImage`` and ``args.MaskSN``.
        self.args = args

    def use_mask(self, data: NDArray[np.float64], mask_value=np.nan) -> NDArray[np.float64]:
        """
        Overwrite the masked positions of every spectral slice, in place.

        The mask is obtained from
        ``get_mask(self.args.ContinuumImage, self.args.MaskSN)`` and is used
        to index the axes that follow the first (spectral) axis.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
          It is modified in place.
        - mask_value: value written at the masked positions (default: np.nan).

        Return:
        - np.ndarray: the same ``data`` object that was passed in.
        """
        # NOTE(review): get_mask is not defined in this part of the file;
        # presumably it returns a boolean spatial mask -- confirm.
        mask = get_mask(self.args.ContinuumImage, self.args.MaskSN)
        data[:, mask] = mask_value

        return data

    @abstractmethod
    def rms_filtering(self, data: NDArray[np.float64], use_mask=False, mask=None) -> NDArray[np.float64]:
        """
        Apply the RMS-based filtering to the data cube.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
        - use_mask: flag telling the implementation whether to mask the cube
          before filtering (presumably through ``use_mask`` -- confirm in the
          concrete class).
        - mask: optional mask supplied by the caller; how it is used is left
          to the implementation. Defaults to None instead of a list literal
          so that no mutable object is shared between calls.

        Return:
        - np.ndarray with the filtered cube.
        """

In [None]:
class SciPyPipeline(Pipeline):
    """Pipeline whose Gaussian filtering is done by scipy.ndimage.gaussian_filter."""

    # Run the spatial filtering first for each spatial S, and the spectral S values afterwards.
    def gaussian_filtering(
        self,
        data: NDArray[np.float64],
        sigma: float,
        spatial_sigma: float,
        mode: str = "constant",
        cval: float = 0.0,
        truncate: float = 5.0,
    ):
        """
        Smooth the cube with a single ``gaussian_filter`` call.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
          One sigma is passed per axis, so three axes are expected.
        - sigma: standard deviation used for the first axis.
        - spatial_sigma: standard deviation used for the second and third axes.
        - mode, cval, truncate: forwarded unchanged to ``gaussian_filter``.

        Return:
        - the array returned by ``gaussian_filter``.
        """
        # Axis order is [first, second, third]; both trailing axes share one sigma.
        per_axis_sigma = [sigma, spatial_sigma, spatial_sigma]
        return gaussian_filter(
            data,
            sigma=per_axis_sigma,
            mode=mode,
            cval=cval,
            truncate=truncate,
        )
    

class SciPyRMS(RMSPipeline):
    """RMS normalisation of a data cube based on ``np.nanstd``."""

    def rms_filtering(self, data: NDArray[np.float64], use_mask=False) -> NDArray[np.float64]:
        """
        Divide each slice along the first axis by its clipped RMS, in place.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
          Every slice is overwritten with its normalised version.
        - use_mask: when true, ``self.use_mask(data)`` is applied first.

        Return:
        - np.ndarray: the same ``data`` object, after normalisation.
        """
        print("Using default nanstd")

        if use_mask:
            # use_mask edits `data` in place, so its return value is not needed.
            self.use_mask(data)

        channel_total = len(data)
        for channel in range(channel_total):
            data[channel] = self.cp_rms(data, channel)

        return data

    def cp_rms(self, data: NDArray[np.float64], index: int):
        """
        Return slice ``index`` of ``data`` divided by its clipped RMS.

        A first NaN-ignoring standard deviation is taken over the whole slice.
        Only values below 5 times that figure are kept for the second one
        (a one-sided cut; NaN entries compare false and so are left out too).
        No guard is applied when the second figure is 0 or NaN, in which case
        the result holds inf/NaN values.

        Parameters:
        - data: data cube to be used, axis format [spectral, spatial, spatial].
        - index: position of the slice along the first axis.

        Return:
        - new array with the normalised slice; ``data`` itself is not modified.
        """
        channel = data[index]
        first_pass_rms = np.nanstd(channel)
        below_cut = channel < 5.0 * first_pass_rms
        clipped_rms = np.nanstd(channel[below_cut])
        return channel / clipped_rms