# Iterate by row group on import

In [1]:
import os
import shutil
from pathlib import Path

import cloudpickle
import pyarrow.dataset as pds
import pyarrow.parquet as pq

from dask.distributed import Client, LocalCluster
from hats import read_hats
from hats.pixel_math.sparse_histogram import SparseHistogram
from hats_import.catalog.arguments import ImportArguments
from hats_import.catalog.file_readers import ParquetPandasReader, ParquetPyarrowReader
from hats_import.pipeline import pipeline_with_client


# Machine-specific locations. The defaults are the original author's paths; set
# HATS_IMPORT_DATA_DIR / NOTEBOOK_TMP_DIR to run the notebook on another machine.
hats_import_data_dir = Path(
    os.environ.get("HATS_IMPORT_DATA_DIR", "/Users/orl/code/lsdb-plus/hats-import/tests/data")
)
current_tmp_dir = Path(
    os.environ.get(
        "NOTEBOOK_TMP_DIR", "/Users/orl/code/lsdb-plus/liv-lf/09 - LSDB Sprint Demos/tmp"
    )
)

# exist_ok=True already makes this a no-op when the directory is present.
current_tmp_dir.mkdir(parents=True, exist_ok=True)

In [2]:
# Unit test utilities


def pickle_file_reader(tmp_path, file_reader) -> Path:
    """Serialize ``file_reader`` with cloudpickle into ``<tmp_path>/reader.pickle``.

    Args:
        tmp_path: directory to write the pickle into (``Path`` or ``str``).
        file_reader: the object to serialize, e.g. a hats-import file reader.

    Returns:
        The ``Path`` of the written pickle file. The original annotation said ``str``,
        but a ``Path`` has always been returned.
    """
    # Path() lets callers pass a plain string as well as a Path.
    pickled_reader_file = Path(tmp_path) / "reader.pickle"
    with open(pickled_reader_file, "wb") as pickle_file:
        cloudpickle.dump(file_reader, pickle_file)
    return pickled_reader_file


def read_partial_histogram(tmp_path, mapping_key, which_histogram="row_count"):
    """Load a partial histogram saved by an earlier map step.

    Reads ``<tmp_path>/<which_histogram>_histograms/<mapping_key>.npz`` with
    ``SparseHistogram.from_file`` and returns the result of its ``to_array()``.
    """
    histogram_dir = tmp_path / f"{which_histogram}_histograms"
    return SparseHistogram.from_file(histogram_dir / f"{mapping_key}.npz").to_array()


def dask_client():
    """Yield a dask ``Client`` attached to a one-worker, one-thread ``LocalCluster``.

    This is a generator in pytest-fixture style, but it has no ``@pytest.fixture``
    decorator here. The code after ``yield`` tears everything down. Teardown sits in
    ``finally`` blocks, so the client and cluster are released even when:
    - the consumer raises,
    - the generator is closed early, or
    - creating the ``Client`` fails.
    """
    # dashboard_address=":0" asks for any free port instead of a fixed one.
    cluster = LocalCluster(n_workers=1, threads_per_worker=1, dashboard_address=":0")
    try:
        client = Client(cluster)
        try:
            yield client
        finally:
            client.close()
    finally:
        cluster.close()

## Make a multi-row-group dataset

Take the `Npix=11` partition of the small sky object catalog and rewrite it as a single parquet file with multiple row groups.

In [3]:
num_row_groups: int = 12  # number of row groups to split the test file into; checked again below

In [4]:
# Load the input data that will be used to generate the multi-row-group dataset.
input_dataset_path = Path(hats_import_data_dir) / "small_sky_object_catalog" / "dataset"
input_ds = pds.parquet_dataset(input_dataset_path / "_metadata")

# Unit tests expect the Npix=11 data file. Match the exact file name and fail with a
# readable error (instead of a bare StopIteration) if it is missing.
input_frag = next(
    (frag for frag in input_ds.get_fragments() if Path(frag.path).name == "Npix=11.parquet"),
    None,
)
if input_frag is None:
    raise FileNotFoundError(f"No Npix=11.parquet fragment found under {input_dataset_path}")
frag_key = Path(input_frag.path).relative_to(input_dataset_path)
input_tbl = input_frag.to_table()

In [5]:
# Write the multi-row-group parquet file.

# NOTE(review): not used by any visible cell; the file is written to `current_tmp_dir`.
output_dataset_dir = Path(hats_import_data_dir) / "test_formats"
output_dataset_path = current_tmp_dir / "multi_row_group.parquet"

# Split the table into `num_row_groups` contiguous slices whose sizes differ by at most
# one row. This yields exactly `num_row_groups` row groups whenever there are at least
# that many rows. Ceil-sized steps can yield fewer groups (e.g. 13 rows -> 7 groups), and
# an empty table would have produced a zero step.
base_size, remainder = divmod(len(input_tbl), num_row_groups)
with pq.ParquetWriter(
    where=output_dataset_path,
    schema=input_tbl.schema,
    use_dictionary=True,
    compression="SNAPPY",
) as parquet_writer:
    offset = 0
    for group_index in range(num_row_groups):
        group_size = base_size + (1 if group_index < remainder else 0)
        if group_size == 0:
            # Fewer rows than requested row groups: never write empty row groups.
            continue
        batch = input_tbl.slice(offset, group_size)
        # Pin the row group size so each batch becomes exactly one row group.
        parquet_writer.write_table(batch, row_group_size=group_size)
        offset += group_size

In [6]:
# Quick check that the file was written and that it round-trips the input rows...

dataset_in = pq.ParquetDataset(output_dataset_path)
round_trip_tbl = dataset_in.read()
# Table.equals compares contents regardless of how the data is chunked into row groups.
assert round_trip_tbl.equals(input_tbl), "multi-row-group file does not match the input table"
round_trip_tbl

pyarrow.Table
_healpix_29: int64
id: int64
ra: double
dec: double
ra_error: int64
dec_error: int64
----
_healpix_29: [[3187422220182231470,3187796123455121090,3188300701662669945,3188300701662669945,3192670279995812269,3192995164288065358,3194102393993053262,3195678697494500888,3196676706683767043,3196723640945762243,3197084959829715592],[3199487976390127826,3200256676290451752,3204516948860795443,3205876081882660000,3210595332878490279,3210618432891708849,3213763510984835044,3214195389015536144,3214969534743101722,3216746212972589575,3220523316512691843],...,[3372246530833947442,3380119216995778932,3380458994856416389,3388235695417907461,3388424365484869373,3389280889354717694,3389344265064945161,3389454143235220745,3390042224874323348,3390233494165235443,3390395511633005928],[3390927915493503682,3391172539243045430,3391463069396352355,3397177333029719609,3397704562975227384,3397804200316730633,3399000453069933430,3399532867186255393,3400255793565258227,3424180623569024089]]
id: [[707

In [7]:
# ...and that it was written with multiple row groups (read from the file footer metadata).
file_in = pq.ParquetFile(output_dataset_path)
file_in.metadata.num_row_groups

12

## Read by row group

In [8]:
multi_row_group_parquet = output_dataset_path

# `num_row_groups` comes from the configuration cell above. It is deliberately not
# redefined here, so this check cannot drift from the value the file was written with.

# Check number of row groups in test file (read_metadata only reads the footer and
# does not leave a file object open).
assert pq.read_metadata(multi_row_group_parquet).num_row_groups == num_row_groups

# Check we can iterate by row group with each parquet reader, and that no rows are lost.
for reader_class in (ParquetPandasReader, ParquetPyarrowReader):
    total_row_groups = 0
    total_rows = 0
    for row_group in reader_class(iterate_by_row_groups=True).read(multi_row_group_parquet):
        total_row_groups += 1
        assert len(row_group) > 0
        total_rows += len(row_group)
    assert total_row_groups == num_row_groups, (
        f"{reader_class.__name__}: {total_row_groups} row groups"
    )
    assert total_rows == len(input_tbl), f"{reader_class.__name__}: {total_rows} rows"

## Import while iterating by row group

In [9]:
# NOTE(review): the input is the small sky *object* catalog, so "source" in this name
# looks like a copy-paste leftover. It is kept unchanged because it presumably forms part
# of `args.catalog_path`; confirm before renaming.
output_artifact_name = "small_sky_source_catalog_by_row"
output_full_path = (current_tmp_dir / "imported_catalog_by_row").as_posix()

# Remove any previous run's output so re-running the notebook starts from a clean directory.
if Path(output_full_path).exists():
    shutil.rmtree(output_full_path)

args = ImportArguments(
    output_artifact_name=output_artifact_name,
    input_file_list=[multi_row_group_parquet],
    # Read the input one row group at a time instead of the whole file at once.
    file_reader=ParquetPyarrowReader(iterate_by_row_groups=True),
    output_path=output_full_path,
    highest_healpix_order=0,
    progress_bar=False,
    # The input already has a `_healpix_29` column (see the schema shown above).
    add_healpix_29=False,
)

In [10]:
# Run the import on a temporary single-worker dask client; the `with` block closes it on exit.
with Client(memory_limit="auto", n_workers=1) as client:
    pipeline_with_client(args, client)

In [11]:
# Load the imported catalog back and confirm its row count matches the input table.
catalog = read_hats(args.catalog_path)
assert len(input_tbl) == len(catalog)

In [12]:
# :)