Merged
107 changes: 79 additions & 28 deletions neurons/validator.py
@@ -1,10 +1,14 @@
import asyncio
import atexit
import os
import signal
import sys
import time
from multiprocessing.managers import AcquirerProxy
from multiprocessing.synchronize import Event

import netaddr
import psutil
import requests
import torch
import torch.multiprocessing as mp
@@ -102,6 +106,7 @@ def start_api(
scoring_queue: list,
reward_events: list,
miners_dict: dict,
event_stop: Event,
):
from prompting.api.api import start_scoring_api # noqa: F401

@@ -124,7 +129,7 @@ async def start():
logger.warning(f"Failed to serve scoring api to chain: {e}")
await start_scoring_api(task_scorer, scoring_queue, reward_events, miners_dict)

while True:
while not event_stop.is_set():
await asyncio.sleep(10)

asyncio.run(start())
@@ -134,6 +139,7 @@ def start_task_sending_loop(
task_queue: list,
scoring_queue: list,
miners_dict: dict,
event_stop: Event,
):
async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
from prompting.tasks.task_sending import TaskSender
@@ -142,7 +148,8 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
task_sender = TaskSender()
asyncio.create_task(task_sender.start(task_queue, scoring_queue, miners_dict, simultaneous_loops=1))
logger.debug("Task sending loop started")
while True:

while not event_stop.is_set():
await asyncio.sleep(5)
logger.debug("Task sending loop is running")

@@ -155,13 +162,13 @@ async def spawn_loops(task_queue, scoring_queue, miners_dict: dict):
raise


def start_availability_checking_loop(miners_dict: dict):
def start_availability_checking_loop(miners_dict: dict, event_stop: Event):
async def spawn_loops(miners_dict: dict):
from prompting.miner_availability.miner_availability import availability_checking_loop

logger.info("Starting availability checking loop in validator...")
asyncio.create_task(availability_checking_loop.start(miners_dict))
while True:
while not event_stop.is_set():
await asyncio.sleep(5)
logger.debug("Availability checking loop is running")

@@ -174,13 +181,13 @@ async def spawn_loops(miners_dict: dict):
raise


def start_weight_setter_loop(reward_events):
def start_weight_setter_loop(reward_events, event_stop: Event):
async def spawn_loops(reward_events):
from prompting.weight_setting.weight_setter import weight_setter

logger.info("Starting weight setter loop in validator...")
asyncio.create_task(weight_setter.start(reward_events))
while True:
while not event_stop.is_set():
await asyncio.sleep(5)
logger.debug("Weight setter loop is running")

@@ -193,6 +200,34 @@ async def spawn_loops(reward_events):
raise


def health_check(parent_pid: int, event_stop: Event):
"""Monitor parent process and kill all child processes in case of emergency."""
step = 0
while True:
try:
if not psutil.pid_exists(parent_pid):
event_stop.set()
logger.warning("Parent process died, killing all child processes")
os.killpg(0, signal.SIGKILL)

block = settings.shared_settings.block
if block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 320 and step > 60:
event_stop.set()
last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
logger.warning(
f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
f"Staled block: {block}, Last update: {last_update_block}"
)
os.killpg(0, signal.SIGKILL)
step += 1

except Exception as e:
logger.error(f"Failed to kill process group: {e}")
finally:
sys.exit(1)
time.sleep(60)


async def main(
cache_rewards: list | None = None,
cache_scores: list | None = None,
@@ -208,6 +243,7 @@ async def main(
mp_lock = manager.Lock()
processes: list[mp.Process] = []
tasks: list[asyncio.Task] = []
event_stop = mp.Event()

model_scheduler = AsyncModelScheduler(llm_model_manager=ModelManager(), mp_lock=mp_lock, sync=True)

@@ -216,15 +252,19 @@
if settings.shared_settings.DEPLOY_SCORING_API and not settings.shared_settings.NEURON_DISABLE_SET_WEIGHTS:
# Use multiprocessing to bypass API blocking issue
api_process = mp.Process(
target=start_api, args=(scoring_queue, reward_events, miners_dict), name="APIProcess"
target=start_api,
args=(scoring_queue, reward_events, miners_dict, event_stop),
name="APIProcess",
daemon=True,
)
api_process.start()
processes.append(api_process)

availability_process = mp.Process(
target=start_availability_checking_loop,
args=(miners_dict,),
args=(miners_dict, event_stop),
name="AvailabilityProcess",
daemon=True,
)
availability_process.start()
processes.append(availability_process)
@@ -243,62 +283,73 @@

sending_task = mp.Process(
target=start_task_sending_loop,
args=(task_queue, scoring_queue, miners_dict),
args=(task_queue, scoring_queue, miners_dict, event_stop),
name="SendingTaskProcess",
daemon=True,
)
sending_task.start()
processes.append(sending_task)

weight_setter_process = mp.Process(
target=start_weight_setter_loop,
args=(reward_events,),
args=(reward_events, event_stop),
name="WeightSetterProcess",
daemon=True,
)
weight_setter_process.start()
processes.append(weight_setter_process)

GPUInfo.log_gpu_info()
health_check_process = mp.Process(
target=health_check,
args=(os.getpid(), event_stop),
name="HealthCheckProcess",
daemon=True,
)
health_check_process.start()
processes.append(health_check_process)

step = 0
GPUInfo.log_gpu_info()
while True:
await asyncio.sleep(30)
block = settings.shared_settings.block
if (
block - settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID] > 500
and step > 150
):
last_update_block = settings.shared_settings.METAGRAPH.last_update[settings.shared_settings.UID]
logger.warning(
f"Metagraph hasn't been updated for {block - last_update_block} blocks. "
f"Staled block: {block}, Last update: {last_update_block}"
)
break
step += 1

except KeyboardInterrupt:
event_stop.set()
logger.info("KeyboardInterrupt detected. Shutting down gracefully...")
except Exception as e:
logger.error(f"Main loop error: {e}")
raise
finally:
logger.warning("🚨 Force‑killing entire process‑group")
logger.warning("🚨 Force‑killing entire process‑group")

# 1. Cancel in‑process tasks so they stop touching the Manager.
for t in tasks:
t.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.sleep(5)

# 2. Manager cleanup *first* (so its socket vanishes).
manager.shutdown()

# 3. Sledgehammer.
if os.name == "posix":
try:
os.killpg(0, signal.SIGKILL)
else:
logger.error(f"Unsupported OS: {os.name}")
except Exception as e:
logger.error(f"Failed to kill process group: {e}")
sys.exit(1)


def kill_process_group():
try:
os.killpg(os.getpgid(0), signal.SIGKILL)
except Exception as e:
logger.error(f"Failed to kill process group: {e}")


# The main function parses the configuration and runs the validator.
if __name__ == "__main__":
try:
os.setpgrp()
atexit.register(kill_process_group)
except BaseException:
logger.warning("Failed to set process group; emergency termination may not work.")
asyncio.run(main())
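The validator changes above all serve one pattern: every child process receives the shared `event_stop` event and polls it instead of spinning forever, while a watchdog process hard-kills the whole process group if the parent dies or the metagraph goes stale. A minimal, self-contained sketch of that pattern, assuming a POSIX host (for `os.setpgrp`/`os.killpg`) and with illustrative worker bodies and sleep intervals rather than the validator's real loops:

```python
import multiprocessing as mp
import os
import signal
import time
from multiprocessing.synchronize import Event

import psutil


def worker_loop(event_stop: Event):
    # Child loops poll the shared event instead of `while True`,
    # so a single event_stop.set() from any process stops them all.
    while not event_stop.is_set():
        time.sleep(5)


def watchdog(parent_pid: int, event_stop: Event):
    # If the parent disappears, flag shutdown and hard-kill the process group.
    while not event_stop.is_set():
        if not psutil.pid_exists(parent_pid):
            event_stop.set()
            os.killpg(0, signal.SIGKILL)
        time.sleep(1)


if __name__ == "__main__":
    os.setpgrp()  # parent and children share one process group (POSIX only)
    event_stop = mp.Event()
    children = [
        mp.Process(target=worker_loop, args=(event_stop,), name="Worker", daemon=True),
        mp.Process(target=watchdog, args=(os.getpid(), event_stop), name="Watchdog", daemon=True),
    ]
    for p in children:
        p.start()
    try:
        time.sleep(10)  # stand-in for the validator's main loop
    finally:
        event_stop.set()  # cooperative shutdown first; the process-group kill is the fallback
        for p in children:
            p.join(timeout=5)
```

Marking the children `daemon=True` means they cannot outlive the parent even if the cooperative shutdown path is skipped, the same belt-and-braces approach the diff takes with `atexit.register(kill_process_group)`.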
83 changes: 31 additions & 52 deletions prompting/llms/model_manager.py
@@ -49,15 +49,12 @@ async def __aexit__(self, exc_type, exc, tb):

class ModelManager(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
always_active_models: list[ModelConfig] = []
total_ram: float = settings.shared_settings.LLM_MODEL_RAM
active_models: dict[ModelConfig, ReproducibleVLLM] = {}
loading_tasks: dict[ModelConfig, asyncio.Future] = {}
used_ram: float = 0.0
lock: ClassVar[AsyncRLock] = AsyncRLock()

async def load_always_active_models(self):
for model_config in self.always_active_models:
await self.load_model(model_config=model_config)
# lock: ClassVar[AsyncRLock] = asyncio.Lock()

async def load_model(self, model_config: ModelConfig, force: bool = True) -> ReproducibleVLLM:
"""Load model into GPU.
@@ -69,56 +66,40 @@ async def load_model(self, model_config: ModelConfig, force: bool = True) -> Rep
force: If enabled, will unload all other models.
"""
async with self.lock:
if model_config in self.active_models.keys():
# Copy active models, since they will be modified in the loop.
active_models = set(self.active_models.keys())

if model_config in active_models:
logger.debug(f"Model {model_config.llm_model_id} is already loaded.")
return self.active_models[model_config]

if force:
logger.debug(f"Forcing model {model_config.llm_model_id} to load.")
for active_model in list(self.active_models.keys()):
if active_model in self.always_active_models:
continue
for active_model in active_models:
logger.debug(f"Unloading {active_model.llm_model_id} to make room for {model_config.llm_model_id}")

await self._unload_model(active_model)
await self.cleanup()

retries_max = 1
retry_counter = 0
retry_delay = 15
while True:
try:
GPUInfo.log_gpu_info()
model_class = model_factory(model_config.llm_model_id)
model = model_class(
model_id=model_config.llm_model_id,
device=settings.shared_settings.NEURON_DEVICE,
sampling_params=settings.shared_settings.SAMPLING_PARAMS,
)
self.used_ram += model_config.min_ram
logger.info(
f"Model {model_config.llm_model_id} has been successfully loaded. "
f"Approx. used VRAM: {self.used_ram:.0f}GB"
)
self.active_models[model_config] = model
await asyncio.sleep(1.0)
return model
except BaseException as e:
if retry_counter > retries_max:
logger.error(f"Failed to load model after {retries_max} retries. Terminating process")
await self.cleanup()
# In case of VRAM leak, raise an exception to terminate the process.
raise MemoryError

retry_counter += 1
retry_delay += retry_counter
await self.cleanup()
logger.error(
f"Failed to load model {model_config.llm_model_id}. Retrying in {retry_delay} seconds. "
f"Error: {str(e)}"
)
logger.debug(f"Current active models: {self.active_models}")
await asyncio.sleep(retry_delay)
try:
GPUInfo.log_gpu_info()
model_class = model_factory(model_config.llm_model_id)
model = model_class(
model_id=model_config.llm_model_id,
device=settings.shared_settings.NEURON_DEVICE,
sampling_params=settings.shared_settings.SAMPLING_PARAMS,
)
self.active_models[model_config] = model
self.used_ram += model_config.min_ram
logger.info(
f"Model {model_config.llm_model_id} has been successfully loaded. "
f"Approx. used VRAM: {self.used_ram:.0f}GB"
)
await asyncio.sleep(1.0)
return model
except BaseException as e:
await self.cleanup()
# In case of VRAM leak, raise an exception to terminate the process.
raise MemoryError(f"Failed to load model {model_config.llm_model_id}: {e}")

async def _cleanup_model(self, model_instance: ReproducibleVLLM, cpu_offload: bool = False):
"""Free VRAM from given model."""
@@ -144,12 +125,10 @@ async def _unload_model(self, model_config: ModelConfig):
return

try:
model_instance = self.active_models.pop(model_config)

# Record initial memory state for debugging.
initial_free_memory = GPUInfo.free_memory
logger.debug(f"Initial free GPU memory before unloading: {initial_free_memory} GB")

# async with self.rlock:
model_instance = self.active_models.pop(model_config)
await self._cleanup_model(model_instance, cpu_offload=False)
await self.cleanup()

@@ -167,13 +146,13 @@ async def _unload_model(self, model_config: ModelConfig):
async def get_model(self, llm_model: ModelConfig | str) -> ReproducibleVLLM:
async with self.lock:
if not llm_model:
llm_model = list(self.active_models.keys())[0] if self.active_models else ModelZoo.get_random()
llm_model = next(iter(self.active_models.keys())) if self.active_models else ModelZoo.get_random()
if isinstance(llm_model, str):
llm_model = ModelZoo.get_model_by_id(llm_model)
if llm_model in self.active_models:
return self.active_models[llm_model]

return await self.load_model(llm_model, force=True)
return await self.load_model(llm_model)

async def generate(
self,
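Taken together, the new `load_model` path is: take the class-level lock, return early if the model is already active, evict every other active model when `force=True`, then attempt the load exactly once and raise `MemoryError` instead of retrying. A stripped-down sketch of that control flow, with a hypothetical `DummyModel` standing in for `ReproducibleVLLM` and plain string keys in place of `ModelConfig` (so none of this is the module's real API):

```python
import asyncio


class DummyModel:
    """Stand-in for ReproducibleVLLM; construction plays the role of the expensive load."""

    def __init__(self, model_id: str):
        self.model_id = model_id


class TinyModelManager:
    def __init__(self):
        self.active_models: dict[str, DummyModel] = {}
        self.lock = asyncio.Lock()

    async def load_model(self, model_id: str, force: bool = True) -> DummyModel:
        async with self.lock:
            if model_id in self.active_models:
                return self.active_models[model_id]  # already loaded
            if force:
                # Evict everything else before loading; the real manager also frees VRAM here.
                for active in list(self.active_models):
                    self.active_models.pop(active)
            try:
                model = DummyModel(model_id)  # single attempt, no retry loop
                self.active_models[model_id] = model
                return model
            except BaseException as e:
                # A failed load is treated as fatal so the process can be restarted cleanly.
                raise MemoryError(f"Failed to load model {model_id}: {e}")


async def demo():
    manager = TinyModelManager()
    await manager.load_model("model-a")
    await manager.load_model("model-b")  # evicts model-a first
    print(list(manager.active_models))   # ['model-b']


asyncio.run(demo())
```

Dropping the retry loop and the `always_active_models` allow-list makes the failure mode explicit: if a model cannot be loaded once, the manager assumes a VRAM leak and lets the process die rather than limp along.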
10 changes: 3 additions & 7 deletions prompting/llms/vllm_llm.py
@@ -191,21 +191,17 @@ def unload_model(self):
if hasattr(self.model, "llm_engine") and hasattr(self.model.llm_engine, "driver_worker"):
del self.model.llm_engine.driver_worker
if hasattr(self.model, "model"):
self.model = None
del self.model
if hasattr(self.model, "tokenizer"):
self.tokenizer = None
del self.tokenizer

gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()

if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Successfully deleted the LLM pipeline and freed GPU memory")

except Exception as e:
except BaseException as e:
logger.error(f"An error occurred during model unloading: {e}")
gc.collect()
if torch.cuda.is_available():
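The `unload_model` change deletes the attribute references outright (instead of assigning `None`) and widens the error handler to `BaseException`. A generic sketch of the overall teardown idea (drop references, collect garbage, then release CUDA and distributed resources), using a hypothetical `release_gpu_memory` helper and `pipeline` object rather than the module's actual API:

```python
import gc

import torch


def release_gpu_memory(pipeline) -> None:
    """Illustrative teardown: drop Python references first, then reclaim GPU memory."""
    try:
        if hasattr(pipeline, "model"):
            del pipeline.model  # drop the reference instead of setting it to None
        if hasattr(pipeline, "tokenizer"):
            del pipeline.tokenizer
    except BaseException as e:  # unloading must never crash the caller
        print(f"An error occurred during model unloading: {e}")
    finally:
        gc.collect()
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached CUDA blocks back to the driver
```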