diff --git a/octoprint_nanny/clients/rest.py b/octoprint_nanny/clients/rest.py index 8058684e..d5dcb73c 100644 --- a/octoprint_nanny/clients/rest.py +++ b/octoprint_nanny/clients/rest.py @@ -3,7 +3,7 @@ import urllib import hashlib import backoff - +import json import beeline from octoprint.events import Events @@ -21,6 +21,7 @@ from print_nanny_client.models.octo_print_device_request import ( OctoPrintDeviceRequest, ) +from octoprint_nanny.utils.encoder import NumpyEncoder logger = logging.getLogger("octoprint.plugins.octoprint_nanny.clients.rest") @@ -294,3 +295,25 @@ async def update_or_create_printer_profile( request ) return printer_profile + + @beeline.traced("RestAPIClient.update_or_create_device_calibration") + @backoff.on_exception( + backoff.expo, + aiohttp.ClientConnectionError, + logger=logger, + max_time=MAX_BACKOFF_TIME, + ) + async def update_or_create_device_calibration( + self, octoprint_device_id, coordinates, mask + ): + mask = json.dumps(mask, cls=NumpyEncoder) + async with AsyncApiClient(self._api_config) as api_client: + api_instance = print_nanny_client.MlOpsApi(api_client=api_client) + + request = print_nanny_client.DeviceCalibrationRequest( + octoprint_device=octoprint_device_id, coordinates=coordinates, mask=mask + ) + device_calibration = await api_instance.device_calibration_update_or_create( + request + ) + return device_calibration diff --git a/octoprint_nanny/manager.py b/octoprint_nanny/manager.py index 2cbacde4..3a5b8dc7 100644 --- a/octoprint_nanny/manager.py +++ b/octoprint_nanny/manager.py @@ -111,8 +111,9 @@ def _register_plugin_event_handlers(self): """ callbacks = { - Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_START: self.monitoring_manager.start, - Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_STOP: self.monitoring_manager.stop, + Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_START: self.monitoring_manager.start, + Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_STOP: self.monitoring_manager.stop, + Events.PLUGIN_OCTOPRINT_NANNY_CALIBRATION_UPDATE: self.on_calibration_update, } self.mqtt_manager.publisher_worker.register_callbacks(callbacks) logger.info(f"Registered callbacks {callbacks} on publisher worker") @@ -167,6 +168,9 @@ async def shutdown(self): await self.mqtt_manager.stop() self._honeycomb_tracer.on_shutdown() + ## + # Event handlers + ## @beeline.traced("WorkerManager.on_print_start") async def on_print_start(self, event_type, event_data, **kwargs): logger.info(f"on_print_start called for {event_type} with data {event_data}") @@ -207,3 +211,22 @@ async def on_print_start(self, event_type, event_data, **kwargs): if self.plugin.get_setting("auto_start"): logger.info("Print Nanny monitoring is set to auto-start") self.monitoring_manager.start() + + async def on_calibration_update(self, event_type, event_data, **kwargs): + logger.info( + f"{self.__class__}.on_calibration_update called for event_type={event_type} event_data={event_data}" + ) + await self.apply_monitoring_settings() + device_calibration = ( + await self.plugin.settings.rest_client.update_or_create_device_calibration( + self.plugin.settings.device_id, + { + "x0": self.plugin.get_setting("calibrate_x0"), + "x1": self.plugin.get_setting("calibrate_x1"), + "y0": self.plugin.get_setting("calibrate_y0"), + "y1": self.plugin.get_setting("calibrate_y1"), + }, + self.plugin.settings.calibration, + ) + ) + logger.info(f"Device calibration upsert succeeded {device_calibration}") diff --git a/octoprint_nanny/plugins.py b/octoprint_nanny/plugins.py index 7834541b..33ed5d83 100644 --- a/octoprint_nanny/plugins.py +++ b/octoprint_nanny/plugins.py @@ -465,13 +465,13 @@ def start_predict(self): Events.PLUGIN_OCTOPRINT_NANNY_PREDICT_DONE, payload={"image": base64.b64encode(res.content)}, ) - self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_START) + self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_START) return flask.json.jsonify({"ok": 1}) @beeline.traced(name="OctoPrintNannyPlugin.stop_predict") @octoprint.plugin.BlueprintPlugin.route("/stopPredict", methods=["POST"]) def stop_predict(self): - self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_STOP) + self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_STOP) return flask.json.jsonify({"ok": 1}) @beeline.traced(name="OctoPrintNannyPlugin.register_device") @@ -541,18 +541,24 @@ def test_auth_token(self): def register_custom_events(self): return [ + # events from octoprint plugin + "calibration_update", "predict_done", - "monitoring_start", - "monitoring_stop", - "snapshot", "device_register_start", "device_register_done", "device_register_failed", "printer_profile_sync_start", "printer_profile_sync_done", "printer_profile_sync_failed", - "worker_stop", - "worker_start", + # events from RemoteControlCommand.CommandChoices (webapp) + "rc_print_start", + "rc_print_stop", + "rc_print_pause", + "rc_print_resume", + "rc_snapshot", + "rc_move_nozzle", + "rc_monitoring_start", + "rc_monitoring_stop", ] @beeline.traced(name="OctoPrintNannyPlugin.on_after_startup") @@ -641,10 +647,6 @@ def on_settings_save(self, data): ["mqtt_bridge_certificate_url"] ) - if prev_mqtt_bridge_certificate_url != new_mqtt_bridge_certificate_url: - asyncio.run_coroutine_threadsafe( - self._download_root_certificates(), self.worker_manager.loop - ) if ( prev_monitoring_fpm != new_monitoring_fpm or prev_calibration != new_calibration @@ -652,8 +654,10 @@ def on_settings_save(self, data): logger.info( "Change in frames per minute or calibration detected, applying new settings" ) - self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_PREDICT_OFFLINE) - self.worker_manager.apply_monitoring_settings() + self._event_bus.fire( + Events.PLUGIN_OCTOPRINT_NANNY_CALIBRATION_UPDATE, + payload={"calibration"}, + ) if prev_auth_token != new_auth_token: logger.info("Change in auth detected, applying new settings") diff --git a/octoprint_nanny/predictor.py b/octoprint_nanny/predictor.py index 9e865703..a3ca9508 100644 --- a/octoprint_nanny/predictor.py +++ b/octoprint_nanny/predictor.py @@ -19,7 +19,8 @@ import sys import beeline -from PIL import Image as PImage + +import PIL import requests import tflite_runtime.interpreter as tflite @@ -49,7 +50,7 @@ class Prediction(TypedDict): detection_scores: np.ndarray detection_boxes: np.ndarray detection_classes: np.ndarray - viz: Optional[PImage.Image] + viz: Optional[PIL.Image.Image] class ThreadLocalPredictor(threading.local): @@ -92,6 +93,8 @@ def __init__( with open(self.metadata_path) as f: self.metadata = json.load(f) + self.input_shape = self.metadata["inputShape"] + with open(self.label_path) as f: self.category_index = [l.strip() for l in f.readlines()] self.category_index = { @@ -102,18 +105,23 @@ def __init__( self.calibration = calibration def load_image(self, bytes): - return PImage.open(bytes) + return PIL.Image.open(bytes) def load_file(self, filepath: str): - return PImage.open(filepath) + return PIL.Image.open(filepath) - def preprocess(self, image: PImage): + def preprocess(self, image: PIL.Image): + # resize to input shape provided by model metadata.json + _, target_height, target_width, _ = self.input_shape + image = image.resize((target_width, target_height), resample=PIL.Image.BILINEAR) image = np.asarray(image) + # expand dimensions to batch size = 1 + image = np.expand_dims(image, 0) return image def write_image(self, outfile: str, image_np: np.ndarray): - img = PImage.fromarray(image_np) + img = PIL.Image.fromarray(image_np) img.save(outfile) def percent_intersection( @@ -161,7 +169,7 @@ def percent_intersection( return aou - def postprocess(self, image: PImage, prediction: Prediction) -> np.array: + def postprocess(self, image: PIL.Image, prediction: Prediction) -> np.array: image_np = np.asarray(image).copy() height, width, _ = image_np.shape @@ -204,7 +212,7 @@ def postprocess(self, image: PImage, prediction: Prediction) -> np.array: ) return viz - def predict(self, image: PImage) -> Prediction: + def predict(self, image: PIL.Image) -> Prediction: tensor = self.preprocess(image) self.tflite_interpreter.set_tensor(self.input_details[0]["index"], tensor) @@ -242,7 +250,7 @@ def _get_predict_bytes(msg): prediction = predictor.predict(image) viz_np = predictor.postprocess(image, prediction) - viz_image = PImage.fromarray(viz_np, "RGB") + viz_image = PIL.Image.fromarray(viz_np, "RGB") viz_buffer = io.BytesIO() viz_buffer.name = "annotated_image.jpg" viz_image.save(viz_buffer, format="JPEG") @@ -363,17 +371,9 @@ def _create_msgs(self, msg, viz_buffer, prediction): mqtt_msg = msg.copy() # publish bounding box prediction to mqtt telemetry topic # del mqtt_msg["original_image"] - - calibration = ( - self._calibration - if self._calibration is None - else self._calibration.get("coords") - ) mqtt_msg.update( { - "calibration": calibration, "event_type": "bounding_box_predict", - "fpm": self._fpm, } ) mqtt_msg.update(prediction) diff --git a/octoprint_nanny/workers/monitoring.py b/octoprint_nanny/workers/monitoring.py index 00ed75d1..44023120 100644 --- a/octoprint_nanny/workers/monitoring.py +++ b/octoprint_nanny/workers/monitoring.py @@ -64,7 +64,7 @@ def _reset(self): logger.info(f"Finished resetting MonitoringManager") @beeline.traced("MonitoringManager.start") - async def start(self): + async def start(self, **kwargs): self._reset() @@ -79,7 +79,7 @@ async def start(self): ) @beeline.traced("MonitoringManager.stop") - async def stop(self): + async def stop(self, **kwargs): self._drain() await self.plugin.settings.rest_client.update_octoprint_device( self.plugin.settings.device_id, monitoring_active=False diff --git a/octoprint_nanny/workers/mqtt.py b/octoprint_nanny/workers/mqtt.py index cd89074c..0d1dbe27 100644 --- a/octoprint_nanny/workers/mqtt.py +++ b/octoprint_nanny/workers/mqtt.py @@ -231,21 +231,22 @@ async def _loop(self): tracked = await self.plugin.settings.event_in_tracked_telemetry(event_type) if tracked: await self._publish_octoprint_event_telemetry(event) - else: - if event_type not in self.MUTED_EVENTS: - logger.warning(f"Discarding {event_type} with payload {event}") - return handler_fns = self._callbacks.get(event_type) if handler_fns is None: - logger.info(f"No {self.__class__} handler registered for {event_type}") + if event_type not in self.MUTED_EVENTS: + logger.info( + f"No {self.__class__} handler registered for {event_type}" + ) return for handler_fn in handler_fns: - if handler_fn: - if inspect.isawaitable(handler_fn): - await handler_fn(**event) - else: - handler_fn(**event) + logger.debug(f"MQTTPublisherWorker calling handler_fn={handler_fn}") + if inspect.isawaitable(handler_fn) or inspect.iscoroutinefunction( + handler_fn + ): + await handler_fn(**event) + else: + handler_fn(**event) except API_CLIENT_EXCEPTIONS as e: logger.error(f"REST client raised exception {e}", exc_info=True)