neurons/validator.py (6 changes: 2 additions & 4 deletions)
@@ -31,15 +31,13 @@ def init_process_logging(name: str):
     try:
         # Add process-specific handlers
         logger.add(
-            f"{name}_{os.getpid()}.log",
+            f"{name}.log",
             rotation="100 MB",
             retention="10 days",
             level="DEBUG",
             enqueue=True,  # Use queue for thread-safe logging
         )
-        logger.add(
-            f"{name}_err_{os.getpid()}.log", rotation="100 MB", retention="10 days", level="WARNING", enqueue=True
-        )
+        logger.add(f"{name}_err.log", rotation="100 MB", retention="10 days", level="WARNING", enqueue=True)
         logger.add(sys.stderr, level=settings.shared_settings.LOG_LEVEL, enqueue=True)
     except Exception as e:
         print(f"Failed to initialize logging for process {os.getpid()}: {e}")
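For reference, a minimal sketch of the simplified setup once the PID suffix is dropped from the log file names. The arguments mirror the hunk above; the shared settings import is an assumption based on how the rest of the codebase reads settings.shared_settings.LOG_LEVEL.

import os
import sys

from loguru import logger
from shared import settings


def init_process_logging(name: str):
    try:
        # Fixed-name sinks shared by all processes; enqueue=True routes writes through a queue.
        logger.add(f"{name}.log", rotation="100 MB", retention="10 days", level="DEBUG", enqueue=True)
        logger.add(f"{name}_err.log", rotation="100 MB", retention="10 days", level="WARNING", enqueue=True)
        logger.add(sys.stderr, level=settings.shared_settings.LOG_LEVEL, enqueue=True)
    except Exception as e:
        print(f"Failed to initialize logging for process {os.getpid()}: {e}")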
prompting/rewards/exact_match.py (37 changes: 32 additions & 5 deletions)
@@ -6,10 +6,11 @@
 import torch.nn.functional as F
 from loguru import logger
 from openai.types.chat import ChatCompletionChunk
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput
 from prompting.tasks.base_task import BaseTextTask
-from shared import constants, settings
+from shared import settings
 from shared.dendrite import DendriteResponseEvent
 from shared.docker_utils import get_logits
 
@@ -61,9 +62,17 @@ async def reward(  # noqa: C901
 
         # If max_tokens are not provided, always check for eos.
         model = task.llm_model_id
-        max_tokens = sampling_parameters.get("max_tokens", 2048)
-        eos_token = constants.SPECIAL_TOKENS.get(model, {}).get("eos_token")
-        bos_token = constants.SPECIAL_TOKENS.get(model, {}).get("bos_token")
+        max_tokens = await self.get_max_tokens(sampling_parameters, default=2048)
+
+        try:
+            tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model, use_fast=True)
+            eos_token = tokenizer.eos_token
+            bos_token = tokenizer.bos_token
+        except BaseException as exc:
+            logger.error(f"Cannot get {model} tokenizer: {exc}. EOS token check is disabled")
+            eos_token = None
+            bos_token = None
+
         special_tokens = set([bos_token, eos_token])
         timing_verified: list[list[float]] = []
         rewards: list[float] = []
@@ -106,13 +115,19 @@ async def reward(  # noqa: C901
             to_complete = "".join(chunks[:check_idx])
             if to_complete:
                 messages.extend([{"role": "assistant", "content": to_complete}])
-            response = await get_logits(
+            response: dict[str, Any] | None = await get_logits(
                 model=task.llm_model_id,
                 messages=messages,
                 top_logprobs=TOP_LOGPROBS,
                 sampling_params=sampling_parameters,
                 continue_last_message=len(to_complete) > 0,
             )
+            if response is None:
+                # Unexpected error on validator side, do not set penalty.
+                penalty = 0.0
+                logger.error(f"Cannot get logprobs for model {task.llm_model_id}")
+                raise ValueError(f"Cannot get logprobs for model {task.llm_model_id} and {messages}")
+
             verification_logits = response[0]
             if check_idx < eos_idx:
                 if chunks[check_idx] in special_tokens:
@@ -204,6 +219,18 @@ async def reward(  # noqa: C901
         logger.debug(f"Logits rewards: {reward_output.model_dump()}")
         return reward_output
 
+    @classmethod
+    async def get_max_tokens(cls, sampling_params: dict[str, Any], default: int = 2048) -> int:
+        # vLLM / HF request.
+        max_tokens = sampling_params.get("max_tokens")
+        if max_tokens is None:
+            # Deprecated request.
+            max_tokens = sampling_params.get("max_new_tokens")
+        if max_tokens is None:
+            # OpenAI request.
+            max_tokens = sampling_params.get("max_completion_tokens", default)
+        return max_tokens
+
     @staticmethod
     def sample_verification_indices(completion_length: int) -> list[int]:
         """Sample random indices for verification, always add 0 and eos_token index."""
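To make the new fallback order concrete, here is a small standalone sketch of the resolution that get_max_tokens performs, written as a plain function for illustration (in the PR it is an async classmethod on the reward model):

from typing import Any


def resolve_max_tokens(sampling_params: dict[str, Any], default: int = 2048) -> int:
    # vLLM / HF-style request.
    max_tokens = sampling_params.get("max_tokens")
    if max_tokens is None:
        # Deprecated HF-style request.
        max_tokens = sampling_params.get("max_new_tokens")
    if max_tokens is None:
        # OpenAI-style request.
        max_tokens = sampling_params.get("max_completion_tokens", default)
    return max_tokens


assert resolve_max_tokens({"max_tokens": 512}) == 512
assert resolve_max_tokens({"max_new_tokens": 256}) == 256
assert resolve_max_tokens({"max_completion_tokens": 128}) == 128
assert resolve_max_tokens({"temperature": 0.7}) == 2048  # nothing set, the default applies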
@@ -10,7 +10,7 @@
 from shared.uids import get_uids
 
 if TYPE_CHECKING:
-    from prompting.tasks.MSRv2_task import MSRv2Task
+    from prompting.tasks.msrv2_task import MSRv2Task
 
 shared_settings = settings.shared_settings
 
prompting/rewards/scoring.py (3 changes: 1 addition & 2 deletions)
@@ -8,7 +8,7 @@
 
 from prompting.rewards.scoring_config import ScoringConfig
 from prompting.tasks.base_task import BaseTextTask
-from prompting.tasks.MSRv2_task import MSRv2Task
+from prompting.tasks.msrv2_task import MSRv2Task
 from prompting.tasks.task_registry import TaskRegistry
 from shared.base import DatasetEntry
 from shared.dendrite import DendriteResponseEvent
@@ -81,7 +81,6 @@ async def run_step(self) -> RewardLoggingEvent:
                 await scoring_config.task.make_reference(
                     dataset_entry=scoring_config.dataset_entry,
                 )
-                logger.info(f"Reference: {scoring_config.task.reference}")
 
                 # and there we then calculate the reward
                 reward_pipeline = TaskRegistry.get_task_reward(scoring_config.task)
prompting/tasks/base_task.py (2 changes: 1 addition & 1 deletion)
@@ -76,7 +76,7 @@ def task_messages(self) -> list[str] | list[dict]:
     @model_validator(mode="after")
     def get_model_id_and_seed(self) -> "BaseTextTask":
         if self.llm_model:
-            self.llm_model_id = self.llm_model.llm_model_id if self.llm_model else None
+            self.llm_model_id = self.llm_model
         return self
 
     async def make_query(self, dataset_entry: DatasetEntry, **kwargs) -> str:
@@ -4,7 +4,7 @@
 from loguru import logger
 
 from prompting.datasets.random_website import DDGDatasetEntry
-from prompting.rewards.MSRv2_reward import MSRv2RewardModel
+from prompting.rewards.msrv2_reward import MSRv2RewardModel
 from prompting.rewards.reward import BaseRewardConfig, BaseRewardModel
 from prompting.tasks.multi_step_reasoning import MultiStepReasoningTask
 from shared.base import Context
prompting/tasks/task_registry.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
 from prompting.rewards.reward import BaseRewardConfig
 from prompting.tasks.base_task import BaseTextTask
 from prompting.tasks.inference import InferenceRewardConfig, InferenceTask
-from prompting.tasks.MSRv2_task import MSRv2RewardConfig, MSRv2Task
+from prompting.tasks.msrv2_task import MSRv2RewardConfig, MSRv2Task
 from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask
 from shared.base import BaseDataset
 
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "prompting"
-version = "2.19.5"
+version = "2.19.6"
 description = "Subnetwork 1 runs on Bittensor and is maintained by Macrocosmos. It's an effort to create decentralised AI"
 authors = ["Kalei Brady, Dmytro Bobrenko, Felix Quinque, Steffen Cruz, Richard Wardle"]
 readme = "README.md"
shared/constants.py (2 changes: 0 additions & 2 deletions)
@@ -1,5 +1,3 @@
 WHITELISTED_VALIDATORS_UIDS = [5, 518, 674, 966, 502, 520, 993, 24]  # OTF # WildSageLabs # Rizzo # Macrocosmos
 
 DOCKER_BASE_URL = "http://localhost:8000"
-
-SPECIAL_TOKENS = {"mrfakename/mistral-small-3.1-24b-instruct-2503-hf": {"bos_token": "<s>", "eos_token": "</s>"}}
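The hardcoded SPECIAL_TOKENS mapping is superseded by the dynamic tokenizer lookup added to prompting/rewards/exact_match.py above. A minimal sketch of that pattern (the helper name is illustrative; the PR inlines this logic in reward()):

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def get_special_tokens(model_id: str) -> tuple[str | None, str | None]:
    try:
        tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        return tokenizer.bos_token, tokenizer.eos_token
    except BaseException as exc:
        # Tokenizer unavailable: the EOS/BOS check is simply disabled.
        print(f"Cannot get {model_id} tokenizer: {exc}")
        return None, None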
shared/docker_utils.py (6 changes: 4 additions & 2 deletions)
@@ -1,3 +1,5 @@
+from typing import Any
+
 import requests
 from loguru import logger
 
@@ -38,7 +40,7 @@ async def get_logits(
     seed: int = None,
     continue_last_message: bool = False,
     top_logprobs: int = 10,
-):
+) -> dict[str, Any] | None:
     url = f"{constants.DOCKER_BASE_URL}/v1/chat/generate_logits"
     headers = {"Content-Type": "application/json"}
     payload = {
@@ -54,7 +56,7 @@
         return json_response
     except requests.exceptions.JSONDecodeError:
         logger.error(f"Error generating logits. Status: {response.status_code}, Body: {response.text}")
-        return ""
+        return None
 
 
 def get_embeddings(inputs):
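Because get_logits now returns dict[str, Any] | None rather than an empty string when the response body cannot be decoded, callers are expected to check for None before using the payload, as the exact_match.py hunk above does. A hypothetical caller sketch under that assumption (the wrapper name and empty sampling_params are illustrative):

from typing import Any

from shared.docker_utils import get_logits


async def fetch_verification_logits(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]:
    response = await get_logits(model=model, messages=messages, sampling_params={}, top_logprobs=10)
    if response is None:
        # Validator-side failure: surface the error instead of silently penalising the miner.
        raise ValueError(f"Cannot get logprobs for model {model}")
    return response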