Skip to content

Commit

Permalink
Core ML linkedModel support (#193)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Apr 20, 2020
1 parent 8d785db commit 2a79e04
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 33 deletions.
82 changes: 49 additions & 33 deletions src/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ coreml.Graph = class {
predictedProbabilitiesName = predictedProbabilitiesName ? predictedProbabilitiesName : '?';
let labelProbabilityInput = this._updateOutput(labelProbabilityLayerName, labelProbabilityLayerName + ':labelProbabilityLayerName');
let operator = classifier.ClassLabels;
this._nodes.push(new coreml.Node(this._metadata, this._group, operator, null, classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ]));
this._nodes.push(new coreml.Node(this._metadata, this._group, operator, null, '', classifier[operator], [ labelProbabilityInput ], [ predictedProbabilitiesName, predictedFeatureName ]));
}
}

Expand All @@ -183,7 +183,7 @@ coreml.Graph = class {
for (const p of preprocessing) {
let input = p.featureName ? p.featureName : preprocessorOutput;
preprocessorOutput = preprocessingInput + ':' + preprocessorIndex.toString();
this._createNode(scope, group, p.preprocessor, null, p[p.preprocessor], [ input ], [ preprocessorOutput ]);
this._createNode(scope, group, p.preprocessor, null, '', p[p.preprocessor], [ input ], [ preprocessorOutput ]);
preprocessorIndex++;
}
for (const node of inputNodes) {
Expand All @@ -200,10 +200,11 @@ coreml.Graph = class {

_loadModel(model, scope, group) {
this._groups = this._groups | (group.length > 0 ? true : false);
const description = model && model.description && model.description.metadata && model.description.metadata.shortDescription ? model.description.metadata.shortDescription : '';
if (model.neuralNetworkClassifier) {
const neuralNetworkClassifier = model.neuralNetworkClassifier;
for (const layer of neuralNetworkClassifier.layers) {
this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output);
this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
}
this._updateClassifierOutput(group, neuralNetworkClassifier);
this._updatePreprocessing(scope, group, neuralNetworkClassifier.preprocessing);
Expand All @@ -212,15 +213,15 @@ coreml.Graph = class {
else if (model.neuralNetwork) {
const neuralNetwork = model.neuralNetwork;
for (const layer of neuralNetwork.layers) {
this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output);
this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
}
this._updatePreprocessing(scope, group, neuralNetwork.preprocessing);
return 'Neural Network';
}
else if (model.neuralNetworkRegressor) {
const neuralNetworkRegressor = model.neuralNetworkRegressor;
for (const layer of neuralNetworkRegressor.layers) {
this._createNode(scope, group, layer.layer, layer.name, layer[layer.layer], layer.input, layer.output);
this._createNode(scope, group, layer.layer, layer.name, description, layer[layer.layer], layer.input, layer.output);
}
this._updatePreprocessing(scope, group, neuralNetworkRegressor);
return 'Neural Network Regressor';
Expand All @@ -244,7 +245,7 @@ coreml.Graph = class {
return 'Pipeline Regressor';
}
else if (model.glmClassifier) {
this._createNode(scope, group, 'glmClassifier', null,
this._createNode(scope, group, 'glmClassifier', null, description,
{
classEncoding: model.glmClassifier.classEncoding,
offset: model.glmClassifier.offset,
Expand All @@ -256,39 +257,43 @@ coreml.Graph = class {
return 'Generalized Linear Classifier';
}
else if (model.glmRegressor) {
this._createNode(scope, group, 'glmRegressor', null,
this._createNode(scope, group, 'glmRegressor', null, description,
model.glmRegressor,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Generalized Linear Regressor';
}
else if (model.dictVectorizer) {
this._createNode(scope, group, 'dictVectorizer', null, model.dictVectorizer,
this._createNode(scope, group, 'dictVectorizer', null, description,
model.dictVectorizer,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Dictionary Vectorizer';
}
else if (model.featureVectorizer) {
this._createNode(scope, group, 'featureVectorizer', null, model.featureVectorizer,
this._createNode(scope, group, 'featureVectorizer', null, description,
model.featureVectorizer,
coreml.Graph._formatFeatureDescriptionList(model.description.input),
[ model.description.output[0].name ]);
return 'Feature Vectorizer';
}
else if (model.treeEnsembleClassifier) {
this._createNode(scope, group, 'treeEnsembleClassifier', null, model.treeEnsembleClassifier.treeEnsemble,
this._createNode(scope, group, 'treeEnsembleClassifier', null, description,
model.treeEnsembleClassifier.treeEnsemble,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
this._updateClassifierOutput(group, model.treeEnsembleClassifier);
return 'Tree Ensemble Classifier';
}
else if (model.treeEnsembleRegressor) {
this._createNode(scope, group, 'treeEnsembleRegressor', null, model.treeEnsembleRegressor.treeEnsemble,
this._createNode(scope, group, 'treeEnsembleRegressor', null, description,
model.treeEnsembleRegressor.treeEnsemble,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Tree Ensemble Regressor';
}
else if (model.supportVectorClassifier) {
this._createNode(scope, group, 'supportVectorClassifier', null,
this._createNode(scope, group, 'supportVectorClassifier', null, description,
{
coefficients: model.supportVectorClassifier.coefficients,
denseSupportVectors: model.supportVectorClassifier.denseSupportVectors,
Expand All @@ -305,7 +310,7 @@ coreml.Graph = class {
return 'Support Vector Classifier';
}
else if (model.supportVectorRegressor) {
this._createNode(scope, group, 'supportVectorRegressor', null,
this._createNode(scope, group, 'supportVectorRegressor', null, description,
{
coefficients: model.supportVectorRegressor.coefficients,
kernel: model.supportVectorRegressor.kernel,
Expand All @@ -317,7 +322,7 @@ coreml.Graph = class {
return 'Support Vector Regressor';
}
else if (model.arrayFeatureExtractor) {
this._createNode(scope, group, 'arrayFeatureExtractor', null,
this._createNode(scope, group, 'arrayFeatureExtractor', null, description,
{ extractIndex: model.arrayFeatureExtractor.extractIndex },
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
Expand All @@ -327,7 +332,7 @@ coreml.Graph = class {
const categoryType = model.oneHotEncoder.CategoryType;
const oneHotEncoderParams = { outputSparse: model.oneHotEncoder.outputSparse };
oneHotEncoderParams[categoryType] = model.oneHotEncoder[categoryType];
this._createNode(scope, group, 'oneHotEncoder', null,
this._createNode(scope, group, 'oneHotEncoder', null, description,
oneHotEncoderParams,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
Expand All @@ -339,22 +344,22 @@ coreml.Graph = class {
let imputerParams = {};
imputerParams[imputedValue] = model.imputer[imputedValue];
imputerParams[replaceValue] = model.imputer[replaceValue];
this._createNode(scope, group, 'oneHotEncoder', null,
this._createNode(scope, group, 'oneHotEncoder', null, description,
imputerParams,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Imputer';

}
else if (model.normalizer) {
this._createNode(scope, group, 'normalizer', null,
this._createNode(scope, group, 'normalizer', null, description,
model.normalizer,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Normalizer';
}
else if (model.wordTagger) {
this._createNode(scope, group, 'wordTagger', null,
this._createNode(scope, group, 'wordTagger', null, description,
model.wordTagger,
[ model.description.input[0].name ],
[
Expand All @@ -366,7 +371,7 @@ coreml.Graph = class {
return 'Word Tagger';
}
else if (model.textClassifier) {
this._createNode(scope, group, 'textClassifier', null,
this._createNode(scope, group, 'textClassifier', null, description,
model.textClassifier,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
Expand All @@ -379,7 +384,7 @@ coreml.Graph = class {
iouThreshold: model.nonMaximumSuppression.iouThreshold,
confidenceThreshold: model.nonMaximumSuppression.confidenceThreshold
};
this._createNode(scope, group, 'nonMaximumSuppression', null,
this._createNode(scope, group, 'nonMaximumSuppression', null, description,
nonMaximumSuppressionParams,
[
model.nonMaximumSuppression.confidenceInputFeatureName,
Expand All @@ -397,40 +402,46 @@ coreml.Graph = class {
const visionFeaturePrintParams = {
scene: model.visionFeaturePrint.scene
}
this._createNode(scope, group, 'visionFeaturePrint', null,
this._createNode(scope, group, 'visionFeaturePrint', null, description,
visionFeaturePrintParams,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Vision Feature Print';
}
else if (model.soundAnalysisPreprocessing) {
this._createNode(scope, group, 'soundAnalysisPreprocessing', null,
this._createNode(scope, group, 'soundAnalysisPreprocessing', null, description,
model.soundAnalysisPreprocessing,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'Sound Analysis Preprocessing';
}
else if (model.kNearestNeighborsClassifier) {
this._createNode(scope, group, 'kNearestNeighborsClassifier', null,
this._createNode(scope, group, 'kNearestNeighborsClassifier', null, description,
model.kNearestNeighborsClassifier,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
this._updateClassifierOutput(group, model.kNearestNeighborsClassifier);
return 'kNearestNeighborsClassifier';
}
else if (model.itemSimilarityRecommender) {
const itemSimilarityRecommenderParams = {
itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
}
this._createNode(scope, group, 'itemSimilarityRecommender', null,
itemSimilarityRecommenderParams,
this._createNode(scope, group, 'itemSimilarityRecommender', null, description,
{
itemStringIds: model.itemSimilarityRecommender.itemStringIds.vector,
itemItemSimilarities: model.itemSimilarityRecommender.itemItemSimilarities
},
model.description.input.map((feature) => feature.name),
model.description.output.map((feature) => feature.name));
return 'itemSimilarityRecommender'
}
else if (model.linkedModel) {
this._createNode(scope, group, 'linkedModel', null, description,
model.linkedModel.linkedModelFile,
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
return 'customModel';
}
else if (model.customModel) {
this._createNode(scope, group, 'customModel', null,
this._createNode(scope, group, 'customModel', null, description,
{ className: model.customModel.className, parameters: model.customModel.parameters },
[ model.description.input[0].name ],
[ model.description.output[0].name ]);
Expand All @@ -439,7 +450,7 @@ coreml.Graph = class {
throw new coreml.Error("Unknown model type '" + JSON.stringify(Object.keys(model)) + "'.");
}

_createNode(scope, group, operator, name, data, inputs, outputs) {
_createNode(scope, group, operator, name, description, data, inputs, outputs) {
inputs = inputs.map((input) => scope[input] ? scope[input].argument : input);
outputs = outputs.map((output) => {
if (scope[output]) {
Expand All @@ -455,7 +466,7 @@ coreml.Graph = class {
return output;
});

const node = new coreml.Node(this._metadata, group, operator, name, data, inputs, outputs);
const node = new coreml.Node(this._metadata, group, operator, name, description, data, inputs, outputs);
this._nodes.push(node);
return node;
}
Expand Down Expand Up @@ -583,13 +594,14 @@ coreml.Argument = class {

coreml.Node = class {

constructor(metadata, group, operator, name, data, inputs, outputs) {
constructor(metadata, group, operator, name, description, data, inputs, outputs) {
this._metadata = metadata;
if (group) {
this._group = group;
}
this._operator = operator;
this._name = name || '';
this._description = description || '';
this._attributes = [];
let initializers = [];
if (data) {
Expand Down Expand Up @@ -620,6 +632,10 @@ coreml.Node = class {
return this._name;
}

get description() {
return this._description;
}

get metadata() {
return this._metadata.type(this.operator);
}
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,13 @@
"format": "Core ML v1",
"link": "https://github.com/gavi/Iris"
},
{
"type": "coreml",
"target": "LinkedUpdatableTinyDrawingClassifier.mlmodel",
"source": "https://github.com/lutzroeder/netron/files/4500539/LinkedUpdatableTinyDrawingClassifier.zip[LinkedUpdatableTinyDrawingClassifier.mlmodel]",
"format": "Core ML v4",
"link": "https://github.com/lutzroeder/netron/issues/193"
},
{
"type": "coreml",
"target": "MessageClassifier.mlmodel",
Expand Down

0 comments on commit 2a79e04

Please sign in to comment.