
feat(infra): add client-side fetch buffer for RTensor #1122

Merged: garrett4wade merged 2 commits into inclusionAI:main from guozhihao-224:fix/rtensor-fetch-buffer-improvements on Apr 1, 2026

Conversation

@guozhihao-224 (Contributor) commented Mar 31, 2026

Description

Add a per-process client-side cache (_fetch_buffer) for RTensor, keyed by shard_id, so that repeated to_local() / localize() calls for the same rollout batch avoid redundant network round-trips. Entries are evicted by clear_node() at the end of each train step. Also includes minor defensive improvements surfaced during code review.

Related Issue

#1117

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

Key changes:

  • Cache check in to_local() before backend fetch
  • Batch buffer resolution in localize() — only cache misses are fetched from the backend
  • clear_node() evicts buffer entries before deleting remote shards
  • Add buffer_stats() for operational monitoring (symmetric with existing storage_stats())
  • Add strict=True to zip in localize() for consistency and safety (matches line 231)
  • Remove unnecessary global declaration in clear_fetch_buffer()
  • Add TestFetchBuffer integration test suite (8 tests covering populate, serve, partial hit, eviction, thread safety)

Files changed:

  • areal/infra/rpc/rtensor.py: Fetch buffer implementation + review fixes
  • tests/test_rtensor.py: 8 new integration tests for fetch buffer

@gemini-code-assist (bot) left a comment:
Code Review

This pull request introduces a client-side fetch buffer for RTensor to cache fetched tensors by shard_id, reducing redundant network transfers. It includes methods for clearing the buffer, retrieving statistics, and evicting entries during node cleanup, along with a comprehensive test suite. Feedback was provided to de-duplicate shard_ids within the localize method to avoid redundant batch-fetching of the same shard.

Comment on lines +430 to +446
```python
to_fetch: list[RTensor] = []
with _fetch_buffer_lock:
    for rt in meta_rtensors:
        cached = _fetch_buffer.get(rt.shard.shard_id)
        if cached is not None:
            rt.data = cached
        else:
            to_fetch.append(rt)

# Batch-fetch only the misses from the backend.
if to_fetch:
    shards = [rt.shard for rt in to_fetch]
    results = get_backend().fetch(shards)
    with _fetch_buffer_lock:
        for rt, tensor in zip(to_fetch, results, strict=True):
            rt.data = tensor
            _fetch_buffer[rt.shard.shard_id] = tensor
```
Severity: medium

The current implementation of localize does not de-duplicate shard_ids before fetching from the backend. If the input structure contains multiple RTensor instances sharing the same shard_id (which is common in rollout batches), they will all be added to to_fetch and subsequently fetched multiple times over the network. De-duplicating by shard_id within the call significantly reduces redundant network traffic and server load.

Suggested change — before:

```python
to_fetch: list[RTensor] = []
with _fetch_buffer_lock:
    for rt in meta_rtensors:
        cached = _fetch_buffer.get(rt.shard.shard_id)
        if cached is not None:
            rt.data = cached
        else:
            to_fetch.append(rt)

# Batch-fetch only the misses from the backend.
if to_fetch:
    shards = [rt.shard for rt in to_fetch]
    results = get_backend().fetch(shards)
    with _fetch_buffer_lock:
        for rt, tensor in zip(to_fetch, results, strict=True):
            rt.data = tensor
            _fetch_buffer[rt.shard.shard_id] = tensor
```

After:

```python
# Resolve as many as possible from the client-side fetch buffer.
to_fetch_map: dict[Any, list[RTensor]] = {}
with _fetch_buffer_lock:
    for rt in meta_rtensors:
        sid = rt.shard.shard_id
        cached = _fetch_buffer.get(sid)
        if cached is not None:
            rt.data = cached
        else:
            to_fetch_map.setdefault(sid, []).append(rt)

# Batch-fetch only the unique misses from the backend.
if to_fetch_map:
    unique_shards = [rts[0].shard for rts in to_fetch_map.values()]
    results = get_backend().fetch(unique_shards)
    with _fetch_buffer_lock:
        for shard, tensor in zip(unique_shards, results, strict=True):
            _fetch_buffer[shard.shard_id] = tensor
            for rt in to_fetch_map[shard.shard_id]:
                rt.data = tensor
```

@garrett4wade (Collaborator) left a comment:

LGTM except for a minor issue.

Comment thread on areal/infra/rpc/rtensor.py (Outdated)
Comment on lines +317 to +330
```python
def clear_fetch_buffer() -> None:
    """Remove all entries from the client-side fetch buffer."""
    with _fetch_buffer_lock:
        _fetch_buffer.clear()


def buffer_stats() -> dict[str, int]:
    """Get current fetch buffer statistics."""
    with _fetch_buffer_lock:
        return dict(
            num_tensors=len(_fetch_buffer),
            total_bytes=sum(t.nbytes for t in _fetch_buffer.values()),
        )
```

@garrett4wade (Collaborator) commented:

These functions are not used elsewhere except for tests.

@guozhihao-224 (Contributor, author) replied:

Fixed

@garrett4wade garrett4wade merged commit 44d54cf into inclusionAI:main Apr 1, 2026
6 checks passed