diff --git a/.github/workflows/momento-local-tests.yml b/.github/workflows/momento-local-tests.yml new file mode 100644 index 00000000..2fe2a134 --- /dev/null +++ b/.github/workflows/momento-local-tests.yml @@ -0,0 +1,41 @@ +name: Momento Local tests + +on: + pull_request: + branches: [main] + +jobs: + local-tests: + strategy: + matrix: + os: [ubuntu-24.04] + python-version: ["3.13"] + runs-on: ${{ matrix.os }} + + env: + TEST_API_KEY: ${{ secrets.ALPHA_TEST_AUTH_TOKEN }} + TEST_CACHE_NAME: python-integration-test-${{ matrix.python-version }}-${{ matrix.new-python-protobuf }}-${{ github.sha }} + + steps: + - uses: actions/checkout@v4 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install and configure Poetry + uses: snok/install-poetry@v1 + with: + version: 1.3.1 + virtualenvs-in-project: true + + - name: Install dependencies + run: poetry install + + - name: Start Momento Local + run: | + docker run --cap-add=NET_ADMIN --rm -d -p 8080:8080 -p 9090:9090 gomomento/momento-local --enable-test-admin + + - name: Run tests + run: poetry run pytest -p no:sugar -q -m local diff --git a/.github/workflows/on-pull-request.yml b/.github/workflows/on-pull-request.yml index 844cff77..9989494c 100644 --- a/.github/workflows/on-pull-request.yml +++ b/.github/workflows/on-pull-request.yml @@ -57,7 +57,7 @@ jobs: run: poetry run ruff format --check --diff src tests - name: Run tests - run: poetry run pytest -p no:sugar -q + run: poetry run pytest -p no:sugar -q -m "not local" test-examples: runs-on: ubuntu-24.04 diff --git a/Makefile b/Makefile index 271a83f4..20bc4a88 100644 --- a/Makefile +++ b/Makefile @@ -50,7 +50,12 @@ gen-sync: do-gen-sync format lint .PHONY: test ## Run unit and integration tests with pytest test: - @poetry run pytest + @poetry run pytest -m "not local" + +.PHONY: test-local +## Run the integration tests that require Momento Local +test-local: + @poetry run pytest -m local .PHONY: precommit ## Run format, lint, and test as a step before committing. diff --git a/pyproject.toml b/pyproject.toml index 9be222ca..604f8a2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ log_level = "ERROR" log_cli = true log_cli_format = "%(asctime)s [%(levelname)s] %(message)s" log_cli_date_format = "%Y-%m-%d %H:%M:%S.%f" +markers = [ + "local: tests that require Momento Local", +] [tool.mypy] python_version = "3.7" diff --git a/src/momento/retry/fixed_timeout_retry_strategy.py b/src/momento/retry/fixed_timeout_retry_strategy.py index 52b465ca..d5bad5fc 100644 --- a/src/momento/retry/fixed_timeout_retry_strategy.py +++ b/src/momento/retry/fixed_timeout_retry_strategy.py @@ -49,7 +49,7 @@ def determine_when_to_retry(self, props: RetryableProps) -> Optional[float]: # If a retry attempt's timeout has passed but the client's overall timeout has not yet passed, # we should reset the deadline and retry. if ( - props.attempt_number > 0 + props.attempt_number > 0 # type: ignore[misc] and props.grpc_status == grpc.StatusCode.DEADLINE_EXCEEDED # type: ignore[misc] and props.overall_deadline > datetime.now() ): diff --git a/tests/conftest.py b/tests/conftest.py index e56cf925..55c9a3d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,9 @@ import asyncio import os import random +from contextlib import asynccontextmanager, contextmanager from datetime import timedelta -from typing import AsyncIterator, Callable, Iterator, List, Optional, Union, cast +from typing import AsyncGenerator, AsyncIterator, Callable, Iterator, List, Optional, Union, cast import pytest import pytest_asyncio @@ -41,6 +42,8 @@ TTopicName, ) +from tests.momento.local.momento_local_async_middleware import MomentoLocalAsyncMiddleware, MomentoLocalMiddlewareArgs +from tests.momento.local.momento_local_middleware import MomentoLocalMiddleware from tests.utils import ( unique_test_cache_name, uuid_bytes, @@ -51,13 +54,17 @@ # Integration test data ####################### -TEST_CONFIGURATION = Configurations.Laptop.latest() +TEST_CONFIGURATION: Configuration = Configurations.Laptop.latest() TEST_TOPIC_CONFIGURATION = TopicConfigurations.Default.latest().with_client_timeout(timedelta(seconds=10)) TEST_AUTH_CONFIGURATION = AuthConfigurations.Laptop.latest() TEST_AUTH_PROVIDER = CredentialProvider.from_environment_variable("TEST_API_KEY") +MOMENTO_LOCAL_HOSTNAME = os.environ.get("MOMENTO_HOSTNAME", "127.0.0.1") +MOMENTO_LOCAL_PORT = int(os.environ.get("MOMENTO_PORT", "8080")) +TEST_LOCAL_AUTH_PROVIDER = CredentialProvider.for_momento_local(MOMENTO_LOCAL_HOSTNAME, MOMENTO_LOCAL_PORT) + TEST_CACHE_NAME: Optional[str] = os.getenv("TEST_CACHE_NAME") if not TEST_CACHE_NAME: @@ -354,6 +361,48 @@ async def auth_client_async() -> AsyncIterator[AuthClientAsync]: yield _auth_client +@asynccontextmanager +async def client_async_local( + cache_name: str, + middleware_args: Optional[MomentoLocalMiddlewareArgs] = None, + config_fn: Optional[Callable[[Configuration], Configuration]] = None, +) -> AsyncGenerator[CacheClientAsync, None]: + config = TEST_CONFIGURATION + + if config_fn: + config = config_fn(config) + + if middleware_args: + config = config.add_middleware(MomentoLocalAsyncMiddleware(middleware_args)) + + client = await CacheClientAsync.create(config, TEST_LOCAL_AUTH_PROVIDER, DEFAULT_TTL_SECONDS) + + await client.create_cache(cache_name) + + yield client + + +@contextmanager +def client_local( + cache_name: str, + middleware_args: Optional[MomentoLocalMiddlewareArgs] = None, + config_fn: Optional[Callable[[Configuration], Configuration]] = None, +) -> Iterator[CacheClient]: + config = TEST_CONFIGURATION + + if config_fn: + config = config_fn(config) + + if middleware_args: + config = config.add_middleware(MomentoLocalMiddleware(middleware_args)) + + client = CacheClient.create(config, TEST_LOCAL_AUTH_PROVIDER, DEFAULT_TTL_SECONDS) + + client.create_cache(cache_name) + + yield client + + TUniqueCacheName = Callable[[CacheClient], str] diff --git a/tests/momento/local/__init__.py b/tests/momento/local/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/momento/local/momento_error_code_metadata.py b/tests/momento/local/momento_error_code_metadata.py new file mode 100644 index 00000000..275c99ac --- /dev/null +++ b/tests/momento/local/momento_error_code_metadata.py @@ -0,0 +1,20 @@ +from momento.errors import MomentoErrorCode + +MOMENTO_ERROR_CODE_TO_METADATA = { + MomentoErrorCode.INVALID_ARGUMENT_ERROR: "invalid-argument", + MomentoErrorCode.UNKNOWN_SERVICE_ERROR: "unknown", + MomentoErrorCode.ALREADY_EXISTS_ERROR: "already-exists", + MomentoErrorCode.NOT_FOUND_ERROR: "not-found", + MomentoErrorCode.INTERNAL_SERVER_ERROR: "internal", + MomentoErrorCode.PERMISSION_ERROR: "permission-denied", + MomentoErrorCode.AUTHENTICATION_ERROR: "unauthenticated", + MomentoErrorCode.CANCELLED_ERROR: "cancelled", + MomentoErrorCode.LIMIT_EXCEEDED_ERROR: "resource-exhausted", + MomentoErrorCode.BAD_REQUEST_ERROR: "invalid-argument", + MomentoErrorCode.TIMEOUT_ERROR: "deadline-exceeded", + MomentoErrorCode.SERVER_UNAVAILABLE: "unavailable", + MomentoErrorCode.CLIENT_RESOURCE_EXHAUSTED: "resource-exhausted", + MomentoErrorCode.FAILED_PRECONDITION_ERROR: "failed-precondition", + MomentoErrorCode.UNKNOWN_ERROR: "unknown", + MomentoErrorCode.CONNECTION_ERROR: "unavailable", +} diff --git a/tests/momento/local/momento_local_async_middleware.py b/tests/momento/local/momento_local_async_middleware.py new file mode 100644 index 00000000..f184defd --- /dev/null +++ b/tests/momento/local/momento_local_async_middleware.py @@ -0,0 +1,110 @@ +import asyncio +from typing import List + +from grpc.aio import Metadata +from momento import logs +from momento.config.middleware import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus +from momento.config.middleware.aio import Middleware, MiddlewareMetadata, MiddlewareRequestHandler + +from tests.momento.local.momento_error_code_metadata import MOMENTO_ERROR_CODE_TO_METADATA +from tests.momento.local.momento_local_middleware_args import MomentoLocalMiddlewareArgs +from tests.momento.local.momento_rpc_method import MomentoRpcMethod + + +class MomentoLocalAsyncMiddlewareRequestHandler(MiddlewareRequestHandler): + def __init__(self, args: MomentoLocalMiddlewareArgs): + self._args = args + self._cache_name = None + self._logger = logs.logger + + async def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + grpc_metadata = metadata.grpc_metadata + + if grpc_metadata is not None: + self._set_grpc_metadata(grpc_metadata, "request-id", self._args.request_id) + + if self._args.return_error is not None: + error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.return_error] + if error is not None: + self._set_grpc_metadata(grpc_metadata, "return-error", error) + + if self._args.error_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.error_rpc_list) + self._set_grpc_metadata(grpc_metadata, "error-rpcs", rpcs) + + if self._args.delay_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.delay_rpc_list) + self._set_grpc_metadata(grpc_metadata, "delay-rpcs", rpcs) + + if self._args.error_count is not None: + self._set_grpc_metadata(grpc_metadata, "error-count", str(self._args.error_count)) + + if self._args.delay_millis is not None: + self._set_grpc_metadata(grpc_metadata, "delay-ms", str(self._args.delay_millis)) + + if self._args.delay_count is not None: + self._set_grpc_metadata(grpc_metadata, "delay-count", str(self._args.delay_count)) + + if self._args.stream_error_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.stream_error_rpc_list) + self._set_grpc_metadata(grpc_metadata, "stream-error-rpcs", rpcs) + + if self._args.stream_error is not None: + error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.stream_error] + if error is not None: + self._set_grpc_metadata(grpc_metadata, "stream-error", error) + + if self._args.stream_error_message_limit is not None: + limit_str = str(self._args.stream_error_message_limit) + self._set_grpc_metadata(grpc_metadata, "stream-error-message-limit", limit_str) + + cache_name = grpc_metadata.get("cache") + if cache_name is not None: + self._cache_name = cache_name + else: + self._logger.debug("No cache name found in metadata.") + + return metadata + + async def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage: + request_type = request.constructor_name + + if self._cache_name is not None: + if self._args.test_metrics_collector is not None: # type: ignore[unreachable] + rpc_method = MomentoRpcMethod.from_request_name(request_type) + if rpc_method: + self._args.test_metrics_collector.add_timestamp( + self._cache_name, + rpc_method, + int(asyncio.get_event_loop().time() * 1000), # Current time in milliseconds + ) + else: + self._logger.debug("No cache name available. Timestamp will not be collected.") + + return request + + async def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + return metadata + + async def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage: + return response + + async def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus: + return status + + @staticmethod + def _set_grpc_metadata(metadata: Metadata, key: str, value: str) -> None: + if value is not None: + metadata[key] = value + + @staticmethod + def _concatenate_rpcs(rpcs: List[MomentoRpcMethod]) -> str: + return " ".join(rpc.metadata for rpc in rpcs) + + +class MomentoLocalAsyncMiddleware(Middleware): + def __init__(self, args: MomentoLocalMiddlewareArgs): + self._args = args + + async def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler: + return MomentoLocalAsyncMiddlewareRequestHandler(self._args) diff --git a/tests/momento/local/momento_local_metrics_collector.py b/tests/momento/local/momento_local_metrics_collector.py new file mode 100644 index 00000000..43762465 --- /dev/null +++ b/tests/momento/local/momento_local_metrics_collector.py @@ -0,0 +1,59 @@ +from collections import defaultdict +from typing import Dict, List + +from tests.momento.local.momento_rpc_method import MomentoRpcMethod + + +class MomentoLocalMetricsCollector: + def __init__(self) -> None: + # Data structure to store timestamps: cacheName -> requestName -> [timestamps] + self.data: Dict[str, Dict[MomentoRpcMethod, List[int]]] = defaultdict(lambda: defaultdict(list)) + + def add_timestamp(self, cache_name: str, request_name: MomentoRpcMethod, timestamp: int) -> None: + """Add a timestamp for a specific request and cache. + + Args: + cache_name: The name of the cache + request_name: The name of the request (using MomentoRpcMethod enum) + timestamp: The timestamp to record in seconds since epoch + """ + self.data[cache_name][request_name].append(timestamp) + + def get_total_retry_count(self, cache_name: str, request_name: MomentoRpcMethod) -> int: + """Calculate the total retry count for a specific cache and request. + + Args: + cache_name: The name of the cache + request_name: The name of the request (using MomentoRpcMethod enum) + + Returns: + The total number of retries + """ + timestamps = self.data.get(cache_name, {}).get(request_name, []) + # Number of retries is one less than the number of timestamps + return max(0, len(timestamps) - 1) + + def get_average_time_between_retries(self, cache_name: str, request_name: MomentoRpcMethod) -> float: + """Calculate the average time between retries for a specific cache and request. + + Args: + cache_name: The name of the cache + request_name: The name of the request (using MomentoRpcMethod enum) + + Returns: + The average time in seconds, or 0.0 if there are no retries + """ + timestamps = self.data.get(cache_name, {}).get(request_name, []) + if len(timestamps) < 2: + return 0.0 # No retries occurred + + total_interval = sum(timestamps[i] - timestamps[i - 1] for i in range(1, len(timestamps))) + return total_interval / (len(timestamps) - 1) + + def get_all_metrics(self) -> Dict[str, Dict[MomentoRpcMethod, List[int]]]: + """Retrieve all collected metrics for debugging or analysis. + + Returns: + The complete data structure with all recorded metrics + """ + return self.data diff --git a/tests/momento/local/momento_local_middleware.py b/tests/momento/local/momento_local_middleware.py new file mode 100644 index 00000000..4b689a0f --- /dev/null +++ b/tests/momento/local/momento_local_middleware.py @@ -0,0 +1,121 @@ +import asyncio +from typing import List, Optional + +from grpc._typing import MetadataType +from momento import logs +from momento.config.middleware import MiddlewareMessage, MiddlewareRequestHandlerContext, MiddlewareStatus +from momento.config.middleware.synchronous import Middleware, MiddlewareMetadata, MiddlewareRequestHandler + +from tests.momento.local.momento_error_code_metadata import MOMENTO_ERROR_CODE_TO_METADATA +from tests.momento.local.momento_local_middleware_args import MomentoLocalMiddlewareArgs +from tests.momento.local.momento_rpc_method import MomentoRpcMethod + + +class MomentoLocalMiddlewareRequestHandler(MiddlewareRequestHandler): + def __init__(self, args: MomentoLocalMiddlewareArgs): + self._args = args + self._cache_name: Optional[str] = None + self._logger = logs.logger + + def on_request_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + grpc_metadata = metadata.grpc_metadata + + if grpc_metadata is not None: + self._set_grpc_metadata(grpc_metadata, "request-id", self._args.request_id) + + if self._args.return_error is not None: + error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.return_error] + if error is not None: + self._set_grpc_metadata(grpc_metadata, "return-error", error) + + if self._args.error_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.error_rpc_list) + self._set_grpc_metadata(grpc_metadata, "error-rpcs", rpcs) + + if self._args.delay_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.delay_rpc_list) + self._set_grpc_metadata(grpc_metadata, "delay-rpcs", rpcs) + + if self._args.error_count is not None: + self._set_grpc_metadata(grpc_metadata, "error-count", str(self._args.error_count)) + + if self._args.delay_millis is not None: + self._set_grpc_metadata(grpc_metadata, "delay-ms", str(self._args.delay_millis)) + + if self._args.delay_count is not None: + self._set_grpc_metadata(grpc_metadata, "delay-count", str(self._args.delay_count)) + + if self._args.stream_error_rpc_list is not None: + rpcs = self._concatenate_rpcs(self._args.stream_error_rpc_list) + self._set_grpc_metadata(grpc_metadata, "stream-error-rpcs", rpcs) + + if self._args.stream_error is not None: + error = MOMENTO_ERROR_CODE_TO_METADATA[self._args.stream_error] + if error is not None: + self._set_grpc_metadata(grpc_metadata, "stream-error", error) + + if self._args.stream_error_message_limit is not None: + limit_str = str(self._args.stream_error_message_limit) + self._set_grpc_metadata(grpc_metadata, "stream-error-message-limit", limit_str) + + cache_name = self._get_from_metadata(grpc_metadata, "cache") + if cache_name is not None: + self._cache_name = cache_name + else: + self._logger.debug("No cache name found in metadata.") + + return metadata + + def on_request_body(self, request: MiddlewareMessage) -> MiddlewareMessage: + request_type = request.constructor_name + + if self._cache_name is not None: + if self._args.test_metrics_collector is not None: + rpc_method = MomentoRpcMethod.from_request_name(request_type) + if rpc_method: + self._args.test_metrics_collector.add_timestamp( + self._cache_name, + rpc_method, + int(asyncio.get_event_loop().time() * 1000), # Current time in milliseconds + ) + else: + self._logger.debug("No cache name available. Timestamp will not be collected.") + + return request + + def on_response_metadata(self, metadata: MiddlewareMetadata) -> MiddlewareMetadata: + return metadata + + def on_response_body(self, response: MiddlewareMessage) -> MiddlewareMessage: + return response + + def on_response_status(self, status: MiddlewareStatus) -> MiddlewareStatus: + return status + + @staticmethod + def _set_grpc_metadata(metadata: MetadataType, key: str, value: str) -> None: + for i, (k, _) in enumerate(metadata): + if k == key: + metadata[i] = (key, value) + break + else: + metadata.append((key, value)) + + @staticmethod + def _get_from_metadata(metadata: MetadataType, key: str) -> Optional[str]: + for k, v in metadata: + if k == key: + return str(v) + return None + + @staticmethod + def _concatenate_rpcs(rpcs: List[MomentoRpcMethod]) -> str: + return " ".join(rpc.metadata for rpc in rpcs) + + +class MomentoLocalMiddleware(Middleware): + def __init__(self, args: MomentoLocalMiddlewareArgs) -> None: + self._args = args + + def on_new_request(self, context: MiddlewareRequestHandlerContext) -> MiddlewareRequestHandler: + return MomentoLocalMiddlewareRequestHandler(self._args) diff --git a/tests/momento/local/momento_local_middleware_args.py b/tests/momento/local/momento_local_middleware_args.py new file mode 100644 index 00000000..23267472 --- /dev/null +++ b/tests/momento/local/momento_local_middleware_args.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import List, Optional + +from momento.errors import MomentoErrorCode + +from tests.momento.local.momento_local_metrics_collector import MomentoLocalMetricsCollector +from tests.momento.local.momento_rpc_method import MomentoRpcMethod + + +@dataclass +class MomentoLocalMiddlewareArgs: + """Arguments for Momento local middleware.""" + + request_id: str + test_metrics_collector: Optional[MomentoLocalMetricsCollector] = None + return_error: Optional[MomentoErrorCode] = None + error_rpc_list: Optional[List[MomentoRpcMethod]] = None + error_count: Optional[int] = None + delay_rpc_list: Optional[List[MomentoRpcMethod]] = None + delay_millis: Optional[int] = None + delay_count: Optional[int] = None + stream_error_rpc_list: Optional[List[MomentoRpcMethod]] = None + stream_error: Optional[MomentoErrorCode] = None + stream_error_message_limit: Optional[int] = None diff --git a/tests/momento/local/momento_rpc_method.py b/tests/momento/local/momento_rpc_method.py new file mode 100644 index 00000000..8d5678e5 --- /dev/null +++ b/tests/momento/local/momento_rpc_method.py @@ -0,0 +1,70 @@ +from enum import Enum +from typing import Optional + + +class MomentoRpcMethod(Enum): + GET = ("_GetRequest", "get") + SET = ("_SetRequest", "set") + DELETE = ("_DeleteRequest", "delete") + INCREMENT = ("_IncrementRequest", "increment") + SET_IF = ("_SetIfRequest", "set-if") + SET_IF_NOT_EXISTS = ("_SetIfNotExistsRequest", "set-if") + GET_BATCH = ("_GetBatchRequest", "get-batch") + SET_BATCH = ("_SetBatchRequest", "set-batch") + KEYS_EXIST = ("_KeysExistRequest", "keys-exist") + UPDATE_TTL = ("_UpdateTtlRequest", "update-ttl") + ITEM_GET_TTL = ("_ItemGetTtlRequest", "item-get-ttl") + ITEM_GET_TYPE = ("_ItemGetTypeRequest", "item-get-type") + DICTIONARY_GET = ("_DictionaryGetRequest", "dictionary-get") + DICTIONARY_FETCH = ("_DictionaryFetchRequest", "dictionary-fetch") + DICTIONARY_SET = ("_DictionarySetRequest", "dictionary-set") + DICTIONARY_INCREMENT = ("_DictionaryIncrementRequest", "dictionary-increment") + DICTIONARY_DELETE = ("_DictionaryDeleteRequest", "dictionary-delete") + DICTIONARY_LENGTH = ("_DictionaryLengthRequest", "dictionary-length") + SET_FETCH = ("_SetFetchRequest", "set-fetch") + SET_SAMPLE = ("_SetSampleRequest", "set-sample") + SET_UNION = ("_SetUnionRequest", "set-union") + SET_DIFFERENCE = ("_SetDifferenceRequest", "set-difference") + SET_CONTAINS = ("_SetContainsRequest", "set-contains") + SET_LENGTH = ("_SetLengthRequest", "set-length") + SET_POP = ("_SetPopRequest", "set-pop") + LIST_PUSH_FRONT = ("_ListPushFrontRequest", "list-push-front") + LIST_PUSH_BACK = ("_ListPushBackRequest", "list-push-back") + LIST_POP_FRONT = ("_ListPopFrontRequest", "list-push-front") + LIST_POP_BACK = ("_ListPopBackRequest", "list-pop-back") + LIST_ERASE = ("_ListEraseRequest", "list-remove") # Alias for list-remove + LIST_REMOVE = ("_ListRemoveRequest", "list-remove") + LIST_FETCH = ("_ListFetchRequest", "list-fetch") + LIST_LENGTH = ("_ListLengthRequest", "list-length") + LIST_CONCATENATE_FRONT = ("_ListConcatenateFrontRequest", "list-concatenate-front") + LIST_CONCATENATE_BACK = ("_ListConcatenateBackRequest", "list-concatenate-back") + LIST_RETAIN = ("_ListRetainRequest", "list-retain") + SORTED_SET_PUT = ("_SortedSetPutRequest", "sorted-set-put") + SORTED_SET_FETCH = ("_SortedSetFetchRequest", "sorted-set-fetch") + SORTED_SET_GET_SCORE = ("_SortedSetGetScoreRequest", "sorted-set-get-score") + SORTED_SET_REMOVE = ("_SortedSetRemoveRequest", "sorted-set-remove") + SORTED_SET_INCREMENT = ("_SortedSetIncrementRequest", "sorted-set-increment") + SORTED_SET_GET_RANK = ("_SortedSetGetRankRequest", "sorted-set-get-rank") + SORTED_SET_LENGTH = ("_SortedSetLengthRequest", "sorted-set-length") + SORTED_SET_LENGTH_BY_SCORE = ("_SortedSetLengthByScoreRequest", "sorted-set-length-by-score") + TOPIC_PUBLISH = ("_PublishRequest", "topic-publish") + TOPIC_SUBSCRIBE = ("_SubscriptionRequest", "topic-subscribe") + + def __init__(self, request_name: str, metadata: str) -> None: + self._request_name = request_name + self._metadata = metadata + + @property + def request_name(self) -> str: + return self._request_name + + @property + def metadata(self) -> str: + return self._metadata + + @classmethod + def from_request_name(cls, request_name: str) -> Optional["MomentoRpcMethod"]: + for method in cls: + if method.request_name == request_name: + return method + return None diff --git a/tests/momento/local/test_fixed_count_retry_strategy.py b/tests/momento/local/test_fixed_count_retry_strategy.py new file mode 100644 index 00000000..3df7c54d --- /dev/null +++ b/tests/momento/local/test_fixed_count_retry_strategy.py @@ -0,0 +1,72 @@ +import pytest +from momento.errors import MomentoErrorCode +from momento.responses import CacheGet, CacheIncrement + +from tests.conftest import client_local +from tests.momento.local.momento_local_async_middleware import MomentoLocalMiddlewareArgs +from tests.momento.local.momento_local_metrics_collector import MomentoLocalMetricsCollector +from tests.momento.local.momento_rpc_method import MomentoRpcMethod +from tests.utils import uuid_str + + +@pytest.mark.local +def test_retry_eligible_api_should_make_max_attempts_when_full_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.GET], + ) + cache_name = uuid_str() + + with client_local(cache_name, middleware_args) as client: + response = client.get(cache_name, "key") + + assert isinstance(response, CacheGet.Error) + assert response.error_code == MomentoErrorCode.SERVER_UNAVAILABLE + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.GET) + assert retry_count == 3 + + +@pytest.mark.local +def test_non_retry_eligible_api_should_make_no_attempts_when_full_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.INCREMENT], + ) + cache_name = uuid_str() + + with client_local(cache_name, middleware_args) as client: + response = client.increment(cache_name, "key", 1) + + assert isinstance(response, CacheIncrement.Error) + assert response.error_code == MomentoErrorCode.SERVER_UNAVAILABLE + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.INCREMENT) + assert retry_count == 0 + + +@pytest.mark.local +def test_retry_eligible_api_should_make_less_than_max_attempts_when_temporary_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.GET], + error_count=2, + ) + cache_name = uuid_str() + + with client_local(cache_name, middleware_args) as client: + response = client.get(cache_name, "key") + + assert isinstance(response, CacheGet.Miss) + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.GET) + assert 2 <= retry_count <= 3 diff --git a/tests/momento/local/test_fixed_count_retry_strategy_async.py b/tests/momento/local/test_fixed_count_retry_strategy_async.py new file mode 100644 index 00000000..463b9877 --- /dev/null +++ b/tests/momento/local/test_fixed_count_retry_strategy_async.py @@ -0,0 +1,75 @@ +import pytest +from momento.errors import MomentoErrorCode +from momento.responses import CacheGet, CacheIncrement + +from tests.conftest import client_async_local +from tests.momento.local.momento_local_async_middleware import MomentoLocalMiddlewareArgs +from tests.momento.local.momento_local_metrics_collector import MomentoLocalMetricsCollector +from tests.momento.local.momento_rpc_method import MomentoRpcMethod +from tests.utils import uuid_str + + +@pytest.mark.asyncio +@pytest.mark.local +async def test_retry_eligible_api_should_make_max_attempts_when_full_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.GET], + ) + cache_name = uuid_str() + + async with client_async_local(cache_name, middleware_args) as client: + response = await client.get(cache_name, "key") + + assert isinstance(response, CacheGet.Error) + assert response.error_code == MomentoErrorCode.SERVER_UNAVAILABLE + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.GET) + assert retry_count == 3 + + +@pytest.mark.asyncio +@pytest.mark.local +async def test_non_retry_eligible_api_should_make_no_attempts_when_full_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.INCREMENT], + ) + cache_name = uuid_str() + + async with client_async_local(cache_name, middleware_args) as client: + response = await client.increment(cache_name, "key", 1) + + assert isinstance(response, CacheIncrement.Error) + assert response.error_code == MomentoErrorCode.SERVER_UNAVAILABLE + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.INCREMENT) + assert retry_count == 0 + + +@pytest.mark.asyncio +@pytest.mark.local +async def test_retry_eligible_api_should_make_less_than_max_attempts_when_temporary_network_outage() -> None: + metrics_collector = MomentoLocalMetricsCollector() + middleware_args = MomentoLocalMiddlewareArgs( + request_id=str(uuid_str()), + test_metrics_collector=metrics_collector, + return_error=MomentoErrorCode.SERVER_UNAVAILABLE, + error_rpc_list=[MomentoRpcMethod.GET], + error_count=2, + ) + cache_name = uuid_str() + + async with client_async_local(cache_name, middleware_args) as client: + response = await client.get(cache_name, "key") + + assert isinstance(response, CacheGet.Miss) + + retry_count = metrics_collector.get_total_retry_count(cache_name, MomentoRpcMethod.GET) + assert 2 <= retry_count <= 3