2 changes: 1 addition & 1 deletion prompting/llms/model_manager.py
@@ -233,7 +233,7 @@ async def cleanup(self):
 class AsyncModelScheduler(AsyncLoopRunner):
     llm_model_manager: ModelManager
     mp_lock: AcquirerProxy
-    interval: int = 1200
+    interval: int = 3600
     scoring_queue: list | None = None
     memory_error: MemoryError | None = None

26 changes: 16 additions & 10 deletions prompting/rewards/exact_match.py
@@ -14,14 +14,14 @@

 shared_settings = settings.shared_settings

-MIN_VERIFY_TOKENS = 10
-MAX_VERIFY_TOKENS = 30
 NO_EOS_PENALTY = -0.1
 INCORRECT_PENALTY = -2
 MIN_SMOOTH_PENALTY_SCALE = 0.6
 MIN_TIME_PENALTY_SCALE = 0.3
-VERIFICATION_THRESH_CONTAINS = 0.92
-VERIFICATION_THRESH_SIM = 0.90
+MIN_VERIFY_TOKENS = 10
+MAX_VERIFY_TOKENS = 30
+VERIFICATION_THRESH_CONTAINS = 0.96
+VERIFICATION_THRESH_SIM = 0.86


 class LogitsRewardModel(BaseRewardModel):
@@ -93,11 +93,15 @@ async def reward( # noqa: C901
         scores_contains: list[float] = []
         for idx in verify_indices:
             check_idx = min(idx, completion_length)
+            messages = task.task_messages.copy()
+            to_complete = "".join(chunks[:check_idx])
+            if to_complete:
+                messages.extend([{"role": "assistant", "content": to_complete}])
             verification_logits, _ = await model_manager.generate_logits(
                 model=task.llm_model_id,
-                messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}],
+                messages=messages,
                 sampling_params=sampling_parameters,
-                continue_last_message=True,
+                continue_last_message=len(to_complete) > 0,
             )
             if check_idx < eos_idx:
                 if not chunk_dicts_raw[check_idx].choices[0].logprobs:
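Note: the hunk above only appends the partial completion as an assistant turn when it is non-empty, so verification at the first index no longer continues from an empty assistant message. A minimal sketch of that branching, using illustrative placeholder data rather than the project's real task objects:

```python
# Sketch of the new message construction; names and data are illustrative only.
task_messages = [{"role": "user", "content": "Explain quicksort."}]
chunks = ["Quick", "sort", " is", " a", " divide-and-conquer", " algorithm."]

for check_idx in (0, 3):
    messages = task_messages.copy()
    to_complete = "".join(chunks[:check_idx])
    if to_complete:
        # A partial completion exists: append it and continue that assistant turn.
        messages.extend([{"role": "assistant", "content": to_complete}])
    continue_last_message = len(to_complete) > 0
    print(check_idx, continue_last_message, messages[-1]["role"])
    # check_idx=0 -> no assistant message, continue_last_message=False
    # check_idx=3 -> partial assistant message, continue_last_message=True
```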
@@ -126,7 +130,6 @@

         score_sim_mean = float(np.mean(scores_sim))
         score_contains_mean = float(np.mean(scores_contains))
-        logger.debug(f"Scores: {score_sim_mean}; {score_contains_mean}")

         if score_sim_mean < VERIFICATION_THRESH_SIM:
             raise ValueError(f"Logits similarity mean score is below threshold: {score_sim_mean:.2f}")
@@ -181,9 +184,12 @@ async def reward( # noqa: C901

     @staticmethod
     def sample_verification_indices(completion_length: int) -> list[int]:
-        """Sample random indices for verification, always add eos_token index."""
-        num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS))
-        verify_indices = random.sample(range(completion_length), num_verify)
+        """Sample random indices for verification, always add 0 and eos_token index."""
+        # Sample indices without first and last index.
+        num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - 2
+        verify_indices = random.sample(range(1, completion_length - 1), num_verify)
+        # Add first index.
+        verify_indices.append(0)
         # Add eos_token index.
         verify_indices.append(completion_length)
         verify_indices.sort()
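For reference, here is the updated sampler from the hunk above lifted into a standalone snippet; the trailing return statement and the example call are assumptions added for illustration (the hunk is truncated after the sort), and the constant mirrors the value set earlier in this file:

```python
import random

import numpy as np

MAX_VERIFY_TOKENS = 30  # value from prompting/rewards/exact_match.py


def sample_verification_indices(completion_length: int) -> list[int]:
    """Sample random indices for verification, always add 0 and eos_token index."""
    # Sample indices without first and last index.
    num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - 2
    verify_indices = random.sample(range(1, completion_length - 1), num_verify)
    # Add first index.
    verify_indices.append(0)
    # Add eos_token index.
    verify_indices.append(completion_length)
    verify_indices.sort()
    return verify_indices  # assumed return, not shown in the truncated hunk


indices = sample_verification_indices(50)
assert indices[0] == 0 and indices[-1] == 50          # first and EOS indices always present
assert len(indices) == min(50, MAX_VERIFY_TOKENS)     # matches expected_k in the updated test
```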
13 changes: 3 additions & 10 deletions prompting/tasks/task_registry.py
@@ -13,7 +13,6 @@
 from prompting.tasks.inference import InferenceRewardConfig, InferenceTask
 from prompting.tasks.MSRv2_task import MSRv2RewardConfig, MSRv2Task
 from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask
-from prompting.tasks.qa import QARewardConfig, WebQuestionAnsweringTask
 from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask
 from shared.base import BaseDataset

@@ -33,27 +32,21 @@ def __hash__(self):
 class TaskRegistry(BaseModel):
     task_configs: ClassVar[list[TaskConfig]] = [
         TaskConfig(task=MSRv2Task, probability=0.05, datasets=[DDGDataset], reward_model=MSRv2RewardConfig),
-        TaskConfig(
-            task=WebQuestionAnsweringTask,
-            probability=0.05,
-            datasets=[DDGDataset],
-            reward_model=QARewardConfig,
-        ),
         TaskConfig(
             task=InferenceTask,
-            probability=0.40,
+            probability=0.45,
             datasets=[SN13Dataset],
             reward_model=InferenceRewardConfig,
         ),
         TaskConfig(
             task=ProgrammingTask,
-            probability=0.20,
+            probability=0.10,
             datasets=[HuggingFaceGithubDataset],
             reward_model=ProgrammingRewardConfig,
         ),
         TaskConfig(
             task=WebRetrievalTask,
-            probability=0.3,
+            probability=0.40,
             datasets=[DDGDataset],
             reward_model=WebRetrievalRewardConfig,
         ),
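After dropping WebQuestionAnsweringTask, the remaining task probabilities still sum to 1.0 (0.05 + 0.45 + 0.10 + 0.40). A quick sanity check with the values copied from the config above; the dictionary itself is purely illustrative:

```python
# Task probabilities after this change (names abbreviated for illustration).
task_probabilities = {
    "MSRv2Task": 0.05,
    "InferenceTask": 0.45,
    "ProgrammingTask": 0.10,
    "WebRetrievalTask": 0.40,
}
assert abs(sum(task_probabilities.values()) - 1.0) < 1e-9
```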
6 changes: 3 additions & 3 deletions tests/prompting/rewards/test_exact_match.py
@@ -303,10 +303,10 @@ def test_fastest_timing_various_cases(values, expected):
 def test_sample_verification_indices_properties(completion_length):
     indices = LogitsRewardModel.sample_verification_indices(completion_length)

-    # Compute expected number of sampled tokens (before adding EOS)
-    expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) + 1
+    # Compute expected number of sampled tokens with first and eos indices.
+    expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS))

-    # The result should have expected_k samples plus one EOS index
+    # The result should have expected_k samples plus one EOS index.
     assert isinstance(indices, list)
     assert len(indices) == expected_k
     assert indices == sorted(indices)