Skip to content

Commit

Permalink
coreml: yolo-nas
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed May 11, 2024
1 parent 01dd480 commit bb1e0ac
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 5 deletions.
4 changes: 2 additions & 2 deletions plugins/coreml/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion plugins/coreml/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
"devDependencies": {
"@scrypted/sdk": "file:../../sdk"
},
"version": "0.1.49"
"version": "0.1.50"
}
10 changes: 9 additions & 1 deletion plugins/coreml/src/coreml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import re
from typing import Any, List, Tuple
import numpy as np

import coremltools as ct
import scrypted_sdk
Expand All @@ -26,6 +27,7 @@

availableModels = [
"Default",
"scrypted_yolo_nas_s_320",
"scrypted_yolov9c_320",
"scrypted_yolov9c",
"scrypted_yolov6n_320",
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(self, nativeId: str | None = None):
self.storage.setItem("model", "Default")
model = "scrypted_yolov9c_320"
self.yolo = "yolo" in model
self.scrypted_yolo_nas = "scrypted_yolo_nas" in model
self.scrypted_yolo = "scrypted_yolo" in model
self.scrypted_model = "scrypted" in model
model_version = "v7"
Expand Down Expand Up @@ -211,7 +214,12 @@ async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
if self.yolo:
out_dict = await self.queue_batch({self.input_name: input})

if self.scrypted_yolo:
if self.scrypted_yolo_nas:
predictions = list(out_dict.values())
objs = yolo.parse_yolo_nas(predictions)
ret = self.create_detection_result(objs, src_size, cvss)
return ret
elif self.scrypted_yolo:
results = list(out_dict.values())[0][0]
objs = yolo.parse_yolov9(results)
ret = self.create_detection_result(objs, src_size, cvss)
Expand Down
14 changes: 14 additions & 0 deletions plugins/openvino/src/common/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@

defaultThreshold = .2

def parse_yolo_nas(predictions):
objs = []
for pred_scores, pred_bboxes in zip(*predictions):
i, j = np.nonzero(pred_scores > .5)
pred_bboxes = pred_bboxes[i]
pred_cls_conf = pred_scores[i, j]
pred_cls_label = j[:]
for box, conf, label in zip(pred_bboxes, pred_cls_conf, pred_cls_label):
obj = Prediction(
int(label), conf.astype(float), Rectangle(box[0].astype(float), box[1].astype(float), box[2].astype(float), box[3].astype(float))
)
objs.append(obj)
return objs

def parse_yolov9(results, threshold = defaultThreshold, scale = None, confidence_scale = None):
objs = []
keep = np.argwhere(results[4:] > threshold)
Expand Down
1 change: 0 additions & 1 deletion plugins/tensorflow-lite/src/detect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ async def detectObjects(self, mediaObject: MediaObject, session: ObjectDetection
if mediaObject.mimeType == ScryptedMimeTypes.Image.value:
image = await scrypted_sdk.sdk.connectRPCObject(mediaObject)
else:
print('non image provided')
image = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(mediaObject, ScryptedMimeTypes.Image.value)

return await self.run_detection_image(image, session)

0 comments on commit bb1e0ac

Please sign in to comment.