<a href="https://colab.research.google.com/github/davidwhogg/RVSanomalies/blob/main/notebooks/RVS_meets_RPCA_via_Claude_vibe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Robust PCA Analysis of Gaia RVS Spectra
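This notebook applies Robust Principal Component Analysis (RPCA) to Gaia RVS spectra. It decomposes the matrix of spectra into a low-rank component (structure shared across many stars) and a sparse component (localized departures from that shared structure). The sparse component can reveal unusual spectral features, emission lines, or other anomalies.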

## Authors:
- **David W. Hogg**
- **Hans-Walter Rix**
- **Claude**

## Requirements:
- `pip install astroquery astropy numpy matplotlib scipy scikit-learn`



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astroquery.gaia import Gaia
from astropy.table import Table
from astropy.io import fits
import os
import warnings
from datetime import datetime
from scipy.linalg import svd
from scipy.sparse.linalg import svds
import pickle

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')

# Notebook-wide matplotlib defaults: figure size and base font size.
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 12,
})

In [None]:
## 1. Data Download Functions

def find_rvs_sources_gspphot(teff_min=4810, teff_max=6200,
                            logg_min=1.0, logg_max=3.0,
                            grvs_mag_max=11.0, n_sources=500):
    """
    Select Gaia DR3 sources that have RVS spectra, filtered on GSP-Phot
    Teff / log g and on GRVS magnitude.

    The query keeps only rows with non-null Teff, log g and radial velocity,
    and returns the brightest ``n_sources`` (ascending ``grvs_mag``).

    Parameters:
    -----------
    teff_min, teff_max : float
        Inclusive GSP-Phot effective-temperature range in K
    logg_min, logg_max : float
        Inclusive GSP-Phot surface-gravity range
    grvs_mag_max : float
        Faintest GRVS magnitude accepted
    n_sources : int
        Maximum number of rows to return (ADQL ``TOP``)

    Returns:
    --------
    astropy.table.Table
        Result table of the asynchronous archive job
    """

    adql = f"""
    SELECT TOP {n_sources}
        source_id, ra, dec, phot_g_mean_mag, grvs_mag,
        radial_velocity, radial_velocity_error,
        teff_gspphot, logg_gspphot, mh_gspphot,
        bp_rp, parallax
    FROM gaiadr3.gaia_source
    WHERE has_rvs = 't'
    AND grvs_mag <= {grvs_mag_max}
    AND teff_gspphot BETWEEN {teff_min} AND {teff_max}
    AND logg_gspphot BETWEEN {logg_min} AND {logg_max}
    AND teff_gspphot IS NOT NULL
    AND logg_gspphot IS NOT NULL
    AND radial_velocity IS NOT NULL
    ORDER BY grvs_mag ASC
    """

    # Echo the selection so the notebook log records what was requested.
    banner = (
        "Querying Gaia archive for RVS sources...",
        f"Criteria: {teff_min} <= Teff <= {teff_max} K",
        f"         {logg_min} <= log g <= {logg_max}",
        f"         GRVS mag <= {grvs_mag_max}",
        f"         Requesting {n_sources} sources",
    )
    print("\n".join(banner))

    results = Gaia.launch_job_async(adql).get_results()
    print(f"\nFound {len(results)} sources matching criteria")
    return results


def download_rvs_spectrum(source_id, output_dir="rvs_spectra_cache"):
    """
    Download the Gaia DR3 RVS spectrum of a single source, with disk caching.

    Parameters
    ----------
    source_id : int or str
        Gaia DR3 source identifier.
    output_dir : str
        Directory holding cached spectra as ``rvs_<source_id>.npz``. It is
        created on the first successful download.

    Returns
    -------
    (wavelength, flux) : tuple of np.ndarray
        Arrays taken from the datalink table's 'wavelength' and 'flux'
        columns, or ``(None, None)`` when no RVS product is returned or the
        download fails.
    """

    # Serve from the cache when available
    cache_file = os.path.join(output_dir, f"rvs_{source_id}.npz")
    if os.path.exists(cache_file):
        # np.load on an .npz returns an NpzFile that holds the file open.
        # Closing it via the context manager avoids leaking one handle per
        # cache hit. Indexing loads each array fully, so the returned arrays
        # remain valid after the file is closed.
        with np.load(cache_file) as cached:
            return cached['wavelength'], cached['flux']

    try:
        # Download spectrum
        retrieval_type = 'RVS'
        data_structure = 'INDIVIDUAL'
        data_release = 'Gaia DR3'

        datalink_products = Gaia.load_data(
            ids=[str(source_id)],
            data_release=data_release,
            retrieval_type=retrieval_type,
            data_structure=data_structure,
            verbose=False
        )

        if not datalink_products:
            return None, None

        product_key = f"RVS-Gaia DR3 {source_id}.xml"

        if product_key not in datalink_products:
            return None, None

        # Extract spectrum
        votable = datalink_products[product_key][0]
        spectrum_table = votable.to_table()

        wavelength = np.array(spectrum_table['wavelength'])  # in nm
        flux = np.array(spectrum_table['flux'])  # normalized

        # Save to cache
        os.makedirs(output_dir, exist_ok=True)
        np.savez(cache_file, wavelength=wavelength, flux=flux)

        return wavelength, flux

    except Exception as e:
        # Deliberately best-effort: report and skip this source so that a
        # bulk download over many sources keeps going.
        print(f"Error downloading spectrum for source {source_id}: {e}")
        return None, None


def download_multiple_spectra(sources, max_spectra=None):
    """
    Download (or load from cache) RVS spectra for the first rows of a table.

    Parameters
    ----------
    sources : astropy.table.Table
        Table with a 'source_id' (or 'SOURCE_ID') column.
    max_spectra : int, optional
        Process at most this many leading rows; ``None`` means all rows.

    Returns:
    --------
    spectra_data : dict
        Dictionary with source_ids as keys and (wavelength, flux) tuples as values.
        Sources whose download fails are omitted.
    """

    # Slice first so the progress denominator reflects the rows actually
    # processed, even when max_spectra exceeds the table length.
    subset = sources if max_spectra is None else sources[:max_spectra]
    n_total = len(subset)

    spectra_data = {}
    successful_downloads = 0

    print(f"\nDownloading RVS spectra for up to {n_total} sources...")

    # Check column names (Gaia returns uppercase)
    source_id_col = 'SOURCE_ID' if 'SOURCE_ID' in sources.colnames else 'source_id'

    for i, source in enumerate(subset):
        source_id = source[source_id_col]

        if i % 10 == 0:
            print(f"Progress: {i}/{n_total} spectra processed...")

        wavelength, flux = download_rvs_spectrum(source_id)

        if wavelength is not None and flux is not None:
            spectra_data[source_id] = (wavelength, flux)
            successful_downloads += 1

    print(f"\nSuccessfully downloaded {successful_downloads} spectra")

    return spectra_data

In [None]:
## 2. Data preprocessing

def create_spectral_matrix(spectra_data, wavelength_grid=None, fill_value=1.0,
                          n_clip_lower=8, n_clip_upper=3):
    """
    Create a matrix Y where each row is a spectrum.

    Parameters:
    -----------
    spectra_data : dict
        Dictionary of (wavelength, flux) tuples
    wavelength_grid : array-like, optional
        Common wavelength grid, used as-is. If None, the first spectrum's
        grid is used after clipping.
    fill_value : float
        Value used for replacing NaN/Inf and padded pixels (default: 1.0 for continuum)
    n_clip_lower : int
        Number of pixels to clip from the lower wavelength end of each flux
        array (and of the derived grid)
    n_clip_upper : int
        Number of pixels to clip from the upper wavelength end

    Returns:
    --------
    Y : np.ndarray
        Matrix of shape (n_spectra, n_wavelengths)
    wavelength_grid : np.ndarray
        Wavelength grid used (empty if spectra_data is empty and no grid given)
    source_ids : list
        List of source IDs in same order as rows of Y
    bad_pixel_mask : np.ndarray
        Boolean mask flagging every pixel whose value was replaced by
        fill_value: NaN/Inf pixels, plus pixels added when a spectrum shorter
        than the grid had to be padded.

    Notes
    -----
    Rows are aligned by pixel index only; spectra are not resampled. A
    spectrum longer than the grid is truncated, and a shorter one is padded.
    """

    clipping = n_clip_lower > 0 or n_clip_upper > 0

    def _clip(arr):
        # Explicit end index so n_clip_upper == 0 keeps the last pixel
        # (arr[lo:-0] would be empty).
        return arr[n_clip_lower:len(arr) - n_clip_upper]

    source_ids = list(spectra_data.keys())
    n_spectra = len(source_ids)

    # Use first spectrum to define wavelength grid if not provided
    if wavelength_grid is None:
        if n_spectra == 0:
            print("Warning: no spectra supplied; returning an empty matrix")
            return (np.zeros((0, 0)), np.array([]), source_ids,
                    np.zeros((0, 0), dtype=bool))
        wavelength_grid = spectra_data[source_ids[0]][0]
        if clipping:
            wavelength_grid = _clip(wavelength_grid)
            print(f"Clipping spectra: removing {n_clip_lower} pixels from lower end, "
                  f"{n_clip_upper} pixels from upper end")
            print(f"New wavelength range: {wavelength_grid[0]:.2f} - {wavelength_grid[-1]:.2f} nm")

    n_wavelengths = len(wavelength_grid)

    Y = np.zeros((n_spectra, n_wavelengths))
    bad_pixel_mask = np.zeros((n_spectra, n_wavelengths), dtype=bool)

    total_bad_pixels = 0
    spectra_with_bad_pixels = 0

    for i, source_id in enumerate(source_ids):
        flux = np.asarray(spectra_data[source_id][1], dtype=float)

        if clipping:
            flux = _clip(flux)

        # Handle different length spectra: truncate, or pad with NaN so the
        # padded pixels are flagged and filled by the bad-pixel step below.
        if len(flux) != n_wavelengths:
            print(f"Warning: Spectrum {source_id} has different length ({len(flux)} vs {n_wavelengths})")
            min_len = min(len(flux), n_wavelengths)
            padded = np.full(n_wavelengths, np.nan)
            padded[:min_len] = flux[:min_len]
            flux = padded

        bad_pixels = ~np.isfinite(flux)  # NaN or +/-Inf
        if bad_pixels.any():
            spectra_with_bad_pixels += 1
            total_bad_pixels += int(bad_pixels.sum())
            bad_pixel_mask[i, :] = bad_pixels
            flux = np.where(bad_pixels, fill_value, flux)

        Y[i, :] = flux

    print(f"\nBad pixel statistics:")
    print(f"  Spectra with bad pixels: {spectra_with_bad_pixels}/{n_spectra}")
    print(f"  Total bad pixels: {total_bad_pixels}")
    print(f"  Bad pixels replaced with: {fill_value}")

    return Y, wavelength_grid, source_ids, bad_pixel_mask


In [None]:
## 3. Robust PCA Implementation

class RobustPCA:
    """
    Robust PCA via an augmented-Lagrangian / ADMM-style iteration.

    Decomposes a matrix Y = L + S where:
    - L is low-rank (singular value thresholding step)
    - S is sparse (soft thresholding step, followed by a heuristic cut of
      entries below a per-column noise level; see ``fit``)

    After ``fit`` the instance exposes ``L``, ``S``, ``rank`` (numerical rank
    of L from ``np.linalg.matrix_rank``) and ``sparsity`` (fraction of
    entries with |S| > 1e-6).
    """

    def __init__(self, lambda_param=None, mu=None, tol=1e-6, max_iter=400):
        """
        Parameters:
        -----------
        lambda_param : float
            Regularization parameter for sparse component
            Default: 1/sqrt(max(m,n)), filled in by ``fit``
        mu : float
            ADMM penalty parameter
            Default: 0.25/mean(abs(Y)), filled in by ``fit``
            (1.0 when Y is identically zero)
        tol : float
            Convergence tolerance on the absolute Frobenius norm ||Y - L - S||_F
        max_iter : int
            Maximum iterations
        """
        self.lambda_param = lambda_param
        self.mu = mu
        self.tol = tol
        self.max_iter = max_iter
        self.L = None
        self.S = None

    def fit(self, Y):
        """
        Fit Robust PCA model to data matrix Y.

        Parameters
        ----------
        Y : np.ndarray, shape (m, n)

        Returns
        -------
        self

        Notes
        -----
        Defaults for ``lambda_param`` / ``mu`` are written back onto the
        instance, so a later ``fit`` on different data reuses them.
        """
        m, n = Y.shape

        # Set default parameters if not provided
        if self.lambda_param is None:
            self.lambda_param = 1. / np.sqrt(max(m, n))

        if self.mu is None:
            mean_abs = np.mean(np.abs(Y))
            # An all-zero Y would give mu = inf, then inf * 0 = NaN in the
            # dual update; fall back to a finite penalty instead.
            self.mu = 0.25 / mean_abs if mean_abs > 0 else 1.0

        # Initialize variables
        L = np.zeros_like(Y)
        S = np.zeros_like(Y)
        Z = np.zeros_like(Y)

        print(f"Starting Robust PCA with lambda={self.lambda_param:.4f}, mu={self.mu:.4f}")

        for iter_num in range(self.max_iter):
            # Update L (low-rank component) via singular value shrinkage
            U, sigma, Vt = svd(Y - S + Z / self.mu, full_matrices=False)
            sigma_shrink = np.maximum(sigma - 1 / self.mu, 0)
            # sigma is sorted descending, so the surviving singular values
            # are a prefix. Rebuild L from those k triplets only, scaling U's
            # columns rather than forming a dense diagonal matrix.
            k = int(np.count_nonzero(sigma_shrink))
            L = (U[:, :k] * sigma_shrink[:k]) @ Vt[:k, :]

            # Update S (sparse component) via soft thresholding
            S = self._soft_threshold(Y - L + Z / self.mu, self.lambda_param / self.mu)
            # Heuristic beyond standard RPCA: zero entries below 3x the
            # column-wise std of Y - L. The removed entries stay in the
            # residual, so Y - L - S may never reach the absolute `tol`.
            noise_threshold = 3.0 * np.std(Y - L, axis=0)
            S = S * (np.abs(S) > noise_threshold[np.newaxis, :])

            # Update dual variable Z
            Z = Z + self.mu * (Y - L - S)

            # Check convergence (absolute, not relative to ||Y||_F)
            primal_residual = np.linalg.norm(Y - L - S, 'fro')

            if iter_num % 50 == 0:
                rank_L = k
                sparsity_S = np.sum(np.abs(S) > 1e-6) / S.size
                print(f"Iter {iter_num}: residual={primal_residual:.6f}, "
                      f"rank(L)={rank_L}, sparsity(S)={sparsity_S:.3%}")

            if primal_residual < self.tol:
                print(f"Converged after {iter_num + 1} iterations")
                break

        self.L = L
        self.S = S

        # Compute final statistics
        self.rank = np.linalg.matrix_rank(L)
        self.sparsity = np.sum(np.abs(S) > 1e-6) / S.size

        return self

    def _soft_threshold(self, X, threshold):
        """Soft thresholding operator: sign(X) * max(|X| - threshold, 0)."""
        return np.sign(X) * np.maximum(np.abs(X) - threshold, 0)

    def transform(self, Y):
        """
        Return the fitted (L, S) decomposition of the training data.

        ``Y`` is currently ignored: projecting new data onto the learned
        subspace is not implemented.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.L is None:
            raise ValueError("Model not fitted yet")
        return self.L, self.S

In [None]:
## 4. Analysis Functions

def analyze_low_rank_component(L, wavelength_grid, n_components=5):
    """
    Take the leading singular triplets of the low-rank matrix L.

    Parameters
    ----------
    L : np.ndarray, shape (n_spectra, n_wavelengths)
    wavelength_grid : array-like
        Passed through unchanged under the 'wavelength' key.
    n_components : int
        Number of leading singular triplets to keep.

    Returns:
    --------
    components : dict
        'U' : left singular vectors (per-spectrum coefficients), first columns
        'S' : leading singular values
        'V' : leading right singular vectors (spectral components), as rows
        'wavelength' : the supplied wavelength grid
    """
    left, singular_values, right_t = svd(L, full_matrices=False)
    keep = slice(0, n_components)

    return {
        'U': left[:, keep],
        'S': singular_values[keep],
        'V': right_t[keep, :],
        'wavelength': wavelength_grid,
    }


def find_sparse_outliers(S, source_ids, threshold_percentile=95):
    """
    Flag spectra whose sparse-component L2 norm is above a percentile cut.

    Parameters
    ----------
    S : np.ndarray, shape (n_spectra, n_wavelengths)
        Sparse component; one row per spectrum.
    source_ids : sequence
        Identifiers aligned with the rows of S.
    threshold_percentile : float
        Percentile of the per-row norms used as the (strict) cut.

    Returns:
    --------
    outlier_info : dict
        'indices' (row indices above the cut), 'source_ids', 'sparse_norms'
        (norms of the flagged rows), 'threshold', 'all_norms'.
    """
    norms = np.linalg.norm(S, axis=1)
    cut = np.percentile(norms, threshold_percentile)
    (selected,) = np.nonzero(norms > cut)

    return {
        'indices': selected,
        'source_ids': [source_ids[k] for k in selected],
        'sparse_norms': norms[selected],
        'threshold': cut,
        'all_norms': norms,
    }


def identify_sparse_features(S_spectrum, wavelength, feature_threshold=0.05):
    """
    Select the significant pixels of one spectrum's sparse component.

    Every pixel whose absolute value exceeds ``feature_threshold`` times the
    spectrum's maximum absolute value is reported. This is a threshold cut,
    not a peak finder, so one broad feature can yield several adjacent
    pixels. Positive values are classed as emission and negative values as
    absorption.

    Parameters:
    -----------
    S_spectrum : array
        Sparse component for one spectrum
    wavelength : array
        Wavelength grid aligned with S_spectrum
    feature_threshold : float
        Threshold for significant features (relative to max |S_spectrum|)

    Returns:
    --------
    features : dict
        'emission' / 'absorption' : list of (wavelength, value) tuples
        'wavelengths' : np.ndarray of the selected wavelengths
        'values' : np.ndarray of the selected sparse values
        All four keys are always present; they are empty when the spectrum
        is empty or its max |value| is below 1e-6.
    """
    S_spectrum = np.asarray(S_spectrum)
    wavelength = np.asarray(wavelength)

    abs_S = np.abs(S_spectrum)
    max_S = abs_S.max() if abs_S.size else 0.0

    if max_S < 1e-6:
        # Same keys and types as the normal return, just empty.
        return {'emission': [], 'absorption': [],
                'wavelengths': wavelength[:0], 'values': S_spectrum[:0]}

    feature_indices = np.nonzero(abs_S > feature_threshold * max_S)[0]

    emission_features = []
    absorption_features = []
    for idx in feature_indices:
        target = emission_features if S_spectrum[idx] > 0 else absorption_features
        target.append((wavelength[idx], S_spectrum[idx]))

    return {
        'emission': emission_features,
        'absorption': absorption_features,
        'wavelengths': wavelength[feature_indices],
        'values': S_spectrum[feature_indices]
    }

In [None]:
## 5. Visualization Functions

def plot_rpca_components(rpca_model, wavelength_grid, n_examples=3):
    """
    Show a 2x2 overview of a fitted Robust PCA decomposition.

    Panels: singular values of L (first 50, log scale); the first five right
    singular vectors of L (offset by 0.5 each); histogram of S entries with
    |S| > 1e-6; and the ``n_examples`` rows of S with the largest L2 norm
    (offset by 0.1 each, Ca II triplet marked).

    Parameters
    ----------
    rpca_model : RobustPCA
        Fitted model; reads ``.L``, ``.S`` and ``.sparsity``.
    wavelength_grid : np.ndarray
        Wavelength of each column of L / S.
    n_examples : int
        Number of sparse rows shown in the last panel.
    """

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Singular values
    U, S, Vt = svd(rpca_model.L, full_matrices=False)
    axes[0, 0].semilogy(S[:50], 'bo-')
    axes[0, 0].set_xlabel('Component')
    axes[0, 0].set_ylabel('Singular Value')
    axes[0, 0].set_title('Singular Values of Low-Rank Component')
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: First few principal components
    ax = axes[0, 1]
    for i in range(min(5, len(S))):
        ax.plot(wavelength_grid, Vt[i, :] + i*0.5,
                label=f'PC{i+1} (σ={S[i]:.2f})')
    ax.set_xlabel('Wavelength (nm)')
    ax.set_ylabel('Principal Component (offset)')
    ax.set_title('First 5 Principal Spectral Components')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: Distribution of sparse values
    ax = axes[1, 0]
    sparse_values = rpca_model.S.flatten()
    sparse_nonzero = sparse_values[np.abs(sparse_values) > 1e-6]
    ax.hist(sparse_nonzero, bins=50, alpha=0.7, color='red')
    ax.set_xlabel('Sparse Component Value')
    ax.set_ylabel('Count')
    ax.set_title(f'Distribution of Non-zero Sparse Values\n'
                 f'Sparsity: {rpca_model.sparsity:.1%}')
    ax.grid(True, alpha=0.3)

    # Plot 4: Example sparse spectra (largest row norms of S)
    ax = axes[1, 1]
    sparse_norms = np.linalg.norm(rpca_model.S, axis=1)
    top_sparse_idx = np.argsort(sparse_norms)[-n_examples:]

    for i, idx in enumerate(top_sparse_idx):
        ax.plot(wavelength_grid, rpca_model.S[idx, :] + i*0.1,
                label=f'Spectrum {idx}')

    ax.set_xlabel('Wavelength (nm)')
    ax.set_ylabel('Sparse Component (offset)')
    ax.set_title('Top Sparse Spectra')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Mark Ca II triplet positions (vacuum wavelengths, nm). All three values
    # must be vacuum: the air wavelength of the first line is 849.80 nm,
    # while its vacuum wavelength is ~850.036 nm.
    ca_lines_vacuum = [850.0358, 854.4444, 866.4536]
    for ca_line in ca_lines_vacuum:
        ax.axvline(ca_line, color='gray', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()


def plot_spectrum_decomposition(Y, L, S, idx, wavelength_grid, source_id=None, return_fig=False):
    """
    Plot one row of the decomposition as four stacked panels.

    Panels: observed spectrum Y, low-rank component L, sparse component S,
    and residual Y - L - S. The Ca II triplet is marked on every panel.

    Parameters
    ----------
    Y, L, S : np.ndarray, shape (n_spectra, n_wavelengths)
    idx : int
        Row to plot.
    wavelength_grid : np.ndarray
    source_id : optional
        Shown in the title when given; otherwise the row index is shown.
    return_fig : bool
        If True, return the Figure without showing it; otherwise call
        ``plt.show()`` and return None.
    """

    fig, axes = plt.subplots(4, 1, figsize=(12, 10), sharex=True)

    # `is not None` so that a source_id of 0 is still displayed.
    title_suffix = f" (Source {source_id})" if source_id is not None else f" (Index {idx})"

    # Original spectrum
    axes[0].plot(wavelength_grid, Y[idx, :], 'b-', linewidth=1.5)
    axes[0].set_ylabel('Flux')
    axes[0].set_title(f'Original Spectrum{title_suffix}')
    axes[0].grid(True, alpha=0.3)
    axes[0].set_ylim(0, 1.2)

    # Low-rank component
    axes[1].plot(wavelength_grid, L[idx, :], 'g-', linewidth=1.5)
    axes[1].set_ylabel('Flux')
    axes[1].set_title('Low-Rank Component')
    axes[1].grid(True, alpha=0.3)
    axes[1].set_ylim(0, 1.2)

    # Sparse component
    axes[2].plot(wavelength_grid, S[idx, :], 'r-', linewidth=1.5)
    axes[2].set_ylabel('Flux')
    axes[2].set_title(f'Sparse Component (||S||₂ = {np.linalg.norm(S[idx, :]):.4f})')
    axes[2].grid(True, alpha=0.3)

    # Residual
    residual = Y[idx, :] - L[idx, :] - S[idx, :]
    axes[3].plot(wavelength_grid, residual, 'k-', linewidth=1.5)
    axes[3].set_ylabel('Flux')
    axes[3].set_title(f'Residual (||r||₂ = {np.linalg.norm(residual):.4f})')
    axes[3].set_xlabel('Wavelength (nm)')
    axes[3].grid(True, alpha=0.3)

    # Mark Ca II triplet (vacuum wavelengths, nm). The first line's vacuum
    # value is ~850.036 nm; 849.80 nm is its air wavelength.
    ca_lines_vacuum = [850.0358, 854.4444, 866.4536]
    for ax in axes:
        for ca_line in ca_lines_vacuum:
            ax.axvline(ca_line, color='gray', linestyle='--', alpha=0.5)

    plt.tight_layout()

    if return_fig:
        return fig
    plt.show()


def plot_outlier_analysis(Y, S, outlier_info, wavelength_grid, source_ids,
                         sources_table, n_examples=5):
    """
    Plot a detailed view of outlier spectra, one row per outlier.

    Left panel: the observed spectrum and Y - S, i.e. the spectrum with the
    sparse part removed. Because the residual is kept, Y - S equals
    L + residual, not L alone. Right panel: the sparse component, with
    pixels selected by ``identify_sparse_features`` marked (red = positive,
    blue = negative).

    Parameters
    ----------
    Y, S : np.ndarray, shape (n_spectra, n_wavelengths)
    outlier_info : dict
        Needs 'indices' (row indices into Y/S) and 'sparse_norms' in the
        same order.
    wavelength_grid : np.ndarray
    source_ids : list
        Identifiers aligned with the rows of Y/S.
    sources_table : astropy.table.Table
        Catalogue used to annotate Teff, log g and GRVS per source.
    n_examples : int
        Maximum number of outliers drawn.
    """

    n_outliers = min(n_examples, len(outlier_info['indices']))

    if n_outliers == 0:
        print("No outliers found!")
        return

    fig, axes = plt.subplots(n_outliers, 2, figsize=(15, 4*n_outliers))

    # plt.subplots returns a 1-D axes array for a single row; keep 2-D indexing.
    if n_outliers == 1:
        axes = axes.reshape(1, -1)

    # Column names may be upper- or lower-case depending on the archive client
    colnames = sources_table.colnames
    source_id_col = 'SOURCE_ID' if 'SOURCE_ID' in colnames else 'source_id'
    teff_col = 'TEFF_GSPPHOT' if 'TEFF_GSPPHOT' in colnames else 'teff_gspphot'
    logg_col = 'LOGG_GSPPHOT' if 'LOGG_GSPPHOT' in colnames else 'logg_gspphot'
    grvs_col = 'GRVS_MAG' if 'GRVS_MAG' in colnames else 'grvs_mag'

    for i in range(n_outliers):
        idx = outlier_info['indices'][i]
        source_id = source_ids[idx]

        # Find source info
        source_mask = sources_table[source_id_col] == source_id
        if np.any(source_mask):
            source_info = sources_table[source_mask][0]
            info_str = (f"Teff={source_info[teff_col]:.0f}K, "
                       f"log g={source_info[logg_col]:.2f}, "
                       f"GRVS={source_info[grvs_col]:.2f}")
        else:
            info_str = "No info"

        # Observed spectrum and the sparse-subtracted spectrum
        ax = axes[i, 0]
        ax.plot(wavelength_grid, Y[idx, :], 'b-', alpha=0.7, label='Original')
        ax.plot(wavelength_grid, Y[idx, :] - S[idx, :], 'g-', alpha=0.7,
                label='Y - S (low-rank + residual)')
        ax.set_ylabel('Flux')
        ax.set_title(f'Source {source_id}\n{info_str}')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1.2)

        # Plot sparse component
        ax = axes[i, 1]
        ax.plot(wavelength_grid, S[idx, :], 'r-', linewidth=1.5)
        ax.axhline(0, color='k', linestyle='-', alpha=0.3)
        ax.set_ylabel('Sparse Component')
        ax.set_title(f'Sparse Component (||S||₂ = {outlier_info["sparse_norms"][i]:.4f})')
        ax.grid(True, alpha=0.3)

        # Identify and mark features
        features = identify_sparse_features(S[idx, :], wavelength_grid)

        # Mark positive (emission-like) pixels
        for wave, value in features['emission']:
            ax.plot(wave, value, 'ro', markersize=8)
            ax.annotate(f'{wave:.1f}', (wave, value),
                       xytext=(5, 5), textcoords='offset points',
                       fontsize=8, color='red')

        # Mark negative (absorption-like) pixels
        for wave, value in features['absorption']:
            ax.plot(wave, value, 'bo', markersize=8)
            ax.annotate(f'{wave:.1f}', (wave, value),
                       xytext=(5, -10), textcoords='offset points',
                       fontsize=8, color='blue')

    # Label the wavelength axis on both panels of the bottom row
    for ax in axes[n_outliers - 1]:
        ax.set_xlabel('Wavelength (nm)')

    plt.tight_layout()
    plt.show()

In [None]:
## 6. Main Analysis Pipeline

def main_analysis():
    """
    Run the end-to-end Robust PCA analysis of Gaia RVS spectra.

    Steps: query the Gaia archive, download (or load cached) RVS spectra,
    build the spectral matrix, fit ``RobustPCA``, analyse the low-rank and
    sparse components, and produce plots and feature statistics.

    Files written to the current directory: 'gaia_rvs_sources.fits',
    'spectra_data.pkl', 'rpca_results.npz',
    'bad_vs_sparse_pixels_scatter.png', and one PNG per top-ranked spectrum
    in 'rpca_spectrum_plots/'.

    Returns
    -------
    dict or None
        Keys 'sources', 'spectra_data', 'Y', 'wavelength_grid', 'source_ids',
        'rpca', 'components', 'outlier_info', 'bad_pixel_mask'; or None when
        no spectra were obtained.
    """

    # Set parameters
    params = {
        'teff_min': 4200,
        'teff_max': 4800,
        'logg_min': 1.0,
        'logg_max': 3.0,
        'grvs_mag_max': 11.0,
        'n_sources': 2000
    }

    print("=" * 70)
    print("Robust PCA Analysis of Gaia RVS Spectra")
    print("=" * 70)
    print("\nParameters:")
    for key, value in params.items():
        print(f"  {key}: {value}")

    # Step 1: Query Gaia archive
    print("\n" + "="*50)
    print("STEP 1: Querying Gaia Archive")
    print("="*50)

    sources = find_rvs_sources_gspphot(**params)

    # Save source table
    sources.write('gaia_rvs_sources.fits', format='fits', overwrite=True)
    print(f"Source table saved to 'gaia_rvs_sources.fits'")

    # Step 2: Download spectra
    print("\n" + "="*50)
    print("STEP 2: Downloading RVS Spectra")
    print("="*50)

    spectra_data = download_multiple_spectra(sources, max_spectra=params['n_sources'])

    # Save spectra data
    with open('spectra_data.pkl', 'wb') as f:
        pickle.dump(spectra_data, f)
    print(f"Spectra data saved to 'spectra_data.pkl'")

    # Step 3: Create spectral matrix
    print("\n" + "="*50)
    print("STEP 3: Creating Spectral Matrix")
    print("="*50)

    # Check for empty data before building the matrix. Both the matrix
    # construction and the statistics below (np.min/np.max, wavelength_grid[0])
    # fail on zero spectra, so a later check would never be reached.
    if not spectra_data:
        print("ERROR: No valid spectra found in the data!")
        return None

    Y, wavelength_grid, source_ids, bad_pixel_mask = create_spectral_matrix(
        spectra_data, n_clip_lower=8, n_clip_upper=3)
    print(f"Spectral matrix shape: {Y.shape}")
    print(f"Wavelength range: {wavelength_grid[0]:.2f} - {wavelength_grid[-1]:.2f} nm")

    # Additional diagnostics
    print(f"\nSpectral matrix statistics:")
    print(f"  Min flux: {np.min(Y):.4f}")
    print(f"  Max flux: {np.max(Y):.4f}")
    print(f"  Mean flux: {np.mean(Y):.4f}")
    print(f"  Std flux: {np.std(Y):.4f}")
    print(f"  Contains NaN: {np.any(np.isnan(Y))}")
    print(f"  Contains Inf: {np.any(np.isinf(Y))}")

    # Step 4: Apply Robust PCA
    print("\n" + "="*50)
    print("STEP 4: Applying Robust PCA")
    print("="*50)

    rpca = RobustPCA()
    rpca.fit(Y)

    print(f"\nRobust PCA Results:")
    print(f"  Rank of low-rank component: {rpca.rank}")
    print(f"  Sparsity of sparse component: {rpca.sparsity:.2%}")
    print(f"  Reconstruction error: {np.linalg.norm(Y - rpca.L - rpca.S, 'fro'):.6f}")

    # Save RPCA results
    np.savez('rpca_results.npz',
             L=rpca.L, S=rpca.S, Y=Y,
             wavelength=wavelength_grid,
             source_ids=source_ids,
             bad_pixel_mask=bad_pixel_mask)
    print(f"RPCA results saved to 'rpca_results.npz'")

    # Step 5: Analyze components
    print("\n" + "="*50)
    print("STEP 5: Analyzing Components")
    print("="*50)

    # Analyze bad pixels in sparse component
    print("\nAnalyzing bad pixels in sparse component:")

    sparse_mask = np.abs(rpca.S) > 1e-6
    bad_in_sparse = bad_pixel_mask & sparse_mask

    n_bad_pixels = np.sum(bad_pixel_mask)
    n_bad_in_sparse = np.sum(bad_in_sparse)

    print(f"  Total bad pixels (NaN/Inf): {n_bad_pixels}")
    print(f"  Bad pixels identified in sparse component: {n_bad_in_sparse}")
    if n_bad_pixels > 0:
        print(f"  Percentage of bad pixels caught by sparse component: {100*n_bad_in_sparse/n_bad_pixels:.1f}%")

    # Analyze low-rank component
    components = analyze_low_rank_component(rpca.L, wavelength_grid)
    # components['S'] holds only the leading singular values. Normalise by
    # the total power of L (sum of ALL squared singular values,
    # = ||L||_F^2); dividing by the truncated sum would force the ratios to
    # add up to 1.
    total_power = np.linalg.norm(rpca.L, 'fro') ** 2
    top_power = components['S'][:5] ** 2
    explained = top_power / total_power if total_power > 0 else np.zeros_like(top_power)
    print(f"\nPrincipal component analysis:")
    print(f"  Explained variance ratios: {explained}")

    # Find outliers
    outlier_info = find_sparse_outliers(rpca.S, source_ids, threshold_percentile=90)
    print(f"\nOutlier analysis:")
    print(f"  Number of outliers (>90th percentile): {len(outlier_info['indices'])}")
    print(f"  Outlier source IDs: {outlier_info['source_ids'][:10]}")

    # Step 6: Visualizations
    print("\n" + "="*50)
    print("STEP 6: Creating Visualizations")
    print("="*50)

    plot_rpca_components(rpca, wavelength_grid)

    plot_dir = "rpca_spectrum_plots"
    os.makedirs(plot_dir, exist_ok=True)
    print(f"\nCreated directory '{plot_dir}' for individual spectrum plots")

    # Sort ALL spectra by sparse component norm and plot the top ones
    print("\nSorting spectra by sparse component norm...")
    sparse_norms = np.linalg.norm(rpca.S, axis=1)
    sorted_indices = np.argsort(sparse_norms)[::-1]  # Descending order

    # Never claim more plots than there are spectra
    n_top = min(20, len(sorted_indices))
    print(f"\nPlotting top {n_top} spectra with largest sparse components:")

    for rank, idx in enumerate(sorted_indices[:n_top]):
        source_id = source_ids[idx]
        sparse_norm = sparse_norms[idx]
        print(f"  Rank {rank+1}: Source {source_id}, ||S||₂ = {sparse_norm:.4f}")

        fig = plot_spectrum_decomposition(Y, rpca.L, rpca.S, idx, wavelength_grid,
                                          source_id=source_id, return_fig=True)

        filename = f"spectrum_rank{rank+1:02d}_source{source_id}_sparsenorm{sparse_norm:.4f}.png"
        filepath = os.path.join(plot_dir, filename)
        fig.savefig(filepath, dpi=150, bbox_inches='tight')
        plt.close(fig)  # Close to save memory

        print(f"    Saved: {filename}")

    print(f"\nAll {n_top} spectrum plots saved to '{plot_dir}/' directory")

    # Scatter plot of bad pixels vs sparse pixels
    print("\nCreating bad pixels vs sparse pixels scatter plot...")

    n_bad_per_spectrum = np.sum(bad_pixel_mask, axis=1)
    n_sparse_per_spectrum = np.sum(sparse_mask, axis=1)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Color points by sparse norm
    scatter = ax.scatter(n_bad_per_spectrum, n_sparse_per_spectrum,
                         c=sparse_norms, cmap='viridis', alpha=0.6, s=50)

    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('||S||₂ (Sparse Norm)', fontsize=12)

    # Diagonal reference line
    max_val = max(np.max(n_bad_per_spectrum), np.max(n_sparse_per_spectrum))
    if max_val > 0:
        ax.plot([0, max_val], [0, max_val], 'k--', alpha=0.3, label='1:1 line')

    ax.set_xlabel('Number of Bad Pixels (NaN/Inf)', fontsize=12)
    ax.set_ylabel('Number of Sparse Pixels (|S| > 1e-6)', fontsize=12)
    ax.set_title('Bad Pixels vs Sparse Pixels for All Spectra', fontsize=14)
    ax.grid(True, alpha=0.3)

    # Statistics box (correlation is NaN when either count is constant)
    correlation = np.corrcoef(n_bad_per_spectrum, n_sparse_per_spectrum)[0, 1]
    textstr = f'Total spectra: {len(n_bad_per_spectrum)}\n'
    textstr += f'Spectra with bad pixels: {np.sum(n_bad_per_spectrum > 0)}\n'
    textstr += f'Correlation: {correlation:.3f}'
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.tight_layout()

    scatter_filename = 'bad_vs_sparse_pixels_scatter.png'
    fig.savefig(scatter_filename, dpi=150, bbox_inches='tight')
    print(f"Scatter plot saved to '{scatter_filename}'")

    plt.show()

    # Outlier plot uses the top 5 by sparse norm (overriding the 90th
    # percentile selection; 'threshold'/'all_norms' still describe that cut).
    print("\nPlotting outlier analysis (top 5)...")
    outlier_info['indices'] = sorted_indices[:5]
    outlier_info['source_ids'] = [source_ids[i] for i in sorted_indices[:5]]
    outlier_info['sparse_norms'] = sparse_norms[sorted_indices[:5]]

    plot_outlier_analysis(Y, rpca.S, outlier_info, wavelength_grid,
                          source_ids, sources, n_examples=5)

    # Step 7: Feature statistics
    print("\n" + "="*50)
    print("STEP 7: Feature Statistics")
    print("="*50)

    all_emission = []
    all_absorption = []

    for idx in outlier_info['indices']:
        features = identify_sparse_features(rpca.S[idx, :], wavelength_grid)
        all_emission.extend([w for w, _ in features['emission']])
        all_absorption.extend([w for w, _ in features['absorption']])

    # Count occurrences per 0.1 nm bin. Counting must be done on the rounded
    # values: counting unrounded wavelengths against rounded keys would almost
    # never match.
    for kind, waves in (('Emission', all_emission), ('Absorption', all_absorption)):
        if not waves:
            continue
        print(f"\n{kind} features found at wavelengths:")
        rounded, counts = np.unique(np.round(waves, 1), return_counts=True)
        # Most frequent first; stable so ties stay in ascending wavelength order
        order = np.argsort(-counts, kind='stable')[:10]
        for j in order:
            print(f"  {rounded[j]:.1f} nm: {counts[j]} occurrences")

    print("\n" + "="*70)
    print("Analysis complete!")
    print("="*70)

    return {
        'sources': sources,
        'spectra_data': spectra_data,
        'Y': Y,
        'wavelength_grid': wavelength_grid,
        'source_ids': source_ids,
        'rpca': rpca,
        'components': components,
        'outlier_info': outlier_info,
        'bad_pixel_mask': bad_pixel_mask
    }



In [None]:
## 7. Additional Analysis Functions (Optional)

def analyze_spectral_type_differences(results):
    """
    Plot how the RPCA sparse-component norm varies with stellar parameters.

    For every spectrum in the decomposition, the GSP-Phot Teff and log g of
    the matching row in ``results['sources']`` are looked up. The row norm
    ||S_i|| of the sparse component is then plotted against Teff (coloured by
    log g) and against log g (coloured by Teff).

    Parameters:
    -----------
    results : dict
        Output of ``main_analysis``. Uses ``'sources'`` (table of Gaia
        sources), ``'rpca'`` (fitted model exposing the sparse matrix ``S``)
        and ``'source_ids'`` (source ID for each row of ``S``).

    Returns:
    --------
    None
        The figure is displayed with ``plt.show()``.
    """

    sources = results['sources']
    rpca = results['rpca']
    source_ids = results['source_ids']

    # Check column names (uppercase from Gaia)
    teff_col = 'TEFF_GSPPHOT' if 'TEFF_GSPPHOT' in sources.colnames else 'teff_gspphot'
    logg_col = 'LOGG_GSPPHOT' if 'LOGG_GSPPHOT' in sources.colnames else 'logg_gspphot'
    source_id_col = 'SOURCE_ID' if 'SOURCE_ID' in sources.colnames else 'source_id'

    # Build the id -> row map once (first occurrence wins, as with the old
    # ``sources[mask][0]``) instead of scanning the whole table per spectrum.
    row_of = {}
    for row, sid in enumerate(np.asarray(sources[source_id_col])):
        row_of.setdefault(sid, row)
    rows = np.array([row_of.get(sid, -1) for sid in source_ids], dtype=int)
    found = rows >= 0

    def _parameter_per_spectrum(col):
        # Masked (null) archive values are filled with NaN explicitly so they
        # are dropped from the scatter plots rather than drawn at a fill value.
        values = np.ma.filled(np.ma.asarray(sources[col], dtype=float), np.nan)
        per_spectrum = np.full(rows.size, np.nan)
        per_spectrum[found] = values[rows[found]]
        return per_spectrum

    teff_array = _parameter_per_spectrum(teff_col)
    logg_array = _parameter_per_spectrum(logg_col)

    # Euclidean norm of each row of the sparse component (one per spectrum)
    sparse_norms = np.linalg.norm(rpca.S, axis=1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Sparse norm vs Teff
    scatter1 = ax1.scatter(teff_array, sparse_norms, c=logg_array,
                           cmap='viridis', alpha=0.6)
    ax1.set_xlabel('Teff (K)')
    ax1.set_ylabel('||S|| (Sparse Norm)')
    ax1.set_title('Sparse Component vs Temperature')
    plt.colorbar(scatter1, ax=ax1, label='log g')

    # Sparse norm vs log g
    scatter2 = ax2.scatter(logg_array, sparse_norms, c=teff_array,
                           cmap='plasma', alpha=0.6)
    ax2.set_xlabel('log g')
    ax2.set_ylabel('||S|| (Sparse Norm)')
    ax2.set_title('Sparse Component vs Surface Gravity')
    plt.colorbar(scatter2, ax=ax2, label='Teff (K)')

    plt.tight_layout()
    plt.show()


def export_results_for_further_analysis(results, output_prefix='rpca_gaia_rvs',
                                        n_components=10):
    """
    Export RPCA outliers and principal components to FITS tables.

    Two files are written (existing files are overwritten):

    - ``{output_prefix}_outliers.fits``: ``source_id`` and ``sparse_norm`` of
      each outlier in ``results['outlier_info']``, plus ``teff_gspphot``,
      ``logg_gspphot`` and ``mh_gspphot`` copied from the matching row of
      ``results['sources']``. These are NaN when the source is not found.
    - ``{output_prefix}_principal_components.fits``: the wavelength grid and
      the first rows of ``results['components']['V']`` as columns
      ``PC1``, ``PC2``, ...

    Parameters:
    -----------
    results : dict
        Output of ``main_analysis``.
    output_prefix : str
        Filename prefix for both output files.
    n_components : int
        Maximum number of principal components (rows of ``V``) to export.
        The default of 10 is the previously hard-coded value. It is capped at
        the number of rows actually available.

    Returns:
    --------
    None
    """

    # Export outlier spectra
    outlier_table = Table()
    outlier_table['source_id'] = results['outlier_info']['source_ids']
    outlier_table['sparse_norm'] = results['outlier_info']['sparse_norms']

    # Add stellar parameters
    sources = results['sources']

    # Gaia archive column names may be upper- or lower-case
    source_id_col = 'SOURCE_ID' if 'SOURCE_ID' in sources.colnames else 'source_id'
    teff_col = 'TEFF_GSPPHOT' if 'TEFF_GSPPHOT' in sources.colnames else 'teff_gspphot'
    logg_col = 'LOGG_GSPPHOT' if 'LOGG_GSPPHOT' in sources.colnames else 'logg_gspphot'
    mh_col = 'MH_GSPPHOT' if 'MH_GSPPHOT' in sources.colnames else 'mh_gspphot'

    # id -> first matching row, built once rather than one full-table scan
    # per outlier.
    row_of = {}
    for row, sid in enumerate(np.asarray(sources[source_id_col])):
        row_of.setdefault(sid, row)

    # Output column name -> input column name (insertion order = column order)
    parameter_columns = {
        'teff_gspphot': teff_col,
        'logg_gspphot': logg_col,
        'mh_gspphot': mh_col,
    }
    for out_name, in_col in parameter_columns.items():
        column = sources[in_col]
        outlier_table[out_name] = [
            column[row_of[sid]] if sid in row_of else np.nan
            for sid in outlier_table['source_id']
        ]

    # Save outlier table
    outlier_table.write(f'{output_prefix}_outliers.fits', format='fits', overwrite=True)
    print(f"Outlier table saved to '{output_prefix}_outliers.fits'")

    # Save principal components (one row of V per component)
    V = results['components']['V']
    pc_table = Table()
    pc_table['wavelength'] = results['wavelength_grid']

    for i in range(min(n_components, V.shape[0])):
        pc_table[f'PC{i+1}'] = V[i, :]

    pc_table.write(f'{output_prefix}_principal_components.fits', format='fits', overwrite=True)
    print(f"Principal components saved to '{output_prefix}_principal_components.fits'")

    print("\nAll results exported successfully!")
    print("\nAll results exported successfully!")


In [None]:
## 8. Parameter Tuning Functions

def tune_lambda_parameter(Y, lambda_values=None, max_iter=400):
    """
    Scan the Robust PCA sparsity weight ``lambda`` and plot diagnostics.

    For each candidate lambda, a fresh ``RobustPCA`` is fitted to ``Y``. The
    following are recorded:

    - the fitted rank and sparsity;
    - the reconstruction error ||Y - L - S||_F;
    - the Frobenius norms of S and L.

    All of them are then plotted against lambda on a 2x2 grid.

    Parameters:
    -----------
    Y : np.ndarray
        2-D data matrix of shape (m, n) passed to ``RobustPCA.fit``.
    lambda_values : sequence of float, optional
        Values to try. Defaults to ``[0.5, 0.75, 1.0, 1.25, 1.5, 2.0]`` times
        ``2 / sqrt(max(m, n))``.
    max_iter : int
        Iteration cap forwarded to every ``RobustPCA`` fit. The default of
        400 is the previously hard-coded value.

    Returns:
    --------
    list of dict
        One record per lambda, with keys 'lambda', 'rank', 'sparsity',
        'reconstruction_error', 'sparse_norm' and 'low_rank_norm'.
    """

    if lambda_values is None:
        m, n = Y.shape
        # Twice the commonly used RPCA weight 1/sqrt(max(m, n))
        base_lambda = 2.0 / np.sqrt(max(m, n))
        lambda_values = base_lambda * np.array([0.5, 0.75, 1.0, 1.25, 1.5, 2.0])

    results = []

    for lam in lambda_values:
        print(f"\nTesting lambda = {lam:.4f}")
        rpca = RobustPCA(lambda_param=lam, max_iter=max_iter)
        rpca.fit(Y)

        results.append({
            'lambda': lam,
            'rank': rpca.rank,
            'sparsity': rpca.sparsity,
            'reconstruction_error': np.linalg.norm(Y - rpca.L - rpca.S, 'fro'),
            'sparse_norm': np.linalg.norm(rpca.S, 'fro'),
            'low_rank_norm': np.linalg.norm(rpca.L, 'fro'),
        })

    # Plot results
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    lambdas = [r['lambda'] for r in results]

    # (axis, [(record key, line format, legend label)], y label, title)
    panels = [
        (axes[0, 0], [('rank', 'bo-', None)], 'Rank', 'Rank vs Lambda'),
        (axes[0, 1], [('sparsity', 'ro-', None)], 'Sparsity', 'Sparsity vs Lambda'),
        (axes[1, 0], [('reconstruction_error', 'go-', None)],
         'Reconstruction Error', 'Reconstruction Error vs Lambda'),
        (axes[1, 1], [('sparse_norm', 'mo-', '||S||_F'),
                      ('low_rank_norm', 'co-', '||L||_F')],
         'Frobenius Norm', 'Component Norms vs Lambda'),
    ]

    for ax, series, ylabel, title in panels:
        for key, fmt, label in series:
            if label is None:
                ax.plot(lambdas, [r[key] for r in results], fmt)
            else:
                ax.plot(lambdas, [r[key] for r in results], fmt, label=label)
        ax.set_xlabel('Lambda')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        # Only the multi-series panel needs a legend
        if len(series) > 1:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return results

In [None]:
results = main_analysis()

# ``results`` bundles every product of the pipeline; list what is available
# so that follow-up cells (e.g. the optional analysis/export/tuning helpers)
# know which keys they can use.
print("\nResults dictionary contains:")
for key in results:
    print(f"  - {key}")