Skip to content

Commit

Permalink
use model name and options handling for imageClassifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyuan-linn committed Feb 29, 2024
1 parent 561c9cd commit ad18e64
Showing 1 changed file with 78 additions and 48 deletions.
126 changes: 78 additions & 48 deletions src/ImageClassifier/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,27 @@ import * as darknet from "./darknet";
import * as doodlenet from "./doodlenet";
import callCallback from "../utils/callcallback";
import { imgToTensor, mediaReady } from "../utils/imageUtilities";
import handleOptions from "../utils/handleOptions";
import { handleModelName } from "../utils/handleOptions";

const DEFAULTS = {
mobilenet: {
version: 2,
alpha: 1.0,
topk: 3,
},
};
const IMAGE_SIZE = 224;
const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"];

/**
* Check if a string is a valid http url
* @param {string} string - The string to check
* @returns {boolean} - True if the string is a valid http url
*/
function isHttpUrl(string) {
let url;
try {
url = new URL(string);
} catch (e) {
return false;
}
return url.protocol === "http:" || url.protocol === "https:";
}

This comment has been minimized.

Copy link
@lindapaiste

lindapaiste Feb 29, 2024

Contributor

IMO, this function should go in /utils because it is not specific to this model.


class ImageClassifier {
/**
* Create an ImageClassifier.
Expand All @@ -44,36 +54,65 @@ class ImageClassifier {
this.signalStop = false; // Signal to stop the loop
this.prevCall = ""; // Track previous call to detectStart() or detectStop()

if (typeof modelNameOrUrl === "string") {
if (MODEL_OPTIONS.includes(modelNameOrUrl)) {
this.modelName = modelNameOrUrl;
this.modelUrl = null;
switch (this.modelName) {
case "mobilenet":
this.modelToUse = mobilenet;
this.version = options.version || DEFAULTS.mobilenet.version;
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha;
this.topk = options.topk || DEFAULTS.mobilenet.topk;
break;
case "darknet":
this.version = "reference"; // this a 28mb model
this.modelToUse = darknet;
break;
case "darknet-tiny":
this.version = "tiny"; // this a 4mb model
this.modelToUse = darknet;
break;
case "doodlenet":
this.modelToUse = doodlenet;
break;
default:
this.modelToUse = null;
}
} else {
// its a url, we expect to find model.json
this.modelUrl = modelNameOrUrl;
// The teachablemachine urls end with a slash, so add model.json to complete the full path
if (this.modelUrl.endsWith("/")) this.modelUrl += "model.json";
if (typeof modelNameOrUrl === "string" && isHttpUrl(modelNameOrUrl)) {
// its a url, we expect to find model.json
this.modelUrl = modelNameOrUrl;
// The teachablemachine urls end with a slash, so add model.json to complete the full path
if (this.modelUrl.endsWith("/")) this.modelUrl += "model.json";
} else {
// its a model name
this.modelUrl = null;
this.modelName = handleModelName(
modelNameOrUrl,
MODEL_OPTIONS,
"mobilenet",
"imageClassifier"
);

switch (this.modelName) {
case "mobilenet":
this.modelToUse = mobilenet;
const config = handleOptions(
options,
{
version: {
type: "enum",
enums: [1, 2],
default: 2,
},
alpha: {
type: "enum",
enums: (config) =>
config.version === 1
? [0.25, 0.5, 0.75, 1.0]
: [0.5, 0.75, 1.0],
default: 1.0,
},
topk: {
type: "number",
integer: true,
default: 3,

This comment has been minimized.

Copy link
@lindapaiste

lindapaiste Feb 29, 2024

Contributor

Applying topk only to mobilenet matches what's in the existing code. But what's there in the existing code doesn't make any sense because all of the models support using topk. The topk is actually applied on our end after we get the predictions from the model so it's definitely not model-specific. I fixed that in ml5js/ml5-library#1362. I guess keep what you have for now, because it matches what's there currently.

},
},
"imageClassifier"
);
this.version = config.version;
this.alpha = config.alpha;
this.topk = config.topk;
break;
case "darknet":
this.version = "reference"; // this a 28mb model
this.modelToUse = darknet;
break;
case "darknet-tiny":
this.version = "tiny"; // this a 4mb model
this.modelToUse = darknet;
break;
case "doodlenet":
this.modelToUse = doodlenet;
break;
default:
this.modelToUse = null;
}
}
// Load the model
Expand Down Expand Up @@ -263,20 +302,11 @@ class ImageClassifier {
}

const imageClassifier = (modelName, optionsOrCallback, cb) => {
const args = handleArguments(modelName, optionsOrCallback, cb).require(
"string",
'Please specify a model to use. E.g: "MobileNet"'
);
const args = handleArguments(modelName, optionsOrCallback, cb);

const { string, options = {}, callback } = args;

let model = string;
// TODO: I think we should delete this.
if (model.indexOf("http") === -1) {
model = model.toLowerCase();
}

const instance = new ImageClassifier(model, options, callback);
const instance = new ImageClassifier(string, options, callback);
return instance;
};

Expand Down

0 comments on commit ad18e64

Please sign in to comment.