1 change: 1 addition & 0 deletions .gitignore
@@ -183,3 +183,4 @@ wandb
.vscode
api_keys.json
prompting/api/api_keys.json
weights.csv
139 changes: 102 additions & 37 deletions neurons/validator.py
@@ -8,6 +8,7 @@
import multiprocessing as mp
import time

import torch
from loguru import logger

from prompting.api.api import start_scoring_api
@@ -20,48 +21,112 @@
from prompting.weight_setting.weight_setter import weight_setter
from shared.profiling import profiler

torch.multiprocessing.set_start_method("spawn", force=True)

NEURON_SAMPLE_SIZE = 100


def create_loop_process(task_queue, scoring_queue, reward_events):
async def spawn_loops(task_queue, scoring_queue, reward_events):
logger.info("Starting Profiler...")
asyncio.create_task(profiler.print_stats(), name="Profiler"),
logger.info("Starting ModelScheduler...")
asyncio.create_task(model_scheduler.start(scoring_queue), name="ModelScheduler"),
logger.info("Starting TaskScorer...")
asyncio.create_task(task_scorer.start(scoring_queue, reward_events), name="TaskScorer"),
logger.info("Starting WeightSetter...")
asyncio.create_task(weight_setter.start(reward_events))

# Main monitoring loop
start = time.time()

logger.info("Starting Main Monitoring Loop...")
while True:
await asyncio.sleep(5)
current_time = time.time()
time_diff = current_time - start
start = current_time

# Check if all tasks are still running
logger.debug(f"Running {time_diff:.2f} seconds")
logger.debug(f"Number of tasks in Task Queue: {len(task_queue)}")
logger.debug(f"Number of tasks in Scoring Queue: {len(scoring_queue)}")
logger.debug(f"Number of tasks in Reward Events: {len(reward_events)}")

asyncio.run(spawn_loops(task_queue, scoring_queue, reward_events))


def start_api():
async def start():
await start_scoring_api()
while True:
await asyncio.sleep(10)
logger.debug("Running API...")

asyncio.run(start())


def create_task_loop(task_queue, scoring_queue):
async def start(task_queue, scoring_queue):
logger.info("Starting AvailabilityCheckingLoop...")
asyncio.create_task(availability_checking_loop.start())

logger.info("Starting TaskSender...")
asyncio.create_task(task_sender.start(task_queue, scoring_queue))

logger.info("Starting TaskLoop...")
asyncio.create_task(task_loop.start(task_queue, scoring_queue))
while True:
await asyncio.sleep(10)
logger.debug("Running task loop...")

asyncio.run(start(task_queue, scoring_queue))


async def main():
# will start checking the availability of miners at regular intervals, needed for API and Validator
asyncio.create_task(availability_checking_loop.start())

if shared_settings.DEPLOY_SCORING_API:
# Use multiprocessing to bypass API blocking issue.
api_process = mp.Process(target=lambda: asyncio.run(start_scoring_api()))
api_process.start()

GPUInfo.log_gpu_info()
# start profiling
asyncio.create_task(profiler.print_stats())

# start rotating LLM models
asyncio.create_task(model_scheduler.start())

# start creating tasks
asyncio.create_task(task_loop.start())

# will start checking the availability of miners at regular intervals
asyncio.create_task(availability_checking_loop.start())

# start sending tasks to miners
asyncio.create_task(task_sender.start())

# sets weights at regular intervals (synchronised between all validators)
asyncio.create_task(weight_setter.start())

# start scoring tasks in separate loop
asyncio.create_task(task_scorer.start())
# # TODO: Think about whether we want to store the task queue locally in case of a crash
# # TODO: Possibly run task scorer & model scheduler with a lock so I don't unload a model whilst it's generating
# # TODO: Make weight setting happen as specific intervals as we load/unload models
start = time.time()
await asyncio.sleep(60)
while True:
await asyncio.sleep(5)
time_diff = -start + (start := time.time())
logger.debug(f"Running {time_diff:.2f} seconds")
with torch.multiprocessing.Manager() as manager:
reward_events = manager.list()
scoring_queue = manager.list()
task_queue = manager.list()

# Create process pool for managed processes
processes = []

try:
# # Start checking the availability of miners at regular intervals

if shared_settings.DEPLOY_SCORING_API:
# Use multiprocessing to bypass API blocking issue
api_process = mp.Process(target=start_api, name="API_Process")
api_process.start()
processes.append(api_process)

loop_process = mp.Process(
target=create_loop_process, args=(task_queue, scoring_queue, reward_events), name="LoopProcess"
)
task_loop_process = mp.Process(
target=create_task_loop, args=(task_queue, scoring_queue), name="TaskLoopProcess"
)
loop_process.start()
task_loop_process.start()
processes.append(loop_process)
processes.append(task_loop_process)
GPUInfo.log_gpu_info()

while True:
await asyncio.sleep(10)
logger.debug("Running...")

except Exception as e:
logger.error(f"Main loop error: {e}")
raise
finally:
# Clean up processes
for process in processes:
if process.is_alive():
process.terminate()
process.join()


# The main function parses the configuration and runs the validator.
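
The new `main()` shares `task_queue`, `scoring_queue`, and `reward_events` across processes through `Manager` list proxies rather than module-level globals. A minimal sketch of that pattern, using the stdlib `multiprocessing` in place of `torch.multiprocessing` and illustrative names (`worker`, `shared`):

```python
import asyncio
import multiprocessing as mp


def worker(shared):
    # The child process runs its own asyncio loop and appends to the proxy
    # list; every process holding the same proxy sees the updates.
    async def run():
        for i in range(3):
            shared.append(i)
            await asyncio.sleep(0.1)

    asyncio.run(run())


if __name__ == "__main__":
    with mp.Manager() as manager:
        shared = manager.list()
        p = mp.Process(target=worker, args=(shared,), name="WorkerProcess")
        p.start()
        p.join()
        print(list(shared))  # -> [0, 1, 2]
```

Each method call on a proxy is routed through the manager process, so individual appends from different processes are safe; compound check-then-mutate sequences still need care.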
2 changes: 1 addition & 1 deletion prompting/api/api.py
@@ -18,7 +18,7 @@ def health():


async def start_scoring_api():
logger.info("Starting API...")
logger.info(f"Starting Scoring API on https://0.0.0.0:{shared_settings.SCORING_API_PORT}")
uvicorn.run(
"prompting.api.api:app", host="0.0.0.0", port=shared_settings.SCORING_API_PORT, loop="asyncio", reload=False
)
21 changes: 19 additions & 2 deletions prompting/api/scoring/api.py
@@ -1,7 +1,8 @@
import uuid
from typing import Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from loguru import logger

from prompting.llms.model_zoo import ModelZoo
from prompting.rewards.scoring import task_scorer
@@ -14,10 +15,25 @@
router = APIRouter()


def validate_scoring_key(api_key: str = Header(...)):
if api_key != shared_settings.SCORING_KEY:
raise HTTPException(status_code=403, detail="Invalid API key")


@router.post("/scoring")
async def score_response(request: Request): # , api_key_data: dict = Depends(validate_api_key)):
async def score_response(request: Request, api_key_data: dict = Depends(validate_scoring_key)):
model = None
payload: dict[str, Any] = await request.json()
body = payload.get("body")

try:
if body.get("model") is not None:
model = ModelZoo.get_model_by_id(body.get("model"))
except Exception:
logger.warning(
f"Organic request with model {body.get('model')} made but the model cannot be found in model zoo. Skipping scoring."
)
return
uid = int(payload.get("uid"))
chunks = payload.get("chunks")
llm_model = ModelZoo.get_model_by_id(model) if (model := body.get("model")) else None
@@ -39,3 +55,4 @@ async def score_response(request: Request): # , api_key_data: dict = Depends(va
step=-1,
task_id=str(uuid.uuid4()),
)
logger.info("Organic tas appended to scoring queue")
10 changes: 7 additions & 3 deletions prompting/llms/model_manager.py
@@ -9,7 +9,6 @@
from prompting.llms.hf_llm import ReproducibleHF
from prompting.llms.model_zoo import ModelConfig, ModelZoo
from prompting.llms.utils import GPUInfo
from prompting.mutable_globals import scoring_queue
from shared.loop_runner import AsyncLoopRunner
from shared.settings import shared_settings

@@ -158,14 +157,19 @@ def generate(
class AsyncModelScheduler(AsyncLoopRunner):
llm_model_manager: ModelManager
interval: int = 14400
scoring_queue: list | None = None

async def start(self, scoring_queue: list):
self.scoring_queue = scoring_queue
return await super().start()

async def initialise_loop(self):
model_manager.load_always_active_models()

async def run_step(self):
"""This method is called periodically according to the interval."""
# try to load the model belonging to the oldest task in the queue
selected_model = scoring_queue[0].task.llm_model if scoring_queue else None
selected_model = self.scoring_queue[0].task.llm_model if self.scoring_queue else None
if not selected_model:
selected_model = ModelZoo.get_random(max_ram=self.llm_model_manager.total_ram)
logger.info(f"Loading model {selected_model.llm_model_id} for {self.interval} seconds.")
@@ -174,7 +178,7 @@ async def run_step(self):
logger.info(f"Model {selected_model.llm_model_id} is already loaded.")
return

logger.debug(f"Active models: {model_manager.active_models.keys()}")
logger.debug(f"Active models: {self.llm_model_manager.active_models.keys()}")
# Load the selected model
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self.llm_model_manager.load_model, selected_model)
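
The same injection pattern recurs in `AsyncModelScheduler`, `TaskScorer`, and `TaskLoop`: the globals from `prompting.mutable_globals` are replaced by an overridden `start()` that stores the cross-process queue on the instance before delegating to the base loop. A minimal sketch with a stand-in base class (the real `AsyncLoopRunner` lives in `shared.loop_runner`; its internals are assumed here):

```python
import asyncio


class LoopRunnerSketch:
    """Illustrative stand-in for shared.loop_runner.AsyncLoopRunner."""

    interval: float = 1.0

    async def start(self):
        while True:
            await self.run_step()
            await asyncio.sleep(self.interval)

    async def run_step(self):
        raise NotImplementedError


class SchedulerSketch(LoopRunnerSketch):
    scoring_queue: list | None = None

    async def start(self, scoring_queue: list):
        # Inject the shared queue before the base class starts looping.
        self.scoring_queue = scoring_queue
        return await super().start()

    async def run_step(self):
        # Peek at the oldest queued task, as AsyncModelScheduler.run_step does.
        oldest = self.scoring_queue[0] if self.scoring_queue else None
        print(f"Oldest queued item: {oldest}")
```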
22 changes: 13 additions & 9 deletions prompting/rewards/scoring.py
@@ -5,7 +5,6 @@
from loguru import logger
from pydantic import ConfigDict

from prompting import mutable_globals
from prompting.llms.model_manager import model_manager, model_scheduler
from prompting.tasks.base_task import BaseTextTask
from prompting.tasks.task_registry import TaskRegistry
@@ -33,9 +32,16 @@ class TaskScorer(AsyncLoopRunner):
is_running: bool = False
thread: threading.Thread = None
interval: int = 10
scoring_queue: list | None = None
reward_events: list | None = None

model_config = ConfigDict(arbitrary_types_allowed=True)

async def start(self, scoring_queue, reward_events):
self.scoring_queue = scoring_queue
self.reward_events = reward_events
return await super().start()

def add_to_queue(
self,
task: BaseTextTask,
@@ -45,7 +51,7 @@ def add_to_queue(
step: int,
task_id: str,
) -> None:
mutable_globals.scoring_queue.append(
self.scoring_queue.append(
ScoringConfig(
task=task,
response=response,
@@ -55,26 +61,24 @@
task_id=task_id,
)
)
logger.debug(
f"SCORING: Added to queue: {task.__class__.__name__}. Queue size: {len(mutable_globals.scoring_queue)}"
)
logger.debug(f"SCORING: Added to queue: {task.__class__.__name__}. Queue size: {len(self.scoring_queue)}")

async def run_step(self) -> RewardLoggingEvent:
await asyncio.sleep(0.1)
# Only score responses for which the model is loaded
scorable = [
scoring_config
for scoring_config in mutable_globals.scoring_queue
for scoring_config in self.scoring_queue
if (scoring_config.task.llm_model in model_manager.active_models.keys())
or (scoring_config.task.llm_model is None)
]
if len(scorable) == 0:
logger.debug("Nothing to score. Skipping scoring step.")
# Run a model_scheduler step to load a new model as there are no more tasks to be scored
if len(mutable_globals.scoring_queue) > 0:
if len(self.scoring_queue) > 0:
await model_scheduler.run_step()
return
mutable_globals.scoring_queue.remove(scorable[0])
self.scoring_queue.remove(scorable[0])
scoring_config: ScoringConfig = scorable.pop(0)

# here we generate the actual reference
@@ -94,7 +98,7 @@ async def run_step(self) -> RewardLoggingEvent:
model_id=scoring_config.task.llm_model,
task=scoring_config.task,
)
mutable_globals.reward_events.append(reward_events)
self.reward_events.append(reward_events)
logger.debug(
f"REFERENCE: {scoring_config.task.reference}\n\n||||RESPONSES: {scoring_config.response.completions}"
)
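
To make the scoring flow easier to follow: `add_to_queue` now appends to the injected `scoring_queue`, and `run_step` only takes entries whose model is currently loaded (or that need no model). A simplified illustration of that selection logic, with a made-up `ScoringConfig` reduced to two fields:

```python
from dataclasses import dataclass


@dataclass
class ScoringConfig:
    task_id: str
    llm_model: str | None


def pick_scorable(scoring_queue, active_models):
    # Mirror TaskScorer.run_step: take only entries whose model is loaded,
    # or that need no model at all; everything else stays queued.
    scorable = [c for c in scoring_queue if c.llm_model is None or c.llm_model in active_models]
    if not scorable:
        return None
    scoring_queue.remove(scorable[0])
    return scorable[0]


queue = [ScoringConfig("a", "llama"), ScoringConfig("b", None)]
print(pick_scorable(queue, active_models={"mistral"}).task_id)  # -> "b"
print(len(queue))  # -> 1, config "a" keeps waiting for its model to load
```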
17 changes: 12 additions & 5 deletions prompting/tasks/task_creation.py
@@ -5,7 +5,6 @@
from pydantic import ConfigDict

from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.mutable_globals import scoring_queue, task_queue
from prompting.tasks.task_registry import TaskRegistry
from shared.logging import ErrorLoggingEvent, ValidatorLoggingEvent
from shared.loop_runner import AsyncLoopRunner
@@ -18,14 +17,20 @@ class TaskLoop(AsyncLoopRunner):
is_running: bool = False
thread: threading.Thread = None
interval: int = 10

task_queue: list | None = []
scoring_queue: list | None = []
model_config = ConfigDict(arbitrary_types_allowed=True)

async def start(self, task_queue, scoring_queue):
self.task_queue = task_queue
self.scoring_queue = scoring_queue
await super().start()

async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
if len(task_queue) > shared_settings.TASK_QUEUE_LENGTH_THRESHOLD:
if len(self.task_queue) > shared_settings.TASK_QUEUE_LENGTH_THRESHOLD:
logger.debug("Task queue is full. Skipping task generation.")
return None
if len(scoring_queue) > shared_settings.SCORING_QUEUE_LENGTH_THRESHOLD:
if len(self.scoring_queue) > shared_settings.SCORING_QUEUE_LENGTH_THRESHOLD:
logger.debug("Scoring queue is full. Skipping task generation.")
return None
await asyncio.sleep(0.1)
@@ -55,7 +60,9 @@ async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
if not task.query:
logger.debug(f"Generating query for task: {task.__class__.__name__}.")
task.make_query(dataset_entry=dataset_entry)
task_queue.append(task)

logger.debug(f"Appending task: {task.__class__.__name__} to task queue.")
self.task_queue.append(task)
except Exception as ex:
logger.exception(ex)
return None