Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
697 changes: 247 additions & 450 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions prompting/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ def health():


async def start_scoring_api(task_scorer, scoring_queue, reward_events):
# We pass an object of task scorer then override it's attributes to ensure that they are managed by mp
app.state.task_scorer = task_scorer
app.state.task_scorer.scoring_queue = scoring_queue
app.state.task_scorer.reward_events = reward_events

logger.info(f"Starting Scoring API on https://0.0.0.0:{settings.shared_settings.SCORING_API_PORT}")
uvicorn.run(
config = uvicorn.Config(
"prompting.api.api:app",
host="0.0.0.0",
port=settings.shared_settings.SCORING_API_PORT,
loop="asyncio",
reload=False,
)
server = uvicorn.Server(config)
await server.serve()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ docstring-code-line-length = "dynamic"
max-complexity = 14

[tool.poetry.dependencies]
bittensor = "8.5.1"
bittensor = "9.0.0"
tqdm = "^4.66.4"
requests = "^2.32.3"
python = ">=3.10 <3.11"
Expand Down
2 changes: 1 addition & 1 deletion scripts/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def make_completion(client: openai.AsyncOpenAI, prompt: str, stream: bool
async def main():
PORT = 8005
# Example API key, replace with yours:
API_KEY = "0566dbe21ee33bba9419549716cd6f1f"
API_KEY = ""
client = openai.AsyncOpenAI(
base_url=f"http://0.0.0.0:{PORT}/v1",
max_retries=0,
Expand Down
6 changes: 3 additions & 3 deletions shared/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ class SharedSettings(BaseSettings):
# API key used to access validator organic scoring mechanism (both .env.validator and .env.api).
SCORING_KEY: str | None = Field(None, env="SCORING_KEY")
# Scoring request rate limit in seconds.
SCORING_RATE_LIMIT_SEC: float = Field(0.5, env="SCORING_RATE_LIMIT_SEC")
SCORING_RATE_LIMIT_SEC: float = Field(5, env="SCORING_RATE_LIMIT_SEC")
# Scoring queue threshold when rate-limit start to kick in, used to query validator API with scoring requests.
SCORING_QUEUE_API_THRESHOLD: int = Field(5, env="SCORING_QUEUE_API_THRESHOLD")
SCORING_QUEUE_API_THRESHOLD: int = Field(1, env="SCORING_QUEUE_API_THRESHOLD")
API_TEST_MODE: bool = Field(False, env="API_TEST_MODE")

# Validator scoring API (.env.validator).
DEPLOY_SCORING_API: bool = Field(False, env="DEPLOY_SCORING_API")
SCORING_API_PORT: int = Field(8094, env="SCORING_API_PORT")
SCORING_ADMIN_KEY: str | None = Field(None, env="SCORING_ADMIN_KEY")
SCORE_ORGANICS: bool = Field(False, env="SCORE_ORGANICS")
WORKERS: int = Field(2, env="WORKERS")
WORKERS: int = Field(1, env="WORKERS")

# API Management (.env.api).
API_PORT: int = Field(8005, env="API_PORT")
Expand Down
10 changes: 7 additions & 3 deletions validator_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: start the background tasks.
background_task = asyncio.create_task(update_miner_availabilities_for_api.start())
yield
background_task.cancel()
Expand All @@ -27,7 +26,6 @@ async def lifespan(app: FastAPI):
pass


# Create the FastAPI app with the lifespan handler.
app = FastAPI(lifespan=lifespan)
app.include_router(gpt_router, tags=["GPT Endpoints"])
app.include_router(api_management_router, tags=["API Management"])
Expand All @@ -41,15 +39,21 @@ async def health():
async def main():
asyncio.create_task(update_miner_availabilities_for_api.start())
asyncio.create_task(scoring_queue.scoring_queue.start())
uvicorn.run(

config = uvicorn.Config(
"validator_api.api:app",
host=shared_settings.API_HOST,
port=shared_settings.API_PORT,
log_level="debug",
timeout_keep_alive=60,
# Note: The `workers` parameter is typically only supported via the CLI.
# When running programmatically with `server.serve()`, only a single worker will run.
workers=shared_settings.WORKERS,
reload=False,
)
server = uvicorn.Server(config)

await server.serve()


if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion validator_api/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from fastapi.responses import StreamingResponse
from loguru import logger

from shared import settings

shared_settings = settings.shared_settings

from shared.epistula import make_openai_query
from shared.settings import shared_settings
from validator_api import scoring_queue
from validator_api.utils import filter_available_uids

Expand Down
4 changes: 2 additions & 2 deletions validator_api/gpt_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from starlette.responses import StreamingResponse

from shared import settings

shared_settings = settings.shared_settings
from shared.epistula import SynapseStreamResult, query_miners
from validator_api import scoring_queue
from validator_api.api_management import _keys
Expand All @@ -19,8 +21,6 @@
from validator_api.test_time_inference import generate_response
from validator_api.utils import filter_available_uids

shared_settings = settings.shared_settings

router = APIRouter()
N_MINERS = 5

Expand Down
4 changes: 3 additions & 1 deletion validator_api/miner_availabilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from pydantic import BaseModel

from shared import settings

shared_settings = settings.shared_settings

from shared.loop_runner import AsyncLoopRunner
from shared.uids import get_uids

shared_settings = settings.shared_settings
router = APIRouter()


Expand Down
3 changes: 2 additions & 1 deletion validator_api/scoring_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from pydantic import BaseModel

from shared import settings
from shared.loop_runner import AsyncLoopRunner

shared_settings = settings.shared_settings

from shared.loop_runner import AsyncLoopRunner


class ScoringPayload(BaseModel):
payload: dict[str, Any]
Expand Down
50 changes: 34 additions & 16 deletions validator_api/test_time_inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import random
import re
Expand All @@ -9,6 +10,7 @@
from validator_api.chat_completion import chat_completion

MAX_THINKING_STEPS = 10
ATTEMPTS_PER_STEP = 10


def parse_multiple_json(api_response):
Expand Down Expand Up @@ -74,7 +76,8 @@ async def single_attempt():
"max_new_tokens": 500,
},
"seed": (seed := random.randint(0, 1000000)),
}
},
num_miners=3,
)
logger.debug(
f"Making API call with\n\nMESSAGES: {messages}\n\nSEED: {seed}\n\nRESPONSE: {response.choices[0].message.content}"
Expand All @@ -85,21 +88,36 @@ async def single_attempt():
logger.error(f"Failed to get valid response: {e}")
return None

# Create three concurrent tasks
try:
return await single_attempt()
except Exception as e:
if is_final_answer:
return {
"title": "Error",
"content": f"Failed to generate final answer. Error: {e}",
}
else:
return {
"title": "Error",
"content": f"Failed to generate step. Error: {e}",
"next_action": "final_answer",
}
# Create three concurrent tasks for more robustness against invalid jsons
tasks = [asyncio.create_task(single_attempt()) for _ in range(ATTEMPTS_PER_STEP)]

# As each task completes, check if it was successful
for completed_task in asyncio.as_completed(tasks):
try:
result = await completed_task
if result is not None:
# Cancel remaining tasks
for task in tasks:
task.cancel()
return result
except Exception as e:
logger.error(f"Task failed with error: {e}")
continue

# If all tasks failed, return error response
error_msg = "All concurrent API calls failed"
logger.error(error_msg)
if is_final_answer:
return {
"title": "Error",
"content": f"Failed to generate final answer. Error: {error_msg}",
}
else:
return {
"title": "Error",
"content": f"Failed to generate step. Error: {error_msg}",
"next_action": "final_answer",
}


async def generate_response(original_messages: list[dict[str, str]], model: str = None):
Expand Down