# RHMF on SDSS-Classic LRGs
Looking for second redshifts, maybe?

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- with help from **Claude** (Anthropic) and consulting from **Hans-Walter Rix** (MPIA)

## License:
Copyright 2025 the author. This code is released for re-use under the open-source *MIT License*.

## Issues:
- Test step not running yet.
- Deciding how to select interesting objects after the test step.

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pylab as plt
import clod_sdss_lrg as clod
import rhmf

In [None]:
# Notebook-wide settings: where processed spectra and plots are cached, and
# the default (wide, short) figure shape used for spectrum plots.
cache = './sdss_lrg_cache'
mpl.rcParams.update({'figure.figsize': (12, 4)})

In [None]:
# get data

# Initialize processor; it keeps its processed products under `cache`.
processor = clod.SDSSLRGProcessor(cache_dir=cache)

# Process the LRG sample.
# NOTE(review): with force_reprocess=True the sample is (presumably) rebuilt
# on every run rather than read back from the cache; set it to False to reuse
# cached results on subsequent runs.
force_reprocess = True
max_objects = 175
processed_spectra = processor.process_lrg_sample(
    max_objects=max_objects, force_reprocess=force_reprocess
)

# To clear cache:
# processor.clear_cache()

In [None]:
# Put the data into rectangular arrays: one row per object, one column per
# rest-frame wavelength pixel. keys() and values() of a dict iterate in the
# same order, so row i of Y and W belongs to objs[i].
objs = np.array(list(processed_spectra.keys()))
wavelengths = processor.rest_wave_grid
spectra = list(processed_spectra.values())
Y = np.vstack([spec['flux'] for spec in spectra])
W = np.vstack([spec['ivar'] for spec in spectra])
print(objs.shape, wavelengths.shape, Y.shape, W.shape)

In [None]:
# Sanity-check the data: each printed count should be zero.
n_nonfinite_flux = np.sum(np.logical_not(np.isfinite(Y)))
print(n_nonfinite_flux)
n_nonfinite_ivar = np.sum(np.logical_not(np.isfinite(W)))
print(n_nonfinite_ivar)
n_negative_ivar = np.sum(W < 0.)
print(n_negative_ivar)

In [None]:
# Set model parameters: the number of components in the factorization and a
# clipping threshold (presumably in units of sigma — confirm against rhmf).
rank = 10
nsigma = 3.0
model = rhmf.RHMF(rank, nsigma)

In [None]:
# censor data, given these model parameters
#
# Rows (objects) are cut first, then columns (wavelength pixels) of the
# surviving rows. A row/column is kept when
#   (a) its total inverse variance is not tiny compared with the median total
#       inverse variance of all rows/columns, and
#   (b) it has more than `min_pixels` entries with nonzero weight
#       (presumably so each row/column constrains more than the `rank`
#       model components — confirm).
# The total-weight threshold must be compared against the median of the same
# per-row/per-column sums. Comparing against np.median(W), a single pixel's
# typical weight, makes cut (a) essentially never reject anything.
weight_fraction = 0.1  # magic
min_pixels = rank + 5  # magic

object_weight = np.sum(W, axis=1)
goodobjects = object_weight > weight_fraction * np.median(object_weight)
goodobjects &= np.sum(W > 0, axis=1) > min_pixels
objs = objs[goodobjects]
Y = Y[goodobjects]
W = W[goodobjects]

wavelength_weight = np.sum(W, axis=0)
goodwavelengths = wavelength_weight > weight_fraction * np.median(wavelength_weight)
goodwavelengths &= np.sum(W > 0, axis=0) > min_pixels
wavelengths = wavelengths[goodwavelengths]
Y = Y[:, goodwavelengths]
W = W[:, goodwavelengths]
print(objs.shape, wavelengths.shape, Y.shape, W.shape)

In [None]:
# Fit the model to the censored fluxes Y with inverse-variance weights W.
max_iterations = 600
model.train(Y, W, maxiter=max_iterations)

In [None]:
# Plot every spectrum with its model synthesis and residual, writing one PNG
# per object into the cache directory.
synth = model.synthesis()
resid = model.resid()
for obj, data, fit, res in zip(objs, Y, synth, resid):
    fig = plt.figure()
    plt.plot(wavelengths, data, "k-", lw=1, alpha=0.45)  # data
    plt.plot(wavelengths, fit, "r-", lw=1, alpha=0.90)   # model synthesis
    plt.plot(wavelengths, res, "k-", lw=1, alpha=0.45)   # residual
    plt.xlim(np.min(wavelengths), np.max(wavelengths))
    # Scale the y-axis to the spectrum's median flux. A zero, negative or
    # non-finite median would give degenerate, inverted or NaN limits, so in
    # that case fall back to matplotlib's autoscaling.
    scale = np.median(data)
    if np.isfinite(scale) and scale > 0:
        plt.ylim(-0.2 * scale, 2.0 * scale)
    plt.title(obj)
    plt.xlabel("wavelength")
    plt.ylabel("flux")
    plt.savefig(f"{cache}/{obj}.png")
    plt.close(fig)  # free the figure so memory does not grow over the loop