# Imports

In [None]:
import os
import glob
import warnings
import pyarrow.parquet as pq

from IPython.display import display
import numpy as np
import pandas as pd

import getpass
import os

# Current OS user name, interpolated into the per-user /scratch catalog paths.
user = getpass.getuser()

# Configuration

In [None]:
# =========================================================
# CONFIGURATION SECTION
# =========================================================
# Full input list for the cluster environment.
# NOTE(review): several entries hard-code the "luigi.silva" scratch directory
# instead of using {user}; confirm this is intended before re-enabling.
#input_paths = [
#    f"/scratch/users/{user}/speczs-catalogs/processed/2dfgrs_final_release.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/2dflens_final_release.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/2mrs_v240.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/processed/3dhst_v4.1.5.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/6dfgs_dr3.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/processed/astrodeep_jwst.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/processed/astrodeep-gs43.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/desi_dr1_in_lsst_dp1_fields.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/jades_dr3.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/mosdef_final_release.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/ozdes_dr2.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/primus_dr1.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/vandels_dr4.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/processed/vlt_vimos_v2.0.1.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/processed/vuds_dr1.parquet",
#    f"/scratch/users/{user}/speczs-catalogs/processed/vvds_final_release.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/johns-catalogs/z_cat_CANDELS_clean_sitcomtn-154.parquet",
#    f"/scratch/users/luigi.silva/speczs-catalogs/johns-catalogs/z_cat_NED_clean_sitcomtn-154.parquet",
#]

# Local/test inputs. glob() returns files in filesystem-dependent order, so
# sort them to get a reproducible listing.
input_paths = sorted(glob.glob('test_data/*.parquet'))

# Replace these paths when running in a different (e.g. local) environment.
final_catalog_path = "./process001/outputs/crd.parquet"
prepared_temp_dir = "./process001/temp/"

# How the combine step merged the inputs. Later cells skip the tie_result /
# compared_to analyses when this is "concatenate".
VALID_COMBINE_MODES = (
    "concatenate",
    "concatenate_and_mark_duplicates",
    "concatenate_and_remove_duplicates",
)
combine_mode = "concatenate_and_mark_duplicates"

# Fail fast on a typo: any unknown value would silently enable the
# duplicate-marking checks in the later cells.
if combine_mode not in VALID_COMBINE_MODES:
    raise ValueError(
        f"Invalid combine_mode {combine_mode!r}; expected one of {VALID_COMBINE_MODES}"
    )

# Validation

## Validation - Final

Counting input and output rows.

In [None]:
# ========================================================
# COUNT INPUT ROWS
# =========================================================
# Row counts come from each file's Parquet footer metadata, so no table data
# is loaded.
total_input_rows = 0
for path in input_paths:
    if not os.path.exists(path):
        warnings.warn(f"⚠️ File not found: {path}")
        continue
    n_rows = pq.read_metadata(path).num_rows
    print(f"{path} -> {n_rows} rows")
    total_input_rows += n_rows

print(f"✅ Total number of input rows: {total_input_rows}")

# =========================================================
# LOAD FINAL MERGED CATALOG
# =========================================================
if not os.path.exists(final_catalog_path):
    raise FileNotFoundError(f"❌ Final catalog not found: {final_catalog_path}")

df_final = pd.read_parquet(final_catalog_path)
n_final_rows = len(df_final)
print(f"✅ Total number of rows in final catalog: {n_final_rows}")

# Compare input and output counts.
# NOTE(review): presumably "concatenate" and "concatenate_and_mark_duplicates"
# keep every input row (difference 0), while removing duplicates can only
# shrink the catalog. Confirm against the combine step.
row_difference = total_input_rows - n_final_rows
print(
    f"ℹ️ Input rows - final rows = {row_difference} "
    f"(combine_mode={combine_mode!r})"
)

In [None]:
# Preview the merged catalog (rendered as this cell's output).
df_final

In [None]:
# Column names of the merged catalog.
df_final.columns

In [None]:
# Column dtypes (e.g. to see whether IDs and flags are stored as numbers or strings).
df_final.dtypes

Basic statistics.

In [None]:
# Summary statistics via pandas describe() (numeric columns by default).
df_final.describe()

Counting tie_result values.

In [None]:
if combine_mode != "concatenate":
    # Distribution of tie_result values. dropna=False counts rows with a
    # missing tie_result (shown as NaN) instead of silently leaving them out.
    print(df_final["tie_result"].value_counts(dropna=False))

Counting values of the `source` column.

In [None]:
# Number of rows per value of the `source` column (missing values excluded by default).
df_final["source"].value_counts()

Checking the percentage of unresolved objects (`tie_result == 2`), relative to all objects and to the compared objects only.

In [None]:
if combine_mode != "concatenate":
    # Share of unresolved objects (tie_result == 2), relative to all objects
    # and to the objects that took part in a comparison.
    total_all = len(df_final)

    # Objects that were compared: compared_to is neither null nor blank.
    # Whitespace-only values count as blank, as in the tie_result pattern
    # validation.
    mask_compared = df_final["compared_to"].notna() & (
        df_final["compared_to"].astype(str).str.strip() != ""
    )
    df_compared = df_final[mask_compared]

    # Parse tie_result numerically (as the other validation cells do) so that
    # values stored as strings or floats still match 2.
    tie_numeric = pd.to_numeric(df_final["tie_result"], errors="coerce")
    count_tie2 = int((tie_numeric == 2).sum())
    count_tie2_compared = int((tie_numeric[mask_compared] == 2).sum())

    # Percentages (0 when the denominator is empty)
    percent_all = (count_tie2 / total_all) * 100 if total_all > 0 else 0
    percent_compared = (count_tie2_compared / len(df_compared)) * 100 if len(df_compared) > 0 else 0

    # Formatted print
    print("📊 tie_result == 2 represents:")
    print(f"  • {percent_all:.2f}% of the total ({count_tie2} out of {total_all})")
    print(f"  • {percent_compared:.2f}% of the compared objects ({count_tie2_compared} out of {len(df_compared)})")

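Grouping compared objects into connected groups (via `compared_to`), classifying each group by Δz, survey, and flag/type ties, and showing examples for each case.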

In [None]:
if combine_mode != "concatenate":
    # =========================================================
    # ANALYZE GROUPS BY COMPARED_TO
    # =========================================================
    # Rows linked through `compared_to` form connected groups. Each group is
    # bucketed by size (pair vs. larger group), pairwise delta_z against
    # `threshold` and survey consistency, plus extra tie-breaking buckets.
    # A few diverse examples of every non-empty bucket are then displayed.
    import numpy as np
    from collections import defaultdict, deque

    threshold = 0.0005
    max_groups = 10000  # For debugging or limiting processed groups (None = no limit)

    desired_order = [
        "CRD_ID", "id", "ra", "dec", "z", "z_flag", "z_err", "instrument_type", "survey", "source",
        "z_flag_homogenized", "instrument_type_homogenized", "tie_result", "compared_to", "role"
    ]

    # Keep only rows with a non-blank 'compared_to'. Whitespace-only values
    # count as blank, matching the tie_result pattern validation.
    df_final_just_compared = df_final[
        (df_final["compared_to"].notnull()) &
        (df_final["compared_to"].astype(str).str.strip() != "")
    ]
    print(f"✅ Number of rows with non-empty compared_to: {len(df_final_just_compared)}")

    # Handle all IDs as strings. `compared_to` lists neighbor IDs as text, so
    # numeric CRD_IDs must be stringified to match them. Mixing int and str
    # keys would split one object into two graph nodes and make
    # sorted(visited) raise TypeError.
    crd_ids = df_final_just_compared["CRD_ID"].astype(str).tolist()
    compared_values = df_final_just_compared["compared_to"].tolist()

    # ------------------------------------------------------------------
    # Build an undirected adjacency graph from 'compared_to' relationships
    # ------------------------------------------------------------------
    adjacency = defaultdict(set)
    for crd_id, compared in zip(crd_ids, compared_values):
        for neighbor in str(compared).split(","):
            nb = neighbor.strip()
            if not nb:
                continue
            adjacency[crd_id].add(nb)
            adjacency[nb].add(crd_id)

    # Row positions per ID, so each group's rows are fetched with one iloc
    # instead of scanning the whole frame with isin() for every group.
    positions_by_id = defaultdict(list)
    for pos, crd_id in enumerate(crd_ids):
        positions_by_id[crd_id].append(pos)

    def get_connected_group(start_id):
        """Return the connected component containing `start_id` (BFS) as a sorted tuple of string IDs."""
        visited = set()
        queue = deque([str(start_id)])
        while queue:
            current = queue.popleft()
            if current in visited:
                continue
            visited.add(current)
            queue.extend(adjacency[current] - visited)
        return tuple(sorted(visited))

    # Buckets for case studies / inspections
    group_cases = {
        "CASE1_small_same": [],
        "CASE1_small_diff": [],
        "CASE1_large_same": [],
        "CASE1_large_diff": [],
        "CASE2_small_same": [],
        "CASE2_small_diff": [],
        "CASE2_large_same": [],
        "CASE2_large_diff": [],
        "TIE_FLAG_TYPE_BREAK_PAIR": [],
        "TIE_FLAG_TYPE_BREAK_GROUP": [],
        "SAME_FLAG_DIFF_TYPE": [],
        "SAME_SOURCE_PAIR": []  # pairs with identical `source`
    }

    processed_groups = 0
    seen_groups = set()
    grouped_ids = set()  # IDs already assigned to a discovered group

    # ------------------------------------------------------------------
    # Walk rows in order. The first row of each connected group "discovers"
    # it (shown as principal), and diagnostics are computed once per group.
    # ------------------------------------------------------------------
    for principal_id in crd_ids:
        if max_groups is not None and processed_groups >= max_groups:
            break

        # An ID that is already grouped belongs to a group handled earlier,
        # so skip it without repeating the BFS.
        if principal_id in grouped_ids:
            continue
        group_ids = get_connected_group(principal_id)
        grouped_ids.update(group_ids)
        seen_groups.add(group_ids)

        # Member rows in their original order (duplicated IDs included).
        positions = sorted(p for gid in group_ids for p in positions_by_id.get(gid, ()))
        if len(positions) < 2:
            continue
        group_df = df_final_just_compared.iloc[positions].copy()

        # Mark the row used to discover the group (only for display)
        group_df["role"] = np.where(
            group_df["CRD_ID"].astype(str) == principal_id, "principal", "compared"
        )

        # Pairwise Δz and survey signature
        z_vals = group_df["z"].to_numpy()
        surveys = group_df["survey"].values

        delta_z_matrix = np.abs(z_vals[:, None] - z_vals[None, :])
        pairwise_dz = delta_z_matrix[np.triu_indices(len(z_vals), k=1)]

        max_delta_z = float(np.max(pairwise_dz)) if len(pairwise_dz) else 0.0
        all_pairs_below_thresh = bool(np.all(pairwise_dz <= threshold)) if len(pairwise_dz) else True
        same_survey = len(set(surveys)) == 1

        # Classification by size (pair/group), Δz, and survey consistency
        if len(group_df) == 2:
            key = (
                "CASE1_small_same" if (max_delta_z <= threshold and same_survey) else
                "CASE1_small_diff" if (max_delta_z <= threshold) else
                "CASE1_large_same" if (same_survey) else
                "CASE1_large_diff"
            )
        else:
            key = (
                "CASE2_small_same" if (all_pairs_below_thresh and same_survey) else
                "CASE2_small_diff" if (all_pairs_below_thresh) else
                "CASE2_large_same" if (same_survey) else
                "CASE2_large_diff"
            )

        # Reorder columns for readability
        all_columns = list(group_df.columns)
        ordered_columns = desired_order + [col for col in all_columns if col not in desired_order]
        group_df = group_df.reindex(columns=ordered_columns)

        group_cases[key].append(group_df)

        # -----------------------------------------------
        # Additional tie-breaking investigations / buckets
        # -----------------------------------------------
        flags = set(group_df["z_flag_homogenized"].dropna())
        types = set(group_df["instrument_type_homogenized"].dropna())
        surveys_in_group = set(group_df["survey"].dropna())

        # Case: same flag, different type, different surveys — split by size
        if len(surveys_in_group) > 1 and len(flags) == 1 and len(types) > 1:
            if len(group_df) == 2:
                group_cases["TIE_FLAG_TYPE_BREAK_PAIR"].append(group_df)
            elif len(group_df) > 2:
                group_cases["TIE_FLAG_TYPE_BREAK_GROUP"].append(group_df)

        # Case: same flag, different types — only for true groups (>2)
        if len(flags) == 1 and len(types) > 1 and len(group_df) > 2:
            group_cases["SAME_FLAG_DIFF_TYPE"].append(group_df)

        # Case: pairs with the same `source` (normalized, non-null, exact match)
        if len(group_df) == 2:
            src_series = group_df["source"]
            src_valid = src_series.dropna().astype(str).str.strip().str.lower()
            if len(src_valid) == 2 and len(set(src_valid)) == 1:
                group_cases["SAME_SOURCE_PAIR"].append(group_df)

        processed_groups += 1

    print(f"✅ Processed {processed_groups} unique groups.")

    # Human-readable descriptions for each bucket
    case_descriptions = {
        "CASE1_small_same": f"pair with delta_z <= {threshold} from same survey",
        "CASE1_small_diff": f"pair with delta_z <= {threshold} from different surveys",
        "CASE1_large_same": f"pair with delta_z > {threshold} from same survey",
        "CASE1_large_diff": f"pair with delta_z > {threshold} from different surveys",
        "CASE2_small_same": f"group with all delta_z <= {threshold} from same survey",
        "CASE2_small_diff": f"group with all delta_z <= {threshold} from different surveys",
        "CASE2_large_same": f"group with some delta_z > {threshold} from same survey",
        "CASE2_large_diff": f"group with some delta_z > {threshold} from different surveys",
        "TIE_FLAG_TYPE_BREAK_PAIR": "pair with equal z_flag_homogenized, different instrument_type_homogenized, and different surveys",
        "TIE_FLAG_TYPE_BREAK_GROUP": "group with equal z_flag_homogenized, different instrument_type_homogenized, and different surveys",
        "SAME_FLAG_DIFF_TYPE": "group with same z_flag_homogenized but at least one differing instrument_type_homogenized",
        "SAME_SOURCE_PAIR": "pair with identical source (normalized, non-null)"
    }

    # Final columns to display for examples
    final_columns = [
        "CRD_ID", "ra", "dec", "z", "z_flag", "z_err",
        "z_flag_homogenized", "instrument_type", "instrument_type_homogenized",
        "tie_result", "survey", "source", "compared_to"
    ]

    def survey_signature(df):
        """
        Signature = sorted tuple of the group's surveys (for diversity sampling).
        Example: ('VANDELS', 'VVDS')
        """
        vals = df["survey"].dropna().astype(str).unique().tolist()
        return tuple(sorted(vals)) if len(vals) else ("<MISSING>",)

    MAX_EXAMPLES_PER_CASE = 5  # show up to 5 groups per bucket

    for case_name, groups in group_cases.items():
        if not groups:
            continue

        print(f"\n📌 Showing examples of {case_descriptions[case_name]} "
              f"({len(groups)} groups found):")

        # 1) Prefer groups with unique survey signatures to increase diversity
        seen_sigs = set()
        diverse_selection = []
        leftovers = []

        for g in groups:
            sig = survey_signature(g)
            if sig not in seen_sigs:
                seen_sigs.add(sig)
                diverse_selection.append(g)
            else:
                leftovers.append(g)

            if len(diverse_selection) >= MAX_EXAMPLES_PER_CASE:
                break

        # 2) If needed, fill remaining slots with leftover groups (allow repetition)
        i = 0
        while len(diverse_selection) < MAX_EXAMPLES_PER_CASE and i < len(leftovers):
            diverse_selection.append(leftovers[i])
            i += 1

        # 3) Display in order with the final selected columns
        for group in diverse_selection:
            group_to_show = group.reindex(columns=final_columns)
            display(group_to_show)
            print("-" * 80)
            print("-" * 80)

In [None]:
# --- Validate tie_result patterns for ALL connected groups built from `compared_to` ---
# Rules:
# - Pair (size == 2): allowed patterns are (2,2), (1,0), and (0,0) ONLY if both z_flag_homogenized == 6
# - Group (size > 2): allowed patterns are:
#     * exactly one "1" and the rest "0" (single winner)
#     * mix of some "2" and some "0" (no "1"s)
#     * all "2"
#     * possibly all "0" ONLY if all have z_flag_homogenized == 6
#
# We will:
# 1) build connected groups via `compared_to`
# 2) classify each group into one of the categories below (including INVALID buckets)
# 3) print a few examples per category (if any)

import numpy as np
import pandas as pd
from collections import defaultdict, deque, Counter
from IPython.display import display

# ------------ Helpers ------------
def _parse_neighbors(s):
    """Split a comma-separated `compared_to` value into a list of ID strings.

    Missing values (None/NaN/NA) and blank strings give an empty list. Each
    piece is whitespace-stripped, and empty pieces (e.g. from ",,") are dropped.
    """
    if pd.isna(s):
        return []
    text = str(s)
    if not text.strip():
        return []
    pieces = (piece.strip() for piece in text.split(","))
    return [piece for piece in pieces if piece]

def _build_adjacency(df):
    """Build an undirected adjacency mapping between CRD_IDs from `compared_to`.

    Every ID (the row's CRD_ID and each neighbor parsed from its `compared_to`)
    is stored as a string, and each link is recorded in both directions. Rows
    whose `compared_to` yields no neighbors add nothing.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the "CRD_ID" and "compared_to" columns.

    Returns
    -------
    collections.defaultdict
        Maps each string ID to the set of string IDs it is linked to.
    """
    adj = defaultdict(set)
    # Iterate only the two needed columns. iterrows() builds a full Series
    # per row, which is far slower on large catalogs.
    for u, compared in zip(df["CRD_ID"], df["compared_to"]):
        neighbors = _parse_neighbors(compared)
        if not neighbors:
            continue
        u_str = str(u)  # hoisted: same for every neighbor of this row
        for v_str in neighbors:  # _parse_neighbors already returns strings
            adj[u_str].add(v_str)
            adj[v_str].add(u_str)
    return adj

def _get_connected(start, adj, seen):
    """Return the connected component of `adj` containing `start` (BFS).

    Parameters
    ----------
    start : hashable
        Node to start from.
    adj : mapping
        Node -> set of neighbor nodes. `adj[node]` is indexed for every
        visited node, so a defaultdict(set) (as built by _build_adjacency)
        avoids KeyError on isolated nodes.
    seen : set
        Shared across calls and updated in place with every visited node, so
        the caller can skip nodes already assigned to a component.

    Returns
    -------
    tuple
        Sorted nodes of the component. Empty if `start` was already in `seen`.
    """
    queue = deque([start])
    component = []
    while queue:
        node = queue.popleft()
        if node in seen:
            continue
        seen.add(node)
        component.append(node)
        # Every member of `component` is already in `seen`, so subtracting
        # `seen` is enough. Rebuilding set(component) at each step made the
        # traversal quadratic in the component size.
        queue.extend(adj[node] - seen)
    return tuple(sorted(component))

def _all_flag6(group):
    """Tell whether every row of `group` has z_flag_homogenized equal to 6.

    Flags are coerced to numbers first (unparseable -> missing). Any missing
    flag makes the answer False. An empty group counts as True.
    """
    flags = pd.to_numeric(group["z_flag_homogenized"], errors="coerce")
    has_no_missing = flags.notna().all()
    return has_no_missing and flags.eq(6).all()

def _as_int_series(s):
    """Convert a tie_result column to plain integers.

    Non-numeric entries become NaN, every NaN is replaced by the sentinel -1
    (an invalid tie_result), and the result is cast with astype(int).
    """
    numeric = pd.to_numeric(s, errors="coerce")
    with_sentinel = numeric.fillna(-1)
    return with_sentinel.astype(int)

# Columns to display in examples
final_columns = [
    "CRD_ID", "ra", "dec", "z", "z_flag", "z_err",
    "z_flag_homogenized", "instrument_type", "instrument_type_homogenized",
    "tie_result", "survey", "source", "compared_to"
]

if combine_mode == "concatenate":
    # Same convention as the earlier cells: plain concatenation does not
    # compare duplicates, so there are no tie_result patterns to validate.
    print("ℹ️ combine_mode == 'concatenate': skipping tie_result pattern validation.")
else:
    # ------------ Filter rows with usable compared_to ------------
    # Boolean indexing already returns a new frame, so no upfront full copy.
    _work = df_final[
        df_final["compared_to"].notna()
        & (df_final["compared_to"].astype(str).str.strip() != "")
    ]
    print(f"✅ Rows with non-empty compared_to: {len(_work)}")

    # ------------ Build groups ------------
    adjacency = _build_adjacency(_work)
    seen = set()
    groups = []
    for node in list(adjacency):
        # A node not seen yet always yields a non-empty component.
        if node not in seen:
            groups.append(_get_connected(node, adjacency, seen))

    print(f"✅ Connected groups found (considering only rows with non-empty compared_to): {len(groups)}")

    # Adjacency keys are string IDs. Index rows by string-form CRD_ID once,
    # instead of converting and scanning the whole CRD_ID column per group.
    positions_by_id = defaultdict(list)
    for pos, crd_id in enumerate(_work["CRD_ID"].astype(str)):
        positions_by_id[crd_id].append(pos)

    # ------------ Classify groups by tie_result pattern ------------
    CATS = {
        # Pairs
        "PAIR_22": [],
        "PAIR_10": [],
        "PAIR_00_ALL_FLAG6": [],
        "PAIR_00_NOT_ALL_FLAG6_INVALID": [],
        "PAIR_OTHER_INVALID": [],
        # Groups (>2)
        "GROUP_SINGLE_WINNER_1and0": [],
        "GROUP_MIX_2and0": [],
        "GROUP_ALL_2": [],
        "GROUP_ALL_0_ALL_FLAG6": [],
        "GROUP_ALL_0_NOT_ALL_FLAG6_INVALID": [],
        "GROUP_OTHER_INVALID": [],
    }

    # Which categories are "allowed" (for summary)
    ALLOWED = {
        "PAIR_22",
        "PAIR_10",
        "PAIR_00_ALL_FLAG6",
        "GROUP_SINGLE_WINNER_1and0",
        "GROUP_MIX_2and0",
        "GROUP_ALL_2",
        "GROUP_ALL_0_ALL_FLAG6",
    }

    for comp in groups:
        # Rows of all members, in _work order (duplicated IDs included).
        positions = sorted(pos for node in comp for pos in positions_by_id.get(node, ()))
        if len(positions) < 2:
            continue  # ignore singletons for this validation
        g = _work.iloc[positions]
        # reindex instead of g[final_columns]: a missing display column shows
        # as NaN instead of raising KeyError.
        g_show = g.reindex(columns=final_columns)

        ties = _as_int_series(g["tie_result"])
        cnt = Counter(ties.tolist())
        size = len(g)

        if size == 2:
            # Normalize as a sorted pair
            pattern = tuple(sorted(ties.tolist()))
            if pattern == (2, 2):
                CATS["PAIR_22"].append(g_show)
            elif pattern == (0, 1):
                CATS["PAIR_10"].append(g_show)
            elif pattern == (0, 0):
                if _all_flag6(g):
                    CATS["PAIR_00_ALL_FLAG6"].append(g_show)
                else:
                    CATS["PAIR_00_NOT_ALL_FLAG6_INVALID"].append(g_show)
            else:
                CATS["PAIR_OTHER_INVALID"].append(g_show)

        else:
            # Groups (>2)
            n0 = cnt.get(0, 0)
            n1 = cnt.get(1, 0)
            n2 = cnt.get(2, 0)
            total = sum(cnt.values())

            if n1 == 1 and (n0 == total - 1) and n2 == 0:
                # one winner (1) + the rest losers (0), no 2s
                CATS["GROUP_SINGLE_WINNER_1and0"].append(g_show)
            elif n2 > 0 and n1 == 0 and (n0 + n2 == total):
                # some 2s + some 0s (no 1s)
                if n0 == 0:
                    CATS["GROUP_ALL_2"].append(g_show)
                else:
                    CATS["GROUP_MIX_2and0"].append(g_show)
            elif n0 == total:
                if _all_flag6(g):
                    CATS["GROUP_ALL_0_ALL_FLAG6"].append(g_show)
                else:
                    CATS["GROUP_ALL_0_NOT_ALL_FLAG6_INVALID"].append(g_show)
            else:
                # Anything else (e.g., 1 mixed with 2; multiple winners; unexpected values)
                CATS["GROUP_OTHER_INVALID"].append(g_show)

    # ------------ Summary ------------
    allowed_count = sum(len(CATS[k]) for k in ALLOWED)
    invalid_count = sum(len(CATS[k]) for k in CATS.keys() - ALLOWED)
    print("\n===== Validation summary =====")
    for k, v in CATS.items():
        tag = "✅" if k in ALLOWED else "⚠️"
        print(f"{tag} {k}: {len(v)} groups")
    print(f"TOTAL allowed groups: {allowed_count}")
    print(f"TOTAL invalid groups: {invalid_count}")

    # ------------ Show a few examples per category ------------
    MAX_EXAMPLES_PER_CAT = 8  # tweak as you like

    for cat, lst in CATS.items():
        if not lst:
            continue
        print("\n" + ("=" * 88))
        print(f"{'ALLOWED' if cat in ALLOWED else 'INVALID'} → {cat}  |  examples: {min(MAX_EXAMPLES_PER_CAT, len(lst))} of {len(lst)}")
        print("=" * 88)
        for i, df_ex in enumerate(lst[:MAX_EXAMPLES_PER_CAT], start=1):
            print(f"\n[{cat}] Example {i}")
            # Helpful quick glance at the pattern
            tie_list = df_ex['tie_result'].tolist()
            flag_list = df_ex['z_flag_homogenized'].tolist()
            print(f"tie_result: {tie_list} | z_flag_homogenized: {flag_list}")
            display(df_ex)
            print("-" * 88)

In [None]:
# --- Validation: when compared_to is <NA>, tie_result must be 1
# --- EXCEPTION: tie_result may be 0 ONLY if z_flag_homogenized == 6

import pandas as pd

if combine_mode == "concatenate":
    # Same convention as the earlier cells: no duplicate comparison was done.
    print("ℹ️ combine_mode == 'concatenate': skipping compared_to <NA> validation.")
else:
    # Ensure compared_to is normalized to real NA (StringDtype shows <NA>).
    # Textual placeholders such as "nan"/"None" also count as missing.
    df_final_val = df_final.copy()
    df_final_val["compared_to"] = (
        df_final_val["compared_to"]
          .astype("string").str.strip()
          .replace({"": pd.NA, "nan": pd.NA, "NaN": pd.NA, "NA": pd.NA, "<NA>": pd.NA, "None": pd.NA, "null": pd.NA})
    )

    # 1) Subset to rows where compared_to is truly missing
    na_cmp = df_final_val[df_final_val["compared_to"].isna()].copy()

    # 2) Parse tie_result and z_flag_homogenized (unparseable -> -1)
    tie_as_int   = pd.to_numeric(na_cmp["tie_result"], errors="coerce").fillna(-1).astype(int)
    zflag_as_int = pd.to_numeric(na_cmp["z_flag_homogenized"], errors="coerce").fillna(-1).astype(int)

    # 3) Valid if:
    #    - tie_result == 1
    #    - OR tie_result == 0 and z_flag_homogenized == 6
    valid_mask = (tie_as_int.eq(1)) | (tie_as_int.eq(0) & zflag_as_int.eq(6))
    violations = na_cmp[~valid_mask].copy()

    # 4) Summary
    total_na     = len(na_cmp)
    valid_count  = int(valid_mask.sum())
    invalid_count = len(violations)

    print("===== compared_to <NA> validation (tie_result rule with flag-6 exception) =====")
    print(f"Rows with compared_to <NA>: {total_na}")
    print(f"Valid:                      {valid_count}")
    print(f"INVALID:                    {invalid_count}")

    # Quick distribution to sanity-check what's inside the NA-compared set.
    # Only tabulate when there is at least one row.
    if total_na > 0:
        print("\nCrosstab of tie_result x z_flag_homogenized (parsed ints) for <NA> compared_to:")
        print(pd.crosstab(tie_as_int, zflag_as_int, dropna=False))

    # 5) Show some offending rows (if any)
    if invalid_count > 0:
        cols_to_show = [
            "CRD_ID", "ra", "dec", "z", "z_flag", "z_err",
            "z_flag_homogenized", "instrument_type", "instrument_type_homogenized",
            "tie_result", "survey", "source", "compared_to"
        ]
        print("\n⚠️ Examples of violations (up to 10):")
        # reindex: a missing display column shows as NaN instead of raising KeyError
        display(violations.reindex(columns=cols_to_show).head(10))
    else:
        print("\n✅ All rows with compared_to <NA> satisfy the rule (1, or 0 with flag==6).")

    # 6) Optional hard assertion
    # assert invalid_count == 0, "Found rows where compared_to is <NA> but tie_result is invalid (not 1, nor 0 with flag==6)."

In [None]:
import pandas as pd

if combine_mode == "concatenate":
    # Same convention as the earlier cells: no duplicate comparison was done.
    print("ℹ️ combine_mode == 'concatenate': skipping z_flag_homogenized == 6 / compared_to check.")
else:
    # --- Normalize compared_to so blanks/strings like "nan" become NA ---
    df_chk = df_final.copy()
    df_chk["compared_to_norm"] = (
        df_chk["compared_to"]
          .astype("string").str.strip()
          .replace({"": pd.NA, "nan": pd.NA, "NaN": pd.NA, "NA": pd.NA,
                    "<NA>": pd.NA, "None": pd.NA, "null": pd.NA})
    )

    # --- Check condition only on rows with z_flag_homogenized == 6 ---
    # Parse the flag numerically, like the other validation cells, so flags
    # stored as strings/objects (e.g. "6") are not silently excluded. Missing
    # or unparseable flags never match.
    zflag_num = pd.to_numeric(df_chk["z_flag_homogenized"], errors="coerce")
    mask6 = zflag_num.eq(6).fillna(False).astype(bool)
    all_na = df_chk.loc[mask6, "compared_to_norm"].isna().all()

    # all() over zero rows is vacuously True, so also report how many rows were checked.
    print(f"Rows with z_flag_homogenized == 6: {int(mask6.sum())}")
    print(f"Todos os z_flag_homogenized==6 têm compared_to <NA>? {all_na}")

    # Optional: list offending rows if any
    if not all_na:
        offenders = (
            df_chk.loc[mask6 & df_chk["compared_to_norm"].notna()]
            .reindex(columns=["CRD_ID", "z", "z_flag_homogenized", "survey", "source", "compared_to", "tie_result"])
            .head(20)
        )
        display(offenders)

## Validation - Prepared Catalogs

In [None]:
# Translation rules (dictionary equivalent of the YAML configuration).
# Top-level keys are compared with each row's `survey` value in validate_row().
# Every "z_flag_translation" / "instrument_type_translation" value is either:
#   - a direct map {original z_flag: homogenized value}, optionally with a
#     "default" fallback, or
#   - a dict with "conditions" (a list of {"expr", "value"}). Expressions are
#     evaluated in order against the row's columns, and the first true one
#     wins. An optional "default" applies when none match.
# NOTE(review): _apply_translation() always looks up direct maps by the row's
# z_flag, even for instrument_type rules. Every instrument-type rule below
# only uses "default" or "conditions", so this does not matter today.
translation_rules = {
    "2DFGRS": {
        "z_flag_translation": {1: 0, 2: 1, 3: 3, 4: 4, 5: 4},
        "instrument_type_translation": {"default": "s"},
    },
    "2DFLENS": {
        "z_flag_translation": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "instrument_type_translation": {"default": "s"},
    },
    "2MRS": {
        "z_flag_translation": {
            "conditions": [
                {"expr": "z_err == 0", "value": 3},
                {"expr": "0 < z_err < 0.0005", "value": 4},
                {"expr": "z_err >= 0.0005", "value": 3},
            ],
            "default": 0,
        },
        "instrument_type_translation": {"default": "s"},
    },
    "3D-HST": {
        "z_flag_translation": {
            "conditions": [
                {"expr": "z_best_s == 0", "value": 6},
                {"expr": "z_best_s == 1 and z_spec != -1", "value": 4},
                {"expr": "z_best_s == 2 and use_zgrism == 1 and flag1 == 0 and flag2 == 0", "value": 3},
                {"expr": "z_best_s == 3 and use_phot == 1", "value": 3},
            ],
            "default": 0,
        },
        "instrument_type_translation": {
            "conditions": [
                {"expr": "z_best_s == 1", "value": "s"},
                {"expr": "z_best_s == 2", "value": "g"},
                {"expr": "z_best_s == 3", "value": "p"},
            ],
            "default": "g",
        },
    },
    "6DFGS": {
        "z_flag_translation": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "instrument_type_translation": {"default": "s"},
    },
    "ASTRODEEP": {
        "z_flag_translation": {
            "conditions": [
                {"expr": "zspec_survey != '-'", "value": 4},
                {"expr": "zspec_survey == '-'", "value": 3},
            ],
            "default": 0,
        },
        "instrument_type_translation": {
            "conditions": [
                {"expr": "zspec_survey != '-'", "value": "s"},
                {"expr": "zspec_survey == '-'", "value": "p"},
            ],
            "default": "p",
        },
    },
    "ASTRODEEP-JWST": {
        "z_flag_translation": {
            "conditions": [
                {"expr": "zspec != -99 and z_flag < 400 and (len(str(int(z_flag))) <= 1 or int(str(int(z_flag))[-2]) <= 3)", "value": 4},
                {"expr": "zspec == -99 and z_flag < 400 and (len(str(int(z_flag))) <= 1 or int(str(int(z_flag))[-2]) <= 3)", "value": 3},
            ],
            "default": 0,
        },
        "instrument_type_translation": {
            "conditions": [
                {"expr": "zspec != -99", "value": "s"},
                {"expr": "zspec == -99", "value": "p"},
            ],
            "default": "p",
        },
    },
    "DESI": {
        "z_flag_translation": {
            "conditions": [
                {"expr": "ZCAT_PRIMARY != True", "value": 0},
                {"expr": "z_flag != 0 and ZCAT_PRIMARY == True", "value": 1},
                {"expr": "z_flag == 0 and ZCAT_PRIMARY == True and z_err < 0.0005", "value": 4},
                {"expr": "z_flag == 0 and ZCAT_PRIMARY == True and z_err >= 0.0005", "value": 3},
            ],
            "default": 0,
        },
        "instrument_type_translation": {"default": "s"},
    },
    "JADES": {
        "z_flag_translation": {4: 4, 3: 3, 2: 2, 1: 1, 0: 0},
        "instrument_type_translation": {"default": "s"},
    },
    "MOSDEF": {
        "z_flag_translation": {7: 4, 6: 3, 5: 2, 4: 2, 3: 1, 2: 1, 1: 0, 0: 0},
        "instrument_type_translation": {"default": "s"},
    },
    "OZDES": {
        "z_flag_translation": {1: 0, 2: 1, 3: 3, 4: 4, 6: 6},
        "instrument_type_translation": {"default": "s"},
    },
    "PRIMUS": {
        "z_flag_translation": {-1: 0, 2: 1, 3: 2, 4: 3},
        "instrument_type_translation": {"default": "g"},
    },
    "VANDELS": {
        "z_flag_translation": {
            0: 0, 1: 1, 2: 2, 3: 4, 4: 4, 9: 3,
            10: 0, 11: 1, 12: 2, 13: 4, 14: 4, 19: 3,
            20: 0, 21: 1, 22: 2, 23: 4, 24: 4, 29: 3,
            210: 0, 211: 1, 212: 2, 213: 4, 214: 4, 219: 3,
        },
        "instrument_type_translation": {"default": "s"},
    },
    "VIMOS": {
        "z_flag_translation": {4: 4, 3: 3, 2: 2, 1: 1, 0: 0},
        "instrument_type_translation": {"default": "s"},
    },
    "VUDS": {
        "z_flag_translation": {
            1: 1, 11: 1, 21: 1, 31: 1, 41: 1,
            2: 2, 12: 2, 22: 2, 32: 2, 42: 2, 9: 2, 19: 2, 29: 2, 39: 2, 49: 2,
            3: 3, 13: 3, 23: 3, 33: 3, 43: 3,
            4: 4, 14: 4, 24: 4, 34: 4, 44: 4,
        },
        "instrument_type_translation": {"default": "s"},
    },
    "VVDS": {
        "z_flag_translation": {
            0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 9: 2,
            10: 0, 11: 1, 12: 2, 13: 3, 14: 4, 19: 2,
            20: 0, 21: 1, 22: 2, 23: 3, 24: 4, 29: 2,
            210: 0, 211: 1, 212: 2, 213: 3, 214: 4, 219: 2,
        },
        "instrument_type_translation": {"default": "s"},
    },

    # Special cases using continuous rule and inherited type.
    # validate_row() handles CANDELS/NED itself: it bins the continuous
    # z_flag in [0, 1] to 0-4 and keeps the row's own instrument_type.
    # These entries are therefore only markers and are never looked up.
    "CANDELS": {"_special": "CANDELS_NED"},
    "NED": {"_special": "CANDELS_NED"},
}

# Compiled code objects for rule expressions, keyed by expression text. The
# rule set is small and fixed, so each expression is compiled once instead of
# once per row.
_COMPILED_RULE_EXPRS = {}


def _safe_eval_expr(expr: str, ctx: dict) -> bool:
    """Evaluate a translation-rule expression against one row's values.

    Only the names in `ctx` (column -> value), `np`, and the builtins `len`,
    `int`, `str` and `float` are visible to the expression.

    NOTE: this only restricts the available names; it is NOT a security
    sandbox. Pass only expressions from the trusted, hard-coded
    `translation_rules`, never expressions read from external input.

    Parameters
    ----------
    expr : str
        Python expression, e.g. "0 < z_err < 0.0005".
    ctx : dict
        Row values, used as the expression's local namespace.

    Returns
    -------
    bool
        Truthiness of the expression's result. False if compiling or
        evaluating raises (e.g. missing column, int() on NaN, invalid syntax).
    """
    try:
        code = _COMPILED_RULE_EXPRS.get(expr)
        if code is None:
            code = compile(expr, "<translation-rule>", "eval")
            _COMPILED_RULE_EXPRS[expr] = code
        # Allow only basic builtins and numpy
        allowed_globals = {
            "__builtins__": {"len": len, "int": int, "str": str, "float": float},
            "np": np,
        }
        return bool(eval(code, allowed_globals, ctx))
    except Exception:
        # Deliberate best-effort: a rule that cannot be evaluated simply does not match.
        return False

def _apply_translation(value_map, row_ctx):
    """Translate one row using a single translation rule.

    `value_map` may be:
      - a plain dict {original: translated}, looked up with the row's
        "z_flag" value and optionally containing a "default" fallback;
      - a dict with "conditions" (list of {"expr", "value"}), whose
        expressions are tried in order with _safe_eval_expr (first true one
        wins), plus an optional "default".

    Returns
    -------
    tuple
        (translated_value, matched). A "default" fallback counts as matched.
        When nothing applies, or `value_map` is not a dict, returns
        (NaN, False).
    """
    if not isinstance(value_map, dict):
        return np.nan, False

    if "conditions" in value_map:
        for cond in value_map["conditions"]:
            expr = cond.get("expr", "")
            if expr and _safe_eval_expr(expr, row_ctx):
                return cond.get("value", np.nan), True
    else:
        lookup_key = row_ctx.get("z_flag", np.nan)
        if lookup_key in value_map:
            return value_map[lookup_key], True

    # Nothing matched: fall back to the rule's default, when present.
    if "default" in value_map:
        return value_map["default"], True
    return np.nan, False

def _continuous_flag_to_homogenized(x):
    """
    Map a continuous quality value in [0, 1] to the homogenized 0-4 scale.

    Bins: ``0 -> 0``, ``(0, 0.7) -> 1``, ``[0.7, 0.9) -> 2``,
    ``[0.9, 0.99) -> 3``, ``[0.99, 1] -> 4``. Any other input (NaN, out of
    range, non-numeric, ``pd.NA``) maps to ``np.nan``.
    """
    # numpy scalars such as float32 / int32 / int64 are not subclasses of the
    # Python float/int types, so they must be accepted explicitly. Checking the
    # type first also keeps ``x == 0.0`` from being evaluated on pd.NA, whose
    # truth value is ambiguous.
    if not isinstance(x, (int, float, np.integer, np.floating)):
        return np.nan
    if x == 0.0:
        return 0.0
    if 0.0 < x < 0.7:
        return 1.0
    if 0.7 <= x < 0.9:
        return 2.0
    if 0.9 <= x < 0.99:
        return 3.0
    if 0.99 <= x <= 1.0:
        return 4.0
    # NaN fails every comparison above and ends up here as well.
    return np.nan


def validate_row(row):
    """
    Compute the expected homogenized values for one catalog row.

    Parameters
    ----------
    row : mapping
        Anything exposing ``.get`` and ``.items`` (e.g. a pandas Series from
        ``DataFrame.iterrows`` or a plain dict). Reads ``survey``, ``z_flag``
        and ``instrument_type``, plus any column referenced by rule expressions.

    Returns
    -------
    tuple
        ``(z_expected, type_expected)``: the expected ``z_flag_homogenized``
        and ``instrument_type_homogenized`` values. Either is ``np.nan`` when
        no rule applies (including surveys absent from ``translation_rules``).
    """
    survey = row.get("survey", None)

    # Context for rule expressions; None -> np.nan so comparisons inside the
    # expressions do not raise on missing values.
    ctx = {key: (np.nan if value is None else value) for key, value in row.items()}

    # Special cases (CANDELS, NED): continuous 0..1 z_flag, inherited type.
    if survey in ("CANDELS", "NED"):
        z_expected = _continuous_flag_to_homogenized(row.get("z_flag", np.nan))
        # The expected type is the row's own instrument_type, unchanged.
        type_expected = row.get("instrument_type", np.nan)
        return z_expected, type_expected

    # General per-survey rules.
    rules = translation_rules.get(survey, None)
    if rules is None:
        return np.nan, np.nan

    # Expected z_flag_homogenized.
    z_rules = rules.get("z_flag_translation", None)
    z_expected = np.nan if z_rules is None else _apply_translation(z_rules, ctx)[0]

    # Expected instrument_type_homogenized. Conditional and direct mappings
    # are both handled by _apply_translation, so one call covers every shape.
    t_rules = rules.get("instrument_type_translation", None)
    type_expected = np.nan if t_rules is None else _apply_translation(t_rules, ctx)[0]

    return z_expected, type_expected


# =========================================================
# VALIDATE TRANSLATIONS IN TEMP FILES
# =========================================================
# Re-derive the expected homogenized values for every row of the prepared
# parquet files (via validate_row) and report rows whose stored values differ.


def _homogenized_mismatch(expected, found):
    """
    Return True when ``expected`` and ``found`` disagree.

    Two missing values (None / NaN / pd.NA) count as a match, and a missing
    value on only one side counts as a mismatch. Missingness is checked first
    so ``!=`` is never evaluated against pd.NA, whose truth value is ambiguous.
    """
    expected_missing = pd.isna(expected)
    found_missing = pd.isna(found)
    if expected_missing or found_missing:
        return not (expected_missing and found_missing)
    return bool(expected != found)


# Sorted so files are validated (and reported) in a deterministic order.
merged_files = sorted(glob.glob(os.path.join(prepared_temp_dir, "prepared*/*.parquet")))
merged_files = [f for f in merged_files if "pipeline_sample" not in f]

if not merged_files:
    print("⚠️ No prepared parquet files found for validation.")
else:
    issues = []

    for merged_file in merged_files:
        print(f"🔍 Validating {merged_file}")
        df = pd.read_parquet(merged_file)

        for _, row in df.iterrows():
            z_exp, type_exp = validate_row(row)

            # Same check for both homogenized fields (z_flag first, then type).
            for field, expected in (
                ("z_flag_homogenized", z_exp),
                ("instrument_type_homogenized", type_exp),
            ):
                found = row[field]
                if _homogenized_mismatch(expected, found):
                    issue = row.to_dict()
                    issue["field"] = field
                    issue["expected"] = expected
                    issue["found"] = found
                    issues.append(issue)

    if issues:
        issues_df = pd.DataFrame(issues)
        display(issues_df)
        print(f"⚠️ {len(issues)} mismatches found!")
    else:
        print("✅ All homogenized fields match the expected values.")

# Time Profiler

In [None]:
import os
import re
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from collections import defaultdict

# ============================================
# 1. CONFIGURATION
# ============================================

# Directory holding the pipeline log files.
log_dir = "process001/process_info"

# Log files scanned for Starting/Finished lines (missing files are skipped).
log_files = [
    "prepare_all.log",
    "import_all.log",
    "margin_cache_all.log",
    "crossmatch_and_merge_all.log",
    "process.log"
]

# Matches "<YYYY-MM-DD-HH:MM:SS.frac>: Starting: <task> id=<id>".
START_RE = re.compile(
    r"(?P<timestamp>\d{4}-\d{2}-\d{2}-\d{2}:\d{2}:\d{2}\.\d+): Starting: (?P<task>[\w_]+) id=(?P<id>[\w\d_]+)"
)
# Matches "<YYYY-MM-DD-HH:MM:SS.frac>: Finished: <task> id=<id>".
FINISH_RE = re.compile(
    r"(?P<timestamp>\d{4}-\d{2}-\d{2}-\d{2}:\d{2}:\d{2}\.\d+): Finished: (?P<task>[\w_]+) id=(?P<id>[\w\d_]+)"
)

# "task|id" -> datetime of the FIRST Starting / Finished line seen.
start_times = {}
end_times = {}

# ============================================
# 2. READ AND PARSE THE LOGS
# ============================================

_TIMESTAMP_FORMAT = "%Y-%m-%d-%H:%M:%S.%f"
_log_matchers = ((START_RE, start_times), (FINISH_RE, end_times))

for log_name in log_files:
    log_path = os.path.join(log_dir, log_name)
    if not os.path.exists(log_path):
        continue

    with open(log_path) as handle:
        for line in handle:
            for pattern, registry in _log_matchers:
                hit = pattern.search(line)
                if hit is None:
                    continue
                task_key = f"{hit.group('task')}|{hit.group('id')}"
                # Keep only the earliest occurrence of each task id.
                if task_key not in registry:
                    registry[task_key] = datetime.strptime(
                        hit.group("timestamp"), _TIMESTAMP_FORMAT
                    )

# ============================================
# 2b. ALIGN PREPARE_CATALOG START TIMES
# ============================================
# Every "prepare_catalog|*" task is re-based to start at the start time of
# "prepare_catalogs|prepare_catalogs", when that task was logged.

prepare_catalogs_id = "prepare_catalogs|prepare_catalogs"
prepare_ids = [
    tid
    for tid in start_times
    if tid.startswith("prepare_catalog|") and tid != prepare_catalogs_id
]

if prepare_catalogs_id in start_times:
    general_prepare_start = start_times[prepare_catalogs_id]
    start_times.update(dict.fromkeys(prepare_ids, general_prepare_start))

# ============================================
# 3. BUILD THE Y-AXIS IN A CUSTOM ORDER
# ============================================
# Row order: pipeline_init, prepare_catalog tasks, import of cat0, then the
# per-step tasks grouped by the number after "cat"/"merged_step", then
# consolidate.

# Only tasks with both a start and a finish line can be drawn.
all_ids = sorted(set(start_times) & set(end_times))

pipeline_init_id = "pipeline_init|pipeline_init"
consolidate_id = "consolidate|consolidate"

import_cat0 = [tid for tid in all_ids if tid == "import_catalog|cat0_hats"]

# Build the exclusion set once instead of re-concatenating lists per element.
_excluded_ids = set(prepare_ids) | set(import_cat0) | {pipeline_init_id, consolidate_id}
remaining_ids = [tid for tid in all_ids if tid not in _excluded_ids]


def task_order(tid):
    """Rank a task inside its step: cat import, margin cache, crossmatch, merged import."""
    if tid.startswith("import_catalog|cat"):
        return 0
    elif tid.startswith("generate_margin_cache"):
        return 1
    elif tid.startswith("crossmatch_and_merge"):
        return 2
    elif tid.startswith("import_catalog|merged_step"):
        return 3
    else:
        return 99


# Group the remaining tasks by step number.
# NOTE: a task id without a "cat<N>" / "merged_step<N>" number is not plotted.
step_dict = defaultdict(list)
for tid in remaining_ids:
    match = re.search(r"(?:cat|merged_step)(\d+)", tid)
    if match:
        step = int(match.group(1))
        step_dict[step].append(tid)

ordered_step_ids = []
for step in sorted(step_dict):
    step_tasks = step_dict[step]
    ordered_step_ids.extend(sorted(step_tasks, key=task_order))

# prepare_ids comes from start_times alone; a prepare task that never logged a
# "Finished" line has no end time and would raise KeyError when the plot data
# is assembled, so it is left out of the plotted rows.
finished_prepare_ids = [tid for tid in prepare_ids if tid in end_times]

ordered_ids = [pipeline_init_id] + finished_prepare_ids + import_cat0 + ordered_step_ids + [consolidate_id]

# ============================================
# 4. ASSEMBLE THE PLOT DATA
# ============================================

# Extra time (seconds) subtracted from the start of pipeline_init.
aditional_pipeline_init_time = 3  # ⏱️ adjust as needed

# Shift the pipeline_init start back by that amount.
start_times[pipeline_init_id] = start_times[pipeline_init_id] - timedelta(
    seconds=aditional_pipeline_init_time
)

# Origin (t = 0) of the plot, taken after the shift above.
start_zero = min(map(start_times.__getitem__, ordered_ids))

# --------------------------------------------------
# 🛠️ Synthetic "register" task, appended after the last logged finish
# --------------------------------------------------
register_id = "register|register"
register_duration = 3  # ⏱️ duration of the "register" task in seconds

ordered_ids.append(register_id)
register_start = max([end_times[consolidate_id], *end_times.values()])
register_end = register_start + timedelta(seconds=register_duration)

start_times[register_id] = register_start
end_times[register_id] = register_end
# --------------------------------------------------

# Start/end of every row, in seconds relative to start_zero.
y_labels = list(ordered_ids)
start_list = [(start_times[tid] - start_zero).total_seconds() for tid in ordered_ids]
end_list = [(end_times[tid] - start_zero).total_seconds() for tid in ordered_ids]

# ============================================
# 4b. Y POSITIONS (keep "register" apart)
# ============================================

# One row per task; the last row ("register") is offset by 5 extra units so
# it is visually separated from the rest.
y_positions = [*range(len(ordered_ids))]
y_positions[-1] += 5.0

# ============================================
# 5. PLOT THE TIME PROFILE
# ============================================

plt.figure(figsize=(12, 4))

# Colour per task group.
group_colors = {
    "pipeline_init": "#003f5c",       # dark blue
    "prepare_catalogs": "#b8860b",    # dark goldenrod
    "crossmatch": "#2f855a",          # dark green
    "consolidate": "#003f5c",         # dark blue (same as pipeline_init)
    "register": "#003f5c",            # dark blue (same as pipeline_init)
}


def get_group(tid):
    """Return the colour/label group for a task id."""
    if tid == pipeline_init_id:
        return "pipeline_init"
    if tid == register_id:
        return "register"
    if tid in prepare_ids:
        return "prepare_catalogs"
    if tid == consolidate_id:
        return "consolidate"
    # Every other task (imports, margin caches, crossmatches) shares one group.
    return "crossmatch"


# One horizontal bar per task, with a dot at each end in the bar's colour.
for y, start, end, tid in zip(y_positions, start_list, end_list, ordered_ids):
    color = group_colors[get_group(tid)]
    plt.hlines(y, start, end, colors=color, linewidth=2)
    for x_edge in (start, end):
        plt.scatter(x_edge, y, color=color, s=10)

# ============================================
# Y-axis labels: one tick per group, centred on that group's rows
# ============================================

group_positions = defaultdict(list)
for y, tid in zip(y_positions, ordered_ids):
    group_positions[get_group(tid)].append(y)

group_labels = []
group_ticks = []

for label in ("pipeline_init", "prepare_catalogs", "crossmatch", "consolidate", "register"):
    rows = group_positions[label]
    if not rows:
        continue
    group_labels.append(label)
    group_ticks.append(sum(rows) / len(rows))

# ============================================
# Final figure styling
# ============================================

plt.yticks(group_ticks, group_labels, fontsize=20)
plt.xticks(fontsize=12)
plt.xlabel("Time (s)", fontsize=20)
# Optional, currently disabled:
#plt.ylabel("Task Group", fontsize=25)
#plt.title("Time Profile", fontsize=30)
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()