Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions py_hamt/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ class KuboCAS(ContentAddressedStore):
the internally-created session.
- **rpc_base_url / gateway_base_url** (str | None): override daemon
endpoints (defaults match the local daemon ports).
- **gateway_base_urls** (list[str] | None): optional list of additional
gateway URLs to try in parallel when loading blocks. Each base URL is
normalized to end with ``/ipfs/``.

...
"""
Expand All @@ -137,6 +140,7 @@ def __init__(
session: aiohttp.ClientSession | None = None,
rpc_base_url: str | None = KUBO_DEFAULT_LOCAL_RPC_BASE_URL,
gateway_base_url: str | None = KUBO_DEFAULT_LOCAL_GATEWAY_BASE_URL,
gateway_base_urls: list[str] | None = None,
concurrency: int = 32,
*,
headers: dict[str, str] | None = None,
Expand Down Expand Up @@ -188,7 +192,14 @@ def __init__(

self.rpc_url: str = f"{rpc_base_url}/api/v0/add?hash={self.hasher}&pin=false"
"""@private"""
self.gateway_base_url: str = f"{gateway_base_url}/ipfs/"

def _normalize(url: str) -> str:
    """Return *url* as a gateway base that ends in exactly one ``/ipfs/``.

    Trailing slashes are stripped first. If the remaining URL already ends
    with ``/ipfs`` the suffix is reused rather than appended a second time,
    so normalizing an already-normalized URL is a no-op.
    """
    base = url.rstrip("/")
    if not base.endswith("/ipfs"):
        base += "/ipfs"
    return base + "/"

bases = gateway_base_urls if gateway_base_urls else [gateway_base_url]
self.gateway_base_urls = [_normalize(u) for u in bases]
self.gateway_base_url = self.gateway_base_urls[0]
"""@private"""

self._session_per_loop: dict[
Expand Down Expand Up @@ -262,8 +273,31 @@ async def save(self, data: bytes, codec: ContentAddressedStore.CodecInput) -> CI
async def load(self, id: IPLDKind) -> bytes:
    """@private

    Fetch the block addressed by ``id`` from the configured gateway(s).

    With a single gateway the request is awaited directly and any error it
    raises propagates unchanged. With several gateways one request per
    gateway is started concurrently; the first successful body is returned
    and the remaining requests are cancelled.

    Raises:
        RuntimeError: every request in a multi-gateway race failed. The
            most recent underlying failure is attached as ``__cause__``.
    """
    cid = cast(CID, id)  # CID is definitely in the IPLDKind type

    async def _fetch(base: str) -> bytes:
        url: str = base + str(cid)
        async with self._sem:  # throttle gateway
            async with self._loop_session().get(url) as resp:
                resp.raise_for_status()
                return await resp.read()

    if len(self.gateway_base_urls) == 1:
        return await _fetch(self.gateway_base_urls[0])

    tasks = [asyncio.create_task(_fetch(base)) for base in self.gateway_base_urls]
    last_error: Exception | None = None
    try:
        for next_done in asyncio.as_completed(tasks):
            try:
                return await next_done
            except Exception as exc:  # this gateway failed; keep racing
                last_error = exc
    finally:
        # cancel() is a no-op on tasks that have already finished.
        for task in tasks:
            task.cancel()
        # Wait for the losers to unwind and consume their outcomes, so no
        # request is still running (and no task exception is left
        # unretrieved) once this coroutine exits.
        await asyncio.gather(*tasks, return_exceptions=True)
    raise RuntimeError("All gateway requests failed") from last_error
69 changes: 68 additions & 1 deletion tests/test_kubo_cas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Literal, cast
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable, Literal, cast

import aiohttp
import dag_cbor
import pytest
from aiohttp import web
from dag_cbor import IPLDKind
from hypothesis import given, settings
from multiformats import CID
from testing_utils import ipld_strategy # noqa

from py_hamt import KuboCAS
Expand Down Expand Up @@ -144,3 +148,66 @@ async def test_kubo_cas(create_ipfs, data: IPLDKind): # noqa
cid = await kubo_cas.save(dag_cbor.encode(data), codec=codec_typed)
result = dag_cbor.decode(await kubo_cas.load(cid))
assert data == result


@pytest.mark.ipfs
@pytest.mark.asyncio(loop_scope="session")
async def test_kubo_multi_gateway(create_ipfs, global_client_session):
    """Round-trip a raw block through a KuboCAS given a two-entry gateway list.

    Both entries are the same gateway URL taken from the ``create_ipfs``
    fixture.
    """
    rpc_url, gateway_url = create_ipfs
    payload = b"hello"

    store = KuboCAS(
        rpc_base_url=rpc_url,
        gateway_base_url=gateway_url,
        gateway_base_urls=[gateway_url] * 2,
        session=global_client_session,
    )
    async with store as kubo_cas:
        cid = await kubo_cas.save(payload, codec="raw")
        fetched = await kubo_cas.load(cid)
        assert fetched == payload


@asynccontextmanager
async def _run_server(
    handler: Callable[[web.Request], Awaitable[web.StreamResponse]],
) -> AsyncIterator[str]:
    """Serve *handler* for every GET path on an OS-assigned localhost port.

    Yields the server's base URL (``http://127.0.0.1:<port>``, without a
    trailing slash) and cleans the runner up when the context exits, even
    if the listening site fails to start.
    """
    app = web.Application()
    app.router.add_get("/{tail:.*}", handler)
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()
        # Port 0 lets the OS choose a free port; read the bound port back
        # through the runner's public ``addresses`` list instead of the
        # site's private ``_server`` attribute.
        port = runner.addresses[0][1]
        yield f"http://127.0.0.1:{port}"
    finally:
        await runner.cleanup()


@pytest.mark.asyncio
async def test_gateway_race_has_fallback():
    """``load`` returns the healthy gateway's body when the other gateway errors."""

    async def always_500(request: web.Request) -> web.Response:
        raise web.HTTPInternalServerError()

    async def delayed_ok(request: web.Request) -> web.Response:
        # Reply after 50 ms -- presumably so the failing gateway's response
        # is seen first and the fallback path is what gets exercised.
        await asyncio.sleep(0.05)
        return web.Response(body=b"ok")

    cid = CID.decode("bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku")
    async with _run_server(always_500) as bad:
        async with _run_server(delayed_ok) as good:
            store = KuboCAS(gateway_base_url=good, gateway_base_urls=[bad, good])
            async with store as cas:
                body = await cas.load(cid)
                assert body == b"ok"


@pytest.mark.asyncio
async def test_gateway_race_all_fail():
    """``load`` raises ``RuntimeError`` when every gateway responds with an error."""

    async def always_500(request: web.Request) -> web.Response:
        raise web.HTTPInternalServerError()

    cid = CID.decode("bafkreihdwdcefgh4dqkjv67uzcmw7ojee6xedzdetojuzjevtenxquvyku")
    async with _run_server(always_500) as bad1:
        async with _run_server(always_500) as bad2:
            store = KuboCAS(gateway_base_url=bad1, gateway_base_urls=[bad1, bad2])
            async with store as cas:
                with pytest.raises(RuntimeError, match="All gateway requests failed"):
                    await cas.load(cid)
Loading