Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix add_limit wrapper for async generators #1310

Closed
wants to merge 9 commits into from
9 changes: 3 additions & 6 deletions dlt/extract/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,31 +341,28 @@ def add_limit(self: TDltResourceImpl, max_items: int) -> TDltResourceImpl: # no

def _gen_wrap(gen: TPipeStep) -> TPipeStep:
"""Wrap a generator to take the first `max_items` records"""

# zero items should produce empty generator
if max_items == 0:
return

count = 0
is_async_gen = False
if callable(gen):
gen = gen() # type: ignore

# wrap async gen already here
if isinstance(gen, AsyncIterator):
gen = wrap_async_iterator(gen)
is_async_gen = True

try:
for i in gen: # type: ignore # TODO: help me fix this later
yield i
if i is not None:
count += 1
# async gen yields awaitable so we must count one awaitable more
# so the previous one is evaluated and yielded.
# new awaitable will be cancelled
if count == max_items + int(is_async_gen):
if count == max_items:
return
count += 1
yield i
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to yield the item only after the `if` check, so the limit is tested before the item is emitted.

finally:
if inspect.isgenerator(gen):
gen.close()
Expand Down
41 changes: 11 additions & 30 deletions dlt/extract/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,39 +180,20 @@ def wrap_async_iterator(
gen: AsyncIterator[TDataItems],
) -> Generator[Awaitable[TDataItems], None, None]:
"""Wraps an async generator into a list of awaitables"""
exhausted = False
busy = False

# creates an awaitable that will return the next item from the async generator
async def run() -> TDataItems:
nonlocal exhausted
loop = asyncio.get_event_loop()
should_stop = False
try:
try:
# if marked exhausted by the main thread and we are wrapping a generator
# we can close it here
if exhausted:
raise StopAsyncIteration()
item = await gen.__anext__()
return item
# on stop iteration mark as exhausted
# also called when futures are cancelled
while True:
if should_stop:
break
yield loop.run_until_complete(gen.__anext__()) # type: ignore[arg-type]
Copy link
Collaborator Author

@sultaniman sultaniman May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used the event loop because the old implementation complained about a coroutine that was never awaited when a resource was combined with `.add_limit`:

tests/extract/test_sources.py::test_add_limit_async
  /[...]/dlt/dlt/extract/pipe_iterator.py:275: RuntimeWarning: coroutine 'wrap_async_iterator.<locals>.run' was never awaited
    pipe_item = next(gen)
  Enable tracemalloc to get traceback where the object was allocated.
  See https://docs.pytest.org/en/stable/how-to/capture-warnings.html#resource-warnings for more info.

except StopAsyncIteration:
exhausted = True
raise
finally:
nonlocal busy
busy = False

# this generator yields None while the async generator is not exhausted
try:
while not exhausted:
while busy:
yield None
busy = True
yield run()
# this gets called from the main thread when the wrapping generator is closed
should_stop = True
except GeneratorExit:
# mark as exhausted
exhausted = True
should_stop = True
if hasattr(gen, "aclose"):
loop.run_until_complete(gen.aclose())


def wrap_parallel_iterator(f: TAnyFunOrGenerator) -> TAnyFunOrGenerator:
Expand Down
42 changes: 29 additions & 13 deletions tests/pipeline/test_resources_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Any, List
import asyncio
import os
import time
import threading
import random
from itertools import product

import dlt, asyncio, pytest, os
import dlt
import pytest

from dlt.extract.exceptions import ResourceExtractionError


@pytest.mark.asyncio
def test_async_iterator_resource() -> None:
# define an asynchronous iterator
@dlt.resource()
Expand Down Expand Up @@ -41,6 +45,7 @@ async def __anext__(self):
#
# async generators resource tests
#
@pytest.mark.asyncio
def test_async_generator_resource() -> None:
async def async_gen_table():
for l_ in ["a", "b", "c"]:
Expand Down Expand Up @@ -70,6 +75,7 @@ async def async_gen_resource():
assert [r[0] for r in rows] == ["a", "b", "c", "d", "e", "f"]


@pytest.mark.asyncio
def test_async_generator_nested() -> None:
def async_inner_table():
async def _gen(idx):
Expand Down Expand Up @@ -124,6 +130,7 @@ async def async_transformer(item):
assert {r[0] for r in rows} == {"at", "bt", "ct"}


@pytest.mark.asyncio
@pytest.mark.parametrize("next_item_mode", ["fifo", "round_robin"])
@pytest.mark.parametrize(
"resource_mode", ["both_sync", "both_async", "first_async", "second_async"]
Expand Down Expand Up @@ -189,20 +196,29 @@ def source():
assert {r[0] for r in rows} == {"e", "f", "g"}

# in both item modes there will be parallel execution
if resource_mode in ["both_async"]:
assert execution_order == ["one", "two", "one", "two", "one", "two"]
fifo_result = ["one", "one", "one", "two", "two", "two"]
round_robin_result = ["one", "two", "one", "two", "one", "two"]

if resource_mode == "both_async" and next_item_mode == "fifo":
assert execution_order == fifo_result
elif resource_mode == "both_async" and next_item_mode == "round_robin":
assert execution_order == round_robin_result
# the first resource is exhausted first, then the second
elif resource_mode in ["both_sync"] and next_item_mode == "fifo":
assert execution_order == ["one", "one", "one", "two", "two", "two"]
elif resource_mode == "both_sync" and next_item_mode == "fifo":
assert execution_order == fifo_result
# round robin is executed in sync
elif resource_mode in ["both_sync"] and next_item_mode == "round_robin":
assert execution_order == ["one", "two", "one", "two", "one", "two"]
elif resource_mode in ["first_async"]:
assert execution_order == ["two", "two", "two", "one", "one", "one"]
elif resource_mode in ["second_async"]:
assert execution_order == ["one", "one", "one", "two", "two", "two"]
elif resource_mode == "both_sync" and next_item_mode == "round_robin":
assert execution_order == round_robin_result
elif resource_mode == "first_async" and next_item_mode == "fifo":
assert execution_order == fifo_result
elif resource_mode == "first_async" and next_item_mode == "round_robin":
assert execution_order == round_robin_result
elif resource_mode == "second_async" and next_item_mode == "fifo":
assert execution_order == fifo_result
elif resource_mode == "second_async" and next_item_mode == "round_robin":
assert execution_order == round_robin_result
else:
raise AssertionError("Unknown combination")
raise AssertionError(f"Unknown combination [{resource_mode}, {next_item_mode}]")


def test_limit_async_resource() -> None:
Expand Down