# Post-processing

We modify each parquet file in place. This is convenient today, but may turn out to be a poor choice later, because the original data is overwritten.

If we used LSDB for this step instead, we would need additional disk storage, because both the fresh and the post-processed data would have to be kept.

Elements of post-processing to be accomplished in this notebook:

* brightness in magnitudes (convert the flux columns to AB magnitudes, plus magnitude errors where a flux error is available)
* join to the visit table, where necessary (to add the `midpointMjdTai` observation time)

In [1]:
import os
import astropy.units as u
import hats
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import tempfile

from tqdm.auto import tqdm
from pathlib import Path
from dask.distributed import as_completed, Client
from hats.catalog import PartitionInfo
from hats.io import paths
from hats.io.parquet_metadata import write_parquet_metadata
from datetime import datetime, timezone

In [2]:
# Data-release version to process; must be provided through the environment
# (a missing variable raises KeyError here).
DRP_VERSION = os.environ["DRP_VERSION"]
print(f"DRP_VERSION: {DRP_VERSION}")
# Plain string: the path has no placeholders, so no f-string is needed.
base_output_dir = Path("/sdf/data/rubin/shared/lsdb_commissioning")
# raw_dir holds raw inputs such as visit_table.parquet; hats_dir holds the
# HATS catalogs that this notebook rewrites in place.
raw_dir = base_output_dir / "raw" / DRP_VERSION
hats_dir = base_output_dir / "hats" / DRP_VERSION

In [3]:
visit_table = pd.read_parquet(raw_dir / "visit_table.parquet")
# Lookup used by add_mjd_from_visit: visitId -> expMidptMJD.
visit_map = (
    visit_table.set_index("visitId")
    .loc[:, "expMidptMJD"]
    .to_dict()
)

In [None]:
def append_mag_and_magerr(table, flux_col_prefixes):
    """Calculate AB magnitudes and their errors for flux columns.

    For every ``prefix`` in ``flux_col_prefixes``, the ``{prefix}Flux``
    column (converted from nJy) becomes ``{prefix}Mag``. When a
    ``{prefix}FluxErr`` column also exists, ``{prefix}MagErr`` is computed
    as half the magnitude spread between ``flux + fluxErr`` and
    ``flux - fluxErr``. Non-positive arguments presumably produce NaN
    magnitudes; verify against astropy's ABmag conversion.

    Parameters
    ----------
    table : pandas.DataFrame
        Partition data containing the ``{prefix}Flux`` columns.
    flux_col_prefixes : iterable of str
        Column prefixes to convert, e.g. ``["psf"]``.

    Returns
    -------
    pandas.DataFrame
        A new frame: the original columns followed by the new float32
        magnitude columns.

    Raises
    ------
    ValueError
        If any magnitude column that would be created already exists in
        ``table``, e.g. because the partition was already post-processed.
    KeyError
        If a ``{prefix}Flux`` column is missing.
    """
    prefixes = list(flux_col_prefixes)
    # Refuse to run twice on the same data: pd.concat would silently create
    # duplicate column names, which only fail later when written to Arrow.
    planned = [f"{prefix}Mag" for prefix in prefixes]
    planned += [
        f"{prefix}MagErr"
        for prefix in prefixes
        if f"{prefix}FluxErr" in table.columns
    ]
    already_present = [col for col in planned if col in table.columns]
    if already_present:
        raise ValueError(
            f"Magnitude columns already present in table: {already_present}"
        )

    mag_cols = {}
    for prefix in prefixes:
        # Magnitude
        flux = table[f"{prefix}Flux"]
        mag_cols[f"{prefix}Mag"] = u.nJy.to(u.ABmag, flux)
        # Magnitude error, if flux error exists
        fluxErr_col = f"{prefix}FluxErr"
        if fluxErr_col in table.columns:
            fluxErr = table[fluxErr_col]
            upper_mag = u.nJy.to(u.ABmag, flux + fluxErr)
            lower_mag = u.nJy.to(u.ABmag, flux - fluxErr)
            # A brighter flux gives a smaller magnitude, so negate to get a
            # positive error.
            mag_cols[f"{prefix}MagErr"] = -(upper_mag - lower_mag) / 2
    mag_table = pd.DataFrame(mag_cols, dtype=np.float32, index=table.index)
    return pd.concat([table, mag_table], axis=1)


def add_mjd_from_visit(table, visit_mapping=None):
    """Add ``midpointMjdTai`` to ``table`` by looking up each row's ``visit``.

    Parameters
    ----------
    table : pandas.DataFrame
        Must contain a ``visit`` column and must not already contain
        ``midpointMjdTai``. The new column is added to this frame in place.
    visit_mapping : dict, optional
        Mapping of visit id to MJD. Defaults to the module-level
        ``visit_map`` built from the visit table.

    Returns
    -------
    pandas.DataFrame
        The same ``table`` object with a float64 ``midpointMjdTai`` column.
        Visits that are absent from the mapping get NaN.

    Raises
    ------
    ValueError
        If ``visit`` is missing or ``midpointMjdTai`` is already present.
    """
    if "visit" not in table.columns:
        raise ValueError("`visit` column is missing")
    if "midpointMjdTai" in table.columns:
        raise ValueError("`midpointMjdTai` is already present in table")
    if visit_mapping is None:
        visit_mapping = visit_map
    # Vectorized dict lookup (missing keys -> NaN). Force float64 so that
    # empty partitions do not end up with an object-typed column.
    table["midpointMjdTai"] = table["visit"].map(visit_mapping).astype("float64")
    return table

Initialize a Dask Client to parallelize the post-processing:

In [5]:
# Scratch directory handed to the Dask workers as their local_directory;
# it is cleaned up again in the final cell.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
client = Client(
    n_workers=16,
    threads_per_worker=1,
    local_directory=tmp_dir,
)

In [10]:
def postprocess_catalog(catalog_name, flux_col_prefixes=None, add_mjds=False):
    """Post-process every partition of a HATS catalog, then refresh its metadata.

    One ``process_partition`` task per HEALPix pixel is submitted to the
    Dask ``client``. The call waits for all of them, then rewrites the
    catalog-level metadata.

    Parameters
    ----------
    catalog_name : str
        Name of the catalog directory under ``hats_dir``.
    flux_col_prefixes : iterable of str, optional
        Flux column prefixes to convert to magnitudes; ``None`` (the
        default) means no magnitude columns are added.
    add_mjds : bool
        Whether to add ``midpointMjdTai`` from the visit lookup.
    """
    # None instead of a mutable [] default; materialize a list so any
    # iterable can be passed and shipped to the workers.
    flux_col_prefixes = [] if flux_col_prefixes is None else list(flux_col_prefixes)
    catalog_dir = hats_dir / catalog_name
    catalog = hats.read_hats(catalog_dir)
    futures = [
        client.submit(
            process_partition,
            catalog_dir=catalog_dir,
            target_pixel=target_pixel,
            flux_col_prefixes=flux_col_prefixes,
            add_mjds=add_mjds,
        )
        for target_pixel in catalog.get_healpix_pixels()
    ]
    wait_for_futures(futures, catalog_name)
    rewrite_catalog_metadata(catalog)


def process_partition(catalog_dir, target_pixel, flux_col_prefixes, add_mjds):
    """Apply the post-processing steps to one partition and overwrite it on disk.

    Parameters
    ----------
    catalog_dir : path-like
        Root directory of the HATS catalog.
    target_pixel :
        HEALPix pixel of the partition, as yielded by
        ``catalog.get_healpix_pixels()``.
    flux_col_prefixes : sequence of str or None
        Prefixes for ``append_mag_and_magerr``; empty or None skips that step.
    add_mjds : bool
        Whether to add ``midpointMjdTai`` via ``add_mjd_from_visit``.
    """
    file_path = hats.io.pixel_catalog_file(catalog_dir, target_pixel)
    table = pd.read_parquet(file_path)
    # Add magnitudes and mjds
    if flux_col_prefixes:
        table = append_mag_and_magerr(table, flux_col_prefixes)
    if add_mjds:
        table = add_mjd_from_visit(table)
    # Drop the schema-level metadata (including the pandas block) from the
    # written file.
    final_table = pa.Table.from_pandas(
        table, preserve_index=False
    ).replace_schema_metadata()
    # The partition is overwritten in place, so a crash mid-write would
    # truncate the only copy. Write a hidden sibling file first, then
    # atomically swap it in. The dot prefix keeps leftovers out of scans
    # that skip dotfiles (such as pyarrow's default dataset discovery).
    destination = Path(file_path.path)
    staging = destination.with_name(f".{destination.name}.tmp")
    try:
        pq.write_table(final_table, str(staging))
        os.replace(staging, destination)
    except BaseException:
        staging.unlink(missing_ok=True)
        raise


def wait_for_futures(futures, catalog_name):
    """Block until all futures finish, showing progress and re-raising the first failure."""
    progress = tqdm(as_completed(futures), desc=catalog_name, total=len(futures))
    for finished in progress:
        if finished.status != "error":
            continue
        raise finished.exception()


def rewrite_catalog_metadata(catalog):
    """Regenerate _metadata, partition_info.csv and properties after leaf rewrites.

    Parameters
    ----------
    catalog :
        The HATS catalog object read before processing; its ``catalog_name``
        locates the directory under ``hats_dir`` and its ``catalog_info`` is
        copied with an updated row count and creation date.
    """
    catalog_path = hats_dir / catalog.catalog_name

    total_rows = write_parquet_metadata(catalog_path)

    # partition_info.csv is derived from the freshly written _metadata.
    PartitionInfo.read_from_dir(catalog_path).write_to_file(
        paths.get_partition_info_pointer(catalog_path)
    )

    creation_stamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M%Z")
    updated_info = catalog.catalog_info.copy_and_update(
        total_rows=total_rows, hats_creation_date=creation_stamp
    )
    updated_info.to_properties_file(catalog_path)

## dia_object

This catalog is the easiest: it requires no post-processing at all.

## dia_source

We need to add the psf/science magnitudes and their errors.

In [11]:
# dia_source: psf and science magnitudes (plus errors); no MJD lookup.
postprocess_catalog(
    "dia_source",
    flux_col_prefixes=["psf", "science"],
)

## dia_object_forced_source

We need to add the psf magnitudes and their errors.

We add `midpointMjdTai` by looking up each row's visit in the visit table.

In [12]:
# dia_object_forced_source: psf magnitudes plus midpointMjdTai from visits.
postprocess_catalog(
    "dia_object_forced_source",
    flux_col_prefixes=["psf"],
    add_mjds=True,
)

## object

We need to add the psf and kron magnitudes for each band, along with their errors.

In [13]:
# One "<band>_<flux type>" prefix per band and flux type, band-major order:
# u_psf, u_kron, g_psf, g_kron, ...
flux_col_prefixes = [
    f"{band}_{flux_name}" for band in "ugrizy" for flux_name in ("psf", "kron")
]
print(flux_col_prefixes)

In [14]:
# object: per-band psf/kron magnitudes built in the previous cell.
postprocess_catalog(
    "object",
    flux_col_prefixes=flux_col_prefixes,
)

## source

We need to add the psf magnitudes and their errors.

We add `midpointMjdTai` by looking up each row's visit in the visit table.

In [15]:
# source: psf magnitudes plus midpointMjdTai from visits.
postprocess_catalog(
    "source",
    flux_col_prefixes=["psf"],
    add_mjds=True,
)

## object_forced_source

We need to add the psf magnitudes and their errors.

We add `midpointMjdTai` by looking up each row's visit in the visit table.

In [16]:
# object_forced_source: psf magnitudes plus midpointMjdTai from visits.
postprocess_catalog(
    "object_forced_source",
    flux_col_prefixes=["psf"],
    add_mjds=True,
)

In [17]:
# Close the Dask client first, then delete the temporary directory that was
# given to it as local_directory; removing it earlier could pull the scratch
# space out from under running workers.
client.close()
tmp_path.cleanup()