Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9b33454
clean up legacy prompting synapse
p-ferreira May 16, 2024
5d3e1b9
initial adjustments to protocol + chunks logging
p-ferreira May 16, 2024
b6b7e22
fix chunk processing + mock
p-ferreira May 16, 2024
f3e5466
adjust dendrite response
p-ferreira May 16, 2024
36344f0
refactors base pipeline apply function
p-ferreira May 16, 2024
b269365
adds tokenizer to flow + dendrite refactor
p-ferreira May 17, 2024
7227814
implements streaming reward function
p-ferreira May 17, 2024
a47bdf1
adds tokens_per_chunk in forward flow + adds global penalties
p-ferreira May 17, 2024
919780d
drops synapse final data out of accumulated chunks
p-ferreira May 20, 2024
a737825
Merge branch 'main' into features/stream-adjustments
p-ferreira May 21, 2024
4b862e7
adds unit tests for streaming reward model
p-ferreira May 21, 2024
9d67963
adjust reward calculation to include global penalties
p-ferreira May 22, 2024
854b7ad
remove unit test todo comment
p-ferreira May 28, 2024
590a7c5
Merge branch 'staging' into features/stream-adjustments
p-ferreira May 28, 2024
d1b73eb
updates versioning
p-ferreira May 28, 2024
3808f4e
drops deprecated prompting synapse from mock code
p-ferreira May 28, 2024
bf61e97
fix mock pipeline
p-ferreira May 28, 2024
ff0de98
fix stream synpase import
p-ferreira Jun 4, 2024
1d3e6bc
fix mock dendrite call
p-ferreira Jun 4, 2024
dfb1c76
Merge branch 'staging' into features/stream-adjustments
p-ferreira Jun 5, 2024
1961f6c
updates versioning
p-ferreira Jun 5, 2024
41d1c67
Merge branch 'staging' into features/stream-adjustments
p-ferreira Jun 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion prompting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# DEALINGS IN THE SOFTWARE.

# Define the version of the template module.
__version__ = "2.3.1"
__version__ = "2.4.0"
version_split = __version__.split(".")
__spec_version__ = (
(10000 * int(version_split[0]))
Expand Down
63 changes: 41 additions & 22 deletions prompting/dendrite.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,59 @@
import torch
import bittensor as bt
from typing import List
from dataclasses import dataclass
from prompting.protocol import StreamPromptingSynapse
from prompting.utils.misc import serialize_exception_to_string


@dataclass
class SynapseStreamResult:
exception: BaseException = None
uid: int = None
accumulated_chunks: List[str] = None
accumulated_chunks_timings: List[float] = None
tokens_per_chunk: List[int] = None
synapse: StreamPromptingSynapse = None


class DendriteResponseEvent:
def __init__(
self, responses: List[bt.Synapse], uids: torch.LongTensor, timeout: float
self, stream_results: SynapseStreamResult, uids: torch.LongTensor, timeout: float
):
self.uids = uids
self.completions = []
self.status_messages = []
self.status_codes = []
self.timings = []
self.stream_results_uids = []
self.stream_results_exceptions = []
self.stream_results_all_chunks = []
self.stream_results_all_chunks_timings = []
self.stream_results_all_tokens_per_chunk = []

for stream_result in stream_results:
synapse = stream_result.synapse

for synapse in responses:
self.completions.append(synapse.completion)
self.status_messages.append(synapse.dendrite.status_message)
status_code = synapse.dendrite.status_code

if len(synapse.completion) == 0 and synapse.dendrite.status_code == 200:
synapse.dendrite.status_code = 204
if len(synapse.completion) == 0 and status_code == 200:
status_code = 204

self.status_codes.append(synapse.dendrite.status_code)

if (synapse.dendrite.process_time) and (
synapse.dendrite.status_code == 200
or synapse.dendrite.status_code == 204
):
self.timings.append(synapse.dendrite.process_time)
elif synapse.dendrite.status_code == 408:
self.status_codes.append(status_code)
process_time = synapse.dendrite.process_time or 0
if status_code == 200 or status_code == 204:
self.timings.append(process_time)
elif status_code == 408:
self.timings.append(timeout)
else:
self.timings.append(0) # situation where miner is not alive
self.timings.append(0)

self.completions = [synapse.completion for synapse in responses]
self.timings = [
synapse.dendrite.process_time or timeout for synapse in responses
]
self.status_messages = [
synapse.dendrite.status_message for synapse in responses
]
self.status_codes = [synapse.dendrite.status_code for synapse in responses]
self.stream_results_uids.append(stream_result.uid)
self.stream_results_exceptions.append(serialize_exception_to_string(stream_result.exception))
self.stream_results_all_chunks.append(stream_result.accumulated_chunks)
self.stream_results_all_chunks_timings.append(stream_result.accumulated_chunks_timings)
self.stream_results_all_tokens_per_chunk.append(stream_result.tokens_per_chunk)

def __state_dict__(self):
return {
Expand All @@ -48,6 +62,11 @@ def __state_dict__(self):
"timings": self.timings,
"status_messages": self.status_messages,
"status_codes": self.status_codes,
"stream_results_uids": self.stream_results_uids,
"stream_results_exceptions": self.stream_results_exceptions,
"stream_results_all_chunks": self.stream_results_all_chunks,
"stream_results_all_chunks_timings": self.stream_results_all_chunks_timings,
"stream_results_all_tokens_per_chunk": self.stream_results_all_tokens_per_chunk,
}

def __repr__(self):
Expand Down
127 changes: 54 additions & 73 deletions prompting/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN
# THE SOFTWARE.
# DEALINGS IN THE SOFTWARE.

import sys
import time
Expand All @@ -25,14 +24,16 @@
import bittensor as bt
from typing import List, Dict, Awaitable
from prompting.agent import HumanAgent
from prompting.dendrite import DendriteResponseEvent
from prompting.dendrite import DendriteResponseEvent, SynapseStreamResult
from prompting.conversation import create_task
from prompting.protocol import StreamPromptingSynapse
from prompting.rewards import RewardResult
from prompting.tasks import QuestionAnsweringTask
from prompting.utils.uids import get_random_uids
from prompting.utils.logging import log_event
from prompting.utils.misc import async_log, serialize_exception_to_string
from transformers import PreTrainedTokenizerFast as Tokenizer
from prompting.utils.uids import get_random_uids
from dataclasses import dataclass

@async_log
Expand All @@ -46,32 +47,35 @@ async def execute_dendrite_call(dendrite_call):
responses = await dendrite_call
return responses

@dataclass
class StreamResult:
synapse: StreamPromptingSynapse = None
exception: BaseException = None
uid: int = None


async def process_response(uid: int, async_generator: Awaitable):
async def process_stream(uid: int, async_iterator: Awaitable, tokenizer: Tokenizer) -> SynapseStreamResult:
"""Process a single response asynchronously."""
try:
chunk = None # Initialize chunk with a default value
async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse.
bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

if chunk is not None:
synapse = chunk # last object yielded is the synapse itself with completion filled

# Assuming chunk holds the last value yielded which should be a synapse
if isinstance(synapse, StreamPromptingSynapse):
return synapse

bt.logging.debug(
f"Synapse is not StreamPromptingSynapse. Miner uid {uid} completion set to '' "
)
except Exception as e:
# bt.logging.error(f"Error in generating reference or handling responses: {e}", exc_info=True)
synapse = None # Initialize chunk with a default value
exception = None
accumulated_chunks = []
accumulated_chunks_timings = []
accumulated_tokens_per_chunk = []
start_time = time.time()

try:
async for chunk in async_iterator: # most important loop, as this is where we acquire the final synapse.
if isinstance(chunk, str):
accumulated_chunks.append(chunk)
accumulated_chunks_timings.append(time.time() - start_time)

tokens_in_chunk = len(tokenizer.tokenize(chunk))
accumulated_tokens_per_chunk.append(tokens_in_chunk)

bt.logging.debug(f"\nchunk for uid {uid}: {chunk}")

# Assuming last chunk of async_iterator holds the last value yielded as a StreamingSynapse
synapse = chunk
if synapse is None or not isinstance(synapse, StreamPromptingSynapse):
raise ValueError(
f"Something went wrong with miner uid {uid}, Synapse is not StreamPromptingSynapse."
)
except Exception as e:
exception = e
traceback_details = traceback.format_exc()
bt.logging.error(
f"Error in generating reference or handling responses for uid {uid}: {e}\n{traceback_details}"
Expand All @@ -81,11 +85,20 @@ async def process_response(uid: int, async_generator: Awaitable):
roles=["user"], messages=["failure"], completion=""
)

return failed_synapse
synapse = failed_synapse
finally:
return SynapseStreamResult(
accumulated_chunks=accumulated_chunks,
accumulated_chunks_timings=accumulated_chunks_timings,
tokens_per_chunk=accumulated_tokens_per_chunk,
synapse=synapse,
uid=uid,
exception=exception
)


@async_log
async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]:
async def handle_response(stream_results_dict: Dict[int, Awaitable], tokenizer: Tokenizer) -> List[SynapseStreamResult]:
"""The handle_response function is responsible for creating asyncio tasks around acquiring streamed miner chunks
and processing them asynchronously. It then pairs the results with their original UIDs and returns a list of StreamResults.

Expand All @@ -99,36 +112,14 @@ async def handle_response(responses: Dict[int, Awaitable]) -> List[StreamResult]
List[StreamResult]: DataClass containing the synapse, exception, and uid
"""
tasks_with_uid = [
(uid, responses[uid]) for uid, _ in responses.items()
(uid, stream_results_dict[uid]) for uid, _ in stream_results_dict.items()
] # Pair UIDs with their tasks

# Start tasks, preserving order and their associated UIDs
tasks = [process_response(uid, resp) for uid, resp in tasks_with_uid]

results = await asyncio.gather(*tasks, return_exceptions=True)

mapped_results = []
# Pair each result with its original uid
for (uid, _), result in zip(tasks_with_uid, results):
# If the result is a StreamPromptingSynapse, the response was successful and the stream result is added without exceptions
if isinstance(result, StreamPromptingSynapse):
mapped_results.append(StreamResult(synapse=result, uid=uid))

# If the result is an exception, the response was unsuccessful and the stream result is added with the exception and an empty synapse
elif isinstance(result, BaseException):
failed_synapse = StreamPromptingSynapse(
roles=["user"], messages=["failure"], completion=""
)
mapped_results.append(
StreamResult(synapse=failed_synapse, exception=result, uid=uid)
)

# If the result is neither an error or a StreamSynapse, log the error and raise a ValueError
else:
bt.logging.error(f"Unexpected result type for UID {uid}: {result}")
raise ValueError(f"Unexpected result type for UID {uid}: {result}")

return mapped_results
process_stream_tasks = [process_stream(uid, resp, tokenizer) for uid, resp in tasks_with_uid]
processed_stream_results = await asyncio.gather(*process_stream_tasks, return_exceptions=True)

return processed_stream_results


@async_log
Expand All @@ -140,7 +131,7 @@ async def generate_reference(agent: HumanAgent):
return result


def log_stream_results(stream_results: List[StreamResult]):
def log_stream_results(stream_results: List[SynapseStreamResult]):
failed_responses = [
response for response in stream_results if response.exception is not None
]
Expand Down Expand Up @@ -186,7 +177,6 @@ async def run_step(
timeout (float): The timeout for the queries.
exclude (list, optional): The list of uids to exclude from the query. Defaults to [].
"""

bt.logging.debug("run_step", agent.task.name)

# Record event start time.
Expand All @@ -207,9 +197,9 @@ async def run_step(
)

# Prepare the task for handling stream responses
handle_stream_responses_task = asyncio.create_task(
handle_response(responses=dict(zip(uids_cpu, streams_responses)))
)
stream_results_dict = dict(zip(uids_cpu, streams_responses))
tokenizer = self.llm_pipeline.tokenizer
handle_stream_responses_task = asyncio.create_task(handle_response(stream_results_dict, tokenizer))

if not agent.task.static_reference:
reference_generation_task = generate_reference(agent)
Expand All @@ -221,11 +211,9 @@ async def run_step(

log_stream_results(stream_results)

all_synapses_results = [stream_result.synapse for stream_result in stream_results]

# Encapsulate the responses in a response event (dataclass)
response_event = DendriteResponseEvent(
responses=all_synapses_results, uids=uids, timeout=timeout
stream_results=stream_results, uids=uids, timeout=timeout
)

bt.logging.info(f"Created DendriteResponseEvent:\n {response_event}")
Expand All @@ -248,20 +236,13 @@ async def run_step(
)

self.update_scores(reward_result.rewards, uids)

stream_results_uids = [stream_result.uid for stream_result in stream_results]
stream_results_exceptions = [
serialize_exception_to_string(stream_result.exception)
for stream_result in stream_results
]

# Log the step event.
event = {
"best": best_response,
"block": self.block,
"step": self.step,
"step_time": time.time() - start_time,
"stream_results_uids": stream_results_uids,
"stream_results_exceptions": stream_results_exceptions,
"step_time": time.time() - start_time,
**agent.__state_dict__(full=self.config.neuron.log_full),
**reward_result.__state_dict__(full=self.config.neuron.log_full),
**response_event.__state_dict__(),
Expand Down
1 change: 1 addition & 0 deletions prompting/llms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
self.model_kwargs = model_kwargs
self.messages = []
self.times = []
self.tokenizer = None

def query(
self,
Expand Down
1 change: 1 addition & 0 deletions prompting/llms/vllm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.llm = load_vllm_pipeline(model_id, device, gpus, llm_max_allowed_memory_in_gb, mock)
self.mock = mock
self.gpus = gpus
self.tokenizer = self.llm.llm_engine.tokenizer

def __call__(self, composed_prompt: str, **model_kwargs: Dict) -> str:
if self.mock:
Expand Down
22 changes: 12 additions & 10 deletions prompting/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import asyncio
import random
import bittensor as bt
from prompting.protocol import StreamPromptingSynapse, PromptingSynapse
from prompting.protocol import StreamPromptingSynapse

from functools import partial
from typing import Dict, List, Union, AsyncGenerator, Any, Iterator
from types import SimpleNamespace


class MockTokenizer:
Expand Down Expand Up @@ -44,6 +45,12 @@ class MockPipeline:
@property
def tokenizer(self):
return self.model.tokenizer

@property
def llm_engine(self):
return SimpleNamespace(
tokenizer=self.model.tokenizer
)

def __init__(
self,
Expand Down Expand Up @@ -287,15 +294,10 @@ async def forward(
deserialize: bool = True,
run_async: bool = True,
streaming: bool = False,
):
if streaming:
assert isinstance(
synapse, StreamPromptingSynapse
), "Synapse must be a StreamPromptingSynapse object when is_stream is True."
else:
assert isinstance(
synapse, PromptingSynapse
), "Synapse must be a PromptingSynapse object when is_stream is False."
):
assert isinstance(
synapse, StreamPromptingSynapse
), "Synapse must be a StreamPromptingSynapse object when is_stream is True."

async def query_all_axons(is_stream: bool):
"""Queries all axons for responses."""
Expand Down
Loading