Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions prompting/rewards/scoring.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import copy
import threading
import time
from multiprocessing.managers import AcquirerProxy

from loguru import logger
Expand Down Expand Up @@ -30,6 +31,7 @@ class TaskScorer(AsyncLoopRunner):
scoring_queue: list | None = None
reward_events: list | None = None
task_queue: list | None = None
expiry_time: int = 60 * 60 * 20
model_config = ConfigDict(arbitrary_types_allowed=True)

async def start(
Expand Down Expand Up @@ -70,12 +72,22 @@ def add_to_queue(
async def run_step(self) -> RewardLoggingEvent:
await asyncio.sleep(0.1)

if not self.scoring_queue:
scoring_config: ScoringConfig | None = None
while self.scoring_queue:
# Pop the oldest item from the queue.
config = self.scoring_queue.pop(0)
# Check if the config is recent enough to be processed.
if config.created_at >= time.time() - self.expiry_time:
scoring_config = config
break
# Otherwise, the old config is discarded and we continue to the next one.
else:
logger.debug(
f"Discarding old scoring config for {config.task.__class__.__name__} created at {config.created_at}"
)
if not scoring_config:
return

# TODO: Filter based on active models before selecting an item to score.
scoring_config: ScoringConfig = self.scoring_queue.pop(0)

# here we generate the actual reference
with Timer(label=f"Generating reference for {scoring_config.task.__class__.__name__}"):
await scoring_config.task.make_reference(
Expand Down
4 changes: 3 additions & 1 deletion prompting/rewards/scoring_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import time
from dataclasses import dataclass, field

from prompting.tasks.base_task import BaseTextTask
from shared.base import DatasetEntry
Expand All @@ -13,3 +14,4 @@ class ScoringConfig:
block: int
step: int
task_id: str
created_at: float = field(default_factory=time.time)