diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 77b4f6f8a..37a1558d5 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -120,6 +120,27 @@ async def sample( entry.sample_count += 1 sampled_episodes.append(entry.data) + # Calculate and record policy age metrics for sampled episodes + sampled_policy_ages = [ + curr_policy_version - ep.policy_version for ep in sampled_episodes + ] + if sampled_policy_ages: + record_metric( + "buffer/sample/avg_sampled_policy_age", + sum(sampled_policy_ages) / len(sampled_policy_ages), + Reduce.MEAN, + ) + record_metric( + "buffer/sample/max_sampled_policy_age", + max(sampled_policy_ages), + Reduce.MAX, + ) + record_metric( + "buffer/sample/min_sampled_policy_age", + min(sampled_policy_ages), + Reduce.MIN, + ) + # Reshape into (dp_size, bsz, ...) reshaped_episodes = [ sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size] @@ -149,22 +170,6 @@ def _evict(self, curr_policy_version): ) self.buffer = deque(self._collect(indices)) - # Record evict metrics - policy_age = [ - curr_policy_version - ep.data.policy_version for ep in self.buffer - ] - if policy_age: - record_metric( - "buffer/evict/avg_policy_age", - sum(policy_age) / len(policy_age), - Reduce.MEAN, - ) - record_metric( - "buffer/evict/max_policy_age", - max(policy_age), - Reduce.MAX, - ) - evicted_count = buffer_len_before_evict - len(self.buffer) record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM)