diff --git a/karma.conf.js b/karma.conf.js index e7e918f95..8a791e483 100644 --- a/karma.conf.js +++ b/karma.conf.js @@ -8,6 +8,7 @@ module.exports = (config) => { files: [ 'src/index.js', `src/${config.model ? config.model : '**'}/*_test.js`, + `src/${config.model ? config.model : '**'}/**/*_test.js`, ], preprocessors: { 'src/index.js': ['webpack'], diff --git a/package-lock.json b/package-lock.json index b36c256da..bb23e7628 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1045,6 +1045,11 @@ "resolved": "https://registry.npmjs.org/@tensorflow-models/body-pix/-/body-pix-1.1.2.tgz", "integrity": "sha512-moCCTlP77v20HMg1e/Hs1LehCDLAKS32e6OUeI1MA/4HrRRO1Dq9engVCLFZUMO2+mJXdQeBdzexcFg0WQox7w==" }, + "@tensorflow-models/coco-ssd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@tensorflow-models/coco-ssd/-/coco-ssd-2.0.0.tgz", + "integrity": "sha512-JexMswea9a5k8//sdImcxdVaofejkzUzJA8cpDBOZSauBkBM34M3c1YoEUrFuJSa5sKxgMu5TwMPWbJ59A7IVA==" + }, "@tensorflow-models/knn-classifier": { "version": "1.2.1", "resolved": "https://registry.npmjs.org/@tensorflow-models/knn-classifier/-/knn-classifier-1.2.1.tgz", diff --git a/package.json b/package.json index 5896b91c7..ae04c3c3a 100644 --- a/package.json +++ b/package.json @@ -98,6 +98,7 @@ "dependencies": { "@magenta/sketch": "0.2.0", "@tensorflow-models/body-pix": "1.1.2", + "@tensorflow-models/coco-ssd": "^2.0.0", "@tensorflow-models/knn-classifier": "1.2.1", "@tensorflow-models/mobilenet": "2.0.3", "@tensorflow-models/posenet": "2.1.3", diff --git a/src/ImageClassifier/index.js b/src/ImageClassifier/index.js index a7b5848c5..418cf6ddb 100644 --- a/src/ImageClassifier/index.js +++ b/src/ImageClassifier/index.js @@ -27,7 +27,7 @@ const MODEL_OPTIONS = ['mobilenet', 'darknet', 'darknet-tiny', 'doodlenet']; class ImageClassifier { /** * Create an ImageClassifier. - * @param {modelNameOrUrl} modelNameOrUrl - The name or the URL of the model to use. Current model name options + * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options * are: 'mobilenet', 'darknet', 'darknet-tiny', and 'doodlenet'. * @param {HTMLVideoElement} video - An HTMLVideoElement. * @param {object} options - An object with options. diff --git a/src/ObjectDetector/CocoSsd/index.js b/src/ObjectDetector/CocoSsd/index.js new file mode 100644 index 000000000..6dac986f7 --- /dev/null +++ b/src/ObjectDetector/CocoSsd/index.js @@ -0,0 +1,59 @@ +// Copyright (c) 2019 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +/* + COCO-SSD Object detection + Wraps the coco-ssd model in tfjs to be used in ml5 +*/ + +import * as cocoSsd from '@tensorflow-models/coco-ssd'; +import callCallback from '../../utils/callcallback'; + +class CocoSsd { + /** + * Create CocoSsd model. Works on video and images. + * @param {function} constructorCallback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise + * that will be resolved once the model has loaded. + */ + constructor(constructorCallback) { + this.constructorCallback = constructorCallback; + this.ready = callCallback(this.loadModel(), constructorCallback); + } + + async loadModel() { + await cocoSsd.load().then(_cocoSsdModel => { + this.cocoSsdModel = _cocoSsdModel; + }); + } + + /** + * Detect objects that are in video, returns bounding box, label, and confidence scores + * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} subject - Subject of the detection. + * @param {function} callback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise + * that will be resolved once the prediction is done. + */ + async detect(subject, callback) { + await this.ready; + return this.cocoSsdModel.detect(subject).then((predictions) => { + const formattedPredictions = []; + for (let i = 0; i < predictions.length; i += 1) { + const prediction = predictions[i]; + formattedPredictions.push({ + label: prediction.class, + confidence: prediction.score, + x: prediction.bbox[0] / subject.width, + y: prediction.bbox[1] / subject.height, + w: prediction.bbox[2] / subject.width, + h: prediction.bbox[3] / subject.height, + }); + } + return callCallback(new Promise((resolve) => { + resolve(formattedPredictions); + }), callback); + }) + } +} + +export default CocoSsd; diff --git a/src/ObjectDetector/CocoSsd/index_test.js b/src/ObjectDetector/CocoSsd/index_test.js new file mode 100644 index 000000000..9340ce400 --- /dev/null +++ b/src/ObjectDetector/CocoSsd/index_test.js @@ -0,0 +1,49 @@ +// Copyright (c) 2019 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +describe('CocoSsd', () => { + let cocoSsd; + + async function getRobin() { + const img = new Image(); + img.crossOrigin = ''; + img.src = 'https://cdn.jsdelivr.net/gh/ml5js/ml5-library@development/assets/bird.jpg'; + await new Promise((resolve) => { img.onload = resolve; }); + return img; + } + + async function getImageData() { + const arr = new Uint8ClampedArray(40000); + + // Iterate through every pixel + for (let i = 0; i < arr.length; i += 4) { + arr[i + 0] = 0; // R value + arr[i + 1] = 190; // G value + arr[i + 2] = 0; // B value + arr[i + 3] = 255; // A value + } + + // Initialize a new ImageData object + const img = new ImageData(arr, 200); + return img; + } + + beforeEach(async () => { + jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000; + cocoSsd = await objectDetector('CocoSsd'); + }); + + it('detects a robin', async () => { + const robin = await getRobin(); + const detection = await cocoSsd.detect(robin); + expect(detection[0].label).toBe('bird'); + }); + + it('detects takes ImageData', async () => { + const img = await getImageData(); + const detection = await cocoSsd.detect(img); + expect(detection).toEqual([]); + }); +}); diff --git a/src/YOLO/index.js b/src/ObjectDetector/YOLO/index.js similarity index 80% rename from src/YOLO/index.js rename to src/ObjectDetector/YOLO/index.js index 263e9dd49..9d235ce75 100644 --- a/src/YOLO/index.js +++ b/src/ObjectDetector/YOLO/index.js @@ -10,11 +10,14 @@ Heavily derived from https://github.com/ModelDepot/tfjs-yolo-tiny (ModelDepot: m */ import * as tf from '@tensorflow/tfjs'; -import Video from '../utils/Video'; -import { imgToTensor } from '../utils/imageUtilities'; -import callCallback from '../utils/callcallback'; -import CLASS_NAMES from './../utils/COCO_CLASSES'; -import modelLoader from '../utils/modelLoader'; +import Video from './../../utils/Video'; +import { + imgToTensor, + isInstanceOfSupportedElement +} from "./../../utils/imageUtilities"; +import callCallback from './../../utils/callcallback'; +import CLASS_NAMES from './../../utils/COCO_CLASSES'; +import modelLoader from './../../utils/modelLoader'; import { nonMaxSuppression, @@ -34,6 +37,9 @@ const DEFAULTS = { const imageSize = 416; class YOLOBase extends Video { + /** + * @deprecated Please use ObjectDetector class instead + */ /** * @typedef {Object} options * @property {number} filterBoxesThreshold - default 0.01 @@ -42,10 +48,9 @@ class YOLOBase extends Video { */ /** * Create YOLO model. Works on video and images. - * @param {HTMLVideoElement} video - Optional. The video to be used for object detection and classification. + * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} video - Optional. The video to be used for object detection and classification. * @param {Object} options - Optional. A set of options. - * @param {function} callback - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise - * that will be resolved once the model has loaded. + * @param {function} callback - Optional. A callback function that is called once the model has loaded. */ constructor(video, options, callback) { super(video, imageSize); @@ -57,7 +62,10 @@ class YOLOBase extends Video { this.modelReady = false; this.isPredicting = false; this.ready = callCallback(this.loadModel(), callback); - // this.then = this.ready.then; + + if (!options.disableDeprecationNotice) { + console.warn("WARNING! Function YOLO has been deprecated, please use the new ObjectDetector function instead"); + } } async loadModel() { @@ -77,22 +85,22 @@ class YOLOBase extends Video { return this; } + /** + * Detect objects that are in video, returns bounding box, label, and confidence scores + * @param {HTMLVideoElement|HTMLImageElement|HTMLCanvasElement|ImageData} inputOrCallback - Subject of the detection, or callback + * @param {function} cb - Optional. A callback function that is called once the model has loaded. If no callback is provided, it will return a promise + * that will be resolved once the prediction is done. + */ async detect(inputOrCallback, cb) { await this.ready; let imgToPredict; let callback = cb; - if (inputOrCallback instanceof HTMLImageElement - || inputOrCallback instanceof HTMLVideoElement - || inputOrCallback instanceof HTMLCanvasElement - || inputOrCallback instanceof ImageData) { + if (isInstanceOfSupportedElement(inputOrCallback)) { imgToPredict = inputOrCallback; - } else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement - || inputOrCallback.elt instanceof HTMLVideoElement - || inputOrCallback.elt instanceof HTMLCanvasElement - || inputOrCallback.elt instanceof ImageData)) { + } else if (typeof inputOrCallback === "object" && isInstanceOfSupportedElement(inputOrCallback.elt)) { imgToPredict = inputOrCallback.elt; // Handle p5.js image and video. - } else if (typeof inputOrCallback === 'function') { + } else if (typeof inputOrCallback === "function") { imgToPredict = this.video; callback = inputOrCallback; } diff --git a/src/YOLO/index_test.js b/src/ObjectDetector/YOLO/index_test.js similarity index 96% rename from src/YOLO/index_test.js rename to src/ObjectDetector/YOLO/index_test.js index 56293b2fc..4c7d69f20 100644 --- a/src/YOLO/index_test.js +++ b/src/ObjectDetector/YOLO/index_test.js @@ -41,7 +41,7 @@ describe('YOLO', () => { beforeEach(async () => { jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000; - yolo = await YOLO(); + yolo = await YOLO({ disableDeprecationNotice: true }); }); it('instantiates the YOLO classifier with defaults', () => { diff --git a/src/YOLO/postprocess.js b/src/ObjectDetector/YOLO/postprocess.js similarity index 100% rename from src/YOLO/postprocess.js rename to src/ObjectDetector/YOLO/postprocess.js diff --git a/src/ObjectDetector/index.js b/src/ObjectDetector/index.js new file mode 100644 index 000000000..f95dc8353 --- /dev/null +++ b/src/ObjectDetector/index.js @@ -0,0 +1,85 @@ +// Copyright (c) 2019 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +/* + ObjectDetection +*/ + +import YOLO from './YOLO/index'; +import CocoSsd from './CocoSsd/index'; +import { isInstanceOfSupportedElement } from '../utils/imageUtilities'; + +class ObjectDetector { + /** + * @typedef {Object} options + * @property {number} filterBoxesThreshold - Optional. default 0.01 + * @property {number} IOUThreshold - Optional. default 0.4 + * @property {number} classProbThreshold - Optional. default 0.4 + */ + /** + * Create ObjectDetector model. Works on video and images. + * @param {string} modelNameOrUrl - The name or the URL of the model to use. Current model name options + * are: 'YOLO' and 'CocoSsd'. + * @param {Object} options - Optional. A set of options. + * @param {function} callback - Optional. A callback function that is called once the model has loaded. + */ + constructor(modelNameOrUrl, options, callback) { + this.modelNameOrUrl = modelNameOrUrl; + this.options = options || {}; + this.callback = callback; + + switch (modelNameOrUrl) { + case "YOLO": + this.model = new YOLO( + { disableDeprecationNotice: true, ...options }, + callback + ); + break; + case "CocoSsd": + this.model = new CocoSsd(callback); + break; + default: + // Uses custom model url + this.model = new YOLO( + { + disableDeprecationNotice: true, + modelUrl: modelNameOrUrl, + ...options + }, + callback + ); + } + } + + /** + * @typedef {Object} ObjectDetectorPrediction + * @property {number} x - top left x coordinate of the prediction box (0 to 1). + * @property {number} y - top left y coordinate of the prediction box (0 to 1). + * @property {number} w - width of the prediction box (0 to 1). + * @property {number} h - height of the prediction box (0 to 1). + * @property {string} label - the label given. + * @property {number} confidence - the confidence score (0 to 1). + */ + /** + * Returns an array of predicted objects + * @param {function} callback - Optional. A callback that deliver the result. If no callback is + * given, a promise is will be returned. + * @return {ObjectDetectorPrediction[]} an array of the prediction result + */ + detect(subject, callback) { + if (isInstanceOfSupportedElement(subject)) { + return this.model.detect(subject, callback); + } else if (typeof subject === "object" && isInstanceOfSupportedElement(subject.elt)) { + return this.model.detect(subject.elt, callback); // Handle p5.js video and image + } + throw new Error('Detection subject not supported'); + } +} + +const objectDetector = (modelName, video, options, callback) => { + return new ObjectDetector(modelName, video, options, callback) +} + +export default objectDetector; diff --git a/src/ObjectDetector/index_test.js b/src/ObjectDetector/index_test.js new file mode 100644 index 000000000..d84d5b9cf --- /dev/null +++ b/src/ObjectDetector/index_test.js @@ -0,0 +1,26 @@ +// Copyright (c) 2019 ml5 +// +// This software is released under the MIT License. +// https://opensource.org/licenses/MIT + +const { objectDetector } = ml5; + +xdescribe('ObjectDetector', () => { + let cocoSsd; + + beforeEach(async () => { + jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000; + cocoSsd = await objectDetector('CocoSsd'); + }); + + it('throws error when a non image is trying to be detected', async () => { + const notAnImage = 'not_an_image' + try { + await cocoSsd.detect(notAnImage); + fail('Error should have been thrown'); + } + catch (error) { + expect(error.message).toBe('Detection subject not supported'); + } + }); +}); diff --git a/src/index.js b/src/index.js index 076402c86..8c973145b 100644 --- a/src/index.js +++ b/src/index.js @@ -10,7 +10,8 @@ import soundClassifier from './SoundClassifier/'; import KNNClassifier from './KNNClassifier/'; import featureExtractor from './FeatureExtractor/'; import word2vec from './Word2vec/'; -import YOLO from './YOLO'; +import YOLO from './ObjectDetector/YOLO'; +import objectDetector from './ObjectDetector'; import poseNet from './PoseNet'; import * as imageUtils from './utils/imageUtilities'; import styleTransfer from './StyleTransfer/'; @@ -44,6 +45,7 @@ const withPreload = { styleTransfer, word2vec, YOLO, + objectDetector, uNet, sentiment, bodyPix, diff --git a/src/utils/imageUtilities.js b/src/utils/imageUtilities.js index 63b066154..902252c0e 100644 --- a/src/utils/imageUtilities.js +++ b/src/utils/imageUtilities.js @@ -134,10 +134,18 @@ function imgToTensor(input, size = null) { }); } +function isInstanceOfSupportedElement(subject) { + return (subject instanceof HTMLVideoElement + || subject instanceof HTMLImageElement + || subject instanceof HTMLCanvasElement + || subject instanceof ImageData) +} + export { array3DToImage, processVideo, cropImage, imgToTensor, + isInstanceOfSupportedElement, flipImage -}; \ No newline at end of file +};