Skip to content

Commit

Permalink
Merge 80a62fd into 3ae17c7
Browse files Browse the repository at this point in the history
  • Loading branch information
jmolinski committed Apr 9, 2019
2 parents 3ae17c7 + 80a62fd commit cb7af16
Show file tree
Hide file tree
Showing 10 changed files with 291 additions and 76 deletions.
95 changes: 82 additions & 13 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# flake8: noqa: F403, F405

import time
import types
from dataclasses import asdict

import pytest
from tests.test_data.countries import COUNTRIES
from tests.test_data.oauth import OAUTH_GET_TOKEN
from tests.utils import MockRequests
from tests.utils import MockRequests, get_last_req, mk_mock_client
from trakt import Trakt, TraktCredentials
from trakt.core.components import DefaultHttpComponent
from trakt.core.exceptions import ClientError
from trakt.core.executors import Executor
from trakt.core.exceptions import ArgumentError, ClientError
from trakt.core.executors import Executor, PaginationIterator
from trakt.core.paths import Path


Expand Down Expand Up @@ -76,15 +75,9 @@ def test_refresh_token_on():

def test_pagination():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

http = lambda client: DefaultHttpComponent(
client,
requests_dependency=MockRequests(
{"pag_off": [data, 200], "pag_on": [data, 200]}, paginated=["pag_on"]
),
client = mk_mock_client(
{"pag_off": [data, 200], "pag_on": [data, 200]}, paginated=["pag_on"]
)

client = Trakt("", "", http_component=http)
executor = Executor(client)

p_nopag = Path("pag_off", [int])
Expand All @@ -96,5 +89,81 @@ def test_pagination():
assert isinstance(res_nopag, list)
assert res_nopag == data

assert isinstance(res_pag, types.GeneratorType)
assert isinstance(res_pag, PaginationIterator)
assert list(executor.run(path=p_pag, page=2, per_page=4)) == [5, 6, 7, 8, 9, 10]


def test_prefetch_off():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

assert get_last_req(client.http) is None
req = executor.run(path=p_pag, page=2, per_page=3)
assert get_last_req(client.http) is None
list(req)
assert get_last_req(client.http) is not None


def test_prefetch_on():
data = list(range(10 ** 4))
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

# prefetch
assert get_last_req(client.http) is None
req = executor.run(path=p_pag, page=2, per_page=3)
assert get_last_req(client.http) is None
req.prefetch_all()
assert get_last_req(client.http) is not None

# reset history
client.http._requests.req_stack = []
assert get_last_req(client.http) is None

# execute prefetched -> assert no new requests
list(req)
assert get_last_req(client.http) is None


def test_take():
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

it = executor.run(path=p_pag, per_page=2)
assert it.has_next()
assert isinstance(next(it), int)
assert next(it) == 2
assert it.has_next()

assert it.take(3) == [3, 4, 5]
assert it.has_next()

with pytest.raises(ArgumentError):
it.take(-5)

assert it.take(0) == []
assert it.take() == [6, 7] # per_page setting
assert it.has_next()

assert it.take_all() == [8, 9, 10]
assert not it.has_next()

with pytest.raises(StopIteration):
next(it)

assert it.take(2) == it.take_all() == []


def test_chaining():
data = list(range(300))
client = mk_mock_client({"pag_on": [data, 200]}, paginated=["pag_on"])
executor = Executor(client)
p_pag = Path("pag_on", [int], pagination=True)

assert executor.run(path=p_pag, per_page=2).take_all() == data
assert executor.run(path=p_pag, per_page=2).prefetch_all().take_all() == data
9 changes: 6 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ def wrapper(client):
USER = TraktCredentials("", "", "", 10e14)


def mk_mock_client(endpoints, client_id="", client_secret="", user=False):
def mk_mock_client(
endpoints, client_id="", client_secret="", user=False, paginated=None
):
return Trakt(
client_id,
client_secret,
http_component=get_mock_http_component(endpoints),
http_component=get_mock_http_component(endpoints, paginated=paginated),
user=USER if user is False else None,
)


def get_last_req(http):
return http._requests.req_stack[-1]
if http._requests.req_stack:
return http._requests.req_stack[-1]
149 changes: 127 additions & 22 deletions trakt/core/executors.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from __future__ import annotations

import itertools
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
TypeVar,
Union,
cast,
)

from trakt.core import json_parser
from trakt.core.exceptions import ClientError
from trakt.core.exceptions import ArgumentError, ClientError

if TYPE_CHECKING: # pragma: no cover
from trakt.api import TraktApi
Expand Down Expand Up @@ -35,10 +46,6 @@ def __init__(
if expires_in < self.client.config["oauth"]["refresh_token_s"]:
self.client.oauth.refresh_token()

def __getattr__(self, param: str) -> Executor:
self.params.append(param)
return self

def __repr__(self) -> str: # pragma: no cover
return f'Executor(params={".".join(self.params)})'

Expand Down Expand Up @@ -101,25 +108,123 @@ def exec_path_call(
def make_generator(self, path: Path, **kwargs: Any):
start_page = int(kwargs.get("page", 1))
per_page = int(kwargs.get("per_page", 10))
max_pages = 10e10
max_pages = 1 << 16

def generator():
page = start_page
stop_at_page = max_pages
return PaginationIterator(self, path, start_page, per_page, max_pages)

while page < stop_at_page:
response, pagination = self.exec_path_call(
path,
pagination=True,
extra_quargs={"page": str(page), "limit": str(per_page)},
)
def find_matching_path(self) -> List[Path]:
return [p for s in self.path_suites for p in s.find_matching(self.params)]

yield from response

page += 1
stop_at_page = int(pagination["page_count"]) + 1
T = TypeVar("T")
PER_PAGE_LIMIT = 100

return generator()

def find_matching_path(self) -> List[Path]:
return [p for s in self.path_suites for p in s.find_matching(self.params)]
class PaginationIterator(Iterable[T]):
pages_total: Optional[int] = None

def __init__(
self,
executor: Executor,
path: Path,
start_page: int,
per_page: int,
max_pages: int,
) -> None:
self._executor = executor
self._path = path
self._start_page = start_page
self._per_page = per_page
self._max_pages = max_pages

self._exhausted = False
self._queue: List[T] = []
self._yielded_items = 0

def __iter__(self) -> PaginationIterator[T]:
if self._exhausted:
return self

self._exhausted = True
self._page = self._start_page
self._stop_at_page = self._max_pages

self._queue = []
self._yielded_items = 0

return self

def __next__(self) -> T:
if not self._queue:
if not self._has_next_page():
raise StopIteration()

self._fetch_next_page()

self._yielded_items += 1
return self._queue.pop(0)

def _fetch_next_page(self, skip_first: int = 0) -> None:
response, pagination = self._executor.exec_path_call(
self._path,
pagination=True,
extra_quargs={"page": str(self._page), "limit": str(self._per_page)},
)

for r in response[skip_first:]:
self._queue.append(r)

self._page += 1
self._stop_at_page = int(pagination["page_count"])
self.pages_total = self._stop_at_page

def prefetch_all(self) -> PaginationIterator[T]:
"""Prefetch all results. Optimized."""
iterator = cast(PaginationIterator[T], iter(self))

if not self._has_next_page():
return iterator

# tweak per_page setting to make fetching as fast as possible
old_per_page = self._per_page
self._per_page = PER_PAGE_LIMIT

self._page = (self._yielded_items // PER_PAGE_LIMIT) + 1
to_skip = (self._yielded_items % PER_PAGE_LIMIT) + len(self._queue)

self._fetch_next_page(skip_first=to_skip)

while self._has_next_page():
self._fetch_next_page()

self._per_page = old_per_page

return iterator

def _has_next_page(self) -> bool:
return self._page <= self._stop_at_page

def take(self, n: int = -1) -> List[T]:
"""Take n next results. By default returns per_page results."""
if n == -1:
n = self._per_page

if not isinstance(n, int) or n < 0:
raise ArgumentError(
f"argument n={n} is invalid; n must be an int and n >= 1"
)

it = iter(self)
return list(itertools.islice(it, n))

def take_all(self) -> List[T]:
"""Take all available results."""
self.prefetch_all()
return self.take(len(self._queue))

def has_next(self) -> bool:
"""Check if there are any results left."""
if not self._exhausted:
iter(self)

return bool(self._queue or self._has_next_page())
11 changes: 8 additions & 3 deletions trakt/core/paths/endpoint_mappings/comments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from dataclasses import asdict
from typing import Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from trakt.core.models import Episode, Movie, Season, Show
from trakt.core.paths.path import Path
Expand All @@ -18,6 +20,9 @@
Validator,
)

if TYPE_CHECKING: # pragma: no cover
from trakt.core.executors import PaginationIterator

COMMENT_TEXT_VALIDATOR = PerArgValidator(
"comment", lambda c: isinstance(c, str) and len(c.split(" ")) > 4
)
Expand Down Expand Up @@ -157,7 +162,7 @@ def delete_comment(self, *, id: Union[Comment, str, int], **kwargs) -> None:

def get_replies(
self, *, id: Union[Comment, str, int], **kwargs
) -> Iterable[Comment]:
) -> PaginationIterator[Comment]:
id = int(self._generic_get_id(id))
return self.run("get_replies", **kwargs, id=id)

Expand All @@ -168,7 +173,7 @@ def post_reply(
comment: str,
spoiler: bool = False,
**kwargs
) -> Iterable[Comment]:
) -> PaginationIterator[Comment]:
id = int(self._generic_get_id(id))

body = {"comment": comment, "spoiler": spoiler}
Expand Down
11 changes: 8 additions & 3 deletions trakt/core/paths/endpoint_mappings/episodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Iterable, List, Union
from __future__ import annotations

from typing import TYPE_CHECKING, List, Union

from trakt.core.models import Comment, Episode, Season
from trakt.core.paths.endpoint_mappings.movies import (
Expand All @@ -22,6 +24,9 @@
SEASON_ID_VALIDATOR = PerArgValidator("season", lambda i: isinstance(i, int))
EPISODE_ID_VALIDATOR = PerArgValidator("season", lambda i: isinstance(i, int))

if TYPE_CHECKING: # pragma: no cover
from trakt.core.executors import PaginationIterator


class EpisodesI(SuiteInterface):
name = "episodes"
Expand Down Expand Up @@ -105,7 +110,7 @@ def get_comments(
episode: Union[Episode, int, str],
sort: str = "newest",
**kwargs
) -> Iterable[Comment]:
) -> PaginationIterator[Comment]:
id = self._generic_get_id(show)
season = self._generic_get_id(season)
episode = self._generic_get_id(episode)
Expand Down Expand Up @@ -143,7 +148,7 @@ def get_lists(
type: str = "personal",
sort: str = "popular",
**kwargs
) -> Iterable[TraktList]:
) -> PaginationIterator[TraktList]:
id = self._generic_get_id(show)
season = self._generic_get_id(season)
episode = self._generic_get_id(episode)
Expand Down
Loading

0 comments on commit cb7af16

Please sign in to comment.