From 457174a993b14dd0c26785a950435d25d4d0b43c Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Wed, 30 Apr 2025 00:44:47 +0000 Subject: [PATCH 1/5] Fix BOS - Fix bos; - Deprecate webQA; - Increase inference and web retrieval. --- prompting/rewards/exact_match.py | 41 ++++++++++++++++----- prompting/tasks/task_registry.py | 12 ++---- tests/prompting/rewards/test_exact_match.py | 6 +-- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index b1b21c66a..5d0667f66 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -14,14 +14,14 @@ shared_settings = settings.shared_settings +MIN_VERIFY_TOKENS = 10 +MAX_VERIFY_TOKENS = 30 NO_EOS_PENALTY = -0.1 INCORRECT_PENALTY = -2 MIN_SMOOTH_PENALTY_SCALE = 0.6 MIN_TIME_PENALTY_SCALE = 0.3 -VERIFICATION_THRESH_CONTAINS = 0.92 -VERIFICATION_THRESH_SIM = 0.90 -MIN_VERIFY_TOKENS = 10 -MAX_VERIFY_TOKENS = 30 +VERIFICATION_THRESH_CONTAINS = 0.96 +VERIFICATION_THRESH_SIM = 0.88 class LogitsRewardModel(BaseRewardModel): @@ -93,11 +93,15 @@ async def reward( # noqa: C901 scores_contains: list[float] = [] for idx in verify_indices: check_idx = min(idx, completion_length) + messages = task.task_messages.copy() + to_complete = "".join(chunks[:check_idx]) + if to_complete: + messages.extend([{"role": "assistant", "content": to_complete}]) verification_logits, _ = await model_manager.generate_logits( model=task.llm_model_id, - messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}], + messages=messages, sampling_params=sampling_parameters, - continue_last_message=True, + continue_last_message=len(to_complete) > 0, ) if check_idx < eos_idx: if not chunk_dicts_raw[check_idx].choices[0].logprobs: @@ -111,13 +115,29 @@ async def reward( # noqa: C901 for info in chunk_dicts_raw[check_idx].choices[0].logprobs.content[0].top_logprobs } + logit_sim = self.verify_logit_similarity(original_logits, verification_logits) scores_sim.append(logit_sim) logit_contains = self.verify_logit_contains( chunks[check_idx], original_logits, verification_logits ) + # if ( + # check_idx == 0 + # and logit_sim < VERIFICATION_THRESH_SIM + # and logit_contains < VERIFICATION_THRESH_CONTAINS + # ): + # raise ValueError("First token verification failed") + scores_contains.append(logit_contains) + if logit_sim < VERIFICATION_THRESH_SIM: + messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}] + logger.debug(f"Failed single logit: {logit_sim}") + logger.debug(f"Messages: {messages}") + logger.debug(f"sampling_parameters: {sampling_parameters}") + logger.debug(f"Verification: {verification_logits}") + logger.debug(f"Original: {original_logits}") + logger.debug("====================") elif check_idx == eos_idx and completion_length < max_tokens: if eos_token and eos_token not in verification_logits: @@ -181,9 +201,12 @@ async def reward( # noqa: C901 @staticmethod def sample_verification_indices(completion_length: int) -> list[int]: - """Sample random indices for verification, always add eos_token index.""" - num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - verify_indices = random.sample(range(completion_length), num_verify) + """Sample random indices for verification, always add 0 and eos_token index.""" + # Sample indices without first and last index. + num_verify = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - 2 + verify_indices = random.sample(range(1, completion_length - 1), num_verify) + # Add first index. + verify_indices.append(0) # Add eos_token index. verify_indices.append(completion_length) verify_indices.sort() diff --git a/prompting/tasks/task_registry.py b/prompting/tasks/task_registry.py index 879f07297..33239f5f3 100644 --- a/prompting/tasks/task_registry.py +++ b/prompting/tasks/task_registry.py @@ -33,27 +33,21 @@ def __hash__(self): class TaskRegistry(BaseModel): task_configs: ClassVar[list[TaskConfig]] = [ TaskConfig(task=MSRv2Task, probability=0.05, datasets=[DDGDataset], reward_model=MSRv2RewardConfig), - TaskConfig( - task=WebQuestionAnsweringTask, - probability=0.05, - datasets=[DDGDataset], - reward_model=QARewardConfig, - ), TaskConfig( task=InferenceTask, - probability=0.40, + probability=0.45, datasets=[SN13Dataset], reward_model=InferenceRewardConfig, ), TaskConfig( task=ProgrammingTask, - probability=0.20, + probability=0.10, datasets=[HuggingFaceGithubDataset], reward_model=ProgrammingRewardConfig, ), TaskConfig( task=WebRetrievalTask, - probability=0.3, + probability=0.40, datasets=[DDGDataset], reward_model=WebRetrievalRewardConfig, ), diff --git a/tests/prompting/rewards/test_exact_match.py b/tests/prompting/rewards/test_exact_match.py index eb670be25..8a40c5042 100644 --- a/tests/prompting/rewards/test_exact_match.py +++ b/tests/prompting/rewards/test_exact_match.py @@ -303,10 +303,10 @@ def test_fastest_timing_various_cases(values, expected): def test_sample_verification_indices_properties(completion_length): indices = LogitsRewardModel.sample_verification_indices(completion_length) - # Compute expected number of sampled tokens (before adding EOS) - expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) + 1 + # Compute expected number of sampled tokens with first and eos indices. + expected_k = int(np.clip(completion_length, 1, MAX_VERIFY_TOKENS)) - # The result should have expected_k samples plus one EOS index + # The result should have expected_k samples plus one EOS index. assert isinstance(indices, list) assert len(indices) == expected_k assert indices == sorted(indices) From 6f8c9558c96b7bf985e1d978ee6b3fdc60efe762 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Wed, 30 Apr 2025 00:47:04 +0000 Subject: [PATCH 2/5] Clean up code --- prompting/rewards/exact_match.py | 16 ---------------- prompting/tasks/task_registry.py | 1 - 2 files changed, 17 deletions(-) diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index 5d0667f66..f31978076 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -115,29 +115,13 @@ async def reward( # noqa: C901 for info in chunk_dicts_raw[check_idx].choices[0].logprobs.content[0].top_logprobs } - logit_sim = self.verify_logit_similarity(original_logits, verification_logits) scores_sim.append(logit_sim) logit_contains = self.verify_logit_contains( chunks[check_idx], original_logits, verification_logits ) - # if ( - # check_idx == 0 - # and logit_sim < VERIFICATION_THRESH_SIM - # and logit_contains < VERIFICATION_THRESH_CONTAINS - # ): - # raise ValueError("First token verification failed") - scores_contains.append(logit_contains) - if logit_sim < VERIFICATION_THRESH_SIM: - messages=task.task_messages + [{"role": "assistant", "content": "".join(chunks[:check_idx])}] - logger.debug(f"Failed single logit: {logit_sim}") - logger.debug(f"Messages: {messages}") - logger.debug(f"sampling_parameters: {sampling_parameters}") - logger.debug(f"Verification: {verification_logits}") - logger.debug(f"Original: {original_logits}") - logger.debug("====================") elif check_idx == eos_idx and completion_length < max_tokens: if eos_token and eos_token not in verification_logits: diff --git a/prompting/tasks/task_registry.py b/prompting/tasks/task_registry.py index 33239f5f3..c16a4c83b 100644 --- a/prompting/tasks/task_registry.py +++ b/prompting/tasks/task_registry.py @@ -13,7 +13,6 @@ from prompting.tasks.inference import InferenceRewardConfig, InferenceTask from prompting.tasks.MSRv2_task import MSRv2RewardConfig, MSRv2Task from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask -from prompting.tasks.qa import QARewardConfig, WebQuestionAnsweringTask from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask from shared.base import BaseDataset From c264b2751f916aa6620b87ff874ef86ee5daf36d Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Wed, 30 Apr 2025 01:24:45 +0000 Subject: [PATCH 3/5] Clean up code --- prompting/rewards/exact_match.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index f31978076..5a3e941a9 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -21,7 +21,7 @@ MIN_SMOOTH_PENALTY_SCALE = 0.6 MIN_TIME_PENALTY_SCALE = 0.3 VERIFICATION_THRESH_CONTAINS = 0.96 -VERIFICATION_THRESH_SIM = 0.88 +VERIFICATION_THRESH_SIM = 0.87 class LogitsRewardModel(BaseRewardModel): @@ -130,7 +130,6 @@ async def reward( # noqa: C901 score_sim_mean = float(np.mean(scores_sim)) score_contains_mean = float(np.mean(scores_contains)) - logger.debug(f"Scores: {score_sim_mean}; {score_contains_mean}") if score_sim_mean < VERIFICATION_THRESH_SIM: raise ValueError(f"Logits similarity mean score is below threshold: {score_sim_mean:.2f}") From 06ae7d93c06db8c5d66e09fc6a26119ecb8cb5c9 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Wed, 30 Apr 2025 01:51:40 +0000 Subject: [PATCH 4/5] Increase model rotation to 1h --- prompting/llms/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/llms/model_manager.py b/prompting/llms/model_manager.py index abff765e7..ed1fd107d 100644 --- a/prompting/llms/model_manager.py +++ b/prompting/llms/model_manager.py @@ -233,7 +233,7 @@ async def cleanup(self): class AsyncModelScheduler(AsyncLoopRunner): llm_model_manager: ModelManager mp_lock: AcquirerProxy - interval: int = 1200 + interval: int = 3600 scoring_queue: list | None = None memory_error: MemoryError | None = None From c9f778d4a7de620ad011f2b5640234b3aa392de7 Mon Sep 17 00:00:00 2001 From: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com> Date: Wed, 30 Apr 2025 19:22:37 +0000 Subject: [PATCH 5/5] Modify similarity threshold --- prompting/rewards/exact_match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index 5a3e941a9..b97314193 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -21,7 +21,7 @@ MIN_SMOOTH_PENALTY_SCALE = 0.6 MIN_TIME_PENALTY_SCALE = 0.3 VERIFICATION_THRESH_CONTAINS = 0.96 -VERIFICATION_THRESH_SIM = 0.87 +VERIFICATION_THRESH_SIM = 0.86 class LogitsRewardModel(BaseRewardModel):