diff --git a/prompting/__init__.py b/prompting/__init__.py
index 476009f42..643f4f1ac 100644
--- a/prompting/__init__.py
+++ b/prompting/__init__.py
@@ -16,7 +16,7 @@
 # DEALINGS IN THE SOFTWARE.
 
 # Define the version of the template module.
-__version__ = "2.3.1"
+__version__ = "2.4.0"
 version_split = __version__.split(".")
 __spec_version__ = (
     (10000 * int(version_split[0]))
diff --git a/prompting/dendrite.py b/prompting/dendrite.py
index a9df0aa8c..7b9981a94 100644
--- a/prompting/dendrite.py
+++ b/prompting/dendrite.py
@@ -1,45 +1,59 @@
 import torch
-import bittensor as bt
 from typing import List
+from dataclasses import dataclass
+from prompting.protocol import StreamPromptingSynapse
+from prompting.utils.misc import serialize_exception_to_string
+
+
+@dataclass
+class SynapseStreamResult:
+    exception: BaseException = None
+    uid: int = None
+    accumulated_chunks: List[str] = None
+    accumulated_chunks_timings: List[float] = None
+    tokens_per_chunk: List[int] = None
+    synapse: StreamPromptingSynapse = None
 
 
 class DendriteResponseEvent:
     def __init__(
-        self, responses: List[bt.Synapse], uids: torch.LongTensor, timeout: float
+        self, stream_results: List[SynapseStreamResult], uids: torch.LongTensor, timeout: float
     ):
         self.uids = uids
         self.completions = []
         self.status_messages = []
         self.status_codes = []
         self.timings = []
+        self.stream_results_uids = []
+        self.stream_results_exceptions = []
+        self.stream_results_all_chunks = []
+        self.stream_results_all_chunks_timings = []
+        self.stream_results_all_tokens_per_chunk = []
+
+        for stream_result in stream_results:
+            synapse = stream_result.synapse
 
-        for synapse in responses:
             self.completions.append(synapse.completion)
             self.status_messages.append(synapse.dendrite.status_message)
+            status_code = synapse.dendrite.status_code
 
-            if len(synapse.completion) == 0 and synapse.dendrite.status_code == 200:
-                synapse.dendrite.status_code = 204
+            if len(synapse.completion) == 0 and status_code == 200:
+                status_code = 204
 
-            self.status_codes.append(synapse.dendrite.status_code)
-
-            if (synapse.dendrite.process_time) and (
-                synapse.dendrite.status_code == 200
-                or synapse.dendrite.status_code == 204
-            ):
-                self.timings.append(synapse.dendrite.process_time)
-            elif synapse.dendrite.status_code == 408:
+            self.status_codes.append(status_code)
+            process_time = synapse.dendrite.process_time or 0
+            if status_code == 200 or status_code == 204:
+                self.timings.append(process_time)
+            elif status_code == 408:
                 self.timings.append(timeout)
             else:
-                self.timings.append(0)  # situation where miner is not alive
+                self.timings.append(0)
 
-        self.completions = [synapse.completion for synapse in responses]
-        self.timings = [
-            synapse.dendrite.process_time or timeout for synapse in responses
-        ]
-        self.status_messages = [
-            synapse.dendrite.status_message for synapse in responses
-        ]
-        self.status_codes = [synapse.dendrite.status_code for synapse in responses]
+            self.stream_results_uids.append(stream_result.uid)
+            self.stream_results_exceptions.append(serialize_exception_to_string(stream_result.exception))
+            self.stream_results_all_chunks.append(stream_result.accumulated_chunks)
+            self.stream_results_all_chunks_timings.append(stream_result.accumulated_chunks_timings)
+            self.stream_results_all_tokens_per_chunk.append(stream_result.tokens_per_chunk)
 
     def __state_dict__(self):
         return {
@@ -48,6 +62,11 @@ def __state_dict__(self):
             "timings": self.timings,
             "status_messages": self.status_messages,
             "status_codes": self.status_codes,
+            "stream_results_uids": self.stream_results_uids,
"stream_results_exceptions": self.stream_results_exceptions, + "stream_results_all_chunks": self.stream_results_all_chunks, + "stream_results_all_chunks_timings": self.stream_results_all_chunks_timings, + "stream_results_all_tokens_per_chunk": self.stream_results_all_tokens_per_chunk, } def __repr__(self): diff --git a/prompting/forward.py b/prompting/forward.py index 170f85254..46d7b9836 100644 --- a/prompting/forward.py +++ b/prompting/forward.py @@ -13,8 +13,7 @@ # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN -# THE SOFTWARE. +# DEALINGS IN THE SOFTWARE. import sys import time @@ -25,7 +24,7 @@ import bittensor as bt from typing import List, Dict, Awaitable from prompting.agent import HumanAgent -from prompting.dendrite import DendriteResponseEvent +from prompting.dendrite import DendriteResponseEvent, SynapseStreamResult from prompting.conversation import create_task from prompting.protocol import StreamPromptingSynapse from prompting.rewards import RewardResult @@ -33,6 +32,8 @@ from prompting.utils.uids import get_random_uids from prompting.utils.logging import log_event from prompting.utils.misc import async_log, serialize_exception_to_string +from transformers import PreTrainedTokenizerFast as Tokenizer +from prompting.utils.uids import get_random_uids from dataclasses import dataclass @async_log @@ -46,32 +47,35 @@ async def execute_dendrite_call(dendrite_call): responses = await dendrite_call return responses -@dataclass -class StreamResult: - synapse: StreamPromptingSynapse = None - exception: BaseException = None - uid: int = None - -async def process_response(uid: int, async_generator: Awaitable): +async def process_stream(uid: int, async_iterator: Awaitable, tokenizer: Tokenizer) -> SynapseStreamResult: """Process a single response asynchronously.""" - try: - chunk = None # Initialize chunk with a default value - async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse. - bt.logging.debug(f"\nchunk for uid {uid}: {chunk}") - - if chunk is not None: - synapse = chunk # last object yielded is the synapse itself with completion filled - - # Assuming chunk holds the last value yielded which should be a synapse - if isinstance(synapse, StreamPromptingSynapse): - return synapse - - bt.logging.debug( - f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' " - ) - except Exception as e: - # bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True) + synapse = None # Initialize chunk with a default value + exception = None + accumulated_chunks = [] + accumulated_chunks_timings = [] + accumulated_tokens_per_chunk = [] + start_time = time.time() + + try: + async for chunk in async_iterator: # most important loop, as this is where we acquire the final synapse. 
+            if isinstance(chunk, str):
+                accumulated_chunks.append(chunk)
+                accumulated_chunks_timings.append(time.time() - start_time)
+
+                tokens_in_chunk = len(tokenizer.tokenize(chunk))
+                accumulated_tokens_per_chunk.append(tokens_in_chunk)
+
+                bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")
+
+        # The last object yielded by the async iterator should be the synapse itself, with the completion filled
+        synapse = chunk
+        if synapse is None or not isinstance(synapse, StreamPromptingSynapse):
+            raise ValueError(
+                f"Something went wrong with miner uid {uid}: the final chunk is not a StreamPromptingSynapse."
+            )
+    except Exception as e:
+        exception = e
         traceback_details = traceback.format_exc()
         bt.logging.error(
             f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
         )
@@ -81,11 +85,20 @@ async def process_response(uid: int, async_generator: Awaitable):
             roles=["user"], messages=["failure"], completion=""
         )
 
-        return failed_synapse
+        synapse = failed_synapse
+    finally:
+        return SynapseStreamResult(
+            accumulated_chunks=accumulated_chunks,
+            accumulated_chunks_timings=accumulated_chunks_timings,
+            tokens_per_chunk=accumulated_tokens_per_chunk,
+            synapse=synapse,
+            uid=uid,
+            exception=exception
+        )
 
 
 @async_log
-async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
+async def handle_response(stream_results_dict: Dict[int, Awaitable], tokenizer: Tokenizer) -> List[SynapseStreamResult]:
     """The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
     and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.
@@ -99,36 +112,14 @@ async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]
         List[StreamResult]: DataClass containing the synapse, exception, and uid
     """
     tasks_with_uid = [
-        (uid, responses[uid]) for uid, _ in responses.items()
+        (uid, stream_results_dict[uid]) for uid, _ in stream_results_dict.items()
     ]  # Pair UIDs with their tasks
 
     # Start tasks, preserving order and their associated UIDs
-    tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]
-
-    results = await asyncio.gather(*tasks, return_exceptions=True)
-
-    mapped_results = []
-    # Pair each result with its original uid
-    for (uid, _), result in zip(tasks_with_uid, results):
-        # If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
-        if isinstance(result, StreamPromptingSynapse):
-            mapped_results.append(StreamResult(synapse=result, uid=uid))
-
-        # If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
-        elif isinstance(result, BaseException):
-            failed_synapse = StreamPromptingSynapse(
-                roles=["user"], messages=["failure"], completion=""
-            )
-            mapped_results.append(
-                StreamResult(synapse=failed_synapse, exception=result, uid=uid)
-            )
-
-        # If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
-        else:
-            bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
-            raise ValueError(f"Unexpected result type for UID {uid}: {result}")
-
-    return mapped_results
+    process_stream_tasks = [process_stream(uid, resp, tokenizer) for uid, resp in tasks_with_uid]
+    processed_stream_results = await asyncio.gather(*process_stream_tasks, return_exceptions=True)
+
+    return processed_stream_results
 
 
 @async_log
@@ -140,7 +131,7 @@ async def generate_reference(agent: HumanAgent):
     return result
 
 
-def log_stream_results(stream_results: List[StreamResult]):
+def log_stream_results(stream_results: List[SynapseStreamResult]):
     failed_responses = [
         response for response in stream_results if response.exception is not None
     ]
@@ -186,7 +177,6 @@ async def run_step(
         timeout (float): The timeout for the queries.
         exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
     """
-    bt.logging.debug("run_step", agent.task.name)
 
     # Record event start time.
@@ -207,9 +197,9 @@ async def run_step(
     )
 
     # Prepare the task for handling stream responses
-    handle_stream_responses_task = asyncio.create_task(
-        handle_response(responses=dict(zip(uids_cpu, streams_responses)))
-    )
+    stream_results_dict = dict(zip(uids_cpu, streams_responses))
+    tokenizer = self.llm_pipeline.tokenizer
+    handle_stream_responses_task = asyncio.create_task(handle_response(stream_results_dict, tokenizer))
 
     if not agent.task.static_reference:
         reference_generation_task = generate_reference(agent)
@@ -221,11 +211,9 @@ async def run_step(
 
     log_stream_results(stream_results)
 
-    all_synapses_results = [stream_result.synapse for stream_result in stream_results]
-
     # Encapsulate the responses in a response event (dataclass)
     response_event = DendriteResponseEvent(
-        responses=all_synapses_results, uids=uids, timeout=timeout
+        stream_results=stream_results, uids=uids, timeout=timeout
     )
 
     bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
@@ -248,20 +236,13 @@ async def run_step(
     )
 
     self.update_scores(reward_result.rewards, uids)
-
-    stream_results_uids = [stream_result.uid for stream_result in stream_results]
-    stream_results_exceptions = [
-        serialize_exception_to_string(stream_result.exception)
-        for stream_result in stream_results
-    ]
+
     # Log the step event.
     event = {
         "best": best_response,
         "block": self.block,
         "step": self.step,
-        "step_time": time.time() - start_time,
-        "stream_results_uids": stream_results_uids,
-        "stream_results_exceptions": stream_results_exceptions,
+        "step_time": time.time() - start_time,
         **agent.__state_dict__(full=self.config.neuron.log_full),
         **reward_result.__state_dict__(full=self.config.neuron.log_full),
         **response_event.__state_dict__(),
diff --git a/prompting/llms/base_llm.py b/prompting/llms/base_llm.py
index 0ed7b1398..e5df5d319 100644
--- a/prompting/llms/base_llm.py
+++ b/prompting/llms/base_llm.py
@@ -22,6 +22,7 @@ def __init__(
         self.model_kwargs = model_kwargs
         self.messages = []
         self.times = []
+        self.tokenizer = None
 
     def query(
         self,
diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py
index 58fca8af0..a8bf9e095 100644
--- a/prompting/llms/vllm_llm.py
+++ b/prompting/llms/vllm_llm.py
@@ -67,6 +67,7 @@ def __init__(
         self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
         self.mock = mock
         self.gpus = gpus
+        self.tokenizer = self.llm.llm_engine.tokenizer
 
     def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
         if self.mock:
diff --git a/prompting/mock.py b/prompting/mock.py
index e6058ddff..cc8ea8167 100644
--- a/prompting/mock.py
+++ b/prompting/mock.py
@@ -3,10 +3,11 @@
 import asyncio
 import random
 import bittensor as bt
-from prompting.protocol import StreamPromptingSynapse, PromptingSynapse
+from prompting.protocol import StreamPromptingSynapse
 from functools import partial
 from typing import Dict, List, Union, AsyncGenerator, Any, Iterator
+from types import SimpleNamespace
 
 
 class MockTokenizer:
@@ -44,6 +45,12 @@ class MockPipeline:
     @property
     def tokenizer(self):
         return self.model.tokenizer
+
+    @property
+    def llm_engine(self):
+        return SimpleNamespace(
+            tokenizer=self.model.tokenizer
+        )
 
     def __init__(
         self,
@@ -287,15 +294,10 @@ async def forward(
         deserialize: bool = True,
         run_async: bool = True,
         streaming: bool = False,
-    ):
-        if streaming:
-            assert isinstance(
-                synapse, StreamPromptingSynapse
-            ), "Synapse must be a StreamPromptingSynapse object when is_stream is True."
-        else:
-            assert isinstance(
-                synapse, PromptingSynapse
-            ), "Synapse must be a PromptingSynapse object when is_stream is False."
+    ):
+        assert isinstance(
+            synapse, StreamPromptingSynapse
+        ), "Synapse must be a StreamPromptingSynapse object."
 
         async def query_all_axons(is_stream: bool):
             """Queries all axons for responses."""
diff --git a/prompting/protocol.py b/prompting/protocol.py
index 1367a4f45..5b8ad01c7 100644
--- a/prompting/protocol.py
+++ b/prompting/protocol.py
@@ -17,116 +17,9 @@
 
 import pydantic
 import bittensor as bt
-
 from typing import List, AsyncIterator
 from starlette.responses import StreamingResponse
-import pdb
-
-
-class PromptingSynapse(bt.Synapse):
-    """
-    The PromptingSynapse subclass of the Synapse class encapsulates the functionalities related to prompting scenarios.
-
-    It specifies three fields - `roles`, `messages` and `completion` - that define the state of the PromptingSynapse object.
-    The `roles` and `messages` are read-only fields defined during object initialization, and `completion` is a mutable
-    field that can be updated as the prompting scenario progresses.
-
-    The Config inner class specifies that assignment validation should occur on this class (validate_assignment = True),
-    meaning value assignments to the instance fields are checked against their defined types for correctness.
-
-    Attributes:
-        roles (List[str]): A list of roles in the prompting scenario. This field is both mandatory and immutable.
-        messages (List[str]): A list of messages in the prompting scenario. This field is both mandatory and immutable.
-        completion (str): A string that captures completion of the prompt. This field is mutable.
-        required_hash_fields List[str]: A list of fields that are required for the hash.
-
-    Methods:
-        deserialize() -> "PromptingSynapse": Returns the instance of the current object.
-
-
-    The `PromptingSynapse` class also overrides the `deserialize` method, returning the
-    instance itself when this method is invoked. Additionally, it provides a `Config`
-    inner class that enforces the validation of assignments (`validate_assignment = True`).
-
-    Here is an example of how the `PromptingSynapse` class can be used:
-
-    ```python
-    # Create a PromptingSynapse instance
-    prompt = PromptingSynapse(roles=["system", "user"], messages=["Hello", "Hi"])
-
-    # Print the roles and messages
-    print("Roles:", prompt.roles)
-    print("Messages:", prompt.messages)
-
-    # Update the completion
-    model_prompt =... # Use prompt.roles and prompt.messages to generate a prompt
-    for your LLM as a single string.
-    prompt.completion = model(model_prompt)
-
-    # Print the completion
-    print("Completion:", prompt.completion)
-    ```
-
-    This will output:
-    ```
-    Roles: ['system', 'user']
-    Messages: ['You are a helpful assistant.', 'Hi, what is the meaning of life?']
-    Completion: "The meaning of life is 42. Deal with it, human."
-    ```
-
-    This example demonstrates how to create an instance of the `PromptingSynapse` class, access the
-    `roles` and `messages` fields, and update the `completion` field.
-    """
-
-    class Config:
-        """
-        Pydantic model configuration class for PromptingSynapse. This class sets validation of attribute assignment as True.
-        validate_assignment set to True means the pydantic model will validate attribute assignments on the class.
-        """
-
-        validate_assignment = True
-
-    def deserialize(self) -> "PromptingSynapse":
-        """
-        Returns the instance of the current PromptingSynapse object.
-
-        This method is intended to be potentially overridden by subclasses for custom deserialization logic.
-        In the context of the PromptingSynapse class, it simply returns the instance itself. However, for subclasses
-        inheriting from this class, it might give a custom implementation for deserialization if need be.
-
-        Returns:
-            PromptingSynapse: The current instance of the PromptingSynapse class.
-        """
-        return self
-
-    roles: List[str] = pydantic.Field(
-        ...,
-        title="Roles",
-        description="A list of roles in the PromptingSynapse scenario. Immuatable.",
-        allow_mutation=False,
-    )
-
-    messages: List[str] = pydantic.Field(
-        ...,
-        title="Messages",
-        description="A list of messages in the PromptingSynapse scenario. Immutable.",
-        allow_mutation=False,
-    )
-
-    completion: str = pydantic.Field(
-        "",
-        title="Completion",
-        description="Completion status of the current PromptingSynapse object. This attribute is mutable and can be updated.",
-    )
-
-    required_hash_fields: List[str] = pydantic.Field(
-        ["messages"],
-        title="Required Hash Fields",
-        description="A list of required fields for the hash.",
-        allow_mutation=False,
-    )
 
 
 class StreamPromptingSynapse(bt.StreamingSynapse):
     """
@@ -212,9 +105,9 @@ async def process_streaming_response(
         self.completion = ""
 
         async for chunk in response.content.iter_any():
-            tokens = chunk.decode("utf-8").split("\n")
+            tokens = chunk.decode("utf-8")
 
-            self.completion = self.completion + "".join([t for t in tokens if t])
+            self.completion = self.completion + tokens
             yield tokens
 
     def deserialize(self) -> str:
diff --git a/prompting/rewards/__init__.py b/prompting/rewards/__init__.py
index 51cab779a..a44b50bfd 100644
--- a/prompting/rewards/__init__.py
+++ b/prompting/rewards/__init__.py
@@ -11,4 +11,5 @@
 from .float_diff import FloatDiffModel
 from .date import DateRewardModel
 from .ordinal import OrdinalRewardModel
+from .streaming import StreamingRewardModel
 from .pipeline import RewardPipeline, REWARD_MODELS
diff --git a/prompting/rewards/code_diff.py b/prompting/rewards/code_diff.py
index 356617902..df9e55cf4 100644
--- a/prompting/rewards/code_diff.py
+++ b/prompting/rewards/code_diff.py
@@ -6,6 +6,7 @@
     BatchRewardOutput,
     RewardModelTypeEnum,
 )
+from prompting.dendrite import DendriteResponseEvent
 import time
 
 
@@ -27,13 +28,13 @@ def unified_diff(self, reference, completion):
     def seq_match(self, reference, completion):
         return difflib.SequenceMatcher(None, reference, completion).ratio()
 
-    def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
+    def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput:
         """Get the score between two strings.
         lines: If True, return a unified diff. If False, return a ratio.
         """
-
         rewards = []
         timings = []
+        completions: List[str] = response_event.completions
 
         if self.lines:
             for completion in completions:
diff --git a/prompting/rewards/date.py b/prompting/rewards/date.py
index 9fcb5bc1e..d2df28423 100644
--- a/prompting/rewards/date.py
+++ b/prompting/rewards/date.py
@@ -5,6 +5,7 @@
 import numpy as np
 from typing import List
 from prompting.rewards import BaseRewardModel, BatchRewardOutput, RewardModelTypeEnum
+from prompting.dendrite import DendriteResponseEvent
 import bittensor as bt
 
 
@@ -83,7 +84,7 @@ def date_score(self, reference: str, completion: str) -> float:
             score = 0
         return score
 
-    def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
+    def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput:
         """Compute difference scores given a completion and reference pair.
 
         Args:
@@ -93,6 +94,7 @@ def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
         Returns:
             BatchRewardOutput: A BatchRewardOutput object containing the rewards and timings.
""" + completions: List[str] = response_event.completions rewards = [] timings = [] diff --git a/prompting/rewards/float_diff.py b/prompting/rewards/float_diff.py index 7d14c65cd..387c2a8a8 100644 --- a/prompting/rewards/float_diff.py +++ b/prompting/rewards/float_diff.py @@ -3,7 +3,7 @@ from typing import List from sympy.parsing.sympy_parser import parse_expr from prompting.rewards import BaseRewardModel, BatchRewardOutput, RewardModelTypeEnum - +from prompting.dendrite import DendriteResponseEvent class FloatDiffModel(BaseRewardModel): @property @@ -52,10 +52,11 @@ def math_score(reference: str, completion: str) -> float: except Exception: return 0.0 - def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: + def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput: """Compute difference scores given a completion and reference pair.""" rewards = [] timings = [] + completions: List[str] = response_event.completions for completion in completions: t0 = time.time() diff --git a/prompting/rewards/ordinal.py b/prompting/rewards/ordinal.py index 34f749621..bf2801a79 100644 --- a/prompting/rewards/ordinal.py +++ b/prompting/rewards/ordinal.py @@ -2,7 +2,7 @@ import torch from typing import List from prompting.rewards import BaseRewardModel, BatchRewardOutput - +from prompting.dendrite import DendriteResponseEvent class OrdinalRewardModel(BaseRewardModel): @property @@ -18,11 +18,13 @@ def __init__(self, **kwargs): "negative", ] - def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: + def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput: """Compute difference scores given a completion and reference pair.""" rewards = [] timings = [] classes = self.sentiments + completions: List[str] = response_event.completions + for completion in completions: t0 = time.time() completion = completion.lower() diff --git a/prompting/rewards/pipeline.py b/prompting/rewards/pipeline.py index 244ec0b3e..2b126c401 100644 --- a/prompting/rewards/pipeline.py +++ b/prompting/rewards/pipeline.py @@ -9,6 +9,7 @@ FloatDiffModel, DateRewardModel, OrdinalRewardModel, + StreamingRewardModel ) REWARD_MODELS = { @@ -18,6 +19,7 @@ "float_diff": FloatDiffModel, "date": DateRewardModel, "ordinal": OrdinalRewardModel, + "streaming": StreamingRewardModel, } @@ -90,6 +92,7 @@ def load_reward_pipeline(self): for task in self.selected_tasks: active_reward_models += TASKS[task].reward_definition active_reward_models += TASKS[task].penalty_definition + active_reward_models += TASKS[task].global_penalty_definition # Instantiate only the required reward models reward_models = {} diff --git a/prompting/rewards/relevance.py b/prompting/rewards/relevance.py index b754ae330..e063a582c 100644 --- a/prompting/rewards/relevance.py +++ b/prompting/rewards/relevance.py @@ -6,8 +6,9 @@ from prompting.rewards import ( BaseRewardModel, BatchRewardOutput, - RewardModelTypeEnum, ) +from prompting.dendrite import DendriteResponseEvent + class RelevanceRewardModel(BaseRewardModel): @@ -25,7 +26,7 @@ def __init__(self, threshold=None, device=None, pooling_strategy="cls"): # This line is necessary to pass the model to the device defined at its initialization self.model = self.model.cuda() - def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: + def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput: """Calculates the cosine similarity between sentence embeddings of the 
         reference and completions. We subtract a baseline score which is what an empty string would get (a failed completion). This is usually around 0.35
         We also clip the rewards between 0 and 1. The maximum effective score is around 0.65
@@ -33,6 +34,7 @@ def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
         reference_embedding = self.model.encode(reference, to_numpy=False)
         rewards = []
         timings = []
+        completions: List[str] = response_event.completions
 
         # baseline is the cosine similarity between the reference and an empty string
         baseline = cosine_similarity(
             reference_embedding.reshape(1, -1),
diff --git a/prompting/rewards/reward.py b/prompting/rewards/reward.py
index bf5d2bc0b..d51ce4c77 100644
--- a/prompting/rewards/reward.py
+++ b/prompting/rewards/reward.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
-
+from prompting.dendrite import DendriteResponseEvent
 
 class RewardModelTypeEnum(Enum):
     WEIGHTED_REWARD = "reward"
@@ -28,12 +28,16 @@ class RewardEvent:
     # implement custom asdict to return a dict with the same keys as the dataclass using the model name
     def asdict(self) -> dict:
         return {
-            f"{self.model_name}_raw_{self.model_type.value}": self.rewards.tolist(),
-            f"{self.model_name}_{self.model_type.value}": self.rewards_normalized.tolist(),
-            f"{self.model_name}_{self.model_type.value}_timings": self.timings.tolist(),
+            f"{self.model_name}_raw_{self.model_type.value}": self.tensor_to_rounded_list(self.rewards),
+            f"{self.model_name}_{self.model_type.value}": self.tensor_to_rounded_list(self.rewards_normalized, 4),
+            f"{self.model_name}_{self.model_type.value}_timings": self.tensor_to_rounded_list(self.timings),
             f"{self.model_name}_{self.model_type.value}_batch_time": self.batch_time,
             f"{self.model_name}_{self.model_type.value}_extra_info": self.extra_info,
         }
+
+    def tensor_to_rounded_list(self, tensor, decimals=6):
+        # Convert the tensor elements to floats and round them to the requested number of decimal places
+        return [round(float(element), decimals) for element in tensor]
 
 
 class RewardResult:
@@ -51,7 +55,7 @@ def __init__(self, reward_pipeline, agent, response_event, device):
         self.response_event = response_event
         self.device = device
         self.task_rewards = agent.task.reward_definition
-        self.task_penalties = agent.task.penalty_definition
+        self.task_penalties = agent.task.penalty_definition + agent.task.global_penalty_definition
         self.reward_events = self.reward_responses(
             reference=agent.task.reference,
             models=self.task_rewards,
@@ -151,12 +155,12 @@ def __init__(self, **kwargs):
         pass
 
     @abstractmethod
-    def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
+    def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput:
         pass
 
-    def apply(self, reference: str, response_event, reward_type) -> RewardEvent:
+    def apply(self, reference: str, response_event: DendriteResponseEvent, reward_type: RewardModelTypeEnum) -> RewardEvent:
         t0 = time.time()
-        batch_rewards_output = self.reward(reference, response_event.completions)
+        batch_rewards_output = self.reward(reference, response_event)
         batch_rewards_time = time.time() - t0
 
         return RewardEvent(
diff --git a/prompting/rewards/rouge.py b/prompting/rewards/rouge.py
index c06d236f8..cd56bec0a 100644
--- a/prompting/rewards/rouge.py
+++ b/prompting/rewards/rouge.py
@@ -5,8 +5,8 @@
 from prompting.rewards import (
     BaseRewardModel,
     BatchRewardOutput,
-    RewardModelTypeEnum,
 )
+from prompting.dendrite import DendriteResponseEvent
 
 class RougeRewardModel(BaseRewardModel):
@@ -28,10 +28,11 @@ def rouge_score(self, reference, completion):
             self.ngram
         ][self.metric]
 
-    def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput:
+    def reward(self, reference: str, response_event: DendriteResponseEvent) -> BatchRewardOutput:
         """Compute ROUGE scores given a completion and reference pair."""
         rewards = []
         timings = []
+        completions: List[str] = response_event.completions
 
         for completion in completions:
             t0 = time.time()
diff --git a/prompting/rewards/streaming.py b/prompting/rewards/streaming.py
new file mode 100644
index 000000000..f6443915a
--- /dev/null
+++ b/prompting/rewards/streaming.py
@@ -0,0 +1,48 @@
+import time
+import torch
+from prompting.dendrite import DendriteResponseEvent
+from prompting.rewards import (
+    BaseRewardModel,
+    BatchRewardOutput,
+)
+
+class StreamingRewardModel(BaseRewardModel):
+    @property
+    def name(self) -> str:
+        return "streaming"
+
+    def __init__(self, max_tokens_per_chunk: int, **kwargs):
+        super().__init__()
+        self.max_tokens_per_chunk = max_tokens_per_chunk
+
+
+    def reward(self, _: str, response_event: DendriteResponseEvent) -> BatchRewardOutput:
+        """Penalize responses whose streamed chunks exceed the maximum number of tokens per chunk."""
+        rewards = []
+        timings = []
+        penalty_per_exceeding_chunk = 0.25
+
+        # Iterate through each response's per-chunk token counts
+        for response_tokens_per_chunks in response_event.stream_results_all_tokens_per_chunk:
+            start_time = time.time()
+
+            # Accumulate a fixed penalty for every chunk that exceeds the token limit
+            accumulated_penalty = sum(
+                penalty_per_exceeding_chunk if tokens_per_chunk > self.max_tokens_per_chunk else 0
+                for tokens_per_chunk in response_tokens_per_chunks
+            )
+
+            # Record the timing for this computation
+            timings.append(time.time() - start_time)
+
+            # Cap the accumulated penalty at 1
+            rewards.append(min(accumulated_penalty, 1))
+
+        # Create the output object with rewards, timings, and extra information
+        output = BatchRewardOutput(
+            rewards=torch.FloatTensor(rewards),
+            timings=torch.FloatTensor(timings),
+            extra_info={"type": "streaming"}
+        )
+        return output
+
diff --git a/prompting/tasks/task.py b/prompting/tasks/task.py
index 041798d3e..e58900623 100644
--- a/prompting/tasks/task.py
+++ b/prompting/tasks/task.py
@@ -51,6 +51,10 @@ class Task(ABC):
     cleaner = None
     clean_reference = True
     challenge_type = 'inference'
+
+    global_penalty_definition = [
+        dict(name="streaming", max_tokens_per_chunk=200, weight=0.2)
+    ]
 
     def __str__(self):
         return f"{self.__class__.__name__}(name={self.name!r}, desc={self.desc!r}, goal={self.goal!r}, query={self.query!r}, reference={self.reference!r}, topic={self.topic!r}, subtopic={self.subtopic!r}, tags={self.tags!r})"
diff --git a/tests/test_mock.py b/tests/test_mock.py
index 12e5dfa6c..7ae6f8604 100644
--- a/tests/test_mock.py
+++ b/tests/test_mock.py
@@ -2,7 +2,7 @@
 import asyncio
 import bittensor as bt
 from prompting.mock import MockDendrite, MockMetagraph, MockSubtensor
-from prompting.protocol import PromptingSynapse
+from prompting.protocol import StreamPromptingSynapse
 
 wallet = bt.MockWallet()
 wallet.create(coldkey_use_password=False)
@@ -70,10 +70,11 @@ def test_mock_dendrite_timings(timeout, min_time, max_time, n):
     async def run():
         return await mock_dendrite(
             axons,
-            synapse=PromptingSynapse(
+            synapse=StreamPromptingSynapse(
                 roles=["user"], messages=["What is the capital of France?"]
             ),
             timeout=timeout,
+            deserialize=False,
        )
 
     eps = 0.2
diff --git a/tests/test_reward.py b/tests/test_reward.py
new
 file mode 100644
index 000000000..d074fb9fe
--- /dev/null
+++ b/tests/test_reward.py
@@ -0,0 +1,30 @@
+import pytest
+import torch
+from unittest.mock import MagicMock
+from prompting.rewards import StreamingRewardModel
+
+@pytest.mark.parametrize(
+    "all_tokens_per_chunk, expected_rewards",
+    [
+        ([[1, 1, 1, 1, 1]], [0]),  # No penalties
+        ([[2, 1, 1, 1, 1]], [0.25]),  # First chunk exceeds
+        ([[2, 2, 1, 1, 1]], [0.5]),  # Two chunks exceed
+        ([[2, 2, 2, 1, 1]], [0.75]),  # Three chunks exceed
+        ([[2, 2, 2, 2, 1]], [1]),  # Four chunks exceed
+        ([[2, 2, 2, 2, 2, 2]], [1]),  # Accumulated penalty exceeds 1, clipped at 1
+    ]
+)
+def test_streaming_reward_model(all_tokens_per_chunk, expected_rewards):
+    max_tokens_per_chunk = 1
+    response_event = MagicMock()
+    response_event.stream_results_all_tokens_per_chunk = all_tokens_per_chunk
+
+    model = StreamingRewardModel(max_tokens_per_chunk)
+
+    output = model.reward("", response_event)
+
+    assert torch.allclose(output.rewards, torch.tensor(expected_rewards, dtype=torch.float)), \
+        f"Expected rewards {expected_rewards} but got {output.rewards.tolist()}"
+
+if __name__ == "__main__":
+    pytest.main()
\ No newline at end of file
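
Reviewer note (not part of the patch): a minimal, dependency-free sketch of the per-chunk penalty rule that the new StreamingRewardModel implements and tests/test_reward.py exercises. The constants mirror global_penalty_definition in prompting/tasks/task.py; the function name and example values here are illustrative only.

# Illustrative sketch of the streaming penalty rule, under the assumptions stated above.
MAX_TOKENS_PER_CHUNK = 200       # mirrors global_penalty_definition in prompting/tasks/task.py
PENALTY_PER_EXCEEDING_CHUNK = 0.25

def streaming_penalty(tokens_per_chunk: list) -> float:
    """Each streamed chunk above the token limit adds 0.25; the total is capped at 1.0."""
    penalty = sum(
        PENALTY_PER_EXCEEDING_CHUNK
        for tokens in tokens_per_chunk
        if tokens > MAX_TOKENS_PER_CHUNK
    )
    return min(penalty, 1.0)

# A stream with every chunk under the limit is not penalized...
assert streaming_penalty([50, 120, 80]) == 0.0
# ...while five oversized chunks saturate the cap (5 * 0.25 = 1.25 -> 1.0).
assert streaming_penalty([250, 300, 201, 400, 500]) == 1.0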