From 6547207b07eea22d090a6bb3d3c35e1a6cf63226 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 1/4] [RFC] Add BufferView and RawBuffer interfaces Summary: Added `BufferView` and `RawBuffer` interface. A sampler will operate on a BufferView and return the sampled keys. A ReplayBuffer will own a RawBuffer and operate on that. Test Plan: n/a --- src/forge/interfaces.py | 100 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 4bd2d4bbe..736c3356d 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -5,11 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Mapping +from typing import Any, Generic, Iterable, Mapping, TypeVar + +from forge.types import Action, Message, Observation, Scalar, State from monarch.actor import Actor, endpoint -from forge.types import Action, Message, Observation, Scalar, State +K = TypeVar("K") +V = TypeVar("V") class Transform(ABC): @@ -88,6 +91,99 @@ async def update_weights(self): pass +class BufferView(ABC, Generic[K, V]): + """Abstract base class for a view into a buffer with key-value pairs. + + This class defines the interface for accessing elements in a buffer + through dictionary-like operations. It supports generic key and value types. + """ + + @abstractmethod + def __len__(self) -> int: + """Return the number of key-value pairs in the buffer. + + Returns: + int: The number of items in the buffer. + """ + pass + + @abstractmethod + def __getitem__(self, key: K) -> V: + """Retrieve a value from the buffer using the specified key. + + Args: + key (K): The key to look up in the buffer. + + Returns: + V: The value associated with the key. + + Raises: + KeyError: If the key is not found in the buffer. + """ + pass + + @abstractmethod + def __iter__(self) -> Iterable[tuple[K, V]]: + """Return an iterator over the key-value pairs in the buffer. + + Returns: + Iterable[tuple[K, V]]: An iterator yielding (key, value) tuples. + """ + pass + + @abstractmethod + def keys(self) -> Iterable[K]: + """Return an iterable of all keys in the buffer. + + Returns: + Iterable[K]: An iterable containing all keys in the buffer. + """ + pass + + +class RawBuffer(BufferView[K, V], ABC): + """Abstract interface for the underlying storage backend (raw buffer) of a ReplayBuffer.""" + + @abstractmethod + def add(self, key: K, val: V) -> None: + """ + Add a key-value pair to the buffer. + + Args: + key (K): The key to store the value under + val (V): The value to store in the buffer + + Returns: + None + """ + pass + + @abstractmethod + def pop(self, key: K) -> V: + """ + Remove and return a value from the buffer using the specified key. + + Args: + key (K): The key to look up and remove from the buffer + + Returns: + V: The value associated with the key before removal + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Remove all key-value pairs from the buffer, effectively emptying it. + + This method should reset the buffer to its initial empty state. + + Returns: + None + """ + pass + + class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. From 2ab5d1852b0d8c949c6d16ae2af91f64038271a2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 2/4] Implement SimpleRawBuffer, a RawBuffer backed by a python dict. Summary: Implement SimpleRawBuffer, a RawBuffer backed by a python dict. Test Plan: unit tests --- src/forge/data/raw_buffer.py | 55 +++++++ tests/unit_tests/test_raw_buffer.py | 231 ++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+) create mode 100644 src/forge/data/raw_buffer.py create mode 100644 tests/unit_tests/test_raw_buffer.py diff --git a/src/forge/data/raw_buffer.py b/src/forge/data/raw_buffer.py new file mode 100644 index 000000000..ed5a14918 --- /dev/null +++ b/src/forge/data/raw_buffer.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterator, TypeVar + +from forge.interfaces import RawBuffer + +K = TypeVar("K") +V = TypeVar("V") + + +class SimpleRawBuffer(RawBuffer[K, V]): + """Simple in-memory RawBuffer backed by a Python dictionary.""" + + def __init__(self) -> None: + self._buffer: dict[K, V] = {} + + def __len__(self) -> int: + """Return the number of key-value pairs in the buffer.""" + return len(self._buffer) + + def __getitem__(self, key: K) -> V: + """Get a value from the buffer using the specified key.""" + return self._buffer[key] + + def __iter__(self) -> Iterator[tuple[K, V]]: + """Iterate over the key-value pairs in the buffer.""" + for k, v in self._buffer.items(): + yield k, v + + def keys(self) -> Iterator[K]: + """Iterate over the keys in the buffer.""" + for k in self._buffer.keys(): + yield k + + def add(self, key: K, val: V) -> None: + """Add a key-value pair to the buffer.""" + if key in self._buffer: + raise KeyError(f"Key {key} already exists in the buffer.") + self._buffer[key] = val + + def pop(self, key: K) -> V: + """Remove and return a value from the buffer using the specified key.""" + if key not in self._buffer: + raise KeyError(f"Key {key} does not exist in the buffer.") + val = self._buffer[key] + del self._buffer[key] + return val + + def clear(self) -> None: + """Clear the buffer.""" + self._buffer.clear() diff --git a/tests/unit_tests/test_raw_buffer.py b/tests/unit_tests/test_raw_buffer.py new file mode 100644 index 000000000..a6a8e5242 --- /dev/null +++ b/tests/unit_tests/test_raw_buffer.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Test for data/raw_buffer.py""" + +import pytest + +from forge.data.raw_buffer import SimpleRawBuffer + + +class TestSimpleRawBuffer: + """Test suite for SimpleRawBuffer class.""" + + def test_init_empty_buffer(self): + """Test that a new buffer is initialized empty.""" + buffer = SimpleRawBuffer[str, int]() + assert len(buffer) == 0 + + def test_add_single_item(self): + """Test adding a single key-value pair.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + + assert len(buffer) == 1 + assert buffer["key1"] == 100 + + def test_add_multiple_items(self): + """Test adding multiple key-value pairs.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + assert len(buffer) == 3 + assert buffer["key1"] == 100 + assert buffer["key2"] == 200 + assert buffer["key3"] == 300 + + def test_add_duplicate_key_raises_error(self): + """Test that adding a duplicate key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + + with pytest.raises(KeyError, match="Key key1 already exists in the buffer"): + buffer.add("key1", 200) + + def test_getitem_existing_key(self): + """Test retrieving an existing key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("test_key", 42) + + assert buffer["test_key"] == 42 + + def test_getitem_missing_key_raises_error(self): + """Test that accessing a non-existent key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + + with pytest.raises(KeyError): + _ = buffer["missing_key"] + + def test_pop_existing_key(self): + """Test removing and returning a value for an existing key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + + value = buffer.pop("key1") + + assert value == 100 + assert len(buffer) == 1 + assert "key1" not in buffer.keys() + assert buffer["key2"] == 200 + + def test_pop_missing_key_raises_error(self): + """Test that popping a non-existent key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + + with pytest.raises( + KeyError, match="Key missing_key does not exist in the buffer" + ): + buffer.pop("missing_key") + + def test_keys_iteration(self): + """Test iterating over keys.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + keys = list(buffer.keys()) + + assert len(keys) == 3 + assert "key1" in keys + assert "key2" in keys + assert "key3" in keys + + def test_keys_empty_buffer(self): + """Test iterating over keys in an empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + keys = list(buffer.keys()) + + assert keys == [] + + def test_iter_key_value_pairs(self): + """Test iterating over key-value pairs.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + + items = list(buffer) + + assert len(items) == 2 + assert ("key1", 100) in items + assert ("key2", 200) in items + + def test_iter_empty_buffer(self): + """Test iterating over an empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + items = list(buffer) + + assert items == [] + + def test_clear_buffer(self): + """Test clearing the buffer.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + assert len(buffer) == 3 + + buffer.clear() + + assert len(buffer) == 0 + assert list(buffer.keys()) == [] + assert list(buffer) == [] + + def test_clear_empty_buffer(self): + """Test clearing an already empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + assert len(buffer) == 0 + + buffer.clear() + + assert len(buffer) == 0 + + def test_different_value_types(self): + """Test buffer with different value types.""" + buffer = SimpleRawBuffer[str, list[int]]() + buffer.add("list1", [1, 2, 3]) + buffer.add("list2", [4, 5, 6]) + + assert buffer["list1"] == [1, 2, 3] + assert buffer["list2"] == [4, 5, 6] + + def test_different_key_types(self): + """Test buffer with different key types.""" + buffer = SimpleRawBuffer[int, str]() + buffer.add(1, "value1") + buffer.add(2, "value2") + + assert buffer[1] == "value1" + assert buffer[2] == "value2" + + def test_complex_workflow(self): + """Test a complex workflow with multiple operations.""" + buffer = SimpleRawBuffer[str, int]() + + # Add some items + buffer.add("a", 1) + buffer.add("b", 2) + buffer.add("c", 3) + assert len(buffer) == 3 + + # Pop one item + value = buffer.pop("b") + assert value == 2 + assert len(buffer) == 2 + + # Add another item + buffer.add("d", 4) + assert len(buffer) == 3 + + # Verify remaining items + assert buffer["a"] == 1 + assert buffer["c"] == 3 + assert buffer["d"] == 4 + + # Clear and verify empty + buffer.clear() + assert len(buffer) == 0 + + def test_len_consistency(self): + """Test that len() remains consistent with add/pop operations.""" + buffer = SimpleRawBuffer[str, int]() + + # Initially empty + assert len(buffer) == 0 + + # Add items and check length + for i in range(5): + buffer.add(f"key{i}", i) + assert len(buffer) == i + 1 + + # Remove items and check length + for i in range(5): + buffer.pop(f"key{i}") + assert len(buffer) == 4 - i + + def test_none_values(self): + """Test storing None values.""" + buffer = SimpleRawBuffer[str, int | None]() + buffer.add("none_value", None) + buffer.add("int_value", 42) + + assert buffer["none_value"] is None + assert buffer["int_value"] == 42 + + def test_empty_string_key(self): + """Test using empty string as key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("", 42) + + assert buffer[""] == 42 + assert "" in list(buffer.keys()) From 1aeb2b37f531271c79ccdb5a9e10a47a3c16c270 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 3/4] Add StatefulSampler interface and implement RandomStatefulSampler Summary: This diff adds a new interface called `StatefulSampler` and implements a new class called `RandomStatefulSampler`. The `RandomStatefulSampler` class is a stateful sampler that uses Python's `random.sample` function for deterministic sampling. Test Plan: unit tests --- src/forge/data/stateful_sampler.py | 73 +++++++++++++++++++++++ src/forge/interfaces.py | 52 ++++++++++++++++ tests/unit_tests/test_stateful_sampler.py | 59 ++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 src/forge/data/stateful_sampler.py create mode 100644 tests/unit_tests/test_stateful_sampler.py diff --git a/src/forge/data/stateful_sampler.py b/src/forge/data/stateful_sampler.py new file mode 100644 index 000000000..3279a658c --- /dev/null +++ b/src/forge/data/stateful_sampler.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import random +from typing import Any, Generic, List, Mapping, TypeVar + +from forge.interfaces import BufferView, StatefulSampler + +K = TypeVar("K") +V = TypeVar("V") + + +class RandomStatefulSampler(StatefulSampler[K, V], Generic[K, V]): + """A simple stateful sampler that uses Python's random.sample for deterministic sampling. + + This sampler maintains an internal random state that can be saved and restored, + allowing for reproducible sampling behavior. It uses random.sample to select + keys from the buffer without replacement. + """ + + def __init__(self, seed: int | None = None): + """Initialize the sampler with an optional random seed. + + Args: + seed: Optional seed for the random number generator. If None, + the sampler will use Python's default random state. + """ + if seed is None: + self._random = random.Random() + self._random = random.Random(seed) + + def sample_keys(self, buffer: BufferView[K, V], num: int) -> List[K]: + """Sample keys from the buffer using random.sample. + + Args: + buffer: The buffer to sample from + num: Number of keys to sample + + Returns: + A list of sampled keys. If num is greater than the buffer size, + returns all available keys. + """ + # Get all keys from the buffer + all_keys = list(buffer.keys()) + + # If requesting more samples than available, return all keys + if num >= len(all_keys): + return all_keys + + # Use random.sample for sampling without replacement + return self._random.sample(all_keys, num) + + def state_dict(self): + """Return the state dict of the sampler. + + Returns: + A dictionary containing the random number generator state. + """ + return {"random_state": self._random.getstate()} + + def set_state_dict(self, state_dict: Mapping[str, Any]): + """Set the state dict of the sampler. + + Args: + state_dict: Dictionary containing the random state to restore. + """ + if "random_state" in state_dict: + self._random.setstate(state_dict["random_state"]) + else: + raise ValueError("Missing 'random_state' in state dict") diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 736c3356d..c0579f09f 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -184,6 +184,58 @@ def clear(self) -> None: pass +class StatefulSampler(ABC, Generic[K, V]): + """Abstract interface for stateful samplers with deterministic behavior given a state. + + This class defines the interface for samplers that maintain internal state and provide + deterministic sampling behavior when the state is fixed. + """ + + @abstractmethod + def sample_keys(self, buffer: BufferView[K, V], num: int) -> list[K]: + """Return the keys of selected samples from the buffer. + + This method samples a specified number of keys from the provided buffer + according to the sampler's internal sampling strategy. The sampling + behavior is deterministic for a given internal state of the sampler. + + Args: + buffer (BufferView[K, V]): The buffer to sample from, containing key-value pairs. + num (int): Desired number of samples to retrieve from the buffer. + If num is greater than the buffer size, implementation may + return fewer samples or handle it according to the specific + sampling strategy. + + Returns: + list[K]: A list of keys corresponding to the selected samples. + The length of this list will typically be equal to num, + unless the buffer contains fewer items. + """ + pass + + @abstractmethod + def state_dict(self) -> Mapping[str, Any]: + """Return the state dict of the sampler. + + This method should capture all the internal state necessary to reproduce + the sampler's behavior, such as random number generator states. + + Returns: + dict: A dictionary containing the internal state of the sampler. + """ + pass + + @abstractmethod + def set_state_dict(self, state_dict): + """Set the state dict of the sampler. + + Args: + state_dict (dict): A dictionary containing the internal state to restore + the sampler to a specific configuration. + """ + pass + + class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. diff --git a/tests/unit_tests/test_stateful_sampler.py b/tests/unit_tests/test_stateful_sampler.py new file mode 100644 index 000000000..287b69aa5 --- /dev/null +++ b/tests/unit_tests/test_stateful_sampler.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from forge.data.raw_buffer import SimpleRawBuffer +from forge.data.stateful_sampler import RandomStatefulSampler + +from forge.interfaces import RawBuffer + + +class TestRandomStatefulSampler: + @pytest.fixture + def raw_buffer(self) -> RawBuffer[int, int]: + buffer = SimpleRawBuffer[int, int]() + for n in range(1000): + buffer.add(n, n) + return buffer + + def test_init(self): + sampler = RandomStatefulSampler() + assert True + + def test_init_with_seed(self): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler(seed=41) + assert str(sampler1.state_dict()) != str(sampler2.state_dict()) + + def test_state_dict(self): + sampler = RandomStatefulSampler() + state_dict = sampler.state_dict() + assert "random_state" in state_dict + assert state_dict["random_state"] is not None + + def test_set_state_dict_no_random_state(self): + sampler = RandomStatefulSampler() + state_dict = {} + with pytest.raises(ValueError, match="Missing 'random_state'"): + sampler.set_state_dict(state_dict) + + def test_deterministic(self, raw_buffer): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler() + sampler2.set_state_dict(sampler1.state_dict()) + for _ in range(10): + batch1 = sampler1.sample_keys(raw_buffer, 5) + batch2 = sampler2.sample_keys(raw_buffer, 5) + assert batch1 == batch2 + + def test_deterministic_resume(self, raw_buffer): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler() + for _ in range(10): + sampler2.set_state_dict(sampler1.state_dict()) + batch1 = sampler1.sample_keys(raw_buffer, 5) + batch2 = sampler2.sample_keys(raw_buffer, 5) + assert batch1 == batch2 From 25e829281744b693c70528053786ec7d7be9111f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:29:39 -0700 Subject: [PATCH 4/4] Allow passing in custom sampler in ReplayBuffer Summary: This diff allows passing in a custom sampler in the ReplayBuffer class. The changes include adding a new sampler class and modifying the ReplayBuffer class to accept a sampler as a parameter. The code changes also include modifying the test_replay_buffer.py file to test the new sampler functionality. Test Plan: unit tests (flaky) --- src/forge/actors/replay_buffer.py | 64 +++++++++------- src/forge/test_util/__init__.py | 0 src/forge/test_util/udp_trace.py | 101 +++++++++++++++++++++++++ tests/unit_tests/test_replay_buffer.py | 54 +++++++++++-- 4 files changed, 185 insertions(+), 34 deletions(-) create mode 100644 src/forge/test_util/__init__.py create mode 100644 src/forge/test_util/udp_trace.py diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index d0e70e85f..ac3618a13 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -5,12 +5,17 @@ # LICENSE file in the root directory of this source tree. import random +import uuid from dataclasses import dataclass from typing import Any -from monarch.actor import endpoint - from forge.controller import ForgeActor +from forge.data.raw_buffer import SimpleRawBuffer +from forge.data.stateful_sampler import RandomStatefulSampler + +from forge.interfaces import StatefulSampler + +from monarch.actor import endpoint @dataclass @@ -22,16 +27,23 @@ class ReplayBuffer(ForgeActor): seed: int | None = None @endpoint - async def setup(self) -> None: - self.buffer: list = [] + async def setup(self, *, sampler: StatefulSampler | None = None) -> None: + self._buffer = SimpleRawBuffer[int, Any]() if self.seed is None: self.seed = random.randint(0, 2**32) - random.seed(self.seed) - self.sampler = random.sample + if sampler is None: + sampler = RandomStatefulSampler(seed=self.seed) + + self._sampler = sampler @endpoint async def add(self, episode) -> None: - self.buffer.append(episode) + # I think key should be provided by the caller, but let's just generate a random one for now + # Note that this means add() is not deterministic, however the original implementation using list + # isn't actually deterministic either because it depends on the order of add() being called. + # Alternatively, add a field in Trajectory as the id of the trajectory. + key = uuid.uuid4().int + self._buffer.add(key, episode) @endpoint async def sample(self, curr_policy_version: int, batch_size: int | None = None): @@ -50,15 +62,11 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): # Evict old episodes self._evict(curr_policy_version) - if bsz > len(self.buffer): + if bsz > len(self._buffer): return None - # TODO: Make this more efficient - idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz) - sorted_idxs = sorted( - idx_to_sample, reverse=True - ) # Sort in desc order to avoid shifting idxs - sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs] + keys_to_sample = self._sampler.sample_keys(self._buffer, num=bsz) + sampled_episodes = [self._buffer.pop(k) for k in keys_to_sample] return sampled_episodes @endpoint @@ -72,35 +80,33 @@ async def evict(self, curr_policy_version: int) -> None: self._evict(curr_policy_version) def _evict(self, curr_policy_version: int) -> None: - self.buffer = [ - trajectory - for trajectory in self.buffer - if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age - ] - - @endpoint - async def _getitem(self, idx: int): - return self.buffer[idx] + keys_to_delete = [] + for key, episode in self._buffer: + if curr_policy_version - episode.policy_version > self.max_policy_age: + keys_to_delete.append(key) + for key in keys_to_delete: + self._buffer.pop(key) @endpoint async def _numel(self) -> int: """Number of elements (episodes) in the replay buffer.""" - return len(self.buffer) + return len(self._buffer) @endpoint async def clear(self) -> None: """Clear the replay buffer immediately - dropping all episodes.""" - self.buffer.clear() + self._buffer.clear() @endpoint async def state_dict(self) -> dict[str, Any]: return { - "buffer": self.buffer, - "rng_state": random.getstate(), + "buffer": self._buffer, + "sampler_state": self._sampler.state_dict(), "seed": self.seed, } @endpoint async def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self.buffer = state_dict["buffer"] - random.setstate(state_dict["rng_state"]) + self._buffer = state_dict["buffer"] + self._sampler.set_state_dict(state_dict["sampler_state"]) + self.seed = state_dict["seed"] diff --git a/src/forge/test_util/__init__.py b/src/forge/test_util/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/forge/test_util/udp_trace.py b/src/forge/test_util/udp_trace.py new file mode 100644 index 000000000..0b7846414 --- /dev/null +++ b/src/forge/test_util/udp_trace.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Test utilities for tracing function calls via UDP packets. + +This module provides utilities for testing distributed/async components where +traditional mocking is difficult due to pickling/unpickling (e.g., with monarch actors). +The UDP tracing approach allows tests to verify that specific functions were called +by having them send UDP packets that can be received and verified by the test. + +Warning: This approach has limitations - tests using UDP tracing can be flaky and +only work reliably when run on a single machine because they listen to localhost. + +Example usage: + # In test code + sampler.sample_keys = add_udp_callback( + sampler.sample_keys, port=TEST_PORT, message=b"sample_keys" + ) + + # Start UDP receiver in separate thread + received = [] + server_thread = threading.Thread( + target=receive_udp_packet, + args=(TEST_PORT, received), + kwargs={"timeout": 15}, + ) + server_thread.start() + + # Execute code that should call the wrapped function + # ... + + # Verify the function was called + server_thread.join() + assert b"sample_keys" in received +""" + +import socket + + +def receive_udp_packet(port, received, *, timeout): + """ + Receives a UDP packet on the specified port and appends it to the received list. + + Args: + port: The port number to listen on + received: A list to which received data will be appended + timeout: Keyword-only argument specifying socket timeout in seconds + + Returns: + None. Data is appended to the received list if a packet is received before timeout. + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(("localhost", port)) + sock.settimeout(timeout) + try: + data, _ = sock.recvfrom(1024) # addr is not used + received.append(data) + except socket.timeout: + pass + finally: + sock.close() + + +def send_udp_packet(port, message): + """ + Sends a UDP packet to localhost on the specified port. + + Args: + port: The port number to send the packet to + message: The message/data to send + + Returns: + None + """ + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.sendto(message, ("localhost", port)) + sock.close() + + +def add_udp_callback(func, port, message): + """ + Decorator function that wraps another function to send a UDP packet after execution. + + Args: + func: The function to wrap + port: The port number to send the packet to + message: The message/data to send + + Returns: + A wrapped function that calls the original function and then sends a UDP packet + """ + + def f(*args, **kwargs): + ret = func(*args, **kwargs) + send_udp_packet(port, message) + return ret + + return f diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index 19258a1a7..61a0fb104 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -4,12 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""Test for data/replay_buffer.py""" +"""Test for actors/replay_buffer.py""" + +import threading +from random import Random import pytest import pytest_asyncio + from forge.actors.replay_buffer import ReplayBuffer -from forge.types import Trajectory +from forge.data.stateful_sampler import RandomStatefulSampler +from forge.interfaces import StatefulSampler + +from forge.test_util.udp_trace import add_udp_callback, receive_udp_packet +from forge.types import State, Trajectory from monarch.actor import proc_mesh @@ -24,12 +32,43 @@ async def replay_buffer(self) -> ReplayBuffer: await replay_buffer.setup.call() return replay_buffer + @pytest.mark.asyncio + async def test_setup_accepts_sampler(self) -> None: + # This test is flaky and only works if it is run on a single machine. + # However, it's impossible to directly mock a function called by monarch + # because it is first pickled and then unpickled. + + sampler = RandomStatefulSampler() + TEST_PORT = 34958 + sampler.sample_keys = add_udp_callback( + sampler.sample_keys, port=TEST_PORT, message=b"sample_keys" + ) + + mesh = await proc_mesh(gpus=1) + replay_buffer = await mesh.spawn( + "replay_buffer", ReplayBuffer, batch_size=1, max_policy_age=1 + ) + received = [] + server_thread = threading.Thread( + target=receive_udp_packet, + args=(TEST_PORT, received), + kwargs={"timeout": 15}, + ) + server_thread.start() + await replay_buffer.setup.call(sampler=sampler) + await replay_buffer.add.call_one(Trajectory(policy_version=0)) + await replay_buffer.sample.call_one(curr_policy_version=0) + server_thread.join() + assert b"".join(received) == b"sample_keys" + @pytest.mark.asyncio async def test_add(self, replay_buffer: ReplayBuffer) -> None: trajectory = Trajectory(policy_version=0) await replay_buffer.add.call_one(trajectory) assert replay_buffer._numel.call_one().get() == 1 - assert replay_buffer._getitem.call_one(0).get() == trajectory + assert replay_buffer.sample.call_one( + curr_policy_version=0, batch_size=1 + ).get() == [trajectory] replay_buffer.clear.call_one().get() @pytest.mark.asyncio @@ -39,8 +78,13 @@ async def test_add_multiple(self, replay_buffer) -> None: await replay_buffer.add.call_one(trajectory_0) await replay_buffer.add.call_one(trajectory_1) assert replay_buffer._numel.call_one().get() == 2 - assert replay_buffer._getitem.call_one(0).get() == trajectory_0 - assert replay_buffer._getitem.call_one(1).get() == trajectory_1 + all_samples = replay_buffer.sample.call_one( + curr_policy_version=0, batch_size=2 + ).get() + assert all_samples == [trajectory_0, trajectory_1] or all_samples == [ + trajectory_1, + trajectory_0, + ] replay_buffer.clear.call_one().get() @pytest.mark.asyncio