# Spectroscopic Properties Lookup Tables

For lost galaxies we need to fill in spectroscopic properties. Redshift is handled first, by its own special procedure. Here we create two additional lookup tables, which are applied in two steps after a redshift has been assigned to each lost galaxy: first `kcorrlookup`, then `dn4000lookup`. Use `train_kcorrlookup.py` and `train_dn4000.py` to search for the best lookup parameters. This notebook lets you inspect those search results, save the final lookup tables via pickle, and validate their performance.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import astropy.io.fits as fits
from astropy.table import Table
import sys
#from pykdtree.kdtree import KDTree as FastKDTree
from scipy.spatial.kdtree import KDTree 
import pickle
import emcee
import corner

if './SelfCalGroupFinder/py/' not in sys.path:
    sys.path.append('./SelfCalGroupFinder/py/')
from pyutils import *
from dataloc import *
from bgs_helpers import *
from groupcatalog import BGS_Z_MAX, BGS_Z_MIN

%load_ext autoreload
%autoreload 2

pd.set_option('display.max_columns', None)

## Load BGS Data

Load the merged BGS Y1 data file used to build the lookup tables.

In [None]:
# Read the merged Y1 BGS catalogue and convert it to a DataFrame via table_to_df.
# The cut arguments (20.175, BGS_Z_MIN, BGS_Z_MAX, True, 1) are the same ones used
# later when building the Y3 validation set.
table = Table.read(IAN_BGS_Y1_MERGED_FILE, format='fits')
df = table_to_df(table, 20.175, BGS_Z_MIN, BGS_Z_MAX, True, 1)
n_loaded = len(df)
print(f"Loaded {n_loaded:,} galaxies")

In [None]:
# Keep only the listed columns and renumber the rows 0..N-1.
columns_needed = [
    'TARGETID', 'Z', 'ABS_MAG_R', 'ABS_MAG_G', 'ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G',
    'DN4000_MODEL', 'G_R_BEST', 'LOG_L_GAL', 'LOGMSTAR',
]
df = df.loc[:, columns_needed].reset_index(drop=True)

In [None]:
# Quality cut: k-correct the raw magnitudes with k_correct_gama and drop galaxies whose
# result differs from the fastspecfit ABSMAG01_SDSS_* magnitude by more than 1 mag in
# either band, as well as any galaxy missing a magnitude or a redshift.
gmr_raw = df['ABS_MAG_G'] - df['ABS_MAG_R']
magr_k_gama = k_correct_gama(df['ABS_MAG_R'], df['Z'], gmr_raw, band='r')
magg_k_gama = k_correct_gama(df['ABS_MAG_G'], df['Z'], gmr_raw, band='g')
r_mismatch = np.abs(magr_k_gama - df['ABSMAG01_SDSS_R']) > 1.0
g_mismatch = np.abs(magg_k_gama - df['ABSMAG01_SDSS_G']) > 1.0
badmatch = r_mismatch | g_mismatch

goodidx = ~badmatch
for col in ('ABS_MAG_R', 'ABS_MAG_G', 'Z', 'ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G'):
    goodidx = goodidx & ~np.isnan(df[col])
print(f"Number of galaxies with good data: {np.sum(goodidx):,}")

# Per-galaxy arrays for every good galaxy (no train/test split is made here)
good = df.loc[goodidx]
z_arr = good['Z'].to_numpy()
magr_arr = good['ABS_MAG_R'].to_numpy()  # no k-correction applied
magg_arr = good['ABS_MAG_G'].to_numpy()  # no k-correction applied
gmr_arr = magg_arr - magr_arr
# k-correction = raw magnitude minus the fastspecfit k-corrected ABSMAG01_SDSS_* magnitude
kcorr_r_arr = magr_arr - good['ABSMAG01_SDSS_R'].to_numpy()
kcorr_g_arr = magg_arr - good['ABSMAG01_SDSS_G'].to_numpy()

In [None]:
# Scatter plot of redshift vs. raw (not k-corrected) g-r colour for the good Y1 galaxies.
# The x-axis is the colour gmr_arr; the previous version plotted the absolute g magnitude
# (magg_arr) here despite the 'g - r' label.
plt.figure(figsize=(8,6))
plt.scatter(gmr_arr, z_arr, s=5, alpha=0.05, label='All Y1 Good Data')
plt.xlabel('g - r')
plt.ylabel('Redshift')
plt.title('Redshift vs. g-r Color')
plt.legend()
plt.show()

In [None]:
# Distribution of the fastspecfit k-corrections (raw minus ABSMAG01_SDSS_* magnitude)
# in the r and g bands. A single panel is used; the previous 1x2 grid left the right
# half of the figure empty.
plt.figure(figsize=(8, 5))
plt.hist(kcorr_r_arr, bins=1250, alpha=0.7, color='r', label='r-band')
plt.hist(kcorr_g_arr, bins=1250, alpha=0.7, color='g', label='g-band')
plt.xlabel('K-correction (mag)')
plt.ylabel('Number of galaxies')
plt.title('Distribution of K-corrections (Training Set)')
plt.legend()
plt.xlim(-0.4, 1.5)
plt.show()

## View MCMC results; Save Best

### kcorrlookup

In [None]:
# Load the kcorr-lookup MCMC results (presumably written by train_kcorrlookup.py).
# Opened read-only: this notebook only inspects the chain and must never modify it.
backend = emcee.backends.HDFBackend(OUTPUT_FOLDER + "kcorr_lookup_optimization.h5", read_only=True)
chains = backend.get_log_prob(flat=True)  # flattened log-probability of every sample
argmax = np.argmax(chains)                # index of the highest-log-probability sample
samples = backend.get_chain(flat=True)    # flattened parameter samples, row-aligned with chains
chisqr = -chains                          # -log_prob is treated as chi-squared below (read once, not twice)

print(f"Total Samples: {len(samples)}")
print(f"Best-fit parameters: {samples[argmax]}")
print(f"Best-fit chi-squared: {chisqr[argmax]}")

In [None]:
# Scatter of the kcorr-lookup MCMC samples in (METRIC_Z, METRIC_GMR), coloured by chi-squared.
# Points are drawn in order of decreasing chi-squared so the best (lowest) samples end up
# on top instead of being buried under worse ones; the best-fit sample is starred.
draw_order = np.argsort(chisqr)[::-1]
plt.figure(figsize=(10, 6))
plt.scatter(samples[draw_order, 0], samples[draw_order, 1], c=chisqr[draw_order], cmap='viridis', s=10)
plt.colorbar(label='Chi-squared')
plt.scatter(samples[argmax, 0], samples[argmax, 1], marker='*', s=200, color='red', label='Best fit')
plt.xlabel('METRIC_Z')
plt.ylabel('METRIC_GMR')
plt.title('MCMC Samples Colored by Chi-squared')
plt.legend()
plt.show()

# Plot corner plot
fig = corner.corner(samples, labels=["METRIC_Z", "METRIC_GMR"], show_titles=True)
plt.show()

In [None]:
# Build the final k-correction lookup from ALL good Y1 galaxies with the chosen metric
# weights, then pickle it. The MCMC found the fit insensitive to these weights, so
# round values are used rather than the exact best-fit sample.
optimal_metric_z = 50.0
optimal_metric_gmr = 10.0
metric_absmag_r = 1.0  # M_r enters the tree unscaled

# Same quality cut as the data-preparation cell: drop galaxies missing a magnitude or z,
# or whose k_correct_gama magnitudes differ from ABSMAG01_SDSS_* by more than 1 mag.
r_mismatch = np.abs(magr_k_gama - df['ABSMAG01_SDSS_R']) > 1.0
g_mismatch = np.abs(magg_k_gama - df['ABSMAG01_SDSS_G']) > 1.0
badmatch = r_mismatch | g_mismatch
goodidx = ~badmatch
for col in ('ABS_MAG_R', 'ABS_MAG_G', 'Z', 'ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G'):
    goodidx = goodidx & ~np.isnan(df[col])

print(f"Building final lookup table with metric:")
print(f"  METRIC_Z: {optimal_metric_z}")
print(f"  METRIC_GMR: {optimal_metric_gmr}")

good = df.loc[goodidx]
z_full = good['Z'].to_numpy()
magr_full = good['ABS_MAG_R'].to_numpy()
magg_full = good['ABS_MAG_G'].to_numpy()
gmr_full = magg_full - magr_full

# k-correction = raw absolute magnitude minus the fastspecfit k-corrected magnitude
kcorr_r_full = magr_full - good['ABSMAG01_SDSS_R'].to_numpy()
kcorr_g_full = magg_full - good['ABSMAG01_SDSS_G'].to_numpy()

# Tree coordinates: (metric_z * z, metric_gmr * (g-r), M_r)
z_scaled = z_full * optimal_metric_z
gmr_scaled = gmr_full * optimal_metric_gmr
magr_scaled = magr_full
lookup_points = np.vstack((z_scaled, gmr_scaled, magr_scaled)).T
kdtree = KDTree(lookup_points)

# Lookup values are the per-galaxy k-corrections, row-aligned with lookup_points
kcorr_r_lookup, kcorr_g_lookup = kcorr_r_full, kcorr_g_full

print(f"Built KDTree with {len(lookup_points):,} galaxies")
print(f"K-correction lookup table shape: {kcorr_r_lookup.shape}")

# Pickled as (points, kcorr_r, kcorr_g, metric_z, metric_gmr, metric_absmag_r). The tree
# itself is not saved. NOTE(review): presumably kcorrlookup rebuilds it from the points; confirm.
lookup_data = (lookup_points, kcorr_r_lookup, kcorr_g_lookup,
               optimal_metric_z, optimal_metric_gmr, metric_absmag_r)
with open(BGS_Y3_KCORR_LOOKUP_FILE, 'wb') as out_file:
    pickle.dump(lookup_data, out_file)

print(f"\nSaved lookup table to {BGS_Y3_KCORR_LOOKUP_FILE}")



### dn4000lookup

In [None]:
# Load the dn4000-lookup MCMC results (presumably written by train_dn4000.py).
# Opened read-only: this notebook only inspects the chain and must never modify it.
backend2 = emcee.backends.HDFBackend(OUTPUT_FOLDER + "dn4000_lookup_optimization.h5", read_only=True)
chains = backend2.get_log_prob(flat=True)  # flattened log-probability of every sample
argmax = np.argmax(chains)                 # index of the highest-log-probability sample
samples = backend2.get_chain(flat=True)    # flattened parameter samples, row-aligned with chains
chisqr = -chains                           # -log_prob is treated as chi-squared below (read once, not twice)

print(f"Total Samples: {len(samples)}")
print(f"Best-fit parameters: {samples[argmax]}")
print(f"Best-fit chi-squared: {chisqr[argmax]}")

In [None]:
# Show the chi-squared distribution as a function of each dn4000-lookup parameter,
# one panel per parameter (column i of samples).
dn4000_param_labels = ["METRIC_GMR", "K", "INNER_METRIC_DN4000"]
n_params = len(dn4000_param_labels)

plt.figure(figsize=(12, 4))
for i, param_label in enumerate(dn4000_param_labels):
    plt.subplot(1, n_params, i + 1)
    plt.scatter(samples[:, i], chisqr, alpha=0.5)
    plt.xlabel(param_label)
    plt.ylabel("Chi-squared")

plt.tight_layout()
plt.show()

In [None]:
# Corner plot of the dn4000-lookup MCMC samples, one axis per fitted parameter
dn4000_param_labels = ["METRIC_GMR", "K", "INNER_METRIC_DN4000"]
fig = corner.corner(samples, labels=dn4000_param_labels, show_titles=True)
plt.show()

In [None]:
# Build the final Dn4000 / stellar-mass lookup and pickle it. The tree is searched in
# (M_r, metric_gmr * (g-r)) space using the fastspecfit k-corrected ABSMAG01_SDSS_* mags.
# NOTE(review): the MCMC also fits "K", which is not set here; presumably it is fixed
# inside dn4000lookup. Confirm.
optimal_metric_gmr = 5.0
optimal_innermetric_dn4000 = 5.0

# Keep galaxies that have every quantity the table is indexed by or stores
required_cols = ('ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G', 'DN4000_MODEL', 'LOGMSTAR')
goodidx = ~np.isnan(df[required_cols[0]])
for col in required_cols[1:]:
    goodidx = goodidx & ~np.isnan(df[col])

print(f"\nBuilding dn4000/mstellar final lookup table...")

good = df.loc[goodidx]
magr_scaled = good['ABSMAG01_SDSS_R'].to_numpy()  # metric_absmag_r = 1.0, i.e. unscaled
gmr_scaled = (good['ABSMAG01_SDSS_G'] - good['ABSMAG01_SDSS_R']).to_numpy() * optimal_metric_gmr
dn4000 = good['DN4000_MODEL'].to_numpy()
logmstar = good['LOGMSTAR'].to_numpy()

lookup_points = np.vstack((magr_scaled, gmr_scaled)).T
kdtree = KDTree(lookup_points)
print(f"Built KDTree with {len(lookup_points):,} galaxies")

# Unlike the k-correction table, the KDTree object itself is pickled here.
# NOTE(review): the meaning of the two literal 1.0 entries is defined by dn4000lookup; confirm.
lookup_data = (kdtree, dn4000, logmstar, optimal_metric_gmr, 1.0, optimal_innermetric_dn4000, 1.0)
with open(BGS_Y3_DN4000_LOOKUP_FILE, 'wb') as out_file:
    pickle.dump(lookup_data, out_file)

print(f"\nSaved lookup table to {BGS_Y3_DN4000_LOOKUP_FILE}")

## Validation

In [None]:
# Build (or load a cached copy of) the Y3 validation set: Y3 galaxies NOT in the Y1
# sample used to build the lookup tables, each with every column the validation needs.
# Delete the cached pickle to regenerate it.
import os

validation_file = OUTPUT_FOLDER + "y3_validation_set.pkl"
N_VALIDATE_MAX = 100000  # cap on the number of validation galaxies
VALIDATION_SEED = 42     # fixed seed so a regenerated validation set is reproducible

if os.path.exists(validation_file):
    df_valid = pd.read_pickle(validation_file)
    print(f"Loaded Y3 validation set from {validation_file}")

else:
    print(f"Creating Y3 validation set {validation_file}")
    table_y3 = Table.read(IAN_BGS_Y3_MERGED_FILE_LOA, format='fits')
    df_y3 = table_to_df(table_y3, 20.175, BGS_Z_MIN, BGS_Z_MAX, True, 1)
    print(f"Loaded {len(df_y3):,} Y3 galaxies")

    # Remove galaxies that were in Y1 (matched on TARGETID) so the validation is out-of-sample.
    # This and the completeness cut run BEFORE subsampling, so the cap counts usable galaxies.
    not_in_y1 = ~df_y3['TARGETID'].isin(df['TARGETID'].unique())
    df_y3_subset = df_y3[not_in_y1]
    print(f"After removing Y1 galaxies: {len(df_y3_subset):,} Y3 galaxies remaining")

    # Keep only galaxies with complete data for all validations
    required_cols = ['ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G', 'ABS_MAG_R', 'ABS_MAG_G',
                     'Z', 'DN4000_MODEL', 'LOGMSTAR', 'QUIESCENT']
    valid_all = df_y3_subset[required_cols].notna().all(axis=1)
    df_y3_subset = df_y3_subset[valid_all]
    print(f"{len(df_y3_subset):,} of them have complete data")

    # Random subset of up to N_VALIDATE_MAX galaxies, drawn with a seeded generator
    n_validate = min(N_VALIDATE_MAX, len(df_y3_subset))
    rng = np.random.default_rng(VALIDATION_SEED)
    validation_idx = rng.choice(len(df_y3_subset), size=n_validate, replace=False)
    df_valid = df_y3_subset.iloc[validation_idx].copy()
    print(f"Validation sample: {len(df_valid):,} Y3 galaxies with complete data\n")

    # Cache the validation set so later runs reuse exactly the same galaxies
    df_valid.to_pickle(validation_file)
    print(f"Saved Y3 validation set to {validation_file}")

In [None]:
# Run the two-step lookup on the held-out Y3 galaxies: k-corrections first, then
# Dn4000 / stellar mass from the predicted k-corrected magnitude and colour.
print("Loading lookup tables...")
kcorr_lookup = kcorrlookup()
dn4000_lookup = dn4000lookup()

# Inputs a lost galaxy would have: redshift and raw (not k-corrected) magnitudes
z_val, magr_val, magg_val = (df_valid[col].to_numpy() for col in ('Z', 'ABS_MAG_R', 'ABS_MAG_G'))
gmr_val = magg_val - magr_val

# Step 1: predicted k-corrections, applied as (raw - kcorr), the convention used to build the table
pred_kcorr_r, pred_kcorr_g = kcorr_lookup.query(gmr_val, z_val, magr_val)
pred_magr_kcorr = magr_val - pred_kcorr_r
pred_magg_kcorr = magg_val - pred_kcorr_g
pred_gmr_kcorr = pred_magg_kcorr - pred_magr_kcorr

# Step 2: Dn4000 and LOGMSTAR from the predicted k-corrected M_r and g-r
pred_dn4000, pred_logmstar = dn4000_lookup.query(pred_magr_kcorr, pred_gmr_kcorr)

# Three quiescent classifiers to compare: two on the predicted quantities, plus the
# lost-galaxy guess that uses only the raw g-r colour
pred_log_l = abs_mag_r_to_log_solar_L(pred_magr_kcorr)
pred_quiescent = is_quiescent_BGS_dn4000(pred_log_l, pred_dn4000, pred_gmr_kcorr)
pred_quiescent_gmrk = is_quiescent_BGS_gmr(pred_log_l, pred_gmr_kcorr)
pred_quiescent_gmr = is_quiescent_lost_gal_guess(gmr_val)

# Ground truth from the Y3 catalogue, using the same k-correction convention
true_kcorr_r = magr_val - df_valid['ABSMAG01_SDSS_R'].to_numpy()
true_kcorr_g = magg_val - df_valid['ABSMAG01_SDSS_G'].to_numpy()
true_dn4000, true_logmstar = (df_valid[col].to_numpy() for col in ('DN4000_MODEL', 'LOGMSTAR'))
true_quiescent = df_valid['QUIESCENT'].to_numpy().astype(bool)

# Signed errors: predicted minus true
err_kcorr_r = pred_kcorr_r - true_kcorr_r
err_kcorr_g = pred_kcorr_g - true_kcorr_g
err_dn4000 = pred_dn4000 - true_dn4000
err_logmstar = pred_logmstar - true_logmstar

# Calculate percentile-based sigma levels
def get_sigma_levels(errors, percentiles=(68.27, 95.45, 99.73)):
    """Return sigma-equivalent levels of ``|errors|`` from percentiles.

    By default the levels are the 68.27th, 95.45th and 99.73th percentiles of the
    absolute errors. These are the fractions of a Gaussian lying within 1, 2 and
    3 sigma, so the levels describe the error spread without fitting a Gaussian.
    They are symmetric (based on |error|) and are not corrected for any bias.

    Parameters
    ----------
    errors : array_like
        Prediction errors (e.g. predicted - true). NaNs propagate into the result.
    percentiles : sequence of float, optional
        Percentiles (0-100) of |errors| to return. Defaults to the Gaussian
        1/2/3-sigma fractions.

    Returns
    -------
    tuple
        One level per entry of ``percentiles`` (three by default), in order.
        All NaN when ``errors`` is empty.
    """
    abs_errors = np.abs(np.asarray(errors))
    if abs_errors.size == 0:
        # np.percentile raises on empty input; report "no information" instead
        return tuple(np.nan for _ in percentiles)
    # One call evaluates every percentile, instead of one pass over |errors| per level.
    return tuple(np.percentile(abs_errors, percentiles))

# Quiescent-classification accuracy (percent agreement with the catalogue QUIESCENT flag)
# for each of the three classifiers.
quiescent_correct, quiescent_correct_gmrk, quiescent_correct_gmr = (
    np.sum(pred == true_quiescent)
    for pred in (pred_quiescent, pred_quiescent_gmrk, pred_quiescent_gmr)
)
quiescent_accuracy = 100.0 * quiescent_correct / len(pred_quiescent)
quiescent_accuracy_gmrk = 100.0 * quiescent_correct_gmrk / len(pred_quiescent_gmrk)
quiescent_accuracy_gmr = 100.0 * quiescent_correct_gmr / len(pred_quiescent_gmr)

# Print summary results
rule_major = "=" * 70
rule_minor = "-" * 70

print(rule_major)
print("LOOKUP TABLE VALIDATION SUMMARY (Y3 DATA)")
print(rule_major)
print(f"\nSample size: {len(df_valid):,} galaxies")
print(f"Redshift range: {z_val.min():.3f} - {z_val.max():.3f}")

print("\n" + rule_minor)
print("ERROR ON LOOKED UP QUANTITIES")
print(rule_minor)
for prefix, errors in (
    ("r-band k-correction:  ", err_kcorr_r),
    ("g-band k-correction:  ", err_kcorr_g),
    ("Dn4000:               ", err_dn4000),
    ("log(M*/M_sun):        ", err_logmstar),
):
    s1, s2, s3 = get_sigma_levels(errors)
    print(f"{prefix}1σ = {s1:.4f},  2σ = {s2:.4f},  3σ = {s3:.4f}")

print("\n" + rule_minor)
print("QUIESCENT CLASSIFICATION")
print(rule_minor)
for acc_prefix, n_prefix, accuracy, n_correct, preds in (
    ("Accuracy:             ", "Correct:              ", quiescent_accuracy, quiescent_correct, pred_quiescent),
    ("Accuracy (gmrk):      ", "Correct (gmrk):       ", quiescent_accuracy_gmrk, quiescent_correct_gmrk, pred_quiescent_gmrk),
    ("Accuracy (gmr):       ", "Correct (gmr):        ", quiescent_accuracy_gmr, quiescent_correct_gmr, pred_quiescent_gmr),
):
    print(f"{acc_prefix}{accuracy:.2f}%")
    print(f"{n_prefix}{n_correct:,} / {len(preds):,}")

print("\n" + rule_major)

# Optional: compact four-panel summary figure, saved to OUTPUT_FOLDER and shown inline.
# It holds error histograms for the looked-up quantities plus the confusion matrix of
# the Dn4000-based quiescent classification.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: k-correction errors in both bands, each with its own 1σ level
# (68.27th percentile of |error|). Previously only the r-band level was shown.
ax = axes[0, 0]
ax.hist(err_kcorr_r, bins=250, alpha=0.6, label='r-band', color='r', density=True)
ax.hist(err_kcorr_g, bins=250, alpha=0.6, label='g-band', color='g', density=True)
ax.axvline(0, color='k', linestyle='--', lw=1)
s1_r, s2_r, s3_r = get_sigma_levels(err_kcorr_r)
s1_g, s2_g, s3_g = get_sigma_levels(err_kcorr_g)
ax.axvline(s1_r, color='r', linestyle=':', alpha=0.5, label=f'1σ (r)={s1_r:.3f}')
ax.axvline(-s1_r, color='r', linestyle=':', alpha=0.5)
ax.axvline(s1_g, color='g', linestyle=':', alpha=0.5, label=f'1σ (g)={s1_g:.3f}')
ax.axvline(-s1_g, color='g', linestyle=':', alpha=0.5)
ax.set_xlabel('K-correction Error (mag)')
ax.set_ylabel('Density')
ax.set_title('K-correction Errors')
ax.set_xlim(-0.5, 0.5)
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Plots 2 and 3: Dn4000 and stellar-mass errors share one layout (color=None = default colour)
for ax, errors, color, xlabel, title, xlim in (
    (axes[0, 1], err_dn4000, None, 'Dn4000 Error', 'Dn4000 Errors', (-0.5, 0.5)),
    (axes[1, 0], err_logmstar, 'green', 'log(M*) Error (dex)', 'Stellar Mass Errors', (-1.0, 1.0)),
):
    ax.hist(errors, bins=100, alpha=0.7, density=True, color=color)
    ax.axvline(0, color='k', linestyle='--', lw=1)
    s1, s2, s3 = get_sigma_levels(errors)
    ax.axvline(s1, color='r', linestyle=':', alpha=0.5, label=f'1σ={s1:.3f}')
    ax.axvline(-s1, color='r', linestyle=':', alpha=0.5)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density')
    ax.set_title(title)
    ax.set_xlim(*xlim)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

# Plot 4: quiescent-classification confusion matrix (rows = true, columns = predicted).
# Cast to bool so `~` is a logical NOT whatever dtype the classifier returns
# (`~` raises TypeError on float arrays and is a bitwise NOT on integers).
ax = axes[1, 1]
pred_q = np.asarray(pred_quiescent, dtype=bool)
true_q = np.asarray(true_quiescent, dtype=bool)
confusion = np.array([
    [np.sum(~pred_q & ~true_q), np.sum(pred_q & ~true_q)],  # true star-forming: TN, FP
    [np.sum(~pred_q & true_q), np.sum(pred_q & true_q)],    # true quiescent:    FN, TP
])
im = ax.imshow(confusion, cmap='Blues', aspect='auto')
ax.set_xticks([0, 1])
ax.set_yticks([0, 1])
ax.set_xticklabels(['Star-forming', 'Quiescent'])
ax.set_yticklabels(['Star-forming', 'Quiescent'])
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'Quiescent Classification\nAccuracy: {quiescent_accuracy:.2f}%')
for (i, j), count in np.ndenumerate(confusion):
    ax.text(j, i, f'{int(count):,}', ha="center", va="center", color="black", fontsize=10)
plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.savefig(OUTPUT_FOLDER + 'lookup_validation_summary.png', dpi=150, bbox_inches='tight')
print(f"\nSaved validation plot to {OUTPUT_FOLDER}lookup_validation_summary.png")
plt.show()