diff --git a/src/coreml.js b/src/coreml.js index b5cbb3c760..cde66a7da2 100644 --- a/src/coreml.js +++ b/src/coreml.js @@ -693,6 +693,7 @@ coreml.Node = class { this._initializer(initializers, 'Weights', 'weights', [ data.inputDim, data.outputChannels ], data.weights); return { 'weights': true }; case 'loadConstant': + case 'loadConstantND': this._initializer(initializers, 'Weights', 'data', data.shape, data.data); return { 'data': true }; case 'scale': @@ -809,16 +810,8 @@ coreml.Attribute = class { this._type = schema.type; } if (this._type && coreml.proto) { - let type = coreml.proto; - const parts = this._type.split('.'); - while (type && parts.length > 0) { - type = type[parts.shift()]; - } - if (type && type[this._value]) { - this._value = type[this.value]; - } + this._value = coreml.Utility.enum(this._type, this._value); } - if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) { this._visible = false; } @@ -1099,6 +1092,32 @@ coreml.OptionalType = class { } }; +coreml.Utility = class { + + static enum(name, value) { + let type = coreml.proto; + const parts = name.split('.'); + while (type && parts.length > 0) { + type = type[parts.shift()]; + } + if (type) { + coreml.Utility._enumKeyMap = coreml.Utility._enumKeyMap || new Map(); + if (!coreml.Utility._enumKeyMap.has(name)) { + const map = new Map(); + for (const key of Object.keys(type)) { + map.set(type[key], key); + } + coreml.Utility._enumKeyMap.set(name, map); + } + const map = coreml.Utility._enumKeyMap.get(name); + if (map.has(value)) { + return map.get(value); + } + } + return value; + } +}; + coreml.Metadata = class { static open(host) {