Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions prompting/base/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ def run(self):
self.loop.run_until_complete(
asyncio.wait_for(task, timeout=forward_timeout)
)
except torch.cuda.OutOfMemoryError as e:
bt.logging.error(f"Out of memory error: {e}")
continue
except MaxRetryError as e:
bt.logging.error(f"MaxRetryError: {e}")
continue
Expand Down
29 changes: 16 additions & 13 deletions prompting/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,28 @@
# DEALINGS IN
# THE SOFTWARE.

import asyncio
import random
import sys
import time
import random
import asyncio
import traceback
import numpy as np
from dataclasses import dataclass
from typing import Awaitable, Dict, List

import bittensor as bt
from typing import List, Dict, Awaitable
import numpy as np
import torch

from prompting.agent import HumanAgent
from prompting.dendrite import DendriteResponseEvent
from prompting.conversation import create_task
from prompting.dendrite import DendriteResponseEvent
from prompting.protocol import StreamPromptingSynapse
from prompting.rewards import RewardResult
from prompting.tasks import QuestionAnsweringTask
from prompting.utils.uids import get_random_uids
from prompting.utils.logging import log_event
from prompting.utils.misc import async_log, serialize_exception_to_string
from dataclasses import dataclass
from prompting.utils.uids import get_random_uids

@async_log
async def generate_reference(agent):
    """Generate the reference for the agent's task off the event loop.

    ``agent.task.generate_reference`` is a synchronous callable, so it is
    dispatched through ``run_in_executor`` with ``None`` as the executor
    (i.e. the loop's default executor) and awaited, which keeps the running
    event loop free while the call is in progress.

    Args:
        agent: Object exposing ``task.generate_reference`` and
            ``llm_pipeline``; the pipeline is passed as the sole argument
            to ``task.generate_reference``.

    Returns:
        Whatever ``agent.task.generate_reference(agent.llm_pipeline)``
        returns.
    """
    event_loop = asyncio.get_running_loop()
    # Hand the blocking call to the default executor and wait for its result.
    return await event_loop.run_in_executor(
        None,
        agent.task.generate_reference,
        agent.llm_pipeline,
    )

@async_log
async def execute_dendrite_call(dendrite_call):
Expand Down Expand Up @@ -167,7 +165,6 @@ def log_stream_results(stream_results: List[StreamResult]):
f"Failed response for uid {failed_response.uid}: {formatted_exception}"
)


async def run_step(
self, agent: HumanAgent, roles: List[str], messages: List[str], k: int, timeout: float, exclude: list = None
):
Expand Down Expand Up @@ -274,6 +271,8 @@ async def forward(self):
"""
Encapsulates a full conversation between the validator and miners. Contains one or more rounds of request-response.

Raises:
torch.cuda.OutOfMemoryError: CUDA out of memory error.
"""
bt.logging.info("🚀 Starting forward loop...")
forward_start_time = time.time()
Expand Down Expand Up @@ -354,6 +353,10 @@ async def forward(self):
messages.append(agent.challenge)
turn += 1

except torch.cuda.OutOfMemoryError as err:
bt.logging.error("CUDA out of memory", str(err))
raise err

except BaseException as e:
unexpected_errors = serialize_exception_to_string(e)
bt.logging.error(
Expand Down