diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 296f9b5c6..d601319e7 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -12,7 +12,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10"] + python-version: ["3.10"] steps: - uses: actions/checkout@v3 diff --git a/neurons/miners/epistula_miner/miner.py b/neurons/miners/epistula_miner/miner.py index 7e918830e..7ed556f9f 100644 --- a/neurons/miners/epistula_miner/miner.py +++ b/neurons/miners/epistula_miner/miner.py @@ -1,8 +1,8 @@ # ruff: noqa: E402 -from prompting import settings +from shared import settings -settings.settings = settings.Settings.load(mode="miner") -settings = settings.settings +settings.shared_settings = settings.SharedSettings.load(mode="miner") +shared_settings = settings.shared_settings import asyncio import json @@ -39,7 +39,7 @@ def __init__(self): self.client = httpx.AsyncClient( base_url="https://api.openai.com/v1", headers={ - "Authorization": f"Bearer {settings.OPENAI_API_KEY}", + "Authorization": f"Bearer {shared_settings.OPENAI_API_KEY}", "Content-Type": "application/json", }, ) @@ -107,14 +107,14 @@ async def verify_request( signed_by = request.headers.get("Epistula-Signed-By") signed_for = request.headers.get("Epistula-Signed-For") - if signed_for != settings.WALLET.hotkey.ss58_address: + if signed_for != shared_settings.WALLET.hotkey.ss58_address: raise HTTPException(status_code=400, detail="Bad Request, message is not intended for self") - if signed_by not in settings.METAGRAPH.hotkeys: + if signed_by not in shared_settings.METAGRAPH.hotkeys: raise HTTPException(status_code=401, detail="Signer not in metagraph") - uid = settings.METAGRAPH.hotkeys.index(signed_by) - stake = settings.METAGRAPH.S[uid].item() - if not settings.NETUID == 61 and stake < 10000: + uid = shared_settings.METAGRAPH.hotkeys.index(signed_by) + stake = shared_settings.METAGRAPH.S[uid].item() + if not 
shared_settings.NETUID == 61 and stake < 10000: logger.warning(f"Blacklisting request from {signed_by} [uid={uid}], not enough stake -- {stake}") raise HTTPException(status_code=401, detail="Stake below minimum: {stake}") @@ -133,7 +133,7 @@ async def verify_request( raise HTTPException(status_code=400, detail=err) def run(self): - external_ip = None # settings.EXTERNAL_IP + external_ip = None # shared_settings.EXTERNAL_IP if not external_ip or external_ip == "[::]": try: external_ip = requests.get("https://checkip.amazonaws.com").text.strip() @@ -142,16 +142,16 @@ def run(self): logger.error("Failed to get external IP") logger.info( - f"Serving miner endpoint {external_ip}:{settings.AXON_PORT} on network: {settings.SUBTENSOR_NETWORK} with netuid: {settings.NETUID}" + f"Serving miner endpoint {external_ip}:{shared_settings.AXON_PORT} on network: {shared_settings.SUBTENSOR_NETWORK} with netuid: {shared_settings.NETUID}" ) serve_success = serve_extrinsic( - subtensor=settings.SUBTENSOR, - wallet=settings.WALLET, + subtensor=shared_settings.SUBTENSOR, + wallet=shared_settings.WALLET, ip=external_ip, - port=settings.AXON_PORT, + port=shared_settings.AXON_PORT, protocol=4, - netuid=settings.NETUID, + netuid=shared_settings.NETUID, ) if not serve_success: logger.error("Failed to serve endpoint") @@ -174,7 +174,7 @@ def run(self): fast_config = uvicorn.Config( app, host="0.0.0.0", - port=settings.AXON_PORT, + port=shared_settings.AXON_PORT, log_level="info", loop="asyncio", workers=4, @@ -182,7 +182,7 @@ def run(self): self.fast_api = FastAPIThreadedServer(config=fast_config) self.fast_api.start() - logger.info(f"Miner starting at block: {settings.SUBTENSOR.block}") + logger.info(f"Miner starting at block: {shared_settings.SUBTENSOR.block}") # Main execution loop. 
try: diff --git a/neurons/test_vanilla_post.py b/neurons/test_vanilla_post.py index ecd458786..31e6adeb0 100644 --- a/neurons/test_vanilla_post.py +++ b/neurons/test_vanilla_post.py @@ -1,10 +1,10 @@ import openai from httpx import Timeout -from prompting import settings +from shared import settings -settings.settings = settings.Settings.load(mode="validator") -settings = settings.settings +settings.shared_settings = settings.SharedSettings.load(mode="validator") +shared_settings = settings.shared_settings from shared.epistula import create_header_hook diff --git a/neurons/validator.py b/neurons/validator.py index 9d7e63054..6062ec8f2 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -81,11 +81,12 @@ async def spawn_loops(task_queue, scoring_queue, reward_events): asyncio.run(spawn_loops(task_queue, scoring_queue, reward_events)) -def start_api(): +def start_api(scoring_queue, reward_events): async def start(): from prompting.api.api import start_scoring_api # noqa: F401 - await start_scoring_api() + await start_scoring_api(scoring_queue, reward_events) + while True: await asyncio.sleep(10) logger.debug("Running API...") @@ -125,7 +126,7 @@ async def main(): if shared_settings.DEPLOY_SCORING_API: # Use multiprocessing to bypass API blocking issue - api_process = mp.Process(target=start_api, name="API_Process") + api_process = mp.Process(target=start_api, args=(scoring_queue, reward_events), name="API_Process") api_process.start() processes.append(api_process) diff --git a/prompting/api/api.py b/prompting/api/api.py index 63678ab75..825b19c7b 100644 --- a/prompting/api/api.py +++ b/prompting/api/api.py @@ -4,6 +4,7 @@ from prompting.api.miner_availabilities.api import router as miner_availabilities_router from prompting.api.scoring.api import router as scoring_router +from prompting.rewards.scoring import task_scorer from shared.settings import shared_settings app = FastAPI() @@ -17,7 +18,9 @@ def health(): return {"status": "healthy"} -async def 
start_scoring_api(): +async def start_scoring_api(scoring_queue, reward_events): + task_scorer.scoring_queue = scoring_queue + task_scorer.reward_events = reward_events logger.info(f"Starting Scoring API on https://0.0.0.0:{shared_settings.SCORING_API_PORT}") uvicorn.run( "prompting.api.api:app", host="0.0.0.0", port=shared_settings.SCORING_API_PORT, loop="asyncio", reload=False diff --git a/prompting/api/scoring/api.py b/prompting/api/scoring/api.py index ee18ade14..d7c981c17 100644 --- a/prompting/api/scoring/api.py +++ b/prompting/api/scoring/api.py @@ -4,9 +4,11 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Request from loguru import logger +from prompting.datasets.random_website import DDGDatasetEntry from prompting.llms.model_zoo import ModelZoo from prompting.rewards.scoring import task_scorer from prompting.tasks.inference import InferenceTask +from prompting.tasks.web_retrieval import WebRetrievalTask from shared.base import DatasetEntry from shared.dendrite import DendriteResponseEvent from shared.epistula import SynapseStreamResult @@ -37,22 +39,54 @@ async def score_response(request: Request, api_key_data: dict = Depends(validate uid = int(payload.get("uid")) chunks = payload.get("chunks") llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None - task_scorer.add_to_queue( - task=InferenceTask( - messages=[msg["content"] for msg in body.get("messages")], - llm_model=llm_model, - llm_model_id=body.get("model"), - seed=int(body.get("seed", 0)), - sampling_params=body.get("sampling_params", {}), - ), - response=DendriteResponseEvent( - uids=[uid], - stream_results=[SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None])], - timeout=shared_settings.NEURON_TIMEOUT, - ), - dataset_entry=DatasetEntry(), - block=shared_settings.METAGRAPH.block, - step=-1, - task_id=str(uuid.uuid4()), - ) - logger.info("Organic tas appended to scoring queue") + task = body.get("task") + if task == 
"InferenceTask": + logger.info(f"Received Organic InferenceTask with body: {body}") + task_scorer.add_to_queue( + task=InferenceTask( + messages=[msg["content"] for msg in body.get("messages")], + llm_model=llm_model, + llm_model_id=body.get("model"), + seed=int(body.get("seed", 0)), + sampling_params=body.get("sampling_params", {}), + ), + response=DendriteResponseEvent( + uids=[uid], + stream_results=[ + SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None]) + ], + timeout=shared_settings.NEURON_TIMEOUT, + ), + dataset_entry=DatasetEntry(), + block=shared_settings.METAGRAPH.block, + step=-1, + task_id=str(uuid.uuid4()), + ) + elif task == "WebRetrievalTask": + logger.info(f"Received Organic WebRetrievalTask with body: {body}") + try: + search_term = body.get("messages")[0].get("content") + except Exception as ex: + logger.error(f"Failed to get search term from messages: {ex}, can't score WebRetrievalTask") + return + + task_scorer.add_to_queue( + task=WebRetrievalTask( + messages=[msg["content"] for msg in body.get("messages")], + seed=int(body.get("seed", 0)), + sampling_params=body.get("sampling_params", {}), + query=search_term, + ), + response=DendriteResponseEvent( + uids=[uid], + stream_results=[ + SynapseStreamResult(accumulated_chunks=[chunk for chunk in chunks if chunk is not None]) + ], + timeout=shared_settings.NEURON_TIMEOUT, + ), + dataset_entry=DDGDatasetEntry(search_term=search_term), + block=shared_settings.METAGRAPH.block, + step=-1, + task_id=str(uuid.uuid4()), + ) + logger.info("Organic task appended to scoring queue") diff --git a/prompting/datasets/random_website.py b/prompting/datasets/random_website.py index 3058812a0..18fd3c8de 100644 --- a/prompting/datasets/random_website.py +++ b/prompting/datasets/random_website.py @@ -15,8 +15,8 @@ class DDGDatasetEntry(DatasetEntry): search_term: str - website_url: str - website_content: str + website_url: str = None + website_content: str = None class 
DDGDataset(BaseDataset): diff --git a/prompting/settings.py b/prompting/settings.py deleted file mode 100644 index 01ce347d3..000000000 --- a/prompting/settings.py +++ /dev/null @@ -1,134 +0,0 @@ -import os -from functools import cached_property -from typing import Any, Literal, Optional - -import bittensor as bt -import dotenv -from loguru import logger -from pydantic import Field, model_validator -from pydantic_settings import BaseSettings - - -class Settings(BaseSettings): - mode: Literal["miner", "validator", "mock"] - MOCK: bool = False - NO_BACKGROUND_THREAD: bool = True - SAVE_PATH: Optional[str] = Field("./storage", env="SAVE_PATH") - model_config = {"frozen": True, "arbitrary_types_allowed": False} - - # Class variables for singleton. - _instance: Optional["Settings"] = None - _instance_mode: Optional[str] = None - - @classmethod - def load_env_file(cls, mode: Literal["miner", "validator", "mock"]): - """Load the appropriate .env file based on the mode.""" - if mode == "miner": - dotenv_file = ".env.miner" - elif mode == "validator": - dotenv_file = ".env.validator" - # For mock testing, still make validator env vars available where possible. - elif mode == "mock": - dotenv_file = ".env.validator" - else: - raise ValueError(f"Invalid mode: {mode}") - - if dotenv_file: - if not dotenv.load_dotenv(dotenv.find_dotenv(filename=dotenv_file)): - logger.warning( - f"No {dotenv_file} file found. The use of args when running a {mode} will be deprecated " - "in the near future." 
- ) - - @classmethod - def load(cls, mode: Literal["miner", "validator", "mock"]) -> "Settings": - """Load or retrieve the Settings instance based on the mode.""" - if cls._instance is not None and cls._instance_mode == mode: - return cls._instance - else: - cls.load_env_file(mode) - cls._instance = cls(mode=mode) - cls._instance_mode = mode - return cls._instance - - @model_validator(mode="before") - def complete_settings(cls, values: dict[str, Any]) -> dict[str, Any]: - mode = values["mode"] - netuid = values.get("NETUID", 61) - - if netuid is None: - raise ValueError("NETUID must be specified") - values["TEST"] = netuid != 1 - if values.get("TEST_MINER_IDS"): - values["TEST_MINER_IDS"] = str(values["TEST_MINER_IDS"]).split(",") - if mode == "mock": - values["MOCK"] = True - values["NEURON_DEVICE"] = "cpu" - logger.info("Running in mock mode. Bittensor objects will not be initialized.") - return values - - # load slow packages only if not in mock mode - import torch - - if not values.get("NEURON_DEVICE"): - values["NEURON_DEVICE"] = "cuda" if torch.cuda.is_available() else "cpu" - - # Ensure SAVE_PATH exists. - save_path = values.get("SAVE_PATH", "./storage") - if not os.path.exists(save_path): - os.makedirs(save_path) - if values.get("SN19_API_KEY") is None or values.get("SN19_API_URL") is None: - logger.warning( - "It is strongly recommended to provide an SN19 API KEY + URL to avoid incurring OpenAI API costs." - ) - if mode == "validator": - if values.get("OPENAI_API_KEY") is None: - raise Exception( - "You must provide an OpenAI API key as a backup. It is recommended to also provide an SN19 API key + url to avoid incurring API costs." - ) - if values.get("SCORING_ADMIN_KEY") is None: - raise Exception("You must provide an admin key to access the API.") - if values.get("PROXY_URL") is None: - logger.warning( - "You must provide a proxy URL to use the DuckDuckGo API - your vtrust might decrease if no DDG URL is provided." 
- ) - return values - - @cached_property - def WALLET(self): - wallet_name = self.WALLET_NAME # or config().wallet.name - hotkey = self.HOTKEY # or config().wallet.hotkey - logger.info(f"Instantiating wallet with name: {wallet_name}, hotkey: {hotkey}") - return bt.wallet(name=wallet_name, hotkey=hotkey) - - @cached_property - def SUBTENSOR(self) -> bt.subtensor: - subtensor_network = self.SUBTENSOR_NETWORK or os.environ.get("SUBTENSOR_NETWORK", "local") - # bt_config = config() - if subtensor_network.lower() == "local": - subtensor_network = os.environ.get("SUBTENSOR_CHAIN_ENDPOINT") # bt_config.subtensor.chain_endpoint or - else: - subtensor_network = subtensor_network.lower() # bt_config.subtensor.network or - logger.info(f"Instantiating subtensor with network: {subtensor_network}") - return bt.subtensor(network=subtensor_network) - - @cached_property - def METAGRAPH(self) -> bt.metagraph: - logger.info(f"Instantiating metagraph with NETUID: {self.NETUID}") - return self.SUBTENSOR.metagraph(netuid=self.NETUID) - - @cached_property - def DENDRITE(self) -> bt.dendrite: - logger.info(f"Instantiating dendrite with wallet: {self.WALLET}") - return bt.dendrite(wallet=self.WALLET) - - -logger.info("Settings class instantiated.") -settings: Optional[Settings] = None -try: - settings: Optional[Settings] = Settings.load(mode="mock") - pass -except Exception as e: - logger.exception(f"Error loading settings: {e}") - settings = None -logger.info("Settings loaded.") diff --git a/prompting/weight_setting/weight_setter.py b/prompting/weight_setting/weight_setter.py index 28c195b13..12fc97c0f 100644 --- a/prompting/weight_setting/weight_setter.py +++ b/prompting/weight_setting/weight_setter.py @@ -102,7 +102,7 @@ def set_weights( "weights": processed_weights.flatten(), "raw_weights": str(list(weights.flatten())), "averaged_weights": str(list(averaged_weights.flatten())), - "block": ttl_get_block(), + "block": ttl_get_block(subtensor=subtensor), } ) step_filename = "weights.csv" 
diff --git a/scripts/client.py b/scripts/client.py index aabe7e431..29e72a0d3 100644 --- a/scripts/client.py +++ b/scripts/client.py @@ -1,7 +1,7 @@ -from prompting import settings +from shared import settings -settings.settings = settings.Settings.load(mode="validator") -settings = settings.settings +settings.shared_settings = settings.SharedSettings.load(mode="validator") +shared_settings = settings.shared_settings import asyncio import json @@ -9,7 +9,6 @@ from loguru import logger from shared.epistula import query_miners -from shared.settings import shared_settings """ This has assumed you have: diff --git a/shared/dendrite.py b/shared/dendrite.py index 254a6c799..ccfd5c86a 100644 --- a/shared/dendrite.py +++ b/shared/dendrite.py @@ -35,9 +35,9 @@ def model_dump(self): class DendriteResponseEvent(BaseModel): uids: np.ndarray | list[float] - axons: list[str] timeout: float stream_results: list[SynapseStreamResult] + axons: list[str] = [] completions: list[str] = [] status_messages: list[str] = [] status_codes: list[int] = [] diff --git a/shared/epistula.py b/shared/epistula.py index 44a59f6f1..5af064856 100644 --- a/shared/epistula.py +++ b/shared/epistula.py @@ -111,7 +111,7 @@ async def merged_stream(responses: list[AsyncGenerator]): logger.error(f"Error while streaming: {e}") -async def query_miners(uids, body: dict[str, Any]): +async def query_miners(uids, body: dict[str, Any]) -> list[SynapseStreamResult]: try: tasks = [] for uid in uids: diff --git a/shared/misc.py b/shared/misc.py index e58f313cd..858765b7c 100644 --- a/shared/misc.py +++ b/shared/misc.py @@ -100,7 +100,7 @@ def ttl_get_block(subtensor: bt.Subtensor | None = None) -> int: efficiently reduces the workload on the blockchain interface. 
Example: - current_block = ttl_get_block(self) + current_block = ttl_get_block(subtensor=subtensor) Note: self here is the miner or validator instance """ diff --git a/shared/uids.py b/shared/uids.py index ef494ba27..a8c99b5d2 100644 --- a/shared/uids.py +++ b/shared/uids.py @@ -114,7 +114,8 @@ def get_top_incentive_uids(k: int, vpermit_tao_limit: int) -> np.ndarray: # Extract the top uids. top_k_uids = [uid for uid, incentive in uid_incentive_pairs_sorted[:k]] - return np.array(top_k_uids).astype(int) + return list(np.array(top_k_uids).astype(int)) + # return [int(k) for k in top_k_uids] def get_uids( diff --git a/validator_api/chat_completion.py b/validator_api/chat_completion.py index b0cef8729..fd05a174a 100644 --- a/validator_api/chat_completion.py +++ b/validator_api/chat_completion.py @@ -1,9 +1,8 @@ import asyncio import json import random -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional -import httpx from fastapi import HTTPException from fastapi.responses import StreamingResponse from loguru import logger @@ -11,83 +10,118 @@ from shared.epistula import make_openai_query from shared.settings import shared_settings from shared.uids import get_uids +from validator_api.utils import forward_response -async def forward_response(uid: int, body: dict[str, any], chunks: list[str]): - uid = int(uid) # sometimes uid is type np.uint64 - logger.info(f"Forwarding response to scoring with body: {body}") - if not shared_settings.SCORE_ORGANICS: # Allow disabling of scoring by default - return - - if body.get("task") != "InferenceTask": - logger.debug(f"Skipping forwarding for non-inference task: {body.get('task')}") - return - url = f"http://{shared_settings.VALIDATOR_API}/scoring" - payload = {"body": body, "chunks": chunks, "uid": uid} - try: - timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0) - async with httpx.AsyncClient(timeout=timeout) as client: - response = await client.post( - url, 
json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"} - ) - if response.status_code == 200: - logger.info(f"Forwarding response completed with status {response.status_code}") - - else: - logger.exception( - f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}" - ) - - except Exception as e: - logger.error(f"Tried to forward response to {url} with payload {payload}") - logger.exception(f"Error while forwarding response: {e}") - - -async def stream_response( - response, collected_chunks: list[str], body: dict[str, any], uid: int +async def stream_from_first_response( + responses: List[asyncio.Task], collected_chunks_list: List[List[str]], body: dict[str, any], uids: List[int] ) -> AsyncGenerator[str, None]: - chunks_received = False + first_valid_response = None try: - async for chunk in response: + # Wait for the first valid response + while responses and first_valid_response is None: + done, pending = await asyncio.wait(responses, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + try: + response = await task + if response and not isinstance(response, Exception): + first_valid_response = response + break + except Exception as e: + logger.error(f"Error in miner response: {e}") + responses.remove(task) + + if first_valid_response is None: + logger.error("No valid response received from any miner") + yield 'data: {"error": "502 - No valid response received"}\n\n' + return + + # Stream the first valid response + chunks_received = False + async for chunk in first_valid_response: chunks_received = True - collected_chunks.append(chunk.choices[0].delta.content) + collected_chunks_list[0].append(chunk.choices[0].delta.content) yield f"data: {json.dumps(chunk.model_dump())}\n\n" if not chunks_received: logger.error("Stream is empty: No chunks were received") yield 'data: {"error": "502 - Response is empty"}\n\n' + yield "data: [DONE]\n\n" - # Forward the collected chunks after 
streaming is complete - asyncio.create_task(forward_response(uid=uid, body=body, chunks=collected_chunks)) + # Continue collecting remaining responses in background for scoring + remaining = asyncio.gather(*pending, return_exceptions=True) + asyncio.create_task(collect_remaining_responses(remaining, collected_chunks_list, body, uids)) + except asyncio.CancelledError: logger.info("Client disconnected, streaming cancelled") + for task in responses: + task.cancel() raise except Exception as e: logger.exception(f"Error during streaming: {e}") yield 'data: {"error": "Internal server Error"}\n\n' -async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple | StreamingResponse: - """Handle regular chat completion without mixture of miners.""" - if uid is None: - uid = random.choice(get_uids(sampling_mode="top_incentive", k=100)) +async def collect_remaining_responses( + remaining: asyncio.Task, collected_chunks_list: List[List[str]], body: dict[str, any], uids: List[int] +): + """Collect remaining responses for scoring without blocking the main response.""" + try: + responses = await remaining + logger.debug(f"responses to forward: {responses}") + for i, response in enumerate(responses): + if isinstance(response, Exception): + logger.error(f"Error collecting response from uid {uids[i+1]}: {response}") + continue + + async for chunk in response: + collected_chunks_list[i + 1].append(chunk.choices[0].delta.content) + for uid, chunks in zip(uids, collected_chunks_list): + # Forward for scoring + asyncio.create_task(forward_response(uid, body, chunks)) - if uid is None: - logger.error("No available miner found") - raise HTTPException(status_code=503, detail="No available miner found") + except Exception as e: + logger.exception(f"Error collecting remaining responses: {e}") + + +async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple: + """Get response from a single miner.""" + return await make_openai_query(shared_settings.METAGRAPH, 
shared_settings.WALLET, body, uid, stream=False) - logger.debug(f"Querying uid {uid}") - STREAM = body.get("stream", False) - collected_chunks: list[str] = [] +async def chat_completion( + body: dict[str, any], uids: Optional[int] = None, num_miners: int = 3 +) -> tuple | StreamingResponse: + """Handle chat completion with multiple miners in parallel.""" + # Get multiple UIDs if none specified + if uids is None: + uids = list(get_uids(sampling_mode="top_incentive", k=100)) + if uids is None or len(uids) == 0: # if not uids throws error, figure out how to fix + logger.error("No available miners found") + raise HTTPException(status_code=503, detail="No available miners found") + selected_uids = random.sample(uids, min(num_miners, len(uids))) + else: + selected_uids = uids[:num_miners] # If UID is specified, only use that one + + logger.debug(f"Querying uids {selected_uids}") + STREAM = body.get("stream", False) - logger.info(f"Making {'streaming' if STREAM else 'non-streaming'} openai query with body: {body}") - response = await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=STREAM) + # Initialize chunks collection for each miner + collected_chunks_list = [[] for _ in selected_uids] if STREAM: + # Create tasks for all miners + response_tasks = [ + asyncio.create_task( + make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=True) + ) + for uid in selected_uids + ] + return StreamingResponse( - stream_response(response, collected_chunks, body, uid), + stream_from_first_response(response_tasks, collected_chunks_list, body, selected_uids), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", @@ -95,10 +129,32 @@ async def chat_completion(body: dict[str, any], uid: int | None = None) -> tuple }, ) else: - asyncio.create_task(forward_response(uid=uid, body=body, chunks=response[1])) - return response[0] - - -async def get_response_from_miner(body: dict[str, any], uid: int) -> tuple: 
- """Get response from a single miner.""" - return await make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, body, uid, stream=False) + # For non-streaming requests, wait for first valid response + response_tasks = [asyncio.create_task(get_response_from_miner(body, uid)) for uid in selected_uids] + + first_valid_response = None + collected_responses = [] + + while response_tasks and first_valid_response is None: + done, pending = await asyncio.wait(response_tasks, return_when=asyncio.FIRST_COMPLETED) + + for task in done: + try: + response = await task + if response and isinstance(response, tuple): + if first_valid_response is None: + first_valid_response = response + collected_responses.append(response) + except Exception as e: + logger.error(f"Error in miner response: {e}") + response_tasks.remove(task) + + if first_valid_response is None: + raise HTTPException(status_code=502, detail="No valid response received") + + # Forward all collected responses for scoring in the background + for i, response in enumerate(collected_responses): + if response and isinstance(response, tuple): + asyncio.create_task(forward_response(uid=selected_uids[i], body=body, chunks=response[1])) + + return first_valid_response[0] # Return only the response object, not the chunks diff --git a/validator_api/gpt_endpoints.py b/validator_api/gpt_endpoints.py index 34681f0e1..b363c1b13 100644 --- a/validator_api/gpt_endpoints.py +++ b/validator_api/gpt_endpoints.py @@ -1,17 +1,34 @@ +import asyncio +import json import random -from fastapi import APIRouter, Request +import numpy as np +from fastapi import APIRouter, Depends, Header, HTTPException, Request from loguru import logger from starlette.responses import StreamingResponse +from shared.epistula import SynapseStreamResult, query_miners +from shared.settings import shared_settings +from shared.uids import get_uids from validator_api.chat_completion import chat_completion from validator_api.mixture_of_miners import 
mixture_of_miners +from validator_api.utils import forward_response router = APIRouter() +# load api keys from api_keys.json +with open("api_keys.json", "r") as f: + _keys = json.load(f) + + +def validate_api_key(api_key: str = Header(...)): + if api_key not in _keys: + raise HTTPException(status_code=403, detail="Invalid API key") + return _keys[api_key] + @router.post("/v1/chat/completions") -async def completions(request: Request): +async def completions(request: Request, api_key: str = Depends(validate_api_key)): """Main endpoint that handles both regular and mixture of miners chat completion.""" try: body = await request.json() @@ -26,3 +43,48 @@ async def completions(request: Request): except Exception as e: logger.exception(f"Error in chat completion: {e}") return StreamingResponse(content="Internal Server Error", status_code=500) + + +@router.post("/web_retrieval") +async def web_retrieval(search_query: str, n_miners: int = 10, uids: list[int] = None): + uids = list(get_uids(sampling_mode="random", k=n_miners)) + logger.debug(f"🔍 Querying uids: {uids}") + if len(uids) == 0: + logger.warning("No available miners. This should already have been caught earlier.") + return + + body = { + "seed": random.randint(0, 1_000_000), + "sampling_parameters": shared_settings.SAMPLING_PARAMS, + "task": "WebRetrievalTask", + "messages": [ + {"role": "user", "content": search_query}, + ], + } + stream_results = await query_miners(uids, body) + results = [ + "".join(res.accumulated_chunks) + for res in stream_results + if isinstance(res, SynapseStreamResult) and res.accumulated_chunks + ] + distinct_results = list(np.unique(results)) + logger.info( + f"🔍 Collected responses from {len(stream_results)} miners. 
{len(results)} responded successfully with a total of {len(distinct_results)} distinct results" + ) + loaded_results = [] + for result in distinct_results: + try: + loaded_results.append(json.loads(result)) + logger.info(f"🔍 Result: {result}") + except Exception: + logger.error(f"🔍 Result: {result}") + if len(loaded_results) == 0: + raise HTTPException(status_code=500, detail="No miner responded successfully") + + for uid, res in zip(uids, stream_results): + asyncio.create_task( + forward_response( + uid=uid, body=body, chunks=res.accumulated_chunks if res and res.accumulated_chunks else [] + ) + ) + return loaded_results diff --git a/validator_api/utils.py b/validator_api/utils.py new file mode 100644 index 000000000..b83e819b1 --- /dev/null +++ b/validator_api/utils.py @@ -0,0 +1,35 @@ +import httpx +from loguru import logger + +from shared.settings import shared_settings + + +# TODO: Modify this so that all the forwarded responses are sent in a single request. This is both more efficient but +# also means that on the validator side all responses are scored at once, speeding up the scoring process. 
async def forward_response(uid: int, body: dict[str, any], chunks: list[str]):
    """Forward one miner's response to the validator's ``/scoring`` endpoint.

    Nothing is sent when ``shared_settings.SCORE_ORGANICS`` is falsy, or when
    ``body["task"]`` is neither ``"InferenceTask"`` nor ``"WebRetrievalTask"``.
    Errors raised while building the client or posting the payload are logged
    and swallowed, so a failed forward never propagates to the caller.

    Args:
        uid: Uid of the miner that produced the response. It is coerced with
            ``int()`` because callers sometimes pass a numpy integer
            (e.g. ``np.uint64``).
        body: Request body that was sent to the miner; forwarded unchanged.
        chunks: Response chunks collected from that miner.

    Returns:
        None.

    NOTE(review): ``any`` in the ``body`` annotation is the builtin function,
    not ``typing.Any``. It is left as-is here because ``typing`` is not among
    this module's imports.
    """
    uid = int(uid)
    logger.info(f"Forwarding response from uid {uid} to scoring with body: {body} and chunks: {chunks}")

    # Scoring of organic traffic can be switched off entirely via settings.
    if not shared_settings.SCORE_ORGANICS:
        return

    # Only these two task types are forwarded; everything else is skipped.
    if body.get("task") not in ("InferenceTask", "WebRetrievalTask"):
        logger.debug(f"Skipping forwarding for non- inference/web retrieval task: {body.get('task')}")
        return

    url = f"http://{shared_settings.VALIDATOR_API}/scoring"
    payload = {"body": body, "chunks": chunks, "uid": uid}
    try:
        timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.post(
                url, json=payload, headers={"api-key": shared_settings.SCORING_KEY, "Content-Type": "application/json"}
            )
            if response.status_code == 200:
                logger.info(f"Forwarding response completed with status {response.status_code}")
            else:
                # No exception is active here, so use `error`: `logger.exception`
                # would append a meaningless "NoneType: None" traceback.
                logger.error(
                    f"Forwarding response uid {uid} failed with status {response.status_code} and payload {payload}"
                )
    except Exception as e:
        # Best-effort forwarding: record the failure (with traceback) and move on.
        logger.error(f"Tried to forward response to {url} with payload {payload}")
        logger.exception(f"Error while forwarding response: {e}")