In [None]:
import glob
import os
import warnings
from pathlib import Path

import dask.dataframe as dd
import hats
import lsdb
import numpy as np
import pandas as pd
import yaml
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from hats_import.catalog.file_readers import ParquetReader
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from hats_import.pipeline import ImportArguments, pipeline_with_client

# ----------------------------
# INITIAL CONFIGURATION
# ----------------------------
# The config location can be overridden with the COMBINE_SPECZ_CONFIG
# environment variable; otherwise the original development path is used.
path_to_yaml_file = os.environ.get(
    "COMBINE_SPECZ_CONFIG",
    "/scratch/users/luigi.silva/pzserver_pipelines/combine_specz/notebooks/config.yaml",
)

with open(path_to_yaml_file, encoding="utf-8") as f:
    config = yaml.safe_load(f)

# All working directories are resolved relative to base_dir and created up front.
base_dir = config["base_dir"]
output_dir = os.path.join(base_dir, config["output_dir"])
logs_dir = os.path.join(base_dir, config["logs_dir"])
temp_dir = os.path.join(base_dir, config["temp_dir"])
for _directory in (output_dir, logs_dir, temp_dir):
    os.makedirs(_directory, exist_ok=True)

# Input spec-z catalogs, processed in list order (the first one is the
# highest-priority "left" side of every merge).
catalogs = config["inputs"]["specz"]
if not catalogs:
    raise ValueError(f"No input catalogs listed under inputs.specz in {path_to_yaml_file}")

# ----------------------------
# HELPER FUNCTIONS
# ----------------------------

def setup_cluster(config, logs_dir):
    """Create the Dask client described by ``config["executor"]``.

    When ``config["executor"]["name"] == "slurm"`` an adaptive ``SLURMCluster``
    is built from ``config["executor"]["args"]["instance"]`` (queue, cores,
    processes, memory, optional job_extra_directives and interface) and scaled
    with ``config["executor"]["args"]["adapt"]``. Any other executor name yields
    a default ``Client()``.

    Args:
        config: parsed pipeline configuration dictionary.
        logs_dir: directory where SLURM job stdout/stderr files are written.

    Returns:
        A connected ``dask.distributed.Client``.
    """
    executor = config["executor"]
    if executor["name"] != "slurm":
        return Client()

    instance = executor["args"]["instance"]
    adapt = executor["args"].get("adapt", {})
    # Build a new list: appending to the list returned by ``instance.get`` would
    # mutate ``config`` and duplicate these directives on every call.
    job_extra_directives = list(instance.get("job_extra_directives", [])) + [
        f"--output={logs_dir}/dask_job_%j.out",
        f"--error={logs_dir}/dask_job_%j.err",
    ]
    cluster = SLURMCluster(
        # Network interface for worker communication; "ib0" unless configured.
        interface=instance.get("interface", "ib0"),
        queue=instance["queue"],
        cores=instance["cores"],
        processes=instance["processes"],
        memory=instance["memory"],
        job_extra_directives=job_extra_directives,
    )
    cluster.adapt(**adapt)
    return Client(cluster)

def import_catalog(entry, artifact_name, output_path, client):
    """Import one input catalog into HATS format at ``output_path/artifact_name``.

    If the input files lack ``survey``/``year`` columns but the entry provides
    them as metadata, the columns are added as constants. The data is then
    written to a temporary parquet directory, and that directory is imported.
    The original files are imported directly only when they already contain
    both columns and an explicit ``input_file_list`` is given.

    Args:
        entry: catalog description with either ``input_file_list`` or ``path``,
            a ``columns`` mapping that must contain ``ra`` and ``dec`` keys, and
            optional ``survey``/``year`` values.
        artifact_name: name of the HATS catalog to create.
        output_path: directory that receives the HATS catalog (and the
            temporary parquet copy, if one is needed).
        client: Dask client used to run the hats-import pipeline.
    """
    input_file_list = entry.get("input_file_list")
    if input_file_list is None:
        input_file_list = [entry["path"]]

    df = dd.read_parquet(input_file_list)

    # Whether the metadata columns physically exist in the input files.
    survey_in_files = "survey" in df.columns
    year_in_files = "year" in df.columns

    if not survey_in_files and "survey" in entry:
        df["survey"] = entry["survey"]
    if not year_in_files and "year" in entry:
        df["year"] = entry["year"]

    column_names = list(entry["columns"].values())
    if "survey" in df.columns:
        column_names.append("survey")
    if "year" in df.columns:
        column_names.append("year")
    # Drop duplicates (order-preserving), e.g. when entry["columns"] already
    # lists survey/year.
    column_names = list(dict.fromkeys(column_names))
    file_reader = ParquetReader(column_names=column_names)

    if survey_in_files and year_in_files and "input_file_list" in entry:
        # Metadata is already stored in the files: import them as they are.
        import_files = input_file_list
    else:
        # Materialise a copy that includes any injected metadata columns.
        # overwrite=True clears parts left by a previous run so the glob
        # below only sees this run's output.
        temp_parquet_path = os.path.join(output_path, f"tmp_{artifact_name}_with_meta")
        df.to_parquet(temp_parquet_path, engine="pyarrow", write_index=False, overwrite=True)
        import_files = sorted(glob.glob(os.path.join(temp_parquet_path, "part.*.parquet")))

    args = ImportArguments(
        ra_column=entry["columns"]["ra"],
        dec_column=entry["columns"]["dec"],
        input_file_list=import_files,
        file_reader=file_reader,
        output_artifact_name=artifact_name,
        output_path=output_path,
    )
    pipeline_with_client(args, client)

def generate_margin_cache(hats_path, output_path, artifact_name, client, margin_threshold=1.0):
    """Build a margin cache for the HATS catalog at ``hats_path``.

    A catalog stored in a single pixel gets no margin cache. A warning is
    emitted instead and nothing is written.

    Args:
        hats_path: path of the HATS catalog to build the margin for.
        output_path: directory that receives the margin catalog.
        artifact_name: name of the margin catalog to create.
        client: Dask client used to run the hats-import pipeline.
        margin_threshold: margin size passed to ``MarginCacheArguments``
            (default 1.0, the value previously hard-coded).

    Returns:
        ``os.path.join(output_path, artifact_name)`` when the margin was
        generated, otherwise ``None``.
    """
    catalog = hats.read_hats(hats_path)
    n_pixels = len(catalog.partition_info.as_dataframe())
    if n_pixels <= 1:
        warnings.warn(f"Number of pixels is {n_pixels}. Margin cache will not be generated.")
        return None

    args = MarginCacheArguments(
        input_catalog_path=hats_path,
        output_path=output_path,
        margin_threshold=margin_threshold,
        output_artifact_name=artifact_name,
    )
    pipeline_with_client(args, client)
    return os.path.join(output_path, artifact_name)

def crossmatch_and_merge(left_catalog, right_catalog, output_xmatch_dir):
    """Crossmatch two catalogs, drop the losing source of each pair, and merge.

    The catalogs are crossmatched within 1 arcsec (nearest neighbour only). The
    result is written to ``output_xmatch_dir`` as HATS and read back. For every
    matched pair the source with the higher ``z_flag`` wins; on equal (or
    missing) flags the more recent ``year`` wins. If both comparisons are
    undecided the pair is a "tie" and neither source is removed. Losing IDs are
    dropped from their catalog, the two filtered catalogs are outer-merged on
    ``id``, and each value column is collapsed with priority left > right.

    Args:
        left_catalog: lsdb catalog with ``id``, ``ra``, ``dec``, ``z``,
            ``z_flag``, ``survey`` and ``year`` columns.
        right_catalog: lsdb catalog with the same columns.
        output_xmatch_dir: directory where the crossmatch result is written
            (overwritten if it exists).

    Returns:
        The merged catalog restricted to
        ``["id", "ra", "dec", "z", "z_flag", "survey", "year"]``.
    """
    # Perform 1 arcsec crossmatch between left and right catalogs
    xmatched = left_catalog.crossmatch(
        right_catalog, radius_arcsec=1.0, n_neighbors=1, suffixes=("left", "right")
    )
    xmatched.to_hats(output_xmatch_dir, overwrite=True)
    df = lsdb.read_hats(output_xmatch_dir)._ddf

    def decide_winner(part):
        """Return a "left"/"right"/"tie" Series for one crossmatch partition.

        Vectorised so that it also works on empty partitions, where a row-wise
        ``DataFrame.apply`` returns an empty DataFrame instead of a Series.
        Comparisons involving missing values count as False, matching the
        row-wise rule.
        """
        def _is_true(mask):
            return mask.to_numpy(dtype=bool, na_value=False)

        zf1, zf2 = part["z_flagleft"], part["z_flagright"]
        y1, y2 = part["yearleft"], part["yearright"]
        # np.select picks the first matching condition, so z_flag outranks year.
        winner = np.select(
            [_is_true(zf1 > zf2), _is_true(zf1 < zf2), _is_true(y1 > y2), _is_true(y1 < y2)],
            ["left", "right", "left", "right"],
            default="tie",
        )
        return pd.Series(winner, index=part.index, name="winner", dtype=object)

    df = df.assign(winner=df.map_partitions(decide_winner, meta=("winner", "object")))

    # Pull the match decisions to the driver. IDs are selected directly from
    # their own columns. Building a combined DataFrame would index-align them
    # and up-cast integer IDs to float, losing precision for large IDs.
    pairs = df[["winner", "idleft", "idright"]].compute()
    loser_ids_cat1 = list(pairs.loc[pairs["winner"] == "right", "idleft"].dropna().unique())
    loser_ids_cat2 = list(pairs.loc[pairs["winner"] == "left", "idright"].dropna().unique())

    # Remove discarded sources and merge the rest.
    # NOTE(review): merging on "id" assumes IDs are unique across the two
    # catalogs; rows from different catalogs that share an id would be combined.
    filtered_left = left_catalog[~left_catalog["id"].isin(loser_ids_cat1)]
    filtered_right = right_catalog[~right_catalog["id"].isin(loser_ids_cat2)]
    final_catalog = filtered_left.merge(filtered_right, how="outer", on="id", suffixes=("_left", "_right"))

    # Collapse columns using priority from left to right
    return final_catalog.assign(
        ra=final_catalog["ra_left"].combine_first(final_catalog["ra_right"]),
        dec=final_catalog["dec_left"].combine_first(final_catalog["dec_right"]),
        z=final_catalog["z_left"].combine_first(final_catalog["z_right"]),
        z_flag=final_catalog["z_flag_left"].combine_first(final_catalog["z_flag_right"]),
        survey=final_catalog["survey_left"].combine_first(final_catalog["survey_right"]),
        year=final_catalog["year_left"].combine_first(final_catalog["year_right"]),
    )[["id", "ra", "dec", "z", "z_flag", "survey", "year"]]

# ----------------------------
# PIPELINE EXECUTION FOR N CATALOGS
# ----------------------------

client = setup_cluster(config, logs_dir)

# Step 1: import the first catalog
import_catalog(catalogs[0], "cat0_hats", temp_dir, client)
cat_prev = lsdb.read_hats(os.path.join(temp_dir, "cat0_hats"))

# Column mapping for the collapsed intermediates. "survey" and "year" are not
# listed: import_catalog appends them itself when the files contain them, so
# listing them here would request each column twice.
collapsed_columns = {"id": "id", "ra": "ra", "dec": "dec", "z": "z", "z_flag": "z_flag"}

# Step 2: sequential crossmatch and merge for the remaining catalogs
for i, entry in enumerate(catalogs[1:], start=1):
    hats_path = os.path.join(temp_dir, f"cat{i}_hats")
    margin_path = os.path.join(temp_dir, f"cat{i}_margin")

    import_catalog(entry, f"cat{i}_hats", temp_dir, client)
    generate_margin_cache(hats_path, temp_dir, f"cat{i}_margin", client)
    # generate_margin_cache skips single-pixel catalogs, so only attach a
    # margin cache that was actually written.
    cat_curr = lsdb.read_hats(
        hats_path,
        margin_cache=margin_path if os.path.isdir(margin_path) else None,
    )
    cat_merged = crossmatch_and_merge(cat_prev, cat_curr, os.path.join(temp_dir, f"xmatch_{i}"))

    # Save intermediate collapsed version
    collapsed_path = os.path.join(temp_dir, f"collapsed_{i}")
    cat_merged.to_parquet(collapsed_path, by_layer=True)

    # Re-import as HATS so the next iteration can crossmatch against it
    import_catalog(
        {
            "input_file_list": sorted(glob.glob(os.path.join(collapsed_path, "base", "part.*.parquet"))),
            "file_reader": "parquet",
            "columns": collapsed_columns,
            "ra": "ra", "dec": "dec"
        },
        f"collapsed_{i}_hats", temp_dir, client
    )
    cat_prev = lsdb.read_hats(os.path.join(temp_dir, f"collapsed_{i}_hats"))

# ----------------------------
# SAVE FINAL COLLAPSED PARQUET
# ----------------------------

final_collapsed_dir = os.path.join(temp_dir, f"collapsed_{len(catalogs) - 1}")
final_output_path = os.path.join(output_dir, f"{config['output_name']}.parquet")
final_parts = sorted(glob.glob(os.path.join(final_collapsed_dir, "base", "part.*.parquet")))

# Collapsed products are written only by the crossmatch loop, so they exist
# only when two or more input catalogs are configured. Fail with a clear message
# instead of handing an empty file list to dd.read_parquet.
if not final_parts:
    raise FileNotFoundError(
        f"No collapsed parquet parts found under {os.path.join(final_collapsed_dir, 'base')}. "
        "Collapsed output is only produced when at least two input catalogs are configured."
    )

print(f"Saving single .parquet file to: {final_output_path}")

# Load by-layer parts into a single pandas DataFrame
df_final = dd.read_parquet(final_parts).compute()

# Save single-file Parquet (not directory)
df_final.to_parquet(
    final_output_path,
    engine="pyarrow",
    index=False
)

print("Pipeline completed successfully.")

# ----------------------------
# REMOVE TEMP DIRECTORY
# ----------------------------

delete_temp_in_the_end = False  # Set to True to remove temp folder after run

# Keep the intermediates by default so the validation cells below can read them.
if not delete_temp_in_the_end:
    print(f"Temporary directory kept at: {temp_dir}")
else:
    import shutil

    print(f"Removing temporary directory: {temp_dir}")
    shutil.rmtree(temp_dir, ignore_errors=True)
    print("Temporary directory removed.")

# Validation 1: first crossmatch (catalog 0 vs. catalog 1) and its collapsed result

In [None]:
# Derived from temp_dir (set in the pipeline cell) so these checks follow the
# configured run instead of a hard-coded process directory.
path_to_xmatch_1 = os.path.join(temp_dir, "xmatch_1")
path_to_collapsed_1 = os.path.join(temp_dir, "collapsed_1", "base")

In [None]:
# 1. Open the first crossmatch result and keep only its underlying Dask DataFrame
df_xmatch_1 = lsdb.read_hats(path_to_xmatch_1)._ddf

# Optional: materialise the whole crossmatch result so it can be inspected here
df_xmatch_1.compute()

In [None]:
# Read the collapsed (deduplicated) catalog from the first merge into memory
df_collapsed_1 = pd.read_parquet(path=path_to_collapsed_1)

# Show it
df_collapsed_1

In [None]:
# 2. Unique, non-null IDs from each side of the first crossmatch
ids_cat_left, ids_cat_right = (
    df_xmatch_1[column].dropna().unique().compute()
    for column in ("idleft", "idright")
)

# 3. Union of the IDs from both sides
ids_from_crossmatch = {*ids_cat_left.tolist(), *ids_cat_right.tolist()}

# 4-5. Keep only collapsed rows whose ID took part in the crossmatch, and show them
df_final_subset = df_collapsed_1.loc[df_collapsed_1["id"].isin(ids_from_crossmatch)]
df_final_subset

# Validation 2: second crossmatch (collapsed catalog vs. catalog 2) and the final output

In [None]:
# Derived from temp_dir so these checks follow the configured run. xmatch_2
# exists only when at least three input catalogs are configured.
path_to_xmatch_2 = os.path.join(temp_dir, "xmatch_2")
path_to_collapsed_final = final_output_path

In [None]:
# 1. Open the second crossmatch result and keep only its underlying Dask DataFrame
df_xmatch_2 = lsdb.read_hats(path_to_xmatch_2)._ddf

# Optional: materialise the whole crossmatch result so it can be inspected here
df_xmatch_2.compute()

In [None]:
# Read the single-file output written at the end of the pipeline (after all merges)
df_collapsed_final = pd.read_parquet(path=path_to_collapsed_final)

# Show it
df_collapsed_final

In [None]:
# 2. Unique, non-null IDs from each side of the second crossmatch
ids_cat_left, ids_cat_right = (
    df_xmatch_2[column].dropna().unique().compute()
    for column in ("idleft", "idright")
)

# 3. Union of the IDs from both sides
ids_from_crossmatch = {*ids_cat_left.tolist(), *ids_cat_right.tolist()}

# 4-5. Keep only final-catalog rows involved in the last crossmatch, and show them
df_final_subset = df_collapsed_final.loc[df_collapsed_final["id"].isin(ids_from_crossmatch)]
df_final_subset