diff --git a/neurons/validator.py b/neurons/validator.py
index beb381830..eff556605 100644
--- a/neurons/validator.py
+++ b/neurons/validator.py
@@ -1,10 +1,14 @@
 import asyncio
+import atexit
 import os
 import signal
 import sys
+import time
 from multiprocessing.managers import AcquirerProxy
+from multiprocessing.synchronize import Event

 import netaddr
+import psutil
 import requests
 import torch
 import torch.multiprocessing as mp
@@ -102,6 +106,7 @@ def start_api(
     scoring_queue: list,
     reward_events: list,
     miners_dict: dict,
+    event_stop: Event,
 ):
     from prompting.api.api import start_scoring_api  # noqa: F401

@@ -124,7 +129,7 @@ async def start():
             logger.warning(f"Failed to serve scoring api to chain: {e}")
         await start_scoring_api(task_scorer, scoring_queue, reward_events, miners_dict)

-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(10)

     asyncio.run(start())
@@ -134,6 +139,7 @@ def start_task_sending_loop(
     task_queue: list,
     scoring_queue: list,
     miners_dict: dict,
+    event_stop: Event,
 ):
     async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         from prompting.tasks.task_sending import TaskSender
@@ -142,7 +148,8 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         task_sender = TaskSender()
         asyncio.create_task(task_sender.start(task_queue, scoring_queue, miners_dict, simultaneous_loops=1))
         logger.debug("Task sending loop started")
-        while True:
+
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Task sending loop is running")

@@ -155,13 +162,13 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
         raise


-def start_availability_checking_loop(miners_dict: dict):
+def start_availability_checking_loop(miners_dict: dict, event_stop: Event):
     async def spawn_loops(miners_dict: dict):
         from prompting.miner_availability.miner_availability import availability_checking_loop

         logger.info("Starting availability checking loop in validator...")
         asyncio.create_task(availability_checking_loop.start(miners_dict))
-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Availability checking loop is running")

@@ -174,13 +181,13 @@ async def spawn_loops(miners_dict: dict):
         raise


-def start_weight_setter_loop(reward_events):
+def start_weight_setter_loop(reward_events, event_stop: Event):
     async def spawn_loops(reward_events):
         from prompting.weight_setting.weight_setter import weight_setter

         logger.info("Starting weight setter loop in validator...")
         asyncio.create_task(weight_setter.start(reward_events))
-        while True:
+        while not event_stop.is_set():
             await asyncio.sleep(5)
             logger.debug("Weight setter loop is running")

@@ -193,6 +200,34 @@ async def spawn_loops(reward_events):
         raise


+def health_check(parent_pid: int, event_stop: Event):
+    """Monitor parent process and kill all child processes in case of emergency."""
+    step = 0
+    while True:
+        try:
+            if not psutil.pid_exists(parent_pid):
+                event_stop.set()
+                logger.warning("Parent process died, killing all child processes")
+                os.killpg(0, signal.SIGKILL)
+
+            block = settings.shared_settings.block
+            if block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 320 and step > 60:
+                event_stop.set()
+                last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
+                logger.warning(
+                    f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
+                    f"Staled block: {block}, Last update: {last_update_block}"
+                )
+                os.killpg(0, signal.SIGKILL)
+            step += 1
+
+        except Exception as e:
+            logger.error(f"Failed to kill process group: {e}")
+        finally:
+            sys.exit(1)
+        time.sleep(60)
+
+
 async def main(
     cache_rewards: list | None = None,
     cache_scores: list | None = None,
@@ -208,6 +243,7 @@ async def main(
     mp_lock = manager.Lock()
     processes: list[mp.Process] = []
     tasks: list[asyncio.Task] = []
+    event_stop = mp.Event()

     model_scheduler = AsyncModelScheduler(llm_model_manager=ModelManager(), mp_lock=mp_lock, sync=True)

@@ -216,15 +252,19 @@ async def main(
         if settings.shared_settings.DEPLOY_SCORING_API and not settings.shared_settings.NEURON_DISABLE_SET_WEIGHTS:
             # Use multiprocessing to bypass API blocking issue
             api_process = mp.Process(
-                target=start_api, args=(scoring_queue, reward_events, miners_dict), name="APIProcess"
+                target=start_api,
+                args=(scoring_queue, reward_events, miners_dict, event_stop),
+                name="APIProcess",
+                daemon=True,
             )
             api_process.start()
             processes.append(api_process)

         availability_process = mp.Process(
             target=start_availability_checking_loop,
-            args=(miners_dict,),
+            args=(miners_dict, event_stop),
             name="AvailabilityProcess",
+            daemon=True,
         )
         availability_process.start()
         processes.append(availability_process)
@@ -243,62 +283,73 @@ async def main(

         sending_task = mp.Process(
             target=start_task_sending_loop,
-            args=(task_queue, scoring_queue, miners_dict),
+            args=(task_queue, scoring_queue, miners_dict, event_stop),
             name="SendingTaskProcess",
+            daemon=True,
         )
         sending_task.start()
         processes.append(sending_task)

         weight_setter_process = mp.Process(
             target=start_weight_setter_loop,
-            args=(reward_events,),
+            args=(reward_events, event_stop),
             name="WeightSetterProcess",
+            daemon=True,
         )
         weight_setter_process.start()
         processes.append(weight_setter_process)

-        GPUInfo.log_gpu_info()
+        health_check_process = mp.Process(
+            target=health_check,
+            args=(os.getpid(), event_stop),
+            name="HealthCheckProcess",
+            daemon=True,
+        )
+        health_check_process.start()
+        processes.append(health_check_process)

-        step = 0
+        GPUInfo.log_gpu_info()
         while True:
             await asyncio.sleep(30)
-            block = settings.shared_settings.block
-            if (
-                block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 500
-                and step > 150
-            ):
-                last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
-                logger.warning(
-                    f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
-                    f"Staled block: {block}, Last update: {last_update_block}"
-                )
-                break
-            step += 1

     except KeyboardInterrupt:
+        event_stop.set()
         logger.info("KeyboardInterrupt detected. Shutting down gracefully...")
     except Exception as e:
         logger.error(f"Main loop error: {e}")
         raise
     finally:
-        logger.warning("🚨 Force‑killing entire process‑group")
+        logger.warning("🚨 Force‑killing entire process‑group")

         # 1. Cancel in‑process tasks so they stop touching the Manager.
         for t in tasks:
             t.cancel()
         await asyncio.gather(*tasks, return_exceptions=True)
+        await asyncio.sleep(5)

         # 2. Manager cleanup *first* (so its socket vanishes).
         manager.shutdown()

         # 3. Sledgehammer.
-        if os.name == "posix":
+        try:
             os.killpg(0, signal.SIGKILL)
-        else:
-            logger.error(f"Unsupported OS: {os.name}")
+        except Exception as e:
+            logger.error(f"Failed to kill process group: {e}")
         sys.exit(1)


+def kill_process_group():
+    try:
+        os.killpg(os.getpgid(0), signal.SIGKILL)
+    except Exception as e:
+        logger.error(f"Failed to kill process group: {e}")
+
+
 # The main function parses the configuration and runs the validator.
 if __name__ == "__main__":
+    try:
+        os.setpgrp()
+        atexit.register(kill_process_group)
+    except BaseException:
+        logger.warning("Failed to set process group; emergency termination may not work.")
     asyncio.run(main())
- if os.name == "posix": + try: os.killpg(0, signal.SIGKILL) - else: - logger.error(f"Unsupported OS: {os.name}") + except Exception as e: + logger.error(f"Failed to kill process group: {e}") sys.exit(1) +def kill_process_group(): + try: + os.killpg(os.getpgid(0), signal.SIGKILL) + except Exception as e: + logger.error(f"Failed to kill process group: {e}") + + # The main function parses the configuration and runs the validator. if __name__ == "__main__": + try: + os.setpgrp() + atexit.register(kill_process_group) + except BaseException: + logger.warning("Failed to set process group; emergency termination may not work.") asyncio.run(main()) diff --git a/prompting/llms/model_manager.py b/prompting/llms/model_manager.py index cde70f52a..abff765e7 100644 --- a/prompting/llms/model_manager.py +++ b/prompting/llms/model_manager.py @@ -49,15 +49,12 @@ async def __aexit__(self, exc_type, exc, tb): class ModelManager(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - always_active_models: list[ModelConfig] = [] total_ram: float = settings.shared_settings.LLM_MODEL_RAM active_models: dict[ModelConfig, ReproducibleVLLM] = {} + loading_tasks: dict[ModelConfig, asyncio.Future] = {} used_ram: float = 0.0 lock: ClassVar[AsyncRLock] = AsyncRLock() - - async def load_always_active_models(self): - for model_config in self.always_active_models: - await self.load_model(model_config=model_config) + # lock: ClassVar[AsyncRLock] = asyncio.Lock() async def load_model(self, model_config: ModelConfig, force: bool = True) -> ReproducibleVLLM: """Load model into GPU. @@ -69,56 +66,40 @@ async def load_model(self, model_config: ModelConfig, force: bool = True) -> Rep force: If enabled, will unload all other models. """ async with self.lock: - if model_config in self.active_models.keys(): + # Copy active models, since they will be modified in the loop. + active_models = set(self.active_models.keys()) + + if model_config in active_models: logger.debug(f"Model {model_config.llm_model_id} is already loaded.") return self.active_models[model_config] if force: logger.debug(f"Forcing model {model_config.llm_model_id} to load.") - for active_model in list(self.active_models.keys()): - if active_model in self.always_active_models: - continue + for active_model in active_models: logger.debug(f"Unloading {active_model.llm_model_id} to make room for {model_config.llm_model_id}") - await self._unload_model(active_model) await self.cleanup() - retries_max = 1 - retry_counter = 0 - retry_delay = 15 - while True: - try: - GPUInfo.log_gpu_info() - model_class = model_factory(model_config.llm_model_id) - model = model_class( - model_id=model_config.llm_model_id, - device=settings.shared_settings.NEURON_DEVICE, - sampling_params=settings.shared_settings.SAMPLING_PARAMS, - ) - self.used_ram += model_config.min_ram - logger.info( - f"Model {model_config.llm_model_id} has been successfully loaded. " - f"Approx. used VRAM: {self.used_ram:.0f}GB" - ) - self.active_models[model_config] = model - await asyncio.sleep(1.0) - return model - except BaseException as e: - if retry_counter > retries_max: - logger.error(f"Failed to load model after {retries_max} retries. Terminating process") - await self.cleanup() - # In case of VRAM leak, raise an exception to terminate the process. - raise MemoryError - - retry_counter += 1 - retry_delay += retry_counter - await self.cleanup() - logger.error( - f"Failed to load model {model_config.llm_model_id}. Retrying in {retry_delay} seconds. 
" - f"Error: {str(e)}" - ) - logger.debug(f"Current active models: {self.active_models}") - await asyncio.sleep(retry_delay) + try: + GPUInfo.log_gpu_info() + model_class = model_factory(model_config.llm_model_id) + model = model_class( + model_id=model_config.llm_model_id, + device=settings.shared_settings.NEURON_DEVICE, + sampling_params=settings.shared_settings.SAMPLING_PARAMS, + ) + self.active_models[model_config] = model + self.used_ram += model_config.min_ram + logger.info( + f"Model {model_config.llm_model_id} has been successfully loaded. " + f"Approx. used VRAM: {self.used_ram:.0f}GB" + ) + await asyncio.sleep(1.0) + return model + except BaseException as e: + await self.cleanup() + # In case of VRAM leak, raise an exception to terminate the process. + raise MemoryError(f"Failed to load model {model_config.llm_model_id}: {e}") async def _cleanup_model(self, model_instance: ReproducibleVLLM, cpu_offload: bool = False): """Free VRAM from given model.""" @@ -144,12 +125,10 @@ async def _unload_model(self, model_config: ModelConfig): return try: - model_instance = self.active_models.pop(model_config) - - # Record initial memory state for debugging. initial_free_memory = GPUInfo.free_memory logger.debug(f"Initial free GPU memory before unloading: {initial_free_memory} GB") - + # async with self.rlock: + model_instance = self.active_models.pop(model_config) await self._cleanup_model(model_instance, cpu_offload=False) await self.cleanup() @@ -167,13 +146,13 @@ async def _unload_model(self, model_config: ModelConfig): async def get_model(self, llm_model: ModelConfig | str) -> ReproducibleVLLM: async with self.lock: if not llm_model: - llm_model = list(self.active_models.keys())[0] if self.active_models else ModelZoo.get_random() + llm_model = next(iter(self.active_models.keys())) if self.active_models else ModelZoo.get_random() if isinstance(llm_model, str): llm_model = ModelZoo.get_model_by_id(llm_model) if llm_model in self.active_models: return self.active_models[llm_model] - return await self.load_model(llm_model, force=True) + return await self.load_model(llm_model) async def generate( self, diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index 0442b6b59..a58fbd180 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -191,21 +191,17 @@ def unload_model(self): if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "driver_worker"): del self.model.llm_engine.driver_worker if hasattr(self.model, "model"): - self.model = None del self.model if hasattr(self.model, "tokenizer"): - self.tokenizer = None del self.tokenizer gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() - + if torch.cuda.is_available(): + torch.cuda.empty_cache() logger.info("Successfully deleted the LLM pipeline and freed GPU memory") - - except Exception as e: + except BaseException as e: logger.error(f"An error occurred during model unloading: {e}") gc.collect() if torch.cuda.is_available(): diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index 84be107da..46c7dd9c2 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -11,15 +11,31 @@ from shared.dendrite import DendriteResponseEvent shared_settings = settings.shared_settings -INCORRECT_PENALTY = 1 -INCOMPLETE_PENALTY = 1 +INCORRECT_PENALTY = 0.5 +INCOMPLETE_PENALTY = 0.25 +MIN_SMOOTH_REWARD = 0.6 VERIFICATION_RATIO = 0.1 
diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py
index 84be107da..46c7dd9c2 100644
--- a/prompting/rewards/exact_match.py
+++ b/prompting/rewards/exact_match.py
@@ -11,15 +11,31 @@
 from shared.dendrite import DendriteResponseEvent

 shared_settings = settings.shared_settings
-INCORRECT_PENALTY = 1
-INCOMPLETE_PENALTY = 1
+INCORRECT_PENALTY = 0.5
+INCOMPLETE_PENALTY = 0.25
+MIN_SMOOTH_REWARD = 0.6
 VERIFICATION_RATIO = 0.1
 VERIFICATION_THRESHOLD = 0.9


-def normalize_timing(timing: float, timings: float) -> float:
+def smooth_timings_reward(timings_uid: list[float], min_reward: float = MIN_SMOOTH_REWARD) -> float:
+    """Return a stream smoothness multiplier based on the deviation between chunk timings.
+
+    Args:
+        timings_uid: List of timings for a specific miner.
+
+    Returns:
+        float: Smoothness multiplier in [min_reward, 1.0], or 0.0 if no timings are provided.
     """
-    Normalize the timing so that a lower timing (i.e. faster response) is closer to 1.
+    if not timings_uid:
+        return 0.0
+
+    smooth_penalty = np.std(timings_uid)
+    return max(min_reward, 1.0 - smooth_penalty)
+
+
+def normalize_timing(timing: float, timings: float) -> float:
+    """Normalize the timing so that a lower timing (i.e. faster response) is closer to 1.

     Ensures the normalized value is between 0 and 1.
     """
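As a quick sanity check on the new smoothness term, the snippet below restates the definition from the hunk above and evaluates it on a few hand-picked timing vectors; the expected values follow directly from max(MIN_SMOOTH_REWARD, 1 - std(timings)).

    import numpy as np

    MIN_SMOOTH_REWARD = 0.6


    def smooth_timings_reward(timings, min_reward=MIN_SMOOTH_REWARD):
        # Mirror of the definition above, repeated only to keep the example self-contained.
        if not timings:
            return 0.0
        return max(min_reward, 1.0 - float(np.std(timings)))


    print(smooth_timings_reward([1.0, 1.0, 1.0]))   # 1.0   -> perfectly even stream
    print(smooth_timings_reward([0.9, 1.0, 1.1]))   # ~0.92 -> mild jitter, small discount
    print(smooth_timings_reward([0.1, 5.0, 10.0]))  # 0.6   -> bursty stream, floored at MIN_SMOOTH_REWARD
    print(smooth_timings_reward([]))                # 0.0   -> no chunks at all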
@@ -34,16 +50,15 @@ def normalize_timing(timing: float, timings: float) -> float:
     return min(1, max(0, (last_chunk - timing) / last_chunk))


-def verify_single_logit(original_logits, verification_logits):
-    """
-    Verify logits by computing cosine similarity between original and verification logits.
+def verify_single_logit(original_logits: dict[str, float], verification_logits: dict[str, float]) -> float:
+    """Verify logits by computing cosine similarity between original and verification logits.

     Args:
-        original_logits: Original model logits
-        verification_logits: Verification model logits
+        original_logits: Original model logits.
+        verification_logits: Verification model logits.

     Returns:
-        float: Cosine similarity score
+        float: Cosine similarity score.
     """
     # Create aligned vectors with same token ordering
     all_tokens = set(original_logits.keys()) | set(verification_logits.keys())
@@ -64,7 +79,7 @@ def verify_single_logit(original_logits: dict[str, float], verification_logits:
     # Calculate cosine similarity
     orig_vec = orig_vec / np.linalg.norm(orig_vec)
     verif_vec = verif_vec / np.linalg.norm(verif_vec)
-    return np.dot(orig_vec, verif_vec)
+    return float(np.dot(orig_vec, verif_vec))


 class LogitsRewardModel(BaseRewardModel):
@@ -76,11 +91,7 @@ class LogitsRewardModel(BaseRewardModel):
         model_manager: ModelManager,
         **kwargs,
     ) -> BatchRewardOutput:
-        """
-        Calculates rewards based on the logits of the response and verifies them.
-        """
-
-        # Check that the model_manager is set
+        """Calculate rewards based on the logits of the response and verify them."""
         if model_manager is None:
             raise ValueError("Model manager must be set")

@@ -96,52 +107,48 @@ async def reward(
             timings=np.array([0.0] * len(completions)),
         )

-        max_length = 0
-        for chunk in all_chunks:
-            if chunk and max_length < len(chunk):
-                max_length = len(chunk)
-
-        if max_length == 0:
-            logger.debug("No chunks to verify, penalizing all")
+        if all(not chunk for chunk in all_chunks):
+            logger.warning("No chunks to verify, penalizing all miners")
             return PENALIZE_ALL

         if timeout <= 0:
-            logger.error("Timeout must be greater than 0. Received timeout: {}", timeout)
+            logger.error(f"Timeout must be greater than 0. Received timeout: {timeout}")
             raise ValueError("Timeout must be greater than 0.")

-        timing_outputs, rewards = [], []
-        num_verify = max(1, int(max_length * VERIFICATION_RATIO))
-        verify_indices = random.sample(
-            range(max_length - 1), num_verify - 1
-        )  # Sample one less to save room for last index
-        verify_indices.append(max_length - 1)  # Always verify the last index
-        verify_indices.sort()
-
-        # Iterate over each response event
-
+        # If max_tokens is not provided, always check for EOS.
+        max_tokens = sampling_parameters.get("max_tokens", float("inf"))
+        model = await model_manager.get_model(task.llm_model_id)
+        eos_token = model.tokenizer.eos_token
+        timing_outputs = []
+        rewards = []
+        # Iterate over each miner response.
         for chunks, timings, chunk_dicts_raw, uid in zip(all_chunks, all_timings, all_chunk_dicts_raw, uids):
             try:
-                # If no response is provided, apply full penalty
+                # If no response is provided, apply full penalty.
                 if not chunks:
                     rewards.append(-INCORRECT_PENALTY)
                     timing_outputs.append(0.0)
                     continue

-                # Verify logits for selected indices
-                verification_scores = []
                 completion_length = len(chunks)
+                # Sample from 1 to 20 indices for verification.
+                num_verify = max(1, min(20, int(completion_length * VERIFICATION_RATIO)))
+                # Sample one less to save room for last index.
+                verify_indices = random.sample(range(completion_length - 1), num_verify - 1)
+                # Always verify the last index.
+                last_idx = completion_length - 1
+                verify_indices.append(last_idx)
+                verify_indices.sort()
+                # Verify logits for selected indices.
+                verification_scores = []
                 for idx in verify_indices:
                     check_idx = min(idx, completion_length - 1)
                     if not chunk_dicts_raw[check_idx].choices[0].logprobs:
-                        logger.debug(f"Miner {uid} failed to provide logprobs: {chunk_dicts_raw[check_idx]}")
-                        verification_scores.append(0.0)
-                        continue
+                        raise ValueError("No logprobs provided")

                     if chunk_dicts_raw[check_idx].choices[0].logprobs.content is None:
-                        logger.debug(f"Miner {uid} failed to provide logprobs content: {chunk_dicts_raw[check_idx]}")
-                        verification_scores.append(0.0)
-                        continue
+                        raise ValueError("Logprobs content is empty")

                     original_logits = {
                         info.token: info.logprob
@@ -157,28 +164,36 @@ async def reward(
                     logit_score = verify_single_logit(original_logits, verification_output)
                     verification_scores.append(logit_score)

-                    if idx >= completion_length:
-                        break
-                final_score = np.mean(verification_scores)
-
-                # Compute timing reward
-                valid_chunks = []
+                    if idx == last_idx and completion_length < max_tokens:
+                        if eos_token and (eos_token not in original_logits or eos_token not in verification_output):
+                            # Do not set full penalty, since top_k = 50 and top_logprobs = 10.
+                            # TODO: Make top_k equal to top_logprobs and check for token in top_logprobs.
+                            verification_scores = [-INCOMPLETE_PENALTY]
+
+                final_score = float(np.mean(verification_scores))
+                if final_score < VERIFICATION_THRESHOLD:
+                    rewards.append(0.0)
+                    timing_outputs.append(0.0)
+                    continue
+
+                valid_chunks: list[float] = []
                 for chunk, timing in zip(chunks, timings):
                     if chunk:
                         valid_chunks.append(normalize_timing(timing, all_timings))
+                timing_reward = float(np.mean(valid_chunks)) if valid_chunks else 0.0
+                smooth_reward = smooth_timings_reward(timings)

-                timing_reward = np.mean(valid_chunks) if valid_chunks else 0.0
-
-                rewards.append(float(final_score > VERIFICATION_THRESHOLD) * timing_reward)
+                rewards.append(final_score * timing_reward * smooth_reward)
                 timing_outputs.append(np.array(valid_chunks).mean())
-            except Exception as e:
-                logger.debug(f"Miner {uid} failed to provide logits chunk, setting reward to 0: {e}")
-                rewards.append(0.0)
+            except BaseException as e:
+                logger.debug(f"Miner {uid} failed to pass logits check: {e}")
+                rewards.append(-INCORRECT_PENALTY)
                 timing_outputs.append(0.0)

         reward_output = BatchRewardOutput(
             rewards=np.array(rewards),
             timings=np.array(timing_outputs),
         )
-        logger.debug(f"REWARD OUTPUT: {reward_output.model_dump()}")
+        logger.debug(f"Logits rewards: {reward_output.model_dump()}")
         return reward_output
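To make the new scoring path concrete, here is a hedged, self-contained illustration of how a single miner's reward is assembled from the pieces above: the mean logit-verification score gates the reward, then the timing reward and the smoothness multiplier scale it. All numbers are invented for the example.

    import numpy as np

    VERIFICATION_THRESHOLD = 0.9

    verification_scores = [0.97, 0.95, 0.99]           # cosine similarities from verify_single_logit
    final_score = float(np.mean(verification_scores))  # 0.97
    timing_reward = 0.8                                # mean normalize_timing over valid chunks
    smooth_reward = 0.92                               # smooth_timings_reward over chunk timings

    if final_score < VERIFICATION_THRESHOLD:
        reward = 0.0  # failed verification earns nothing, regardless of speed
    else:
        reward = final_score * timing_reward * smooth_reward
    print(round(reward, 3))  # 0.714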
diff --git a/pyproject.toml b/pyproject.toml
index 0038cd5cd..8c0e38836 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "prompting"
-version = "2.18.2"
+version = "2.18.3"
 description = "Subnetwork 1 runs on Bittensor and is maintained by Macrocosmos. It's an effort to create decentralised AI"
 authors = ["Kalei Brady, Dmytro Bobrenko, Felix Quinque, Steffen Cruz, Richard Wardle"]
 readme = "README.md"
diff --git a/tests/prompting/rewards/test_exact_match.py b/tests/prompting/rewards/test_exact_match.py
new file mode 100644
index 000000000..72678501d
--- /dev/null
+++ b/tests/prompting/rewards/test_exact_match.py
@@ -0,0 +1,250 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import numpy as np
+import pytest
+from openai.types.chat import ChatCompletionChunk
+
+from prompting.llms.model_manager import ModelManager
+from prompting.rewards.exact_match import (
+    INCORRECT_PENALTY,
+    MIN_SMOOTH_REWARD,
+    VERIFICATION_THRESHOLD,
+    LogitsRewardModel,
+    smooth_timings_reward,
+    verify_single_logit,
+)
+from prompting.rewards.reward import BatchRewardOutput
+from prompting.tasks.base_task import BaseTextTask
+from shared.dendrite import DendriteResponseEvent
+
+
+@pytest.fixture
+def model_manager():
+    """Mock ModelManager for testing."""
+    manager = MagicMock(spec=ModelManager)
+    model = MagicMock()
+    tokenizer = MagicMock()
+    tokenizer.eos_token = "<|endoftext|>"
+    model.tokenizer = tokenizer
+    manager.get_model.return_value = model
+
+    async def mock_generate_logits(*args, **kwargs):
+        return {"token1": -0.1, "token2": -0.5, "<|endoftext|>": -1.0}, "prompt"
+
+    manager.generate_logits = AsyncMock(side_effect=mock_generate_logits)
+    return manager
+
+
+@pytest.fixture
+def task():
+    """Mock Task for testing."""
+    task = MagicMock(spec=BaseTextTask)
+    task.llm_model_id = "gpt-4"
+    task.task_messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Tell me a joke."},
+    ]
+    task.sampling_params = {"temperature": 0.7, "max_tokens": 100}
+    return task
+
+
+def create_chat_completion_chunk(content="", logprobs=None):
+    """Helper function to create a ChatCompletionChunk object."""
+    if logprobs is None:
+        logprobs = {"token1": -0.1, "token2": -0.5, "token3": -0.6, "token4": -0.7, "<|endoftext|>": -1.0}
+
+    chunk = MagicMock(spec=ChatCompletionChunk)
+    choice = MagicMock()
+    choice.index = 0
+    choice.delta = MagicMock()
+    choice.delta.role = "assistant"
+    choice.delta.content = content
+
+    if logprobs:
+        choice.logprobs = MagicMock()
+        choice.logprobs.content = [MagicMock()]
+        choice.logprobs.content[0].top_logprobs = []
+        for token, logprob in logprobs.items():
+            token_logprob = MagicMock()
+            token_logprob.token = token
+            token_logprob.logprob = logprob
+            choice.logprobs.content[0].top_logprobs.append(token_logprob)
+    else:
+        choice.logprobs = None
+
+    chunk.choices = [choice]
+    chunk.id = "chunk_id"
+    chunk.created = 1234567890
+    chunk.model = "VeryStronkModel"
+    chunk.object = "chat.completion.chunk"
+    chunk.usage = None
+    return chunk
+
+
+@pytest.mark.asyncio
+async def test_ideal_completion(model_manager, task):
+    """Test case 1: Ideal completion receives a positive reward in the expected range."""
+    chunks = [["Hello", ", ", "world", "!"]]
+    chunk_dicts_raw = [
+        [
+            create_chat_completion_chunk("Hello"),
+            create_chat_completion_chunk(", "),
+            create_chat_completion_chunk("world"),
+            create_chat_completion_chunk("!"),
+        ]
+    ]
+
+    with patch("prompting.rewards.exact_match.verify_single_logit", return_value=0.95):
+        response_event = MagicMock(spec=DendriteResponseEvent)
+        response_event.stream_results_all_chunks = chunks
+        response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
+        response_event.uids = [1]
+        response_event.stream_results_all_chunks_timings = [[0.1, 0.2, 0.3, 0.4]]
+        response_event.completions = ["Hello, world!"]
+        response_event.timeout = 10.0
+
+        reward_model = LogitsRewardModel()
+        result = await reward_model.reward(
+            reference="", response_event=response_event, task=task, model_manager=model_manager
+        )
+
+        assert isinstance(result, BatchRewardOutput)
+        assert len(result.rewards) == 1
+        assert 0.2 < result.rewards[0] <= 0.4
+
+
+@pytest.mark.asyncio
+async def test_mixed_completions(model_manager, task):
+    """Test case 2: One ideal completion, one with mismatched logprobs, and one with missing logprobs."""
+    chunks = [["Hello", ", ", "world", "!"], ["Fail", "ed", " ", "completion"], ["Wro", "ng", " ", "completion"]]
+    correct_logprobs = []
+    for part in chunks[0]:
+        correct_logprobs.append(create_chat_completion_chunk(part))
+
+    incorrect_logprobs = []
+    wrong_logprobs = {"wrong": -0.1, "log": -5.43, "prob": -8.54, "defined": -11, "<|endoftext|>": -3000000}
+    for part in chunks[1]:
+        incorrect_logprobs.append(create_chat_completion_chunk(part, logprobs=wrong_logprobs))
+    empty_logprobs = []
+    for part in chunks[1]:
+        empty_logprobs.append(create_chat_completion_chunk(part, logprobs={}))
+    chunk_dicts_raw = [correct_logprobs, incorrect_logprobs, empty_logprobs]
+
+    # Mock verify_single_logit to return different values based on input
+    def mock_verify(original_logits, verification_logits):
+        # Check if this is the incorrect logprobs case
+        if "wrong" in original_logits:
+            return VERIFICATION_THRESHOLD * 0.9
+        else:
+            return VERIFICATION_THRESHOLD * 1.1
+
+    with patch("prompting.rewards.exact_match.verify_single_logit", side_effect=mock_verify):
+        response_event = MagicMock(spec=DendriteResponseEvent)
+        response_event.stream_results_all_chunks = chunks
+        response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
+        response_event.uids = [1, 2, 3]
+        response_event.stream_results_all_chunks_timings = [
+            [0.1, 0.2, 0.3, 0.4],
+            [0.1, 0.2, 0.3, 0.4],
+            [0.1, 0.2, 0.3, 0.4],
+        ]
+        response_event.completions = ["Hello, world!", "Missing logprobs", "Empty logprobs"]
+        response_event.timeout = 10.0
+
+        reward_model = LogitsRewardModel()
+        result = await reward_model.reward(
+            reference="", response_event=response_event, task=task, model_manager=model_manager
+        )
+
+        assert isinstance(result, BatchRewardOutput)
+        assert len(result.rewards) == 3
+        assert 0.2 < result.rewards[0] <= 0.5
+        assert result.rewards[1] == 0
+        assert result.rewards[2] == -INCORRECT_PENALTY
+
+
+@pytest.mark.asyncio
+async def test_no_eos_token(model_manager, task):
+    """Test case 3: Missing eos_token in logits → zero reward."""
+    chunks = [["Hello", ", ", "world", "!"]]
+    chunk_dicts_raw = [
+        [
+            create_chat_completion_chunk("Hello"),
+            create_chat_completion_chunk(", "),
+            create_chat_completion_chunk("world"),
+            create_chat_completion_chunk("!"),
+        ]
+    ]
+
+    async def mock_generate_logits_no_eos(*args, **kwargs):
+        return {"token1": -0.1, "token2": -0.5}, "prompt"
+
+    model_manager.generate_logits = AsyncMock(side_effect=mock_generate_logits_no_eos)
+
+    # Replace last chunk without eos in its logprobs
+    chunk_dicts_raw[0][3] = create_chat_completion_chunk("!", {"token1": -0.1, "token2": -0.5})
+
+    with patch("prompting.rewards.exact_match.verify_single_logit", return_value=0.95):
+        response_event = MagicMock(spec=DendriteResponseEvent)
+        response_event.stream_results_all_chunks = chunks
+        response_event.stream_results_all_chunk_dicts_raw = chunk_dicts_raw
+        response_event.uids = [1]
+        response_event.stream_results_all_chunks_timings = [[0.1, 0.2, 0.3, 0.4]]
+        response_event.completions = ["Hello, world!"]
+        response_event.timeout = 10.0
+
+        reward_model = LogitsRewardModel()
+        result = await reward_model.reward(
+            reference="", response_event=response_event, task=task, model_manager=model_manager
+        )
+
+        assert isinstance(result, BatchRewardOutput)
+        assert len(result.rewards) == 1
+        assert result.rewards[0] == 0.0
+
+
+def test_verify_single_logit():
+    """Test the verify_single_logit similarity metric."""
+    original = {"token1": -0.1, "token2": -0.5}
+    # Identical distributions → 1.0
+    assert verify_single_logit(original, original) == 1.0
+
+    # Disjoint tokens → near zero
+    disjoint = {"foo": -0.1, "bar": -0.5}
+    sim = verify_single_logit(original, disjoint)
+    assert 0 <= sim <= 0.01
+
+    # Partial overlap → between 0 and 1
+    partial = {"token1": -0.1, "foo": -0.5}
+    sim2 = verify_single_logit(original, partial)
+    assert 0 < sim2 < 1.0
+
+
+def test_smooth_reward_scale():
+    """Test the smooth_timings_reward function under various conditions."""
+    # Test empty timings list.
+    assert smooth_timings_reward([]) == 0.0
+
+    # Test uniform timings (should give maximum reward).
+    uniform_timings = [1.0, 1.0, 1.0, 1.0, 1.0]
+    assert smooth_timings_reward(uniform_timings) == 1.0
+
+    # Test high variance timings (should give minimum reward).
+    high_var_timings = [0.1, 5.0, 10.0, 0.5, 8.0]
+    std_dev = np.std(high_var_timings)
+    assert smooth_timings_reward(high_var_timings) == MIN_SMOOTH_REWARD
+    assert 1.0 - std_dev < MIN_SMOOTH_REWARD
+
+    # Test moderate variance timings
+    moderate_var_timings = [0.9, 1.0, 1.1, 0.95, 1.05]
+    expected = max(MIN_SMOOTH_REWARD, 1.0 - np.std(moderate_var_timings))
+    assert smooth_timings_reward(moderate_var_timings) == pytest.approx(expected)
+    assert MIN_SMOOTH_REWARD < smooth_timings_reward(moderate_var_timings) < 1.0
+
+    # Test with custom minimum reward.
+    custom_min = 0.8
+    assert smooth_timings_reward(high_var_timings, min_reward=custom_min) == custom_min
+
+    # Test with single timing value.
+    single_timing = [1.5]
+    assert smooth_timings_reward(single_timing) == 1.0
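Assuming pytest and pytest-asyncio are available in the project's development environment (the module relies on @pytest.mark.asyncio), the new suite can be run on its own, for example programmatically:

    import pytest

    raise SystemExit(pytest.main(["-v", "tests/prompting/rewards/test_exact_match.py"]))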