In [None]:
import shutil

import healpy
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from astropy.coordinates import SkyCoord
from dustmaps.sfd import SFDQuery
from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN
from joblib import Parallel, delayed

from mom_builder import gen_mom_from_fn
from paths import *

Create a function to query the SFD dust map

In [None]:
# Build the SFD query object once at cell level; ebv() below calls it on every invocation.
# NOTE(review): INPUT_DIR is not defined in this notebook — presumably it comes from
# `from paths import *`; confirm.
sfd_query = SFDQuery(INPUT_DIR)


def ebv(norder: int, index_range=None):
    """Query the SFD map at the positions of HEALPix pixels of order ``norder``.

    Parameters
    ----------
    norder : int
        HEALPix order of the pixels.
    index_range : numpy.ndarray, tuple or None
        Which pixels (nested numbering) to query. An array is taken as the
        explicit list of pixel indexes. Anything else is treated as the
        arguments of ``np.arange``; its upper bound is clipped to the number
        of pixels at this order. ``None`` selects every pixel.

    Returns
    -------
    Whatever ``sfd_query`` returns for the pixel coordinates (one value per
    requested pixel, presumably E(B-V) — confirm against dustmaps).
    """
    nside = healpy.order2nside(norder)
    total_pixels = healpy.order2npix(norder)

    if isinstance(index_range, np.ndarray):
        pixels = index_range
    else:
        bounds = (0, total_pixels) if index_range is None else index_range
        # Never run past the last pixel of this order.
        if bounds[1] > total_pixels:
            bounds = (bounds[0], total_pixels)
        pixels = np.arange(*bounds)
    pixels = np.asarray(pixels, dtype=int)

    lon, lat = healpy.pix2ang(nside, pixels, nest=True, lonlat=True)
    return sfd_query(SkyCoord(ra=lon, dec=lat, unit='deg'))

Context manager to write tiles to Parquet files

In [None]:
class Writer:
    """Context manager that appends tiles to one Parquet file per HEALPix order.

    A ``pixel_norder=NN.parquet`` file is opened lazily under ``path`` the
    first time a given order is written; every opened file is closed when the
    ``with`` block exits.

    It doesn't optimize Parquet group size for now.
    """

    def __init__(self, path=PARQUET_DIR):
        """Remember the output directory and create it if it is missing."""
        self.path = path
        self.path.mkdir(parents=True, exist_ok=True)

        # norder -> open pq.ParquetWriter, filled lazily by write().
        self.parquet_writers = {}

    @staticmethod
    def _schema():
        """Return the column layout shared by the output files and written tables."""
        return pa.schema([
            pa.field(HIPSCAT_ID_COLUMN, pa.uint64()),
            pa.field('pixel_Norder', pa.uint8()),
            pa.field('pixel_Npix', pa.uint64()),
            pa.field('ebv', pa.float32()),
        ])

    def _create_parquet_writer(self, norder):
        """Open the Parquet file that collects all tiles of order ``norder``."""
        path = self.path / f'pixel_norder={norder:02d}.parquet'
        return pq.ParquetWriter(path, self._schema())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close every opened file, even if closing one of them fails.

        The first close error (if any) is re-raised after all writers have
        been given the chance to close, so no file handle is left open.
        """
        close_error = None
        for writer in self.parquet_writers.values():
            try:
                writer.close()
            except Exception as error:
                if close_error is None:
                    close_error = error
        if close_error is not None:
            raise close_error

    def write(self, norder, indexes, values):
        """Append one batch of tiles of order ``norder``.

        Parameters
        ----------
        norder : int
            HEALPix order of the tiles.
        indexes : array-like
            Pixel indexes of the tiles, stored in ``pixel_Npix``.
        values : array-like
            One value per tile, stored in ``ebv``.
        """
        hipscat_index = healpix_to_hipscat_id(norder, indexes)
        # Build the table against the file schema so compatible dtypes
        # (e.g. int64 indexes, float64 values) are cast instead of making
        # write_table() reject the table for a schema mismatch.
        table = pa.Table.from_arrays(
            [hipscat_index, np.full(hipscat_index.shape, norder, dtype=np.uint8), indexes, values],
            schema=self._schema(),
        )

        if norder not in self.parquet_writers:
            self.parquet_writers[norder] = self._create_parquet_writer(norder)
        self.parquet_writers[norder].write_table(table)

Create intermediate Parquet files for the multi-order map

In [None]:
%%time 

# Passed as `max_norder` to gen_mom_from_fn and reused by the area check in the final cell.
max_norder = 17
# Passed as `merger` to gen_mom_from_fn; the value halves for every order above 13.
# NOTE(review): the meaning of 0.16 and of the reference order 13 is not documented here — confirm.
threshold = 0.16 / 2 ** (max_norder - 13)
# Passed as `split_norder` to gen_mom_from_fn: 12 orders below max_norder, but never less than 1.
subtree_norder = max(max_norder - 12, 1)


def worker(n_jobs, parallel):
    """Build a function that evaluates ``ebv`` over a pixel range in parallel batches.

    Parameters
    ----------
    n_jobs : int
        Number of batches to aim for. Values below 1 are treated as 1.
    parallel : callable
        Executes a list of ``delayed`` tasks and returns their results in
        order (a ``joblib.Parallel`` instance in this notebook).

    Returns
    -------
    callable
        ``fn(norder, rng)`` returning the concatenated ``ebv`` values for the
        pixel indexes in ``rng``, in the same order as ``rng``.
    """
    n_chunks = max(1, int(n_jobs))

    def fn(norder, rng):
        # Ceiling division, floored at 1: at most n_chunks batches, and never a
        # zero batch size (which would make range() raise) when rng is shorter
        # than the number of jobs.
        n_batch = max(1, -(-len(rng) // n_chunks))
        batches = parallel([
            delayed(ebv)(norder, rng[start:start + n_batch])
            for start in range(0, len(rng), n_batch)
        ])
        return np.concatenate(batches)

    return fn


# Start from an empty output directory: files for orders that this run does not
# produce would otherwise survive and be counted by the area check in the final cell.
shutil.rmtree(PARQUET_DIR, ignore_errors=True)

with Parallel(n_jobs=12, backend="threading") as parallel:
    # Bind the batched function to its own name so `worker` keeps referring to
    # the factory rather than to its product.
    batched_ebv = worker(parallel.n_jobs, parallel)
    with Writer() as writer:
        for tiles in gen_mom_from_fn(
                batched_ebv,
                max_norder=max_norder,
                split_norder=subtree_norder,
                merger=threshold,
        ):
            # Each item is unpacked into Writer.write(norder, indexes, values).
            writer.write(*tiles)

In [None]:
# Sanity check: the written tiles must cover the whole sphere exactly once.
# Areas are counted in units of one pixel at max_norder, so a tile of order o
# contributes 4 ** (max_norder - o) and the full sky is 12 * 4 ** max_norder.
area_desired = 12 * 4 ** max_norder

tile_files = {
    order: PARQUET_DIR / f'pixel_norder={order:02d}.parquet'
    for order in range(0, max_norder + 1)
}
# Orders for which no file was written simply contribute nothing.
area_actual = sum(
    pq.read_metadata(tile_path).num_rows * 4 ** (max_norder - order)
    for order, tile_path in tile_files.items()
    if tile_path.exists()
)

assert area_actual == area_desired, f'{area_actual} != {area_desired}'