# Import to HATS

Use `hats-import` to ingest the raw parquet files (listed, one path per line, in a text file per dataset type) and create each HATS catalog.

In [1]:
# %pip install -q lsdb hats-import

In [1]:
import os
import tempfile
import pandas as pd
import hats_import.pipeline as runner

from pathlib import Path
from dask.distributed import Client
from hats_import.catalog.arguments import ImportArguments
from hats_import.catalog.file_readers import ParquetPyarrowReader

In [None]:
# Release identifiers are read from the environment; a missing variable
# raises KeyError here, before any path is built from it.
DRP_VERSION = os.environ["DRP_VERSION"]
COLLECTION_TAG = os.environ["COLLECTION_TAG"]

# Root directory for this release; the raw inputs and HATS outputs used by the
# later cells are located relative to it.
base_output_dir = Path(f"/sdf/data/rubin/shared/lsdb_commissioning/hats/{DRP_VERSION}")
# Collection string for this release (not referenced by the cells visible here).
collections = f"LSSTComCam/runs/DRP/DP1/{DRP_VERSION}/{COLLECTION_TAG}"

# Echo the configuration so the executed notebook records which release was used.
print(
    f"DRP_VERSION: {DRP_VERSION}",
    f"COLLECTION_TAG: {COLLECTION_TAG}",
    sep="\n",
)

In [4]:
# Inputs are read from <base>/raw; the catalogs are written under <base>/hats.
raw_dir = base_output_dir.joinpath("raw")
hats_dir = base_output_dir.joinpath("hats")
# Create the output directory together with any missing parents; an already
# existing directory is not an error.
hats_dir.mkdir(exist_ok=True, parents=True)

In [5]:
# Scratch directory handed to Dask; the TemporaryDirectory object is kept in
# `tmp_path` so it can be cleaned up explicitly in the final cell.
tmp_path = tempfile.TemporaryDirectory()
tmp_dir = tmp_path.name
# Dask client with 4 workers, 1 thread each, using the scratch directory as
# its local_directory. Every import pipeline below runs on this client.
client = Client(n_workers=4, threads_per_worker=1, local_directory=tmp_dir)

### Helper methods

In [6]:
def get_paths(dataset_type, raw_dir):
    """Return the input file paths listed for ``dataset_type``.

    Reads ``<raw_dir>/paths/<dataset_type>.txt``, which is expected to hold
    one path per line.

    Args:
        dataset_type: Name of the dataset; selects which ``.txt`` listing
            under ``<raw_dir>/paths`` is read.
        raw_dir: Directory containing the ``paths`` folder (``str`` or
            ``pathlib.Path``).

    Returns:
        list[str]: The listed paths in file order, with surrounding
        whitespace removed and blank lines skipped.

    Raises:
        FileNotFoundError: If the listing file does not exist.
    """
    file_pointer = Path(raw_dir) / "paths" / f"{dataset_type}.txt"
    with file_pointer.open("r", encoding="utf8") as text_file:
        # Drop whitespace-only lines (e.g. a trailing blank line) so that no
        # empty string is handed to the importer as an input file.
        return [line.strip() for line in text_file if line.strip()]

#### DiaObject

In [None]:
# Columns handed to the parquet reader (as `column_names`) for the DiaObject import.
diaObj_default_columns = [
    "diaObjectId",
    "ra",
    "dec",
    "nDiaSources",
    "radecMjdTai",
]

args = ImportArguments(
    # Input: files listed in raw_dir/paths/diaObjectTable_tract.txt.
    input_file_list=get_paths("diaObjectTable_tract", raw_dir),
    file_reader=ParquetPyarrowReader(column_names=diaObj_default_columns),
    # Sky-position columns.
    ra_column="ra",
    dec_column="dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="diaObject",
    catalog_type="object",
    # Run options.
    resume=False,
    pixel_threshold=2_000_000,
)
runner.pipeline_with_client(args, client)

#### DiaSource

In [None]:
args = ImportArguments(
    # Input: files listed in raw_dir/paths/diaSourceTable_tract.txt, read with
    # the reader's default settings (no column list is given).
    input_file_list=get_paths("diaSourceTable_tract", raw_dir),
    file_reader=ParquetPyarrowReader(),
    # Sky-position columns.
    ra_column="ra",
    dec_column="dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="diaSource",
    catalog_type="source",
    # Run options.
    resume=False,
    pixel_threshold=2_000_000,
)
runner.pipeline_with_client(args, client)

#### DiaForcedSource

In [None]:
# NOTE(review): unlike the other import cells in this notebook, this call does
# not pass resume=False, so the library default applies — confirm that this is
# intended. It is also the only cell that sets highest_healpix_order.
args = ImportArguments(
    # Input: files listed in raw_dir/paths/forcedSourceOnDiaObjectTable.txt,
    # read with the reader's default settings.
    input_file_list=get_paths("forcedSourceOnDiaObjectTable", raw_dir),
    file_reader=ParquetPyarrowReader(),
    # Sky-position columns.
    ra_column="coord_ra",
    dec_column="coord_dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="diaForcedSource",
    catalog_type="source",
    # Run options.
    highest_healpix_order=12,
    pixel_threshold=5_000_000,
)
runner.pipeline_with_client(args, client)

#### Object

In [None]:
def _band_columns(band):
    """Return one band's columns: psf/kron flux and flux error, then the kron radius."""
    flux_columns = [
        f"{band}_{flux_type}{suffix}"
        for flux_type in ("psf", "kron")
        for suffix in ("Flux", "FluxErr")
    ]
    return flux_columns + [f"{band}_kronRad"]


# Per-band columns for u, g, r, i, z, y, kept in band order.
cols_per_band = [column for band in "ugrizy" for column in _band_columns(band)]

# Band-independent columns first, followed by all per-band columns.
obj_default_columns = [
    "objectId",
    "refFwhm",
    "shape_flag",
    "sky_object",
    "parentObjectId",
    "detect_isPrimary",
    "x",
    "y",
    "xErr",
    "yErr",
    "shape_yy",
    "shape_xx",
    "shape_xy",
    "coord_ra",
    "coord_dec",
    "coord_raErr",
    "coord_decErr",
    "tract",
    "patch",
    "detect_isIsolated",
    *cols_per_band,
]

# Display the final column list as the cell output.
obj_default_columns

In [None]:
args = ImportArguments(
    # Input: files listed in raw_dir/paths/objectTable.txt; the reader is given
    # obj_default_columns as its column list.
    input_file_list=get_paths("objectTable", raw_dir),
    file_reader=ParquetPyarrowReader(column_names=obj_default_columns),
    # Sky-position columns.
    ra_column="coord_ra",
    dec_column="coord_dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="object",
    catalog_type="object",
    # Run options.
    resume=False,
    pixel_threshold=300_000,
)
runner.pipeline_with_client(args, client)

#### Source

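This import is the one that is going to get much worse very quickly. The `sourceTable` dataset is organized per visit, so each file is very small and there are LOTS of them. The progress log from one run is shown below; note that the `Mapping` and `Splitting` stages each process 16471 files, and `Splitting` alone takes almost 29 minutes.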

```
Planning  : 100% 4/4 [00:00<00:00, 123.68it/s]
Mapping   : 100% 16471/16471 [04:25<00:00,  1.77it/s]
Binning   : 100% 2/2 [00:38<00:00, 17.09s/it]
Splitting : 100% 16471/16471 [28:41<00:00,  1.64s/it]
Reducing  : 100% 148/148 [04:30<00:00,  2.21s/it]
Finishing : 100% 5/5 [00:24<00:00,  8.99s/it]
```

Solutions:

- Use the `IndexedParquetReader`. We can aggregate each index file by something like tract/patch of the visit, to reduce intermediate file usage.
- Escalate to DM. This is going to be ROUGH for everyone if there is no aggregation.

In [None]:
args = ImportArguments(
    # Input: files listed in raw_dir/paths/sourceTable.txt, read with the
    # reader's default settings.
    input_file_list=get_paths("sourceTable", raw_dir),
    file_reader=ParquetPyarrowReader(),
    # Sky-position columns.
    ra_column="ra",
    dec_column="dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="source",
    catalog_type="source",
    # Run options.
    resume=False,
    pixel_threshold=1_000_000,
)
runner.pipeline_with_client(args, client)

#### ForcedSource

In [13]:
# Visit table stored alongside the raw inputs.
visits = pd.read_parquet(raw_dir / "visits.parquet")
# Mapping from each index label of `visits` (presumably the visit ID — confirm)
# to its expMidptMJD value. Series.to_dict() builds this directly, without
# transposing the frame and taking the first record.
# NOTE(review): visit_map is not referenced by any cell visible in this notebook.
visit_map = visits["expMidptMJD"].to_dict()

In [None]:
args = ImportArguments(
    # Input: files listed in raw_dir/paths/forcedSourceTable.txt, read with the
    # reader's default settings.
    input_file_list=get_paths("forcedSourceTable", raw_dir),
    file_reader=ParquetPyarrowReader(),
    # Sky-position columns.
    ra_column="coord_ra",
    dec_column="coord_dec",
    # Output catalog.
    output_path=hats_dir,
    output_artifact_name="forcedSource",
    catalog_type="source",
    # Run options.
    resume=False,
    pixel_threshold=8_000_000,
)
runner.pipeline_with_client(args, client)

In [15]:
# Close the Dask client first, then remove the scratch directory that was
# passed to it as local_directory.
client.close()
tmp_path.cleanup()