From d61c446398797e15b9e33e46c2f9a71101dccbd8 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 17 Jan 2019 16:59:01 +1100 Subject: [PATCH 1/6] Loss fixes --- .../api/ops/impl/loss/SigmoidCrossEntropyLoss.java | 8 ++++++++ .../nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java | 12 ++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index a69dcfc47897..468224939385 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -68,4 +68,12 @@ public String opName() { public String tensorflowName() { return "sigmoid_cross_entropy"; } + + @Override + public List doDiff(List grad){ + //No external gradient + //Args are: predictions, weights, label + SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); + return Arrays.asList(grads); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java index db585136f6aa..f26c97aa704e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java @@ -40,8 +40,16 @@ public int getNumOutputs(){ @Override public List calculateOutputDataTypes(List inputDataTypes){ - - return Arrays.asList(arg(0).dataType(), arg(1).dataType(), arg(2).dataType()); + Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input 0 (predictions) must be a floating point type; inputs datatypes are %s for %s", + inputDataTypes, getClass()); + DataType dt0 = inputDataTypes.get(0); + DataType dt1 = arg(1).dataType(); + DataType dt2 = arg(2).dataType(); + if(!dt1.isFPType()) + dt1 = dt0; + if(!dt2.isFPType()) + dt2 = dt0; + return Arrays.asList(dt0, dt1, dt2); } @Override From 79c08729c04bd435c76400f4b29b9ecde2e98525 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 17 Jan 2019 17:53:43 +1100 Subject: [PATCH 2/6] L2 loss backprop, tests, convenience methods --- .../DifferentialFunctionFactory.java | 4 + .../org/nd4j/autodiff/samediff/SameDiff.java | 24 ++++- .../nd4j/linalg/api/ops/impl/loss/L2Loss.java | 11 ++- .../opvalidation/LossOpValidation.java | 89 ++++++++++++++++--- 4 files changed, 114 insertions(+), 14 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index b3554f0eb0ea..b15e74ba4bee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1324,6 +1324,10 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable return new WeightedCrossEntropyLoss(sameDiff(), 
targets, inputs, weights).outputVariable(); } + public SDVariable lossL2(SDVariable var){ + return new L2Loss(sameDiff(), var).outputVariable(); + } + public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index fb2cf69d5601..2d0bb276c5f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -8709,6 +8709,26 @@ public SDVariable tensorMmul(String name, return updateVariableNameAndReference(result, name); } + /** + * L2 loss: 1/2 * sum(x^2) + * @param var Variable to calculate L2 loss of + * @return L2 loss + */ + public SDVariable lossL2(@NonNull SDVariable var){ + return lossL2(null, var); + } + + /** + * L2 loss: 1/2 * sum(x^2) + * @param name Name of the output variable + * @param var Variable to calculate L2 loss of + * @return L2 loss + */ + public SDVariable lossL2(String name, @NonNull SDVariable var){ + SDVariable ret = f().lossL2(var); + return updateVariableNameAndReference(ret, name); + } + /** * See {@link #lossAbsoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}. */ @@ -10540,7 +10560,7 @@ protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBuffer inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx)); } - log.debug("Own Name: {}", node.getOwnName()); + log.trace("Own Name: {}", node.getOwnName()); int ownId = id != null ? id : idCounter.incrementAndGet(); //forwardMap.containsKey(node.getOwnName()) ? forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet(); String[] outNames = node.outputVariablesNames(); for(String s : outNames){ @@ -10665,7 +10685,7 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con List allVars = variables(); for (SDVariable variable : allVars) { INDArray arr = variable.getArr(); - log.debug("Exporting variable: [{}]", variable.getVarName()); + log.trace("Exporting variable: [{}]", variable.getVarName()); //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output // numbers. 
For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index b6d388d52a2c..cfef2a61b5dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -34,8 +34,8 @@ @NoArgsConstructor public class L2Loss extends DynamicCustomOp { - public L2Loss(SameDiff sameDiff, SDVariable[] args) { - super(null, sameDiff, args); + public L2Loss(SameDiff sameDiff, SDVariable var) { + super(sameDiff, new SDVariable[]{var}); } @Override @@ -59,4 +59,11 @@ public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input datatype must be floating point for %s, got %s", getClass(), inputDataTypes); return inputDataTypes; } + + @Override + public List doDiff(List grad){ + //L2 loss: L = 1/2 * sum(x_i^2) + //dL/dxi = xi + return Collections.singletonList(f().identity(arg())); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 0f7c552e6064..278f34da2dac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -33,10 +33,10 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; -import java.util.ArrayList; -import java.util.List; +import java.util.*; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; @Slf4j public class LossOpValidation extends BaseOpValidation { @@ -44,16 +44,29 @@ public LossOpValidation(Nd4jBackend backend) { super(backend); } + public static final Set NO_BP_YET = new HashSet<>(); + static { + NO_BP_YET.addAll(Arrays.asList("hinge", "huber", "l2_loss", "poisson", "mpwse")); + } + @Test public void testLoss2d() { - OpValidationSuite.ignoreFailing(); //2019/01/17 - WIP, some passing, some failing + OpValidationSuite.ignoreFailing(); //2019/01/17 - Some passing, some not yet implemented, some issues: Issue 17 https://github.com/deeplearning4j/deeplearning4j/issues/6958 Nd4j.getRandom().setSeed(12345); List failed = new ArrayList<>(); + int totalRun = 0; for (String fn : new String[]{"absdiff", "cosine", "hinge", "huber", "log", "mse", "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "softmaxxentlogits", "sparsesoftmax"}) { + + if((fn.equals("softmaxxentlogits") || fn.equals("sparsesoftmax")) && OpValidationSuite.IGNORE_FAILING){ + log.warn("NOT YET IMPLEMENTED: {}", fn); + continue; + } + + for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) { if((fn.startsWith("softmax") || fn.equals("cosine")) && weights.equals("perOutput")) continue; //Skip this combination (not possible) @@ -66,6 +79,9 @@ public void testLoss2d() { if(fn.equals("mpwse") && (reduction != LossReduce.MEAN_BY_WEIGHT || weights.equals("perOutput"))) //LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT) continue; //MPWSE only provides scalar output - i.e., no other reduction modes. 
And only none/scalar/per-example weights + if((fn.equals("softmaxxent") || fn.equals("softmaxxent_smooth")) && reduction == LossReduce.NONE) + continue; //Combination not supported (doesn't make sense) + SameDiff sd = SameDiff.create(); int nOut = 4; @@ -84,7 +100,12 @@ public void testLoss2d() { wArrBroadcast = Nd4j.valueArrayOf(minibatch, nOut, 1.0).castTo(DataType.DOUBLE); break; case "perExample": - w = sd.var("weights", Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4}).reshape(minibatch, 1)); + INDArray wpe = Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4}); + if(!fn.equals("softmaxxent") && !fn.equals("softmaxxent_smooth")){ + //Softmaxxent only supports rank 1 not rank 2?? + wpe = wpe.reshape(minibatch, 1); + } + w = sd.var("weights", wpe); wArrBroadcast = Nd4j.create(DataType.DOUBLE, minibatch, nOut).addiColumnVector(w.getArr()); break; case "perOutput": @@ -131,8 +152,8 @@ public void testLoss2d() { double delta = 1.0; INDArray absDiff = Transforms.abs(labelsArr.sub(predictionsArr)); INDArray diff = labelsArr.sub(predictionsArr); - INDArray lte = absDiff.lte(delta); - INDArray gt = absDiff.gt(delta); + INDArray lte = absDiff.lte(delta).castTo(DataType.DOUBLE); + INDArray gt = absDiff.gt(delta).castTo(DataType.DOUBLE); expOut = diff.mul(diff).mul(0.5).muli(lte); expOut.addi(absDiff.mul(delta).subi(0.5 * delta * delta).mul(gt)); loss = sd.lossHuber("loss", labels, predictions, w, reduction, delta); @@ -207,7 +228,6 @@ public void testLoss2d() { } } } -// expOut.divi(pairCount); loss = sd.lossMeanPairwiseSquaredError("loss", labels, predictions, w); break; case "softmaxxentlogits": @@ -275,10 +295,16 @@ public void testLoss2d() { loss = loss.sum(); } + boolean doGradCheck = true; + if (OpValidationSuite.IGNORE_FAILING && NO_BP_YET.contains(fn)) { + log.warn("--- Skipping gradient check for: {} ---", fn); + doGradCheck = false; + } + TestCase tc = new TestCase(sd) .expectedOutput("loss", expOut) - .gradientCheck(true) - .testFlatBufferSerialization(TestCase.TestSerialization.NONE) //TODO Re-enable later + .gradientCheck(doGradCheck) + .testFlatBufferSerialization(TestCase.TestSerialization.BOTH) ; String error; @@ -291,11 +317,12 @@ public void testLoss2d() { if (error != null) { failed.add(msg + ": " + error); } + totalRun++; } } } - assertEquals(failed.toString(), 0, failed.size()); + assertEquals(failed.size() + " of " + totalRun + " failed: " + failed.toString(), 0, failed.size()); } @@ -316,4 +343,46 @@ public void testCosineDistance(){ INDArray exp = Nd4j.scalar(0.6); //https://github.com/deeplearning4j/deeplearning4j/issues/6532 assertEquals(exp, out); } + + @Test + public void testL2Loss(){ + + for( int rank=0; rank<=3; rank++ ){ + long[] shape; + switch (rank){ + case 0: + shape = new long[0]; + break; + case 1: + shape = new long[]{5}; + break; + case 2: + shape = new long[]{3,4}; + break; + case 3: + shape = new long[]{2,3,4}; + break; + case 4: + shape = new long[]{2,3,2,3}; + break; + default: + throw new RuntimeException(); + } + INDArray arr = Nd4j.rand(DataType.DOUBLE, shape); + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("v", arr); + SDVariable loss = sd.lossL2("loss", in); + + INDArray exp = arr.mul(arr).sum().muli(0.5); + + TestCase tc = new TestCase(sd) + .expectedOutput("loss", exp) + .gradientCheck(true) + .testFlatBufferSerialization(TestCase.TestSerialization.BOTH); + + String err = OpValidation.validate(tc); + assertNull(err); + } + } } From 57649eaa4b7074b7d4dd446df6755158ee38fa52 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 17 Jan 2019 
18:05:30 +1100 Subject: [PATCH 3/6] Remove SoftmaxCrossEntropyWithLogits as mathematically the same as SoftmaxCrossEntropy (and not used for import) --- .../SoftmaxCrossEntropyWithLogitsLoss.java | 79 ------------------- .../SoftmaxCrossEntropyWithLogitsLossBp.java | 55 ------------- .../opvalidation/LossOpValidation.java | 8 +- 3 files changed, 3 insertions(+), 139 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java deleted file mode 100644 index c5d461e76401..000000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.loss; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.Op; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - - -/** - * Softmax cross entropy loss with Logits - * - * @author Max Pumperla - */ -@NoArgsConstructor -public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { - - protected int classesDim; - - public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) { - super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false); - this.classesDim = classesDim; - addIArgument(classesDim); - } - - @Override - public String opName() { - return "softmax_cross_entropy_loss_with_logits"; - } - - @Override - public String tensorflowName() { - return "SoftmaxCrossEntropyWithLogits"; - } - - @Override - public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), - "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes); - - return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions - } - - @Override - public List doDiff(List grad){ - //No external gradient - //Args: logits, weigths, label - SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim); - return Arrays.asList(grads); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java deleted file mode 100644 index a27df6884937..000000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.loss.bp; - -import lombok.NoArgsConstructor; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ops.DynamicCustomOp; - -import java.util.Collections; -import java.util.List; - - -/** - * Softmax cross entropy loss with Logits - * - * @author Max Pumperla - */ -@NoArgsConstructor -public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp { - - protected int classesDim; - - public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) { - super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false); - this.classesDim = classesDim; - addIArgument(classesDim); - } - - @Override - public String opName() { - return "softmax_cross_entropy_loss_with_logits_grad"; - } - - @Override - public List doDiff(List grad){ - throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java index 278f34da2dac..b0e2935ee3d8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java @@ -59,9 +59,10 @@ public void testLoss2d() { int totalRun = 0; for (String fn : new String[]{"absdiff", "cosine", "hinge", "huber", "log", "mse", - "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "softmaxxentlogits", "sparsesoftmax"}) { + "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", + "sparsesoftmax", "sparsesoftmax_onehot"}) { - if((fn.equals("softmaxxentlogits") || fn.equals("sparsesoftmax")) && OpValidationSuite.IGNORE_FAILING){ + if((fn.equals("sparsesoftmax")) && OpValidationSuite.IGNORE_FAILING){ log.warn("NOT YET IMPLEMENTED: {}", fn); continue; } @@ -229,9 +230,6 @@ public void testLoss2d() { } } loss = sd.lossMeanPairwiseSquaredError("loss", labels, predictions, w); - break; - case "softmaxxentlogits": - break; case "sparsesoftmax": From 92c72d6cee22cf9fdf74507b554da0cba2fb97f7 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 17 Jan 2019 20:12:30 +1100 Subject: [PATCH 4/6] Sparse softmax cross entropy loss function + tests --- .../DifferentialFunctionFactory.java | 8 ++-- .../nd4j/autodiff/samediff/SDVariable.java | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 24 +++++++++++ .../autodiff/validation/GradCheckUtil.java | 4 +- ...arseSoftmaxCrossEntropyLossWithLogits.java | 11 ++++- ...seSoftmaxCrossEntropyLossWithLogitsBp.java | 13 +++++- .../opvalidation/LossOpValidation.java | 40 ++++++++++++++----- 7 files changed, 83 insertions(+), 19 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index b15e74ba4bee..ce6ba607325f 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1392,12 +1392,12 @@ public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logi return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); } - public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable(); + public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ + return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable(); } - public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables(); + public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){ + return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 312f4361249c..e7a2f7ea9c5b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -1667,7 +1667,7 @@ public INDArray eval() { @Override public String toString() { - return varName; + return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + ")"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 2d0bb276c5f3..09dce5d92871 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -1267,6 +1267,7 @@ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) { * @return The differential function that this variable is an output of, or null if it is not the output of a function */ public DifferentialFunction getVariableOutputFunction(String variableName) { + Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName); if(variables.get(variableName).getOutputOfOp() == null) return null; return ops.get(variables.get(variableName).getOutputOfOp()).getOp(); @@ -9054,6 +9055,29 @@ public SDVariable lossSoftmaxCrossEntropy(String name, @NonNull SDVariable oneHo return updateVariableNameAndReference(result, name); } + /** + * See {@link #lossSparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)} + */ + public SDVariable lossSparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) { + return lossSparseSoftmaxCrossEntropy(null, logits, labels); + } + + /** + * As per {@link #lossSoftmaxCrossEntropy(String, 
SDVariable, SDVariable, LossReduce)} but the labels variable + * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1 + * + * @param name Name of the output variable. May be null + * @param logits Logits array ("pre-softmax activations") + * @param labels Labels array. Must be an integer type. + * @return Softmax cross entropy + */ + public SDVariable lossSparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) { + Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", logits); + SDVariable ret = f().lossSparseSoftmaxCrossEntropy(logits, labels); + return updateVariableNameAndReference(ret, name); + } + /** * TODO * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index f4d3aabec507..3b3c0eda1600 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -129,7 +129,9 @@ public static boolean checkGradients(SameDiff sd, Map placehold Set gradVarNames = new HashSet<>(); for(Variable v : sd.getVariables().values()){ if(v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER){ - gradVarNames.add(v.getVariable().getGradient().getVarName()); + SDVariable g = v.getVariable().getGradient(); + Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable()); + gradVarNames.add(g.getVarName()); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index 2fefaea2be79..9c385c42513d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -46,7 +46,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) { - super(null, sameDiff, new SDVariable[]{logits, labels}, false); + super(null, sameDiff, new SDVariable[]{labels, logits}, false); } @@ -87,6 +87,13 @@ public Op.Type opType() { @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions + return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits) + } + + @Override + public List doDiff(List grad){ + //args: label, logits + SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0)); + return Arrays.asList(f().zerosLike(arg(0)), ret[0]); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
index f6983c04007e..13639d94fd63 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
@@ -47,7 +47,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogitsBp extends DynamicCustomOp {
 
     public SparseSoftmaxCrossEntropyLossWithLogitsBp(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
-        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
+        super(null, sameDiff, new SDVariable[]{labels, logits}, false);
     }
 
     @Override
@@ -59,4 +59,15 @@ public String opName() {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(1));   //Same as predictions (logits)
+    }
+
+    @Override
+    public int getNumOutputs(){
+        return 1;
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
index b0e2935ee3d8..380e3c0f474a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
@@ -58,14 +58,10 @@ public void testLoss2d() {
         List<String> failed = new ArrayList<>();
         int totalRun = 0;
-        for (String fn : new String[]{"absdiff", "cosine", "hinge", "huber", "log", "mse",
+        for (String fn : new String[]{
+                "absdiff", "cosine", "hinge", "huber", "log", "mse",
                 "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
-                "sparsesoftmax", "sparsesoftmax_onehot"}) {
-
-            if((fn.equals("sparsesoftmax")) && OpValidationSuite.IGNORE_FAILING){
-                log.warn("NOT YET IMPLEMENTED: {}", fn);
-                continue;
-            }
+                "sparsesoftmax"}) {
 
 
             for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
@@ -83,12 +79,22 @@ public void testLoss2d() {
                     if((fn.equals("softmaxxent") || fn.equals("softmaxxent_smooth")) && reduction == LossReduce.NONE)
                         continue;   //Combination not supported (doesn't make sense)
 
+                    if(fn.equals("sparsesoftmax") && (!weights.equals("none") || reduction != LossReduce.SUM) )
+                        continue;   //sparse softmax doesn't support weights or reduction config
+
                     SameDiff sd = SameDiff.create();
 
                     int nOut = 4;
                     int minibatch = 10;
                     SDVariable predictions = sd.var("in", DataType.DOUBLE, -1, nOut);
-                    SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+                    SDVariable labels;
+                    if("sparsesoftmax".equalsIgnoreCase(fn)){
+                        labels = sd.var("labels", DataType.INT, -1);
+                    } else {
+                        //All other loss functions
+                        labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+                    }
+
                     SDVariable w;
                     INDArray wArrBroadcast;
                     switch (weights){
@@ -232,7 +238,18 @@ public void testLoss2d() {
                     loss = sd.lossMeanPairwiseSquaredError("loss", labels, predictions, w);
                     break;
                 case "sparsesoftmax":
+                    labelsArr = Nd4j.create(DataType.INT, minibatch);
+                    INDArray oneHot
= Nd4j.create(DataType.DOUBLE, minibatch, nOut); + for( int i=0; i Date: Thu, 17 Jan 2019 20:51:28 +1100 Subject: [PATCH 5/6] Re-add softmax cross entropy with logits (actually not identical behaviour to softmax cross entropy) --- .../DifferentialFunctionFactory.java | 8 +++ .../impl/loss/SoftmaxCrossEntropyLoss.java | 4 ++ .../SoftmaxCrossEntropyWithLogitsLoss.java | 72 +++++++++++++++++++ .../SoftmaxCrossEntropyWithLogitsLossBp.java | 52 ++++++++++++++ 4 files changed, 136 insertions(+) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index ce6ba607325f..a48fc530e8db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1392,6 +1392,14 @@ public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logi return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); } + public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { + return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable(); + } + + public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) { + return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables(); + } + public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index 3b12187e307c..e63ba2d9f909 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -78,6 +78,10 @@ public String onnxName() { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); } +// @Override +// public String tensorflowName() { +// return "SoftmaxCrossEntropy"; +// } @Override public String tensorflowName() { return "SoftmaxCrossEntropy"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java new file mode 100644 index 000000000000..b873ca268bb5 --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
@@ -0,0 +1,72 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.loss;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+
+/**
+ * Softmax cross entropy loss with Logits
+ *
+ * @author Max Pumperla
+ */
+@NoArgsConstructor
+public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
+
+    protected int classesDim;
+
+    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+        this.classesDim = classesDim;
+        addIArgument(classesDim);
+    }
+
+    @Override
+    public String opName() {
+        return "softmax_cross_entropy_loss_with_logits";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "SoftmaxCrossEntropyWithLogits";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
+                "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
+
+        return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grad){
+        //No external gradient
+        //Args: logits, weights, label
+        SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim);
+        return Arrays.asList(grads);
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
new file mode 100644
index 000000000000..68cd88f0998b
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
@@ -0,0 +1,52 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.linalg.api.ops.impl.loss.bp;
+
+import lombok.NoArgsConstructor;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.List;
+
+
+/**
+ * Softmax cross entropy loss with Logits
+ *
+ * @author Max Pumperla
+ */
+@NoArgsConstructor
+public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
+
+    protected int classesDim;
+
+    public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+        this.classesDim = classesDim;
+        addIArgument(classesDim);
+    }
+
+    @Override
+    public String opName() {
+        return "softmax_cross_entropy_loss_with_logits_grad";
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grad){
+        throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
+    }
+}

From 42908dedd77e8e96887ae0beb295c7de176da9ab Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Thu, 17 Jan 2019 20:54:01 +1100
Subject: [PATCH 6/6] Small cleanup

---
 .../linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
index e63ba2d9f909..3b12187e307c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java
@@ -78,10 +78,6 @@ public String onnxName() {
         throw new NoOpNameFoundException("No onnx op opName found for " + opName());
     }
 
-//    @Override
-//    public String tensorflowName() {
-//        return "SoftmaxCrossEntropy";
-//    }
     @Override
     public String tensorflowName() {
         return "SoftmaxCrossEntropy";
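
The commits above introduce two SameDiff convenience methods: lossL2(...), defined as 1/2 * sum(x^2), and lossSparseSoftmaxCrossEntropy(...), which takes integer class-index labels of rank N-1 against rank-N logits. The sketch below shows how they might be called, mirroring the patterns used in LossOpValidation and testL2Loss. It is an illustrative example only, not part of the patch series: the class name, shapes and values are arbitrary.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LossConvenienceSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // L2 loss: 1/2 * sum(x^2), as documented on SameDiff.lossL2
        INDArray values = Nd4j.rand(DataType.DOUBLE, 3, 4);
        SDVariable in = sd.var("v", values);
        SDVariable l2 = sd.lossL2("l2Loss", in);

        // Expected value, computed the same way as LossOpValidation.testL2Loss does
        INDArray expected = values.mul(values).sum().muli(0.5);
        System.out.println(l2.eval() + " (expected: " + expected + ")");

        // Sparse softmax cross entropy: labels are integer class indices of rank N-1,
        // not one-hot arrays of the same rank as the logits
        SDVariable logits = sd.var("logits", Nd4j.rand(DataType.DOUBLE, 10, 4));
        SDVariable labels = sd.var("labels", Nd4j.create(DataType.INT, 10));   //All-zero class indices, arbitrary example
        SDVariable sce = sd.lossSparseSoftmaxCrossEntropy("sceLoss", logits, labels);
        System.out.println(sce.eval());
    }
}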