# RHMF on *SDSS* *eBOSS* LRGs
Looking for second redshifts, maybe?

## Author:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- with help from **Claude** (Anthropic) and consulting from **Hans-Walter Rix** (MPIA)

## License:
Copyright 2025 the author. This code is released for re-use under the open-source *MIT License*.

## Issues:
- How to select interesting objects after test step?
- Maybe we should have some code that reads the whole cache and deletes any bad files?
- Should have at least 2 models -- maybe 5-ish -- for comparisons.
- The RHMF code has an issue where it increases the objective occasionally at the a-step or g-step. What to do / analyze?

In [None]:
import os
import numpy as np
import matplotlib as mpl
import matplotlib.pylab as plt
import clod_sdss_lrg as clod
import rhmf

In [None]:
# Notebook-wide defaults: redshift window, cache location, seeded random
# generator, and the default matplotlib figure size.
zmin = 0.20
zmax = 0.55
cache = './sdss_lrg_cache'
rng = np.random.default_rng(17)
mpl.rcParams['figure.figsize'] = (12, 4)

In [None]:
# Sample sizes: the cap on objects requested from the processor, and the
# requested number of training objects.
Nmax, Ntrain = 15_000, 1_000

In [None]:
# Model configuration: the RHMF object is built from a rank and an nsigma value.
rank = 10
nsigma = 2.5
model = rhmf.RHMF(rank, nsigma)

In [None]:
# get data

# Initialize processor
# (cache_dir points the processor at the cache directory chosen in the defaults)
processor = clod.SDSSLRGProcessor(cache_dir=cache)

# Process LRG sample (will use cache on subsequent runs)
# The request is limited to Nmax objects within the redshift window [zmin, zmax].
# NOTE(review): force_reprocess=True looks at odds with "will use cache on
# subsequent runs" -- confirm against clod_sdss_lrg whether this flag bypasses
# the cache; if it does, it probably wants to be False once the cache is filled.
processed_spectra = processor.process_lrg_sample(max_objects=Nmax,
                                                 z_min=zmin, z_max=zmax,
                                                 force_reprocess=True)

# To clear cache:
# processor.clear_cache()

In [None]:
# Pack the per-object results into aligned arrays, one row per object.
# keys() and values() of the same unmodified mapping iterate in matching order,
# so objs, zs, Y and W stay row-aligned.
objs = np.array(list(processed_spectra.keys()))
zs = np.array([spec['redshift'] for spec in processed_spectra.values()])
wavelengths = processor.rest_wave_grid
Y = np.vstack([spec['flux'] for spec in processed_spectra.values()])
W = np.vstack([spec['ivar'] for spec in processed_spectra.values()])
print(objs.shape, zs.shape, wavelengths.shape, Y.shape, W.shape)

In [None]:
# Redshift histogram of everything that was loaded.
_ = plt.hist(zs, bins=100)
plt.ylabel("number per bin")
plt.xlabel("redshift")
plt.title(f"full set of {len(Y)} LRGs")

In [None]:
# Drop objects with too little usable data. An object is kept only if
#   (a) its summed weight exceeds 0.1 times the median of all weights, and
#   (b) it has more than rank + 5 pixels with positive weight.
# Both thresholds are "magic" (hand-picked) numbers.
goodobjects = (
    (np.sum(W, axis=1) > (0.1 * np.median(W)))
    & (np.sum(W > 0, axis=1) > (rank + 5))
)
# Apply the same row mask to every per-object array so they stay aligned.
objs, zs, Y, W = objs[goodobjects], zs[goodobjects], Y[goodobjects], W[goodobjects]
print(objs.shape, zs.shape, wavelengths.shape, Y.shape, W.shape)

In [None]:
# Sanity checks on the data arrays.
# First line: counts of non-finite fluxes, non-finite weights, and negative
# weights (presumably all expected to be zero).
print((~np.isfinite(Y)).sum(), (~np.isfinite(W)).sum(), (W < 0.).sum())
# Second line: fraction of pixels that carry positive weight.
print(np.sum(W > 0.) / W.size)

In [None]:
# One number per object: the median weight times the square of the
# 75th-percentile flux, used as a rough proxy for information content.
information_proxy = np.square(np.percentile(Y, 75., axis=1)) * np.percentile(W, 50., axis=1)

In [None]:
# make training set as an index subset
# make training set be uniform in z, because why not?
# WARNING: This must produce an index list, not a boolean list.
dz = (zmax - zmin) / Ntrain
# Centers of Ntrain equal-width redshift bins covering [zmin, zmax].
# The grid has to start at zmin: starting it near zero wastes columns on
# redshifts below the sample, all of which select the same lowest-z object.
zgrid = zmin + dz * (np.arange(Ntrain) + 0.5)
# |z_object - z_grid| for every (object, grid point) pair; for each grid point
# take the closest object, then drop repeats (one object can win several bins).
foo = np.abs(zs[:, None] - zgrid[None, :])
train = np.unique(np.argmin(foo, axis=0))
# Ntrain now records the number of distinct objects actually selected.
# Re-running this cell therefore starts from the reduced value.
Ntrain = len(train)
print(train.shape)

In [None]:
# Test set: every object index that is not in the training set, in ascending
# order (an index array, like `train`, not a boolean mask).
test = np.setdiff1d(np.arange(len(zs)), train)
print(test.shape)

In [None]:
# Redshift histogram of the training subset only.
_ = plt.hist(zs[train], bins=100)
plt.ylabel("number per bin")
plt.xlabel("redshift")
plt.title(f"training set of {len(train)} LRGs")

In [None]:
# Drop wavelength pixels that the training set cannot constrain. A pixel is
# kept only if, over the training rows,
#   (a) its summed weight exceeds 0.1 times the median training weight, and
#   (b) more than rank + 5 training objects have positive weight there.
# Both thresholds are "magic" (hand-picked) numbers.
goodwavelengths = (
    (np.sum(W[train], axis=0) > (0.1 * np.median(W[train])))
    & (np.sum(W[train] > 0, axis=0) > (rank + 5))
)
# Apply the same column mask to the grid and to both data arrays.
wavelengths = wavelengths[goodwavelengths]
Y, W = Y[:, goodwavelengths], W[:, goodwavelengths]
print(objs.shape, wavelengths.shape, Y.shape, W.shape)

In [None]:
# Train the model on the training rows of the flux and weight arrays.
# maxiter=500 is presumably an iteration cap -- see rhmf for the exact meaning.
model.train(Y[train], W[train], maxiter=500)

In [None]:
# Second train() call with exactly the same arguments as the previous cell.
# NOTE(review): presumably this continues from the current model state to get
# more iterations -- confirm against rhmf; if train() restarts from scratch,
# this cell only repeats the work.
model.train(Y[train], W[train], maxiter=500)

In [None]:
# Draw every row of model.G against wavelength, shifting each successive
# row upward by a constant offset so the curves do not overlap.
f = plt.figure(figsize=(12, 8))
foo = 0.1  # vertical offset between consecutive rows
for k, g in enumerate(model.G):
    plt.plot(wavelengths, g + k * foo, lw=1)
plt.xlim(np.min(wavelengths), np.max(wavelengths))
plt.ylim(-foo, foo * rank)

In [None]:
# Emission-line table; this from Claude (Anthropic), lightly edited.
#
# Each entry is (plot label, wavelength). plot_lines() multiplies the wavelength
# by 10 before drawing, so the values look like nm on an Angstrom axis --
# TODO confirm units. An entry with an empty label is drawn as a line with no
# text. Commented-out entries are lines that are currently switched off.
nebular_lines = [
    # --- oxygen ---
    ('O II', 372.709),
    ('', 372.988),
    ('', 436.444),
    ('', 496.030),
    ('O III', 500.824),

    # --- hydrogen (Balmer series) ---
#    ('H12', 375.122),
#    ('H11', 377.170),
#    ('H10', 379.898),
#    ('H9', 383.649),
#    ('H8', 389.015),
    ('H_epsilon', 397.120),
    ('H_delta', 410.289),
    ('H_gamma', 434.168),
    ('H_beta', 486.268),
    ('H_alpha', 656.461),

    # --- nitrogen ---
    ('', 654.986),
    ('', 658.527),

    # --- sulfur ---
    ('S II', 671.829),
    ('', 673.267),
    ('', 907.1),
    ('S III', 953.3),

    # --- helium ---
#    ('HeI_3889', 388.975),
#    ('HeI_4027', 402.734),
#    ('HeI_4472', 447.276),
    ('He I', 587.729),
#    ('HeII_4687', 468.702),

    # --- other ---
#    ('NeIII_3869', 386.986),
#    ('ArIII_7137', 713.777),
]

In [None]:
 def plot_lines(y):
    for label, line in nebular_lines:
        plt.axvline(10. * line, color="b", lw=1, alpha=0.4, zorder=-1001)
        plt.text(10. * line, y, label, size="small", color="b", alpha=0.4,
                 rotation=90, va="top", ha="right")

def plot_one_spectrum(data, synth, title, fn, w=None, nolim=False):
    """Plot a single spectrum and save the figure to ``fn``.

    The spectrum is drawn against the module-level ``wavelengths`` grid on a
    logarithmic wavelength axis, with the ``nebular_lines`` markers behind it.

    Parameters
    ----------
    data : array
        Flux values, one per element of ``wavelengths``.
    synth : array or None
        Model synthesis on the same grid. When given, the synthesis is drawn
        in red and the residual ``data - synth`` in black.
    title : str
        Figure title.
    fn : str
        Output path handed to ``plt.savefig``.
    w : array or None, optional
        Inverse-variance weights for ``data``. When given, a shaded band of
        +/- 3 sigma around zero is drawn.
    nolim : bool, optional
        If False (default) the y-range is fixed to [-0.5, 2.5] times the 75th
        percentile of ``data``; if True matplotlib chooses the y-range.

    The figure is always closed before returning, including when plotting or
    saving raises, so repeated calls in a loop do not accumulate open figures.
    """
    f = plt.figure()
    try:
        # The 75th-percentile flux sets the vertical scale of the plot.
        foo = np.percentile(data, 75.)
        plot_lines(2.49 * foo) # magic: labels sit just below the upper y-limit
        plt.step(wavelengths, data, where="mid", color="k", lw=1, alpha=0.75)
        if synth is not None:
            plt.step(wavelengths, data - synth, where="mid", color="k", lw=1, alpha=0.75)
            plt.step(wavelengths, synth, where="mid", color="r", lw=1, alpha=0.90)
        if w is not None:
            # Soften the weights so zero-weight pixels give a finite error.
            # tiny is an inverse variance (1 / flux**2), so it has to be added
            # to w itself, not to sqrt(w); this caps the error at 2 * foo.
            tiny = 0.25 / foo ** 2
            yerr = 1. / np.sqrt(w + tiny)
            plt.fill_between(wavelengths, -3. * yerr, 3. * yerr, step="mid", color="k", alpha=0.20)
        plt.axhline(0., color="r", lw=1, alpha=0.90)
        plt.semilogx()
        # Plain-number tick labels every 1000 units instead of powers of ten.
        xts = np.arange(1000, 9000, 1000)
        xtsl = np.array([f"{x:4.0f}" for x in xts])
        plt.xticks(ticks=xts, labels=xtsl)
        plt.xlim(np.min(wavelengths), np.max(wavelengths))
        if not nolim:
            plt.ylim(-0.5 * foo, 2.5 * foo)
        plt.title(title)
        plt.xlabel("wavelength")
        plt.ylabel("flux")
        plt.savefig(fn)
    finally:
        plt.close(f)

In [None]:
# plot held-out test data with syntheses
import glob

prefix = f"{cache}/test-"
# Remove plots left over from an earlier run. Done in Python rather than by
# shelling out to `rm`, so it works on any platform and with any cache path.
for stale in glob.glob(f"{glob.escape(prefix)}*.png"):
    os.remove(stale)
for ii in test:
    synthii = model.test(Y[ii], W[ii])
    # chi: residual scaled by the square root of the weight.
    chiii = (Y[ii] - synthii) * np.sqrt(W[ii])
    # Keep only objects with more than 10 pixels above +5. The cut is one-sided
    # (positive residuals only) -- presumably intentional; TODO confirm.
    if np.sum(chiii > 5.) > 10: # magic magic
        plot_one_spectrum(Y[ii], synthii, f"test set object {objs[ii]}",
                          f"{prefix}{objs[ii]}.png", w=W[ii])