# Robust HMF on *BOSS* spectra of hot stars...
...to find evidence of H-alpha emission.

## Authors:
- **David W. Hogg** (NYU) (MPIA) (Flatiron)
- **Hans-Walter Rix** (MPIA)

## To-do items:
- Vet results and deliver to HWR's people.
- Make some method (perhaps in `rhmf.py`) to save and restore a Robust HMF model.

## Bugs:
- Probably RHMF is the wrong tool for this job?

In [None]:
# Cell 1: Import required libraries
import numpy as np
import pandas as pd
from astropy.io import fits
import requests
from requests.auth import HTTPBasicAuth
import os
import concurrent.futures
import matplotlib.pyplot as plt
import rhmf

In [None]:
# data choices: which BOSS reduction tag to use, and where everything is cached locally
bosstag = 'v6_2_1'
cachedir = f'./boss_{bosstag}_star_cache'

# plots live in their own subdirectory of the cache
plot_folder = f'{cachedir}/plots'

# create both directories up front (no-ops when they already exist)
os.makedirs(cachedir, exist_ok=True)
os.makedirs(plot_folder, exist_ok=True)

In [None]:
# model choices: the two arguments handed to rhmf.RHMF(rank, nsigma) when the models are built
rank = 24
nsigma = 3.5

In [None]:
# Define download functions
# Credentials passed to HTTPBasicAuth by the download code; both start out unset.
user = None
password = None

def download_one_file_from_df(args):
    """Download a single file from SDSS into the local cache.

    Parameters
    ----------
    args : tuple
        ``(url, filename, user, password, cachedir)``, packed into one tuple
        so the function can be mapped over a list of jobs by an executor.

    Returns
    -------
    bool
        True if the file is already present or was downloaded successfully;
        False if anything went wrong (the error is printed, not raised).
    """
    url, filename, user, password, cachedir = args
    filepath = os.path.join(cachedir, filename)

    # Skip if already downloaded
    if os.path.exists(filepath):
        return True

    # Write to a temporary name and rename only once the write has finished.
    # Otherwise an interrupted write would leave a truncated file at `filepath`
    # that the exists-check above would accept as complete on every later run.
    partpath = filepath + '.part'
    try:
        response = requests.get(url, auth=HTTPBasicAuth(user, password), timeout=30)
        response.raise_for_status()
        with open(partpath, 'wb') as f:
            f.write(response.content)
        os.replace(partpath, filepath)
    except Exception as e:
        # Deliberately best-effort: one failed file must not abort the whole batch.
        try:
            os.remove(partpath)
        except OSError:
            pass  # nothing was written, or it is already gone
        print(f"Failed to download {filename}: {e}")
        return False

    # Report roughly one in ten successes, so progress is visible without flooding the output.
    if np.random.uniform() < 0.1:
        print(f"Random example: File downloaded: {filename}")
    return True

def download_files_from_df(df, user, password, dest_folder, boss_tag='v6_2_1', coadd_version='daily', max_workers=8):
    """Download every spectrum listed in a dataframe, using a thread pool.

    Parameters
    ----------
    df : DataFrame with (at least) the columns 'SPEC_FILE', 'FIELD' and 'MJD'.
    user, password : credentials forwarded to each single-file download.
    dest_folder : directory the files are written to (created if missing).
    boss_tag, coadd_version : path components of the remote URL.
    max_workers : number of download threads.

    Returns
    -------
    list of bool
        One success flag per dataframe row.
    """
    os.makedirs(dest_folder, exist_ok=True)

    # everything in the URL that does not depend on the individual spectrum
    base_url = (
        f"https://data.sdss5.org/sas/sdsswork/bhm/boss/spectro/redux/"
        f"{boss_tag}/spectra/{coadd_version}/lite"
    )

    jobs = []
    for spec_file, field, mjd in zip(df['SPEC_FILE'], df['FIELD'], df['MJD']):
        field_str = f"{field:06d}"
        # remote directories group fields by everything but the last three digits
        field_group = field_str[:-3] + 'XXX'
        url = f"{base_url}/{field_group}/{field_str}/{str(mjd)}/{spec_file}"
        jobs.append((url, spec_file, user, password, dest_folder))

    print(f"Starting attempts to download {len(jobs)} files")
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        outcomes = list(pool.map(download_one_file_from_df, jobs))
    print(f"Number successful: {sum(outcomes)} files")
    return outcomes

In [None]:
# Download and examine the spAll file
# Three states are handled: uncompressed file already present (nothing to do),
# only the .gz present (decompress), or neither present (download, then decompress).
spallname = f'spAll-lite-{bosstag}.fits'
summaryurl = f'https://data.sdss5.org/sas/sdsswork/bhm/boss/spectro/redux/{bosstag}/summary/daily/{spallname}.gz'
summaryfile = cachedir + '/' + spallname + '.gz'
summaryfile_uncompressed = cachedir + '/' + spallname

if not os.path.exists(summaryfile_uncompressed):
    if not os.path.exists(summaryfile):
        print(f"Downloading summary file from {summaryurl}")
        response = requests.get(summaryurl, auth=HTTPBasicAuth(user, password))
        # fail loudly on an HTTP error instead of saving the error page as a ".gz" file
        response.raise_for_status()
        with open(summaryfile, 'wb') as f:
            f.write(response.content)
        print(f"Summary file {summaryfile} downloaded")

    # Decompress
    gunzip_status = os.system(f'gunzip -v {summaryfile}')
    if gunzip_status != 0:
        # do not claim success (and fail confusingly in a later cell) if gunzip did not work
        raise RuntimeError(f"gunzip failed on {summaryfile} (exit status {gunzip_status})")
    print(f"Summary file {summaryfile} decompressed")
else:
    print(f"Summary file {summaryfile_uncompressed} already exists")

In [None]:
# Load spAll data
# The table in HDU 1 of the summary file is kept in `data` for the selection below.
with fits.open(summaryfile_uncompressed) as hdul:
    data = hdul[1].data

# Debugging aid, switched off: flip to True to list every column name with a running number.
if False:
    banner = "=" * 70
    print(banner)
    print("ALL AVAILABLE COLUMNS IN SUMMARY FILE")
    print(banner)
    columns = data.columns.names
    for number, col in enumerate(columns, start=1):
        print(f"{number:3d}. {col}")

print(f"rows: {len(data)}; columns: {len(data.columns.names)}")

In [None]:
# Select a sample of spectra to download
# Let's look for high-SNR spectra with BP_MAG - RP_MAG < 0.5 to examine
#
# Each column is byte-swapped and re-labelled with the opposite byte order before
# it goes into pandas. `ndarray.newbyteorder()` was removed in NumPy 2.0, so the
# re-labelling is done with the equivalent `view(dtype.newbyteorder())`, which
# works on both old and new NumPy.
df = pd.DataFrame({
    col: data[col].byteswap().view(data[col].dtype.newbyteorder())
    for col in ['SPEC_FILE', 'FIELD', 'MJD', 'SN_MEDIAN_ALL', 'BP_MAG', 'RP_MAG', 'PROGRAMNAME']
})

# Filter for high SNR spectra from the 'mwm_ob' program.
# NOTE(review): the comments and the message below talk about a BP_MAG - RP_MAG < 0.5
# cut, but no colour cut is applied here; BP_MAG and RP_MAG are loaded and not used.
# Confirm whether the PROGRAMNAME selection is meant to stand in for the colour cut.
high_snr_df = df[(df['SN_MEDIAN_ALL'] > 40) & (df['PROGRAMNAME'] == 'mwm_ob') ] #.head(50)  # Just 50 spectra for testing
print(f"Selected {len(high_snr_df)} high-SNR and BP-RP < 0.5 spectra for header examination")

In [None]:
# Download the sample spectra
# The per-file success flags are not needed here, so the return value is discarded.
_ = download_files_from_df(
    high_snr_df,
    user,
    password,
    cachedir,
    boss_tag=bosstag,
    coadd_version='daily',
    max_workers=4,
)

In [None]:
# make lists of strings
# os.listdir returns entries in arbitrary order; sort them so that the row order of
# the data arrays built from this list is the same on every run and every machine.
filenames = np.array(sorted(
    f for f in os.listdir(cachedir) if f.startswith('spec-') and f.endswith('.fits')
))
# star name = file name without the leading 'spec-' and the trailing '.fits'
starnames = np.array([f[5:-5] for f in filenames])
print(filenames.shape, starnames.shape)
if len(filenames) > 0:
    # show one example pair; the index is clamped so small caches do not raise IndexError
    print(filenames[min(13, len(filenames) - 1)], starnames[min(13, len(filenames) - 1)])

In [None]:
# make rectangular data, plus wavelength grid
# Row i of `flux` / `ivar` belongs to filenames[i]. A file that cannot be used keeps
# its placeholder row (flux = 1, ivar = 0), so the rows stay aligned with the names.
wavelength = None
N = len(filenames)
print(f"reading {N} files...")
for i, fn in enumerate(filenames):
    filepath = cachedir + '/' + fn

    try:
        with fits.open(filepath) as hdul:
            if len(hdul) < 2 or not hasattr(hdul[1], 'data'):
                # previously skipped without any message
                print(f"  Dropped {filepath}: no data in HDU 1")
                continue
            spec_data = hdul[1].data
            loglam = spec_data['LOGLAM']
            fl = spec_data['FLUX']
            iv = spec_data['IVAR']
            wa = 10**loglam
            if wavelength is None:
                # the first readable spectrum defines the common wavelength grid
                wavelength = wa
                M = len(wavelength)
                flux = np.ones((N, M))
                ivar = np.zeros_like(flux)
            # check the shape first: np.allclose raises on arrays that cannot be broadcast
            if wa.shape != wavelength.shape or not np.allclose(wa, wavelength):
                print(f"  Dropped {filepath}: bad wavelength grid")
                continue
            # normalise to unit median flux; a zero, negative or non-finite median
            # would produce inf/NaN or sign-flipped rows, so such spectra are dropped
            norm = np.median(fl)
            if not (np.isfinite(norm) and norm > 0.):
                print(f"  Dropped {filepath}: bad median flux {norm}")
                continue
            flux[i] = fl / norm
            ivar[i] = iv * norm ** 2

    except Exception as e:
        # best-effort: an unreadable file is reported and skipped, not fatal
        print(f"  Dropped {filepath}: {e}")

if wavelength is None:
    # without this, the prints below fail with an unhelpful NameError on `flux`
    raise RuntimeError(f"no readable spectra found in {cachedir}")

print("data blocks:", flux.shape, ivar.shape, np.prod(flux.shape))
print("bad pixels:", np.sum(~ np.isfinite(flux)), np.sum(~ np.isfinite(ivar)),
      np.sum(ivar <= 0.) / np.prod(flux.shape))

In [None]:
# trim data
# Keep only the pixels strictly inside the 3700-12000 wavelength window, and apply
# the same pixel mask to the wavelength grid and to both data blocks.
good = np.logical_and(wavelength > 3700, wavelength < 12000) # magic
wavelength, flux, ivar = wavelength[good], flux[:, good], ivar[:, good]
print(flux.shape, ivar.shape, wavelength.shape)

In [None]:
# floor and ceil the ivars ## HACK
# Step 1, per pixel: no negative ivars, and none above 1.e4 / flux**2.
# magic -- nothing is known to better than 1 percent
ivar = np.clip(ivar, 0., 1.e4 / flux ** 2)

# Step 2, per spectrum: bounds built from the squared median flux of each row.
medflux2 = np.median(flux, axis=1) ** 2
maxivar = 1.e4 / medflux2 # magic -- nothing is known to better than 1 percent on average
minivar = 1.e-4 / medflux2 # magic -- there is trivial information even at useless pixels
ivar = np.clip(ivar, minivar[:, None], maxivar[:, None])
print(np.min(ivar), np.max(ivar))

In [None]:
# make two disjoint training sets
# Every spectrum gets a random number; those below the median form set A and the
# rest form set B, so the two sets are disjoint and of (nearly) equal size.
N, M = flux.shape
rng = np.random.default_rng(17)
# Draw from the seeded generator. The previous np.random.uniform call used NumPy's
# global state, so the seed above had no effect and the split changed on every run.
foo = rng.uniform(size=N)
A = foo < np.median(foo)
B = np.logical_not(A)
# sanity check: sizes of the two sets, and confirmation that they do not overlap
print(np.sum(A), np.sum(B), ~np.any(np.logical_and(A, B)))

In [None]:
# Turn the boolean membership masks into arrays of row indices.
all_rows = np.arange(N)
Aidx = all_rows[A]
Bidx = all_rows[B]
# sanity check: set sizes, and every index really belongs to its own set
print(len(Aidx), len(Bidx), np.all(A[Aidx]), np.all(B[Bidx]))

In [None]:
# plotting utility: Hydrogen recombination lines

def hydrogen_line(n_upper, n_lower):
    """Return the wavelength, in Angstrom, of a hydrogen transition (Rydberg formula).

    Parameters
    ----------
    n_upper, n_lower : principal quantum numbers of the two levels. The result does
        not depend on the order in which they are given.

    Returns
    -------
    float
        Vacuum wavelength of the transition in Angstrom.
    """
    # 10973731.568157 per meter is the Rydberg constant for an infinitely heavy nucleus.
    # Hydrogen needs the reduced-mass correction; without it every line comes out
    # about 3.6 Angstrom too blue at H-alpha (6561.1 instead of 6564.7).
    R_inf = 10973731.568157 # (12) per meter; Wikipedia
    m_e_over_m_p = 1. / 1836.152673426 # electron-to-proton mass ratio; CODATA
    R_H = R_inf / (1. + m_e_over_m_p) # per meter
    wave_number = R_H * (1/n_lower**2 - 1/n_upper**2) # per meter
    return (1. / np.abs(wave_number)) * 1.e10 # Angstrom

def plot_hydrogen_lines(ax):
    """Mark hydrogen lines on `ax` as faint blue vertical lines.

    For each of the lower levels n = 2 and n = 3, the transitions from the next
    14 higher levels are drawn at the wavelengths given by `hydrogen_line`.
    """
    line_style = dict(color="b", lw=0.5, alpha=0.23)
    for n_lower in (2, 3):
        upper_levels = range(n_lower + 1, n_lower + 15) # magic 15
        for n_upper in upper_levels:
            ax.axvline(hydrogen_line(n_upper, n_lower), **line_style)

In [None]:
# plotting utility: Hogg cares about wavelength axes.

def hogg_wavelength_axis(ax, wavelength):
    """Give `ax` the standard wavelength x-axis used throughout this notebook.

    Draws the hydrogen-line markers, switches the x-axis to log scale, puts plainly
    labelled ticks every 1000 from 3000 to 10000, limits the axis to the range
    spanned by `wavelength`, and labels it. Returns the same `ax`.
    """
    plot_hydrogen_lines(ax)
    ax.semilogx()

    # explicit tick labels, so the log axis shows plain numbers
    tick_positions = list(range(3000, 10001, 1000))
    ax.set_xticks(tick_positions, [str(position) for position in tick_positions])

    ax.set_xlim(np.min(wavelength), np.max(wavelength))
    ax.set_xlabel('wavelength')
    return ax

In [None]:
# plot the eigenvectors of a model

def plot_G(model, title):
    """Plot every row of `model.G` against the global `wavelength` grid.

    Row k is scaled by 10 and shifted up by k, so the rows are stacked rather than
    drawn on top of each other. The y-range runs from -1 to `model.K`.
    """
    plt.figure(figsize=(12,8))
    for offset, component in enumerate(model.G):
        stacked = offset + 10. * component
        plt.step(wavelength, stacked, where='mid', lw=0.5, alpha=0.90)
    plt.ylim(-1., model.K)
    hogg_wavelength_axis(plt.gca(), wavelength)
    plt.title(title)

In [None]:
# plot a spectrum and a synthetic spectrum and residuals

def plot_one_spectrum(wavelength, flux, ivar, name, prefix, synth=None):
    """Plot one spectrum (optionally with its model) and save the figure as a PNG.

    Parameters
    ----------
    wavelength, flux, ivar : arrays for one spectrum, plotted against each other.
    name : used as the plot title and as part of the file name.
    prefix : prepended to `name` in the file name.
    synth : optional synthetic spectrum; when given, the residual (flux - synth) is
        drawn in black, and the synthetic spectrum and a zero line are drawn in red.

    The figure is written to `plot_folder`/<prefix><name>.png and then closed.
    """
    fig = plt.figure(figsize=(12, 4))
    plt.axhline(0., lw=0.5, color='k', alpha=0.45)

    # the data themselves
    plt.step(wavelength, flux, where='mid', color='k', lw=0.5, alpha=0.90)

    # grey band of half-width 1 / sqrt(ivar + tiny); `tiny` keeps the band finite where ivar is zero
    tiny = 0.01 / np.median(flux) ** 2
    upper = flux + 1. / np.sqrt(ivar + tiny)
    lower = flux - 1. / np.sqrt(ivar + tiny)
    plt.fill_between(wavelength, lower, upper, step='mid', color='k', alpha=0.23)

    if synth is not None:
        black = dict(where='mid', color='k', lw=0.5, alpha=0.90)
        red = dict(where='mid', color='r', lw=0.5, alpha=0.90)
        plt.step(wavelength, flux - synth, **black)
        plt.step(wavelength, synth, **red)
        plt.step(wavelength, np.zeros_like(flux), **red)

    # adjust axes
    level = np.nanmedian(flux)
    plt.ylim(-0.5 * level, 2.5 * level)
    plt.ylabel('flux')
    plt.title(name)
    hogg_wavelength_axis(plt.gca(), wavelength)

    # Save plot
    plot_filename = plot_folder + '/' + prefix + name + '.png'
    plt.savefig(plot_filename)
    plt.close(fig)
    print(f"  Plot saved: {plot_filename}")

In [None]:
# make test step but with a line held out of the fitting (like, say, H-alpha)

def censored_cross_test(Y, W, models, line, delta):
    """Run `cross_test` with the weights zeroed in a window around one line.

    Pixels of the global `wavelength` grid strictly within `delta` of `line` get
    zero weight in a copy of `W`; `W` itself is left untouched. The number of
    censored pixels is printed. Returns whatever `cross_test` returns.
    """
    lower_edge, upper_edge = line - delta, line + delta
    near_line = (wavelength > lower_edge) & (wavelength < upper_edge)
    print(np.sum(near_line))

    censored = 1. * W # copy
    censored[:, near_line] = 0.
    return cross_test(Y, censored, models)

In [None]:
# cross test: Test A objects with model B, and B objects with model A.

def cross_test(Y, W, models):
    """Synthesize every object with the model it was NOT assigned to.

    Parameters
    ----------
    Y, W : data and weights, indexed by object along the first axis.
    models : sequence of exactly two ``(model, idx, label)`` tuples.

    Returns
    -------
    Array shaped like `Y`. Row i holds ``model.test(Y[i], W[i])`` from the opposite
    model; rows that appear in neither `idx` stay NaN. The count of NaN entries is
    printed before the first pass and after each pass as a progress check.
    """
    assert len(models) == 2
    synth = np.nan + np.zeros_like(Y)
    print(np.sum(np.isnan(synth)))
    # pair each model with the index list of the other one
    for (model, _, _), (_, idx, _) in zip(models, reversed(models)):
        for i in idx:
            synth[i] = model.test(Y[i], W[i])
        print(np.sum(np.isnan(synth)))
    return synth

In [None]:
# the full train and test pipeline

def train_and_test(Y, W, models, maxiter=10):
    """Train both models, cross-test them with H-alpha censored, and plot the outliers.

    Parameters
    ----------
    Y, W : data and weights, indexed by object along the first axis.
    models : list of ``(model, idx, label)`` tuples; each model is trained on the
        rows listed in its own `idx`.
    maxiter : forwarded to ``model.train``.

    Returns
    -------
    The synthetic spectra from the cross test in which the pixels around H-alpha
    carried zero weight.

    Side effects: shows one plot per model and a chi-squared histogram, deletes the
    previous "halpha_emitter_*.png" files in `plot_folder`, and writes new plots for
    the 300 objects with the largest chi-squared near H-alpha.
    """
    import glob  # local import: only needed for the plot clean-up below

    # train step
    for model, idx, label in models:
        print(label)
        model.train(Y[idx], W[idx], maxiter=maxiter)
        plot_G(model, label)
        plt.show()

    # test step
    halpha, delta = 6564.6, 5. # line from Wikipedia, Angstroms; delta from magic
    synth_ex_halpha = censored_cross_test(Y, W, models, halpha, delta)

    # choose interesting objects to plot
    # chi-squared uses the ORIGINAL weights, restricted to the censored window, so it
    # measures how badly the model (which never saw these pixels) predicts the line
    near_halpha = (wavelength > (halpha - delta)) & (wavelength < (halpha + delta))
    resid = Y - synth_ex_halpha
    chi_halpha = (resid * np.sqrt(W))[:, near_halpha]
    chi2_halpha = np.sum(chi_halpha ** 2, axis=1)
    interesting = np.argsort(-chi2_halpha)
    # log10 of a zero or non-finite chi-squared would make plt.hist raise; leave those out
    plottable = chi2_halpha[np.isfinite(chi2_halpha) & (chi2_halpha > 0.)]
    _ = plt.hist(np.log10(plottable), bins=100)
    plt.xlabel("log10(chi-squared)")
    plt.semilogy()
    plt.show()

    # make plots
    prefix = "halpha_emitter_"
    # Remove the plots of the previous run in Python rather than through a shell `rm`:
    # this works for paths with spaces, needs no shell, and is quiet when nothing matches.
    for stale in glob.glob(os.path.join(glob.escape(plot_folder), prefix + "*.png")):
        os.remove(stale)
        print(f"removed '{stale}'")
    for i in interesting[:300]:
        plot_one_spectrum(wavelength, Y[i], W[i], starnames[i], prefix, synth=synth_ex_halpha[i])
    return synth_ex_halpha

In [None]:
# start models
# Two models with identical settings, one for each half of the data; each entry is
# (model, training-row indices, label), the layout expected by train_and_test.
models = [
    (rhmf.RHMF(rank, nsigma), idx, label)
    for idx, label in ((Aidx, "model A"), (Bidx, "model B"))
]
synth = train_and_test(flux, ivar, models, maxiter=30)

In [None]:
# train even more
# Re-runs the whole train-and-test pipeline 30 more times on the same `models` list,
# so model.train is called again on the same model objects.
# NOTE(review): presumably each call continues from the model's current state rather
# than restarting -- confirm against rhmf.
# Every pass also shows the diagnostic plots again and rewrites the saved spectrum
# plots; `synth` ends up holding the result of the last pass only.
for t in range(30):
    synth = train_and_test(flux, ivar, models, maxiter=30)