In [1]:
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from astropy.coordinates import Angle, Latitude, Longitude
from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN
from mocpy import MOC
from mom_builder import MOMMerger
from mom_builder.mom_generator import gen_mom_from_fn
from tqdm import tqdm

In [2]:
class Writer:
    """Write tiles to parquet files
    
    It doesn't optimize Parquet group size for now.
    """

    def __init__(self, path, col_name, col_type):
        self.path = Path(path)
        self.path.mkdir(parents=True, exist_ok=True)

        self.col_name = col_name
        self.col_type = col_type

        self.parquet_writers = {}

    def _create_parquet_writer(self, norder):
        path = self.path / f'pixel_norder={norder:02d}.parquet'
        return pq.ParquetWriter(
            path,
            pa.schema([
                pa.field(HIPSCAT_ID_COLUMN, pa.uint64()),
                pa.field('pixel_Norder', pa.uint8()),
                pa.field('pixel_Npix', pa.uint64()),
                pa.field(self.col_name, self.col_type),
            ])
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for writer in self.parquet_writers.values():
            writer.close()

    def write(self, norder, indexes, values):
        hipscat_index = healpix_to_hipscat_id(norder, indexes)
        table = pa.Table.from_arrays(
            [hipscat_index, np.full(hipscat_index.shape, norder, dtype=np.uint8), indexes, values],
            names=[HIPSCAT_ID_COLUMN, 'pixel_Norder', 'pixel_Npix', self.col_name]
        )

        if norder not in self.parquet_writers:
            self.parquet_writers[norder] = self._create_parquet_writer(norder)
        self.parquet_writers[norder].write_table(table)

In [3]:
max_norder = 14
split_norder = 2

In [4]:
n = 20_000
rng = np.random.default_rng(0)
ra = Longitude(rng.uniform(0, 360, n), 'deg')
dec = Latitude(np.arcsin(rng.uniform(-1, 1, n)), 'rad')
radius = Angle(rng.lognormal(1.0, 0.5, n), 'arcmin')

In [5]:
# mocs = MOC.from_cones(lon=ra, lat=dec, radius=radius, max_depth=19, delta_depth=2)
# moc = mocs[0].union(*mocs[1:])
# del mocs

In [6]:
class Mask:
    def __init__(self, *, ra, dec, radius, split_norder):
        self.ra = ra
        self.dec = dec
        self.radius = radius
        self.split_norder = split_norder

        self.masks_in_tiles = self._build_masks_in_tiles()

    def _build_masks_in_tiles(self):
        """Returns a lookup table tile-index -> list(mask index) for split order"""
        mocs = MOC.from_cones(lon=self.ra, lat=self.dec, radius=self.radius, max_depth=self.split_norder,
                              delta_depth=2)
        n_tiles = 12 << (2 * self.split_norder)

        lookup = [[]] * n_tiles
        for i, moc in enumerate(mocs):
            for index in moc.flatten():
                lookup[index].append(i)

        lookup = [np.array(a) for a in lookup]
        return lookup

    def moc_in_tile_approx(self, split_index, max_norder):
        """MOC with all the masks intersects with given split_norder tile"""
        idx = self.masks_in_tiles[split_index]
        mocs = MOC.from_cones(lon=ra[idx], lat=dec[idx], radius=radius[idx], max_depth=max_norder,
                              delta_depth=2)
        if len(mocs) == 0:
            return MOC.from_lonlat(lon=Longitude([], 'deg'), lat=Latitude([], 'deg'), max_norder=max_norder)
        if len(mocs) == 1:
            return mocs[0]
        return mocs[0].union(*mocs[1:])

    def moc_in_tile(self, split_index, max_norder):
        """MOC with masks within given tile of split_norder"""
        moc = self.moc_in_tile_approx(split_index, max_norder)
        healpix_cells = np.array([split_index], dtype=np.uint64)
        tile_moc = MOC.from_healpix_cells(healpix_cells, self.split_norder, self.split_norder)
        return tile_moc.intersection(moc)

    def indexes_for_tile(self, split_index, target_norder):
        """Healpix indexes of targer_depth (>= split_norder) for split_index tile of split_norder"""
        moc = self.moc_in_tile(split_index, target_norder)
        return moc.flatten()

%time mask = Mask(ra=ra, dec=dec, radius=radius, split_norder=split_norder)

CPU times: user 196 ms, sys: 46.2 ms, total: 242 ms
Wall time: 165 ms


In [7]:
def parent(index, child_order, parent_order):
    delta_depth = child_order - parent_order
    return np.array(int(index) >> (2 * delta_depth), dtype=np.uint64)


def get_value(order, indexes):
    first_index = indexes[0]
    top_index = parent(first_index, order, mask.split_norder)
    mask_indexes = mask.indexes_for_tile(top_index, order)

    # values = np.isin(indexes, mask_indexes).astype(np.uint8)
    values = np.zeros(indexes.shape, dtype=np.uint8)
    values[mask_indexes - first_index] = 1

    return values

In [8]:
%%time

merger = MOMMerger(state="value", merger="equal", dtype=np.dtype('u1'))

with Writer('parquet', col_name='value', col_type=pa.uint8()) as writer:
    for tiles in tqdm(gen_mom_from_fn(
            get_value,
            max_norder=max_norder,
            split_norder=split_norder,
            merger=merger,
            n_threads=12,
    )):
        writer.write(*tiles)

1923it [01:05, 29.29it/s] 

CPU times: user 10min 8s, sys: 6.94 s, total: 10min 15s
Wall time: 1min 5s



