Commit

fix softmax flag behavior
Change-Id: I18ee08116a7ca302a0542ef5d361a64c9e5e2227
roywei committed Feb 16, 2021
1 parent b9fc3d0 commit 6dd1604
Showing 2 changed files with 26 additions and 5 deletions.
SoftmaxCrossEntropyLoss.java
@@ -52,9 +52,10 @@ public SoftmaxCrossEntropyLoss(String name) {
      * @param name the name of the loss
      * @param weight the weight to apply on the loss value, default 1
      * @param classAxis the axis that represents the class probabilities, default -1
-     * @param sparseLabel whether labels are integer array or probabilities, default true
-     * @param fromLogit whether predictions are log probabilities or un-normalized numbers, default
-     *     false
+     * @param sparseLabel whether labels are 1-D integer array or 2-D probabilities of [batch_size,
+     *     n-class], default true
+     * @param fromLogit whether predictions are un-normalized numbers or log probabilities, if true,
+     *     logSoftmax will be applied to input, default true
      */
     public SoftmaxCrossEntropyLoss(
             String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
@@ -69,7 +70,7 @@ public SoftmaxCrossEntropyLoss(
     @Override
     public NDArray evaluate(NDList label, NDList prediction) {
         NDArray pred = prediction.singletonOrThrow();
-        if (!fromLogit) {
+        if (fromLogit) {
             pred = pred.logSoftmax(classAxis);
         }
         NDArray loss;
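The behavioral fix is the inverted check in evaluate: previously logSoftmax was applied when fromLogit was false, i.e. to inputs that were already log probabilities; after this change it is applied only to un-normalized inputs. A minimal sketch of the corrected contract, not part of the commit and assuming a DJL build that contains this fix (the class name FromLogitSketch and the printout are only for illustration): raw logits with fromLogit=true should yield the same loss as pre-computed log probabilities with fromLogit=false.

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.loss.Loss;

public class FromLogitSketch {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray logits = manager.create(new float[] {1, 2, 3, 4, 5});
            NDArray label = manager.ones(new Shape(1)); // sparse label: class index 1

            // fromLogit=true: pass raw logits, logSoftmax is applied inside evaluate
            NDArray fromRawLogits =
                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, true, true)
                            .evaluate(new NDList(label), new NDList(logits));

            // fromLogit=false: caller normalizes first, the loss uses the input as-is
            NDArray fromLogProbs =
                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, true, false)
                            .evaluate(new NDList(label), new NDList(logits.logSoftmax(-1)));

            // Both values should be ~3.45191431, matching the test below
            System.out.println(fromRawLogits + " " + fromLogProbs);
        }
    }
}
```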
LossTest.java
@@ -46,11 +46,31 @@ public void l2LossTest() {
     @Test
     public void softmaxCrossEntropyTest() {
         try (NDManager manager = NDManager.newBaseManager()) {
+            // test fromLogits=true, sparseLabel=true
             NDArray pred = manager.create(new float[] {1, 2, 3, 4, 5});
             NDArray label = manager.ones(new Shape(1));
             Assertions.assertAlmostEquals(
-                    Loss.softmaxCrossEntropyLoss().evaluate(new NDList(label), new NDList(pred)),
+                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, true, true)
+                            .evaluate(new NDList(label), new NDList(pred)),
                     manager.create(3.45191431f));
+
+            // test fromLogits=false, sparseLabel=true
+            pred =
+                    manager.create(
+                            new float[] {4.0f, 2.0f, 1.0f, 0.0f, 5.0f, 1.0f}, new Shape(2, 3));
+            label = manager.create(new float[] {0, 1}, new Shape(2));
+            NDArray nonSparseLabel =
+                    manager.create(new float[] {1f, 0f, 0f, 0f, 1f, 0f}, new Shape(2, 3));
+            NDArray sparseOutput =
+                    Loss.softmaxCrossEntropyLoss()
+                            .evaluate(new NDList(label), new NDList(pred.logSoftmax(-1)));
+            // test fromLogits=false, sparseLabel=false
+            NDArray nonSparseOutput =
+                    Loss.softmaxCrossEntropyLoss("loss", 1, -1, false, false)
+                            .evaluate(new NDList(nonSparseLabel), new NDList(pred.logSoftmax(-1)));
+
+            Assertions.assertAlmostEquals(sparseOutput, nonSparseOutput);
+            Assertions.assertAlmostEquals(sparseOutput, manager.create(0.09729549f));
         }
     }
 
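For reference, the expected constants follow directly from the cross-entropy definition. In the first case, logSumExp of {1, 2, 3, 4, 5} is about 5.45191, so the negative log-probability of class 1 is 5.45191 - 2 = 3.45191. In the second case, the per-row negative log-probabilities are about 0.16985 (row {4, 2, 1}, class 0) and 0.02475 (row {0, 5, 1}, class 1), whose mean is 0.09730. The sparse and non-sparse variants agree because picking the label's index out of the log-probabilities and summing the log-probabilities weighted by a one-hot vector compute the same quantity.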
