diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index b97314193..2c82911c1 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -18,6 +18,7 @@ MAX_VERIFY_TOKENS = 30 NO_EOS_PENALTY = -0.1 INCORRECT_PENALTY = -2 +NOT_ENOUGH_TOKENS_PENALTY_SCALE = 0.1 MIN_SMOOTH_PENALTY_SCALE = 0.6 MIN_TIME_PENALTY_SCALE = 0.3 VERIFICATION_THRESH_CONTAINS = 0.96 @@ -67,6 +68,7 @@ async def reward( # noqa: C901 # Iterate over each miner response. for chunks, timings, chunk_dicts_raw, uid in zip(all_chunks, all_timings, all_chunk_dicts_raw, uids): penalty = INCORRECT_PENALTY + reward_scale = 1.0 try: # If no response is provided, apply full penalty. if not chunks: @@ -82,10 +84,8 @@ async def reward( # noqa: C901 continue if completion_length < MIN_VERIFY_TOKENS: - # Not enough tokens to verify, set reward to 0. - rewards.append(0) - timing_verified.append([-1.0]) - continue + # Not enough tokens to verify, still proceed to verification with scaled reward if checks will pass. + reward_scale = NOT_ENOUGH_TOKENS_PENALTY_SCALE eos_idx = completion_length verify_indices = self.sample_verification_indices(completion_length) @@ -142,8 +142,7 @@ async def reward( # noqa: C901 # Min-max scale logits reward, e.g from [0.95; 1.0] to [0.0, 1.0]. score_sim_mean = self.rescale(score_sim_mean, min_value=VERIFICATION_THRESH_SIM) score_contains_mean = self.rescale(score_contains_mean, min_value=VERIFICATION_THRESH_CONTAINS) - logits_score = (score_sim_mean + score_contains_mean) / 2 - rewards.append(logits_score * smooth_reward) + rewards.append(score_sim_mean * score_contains_mean * smooth_reward * reward_scale) except BaseException as e: logger.debug(f"Miner {uid} failed to pass logits check: {e}") rewards.append(penalty)