Skip to content

Commit

Permalink
fix issue in TrainMnistWithLSTM (#1965)
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Aug 26, 2022
1 parent a45dc3b commit 0c12a55
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 4 additions & 1 deletion api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java
Expand Up @@ -68,7 +68,10 @@ public RecurrentBlock(BaseBuilder<?> builder) {
bidirectional = builder.bidirectional;
returnState = builder.returnState;

Parameter.Type[] parameterTypes = {Parameter.Type.WEIGHT, Parameter.Type.BIAS};
Parameter.Type[] parameterTypes =
hasBiases
? new Parameter.Type[] {Parameter.Type.WEIGHT, Parameter.Type.BIAS}
: new Parameter.Type[] {Parameter.Type.WEIGHT};
String[] directions = {"l"};
if (builder.bidirectional) {
directions = new String[] {"l", "r"};
Expand Down
Expand Up @@ -835,7 +835,11 @@ public NDList lstm(
boolean training,
boolean bidirectional,
boolean batchFirst) {
int numParams = numLayers * ((hasBiases) ? 4 : 2) * ((bidirectional) ? 2 : 1);
if (!hasBiases) {
throw new UnsupportedOperationException(
"Setting hasBias to be false is not supported on MXNet engine.");
}
int numParams = numLayers * 4 * (bidirectional ? 2 : 1);
Preconditions.checkArgument(
params.size() == numParams,
"The size of Params is incorrect expect "
Expand Down

0 comments on commit 0c12a55

Please sign in to comment.