Skip to content

Commit

Permalink
Add model functionalities, update example
Browse files Browse the repository at this point in the history
  • Loading branch information
OrpheasK committed Jan 26, 2024
1 parent 9aeb162 commit 3443d8f
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 33 deletions.
1 change: 0 additions & 1 deletion examples/ImageClassifier-video/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
<script src="../../dist/ml5.js"></script>
</head>
<body>
<h1>Webcam Image classification using MobileNet</h1>
<script src="sketch.js"></script>
</body>
</html>
43 changes: 29 additions & 14 deletions examples/ImageClassifier-video/sketch.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
/* ===
ml5 Example
Webcam video classification using MobileNet and p5.js
This example uses a callback function to display the results
This example uses a callback function to update the canvas label with the latest results,
it makes use of the p5 mousePressed() function to toggle between an active classification
=== */

// A variable to initialize the Image Classifier
Expand All @@ -15,28 +16,42 @@ let classifier;
// A variable to hold the video we want to classify
let vid;

// Element for displaying the results
let resultsP;
// Variable for displaying the results on the canvas
let label = 'Model loading...';

function preload() {
classifier = ml5.imageClassifier('MobileNet');
}

function setup() {
noCanvas();
// Using webcam feed as video input
createCanvas(640, 480);
background(255);
textSize(32);
fill(255);
// Using webcam feed as video input, hiding html element to avoid duplicate with canvas
vid = createCapture(VIDEO);
vid.hide();
classifier.classifyStart(vid, gotResult);
resultsP = createP("Model loading...");
}

// A function to run when we get any errors and the results
function gotResult(results, error) {
// Display error in the console
if (error) {
console.error(error);
function draw() {
//Each video frame is painted on the canvas
image(vid, 0, 0);
//Printing class with the highest probability on the canvas
text(label, 20, 50);
}

//A mouse click to stop and restart the classification process
function mousePressed(){
if (classifier.isClassifying){
classifier.classifyStop();
}else{
classifier.classifyStart(vid, gotResult);
}
// The results are in an array ordered by confidence.
console.log(results);
resultsP.html('Label: ' + results[0].label + '<br>Confidence: ' + nf(results[0].confidence, 0, 2));
}

// A function to run when we get the results and any errors
function gotResult(results) {
//update label variable which is displayed on the canvas
label = results[0].label;
}
69 changes: 51 additions & 18 deletions src/ImageClassifier/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import handleArguments from "../utils/handleArguments";
import * as darknet from "./darknet";
import * as doodlenet from "./doodlenet";
import callCallback from "../utils/callcallback";
import { mediaReady } from "../utils/imageUtilities";
import { imgToTensor, mediaReady } from "../utils/imageUtilities";

const DEFAULTS = {
mobilenet: {
Expand All @@ -24,6 +24,7 @@ const DEFAULTS = {
topk: 3,
},
};
const IMAGE_SIZE = 224;
const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"];

class ImageClassifier {
Expand All @@ -37,6 +38,12 @@ class ImageClassifier {
constructor(modelNameOrUrl, options, callback) {
this.model = null;
this.mapStringToIndex = [];

// flags for classifyStart() and classifyStop()
this.isClassifying = false;// True when classification loop is running
this.signalStop = false; // Signal to stop the loop
this.prevCall = ""; // Track previous call to detectStart() or detectStop()

if (typeof modelNameOrUrl === "string") {
if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
this.modelName = modelNameOrUrl;
Expand Down Expand Up @@ -134,6 +141,10 @@ class ImageClassifier {
await this.ready;
await mediaReady(imgToPredict, true);

// For Doodlenet and Teachable Machine models a manual resizing of the image is still necessary
const imageResize = [IMAGE_SIZE, IMAGE_SIZE];
if (this.modelName == "doodlenet" || this.modelUrl) imgToPredict = imgToTensor(imgToPredict, imageResize);

if (this.modelUrl) {
await tf.nextFrame();
const predictedClasses = tf.tidy(() => {
Expand Down Expand Up @@ -188,26 +199,48 @@ class ImageClassifier {
* @param {function} cb - a callback function that handles the results of the function.
* @return {function} a promise or the results of a given callback, cb.
*/
async classifyStart(inputNumOrCallback, numOrCallback, cb) {
const { image, number, callback } = handleArguments(inputNumOrCallback, numOrCallback, cb)
.require('image', "No input provided.");

// Function to classify a single frame
const classifyFrame = async () => {
await mediaReady(image, true);
await this.classifyInternal(image, number);
// call the callback function
callCallback(this.classifyInternal(image, number), callback);

// call recursively for continuous classification
async classifyStart(inputNumOrCallback, numOrCallback, cb) {
const { image, number, callback } = handleArguments(inputNumOrCallback, numOrCallback, cb)
.require('image', "No input provided.");

// Function to classify a single frame
const classifyFrame = async () => {
await mediaReady(image, true);
await this.classifyInternal(image, number);
// call the callback function
callCallback(this.classifyInternal(image, number), callback);

// call recursively for continuous classification
if (!this.signalStop){
requestAnimationFrame(classifyFrame);
};

// Start the classification
}else{
this.isClassifying = false;
}
};

// Start the classification
this.signalStop = false;
if (!this.isClassifying){
this.isClassifying = true;
classifyFrame();

return callCallback(this.classifyInternal(image, number), callback);
}
if (this.prevCall === "start") {
console.warn(
"classifyStart() was called more than once without calling classifyStop(). Only the latest classifyStart() call will take effect."
);
}
this.prevCall = "start";
}

/**
* Used to stop the continuous classification of a video
*/
classifyStop(){
if (this.isClassifying) {
this.signalStop = true;
}
this.prevCall = "stop";
}
}

const imageClassifier = (modelName, optionsOrCallback, cb) => {
Expand Down

0 comments on commit 3443d8f

Please sign in to comment.