diff --git a/prompting/base/validator.py b/prompting/base/validator.py index 893bf2055..3ce252ee1 100644 --- a/prompting/base/validator.py +++ b/prompting/base/validator.py @@ -142,9 +142,6 @@ def run(self): self.loop.run_until_complete( asyncio.wait_for(task, timeout=forward_timeout) ) - except torch.cuda.OutOfMemoryError as e: - bt.logging.error(f"Out of memory error: {e}") - continue except MaxRetryError as e: bt.logging.error(f"MaxRetryError: {e}") continue diff --git a/prompting/forward.py b/prompting/forward.py index 170f85254..a92cb80bd 100644 --- a/prompting/forward.py +++ b/prompting/forward.py @@ -16,30 +16,28 @@ # DEALINGS IN # THE SOFTWARE. +import asyncio +import random import sys import time -import random -import asyncio import traceback -import numpy as np +from dataclasses import dataclass +from typing import Awaitable, Dict, List + import bittensor as bt -from typing import List, Dict, Awaitable +import numpy as np +import torch + from prompting.agent import HumanAgent -from prompting.dendrite import DendriteResponseEvent from prompting.conversation import create_task +from prompting.dendrite import DendriteResponseEvent from prompting.protocol import StreamPromptingSynapse from prompting.rewards import RewardResult from prompting.tasks import QuestionAnsweringTask -from prompting.utils.uids import get_random_uids from prompting.utils.logging import log_event from prompting.utils.misc import async_log, serialize_exception_to_string -from dataclasses import dataclass +from prompting.utils.uids import get_random_uids -@async_log -async def generate_reference(agent): - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline) - return result @async_log async def execute_dendrite_call(dendrite_call): @@ -167,7 +165,6 @@ def log_stream_results(stream_results: List[StreamResult]): f"Failed response for uid {failed_response.uid}: {formatted_exception}" ) - async def run_step( self, agent: HumanAgent, roles: List[str], messages: List[str], k: int, timeout: float, exclude: list = None ): @@ -274,6 +271,8 @@ async def forward(self): """ Encapsulates a full conversation between the validator and miners. Contains one or more rounds of request-response. + Raises: + torch.cuda.OutOfMemoryError: CUDA out of memory error. """ bt.logging.info("🚀 Starting forward loop...") forward_start_time = time.time() @@ -354,6 +353,10 @@ async def forward(self): messages.append(agent.challenge) turn += 1 + except torch.cuda.OutOfMemoryError as err: + bt.logging.error("CUDA out of memory", str(err)) + raise err + except BaseException as e: unexpected_errors = serialize_exception_to_string(e) bt.logging.error(