In [None]:
# Installs/upgrades the HATS + LSDB stack into the kernel's environment.
# NOTE(review): versions are unpinned (`-U`), so reruns may pick up newer releases with
# different behavior — pin exact versions (e.g. `hats==X.Y lsdb==X.Y`) for reproducibility.
%pip install -U hats lsdb

In [None]:
import math
from pathlib import Path
from shutil import rmtree

import dask
import lsdb
from dask.distributed import Client
from lsdb.core.search.pixel_search import PixelSearch
from tqdm.auto import tqdm

from hats_import_parquet import hats_import_parquet

# Raise Dask's distributed timeouts above their defaults (presumably so that
# long-running cross-match jobs on a busy cluster don't hit them).
DASK_TIMEOUTS = {
    "distributed.comm.timeouts.connect": "600s",    # scheduler/worker connection setup
    "distributed.comm.timeouts.tcp": "1200s",       # TCP communication
    "distributed.nanny.shutdown-timeout": "1200s",  # time allowed for nanny shutdown
}
dask.config.set(DASK_TIMEOUTS)

In [None]:
# Input catalogs shared on the PSC Bridges-2 cluster, and a local directory
# for catalogs produced by this notebook.
GLOBAL_HATS_PATH = Path("/ocean/projects/phy210048p/shared/hats/catalogs/")
LOCAL_HATS_PATH = Path("./hats")

# Pan-STARRS1 object table and its margin cache.
PS1_OTMO_PATH = GLOBAL_HATS_PATH.joinpath("ps1", "ps1_otmo")
PS1_OTMO_MARGIN_PATH = GLOBAL_HATS_PATH.joinpath("ps1", "ps1_otmo_10arcs")

# ZTF DR16 Zubercal detections.
ZUBERCAL_PATH = GLOBAL_HATS_PATH.joinpath("ztf_dr16", "zubercal")

# Local Gaia DR3 variables catalog of the selected type, plus its margin cache.
GAIA_CATALOG_TYPE = "vcep"
_gaia_catalog_name = f"gaia_dr3_{GAIA_CATALOG_TYPE}"
GAIA_VARS_PATH = LOCAL_HATS_PATH / _gaia_catalog_name
GAIA_VARS_MARGIN_PATH = LOCAL_HATS_PATH / f"{_gaia_catalog_name}_10arcsec"

# PS1 columns are built as one column per (filter, suffix) pair.
PS1_FILTERS = 'grizy'
PS1_MAG_SUFFIXES = ['MeanPSFMag', 'MeanPSFMagErr', 'Flags']

# Per-tile parquet output directory, created up front.
OUTPUT_CATALOG_NAME = f"zubercal_{GAIA_CATALOG_TYPE}"
LSDB_OUTPUT_PATH = Path("lsdb", OUTPUT_CATALOG_NAME)
LSDB_OUTPUT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
def _read_hats_or_none(path, **read_kwargs):
    """Open a HATS catalog with ``lsdb.read_hats``, or return None on ValueError.

    This notebook treats a ``ValueError`` from ``read_hats`` as "the search
    filter has no coverage in this catalog": the message is printed and
    ``None`` is returned so the caller can skip the tile.
    """
    try:
        return lsdb.read_hats(path, **read_kwargs)
    except ValueError as e:  # No coverage
        print(e)
        return None


def matched_catalog(search_filter, output_catalog_name):
    """Build the lazy Gaia-variables x PS1 x Zubercal catalog for one sky region.

    Gaia variables are cross-matched to the PS1 object table within 1 arcsec.
    Zubercal detections are then nested under each match as the ``lc`` column,
    joining PS1 ``objID`` to Zubercal ``objectid``.

    Parameters
    ----------
    search_filter : lsdb search object
        Spatial filter (e.g. ``PixelSearch``) applied to every input catalog.
    output_catalog_name : str
        Name given to the resulting catalog.

    Returns
    -------
    lsdb.Catalog or None
        The not-yet-computed joined catalog, or ``None`` if any input catalog
        has no coverage for ``search_filter``.
    """
    ps1_otmo = _read_hats_or_none(
        PS1_OTMO_PATH,
        margin_cache=PS1_OTMO_MARGIN_PATH,
        # Few useful columns from PS1 object catalog
        columns=(
            ['objID', 'raMean', 'decMean']
            + [f'{fltr}{suffix}' for fltr in PS1_FILTERS for suffix in PS1_MAG_SUFFIXES]
        ),
        search_filter=search_filter,
    )
    if ps1_otmo is None:
        return None

    # Zubercal catalog, skip coordinates and few other columns
    # Column description:
    # http://atua.caltech.edu/ZTF/Fields/ReadMe.txt
    zubercal = _read_hats_or_none(
        ZUBERCAL_PATH,
        columns=['mjd', 'mag', 'magerr', 'objectid', 'band'],
        search_filter=search_filter,
        filters=[
            ("info", "==", 0),  # No errors in calibration
            ("flag", "==", 0),  # Good observational conditions
        ],
    )
    if zubercal is None:
        return None

    # Handled like the other inputs so a tile without Gaia coverage is skipped
    # instead of aborting the whole batch loop.
    gaia_var = _read_hats_or_none(
        GAIA_VARS_PATH,
        margin_cache=GAIA_VARS_MARGIN_PATH,
        search_filter=search_filter,
    )
    if gaia_var is None:
        return None

    return gaia_var.crossmatch(
        ps1_otmo,
        radius_arcsec=1.0,
        suffixes=["", ""],
        output_catalog_name="gaia_vars_x_ps1_otmo",
    ).join_nested(
        zubercal,
        left_on='objID',
        right_on='objectid',
        nested_column_name='lc',
        # Use the caller-supplied name (previously the global was used and
        # this parameter was silently ignored).
        output_catalog_name=output_catalog_name,
    )

In [None]:
%%time

order = 1
num_pixels = 12 * 4 ** order
batch_n_digits = len(str(num_pixels))

with Client(n_workers=64, threads_per_worker=1, memory_limit='64GB') as client:
    display(client)

    for pix in tqdm(range(num_pixels)):
        batch_str = f"batch_{pix:0{batch_n_digits}d}"
        output_path = LSDB_OUTPUT_PATH / f"{batch_str}.parquet"
        output_catalog_name = f"{OUTPUT_CATALOG_NAME}_{batch_str}"

        if output_path.exists():
            print(f"{output_path} exists, skipping")
            continue
            # if (properties := output_path / "properties").exists():
            #     print(f"{properties} exists, skipping")
            #     continue
            # print(f"Deleting incomplete catalog {output_path}")
            # rmtree(output_path)

        pixel_search = PixelSearch([(order, pix)])

        print("X-matching...")
        batch = matched_catalog(pixel_search, output_catalog_name)

        # No coverage
        if batch is None:
            print(f"No coverage for tile ({order}, {pix})")
            continue

        print(f"Matched partitions: {batch._ddf.npartitions}")

        print("Computing...")
        result = batch.compute()

        n_rows = result.shape[0]
        print(f"Number of rows: {result.shape[0]}")

        if n_rows == 0:
            print(f"No matched rows for tile ({pix}, {order})")
            continue

        print("Saving...")
        result.drop(["Dir", "Norder", "Npix"], axis=1).to_parquet(output_path, index=True)

In [None]:
%%time

hats_import_parquet(LSDB_OUTPUT_PATH, LOCAL_HATS_PATH, OUTPUT_CATALOG_NAME,
                    dask_kwargs={'n_workers': 16, 'memory_limit': '64GB'})