-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use model name and options handling for imageClassifier
- Loading branch information
1 parent
561c9cd
commit ad18e64
Showing
1 changed file
with
78 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,17 +16,27 @@ import * as darknet from "./darknet"; | |
import * as doodlenet from "./doodlenet"; | ||
import callCallback from "../utils/callcallback"; | ||
import { imgToTensor, mediaReady } from "../utils/imageUtilities"; | ||
import handleOptions from "../utils/handleOptions"; | ||
import { handleModelName } from "../utils/handleOptions"; | ||
|
||
const DEFAULTS = { | ||
mobilenet: { | ||
version: 2, | ||
alpha: 1.0, | ||
topk: 3, | ||
}, | ||
}; | ||
const IMAGE_SIZE = 224; | ||
const MODEL_OPTIONS = ["mobilenet", "darknet", "darknet-tiny", "doodlenet"]; | ||
|
||
/** | ||
* Check if a string is a valid http url | ||
* @param {string} string - The string to check | ||
* @returns {boolean} - True if the string is a valid http url | ||
*/ | ||
function isHttpUrl(string) { | ||
let url; | ||
try { | ||
url = new URL(string); | ||
} catch (e) { | ||
return false; | ||
} | ||
return url.protocol === "http:" || url.protocol === "https:"; | ||
} | ||
This comment has been minimized.
Sorry, something went wrong. |
||
|
||
class ImageClassifier { | ||
/** | ||
* Create an ImageClassifier. | ||
|
@@ -44,36 +54,65 @@ class ImageClassifier { | |
this.signalStop = false; // Signal to stop the loop | ||
this.prevCall = ""; // Track previous call to detectStart() or detectStop() | ||
|
||
if (typeof modelNameOrUrl === "string") { | ||
if (MODEL_OPTIONS.includes(modelNameOrUrl)) { | ||
this.modelName = modelNameOrUrl; | ||
this.modelUrl = null; | ||
switch (this.modelName) { | ||
case "mobilenet": | ||
this.modelToUse = mobilenet; | ||
this.version = options.version || DEFAULTS.mobilenet.version; | ||
this.alpha = options.alpha || DEFAULTS.mobilenet.alpha; | ||
this.topk = options.topk || DEFAULTS.mobilenet.topk; | ||
break; | ||
case "darknet": | ||
this.version = "reference"; // this a 28mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case "darknet-tiny": | ||
this.version = "tiny"; // this a 4mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case "doodlenet": | ||
this.modelToUse = doodlenet; | ||
break; | ||
default: | ||
this.modelToUse = null; | ||
} | ||
} else { | ||
// its a url, we expect to find model.json | ||
this.modelUrl = modelNameOrUrl; | ||
// The teachablemachine urls end with a slash, so add model.json to complete the full path | ||
if (this.modelUrl.endsWith("/")) this.modelUrl += "model.json"; | ||
if (typeof modelNameOrUrl === "string" && isHttpUrl(modelNameOrUrl)) { | ||
// its a url, we expect to find model.json | ||
this.modelUrl = modelNameOrUrl; | ||
// The teachablemachine urls end with a slash, so add model.json to complete the full path | ||
if (this.modelUrl.endsWith("/")) this.modelUrl += "model.json"; | ||
} else { | ||
// its a model name | ||
this.modelUrl = null; | ||
this.modelName = handleModelName( | ||
modelNameOrUrl, | ||
MODEL_OPTIONS, | ||
"mobilenet", | ||
"imageClassifier" | ||
); | ||
|
||
switch (this.modelName) { | ||
case "mobilenet": | ||
this.modelToUse = mobilenet; | ||
const config = handleOptions( | ||
options, | ||
{ | ||
version: { | ||
type: "enum", | ||
enums: [1, 2], | ||
default: 2, | ||
}, | ||
alpha: { | ||
type: "enum", | ||
enums: (config) => | ||
config.version === 1 | ||
? [0.25, 0.5, 0.75, 1.0] | ||
: [0.5, 0.75, 1.0], | ||
default: 1.0, | ||
}, | ||
topk: { | ||
type: "number", | ||
integer: true, | ||
default: 3, | ||
This comment has been minimized.
Sorry, something went wrong.
lindapaiste
Contributor
|
||
}, | ||
}, | ||
"imageClassifier" | ||
); | ||
this.version = config.version; | ||
this.alpha = config.alpha; | ||
this.topk = config.topk; | ||
break; | ||
case "darknet": | ||
this.version = "reference"; // this a 28mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case "darknet-tiny": | ||
this.version = "tiny"; // this a 4mb model | ||
this.modelToUse = darknet; | ||
break; | ||
case "doodlenet": | ||
this.modelToUse = doodlenet; | ||
break; | ||
default: | ||
this.modelToUse = null; | ||
} | ||
} | ||
// Load the model | ||
|
@@ -263,20 +302,11 @@ class ImageClassifier { | |
} | ||
|
||
const imageClassifier = (modelName, optionsOrCallback, cb) => { | ||
const args = handleArguments(modelName, optionsOrCallback, cb).require( | ||
"string", | ||
'Please specify a model to use. E.g: "MobileNet"' | ||
); | ||
const args = handleArguments(modelName, optionsOrCallback, cb); | ||
|
||
const { string, options = {}, callback } = args; | ||
|
||
let model = string; | ||
// TODO: I think we should delete this. | ||
if (model.indexOf("http") === -1) { | ||
model = model.toLowerCase(); | ||
} | ||
|
||
const instance = new ImageClassifier(model, options, callback); | ||
const instance = new ImageClassifier(string, options, callback); | ||
return instance; | ||
}; | ||
|
||
|
IMO, this function should go in
/utils
because it is not specific to this model.