# Nesting

Create nested (light-curve) catalogs for `dia_object` and `object`, embedding each object's sources and forced sources as nested columns.

In [None]:
import os
import lsdb
import tempfile
import hats_import.pipeline as runner

from pathlib import Path
from dask.distributed import Client
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments
from nested_pandas import NestedDtype

In [None]:
# The DRP version must come from the environment; os.environ[...] raises
# KeyError immediately if it is not set.
DRP_VERSION = os.environ["DRP_VERSION"]
print(f"DRP_VERSION: {DRP_VERSION}")
# Shared commissioning area; raw and HATS data are kept in per-DRP-version
# subdirectories. (Plain string: no interpolation is needed here.)
base_output_dir = Path("/sdf/data/rubin/shared/lsdb_commissioning")
raw_dir = base_output_dir / "raw" / DRP_VERSION
hats_dir = base_output_dir / "hats" / DRP_VERSION

In [None]:
# Scratch space used both as the Dask workers' local directory and as the
# output location of the intermediate margin caches below.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
client = Client(
    n_workers=16,
    threads_per_worker=1,
    local_directory=tmp_dir,
)

In [4]:
def sort_nested_sources(df, source_cols, mjd_col="midpointMjdTai"):
    """Time-order the rows of each nested source column, per object.

    For every column in ``source_cols`` the nested data is flattened, sorted
    so that rows are grouped by the frame's index (one group per object) and
    ordered by ``mjd_col`` within each group, and then re-nested under the
    same column name.

    Parameters
    ----------
    df : NestedFrame
        Partition whose nested columns should be sorted.
    source_cols : list of str
        Names of the nested columns to sort.
    mjd_col : str, optional
        Name of the time column inside the nested columns used for ordering.
        Defaults to ``"midpointMjdTai"``.

    Returns
    -------
    NestedFrame
        Frame with the sorted nested columns. Each column is dropped and
        re-added, so column order may differ from the input.
    """
    for source_col in source_cols:
        flat_sources = df[source_col].nest.to_flat()
        # Stable sort by time, then stable sort by index: rows end up grouped
        # by object and time-ordered within each object. Unlike sorting by
        # ``[index.name, mjd_col]``, this also works when the index is unnamed.
        flat_sources = flat_sources.sort_values(mjd_col, kind="stable").sort_index(
            kind="stable"
        )
        df = df.drop(columns=[source_col]).add_nested(flat_sources, source_col)
    return df

### Generate margin caches

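To nest the sources accurately, objects near partition boundaries must also be matched against sources stored in neighboring partitions, so we first generate intermediate margin caches for the source catalogs (`dia_source`, `dia_object_forced_source` and `object_forced_source`). They are written to a temporary scratch directory, which is deleted by the cleanup cell at the end of the notebook.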

In [None]:
# Margin width (arcseconds, per the variable name) passed as margin_threshold
# to MarginCacheArguments; it is also embedded in the margin-cache artifact names.
margin_radius_arcsec = 2

In [None]:
# Margin cache for dia_source; it is read back below as the margin_cache of
# dia_source_cat when nesting sources under dia_object.
args = MarginCacheArguments(
    output_artifact_name=f"dia_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    input_catalog_path=hats_dir / "dia_source",
    margin_threshold=margin_radius_arcsec,
)
runner.pipeline_with_client(args, client)

In [None]:
# Margin cache for dia_object_forced_source; it is read back below as the
# margin_cache of dia_object_forced_source_cat.
args = MarginCacheArguments(
    output_artifact_name=f"dia_object_forced_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    input_catalog_path=hats_dir / "dia_object_forced_source",
    margin_threshold=margin_radius_arcsec,
)
runner.pipeline_with_client(args, client)

In [None]:
# Margin cache for object_forced_source; it is read back below as the
# margin_cache of object_forced_source_cat.
args = MarginCacheArguments(
    output_artifact_name=f"object_forced_source_{margin_radius_arcsec}arcs",
    output_path=tmp_dir,
    input_catalog_path=hats_dir / "object_forced_source",
    margin_threshold=margin_radius_arcsec,
)
runner.pipeline_with_client(args, client)

### dia_object with nested sources

In [None]:
# Source catalogs are opened together with the margin caches generated above,
# so join_nested can also see sources just across partition boundaries.
dia_source_cat = lsdb.read_hats(
    hats_dir / "dia_source",
    margin_cache=Path(tmp_dir) / f"dia_source_{margin_radius_arcsec}arcs",
)
dia_object_forced_source_cat = lsdb.read_hats(
    hats_dir / "dia_object_forced_source",
    margin_cache=Path(tmp_dir) / f"dia_object_forced_source_{margin_radius_arcsec}arcs",
)

# The object catalog is the left side of the joins and needs no margin.
dia_object_cat = lsdb.read_hats(hats_dir / "dia_object")

In [None]:
# Step 1: nest diaSource rows under each diaObject, matching on diaObjectId.
dia_object_cat_nested = dia_object_cat.join_nested(
    dia_source_cat,
    left_on="diaObjectId",
    right_on="diaObjectId",
    nested_column_name="diaSource",
)
# Step 2: nest the forced sources the same way into a second column.
dia_object_cat_nested = dia_object_cat_nested.join_nested(
    dia_object_forced_source_cat,
    left_on="diaObjectId",
    right_on="diaObjectId",
    nested_column_name="diaObjectForcedSource",
)
dia_object_cat_nested

Also, for each object, sort its nested sources and forced sources by timestamp (`midpointMjdTai`):

In [None]:
def _sort_dia_object_partition(partition):
    """Time-order both nested columns of one dia_object partition."""
    return sort_nested_sources(
        partition, source_cols=["diaSource", "diaObjectForcedSource"]
    )


dia_object_cat_nested = dia_object_cat_nested.map_partitions(_sort_dia_object_partition)

Save resulting catalog to disk:

In [None]:
# Persist the nested catalog as a new HATS catalog under hats_dir.
dia_object_cat_nested.to_hats(
    hats_dir / "dia_object_lc",
    catalog_name="dia_object_lc",
)

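Reading the catalog back with LSDB currently requires a bit of manipulation: the nested columns have to be cast back to `NestedDtype` from the dtype they are loaded with: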

In [None]:
def _restore_dia_object_nested_dtypes(df):
    """Cast the nested light-curve columns of one partition to NestedDtype."""
    converted = {}
    for column in ["diaSource", "diaObjectForcedSource"]:
        loaded_dtype = df.dtypes[column]
        converted[column] = df[column].astype(
            NestedDtype.from_pandas_arrow_dtype(loaded_dtype)
        )
    return df.assign(**converted)


dia_object_lc = lsdb.read_hats(hats_dir / "dia_object_lc").map_partitions(
    _restore_dia_object_nested_dtypes
)
dia_object_lc

### object with nested sources

In [None]:
# Forced sources are opened with the margin cache generated above so that
# join_nested can match objects near partition boundaries.
object_forced_source_cat = lsdb.read_hats(
    hats_dir / "object_forced_source",
    margin_cache=Path(tmp_dir) / f"object_forced_source_{margin_radius_arcsec}arcs",
)

# Left side of the join; no margin needed.
object_cat = lsdb.read_hats(hats_dir / "object")

In [None]:
# Nest each object's forced sources into a single column, matching on objectId.
object_cat_nested = object_cat.join_nested(
    object_forced_source_cat,
    nested_column_name="objectForcedSource",
    left_on="objectId",
    right_on="objectId",
)
object_cat_nested

Also, for each object, sort its nested forced sources by timestamp (`midpointMjdTai`):

In [None]:
def _sort_object_partition(partition):
    """Time-order the nested forced sources of one object partition."""
    return sort_nested_sources(partition, source_cols=["objectForcedSource"])


object_cat_nested = object_cat_nested.map_partitions(_sort_object_partition)

Save resulting catalog to disk:

In [None]:
# Persist the nested catalog as a new HATS catalog under hats_dir.
object_cat_nested.to_hats(
    hats_dir / "object_lc",
    catalog_name="object_lc",
)

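Reading the catalog back with LSDB currently requires a bit of manipulation: the nested column has to be cast back to `NestedDtype` from the dtype it is loaded with: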

In [None]:
# Convert the dtype that objectForcedSource is loaded with back to a
# NestedDtype in every partition, so the nested accessor works again.
# A plain keyword argument replaces the needless ``**{"...": ...}`` unpacking.
object_lc = lsdb.read_hats(hats_dir / "object_lc").map_partitions(
    lambda df: df.assign(
        objectForcedSource=df["objectForcedSource"].astype(
            NestedDtype.from_pandas_arrow_dtype(df.dtypes["objectForcedSource"])
        )
    )
)
object_lc

In [None]:
# Teardown. Close the Dask client first, since its workers were started with
# local_directory=tmp_dir. Then delete the scratch directory, which also
# removes the intermediate margin caches written there.
client.close()
tmp_path.cleanup()