diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index b3554f0eb0ea..a48fc530e8db 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -1324,6 +1324,10 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable
return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable();
}
+ public SDVariable lossL2(SDVariable var){
+ return new L2Loss(sameDiff(), var).outputVariable();
+ }
+
public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable();
}
@@ -1396,6 +1400,14 @@ public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVar
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
}
+ public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
+ return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable();
+ }
+
+ public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){
+ return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables();
+ }
+
public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
return new XwPlusB(sameDiff(), input, weights, bias).outputVariable();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
index 312f4361249c..e7a2f7ea9c5b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
@@ -1667,7 +1667,7 @@ public INDArray eval() {
@Override
public String toString() {
- return varName;
+ return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + ")";
}
@Override
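With the richer toString, log and precondition messages that format an SDVariable (as several added checks in this PR do) now identify the variable unambiguously. A minimal sketch of the expected output; the variable name, shape and dtype are illustrative:

```java
SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 3, 4);
System.out.println(x);
// previously printed just:  x
// now prints something like: SDVariable(name="x",variableType=VARIABLE,dtype=FLOAT)
```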
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index fb2cf69d5601..09dce5d92871 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -1267,6 +1267,7 @@ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputFunction(String variableName) {
+ Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if(variables.get(variableName).getOutputOfOp() == null)
return null;
return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
@@ -8709,6 +8710,26 @@ public SDVariable tensorMmul(String name,
return updateVariableNameAndReference(result, name);
}
+ /**
+ * L2 loss: 1/2 * sum(x^2)
+ * @param var Variable to calculate L2 loss of
+ * @return L2 loss
+ */
+ public SDVariable lossL2(@NonNull SDVariable var){
+ return lossL2(null, var);
+ }
+
+ /**
+ * L2 loss: 1/2 * sum(x^2)
+ * @param name Name of the output variable
+ * @param var Variable to calculate L2 loss of
+ * @return L2 loss
+ */
+ public SDVariable lossL2(String name, @NonNull SDVariable var){
+ SDVariable ret = f().lossL2(var);
+ return updateVariableNameAndReference(ret, name);
+ }
+
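A quick usage sketch for the new lossL2 method; the variable names and input array are illustrative, not part of this PR:

```java
SameDiff sd = SameDiff.create();
SDVariable w = sd.var("w", Nd4j.rand(3, 4));
SDVariable l2 = sd.lossL2("l2", w);   // scalar, equal to 0.5 * sum(w^2)
```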
/**
* See {@link #lossAbsoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
*/
@@ -9034,6 +9055,29 @@ public SDVariable lossSoftmaxCrossEntropy(String name, @NonNull SDVariable oneHo
return updateVariableNameAndReference(result, name);
}
+ /**
+ * See {@link #lossSparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)}
+ */
+ public SDVariable lossSparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) {
+ return lossSparseSoftmaxCrossEntropy(null, logits, labels);
+ }
+
+ /**
+ * As per {@link #lossSoftmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce)} but the labels variable
+ * is represented as an integer array instead of the equivalent one-hot array.
+ * i.e., if logits are rank N, then labels have rank N-1
+ *
+ * @param name Name of the output variable. May be null
+ * @param logits Logits array ("pre-softmax activations")
+ * @param labels Labels array. Must be an integer type.
+ * @return Softmax cross entropy
+ */
+ public SDVariable lossSparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) {
+ Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", labels);
+ SDVariable ret = f().lossSparseSoftmaxCrossEntropy(logits, labels);
+ return updateVariableNameAndReference(ret, name);
+ }
+
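Usage sketch for the sparse variant: with rank-2 logits, the labels are a rank-1 integer vector of class indices rather than a one-hot matrix (names and shapes illustrative):

```java
SameDiff sd = SameDiff.create();
SDVariable logits = sd.var("logits", DataType.DOUBLE, -1, 4);   // [minibatch, nOut]
SDVariable labels = sd.var("labels", DataType.INT, -1);         // [minibatch], values in 0..nOut-1
SDVariable loss = sd.lossSparseSoftmaxCrossEntropy("loss", logits, labels);
```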
/**
* TODO
*
@@ -10540,7 +10584,7 @@ protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBuffer
inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx));
}
- log.debug("Own Name: {}", node.getOwnName());
+ log.trace("Own Name: {}", node.getOwnName());
int ownId = id != null ? id : idCounter.incrementAndGet(); //forwardMap.containsKey(node.getOwnName()) ? forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet();
String[] outNames = node.outputVariablesNames();
for(String s : outNames){
@@ -10665,7 +10709,7 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con
List<SDVariable> allVars = variables();
for (SDVariable variable : allVars) {
INDArray arr = variable.getArr();
- log.debug("Exporting variable: [{}]", variable.getVarName());
+ log.trace("Exporting variable: [{}]", variable.getVarName());
//If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
// numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0)
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
index f4d3aabec507..3b3c0eda1600 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java
@@ -129,7 +129,9 @@ public static boolean checkGradients(SameDiff sd, Map placehold
Set<String> gradVarNames = new HashSet<>();
for(Variable v : sd.getVariables().values()){
if(v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER){
- gradVarNames.add(v.getVariable().getGradient().getVarName());
+ SDVariable g = v.getVariable().getGradient();
+ Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
+ gradVarNames.add(g.getVarName());
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
index b6d388d52a2c..cfef2a61b5dd 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java
@@ -34,8 +34,8 @@
@NoArgsConstructor
public class L2Loss extends DynamicCustomOp {
- public L2Loss(SameDiff sameDiff, SDVariable[] args) {
- super(null, sameDiff, args);
+ public L2Loss(SameDiff sameDiff, SDVariable var) {
+ super(sameDiff, new SDVariable[]{var});
}
@Override
@@ -59,4 +59,11 @@ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input datatype must be floating point for %s, got %s", getClass(), inputDataTypes);
return inputDataTypes;
}
+
+ @Override
+ public List<SDVariable> doDiff(List<SDVariable> grad){
+ //L2 loss: L = 1/2 * sum(x_i^2)
+ //dL/dx_i = x_i
+ return Collections.singletonList(f().identity(arg()));
+ }
}
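The doDiff addition follows directly from the definition: for L = 1/2 * sum(x_i^2), dL/dx_i = x_i, so the gradient is simply the input, wrapped in identity() so a distinct gradient variable exists in the graph. A small sanity check with plain INDArray ops (shape assumed, not part of the PR):

```java
INDArray x = Nd4j.rand(3, 4);
double forward = x.mul(x).sumNumber().doubleValue() * 0.5;   // L = 1/2 * sum(x^2)
INDArray backward = x.dup();                                  // dL/dx = x
```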
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
index a69dcfc47897..468224939385 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java
@@ -68,4 +68,12 @@ public String opName() {
public String tensorflowName() {
return "sigmoid_cross_entropy";
}
+
+ @Override
+ public List<SDVariable> doDiff(List<SDVariable> grad){
+ //No external gradient
+ //Args are: predictions, weights, label
+ SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing);
+ return Arrays.asList(grads);
+ }
}
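Note the argument reordering in this doDiff: as the comment states, the forward op holds its args as (predictions, weights, labels), while the bp factory expects labels first, hence arg(2), arg(0), arg(1). A comment-only sketch of the mapping:

```java
// forward op args: arg(0) = predictions, arg(1) = weights, arg(2) = labels
// bp factory call above:
//   lossSigmoidCrossEntropyBp(labels = arg(2), predictions = arg(0), weights = arg(1), lossReduce, labelSmoothing)
```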
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
index c5d461e76401..b873ca268bb5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java
@@ -20,19 +20,12 @@
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.Op;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.Map;
/**
@@ -65,7 +58,7 @@ public String tensorflowName() {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
-
+
return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
index 2fefaea2be79..9c385c42513d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java
@@ -46,7 +46,7 @@
public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
- super(null, sameDiff, new SDVariable[]{logits, labels}, false);
+ super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}
@@ -87,6 +87,13 @@ public Op.Type opType() {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
- return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
+ return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits)
+ }
+
+ @Override
+ public List<SDVariable> doDiff(List<SDVariable> grad){
+ //args: label, logits
+ SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0));
+ return Arrays.asList(f().zerosLike(arg(0)), ret[0]);
}
}
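Two things happen in this file: the constructor now passes (labels, logits) to match the native op's expected input order, and doDiff returns a zeros-like placeholder for the integer labels (no gradient is defined with respect to them) plus the logits gradient from the bp op. A rough sketch of the gradient the bp op is expected to produce for sum reduction, using plain INDArray ops (names, shapes and label values assumed):

```java
int minibatch = 10, nOut = 4;
INDArray logitsArr = Nd4j.rand(minibatch, nOut);
INDArray oneHot = Nd4j.zeros(minibatch, nOut);
for (int i = 0; i < minibatch; i++) {
    oneHot.putScalar(new int[]{i, i % nOut}, 1.0);   // illustrative label i % nOut for example i
}
// dL/dlogits = softmax(logits) - oneHot(labels)
INDArray expLogitsGrad = Transforms.softmax(logitsArr, true).sub(oneHot);
```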
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
index db585136f6aa..f26c97aa704e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/BaseLossBp.java
@@ -40,8 +40,16 @@ public int getNumOutputs(){
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
-
- return Arrays.asList(arg(0).dataType(), arg(1).dataType(), arg(2).dataType());
+ Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input 0 (predictions) must be a floating point type; inputs datatypes are %s for %s",
+ inputDataTypes, getClass());
+ DataType dt0 = inputDataTypes.get(0);
+ DataType dt1 = arg(1).dataType();
+ DataType dt2 = arg(2).dataType();
+ if(!dt1.isFPType())
+ dt1 = dt0;
+ if(!dt2.isFPType())
+ dt2 = dt0;
+ return Arrays.asList(dt0, dt1, dt2);
}
@Override
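The reworked calculateOutputDataTypes keeps all three returned gradient dtypes in floating point: predictions must already be FP, and the label/weight gradient dtypes fall back to the predictions dtype whenever their own dtype is not FP (e.g. integer labels for the sparse softmax case). An illustrative mapping, not taken from the PR:

```java
// (predictions, weights, labels) input dtypes -> output gradient dtypes
// (DOUBLE, DOUBLE, DOUBLE) -> [DOUBLE, DOUBLE, DOUBLE]
// (FLOAT,  FLOAT,  INT)    -> [FLOAT,  FLOAT,  FLOAT]   // non-FP label dtype replaced by predictions dtype
```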
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
index a27df6884937..68cd88f0998b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SoftmaxCrossEntropyWithLogitsLossBp.java
@@ -19,11 +19,8 @@
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
-import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import java.util.Collections;
import java.util.List;
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
index f6983c04007e..13639d94fd63 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/SparseSoftmaxCrossEntropyLossWithLogitsBp.java
@@ -47,7 +47,7 @@
public class SparseSoftmaxCrossEntropyLossWithLogitsBp extends DynamicCustomOp {
public SparseSoftmaxCrossEntropyLossWithLogitsBp(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
- super(null, sameDiff, new SDVariable[]{logits, labels}, false);
+ super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}
@Override
@@ -59,4 +59,15 @@ public String opName() {
public List<SDVariable> doDiff(List<SDVariable> grad){
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
}
+
+ @Override
+ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+ Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
+ return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits)
+ }
+
+ @Override
+ public int getNumOutputs(){
+ return 1;
+ }
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
index 0f7c552e6064..380e3c0f474a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LossOpValidation.java
@@ -33,10 +33,10 @@
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.*;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
@Slf4j
public class LossOpValidation extends BaseOpValidation {
@@ -44,16 +44,26 @@ public LossOpValidation(Nd4jBackend backend) {
super(backend);
}
+ public static final Set<String> NO_BP_YET = new HashSet<>();
+ static {
+ NO_BP_YET.addAll(Arrays.asList("hinge", "huber", "l2_loss", "poisson", "mpwse"));
+ }
+
@Test
public void testLoss2d() {
- OpValidationSuite.ignoreFailing(); //2019/01/17 - WIP, some passing, some failing
+ OpValidationSuite.ignoreFailing(); //2019/01/17 - Some passing, some not yet implemented, some issues: Issue 17 https://github.com/deeplearning4j/deeplearning4j/issues/6958
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
- for (String fn : new String[]{"absdiff", "cosine", "hinge", "huber", "log", "mse",
- "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse", "softmaxxentlogits", "sparsesoftmax"}) {
+ int totalRun = 0;
+ for (String fn : new String[]{
+ "absdiff", "cosine", "hinge", "huber", "log", "mse",
+ "sigmoidxent", "sigmoidxent_smooth", "softmaxxent", "softmaxxent_smooth", "mpwse",
+ "sparsesoftmax"}) {
+
+
for(String weights : new String[]{"none", "scalar", "perExample", "perOutput"}) {
if((fn.startsWith("softmax") || fn.equals("cosine")) && weights.equals("perOutput"))
continue; //Skip this combination (not possible)
@@ -66,12 +76,25 @@ public void testLoss2d() {
if(fn.equals("mpwse") && (reduction != LossReduce.MEAN_BY_WEIGHT || weights.equals("perOutput"))) //LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT)
continue; //MPWSE only provides scalar output - i.e., no other reduction modes. And only none/scalar/per-example weights
+ if((fn.equals("softmaxxent") || fn.equals("softmaxxent_smooth")) && reduction == LossReduce.NONE)
+ continue; //Combination not supported (doesn't make sense)
+
+ if(fn.equals("sparsesoftmax") && (!weights.equals("none") || reduction != LossReduce.SUM) )
+ continue; //sparse softmax doesn't support weights or reduction config
+
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable predictions = sd.var("in", DataType.DOUBLE, -1, nOut);
- SDVariable labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+ SDVariable labels;
+ if("sparsesoftmax".equalsIgnoreCase(fn)){
+ labels = sd.var("labels", DataType.INT, -1);
+ } else {
+ //All other loss functions
+ labels = sd.var("labels", DataType.DOUBLE, -1, nOut);
+ }
+
SDVariable w;
INDArray wArrBroadcast;
switch (weights){
@@ -84,7 +107,12 @@ public void testLoss2d() {
wArrBroadcast = Nd4j.valueArrayOf(minibatch, nOut, 1.0).castTo(DataType.DOUBLE);
break;
case "perExample":
- w = sd.var("weights", Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4}).reshape(minibatch, 1));
+ INDArray wpe = Nd4j.create(new double[]{0,0,1,1,2,2,3,3,4,4});
+ if(!fn.equals("softmaxxent") && !fn.equals("softmaxxent_smooth")){
+ //Softmaxxent only supports rank 1 not rank 2??
+ wpe = wpe.reshape(minibatch, 1);
+ }
+ w = sd.var("weights", wpe);
wArrBroadcast = Nd4j.create(DataType.DOUBLE, minibatch, nOut).addiColumnVector(w.getArr());
break;
case "perOutput":
@@ -131,8 +159,8 @@ public void testLoss2d() {
double delta = 1.0;
INDArray absDiff = Transforms.abs(labelsArr.sub(predictionsArr));
INDArray diff = labelsArr.sub(predictionsArr);
- INDArray lte = absDiff.lte(delta);
- INDArray gt = absDiff.gt(delta);
+ INDArray lte = absDiff.lte(delta).castTo(DataType.DOUBLE);
+ INDArray gt = absDiff.gt(delta).castTo(DataType.DOUBLE);
expOut = diff.mul(diff).mul(0.5).muli(lte);
expOut.addi(absDiff.mul(delta).subi(0.5 * delta * delta).mul(gt));
loss = sd.lossHuber("loss", labels, predictions, w, reduction, delta);
@@ -207,14 +235,21 @@ public void testLoss2d() {
}
}
}
-// expOut.divi(pairCount);
loss = sd.lossMeanPairwiseSquaredError("loss", labels, predictions, w);
- break;
- case "softmaxxentlogits":
-
break;
case "sparsesoftmax":
+ labelsArr = Nd4j.create(DataType.INT, minibatch);
+ INDArray oneHot = Nd4j.create(DataType.DOUBLE, minibatch, nOut);
+ for( int i=0; i