From e21ac6b09d42612d560c1637478314d135f5c18e Mon Sep 17 00:00:00 2001 From: Lai Wei Date: Tue, 16 Feb 2021 14:34:54 -0800 Subject: [PATCH] fix rnn test Change-Id: I25f80a4f965e820d7d16aba515928b009d1a8b76 --- .../ai/djl/integration/tests/nn/BlockCoreTest.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index df34fb7d777..46767e14057 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -473,7 +473,7 @@ public void testRNNTanh() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 2); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -18); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -521,7 +521,9 @@ public void testRNNRelu() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 2); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -908); + // loss should be the same as testRNNTanh because outputs are equal for each + // class + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -571,7 +573,7 @@ public void testLstm() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 3); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -16.340019); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -628,7 +630,7 @@ public void testGRU() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 1); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -8.17537307E-4); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } }