<a href="https://colab.research.google.com/github/davidwhogg/RVSanomalies/blob/main/notebooks/RobustHMF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Robust Heteroskedastic Matrix Factorization
A robust-PCA-like matrix factorization that takes per-pixel observational uncertainties into account.

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)

## Dependencies:
- `pip3 install jax matplotlib astropy astroquery`

## Issues:
- Not yet written.
- Assumes rectangular data with known uncertainties.

In [None]:
# jax related
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

In [None]:
!pip install astropy
!pip install astroquery

In [None]:
# data-gathering and plotting related
import numpy as np
import matplotlib.pyplot as plt
from astroquery.gaia import Gaia
from astropy.table import Table
import os
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 12

In [None]:
def find_rvs_sources_gspphot(teff_min=4810, teff_max=6200,
                            logg_min=1.0, logg_max=3.0,
                            grvs_mag_max=11.0, n_sources=500):
    """
    Query the Gaia DR3 archive for sources that have RVS spectra.

    Selects at most `n_sources` rows from gaiadr3.gaia_source with
    has_rvs set, grvs_mag <= `grvs_mag_max`, teff_gspphot in
    [`teff_min`, `teff_max`], logg_gspphot in [`logg_min`, `logg_max`]
    and a non-null radial_velocity. Rows are ordered by grvs_mag
    ascending, so the brightest sources in G_RVS come first.

    Returns the result table produced by `job.get_results()`. A summary
    line with the row count is printed.
    """
    # NOTE: the parameters are interpolated straight into the ADQL text;
    # they are expected to be plain numbers supplied by this notebook.
    adql = f"""
    SELECT TOP {n_sources}
        source_id, ra, dec, phot_g_mean_mag, grvs_mag,
        radial_velocity, radial_velocity_error,
        teff_gspphot, logg_gspphot, mh_gspphot,
        bp_rp, parallax
    FROM gaiadr3.gaia_source
    WHERE has_rvs = 't'
    AND grvs_mag <= {grvs_mag_max}
    AND teff_gspphot BETWEEN {teff_min} AND {teff_max}
    AND logg_gspphot BETWEEN {logg_min} AND {logg_max}
    AND teff_gspphot IS NOT NULL
    AND logg_gspphot IS NOT NULL
    AND radial_velocity IS NOT NULL
    ORDER BY grvs_mag ASC
    """
    result_table = Gaia.launch_job_async(adql).get_results()
    print(f"\nFound {len(result_table)} sources matching criteria")
    return result_table

In [None]:
# Query settings for this run: teff_gspphot 8000-10000, logg_gspphot 1-5,
# grvs_mag <= 11, keeping at most 200 sources.
params = dict(
    teff_min=8000,
    teff_max=10000,
    logg_min=1.0,
    logg_max=5.0,
    grvs_mag_max=11.0,
    n_sources=200,
)
sources = find_rvs_sources_gspphot(**params)

In [None]:
def download_rvs_spectrum(source_id, output_dir="rvs_spectra_cache"):
    """
    Download the Gaia DR3 RVS spectrum for a single source, with disk caching.

    Parameters
    ----------
    source_id : int or str
        Gaia DR3 source identifier.
    output_dir : str
        Cache directory. Spectra are stored as ``rvs_<source_id>.npz`` and
        the directory is created on the first successful download.

    Returns
    -------
    (wavelength, flux, flux_error) : tuple of np.ndarray
        Columns of the RVS product (wavelength in nm, normalized flux).
        On any failure the 3-tuple ``(None, None, None)`` is returned, so
        callers can always unpack exactly three values.
    """
    cache_file = os.path.join(output_dir, f"rvs_{source_id}.npz")
    if os.path.exists(cache_file):
        # Close the NpzFile's file handle; indexing a key reads the whole
        # array into memory, so the returned arrays stay valid after closing.
        with np.load(cache_file) as data:
            return data['wavelength'], data['flux'], data['flux_error']
    try:
        # Download spectrum
        retrieval_type = 'RVS'
        data_structure = 'INDIVIDUAL'
        data_release = 'Gaia DR3'

        datalink_products = Gaia.load_data(
            ids=[str(source_id)],
            data_release=data_release,
            retrieval_type=retrieval_type,
            data_structure=data_structure,
            verbose=False
        )

        if not datalink_products:
            return None, None, None

        product_key = f"RVS-Gaia DR3 {source_id}.xml"

        if product_key not in datalink_products:
            return None, None, None

        # Extract spectrum
        votable = datalink_products[product_key][0]
        spectrum_table = votable.to_table()

        wavelength = np.array(spectrum_table['wavelength'])  # in nm
        flux = np.array(spectrum_table['flux'])  # normalized
        flux_error = np.array(spectrum_table['flux_error'])

        # Save to cache
        os.makedirs(output_dir, exist_ok=True)
        np.savez(cache_file, wavelength=wavelength,
                 flux=flux, flux_error=flux_error)

        return wavelength, flux, flux_error

    except Exception as e:
        # Best-effort: report and skip this source rather than abort a batch.
        print(f"Error downloading spectrum for source {source_id}: {e}")
        return None, None, None

def download_multiple_spectra(sources, max_spectra=None):
    """
    Fetch RVS spectra for the first `max_spectra` rows of `sources`.

    Parameters
    ----------
    sources : table
        Source table exposing `.colnames`, with a 'SOURCE_ID' or
        'source_id' column.
    max_spectra : int, optional
        Number of leading rows to process (default: all rows).

    Returns
    -------
    dict
        source_id -> (wavelength, flux, flux_error), for successful
        downloads only. Progress is printed every 10 rows.
    """
    limit = len(sources) if max_spectra is None else max_spectra

    spectra = {}
    n_ok = 0

    print(f"\nDownloading RVS spectra for up to {limit} sources...")

    # Gaia results may use an upper-case column name.
    id_column = 'SOURCE_ID' if 'SOURCE_ID' in sources.colnames else 'source_id'

    for count, row in enumerate(sources[:limit]):
        sid = row[id_column]

        if count % 10 == 0:
            print(f"Progress: {count}/{limit} spectra processed...")

        wl, fl, fe = download_rvs_spectrum(sid)
        if wl is None or fl is None:
            continue

        spectra[sid] = (wl, fl, fe)
        n_ok += 1

    print(f"\nSuccessfully downloaded {n_ok} spectra")

    return spectra

def create_spectral_matrices(spectra_data, wavelength_grid=None, fill_value=1.0):
    """
    Create matrices Y, W where each row is a spectrum or its invvar weight.

    Parameters:
    -----------
    spectra_data : dict
        Mapping source_id -> (wavelength, flux, flux_error) arrays.
    wavelength_grid : array-like, optional
        Common wavelength grid. If None, uses the first spectrum's grid.
        No resampling is done: every spectrum must already be sampled on
        this grid (same length and ``np.allclose`` values).
    fill_value : float
        Flux written into bad pixels (default: 1.0 for continuum). Bad
        pixels also get zero weight, so this only matters to consumers that
        ignore W (e.g. an unweighted SVD initialization).

    A pixel is "bad" when its flux is NaN/Inf, or its flux_error is NaN/Inf
    or not strictly positive. Those errors would otherwise produce NaN or
    infinite weights.

    Returns:
    --------
    Y : np.ndarray
        Matrix of shape (n_spectra, n_wavelengths)
    wavelength_grid : np.ndarray
        Wavelength grid used
    source_ids : list
        List of source IDs in same order as rows of Y
    W : np.ndarray
        invvars for Y (same shape), with bad pixels zeroed out

    Raises:
    -------
    ValueError
        If `spectra_data` is empty and no grid is given, or if a spectrum is
        not sampled on the common grid.
    """
    source_ids = list(spectra_data.keys())
    n_spectra = len(source_ids)

    # Use first spectrum to define wavelength grid if not provided
    if wavelength_grid is None:
        if n_spectra == 0:
            raise ValueError("spectra_data is empty and no wavelength_grid was given")
        wavelength_grid = spectra_data[source_ids[0]][0]
    wavelength_grid = np.asarray(wavelength_grid)
    n_wavelengths = len(wavelength_grid)

    # Every row is overwritten below; NaN makes any unfilled row obvious.
    Y = np.full((n_spectra, n_wavelengths), np.nan)
    W = np.zeros((n_spectra, n_wavelengths))

    # Track statistics
    total_bad_pixels = 0
    spectra_with_bad_pixels = 0

    for i, source_id in enumerate(source_ids):
        wavelength, flux, flux_error = (
            np.asarray(a, dtype=float) for a in spectra_data[source_id])

        if (wavelength.shape != wavelength_grid.shape
                or not np.allclose(wavelength, wavelength_grid)):
            raise ValueError(
                f"spectrum {source_id} is not sampled on the common wavelength grid")

        # Zero, negative or non-finite errors are flagged below, so silence
        # the divide/overflow/invalid warnings they trigger here.
        with np.errstate(all='ignore'):
            invvar = 1. / flux_error ** 2
            bad_pixels = ~(np.isfinite(flux) & np.isfinite(invvar) & (flux_error > 0))

        n_bad = int(np.sum(bad_pixels))
        if n_bad:
            spectra_with_bad_pixels += 1
            total_bad_pixels += n_bad

            # Replace bad pixels with fill_value and give them zero weight
            flux = np.where(bad_pixels, fill_value, flux)
            invvar = np.where(bad_pixels, 0., invvar)

        Y[i, :] = flux
        W[i, :] = invvar

    print(f"\nBad pixel statistics:")
    print(f"  Spectra with bad pixels: {spectra_with_bad_pixels}/{n_spectra}")
    print(f"  Total bad pixels: {total_bad_pixels}")
    print(f"  Bad pixels replaced with: {fill_value}")

    return Y, wavelength_grid, source_ids, W


In [None]:
# Fetch (or load from cache) one RVS spectrum per queried source.
spectra_data = download_multiple_spectra(sources=sources, max_spectra=params['n_sources'])

In [None]:
Y, wavelength_grid, source_ids, W = create_spectral_matrices(spectra_data)
print(
    f"Spectral matrix shape: {Y.shape}\n"
    f"Wavelength range: {wavelength_grid[0]:.2f} - {wavelength_grid[-1]:.2f} nm"
)

# Summary statistics of Y and W, plus a check that no NaN/Inf survived
# into the data matrix.
print("\n".join([
    "\nSpectral matrix statistics:",
    f"  Min flux: {np.min(Y):.4f}",
    f"  Max flux: {np.max(Y):.4f}",
    f"  Mean flux: {np.mean(Y):.4f}",
    f"  Std flux: {np.std(Y):.4f}",
    f"  Mean invvar: {np.nanmean(W):.4f}",
    f"  Contains NaN: {np.any(np.isnan(Y))}",
    f"  Contains Inf: {np.any(np.isinf(Y))}",
]))

In [None]:
class RHMF():
    """
    Robust heteroskedastic matrix factorization.

    Models an (N, M) data matrix as `Y ~= A.T @ G`, with A of shape (K, N)
    and G of shape (K, M), by alternating weighted least-squares solves for
    A and G. Robustness comes from iterative reweighting: after each sweep
    every element's weight becomes
    `input_W * nsigma**2 / (input_W * resid**2 + nsigma**2)`. This
    down-weights elements whose residual is large compared with `nsigma`
    times their input uncertainty.
    """

    def __init__(self, rank, nsigma, tol=1.e-2, max_iter=None):
        """
        # inputs:
        `rank`:     number of components K.
        `nsigma`:   soft-clipping scale, in units of the input uncertainties.
        `tol`:      the fit is declared converged once a G-step lowers the
                    (reweighted) objective by less than this absolute amount.
        `max_iter`: optional cap on the number of iterations (None = no cap).
        """
        self.K = int(rank)
        self.nsigma = float(nsigma)
        self.Q2 = self.nsigma ** 2
        self.tol = float(tol)
        self.max_iter = None if max_iter is None else int(max_iter)

    def fit(self, data, weights):
        """
        Fit the model, printing objective values as it goes.

        # inputs:
        `data`:     (N, M) array of observations.
        `weights`:  (N, M) array; units of (and equivalent to) inverse
                    uncertainty variances. A zero weight removes that
                    element from every solve.
        # raises:
        `ValueError` if `data` or `weights` contain NaN/Inf or their shapes
        differ; `FloatingPointError` if the objective becomes non-finite.
        # outputs:
        Sets `self.A`, `self.G`, `self.W`, `self.converged`, `self.n_iter`.
        """
        self.Y = jnp.array(data)
        self.input_W = jnp.array(weights)
        # Explicit checks (not `assert`, which is stripped under `python -O`).
        if not bool(jnp.all(jnp.isfinite(self.Y))):
            raise ValueError("fit(): data contains NaN or Inf values")
        if not bool(jnp.all(jnp.isfinite(self.input_W))):
            raise ValueError("fit(): weights contain NaN or Inf values")
        if self.Y.shape != self.input_W.shape:
            raise ValueError(
                f"fit(): data shape {self.Y.shape} != weights shape {self.input_W.shape}")
        self.N, self.M = self.Y.shape
        self.converged = False
        self.n_iter = 0
        self._initialize()
        print("fit(): before starting:", self.objective(), self.original_objective())
        while not self.converged:
            if self.max_iter is not None and self.n_iter >= self.max_iter:
                print("fit(): WARNING: reached max_iter without converging:", self.n_iter)
                break
            print("fit(): before A-step:", self.objective(), self.original_objective())
            self._A_step()
            print("fit(): before G-step:", self.objective(), self.original_objective())
            self._G_step()
            print("fit(): before affine step:", self.objective(), self.original_objective())
            self._affine()
            print("fit(): before weight update step:", self.objective(), self.original_objective())
            self._update_W()
            print("fit(): after weight update step:", self.objective(), self.original_objective())
            self.n_iter += 1

    def _initialize(self):
        """
        Start from W = input_W and a rank-K truncated SVD of the
        (unweighted) data.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        """
        self.W = 1. * self.input_W # copy not reference
        # Must use the data passed to fit(), not any global of the same name.
        u, s, v = jnp.linalg.svd(self.Y, full_matrices=False) # maybe slow
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]
        print("_initialize():", self.A.shape, self.G.shape)

    def _one_star_A_step(self, i):
        """
        Weighted least-squares solve for the K coefficients of row i, with G
        held fixed. If row i has too few nonzero weights the K x K system is
        singular, and the solve may return non-finite values.
        """
        # `*` and `@` share precedence (left-assoc): (G * W[i]) @ G.T.
        XTCinvX = self.G * self.W[i] @ self.G.T
        XTCinvY = self.G * self.W[i] @ self.Y[i]
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _one_star_G_step(self, j):
        """
        Weighted least-squares solve for column j of G (K values), with A
        held fixed. This has the same singularity caveat as the A-step,
        here for columns.
        """
        XTCinvX = self.A * self.W[:,j] @ self.A.T
        XTCinvY = self.A * self.W[:,j] @ self.Y[:,j]
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _A_step(self):
        # vmap over rows gives (N, K); transpose to the (K, N) convention.
        self.A = jax.vmap(self._one_star_A_step)(jnp.arange(self.N)).T

    def _G_step(self):
        """
        Update G, then test convergence on the resulting objective decrease.

        Raises `FloatingPointError` on a non-finite objective. Otherwise
        NaN comparisons are always False and the fit loop would never end.
        """
        before = self.objective()
        self.G = jax.vmap(self._one_star_G_step)(jnp.arange(self.M)).T
        after = self.objective()
        if not jnp.isfinite(after):
            raise FloatingPointError(
                "_G_step(): objective is not finite; a row or column of the "
                "weights may have too few nonzero entries")
        if before < after:
            print("_G_step(): ERROR: objective got worse", before, after)
        if before - after < self.tol:
            self.converged = True

    def _affine(self):
        """
        Re-express the same rank-K synthesis A.T @ G so that the rows of G
        are orthonormal and ordered by singular value. The synthesis is
        unchanged up to rounding.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        - This SVDs the full (N, M) synthesis; a QR of G would be cheaper.
        """
        u, s, v = jnp.linalg.svd(self.A.T @ self.G, full_matrices=False)
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]

    def synthesis(self):
        # (N, M) model prediction.
        return self.A.T @ self.G

    def resid(self):
        return self.Y - self.synthesis()

    def objective(self):
        # Chi-squared under the current (robustly reweighted) weights.
        return jnp.sum(self.W * self.resid() ** 2)

    def original_chi(self):
        # Residuals in units of the input uncertainties.
        return self.resid() * jnp.sqrt(self.input_W)

    def original_objective(self):
        # Chi-squared under the original input weights.
        return jnp.sum(self.input_W * self.resid() ** 2)

    def _update_W(self):
        # Equivalent to input_W / (1 + chi**2 / nsigma**2): soft down-weighting
        # of outliers; elements with zero input weight stay at zero.
        self.W = self.input_W * self.Q2 / (self.input_W * self.resid() ** 2 + self.Q2)

In [None]:
# Rank-10 robust factorization with soft clipping at 2.5 sigma.
model = RHMF(rank=10, nsigma=2.5)
model.fit(data=Y, weights=W)

In [None]:
# Plot each fitted component (row of G), offset vertically by 0.15 per row.
for k in range(model.G.shape[0]):
    g = model.G[k]
    plt.plot(wavelength_grid, g + k * 0.15)

In [None]:
# Per-spectrum goodness of fit under the original input weights:
# original_chi() is resid * sqrt(input_W), one row per spectrum, so summing
# chi**2 over axis 1 gives each spectrum's chi-squared.
chi = model.original_chi()
chi_squared = np.sum(chi ** 2, axis=1)
# Row indices ordered from smallest to largest chi-squared (best fit first).
indx = np.argsort(chi_squared)
print(chi.shape, indx.shape)

In [None]:
# Plot residuals of the worst-fitting spectra (largest chi-squared first),
# offset vertically by 0.03 each. `indx` is sorted ascending, so the worst
# fits sit at the end: indx[-1], indx[-2], ... (indx[-0] would be indx[0],
# the BEST fit).
resid = model.resid()
for i in range(min(8, len(indx))):
    plt.plot(wavelength_grid, resid[indx[-(i + 1)]] + 0.03 * i)