diff --git a/validator_api/api.py b/validator_api/api.py
index 9e06f1246..d31da3a0a 100644
--- a/validator_api/api.py
+++ b/validator_api/api.py
@@ -33,13 +33,42 @@ async def lifespan(app: FastAPI):
     pass
 
 
-app = FastAPI(lifespan=lifespan)
+app = FastAPI(
+    title="Validator API",
+    description="API for interacting with the validator network and miners",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc",
+    openapi_url="/openapi.json",
+    openapi_tags=[
+        {
+            "name": "GPT Endpoints",
+            "description": "Endpoints for chat completions, web retrieval, and test time inference",
+        },
+        {
+            "name": "API Management",
+            "description": "Endpoints for API key management and validation",
+        },
+    ],
+    lifespan=lifespan,
+)
 app.include_router(gpt_router, tags=["GPT Endpoints"])
 app.include_router(api_management_router, tags=["API Management"])
 
 
-@app.get("/health")
+@app.get(
+    "/health",
+    summary="Health check endpoint",
+    description="Simple endpoint to check if the API is running",
+    tags=["Health"],
+    response_description="Status of the API",
+)
 async def health():
+    """
+    Health check endpoint to verify the API is operational.
+
+    Returns a simple JSON object with status "ok" if the API is running.
+    """
     return {"status": "ok"}
 
 
diff --git a/validator_api/chat_completion.py b/validator_api/chat_completion.py
index 774dd2b6b..368f8ba6d 100644
--- a/validator_api/chat_completion.py
+++ b/validator_api/chat_completion.py
@@ -98,6 +98,7 @@ async def stream_from_first_response(
 ) -> AsyncGenerator[str, None]:
     first_valid_response = None
     response_start_time = time.monotonic()
+
     try:
         # Keep looping until we find a valid response or run out of tasks
         while responses and first_valid_response is None:
@@ -245,11 +246,20 @@ async def chat_completion(
     collected_chunks_list = [[] for _ in uids]
     timings_list = [[] for _ in uids]
 
-    if not body.get("sampling_parameters"):
-        raise HTTPException(status_code=422, detail="Sampling parameters are required")
     timeout_seconds = max(
-        30, max(0, math.floor(math.log2(body["sampling_parameters"].get("max_new_tokens", 256) / 256))) * 10 + 30
+        30,
+        max(
+            0,
+            math.floor(
+                math.log2(  # a missing OR explicitly-None "sampling_parameters" falls back to the defaults
+                    (body.get("sampling_parameters") or shared_settings.SAMPLING_PARAMS).get("max_new_tokens", 256) / 256
+                )
+            ),
+        )
+        * 10
+        + 30,
     )
+
     if STREAM:
         # Create tasks for all miners
         response_tasks = [
@@ -297,7 +307,7 @@ async def chat_completion(
             raise HTTPException(status_code=502, detail="No valid response received")
 
         asyncio.create_task(
-            collect_remainin_nonstream_responses(
+            collect_remaining_nonstream_responses(
                 pending=pending,
                 collected_responses=collected_responses,
                 body=body,
@@ -308,7 +318,7 @@ async def chat_completion(
         return first_valid_response[0]  # Return only the response object, not the chunks
 
 
-async def collect_remainin_nonstream_responses(
+async def collect_remaining_nonstream_responses(
     pending: set[asyncio.Task],
     collected_responses: list,
     body: dict,
@@ -316,6 +326,7 @@ async def collect_remainin_nonstream_responses(
     timings_list: list,
 ):
     """Wait for all pending miner tasks to complete and append their responses to the scoring queue."""
+
     try:
         # Wait for all remaining tasks; allow exceptions to be returned.
remaining_responses = await asyncio.gather(*pending, return_exceptions=True) diff --git a/validator_api/gpt_endpoints.py b/validator_api/gpt_endpoints.py index 37bbbca60..67f9c1b83 100644 --- a/validator_api/gpt_endpoints.py +++ b/validator_api/gpt_endpoints.py @@ -5,7 +5,7 @@ import uuid import numpy as np -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, status from loguru import logger from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta from starlette.responses import StreamingResponse @@ -18,6 +18,13 @@ from validator_api.api_management import validate_api_key from validator_api.chat_completion import chat_completion from validator_api.mixture_of_miners import mixture_of_miners +from validator_api.serializers import ( + CompletionsRequest, + TestTimeInferenceRequest, + WebRetrievalRequest, + WebRetrievalResponse, + WebSearchResult, +) from validator_api.test_time_inference import generate_response from validator_api.utils import filter_available_uids @@ -25,15 +32,63 @@ N_MINERS = 5 -@router.post("/v1/chat/completions") -async def completions(request: Request, api_key: str = Depends(validate_api_key)): - """Main endpoint that handles both regular and mixture of miners chat completion.""" +@router.post( + "/v1/chat/completions", + summary="Chat completions endpoint", + description="Main endpoint that handles both regular, multi step reasoning, test time inference, and mixture of miners chat completion.", + response_description="Streaming response with generated text", + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: { + "description": "Successful response with streaming text", + "content": {"text/event-stream": {}}, + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Internal server error or no available miners"}, + }, +) +async def completions(request: CompletionsRequest, api_key: str = Depends(validate_api_key)): + """ 
+ Chat completions endpoint that supports different inference modes. + + This endpoint processes chat messages and returns generated completions using + different inference strategies based on the request parameters. + + ## Inference Modes: + - Regular chat completion + - Multi Step Reasoning + - Test time inference + - Mixture of miners + + ## Request Parameters: + - **uids** (List[int], optional): Specific miner UIDs to query. If not provided, miners will be selected automatically. + - **messages** (List[dict]): List of message objects with 'role' and 'content' keys. Required. + - **seed** (int, optional): Random seed for reproducible results. + - **task** (str, optional): Task identifier to filter available miners. + - **model** (str, optional): Model identifier to filter available miners. + - **test_time_inference** (bool, default=False): Enable step-by-step reasoning mode. + - **mixture** (bool, default=False): Enable mixture of miners mode. + - **sampling_parameters** (dict, optional): Parameters to control text generation. + + The endpoint selects miners based on the provided UIDs or filters available miners + based on task and model requirements. + + Example request: + ```json + { + "messages": [ + {"role": "user", "content": "Tell me about neural networks"} + ], + "model": "gpt-4", + "seed": 42 + } + ``` + """ try: - body = await request.json() + body = request.model_dump() body["seed"] = int(body.get("seed") or random.randint(0, 1000000)) if body.get("uids"): try: - uids = [int(uid) for uid in body.get("uids")] + uids = list(map(int, body.get("uids"))) except Exception: logger.error(f"Error in uids: {body.get('uids')}") else: @@ -43,10 +98,10 @@ async def completions(request: Request, api_key: str = Depends(validate_api_key) if not uids: raise HTTPException(status_code=500, detail="No available miners") - # Choose between regular completion and mixture of miners. + # Choose between regular inference, test time inference, and mixture of miners. 
if body.get("test_time_inference", False):
-            return await test_time_inference(body["messages"], body.get("model", None), target_uids=body.get("uids"))
-        if body.get("mixture", False):
+            return await test_time_inference(request)
+        elif body.get("mixture", False):
             return await mixture_of_miners(body, uids=uids)
         else:
             return await chat_completion(body, uids=uids)
@@ -56,24 +111,66 @@ async def completions(request: Request, api_key: str = Depends(validate_api_key)
         return StreamingResponse(content="Internal Server Error", status_code=500)
 
 
-@router.post("/web_retrieval")
+@router.post(
+    "/web_retrieval",
+    response_model=WebRetrievalResponse,
+    summary="Web retrieval endpoint",
+    description="Retrieves information from the web based on a search query using multiple miners.",
+    response_description="List of unique web search results",
+    status_code=status.HTTP_200_OK,
+    responses={
+        status.HTTP_200_OK: {
+            "description": "Successful response with web search results",
+            "model": WebRetrievalResponse,
+        },
+        status.HTTP_500_INTERNAL_SERVER_ERROR: {
+            "description": "Internal server error, no available miners, or no successful miner responses"
+        },
+    },
+)
 async def web_retrieval(
-    search_query: str,
-    n_miners: int = 10,
-    n_results: int = 5,
-    max_response_time: int = 10,
+    request: WebRetrievalRequest,
     api_key: str = Depends(validate_api_key),
-    target_uids: list[str] | list[int] = None,
 ):
-    if target_uids:
-        uids = target_uids
+    """
+    Web retrieval endpoint that queries multiple miners to search the web.
+
+    This endpoint distributes a search query to multiple miners, which perform web searches
+    and return relevant results. The results are deduplicated based on URLs before being returned.
+
+    ## Request Parameters:
+    - **search_query** (str): The query to search for on the web. Required.
+    - **n_miners** (int, default=3): Number of miners to query for results.
+    - **n_results** (int, default=1): Maximum number of results to return in the response.
+ - **max_response_time** (int, default=10): Maximum time to wait for responses in seconds. + - **uids** (List[int], optional): Optional list of specific miner UIDs to query. + + ## Response: + Returns a list of unique web search results, each containing: + - **url** (str): The URL of the web page + - **content** (str, optional): The relevant content from the page + - **relevant** (str, optional): Information about why this result is relevant + + Example request: + ```json + { + "search_query": "latest advancements in quantum computing", + "n_miners": 15, + "n_results": 10 + } + ``` + """ + if request.uids: + uids = request.uids try: - uids = [int(uid) for uid in uids] + uids = list(map(int, uids)) except Exception: logger.error(f"Error in uids: {uids}") else: - uids = filter_available_uids(task="WebRetrievalTask", test=shared_settings.API_TEST_MODE, n_miners=n_miners) - uids = random.sample(uids, min(len(uids), n_miners)) + uids = filter_available_uids( + task="WebRetrievalTask", test=shared_settings.API_TEST_MODE, n_miners=request.n_miners + ) + uids = random.sample(uids, min(len(uids), request.n_miners)) if len(uids) == 0: raise HTTPException(status_code=500, detail="No available miners") @@ -82,10 +179,10 @@ async def web_retrieval( "seed": random.randint(0, 1_000_000), "sampling_parameters": shared_settings.SAMPLING_PARAMS, "task": "WebRetrievalTask", - "target_results": n_results, - "timeout": max_response_time, + "target_results": request.n_results, + "timeout": request.max_response_time, "messages": [ - {"role": "user", "content": search_query}, + {"role": "user", "content": request.search_query}, ], } @@ -115,31 +212,65 @@ async def web_retrieval( unique_results = [] seen_urls = set() - # for result in flat_results: - # # TODO: This is a hack to try and avoid the stringify json issue, this needs a deeper fix. 
- # try: - # if isinstance(result, str): - # result = json.loads(result) - # if isinstance(result, dict) and 'url' in result: - # if result["url"] not in seen_urls: - # seen_urls.add(result["url"]) - # unique_results.append(result) - # except Exception: - # logger.warning(f"Skipping invalid result: {result}") - - # sometimes the results are not in the correct format, so we need to filter them out for result in flat_results: if isinstance(result, dict) and "url" in result: if result["url"] not in seen_urls: seen_urls.add(result["url"]) - unique_results.append(result) - return unique_results + # Convert dict to WebSearchResult + unique_results.append(WebSearchResult(**result)) + return WebRetrievalResponse(results=unique_results) + + +@router.post( + "/test_time_inference", + summary="Test time inference endpoint", + description="Provides step-by-step reasoning and thinking process during inference.", + response_description="Streaming response with reasoning steps", + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: { + "description": "Successful streaming response with reasoning steps", + "content": {"text/event-stream": {}}, + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"description": "Internal server error during streaming"}, + }, +) +async def test_time_inference(request: TestTimeInferenceRequest): + """ + Test time inference endpoint that provides step-by-step reasoning. + + This endpoint streams the thinking process and reasoning steps during inference, + allowing visibility into how the model arrives at its conclusions. Each step of + the reasoning process is streamed as it becomes available. + + ## Request Parameters: + - **messages** (List[dict]): List of message objects with 'role' and 'content' keys. Required. + - **model** (str, optional): Optional model identifier to use for inference. + - **uids** (List[int], optional): Optional list of specific miner UIDs to query. 
+ + ## Response: + The response is streamed as server-sent events (SSE) with each step of reasoning. + Each event contains: + - A step title/heading + - The content of the reasoning step + - Timing information (debug only) + + Example request: + ```json + { + "messages": [ + {"role": "user", "content": "Solve the equation: 3x + 5 = 14"} + ], + "model": "gpt-4" + } + ``` + """ -@router.post("/test_time_inference") -async def test_time_inference(messages: list[dict], model: str = None, target_uids: list[str] = None): - async def create_response_stream(messages): - async for steps, total_thinking_time in generate_response(messages, model=model, target_uids=target_uids): + async def create_response_stream(request): + async for steps, total_thinking_time in generate_response( + request.messages, model=request.model, uids=request.uids + ): if total_thinking_time is not None: logger.debug(f"**Total thinking time: {total_thinking_time:.2f} seconds**") yield steps, total_thinking_time @@ -148,12 +279,12 @@ async def create_response_stream(messages): async def stream_steps(): try: i = 0 - async for steps, thinking_time in create_response_stream(messages): + async for steps, thinking_time in create_response_stream(request): i += 1 yield "data: " + ChatCompletionChunk( id=str(uuid.uuid4()), created=int(time.time()), - model=model or "None", + model=request.model or "None", object="chat.completion.chunk", choices=[ Choice(index=i, delta=ChoiceDelta(content=f"## {steps[-1][0]}\n\n{steps[-1][1]}" + "\n\n")) diff --git a/validator_api/mixture_of_miners.py b/validator_api/mixture_of_miners.py index b5742e185..e26c018d3 100644 --- a/validator_api/mixture_of_miners.py +++ b/validator_api/mixture_of_miners.py @@ -7,6 +7,7 @@ from fastapi.responses import StreamingResponse from loguru import logger +from shared.settings import shared_settings from shared.uids import get_uids from validator_api.chat_completion import chat_completion, get_response_from_miner @@ -45,6 +46,7 @@ async def 
mixture_of_miners(body: dict[str, any], uids: list[int]) -> tuple | St
     if len(uids) == 0:
         raise HTTPException(status_code=503, detail="No available miners found")
 
+    body["sampling_parameters"] = body.get("sampling_parameters") or shared_settings.SAMPLING_PARAMS  # key may be present but None
     # Concurrently collect responses from all miners.
     timeout_seconds = max(
         30, max(0, math.floor(math.log2(body["sampling_parameters"].get("max_new_tokens", 256) / 256))) * 10 + 30
diff --git a/validator_api/scoring_queue.py b/validator_api/scoring_queue.py
index 0ee876d6d..220ae92fb 100644
--- a/validator_api/scoring_queue.py
+++ b/validator_api/scoring_queue.py
@@ -58,6 +58,11 @@ async def run_step(self):
         except Exception as e:
             logger.exception(f"Could not find available validator scoring endpoint: {e}")
         try:
+            if hasattr(payload, "to_dict"):
+                payload = payload.to_dict()
+            elif isinstance(payload, BaseModel):
+                payload = payload.model_dump()
+
             timeout = httpx.Timeout(timeout=120.0, connect=60.0, read=30.0, write=30.0, pool=5.0)
 
             # Add required headers for signature verification
@@ -94,7 +99,7 @@ async def append_response(
             # logger.debug(f"Skipping forwarding for non-inference/web retrieval task: {body.get('task')}")
             return
 
-        uids = [int(u) for u in uids]
+        uids = list(map(int, uids))
         chunk_dict = {str(u): c for u, c in zip(uids, chunks)}
         if timings:
             timing_dict = {str(u): t for u, t in zip(uids, timings)}
diff --git a/validator_api/serializers.py b/validator_api/serializers.py
new file mode 100644
index 000000000..d0355ea76
--- /dev/null
+++ b/validator_api/serializers.py
@@ -0,0 +1,112 @@
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Field
+
+
+class CompletionsRequest(BaseModel):
+    """Request model for the /v1/chat/completions endpoint."""
+
+    uids: Optional[List[int]] = Field(
+        default=None,
+        description="List of specific miner UIDs to query.
If not provided, miners will be selected automatically.", + example=[1, 2, 3], + ) + messages: List[Dict[str, str]] = Field( + ..., + description="List of message objects with 'role' and 'content' keys. Roles can be 'system', 'user', or 'assistant'.", + example=[{"role": "user", "content": "Tell me about neural networks"}], + ) + seed: Optional[int] = Field( + default=None, + description="Random seed for reproducible results. If not provided, a random seed will be generated.", + example=42, + ) + task: Optional[str] = Field( + default=None, description="Task identifier to choose the inference type.", example="InferenceTask" + ) + model: Optional[str] = Field( + default=None, + description="Model identifier to filter available miners.", + example="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", + ) + test_time_inference: bool = Field( + default=False, description="Enable step-by-step reasoning mode that shows the model's thinking process." + ) + mixture: bool = Field( + default=False, description="Enable mixture of miners mode that combines responses from multiple miners." + ) + sampling_parameters: Optional[Dict[str, Any]] = Field( + default=None, + description="Parameters to control text generation, such as temperature, top_p, etc.", + example={ + "temperature": 0.7, + "top_p": 0.95, + "top_k": 50, + "max_new_tokens": 512, + "do_sample": True, + }, + ) + + +class WebRetrievalRequest(BaseModel): + """Request model for the /web_retrieval endpoint.""" + + uids: Optional[List[int]] = Field( + default=None, + description="List of specific miner UIDs to query. 
If not provided, miners will be selected automatically.",
+        example=[1, 2, 3],
+    )
+    search_query: str = Field(
+        ..., description="The query to search for on the web.", example="latest advancements in quantum computing"
+    )
+    n_miners: int = Field(default=3, description="Number of miners to query for results.", example=15, ge=1)
+    n_results: int = Field(
+        default=1, description="Maximum number of results to return in the response.", example=5, ge=1
+    )
+    max_response_time: int = Field(
+        default=10, description="Maximum time to wait for responses in seconds.", example=15, ge=1
+    )
+
+
+class WebSearchResult(BaseModel):
+    """Model for a single web search result."""
+
+    url: str = Field(..., description="The URL of the web page.", example="https://example.com/article")
+    content: Optional[str] = Field(
+        default=None,
+        description="The relevant content extracted from the page.",
+        example="Quantum computing has seen significant advancements in the past year...",
+    )
+    relevant: Optional[str] = Field(
+        default=None,
+        description="Information about why this result is relevant to the query.",
+        example="This article discusses the latest breakthroughs in quantum computing research.",
+    )
+
+
+class WebRetrievalResponse(BaseModel):
+    """Response model for the /web_retrieval endpoint."""
+
+    results: List[WebSearchResult] = Field(..., description="List of unique web search results.")
+
+    def to_dict(self):
+        return self.model_dump()  # dumps nested results recursively; dict.update() would return None
+
+
+class TestTimeInferenceRequest(BaseModel):
+    """Request model for the /test_time_inference endpoint."""
+
+    uids: Optional[List[int]] = Field(
+        default=None,
+        description="List of specific miner UIDs to query. If not provided, miners will be selected automatically.",
+        example=[1, 2, 3],
+    )
+    messages: List[Dict[str, str]] = Field(
+        ...,
+        description="List of message objects with 'role' and 'content' keys.
Roles can be 'system', 'user', or 'assistant'.",
+        example=[{"role": "user", "content": "Solve the equation: 3x + 5 = 14"}],
+    )
+    model: Optional[str] = Field(default=None, description="Model identifier to use for inference.", example="gpt-4")
+
+    def to_dict(self):
+        return self.model_dump()  # messages are plain dicts (no model_dump); dict.update() would return None
diff --git a/validator_api/test_time_inference.py b/validator_api/test_time_inference.py
index c7847fb65..1f084318c 100644
--- a/validator_api/test_time_inference.py
+++ b/validator_api/test_time_inference.py
@@ -58,11 +58,12 @@ def parse_multiple_json(api_response):
                     f"Invalid JSON object found in the response - field missing. The miner response was: {api_response}"
                 )
                 return None
+
     return parsed_objects
 
 
 async def make_api_call(
-    messages, model=None, is_final_answer: bool = False, use_miners: bool = True, target_uids: list[str] = None
+    messages, model=None, is_final_answer: bool = False, use_miners: bool = True, uids: list[int] | None = None
 ):
     async def single_attempt():
         try:
@@ -81,7 +82,7 @@ async def single_attempt():
                         "seed": random.randint(0, 1000000),
                     },
                     num_miners=3,
-                    uids=target_uids,
+                    uids=uids,
                 )
                 response_str = response.choices[0].message.content
             else:
@@ -147,7 +148,7 @@ async def single_attempt():
 
 
 async def generate_response(
-    original_messages: list[dict[str, str]], model: str = None, target_uids: list[str] = None, use_miners: bool = True
+    original_messages: list[dict[str, str]], model: str = None, uids: list[int] | None = None, use_miners: bool = True
 ):
     messages = [
         {
@@ -236,7 +237,7 @@ async def generate_response(
 
     for _ in range(MAX_THINKING_STEPS):
         with Timer() as timer:
-            step_data = await make_api_call(messages, model=model, use_miners=use_miners, target_uids=target_uids)
+            step_data = await make_api_call(messages, model=model, use_miners=use_miners, uids=uids)
         thinking_time = timer.final_time
         total_thinking_time += thinking_time
 
@@ -272,9 +273,7 @@ async def generate_response(
         )
start_time = time.time() - final_data = await make_api_call( - messages, model=model, is_final_answer=True, use_miners=use_miners, target_uids=target_uids - ) + final_data = await make_api_call(messages, model=model, is_final_answer=True, use_miners=use_miners, uids=uids) end_time = time.time() thinking_time = end_time - start_time