# Robust Heteroskedastic Matrix Factorization
A robust-PCA-like model that knows about observational uncertainties

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- (with help from Claude)

## Dependencies:
- `pip3 install jax matplotlib astropy astroquery`

## Issues:
- Assumes (and gets) rectangular data with known uncertainties.
- `train()` function is written but `test()` function is not.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import clod_gaia_rvs as clod
import rhmf
# Notebook-wide matplotlib defaults: 8 x 4.5 figures and 12-point text.
plt.rcParams.update({
    'figure.figsize': (8, 4.5),
    'font.size': 12,
})

In [None]:
# Selection settings for the source query; every key is forwarded as a
# keyword argument to clod.find_rvs_sources_gspphot.
params = dict(
    teff_min=9000,
    teff_max=20000,
    logg_min=1.0,
    logg_max=5.0,
    grvs_mag_max=9.0,
    n_sources=1000,
)
sources = clod.find_rvs_sources_gspphot(**params)

In [None]:
# Cap the download at whichever is smaller: the requested number of sources
# or the number of sources the query returned.
n_sources = min(len(sources), params['n_sources'])
spectra_data = clod.download_multiple_spectra(
    sources,
    max_spectra=n_sources,
)

In [None]:
Y, wavelength_grid, source_ids, W = clod.create_spectral_matrices(spectra_data)

# Sanity report on the flux matrix Y and the weights W. W is treated as an
# inverse variance here (uncertainty is reported as 1 / sqrt(W)).
# NOTE(review): the last line counts entries of W below 1.0, although its
# label says "zeros" -- confirm the intended threshold.
print("\n".join([
    f"\nSpectral matrix statistics:",
    f"  shape: {Y.shape}",
    f"  min flux: {np.min(Y):.4f}",
    f"  max flux: {np.max(Y):.4f}",
    f"  mean flux: {np.mean(Y):.4f}",
    f"  std flux: {np.std(Y):.4f}",
    f"  median uncertainty: {1. / np.sqrt(np.median(W)):.4f}",
    f"  flux contains NaN: {np.any(np.isnan(Y))}",
    f"  flux contains Inf: {np.any(np.isinf(Y))}",
    f"  invvar zeros: {np.sum(W < 1.e0)}",
]))

In [None]:
# Split the sources into two disjoint subsets, A and B, using a seeded random
# draw per source: A takes the draws below the median, B takes the rest.
rng = np.random.default_rng(17)
rr = rng.uniform(size=len(source_ids))
A = rr < np.median(rr)
B = ~A  # A is a boolean mask, so ~A is its element-wise complement
YA, WA, source_ids_A = (arr[A] for arr in (Y, W, source_ids))
YB, WB, source_ids_B = (arr[B] for arr in (Y, W, source_ids))
print(YA.shape, YB.shape)

In [None]:
# Settings passed positionally to rhmf.RHMF.
# NOTE(review): presumably k is the number of components and nsigma an
# outlier threshold -- confirm against the rhmf module.
k = 30
nsigma = 3.0

# Fit the first model on subset A only.
modelA = rhmf.RHMF(k, nsigma)
modelA.train(YA, WA)

In [None]:
def plot_components(model, title, savefig=None, offset=0.15):
    """Plot each row of ``model.G`` against the notebook's wavelength grid.

    Successive components are shifted upward by ``offset`` times their index
    so the curves are stacked rather than drawn on top of one another.

    Parameters
    ----------
    model
        Object exposing an iterable ``G`` attribute; each item is plotted
        against the module-level ``wavelength_grid``.
    title
        Title placed on the figure.
    savefig
        Optional output path; when not ``None`` the figure is written there
        with ``plt.savefig``.
    offset
        Vertical spacing between consecutive components (default 0.15).
    """
    # Use the model that was passed in. The index is called `j` so it does
    # not reuse the name of the notebook-level `k`.
    for j, g in enumerate(model.G):
        plt.plot(wavelength_grid, g + offset * j)
    plt.xlabel("wavelength")
    plt.ylabel("spectral component (plus offset)")
    plt.title(title)
    if savefig is not None:
        plt.savefig(savefig)


plot_components(modelA, "model A", savefig="modelA.png")

In [None]:
# Build a second model with the same k and nsigma, handing it a copy of
# modelA.G through the G= argument, then train it on subset B.
# NOTE(review): presumably G= sets the starting components -- confirm
# against rhmf.RHMF. The copy keeps modelA.G from being shared with modelB.
modelB = rhmf.RHMF(k, nsigma, G=modelA.G.copy())
modelB.train(YB, WB)

In [None]:
# Call plot_components for modelB and save the figure to modelB.png.
plot_components(
    model=modelB,
    title="model B",
    savefig="modelB.png",
)

In [None]:
# Run modelA.test on every spectrum of subset B, one row at a time, and
# store what it returns. The output starts NaN-filled so any row that is
# never written remains NaN.
synthB = np.zeros_like(YB) + np.nan
for row, (flux, weight) in enumerate(zip(YB, WB)):
    synthB[row] = modelA.test(flux, weight)

In [None]:
# Directory that receives one diagnostic PNG per subset-B source.
cache = "./rvs_plot_cache"
os.makedirs(cache, exist_ok=True)

# YB and synthB hold only the subset-B rows, so the matching identifiers are
# in source_ids_B. Indexing the unsplit source_ids array by the YB row number
# would title and name each figure after the wrong source.
for sid, observed, synthetic in zip(source_ids_B, YB, synthB):
    f = plt.figure()
    plt.plot(wavelength_grid, synthetic, "r-", lw=1, alpha=0.5)  # model
    plt.plot(wavelength_grid, observed, "k-")                    # data
    plt.plot(wavelength_grid, observed - synthetic, "k-")        # residual
    plt.title(sid)
    plt.savefig(f"{cache}/{sid}.png")
    plt.close(f)  # release the figure so open figures do not pile up