Skip to content

Commit

Permalink
chore: type hint QuerySet and QuerySetEndpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
jorwoods committed Jun 17, 2024
1 parent 2f8fd5d commit c7cec85
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""


class CustomViews(QuerysetEndpoint):
class CustomViews(QuerysetEndpoint[CustomViewItem]):
def __init__(self, parent_srv):
super(CustomViews, self).__init__(parent_srv)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Datasources(QuerysetEndpoint):
class Datasources(QuerysetEndpoint[DatasourceItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Datasources, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
27 changes: 21 additions & 6 deletions tableauserverclient/server/endpoint/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from tableauserverclient import datetime_helpers as datetime

import abc
from packaging.version import Version
from functools import wraps
from xml.etree.ElementTree import ParseError
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generic, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Union

from tableauserverclient.models.pagination_item import PaginationItem
from tableauserverclient.server.request_options import RequestOptions

from .exceptions import (
ServerResponseError,
Expand Down Expand Up @@ -300,25 +304,36 @@ def wrapper(self, *args, **kwargs):
return _decorator


class QuerysetEndpoint(Endpoint):
T = TypeVar("T")


class QuerysetEndpoint(Endpoint, Generic[T]):
@api(version="2.0")
def all(self, *args, **kwargs):
def all(self, *args, **kwargs) -> QuerySet[T]:
if args or kwargs:
raise ValueError(".all method takes no arguments.")
queryset = QuerySet(self)
return queryset

@api(version="2.0")
def filter(self, *_, **kwargs) -> QuerySet:
def filter(self, *_, **kwargs) -> QuerySet[T]:
if _:
raise RuntimeError("Only keyword arguments accepted.")
queryset = QuerySet(self).filter(**kwargs)
return queryset

@api(version="2.0")
def order_by(self, *args, **kwargs):
def order_by(self, *args, **kwargs) -> QuerySet[T]:
if kwargs:
raise ValueError(".order_by does not accept keyword arguments.")
queryset = QuerySet(self).order_by(*args)
return queryset

@api(version="2.0")
def paginate(self, **kwargs):
def paginate(self, **kwargs) -> QuerySet[T]:
queryset = QuerySet(self).paginate(**kwargs)
return queryset

@abc.abstractmethod
def get(self, request_options: RequestOptions) -> Tuple[List[T], PaginationItem]:
raise NotImplementedError(f".get has not been implemented for {self.__class__.__qualname__}")
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/flow_runs_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..request_options import RequestOptions


class FlowRuns(QuerysetEndpoint):
class FlowRuns(QuerysetEndpoint[FlowRunItem]):
def __init__(self, parent_srv: "Server") -> None:
super(FlowRuns, self).__init__(parent_srv)
return None
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/flows_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Flows(QuerysetEndpoint):
class Flows(QuerysetEndpoint[FlowItem]):
def __init__(self, parent_srv):
super(Flows, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/groups_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..request_options import RequestOptions


class Groups(QuerysetEndpoint):
class Groups(QuerysetEndpoint[GroupItem]):
@property
def baseurl(self) -> str:
return "{0}/sites/{1}/groups".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/jobs_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import List, Optional, Tuple, Union


class Jobs(QuerysetEndpoint):
class Jobs(QuerysetEndpoint[JobItem]):
@property
def baseurl(self):
return "{0}/sites/{1}/jobs".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/metrics_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tableauserverclient.helpers.logging import logger


class Metrics(QuerysetEndpoint):
class Metrics(QuerysetEndpoint[MetricItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Metrics, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/projects_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from tableauserverclient.helpers.logging import logger


class Projects(QuerysetEndpoint):
class Projects(QuerysetEndpoint[ProjectItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Projects, self).__init__(parent_srv)

Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/users_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tableauserverclient.helpers.logging import logger


class Users(QuerysetEndpoint):
class Users(QuerysetEndpoint[UserItem]):
@property
def baseurl(self) -> str:
return "{0}/sites/{1}/users".format(self.parent_srv.baseurl, self.parent_srv.site_id)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/views_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)


class Views(QuerysetEndpoint):
class Views(QuerysetEndpoint[ViewItem]):
def __init__(self, parent_srv):
super(Views, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
2 changes: 1 addition & 1 deletion tableauserverclient/server/endpoint/workbooks_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
PathOrFileW = Union[FilePath, FileObjectW]


class Workbooks(QuerysetEndpoint):
class Workbooks(QuerysetEndpoint[WorkbookItem]):
def __init__(self, parent_srv: "Server") -> None:
super(Workbooks, self).__init__(parent_srv)
self._resource_tagger = _ResourceTagger(parent_srv)
Expand Down
71 changes: 46 additions & 25 deletions tableauserverclient/server/query.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import Tuple
from .filter import Filter
from .request_options import RequestOptions
from .sort import Sort
from collections.abc import Iterable, Sized
from itertools import count
from typing import Iterator, List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, overload
from tableauserverclient.models.pagination_item import PaginationItem
from tableauserverclient.server.filter import Filter
from tableauserverclient.server.request_options import RequestOptions
from tableauserverclient.server.sort import Sort
import math

from typing_extensions import Self

if TYPE_CHECKING:
from tableauserverclient.server.endpoint import QuerysetEndpoint

T = TypeVar("T")


class Slice(Protocol):
start: Optional[int]
step: Optional[int]
stop: Optional[int]


def to_camel_case(word: str) -> str:
return word.split("_")[0] + "".join(x.capitalize() or "_" for x in word.split("_")[1:])
Expand All @@ -16,28 +32,33 @@ def to_camel_case(word: str) -> str:
"""


class QuerySet:
def __init__(self, model):
class QuerySet(Iterable[T], Sized):
def __init__(self, model: "QuerysetEndpoint[T]") -> None:
self.model = model
self.request_options = RequestOptions()
self._result_cache = None
self._pagination_item = None
self._result_cache: List[T] = []
self._pagination_item = PaginationItem()

def __iter__(self):
def __iter__(self) -> Iterator[T]:
# Not built to be re-entrant. Starts back at page 1, and empties
# the result cache.
self.request_options.pagenumber = 1
self._result_cache = None
total = self.total_available
size = self.page_size
yield from self._result_cache

# Loop through the subsequent pages.
for page in range(1, math.ceil(total / size)):
self.request_options.pagenumber = page + 1
self._result_cache = None
for page in count(1):
self.request_options.pagenumber = page
self._fetch_all()
yield from self._result_cache
# Set result_cache to empty so the fetch will populate
self._result_cache = []
if (page * self.page_size) >= len(self):
return

@overload
def __getitem__(self, k: Slice) -> List[T]:
...

@overload
def __getitem__(self, k: int) -> T:
...

def __getitem__(self, k):
page = self.page_number
Expand Down Expand Up @@ -78,19 +99,19 @@ def __getitem__(self, k):
return self._result_cache[k % size]
elif k in range(self.total_available):
# Otherwise, check if k is even sensible to return
self._result_cache = None
self._result_cache = []
# Add one to k, otherwise it gets stuck at page boundaries, e.g. 100
self.request_options.pagenumber = max(1, math.ceil((k + 1) / size))
return self[k]
else:
# If k is unreasonable, raise an IndexError.
raise IndexError

def _fetch_all(self):
def _fetch_all(self) -> None:
"""
Retrieve the data and store result and pagination item in cache
"""
if self._result_cache is None:
if not self._result_cache:
self._result_cache, self._pagination_item = self.model.get(self.request_options)

def __len__(self) -> int:
Expand All @@ -111,21 +132,21 @@ def page_size(self) -> int:
self._fetch_all()
return self._pagination_item.page_size

def filter(self, *invalid, **kwargs):
def filter(self, *invalid, **kwargs) -> Self:
if invalid:
raise RuntimeError(f"Only accepts keyword arguments.")
raise RuntimeError("Only accepts keyword arguments.")
for kwarg_key, value in kwargs.items():
field_name, operator = self._parse_shorthand_filter(kwarg_key)
self.request_options.filter.add(Filter(field_name, operator, value))
return self

def order_by(self, *args):
def order_by(self, *args) -> Self:
for arg in args:
field_name, direction = self._parse_shorthand_sort(arg)
self.request_options.sort.add(Sort(field_name, direction))
return self

def paginate(self, **kwargs):
def paginate(self, **kwargs) -> Self:
if "page_number" in kwargs:
self.request_options.pagenumber = kwargs["page_number"]
if "page_size" in kwargs:
Expand Down

0 comments on commit c7cec85

Please sign in to comment.