Skip to content

Commit

Permalink
Add PaginationIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
jmolinski committed Apr 8, 2019
1 parent 3ae17c7 commit f5ff543
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 35 deletions.
51 changes: 40 additions & 11 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import pytest
from tests.test_data.countries import COUNTRIES
from tests.test_data.oauth import OAUTH_GET_TOKEN
from tests.utils import MockRequests
from tests.utils import MockRequests, get_last_req, mk_mock_client
from trakt import Trakt, TraktCredentials
from trakt.core.components import DefaultHttpComponent
from trakt.core.exceptions import ClientError
from trakt.core.executors import Executor
from trakt.core.executors import Executor, PaginationIterator
from trakt.core.paths import Path


Expand Down Expand Up @@ -76,15 +76,9 @@ def test_refresh_token_on():

def test_pagination():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

http = lambda client: DefaultHttpComponent(
client,
requests_dependency=MockRequests(
{"pag_off": [data, 200], "pag_on": [data, 200]}, paginated=["pag_on"]
),
client = mk_mock_client(
{"pag_off": [data, 200], "pag_on": [data, 200]}, paginated=["pag_on"]
)

client = Trakt("", "", http_component=http)
executor = Executor(client)

p_nopag = Path("pag_off", [int])
Expand All @@ -96,5 +90,40 @@ def test_pagination():
assert isinstance(res_nopag, list)
assert res_nopag == data

assert isinstance(res_pag, types.GeneratorType)
assert isinstance(res_pag, PaginationIterator)
assert list(executor.run(path=p_pag, page=2, per_page=4)) == [5, 6, 7, 8, 9, 10]


def test_prefetch_off():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

assert get_last_req(client.http) is None
req = executor.run(path=p_pag, page=2, per_page=3)
assert get_last_req(client.http) is None
list(req)
assert get_last_req(client.http) is not None


def test_prefetch_on():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

# prefetch
assert get_last_req(client.http) is None
req = executor.run(path=p_pag, page=2, per_page=3)
assert get_last_req(client.http) is None
req.prefetch_all()
assert get_last_req(client.http) is not None

# reset history
client.http._requests.req_stack = []
assert get_last_req(client.http) is None

# execute prefetched -> assert no new requests
list(req)
assert get_last_req(client.http) is None
9 changes: 6 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ def wrapper(client):
USER = TraktCredentials("", "", "", 10e14)


def mk_mock_client(endpoints, client_id="", client_secret="", user=False):
def mk_mock_client(
endpoints, client_id="", client_secret="", user=False, paginated=None
):
return Trakt(
client_id,
client_secret,
http_component=get_mock_http_component(endpoints),
http_component=get_mock_http_component(endpoints, paginated=paginated),
user=USER if user is False else None,
)


def get_last_req(http):
return http._requests.req_stack[-1]
if http._requests.req_stack:
return http._requests.req_stack[-1]
87 changes: 66 additions & 21 deletions trakt/core/executors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast

from trakt.core import json_parser
from trakt.core.exceptions import ClientError
Expand Down Expand Up @@ -35,10 +35,6 @@ def __init__(
if expires_in < self.client.config["oauth"]["refresh_token_s"]:
self.client.oauth.refresh_token()

def __getattr__(self, param: str) -> Executor:
self.params.append(param)
return self

def __repr__(self) -> str: # pragma: no cover
return f'Executor(params={".".join(self.params)})'

Expand Down Expand Up @@ -101,25 +97,74 @@ def exec_path_call(
def make_generator(self, path: Path, **kwargs: Any):
start_page = int(kwargs.get("page", 1))
per_page = int(kwargs.get("per_page", 10))
max_pages = 10e10
max_pages = 1 << 16

def generator():
page = start_page
stop_at_page = max_pages
return PaginationIterator(self, path, start_page, per_page, max_pages)

while page < stop_at_page:
response, pagination = self.exec_path_call(
path,
pagination=True,
extra_quargs={"page": str(page), "limit": str(per_page)},
)
def find_matching_path(self) -> List[Path]:
return [p for s in self.path_suites for p in s.find_matching(self.params)]

yield from response

page += 1
stop_at_page = int(pagination["page_count"]) + 1
class PaginationIterator:
pages_total: Optional[int] = None

return generator()
def __init__(
self,
executor: Executor,
path: Path,
start_page: int,
per_page: int,
max_pages: int,
) -> None:
self._executor = executor
self._path = path
self._start_page = start_page
self._per_page = per_page
self._max_pages = max_pages

def find_matching_path(self) -> List[Path]:
return [p for s in self.path_suites for p in s.find_matching(self.params)]
self._exhausted = False

def __iter__(self) -> PaginationIterator:
if self._exhausted:
return self

self._exhausted = True
self._page = self._start_page
self._stop_at_page = self._max_pages

self._queue: Any = [] # TODO generic type specification

return self

def __next__(self):
if not self._queue:
if not self._has_next_page():
raise StopIteration()

self._fetch_next_page()

return self._queue.pop(0)

def _fetch_next_page(self) -> None:
response, pagination = self._executor.exec_path_call(
self._path,
pagination=True,
extra_quargs={"page": str(self._page), "limit": str(self._per_page)},
)

for r in response:
self._queue.append(r)

self._page += 1
self._stop_at_page = int(pagination["page_count"])
self.pages_total = self._stop_at_page

def prefetch_all(self) -> PaginationIterator:
iterator = cast(PaginationIterator, iter(self))
while self._has_next_page():
self._fetch_next_page()

return iterator

def _has_next_page(self) -> bool:
return self._page <= self._stop_at_page

0 comments on commit f5ff543

Please sign in to comment.