## Prototyping Dask-HiPSCat Integration

In [1]:
# Since the best way to learn/hack on Dask seems to be by directly editing the installed source code,
# let's have it all reloaded automatically.
#
# We'll exempt modules where we monkeypatch functions from autoreloading (e.g. dask.utils).
# Otherwise the autoreload would remove our patched version.
#

%load_ext autoreload

%aimport -dask.utils
%autoreload 2

In [2]:
# To set up an environment for this experiment, run:
#    mamba create -n lsd2-2023 -c conda-forge dask pyarrow healpy ipykernel python-graphviz rich ipywidgets
# and then create a Jupyter kernel w. something like:
#    conda activate lsd2-2023
#    python -m ipykernel install --user --name lsd2-2023 --display-name "LSD2 (2023)"
#

import dask
import dask.dataframe as dd
import pandas as pd
import numpy as np
import healpy as hp

print(f"{dask.__version__=}")
print(f"{pd.__version__=}")
print(f"{np.__version__=}")
print(f"{hp.__version__=}")

dask.__version__='2022.12.1'
pd.__version__='1.5.2'
np.__version__='1.24.0'
hp.__version__='1.16.1'


It's useful to have a unique (within the catalog!) ID for each row. Two reasons: tables can be efficiently joined on that ID (e.g., a table of objects with its corresponding sources), and it can also be used to determine the partition where that row lives (maybe w. some extra metadata).

For now, I'll just make that be a very high-order healpix ipix index computed from the object's (ra, dec) and maximally bitshifted to the left, with the remaining bits to the right just increasing numerically if there is more than one object with the same ipix in the catalog.

This is implemented by `compute_index` (and should be performant).
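
As a minimal sketch of that bit layout (the helper names `pack_index` and `unpack_index` are hypothetical and not part of this notebook), the pixel number and the rank can be packed into, and recovered from, a single 64-bit integer with shifts and masks:

```python
import numpy as np

ORDER = 20
PIX_BITS = 4 + 2 * ORDER     # 12 * 4**ORDER pixels fit in 4 + 2*ORDER bits
RANK_BITS = 64 - PIX_BITS    # bits left over for the per-pixel rank


def pack_index(pix, rank):
    """Hypothetical helper: combine pixel number and rank into one uint64."""
    pix = np.asarray(pix, dtype=np.uint64)
    rank = np.asarray(rank, dtype=np.uint64)
    return (pix << np.uint64(RANK_BITS)) | rank


def unpack_index(idx):
    """Hypothetical helper: split a uint64 index back into (pix, rank)."""
    idx = np.asarray(idx, dtype=np.uint64)
    pix = idx >> np.uint64(RANK_BITS)
    rank = idx & np.uint64((1 << RANK_BITS) - 1)
    return pix, rank


# two objects in pixel 7, one object in pixel 8
idx = pack_index([7, 7, 8], [0, 1, 0])
pix, rank = unpack_index(idx)
```

Because the pixel number occupies the high bits, sorting by the index sorts by pixel first and by rank second.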

In [3]:
# some stats on high-order pixels
# for order=20, ipix corresponds to 0.2" x 0.2" area in the sky (~a single LSST pixel)
# and we have room for up to ~1M objects detected within that same pixel.
#
# FIXME: while I can't yet imagine a scenario with 1M objects in 0.2" x 0.2" area, any such limitation feels icky.
#        someone at some point will come up with a need for more... we should think of an escape hatch...
#
order=20
pix_edge_len_arcsec = np.sqrt(hp.nside2pixarea(hp.order2nside(order), degrees=True))*3600 # arcsec x arcsec square
bits=4 + 2*order
maxindex=2**(64-bits)
print(f"{order=}\n{pix_edge_len_arcsec=}\n{bits=}\n{maxindex=}")

order=20
pix_edge_len_arcsec=0.2012980319424261
bits=44
maxindex=1048576
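
The same trade-off can be tabulated for other orders without healpy. This is only a sketch: it assumes nothing beyond the standard HEALPix count of `12 * 4**order` equal-area pixels, and the function name `index_budget` is made up for illustration.

```python
import math


def index_budget(order):
    """Return (pixel edge in arcsec, bits used by pix, ranks available per pixel)."""
    npix = 12 * 4**order
    area_sr = 4 * math.pi / npix                      # all pixels have equal area
    edge_arcsec = math.degrees(math.sqrt(area_sr)) * 3600
    bits = 4 + 2 * order
    return edge_arcsec, bits, 2 ** (64 - bits)


for order in (14, 20, 24):
    edge, bits, max_rank = index_budget(order)
    print(f"{order=}  edge={edge:.4f} arcsec  {bits=}  {max_rank=}")
```

Lower orders leave more room for ranks at the cost of larger pixels; higher orders do the opposite.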


In [4]:
def compute_index(ra, dec, order=20):
    # the 64-bit index, viewed as a bit array, consists of two parts:
    #
    #    idx = |(pix)|(rank)|
    #
    # where pix is the healpix nest-scheme index for the given order,
    # and rank is a monotonically increasing integer for all objects
    # with the same value of pix.

    # compute the healpix pix-index of each object
    pix = hp.ang2pix(2**order, ra, dec, nest=True, lonlat=True)

    # shift to higher bits of idx
    bits=4 + 2*order
    idx = pix.astype(np.uint64) << (64-bits)

    # sort
    orig_idx = np.arange(len(idx))
    sorted_idx = np.lexsort((dec, ra, idx))
    idx, ra, dec, orig_idx = idx[sorted_idx], ra[sorted_idx], dec[sorted_idx], orig_idx[sorted_idx]

    # compute the rank for each unique value of idx (== bitshifted pix, at this point)
    # the goal: given values of idx such as:
    #   1000, 1000, 1000, 2000, 2000, 3000, 5000, 5000, 5000, 5000, ...
    # compute a unique array such as:
    #   1000, 1001, 1002, 2000, 2001, 3000, 5000, 5001, 5002, 5003, ...
    # that is, for the subset of nobj objects with the same pix, add
    # to the index a range [0..nobj)
    #
    # how this works:
    # * x are the indices of the first appearance of a new pix value. In the example above,
    # it would be equal to [0, 3, 5, 6, ...]. But note that this is also the total number
    # of entries before the next unique value (e.g. 5 above means there were 5 elements in
    # idx -- 1000, 1000, 1000, 2000, 2000 -- before the third unique value of idx -- 3000)
    # * i are the indices of each unique value of idx, starting with 0 for the first one
    # in the example above, i = [0, 0, 0, 1, 1, 2, 3, 3, 3, 3]
    # * we need to construct an array such as [0, 1, 2, 0, 1, 0, 0, 1, 2, 3, ...], i.e.
    # one that resets every time the value of idx changes. If we can construct this, we
    # can add this array to idx and achieve our objective.
    # * the way to do it: start with a monotonically increasing array
    #  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...] and subtract an array that looks like this:
    #  [0, 0, 0, 3, 3, 5, 6, 6, 6, 6, ...]. This is an array that at each location has
    #  the index where that location's pix value appeared for the first time. It's easy
    #  to confirm that this is simply x[i].
    #
    # And this is what the following four lines implement.
    _, x, i = np.unique(idx, return_inverse=True, return_index=True)
    x = x.astype(np.uint64)
    ii = np.arange(len(i), dtype=np.uint64)
    di = ii - x[i]
    idx += di

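    # remap back to the original row order: element k of the sorted array came
    # from original position orig_idx[k], so scatter each value back to that
    # position (gathering with idx[orig_idx] would apply the sort a second time)
    out = np.empty_like(idx)
    out[orig_idx] = idx
    idx = out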

    return idx

if False:
    # quick test
    try:
        df_orig
    except NameError:
        #df_orig = pd.read_parquet('/epyc/projects3/sam_hipscat/output/gaia_real/Norder6/Npix29079/catalog.parquet')
        df_orig = pd.read_parquet('/epyc/projects3/sam_hipscat/output/gaia_real/Norder2/Npix138/catalog.parquet')
    dff = df_orig.copy()
    dff["_ID"] = compute_index(dff["ra"].values, dff["dec"].values, order=14)
    dff.set_index("_ID", inplace=True)
    dff.sort_index(inplace=True)

    from IPython.display import display
    display(dff.iloc[:10])

Sam's current conversion of Gaia doesn't include an index column. We'll have to clone his catalog and add it here. That's what the next few cells do.

In [5]:
# test parquet writing with metadata
def add_index(infile, outfile):
    df = pd.read_parquet(infile)

    df["_ID"] = compute_index(df["ra"].values, df["dec"].values, order=14)
    df.set_index("_ID", inplace=True)
    df.sort_index(inplace=True)

    import os, os.path
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    df.to_parquet(outfile, engine='pyarrow', compression='snappy', index=True)

def fixup_hipscat(inprefix, outprefix, start=None, end=None):
    # the [start:end] at the end lets you convert a subset of
    # partitions, to speed things up while developing

    incat = glob.glob(f'{inprefix}/*/*/catalog.parquet')
    outcat = [ outprefix + fn[len(inprefix):] for fn in incat ]

    from rich.progress import track
    for infile, outfile in track(list(zip(incat, outcat))[start:end]):
        add_index(infile, outfile)

This creates two catalogs, so we can test `join`s later on.

(note: rich.progress seems to render double progress bars for me; I found https://github.com/Textualize/rich/issues/1737 which claims this has been fixed, but doesn't seem so in this situation)

In [6]:
import glob

inprefix = "/epyc/projects3/sam_hipscat/output/gaia_real"
outprefix1 = "/epyc/projects3/mjuric_hipscat/gaiaA"
outprefix2 = "/epyc/projects3/mjuric_hipscat/gaiaB"

In [None]:
fixup_hipscat(inprefix, outprefix1, 0, 100)
fixup_hipscat(inprefix, outprefix2, 0, 10)

For our next trick, we'll use Dask's built-in `dd.read_parquet` to read our HiPSCat as a single parquet dataset.

We will also have dask discover the min/max values of the index (`_ID`) column in each .parquet file, by reading parquet metadata (as opposed to loading the entire file). This will allow dask to compute its `DataFrame.divisions` field, which tells it which files contain what range of indices. This, in turn, allows for rapid selection based on the index, as Dask can just load the files having the selected data.

For this to work, a few things need to be true:
1. There has to be an index column in each file (✔)
1. Ranges of indices contained in individual partitions must not overlap (✔, by construction of HiPSCat & our index)
1. The parquet files must be read in increasing order of index (see below)

If any of these (poorly documented) requirements is violated, Dask will still load the dataset but won't populate the `DataFrame.divisions` field, making index-based operations horribly slow. It also won't give any indication as to why it failed to load divisions -- for example, it took me an hour to find a bug where one file had out-of-order indices. This is an example of unfortunate UX design. IMHO, if something is explicitly requested and it fails, the code should complain loudly to let the user know something's off (raise an exception pointing to the problem). Here, Dask doesn't even emit a warning. This type of thing (silent failures w/o information on errors) is a common antipattern in Dask; caveat emptor.
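
A sketch of the kind of loud check I'd have liked to see. This is not something Dask provides: `check_divisions` is a hypothetical helper, and the `(filename, min, max)` tuples are assumed to have been collected from the parquet footers beforehand.

```python
def check_divisions(stats):
    """Validate per-file index ranges before trusting them as divisions.

    `stats` is a list of (filename, min_index, max_index) tuples, in the
    order the files will be read. Raises ValueError naming the first
    offending file instead of failing silently.
    """
    if not stats:
        return ()
    prev_name, prev_max = None, None
    for name, lo, hi in stats:
        if lo > hi:
            raise ValueError(f"{name}: min index {lo} > max index {hi}")
        if prev_max is not None and lo <= prev_max:
            raise ValueError(
                f"{name}: min index {lo} overlaps or precedes "
                f"{prev_name} (max index {prev_max})"
            )
        prev_name, prev_max = name, hi
    # divisions: the minimum of every partition, plus the maximum of the last one
    return tuple(lo for _, lo, _ in stats) + (stats[-1][2],)


good = [("a.parquet", 0, 9), ("b.parquet", 10, 19)]
print(check_divisions(good))  # (0, 10, 19)
```

Passing the same files in the wrong order, or with overlapping ranges, raises an error that names the file at fault.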

Now for our point 3. above: when passed a list of files (or a directory), Dask will sort them using natural sort (https://en.wikipedia.org/wiki/Natural_sort_order) hoping this file order is also the order in which indices increase in each file. This works for the typical case of `(part-1.parquet, part-2.parquet, ... part-9.parquet, part-10.parquet, ...)`. Directory names are also included in this sort. But this is where there's an issue with HiPSCat as (for example) `Norder1/Npix11/catalog.parquet` will contain objects with indices that are larger than (say) `Norder2/Npix0/catalog.parquet` (Healpix indices start at zero at the north pole and roughly increase towards the south pole).
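
To make the ordering problem concrete, here is a small sketch with made-up paths. `natural_key` is a simplified stand-in for a natural sort key (not Dask's own function), and `hips_key` with its `max_order` parameter is an illustrative name for a key that orders partitions by the first index each pixel covers at a common reference order.

```python
import re


def natural_key(s):
    """Simplified natural sort key: runs of digits compare as integers."""
    return [int(p) if p.isdigit() else p for p in re.split(r"(\d+)", s)]


def hips_key(s, max_order=20):
    """Key by the pixel's first nested index at a common reference order."""
    m = re.match(r"^(.*)/Norder(\d+)/Npix(\d+)/([^/]*)$", s)
    root, order, ipix, leaf = m.group(1), int(m.group(2)), int(m.group(3)), m.group(4)
    # each additional order splits a pixel in four, i.e. adds two bits
    return (root, ipix << 2 * (max_order - order), leaf)


paths = [
    "cat/Norder2/Npix0/catalog.parquet",
    "cat/Norder1/Npix11/catalog.parquet",
    "cat/Norder2/Npix5/catalog.parquet",
]
by_name = sorted(paths, key=natural_key)
by_index = sorted(paths, key=hips_key)
print(by_name)
print(by_index)
```

Natural sort puts `Norder1/Npix11` first because `1 < 2`, while ordering by index puts it last, since order-1 pixel 11 covers order-2 pixels 44 through 47.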

So we need to teach Dask to detect HiPSCat catalogs and sort their filenames differently. We'll do this by monkeypatching `dask.utils.natural_sort_key`:

In [7]:
# patch the dask.utils.natural_sort_key to recognize and sort hierarchical HiPS directories
# in order of increasing healpix index of their contents.

import dask.utils as du
try:
    _orig_natural_sort_key
except NameError:
    _orig_natural_sort_key = du.natural_sort_key

def hips_or_natural_sort_key(s: str) -> list[str | int]:
    import re
    m = re.match(r"^(.*)/Norder(\d+)/Npix(\d+)/([^/]*)$", s)
    if m is None:
        return _orig_natural_sort_key(s)
    
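    root, order, ipix, leaf = m.groups()
    order, ipix = int(order), int(ipix)
    # Express the pixel at a common reference order (20): each extra order
    # splits a pixel in four (2 bits), so ipix20 is the smallest order-20
    # nested pixel number inside this partition and grows with its indices.
    ipix20 = ipix << 2*(20 - order)
    k = (root, ipix20, leaf)
    return k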
hips_or_natural_sort_key.__doc__ = _orig_natural_sort_key.__doc__

du.natural_sort_key = hips_or_natural_sort_key

Let's check that this works -- if it does, `df.divisions` will be a list of indices. Note `calculate_divisions=True` below; this requests this metadata to be loaded.

(I also specify which `columns` to load, just to speed up development)

In [8]:
df = dd.read_parquet(f'{outprefix1}/*/*/catalog.parquet', calculate_divisions=True, columns=['ra', 'dec', 'source_id', 'parallax'])
print(f"ndivisions={len(df.divisions)}: {df.divisions[:5]} ...")
assert df.divisions and df.divisions[0] is not None
print(len(df))

ndivisions=101: (8589934592, 72057602627862528, 126100802451275776, 144115505903435776, 216172782113783808) ...
61472247


In [10]:
df2 = dd.read_parquet(f'{outprefix2}/*/*/catalog.parquet', columns=['ra', 'dec', 'source_id', 'parallax'], calculate_divisions=True)
print(f"ndivisions={len(df2.divisions)}: {df2.divisions[:5]} ...")
assert df2.divisions and df2.divisions[0] is not None
print(len(df2))

ndivisions=11: (8589934592, 2449958424922816512, 4683743706954596352, 4755801335352262656, 4827858873555615744) ...
6961555


And now let's demonstrate we can use Dask's own machinery to do something interesting!

In [11]:
from dask.diagnostics import ProgressBar
ProgressBar().register()

In [12]:
df.join(df2, rsuffix="_2", how="inner").query("ra_2 >= 45 and dec > 25").\
    to_parquet(f'{outprefix2}-dask', overwrite=True, write_metadata_file=True)

[########################################] | 100% Completed | 14.69 s


The output is just plain partitioned parquet, with one file per partition. Note that most of these ended up being empty; Dask doesn't cull them by default.

In [13]:
! ls -lrt '{outprefix2}-dask'

total 38732
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.61.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.57.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.1.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.2.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.58.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.4.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.5.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.6.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.0.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.7.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.3.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.8.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.84.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.10.parquet
-rw-rw-r-- 1 mjuric mjuric     4962 Dec 28 16:08 part.12.parquet
...

What have we learned?

The good:

* Dask can be taught to read HiPSCat parquet as if they're partitioned datasets
* It involves some nasty monkeypatching of code, but not too much (one can imagine a PR to generalize the sort order)
* Once loaded, things like simple joins work as expected

The bad:
* Dask's DataFrame partition model is rather inefficient (for our use case). It only records the _minimum_ index value of the data in each partition, but not the maximum. For example, given indices `df2.divisions = [8589934592, 2449958424922816512, 4683743706954596352, 4755801335352262656, 4827858873555615744, ...]`, it assumes that the first partition may contain any value in `[8589934592, 2449958424922816512)` -- essentially the entire northern sky. But it doesn't know that the maximum value in that partition is only around ~10000000000. So when it computes the execution graph for `join`, rather than immediately excluding many partitions from `df2`, it executes each and every one of those sub-joins (which all return empty dataframes). This can be seen by running `.visualize()` on the invocation above. This was discussed within the Dask community (https://github.com/dask/dask/issues/3384), but the conclusion was that it's a niche use case. Unfortunately, it's a very frequent, major use case for us.
* Dask doesn't appear to have a mature optimizer like Spark's Catalyst that could perform projection push-down (for example, eliminating all unused columns from parquet reads). To _not_ read everything, one has to explicitly specify the columns one will use in `read_parquet`'s `columns` kwarg, and this argument is optional -- left unspecified, it defaults to reading everything. Users are bound to forget this :(, and likely read 10x more data than they really need (costing both time and money). We need to fix this in our API (e.g., require columns to be specified in any future `read_hipscat` or alike).
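
To illustrate the first point above, here is a toy model with made-up index ranges. It is a sketch only and does not reproduce Dask's actual partition-alignment logic; `candidate_pairs` and `ranges_from_divisions` are hypothetical helpers. It counts how many partition pairs a join would have to consider when only minimums are known, versus when true per-partition (min, max) ranges are known.

```python
def candidate_pairs(ranges_a, ranges_b):
    """Count partition pairs whose inclusive [lo, hi] index ranges intersect."""
    return sum(
        1
        for lo_a, hi_a in ranges_a
        for lo_b, hi_b in ranges_b
        if lo_a <= hi_b and lo_b <= hi_a
    )


def ranges_from_divisions(divisions):
    """What min-only divisions imply: each partition may span up to the next minimum."""
    lows, highs = divisions[:-1], divisions[1:]
    out = [(lo, hi - 1) for lo, hi in zip(lows, highs)]
    out[-1] = (lows[-1], highs[-1])  # the final division is an inclusive maximum
    return out


# dense catalog A: five partitions that really do fill their ranges
true_a = [(0, 49), (50, 99), (100, 149), (150, 199), (200, 249)]
# sparse catalog B: three partitions, each covering only a sliver of its range
true_b = [(0, 5), (100, 105), (200, 205)]
divisions_b = (0, 100, 200, 205)

with_min_only = candidate_pairs(true_a, ranges_from_divisions(divisions_b))
with_min_max = candidate_pairs(true_a, true_b)
print(with_min_only, with_min_max)  # 5 3
```

In this toy case, two of the five sub-joins could be skipped outright if the maximum of each partition were recorded.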

The ugly:
* The documentation of internals is extremely poor. For example, the dask version used here uses a mysterious `BlockIndex` instance as an argument to `df.map_partitions` in its `to_parquet` implementation. This instance behaves "magically" -- in the actual call it ends up being replaced by the index of the partition being mapped. This isn't documented anywhere -- official documentation says to use `partition_info` to achieve this. After reading the source code and inserting print statements one can figure out what BlockIndex does, but the lack of documentation is a huge time sink. It also points to Dask's development model not having a requirement that the corresponding documentation is updated with each PR (a lesson to us NOT to do that).

Lessons learned and what do we do next?

1. The main purpose of this exercise was to figure out how far we could get if we relied on Dask's `df.divisions` for efficient large-dataset computations. While it's nice to see going from zero to a simple join in ~24 hrs, the `df.divisions` data model (just minimums, no maximums) makes it unlikely we can build a performant class around it. We'll unfortunately have to override division handling in its entirety.

  Fortunately, there's an example to follow! It looks like https://github.com/geopandas/dask-geopandas did something like this for geospatial dataframes. I haven't had much time to analyze their [code](https://github.com/geopandas/dask-geopandas/blob/main/dask_geopandas/core.py), but I see mentions of `spatial_partitions`, which gives me hope.

1. I now have a much better idea of what a good API may look like. I think we'll have something like a `HiPSDataFrame` class, and want to write something like
  
```
df  = lsdb.read_hipscat(outprefix1, columns=['ra', 'dec', 'source_id', 'parallax'])
df2 = lsdb.read_hipscat(outprefix2, columns=['ra', 'dec', 'source_id', 'parallax'])
df.join(df2, rsuffix="_2", how="inner").query("ra_2 >= 45 and dec > 25").to_hipscat(f'{outprefix2}-dask')
```  
  for analytics, and something like
```
df = dd.read_csv("gaia_inputs/*.csv")
hcat = lsdb.from_dask(df, lon="ra", lat="dec")
hcat.to_hipscat('gaia')
```
for importing of catalogs.