diff --git a/prompting/rewards/scoring.py b/prompting/rewards/scoring.py index fa7126715..84ebaf5c8 100644 --- a/prompting/rewards/scoring.py +++ b/prompting/rewards/scoring.py @@ -1,6 +1,7 @@ import asyncio import copy import threading +import time from multiprocessing.managers import AcquirerProxy from loguru import logger @@ -30,6 +31,7 @@ class TaskScorer(AsyncLoopRunner): scoring_queue: list | None = None reward_events: list | None = None task_queue: list | None = None + expiry_time: int = 60 * 60 * 20 model_config = ConfigDict(arbitrary_types_allowed=True) async def start( @@ -70,12 +72,22 @@ def add_to_queue( async def run_step(self) -> RewardLoggingEvent: await asyncio.sleep(0.1) - if not self.scoring_queue: + scoring_config: ScoringConfig | None = None + while self.scoring_queue: + # Pop the oldest item from the queue. + config = self.scoring_queue.pop(0) + # Check if the config is recent enough to be processed. + if config.created_at >= time.time() - self.expiry_time: + scoring_config = config + break + # Otherwise, the old config is discarded and we continue to the next one. + else: + logger.debug( + f"Discarding old scoring config for {config.task.__class__.__name__} created at {config.created_at}" + ) + if not scoring_config: return - # TODO: Filter based on active models before selecting an item to score. - scoring_config: ScoringConfig = self.scoring_queue.pop(0) - # here we generate the actual reference with Timer(label=f"Generating reference for {scoring_config.task.__class__.__name__}"): await scoring_config.task.make_reference( diff --git a/prompting/rewards/scoring_config.py b/prompting/rewards/scoring_config.py index 1606b21b3..fb176c304 100644 --- a/prompting/rewards/scoring_config.py +++ b/prompting/rewards/scoring_config.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +import time +from dataclasses import dataclass, field from prompting.tasks.base_task import BaseTextTask from shared.base import DatasetEntry @@ -13,3 +14,4 @@ class ScoringConfig: block: int step: int task_id: str + created_at: float = field(default_factory=time.time)