In [8]:
import numpy as np
import warnings
import pandas as pd

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import FeatureUnion, Pipeline
from sklearn.decomposition import PCA
from scipy.signal import find_peaks, peak_widths

import astropy
from astropy.table import Table
from astropy.io import fits
from astropy import units as u

warnings.filterwarnings('ignore')

In [2]:
# Relative locations of the two cached NumPy arrays used below.
X_full_path = '../data/X_full.npy'
wavelengths_path = '../data/wavelengths.npy'

# Read both arrays from disk, in the same order as the paths above.
X_full, wavelengths = (np.load(npy_path) for npy_path in (X_full_path, wavelengths_path))


In [7]:
# Display the loaded array; NumPy abbreviates the repr of large arrays with "...".
X_full

array([[1.19550204, 1.03793788, 1.25705957, ..., 0.19054089, 0.47583213,
        0.43016803],
       [7.60315248, 7.13697939, 6.64533472, ..., 3.12152142, 3.13476166,
        3.37540529],
       [5.31835146, 5.16014164, 4.86810956, ..., 1.1297361 , 1.09039694,
        1.71308875],
       ...,
       [2.52033424, 2.72489985, 2.60168448, ..., 0.77349984, 0.83812168,
        0.71297602],
       [3.12595589, 2.91660415, 3.53408267, ..., 1.09230628, 1.38101776,
        1.1850602 ],
       [2.72489721, 3.02516791, 1.20496079, ..., 0.28848155, 0.37093287,
        0.27019219]])

In [None]:
# Relative path to the DR16Q FITS file.
dr16q_filename = '../data/DR16Q_v4.fits'

# Open the FITS file. The handle is stored in `hdul` and is not closed in this
# cell, so it stays open for later use; call `hdul.close()` once it is no longer needed.
hdul = fits.open(dr16q_filename)

In [None]:

# Fit on the array loaded from disk. The previous `X_train = X` relied on a
# variable `X` that no cell in this notebook defines, so the cell failed with a
# NameError after Restart Kernel -> Run All.
X_train = X_full

# Five components with the exact (full) SVD solver.
pca = PCA(n_components=5, svd_solver='full', random_state=12345)
# pca_mle = PCA(n_components='mle', random_state=1234)

pca.fit(X_train)

In [None]:
class CustomPCA(PCA):
    """PCA that is fitted only on the rows of ``X`` flagged with ``pca_train == 1``.

    ``__init__`` is deliberately not overridden. scikit-learn discovers an
    estimator's hyper-parameters by introspecting the ``__init__`` signature, so
    a ``**kwargs``-only constructor would hide every parameter from
    ``get_params`` / ``clone``. Keyword construction such as
    ``CustomPCA(n_components=5)`` works through the inherited constructor.
    """

    @staticmethod
    def get_bal_mask(flux, wavelengths, rest_wavelength=1549.48, v_range=(-25000, 0),
                     min_trough_depth=0.4, min_trough_width=2000, max_trough_width=25000,
                     plt_return=False):
        """Locate an absorption trough blueward of a line and return a mask range.

        The spectrum is restricted to the window between ``rest_wavelength`` and
        the wavelength blueshifted by ``v_range[0]``, smoothed with a 5-sample
        boxcar, and searched for troughs with ``scipy.signal.find_peaks``.

        Parameters
        ----------
        flux : array-like
            Flux values of one spectrum. Assumed 1-D and the same length as
            ``wavelengths`` -- TODO confirm against callers.
        wavelengths : array-like
            Wavelength of every flux sample, in the same unit as
            ``rest_wavelength``. Assumed to be in ascending order.
        rest_wavelength : float
            Rest wavelength of the line; also the red edge of the search window
            and of the returned mask.
        v_range : sequence of two numbers
            Velocity range in km/s. Only ``v_range[0]`` (the blue limit) is
            used; the red limit is always ``rest_wavelength``.
        min_trough_depth : float
            Currently unused: the detection threshold is the fixed prominence
            of 1.5 passed to ``find_peaks``.
            NOTE(review): confirm whether this should drive ``prominence``.
        min_trough_width, max_trough_width : float
            Allowed trough width expressed as a velocity in km/s.
        plt_return : bool
            When true, also return the intermediate arrays.

        Returns
        -------
        list or None
            ``[min_mask_wavelength, rest_wavelength]`` when at least one trough
            is found, otherwise ``None``. With ``plt_return=True`` an 8-tuple
            ``(bal_mask, wavelength, flux, civ_wavelength, smoothed_flux,
            civ_wavelength, trough_indices, trough_widths)`` is returned instead.
            ``trough_indices`` index ``smoothed_flux`` and ``trough_widths`` are
            in samples.
        """
        # Local import: the file imports ``astropy`` but never imports the
        # ``astropy.constants`` sub-module explicitly.
        from astropy.constants import c as speed_of_light

        kernel_size = 5
        # A 'valid' boxcar output sample i is centred on input sample i + kernel_size // 2.
        kernel_offset = kernel_size // 2

        # Use the arguments that were passed in (not a notebook-global spectrum).
        wavelength = np.asarray(wavelengths, dtype=float)
        flux = np.asarray(flux, dtype=float)

        c_kms = speed_of_light.to_value(u.km / u.s)

        # First-order Doppler shift: delta_lambda = (v / c) * lambda_rest.
        blue_extent = -v_range[0] / c_kms * rest_wavelength
        min_width_wl = abs(min_trough_width) / c_kms * rest_wavelength
        max_width_wl = abs(max_trough_width) / c_kms * rest_wavelength

        # Samples between the blueshifted limit and the rest wavelength.
        in_window = (wavelength >= rest_wavelength - blue_extent) & (wavelength <= rest_wavelength)
        civ_indices = np.where(in_window)[0]
        civ_wavelength = wavelength[civ_indices]
        civ_flux = flux[civ_indices]

        # Typical wavelength step per sample inside the window; needed because
        # scipy measures peak widths in samples, not in wavelength units.
        if civ_wavelength.size > 1:
            pixel_scale = float(np.abs(np.median(np.diff(civ_wavelength))))
        else:
            pixel_scale = 0.0

        if civ_flux.size >= kernel_size and pixel_scale > 0:
            # Smooth the flux data to enhance trough detection.
            smoothed_flux = np.convolve(civ_flux, np.ones(kernel_size) / kernel_size, mode='valid')

            # Troughs are peaks of the negated flux; width limits converted to samples.
            trough_indices, _ = find_peaks(
                -smoothed_flux,
                prominence=1.5,
                width=[min_width_wl / pixel_scale, max_width_wl / pixel_scale],
            )
            trough_widths, _, _, _ = peak_widths(-smoothed_flux, trough_indices, rel_height=1)
        else:
            # Window too short to smooth (or degenerate sampling): no trough.
            smoothed_flux = np.empty(0)
            trough_indices = np.empty(0, dtype=int)
            trough_widths = np.empty(0)

        if len(trough_indices) > 0:
            # Bluest trough: lowest index, given ascending wavelengths.
            bluest = trough_indices.argmin()
            trough_centre = civ_wavelength[trough_indices[bluest] + kernel_offset]
            # Half the trough width, converted from samples back to wavelength units.
            min_mask_wavelength = trough_centre - 0.5 * trough_widths[bluest] * pixel_scale
            bal_mask = [min_mask_wavelength, rest_wavelength]
        else:
            bal_mask = None

        if plt_return:
            return (bal_mask, wavelength, flux, civ_wavelength, smoothed_flux,
                    civ_wavelength, trough_indices, trough_widths)
        return bal_mask

    def fit(self, X, y=None):
        """Fit the PCA model on the rows of ``X`` where ``pca_train == 1``.

        Parameters
        ----------
        X : object exposing a ``pca_train`` attribute and boolean-mask indexing
            Presumably a ``pandas.DataFrame`` with a ``pca_train`` flag column --
            TODO confirm. The flag column is not removed before fitting.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self : CustomPCA
            The fitted estimator, as the scikit-learn ``fit`` contract requires.
            (``PCA.fit`` returns the estimator, not transformed data, so its
            result cannot be wrapped in a DataFrame.)
        """
        # NOTE(review): confirm whether ``fit_transform`` should apply the same
        # row filter; only ``fit`` is overridden here.
        train_rows = X[X.pca_train == 1]
        super().fit(train_rows)
        return self