## Post-processing

- For each object, keep only the row with the latest `validityStart`.
- Add science magnitude columns (and their errors) computed from the flux columns.
- Cast all float64 columns except the positional and time columns to float32.

In [1]:
import astropy.units as u
import hats
import os
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

from dask.distributed import as_completed, Client
from hats.catalog import PartitionInfo
from hats.io import paths
from hats.io.parquet_metadata import write_parquet_metadata
from datetime import datetime, timezone
from pathlib import Path
from tqdm import tqdm

In [2]:
# Working directory holding the catalogs to post-process: "$OUTPUT_DIR/tmp".
TMP_DIR = Path(os.environ["OUTPUT_DIR"], "tmp")

In [3]:
# Dask client with 16 single-threaded workers, using TMP_DIR as the workers'
# local directory.
client = Client(
    local_directory=TMP_DIR,
    n_workers=16,
    threads_per_worker=1,
)

In [None]:
def select_by_latest_validity(table):
    """Select rows with the latest validityStart for each object.

    Parameters
    ----------
    table : pandas.DataFrame
        Table with at least ``diaObjectId`` and ``validityStart`` columns.

    Returns
    -------
    pandas.DataFrame
        One row per ``diaObjectId``: the row with the greatest
        ``validityStart``. A missing ``validityStart`` ranks below every
        known value. Among rows that tie, the one appearing last in
        ``table`` wins. Retained rows keep their original index labels and
        their original relative order.
    """
    # Sort only the two key columns, keyed by position, so the chosen rows
    # can be taken from ``table`` in their original order whatever its
    # index looks like.
    keys = table[["diaObjectId", "validityStart"]].reset_index(drop=True)
    latest_positions = (
        keys.sort_values(
            "validityStart",
            # Stable sort: ties keep input order, so keep="last" below picks
            # the last-occurring row deterministically.
            kind="stable",
            # Missing starts sort first so they never win over a known start.
            na_position="first",
        )
        .drop_duplicates("diaObjectId", keep="last")
        .index.sort_values()
    )
    return table.iloc[latest_positions]


def append_mag_and_magerr(table, flux_cols):
    """Append AB magnitudes and magnitude errors computed from flux columns.

    Parameters
    ----------
    table : pandas.DataFrame
        Table holding, for every name ``X`` in ``flux_cols``, a flux column
        ``X`` and its error column ``XErr``. Values are converted as nJy.
    flux_cols : iterable of str
        Names of the flux columns to convert.

    Returns
    -------
    pandas.DataFrame
        ``table`` with two extra pyarrow float32 columns per flux column.
        Both are named after the flux column with ``"Flux"`` replaced by
        ``"Mag"``:

        - the magnitude itself;
        - the same name plus ``"Err"``: half the magnitude difference
          between ``flux - err`` and ``flux + err``.
    """

    def to_abmag(nanojansky):
        return u.nJy.to(u.ABmag, nanojansky)

    derived = {}
    for flux_name in flux_cols:
        mag_name = flux_name.replace("Flux", "Mag")
        flux = table[flux_name]
        sigma = table[f"{flux_name}Err"]

        derived[mag_name] = to_abmag(flux)
        # Magnitude decreases as flux increases, so the faint side
        # (flux - sigma) has the larger magnitude; half the spread is the
        # error. NOTE(review): non-positive fluxes presumably yield NaN here
        # (see the ufunc warnings in the cell outputs).
        faint_mag = to_abmag(flux - sigma)
        bright_mag = to_abmag(flux + sigma)
        derived[f"{mag_name}Err"] = (faint_mag - bright_mag) / 2

    mag_table = pd.DataFrame(
        derived,
        index=table.index,
        dtype=pd.ArrowDtype(pa.float32()),
    )
    return pd.concat([table, mag_table], axis=1)


def cast_columns_float32(table):
    """Cast non-(positional/time) columns to single-precision.

    Every column whose dtype is pyarrow-backed float64 is cast to
    pyarrow-backed float32, except the positional/time columns listed below,
    which keep double precision. Other columns are left as they are.

    Parameters
    ----------
    table : pandas.DataFrame
        Table to downcast (expected to use pyarrow-backed dtypes).

    Returns
    -------
    pandas.DataFrame
        A new table with the selected columns cast to float32.
    """
    position_time_cols = frozenset(
        {
            "ra",
            "dec",
            "raErr",
            "decErr",
            "x",
            "y",
            "xErr",
            "yErr",
            "midpointMjdTai",
            "radecMjdTai",
        }
    )
    # Build the dtypes once rather than once per column.
    double_dtype = pd.ArrowDtype(pa.float64())
    single_dtype = pd.ArrowDtype(pa.float32())
    dtype_map = {
        name: single_dtype
        for name, dtype in table.dtypes.items()
        if name not in position_time_cols and dtype == double_dtype
    }
    return table.astype(dtype_map)

The functions below apply these steps to every partition in parallel with Dask and then refresh the catalog metadata:

In [None]:
def postprocess_catalog(catalog_name, flux_col_prefixes):
    """Post-process every partition of a catalog under TMP_DIR in parallel.

    Submits one ``process_partition`` task per HEALPix pixel to the Dask
    ``client``, waits for all of them (with a progress bar), then rewrites
    the catalog-level metadata.

    Parameters
    ----------
    catalog_name : str
        Name of the catalog directory under ``TMP_DIR``; also used as the
        progress-bar label.
    flux_col_prefixes : list of str
        Flux columns for which magnitudes are appended in each partition.

    Raises
    ------
    Exception
        Whatever a partition task raised, re-raised with its traceback, or
        the cancellation error if a task was cancelled. Metadata is not
        rewritten in that case.
    """
    catalog_dir = TMP_DIR / catalog_name
    catalog = hats.read_hats(catalog_dir)
    futures = [
        client.submit(
            process_partition,
            catalog_dir=catalog_dir,
            target_pixel=target_pixel,
            flux_col_prefixes=flux_col_prefixes,
        )
        for target_pixel in catalog.get_healpix_pixels()
    ]
    for future in tqdm(as_completed(futures), desc=catalog_name, total=len(futures)):
        # result() re-raises a failed task's exception with its traceback and
        # also raises for cancelled tasks, which a status == "error" check
        # would let through silently.
        future.result()
    rewrite_catalog_metadata(catalog)


def process_partition(catalog_dir, target_pixel, flux_col_prefixes):
    """Apply post-processing steps to a single partition and overwrite it.

    Steps, in order:

    1. Read the pixel's parquet file with pyarrow-backed dtypes.
    2. If it has a ``validityStart`` column, keep the latest row per object.
    3. If ``flux_col_prefixes`` is non-empty, append magnitude columns.
    4. Downcast non-positional/time float64 columns to float32.
    5. Write the result back to the same file.

    Parameters
    ----------
    catalog_dir : path-like
        Root directory of the HATS catalog.
    target_pixel : HEALPix pixel
        Partition to process, as yielded by ``get_healpix_pixels()``.
    flux_col_prefixes : list of str or None
        Flux columns to convert to magnitudes; empty or None skips step 3.
    """
    file_path = hats.io.pixel_catalog_file(catalog_dir, target_pixel)
    table = pd.read_parquet(file_path, dtype_backend="pyarrow")
    if "validityStart" in table.columns:
        table = select_by_latest_validity(table)
    if flux_col_prefixes:
        table = append_mag_and_magerr(table, flux_col_prefixes)
    table = cast_columns_float32(table)
    # The pandas index is not written and the pandas schema metadata is
    # stripped, leaving a plain Arrow schema on disk.
    # NOTE(review): if a partition stores a meaningful column as the pandas
    # index, preserve_index=False drops it; confirm none do.
    final_table = pa.Table.from_pandas(
        table, preserve_index=False
    ).replace_schema_metadata()
    # NOTE(review): `.path` presumably strips any filesystem protocol, which
    # is fine for a local TMP_DIR; confirm before using remote storage.
    pq.write_table(final_table, file_path.path)


def rewrite_catalog_metadata(catalog, destination_path=None):
    """Update catalog metadata after processing the leaf parquet files.

    Regenerates the dataset-level parquet metadata (which also yields the
    new total row count) and writes ``partition_info.csv``. Then writes a
    properties file with the updated ``total_rows`` and a fresh
    ``hats_creation_date``.

    Parameters
    ----------
    catalog : hats catalog
        Catalog whose leaf files were rewritten; its ``catalog_info`` is the
        base for the new properties.
    destination_path : path-like, optional
        Catalog directory to update. Defaults to the directory the catalog
        was read from. If the catalog has no on-disk location, falls back to
        ``TMP_DIR / catalog.catalog_name``.
    """
    if destination_path is None:
        # Use the directory the catalog was actually loaded from. The name in
        # its properties file need not match the directory name.
        destination_path = catalog.catalog_base_dir
    if destination_path is None:
        destination_path = TMP_DIR / catalog.catalog_name
    parquet_rows = write_parquet_metadata(destination_path)
    # Re-read the partition info and write it to partition_info.csv.
    # NOTE(review): read_from_dir may prefer an existing partition_info.csv
    # over _metadata; the pixel set is unchanged by post-processing either way.
    partition_info = PartitionInfo.read_from_dir(destination_path)
    partition_info_file = paths.get_partition_info_pointer(destination_path)
    partition_info.write_to_file(partition_info_file)
    now = datetime.now(tz=timezone.utc)
    catalog.catalog_info.copy_and_update(
        total_rows=parquet_rows, hats_creation_date=now.strftime("%Y-%m-%dT%H:%M%Z")
    ).to_properties_file(destination_path)

For DIA objects, calculate the mean magnitudes per band:

In [6]:
# Mean science flux in each of the ugrizy bands; one magnitude pair per band.
flux_col_prefixes = [f"{band}_scienceFluxMean" for band in "ugrizy"]
postprocess_catalog("dia_object", flux_col_prefixes=flux_col_prefixes)

  return getattr(ufunc, method)(*new_inputs, **kwargs)
dia_object: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.70s/it]


For DIA sources and DIA forced sources, calculate the science magnitudes:

In [7]:
# DIA sources: magnitudes derived from the per-source science flux.
postprocess_catalog(catalog_name="dia_source", flux_col_prefixes=["scienceFlux"])

  return getattr(ufunc, method)(*new_inputs, **kwargs)
dia_source: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.23s/it]


In [8]:
# DIA forced sources: magnitudes derived from the per-source science flux.
postprocess_catalog(catalog_name="dia_forced_source", flux_col_prefixes=["scienceFlux"])

  return getattr(ufunc, method)(*new_inputs, **kwargs)
dia_forced_source: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.77s/it]


In [9]:
# Close the Dask client now that every catalog has been post-processed.
client.close()