Commit

feat: refactor parquet multiprocessing wrapper

RaczeQ committed May 21, 2024
1 parent 7d29648 commit 3636fea

Showing 3 changed files with 228 additions and 82 deletions.
105 changes: 23 additions & 82 deletions quackosm/_intersection.py
@@ -1,58 +1,29 @@
-import multiprocessing
+from functools import partial
 from pathlib import Path
-from queue import Queue
-from time import sleep
 from typing import Optional

 import pyarrow as pa
-import pyarrow.parquet as pq
 from geoarrow.rust.core import PointArray
 from shapely import STRtree
 from shapely.geometry.base import BaseGeometry

-from quackosm._rich_progress import TaskProgressBar, log_message  # type: ignore[attr-defined]
+from quackosm._parquet_multiprocessing import map_parquet_dataset
+from quackosm._rich_progress import TaskProgressBar  # type: ignore[attr-defined]


-def _intersection_worker(
-    queue: Queue[tuple[str, int]], save_path: Path, geometry_filter: BaseGeometry
-) -> None:  # pragma: no cover
-    current_pid = multiprocessing.current_process().pid
-
-    filepath = save_path / f"{current_pid}.parquet"
-    writer = None
-    while not queue.empty():
-        try:
-            file_name = None
-            file_name, row_group_index = queue.get(block=True, timeout=1)
-
-            pq_file = pq.ParquetFile(file_name)
-            row_group_table = pq_file.read_row_group(row_group_index, ["id", "lat", "lon"])
-            if len(row_group_table) == 0:
-                continue
-
-            points_array = PointArray.from_xy(
-                x=row_group_table["lon"].combine_chunks(), y=row_group_table["lat"].combine_chunks()
-            )
-
-            tree = STRtree(points_array.to_shapely())
-
-            intersecting_ids_array = row_group_table["id"].take(
-                tree.query(geometry_filter, predicate="intersects")
-            )
-
-            table = pa.table({"id": intersecting_ids_array})
+def _intersect_nodes(
+    table: pa.Table,
+    geometry_filter: BaseGeometry,
+) -> pa.Table:  # pragma: no cover
+    points_array = PointArray.from_xy(
+        x=table["lon"].combine_chunks(), y=table["lat"].combine_chunks()
+    )

-            if not writer:
-                writer = pq.ParquetWriter(filepath, table.schema)
+    tree = STRtree(points_array.to_shapely())

-            writer.write_table(table)
-        except Exception as ex:
-            log_message(ex)
-            if file_name is not None:
-                queue.put((file_name, row_group_index))
+    intersecting_ids_array = table["id"].take(tree.query(geometry_filter, predicate="intersects"))

-    if writer:
-        writer.close()
+    return pa.table({"id": intersecting_ids_array})


 def intersect_nodes_with_geometry(
@@ -69,43 +40,13 @@ def intersect_nodes_with_geometry(
         progress_bar (Optional[TaskProgressBar]): Progress bar to show task status.
             Defaults to `None`
     """
-    queue: Queue[tuple[str, int]] = multiprocessing.Manager().Queue()
-
-    dataset = pq.ParquetDataset(tmp_dir_path / "nodes_valid_with_tags")
-
-    for pq_file in dataset.files:
-        for row_group in range(pq.ParquetFile(pq_file).num_row_groups):
-            queue.put((pq_file, row_group))
-
-    total = queue.qsize()
-
-    nodes_intersecting_path = tmp_dir_path / "nodes_intersecting_ids"
-    nodes_intersecting_path.mkdir(parents=True, exist_ok=True)
-
-    try:
-        processes = [
-            multiprocessing.Process(
-                target=_intersection_worker,
-                args=(queue, nodes_intersecting_path, geometry_filter),
-            )
-            for _ in range(multiprocessing.cpu_count())
-        ]
-
-        # Run processes
-        for p in processes:
-            p.start()
-
-        if progress_bar:  # pragma: no cover
-            progress_bar.create_manual_bar(total=total)
-        while any(process.is_alive() for process in processes):
-            if progress_bar:  # pragma: no cover
-                progress_bar.update_manual_bar(current_progress=total - queue.qsize())
-            sleep(1)
-
-        if progress_bar:  # pragma: no cover
-            progress_bar.update_manual_bar(current_progress=total)
-    finally:  # pragma: no cover
-        # In case of exception
-        for p in processes:
-            if p.is_alive():
-                p.terminate()
+    dataset_path = tmp_dir_path / "nodes_valid_with_tags"
+    destination_path = tmp_dir_path / "nodes_intersecting_ids"
+
+    map_parquet_dataset(
+        dataset_path=dataset_path,
+        destination_path=destination_path,
+        progress_bar=progress_bar,
+        function=partial(_intersect_nodes, geometry_filter=geometry_filter),
+        columns=["id", "lat", "lon"],
+    )
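Note: the refactor reduces `_intersection.py` to a pure `pa.Table -> pa.Table` transform plus a single call to the new wrapper. A minimal sketch of the same pattern in isolation (the `double_ids` function and both paths are hypothetical, for illustration only):

from functools import partial
from pathlib import Path

import pyarrow as pa
import pyarrow.compute as pc

from quackosm._parquet_multiprocessing import map_parquet_dataset


def double_ids(table: pa.Table, factor: int = 2) -> pa.Table:
    # Hypothetical row-group transform: scale the "id" column.
    return pa.table({"id": pc.multiply(table["id"], factor)})


map_parquet_dataset(
    dataset_path=Path("some_dataset"),  # hypothetical parquet dataset directory
    destination_path=Path("some_destination"),  # hypothetical output directory
    function=partial(double_ids, factor=2),
    columns=["id"],
)

As in `_intersection.py`, per-dataset arguments are bound with `functools.partial` so the worker function only receives the row group table.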
156 changes: 156 additions & 0 deletions quackosm/_parquet_multiprocessing.py
@@ -0,0 +1,156 @@
import multiprocessing
import traceback
from pathlib import Path
from queue import Queue
from time import sleep
from typing import Callable, Optional

import pyarrow as pa
import pyarrow.parquet as pq

from quackosm._rich_progress import TaskProgressBar # type: ignore[attr-defined]


def _job(
    queue: Queue[tuple[str, int]],
    save_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
) -> None:  # pragma: no cover
    current_pid = multiprocessing.current_process().pid

    filepath = save_path / f"{current_pid}.parquet"
    writer = None
    while not queue.empty():
        try:
            file_name = None
            file_name, row_group_index = queue.get(block=True, timeout=1)

            pq_file = pq.ParquetFile(file_name)
            row_group_table = pq_file.read_row_group(row_group_index, columns=columns)
            if len(row_group_table) == 0:
                continue

            result_table = function(row_group_table)

            if not writer:
                writer = pq.ParquetWriter(filepath, result_table.schema)

            writer.write_table(result_table)
        except Exception as ex:
            # Return the row group to the queue before failing this worker.
            if file_name is not None:
                queue.put((file_name, row_group_index))

            msg = (
                f"Error in worker (PID: {current_pid},"
                f" Parquet: {file_name}, Row group: {row_group_index})"
            )
            raise RuntimeError(msg) from ex

    if writer:
        writer.close()

# multiprocessing.Process subclass that captures exceptions raised in the
# child and exposes them to the parent through a pipe.
class WorkerProcess(multiprocessing.Process):
    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
        multiprocessing.Process.__init__(self, *args, **kwargs)
        self._pconn, self._cconn = multiprocessing.Pipe()
        self._exception: Optional[tuple[Exception, str]] = None

    def run(self) -> None:
        try:
            multiprocessing.Process.run(self)
            self._cconn.send(None)
        except Exception as e:
            tb: str = traceback.format_exc()
            self._cconn.send((e, tb))

    @property
    def exception(self) -> Optional[tuple[Exception, str]]:
        if self._pconn.poll():
            self._exception = self._pconn.recv()
        return self._exception


def map_parquet_dataset(
    dataset_path: Path,
    destination_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
    progress_bar: Optional[TaskProgressBar] = None,
) -> None:
    """
    Apply a function over a parquet dataset in a multiprocessing environment.

    Saves results as multiple files in the destination path.

    Args:
        dataset_path (Path): Path of the parquet dataset.
        destination_path (Path): Path of the destination.
        function (Callable[[pa.Table], pa.Table]): Function to apply over a row group table.
            Each resulting table is saved in a new parquet file.
        columns (Optional[list[str]]): List of columns to read. Defaults to `None`.
        progress_bar (Optional[TaskProgressBar]): Progress bar to show task status.
            Defaults to `None`.
    """
    queue: Queue[tuple[str, int]] = multiprocessing.Manager().Queue()

    dataset = pq.ParquetDataset(dataset_path)

    for pq_file in dataset.files:
        for row_group in range(pq.ParquetFile(pq_file).num_row_groups):
            queue.put((pq_file, row_group))

    total = queue.qsize()

    destination_path.mkdir(parents=True, exist_ok=True)

    try:
        processes = [
            WorkerProcess(
                target=_job,
                args=(queue, destination_path, function, columns),
            )  # type: ignore[no-untyped-call]
            for _ in range(multiprocessing.cpu_count())
        ]

        # Run processes
        for p in processes:
            p.start()

        if progress_bar:  # pragma: no cover
            progress_bar.create_manual_bar(total=total)
        while any(process.is_alive() for process in processes):
            if any(p.exception for p in processes):
                break

            if progress_bar:  # pragma: no cover
                progress_bar.update_manual_bar(current_progress=total - queue.qsize())
            sleep(1)

        if progress_bar:  # pragma: no cover
            progress_bar.update_manual_bar(current_progress=total)
    finally:  # pragma: no cover
        # In case of exception
        exceptions = []
        for p in processes:
            if p.is_alive():
                p.terminate()

            if p.exception:
                exceptions.append(p.exception)

        if exceptions:
            # use ExceptionGroup in Python3.11
            _raise_multiple(exceptions)

[codefactor.io / CodeFactor check notice on quackosm/_parquet_multiprocessing.py#L75-L145: Complex Method]

def _raise_multiple(exceptions: list[tuple[Exception, str]]) -> None:
    if not exceptions:
        return
    try:
        error, traceback = exceptions.pop()
        msg = f"{error}\n\nOriginal {traceback}"
        raise type(error)(msg)
    finally:
        # Recurse so every remaining exception is raised as chained context.
        _raise_multiple(exceptions)
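The `# use ExceptionGroup in Python3.11` comment above points at a simpler alternative to the recursive `_raise_multiple` chain. A sketch of what that could look like once the package targets Python 3.11+ (the `_raise_group` name is hypothetical):

def _raise_group(exceptions: list[tuple[Exception, str]]) -> None:
    # Raise all worker exceptions at once; tracebacks are already
    # captured as strings by WorkerProcess.run.
    if exceptions:
        raise ExceptionGroup(
            "Errors in worker processes",
            [error for error, _traceback in exceptions],
        )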
49 changes: 49 additions & 0 deletions tests/base/test_parquet_multiprocessing.py
@@ -0,0 +1,49 @@
"""Tests for Parquet multiprocessing wrapper."""

import tempfile
from pathlib import Path
from random import random
from time import sleep
from typing import Any

import duckdb
import pytest

from quackosm._parquet_multiprocessing import map_parquet_dataset


def test_exception_wrapping() -> None:
    """Test if multiprocessing exception raising works."""
    pbf_file = Path(__file__).parent.parent / "test_files" / "monaco.osm.pbf"

    with tempfile.TemporaryDirectory(dir=Path(__file__).parent.resolve()) as tmp_dir_name:
        duckdb.install_extension("spatial")
        duckdb.load_extension("spatial")
        nodes_destination = Path(tmp_dir_name) / "nodes_valid_with_tags"
        nodes_destination.mkdir(exist_ok=True, parents=True)
        duckdb.sql(
            f"""
            COPY (
                SELECT
                    id, lon, lat
                FROM ST_ReadOSM('{pbf_file}')
                WHERE kind = 'node'
                AND lat IS NOT NULL AND lon IS NOT NULL
            ) TO '{nodes_destination}' (
                FORMAT 'parquet',
                PER_THREAD_OUTPUT true,
                ROW_GROUP_SIZE 25000
            )
            """
        )

        def raise_error(pa: Any) -> Any:
            sleep(random())
            raise KeyError("XD")

        with pytest.raises(RuntimeError):
            map_parquet_dataset(
                dataset_path=nodes_destination,
                destination_path=Path(tmp_dir_name) / "test",
                function=raise_error,
            )
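The test exercises only the failure path. A possible happy-path companion (hypothetical, reusing the same `nodes_destination` fixture inside the temporary directory) might look like:

        def keep_ids(table: Any) -> Any:
            # Identity-style transform: keep only the "id" column.
            return table.select(["id"])

        map_parquet_dataset(
            dataset_path=nodes_destination,
            destination_path=Path(tmp_dir_name) / "ids",
            function=keep_ids,
        )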
