# Spectroscopic Properties Lookup Table

This notebook creates and validates lookup tables for the BGS data: a k-correction lookup table and a Dn4000 / stellar-mass lookup table. The k-correction lookup table uses a KDTree for efficient nearest-neighbor searching based on redshift and g-r color to provide k-corrections for absolute magnitude calculations.

## Import Required Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import astropy.io.fits as fits
from astropy.table import Table
import sys
from scipy.spatial import KDTree
import pickle
import emcee
import corner

if './SelfCalGroupFinder/py/' not in sys.path:
    sys.path.append('./SelfCalGroupFinder/py/')
from pyutils import *
from dataloc import *
from bgs_helpers import *
from groupcatalog import BGS_Z_MAX, BGS_Z_MIN

%load_ext autoreload
%autoreload 2

pd.set_option('display.max_columns', None)

## Load BGS Data

Load the merged BGS data file to use for creating the lookup table.

In [None]:
# Read the merged Y1 BGS catalog and convert it to a DataFrame. The same
# table_to_df arguments are used for the Y3 validation sample further down.
# NOTE(review): 20.175 is presumably the r-band apparent-magnitude limit — confirm
# against table_to_df.
table = Table.read(IAN_BGS_Y1_MERGED_FILE, format='fits')
df = table_to_df(table, 20.175, BGS_Z_MIN, BGS_Z_MAX, True, 1)
galaxy_count = len(df)
print(f"Loaded {galaxy_count:,} galaxies")

In [None]:
# Show every column produced by table_to_df before the frame is trimmed below.
column_index = df.columns
print(column_index)

In [None]:
# Drop every column except the ones listed here.
columns_to_keep = [
    'TARGETID', 'Z',
    'ABS_MAG_R', 'ABS_MAG_G',
    'ABSMAG01_SDSS_R', 'ABSMAG01_SDSS_G',
    'DN4000_MODEL', 'G_R_BEST', 'LOG_L_GAL', 'LOGMSTAR',
]
df = df.loc[:, columns_to_keep]

In [None]:
# Renumber rows 0..n-1, discarding the old index (rebinding instead of inplace=True).
df = df.reset_index(drop=True)

In [None]:
# Distribution of k-corrections in the r and g bands for the Y1 sample in `df`.
# The k-correction is (uncorrected abs mag) - (ABSMAG01_SDSS abs mag), the same
# convention used for the Y3 validation below.
# NOTE(review): this cell used to plot `kcorr_r_train` / `kcorr_g_train`, which are not
# defined anywhere in this notebook (so it failed on a fresh kernel); they are
# recomputed here from `df` instead.
kcorr_r_y1 = (df['ABS_MAG_R'] - df['ABSMAG01_SDSS_R']).to_numpy()
kcorr_g_y1 = (df['ABS_MAG_G'] - df['ABSMAG01_SDSS_G']).to_numpy()
# hist cannot auto-range arrays that contain NaN, so keep finite values only.
kcorr_r_y1 = kcorr_r_y1[np.isfinite(kcorr_r_y1)]
kcorr_g_y1 = kcorr_g_y1[np.isfinite(kcorr_g_y1)]

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(kcorr_r_y1, bins=1250, alpha=0.7, color='r', label='r-band')
ax.hist(kcorr_g_y1, bins=1250, alpha=0.7, color='g', label='g-band')
ax.set_xlabel('K-correction (mag)')
ax.set_ylabel('Number of galaxies')
ax.set_title('Distribution of K-corrections (Y1)')
ax.legend()
plt.show()

In [None]:
# First results.
# (This cell currently contains no code.)

## Validation of the Final `kcorrlookup` Table

In [None]:
# Load the Y3 merged catalog with the same table_to_df arguments as Y1.
# A random subset of it (not simply the first rows) is drawn in the next cell
# and used for validation.
print("Loading Y3 validation data...")
table_y3 = Table.read(IAN_BGS_Y3_MERGED_FILE_LOA, format='fits')
y3_cut_args = (20.175, BGS_Z_MIN, BGS_Z_MAX, True, 1)
df_y3 = table_to_df(table_y3, *y3_cut_args)

In [None]:
# Take a reproducible random subset of Y3 for validation (up to 100k galaxies).
# A fixed seed keeps the validation metrics below identical across re-runs.
VALIDATION_SEED = 42
validation_rng = np.random.default_rng(VALIDATION_SEED)
n_validate = min(100000, len(df_y3))
validation_idx = validation_rng.choice(len(df_y3), size=n_validate, replace=False)
df_y3_subset = df_y3.iloc[validation_idx].copy()
print(f"Loaded {n_validate:,} Y3 galaxies for validation")

# Remove galaxies that also appear in the Y1 sample `df` (matched on TARGETID).
y1_targetids = set(df['TARGETID'].values)
not_in_y1 = ~df_y3_subset['TARGETID'].isin(y1_targetids)
df_y3_subset = df_y3_subset[not_in_y1].copy()
print(f"After removing Y1 galaxies: {len(df_y3_subset):,} Y3 galaxies remaining")

# Keep only galaxies with every quantity the k-correction validation needs:
# catalog k-corrected magnitudes, uncorrected magnitudes, and a redshift.
has_kcorr = ~np.isnan(df_y3_subset['ABSMAG01_SDSS_R']) & ~np.isnan(df_y3_subset['ABSMAG01_SDSS_G'])
has_mags = ~np.isnan(df_y3_subset['ABS_MAG_R']) & ~np.isnan(df_y3_subset['ABS_MAG_G'])
has_z = ~np.isnan(df_y3_subset['Z'])
valid_y3 = has_kcorr & has_mags & has_z

df_y3_valid = df_y3_subset[valid_y3].copy()
print(f"Number of Y3 galaxies with complete data: {len(df_y3_valid):,}")

# Load the saved k-correction lookup table.
# NOTE(review): the message reports BGS_Y3_KCORR_LOOKUP_FILE but kcorrlookup() is
# constructed with no arguments — confirm it defaults to that file.
print(f"\nLoading lookup table from {BGS_Y3_KCORR_LOOKUP_FILE}...")
lookup = kcorrlookup()

# Lookup inputs: redshift, the uncorrected absolute magnitudes and their g-r colour.
z_y3 = df_y3_valid['Z'].to_numpy()
magr_y3, magg_y3 = (df_y3_valid[col].to_numpy() for col in ('ABS_MAG_R', 'ABS_MAG_G'))
gmr_y3 = magg_y3 - magr_y3

# Predicted k-corrections.
# NOTE(review): query is called here as (g-r, z, M_r), but the k-correction spot-check
# cell calls kcorrlookup.query(M_r, M_g, z). At most one of these matches the
# current signature — confirm against bgs_helpers.
pred_kcorr_r_y3, pred_kcorr_g_y3 = lookup.query(gmr_y3, z_y3, magr_y3)

# Catalog k-corrected magnitudes, their colour, and the implied "true" k-corrections
# (uncorrected minus corrected magnitude).
true_abs_mag_r_kcorr = df_y3_valid['ABSMAG01_SDSS_R'].to_numpy()
true_abs_mag_g_kcorr = df_y3_valid['ABSMAG01_SDSS_G'].to_numpy()
true_gmr_kcorr = true_abs_mag_g_kcorr - true_abs_mag_r_kcorr
true_kcorr_r_y3 = magr_y3 - true_abs_mag_r_kcorr
true_kcorr_g_y3 = magg_y3 - true_abs_mag_g_kcorr

# Magnitudes and colour after applying the predicted k-corrections.
pred_abs_mag_r_kcorr = magr_y3 - pred_kcorr_r_y3
pred_abs_mag_g_kcorr = magg_y3 - pred_kcorr_g_y3
pred_gmr_kcorr = pred_abs_mag_g_kcorr - pred_abs_mag_r_kcorr

# Residuals, always predicted minus true.
kcorr_r_error = pred_kcorr_r_y3 - true_kcorr_r_y3
kcorr_g_error = pred_kcorr_g_y3 - true_kcorr_g_y3
abs_mag_r_error = pred_abs_mag_r_kcorr - true_abs_mag_r_kcorr
gmr_error = pred_gmr_kcorr - true_gmr_kcorr

def _mae(err):
    """Mean absolute value of a residual array."""
    return np.mean(np.abs(err))


def _rms(err):
    """Root-mean-square of a residual array."""
    return np.sqrt(np.mean(err**2))


# Print validation statistics: all MAEs for a group first, then all RMS values.
banner = "=" * 60
print("\n" + banner)
print("VALIDATION RESULTS")
print(banner)

print("\nK-correction errors:")
kcorr_residuals = (("r-band k-corr", kcorr_r_error), ("g-band k-corr", kcorr_g_error))
for label, err in kcorr_residuals:
    print(f"  {label} MAE: {_mae(err):.4f} mag")
for label, err in kcorr_residuals:
    print(f"  {label} RMS: {_rms(err):.4f} mag")

print("\nK-corrected magnitude errors:")
magnitude_residuals = (("r-band abs mag", abs_mag_r_error), ("g-r color", gmr_error))
for label, err in magnitude_residuals:
    print(f"  {label} MAE: {_mae(err):.4f} mag")
for label, err in magnitude_residuals:
    print(f"  {label} RMS: {_rms(err):.4f} mag")


def _plot_kcorr_comparison(ax, truth, pred, line_range, xlabel, ylabel, title, lims=None):
    """Scatter `pred` against `truth` on `ax` with a dashed red 1:1 reference line.

    line_range -- (lo, hi) endpoints of the 1:1 line, in data units.
    lims       -- optional (lo, hi) applied to both axes; line_range should cover it
                  so the reference line spans the whole panel.
    """
    ax.scatter(truth, pred, alpha=0.1, s=1)
    ax.plot(line_range, line_range, 'r--', lw=2)
    if lims is not None:
        ax.set_xlim(*lims)
        ax.set_ylim(*lims)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


# Plot validation results
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Plot 1: K-correction comparison for r-band
_plot_kcorr_comparison(
    axes[0, 0], true_kcorr_r_y3, pred_kcorr_r_y3, line_range=(-1, 1),
    xlabel='True r-band k-correction', ylabel='Predicted r-band k-correction',
    title=f'r-band k-correction\nMAE={np.mean(np.abs(kcorr_r_error)):.4f}',
    lims=(-0.5, 1),
)

# Plot 2: K-correction comparison for g-band. The axes run to 2 mag, so the 1:1
# line must as well (previously it stopped at 1, halfway up the panel).
_plot_kcorr_comparison(
    axes[0, 1], true_kcorr_g_y3, pred_kcorr_g_y3, line_range=(-0.5, 2),
    xlabel='True g-band k-correction', ylabel='Predicted g-band k-correction',
    title=f'g-band k-correction\nMAE={np.mean(np.abs(kcorr_g_error)):.4f}',
    lims=(-0.5, 2),
)

# Plot 3: K-correction error histogram. NaNs are dropped because hist cannot
# auto-range arrays that contain them.
ax = axes[0, 2]
ax.hist(kcorr_r_error[np.isfinite(kcorr_r_error)], bins=50, alpha=0.5, label='r-band', density=True)
ax.hist(kcorr_g_error[np.isfinite(kcorr_g_error)], bins=50, alpha=0.5, label='g-band', density=True)
ax.axvline(0, color='k', linestyle='--', lw=2)
ax.set_xlabel('K-correction error (mag)')
ax.set_ylabel('Density')
ax.set_title('K-correction error distribution')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot 4: Absolute magnitude comparison
_plot_kcorr_comparison(
    axes[1, 0], true_abs_mag_r_kcorr, pred_abs_mag_r_kcorr, line_range=(-24, -16),
    xlabel='True k-corrected r-band abs mag', ylabel='Predicted k-corrected r-band abs mag',
    title=f'Absolute magnitude\nMAE={np.mean(np.abs(abs_mag_r_error)):.4f}',
)

# Plot 5: g-r color comparison
_plot_kcorr_comparison(
    axes[1, 1], true_gmr_kcorr, pred_gmr_kcorr, line_range=(0, 1.5),
    xlabel='True k-corrected g-r color', ylabel='Predicted k-corrected g-r color',
    title=f'g-r color\nMAE={np.mean(np.abs(gmr_error)):.4f}',
)

# Plot 6: Error vs redshift
ax = axes[1, 2]
ax.scatter(z_y3, abs_mag_r_error, alpha=0.1, s=1)
ax.axhline(0, color='r', linestyle='--', lw=2)
ax.set_xlabel('Redshift')
ax.set_ylabel('Abs mag error (mag)')
ax.set_title('Error vs redshift')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Validation of the Final `dn4000lookup` Table

In [None]:
# Validation of the Dn4000 / stellar-mass lookup on Y3 data.
# Reuses the Y3 subset drawn for the k-correction validation; nothing new is loaded.
print("Preparing Y3 validation data for Dn4000...")

# df_y3_subset was already restricted to galaxies not in Y1, so only the
# Dn4000-specific completeness cuts are applied here. And-ing in `not_in_y1`
# (computed before that filter, on a larger index) is redundant and forces pandas
# to realign/reindex mismatched boolean Series.
has_dn4000 = ~np.isnan(df_y3_subset['DN4000_MODEL'])
has_logmstar = ~np.isnan(df_y3_subset['LOGMSTAR'])
has_mags = ~np.isnan(df_y3_subset['ABS_MAG_R']) & ~np.isnan(df_y3_subset['ABS_MAG_G'])
valid_dn4000_y3 = has_dn4000 & has_logmstar & has_mags

df_y3_dn4000 = df_y3_subset[valid_dn4000_y3].copy()
print(f"Number of Y3 galaxies with complete Dn4000 data: {len(df_y3_dn4000):,}")

# Load the saved Dn4000 / stellar-mass lookup table.
# NOTE(review): dn4000lookup() is constructed with no arguments — confirm it defaults
# to BGS_Y3_DN4000_LOOKUP_FILE, which is what the message reports.
print(f"\nLoading Dn4000 lookup table from {BGS_Y3_DN4000_LOOKUP_FILE}...")
dn4000_lookup = dn4000lookup()

# Lookup inputs: uncorrected r-band absolute magnitude and the g-r colour.
magr_y3_dn, magg_y3_dn = (df_y3_dn4000[col].to_numpy() for col in ('ABS_MAG_R', 'ABS_MAG_G'))
gmr_y3_dn = magg_y3_dn - magr_y3_dn

# NOTE(review): query is called here as (g-r, M_r) but as (M_r, g-r) in the Dn4000
# spot-check cell; confirm the argument order against dn4000lookup.query.
pred_dn4000_y3, pred_logmstar_y3 = dn4000_lookup.query(gmr_y3_dn, magr_y3_dn)

# Catalog values and residuals (predicted minus true).
true_dn4000_y3, true_logmstar_y3 = (df_y3_dn4000[col].to_numpy() for col in ('DN4000_MODEL', 'LOGMSTAR'))
dn4000_error = pred_dn4000_y3 - true_dn4000_y3
logmstar_error = pred_logmstar_y3 - true_logmstar_y3

# Print Dn4000 and stellar-mass residual statistics (MAE, RMS, median, std).
banner_line = "=" * 60
print("\n" + banner_line)
print("DN4000 VALIDATION RESULTS")
print(banner_line)

for heading, name, err in (("Dn4000 errors:", "Dn4000", dn4000_error),
                           ("LogMstar errors:", "LogMstar", logmstar_error)):
    print(f"\n{heading}")
    print(f"  {name} MAE: {np.mean(np.abs(err)):.4f}")
    print(f"  {name} RMS: {np.sqrt(np.mean(err**2)):.4f}")
    print(f"  {name} median error: {np.median(err):.4f}")
    print(f"  {name} std: {np.std(err):.4f}")

def _dn_truth_vs_pred(ax, truth, pred, lo, hi, xlabel, ylabel, title):
    """Predicted-vs-true scatter with a dashed red 1:1 line; both axes limited to [lo, hi]."""
    ax.scatter(truth, pred, alpha=0.1, s=1)
    ax.plot([lo, hi], [lo, hi], 'r--', lw=2)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


def _dn_error_hist(ax, err, label, xlabel, title):
    """Density-normalised 50-bin histogram of residuals with a dashed zero line."""
    ax.hist(err, bins=50, alpha=0.7, label=label, density=True)
    ax.axvline(0, color='k', linestyle='--', lw=2)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Density')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _dn_error_trend(ax, x, err, xlabel, title):
    """Scatter of Dn4000 residuals against `x` with a dashed zero line."""
    ax.scatter(x, err, alpha=0.1, s=1)
    ax.axhline(0, color='r', linestyle='--', lw=2)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Dn4000 error')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


# Six-panel summary: comparisons and a histogram on top, residual trends and
# the stellar-mass histogram underneath.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
top_row, bottom_row = axes

_dn_truth_vs_pred(top_row[0], true_dn4000_y3, pred_dn4000_y3, 1.0, 2.5,
                  'True Dn4000', 'Predicted Dn4000',
                  f'Dn4000\nMAE={np.mean(np.abs(dn4000_error)):.4f}')
_dn_truth_vs_pred(top_row[1], true_logmstar_y3, pred_logmstar_y3, 8, 12,
                  'True log(M*/M_sun)', 'Predicted log(M*/M_sun)',
                  f'Stellar Mass\nMAE={np.mean(np.abs(logmstar_error)):.4f}')
_dn_error_hist(top_row[2], dn4000_error, 'Dn4000', 'Dn4000 error', 'Dn4000 error distribution')

_dn_error_trend(bottom_row[0], gmr_y3_dn, dn4000_error, 'g-r color', 'Error vs color')
_dn_error_trend(bottom_row[1], magr_y3_dn, dn4000_error, 'M_r (abs mag)', 'Error vs magnitude')
_dn_error_hist(bottom_row[2], logmstar_error, 'LogMstar', 'log(M*) error', 'Stellar mass error distribution')

plt.tight_layout()
plt.show()

## Spot Check Final Lookup Tables
For lost galaxies we want a Dn4000 value and a stellar mass, but they have no spectra. So, using their k-corrected g-r color and their absolute magnitude (once a redshift is assigned to them), we can assign a random Dn4000 drawn from similar observed galaxies. This approach helps keep whatever systematics are associated with our quiescence determination similar between lost and observed galaxies.

In [None]:
# Spot check: assign Dn4000 (and stellar mass) from the lookup for a random sample of
# Y3 galaxies, classify them as quiescent, and compare with the catalog QUIESCENT flag.
# All columns come from df_y3, the frame idx_test indexes. The previous version read
# APP_MAG_R / QUIESCENT from `df2`, which is not defined in this notebook.
# NOTE(review): assumes table_to_df provides a QUIESCENT column on df_y3 — confirm.
dn4000_lookup = dn4000lookup()
spot_rng = np.random.default_rng(42)  # fixed seed so the spot check is reproducible
# Pick up to 1000 random rows from df_y3.
n = min(len(df_y3), 1000)
idx_test = spot_rng.choice(len(df_y3), size=n, replace=False)
example_z = df_y3['Z'].values[idx_test]
example_abs_mag = app_mag_to_abs_mag(df_y3['APP_MAG_R'].values[idx_test], example_z)
dn4000_truth = df_y3['DN4000_MODEL'].values[idx_test]
example_gmr = df_y3['APP_MAG_G'].values[idx_test] - df_y3['APP_MAG_R'].values[idx_test]
quiescent_truth = df_y3['QUIESCENT'].values[idx_test]

# NOTE(review): query is called here as (M_r, g-r) but as (g-r, M_r) in the Dn4000
# validation cell; confirm the argument order against dn4000lookup.query.
nearest_dn4000, near_logmstar = dn4000_lookup.query(example_abs_mag, example_gmr)
example_q = is_quiescent_BGS_dn4000(abs_mag_r_to_log_solar_L(example_abs_mag), nearest_dn4000, example_gmr)
# TODO: use a k-corrected g-r (gmr_kcorr) for the colour-based classification.
example_q2 = is_quiescent_BGS_gmr(abs_mag_r_to_log_solar_L(example_abs_mag), example_gmr)

# Percentage of galaxies whose classification agrees with the catalog flag.
agreement = np.sum(quiescent_truth == example_q) / len(example_q) * 100
print(f"\nQuiescent classification agreement (dn4000 method): {agreement:.2f}%\n")  # previously ~76%
agreement2 = np.sum(quiescent_truth == example_q2) / len(example_q2) * 100
print(f"Quiescent classification agreement (g-r method): {agreement2:.2f}%\n")  # previously ~81%

# Show a few individual galaxies (fewer if the sample is smaller than 10).
for i in range(min(10, n)):
    print(f"Test {i}: M_r={example_abs_mag[i]:.2f}, g-r={example_gmr[i]:.2f}, T DN4000={dn4000_truth[i]:.3f}, lookup DN4000={nearest_dn4000[i]:.3f}, T Q={quiescent_truth[i].astype(int)}, lookup Q={example_q[i].astype(int)}")

In [None]:
# Spot check the k-correction lookup: k-correct a random sample of Y3 galaxies with
# k_correct_fromlookup and compare the resulting g-r colour with the catalog
# ABSMAG01_SDSS_* magnitudes.
klookup = kcorrlookup()
spot_rng = np.random.default_rng(42)  # fixed seed so the spot check is reproducible
# Pick up to 1000 random rows from df_y3.
n = min(len(df_y3), 1000)
idx_test = spot_rng.choice(len(df_y3), size=n, replace=False)
example_z = df_y3['Z'].values[idx_test]
example_abs_mag_r = app_mag_to_abs_mag(df_y3['APP_MAG_R'].values[idx_test], example_z)
example_abs_mag_g = app_mag_to_abs_mag(df_y3['APP_MAG_G'].values[idx_test], example_z)
true_abs_mag_r = df_y3['ABSMAG01_SDSS_R'].values[idx_test]
true_abs_mag_g = df_y3['ABSMAG01_SDSS_G'].values[idx_test]

# NOTE(review): the outputs of this direct query are not used below. It is also called
# as (M_r, M_g, z), whereas the validation cell calls kcorrlookup.query(g-r, z, M_r);
# confirm which order matches the current signature.
nearest_kcorr_r, nearest_kcorr_g = klookup.query(example_abs_mag_r, example_abs_mag_g, example_z)
calculated_abs_mag_r, calculated_abs_mag_g = k_correct_fromlookup(example_abs_mag_r, example_abs_mag_g, example_z)

# What % have a k-corrected g-r within 0.1 mag of the true value? Galaxies lacking a
# catalog (or computed) magnitude give NaN, which compares False and would be counted
# as a disagreement, so only finite pairs are scored.
true_gmr = true_abs_mag_g - true_abs_mag_r
calc_gmr = calculated_abs_mag_g - calculated_abs_mag_r
comparable = np.isfinite(true_gmr) & np.isfinite(calc_gmr)
n_comparable = int(np.sum(comparable))
if n_comparable > 0:
    agreement = np.sum(np.abs(true_gmr[comparable] - calc_gmr[comparable]) < 0.1) / n_comparable * 100
else:
    agreement = float('nan')
print(f"\nK-correction g-r agreement within 0.1 mag: {agreement:.2f}%\n")

# Show a few individual galaxies (fewer if the sample is smaller than 10).
for i in range(min(10, n)):
    print(f"Test {i}: z={example_z[i]:.3f}, M_r={example_abs_mag_r[i]:.2f}, M_g={example_abs_mag_g[i]:.2f}, True M_r^0.1={true_abs_mag_r[i]:.2f}, Calc M_r^0.1={calculated_abs_mag_r[i]:.2f}, True M_g^0.1={true_abs_mag_g[i]:.2f}, Calc M_g^0.1={calculated_abs_mag_g[i]:.2f}")