Skip to content

Commit

Permalink
Remove calibration from telemetry messages (#124)
Browse files Browse the repository at this point in the history
* add rc_ prefix to event strings originating from webapp RemoteControlCommand.CommandChoices enum

* persist calibration on settings save

* remove calibration from every inference message
  • Loading branch information
leigh-johnson committed Feb 14, 2021
1 parent 33bbd1b commit 5b3397f
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 45 deletions.
25 changes: 24 additions & 1 deletion octoprint_nanny/clients/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import urllib
import hashlib
import backoff

import json
import beeline

from octoprint.events import Events
Expand All @@ -21,6 +21,7 @@
from print_nanny_client.models.octo_print_device_request import (
OctoPrintDeviceRequest,
)
from octoprint_nanny.utils.encoder import NumpyEncoder


logger = logging.getLogger("octoprint.plugins.octoprint_nanny.clients.rest")
Expand Down Expand Up @@ -294,3 +295,25 @@ async def update_or_create_printer_profile(
request
)
return printer_profile

@beeline.traced("RestAPIClient.update_or_create_device_calibration")
@backoff.on_exception(
backoff.expo,
aiohttp.ClientConnectionError,
logger=logger,
max_time=MAX_BACKOFF_TIME,
)
async def update_or_create_device_calibration(
self, octoprint_device_id, coordinates, mask
):
mask = json.dumps(mask, cls=NumpyEncoder)
async with AsyncApiClient(self._api_config) as api_client:
api_instance = print_nanny_client.MlOpsApi(api_client=api_client)

request = print_nanny_client.DeviceCalibrationRequest(
octoprint_device=octoprint_device_id, coordinates=coordinates, mask=mask
)
device_calibration = await api_instance.device_calibration_update_or_create(
request
)
return device_calibration
27 changes: 25 additions & 2 deletions octoprint_nanny/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def _register_plugin_event_handlers(self):
"""

callbacks = {
Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_START: self.monitoring_manager.start,
Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_STOP: self.monitoring_manager.stop,
Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_START: self.monitoring_manager.start,
Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_STOP: self.monitoring_manager.stop,
Events.PLUGIN_OCTOPRINT_NANNY_CALIBRATION_UPDATE: self.on_calibration_update,
}
self.mqtt_manager.publisher_worker.register_callbacks(callbacks)
logger.info(f"Registered callbacks {callbacks} on publisher worker")
Expand Down Expand Up @@ -167,6 +168,9 @@ async def shutdown(self):
await self.mqtt_manager.stop()
self._honeycomb_tracer.on_shutdown()

##
# Event handlers
##
@beeline.traced("WorkerManager.on_print_start")
async def on_print_start(self, event_type, event_data, **kwargs):
logger.info(f"on_print_start called for {event_type} with data {event_data}")
Expand Down Expand Up @@ -207,3 +211,22 @@ async def on_print_start(self, event_type, event_data, **kwargs):
if self.plugin.get_setting("auto_start"):
logger.info("Print Nanny monitoring is set to auto-start")
self.monitoring_manager.start()

async def on_calibration_update(self, event_type, event_data, **kwargs):
logger.info(
f"{self.__class__}.on_calibration_update called for event_type={event_type} event_data={event_data}"
)
await self.apply_monitoring_settings()
device_calibration = (
await self.plugin.settings.rest_client.update_or_create_device_calibration(
self.plugin.settings.device_id,
{
"x0": self.plugin.get_setting("calibrate_x0"),
"x1": self.plugin.get_setting("calibrate_x1"),
"y0": self.plugin.get_setting("calibrate_y0"),
"y1": self.plugin.get_setting("calibrate_y1"),
},
self.plugin.settings.calibration,
)
)
logger.info(f"Device calibration upsert succeeded {device_calibration}")
30 changes: 17 additions & 13 deletions octoprint_nanny/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,13 @@ def start_predict(self):
Events.PLUGIN_OCTOPRINT_NANNY_PREDICT_DONE,
payload={"image": base64.b64encode(res.content)},
)
self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_START)
self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_START)
return flask.json.jsonify({"ok": 1})

@beeline.traced(name="OctoPrintNannyPlugin.stop_predict")
@octoprint.plugin.BlueprintPlugin.route("/stopPredict", methods=["POST"])
def stop_predict(self):
self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_MONITORING_STOP)
self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_RC_MONITORING_STOP)
return flask.json.jsonify({"ok": 1})

@beeline.traced(name="OctoPrintNannyPlugin.register_device")
Expand Down Expand Up @@ -541,18 +541,24 @@ def test_auth_token(self):

def register_custom_events(self):
return [
# events from octoprint plugin
"calibration_update",
"predict_done",
"monitoring_start",
"monitoring_stop",
"snapshot",
"device_register_start",
"device_register_done",
"device_register_failed",
"printer_profile_sync_start",
"printer_profile_sync_done",
"printer_profile_sync_failed",
"worker_stop",
"worker_start",
# events from RemoteControlCommand.CommandChoices (webapp)
"rc_print_start",
"rc_print_stop",
"rc_print_pause",
"rc_print_resume",
"rc_snapshot",
"rc_move_nozzle",
"rc_monitoring_start",
"rc_monitoring_stop",
]

@beeline.traced(name="OctoPrintNannyPlugin.on_after_startup")
Expand Down Expand Up @@ -641,19 +647,17 @@ def on_settings_save(self, data):
["mqtt_bridge_certificate_url"]
)

if prev_mqtt_bridge_certificate_url != new_mqtt_bridge_certificate_url:
asyncio.run_coroutine_threadsafe(
self._download_root_certificates(), self.worker_manager.loop
)
if (
prev_monitoring_fpm != new_monitoring_fpm
or prev_calibration != new_calibration
):
logger.info(
"Change in frames per minute or calibration detected, applying new settings"
)
self._event_bus.fire(Events.PLUGIN_OCTOPRINT_NANNY_PREDICT_OFFLINE)
self.worker_manager.apply_monitoring_settings()
self._event_bus.fire(
Events.PLUGIN_OCTOPRINT_NANNY_CALIBRATION_UPDATE,
payload={"calibration"},
)

if prev_auth_token != new_auth_token:
logger.info("Change in auth detected, applying new settings")
Expand Down
34 changes: 17 additions & 17 deletions octoprint_nanny/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import sys

import beeline
from PIL import Image as PImage

import PIL
import requests
import tflite_runtime.interpreter as tflite

Expand Down Expand Up @@ -49,7 +50,7 @@ class Prediction(TypedDict):
detection_scores: np.ndarray
detection_boxes: np.ndarray
detection_classes: np.ndarray
viz: Optional[PImage.Image]
viz: Optional[PIL.Image.Image]


class ThreadLocalPredictor(threading.local):
Expand Down Expand Up @@ -92,6 +93,8 @@ def __init__(
with open(self.metadata_path) as f:
self.metadata = json.load(f)

self.input_shape = self.metadata["inputShape"]

with open(self.label_path) as f:
self.category_index = [l.strip() for l in f.readlines()]
self.category_index = {
Expand All @@ -102,18 +105,23 @@ def __init__(
self.calibration = calibration

def load_image(self, bytes):
return PImage.open(bytes)
return PIL.Image.open(bytes)

def load_file(self, filepath: str):
return PImage.open(filepath)
return PIL.Image.open(filepath)

def preprocess(self, image: PImage):
def preprocess(self, image: PIL.Image):
# resize to input shape provided by model metadata.json
_, target_height, target_width, _ = self.input_shape
image = image.resize((target_width, target_height), resample=PIL.Image.BILINEAR)
image = np.asarray(image)
# expand dimensions to batch size = 1
image = np.expand_dims(image, 0)
return image

def write_image(self, outfile: str, image_np: np.ndarray):

img = PImage.fromarray(image_np)
img = PIL.Image.fromarray(image_np)
img.save(outfile)

def percent_intersection(
Expand Down Expand Up @@ -161,7 +169,7 @@ def percent_intersection(

return aou

def postprocess(self, image: PImage, prediction: Prediction) -> np.array:
def postprocess(self, image: PIL.Image, prediction: Prediction) -> np.array:

image_np = np.asarray(image).copy()
height, width, _ = image_np.shape
Expand Down Expand Up @@ -204,7 +212,7 @@ def postprocess(self, image: PImage, prediction: Prediction) -> np.array:
)
return viz

def predict(self, image: PImage) -> Prediction:
def predict(self, image: PIL.Image) -> Prediction:
tensor = self.preprocess(image)

self.tflite_interpreter.set_tensor(self.input_details[0]["index"], tensor)
Expand Down Expand Up @@ -242,7 +250,7 @@ def _get_predict_bytes(msg):
prediction = predictor.predict(image)

viz_np = predictor.postprocess(image, prediction)
viz_image = PImage.fromarray(viz_np, "RGB")
viz_image = PIL.Image.fromarray(viz_np, "RGB")
viz_buffer = io.BytesIO()
viz_buffer.name = "annotated_image.jpg"
viz_image.save(viz_buffer, format="JPEG")
Expand Down Expand Up @@ -363,17 +371,9 @@ def _create_msgs(self, msg, viz_buffer, prediction):
mqtt_msg = msg.copy()
# publish bounding box prediction to mqtt telemetry topic
# del mqtt_msg["original_image"]

calibration = (
self._calibration
if self._calibration is None
else self._calibration.get("coords")
)
mqtt_msg.update(
{
"calibration": calibration,
"event_type": "bounding_box_predict",
"fpm": self._fpm,
}
)
mqtt_msg.update(prediction)
Expand Down
4 changes: 2 additions & 2 deletions octoprint_nanny/workers/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _reset(self):
logger.info(f"Finished resetting MonitoringManager")

@beeline.traced("MonitoringManager.start")
async def start(self):
async def start(self, **kwargs):

self._reset()

Expand All @@ -79,7 +79,7 @@ async def start(self):
)

@beeline.traced("MonitoringManager.stop")
async def stop(self):
async def stop(self, **kwargs):
self._drain()
await self.plugin.settings.rest_client.update_octoprint_device(
self.plugin.settings.device_id, monitoring_active=False
Expand Down
21 changes: 11 additions & 10 deletions octoprint_nanny/workers/mqtt.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,22 @@ async def _loop(self):
tracked = await self.plugin.settings.event_in_tracked_telemetry(event_type)
if tracked:
await self._publish_octoprint_event_telemetry(event)
else:
if event_type not in self.MUTED_EVENTS:
logger.warning(f"Discarding {event_type} with payload {event}")
return

handler_fns = self._callbacks.get(event_type)
if handler_fns is None:
logger.info(f"No {self.__class__} handler registered for {event_type}")
if event_type not in self.MUTED_EVENTS:
logger.info(
f"No {self.__class__} handler registered for {event_type}"
)
return
for handler_fn in handler_fns:
if handler_fn:
if inspect.isawaitable(handler_fn):
await handler_fn(**event)
else:
handler_fn(**event)
logger.debug(f"MQTTPublisherWorker calling handler_fn={handler_fn}")
if inspect.isawaitable(handler_fn) or inspect.iscoroutinefunction(
handler_fn
):
await handler_fn(**event)
else:
handler_fn(**event)
except API_CLIENT_EXCEPTIONS as e:
logger.error(f"REST client raised exception {e}", exc_info=True)

Expand Down

0 comments on commit 5b3397f

Please sign in to comment.