# Imports

In [None]:
import os
import glob
import warnings
import pyarrow.parquet as pq

from IPython.display import display
import numpy as np
import pandas as pd

# Configuration

In [None]:
# =========================================================
# CONFIGURATION SECTION
# =========================================================
# Directory that holds the processed input catalogs. Kept in one place so the
# location can be changed without editing every entry below.
catalogs_dir = "/scratch/users/luigi.silva/speczs-catalogs/processed"

# Base names (without the ".parquet" extension) of the input catalogs.
catalog_names = [
    "2dfgrs_final_release",
    "2dflens_final_release",
    "2mrs_v240",
    "6dfgs_dr3",
    "desi_dr1_in_lsst_dp1_fields",
    "jades_dr3",
    "mosdef_final_release",
    "ozdes_dr2",
    "primus_dr1",
    "vandels_dr4",
    "vvds_final_release",
]

# Full paths of the input catalogs, in the same order as catalog_names.
input_paths = [f"{catalogs_dir}/{name}.parquet" for name in catalog_names]

# Directory of the pipeline run whose outputs are validated in this notebook.
process_dir = "/scratch/users/luigi.silva/pzserver_pipelines/combine_redshift_dedup/process003"

# Final merged catalog, and the directory searched for "prepared*" parquet files
# (the trailing slash is kept so the value matches the previous literal exactly).
final_catalog_path = f"{process_dir}/outputs/crd.parquet"
prepared_temp_dir = f"{process_dir}/temp/"

# Validation

## Validation - Final

Counting input and output rows.

In [None]:
# =========================================================
# COUNT INPUT ROWS
# =========================================================
# Row counts are taken from each file's parquet metadata. Paths that do not
# exist only trigger a warning and contribute nothing to the total.
total_input_rows = 0
for path in input_paths:
    if not os.path.exists(path):
        warnings.warn(f"⚠️ File not found: {path}")
        continue

    parquet_file = pq.ParquetFile(path)
    n_rows = parquet_file.metadata.num_rows
    print(f"{path} -> {n_rows} rows")
    total_input_rows += n_rows

print(f"✅ Total number of input rows: {total_input_rows}")

# =========================================================
# LOAD FINAL MERGED CATALOG
# =========================================================
# Unlike a missing input, a missing final catalog is fatal: every later cell
# reads df_final.
if os.path.exists(final_catalog_path):
    df_final = pd.read_parquet(final_catalog_path)
else:
    raise FileNotFoundError(f"❌ Final catalog not found: {final_catalog_path}")

print(f"✅ Total number of rows in final catalog: {len(df_final)}")

In [None]:
# Show the final merged catalog as this cell's output.
df_final

Basic statistics.

In [None]:
# Summary statistics of the final merged catalog, as produced by DataFrame.describe().
df_final.describe()

Counting tie_result values.

In [None]:
# Number of rows per distinct tie_result value in the final catalog.
df_final["tie_result"].value_counts()

Counting survey values.

In [None]:
# Number of rows per distinct survey value in the final catalog.
df_final["survey"].value_counts()

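Checking the percentage of objects with `tie_result == 2`, both over the whole catalog and over the objects that were compared to at least one other object (non-empty `compared_to`).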

In [None]:
# Size of the whole final catalog
total_all = len(df_final)

# Rows that were compared to something: compared_to is neither missing nor ""
mask_compared = df_final["compared_to"].notna() & df_final["compared_to"].ne("")
df_compared = df_final.loc[mask_compared]
n_compared = len(df_compared)

# Occurrences of tie_result == 2 in the full catalog and in the compared subset
count_tie2 = df_final["tie_result"].eq(2).sum()
count_tie2_compared = df_compared["tie_result"].eq(2).sum()

# Percentages, falling back to 0 when the denominator is empty
if total_all > 0:
    percent_all = (count_tie2 / total_all) * 100
else:
    percent_all = 0

if n_compared > 0:
    percent_compared = (count_tie2_compared / n_compared) * 100
else:
    percent_compared = 0

# Formatted report
print("📊 tie_result == 2 represents:")
print(f"  • {percent_all:.2f}% of the total ({count_tie2} out of {total_all})")
print(f"  • {percent_compared:.2f}% of the compared objects ({count_tie2_compared} out of {n_compared})")

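Grouping the objects linked through `compared_to`, classifying each group by its size, its pairwise redshift differences and whether all members come from the same survey, and displaying a few example groups per case.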

In [None]:
%%time

# =========================================================
# ANALYZE GROUPS BY COMPARED_TO
# =========================================================
threshold = 0.0005   # largest |delta z| between two members still classified as "small"
max_groups = 20000   # cap on the number of groups classified; None processes all (use e.g. 50 for debugging)
n_examples = 5       # number of example groups displayed per case

desired_order = [
    "CRD_ID", "id", "ra", "dec", "z", "z_flag", "z_err", "type", "survey", "source",
    "z_flag_homogenized", "instrument_type_homogenized", "tie_result", "compared_to", "role"
]

# Keep the filtered rows under their own name instead of rebinding df_final, so
# that cells reading the full catalog give the same result when re-run after
# this cell.
has_partner = df_final["compared_to"].notnull() & (df_final["compared_to"] != "")
df_compared = df_final[has_partner]
print(f"✅ Number of rows with non-empty compared_to: {len(df_compared)}")

group_cases = {
    "CASE1_small_same": [],
    "CASE1_small_diff": [],
    "CASE1_large_same": [],
    "CASE1_large_diff": [],
    "CASE2_small_same": [],
    "CASE2_small_diff": [],
    "CASE2_large_same": [],
    "CASE2_large_diff": [],
    "TIE_FLAG_TYPE_BREAK_PAIR": [],
    "TIE_FLAG_TYPE_BREAK_GROUP": [],
    "SAME_FLAG_DIFF_TYPE": []
}

# Column order for every displayed group: the desired columns first, then the
# remaining ones. It is the same for all groups, so it is computed once
# ("role" is added per group below and is already part of desired_order).
ordered_columns = desired_order + [col for col in df_compared.columns if col not in desired_order]

# Looking up group members by CRD_ID through an Index avoids scanning the whole
# frame for every group. get_indexer requires unique labels, so the boolean
# isin() scan is kept as a fallback for the non-unique case.
crd_index = pd.Index(df_compared["CRD_ID"])
crd_ids_unique = crd_index.is_unique

processed_groups = 0
seen_groups = set()

for crd_id, compared_to in df_compared[["CRD_ID", "compared_to"]].itertuples(index=False, name=None):
    if max_groups is not None and processed_groups >= max_groups:
        break

    # A group is the current object plus every id listed in its compared_to.
    # Tokens are stripped and empty ones dropped so that "a, b" or a trailing
    # comma cannot produce ids that match nothing.
    partner_ids = [token.strip() for token in compared_to.split(",")]
    group_ids = tuple(sorted([crd_id] + [pid for pid in partner_ids if pid]))
    if group_ids in seen_groups:
        continue
    seen_groups.add(group_ids)

    if crd_ids_unique:
        positions = crd_index.get_indexer(list(group_ids))
        # -1 marks ids absent from the frame; np.unique also sorts the
        # positions, which keeps the members in their original row order.
        group_df = df_compared.iloc[np.unique(positions[positions >= 0])].copy()
    else:
        group_df = df_compared[df_compared["CRD_ID"].isin(group_ids)].copy()
    if len(group_df) < 2:
        continue

    # The row that introduced the group is the "principal"; the rest are "compared".
    group_df["role"] = np.where(group_df["CRD_ID"] == crd_id, "principal", "compared")

    # Absolute redshift difference of every unordered pair of members.
    z_vals = group_df["z"].to_numpy()
    delta_z_matrix = np.abs(z_vals[:, None] - z_vals[None, :])
    pairwise_dz = delta_z_matrix[np.triu_indices(len(z_vals), k=1)]

    all_pairs_below_thresh = np.all(pairwise_dz <= threshold)
    same_survey = len(set(group_df["survey"].values)) == 1

    # Classify group: CASE1 = pair, CASE2 = three or more members;
    # small = every pairwise delta_z <= threshold; same = a single survey.
    size_tag = "CASE1" if len(group_df) == 2 else "CASE2"
    dz_tag = "small" if all_pairs_below_thresh else "large"
    survey_tag = "same" if same_survey else "diff"
    key = f"{size_tag}_{dz_tag}_{survey_tag}"

    # Reorder columns
    group_df = group_df.reindex(columns=ordered_columns)

    group_cases[key].append(group_df)

    # Additional cases: members share one homogenized flag but differ in
    # homogenized instrument type (optionally also coming from several surveys).
    flags = set(group_df["z_flag_homogenized"].dropna())
    types = set(group_df["instrument_type_homogenized"].dropna())
    surveys_in_group = set(group_df["survey"].dropna())

    if len(flags) == 1 and len(types) > 1:
        if len(surveys_in_group) > 1:
            case_key = "TIE_FLAG_TYPE_BREAK_PAIR" if len(group_df) == 2 else "TIE_FLAG_TYPE_BREAK_GROUP"
            group_cases[case_key].append(group_df)
        group_cases["SAME_FLAG_DIFF_TYPE"].append(group_df)

    processed_groups += 1

print(f"✅ Processed {processed_groups} unique groups.")

# Case descriptions
case_descriptions = {
    "CASE1_small_same": f"pair with delta_z <= {threshold} from same survey",
    "CASE1_small_diff": f"pair with delta_z <= {threshold} from different surveys",
    "CASE1_large_same": f"pair with delta_z > {threshold} from same survey",
    "CASE1_large_diff": f"pair with delta_z > {threshold} from different surveys",
    "CASE2_small_same": f"group with all delta_z <= {threshold} from same survey",
    "CASE2_small_diff": f"group with all delta_z <= {threshold} from different surveys",
    "CASE2_large_same": f"group with some delta_z > {threshold} from same survey",
    "CASE2_large_diff": f"group with some delta_z > {threshold} from different surveys",
    "TIE_FLAG_TYPE_BREAK_PAIR": "pair with equal z_flag_homogenized, different instrument_type_homogenized, and different surveys",
    "TIE_FLAG_TYPE_BREAK_GROUP": "group with equal z_flag_homogenized, different instrument_type_homogenized, and different surveys",
    "SAME_FLAG_DIFF_TYPE": "group with same z_flag_homogenized but at least one differing instrument_type_homogenized"
}

# Display up to n_examples examples per case
for case_name, groups in group_cases.items():
    if not groups:
        continue
    print(f"\n📌 Showing examples of {case_descriptions[case_name]} "
          f"({len(groups)} groups found):")
    for group in groups[:n_examples]:
        display(group)
        print("-" * 80)
## Validation - Prepared Catalogs

In [None]:
# =========================================================
# VALIDATE TRANSLATIONS IN TEMP FILES
# =========================================================
# Collect every parquet file located in a sub-directory of prepared_temp_dir
# whose name starts with "prepared".
merged_files = glob.glob(os.path.join(prepared_temp_dir, "prepared*/*.parquet"))
if not merged_files:
    # Nothing matched the pattern: report it and skip the validation.
    print("⚠️ No prepared parquet files found for validation.")
else:
    # Expected instrument_type_homogenized for every supported survey.
    _EXPECTED_TYPE = {
        "2DFGRS": "s", "2DFLENS": "s", "2MRS": "s", "6DFGS": "s", "DESI": "s",
        "JADES": "s", "MOSDEF": "s", "OZDES": "s", "PRIMUS": "g",
        "VANDELS": "s", "VVDS": "s",
    }

    # Surveys whose expected z_flag_homogenized is a plain lookup on z_flag.
    # 2MRS and DESI are absent on purpose: they need z_err / ZCAT_PRIMARY.
    _FLAG_MAPS = {
        "2DFGRS": {1: 0, 2: 1, 3: 3, 4: 4, 5: 4},
        "2DFLENS": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "6DFGS": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "JADES": {0: 0, 1: 1, 2: 2, 3: 3, 4: 4},
        "MOSDEF": {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2, 6: 3, 7: 4},
        "OZDES": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "PRIMUS": {-1: 0, 2: 1, 3: 2, 4: 3},
        "VANDELS": {
            0: 0, 1: 1, 2: 2, 3: 4, 4: 4, 9: 3,
            10: 0, 11: 1, 12: 2, 13: 4, 14: 4, 19: 3,
            20: 0, 21: 1, 22: 2, 23: 4, 24: 4, 29: 3,
            210: 0, 211: 1, 212: 2, 213: 4, 214: 4, 219: 3,
        },
        "VVDS": {
            0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 9: 2,
            10: 0, 11: 1, 12: 2, 13: 3, 14: 4, 19: 2,
            20: 0, 21: 1, 22: 2, 23: 3, 24: 4, 29: 2,
            210: 0, 211: 1, 212: 2, 213: 3, 214: 4, 219: 2,
        },
    }

    def validate_row(row, z_err_threshold=0.0005):
        """Compute the homogenized values expected for one catalog row.

        Parameters
        ----------
        row : mapping-like (pandas Series or dict)
            Must provide "survey"; "z_flag", "z_err" and "ZCAT_PRIMARY" are read
            with ``.get`` and may be absent.
        z_err_threshold : float, optional
            Redshift-error boundary used by the 2MRS and DESI rules: errors
            below it map to flag 4, errors at or above it map to flag 3.

        Returns
        -------
        tuple
            ``(z_expected, type_expected)``: the expected z_flag_homogenized and
            instrument_type_homogenized. Either element is NaN when no rule
            applies (unknown survey, unmapped flag, missing error).
        """
        survey = row["survey"]
        z_flag = row.get("z_flag", None)
        z_err = row.get("z_err", None)
        zcat_primary = row.get("ZCAT_PRIMARY", None)

        # Treat every kind of missing error (None, NaN, pd.NA) the same way;
        # comparing pd.NA with a number would otherwise raise inside an `if`.
        if z_err is not None and pd.isna(z_err):
            z_err = None

        type_expected = _EXPECTED_TYPE.get(survey, np.nan)
        z_expected = np.nan

        if survey in _FLAG_MAPS:
            z_expected = _FLAG_MAPS[survey].get(z_flag, np.nan)

        elif survey == "2MRS":
            # Quality is derived from the redshift error alone.
            if z_err is not None:
                if z_err == 0:
                    z_expected = 3
                elif 0 < z_err < z_err_threshold:
                    z_expected = 4
                elif z_err >= z_err_threshold:
                    z_expected = 3

        elif survey == "DESI":
            # Accept numpy booleans as well: `np.True_ is True` is False, so an
            # identity check would classify a primary row as non-primary.
            is_primary = isinstance(zcat_primary, (bool, np.bool_)) and bool(zcat_primary)
            # A missing flag counts as "not zero", like any other non-zero flag.
            flag_is_zero = (not pd.isna(z_flag)) and z_flag == 0

            if not is_primary:
                z_expected = 0
            elif not flag_is_zero:
                z_expected = 1
            elif z_err is not None:
                z_expected = 4 if z_err < z_err_threshold else 3

        return z_expected, type_expected


    def _is_mismatch(expected, found):
        """Return True when expected and found differ; two missing values count as equal."""
        expected_missing = pd.isna(expected)
        found_missing = pd.isna(found)
        if expected_missing or found_missing:
            # Decided without evaluating `expected != found`, which yields pd.NA
            # (and raises inside an `if`) when one side is pd.NA.
            return not (expected_missing and found_missing)
        return bool(expected != found)

    # One record per mismatching (row, field): the full row plus the field name,
    # the expected value and the value found in the file.
    issues = []

    for merged_file in merged_files:
        print(f"🔍 Validating {merged_file}")
        df = pd.read_parquet(merged_file)

        for _, row in df.iterrows():
            z_exp, type_exp = validate_row(row)

            checks = (
                ("z_flag_homogenized", z_exp),
                ("instrument_type_homogenized", type_exp),
            )
            for field, expected in checks:
                found = row[field]
                if _is_mismatch(expected, found):
                    issue = row.to_dict()
                    issue["field"] = field
                    issue["expected"] = expected
                    issue["found"] = found
                    issues.append(issue)

    if issues:
        issues_df = pd.DataFrame(issues)
        display(issues_df)
        print(f"⚠️ {len(issues)} mismatches found!")
    else:
        print("✅ All homogenized fields match the expected values.")