# Robust Heteroskedastic Matrix Factorization
A robust-PCA-like matrix-factorization model that accounts for known observational uncertainties.

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- (with help from Claude)

## Dependencies:
- `pip3 install jax matplotlib astropy astroquery`

## Issues:
- Assumes (and gets) rectangular data with known uncertainties.
- `train()` and `test()` are both written; `test()` handles one held-out observation at a time.

In [None]:
# jax related
import jax.numpy as jnp
import jax
# Switch JAX to 64-bit floats (its default is 32-bit); the factorization's
# least-squares solves and SVDs below run in jax.numpy.
jax.config.update("jax_enable_x64", True)

In [None]:
# data-gathering and plotting related
import os
import numpy as np
import matplotlib.pyplot as plt
import clod
# Default figure size (inches) and font size for every plot in this notebook.
plt.rc('figure', figsize=(8, 4.5))
plt.rc('font', size=12)

In [None]:
# Selection cuts forwarded verbatim to clod.find_rvs_sources_gspphot.
# Presumably Teff (K), log g, and G_RVS magnitude limits plus a source cap;
# NOTE(review): confirm units/semantics against the clod module.
params = dict(
    teff_min=8000,
    teff_max=20000,
    logg_min=1.0,
    logg_max=5.0,
    grvs_mag_max=9.0,
    n_sources=500,
)
sources = clod.find_rvs_sources_gspphot(**params)

In [None]:
# Never request more spectra than the source query actually returned.
n_sources = min(len(sources), params['n_sources'])
spectra_data = clod.download_multiple_spectra(sources, max_spectra=n_sources)

In [None]:
# Build the rectangular flux matrix Y and the matching weight matrix W
# (inverse uncertainty variances, as RHMF.train() expects), then print
# quick sanity statistics.
Y, wavelength_grid, source_ids, W = clod.create_spectral_matrices(spectra_data)
print("\nSpectral matrix statistics:")
print(f"  shape: {Y.shape}")
print(f"  min flux: {np.min(Y):.4f}")
print(f"  max flux: {np.max(Y):.4f}")
print(f"  mean flux: {np.mean(Y):.4f}")
print(f"  std flux: {np.std(Y):.4f}")
# sigma = 1 / sqrt(invvar), evaluated at the median inverse variance.
print(f"  median uncertainty: {1. / np.sqrt(np.median(W)):.4f}")
print(f"  flux contains NaN: {np.any(np.isnan(Y))}")
print(f"  flux contains Inf: {np.any(np.isinf(Y))}")
# W == 0 counts exactly-zero inverse variances; W < 1 additionally catches
# pixels whose uncertainty exceeds 1 (sigma > 1), so report them separately.
print(f"  invvar zeros: {np.sum(W == 0)}")
print(f"  invvar < 1: {np.sum(W < 1.e0)}")

In [None]:
class RHMF():
    """
    Robust heteroskedastic matrix factorization.

    Models an (N, M) data matrix as `A.T @ G`, with `A` of shape (K, N) and
    `G` of shape (K, M), by alternating weighted least squares. After every
    iteration the working weights are recomputed as
    `input_W * Q2 / (input_W * resid**2 + Q2)` with `Q2 = nsigma**2`, which
    equals `input_W / (1 + chi**2 / nsigma**2)`; elements whose residual is
    many input sigmas from the model are therefore down-weighted.

    # inputs:
    `rank`:      number of components K.
    `nsigma`:    soft-clipping scale, in units of the input uncertainties.
    `tol`:       absolute objective decrease below which an iteration counts
                 as converged.
    `max_iter`:  optional cap on `train()` iterations; `None` (the default)
                 iterates until converged.
    """

    def __init__(self, rank, nsigma, tol=1.e-4, max_iter=None):
        self.K = int(rank)
        self.nsigma = float(nsigma)
        self.Q2 = self.nsigma ** 2
        self.tol = tol
        self.max_iter = None if max_iter is None else int(max_iter)
        self.trained = False

    def train(self, data, weights):
        """
        Fit `A` and `G` to `data` by iteratively reweighted alternating
        least squares.

        # inputs:
        `data`:     (N, M) array of observations.
        `weights`:  (N, M) units of (and equivalent to) inverse uncertainty variances.

        # side effects:
        - Sets `Y`, `input_W`, `W`, `A`, `G`, `N`, `M`, `n_iter`,
          `converged`, and `trained` on the instance.

        # raises:
        - ValueError if the rank K exceeds min(N, M).

        # comments:
        - Checks convergence with the g-step only.
        - If `max_iter` is reached first, stops with `converged` False; the
          model is still marked `trained` so it can be used.
        """
        self.trained = False
        assert np.all(jnp.isfinite(data))
        assert np.all(jnp.isfinite(weights))
        self.Y = jnp.array(data)
        self.input_W = jnp.array(weights)
        assert self.Y.shape == self.input_W.shape
        self.N, self.M = self.Y.shape
        if self.K > min(self.N, self.M):
            # The SVD initialization can supply at most min(N, M) components;
            # a larger K would silently train at lower rank and then break
            # test(), which builds a length-K coefficient vector.
            raise ValueError(
                f"rank {self.K} exceeds min(N, M) = {min(self.N, self.M)}")
        self.converged = False
        self.n_iter = 0
        self._initialize()
        print("train(): before starting:", self.objective(), self.original_objective())
        while not self.converged:
            if self.max_iter is not None and self.n_iter >= self.max_iter:
                break
            self._A_step()
            self._G_step()  # sets self.converged
            self._affine()
            self._update_W()
            self.n_iter += 1
            if self.n_iter % 100 == 0:
                print("train(): after iteration", self.n_iter, ":",
                      self.objective(), self.original_objective())
        if self.converged:
            print("train(): converged at iteration", self.n_iter, ":",
                  self.objective(), self.original_objective())
        else:
            print("train(): stopped at max_iter without converging:", self.n_iter, ":",
                  self.objective(), self.original_objective())
        self.trained = True

    def test(self, ystar, wstar):
        """
        Fit coefficients for one new observation with `G` held fixed.

        # inputs:
        `ystar`:     (M, ) array for one observation.
        `wstar`:     (M, ) units of (and equivalent to) inverse uncertainty variances.

        # outputs:
        `synth`:     (M, ) synthetic spectrum.

        # comments:
        - Checks convergence with the a-step only.
        - Uses local iteration state, so the training diagnostics
          (`self.n_iter`, `self.converged`) are left untouched.
        """
        assert self.trained
        assert np.all(np.isfinite(ystar))
        assert np.all(np.isfinite(wstar))
        assert ystar.shape == (self.M, )
        assert wstar.shape == (self.M, )
        converged = False
        n_iter = 0
        w = 1. * wstar
        a = np.zeros(self.K)
        while not converged:
            before = self.one_star_objective(ystar, w, a)
            a = self._one_star_A_step(ystar, w)
            after = self.one_star_objective(ystar, w, a)
            if before - after < self.tol:
                converged = True
            # Reweight against the ORIGINAL weights, as _update_W does.
            w = self._update_one_star_W(ystar, wstar, a)
            n_iter += 1
        print("test(): converged at iteration:", n_iter, ":",
              self.one_star_objective(ystar, w, a),
              self.one_star_objective(ystar, wstar, a))
        return self.one_star_synthesis(a)

    def synthesis(self):
        """Return the (N, M) model matrix A.T @ G."""
        return self.A.T @ self.G

    def one_star_synthesis(self, a):
        """Return the (M,) model for one observation with coefficients `a`."""
        return a @ self.G

    def resid(self):
        """Return data minus model, shape (N, M)."""
        return self.Y - self.synthesis()

    def one_star_resid(self, y, a):
        """Return `y` minus its model for coefficients `a`."""
        return y - self.one_star_synthesis(a)

    def objective(self):
        """Weighted squared residual sum using the current working weights W."""
        return jnp.sum(self.W * self.resid() ** 2)

    def one_star_objective(self, y, w, a):
        """Weighted squared residual sum for one observation."""
        return jnp.sum(w * self.one_star_resid(y, a) ** 2)

    def original_chi(self):
        """Residuals scaled by the input (not reweighted) inverse uncertainties."""
        return jnp.sqrt(self.input_W) * self.resid()

    def original_objective(self):
        """Weighted squared residual sum using the input weights."""
        return jnp.sum(self.input_W * self.resid() ** 2)

    def _initialize(self):
        """
        Start from the input weights and a rank-K truncated SVD of Y.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        """
        self.W = 1. * self.input_W # copy not reference
        u, s, v = jnp.linalg.svd(self.Y, full_matrices=False) # maybe slow
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]
        print("_initialize():", self.A.shape, self.G.shape)

    def _one_star_A_step(self, y1, w1):
        """Weighted least-squares (K,) coefficients for one row, G fixed."""
        XTCinvX = self.G * w1 @ self.G.T
        XTCinvY = self.G * w1 @ y1
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _one_star_G_step(self, y1, w1):
        """Weighted least-squares (K,) component values for one column, A fixed."""
        XTCinvX = self.A * w1 @ self.A.T
        XTCinvY = self.A * w1 @ y1
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _A_step(self):
        """Update A row by row (vmapped over observations); warn if worse."""
        foo = self.objective()
        self.A = jax.vmap(self._one_star_A_step)(self.Y, self.W).T
        bar = self.objective()
        if foo < bar:
            print("_A_step(): ERROR: objective got worse", foo, bar)

    def _G_step(self):
        """Update G column by column (vmapped over pixels); sets `converged`."""
        foo = self.objective()
        self.G = jax.vmap(self._one_star_G_step)(self.Y.T, self.W.T).T
        bar = self.objective()
        if foo < bar:
            print("_G_step(): ERROR: objective got worse", foo, bar)
        if foo - bar < self.tol:
            self.converged = True

    def _affine(self):
        """
        Re-express A.T @ G through its rank-K SVD so the rows of G are
        orthonormal and A carries the singular values; the product itself
        is unchanged up to rounding because it has rank at most K.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        """
        u, s, v = jnp.linalg.svd(self.A.T @ self.G, full_matrices=False)
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]

    def _update_W(self):
        """Soft-clip: W = input_W / (1 + chi**2 / nsigma**2)."""
        self.W = self.input_W * self.Q2 / (self.input_W * self.resid() ** 2 + self.Q2)

    def _update_one_star_W(self, y, w, a):
        """Single-observation version of _update_W, from input weights `w`."""
        return w * self.Q2 / (w * self.one_star_resid(y, a) ** 2 + self.Q2)


In [None]:
# split data
# Seeded random half-split of the sources: thresholding uniform draws at
# their median puts (rounded-down) half of them in A and the rest in B.
rng = np.random.default_rng(17)
rr = rng.uniform(size=len(source_ids))
A = rr < np.median(rr)
B = ~A
YA, WA, source_ids_A = (arr[A] for arr in (Y, W, source_ids))
YB, WB, source_ids_B = (arr[B] for arr in (Y, W, source_ids))

In [None]:
# Rank and soft-clipping scale shared by both half-sample models.
k = 15
nsigma = 3.0
modelA = RHMF(k, nsigma)
modelA.train(YA, WA)

In [None]:
def plot_components(model, title, savefig=None, offset=0.15):
    """
    Plot each row of `model.G` against the global `wavelength_grid`,
    shifting component k upward by `offset * k` so the curves separate.

    # inputs:
    `model`:    trained model whose `G` holds one component per row.
    `title`:    plot title.
    `savefig`:  optional filename passed to `plt.savefig`.
    `offset`:   vertical spacing between successive components.
    """
    # Use the `model` argument, not a global model, so each call plots the
    # components of the model it was given.
    for k, g in enumerate(model.G):
        plt.plot(wavelength_grid, g + offset * k)
    plt.xlabel("wavelength")
    plt.ylabel("spectral component (plus offset)")
    plt.title(title)
    if savefig is not None:
        plt.savefig(savefig)

plot_components(modelA, title="model A", savefig="modelA.png")

In [None]:
# Independent model trained on the other half of the sources.
modelB = RHMF(rank=k, nsigma=nsigma)
modelB.train(YB, WB)

In [None]:
plot_components(modelB, title="model B", savefig="modelB.png")

In [None]:
# Cross-validation: synthesize every half-B spectrum with the model trained
# on half A. Rows start as NaN so any row not filled stands out.
synthB = np.nan + np.zeros_like(YB)
for row, (y_obs, w_obs) in enumerate(zip(YB, WB)):
    synthB[row] = modelA.test(y_obs, w_obs)

In [None]:
ii = 23
# Red: cross-validated synthesis; black: observed spectrum, then the
# observed-minus-synthesis residual for the same star.
plt.plot(wavelength_grid, synthB[ii], "r-", alpha=0.5, lw=1)
plt.plot(wavelength_grid, YB[ii], "k-")
plt.plot(wavelength_grid, YB[ii] - synthB[ii], "k-")