# Nesting

Create catalogs for `dia_object` and `object` in which each row carries its sources and forced sources as nested columns.

In [None]:
import os
import shutil
import tempfile
from pathlib import Path

import lsdb
from dask.distributed import Client
from hats_import import pipeline_with_client
from hats_import.catalog import ImportArguments
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments

In [None]:
def full_column_names(cat):
    """Yield every column name of *cat*, expanding nested columns.

    A column without a ``.nest`` accessor is yielded unchanged. A column that
    has one is replaced by ``"<column>.<field>"`` entries, one for each name in
    its ``.nest.columns``, in that order.
    """
    for name in cat.columns:
        column = cat[name]
        if hasattr(column, 'nest'):
            yield from (f'{name}.{field}' for field in column.nest.columns)
        else:
            yield name


In [None]:
# Run configuration is read from the environment; a missing variable raises KeyError.
env = os.environ
VERSION = env["VERSION"]
OUTPUT_DIR = Path(env["OUTPUT_DIR"])

print(f"VERSION: {VERSION}")
print(f"OUTPUT_DIR: {OUTPUT_DIR}")

# Per-version "raw" and "hats" subdirectories of OUTPUT_DIR.
raw_dir, hats_dir = (OUTPUT_DIR / sub / VERSION for sub in ("raw", "hats"))

In [None]:
# Scratch directory: used as the Dask local_directory and, below, as the output
# location of the temporary margin caches. It is removed by tmp_path.cleanup()
# at the end of the notebook.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
client = Client(
    local_directory=tmp_dir,
    n_workers=16,
    threads_per_worker=1,
)

In [None]:
def sort_nested_sources(df, source_cols, mjd_col="midpointMjdTai"):
    """Sort the rows inside each listed nested column of *df* by time.

    For every name in *source_cols*, the nested column is flattened, ordered by
    parent index and then by *mjd_col*, dropped from *df*, and re-attached
    under the same name via ``join_nested``.

    Parameters
    ----------
    df : frame with nested columns
        Presumably a nested-pandas NestedFrame partition (it must support
        ``df[col].nest.to_flat()``, ``drop`` and ``join_nested``).
    source_cols : iterable of str
        Names of the nested columns to sort.
    mjd_col : str, optional
        Field inside each nested column used as the time key. Defaults to
        ``"midpointMjdTai"``.

    Returns
    -------
    The frame with each listed nested column replaced by its sorted version.
    """
    for source_col in source_cols:
        flat_sources = df[source_col].nest.to_flat()
        # Sort by time, then *stably* by the parent index: rows end up grouped
        # by parent and keep their time order inside each group. Unlike
        # sort_values([index.name, mjd_col]), this does not depend on the
        # index having a (unique, non-None) name.
        flat_sources = flat_sources.sort_values(mjd_col, kind="stable").sort_index(
            kind="stable"
        )
        df = df.drop(columns=[source_col])
        df = df.join_nested(flat_sources, source_col)
    return df

### Generate margin caches

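To nest the sources accurately, we first generate intermediate margin caches for the source catalogs, so that sources lying just across a partition boundary from their parent object can still be matched to it. The caches are stored temporarily in a scratch directory and are deleted automatically at the end of the notebook.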

In [None]:
# Margin width (arcseconds, per the variable name) for the temporary margin
# caches; passed as margin_threshold and embedded in the cache directory names.
margin_radius_arcsec = 2

In [None]:
# Build the dia_source margin cache in the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir / "dia_source",
    output_artifact_name=f"dia_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
pipeline_with_client(args, client)

In [None]:
# Build the dia_object_forced_source margin cache in the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir / "dia_object_forced_source",
    output_artifact_name=f"dia_object_forced_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
pipeline_with_client(args, client)

In [None]:
# Build the object_forced_source margin cache in the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir / "object_forced_source",
    output_artifact_name=f"object_forced_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
pipeline_with_client(args, client)

### dia_object with nested sources

In [None]:
# The two source catalogs are opened together with the margin caches built above.
margin_root = Path(tmp_dir)

dia_object_cat = lsdb.read_hats(hats_dir / "dia_object")

dia_source_cat = lsdb.read_hats(
    hats_dir / "dia_source",
    margin_cache=margin_root / f"dia_source_{margin_radius_arcsec}arcs",
)

dia_object_forced_source_cat = lsdb.read_hats(
    hats_dir / "dia_object_forced_source",
    margin_cache=margin_root / f"dia_object_forced_source_{margin_radius_arcsec}arcs",
)

In [None]:
# Nest dia_source rows under their dia_object (matched on diaObjectId)...
dia_object_cat_nested = dia_object_cat.join_nested(
    dia_source_cat,
    nested_column_name="diaSource",
    left_on="diaObjectId",
    right_on="diaObjectId",
)
# ...then nest the forced sources as a second nested column.
dia_object_cat_nested = dia_object_cat_nested.join_nested(
    dia_object_forced_source_cat,
    nested_column_name="diaObjectForcedSource",
    left_on="diaObjectId",
    right_on="diaObjectId",
)
dia_object_cat_nested

Also, for each object, sort its nested sources by observation time (`midpointMjdTai`):

In [None]:
def _sort_dia_object_partition(partition):
    """Time-sort both nested source columns of one dia_object partition."""
    return sort_nested_sources(
        partition, source_cols=["diaSource", "diaObjectForcedSource"]
    )


dia_object_cat_nested = dia_object_cat_nested.map_partitions(_sort_dia_object_partition)

And save the resulting catalog to disk:

In [None]:
# Write the nested catalog to an intermediate location; it is reimported below.
dia_object_cat_nested.to_hats(
    hats_dir / "dia_object_lc_intermediate",
    catalog_name="dia_object_lc",
)

Finally, select the columns to be loaded by default and reimport the catalog with a new partitioning threshold:

In [None]:
# Columns that readers should load by default. Nested fields are written as
# "<nested column>.<field>", matching the names yielded by full_column_names().
dia_object_desired_cols = [
    "dec",
    "diaObjectForcedSource.band",
    "diaObjectForcedSource.coord_dec",
    "diaObjectForcedSource.coord_ra",
    "diaObjectForcedSource.diff_PixelFlags_nodataCenter",
    "diaObjectForcedSource.invalidPsfFlag",
    "diaObjectForcedSource.midpointMjdTai",
    "diaObjectForcedSource.pixelFlags_bad",
    "diaObjectForcedSource.pixelFlags_cr",
    "diaObjectForcedSource.pixelFlags_crCenter",
    "diaObjectForcedSource.pixelFlags_edge",
    "diaObjectForcedSource.pixelFlags_interpolated",
    "diaObjectForcedSource.pixelFlags_interpolatedCenter",
    "diaObjectForcedSource.pixelFlags_nodata",
    "diaObjectForcedSource.pixelFlags_saturated",
    "diaObjectForcedSource.pixelFlags_saturatedCenter",
    "diaObjectForcedSource.pixelFlags_suspect",
    "diaObjectForcedSource.pixelFlags_suspectCenter",
    "diaObjectForcedSource.psfDiffFlux",
    "diaObjectForcedSource.psfDiffFlux_flag",
    "diaObjectForcedSource.psfDiffFluxErr",
    "diaObjectForcedSource.psfFlux",
    "diaObjectForcedSource.psfFlux_flag",
    "diaObjectForcedSource.psfFluxErr",
    "diaObjectForcedSource.psfMag",
    "diaObjectForcedSource.psfMagErr",
    "diaObjectForcedSource.visit",
    "diaObjectId",
    "diaSource.band",
    "diaSource.centroid_flag",
    "diaSource.coord_dec",
    "diaSource.coord_ra",
    "diaSource.dec",
    "diaSource.decErr",
    "diaSource.diaSourceId",
    "diaSource.forced_PsfFlux_flag",
    "diaSource.forced_PsfFlux_flag_edge",
    "diaSource.forced_PsfFlux_flag_noGoodPixels",
    "diaSource.midpointMjdTai",
    "diaSource.pixelFlags",
    "diaSource.pixelFlags_bad",
    "diaSource.pixelFlags_cr",
    "diaSource.pixelFlags_crCenter",
    "diaSource.pixelFlags_edge",
    "diaSource.pixelFlags_interpolated",
    "diaSource.pixelFlags_interpolatedCenter",
    "diaSource.pixelFlags_nodata",
    "diaSource.pixelFlags_nodataCenter",
    "diaSource.pixelFlags_offimage",
    "diaSource.pixelFlags_saturated",
    "diaSource.pixelFlags_saturatedCenter",
    "diaSource.pixelFlags_streak",
    "diaSource.pixelFlags_streakCenter",
    "diaSource.pixelFlags_suspect",
    "diaSource.pixelFlags_suspectCenter",
    "diaSource.psfFlux",
    "diaSource.psfFlux_flag",
    "diaSource.psfFlux_flag_edge",
    "diaSource.psfFlux_flag_noGoodPixels",
    "diaSource.psfFluxErr",
    "diaSource.psfMag",
    "diaSource.psfMagErr",
    "diaSource.ra",
    "diaSource.raErr",
    "diaSource.reliability",
    "diaSource.scienceFlux",
    "diaSource.scienceFluxErr",
    "diaSource.scienceMag",
    "diaSource.scienceMagErr",
    "diaSource.shape_flag",
    "diaSource.shape_flag_no_pixels",
    "diaSource.shape_flag_not_contained",
    "diaSource.shape_flag_parent_source",
    "diaSource.snr",
    "diaSource.trail_flag_edge",
    "diaSource.visit",
    "diaSource.x",
    "diaSource.xErr",
    "diaSource.y",
    "diaSource.yErr",
    "nDiaSources",
    "ra",
    "tract",
]
# Guard against accidental duplicates while keeping the listed order, so the
# comma-separated default list never repeats a column.
dia_object_desired_cols = list(dict.fromkeys(dia_object_desired_cols))
dia_object_actual_cols = set(full_column_names(dia_object_cat_nested))
dia_object_missing_cols = sorted(set(dia_object_desired_cols) - dia_object_actual_cols)
if dia_object_missing_cols:
    print(
        "Warning: requested default columns missing from catalog: "
        + ", ".join(dia_object_missing_cols)
    )
# Only columns that actually exist in the nested catalog are kept, as one
# comma-separated string for the hats_cols_default property.
hats_cols_default = ",".join(
    c for c in dia_object_desired_cols if c in dia_object_actual_cols
)


In [None]:
# Reimport the intermediate catalog into hats_dir with the new partitioning
# settings, storing the default column list in the catalog properties.
args = ImportArguments.reimport_from_hats(
    hats_dir / "dia_object_lc_intermediate",
    output_dir=hats_dir,
    pixel_threshold=15_000,
    highest_healpix_order=11,
    skymap_alt_orders=[2, 4, 6],
    addl_hats_properties={"hats_cols_default": hats_cols_default},
    row_group_kwargs={"subtile_order_delta": 1},
)
pipeline_with_client(args, client)

In [None]:
# Delete the intermediate catalog now that it has been reimported. Passing the
# Path to shutil.rmtree avoids the unquoted shell expansion of `%rm -r $hats_dir/...`
# (which breaks on paths containing spaces) and does not require IPython.
shutil.rmtree(hats_dir / "dia_object_lc_intermediate")

### object with nested sources

In [None]:
object_cat = lsdb.read_hats(hats_dir / "object")

# Forced sources are opened together with the margin cache built above.
object_forced_source_margin = Path(
    tmp_dir, f"object_forced_source_{margin_radius_arcsec}arcs"
)
object_forced_source_cat = lsdb.read_hats(
    hats_dir / "object_forced_source",
    margin_cache=object_forced_source_margin,
)

In [None]:
# Nest forced-source rows under their object, matched on objectId.
object_cat_nested = object_cat.join_nested(
    object_forced_source_cat,
    nested_column_name="objectForcedSource",
    left_on="objectId",
    right_on="objectId",
)
object_cat_nested

Also, for each object, sort its nested forced sources by observation time (`midpointMjdTai`):

In [None]:
def _sort_object_partition(partition):
    """Time-sort the nested objectForcedSource column of one object partition."""
    return sort_nested_sources(partition, source_cols=["objectForcedSource"])


object_cat_nested = object_cat_nested.map_partitions(_sort_object_partition)

And save the resulting catalog to disk:

In [None]:
# Write the nested catalog to an intermediate location; it is reimported below.
object_cat_nested.to_hats(
    hats_dir / "object_lc_intermediate",
    catalog_name="object_lc",
)

Finally, select the columns to be loaded by default and reimport the catalog with a new partitioning threshold:

In [None]:
# Columns that readers should load by default. Nested fields are written as
# "<nested column>.<field>", matching the names yielded by full_column_names().
object_desired_cols = [
    "coord_dec",
    "coord_decErr",
    "coord_ra",
    "coord_raErr",
    "g_psfFlux",
    "g_psfFluxErr",
    "g_psfMag",
    "g_psfMagErr",
    "i_psfFlux",
    "i_psfFluxErr",
    "i_psfMag",
    "i_psfMagErr",
    "objectForcedSource.band",
    "objectForcedSource.coord_dec",
    "objectForcedSource.coord_ra",
    "objectForcedSource.detector",
    "objectForcedSource.forcedSourceId",
    "objectForcedSource.invalidPsfFlag",
    "objectForcedSource.midpointMjdTai",
    "objectForcedSource.pixelFlags_bad",
    "objectForcedSource.pixelFlags_cr",
    "objectForcedSource.pixelFlags_crCenter",
    "objectForcedSource.pixelFlags_edge",
    "objectForcedSource.pixelFlags_interpolated",
    "objectForcedSource.pixelFlags_interpolatedCenter",
    "objectForcedSource.pixelFlags_nodata",
    "objectForcedSource.pixelFlags_saturated",
    "objectForcedSource.pixelFlags_saturatedCenter",
    "objectForcedSource.pixelFlags_suspect",
    "objectForcedSource.pixelFlags_suspectCenter",
    "objectForcedSource.psfDiffFlux",
    "objectForcedSource.psfDiffFlux_flag",
    "objectForcedSource.psfDiffFluxErr",
    "objectForcedSource.psfFlux",
    "objectForcedSource.psfFlux_flag",
    "objectForcedSource.psfFluxErr",
    "objectForcedSource.psfMag",
    "objectForcedSource.psfMagErr",
    "objectForcedSource.visit",
    "objectId",
    "patch",
    "r_psfFlux",
    "r_psfFluxErr",
    "r_psfMag",
    "r_psfMagErr",
    "refBand",
    "shape_flag",
    "shape_xx",
    "shape_xy",
    "shape_yy",
    "tract",
    "u_psfFlux",
    "u_psfFluxErr",
    "u_psfMag",
    "u_psfMagErr",
    "x",
    "xErr",
    "y",
    "y_psfFlux",
    "y_psfFluxErr",
    "y_psfMag",
    "y_psfMagErr",
    "yErr",
    "z_psfFlux",
    "z_psfFluxErr",
    "z_psfMag",
    "z_psfMagErr",
]
# Guard against accidental duplicates while keeping the listed order, so the
# comma-separated default list never repeats a column.
object_desired_cols = list(dict.fromkeys(object_desired_cols))
object_actual_cols = set(full_column_names(object_cat_nested))
object_missing_cols = sorted(set(object_desired_cols) - object_actual_cols)
if object_missing_cols:
    print(
        "Warning: requested default columns missing from catalog: "
        + ", ".join(object_missing_cols)
    )
# Only columns that actually exist in the nested catalog are kept, as one
# comma-separated string for the hats_cols_default property.
hats_cols_default = ",".join(
    c for c in object_desired_cols if c in object_actual_cols
)


In [None]:
# Reimport the intermediate catalog into hats_dir with the new partitioning
# settings, storing the default column list in the catalog properties.
args = ImportArguments.reimport_from_hats(
    hats_dir / "object_lc_intermediate",
    output_dir=hats_dir,
    pixel_threshold=15_000,
    highest_healpix_order=11,
    skymap_alt_orders=[2, 4, 6],
    addl_hats_properties={"hats_cols_default": hats_cols_default},
    row_group_kwargs={"subtile_order_delta": 1},
)
pipeline_with_client(args, client)

In [None]:
# Delete the intermediate catalog now that it has been reimported. Passing the
# Path to shutil.rmtree avoids the unquoted shell expansion of `%rm -r $hats_dir/...`
# (which breaks on paths containing spaces) and does not require IPython.
shutil.rmtree(hats_dir / "object_lc_intermediate")

In [None]:
# Close the Dask client first (it was started with tmp_dir as its
# local_directory), then delete tmp_dir, which also holds the temporary margin caches.
client.close()
tmp_path.cleanup()