# Nesting

Create catalogs for `dia_object` and `object` with nested sources and forced sources.

In [14]:
import os
import lsdb
import tempfile
import hats_import.pipeline as runner

from pathlib import Path
from dask.distributed import Client
from hats_import.margin_cache.margin_cache_arguments import MarginCacheArguments

In [15]:
# Both settings are mandatory: a missing variable raises KeyError right here.
VERSION = os.environ["VERSION"]
OUTPUT_DIR = Path(os.environ["OUTPUT_DIR"])

print(f"VERSION: {VERSION}")
print(f"OUTPUT_DIR: {OUTPUT_DIR}")

# Versioned sub-directories under the output root.
raw_dir = OUTPUT_DIR.joinpath("raw", VERSION)
hats_dir = OUTPUT_DIR.joinpath("hats", VERSION)

In [16]:
# Scratch directory: used as the Dask workers' local directory and, further
# down, as the output location for the intermediate margin caches.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
client = Client(
    local_directory=tmp_dir,
    n_workers=16,
    threads_per_worker=1,
)

In [17]:
def sort_nested_sources(df, source_cols, mjd_col="midpointMjdTai"):
    """Re-order the rows of each nested column by index and then by time.

    Every column named in ``source_cols`` is flattened, removed from ``df``
    and nested back in after its flat rows have been sorted by the flat
    frame's index and then by ``mjd_col``.

    Parameters
    ----------
    df
        Frame holding the nested columns. In this notebook it is the
        partition handed over by ``map_partitions``.
    source_cols : iterable of str
        Names of the nested columns to sort.
    mjd_col : str, default "midpointMjdTai"
        Column of the flattened sources used as the secondary sort key.

    Returns
    -------
    The frame with each listed nested column replaced by its sorted version.

    Notes
    -----
    The flattened frame's index must be named, because that name is used as
    the primary sort key.
    """
    for source_col in source_cols:
        flat_sources = df[source_col].nest.to_flat()
        # Sort on the (named) index first, then on the timestamp column.
        sorted_sources = flat_sources.sort_values(
            [flat_sources.index.name, mjd_col]
        )
        df = df.drop(columns=[source_col]).add_nested(sorted_sources, source_col)
    return df

### Generate margin caches

To nest the sources accurately, we need intermediate margin caches for the source catalogs (`dia_source`, `dia_object_forced_source`, and `object_forced_source`). They are stored temporarily in a scratch directory, which is erased by the final cell of the notebook.

In [18]:
# Passed as `margin_threshold` to every MarginCacheArguments call below and
# embedded in the cache artifact names (arcseconds, going by the variable name).
margin_radius_arcsec = 2

In [19]:
# Margin cache for `dia_source`, written to the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir.joinpath("dia_source"),
    output_path=tmp_dir,
    output_artifact_name=f"dia_source_{margin_radius_arcsec}arcs",
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
runner.pipeline_with_client(args, client)

In [20]:
# Margin cache for `dia_object_forced_source`, written to the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir.joinpath("dia_object_forced_source"),
    output_path=tmp_dir,
    output_artifact_name=f"dia_object_forced_source_{margin_radius_arcsec}arcs",
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
runner.pipeline_with_client(args, client)

In [21]:
# Margin cache for `object_forced_source`, written to the scratch directory.
args = MarginCacheArguments(
    input_catalog_path=hats_dir.joinpath("object_forced_source"),
    output_path=tmp_dir,
    output_artifact_name=f"object_forced_source_{margin_radius_arcsec}arcs",
    margin_threshold=margin_radius_arcsec,
    resume=False,
    simple_progress_bar=True,
)
runner.pipeline_with_client(args, client)

### dia_object with nested sources

In [22]:
# The object catalog is read as-is; each source catalog is paired with the
# margin cache that was written to the scratch directory.
dia_object_cat = lsdb.read_hats(hats_dir.joinpath("dia_object"))

dia_source_cat = lsdb.read_hats(
    hats_dir.joinpath("dia_source"),
    margin_cache=Path(tmp_dir).joinpath(f"dia_source_{margin_radius_arcsec}arcs"),
)

dia_object_forced_source_cat = lsdb.read_hats(
    hats_dir.joinpath("dia_object_forced_source"),
    margin_cache=Path(tmp_dir).joinpath(
        f"dia_object_forced_source_{margin_radius_arcsec}arcs"
    ),
)

In [23]:
# Nest `dia_source` first, then `dia_object_forced_source`; both joins match
# on `diaObjectId`.
dia_object_cat_nested = dia_object_cat.join_nested(
    dia_source_cat,
    nested_column_name="diaSource",
    left_on="diaObjectId",
    right_on="diaObjectId",
)
dia_object_cat_nested = dia_object_cat_nested.join_nested(
    dia_object_forced_source_cat,
    nested_column_name="diaObjectForcedSource",
    left_on="diaObjectId",
    right_on="diaObjectId",
)
dia_object_cat_nested

Also, for each object, sort sources by timestamp:

In [24]:
# Apply the sort to both nested columns, one partition at a time.
dia_object_cat_nested = dia_object_cat_nested.map_partitions(
    lambda partition: sort_nested_sources(
        partition,
        source_cols=["diaSource", "diaObjectForcedSource"],
    )
)

Save the resulting catalog to disk:

In [25]:
# Write the nested catalog next to the other HATS catalogs of this version.
dia_object_cat_nested.to_hats(
    hats_dir.joinpath("dia_object_lc"),
    catalog_name="dia_object_lc",
)

Reading the catalog with LSDB:

In [26]:
# Re-open the catalog that was just written and display it.
dia_object_lc = lsdb.read_hats(hats_dir.joinpath("dia_object_lc"))
dia_object_lc

### object with nested sources

In [27]:
# The object catalog is read as-is; the forced-source catalog is paired with
# the margin cache that was written to the scratch directory.
object_cat = lsdb.read_hats(hats_dir.joinpath("object"))

object_forced_source_cat = lsdb.read_hats(
    hats_dir.joinpath("object_forced_source"),
    margin_cache=Path(tmp_dir).joinpath(
        f"object_forced_source_{margin_radius_arcsec}arcs"
    ),
)

In [28]:
# Nest `object_forced_source` into the object catalog, matching on `objectId`.
object_cat_nested = object_cat.join_nested(
    object_forced_source_cat,
    nested_column_name="objectForcedSource",
    left_on="objectId",
    right_on="objectId",
)
object_cat_nested

Also, for each object, sort sources by timestamp:

In [29]:
# Apply the sort to the nested column, one partition at a time.
object_cat_nested = object_cat_nested.map_partitions(
    lambda partition: sort_nested_sources(
        partition,
        source_cols=["objectForcedSource"],
    )
)

Save the resulting catalog to disk:

In [30]:
# Write the nested catalog next to the other HATS catalogs of this version.
object_cat_nested.to_hats(
    hats_dir.joinpath("object_lc"),
    catalog_name="object_lc",
)

Reading the catalog with LSDB:

In [31]:
# Re-open the catalog that was just written and display it.
object_lc = lsdb.read_hats(hats_dir.joinpath("object_lc"))
object_lc

In [32]:
# Stop the Dask client before removing the scratch directory it was given as
# `local_directory`; that directory also holds the intermediate margin caches.
client.close()
tmp_path.cleanup()