From 9ac28bc4ffe56b3758c13706032b7d2fbb88fe34 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sat, 15 Mar 2025 17:06:06 +0200 Subject: [PATCH] snmp: fix async handling Have the callback fire an event to avoid waiting too long Signed-off-by: Benny Zlotnik --- .../jumpstarter_driver_snmp/driver.py | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py b/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py index 72f42dcd2..dac94f642 100644 --- a/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py +++ b/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py @@ -2,6 +2,7 @@ import socket from dataclasses import dataclass, field from enum import Enum, IntEnum +from typing import Any, Dict, Tuple from pysnmp.carrier.asyncio.dgram import udp from pysnmp.entity import config, engine @@ -116,9 +117,7 @@ def _setup_snmp(self): def client(cls) -> str: return "jumpstarter_driver_snmp.client.SNMPServerClient" - def _snmp_set(self, state: PowerState): - result = {"success": False, "error": None} - + def _create_snmp_callback(self, result: Dict[str, Any], response_received: asyncio.Event): def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorIndex, varBinds, cbCtx): self.logger.debug(f"Callback {errorIndication} {errorStatus} {errorIndex} {varBinds}") if errorIndication: @@ -135,20 +134,35 @@ def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorI for oid, val in varBinds: self.logger.debug(f"{oid.prettyPrint()} = {val.prettyPrint()}") self.logger.debug(f"SNMP set result: {result}") + response_received.set() + return callback + + def _setup_event_loop(self) -> Tuple[asyncio.AbstractEventLoop, bool]: try: - self.logger.info(f"Sending power {state.name} command to {self.host}") - created_loop = False + loop = asyncio.get_running_loop() + return loop, False + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop, True + + async def _run_snmp_dispatcher(self, snmp_engine: engine.SnmpEngine, response_received: asyncio.Event): + snmp_engine.open_dispatcher() + await response_received.wait() + snmp_engine.close_dispatcher() - try: - asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - created_loop = True + def _snmp_set(self, state: PowerState): + result = {"success": False, "error": None} + response_received = asyncio.Event() + loop = None + created_loop = False + try: + self.logger.info(f"Sending power {state.name} command to {self.host}") + loop, created_loop = self._setup_event_loop() snmp_engine = self._setup_snmp() - + callback = self._create_snmp_callback(result, response_received) cmdgen.SetCommandGenerator().send_varbinds( snmp_engine, "my-target", @@ -158,11 +172,15 @@ def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorI callback, ) - snmp_engine.open_dispatcher(self.timeout) - snmp_engine.close_dispatcher() + dispatcher_task = loop.create_task(self._run_snmp_dispatcher(snmp_engine, response_received)) + try: + loop.run_until_complete(asyncio.wait_for(dispatcher_task, self.timeout)) + except asyncio.TimeoutError: + self.logger.warning(f"SNMP operation timed out after {self.timeout} seconds") + result["error"] = "SNMP operation timed out" if not result["success"]: - raise SNMPError(result["error"]) + raise SNMPError(result["error"] or "Unknown SNMP error") return f"Power {state.name} command sent successfully" @@ -171,7 +189,7 @@ def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorI self.logger.error(error_msg) raise SNMPError(error_msg) from e finally: - if created_loop: + if created_loop and loop: loop.close() @export