In [1]:
import shutil
from pathlib import Path

import healpy
import numpy as np
import pyarrow as pa
import pyarrow.dataset
import pyarrow.parquet as pq
from astropy.coordinates import SkyCoord
from dustmaps.sfd import SFDQuery
from joblib import Parallel, delayed
from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN

from mom_builder import mom_from_array, mom_from_batch_it, gen_mom_from_fn
from multiorder import AbstractMultiorderMapBuilder
from min_max_mean_state import MinMaxMeanState, MinMaxMeanStateMerger
from paths import *

In [2]:
class StateMerger(MinMaxMeanStateMerger):
    """Merger that accepts a state only while its relative spread is small.

    The spread is ``max - min`` divided by the larger absolute value of the
    two extrema; a state is valid when that ratio does not exceed
    ``threshold``.
    """

    def __init__(self, threshold: float = 0.16):
        # Upper bound on the relative min-max spread of an acceptable state.
        self.threshold = threshold

    def validate_state(self, state: MinMaxMeanState) -> bool:
        """Return True when ``state`` stays within the relative-spread threshold."""
        largest_abs = max(abs(state.min), abs(state.max))
        # Both extrema are zero: accept the state instead of dividing by zero.
        if largest_abs == 0.0:
            return True
        relative_spread = (state.max - state.min) / largest_abs
        return relative_spread <= self.threshold
    
# Shared SFD dust-map query object, built once and reused by every ebv() call.
# NOTE(review): INPUT_DIR is not defined in this notebook; presumably it comes
# from `from paths import *` -- confirm.
sfd_query = SFDQuery(INPUT_DIR)

def ebv(norder: int, index_range=None):
    """Query ``sfd_query`` at the centres of NESTED HEALPix pixels of order ``norder``.

    Parameters
    ----------
    norder : int
        HEALPix order of the pixels.
    index_range : None, sequence or numpy.ndarray, optional
        Pixels to evaluate. ``None`` selects every pixel of the order. A
        ``numpy.ndarray`` is taken as explicit pixel indexes. Any other value
        is treated as ``numpy.arange`` arguments (typically ``(start, stop)``);
        when its second element exceeds the pixel count it is replaced by
        ``(start, pixel_count)``.

    Returns
    -------
    The result of ``sfd_query`` for the pixel-centre coordinates.
    """
    nside = healpy.order2nside(norder)
    total_pixels = healpy.order2npix(norder)

    if isinstance(index_range, np.ndarray):
        pixels = index_range
    else:
        bounds = (0, total_pixels) if index_range is None else index_range
        # Clamp a stop that runs past the last pixel of this order.
        if bounds[1] > total_pixels:
            bounds = (bounds[0], total_pixels)
        pixels = np.arange(*bounds)
    pixels = np.asarray(pixels, dtype=int)

    lon, lat = healpy.pix2ang(nside, pixels, nest=True, lonlat=True)
    return sfd_query(SkyCoord(ra=lon, dec=lat, unit='deg'))


def ebv_f64(norder: int, index_range=None):
    """Same as ``ebv``, with the result converted to ``float64``."""
    values = ebv(norder, index_range)
    return values.astype(np.float64)
    

class Builder(StateMerger, AbstractMultiorderMapBuilder):
    """Multi-order map builder whose leaf states come from a precomputed array.

    ``ebv_vals`` is indexed by the pixel index at ``max_norder``; each leaf
    state has identical min, max and mean equal to that pixel's value.
    """

    def __init__(self, max_norder, threshold, ebv_vals):
        # Both bases are initialised explicitly, builder first, then merger.
        AbstractMultiorderMapBuilder.__init__(self, max_norder)
        StateMerger.__init__(self, threshold)

        # Per-pixel values at the deepest order, looked up by calculate_state().
        self.ebv = ebv_vals

    def calculate_state(self, index_max_norder: int) -> MinMaxMeanState:
        """Return the single-value state of the leaf pixel ``index_max_norder``."""
        leaf_value = self.ebv[index_max_norder].item()
        return MinMaxMeanState(
            min=leaf_value,
            max=leaf_value,
            mean=leaf_value,
        )

In [3]:
def accumulate_gen(max_norder, threshold):
    """Collect the output of ``gen_mom_from_fn`` into one pair of arrays per order.

    The generator is driven by ``ebv`` with ``subtree_norder=0`` and yields
    ``(norder, indexes, values)`` triples; chunks of the same order are
    concatenated in the order they were produced.

    Returns a list of ``max_norder + 1`` tuples ``(indexes, values)``, indexed
    by order. Every order is seeded with an empty ``(uint64, float64)`` pair so
    that orders without tiles still concatenate; because of that seed,
    ``np.concatenate`` promotes narrower float values to ``float64``.
    """
    chunks_per_order = [
        [(np.array([], dtype=np.uint64), np.array([], dtype=float))]
        for _ in range(max_norder + 1)
    ]

    produced = gen_mom_from_fn(
        fn=ebv,
        max_norder=max_norder,
        subtree_norder=0,
        threshold=threshold,
    )
    for norder, indexes, values in produced:
        chunks_per_order[norder].append((indexes, values))

    return [
        (
            np.concatenate([chunk_indexes for chunk_indexes, _ in chunks]),
            np.concatenate([chunk_values for _, chunk_values in chunks]),
        )
        for chunks in chunks_per_order
    ]

In [4]:
class Writer:
    """Write tiles to parquet files, one file per HEALPix order.

    Files are created lazily inside ``path`` as ``pixel_norder=NN.parquet``
    the first time an order is written. Use the object as a context manager
    so that every opened file is closed on exit.

    It doesn't optimize Parquet group size for now.
    """
    def __init__(self, path=PARQUET_DIR):
        # ``path`` must support ``mkdir`` and ``/`` (e.g. a pathlib.Path).
        self.path = path
        self.path.mkdir(parents=True, exist_ok=True)

        # norder -> open pq.ParquetWriter
        self.parquet_writers = {}

    @staticmethod
    def _schema():
        """Return the column layout shared by the files and the written tables."""
        return pa.schema([
            pa.field(HIPSCAT_ID_COLUMN, pa.uint64()),
            pa.field('pixel_Norder', pa.uint8()),
            pa.field('pixel_Npix', pa.uint64()),
            pa.field('ebv', pa.float32()),
        ])

    def _create_parquet_writer(self, norder):
        """Open the parquet file that stores tiles of order ``norder``."""
        path = self.path / f'pixel_norder={norder:02d}.parquet'
        return pq.ParquetWriter(path, self._schema())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Detach the writers first so a closed writer can never be reused.
        writers, self.parquet_writers = self.parquet_writers, {}
        close_error = None
        for writer in writers.values():
            # Keep closing the remaining files even if one close() fails;
            # an unclosed parquet file has no footer and cannot be read.
            try:
                writer.close()
            except Exception as error:
                if close_error is None:
                    close_error = error
        # Report a close failure unless another exception is already propagating.
        if close_error is not None and exc_type is None:
            raise close_error

    def write(self, norder, indexes, values):
        """Append tiles of a single order to that order's parquet file.

        Parameters
        ----------
        norder : int
            HEALPix order of all the tiles in this call.
        indexes : array-like
            Pixel indexes of the tiles at ``norder``.
        values : array-like
            One value per tile, stored in the ``ebv`` column.
        """
        hipscat_index = healpix_to_hipscat_id(norder, indexes)
        schema = self._schema()
        columns = [
            hipscat_index,
            np.full(hipscat_index.shape, norder, dtype=np.uint8),
            indexes,
            values,
        ]
        # Convert every column to the file's declared type so the table always
        # matches the writer schema, whatever dtype the caller passed in.
        table = pa.Table.from_arrays(
            [pa.array(column, type=column_type)
             for column, column_type in zip(columns, schema.types)],
            schema=schema,
        )

        if norder not in self.parquet_writers:
            self.parquet_writers[norder] = self._create_parquet_writer(norder)
        self.parquet_writers[norder].write_table(table)

In [5]:
max_norder = 9

%time ebv_vals = ebv(max_norder)

%time tiles = Builder(max_norder=max_norder, threshold=0.16, ebv_vals=ebv_vals).build()

%time mom = mom_from_array(ebv_vals, max_norder, 0.16)

%time mom_from_gen = accumulate_gen(max_norder, 0.16)

CPU times: user 595 ms, sys: 126 ms, total: 722 ms
Wall time: 751 ms
CPU times: user 3.69 s, sys: 34.4 ms, total: 3.73 s
Wall time: 3.77 s
CPU times: user 40.5 ms, sys: 4.46 ms, total: 45 ms
Wall time: 45.3 ms
CPU times: user 679 ms, sys: 141 ms, total: 820 ms
Wall time: 839 ms


In [6]:
# Show the map built by mom_from_array: one (indexes, values) pair per order.
mom

[(array([], dtype=uint64), array([], dtype=float32)),
 (array([], dtype=uint64), array([], dtype=float32)),
 (array([], dtype=uint64), array([], dtype=float32)),
 (array([], dtype=uint64), array([], dtype=float32)),
 (array([], dtype=uint64), array([], dtype=float32)),
 (array([], dtype=uint64), array([], dtype=float32)),
 (array([ 2126,  2137,  2320,  2504,  2572,  5065,  5157,  6385,  6402,
          6707,  6749,  6753,  6772,  6887,  8209, 10931, 10933, 12612,
         12617, 12623, 12626, 12634, 12653, 12664, 12665, 12667, 12752,
         12753, 12758, 12761, 12767, 13149, 13151, 13482, 13832, 13834,
         13856, 13857, 13858, 13859, 13864, 13866, 13993, 16394, 16416,
         16419, 16523, 16569, 16614, 16901, 16912, 16924, 16946, 16952,
         16953, 16956, 16957, 16959, 17000, 17002, 17003, 17006, 17041,
         17043, 17044, 17051, 17054, 17076, 17077, 17079, 17087, 17089,
         17091, 17096, 17097, 17098, 17099, 17102, 17106, 17107, 17110,
         17112, 17115, 17120

In [7]:
# Show the map built by accumulate_gen, for comparison with `mom`.
mom_from_gen

[(array([], dtype=uint64), array([], dtype=float64)),
 (array([], dtype=uint64), array([], dtype=float64)),
 (array([], dtype=uint64), array([], dtype=float64)),
 (array([], dtype=uint64), array([], dtype=float64)),
 (array([], dtype=uint64), array([], dtype=float64)),
 (array([], dtype=uint64), array([], dtype=float64)),
 (array([ 2126,  2137,  2320,  2504,  2572,  5065,  5157,  6385,  6402,
          6707,  6749,  6753,  6772,  6887,  8209, 10931, 10933, 12612,
         12617, 12623, 12626, 12634, 12653, 12664, 12665, 12667, 12752,
         12753, 12758, 12761, 12767, 13149, 13151, 13482, 13832, 13834,
         13856, 13857, 13858, 13859, 13864, 13866, 13993, 16394, 16416,
         16419, 16523, 16569, 16614, 16901, 16912, 16924, 16946, 16952,
         16953, 16956, 16957, 16959, 17000, 17002, 17003, 17006, 17041,
         17043, 17044, 17051, 17054, 17076, 17077, 17079, 17087, 17089,
         17091, 17096, 17097, 17098, 17099, 17102, 17106, 17107, 17110,
         17112, 17115, 17120

In [8]:
# Sanity check: the Builder output and mom_from_array should hold the same
# number of order-6 tiles.
len(tiles[6].indexes), len(mom[6][0])

(318, 318)

In [9]:
%%time

# Build the order-14 map from batches of pixel values instead of one array.
max_norder = 14
threshold = 0.16 / 2
batch_size = 1 << 20

# Batch boundaries. The stop is padded by one batch so the last boundary
# reaches the pixel count; ebv() clamps any stop that runs past it.
batches = range(0, healpy.order2npix(max_norder) + batch_size, batch_size)

# Evaluate ebv() for each consecutive (start, stop) pair on a thread pool;
# return_as="generator" hands the results over as a generator.
it = Parallel(n_jobs=-1, return_as="generator", backend="threading")(
    delayed(ebv)(max_norder, rng)
    for rng in zip(batches, batches[1:])
)

# NOTE: this rebinds `mom`, which an earlier cell set to the order-9 map.
mom = mom_from_batch_it(it, max_norder, threshold)

CPU times: user 14min 1s, sys: 1min 14s, total: 15min 15s
Wall time: 2min 37s


In [6]:
# Compare the number of multi-order tiles with a full map at the deepest order.
max_ntiles = healpy.order2npix(max_norder)
ntiles = sum(len(order_indexes) for order_indexes, _ in mom)
print(f"Number of tiles: {ntiles:_d} / {max_ntiles:_d} ({ntiles / max_ntiles:.2%})")

Number of tiles: 25_869_093 / 3_221_225_472 (0.80%)


In [5]:
%%time 

max_norder = 17
threshold = 0.16 / 2**(max_norder - 13)
subtree_norder = max(max_norder - 12, 0)


def worker(n_jobs, parallel):
    def fn(norder, rng):
        n_batch = len(rng) // n_jobs
        batches = parallel([
            delayed(ebv)(norder, rng[i:i + n_batch])
            for i in range(0, len(rng), n_batch)
        ])
        return np.concatenate(batches)
    return fn


import shutil
shutil.rmtree(PARQUET_DIR, ignore_errors=True)


with Parallel(n_jobs=12, backend="threading") as parallel:
    worker = worker(parallel.n_jobs, parallel)
    with Writer() as writer:
        for tiles in gen_mom_from_fn(
                worker,
                max_norder,
                subtree_norder=subtree_norder,
                threshold=threshold
        ):
            writer.write(*tiles)

CPU times: user 14h 59min 54s, sys: 3h 43min 14s, total: 18h 43min 9s
Wall time: 4h 7min


### Validation

In [32]:
max_norder = 17

parquet_dataset = pa.dataset.dataset(PARQUET_DIR, format='parquet')

# Single pass over the files: remember (order, row count) for each of them so
# the row counts are read only once.
rows_per_file = []
for fragment in parquet_dataset.get_fragments():
    # The order is encoded in the file name: ``pixel_norder=NN.parquet``.
    file_name = Path(fragment.path).name
    norder = int(file_name.split('=')[1].split('.')[0])
    rows_per_file.append((norder, fragment.count_rows()))

ntiles = sum(nrows for _, nrows in rows_per_file)
max_ntiles = healpy.order2npix(max_norder)
print(f"Number of tiles: {ntiles:_d} / {max_ntiles:_d} ({ntiles / max_ntiles:.2%})")

# A tile of order `norder` spans 4 ** (max_norder - norder) pixels of the
# deepest order; summed over all tiles this must equal the full pixel count.
area_in_max_norder = 0
for norder, nrows in rows_per_file:
    assert norder <= max_norder, f"{norder} > {max_norder}"
    area_in_max_norder += nrows * 4 ** (max_norder - norder)
assert area_in_max_norder == max_ntiles, f"{area_in_max_norder} != {max_ntiles}"

Number of tiles: 1_849_261_851 / 206_158_430_208 (0.90%)
