Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 85 additions & 21 deletions py_hamt/store_httpx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import random
import re
from abc import ABC, abstractmethod
from typing import Any, Literal, Tuple, cast
Expand Down Expand Up @@ -146,6 +147,9 @@ def __init__(
headers: dict[str, str] | None = None,
auth: Tuple[str, str] | None = None,
chunker: str = "size-1048576",
max_retries: int = 3,
initial_delay: float = 1.0,
backoff_factor: float = 2.0,
):
"""
If None is passed into the rpc or gateway base url, then the default for kubo local daemons will be used. The default local values will also be used if nothing is passed in at all.
Expand Down Expand Up @@ -219,14 +223,24 @@ def __init__(
self._owns_client = True
self._client_per_loop = {}

# The instance is never closed on initialization.
self._closed = False

# store for later use by _loop_client()
self._default_headers = headers
self._default_auth = auth

self._sem: asyncio.Semaphore = asyncio.Semaphore(concurrency)
self._closed: bool = False

# Validate retry parameters
if max_retries < 0:
raise ValueError("max_retries must be non-negative")
if initial_delay <= 0:
raise ValueError("initial_delay must be positive")
if backoff_factor < 1.0:
raise ValueError("backoff_factor must be >= 1.0 for exponential backoff")

self.max_retries = max_retries
self.initial_delay = initial_delay
self.backoff_factor = backoff_factor

# --------------------------------------------------------------------- #
# helper: get or create the client bound to the current running loop #
Expand Down Expand Up @@ -338,28 +352,78 @@ def __del__(self) -> None:
# save() – now uses the per-loop client #
# --------------------------------------------------------------------- #
async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CID:
async with self._sem: # throttle RPC
# Create multipart form data
async with self._sem:
files = {"file": data}

# Send the POST request
client = self._loop_client()
response = await client.post(self.rpc_url, files=files)
response.raise_for_status()
cid_str: str = response.json()["Hash"]
retry_count = 0

cid: CID = CID.decode(cid_str)
if cid.codec.code != self.DAG_PB_MARKER:
cid = cid.set(codec=codec)
return cid
while retry_count <= self.max_retries:
try:
response = await client.post(
self.rpc_url, files=files, timeout=60.0
)
response.raise_for_status()
cid_str: str = response.json()["Hash"]
cid: CID = CID.decode(cid_str)
if cid.codec.code != self.DAG_PB_MARKER:
cid = cid.set(codec=codec)
return cid

except (httpx.TimeoutException, httpx.RequestError) as e:
retry_count += 1
if retry_count > self.max_retries:
raise httpx.TimeoutException(
f"Failed to save data after {self.max_retries} retries: {str(e)}",
request=e.request
if isinstance(e, httpx.RequestError)
else None,
)

# Calculate backoff delay
delay = self.initial_delay * (
self.backoff_factor ** (retry_count - 1)
)
# Add some jitter to prevent thundering herd
jitter = delay * 0.1 * (random.random() - 0.5)
await asyncio.sleep(delay + jitter)

except httpx.HTTPStatusError:
# Re-raise non-timeout HTTP errors immediately
raise
raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover

async def load(self, id: IPLDKind) -> bytes:
"""@private"""
cid = cast(CID, id) # CID is definitely in the IPLDKind type
cid = cast(CID, id)
url: str = f"{self.gateway_base_url + str(cid)}"

async with self._sem: # throttle gateway
async with self._sem:
client = self._loop_client()
response = await client.get(url)
response.raise_for_status()
return response.content
retry_count = 0

while retry_count <= self.max_retries:
try:
response = await client.get(url, timeout=60.0)
response.raise_for_status()
return response.content

except (httpx.TimeoutException, httpx.RequestError) as e:
retry_count += 1
if retry_count > self.max_retries:
raise httpx.TimeoutException(
f"Failed to load data after {self.max_retries} retries: {str(e)}",
request=e.request
if isinstance(e, httpx.RequestError)
else None,
)

# Calculate backoff delay
delay = self.initial_delay * (
self.backoff_factor ** (retry_count - 1)
)
# Add some jitter to prevent thundering herd
jitter = delay * 0.1 * (random.random() - 0.5)
await asyncio.sleep(delay + jitter)

except httpx.HTTPStatusError:
# Re-raise non-timeout HTTP errors immediately
raise
raise RuntimeError("Exited the retry loop unexpectedly.") # pragma: no cover
207 changes: 207 additions & 0 deletions tests/test_kubo_cas.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
from typing import Literal, cast
from unittest.mock import AsyncMock, patch

import dag_cbor
import httpx
Expand Down Expand Up @@ -164,3 +166,208 @@ async def test_chunker_invalid_patterns(invalid):
with pytest.raises(ValueError, match="Invalid chunker specification"):
async with KuboCAS(chunker=invalid):
pass


@pytest.mark.asyncio
async def test_kubo_timeout_retries():
"""
Test that KuboCAS handles timeouts with retries and exponential backoff
for both save and load operations using unittest.mock.
"""
timeout_count = 0
successful_after = 2 # Succeed after 2 timeout attempts
test_cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"

async def mock_post(url, **kwargs):
nonlocal timeout_count
# Manually create a dummy request object
dummy_request = httpx.Request("POST", url, files=kwargs.get("files"))
if timeout_count < successful_after:
timeout_count += 1
raise httpx.TimeoutException("Simulated timeout", request=dummy_request)
return httpx.Response(200, json={"Hash": test_cid}, request=dummy_request)

async def mock_get(url, **kwargs):
nonlocal timeout_count
# Manually create a dummy request object
dummy_request = httpx.Request("GET", url)
if timeout_count < successful_after:
timeout_count += 1
raise httpx.TimeoutException("Simulated timeout", request=dummy_request)
return httpx.Response(200, content=test_data, request=dummy_request)

# Patch the httpx.AsyncClient methods
with patch.object(httpx.AsyncClient, "post", new=AsyncMock(side_effect=mock_post)):
with patch.object(
httpx.AsyncClient, "get", new=AsyncMock(side_effect=mock_get)
):
async with httpx.AsyncClient() as client:
async with KuboCAS(
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url="http://127.0.0.1:8080",
client=client,
max_retries=3,
initial_delay=0.1,
backoff_factor=2.0,
) as kubo_cas:
# Test save with retries
timeout_count = 0
test_data = dag_cbor.encode("test")
cid = await kubo_cas.save(test_data, codec="dag-cbor")
assert timeout_count == successful_after, (
"Should have retried twice before success"
)
assert str(cid) == test_cid

# Test load with retries
timeout_count = 0
result = await kubo_cas.load(cid)
assert timeout_count == successful_after, (
"Should have retried twice before success"
)
assert result == test_data

# Test failure after max retries
async def failing_method(url, **kwargs):
dummy_request = httpx.Request(
"POST", url
) # Create the dummy request
raise httpx.TimeoutException(
"Simulated timeout", request=dummy_request
)

with patch.object(
httpx.AsyncClient,
"post",
new=AsyncMock(side_effect=failing_method),
):
with patch.object(
httpx.AsyncClient,
"get",
new=AsyncMock(side_effect=failing_method),
):
with pytest.raises(
httpx.TimeoutException,
match="Failed to save data after 3 retries",
):
await kubo_cas.save(test_data, codec="dag-cbor")

with pytest.raises(
httpx.TimeoutException,
match="Failed to load data after 3 retries",
):
await kubo_cas.load(cid)


@pytest.mark.asyncio
async def test_kubo_backoff_timing():
"""
Test that KuboCAS implements exponential backoff with jitter correctly.
"""

async def timeout_method(url, **kwargs):
# Manually create a dummy request for the exception
dummy_request = httpx.Request("POST", url)
raise httpx.TimeoutException("Simulated timeout", request=dummy_request)

# Patch sleep to record timing
original_sleep = asyncio.sleep
sleep_times = []

async def mock_sleep(delay):
sleep_times.append(delay)
# Call the original sleep function to avoid recursion
await original_sleep(0)

with patch.object(
httpx.AsyncClient, "post", new=AsyncMock(side_effect=timeout_method)
):
async with httpx.AsyncClient() as client:
async with KuboCAS(
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url="http://127.0.0.1:8080",
client=client,
max_retries=3,
initial_delay=0.1,
backoff_factor=2.0,
) as kubo_cas:
with patch("asyncio.sleep", side_effect=mock_sleep):
with pytest.raises(httpx.TimeoutException):
await kubo_cas.save(b"test", codec="dag-cbor")

# Verify backoff timing
assert len(sleep_times) == 3, "Should have attempted 3 retries"
assert 0.09 <= sleep_times[0] <= 0.11, "First retry should be ~0.1s"
assert 0.18 <= sleep_times[1] <= 0.22, (
"Second retry should be ~0.2s"
)
assert 0.36 <= sleep_times[2] <= 0.44, "Third retry should be ~0.4s"


@pytest.mark.asyncio
async def test_kubo_http_status_error_no_retry():
"""
Tests that KuboCAS immediately raises HTTPStatusError without retrying.
"""

# This mock simulates a server error by returning a 500 status code.
async def mock_post_server_error(url, **kwargs):
dummy_request = httpx.Request("POST", url)
return httpx.Response(
500, request=dummy_request, content=b"Internal Server Error"
)

# Patch the client's post method to always return the 500 error.
with patch.object(
httpx.AsyncClient, "post", new=AsyncMock(side_effect=mock_post_server_error)
):
# Also patch asyncio.sleep to verify it's not called (i.e., no retries).
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
async with httpx.AsyncClient() as client:
async with KuboCAS(client=client) as kubo_cas:
# Assert that the specific error is raised.
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await kubo_cas.save(b"some data", codec="raw")

# Verify that the response in the exception has the correct status code.
assert exc_info.value.response.status_code == 500
# Verify that no retry was attempted.
mock_sleep.assert_not_called()


@pytest.mark.asyncio
async def test_kubo_cas_retry_validation():
"""Test validation of retry parameters in KuboCAS constructor"""

# Test max_retries validation
with pytest.raises(ValueError, match="max_retries must be non-negative"):
KuboCAS(max_retries=-1)

with pytest.raises(ValueError, match="max_retries must be non-negative"):
KuboCAS(max_retries=-5)

# Test initial_delay validation
with pytest.raises(ValueError, match="initial_delay must be positive"):
KuboCAS(initial_delay=0)

with pytest.raises(ValueError, match="initial_delay must be positive"):
KuboCAS(initial_delay=-1.0)

# Test backoff_factor validation
with pytest.raises(
ValueError, match="backoff_factor must be >= 1.0 for exponential backoff"
):
KuboCAS(backoff_factor=0.5)

with pytest.raises(
ValueError, match="backoff_factor must be >= 1.0 for exponential backoff"
):
KuboCAS(backoff_factor=0.9)

# Test valid edge case values
async with KuboCAS(
max_retries=0, initial_delay=0.001, backoff_factor=1.0
) as kubo_cas:
assert kubo_cas.max_retries == 0
assert kubo_cas.initial_delay == 0.001
assert kubo_cas.backoff_factor == 1.0
12 changes: 10 additions & 2 deletions tests/test_public_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ async def test_kubocas_public_gateway():
cas_save = KuboCAS(
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url="http://127.0.0.1:8080",
max_retries=0,
)

try:
Expand Down Expand Up @@ -218,6 +219,7 @@ async def test_trailing_slash_gateway():
cas = KuboCAS(
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url="http://127.0.0.1:8080/", # Note the trailing slash
max_retries=0,
)

try:
Expand Down Expand Up @@ -278,15 +280,21 @@ async def test_fix_kubocas_load():
]

for input_url, expected_base in test_cases:
cas = KuboCAS(rpc_base_url="http://127.0.0.1:5001", gateway_base_url=input_url)
cas = KuboCAS(
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url=input_url,
max_retries=0,
)
assert cas.gateway_base_url == expected_base, (
f"URL construction failed for {input_url}"
)
await cas.aclose()

# Test actual loading with local gateway
cas = KuboCAS(
rpc_base_url="http://127.0.0.1:5001", gateway_base_url="http://127.0.0.1:8080"
rpc_base_url="http://127.0.0.1:5001",
gateway_base_url="http://127.0.0.1:8080",
max_retries=0,
)

try:
Expand Down