diff --git a/prompting/api/scoring/api.py b/prompting/api/scoring/api.py index 2216e5e7c..49f930f5f 100644 --- a/prompting/api/scoring/api.py +++ b/prompting/api/scoring/api.py @@ -59,29 +59,22 @@ async def score_response( ): model = None payload: dict[str, Any] = await request.json() - logger.debug(f"Awaited body: {payload}") body = payload.get("body") timeout = payload.get("timeout", shared_settings.NEURON_TIMEOUT) uids = payload.get("uids", []) chunks = payload.get("chunks", {}) chunk_dicts_raw = payload.get("chunk_dicts_raw", {}) timings = payload.get("timings", {}) - logger.debug("About to check chunks and uids") if not uids or not chunks: logger.error(f"Either uids: {uids} or chunks: {chunks} is not valid, skipping scoring") return uids = [int(uid) for uid in uids] model = body.get("model") - logger.debug("About to check model") if model and model not in shared_settings.LLM_MODEL: logger.error(f"Model {model} not available for scoring on this validator.") return - logger.debug("Model has been checked") llm_model = ModelZoo.get_model_by_id(model) - logger.debug("Got LLM Model from ModelZoo") task_name = body.get("task") - logger.debug(f"Task name set: {task_name}") - logger.debug(f"Length pre-insertion: {len(task_scorer.scoring_queue)}") if task_name == "InferenceTask": organic_task = InferenceTask( messages=body.get("messages"), @@ -120,6 +113,8 @@ async def score_response( seed=int(body.get("seed", 0)), sampling_params=body.get("sampling_params", {}), query=search_term, + target_results=body.get("target_results", 1), + timeout=body.get("timeout", 10), ), response=DendriteResponseEvent( uids=uids, @@ -132,5 +127,4 @@ async def score_response( step=-1, task_id=str(uuid.uuid4()), ) - logger.debug(f"Current Queue: {len(task_scorer.scoring_queue)}") - logger.info("Organic task appended to scoring queue") + logger.debug(f"Organic queue size: {len(task_scorer.scoring_queue)}") diff --git a/prompting/datasets/random_website.py b/prompting/datasets/random_website.py index 00729f8ac..7f85c5a6e 100644 --- a/prompting/datasets/random_website.py +++ b/prompting/datasets/random_website.py @@ -15,8 +15,8 @@ class DDGDatasetEntry(DatasetEntry): search_term: str - website_url: str = None - website_content: str = None + website_url: str | None = None + website_content: str | None = None query: str | None = None source: str | None = None @@ -40,7 +40,7 @@ def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[ @staticmethod @lru_cache(maxsize=1000) - def extract_website_content(url: str, retries: int = 3) -> Optional[str]: + def extract_website_content(url: str, retries: int = 3) -> str | None: exception: Exception | None = None for _ in range(retries): try: diff --git a/prompting/llms/model_manager.py b/prompting/llms/model_manager.py index ed1fd107d..aa5388273 100644 --- a/prompting/llms/model_manager.py +++ b/prompting/llms/model_manager.py @@ -190,6 +190,7 @@ async def generate_logits( sampling_params: dict[str, float] = None, seed: int = None, continue_last_message: bool = False, + top_logprobs: int = 10, ): model_instance: ReproducibleVLLM = await self.get_model(model) return await model_instance.generate_logits( @@ -197,6 +198,7 @@ async def generate_logits( sampling_params=sampling_params, seed=seed, continue_last_message=continue_last_message, + top_logprobs=top_logprobs, ) async def cleanup(self): @@ -262,5 +264,5 @@ async def run_step(self): await self.llm_model_manager.load_model(selected_model) except MemoryError as e: self.memory_error = e - logger.debug(f"Active 
models: {self.llm_model_manager.active_models.keys()}") + logger.debug(f"Active models: {list(self.llm_model_manager.active_models.keys())}") await asyncio.sleep(0.01) diff --git a/prompting/llms/vllm_llm.py b/prompting/llms/vllm_llm.py index d590ecc31..083cc21eb 100644 --- a/prompting/llms/vllm_llm.py +++ b/prompting/llms/vllm_llm.py @@ -34,6 +34,38 @@ def __init__( # Store tokenizer from VLLM for consistency self.tokenizer = self.model.get_tokenizer() + @classmethod + async def get_max_tokens( + cls, + sampling_params: dict[str, str | float | int | bool], + default_value: int = 512, + ) -> int: + # Process max tokens with backward compatibility. + max_tokens = sampling_params.get("max_tokens") + if max_tokens is None: + max_tokens = sampling_params.get("max_new_tokens") + if max_tokens is None: + max_tokens = sampling_params.get("max_completion_tokens", default_value) + return max_tokens + + @classmethod + async def prepare_sampling_params( + cls, sampling_params: dict[str, str | float | int | bool] | None = None + ) -> SamplingParams: + sampling_params = sampling_params or {} + max_tokens = await cls.get_max_tokens(sampling_params) + + params = SamplingParams( + temperature=float(sampling_params.get("temperature", 1.0)), + top_p=float(sampling_params.get("top_p", 1.0)), + max_tokens=int(max_tokens), + presence_penalty=float(sampling_params.get("presence_penalty", 0.0)), + frequency_penalty=float(sampling_params.get("frequency_penalty", 0.0)), + top_k=int(sampling_params.get("top_k", -1)), + logprobs=sampling_params.get("logprobs", None), + ) + return params + async def generate( self, messages: list[str] | list[dict[str, str]], @@ -71,24 +103,9 @@ async def generate( else: prompt = messages[0] if isinstance(messages, list) else messages - # Convert sampling parameters to VLLM format + # Convert sampling parameters to vLLM format. params = sampling_params if sampling_params else self.sampling_params - - max_tokens = params.get("max_new_tokens") - if max_tokens is None: - max_tokens = params.get("max_tokens", 512) - - vllm_params = SamplingParams( - temperature=params.get("temperature", 1.0), - top_p=params.get("top_p", 1.0), - max_tokens=int(max_tokens), - presence_penalty=params.get("presence_penalty", 0.0), - frequency_penalty=params.get("frequency_penalty", 0.0), - top_k=int(params.get("top_k", -1)), - logprobs=params.get("logprobs", None), - ) - - # Generate using VLLM + vllm_params = await self.prepare_sampling_params(params) outputs = self.model.generate(prompt, vllm_params) if not outputs: @@ -101,7 +118,7 @@ async def generate( async def generate_logits( self, messages: list[str] | list[dict[str, str]], - top_n: int = 10, + top_logprobs: int = 10, sampling_params: dict[str, str | float | int | bool] | None = None, seed: int | None = None, continue_last_message: bool = False, @@ -110,7 +127,7 @@ async def generate_logits( Args: messages: Input messages or text. - top_n: Number of top logits to return (default: 10). + top_logprobs: Number of top logits to return (default: 10). sampling_params: Generation parameters. seed: Random seed for reproducibility. continue_last_message: Whether to continue the last message in chat format. 
@@ -122,8 +139,8 @@ async def generate_logits( params = sampling_params if sampling_params else self.sampling_params params = params.copy() params["max_tokens"] = 1 - params["logprobs"] = top_n - vllm_params = SamplingParams(**params) + params["logprobs"] = top_logprobs + vllm_params = await self.prepare_sampling_params(params) prompt = self.tokenizer.apply_chat_template( conversation=messages, diff --git a/prompting/rewards/exact_match.py b/prompting/rewards/exact_match.py index 0895ca87f..8ff5db9c3 100644 --- a/prompting/rewards/exact_match.py +++ b/prompting/rewards/exact_match.py @@ -1,4 +1,5 @@ import random +from typing import Any import numpy as np import torch @@ -14,15 +15,17 @@ shared_settings = settings.shared_settings +TOP_LOGPROBS = 10 MIN_VERIFY_TOKENS = 10 -MAX_VERIFY_TOKENS = 30 -NO_EOS_PENALTY = -0.1 -INCORRECT_PENALTY = -1.5 +MAX_VERIFY_TOKENS = 51 +PARTIAL_PENALTY = -1.0 +INCORRECT_PENALTY = -2.0 NOT_ENOUGH_TOKENS_PENALTY_SCALE = 0.1 -MIN_SMOOTH_PENALTY_SCALE = 0.6 +MIN_SMOOTH_PENALTY_SCALE = 0.3 MIN_TIME_PENALTY_SCALE = 0.3 VERIFICATION_THRESH_CONTAINS = 0.92 VERIFICATION_THRESH_SIM = 0.83 +VERIFICATION_SIM_EXP_SCALE = 2.0 class LogitsRewardModel(BaseRewardModel): @@ -39,7 +42,7 @@ async def reward( # noqa: C901 raise ValueError("Model manager must be set") all_chunks: list[list[str]] = response_event.stream_results_all_chunks - all_chunk_dicts_raw: list[list[ChatCompletionChunk]] = response_event.stream_results_all_chunk_dicts_raw + all_chunk_dicts_raw: list[list[ChatCompletionChunk | dict]] = response_event.stream_results_all_chunk_dicts_raw uids: np.ndarray | list[float] = response_event.uids all_timings: list[list[float]] = response_event.stream_results_all_chunks_timings completions: list[str] = response_event.completions @@ -59,9 +62,11 @@ async def reward( # noqa: C901 raise ValueError("Timeout must be greater than 0.") # If max_tokens are not provided, always check for eos. - max_tokens = sampling_parameters.get("max_tokens", 8192) model = await model_manager.get_model(task.llm_model_id) + max_tokens = await model.get_max_tokens(sampling_parameters, default_value=2048) eos_token = model.tokenizer.eos_token + bos_token = model.tokenizer.bos_token + special_tokens = set([bos_token, eos_token]) timing_verified: list[list[float]] = [] rewards: list[float] = [] logger.info(f"Verifying logits with model {task.llm_model_id}") @@ -70,8 +75,8 @@ async def reward( # noqa: C901 penalty = INCORRECT_PENALTY reward_scale = 1.0 try: - # If no response is provided, apply full penalty. - if not chunks: + if not chunks or not chunk_dicts_raw: + # If no response is provided, apply full penalty. rewards.append(INCORRECT_PENALTY) timing_verified.append([-1.0]) continue @@ -83,6 +88,12 @@ async def reward( # noqa: C901 timing_verified.append([-1.0]) continue + if completion_length > max_tokens: + # Sampling params is ignored. + rewards.append(PARTIAL_PENALTY) + timing_verified.append([-1.0]) + continue + if completion_length < MIN_VERIFY_TOKENS: # Not enough tokens to verify, still proceed to verification with scaled reward if checks will pass. 
reward_scale = NOT_ENOUGH_TOKENS_PENALTY_SCALE @@ -100,32 +111,43 @@ async def reward( # noqa: C901 verification_logits, _ = await model_manager.generate_logits( model=task.llm_model_id, messages=messages, + top_logprobs=TOP_LOGPROBS, sampling_params=sampling_parameters, continue_last_message=len(to_complete) > 0, ) if check_idx < eos_idx: - if not chunk_dicts_raw[check_idx].choices[0].logprobs: - raise ValueError("No logprobs provided") + if chunks[check_idx] in special_tokens: + raise ValueError("Special tokens mid-completion") - if chunk_dicts_raw[check_idx].choices[0].logprobs.content is None: + chunk_dict: dict[str, Any] | ChatCompletionChunk = chunk_dicts_raw[check_idx] + if isinstance(chunk_dict, ChatCompletionChunk): + # Convert chunks to unified dict format. + chunk_dict = chunk_dict.model_dump(mode="python") + + if chunk_dict.get("choices", [{}])[0].get("logprobs", {}).get("content") is None: raise ValueError("Logprobs content is empty") original_logits = { - info.token: info.logprob - for info in chunk_dicts_raw[check_idx].choices[0].logprobs.content[0].top_logprobs + info["token"]: info["logprob"] + for info in chunk_dict["choices"][0]["logprobs"]["content"][0]["top_logprobs"] } + if len(verification_logits) == TOP_LOGPROBS + 1: + # Sampled logprobs can be +1, remove the lowest value. + del verification_logits[min(verification_logits, key=verification_logits.get)] + logit_sim = self.verify_logit_similarity(original_logits, verification_logits) scores_sim.append(logit_sim) logit_contains = self.verify_logit_contains( chunks[check_idx], original_logits, verification_logits ) + scores_contains.append(logit_contains) elif check_idx == eos_idx and completion_length < max_tokens: if eos_token and eos_token not in verification_logits: - penalty = NO_EOS_PENALTY + penalty = PARTIAL_PENALTY raise ValueError("Partial completion") score_sim_mean = float(np.mean(scores_sim)) @@ -138,9 +160,11 @@ async def reward( # noqa: C901 raise ValueError(f"Logits contains mean score is below threshold: {score_contains_mean:.2f}") timing_verified.append(timings) - smooth_reward = self.smooth_timings_reward(timings) + timingsdt = np.abs(np.diff(timings)) + smooth_reward = self.smooth_timings_reward(timingsdt) # Min-max scale logits reward, e.g from [0.95; 1.0] to [0.0, 1.0]. score_sim_mean = self.rescale(score_sim_mean, min_value=VERIFICATION_THRESH_SIM) + score_sim_mean = score_sim_mean**VERIFICATION_SIM_EXP_SCALE score_contains_mean = self.rescale(score_contains_mean, min_value=VERIFICATION_THRESH_CONTAINS) rewards.append(score_sim_mean * score_contains_mean * smooth_reward * reward_scale) except BaseException as e: @@ -205,26 +229,28 @@ def fastest_timing(values: list[list[float]]) -> float: """Return the smallest sum of inner list, compute its sum only if the list contains no negative numbers.""" best = float("+inf") for subset in values: - if subset and min(subset) >= 0.0: + if len(subset) and min(subset) >= 0.0: subset_sum = sum(subset) / len(subset) if subset_sum < best: best = subset_sum return best if best < float("+inf") else 1e-6 @staticmethod - def smooth_timings_reward(timings_uid: list[float], min_reward: float = MIN_SMOOTH_PENALTY_SCALE) -> float: - """Return smooth stream ration based on the deviation between chunks timings. - - Args: - timings_uid: List of timings for a specific miner. - - Returns: - float: Smoothed penalty value. 
- """ - if not timings_uid: + def smooth_timings_reward( + timings_uid: list[float] | np.ndarray, + tolerance_sec: float = 1, + min_reward: float = MIN_SMOOTH_PENALTY_SCALE, + penalty_strength: float = 5, + ) -> float: + """If delay between chunks is longer than tolerance, apply non-smooth stream penalty.""" + if not len(timings_uid): return 0.0 - smooth_penalty = np.std(timings_uid) + max_timing = max(timings_uid) + if max_timing < tolerance_sec: + return 1.0 + + smooth_penalty = np.std(timings_uid) * penalty_strength return max(min_reward, 1.0 - smooth_penalty) @staticmethod @@ -252,7 +278,7 @@ def verify_logit_similarity( if not gt_logits: return 0.0 - if len(candidate_logits) != len(gt_logits): + if len(candidate_logits) != TOP_LOGPROBS: return 0.0 # Tokens common to both distributions. diff --git a/prompting/rewards/inference_reward_model.py b/prompting/rewards/inference_reward_model.py index d889ba67c..a70db9d4a 100644 --- a/prompting/rewards/inference_reward_model.py +++ b/prompting/rewards/inference_reward_model.py @@ -1,3 +1,5 @@ +from loguru import logger + from prompting.rewards.exact_match import LogitsRewardModel from prompting.rewards.relevance import RelevanceRewardModel from prompting.rewards.reward import BaseRewardModel, BatchRewardOutput @@ -19,9 +21,10 @@ async def reward( if model_manager is None: raise ValueError("Model manager must be set") - if not model_id or task.organic: - relevance_reward_model = RelevanceRewardModel() - return await relevance_reward_model.reward(reference, response_event, model_manager=model_manager) + if model_id or task.organic: + logger.info("Using logits reward model") + logits_reward_model = LogitsRewardModel() + return await logits_reward_model.reward(reference, response_event, task, model_manager=model_manager) - logits_reward_model = LogitsRewardModel() - return await logits_reward_model.reward(reference, response_event, task, model_manager=model_manager) + relevance_reward_model = RelevanceRewardModel() + return await relevance_reward_model.reward(reference, response_event, model_manager=model_manager) diff --git a/prompting/rewards/scoring.py b/prompting/rewards/scoring.py index f1647c39b..ca744885b 100644 --- a/prompting/rewards/scoring.py +++ b/prompting/rewards/scoring.py @@ -1,4 +1,5 @@ import asyncio +import copy import threading from multiprocessing.managers import AcquirerProxy @@ -107,7 +108,7 @@ async def run_step(self) -> RewardLoggingEvent: task_queue=self.task_queue, ) if scoring_config.task.organic: - logger.info(f"Reward events: {reward_events}") + logger.debug(f"Reward events size: {len(reward_events)}") self.reward_events.append(reward_events) logger.debug( @@ -115,9 +116,15 @@ async def run_step(self) -> RewardLoggingEvent: f"{scoring_config.task.llm_model_id}" ) if not scoring_config.task.organic: + # Reduce log size for raw chunks, wandb fails to log any data when overloaded. 
+ response = copy.deepcopy(scoring_config.response) + response.stream_results_all_chunk_dicts_raw = [] + for idx in range(len(response.stream_results)): + response.stream_results[idx].accumulated_chunk_dicts_raw = [] + log_event( RewardLoggingEvent( - response_event=scoring_config.response, + response_event=response, reward_events=reward_events, reference=scoring_config.task.reference, challenge=scoring_config.task.query, diff --git a/prompting/rewards/web_retrieval.py b/prompting/rewards/web_retrieval.py index 03deb2b4e..2d33f9fb8 100644 --- a/prompting/rewards/web_retrieval.py +++ b/prompting/rewards/web_retrieval.py @@ -24,6 +24,10 @@ MIN_RELEVANT_CHARS = 300 MIN_MATCH_THRESHOLD = 98 +MIN_SIM_THRESHOLD = 0.44 +PENALIZE_SIM_THRESHOLD = 0.36 +PENALTY = -0.1 + # Define file paths PAST_WEBSITES_FILE = "past_websites.csv" TOP_DOMAINS_FILE = "data/top100k_domains.csv" @@ -144,7 +148,7 @@ async def score_website_result( self, dataset_entry: DDGDatasetEntry, response_url: str, response_content: str, response_relevant: str, uid: str ) -> float: if not response_url or not response_content or not response_relevant: - return 0 + return PENALTY # Extract domain from URL. netloc = extract_main_domain(response_url) @@ -165,6 +169,7 @@ async def score_website_result( if not response_url or len(response_url) > 500: logger.debug(f"URL {response_url} is too long, setting discount factor to 0") return 0 + if not netloc or any(c.isdigit() for c in netloc.split(".")) or ":" in netloc: discount_factor = 0 logger.debug(f"URL {response_url} appears to be IP-based or on specific port, setting discount factor to 0") @@ -204,10 +209,15 @@ async def score_website_result( if response_relevant not in response_content: return 0 - return ( - await self._cosine_similarity(content1=dataset_entry.query, content2=response_relevant) - * discount_factor - ) + similarity = await self._cosine_similarity(content1=dataset_entry.query, content2=response_relevant) + if similarity < PENALIZE_SIM_THRESHOLD: + # Penalise if similarity is too low. 
+ logger.debug(f"Miner {uid} returned text that doesn't match the query") + return PENALTY + elif similarity < MIN_SIM_THRESHOLD: + logger.debug(f"Miner {uid} returned text has low similarity") + return 0 + return similarity * discount_factor async def score_miner_response( self, dataset_entry: DDGDatasetEntry, completion: str, task: BaseTextTask | None = None, uid: str | None = None @@ -217,7 +227,7 @@ async def score_miner_response( unique_websites = np.unique([website.url for website in miner_websites]) if unique_websites.size != len(miner_websites) or unique_websites.size != task.target_results: # logger.warning("Miner returned multiple websites with the same URL") - return 0 + return PENALTY tasks = [ self.score_website_result(dataset_entry, website.url, website.content, website.relevant, uid) diff --git a/prompting/tasks/task_registry.py b/prompting/tasks/task_registry.py index c16a4c83b..8504afecd 100644 --- a/prompting/tasks/task_registry.py +++ b/prompting/tasks/task_registry.py @@ -5,14 +5,12 @@ from loguru import logger from pydantic import BaseModel, ConfigDict -from prompting.datasets.huggingface_github import HuggingFaceGithubDataset from prompting.datasets.random_website import DDGDataset from prompting.datasets.sn13 import SN13Dataset from prompting.rewards.reward import BaseRewardConfig from prompting.tasks.base_task import BaseTextTask from prompting.tasks.inference import InferenceRewardConfig, InferenceTask from prompting.tasks.MSRv2_task import MSRv2RewardConfig, MSRv2Task -from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask from shared.base import BaseDataset @@ -31,22 +29,16 @@ def __hash__(self): class TaskRegistry(BaseModel): task_configs: ClassVar[list[TaskConfig]] = [ - TaskConfig(task=MSRv2Task, probability=0.05, datasets=[DDGDataset], reward_model=MSRv2RewardConfig), + TaskConfig(task=MSRv2Task, probability=0.10, datasets=[DDGDataset], reward_model=MSRv2RewardConfig), TaskConfig( task=InferenceTask, - probability=0.45, + probability=0.55, datasets=[SN13Dataset], reward_model=InferenceRewardConfig, ), - TaskConfig( - task=ProgrammingTask, - probability=0.10, - datasets=[HuggingFaceGithubDataset], - reward_model=ProgrammingRewardConfig, - ), TaskConfig( task=WebRetrievalTask, - probability=0.40, + probability=0.35, datasets=[DDGDataset], reward_model=WebRetrievalRewardConfig, ), diff --git a/pyproject.toml b/pyproject.toml index be33a7b58..831361976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "prompting" -version = "2.19.2" +version = "2.19.3" description = "Subnetwork 1 runs on Bittensor and is maintained by Macrocosmos. It's an effort to create decentralised AI" authors = ["Kalei Brady, Dmytro Bobrenko, Felix Quinque, Steffen Cruz, Richard Wardle"] readme = "README.md" diff --git a/shared/settings.py b/shared/settings.py index f44f29c75..27cdba1b7 100644 --- a/shared/settings.py +++ b/shared/settings.py @@ -96,6 +96,9 @@ class SharedSettings(BaseSettings): DEPLOY_VALIDATOR: bool = Field(True, env="DEPLOY_VALDITAOR") DEPLOY_SCORING_API: bool = Field(True, env="DEPLOY_SCORING_API") SCORING_API_PORT: int = Field(8095, env="SCORING_API_PORT") + # Hard-code MC validator axon, since it might be overwritten in the metagraph. 
+ MC_VALIDATOR_HOTKEY: str = Field("5Cg5QgjMfRqBC6bh8X4PDbQi7UzVRn9eyWXsB8gkyfppFPPy", env="MC_VALIDATOR_HOTKEY") + MC_VALIDATOR_AXON: str = Field("184.105.5.17:42174", env="MC_VALIDATOR_AXON") # ==== API ===== # Hotkey used to run api, defaults to Macrocosmos @@ -105,6 +108,8 @@ class SharedSettings(BaseSettings): # Scoring queue threshold when rate-limit start to kick in, used to query validator API with scoring requests. SCORING_QUEUE_API_THRESHOLD: int = Field(1, env="SCORING_QUEUE_API_THRESHOLD") API_TEST_MODE: bool = Field(False, env="API_TEST_MODE") + API_UIDS_EXPLORE: float = Field(0.2, env="API_UIDS_EXPLORE") + API_TOP_MINERS_SAMPLE: int = Field(400, env="API_TOP_MINERS_SAMPLE") # Validator scoring API (.env.validator). SCORE_ORGANICS: bool = Field(False, env="SCORE_ORGANICS") diff --git a/shared/uids.py b/shared/uids.py index 8de376ebe..1c6b4a349 100644 --- a/shared/uids.py +++ b/shared/uids.py @@ -92,7 +92,7 @@ def get_random_uids(k: int | None = 10**6, exclude: list[int] = None, own_uid: i raise ValueError(f"No eligible uids were found. Cannot return {k} uids") -def get_top_incentive_uids(k: int, vpermit_tao_limit: int) -> np.ndarray: +def get_top_incentive_uids(k: int, vpermit_tao_limit: int, explore: float = 0) -> np.ndarray: miners_uids = list(map(int, filter(lambda uid: check_uid_availability(uid), shared_settings.METAGRAPH.uids))) # Builds a dictionary of uids and their corresponding incentives. @@ -108,10 +108,18 @@ def get_top_incentive_uids(k: int, vpermit_tao_limit: int) -> np.ndarray: uid_incentive_pairs_sorted = sorted(uid_incentive_pairs, key=lambda x: x[1], reverse=True) # Extract the top uids. - top_k_uids = [uid for uid, incentive in uid_incentive_pairs_sorted[:k]] + num_explore_uids = int(k * explore) + num_top_uids = k - num_explore_uids + top_k_uids = [uid for uid, _ in uid_incentive_pairs_sorted[:num_top_uids]] - return list(np.array(top_k_uids).astype(int)) - # return [int(k) for k in top_k_uids] + if num_explore_uids > 0: + # Sample exploration uids randomly from the remaining pool. 
+ remaining_pairs = uid_incentive_pairs_sorted[num_top_uids:] + remaining_uids = [uid for uid, _ in remaining_pairs] + explore_uids = random.sample(remaining_uids, min(num_explore_uids, len(remaining_uids))) + top_k_uids.extend(explore_uids) + + return top_k_uids def get_uids( @@ -119,6 +127,7 @@ def get_uids( k: int | None = None, exclude: list[int] = [], own_uid: int | None = None, + explore: float = 0.0, ) -> np.ndarray: if shared_settings.TEST and shared_settings.TEST_MINER_IDS: return random.sample( @@ -129,6 +138,6 @@ def get_uids( return get_random_uids(k=k, exclude=exclude or []) if sampling_mode == "top_incentive": vpermit_tao_limit = shared_settings.NEURON_VPERMIT_TAO_LIMIT - return get_top_incentive_uids(k=k, vpermit_tao_limit=vpermit_tao_limit) + return get_top_incentive_uids(k=k, vpermit_tao_limit=vpermit_tao_limit, explore=explore) if sampling_mode == "all": return [int(uid) for uid in shared_settings.METAGRAPH.uids if (uid != own_uid and check_uid_availability(uid))] diff --git a/tests/prompting/rewards/test_exact_match.py b/tests/prompting/rewards/test_exact_match.py index 8a40c5042..d8d4839bd 100644 --- a/tests/prompting/rewards/test_exact_match.py +++ b/tests/prompting/rewards/test_exact_match.py @@ -10,7 +10,8 @@ MAX_VERIFY_TOKENS, MIN_SMOOTH_PENALTY_SCALE, MIN_VERIFY_TOKENS, - NO_EOS_PENALTY, + PARTIAL_PENALTY, + TOP_LOGPROBS, VERIFICATION_THRESH_SIM, LogitsRewardModel, ) @@ -160,7 +161,7 @@ def mock_verify_sim(original_logits, verification_logits): assert isinstance(result, BatchRewardOutput) assert len(result.rewards) == 3 - assert 0.3 < result.rewards[0] <= 0.9 + assert 0.3 < result.rewards[0] <= 1.0 assert result.rewards[1] == INCORRECT_PENALTY assert result.rewards[2] == INCORRECT_PENALTY @@ -170,7 +171,7 @@ def mock_verify_sim(original_logits, verification_logits): "eos_in_logits, expected_penalty", [ (True, None), - (False, NO_EOS_PENALTY), + (False, PARTIAL_PENALTY), ], ids=["eos_present", "eos_missing"], ) @@ -200,7 +201,7 @@ async def test_eos_handling(eos_in_logits, expected_penalty, model_manager, task assert len(result.rewards) == 1 if expected_penalty is None: # eos present. - assert result.rewards[0] != NO_EOS_PENALTY + assert result.rewards[0] != PARTIAL_PENALTY else: # eos missing. assert result.rewards[0] == pytest.approx(expected_penalty) @@ -208,19 +209,21 @@ async def test_eos_handling(eos_in_logits, expected_penalty, model_manager, task def test_verify_logit_similarity(): """Test the verify_logit_similarity similarity metric.""" - original = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "token4": -1.5, "token5": -2.0} + original = {f"token{idx}": -0.01 for idx in range(TOP_LOGPROBS)} # Identical distributions -> 1.0. assert LogitsRewardModel.verify_logit_similarity(original, original) == pytest.approx(1.0) - # Disjoint tokens -> near zero. - disjoint = {"foo": -0.1, "bar": -0.5, "foo1": -1.0, "bar1": -1.5, "foo2": -2.0} - sim = LogitsRewardModel.verify_logit_similarity(original, disjoint) - assert sim == pytest.approx(0.0) + with patch("prompting.rewards.exact_match.TOP_LOGPROBS", 5): + # Disjoint tokens -> near zero. + disjoint = {"foo": -0.1, "bar": -0.5, "foo1": -1.0, "bar1": -1.5, "foo2": -2.0} + sim = LogitsRewardModel.verify_logit_similarity(original, disjoint) + assert sim == pytest.approx(0.0) - # Partial overlap -> between 0 and 1. 
- partial = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "foo1": -1.5, "bar1": -2.0} - sim2 = LogitsRewardModel.verify_logit_similarity(original, partial) - assert sim2 == pytest.approx(0.6) + # Partial overlap -> between 0 and 1. + original = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "foo1": -1.5, "bar1": -2.0} + partial = {"token1": -0.1, "token2": -0.5, "token3": -1.0, "token4": -1.5, "token5": -2.0} + sim2 = LogitsRewardModel.verify_logit_similarity(original, partial) + assert sim2 == pytest.approx(0.6) def test_smooth_reward_scale(): @@ -229,20 +232,16 @@ def test_smooth_reward_scale(): assert LogitsRewardModel.smooth_timings_reward([]) == 0.0 # Test uniform timings (should give maximum reward). - uniform_timings = [1.0, 1.0, 1.0, 1.0, 1.0] - assert LogitsRewardModel.smooth_timings_reward(uniform_timings) == 1.0 + uniform_timings = [0.1, 0.1, 0.1, 0.1, 0.1] + assert LogitsRewardModel.smooth_timings_reward(uniform_timings) == pytest.approx(1.0) # Test high variance timings (should give minimum reward). - high_var_timings = [0.1, 5.0, 10.0, 0.5, 8.0] - std_dev = np.std(high_var_timings) + high_var_timings = [0.1, 0.1, 15.0, 0.1, 0.1] assert LogitsRewardModel.smooth_timings_reward(high_var_timings) == MIN_SMOOTH_PENALTY_SCALE - assert 1.0 - std_dev < MIN_SMOOTH_PENALTY_SCALE # Test moderate variance timings. - moderate_var_timings = [0.9, 1.0, 1.1, 0.95, 1.05] - expected = max(MIN_SMOOTH_PENALTY_SCALE, 1.0 - np.std(moderate_var_timings)) - assert LogitsRewardModel.smooth_timings_reward(moderate_var_timings) == pytest.approx(expected) - assert MIN_SMOOTH_PENALTY_SCALE < LogitsRewardModel.smooth_timings_reward(moderate_var_timings) < 1.0 + moderate_var_timings = [0.3, 0.2, 0.4, 0.1, 0.1] + assert LogitsRewardModel.smooth_timings_reward(moderate_var_timings) == pytest.approx(1.0) # Test with custom minimum reward. custom_min = 0.8 diff --git a/validator_api/api_management.py b/validator_api/api_management.py index 2604e32f7..337fda339 100644 --- a/validator_api/api_management.py +++ b/validator_api/api_management.py @@ -48,12 +48,6 @@ def validate_api_key( 2) Else, if 'Authorization' header exists and starts with Bearer, extract token and validate. 3) Otherwise, raise a 403. 
""" - - if api_key: - if api_key not in _keys: - raise HTTPException(status_code=403, detail="Invalid API key") - return _keys[api_key] - if authorization: scheme, _, token = authorization.partition(" ") if scheme.lower() != "bearer": @@ -62,6 +56,11 @@ def validate_api_key( raise HTTPException(status_code=403, detail="Invalid API key") return _keys[token] + if api_key: + if api_key not in _keys: + raise HTTPException(status_code=403, detail="Invalid API key") + return _keys[api_key] + raise HTTPException(status_code=403, detail="Missing API key") diff --git a/validator_api/chat_completion.py b/validator_api/chat_completion.py index b3499d515..312ce8fc4 100644 --- a/validator_api/chat_completion.py +++ b/validator_api/chat_completion.py @@ -3,10 +3,9 @@ import math import random import time -from typing import Any, AsyncGenerator, Callable, List, Optional +from typing import Any, AsyncGenerator, Optional -import httpcore -import httpx +import openai from fastapi import HTTPException from fastapi.responses import StreamingResponse from loguru import logger @@ -20,176 +19,112 @@ from validator_api.utils import filter_available_uids -async def peek_until_valid_chunk( - response: AsyncGenerator, is_valid_chunk: Callable[[Any], bool] -) -> tuple[Optional[Any], Optional[AsyncGenerator]]: - """ - Keep reading chunks until we find a 'valid' one or run out of chunks. - Return (first_valid_chunk, a_generator_of_all_chunks_including_this_one). - If no chunks or no valid chunks, return (None, None). - """ - consumed = [] - valid_chunk = None - - try: - async for chunk in response: - consumed.append(chunk) - if is_valid_chunk(chunk): - valid_chunk = chunk - break # we found our valid chunk - except StopAsyncIteration: - # no more chunks - pass - - if not consumed or valid_chunk is None: - # Either the generator is empty or we never found a valid chunk - return None, None - - # Rebuild a generator from the chunks we already consumed - # plus any remaining chunks that weren't pulled yet. - async def rebuilt_generator() -> AsyncGenerator: - # yield everything we consumed - for c in consumed: - yield c - # yield anything else still left in 'response' - async for c in response: - yield c - - return valid_chunk, rebuilt_generator() - - -def is_valid_chunk(chunk: Any) -> bool: - if chunk: - return ( - hasattr(chunk, "choices") - and len(chunk.choices) > 0 - and getattr(chunk.choices[0].delta, "content", None) is not None - ) - - -async def peek_first_chunk( - response: AsyncGenerator, -) -> tuple[Optional[any], Optional[AsyncGenerator]]: - """ - Pull one chunk from the async generator and return: - (the_chunk, a_new_generator_that_includes_this_chunk) - If the generator is empty, return (None, None). - """ - try: - first_chunk = await anext(response) # or: await anext(response, default=None) in Python 3.10+ - except StopAsyncIteration: - # Generator is empty - return None, None - - # At this point, we have the first chunk. We need to rebuild a generator - # that yields this chunk first, then yields the rest of the original response. 
- async def reconstructed_response() -> AsyncGenerator: - yield first_chunk - async for c in response: - yield c - - return first_chunk, reconstructed_response() - - -async def stream_chunks( - first_valid_response: AsyncGenerator, - collected_chunks_list: List[List[str]], - timings_list: List[List[float]], - response_start_time: float, +async def stream_from_first_response( # noqa: C901 + responses: list[asyncio.Task], + collected_chunks_list: list[list[str]], + collected_chunks_raw_list: list[list[Any]], + body: dict[str, Any], + uids: list[int], + timings_list: list[list[float]], ) -> AsyncGenerator[str, None]: - """Stream chunks from a valid response and collect timing data. + """Start streaming as soon as any miner produces the first non-empty chunk. + + While streaming, collect primary miner and all other miners chunks for scoring. - Args: - first_valid_response: The async generator containing response chunks - collected_chunks_list: List to collect response chunks - timings_list: List to collect timing data - response_start_time: Start time of the response for timing calculations + A chunk is considered non-empty if it has: + - chunk.choices[0].delta + - chunk.choices[0].delta.content + - chunk.choices[0].logprobs + - chunk.choices[0].logprobs.content """ - chunks_received = False - async for chunk in first_valid_response: - # Safely handle the chunk - if not chunk.choices or not chunk.choices[0].delta: - continue - - content = getattr(chunk.choices[0].delta, "content", None) - if content is None: - continue - - chunks_received = True - timings_list[0].append(time.monotonic() - response_start_time) - collected_chunks_list[0].append(content) - yield f"data: {json.dumps(chunk.model_dump())}\n\n" - - if not chunks_received: - logger.error("Stream is empty: No chunks were received") - yield 'data: {"error": "502 - Response is empty"}\n\n' - - yield "data: [DONE]\n\n" - - if timings_list and timings_list[0]: - logger.info(f"Response completion time: {timings_list[0][-1]:.2f}s") - - -async def stream_from_first_response( - responses: List[asyncio.Task], - collected_chunks_list: List[List[str]], - collected_chunks_raw_list: List, - body: dict[str, any], - uids: List[int], - timings_list: List[List[float]], -) -> AsyncGenerator[str, None]: - first_valid_response = None response_start_time = time.monotonic() + def _is_valid(chunk: Any) -> bool: + """Return True for the first chunk we care about (delta + logprobs.content).""" + try: + choice = chunk.choices[0] + return ( + choice.delta is not None + and getattr(choice.delta, "content", None) is not None + and choice.logprobs is not None + and getattr(choice.logprobs, "content", None) is not None + ) + except (AttributeError, IndexError): + return False + + # Guards first stream. + first_found_evt = asyncio.Event() + first_queue: asyncio.Queue[tuple[int, Any, AsyncGenerator]] = asyncio.Queue() + + async def _collector(idx: int, resp_task: asyncio.Task) -> None: + """Miner stream collector. + + 1. Wait for the miner's async-generator. + 2. On first valid chunk: + - if we’re FIRST → notify main via queue and exit + - else (someone already started) → keep collecting for scoring + """ + try: + resp_gen = await resp_task + if not resp_gen or isinstance(resp_gen, Exception): + return + + async for chunk in resp_gen: + if not _is_valid(chunk): + continue + + # Someone already claimed the stream? 
+ if not first_found_evt.is_set(): + first_found_evt.set() + await first_queue.put((idx, chunk, resp_gen)) + return + + # We’re NOT the first – just collect for scoring. + collected_chunks_raw_list[idx].append(chunk) + collected_chunks_list[idx].append(chunk.choices[0].delta.content) + timings_list[idx].append(time.monotonic() - response_start_time) + + except (openai.APIConnectionError, asyncio.CancelledError): + pass + except Exception as e: + logger.exception(f"Collector error for miner index {idx}: {e}") + + # Spawn collectors for every miner. + collectors = [asyncio.create_task(_collector(idx, stream)) for idx, stream in enumerate(responses)] + + # Wait for the first valid chunk. try: - # Keep looping until we find a valid response or run out of tasks - while responses and first_valid_response is None: - done, pending = await asyncio.wait(responses, return_when=asyncio.FIRST_COMPLETED) - - for task in done: - responses.remove(task) - try: - response = await task # This is (presumably) an async generator + try: + primary_idx, first_chunk, primary_gen = await asyncio.wait_for(first_queue.get(), timeout=30) + except asyncio.TimeoutError: + logger.error("No miner produced a valid chunk within 30 s") + yield 'data: {"error": "502 - No valid response received"}\n\n' + return - if not response or isinstance(response, Exception): - continue - # Peak at the first chunk - first_chunk, rebuilt_generator = await peek_until_valid_chunk(response, is_valid_chunk) - if first_chunk is None: - continue + # Stream the very first chunk immediately. + collected_chunks_raw_list[primary_idx].append(first_chunk) + collected_chunks_list[primary_idx].append(first_chunk.choices[0].delta.content) + timings_list[primary_idx].append(time.monotonic() - response_start_time) + yield f"data: {json.dumps(first_chunk.model_dump())}\n\n" - first_valid_response = rebuilt_generator - break + # Continue streaming the primary miner. + async for chunk in primary_gen: + if not _is_valid(chunk): + continue + collected_chunks_raw_list[primary_idx].append(chunk) + collected_chunks_list[primary_idx].append(chunk.choices[0].delta.content) + timings_list[primary_idx].append(time.monotonic() - response_start_time) + yield f"data: {json.dumps(chunk.model_dump())}\n\n" - except Exception as e: - logger.exception(f"Error in miner response: {e}") - # just skip and continue to the next task + # End of stream. + yield "data: [DONE]\n\n" + if timings_list[primary_idx]: + logger.info(f"Response completion time: {timings_list[primary_idx][-1]:.2f}s") - if first_valid_response is None: - logger.error("No valid response received from any miner") - yield 'data: {"error": "502 - No valid response received"}\n\n' - return + # Wait for background collectors to finish. + await asyncio.gather(*collectors, return_exceptions=True) - # Stream the first valid response - async for chunk_data in stream_chunks( - first_valid_response, collected_chunks_list, timings_list, response_start_time - ): - yield chunk_data - - # Continue collecting remaining responses in background for scoring - remaining = asyncio.gather(*pending, return_exceptions=True) - remaining_tasks = asyncio.create_task( - collect_remaining_responses( - remaining=remaining, - collected_chunks_list=collected_chunks_list, - collected_chunks_raw_list=collected_chunks_raw_list, - body=body, - uids=uids, - timings_list=timings_list, - response_start_time=response_start_time, - ) - ) - await remaining_tasks + # Push everything to the scoring queue. 
asyncio.create_task( scoring_queue.scoring_queue.append_response( uids=uids, @@ -200,54 +135,16 @@ async def stream_from_first_response( ) ) - except asyncio.CancelledError: + except (openai.APIConnectionError, asyncio.CancelledError): logger.info("Client disconnected, streaming cancelled") - for task in responses: - task.cancel() + for c in collectors: + c.cancel() raise except Exception as e: logger.exception(f"Error during streaming: {e}") yield 'data: {"error": "Internal server Error"}\n\n' -async def collect_remaining_responses( - remaining: asyncio.Task, - collected_chunks_list: List[List[str]], - collected_chunks_raw_list: List, - body: dict[str, any], - uids: List[int], - timings_list: List[List[float]], - response_start_time: float, -): - """Collect remaining responses for scoring without blocking the main response.""" - try: - responses = await remaining - for i, response in enumerate(responses): - if isinstance(response, Exception): - logger.error(f"Error collecting response from uid {uids[i+1]}: {response}") - continue - - try: - async for chunk in response: - if not chunk.choices or not chunk.choices[0].delta: - continue - content = getattr(chunk.choices[0].delta, "content", None) - if content is None: - continue - - timings_list[i + 1].append(time.monotonic() - response_start_time) - collected_chunks_list[i + 1].append(content) - collected_chunks_raw_list[i + 1].append(chunk) - - except (httpx.ReadTimeout, httpcore.ReadTimeout) as e: - logger.warning(f"Stream timeout for index {i}: partial results collected. {e}") - except Exception as e: - logger.error(f"Unexpected error collecting stream for index {i}: {e}") - - except Exception as e: - logger.exception(f"Error collecting remaining responses: {e}") - - async def get_response_from_miner(body: dict[str, any], uid: int, timeout_seconds: int) -> tuple: """Get response from a single miner.""" return await make_openai_query( @@ -349,16 +246,18 @@ async def chat_completion( if first_valid_response is None: raise HTTPException(status_code=502, detail="No valid response received") - asyncio.create_task( - collect_remaining_nonstream_responses( - pending=pending, - collected_responses=collected_responses, - body=body, - uids=uids, - timings_list=timings_list, - ) - ) - return first_valid_response[0] # Return only the response object, not the chunks + # TODO: Non-stream scoring is not supported right now. + # asyncio.create_task( + # collect_remaining_nonstream_responses( + # pending=pending, + # collected_responses=collected_responses, + # body=body, + # uids=uids, + # timings_list=timings_list, + # ) + # ) + # Return only the response object, not the chunks. + return first_valid_response[0] async def collect_remaining_nonstream_responses( diff --git a/validator_api/deep_research/orchestrator_v2.py b/validator_api/deep_research/orchestrator_v2.py index ebe955af3..f5b4fe4af 100644 --- a/validator_api/deep_research/orchestrator_v2.py +++ b/validator_api/deep_research/orchestrator_v2.py @@ -72,7 +72,7 @@ async def search_web(question: str, n_results: int = 2, completions=None) -> dic search_results = {"results": []} # Generate referenced answer - answer_prompt = f"""Based on the provided search results, generate a comprehensive answer to the question. + answer_prompt = f"""Based on the provided search results, generate a concise but well-structured answer to the question. Include inline references to sources using markdown format [n] where n is the source number. 
Question: {question} @@ -119,12 +119,14 @@ async def search_web(question: str, n_results: int = 2, completions=None) -> dic @retry( - stop=stop_after_attempt(5), + stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=5), retry=retry_if_exception_type(json.JSONDecodeError), ) async def make_mistral_request_with_json( - messages: list[dict], step_name: str, completions: Callable[[CompletionsRequest], Awaitable[StreamingResponse]] + messages: list[dict], + step_name: str, + completions: Callable[[CompletionsRequest], Awaitable[StreamingResponse]], ): """Makes a request to Mistral API and records the query""" raw_response, query_record = await make_mistral_request(messages, step_name, completions) @@ -137,7 +139,7 @@ async def make_mistral_request_with_json( @retry( - stop=stop_after_attempt(7), + stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=5), retry=retry_if_exception_type(BaseException), ) @@ -234,7 +236,7 @@ def description(self) -> str: return """Searches the web to answer a question. Provides a referenced answer with citations. Input parameters: - question: The natural language question to answer - - n_results: (optional) Number of search results to use (default: 5) + - n_results: (optional) Number of search results to use (default: 2) Returns a dictionary containing: - question: Original question asked diff --git a/validator_api/gpt_endpoints.py b/validator_api/gpt_endpoints.py index 1a0ec281e..f8179111d 100644 --- a/validator_api/gpt_endpoints.py +++ b/validator_api/gpt_endpoints.py @@ -1,6 +1,6 @@ import random -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, status from loguru import logger from starlette.responses import StreamingResponse @@ -10,8 +10,9 @@ from validator_api.api_management import validate_api_key from validator_api.chat_completion import chat_completion from validator_api.deep_research.orchestrator_v2 import OrchestratorV2 +from validator_api.job_store import JobStatus, job_store, process_chain_of_thought_job from validator_api.mixture_of_miners import mixture_of_miners -from validator_api.serializers import CompletionsRequest, TestTimeInferenceRequest +from validator_api.serializers import CompletionsRequest, JobResponse, JobResultResponse, TestTimeInferenceRequest from validator_api.utils import filter_available_uids router = APIRouter() @@ -84,7 +85,12 @@ async def completions(request: CompletionsRequest, api_key: str = Depends(valida logger.error(f"Error in uids: {body.get('uids')}") else: uids = filter_available_uids( - task=body.get("task"), model=body.get("model"), test=shared_settings.API_TEST_MODE, n_miners=N_MINERS + task=body.get("task"), + model=body.get("model"), + test=shared_settings.API_TEST_MODE, + n_miners=N_MINERS, + n_top_incentive=shared_settings.API_TOP_MINERS_SAMPLE, + explore=shared_settings.API_UIDS_EXPLORE, ) if not uids: raise HTTPException(status_code=500, detail="No available miners") @@ -151,3 +157,161 @@ async def create_response_stream(request): "Connection": "keep-alive", }, ) + + +@router.post( + "/v1/chat/completions/jobs", + summary="Asynchronous chat completions endpoint for Chain-of-Thought", + description="Submit a Chain-of-Thought inference job to be processed in the background and get a job ID immediately.", + response_model=JobResponse, + status_code=status.HTTP_202_ACCEPTED, + responses={ + status.HTTP_202_ACCEPTED: { + "description": "Job accepted for processing", + "model": 
JobResponse, + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Internal server error or no available miners"}, + }, +) +async def submit_chain_of_thought_job( + request: CompletionsRequest, background_tasks: BackgroundTasks, api_key: str = Depends(validate_api_key) +): + """ + Submit a Chain-of-Thought inference job to be processed in the background. + + This endpoint accepts the same parameters as the /v1/chat/completions endpoint, + but instead of streaming the response, it submits the job to the background and + returns a job ID immediately. The job results can be retrieved using the + /v1/chat/completions/jobs/{job_id} endpoint. + + ## Request Parameters: + - **uids** (List[int], optional): Specific miner UIDs to query. If not provided, miners will be selected automatically. + - **messages** (List[dict]): List of message objects with 'role' and 'content' keys. Required. + - **model** (str, optional): Model identifier to filter available miners. + + ## Response: + - **job_id** (str): Unique identifier for the job. + - **status** (str): Current status of the job (pending, running, completed, failed). + - **created_at** (str): Timestamp when the job was created. + - **updated_at** (str): Timestamp when the job was last updated. + + Example request: + ```json + { + "messages": [ + {"role": "user", "content": "Solve the equation: 3x + 5 = 14"} + ], + "model": "gpt-4" + } + ``` + """ + try: + body = request.model_dump() + + # Check if inference mode is Chain-of-Thought, if not return error + if body.get("inference_mode") != "Chain-of-Thought": + raise HTTPException(status_code=400, detail="This endpoint only accepts Chain-of-Thought inference mode") + + body["model"] = ( + "mrfakename/mistral-small-3.1-24b-instruct-2503-hf" if body.get("model") == "Default" else body.get("model") + ) + + body["seed"] = int(body.get("seed") or random.randint(0, 1000000)) + + uids = ( + [int(uid) for uid in body.get("uids")] + if body.get("uids") + else filter_available_uids( + task=body.get("task"), model=body.get("model"), test=shared_settings.API_TEST_MODE, n_miners=N_MINERS + ) + ) + + if not uids: + raise HTTPException(status_code=500, detail="No available miners") + + # Create a new job + job_id = job_store.create_job() + + # Create the test time inference request + test_time_request = TestTimeInferenceRequest( + messages=request.messages, + model=request.model, + uids=uids, + json_format=request.json_format, + ) + + # Create the orchestrator + orchestrator = OrchestratorV2(completions=completions) + + # Add the background task + background_tasks.add_task( + process_chain_of_thought_job, + job_id=job_id, + orchestrator=orchestrator, + messages=test_time_request.messages, + ) + + # Get the job + job = job_store.get_job(job_id) + + # Return the job response + return JobResponse( + job_id=job.job_id, + status=job.status, + created_at=job.created_at.isoformat(), + updated_at=job.updated_at.isoformat(), + ) + + except Exception as e: + logger.exception(f"Error in creating chain of thought job: {e}") + raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") + + +@router.get( + "/v1/chat/completions/jobs/{job_id}", + summary="Get the status and result of a Chain-of-Thought job", + description="Retrieve the status and result of a Chain-of-Thought job by its ID.", + response_model=JobResultResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: { + "description": "Job status and result", + "model": JobResultResponse, + }, + status.HTTP_404_NOT_FOUND: 
{"description": "Job not found"}, + }, +) +async def get_chain_of_thought_job(job_id: str, api_key: str = Depends(validate_api_key)): + """ + Get the status and result of a Chain-of-Thought job. + + This endpoint retrieves the status and result of a Chain-of-Thought job by its ID. + If the job is completed, the result will be included in the response. + If the job failed, the error message will be included in the response. + + ## Path Parameters: + - **job_id** (str): The ID of the job to retrieve. + + ## Response: + - **job_id** (str): Unique identifier for the job. + - **status** (str): Current status of the job (pending, running, completed, failed). + - **created_at** (str): Timestamp when the job was created. + - **updated_at** (str): Timestamp when the job was last updated. + - **result** (List[str], optional): Result of the job if completed. + - **error** (str, optional): Error message if the job failed. + """ + job = job_store.get_job(job_id) + + if job.status == JobStatus.COMPLETED: # todo check if job is deleted + job_store.delete_job(job_id) + if not job: + raise HTTPException(status_code=404, detail=f"Job with ID {job_id} not found") + + return JobResultResponse( + job_id=job.job_id, + status=job.status, + created_at=job.created_at.isoformat(), + updated_at=job.updated_at.isoformat(), + result=job.result, + error=job.error, + ) diff --git a/validator_api/job_store.py b/validator_api/job_store.py new file mode 100644 index 000000000..d20905a51 --- /dev/null +++ b/validator_api/job_store.py @@ -0,0 +1,156 @@ +import sqlite3 +import uuid +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel + + +class JobStatus(str, Enum): + """Enum for job status.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class JobResult(BaseModel): + """Model for job result.""" + + job_id: str + status: JobStatus + created_at: datetime + updated_at: datetime + result: Optional[List[str]] = None + error: Optional[str] = None + + +class JobStore: + """Store for background jobs using SQLite database.""" + + def __init__(self, db_path: str = "jobs.db"): + """Initialize the job store with SQLite database.""" + self.db_path = db_path + self._init_db() + + def _init_db(self) -> None: + """Initialize the database and create the jobs table if it doesn't exist.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS jobs ( + job_id TEXT PRIMARY KEY, + status TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL, + result TEXT, + error TEXT + ) + """ + ) + conn.commit() + + def delete_job(self, job_id: str) -> None: + """Delete a job by its ID.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,)) + conn.commit() + + def create_job(self) -> str: + """Create a new job and return its ID.""" + job_id = str(uuid.uuid4()) + now = datetime.now() + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT INTO jobs (job_id, status, created_at, updated_at) + VALUES (?, ?, ?, ?) 
+ """, + (job_id, JobStatus.PENDING, now, now), + ) + conn.commit() + return job_id + + def get_job(self, job_id: str) -> Optional[JobResult]: + """Get a job by its ID.""" + with sqlite3.connect(self.db_path) as conn: + conn.row_factory = sqlite3.Row + cursor = conn.execute("SELECT * FROM jobs WHERE job_id = ?", (job_id,)) + row = cursor.fetchone() + + if row is None: + return None + + # Convert the result string to List[str] if it exists + result = eval(row["result"]) if row["result"] is not None else None + + return JobResult( + job_id=row["job_id"], + status=JobStatus(row["status"]), + created_at=datetime.fromisoformat(row["created_at"]), + updated_at=datetime.fromisoformat(row["updated_at"]), + result=result, + error=row["error"], + ) + + def update_job_status(self, job_id: str, status: JobStatus) -> None: + """Update the status of a job.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + UPDATE jobs + SET status = ?, updated_at = ? + WHERE job_id = ? + """, + (status, datetime.now(), job_id), + ) + conn.commit() + + def update_job_result(self, job_id: str, result: List[str]) -> None: + """Update the result of a job.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + UPDATE jobs + SET result = ?, status = ?, updated_at = ? + WHERE job_id = ? + """, + (str(result), JobStatus.COMPLETED, datetime.now(), job_id), + ) + conn.commit() + + def update_job_error(self, job_id: str, error: str) -> None: + """Update the error of a job.""" + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + UPDATE jobs + SET error = ?, status = ?, updated_at = ? + WHERE job_id = ? + """, + (error, JobStatus.FAILED, datetime.now(), job_id), + ) + conn.commit() + + +# Create a singleton instance +job_store = JobStore() + + +async def process_chain_of_thought_job(job_id: str, orchestrator, messages: List[Dict[str, str]]) -> None: + """Process a chain of thought job in the background.""" + try: + job_store.update_job_status(job_id, JobStatus.RUNNING) + + # Collect all chunks from the orchestrator + chunks = [] + async for chunk in orchestrator.run(messages=messages): + chunks.append(chunk) + + # Update the job with the result + job_store.update_job_result(job_id, chunks) + except Exception as e: + # Update the job with the error + job_store.update_job_error(job_id, str(e)) diff --git a/validator_api/scoring_queue.py b/validator_api/scoring_queue.py index 00d12b6dd..4769d9251 100644 --- a/validator_api/scoring_queue.py +++ b/validator_api/scoring_queue.py @@ -1,5 +1,6 @@ import asyncio import datetime +import json from collections import deque from typing import Any @@ -47,7 +48,7 @@ async def wait_for_next_execution(self, last_run_time) -> datetime.datetime: async def run_step(self): """Perform organic scoring: pop queued payload, forward to the validator API.""" - logger.debug("Running scoring step") + # logger.debug("Running scoring step") async with self._scoring_lock: if not self._scoring_queue: return @@ -70,6 +71,7 @@ async def run_step(self): payload = payload.to_dict() elif isinstance(payload, BaseModel): payload = payload.model_dump() + payload_bytes = json.dumps(payload).encode() timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0) # Add required headers for signature verification @@ -80,8 +82,8 @@ async def run_step(self): ) as client: response = await client.post( url=url, - json=payload, - # headers=headers, + content=payload_bytes, + headers={"Content-Type": "application/json"}, ) 
     validator_registry.update_validators(uid=vali_uid, response_code=response.status_code)
     if response.status_code != 200:
@@ -142,7 +144,7 @@ async def append_response(
             "chunk_dicts_raw": chunk_dict_raw,
         }
         scoring_item = ScoringPayload(payload=payload, date=datetime.datetime.now().replace(microsecond=0))
-        logger.info(f"Appending organic to scoring queue: {scoring_item}")
+        # logger.info(f"Appending organic to scoring queue: {scoring_item}")
         async with self._scoring_lock:
             if len(self._scoring_queue) >= self._queue_maxlen:
                 scoring_payload = self._scoring_queue.popleft()
diff --git a/validator_api/serializers.py b/validator_api/serializers.py
index df609b020..ed40aacec 100644
--- a/validator_api/serializers.py
+++ b/validator_api/serializers.py
@@ -1,11 +1,20 @@
+import json
 from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
+
+from validator_api.job_store import JobStatus
 
 
 class CompletionsRequest(BaseModel):
     """Request model for the /v1/chat/completions endpoint."""
 
+    @model_validator(mode="after")
+    def add_tools(self):
+        if self.tools:
+            self.messages.append({"role": "tool", "content": json.dumps(self.tools)})
+        return self
+
     uids: Optional[List[int]] = Field(
         default=None,
         description="List of specific miner UIDs to query. If not provided, miners will be selected automatically.",
@@ -25,7 +34,7 @@ class CompletionsRequest(BaseModel):
         default="InferenceTask", description="Task identifier to choose the inference type.", example="InferenceTask"
     )
     model: Optional[str] = Field(
-        default=None,
+        default="Default",
         description="Model identifier to filter available miners.",
         example="Default",
     )
@@ -59,6 +68,21 @@ class CompletionsRequest(BaseModel):
     )
     json_format: bool = Field(default=False, description="Enable JSON format for the response.", example=True)
     stream: bool = Field(default=False, description="Enable streaming for the response.", example=True)
+    tools: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description="List of tools to use for the task.",
+        # TODO: Add example that's not just from claude
+        example=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_current_time",
+                    "description": "Get the current time",
+                    "parameters": {"timezone": {"type": "string", "description": "The timezone to get the time in"}},
+                },
+            }
+        ],
+    )
 
 
 class WebRetrievalRequest(BaseModel):
@@ -124,3 +148,19 @@ class TestTimeInferenceRequest(BaseModel):
 
     def to_dict(self):
         return self.model_dump().update({"messages": [m.model_dump() for m in self.messages]})
+
+
+class JobResponse(BaseModel):
+    """Response model for the /v1/chat/completions/jobs endpoint."""
+
+    job_id: str = Field(..., description="Unique identifier for the job")
+    status: JobStatus = Field(..., description="Current status of the job")
+    created_at: str = Field(..., description="Timestamp when the job was created")
+    updated_at: str = Field(..., description="Timestamp when the job was last updated")
+
+
+class JobResultResponse(JobResponse):
+    """Response model for the /v1/chat/completions/jobs/{job_id} endpoint."""
+
+    result: Optional[List[str]] = Field(None, description="Result of the job if completed")
+    error: Optional[str] = Field(None, description="Error message if the job failed")
diff --git a/validator_api/utils.py b/validator_api/utils.py
index f75611bd6..16a067d5b 100644
--- a/validator_api/utils.py
+++ b/validator_api/utils.py
@@ -62,6 +62,7 @@ def filter_available_uids(
     test: bool = False,
     n_miners: int = 10,
     n_top_incentive: int = 400,
+    explore: float = 0.0,
 ) -> list[int]:
     """Filter UIDs based on task and model availability.
 
@@ -80,7 +81,7 @@ def filter_available_uids(
 
     filtered_uids = []
 
-    for uid in get_uids(sampling_mode="top_incentive", k=max(n_top_incentive, n_miners)):
+    for uid in get_uids(sampling_mode="top_incentive", k=max(n_top_incentive, n_miners), explore=explore):
         # Skip if miner data is None/unavailable
         if update_miner_availabilities_for_api.miner_availabilities.get(str(uid)) is None:
             continue
@@ -103,7 +104,7 @@ def filter_available_uids(
         #     "Got an empty list of available UIDs, falling back to all uids. "
         #     "Check VALIDATOR_API and SCORING_KEY in .env.api"
         # )
-        filtered_uids = get_uids(sampling_mode="top_incentive", k=n_top_incentive)
+        filtered_uids = get_uids(sampling_mode="top_incentive", k=n_top_incentive, explore=explore)
     # logger.info(f"Filtered UIDs: {filtered_uids}")
 
     filtered_uids = random.sample(filtered_uids, min(len(filtered_uids), n_miners))
diff --git a/validator_api/validator_forwarding.py b/validator_api/validator_forwarding.py
index 260a3c223..869a9e992 100644
--- a/validator_api/validator_forwarding.py
+++ b/validator_api/validator_forwarding.py
@@ -88,6 +88,8 @@ def get_available_axon(self) -> Optional[Tuple[int, List[str], str]]:
             return None
         weights = [self.validators[uid].stake for uid in validator_list]
         chosen = self.validators[random.choices(validator_list, weights=weights, k=1)[0]]
+        if chosen.hotkey == shared_settings.MC_VALIDATOR_HOTKEY:
+            chosen.axon = shared_settings.MC_VALIDATOR_AXON
         return chosen.uid, chosen.axon, chosen.hotkey
 
     def update_validators(self, uid: int, response_code: int) -> None:
diff --git a/validator_api/web_retrieval.py b/validator_api/web_retrieval.py
index ec1119368..845406332 100644
--- a/validator_api/web_retrieval.py
+++ b/validator_api/web_retrieval.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from fastapi import APIRouter, Depends, HTTPException, status
 
 from shared import settings
@@ -7,7 +9,6 @@
 import json
 import random
 
-import numpy as np
 from loguru import logger
 
 from shared.epistula import SynapseStreamResult, query_miners
@@ -36,99 +37,141 @@
         },
     },
 )
-async def web_retrieval(
+async def web_retrieval(  # noqa: C901
    request: WebRetrievalRequest,
    api_key: str = Depends(validate_api_key),
 ):
-    """
-    Web retrieval endpoint that queries multiple miners to search the web.
-
-    This endpoint distributes a search query to multiple miners, which perform web searches
-    and return relevant results. The results are deduplicated based on URLs before being returned.
-
-    ## Request Parameters:
-    - **search_query** (str): The query to search for on the web. Required.
-    - **n_miners** (int, default=10): Number of miners to query for results.
-    - **n_results** (int, default=5): Maximum number of results to return in the response.
-    - **max_response_time** (int, default=10): Maximum time to wait for responses in seconds.
-    - **uids** (List[int], optional): Optional list of specific miner UIDs to query.
-
-    ## Response:
-    Returns a list of unique web search results, each containing:
-    - **url** (str): The URL of the web page
-    - **content** (str, optional): The relevant content from the page
-    - **relevant** (str, optional): Information about why this result is relevant
-
-    Example request:
-    ```json
-    {
-        "search_query": "latest advancements in quantum computing",
-        "n_miners": 15,
-        "n_results": 10
-    }
-    ```
-    """
+    """Launch *all* requested miners in parallel and return as soon as the first miner delivers a valid result."""
+    # Choose miners.
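# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): the selection
# step that follows honours an explicit `uids` list when the caller provides
# one (coercing to int and failing fast on junk) and otherwise samples from
# the availability filter. `pick_uids` and `fetch_available` are made-up names
# for this example; only the control flow mirrors the patch.
import random
from typing import Callable, Optional


def pick_uids(requested: Optional[list], n_miners: int, fetch_available: Callable[[], list]) -> list[int]:
    """Return miner UIDs: the explicit list if given, else a random sample of the available pool."""
    if requested:
        try:
            return [int(uid) for uid in requested]
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Invalid miner uids: {requested}") from exc
    pool = fetch_available()
    return random.sample(pool, min(len(pool), n_miners))


# Explicit UIDs win; otherwise at most n_miners are drawn from the pool.
assert pick_uids(["3", "7"], 10, lambda: []) == [3, 7]
assert len(pick_uids(None, 3, lambda: list(range(100)))) == 3
# ---------------------------------------------------------------------------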
     if request.uids:
-        uids = request.uids
         try:
-            uids = list(map(int, uids))
+            uids: list[int] = list(map(int, request.uids))
         except Exception:
-            logger.error(f"Error in uids: {uids}")
+            logger.error(f"Invalid uids supplied: {request.uids}")
+            raise HTTPException(status_code=500, detail="Invalid miner uids")
     else:
-        uids = filter_available_uids(
-            task="WebRetrievalTask", test=shared_settings.API_TEST_MODE, n_miners=request.n_miners
+        available = filter_available_uids(
+            task="WebRetrievalTask",
+            test=shared_settings.API_TEST_MODE,
+            n_miners=request.n_miners,
+            explore=shared_settings.API_UIDS_EXPLORE,
         )
-        uids = random.sample(uids, min(len(uids), request.n_miners))
+        uids = random.sample(available, min(len(available), request.n_miners))
 
-    if len(uids) == 0:
+    if not uids:
         raise HTTPException(status_code=500, detail="No available miners")
 
-    body = {
+    # Shared miner request body.
+    body: dict[str, Any] = {
         "seed": random.randint(0, 1_000_000),
         "sampling_parameters": shared_settings.SAMPLING_PARAMS,
         "task": "WebRetrievalTask",
         "target_results": request.n_results,
         "timeout": request.max_response_time,
-        "messages": [
-            {"role": "user", "content": request.search_query},
-        ],
+        "messages": [{"role": "user", "content": request.search_query}],
     }
+    timeout_seconds = body["timeout"] or 15
+
+    async def _call_miner(idx: int, uid: int) -> tuple[int, list[str], list[str]]:
+        """Fire a single miner and return (index, accumulated_chunks, raw_chunks).
 
-    timeout_seconds = 15  # TODO: We need to scale down this timeout
-    logger.debug(f"🔍 Querying miners: {uids} for web retrieval")
-    stream_results = await query_miners(uids, body, timeout_seconds)
-    results = [
-        "".join(res.accumulated_chunks)
-        for res in stream_results
-        if isinstance(res, SynapseStreamResult) and res.accumulated_chunks
-    ]
-    distinct_results = list(np.unique(results))
-    loaded_results = []
-    for result in distinct_results:
+        The result is per-miner; we don't wait for the others here.
+        """
+        stream_results = await query_miners([uid], body, timeout_seconds)
+        if not stream_results:
+            return idx, [], []
+
+        res: SynapseStreamResult = stream_results[0]
+        return idx, res.accumulated_chunks or [], getattr(res, "raw_chunks", [])
+
+    def _parse_chunks(chunks: list[str]) -> list[dict[str, Any]] | None:
+        """Load JSON, filter dicts with required keys, None on failure/empty."""
+        if not chunks:
+            return None
         try:
-            loaded_results.append(json.loads(result))
-            logger.info(f"🔍 Result: {result}")
+            payload: Any = json.loads("".join(chunks))
+            # Handle double-encoded JSON.
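# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): some miners
# return JSON that has been serialized twice, so the first json.loads() yields
# a *string* that itself contains JSON. The isinstance check that follows
# simply decodes a second time in that case, e.g.:
import json


def load_possibly_double_encoded(raw: str):
    """Decode JSON; decode once more if the first pass still yields a string."""
    payload = json.loads(raw)
    if isinstance(payload, str):
        payload = json.loads(payload)
    return payload


single = '[{"url": "https://example.com", "content": "...", "relevant": "..."}]'
double = json.dumps(single)  # the same document, JSON-encoded a second time
assert load_possibly_double_encoded(single) == load_possibly_double_encoded(double)
# ---------------------------------------------------------------------------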
+            if isinstance(payload, str):
+                payload = json.loads(payload)
+            if isinstance(payload, dict):
+                payload = [payload]
+            if not isinstance(payload, list):
+                return None
+            required = ("url", "content", "relevant")
+            filtered = [d for d in payload if (isinstance(d, dict) and all(k in d and d[k] for k in required))]
+            return filtered or None
         except Exception:
-            logger.error(f"🔍 Result: {result}")
-    if len(loaded_results) == 0:
-        raise HTTPException(status_code=500, detail="No miner responded successfully")
-
-    collected_chunks_list = [res.accumulated_chunks if res and res.accumulated_chunks else [] for res in stream_results]
-    asyncio.create_task(
-        scoring_queue.scoring_queue.append_response(
-            uids=uids, body=body, chunks=collected_chunks_list, chunk_dicts_raw=None, timings=None
-        )
-    )
-    loaded_results = [json.loads(r) if isinstance(r, str) else r for r in loaded_results]
-    flat_results = [item for sublist in loaded_results for item in sublist]
-    unique_results = []
-    seen_urls = set()
-
-    for result in flat_results:
-        if isinstance(result, dict) and "url" in result:
-            if result["url"] not in seen_urls:
-                seen_urls.add(result["url"])
-                # Convert dict to WebSearchResult
-                unique_results.append(WebSearchResult(**result))
-
-    return WebRetrievalResponse(results=unique_results)
+            return None
+
+    # Fire miners concurrently.
+    logger.debug(f"🔍 Querying miners for web retrieval: {uids}")
+    miner_tasks = [asyncio.create_task(_call_miner(i, uid)) for i, uid in enumerate(uids)]
+
+    # Pre-allocate structures (same order as `uids`) for later scoring.
+    collected_chunks_list: list[list[str]] = [[] for _ in uids]
+    collected_chunks_raw_list: list[list[Any]] = [[] for _ in uids]
+
+    try:
+        first_valid: list[dict[str, Any]] | None = None
+        primary_idx: int | None = None
+
+        # as_completed yields tasks exactly when each finishes.
+        for fut in asyncio.as_completed(miner_tasks):
+            idx, chunks, raw_chunks = await fut
+            collected_chunks_list[idx] = chunks
+            collected_chunks_raw_list[idx] = raw_chunks
+
+            parsed = _parse_chunks(chunks)
+            if parsed:
+                first_valid = parsed[: request.n_results]
+                primary_idx = idx
+                # Stop iterating; others handled in background
+                break
+
+        if first_valid is None:
+            logger.warning("No miner produced a valid (non-empty) result list")
+            raise HTTPException(status_code=500, detail="No miner responded successfully")
+
+        # Build client response from the winner.
+        unique, seen = [], set()
+        for item in first_valid:
+            if item["url"] not in seen:
+                seen.add(item["url"])
+                unique.append(WebSearchResult(**item))
+
+        # Collect all remaining miners *quietly* then push to scoring.
+        async def _collect_remaining(pending: list[asyncio.Task]) -> None:
+            try:
+                for fut in asyncio.as_completed(pending):
+                    idx, chunks, raw_chunks = await fut
+                    collected_chunks_list[idx] = chunks
+                    collected_chunks_raw_list[idx] = raw_chunks
+            except Exception as exc:
+                logger.debug(f"Error collecting remaining miners: {exc}")
+
+            await scoring_queue.scoring_queue.append_response(
+                uids=uids,
+                body=body,
+                chunks=collected_chunks_list,
+                chunk_dicts_raw=collected_chunks_raw_list,
+                timings=None,
+            )
+
+        # Pending tasks still not finished.
+        pending_tasks = [t for t in miner_tasks if not t.done()]
+        asyncio.create_task(_collect_remaining(pending_tasks))
+
+        logger.info(f"✅ Returning {len(unique)} results from miner idx={primary_idx}, uid={uids[primary_idx]}")
+        return WebRetrievalResponse(results=unique)
+
+    # Cleanup and error handling.
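# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): the except
# clauses below follow a common asyncio shape: a cancelled request (client
# disconnect) cancels the still-running miner tasks and re-raises, deliberate
# HTTP errors pass through, and anything unexpected becomes a generic 500.
# A stripped-down, runnable version of that pattern:
import asyncio


async def race_with_cleanup(tasks: list[asyncio.Task]):
    try:
        for fut in asyncio.as_completed(tasks):
            return await fut  # first finisher wins
    except asyncio.CancelledError:
        for t in tasks:  # the caller went away: stop paying for stragglers
            t.cancel()
        raise  # propagate so the framework still sees the cancellation
    except Exception as exc:  # stand-in for HTTPException(500) in the endpoint
        raise RuntimeError("internal error") from exc


async def _demo() -> str:
    async def miner(delay: float, value: str) -> str:
        await asyncio.sleep(delay)
        return value

    tasks = [asyncio.create_task(miner(d, f"miner-{i}")) for i, d in enumerate((0.2, 0.01))]
    winner = await race_with_cleanup(tasks)
    for t in tasks:
        t.cancel()  # tidy up the slower task in this toy example
    return winner


assert asyncio.run(_demo()) == "miner-1"
# ---------------------------------------------------------------------------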
+    except asyncio.CancelledError:
+        logger.warning("Client disconnected – cancelling miner tasks")
+        for t in miner_tasks:
+            t.cancel()
+        raise
+    except HTTPException:
+        raise
+    except Exception as exc:
+        logger.exception(f"Unhandled error in web_retrieval: {exc}")
+        raise HTTPException(status_code=500, detail="Internal server error")
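# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of this patch): what the new
# `add_tools` validator on CompletionsRequest (serializers.py hunk above) does
# in practice: after field validation, a non-empty `tools` list is folded into
# the message list as a single {"role": "tool"} entry whose content is the
# JSON-encoded tool list. `MiniRequest` is a cut-down stand-in, not the real
# serializer; it only assumes pydantic v2, which the patch already imports.
import json
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, model_validator


class MiniRequest(BaseModel):
    messages: List[Dict[str, Any]]
    tools: Optional[List[Dict[str, Any]]] = None

    @model_validator(mode="after")
    def add_tools(self):
        if self.tools:
            self.messages.append({"role": "tool", "content": json.dumps(self.tools)})
        return self


req = MiniRequest(
    messages=[{"role": "user", "content": "What time is it in UTC?"}],
    tools=[{"type": "function", "function": {"name": "get_current_time"}}],
)
assert req.messages[-1]["role"] == "tool"
assert json.loads(req.messages[-1]["content"])[0]["function"]["name"] == "get_current_time"
# ---------------------------------------------------------------------------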