diff --git a/py_hamt/store_httpx.py b/py_hamt/store_httpx.py index a4bb45d..ecbe4d0 100644 --- a/py_hamt/store_httpx.py +++ b/py_hamt/store_httpx.py @@ -1,4 +1,5 @@ import asyncio +import random import re from abc import ABC, abstractmethod from typing import Any, Literal, Tuple, cast @@ -146,6 +147,9 @@ def __init__( headers: dict[str, str] | None = None, auth: Tuple[str, str] | None = None, chunker: str = "size-1048576", + max_retries: int = 3, + initial_delay: float = 1.0, + backoff_factor: float = 2.0, ): """ If None is passed into the rpc or gateway base url, then the default for kubo local daemons will be used. The default local values will also be used if nothing is passed in at all. @@ -219,14 +223,24 @@ def __init__( self._owns_client = True self._client_per_loop = {} - # The instance is never closed on initialization. - self._closed = False - # store for later use by _loop_client() self._default_headers = headers self._default_auth = auth self._sem: asyncio.Semaphore = asyncio.Semaphore(concurrency) + self._closed: bool = False + + # Validate retry parameters + if max_retries < 0: + raise ValueError("max_retries must be non-negative") + if initial_delay <= 0: + raise ValueError("initial_delay must be positive") + if backoff_factor < 1.0: + raise ValueError("backoff_factor must be >= 1.0 for exponential backoff") + + self.max_retries = max_retries + self.initial_delay = initial_delay + self.backoff_factor = backoff_factor # --------------------------------------------------------------------- # # helper: get or create the client bound to the current running loop # @@ -338,28 +352,78 @@ def __del__(self) -> None: # save() – now uses the per-loop client # # --------------------------------------------------------------------- # async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CID: - async with self._sem: # throttle RPC - # Create multipart form data + async with self._sem: files = {"file": data} - - # Send the POST request client = self._loop_client() - response = await client.post(self.rpc_url, files=files) - response.raise_for_status() - cid_str: str = response.json()["Hash"] + retry_count = 0 - cid: CID = CID.decode(cid_str) - if cid.codec.code != self.DAG_PB_MARKER: - cid = cid.set(codec=codec) - return cid + while retry_count <= self.max_retries: + try: + response = await client.post( + self.rpc_url, files=files, timeout=60.0 + ) + response.raise_for_status() + cid_str: str = response.json()["Hash"] + cid: CID = CID.decode(cid_str) + if cid.codec.code != self.DAG_PB_MARKER: + cid = cid.set(codec=codec) + return cid + + except (httpx.TimeoutException, httpx.RequestError) as e: + retry_count += 1 + if retry_count > self.max_retries: + raise httpx.TimeoutException( + f"Failed to save data after {self.max_retries} retries: {str(e)}", + request=e.request + if isinstance(e, httpx.RequestError) + else None, + ) + + # Calculate backoff delay + delay = self.initial_delay * ( + self.backoff_factor ** (retry_count - 1) + ) + # Add some jitter to prevent thundering herd + jitter = delay * 0.1 * (random.random() - 0.5) + await asyncio.sleep(delay + jitter) + + except httpx.HTTPStatusError: + # Re-raise non-timeout HTTP errors immediately + raise + raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover async def load(self, id: IPLDKind) -> bytes: - """@private""" - cid = cast(CID, id) # CID is definitely in the IPLDKind type + cid = cast(CID, id) url: str = f"{self.gateway_base_url + str(cid)}" - - async with self._sem: # throttle gateway + async with self._sem: client = self._loop_client() - response = await client.get(url) - response.raise_for_status() - return response.content + retry_count = 0 + + while retry_count <= self.max_retries: + try: + response = await client.get(url, timeout=60.0) + response.raise_for_status() + return response.content + + except (httpx.TimeoutException, httpx.RequestError) as e: + retry_count += 1 + if retry_count > self.max_retries: + raise httpx.TimeoutException( + f"Failed to load data after {self.max_retries} retries: {str(e)}", + request=e.request + if isinstance(e, httpx.RequestError) + else None, + ) + + # Calculate backoff delay + delay = self.initial_delay * ( + self.backoff_factor ** (retry_count - 1) + ) + # Add some jitter to prevent thundering herd + jitter = delay * 0.1 * (random.random() - 0.5) + await asyncio.sleep(delay + jitter) + + except httpx.HTTPStatusError: + # Re-raise non-timeout HTTP errors immediately + raise + raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover diff --git a/tests/test_kubo_cas.py b/tests/test_kubo_cas.py index db138b1..a508313 100644 --- a/tests/test_kubo_cas.py +++ b/tests/test_kubo_cas.py @@ -1,4 +1,6 @@ +import asyncio from typing import Literal, cast +from unittest.mock import AsyncMock, patch import dag_cbor import httpx @@ -164,3 +166,208 @@ async def test_chunker_invalid_patterns(invalid): with pytest.raises(ValueError, match="Invalid chunker specification"): async with KuboCAS(chunker=invalid): pass + + +@pytest.mark.asyncio +async def test_kubo_timeout_retries(): + """ + Test that KuboCAS handles timeouts with retries and exponential backoff + for both save and load operations using unittest.mock. + """ + timeout_count = 0 + successful_after = 2 # Succeed after 2 timeout attempts + test_cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi" + + async def mock_post(url, **kwargs): + nonlocal timeout_count + # Manually create a dummy request object + dummy_request = httpx.Request("POST", url, files=kwargs.get("files")) + if timeout_count < successful_after: + timeout_count += 1 + raise httpx.TimeoutException("Simulated timeout", request=dummy_request) + return httpx.Response(200, json={"Hash": test_cid}, request=dummy_request) + + async def mock_get(url, **kwargs): + nonlocal timeout_count + # Manually create a dummy request object + dummy_request = httpx.Request("GET", url) + if timeout_count < successful_after: + timeout_count += 1 + raise httpx.TimeoutException("Simulated timeout", request=dummy_request) + return httpx.Response(200, content=test_data, request=dummy_request) + + # Patch the httpx.AsyncClient methods + with patch.object(httpx.AsyncClient, "post", new=AsyncMock(side_effect=mock_post)): + with patch.object( + httpx.AsyncClient, "get", new=AsyncMock(side_effect=mock_get) + ): + async with httpx.AsyncClient() as client: + async with KuboCAS( + rpc_base_url="http://127.0.0.1:5001", + gateway_base_url="http://127.0.0.1:8080", + client=client, + max_retries=3, + initial_delay=0.1, + backoff_factor=2.0, + ) as kubo_cas: + # Test save with retries + timeout_count = 0 + test_data = dag_cbor.encode("test") + cid = await kubo_cas.save(test_data, codec="dag-cbor") + assert timeout_count == successful_after, ( + "Should have retried twice before success" + ) + assert str(cid) == test_cid + + # Test load with retries + timeout_count = 0 + result = await kubo_cas.load(cid) + assert timeout_count == successful_after, ( + "Should have retried twice before success" + ) + assert result == test_data + + # Test failure after max retries + async def failing_method(url, **kwargs): + dummy_request = httpx.Request( + "POST", url + ) # Create the dummy request + raise httpx.TimeoutException( + "Simulated timeout", request=dummy_request + ) + + with patch.object( + httpx.AsyncClient, + "post", + new=AsyncMock(side_effect=failing_method), + ): + with patch.object( + httpx.AsyncClient, + "get", + new=AsyncMock(side_effect=failing_method), + ): + with pytest.raises( + httpx.TimeoutException, + match="Failed to save data after 3 retries", + ): + await kubo_cas.save(test_data, codec="dag-cbor") + + with pytest.raises( + httpx.TimeoutException, + match="Failed to load data after 3 retries", + ): + await kubo_cas.load(cid) + + +@pytest.mark.asyncio +async def test_kubo_backoff_timing(): + """ + Test that KuboCAS implements exponential backoff with jitter correctly. + """ + + async def timeout_method(url, **kwargs): + # Manually create a dummy request for the exception + dummy_request = httpx.Request("POST", url) + raise httpx.TimeoutException("Simulated timeout", request=dummy_request) + + # Patch sleep to record timing + original_sleep = asyncio.sleep + sleep_times = [] + + async def mock_sleep(delay): + sleep_times.append(delay) + # Call the original sleep function to avoid recursion + await original_sleep(0) + + with patch.object( + httpx.AsyncClient, "post", new=AsyncMock(side_effect=timeout_method) + ): + async with httpx.AsyncClient() as client: + async with KuboCAS( + rpc_base_url="http://127.0.0.1:5001", + gateway_base_url="http://127.0.0.1:8080", + client=client, + max_retries=3, + initial_delay=0.1, + backoff_factor=2.0, + ) as kubo_cas: + with patch("asyncio.sleep", side_effect=mock_sleep): + with pytest.raises(httpx.TimeoutException): + await kubo_cas.save(b"test", codec="dag-cbor") + + # Verify backoff timing + assert len(sleep_times) == 3, "Should have attempted 3 retries" + assert 0.09 <= sleep_times[0] <= 0.11, "First retry should be ~0.1s" + assert 0.18 <= sleep_times[1] <= 0.22, ( + "Second retry should be ~0.2s" + ) + assert 0.36 <= sleep_times[2] <= 0.44, "Third retry should be ~0.4s" + + +@pytest.mark.asyncio +async def test_kubo_http_status_error_no_retry(): + """ + Tests that KuboCAS immediately raises HTTPStatusError without retrying. + """ + + # This mock simulates a server error by returning a 500 status code. + async def mock_post_server_error(url, **kwargs): + dummy_request = httpx.Request("POST", url) + return httpx.Response( + 500, request=dummy_request, content=b"Internal Server Error" + ) + + # Patch the client's post method to always return the 500 error. + with patch.object( + httpx.AsyncClient, "post", new=AsyncMock(side_effect=mock_post_server_error) + ): + # Also patch asyncio.sleep to verify it's not called (i.e., no retries). + with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep: + async with httpx.AsyncClient() as client: + async with KuboCAS(client=client) as kubo_cas: + # Assert that the specific error is raised. + with pytest.raises(httpx.HTTPStatusError) as exc_info: + await kubo_cas.save(b"some data", codec="raw") + + # Verify that the response in the exception has the correct status code. + assert exc_info.value.response.status_code == 500 + # Verify that no retry was attempted. + mock_sleep.assert_not_called() + + +@pytest.mark.asyncio +async def test_kubo_cas_retry_validation(): + """Test validation of retry parameters in KuboCAS constructor""" + + # Test max_retries validation + with pytest.raises(ValueError, match="max_retries must be non-negative"): + KuboCAS(max_retries=-1) + + with pytest.raises(ValueError, match="max_retries must be non-negative"): + KuboCAS(max_retries=-5) + + # Test initial_delay validation + with pytest.raises(ValueError, match="initial_delay must be positive"): + KuboCAS(initial_delay=0) + + with pytest.raises(ValueError, match="initial_delay must be positive"): + KuboCAS(initial_delay=-1.0) + + # Test backoff_factor validation + with pytest.raises( + ValueError, match="backoff_factor must be >= 1.0 for exponential backoff" + ): + KuboCAS(backoff_factor=0.5) + + with pytest.raises( + ValueError, match="backoff_factor must be >= 1.0 for exponential backoff" + ): + KuboCAS(backoff_factor=0.9) + + # Test valid edge case values + async with KuboCAS( + max_retries=0, initial_delay=0.001, backoff_factor=1.0 + ) as kubo_cas: + assert kubo_cas.max_retries == 0 + assert kubo_cas.initial_delay == 0.001 + assert kubo_cas.backoff_factor == 1.0 diff --git a/tests/test_public_gateway.py b/tests/test_public_gateway.py index 0cbc4c8..6a988cc 100644 --- a/tests/test_public_gateway.py +++ b/tests/test_public_gateway.py @@ -152,6 +152,7 @@ async def test_kubocas_public_gateway(): cas_save = KuboCAS( rpc_base_url="http://127.0.0.1:5001", gateway_base_url="http://127.0.0.1:8080", + max_retries=0, ) try: @@ -218,6 +219,7 @@ async def test_trailing_slash_gateway(): cas = KuboCAS( rpc_base_url="http://127.0.0.1:5001", gateway_base_url="http://127.0.0.1:8080/", # Note the trailing slash + max_retries=0, ) try: @@ -278,7 +280,11 @@ async def test_fix_kubocas_load(): ] for input_url, expected_base in test_cases: - cas = KuboCAS(rpc_base_url="http://127.0.0.1:5001", gateway_base_url=input_url) + cas = KuboCAS( + rpc_base_url="http://127.0.0.1:5001", + gateway_base_url=input_url, + max_retries=0, + ) assert cas.gateway_base_url == expected_base, ( f"URL construction failed for {input_url}" ) @@ -286,7 +292,9 @@ async def test_fix_kubocas_load(): # Test actual loading with local gateway cas = KuboCAS( - rpc_base_url="http://127.0.0.1:5001", gateway_base_url="http://127.0.0.1:8080" + rpc_base_url="http://127.0.0.1:5001", + gateway_base_url="http://127.0.0.1:8080", + max_retries=0, ) try: