diff --git a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java index df34fb7d777..46767e14057 100644 --- a/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/nn/BlockCoreTest.java @@ -473,7 +473,7 @@ public void testRNNTanh() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 2); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -18); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -521,7 +521,9 @@ public void testRNNRelu() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 2); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -908); + // loss should be the same as testRNNTanh because outputs are equal for each + // class + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -571,7 +573,7 @@ public void testLstm() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 3); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -16.340019); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } } @@ -628,7 +630,7 @@ public void testGRU() throws IOException, MalformedModelException { Assertions.assertAlmostEquals(result.size(), 1); NDArray lossValue = loss.evaluate(new NDList(labels), new NDList(result.head())); - Assertions.assertAlmostEquals(lossValue.getFloat(), -8.17537307E-4); + Assertions.assertAlmostEquals(lossValue.getFloat(), 24.9533); testEncode(manager, block); } }