[WIP] More loss function implementation and tests #7022

Merged: 6 commits, Jan 17, 2019
@@ -1324,6 +1324,10 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable
return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable();
}

public SDVariable lossL2(SDVariable var){
return new L2Loss(sameDiff(), var).outputVariable();
}

public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable();
}
@@ -1396,6 +1400,14 @@ public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVar
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
}

public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable();
}

public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){
return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables();
}


public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
return new XwPlusB(sameDiff(), input, weights, bias).outputVariable();
@@ -1667,7 +1667,7 @@ public INDArray eval() {

@Override
public String toString() {
return varName;
return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + ")";
}

@Override
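For illustration (variable name and dtype assumed, not taken from the PR), the change swaps the bare variable name for a richer description:

    // before: "x"
    // after:  "SDVariable(name=\"x\",variableType=VARIABLE,dtype=FLOAT)"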
@@ -1267,6 +1267,7 @@ public void addArgsFor(SDVariable[] variables, DifferentialFunction function) {
* @return The differential function that this variable is an output of, or null if it is not the output of a function
*/
public DifferentialFunction getVariableOutputFunction(String variableName) {
Preconditions.checkState(variables.containsKey(variableName), "No variable with name \"%s\" found in graph", variableName);
if(variables.get(variableName).getOutputOfOp() == null)
return null;
return ops.get(variables.get(variableName).getOutputOfOp()).getOp();
@@ -8709,6 +8710,26 @@ public SDVariable tensorMmul(String name,
return updateVariableNameAndReference(result, name);
}

/**
* L2 loss: 1/2 * sum(x^2)
* @param var Variable to calculate L2 loss of
* @return L2 loss
*/
public SDVariable lossL2(@NonNull SDVariable var){
return lossL2(null, var);
}

/**
* L2 loss: 1/2 * sum(x^2)
* @param name Name of the output variable
* @param var Variable to calculate L2 loss of
* @return L2 loss
*/
public SDVariable lossL2(String name, @NonNull SDVariable var){
SDVariable ret = f().lossL2(var);
return updateVariableNameAndReference(ret, name);
}
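
For orientation, a minimal usage sketch of the new method (illustrative only; the array values, variable names, and imports are assumptions, not taken from this PR):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.create(new double[]{1, 2, 3}));
    SDVariable loss = sd.lossL2("l2", x);   // 0.5 * (1 + 4 + 9) = 7.0
    INDArray out = loss.eval();             // scalar holding 7.0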

/**
* See {@link #lossAbsoluteDifference(String, SDVariable, SDVariable, SDVariable, LossReduce)}.
*/
@@ -9034,6 +9055,29 @@ public SDVariable lossSoftmaxCrossEntropy(String name, @NonNull SDVariable oneHo
return updateVariableNameAndReference(result, name);
}

/**
* See {@link #lossSparseSoftmaxCrossEntropy(String, SDVariable, SDVariable)}
*/
public SDVariable lossSparseSoftmaxCrossEntropy(@NonNull SDVariable logits, @NonNull SDVariable labels) {
return lossSparseSoftmaxCrossEntropy(null, logits, labels);
}

/**
* As per {@link #lossSoftmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce)} but the labels variable
* is represented as an integer array instead of the equivalent one-hot array.<br>
* i.e., if logits are rank N, then labels have rank N-1
*
* @param name Name of the output variable. May be null
* @param logits Logits array ("pre-softmax activations")
* @param labels Labels array. Must be an integer type.
* @return Softmax cross entropy
*/
public SDVariable lossSparseSoftmaxCrossEntropy(String name, @NonNull SDVariable logits, @NonNull SDVariable labels) {
        Preconditions.checkState(labels.dataType().isIntType(), "Labels variable must be an integer type: got %s", labels.dataType());
SDVariable ret = f().lossSparseSoftmaxCrossEntropy(logits, labels);
return updateVariableNameAndReference(ret, name);
}
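
A hedged usage sketch of the new sparse variant: logits of shape [minibatch, numClasses] paired with integer class-index labels of shape [minibatch]. The shapes, values, and the sd.constant / Nd4j.createFromArray helpers are assumptions (they exist in current ND4J but may differ in this PR's snapshot):

    SameDiff sd = SameDiff.create();
    SDVariable logits = sd.var("logits", Nd4j.rand(4, 3));                       // [minibatch=4, numClasses=3]
    SDVariable labels = sd.constant("labels", Nd4j.createFromArray(0, 2, 1, 2)); // rank N-1, integer class indices
    SDVariable loss = sd.lossSparseSoftmaxCrossEntropy("loss", logits, labels);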

/**
* TODO
*
@@ -10540,7 +10584,7 @@ protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull FlatBuffer
inPaired.add(IntPair.createIntPair(bufferBuilder, nodeId, outIdx));
}

log.debug("Own Name: {}", node.getOwnName());
log.trace("Own Name: {}", node.getOwnName());
int ownId = id != null ? id : idCounter.incrementAndGet(); //forwardMap.containsKey(node.getOwnName()) ? forwardMap.get(node.getOwnName()) : idCounter.incrementAndGet();
String[] outNames = node.outputVariablesNames();
for(String s : outNames){
@@ -10665,7 +10709,7 @@ public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration con
List<SDVariable> allVars = variables();
for (SDVariable variable : allVars) {
INDArray arr = variable.getArr();
log.debug("Exporting variable: [{}]", variable.getVarName());
log.trace("Exporting variable: [{}]", variable.getVarName());

//If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
// numbers. For example, unstack(x) -> y0, y1, y2 -> the y's should be say (3,0), (3,1), (3,2) NOT (4,0), (5,0), (6,0)
@@ -129,7 +129,9 @@ public static boolean checkGradients(SameDiff sd, Map<String,INDArray> placehold
Set<String> gradVarNames = new HashSet<>();
for(Variable v : sd.getVariables().values()){
if(v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER){
gradVarNames.add(v.getVariable().getGradient().getVarName());
SDVariable g = v.getVariable().getGradient();
Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
gradVarNames.add(g.getVarName());
}
}

@@ -34,8 +34,8 @@
@NoArgsConstructor
public class L2Loss extends DynamicCustomOp {

public L2Loss(SameDiff sameDiff, SDVariable[] args) {
super(null, sameDiff, args);
public L2Loss(SameDiff sameDiff, SDVariable var) {
super(sameDiff, new SDVariable[]{var});
}

@Override
@@ -59,4 +59,11 @@ public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input datatype must be floating point for %s, got %s", getClass(), inputDataTypes);
return inputDataTypes;
}

@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
//L2 loss: L = 1/2 * sum(x_i^2)
//dL/dxi = xi
return Collections.singletonList(f().identity(arg()));
}
}
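
For reference, the derivative used in doDiff follows directly from the loss definition (a worked step, not part of the diff):

    // L = 1/2 * sum_i x_i^2
    // dL/dx_j = 1/2 * 2 * x_j = x_j
    // so the gradient w.r.t. the input is the input itself, hence identity(arg()).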
@@ -68,4 +68,12 @@ public String opName() {
public String tensorflowName() {
return "sigmoid_cross_entropy";
}

@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
//No external gradient
//Args are: predictions, weights, label
SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing);
return Arrays.asList(grads);
}
}
@@ -20,19 +20,12 @@
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;


/**
@@ -65,7 +58,7 @@ public String tensorflowName() {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);

return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
}

@@ -46,7 +46,7 @@
public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {

public SparseSoftmaxCrossEntropyLossWithLogits(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}


@@ -87,6 +87,13 @@ public Op.Type opType() {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0)); //Same as predictions
return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits)
}

@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
//args: label, logits
SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0));
return Arrays.asList(f().zerosLike(arg(0)), ret[0]);
}
}
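
As a reading aid (not part of the diff): doDiff must return one gradient per input, in the op's input order, which after this change is {labels, logits}.

    // gradients returned by doDiff, in input order {labels, logits}:
    //   labels -> zerosLike(arg(0)); integer class indices carry no meaningful gradient
    //   logits -> ret[0], the output of the SparseSoftmaxCrossEntropyLossWithLogitsBp op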
@@ -40,8 +40,16 @@ public int getNumOutputs(){

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){

return Arrays.asList(arg(0).dataType(), arg(1).dataType(), arg(2).dataType());
Preconditions.checkState(inputDataTypes.get(0).isFPType(), "Input 0 (predictions) must be a floating point type; inputs datatypes are %s for %s",
inputDataTypes, getClass());
DataType dt0 = inputDataTypes.get(0);
DataType dt1 = arg(1).dataType();
DataType dt2 = arg(2).dataType();
if(!dt1.isFPType())
dt1 = dt0;
if(!dt2.isFPType())
dt2 = dt0;
return Arrays.asList(dt0, dt1, dt2);
}

@Override
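A worked example of the fallback above, with dtypes assumed purely for illustration:

    // inputs (predictions, weights, labels) = (FLOAT, DOUBLE, INT)
    // dt0 = FLOAT  (floating point, passes the precondition)
    // dt1 = DOUBLE (already FP, kept as-is)
    // dt2 = INT    (not FP, falls back to dt0) -> FLOAT
    // resulting output gradient dtypes: [FLOAT, DOUBLE, FLOAT]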
@@ -19,11 +19,8 @@
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

import java.util.Collections;
import java.util.List;


@@ -47,7 +47,7 @@
public class SparseSoftmaxCrossEntropyLossWithLogitsBp extends DynamicCustomOp {

public SparseSoftmaxCrossEntropyLossWithLogitsBp(SameDiff sameDiff, SDVariable logits, SDVariable labels) {
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
super(null, sameDiff, new SDVariable[]{labels, logits}, false);
}

@Override
@@ -59,4 +59,15 @@ public String opName() {
public List<SDVariable> doDiff(List<SDVariable> grad){
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
}

@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits)
}

@Override
public int getNumOutputs(){
return 1;
}
}
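
Editorial note (an observation from the diff above, not a change): declaring a single output here lines up with the forward op's doDiff, which consumes only ret[0] as the logits gradient and synthesizes the labels gradient as zerosLike.

    // getNumOutputs() == 1: only dL/dlogits is produced by this bp op;
    // dL/dlabels is handled as zerosLike in the forward op's doDiff.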