2 changes: 1 addition & 1 deletion api_keys.json
@@ -1 +1 @@
-{}
+{}
6 changes: 6 additions & 0 deletions neurons/validator.py
@@ -85,6 +85,12 @@ def start_api(scoring_queue, reward_events):
     async def start():
         from prompting.api.api import start_scoring_api  # noqa: F401

+        # TODO: We should not use 2 availability loops for each process, in reality
+        # we should only be sharing the miner availability data between processes.
+        from prompting.miner_availability.miner_availability import availability_checking_loop
+
+        asyncio.create_task(availability_checking_loop.start())
+
         await start_scoring_api(scoring_queue, reward_events)

     while True:
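A note on the TODO above: each process currently starts its own availability_checking_loop. One possible direction, purely illustrative and not part of this PR, is to share a single availability map across processes via a multiprocessing.Manager:

import multiprocessing as mp

def reader(shared_availabilities):
    # Readers see updates published by the single writer process.
    print(dict(shared_availabilities))

if __name__ == "__main__":
    manager = mp.Manager()
    # Hypothetical payload shape; the writer would be the one checking loop.
    availabilities = manager.dict({1: {"alive": True}})
    p = mp.Process(target=reader, args=(availabilities,))
    p.start()
    p.join()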
6 changes: 6 additions & 0 deletions prompting/miner_availability/miner_availability.py
@@ -91,6 +91,12 @@ async def run_step(self):
         if self.current_index >= len(self.uids):
             self.current_index = 0

+        tracked_miners = [m for m in miner_availabilities.miners.values() if m is not None]
+        logger.debug(
+            f"TRACKED MINERS: {len(tracked_miners)} --- UNTRACKED MINERS: {len(self.uids) - len(tracked_miners)}"
+        )
+        if tracked_miners:
+            logger.debug(f"SAMPLE MINER: {tracked_miners[0]}")
         await asyncio.sleep(0.1)
1 change: 0 additions & 1 deletion shared/logging.py
@@ -224,7 +224,6 @@ def log_event(event: BaseEvent):
     reinit_wandb()
     unpacked_event = unpack_events(event)
     unpacked_event = convert_arrays_to_lists(unpacked_event)
-    logger.debug(f"""LOGGING WANDB EVENT: {unpacked_event}""")
     wandb.log(unpacked_event)
15 changes: 10 additions & 5 deletions validator_api/api.py
@@ -1,3 +1,5 @@
+import asyncio
+
 import uvicorn
 from fastapi import FastAPI

@@ -8,28 +10,31 @@
 from validator_api.api_management import router as api_management_router
 from validator_api.gpt_endpoints import router as gpt_router
+from validator_api.utils import update_miner_availabilities_for_api

 app = FastAPI()
 app.include_router(gpt_router, tags=["GPT Endpoints"])
 app.include_router(api_management_router, tags=["API Management"])

+# TODO: This api requests miner availabilities from validator
+# TODO: Forward the results from miners to the validator
+

 @app.get("/health")
 async def health():
     return {"status": "ok"}


-def main():
+async def main():
+    asyncio.create_task(update_miner_availabilities_for_api.start())
     uvicorn.run(
-        "validator_api.api:app",
+        app,
         host=shared_settings.API_HOST,
         port=shared_settings.API_PORT,
         log_level="debug",
+        timeout_keep_alive=60,
-        workers=shared_settings.WORKERS,
-        reload=False,
     )


 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
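One caveat worth noting on the main() change: uvicorn.run() creates and runs its own event loop, so calling it from inside asyncio.run(main()) can conflict with the already-running loop. A hedged sketch of the awaitable-server pattern, reusing the names from this diff, would be:

import asyncio
import uvicorn

async def main():
    # Background availability polling, as in the diff.
    asyncio.create_task(update_miner_availabilities_for_api.start())
    config = uvicorn.Config(
        app,
        host=shared_settings.API_HOST,
        port=shared_settings.API_PORT,
        log_level="debug",
        timeout_keep_alive=60,
    )
    # Server.serve() is awaitable and runs in the current event loop.
    await uvicorn.Server(config).serve()

This sketch assumes a single process is intended; uvicorn's workers option only applies when the app is passed as an import string, which the diff moves away from.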
26 changes: 15 additions & 11 deletions validator_api/gpt_endpoints.py
@@ -12,14 +12,14 @@

 from shared.epistula import SynapseStreamResult, query_miners
 from shared.settings import shared_settings
-from shared.uids import get_uids
 from validator_api.api_management import _keys
 from validator_api.chat_completion import chat_completion
 from validator_api.mixture_of_miners import mixture_of_miners
 from validator_api.test_time_inference import generate_response
-from validator_api.utils import forward_response
+from validator_api.utils import filter_available_uids, forward_response

 router = APIRouter()
+N_MINERS = 5


 def validate_api_key(api_key: str = Header(...)):
@@ -34,14 +34,18 @@ async def completions(request: Request, api_key: str = Depends(validate_api_key)
     try:
         body = await request.json()
         body["seed"] = int(body.get("seed") or random.randint(0, 1000000))
+        uids = body.get("uids") or filter_available_uids(task=body.get("task"), model=body.get("model"))
+        if not uids:
+            raise HTTPException(status_code=500, detail="No available miners")
+        uids = random.sample(uids, min(len(uids), N_MINERS))

         # Choose between regular completion and mixture of miners.
         if body.get("test_time_inference", False):
             return await test_time_inference(body["messages"], body.get("model"))
         if body.get("mixture", False):
-            return await mixture_of_miners(body)
+            return await mixture_of_miners(body, uids=uids)
         else:
-            return await chat_completion(body)
+            return await chat_completion(body, uids=uids)

     except Exception as e:
         logger.exception(f"Error in chat completion: {e}")
@@ -50,7 +54,11 @@ async def completions(request: Request, api_key: str = Depends(validate_api_key)

 @router.post("/web_retrieval")
 async def web_retrieval(search_query: str, n_miners: int = 10, uids: list[int] = None):
-    uids = list(get_uids(sampling_mode="random", k=n_miners))
+    if not uids:
+        uids = filter_available_uids(task="WebRetrievalTask")
+        if not uids:
+            raise HTTPException(status_code=500, detail="No available miners")
+    uids = random.sample(uids, min(len(uids), n_miners))
     logger.debug(f"🔍 Querying uids: {uids}")
     if len(uids) == 0:
         logger.warning("No available miners. This should already have been caught earlier.")
@@ -86,12 +94,8 @@ async def web_retrieval(search_query: str, n_miners: int = 10, uids: list[int] =
     if len(loaded_results) == 0:
         raise HTTPException(status_code=500, detail="No miner responded successfully")

-    for uid, res in zip(uids, stream_results):
-        asyncio.create_task(
-            forward_response(
-                uid=uid, body=body, chunks=res.accumulated_chunks if res and res.accumulated_chunks else []
-            )
-        )
+    chunks = [res.accumulated_chunks if res and res.accumulated_chunks else [] for res in stream_results]
+    asyncio.create_task(forward_response(uids=uids, body=body, chunks=chunks))
     return loaded_results
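The last hunk replaces one fire-and-forget forwarding task per miner with a single batched call whose chunks list is index-aligned with uids. A self-contained sketch of the new payload shape (all values made up):

import asyncio

async def forward_response(uids, body, chunks):
    # Stand-in for validator_api.utils.forward_response.
    print({u: c for u, c in zip(uids, chunks)})

async def demo():
    uids = [101, 202]
    # One chunk list per uid; an empty list marks a miner that returned nothing.
    chunks = [["Hello", " world"], []]
    await forward_response(uids=uids, body={"task": "demo"}, chunks=chunks)

asyncio.run(demo())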
9 changes: 5 additions & 4 deletions validator_api/mixture_of_miners.py
@@ -36,7 +36,7 @@ async def get_miner_response(body: dict, uid: str, timeout_seconds: int) -> tupl
         return None


-async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
+async def mixture_of_miners(body: dict[str, any], uids: list[int]) -> tuple | StreamingResponse:
     """Handle chat completion with mixture of miners approach.

     Based on Mixture-of-Agents Enhances Large Language Model Capabilities, 2024, Wang et al.:
@@ -53,15 +53,16 @@ async def mixture_of_miners(body: dict[str, any]) -> tuple | StreamingResponse:
     body_first_step["stream"] = False

     # Get multiple miners
-    miner_uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
-    if len(miner_uids) == 0:
+    if not uids:
+        uids = get_uids(sampling_mode="top_incentive", k=NUM_MIXTURE_MINERS)
+    if len(uids) == 0:
         raise HTTPException(status_code=503, detail="No available miners found")

     # Concurrently collect responses from all miners.
     timeout_seconds = max(
         30, max(0, math.floor(math.log2(body["sampling_parameters"].get("max_new_tokens", 256) / 256))) * 10 + 30
     )
-    miner_tasks = [get_miner_response(body_first_step, uid, timeout_seconds=timeout_seconds) for uid in miner_uids]
+    miner_tasks = [get_miner_response(body_first_step, uid, timeout_seconds=timeout_seconds) for uid in uids]
     responses = await asyncio.gather(*miner_tasks)

     # Filter out None responses (failed requests).
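For reference, the timeout expression grants roughly 10 extra seconds per doubling of max_new_tokens beyond 256, with a 30-second floor. A quick worked check of the same arithmetic:

import math

def timeout_for(max_new_tokens: int) -> int:
    # Mirrors the expression in the diff above.
    return max(30, max(0, math.floor(math.log2(max_new_tokens / 256))) * 10 + 30)

assert timeout_for(256) == 30   # log2(1) = 0 -> 30 s
assert timeout_for(512) == 40   # log2(2) = 1 -> 40 s
assert timeout_for(1024) == 50  # log2(4) = 2 -> 50 s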
65 changes: 64 additions & 1 deletion validator_api/utils.py
@@ -1,12 +1,75 @@
 import httpx
+import requests
 from loguru import logger

+from shared.loop_runner import AsyncLoopRunner
 from shared.settings import shared_settings
+from shared.uids import get_uids
+
+
+class UpdateMinerAvailabilitiesForAPI(AsyncLoopRunner):
+    miner_availabilities: dict[int, dict] = {}
+
+    async def run_step(self):
+        try:
+            response = requests.post(
+                f"http://{shared_settings.VALIDATOR_API}/miner_availabilities/miner_availabilities",
+                headers={"accept": "application/json", "Content-Type": "application/json"},
+                json=get_uids(sampling_mode="all"),
+                timeout=10,
+            )
+
+            self.miner_availabilities = response.json()
+        except Exception as e:
+            logger.exception(f"Error while updating miner availabilities for API: {e}")
+        tracked_availabilities = [m for m in self.miner_availabilities.values() if m is not None]
+        logger.debug(
+            f"MINER AVAILABILITIES UPDATED, TRACKED: {len(tracked_availabilities)}, UNTRACKED: {len(self.miner_availabilities) - len(tracked_availabilities)}"
+        )
+
+
+update_miner_availabilities_for_api = UpdateMinerAvailabilitiesForAPI()
+
+
+def filter_available_uids(task: str | None = None, model: str | None = None) -> list[int]:
+    """
+    Filter UIDs based on task and model availability.
+
+    Args:
+        task: Task type to check availability for, or None if any task is acceptable
+        model: Model name to check availability for, or None if any model is acceptable
+
+    Returns:
+        List of UIDs that can serve the requested task/model combination
+    """
+    filtered_uids = []
+
+    for uid in get_uids(sampling_mode="all"):
+        # Skip if miner data is None/unavailable
+        if update_miner_availabilities_for_api.miner_availabilities.get(str(uid)) is None:
+            continue
+
+        miner_data = update_miner_availabilities_for_api.miner_availabilities[str(uid)]
+
+        # Check task availability if specified
+        if task is not None:
+            if not miner_data["task_availabilities"].get(task, False):
+                continue
+
+        # Check model availability if specified
+        if model is not None:
+            if not miner_data["llm_model_availabilities"].get(model, False):
+                continue
+
+        filtered_uids.append(uid)
+
+    return filtered_uids


 # TODO: Modify this so that all the forwarded responses are sent in a single request. This is both more efficient but
 # also means that on the validator side all responses are scored at once, speeding up the scoring process.
-async def forward_response(uids: int, body: dict[str, any], chunks: list[str]):
+async def forward_response(uids: list[int], body: dict[str, any], chunks: list[list[str]]):
     uids = [int(u) for u in uids]
+    chunk_dict = {u: c for u, c in zip(uids, chunks)}
     logger.info(f"Forwarding response from uid {uids} to scoring with body: {body} and chunks: {chunks}")
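Taken together: the API process now polls the validator for availability snapshots (keyed by stringified uid) and filters locally. A self-contained sketch of the filtering logic with made-up data:

# Hypothetical snapshot, shaped like the payload consumed above.
availabilities = {
    "1": {
        "task_availabilities": {"WebRetrievalTask": True},
        "llm_model_availabilities": {"mock-llm": False},
    },
    "2": None,  # validator has no data for this miner yet
}

def filter_uids(task=None, model=None):
    out = []
    for uid in (1, 2):  # stands in for get_uids(sampling_mode="all")
        data = availabilities.get(str(uid))
        if data is None:
            continue
        if task is not None and not data["task_availabilities"].get(task, False):
            continue
        if model is not None and not data["llm_model_availabilities"].get(model, False):
            continue
        out.append(uid)
    return out

print(filter_uids(task="WebRetrievalTask"))  # [1]
print(filter_uids(model="mock-llm"))         # []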