# Robust Heteroskedastic Matrix Factorization
A robust-PCA-like model that knows about observational uncertainties

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- (with help from Claude)

## Dependencies:
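- `pip3 install jax matplotlib astropy astroquery`
- The `clod` module (imported below for data gathering) must also be importable; it is not installed by the command above.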

## Issues:
- Assumes (and gets) rectangular data with known uncertainties.
- `train()` function is written but `test()` function is not.

In [None]:
# jax related
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

In [None]:
# data-gathering and plotting related
import os
import numpy as np
import matplotlib.pyplot as plt
import clod
plt.rcParams['figure.figsize'] = (8, 4.5)
plt.rcParams['font.size'] = 12

In [None]:
# Selection cuts for the source query. Every key is forwarded verbatim as a
# keyword argument to `clod.find_rvs_sources_gspphot()`; `n_sources` is also
# re-used later as the cap on the number of spectra to download.
params = dict(
    teff_min=8000,
    teff_max=9000,
    logg_min=1.0,
    logg_max=5.0,
    grvs_mag_max=10.0,
    n_sources=1000,
)
sources = clod.find_rvs_sources_gspphot(**params)

In [None]:
# Never ask for more spectra than the query actually returned.
n_sources = min(len(sources), params['n_sources'])
spectra_data = clod.download_multiple_spectra(
    sources,
    max_spectra=n_sources,
)

In [None]:
# Assemble the spectra into a data block `Y` and a matching weight block `W`
# (used below as inverse variances), then print a few sanity statistics.
Y, wavelength_grid, source_ids, W = clod.create_spectral_matrices(spectra_data)

# NOTE(review): the "invvar zeros" line counts weights below 1.0, not weights
# that are exactly zero -- confirm that this threshold is intended.
print("\n".join([
    "\nSpectral matrix statistics:",
    f"  min flux: {np.min(Y):.4f}",
    f"  max flux: {np.max(Y):.4f}",
    f"  mean flux: {np.mean(Y):.4f}",
    f"  std flux: {np.std(Y):.4f}",
    f"  median uncertainty: {1. / np.sqrt(np.median(W)):.4f}",
    f"  flux contains NaN: {np.any(np.isnan(Y))}",
    f"  flux contains Inf: {np.any(np.isinf(Y))}",
    f"  invvar zeros: {np.sum(W < 1.e0)}",
]))

In [None]:
class RHMF():
    """
    Robust heteroskedastic matrix factorization: a rank-`K` model
    `Y ~ A.T @ G` fit by alternating weighted least squares, with the weights
    re-computed after every iteration so that badly fit pixels are
    down-weighted (iteratively reweighted least squares).

    # inputs:
    `rank`:    number of components `K` (cast to `int`).
    `nsigma`:  softening scale of the robust reweighting (cast to `float`);
               its square `Q2` is used in `_update_W()`.
    `tol`:     iterations stop once the objective improves by less than this.

    # attributes set by `train()`:
    `A`:  (K, N) array of per-object coefficients.
    `G`:  (K, M) array of basis vectors.
    `W`:  (N, M) array of current (robust) weights.
    """

    def __init__(self, rank, nsigma, tol=1.e-3):
        self.K = int(rank)
        self.nsigma = float(nsigma)
        self.Q2 = self.nsigma ** 2
        self.tol = float(tol)
        self.trained = False

    def train(self, data, weights):
        """
        Fit `A` and `G` to a rectangular block of data.

        # inputs:
        `data`:     (N, M) array of observations.
        `weights`:  units of (and equivalent to) inverse uncertainty variances.

        # notes:
        - Iterates A-step, G-step, re-orthogonalization and reweighting until
          the G-step improves the objective by less than `self.tol`.
        - Prints the robust and the original objective as it goes.
        """
        self.trained = False
        assert not jnp.any(jnp.isnan(data)), "data contains NaN"
        assert not jnp.any(jnp.isnan(weights)), "weights contain NaN"
        self.Y = jnp.array(data)
        self.input_W = jnp.array(weights)
        assert self.Y.shape == self.input_W.shape
        self.N, self.M = self.Y.shape
        self.converged = False
        self.n_iter = 0
        self._initialize()
        print("train(): before starting:", self.objective(), self.original_objective())
        while not self.converged:
            self._A_step()
            self._G_step()  # also sets self.converged
            self._affine()
            self._update_W()
            self.n_iter += 1
            if self.n_iter % 10 == 0:
                print("train(): after iteration", self.n_iter, ":",
                      self.objective(), self.original_objective())
        print("train(): converged at iteration", self.n_iter, ":",
              self.objective(), self.original_objective())
        self.trained = True

    def test(self, ystar, wstar, maxiter=1000):
        """
        Fit coefficients for one new observation, holding `G` fixed.

        # inputs:
        `ystar`:     (M, ) array for one observation.
        `wstar`:     units of (and equivalent to) inverse uncertainty variances.
        `maxiter`:   safety cap on the number of reweighting iterations.

        # outputs:
        `a`:  (K, ) array of coefficients, so the prediction is `a @ self.G`.
        `w`:  (M, ) array of final robust weights.

        # notes:
        - Uses local iteration state, so the `converged` and `n_iter`
          attributes left by `train()` are not overwritten.
        """
        assert self.trained
        ystar = jnp.array(ystar)
        wstar = jnp.array(wstar)
        assert not jnp.any(jnp.isnan(ystar)), "ystar contains NaN"
        assert not jnp.any(jnp.isnan(wstar)), "wstar contains NaN"
        assert ystar.shape == (self.M, )
        assert wstar.shape == (self.M, )
        n_iter = 0
        w = 1. * wstar  # copy not reference
        previous = jnp.inf
        while True:
            a = self._one_star_A_step(ystar, w)
            w = self._update_one_W(ystar, wstar, a)
            current = self.one_star_objective(ystar, w, a)
            n_iter += 1
            # Same tolerance as training, applied to the change in the
            # robust objective between consecutive iterations.
            if jnp.abs(previous - current) < self.tol:
                break
            if n_iter >= maxiter:
                print("test(): WARNING: not converged after", n_iter, "iterations")
                break
            previous = current
        print("test(): converged at iteration:", n_iter, ":",
              self.one_star_objective(ystar, w, a),
              self.one_star_objective(ystar, wstar, a))
        return a, w

    def _initialize(self):
        """
        Start from the unweighted rank-`K` truncated SVD of the training data.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        """
        self.W = 1. * self.input_W # copy not reference
        # Use the data handed to train(), not a notebook-level global.
        u, s, v = jnp.linalg.svd(self.Y, full_matrices=False) # maybe slow
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]
        print("_initialize():", self.A.shape, self.G.shape)

    def _one_star_A_step(self, y1, w1):
        """
        Weighted least-squares coefficients for one object at fixed `G`.
        `y1` and `w1` are (M, ) arrays; returns a (K, ) array.
        """
        # `*` and `@` share precedence and associate left-to-right, so this
        # is (G * w1) @ G.T: the weighted normal-equations matrix.
        XTCinvX = self.G * w1 @ self.G.T
        XTCinvY = self.G * w1 @ y1
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _one_star_G_step(self, y1, w1):
        """
        Weighted least-squares basis values for one pixel at fixed `A`.
        `y1` and `w1` are (N, ) arrays; returns a (K, ) array.
        """
        XTCinvX = self.A * w1 @ self.A.T
        XTCinvY = self.A * w1 @ y1
        return jnp.linalg.solve(XTCinvX, XTCinvY)

    def _update_one_W(self, y1, w1, a1):
        """
        Robust weights for one object: the single-row analogue of
        `_update_W()`, given input weights `w1` and coefficients `a1`.
        """
        resid1 = y1 - a1 @ self.G
        return w1 * self.Q2 / (w1 * resid1 ** 2 + self.Q2)

    def one_star_objective(self, y1, w1, a1):
        """
        Weighted sum of squared residuals for one object with coefficients
        `a1`: the single-row analogue of `objective()`.
        """
        return jnp.sum(w1 * (y1 - a1 @ self.G) ** 2)

    def _A_step(self):
        """Update every column of `A` at fixed `G`; warn if the fit got worse."""
        before = self.objective()
        self.A = jax.vmap(self._one_star_A_step)(self.Y, self.W).T
        after = self.objective()
        if before < after:
            print("_A_step(): ERROR: objective got worse", before, after)

    def _G_step(self):
        """
        Update every column of `G` at fixed `A`; warn if the fit got worse and
        flag convergence when the improvement drops below `self.tol`.
        """
        before = self.objective()
        self.G = jax.vmap(self._one_star_G_step)(self.Y.T, self.W.T).T
        after = self.objective()
        if before < after:
            print("_G_step(): ERROR: objective got worse", before, after)
        if before - after < self.tol:
            self.converged = True

    def _affine(self):
        """
        Re-factor the current synthesis `A.T @ G` with an SVD, which leaves the
        product unchanged (up to round-off) but makes the rows of `G`
        orthonormal.

        # bugs:
        - Consider switching SVD to a fast PCA implementation?
        """
        u, s, v = jnp.linalg.svd(self.A.T @ self.G, full_matrices=False)
        self.A = (u[:,:self.K] * s[:self.K]).T
        self.G = v[:self.K,:]

    def synthesis(self):
        """Return the (N, M) model prediction `A.T @ G`."""
        return self.A.T @ self.G

    def resid(self):
        """Return the (N, M) residual, data minus synthesis."""
        return self.Y - self.synthesis()

    def objective(self):
        """Return the weighted sum of squared residuals under the current weights `W`."""
        return jnp.sum(self.W * self.resid() ** 2)

    def original_chi(self):
        """Return the (N, M) residual scaled by the square root of the input weights."""
        return jnp.sqrt(self.input_W) * self.resid()

    def original_objective(self):
        """Return the weighted sum of squared residuals under the input weights."""
        return jnp.sum(self.input_W * self.resid() ** 2)

    def _update_W(self):
        """
        Recompute the robust weights from the input weights and the current
        residuals. Where the input-weighted squared residual is small compared
        to `Q2` the weight stays near its input value; where it is large the
        weight is reduced.
        """
        self.W = self.input_W * self.Q2 / (self.input_W * self.resid() ** 2 + self.Q2)

In [None]:
# Rank-15 model with a robustness scale of 2.5, trained on the full data
# block and its inverse-variance weights.
model = RHMF(rank=15, nsigma=2.5)
model.train(data=Y, weights=W)

In [None]:
# Plot every row of `model.G` against wavelength, shifting each successive
# row upwards by 0.15 so the curves do not overlap.
for k, g in enumerate(model.G):
    vertical_shift = 0.15 * k
    plt.plot(wavelength_grid, g + vertical_shift)

In [None]:
# Per-pixel diagnostics from the trained model.
chi = model.original_chi()
resid = model.resid()

# Per-row summaries: total chi-squared and mean squared residual.
chi_squared = np.sum(chi ** 2, axis=1)
mse = np.mean(resid ** 2, axis=1)

# Row indices ordered from largest to smallest chi-squared.
indx = np.argsort(-chi_squared)
print(chi.shape, indx.shape)

In [None]:
# Save one figure per row of `Y` (data and residual overplotted) into a local
# cache directory, and display every 64th figure inline.
plotcache = "./rvs_plot_cache"
os.makedirs(plotcache, exist_ok=True)

for ii in range(len(Y)):
    f, ax = plt.subplots()
    ax.plot(wavelength_grid, Y[ii], "k-")
    ax.plot(wavelength_grid, resid[ii], "k-")
    ax.set_title(f"{source_ids[ii]}")
    f.savefig(f"{plotcache}/{source_ids[ii]}.png")
    if ii % 64 == 0:
        plt.show()
    # Close explicitly so figures do not accumulate in memory.
    plt.close(f)