feat(infra): add client-side fetch buffer for RTensor #1122
Conversation
Add a per-process cache (`_fetch_buffer`) keyed by `shard_id` so that repeated `to_local()` / `localize()` calls for the same rollout batch avoid redundant network round-trips. Entries are evicted by `clear_node()` at the end of each train step.

Key changes:
- Cache check in `to_local()` before backend fetch
- Batch buffer resolution in `localize()` (fetch only misses)
- `clear_node()` evicts buffer entries before deleting remote shards
- Add `buffer_stats()` for operational monitoring
- Add `strict=True` to `zip` in `localize()` for safety
- Add `TestFetchBuffer` integration test suite (8 tests)
Code Review
This pull request introduces a client-side fetch buffer for RTensor to cache fetched tensors by shard_id, reducing redundant network transfers. It includes methods for clearing the buffer, retrieving statistics, and evicting entries during node cleanup, along with a comprehensive test suite. Feedback was provided to de-duplicate shard_ids within the localize method to avoid redundant batch-fetching of the same shard.
```python
to_fetch: list[RTensor] = []
with _fetch_buffer_lock:
    for rt in meta_rtensors:
        cached = _fetch_buffer.get(rt.shard.shard_id)
        if cached is not None:
            rt.data = cached
        else:
            to_fetch.append(rt)

# Batch-fetch only the misses from the backend.
if to_fetch:
    shards = [rt.shard for rt in to_fetch]
    results = get_backend().fetch(shards)
    with _fetch_buffer_lock:
        for rt, tensor in zip(to_fetch, results, strict=True):
            rt.data = tensor
            _fetch_buffer[rt.shard.shard_id] = tensor
```
The current implementation of localize does not de-duplicate shard_ids before fetching from the backend. If the input structure contains multiple RTensor instances sharing the same shard_id (which is common in rollout batches), they will all be added to to_fetch and subsequently fetched multiple times over the network. De-duplicating by shard_id within the call significantly reduces redundant network traffic and server load.
Suggested change:

```python
# Resolve as many as possible from the client-side fetch buffer.
to_fetch_map: dict[Any, list[RTensor]] = {}
with _fetch_buffer_lock:
    for rt in meta_rtensors:
        sid = rt.shard.shard_id
        cached = _fetch_buffer.get(sid)
        if cached is not None:
            rt.data = cached
        else:
            to_fetch_map.setdefault(sid, []).append(rt)

# Batch-fetch only the unique misses from the backend.
if to_fetch_map:
    unique_shards = [rts[0].shard for rts in to_fetch_map.values()]
    results = get_backend().fetch(unique_shards)
    with _fetch_buffer_lock:
        for shard, tensor in zip(unique_shards, results, strict=True):
            _fetch_buffer[shard.shard_id] = tensor
            for rt in to_fetch_map[shard.shard_id]:
                rt.data = tensor
```
garrett4wade left a comment:
LGTM except for a minor issue.
```python
def clear_fetch_buffer() -> None:
    """Remove all entries from the client-side fetch buffer."""
    with _fetch_buffer_lock:
        _fetch_buffer.clear()


def buffer_stats() -> dict[str, int]:
    """Get current fetch buffer statistics."""
    with _fetch_buffer_lock:
        return dict(
            num_tensors=len(_fetch_buffer),
            total_bytes=sum(t.nbytes for t in _fetch_buffer.values()),
        )
```
These functions are not used elsewhere except for tests.
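One intended non-test use is operational logging, e.g. emitting buffer statistics at the end of each train step. A minimal sketch: `buffer_stats()` mirrors the PR's helper, while `FakeTensor` is a hypothetical stand-in exposing only the `.nbytes` attribute the helper relies on:

```python
import threading

_fetch_buffer: dict = {}
_fetch_buffer_lock = threading.Lock()

def buffer_stats() -> dict[str, int]:
    """Summarize buffer occupancy under the lock."""
    with _fetch_buffer_lock:
        return dict(
            num_tensors=len(_fetch_buffer),
            total_bytes=sum(t.nbytes for t in _fetch_buffer.values()),
        )

class FakeTensor:
    """Hypothetical stand-in exposing only nbytes, like a real tensor."""
    def __init__(self, nbytes: int) -> None:
        self.nbytes = nbytes

_fetch_buffer["s0"] = FakeTensor(1024)
_fetch_buffer["s1"] = FakeTensor(2048)
stats = buffer_stats()
print(stats)  # {'num_tensors': 2, 'total_bytes': 3072}
```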
Description
Add a per-process client-side cache (`_fetch_buffer`) for RTensor, keyed by `shard_id`, so that repeated `to_local()` / `localize()` calls for the same rollout batch avoid redundant network round-trips. Entries are evicted by `clear_node()` at the end of each train step. Also includes minor defensive improvements surfaced during code review.

Related Issue
#1117
Type of Change
Checklist

- Lint and format checks pass (`pre-commit run --all-files`)
- Docs build succeeds (`./docs/build_all.sh`)
- Branch is up to date with `main`
- `/review-pr` command
- `/create-pr`

Breaking Change Details (if applicable):

N/A
Additional Context
Key changes:
- Cache check in `to_local()` before backend fetch
- Batch buffer resolution in `localize()` — only cache misses are fetched from the backend
- `clear_node()` evicts buffer entries before deleting remote shards
- `buffer_stats()` for operational monitoring (symmetric with existing `storage_stats()`)
- `strict=True` added to `zip` in `localize()` for consistency and safety (matches line 231)
- `global` declaration in `clear_fetch_buffer()`
- `TestFetchBuffer` integration test suite (8 tests covering populate, serve, partial hit, eviction, thread safety)

Files changed:
- `areal/infra/rpc/rtensor.py`: fetch buffer implementation + review fixes
- `tests/test_rtensor.py`: 8 new integration tests for fetch buffer