67 changes: 58 additions & 9 deletions elasticsearch/_async/helpers.py
@@ -33,12 +33,16 @@
Union,
)

from ..compat import safe_task
from ..exceptions import ApiError, NotFoundError, TransportError
from ..helpers.actions import (
_TYPE_BULK_ACTION,
_TYPE_BULK_ACTION_BODY,
_TYPE_BULK_ACTION_HEADER,
_TYPE_BULK_ACTION_HEADER_AND_BODY,
_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY,
_TYPE_BULK_ACTION_WITH_META,
BulkMeta,
_ActionChunker,
_process_bulk_chunk_error,
_process_bulk_chunk_success,
@@ -54,9 +58,10 @@


async def _chunk_actions(
actions: AsyncIterable[_TYPE_BULK_ACTION_HEADER_AND_BODY],
actions: AsyncIterable[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY],
chunk_size: int,
max_chunk_bytes: int,
flush_after_seconds: Optional[float],
serializer: Serializer,
) -> AsyncIterable[
Tuple[
@@ -76,10 +81,42 @@ async def _chunk_actions(
chunker = _ActionChunker(
chunk_size=chunk_size, max_chunk_bytes=max_chunk_bytes, serializer=serializer
)
async for action, data in actions:
ret = chunker.feed(action, data)
if ret:
yield ret

if not flush_after_seconds:
async for action, data in actions:
ret = chunker.feed(action, data)
if ret:
yield ret
else:
item_queue: asyncio.Queue[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY] = (
asyncio.Queue()
)

async def get_items() -> None:
try:
async for item in actions:
await item_queue.put(item)
finally:
await item_queue.put((BulkMeta.done, None))

async with safe_task(get_items()):
timeout: Optional[float] = flush_after_seconds
while True:
try:
action, data = await asyncio.wait_for(
item_queue.get(), timeout=timeout
)
timeout = flush_after_seconds
except asyncio.TimeoutError:
action, data = BulkMeta.flush, None
timeout = None

if action is BulkMeta.done:
break
ret = chunker.feed(action, data)
if ret:
yield ret

ret = chunker.flush()
if ret:
yield ret
@@ -159,9 +196,13 @@ async def azip(

async def async_streaming_bulk(
client: AsyncElasticsearch,
actions: Union[Iterable[_TYPE_BULK_ACTION], AsyncIterable[_TYPE_BULK_ACTION]],
actions: Union[
Iterable[_TYPE_BULK_ACTION_WITH_META],
AsyncIterable[_TYPE_BULK_ACTION_WITH_META],
],
chunk_size: int = 500,
max_chunk_bytes: int = 100 * 1024 * 1024,
flush_after_seconds: Optional[float] = None,
raise_on_error: bool = True,
expand_action_callback: Callable[
[_TYPE_BULK_ACTION], _TYPE_BULK_ACTION_HEADER_AND_BODY
@@ -194,6 +235,9 @@ async def async_streaming_bulk(
:arg actions: iterable or async iterable containing the actions to be executed
:arg chunk_size: number of docs in one chunk sent to es (default: 500)
:arg max_chunk_bytes: the maximum size of the request in bytes (default: 100MB)
:arg flush_after_seconds: time in seconds after which a chunk is written even
if it hasn't reached `chunk_size` or `max_chunk_bytes`. Set to 0 or None to
disable the timeout-based flush. (default: None)
:arg raise_on_error: raise ``BulkIndexError`` containing errors (as `.errors`)
from the execution of the last chunk when some occur. By default we raise.
:arg raise_on_exception: if ``False`` then don't propagate exceptions from
@@ -220,9 +264,14 @@
if isinstance(retry_on_status, int):
retry_on_status = (retry_on_status,)

async def map_actions() -> AsyncIterable[_TYPE_BULK_ACTION_HEADER_AND_BODY]:
async def map_actions() -> (
AsyncIterable[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY]
):
async for item in aiter(actions):
yield expand_action_callback(item)
if isinstance(item, BulkMeta):
yield item, None
else:
yield expand_action_callback(item)

serializer = client.transport.serializers.get_serializer("application/json")

@@ -234,7 +283,7 @@ async def map_actions() -> AsyncIterable[_TYPE_BULK_ACTION_HEADER_AND_BODY]:
]
bulk_actions: List[bytes]
async for bulk_data, bulk_actions in _chunk_actions(
map_actions(), chunk_size, max_chunk_bytes, serializer
map_actions(), chunk_size, max_chunk_bytes, flush_after_seconds, serializer
):
for attempt in range(max_retries + 1):
to_retry: List[bytes] = []
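A minimal usage sketch of the new timeout-based flush (assumes this branch is installed; the endpoint, index name, and documents are placeholders). Without flush_after_seconds, a slow producer leaves actions sitting in a partially filled chunk until chunk_size or max_chunk_bytes is reached; with it, whatever has accumulated is sent once the timeout passes without a new action:

import asyncio

from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_streaming_bulk


async def slow_actions():
    # Documents trickle in far too slowly to ever fill a 500-doc chunk.
    for i in range(5):
        await asyncio.sleep(2)
        yield {"_index": "my-index", "_id": i, "value": i}


async def main() -> None:
    client = AsyncElasticsearch("http://localhost:9200")
    try:
        async for ok, item in async_streaming_bulk(
            client,
            slow_actions(),
            # Send buffered actions after 1s with no new ones, instead of
            # waiting for chunk_size / max_chunk_bytes.
            flush_after_seconds=1.0,
        ):
            if not ok:
                print("failed:", item)
    finally:
        await client.close()


asyncio.run(main())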
46 changes: 45 additions & 1 deletion elasticsearch/compat.py
@@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import asyncio
import inspect
import os
import sys
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
from typing import Tuple, Type, Union
from threading import Thread
from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, Tuple, Type, Union

string_types: Tuple[Type[str], Type[bytes]] = (str, bytes)

@@ -76,9 +79,50 @@ def warn_stacklevel() -> int:
return 0


@contextmanager
def safe_thread(
target: Callable[..., Any], *args: Any, **kwargs: Any
) -> Iterator[Thread]:
"""Run a thread within a context manager block.

The thread is automatically joined when the block ends. If the thread raised
an exception, it is raised in the caller's context.
"""
captured_exception = None

def run() -> None:
try:
target(*args, **kwargs)
except BaseException as exc:
nonlocal captured_exception
captured_exception = exc

thread = Thread(target=run)
thread.start()
yield thread
thread.join()
if captured_exception:
raise captured_exception


@asynccontextmanager
async def safe_task(
coro: Coroutine[Any, Any, Any],
) -> "AsyncIterator[asyncio.Task[Any]]":
"""Run a background task within a context manager block.

The task is awaited when the block ends.
"""
task = asyncio.create_task(coro)
yield task
await task


__all__ = [
"string_types",
"to_str",
"to_bytes",
"warn_stacklevel",
"safe_thread",
"safe_task",
]
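A hedged usage sketch for the new safe_thread helper (the worker function and its failure are illustrative): the thread is joined when the with block exits, and an exception raised inside it resurfaces in the caller.

from elasticsearch.compat import safe_thread


def worker(n: int) -> None:
    if n < 0:
        raise ValueError("negative input")
    print(f"processed {n}")


# The thread is joined on exit; the ValueError raised inside it is
# re-raised here, where it can be handled normally.
try:
    with safe_thread(worker, -1):
        pass  # do other work while the thread runs
except ValueError as exc:
    print("worker failed:", exc)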
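And a companion sketch for safe_task, mirroring how _chunk_actions uses it above: a producer coroutine fills a queue in the background and is awaited, surfacing any exception, when the async with block ends (the queue and producer are illustrative).

import asyncio

from elasticsearch.compat import safe_task


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    async def producer() -> None:
        for i in range(3):
            await queue.put(i)
        await queue.put(None)  # sentinel: no more items

    # The task is awaited when the block exits, so producer errors propagate.
    async with safe_task(producer()):
        while (item := await queue.get()) is not None:
            print("got", item)


asyncio.run(main())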
11 changes: 10 additions & 1 deletion elasticsearch/helpers/__init__.py
@@ -19,12 +19,21 @@
from .._utils import fixup_module_metadata
from .actions import _chunk_actions # noqa: F401
from .actions import _process_bulk_chunk # noqa: F401
from .actions import bulk, expand_action, parallel_bulk, reindex, scan, streaming_bulk
from .actions import (
BULK_FLUSH,
bulk,
expand_action,
parallel_bulk,
reindex,
scan,
streaming_bulk,
)
from .errors import BulkIndexError, ScanError

__all__ = [
"BulkIndexError",
"ScanError",
"BULK_FLUSH",
"expand_action",
"streaming_bulk",
"bulk",
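The new BULK_FLUSH export is not defined in the hunks shown here; judging from the BulkMeta handling in map_actions above, it is presumably a flush sentinel that can be yielded among the actions to force the buffered chunk to be sent immediately. A sketch under that assumption (client, index name, and documents are placeholders):

from elasticsearch import Elasticsearch
from elasticsearch.helpers import BULK_FLUSH, streaming_bulk


def actions():
    yield {"_index": "my-index", "_id": 1, "value": "first"}
    yield {"_index": "my-index", "_id": 2, "value": "second"}
    # Assumed semantics: send whatever is buffered now, without waiting
    # for chunk_size, max_chunk_bytes, or a timeout.
    yield BULK_FLUSH
    yield {"_index": "my-index", "_id": 3, "value": "third"}


client = Elasticsearch("http://localhost:9200")
for ok, item in streaming_bulk(client, actions()):
    if not ok:
        print("failed:", item)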