From 589f3876a0765c685bc0ff397d0935a5ae63dd6a Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 8 Dec 2018 16:16:47 -0800 Subject: [PATCH] Fix CoreML tensor shape (#193) --- src/coreml.js | 17 ++++++++++------- src/onnx.js | 5 ++++- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/coreml.js b/src/coreml.js index 8280370aa08..a05a6580e09 100644 --- a/src/coreml.js +++ b/src/coreml.js @@ -393,7 +393,7 @@ coreml.Graph = class { var result = ''; switch (type.Type) { case 'multiArrayType': - var shape = new coreml.TensorShape(null); + var shape = new coreml.TensorShape([]); if (type.multiArrayType.shape && type.multiArrayType.shape.length > 0) { shape = new coreml.TensorShape(type.multiArrayType.shape); } @@ -630,11 +630,11 @@ coreml.Node = class { case 'bias': this._initializers.push(new coreml.Tensor('Weights', 'bias', data.shapeBias, data.bias)); return { 'bias': true }; - case 'simpleRecurrentLayer': - this._initializers.push(new coreml.Tensor('Weights', 'weights', null, data.weightMatrix)); - this._initializers.push(new coreml.Tensor('Weights', 'recurrent', null, data.recursionMatrix)); + case 'simpleRecurrent': + this._initializers.push(new coreml.Tensor('Weights', 'weights', [ data.outputVectorSize, data.inputVectorSize ], data.weightMatrix)); + this._initializers.push(new coreml.Tensor('Weights', 'recurrent', [ data.outputVectorSize, data.inputVectorSize ], data.recursionMatrix)); if (data.hasBiasVectors) { - this._initializers.push(new coreml.Tensor('Weights', 'bias', null, data.biasVector)); + this._initializers.push(new coreml.Tensor('Weights', 'bias', [ data.outputVectorSize ], data.biasVector)); } return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors }; case 'gru': @@ -894,7 +894,7 @@ coreml.TensorType = class { constructor(dataType, shape) { this._dataType = dataType; - this._shape = shape || new coreml.TensorShape(null); + this._shape = shape || new coreml.TensorShape([]); } get dataType() { @@ -921,7 +921,10 @@ coreml.TensorShape = class { } toString() { - return this._dimensions ? ('[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']') : ''; + if (!this._dimensions || this._dimensions.length == 0) { + return ''; + } + return '[' + this._dimensions.map((dimension) => dimension.toString()).join(',') + ']'; } }; diff --git a/src/onnx.js b/src/onnx.js index f10fa34125a..b5134fe3045 100644 --- a/src/onnx.js +++ b/src/onnx.js @@ -1039,7 +1039,10 @@ onnx.TensorShape = class { } toString() { - return (this._dimensions && this._dimensions.length) ? ('[' + this._dimensions.join(',') + ']') : ''; + if (!this._dimensions || this._dimensions.length == 0) { + return ''; + } + return '[' + this._dimensions.join(',') + ']'; } };