Skip to content

Commit

Permalink
Fix CoreML GRU tensor sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 6, 2018
1 parent 5f9e541 commit 93faa39
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/coreml.js
Original file line number Diff line number Diff line change
Expand Up @@ -597,16 +597,19 @@ coreml.Node = class {
}
return { 'weightMatrix': true, 'recursionMatrix': true, 'biasVector': data.hasBiasVectors };
case 'gru':
this._initializers.push(new coreml.Tensor('Weights', 'updateGateWeightMatrix', [ data.updateGateWeightMatrix.length ], data.updateGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateWeightMatrix', [ data.resetGateWeightMatrix.length ], data.resetGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateWeightMatrix', [ data.outputGateWeightMatrix.length ], data.outputGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'updateGateRecursionMatrix', [ data.updateGateRecursionMatrix.length ], data.updateGateRecursionMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateRecursionMatrix', [ data.resetGateRecursionMatrix.length ], data.resetGateRecursionMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateRecursionMatrix', [ data.outputGateRecursionMatrix.length ], data.outputGateRecursionMatrix));
var recursionMatrixShape = [ data.outputVectorSize, data.outputVectorSize ];
var weightMatrixShape = [ data.outputVectorSize, data.inputVectorSize ];
var biasVectorShape = [ data.outputVectorSize ];
this._initializers.push(new coreml.Tensor('Weights', 'updateGateWeightMatrix', weightMatrixShape, data.updateGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateWeightMatrix', weightMatrixShape, data.resetGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateWeightMatrix', weightMatrixShape, data.outputGateWeightMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'updateGateRecursionMatrix', recursionMatrixShape, data.updateGateRecursionMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateRecursionMatrix', recursionMatrixShape, data.resetGateRecursionMatrix));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateRecursionMatrix', recursionMatrixShape, data.outputGateRecursionMatrix));
if (data.hasBiasVectors) {
this._initializers.push(new coreml.Tensor('Weights', 'updateGateBiasVector', [ data.updateGateBiasVector.length ], data.updateGateBiasVector));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateBiasVector', [ data.resetGateBiasVector.length ], data.resetGateBiasVector));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateBiasVector', [ data.outputGateBiasVector.length ], data.outputGateBiasVector));
this._initializers.push(new coreml.Tensor('Weights', 'updateGateBiasVector', biasVectorShape, data.updateGateBiasVector));
this._initializers.push(new coreml.Tensor('Weights', 'resetGateBiasVector', biasVectorShape, data.resetGateBiasVector));
this._initializers.push(new coreml.Tensor('Weights', 'outputGateBiasVector', biasVectorShape, data.outputGateBiasVector));
}
return {
'updateGateWeightMatrix': true, 'resetGateWeightMatrix': true, 'outputGateWeightMatrix': true,
Expand Down

0 comments on commit 93faa39

Please sign in to comment.