diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index b3554f0eb0ea..a48fc530e8db 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -1324,6 +1324,10 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable
         return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable();
     }
 
+    public SDVariable lossL2(SDVariable var){
+        return new L2Loss(sameDiff(), var).outputVariable();
+    }
+
     public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
         return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable();
     }
@@ -1396,6 +1400,14 @@ public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVar
         return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
     }
 
+    public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
+        return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable();
+    }
+
+    public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){
+        return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables();
+    }
+
     public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
         return new XwPlusB(sameDiff(), input, weights, bias).outputVariable();
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
index 312f4361249c..e7a2f7ea9c5b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
@@ -1667,7 +1667,7 @@ public INDArray eval() {
 
     @Override
     public String toString() {
-        return varName;
+        return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + ")";
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index fb2cf69d5601..09dce5d92871 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -1267,6 +1267,7 @@ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
      * @return The differential function that this variable is an output of, or null if it is not the output of a function
      */
    public DifferentialFunction getVariableOutputFunction(String variableName) {
+        Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
        if(variables.get(variableName).getOutputOfOp() == null)
            return null;
        return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
@@ -8709,6 +8710,26 @@ public SDVariable tensorMmul(String name,
         return updateVariableNameAndReference(result, name);
     }
 
+    /**
+     * L2 loss: 1/2 * sum(x^2)
+     * @param var Variable to calculate L2 loss of
+     * @return L2 loss
+     */
+    public SDVariable lossL2(@NonNull SDVariable var){
+        return lossL2(null, var);
+    }
+
+    /**
+     * L2 loss: 1/2 * sum(x^2)
+     * @param name Name of the output variable
+     * @param var  Variable to calculate L2 loss of
+     * @return L2 loss
+     */
+    public SDVariable lossL2(String name, @NonNull SDVariable var){
+        SDVariable ret = f().lossL2(var);
+        return updateVariableNameAndReference(ret, name);
+    }
+
     /**
      * See {@link #lossAbsoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
      */
@@ -9034,6 +9055,29 @@ public SDVariable lossSoftmaxCrossEntropy(String name, @NonNull SDVariable oneHo
         return updateVariableNameAndReference(result, name);
     }
 
+    /**
+     * See {@link #lossSparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)}
+     */
+    public SDVariable lossSparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) {
+        return lossSparseSoftmaxCrossEntropy(null, logits, labels);
+    }
+
+    /**
+     * As per {@link #lossSoftmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce)} but the labels variable
+     * is represented as an integer array instead of the equivalent one-hot array.
+     * i.e., if logits are rank N, then labels have rank N-1
+     *
+     * @param name   Name of the output variable. May be null
+     * @param logits Logits array ("pre-softmax activations")
+     * @param labels Labels array. Must be an integer type.
+     * @return Softmax cross entropy
+     */
+    public SDVariable lossSparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) {
+        Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", logits);
+        SDVariable ret = f().lossSparseSoftmaxCrossEntropy(logits, labels);
+        return updateVariableNameAndReference(ret, name);
+    }
+
     /**
      * TODO
      *
@@ -10540,7 +10584,7 @@ protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBuffer
             inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx));
         }
 
-        log.debug("Own Name: {}", node.getOwnName());
+        log.trace("Own Name: {}", node.getOwnName());
         int ownId = id != null ? id : idCounter.incrementAndGet();  //forwardMap.containsKey(node.getOwnName()) ? forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet();
         String[] outNames = node.outputVariablesNames();
         for(String s : outNames){
@@ -10665,7 +10709,7 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con
         List<SDVariable> allVars = variables();
         for (SDVariable variable : allVars) {
             INDArray arr = variable.getArr();
-            log.debug("Exporting variable: [{}]", variable.getVarName());
+            log.trace("Exporting variable: [{}]", variable.getVarName());
 
             //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
             // numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0)
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
index f4d3aabec507..3b3c0eda1600 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
@@ -129,7 +129,9 @@ public static boolean checkGradients(SameDiff sd, Map<String,INDArray> placehold
         Set<String> gradVarNames = new HashSet<>();
         for(Variable v : sd.getVariables().values()){
             if(v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER){
-                gradVarNames.add(v.getVariable().getGradient().getVarName());
+                SDVariable g = v.getVariable().getGradient();
+                Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
+                gradVarNames.add(g.getVarName());
             }
         }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
index b6d388d52a2c..cfef2a61b5dd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
@@ -34,8 +34,8 @@
 @NoArgsConstructor
 public class L2Loss extends DynamicCustomOp {
 
-    public L2Loss(SameDiff sameDiff, SDVariable[] args) {
-        super(null, sameDiff, args);
+    public L2Loss(SameDiff sameDiff, SDVariable var) {
+        super(sameDiff, new SDVariable[]{var});
     }
 
     @Override
@@ -59,4 +59,11 @@ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input datatype must be floating point for %s, got %s", getClass(), inputDataTypes);
         return inputDataTypes;
     }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grad){
+        //L2 loss: L = 1/2 * sum(x_i^2)
+        //dL/dxi = xi
+        return Collections.singletonList(f().identity(arg()));
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
index a69dcfc47897..468224939385 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
@@ -68,4 +68,12 @@ public String opName() {
     public String tensorflowName() {
         return "sigmoid_cross_entropy";
     }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grad){
+        //No external gradient
+        //Args are: predictions, weights, label
+        SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing);
+        return Arrays.asList(grads);
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
index c5d461e76401..b873ca268bb5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
@@ -20,19 +20,12 @@
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.Op;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
 
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 
 
 /**
@@ -65,7 +58,7 @@ public String tensorflowName() {
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                 "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
-
+
         return Collections.singletonList(inputDataTypes.get(0));   //Same as predictions
     }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
index 2fefaea2be79..9c385c42513d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
@@ -46,7 +46,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
 
     public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
-        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
+        super(null, sameDiff, new SDVariable[]{labels, logits}, false);
     }
 
 
@@ -87,6 +87,13 @@ public Op.Type opType() {
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
-        return Collections.singletonList(inputDataTypes.get(0));   //Same as predictions
+        return Collections.singletonList(inputDataTypes.get(1));   //Same as predictions (logits)
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> grad){
+        //args: label, logits
+        SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0));
+        return Arrays.asList(f().zerosLike(arg(0)), ret[0]);
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
index db585136f6aa..f26c97aa704e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
@@ -40,8 +40,16 @@ public int getNumOutputs(){
 
     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
-
-        return Arrays.asList(arg(0).dataType(), arg(1).dataType(), arg(2).dataType());
+        Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input 0 (predictions) must be a floating point type; inputs datatypes are %s for %s",
+                inputDataTypes, getClass());
+        DataType dt0 = inputDataTypes.get(0);
+        DataType dt1 = arg(1).dataType();
+        DataType dt2 = arg(2).dataType();
+        if(!dt1.isFPType())
+            dt1 = dt0;
+        if(!dt2.isFPType())
+            dt2 = dt0;
+        return Arrays.asList(dt0, dt1, dt2);
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
index a27df6884937..68cd88f0998b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
@@ -19,11 +19,8 @@
 import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
-import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
-import java.util.Collections;
 import java.util.List;
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
index f6983c04007e..13639d94fd63 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
@@ -47,7 +47,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogitsBp extends DynamicCustomOp {
 
     public SparseSoftmaxCrossEntropyLossWithLogitsBp(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
-        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
+        super(null, sameDiff, new SDVariable[]{labels, logits}, false);
     }
 
     @Override
@@ -59,4 +59,15 @@ public String opName() {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(1));   //Same as predictions (logits)
+    }
+
+    @Override
+    public int getNumOutputs(){
+        return 1;
+    }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
index 0f7c552e6064..380e3c0f474a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
@@ -33,10 +33,10 @@
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 
 @Slf4j
 public class LossOpValidation extends BaseOpValidation {
@@ -44,16 +44,26 @@ public LossOpValidation(Nd4jBackend backend) {
         super(backend);
     }
 
+    public static final Set<String> NO_BP_YET = new HashSet<>();
+    static {
+        NO_BP_YET.addAll(Arrays.asList("hinge", "huber", "l2_loss", "poisson", "mpwse"));
+    }
+
     @Test
     public void testLoss2d() {
-        OpValidationSuite.ignoreFailing();  //2019/01/17 - WIP, some passing, some failing
+        OpValidationSuite.ignoreFailing();  //2019/01/17 - Some passing, some not yet implemented, some issues: Issue 17 https://github.com/deeplearning4j/deeplearning4j/issues/6958
         Nd4j.getRandom().setSeed(12345);
 
         List<String> failed = new ArrayList<>();
 
-        for (String fn : new String[]{"absdiff", "cosine", "hinge", "huber", "log", "mse",
-                "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "softmaxxentlogits", "sparsesoftmax"}) {
+        int totalRun = 0;
+        for (String fn : new String[]{
+                "absdiff", "cosine", "hinge", "huber", "log", "mse",
+                "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
+                "sparsesoftmax"}) {
+
+
             for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
                 if((fn.startsWith("softmax") || fn.equals("cosine")) && weights.equals("perOutput"))
                     continue;   //Skip this combination (not possible)
@@ -66,12 +76,25 @@ public void testLoss2d() {
                     if(fn.equals("mpwse") && (reduction != LossReduce.MEAN_BY_WEIGHT || weights.equals("perOutput"))) //LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT)
                         continue;   //MPWSE only provides scalar output - i.e., no other reduction modes. And only none/scalar/per-example weights
 
+                    if((fn.equals("softmaxxent") || fn.equals("softmaxxent_smooth")) && reduction == LossReduce.NONE)
+                        continue;   //Combination not supported (doesn't make sense)
+
+                    if(fn.equals("sparsesoftmax") && (!weights.equals("none") || reduction != LossReduce.SUM) )
+                        continue;   //sparse softmax doesn't support weights or reduction config
+
                     SameDiff sd = SameDiff.create();
 
                     int nOut = 4;
                     int minibatch = 10;
                     SDVariable predictions = sd.var("in", DataType.DOUBLE, -1, nOut);
-                    SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+                    SDVariable labels;
+                    if("sparsesoftmax".equalsIgnoreCase(fn)){
+                        labels = sd.var("labels", DataType.INT, -1);
+                    } else {
+                        //All other loss functions
+                        labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+                    }
+
                     SDVariable w;
                     INDArray wArrBroadcast;
                     switch (weights){
@@ -84,7 +107,12 @@ public void testLoss2d() {
                            wArrBroadcast = Nd4j.valueArrayOf(minibatch, nOut, 1.0).castTo(DataType.DOUBLE);
                            break;
                        case "perExample":
-                            w = sd.var("weights", Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4}).reshape(minibatch, 1));
+                            INDArray wpe = Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4});
+                            if(!fn.equals("softmaxxent") && !fn.equals("softmaxxent_smooth")){
+                                //Softmaxxent only supports rank 1 not rank 2??
+                                wpe = wpe.reshape(minibatch, 1);
+                            }
+                            w = sd.var("weights", wpe);
                            wArrBroadcast = Nd4j.create(DataType.DOUBLE, minibatch, nOut).addiColumnVector(w.getArr());
                            break;
                        case "perOutput":
@@ -131,8 +159,8 @@ public void testLoss2d() {
                            double delta = 1.0;
                            INDArray absDiff = Transforms.abs(labelsArr.sub(predictionsArr));
                            INDArray diff = labelsArr.sub(predictionsArr);
-                            INDArray lte = absDiff.lte(delta);
-                            INDArray gt = absDiff.gt(delta);
+                            INDArray lte = absDiff.lte(delta).castTo(DataType.DOUBLE);
+                            INDArray gt = absDiff.gt(delta).castTo(DataType.DOUBLE);
                            expOut = diff.mul(diff).mul(0.5).muli(lte);
                            expOut.addi(absDiff.mul(delta).subi(0.5 * delta * delta).mul(gt));
                            loss = sd.lossHuber("loss", labels, predictions, w, reduction, delta);
@@ -207,14 +235,21 @@ public void testLoss2d() {
                                    }
                                }
                            }
-//                            expOut.divi(pairCount);
                            loss = sd.lossMeanPairwiseSquaredError("loss", labels, predictions, w);
-                            break;
-                        case "softmaxxentlogits":
                            break;
                        case "sparsesoftmax":
+                            labelsArr = Nd4j.create(DataType.INT, minibatch);
+                            INDArray oneHot = Nd4j.create(DataType.DOUBLE, minibatch, nOut);
+                            for( int i=0; i
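A minimal usage sketch of the two SameDiff methods added above, lossL2 and lossSparseSoftmaxCrossEntropy. The class name, variable names, minibatch size and class count are illustrative only; the calls mirror the signatures added in the SameDiff.java hunks, including the requirement that sparse labels be an integer type with rank one less than the logits.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class SparseLossUsageSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();

            // Logits: [minibatch, nOut]; sparse labels: [minibatch] integer class indices (rank N-1)
            SDVariable logits = sd.var("logits", DataType.DOUBLE, -1, 4);
            SDVariable labels = sd.var("labels", DataType.INT, -1);

            // Sparse softmax cross entropy on integer labels, as added in this patch
            SDVariable xent = sd.lossSparseSoftmaxCrossEntropy("xent", logits, labels);

            // L2 loss of the logits: 1/2 * sum(x^2)
            SDVariable l2 = sd.lossL2("l2", logits);
        }
    }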
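The identity gradient added in L2Loss.doDiff follows directly from the formula in its comments: L = 1/2 * sum(x_i^2), so dL/dx_i = x_i. A quick INDArray check of that arithmetic; the class name and the example values are arbitrary.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class L2LossArithmeticSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(new double[]{1.0, -2.0, 3.0});

            // L = 1/2 * sum(x_i^2) = 0.5 * (1 + 4 + 9) = 7.0
            double l2 = x.mul(x).sumNumber().doubleValue() / 2.0;

            // dL/dx_i = x_i, i.e. the gradient is simply a copy (identity) of the input
            INDArray grad = x.dup();

            System.out.println("L2 loss = " + l2 + ", gradient = " + grad);
        }
    }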
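Purely as an illustration of the relationship described in the lossSparseSoftmaxCrossEntropy javadoc (rank-1 integer labels versus the equivalent rank-2 one-hot array), a hypothetical standalone snippet is sketched below. It is not the elided test code; the class name, label values and sizes are made up for the example.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class OneHotLabelSketch {
        public static void main(String[] args) {
            int minibatch = 4;
            int nOut = 4;

            // Rank-1 integer labels: one class index per example
            INDArray sparseLabels = Nd4j.create(DataType.INT, minibatch);
            sparseLabels.putScalar(0, 2);
            sparseLabels.putScalar(1, 0);
            sparseLabels.putScalar(2, 3);
            sparseLabels.putScalar(3, 1);

            // Equivalent rank-2 one-hot labels, as consumed by the dense softmax cross entropy loss
            INDArray oneHot = Nd4j.create(DataType.DOUBLE, minibatch, nOut);
            for (int i = 0; i < minibatch; i++) {
                oneHot.putScalar(i, sparseLabels.getInt(i), 1.0);
            }
        }
    }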