# Catalog inspection - By Field

Perform more detailed verification of the datasets, using LSDB cone searches around a set of spatial fields to inspect the leaf parquet files.

In [1]:
import os
import lsdb
import numpy as np
import pandas as pd
from tqdm import tqdm
import itertools

from pathlib import Path

# Notebook display: never truncate rows, and render floats with 4 decimals.
pd.options.display.max_rows = None
pd.set_option("display.float_format", "{:.4f}".format)

In [2]:
# Run configuration comes from the environment (KeyError if either is unset).
VERSION = os.environ["VERSION"]
OUTPUT_DIR = Path(os.environ["OUTPUT_DIR"])

print(f"VERSION: {VERSION}", f"OUTPUT_DIR: {OUTPUT_DIR}", sep="\n")

In [3]:
# Input HATS catalogs and the validation output live under OUTPUT_DIR/<kind>/VERSION.
hats_dir = OUTPUT_DIR.joinpath("hats", VERSION)
validation_dir = OUTPUT_DIR.joinpath("validation", VERSION)
validation_dir.mkdir(exist_ok=True, parents=True)

In [4]:
# Target fields: name -> (RA, Dec) cone-search centre. Insertion order is
# the order in which fields are processed and reported.
fields = {
    # Extended Chandra Deep Field South
    "ECDFS": (53.13, -28.10),
    # Euclid Deep Field South
    "EDFS": (59.10, -48.73),
    # Low Ecliptic Latitude Field
    "Rubin_SV_38_7": (37.86, 6.98),
    # Low Galactic Latitude Field
    "Rubin_SV_95_-25": (95.00, -25.00),
    # 47 Tuc Globular Cluster
    "47_Tuc": (6.02, -72.08),
    # Fornax Dwarf Spheroidal Galaxy
    "Fornax_dSph": (40.00, -34.45),
}

# Cone-search radius: 2 degrees, expressed in arcseconds.
selection_radius_arcsec = 2.0 * 60 * 60

# Photometric bands, in the order statistics are reported.
bands = list("ugrizy")

In [5]:
# Function to compute statistics for a given band
def compute_band_stats(df, band, stat_columns):
    """Compute NaN-ignoring mean/min/max per column, plus a source count, for one band.

    Parameters
    ----------
    df : pandas.DataFrame
        Partition data. It must contain a ``band`` column and every column
        listed in ``stat_columns``.
    band : str
        Band label selected with ``df["band"] == band``.
    stat_columns : iterable of str
        Columns to summarise.

    Returns
    -------
    dict
        For each column: ``mean_<col>_<band>``, ``min_<col>_<band>`` and
        ``max_<col>_<band>`` (NaN-ignoring). Also ``len_<band>``: the number of
        rows in the band with a non-missing ``x`` value if an ``x`` column
        exists, otherwise the number of rows in the band. When the band has
        no rows, only ``{"len_<band>": 0}`` is returned.
    """
    count_key = f"len_{band}"
    band_rows = df.loc[df["band"] == band]

    if band_rows.empty:
        # Emit an explicit 0 so the count always matches the integer ("i8")
        # dtype declared for it, rather than leaving the caller's NaN.
        return {count_key: 0}

    stats = {}
    for col in stat_columns:
        col_values = band_rows[col]
        stats[f"mean_{col}_{band}"] = np.nanmean(col_values)
        stats[f"min_{col}_{band}"] = np.nanmin(col_values)
        stats[f"max_{col}_{band}"] = np.nanmax(col_values)

    if "x" in df.columns:
        # notna() also recognises pd.NA, unlike np.isnan.
        stats[count_key] = int(band_rows["x"].notna().sum())
    else:
        stats[count_key] = len(band_rows)

    return stats


# Function to compute statistics for a given DataFrame
def get_stats(df, stat_columns, out_columns):
    """Return a single-row DataFrame of per-band statistics for one partition.

    If a ``sky_source`` column is present, only rows whose flag equals False
    are kept. Every name in ``out_columns`` starts as NaN. Each entry is then
    overwritten by whatever ``compute_band_stats`` reports for each band in
    the module-level ``bands`` list.
    """
    if "sky_source" in df.columns:
        # Keep only rows flagged as not being sky sources.
        df = df[df["sky_source"].eq(False)]

    row = dict.fromkeys(out_columns, np.nan)
    for band_name in bands:
        row.update(compute_band_stats(df, band_name, stat_columns))

    return pd.DataFrame([row])


# Function to compute weighted statistics across fields
def compute_weighted_stats(result, column_names, bands):
    """Combine per-partition band statistics into one value per statistic.

    Parameters
    ----------
    result : pandas.DataFrame
        One row per partition, with ``mean_/min_/max_<column>_<band>`` and
        ``len_<band>`` columns.
    column_names : iterable of str
        Columns whose statistics are combined.
    bands : iterable of str
        Bands to combine.

    Returns
    -------
    dict
        ``mean_<column>_<band>`` is the ``len_<band>``-weighted average of the
        partition means, over partitions whose mean is present. It is NaN if
        there are none. ``min_``/``max_`` are the NaN-ignoring extremes across
        partitions. Bands whose total ``len_<band>`` is zero or missing are
        omitted.
    """
    weighted_stats = {}

    for column, band in itertools.product(column_names, bands):
        counts = result[f"len_{band}"]
        if not np.nansum(counts):
            continue  # no sources in this band

        mean_col_name = f"mean_{column}_{band}"
        min_col_name = f"min_{column}_{band}"
        max_col_name = f"max_{column}_{band}"

        means = result[mean_col_name]
        # Partitions without a mean (column all-NaN there) must leave the
        # denominator too. Otherwise their counts drag the result toward 0,
        # and an entirely-NaN column would report 0.0 instead of NaN.
        # NOTE(review): weights count rows with non-missing ``x``, which may
        # differ from the number of non-NaN values behind each mean.
        valid = means.notna() & counts.notna()
        weight_total = counts[valid].sum()
        weighted_stats[mean_col_name] = (
            (means[valid] * counts[valid]).sum() / weight_total
            if weight_total
            else np.nan
        )

        # Extremes combine directly; pandas skips NaN without all-NaN warnings.
        weighted_stats[min_col_name] = result[min_col_name].min()
        weighted_stats[max_col_name] = result[max_col_name].max()

    return weighted_stats


def run_weighted_statistics(cat):
    """Compute per-field, per-band summary statistics for a catalog.

    The candidate columns are the catalog's numeric columns, excluding:
    the HATS bookkeeping columns, names ending in "Id", and names containing
    "Mag".

    For each entry in the module-level ``fields`` dict, the function runs a
    cone search of radius ``selection_radius_arcsec``. It reduces each
    partition with ``get_stats`` and merges the partitions with
    ``compute_weighted_stats``.

    Returns a DataFrame with one column per field and one row per statistic.
    """
    schema = cat._ddf.meta
    print("starting column count", len(schema.columns))

    hats_columns = {"_healpix_29", "Norder", "Dir", "Npix"}
    column_names = [
        name
        for name in schema.select_dtypes(include=np.number)
        if name not in hats_columns and not name.endswith("Id") and "Mag" not in name
    ]
    print("effective column count", len(column_names))

    # Output schema handed to Dask: mean/min/max per (column, band), then one
    # count per band.
    meta = {
        f"{stat}_{column}_{band}": "f8"
        for column, band in itertools.product(column_names, bands)
        for stat in ("mean", "min", "max")
    }
    meta.update({f"len_{band}": "i8" for band in bands})

    per_field = {}
    for field_name, (ra, dec) in tqdm(fields.items()):
        cone = cat.cone_search(ra=ra, dec=dec, radius_arcsec=selection_radius_arcsec)
        partition_stats = cone.map_partitions(
            get_stats,
            meta=meta,
            stat_columns=column_names,
            out_columns=meta.keys(),
        ).compute()
        per_field[field_name] = compute_weighted_stats(
            partition_stats, column_names, bands
        )

    # Rows become statistics, columns become field names.
    return pd.DataFrame.from_dict(per_field, orient="index").T

In [6]:
# Omitting.  A numeric 'band' is not in 'source2'
# cat = lsdb.read_hats(hats_dir / "source2")
# final_statistics = run_weighted_statistics(cat)
# final_statistics.to_parquet(validation_dir / "source_byfield.parquet")

In [7]:
# Per-field statistics for the object_forced_source catalog.
cat = lsdb.read_hats(hats_dir.joinpath("object_forced_source"))
final_statistics = run_weighted_statistics(cat)
final_statistics.to_parquet(
    validation_dir.joinpath("object_forced_source_byfield.parquet")
)

In [6]:
# Per-field statistics for the dia_source catalog.
cat = lsdb.read_hats(hats_dir.joinpath("dia_source"))
final_statistics = run_weighted_statistics(cat)
final_statistics.to_parquet(
    validation_dir.joinpath("dia_source_byfield.parquet")
)

In [7]:
# Per-field statistics for the dia_object_forced_source catalog.
cat = lsdb.read_hats(hats_dir.joinpath("dia_object_forced_source"))
final_statistics = run_weighted_statistics(cat)
final_statistics.to_parquet(
    validation_dir.joinpath("dia_object_forced_source_byfield.parquet")
)