diff --git a/pyproject.toml b/pyproject.toml index 1321a0ed848c..f6af1da8c223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -376,7 +376,6 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ "sentry.api.endpoints.organization_releases", - "sentry.api.paginator", "sentry.db.postgres.base", "sentry.middleware.auth", "sentry.middleware.ratelimit", diff --git a/src/sentry/api/paginator.py b/src/sentry/api/paginator.py index dfdf16b5d11b..dab8cea9c9ad 100644 --- a/src/sentry/api/paginator.py +++ b/src/sentry/api/paginator.py @@ -14,7 +14,7 @@ from django.db.models.functions import Lower from sentry_protos.snuba.v1.request_common_pb2 import PageToken -from sentry.utils.cursors import Cursor, CursorResult, build_cursor +from sentry.utils.cursors import Cursor, CursorResult, CursorValue, build_cursor from sentry.utils.pagination_factory import PaginatorLike quote_name = connections["default"].ops.quote_name @@ -27,7 +27,7 @@ MAX_SNUBA_ELEMENTS = 10000 -def count_hits(queryset, max_hits): +def count_hits(queryset: Any, max_hits: int) -> int: if not max_hits: return 0 hits_query = queryset.values()[:max_hits].query @@ -79,7 +79,7 @@ def __init__( self.on_results = on_results self.post_query_filter = post_query_filter - def _is_asc(self, is_prev): + def _is_asc(self, is_prev: bool) -> bool: return (self.desc and is_prev) or not (self.desc or is_prev) def build_queryset(self, value, is_prev): @@ -133,10 +133,10 @@ def build_queryset(self, value, is_prev): return queryset - def get_item_key(self, item, for_prev): + def get_item_key(self, item: Any, for_prev: bool = False) -> CursorValue: raise NotImplementedError - def value_from_cursor(self, cursor): + def value_from_cursor(self, cursor: Cursor) -> CursorValue: raise NotImplementedError def get_result(self, limit=100, cursor=None, count_hits=False, known_hits=None, max_hits=None): @@ -204,7 +204,7 @@ def get_result(self, limit=100, cursor=None, count_hits=False, known_hits=None, max_hits=max_hits if count_hits else None, cursor=cursor, is_desc=self.desc, - key=self.get_item_key, + key=self.get_item_key, # type: ignore[arg-type] # mypy loses default-param info on bound methods on_results=self.on_results, ) @@ -216,28 +216,28 @@ def get_result(self, limit=100, cursor=None, count_hits=False, known_hits=None, return cursor - def count_hits(self, max_hits): + def count_hits(self, max_hits: int) -> int: return count_hits(self.queryset, max_hits) class Paginator(BasePaginator): - def get_item_key(self, item, for_prev=False): + def get_item_key(self, item: Any, for_prev: bool = False) -> int: value = getattr(item, self.key) return int(math.floor(value) if self._is_asc(for_prev) else math.ceil(value)) - def value_from_cursor(self, cursor): + def value_from_cursor(self, cursor: Cursor) -> CursorValue: return cursor.value class DateTimePaginator(BasePaginator): multiplier = 1000 - def get_item_key(self, item, for_prev=False): + def get_item_key(self, item: Any, for_prev: bool = False) -> int: value = getattr(item, self.key) value = float(value.strftime("%s.%f")) * self.multiplier return int(math.floor(value) if self._is_asc(for_prev) else math.ceil(value)) - def value_from_cursor(self, cursor): + def value_from_cursor(self, cursor: Cursor) -> datetime: # type: ignore[override] return datetime.fromtimestamp(float(cursor.value) / self.multiplier).replace( tzinfo=timezone.utc ) @@ -264,10 +264,10 @@ def get_result( self, limit: int = 100, cursor: Any = None, - count_hits: Any = False, - known_hits: Any = None, - max_hits: Any = None, - ): + count_hits: bool = False, + known_hits: int | None = None, + max_hits: int | None = None, + ) -> CursorResult[Any]: # offset is page # # value is page limit if cursor is None: @@ -310,7 +310,7 @@ def get_result( return CursorResult(results=results, next=next_cursor, prev=prev_cursor, hits=hits) - def count_hits(self, max_hits): + def count_hits(self, max_hits: int) -> int: return count_hits(self.queryset, max_hits) @@ -341,7 +341,7 @@ def __init__( self.data_count_func = data_count_func self.queryset_load_func = queryset_load_func - def get_result(self, limit=100, cursor=None): + def get_result(self, limit: int = 100, cursor: Any = None) -> CursorResult[Any]: # type: ignore[override] if cursor is None: cursor = Cursor(0, 0, 0) @@ -393,7 +393,7 @@ def get_result(self, limit=100, cursor=None): return CursorResult(results=results, next=next_cursor, prev=prev_cursor) -def reverse_bisect_left(a, x, lo=0, hi=None): +def reverse_bisect_left(a: Sequence[Any], x: Any, lo: int = 0, hi: int | None = None) -> int: """\ Similar to ``bisect.bisect_left``, but expects the data in the array ``a`` to be provided in descending order, rather than the ascending order assumed @@ -808,7 +808,7 @@ def __init__( self.callback = callback self.on_results = on_results - def get_result(self, limit: int, cursor: Cursor | None = None): + def get_result(self, limit: int, cursor: Cursor | None = None) -> CursorResult[Any]: if cursor is None: cursor = Cursor(0, 0, 0) @@ -818,7 +818,8 @@ def get_result(self, limit: int, cursor: Cursor | None = None): fetch_limit += 1 # +1 to limit so that we can tell if there are more results left after the current page # offset = "page" number * max number of items per page - fetch_offset = cursor.offset * cursor.value + assert isinstance(cursor.value, (int, float)) + fetch_offset = int(cursor.offset * cursor.value) if self.offset < 0: raise BadPaginationError("Pagination offset cannot be negative") @@ -855,7 +856,7 @@ def get_result(self, limit, cursor=None): next=self.cursor_from_page_token(page_token=next_page_token), ) - def cursor_from_page_token(self, page_token: PageToken): + def cursor_from_page_token(self, page_token: PageToken) -> Cursor: has_more = not page_token.HasField("end_pagination") or not page_token.end_pagination return Cursor( @@ -864,10 +865,11 @@ def cursor_from_page_token(self, page_token: PageToken): has_results=has_more, ) - def page_token_from_cursor(self, cursor: Cursor | None): + def page_token_from_cursor(self, cursor: Cursor | None) -> PageToken | None: if cursor is None: return None + assert isinstance(cursor.value, str) bytes = base64.b64decode(cursor.value.encode("utf-8")) page_token = PageToken() diff --git a/src/sentry/utils/pagination_factory.py b/src/sentry/utils/pagination_factory.py index 7c59c0a96546..23a38ab6c554 100644 --- a/src/sentry/utils/pagination_factory.py +++ b/src/sentry/utils/pagination_factory.py @@ -18,9 +18,9 @@ def get_result( self, limit: int = 100, cursor: Any = None, - count_hits: Any = False, - known_hits: Any = None, - max_hits: Any = None, + count_hits: bool = False, + known_hits: int | None = None, + max_hits: int | None = None, ) -> CursorResult[Any]: pass