Skip to content

Commit

Permalink
SameDiff loss functions (#6534)
Browse files Browse the repository at this point in the history
* First pass on properly blocking out loss functions

* Cleanup and loss function methods

* First steps for new loss tests

* Build out loss tests; Huber op arg fix

* Javadoc, test fixes

* More test fixes, javadoc pass

* Clean up old code; more tests

* Weighted loss function tests; ignore case with logged issue

* Fix MPWSE loss function (no reduce mode as not supported)

* Improvements to weighted loss tests

* Javadoc and cleanup
  • Loading branch information
AlexDBlack committed Oct 5, 2018
1 parent 6be8fd0 commit 409a706
Show file tree
Hide file tree
Showing 20 changed files with 791 additions and 1,218 deletions.
Expand Up @@ -20,6 +20,7 @@
import lombok.NonNull;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
Expand All @@ -36,6 +37,7 @@
import org.nd4j.linalg.api.ops.impl.indexaccum.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
import org.nd4j.linalg.api.ops.impl.loss.*;
import org.nd4j.linalg.api.ops.impl.scalar.*;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.*;
import org.nd4j.linalg.api.ops.impl.scatter.*;
Expand Down Expand Up @@ -1308,101 +1310,43 @@ public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable
return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable();
}

public SDVariable sigmoidCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
int reductionMode, double labelSmoothing) {
return new SigmoidCrossEntropyLoss(sameDiff(), logits, weights, labels,
reductionMode, labelSmoothing).outputVariable();
/**
 * Absolute difference loss: delegates to the {@code AbsoluteDifferenceLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @return Output variable of the loss op
 */
public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
    final AbsoluteDifferenceLoss lossOp = new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label);
    return lossOp.outputVariable();
}

public SDVariable softmaxCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels,
int reductionMode, double labelSmoothing) {
return new SoftmaxCrossEntropyLoss(sameDiff(), logits, weights, labels,
reductionMode, labelSmoothing).outputVariable();
/**
 * Cosine distance loss: delegates to the {@code CosineDistanceLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @param dimension   Dimension along which the cosine distance is computed
 * @return Output variable of the loss op
 */
public SDVariable lossCosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){
    final CosineDistanceLoss lossOp = new CosineDistanceLoss(sameDiff(), lossReduce, predictions, weights, label, dimension);
    return lossOp.outputVariable();
}

public SDVariable lossBinaryXENT(SDVariable iX,
SDVariable i_y,
int... dimensions) {
throw new UnsupportedOperationException();
/**
 * Hinge loss: delegates to the {@code HingeLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @return Output variable of the loss op
 */
public SDVariable lossHinge(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
    final HingeLoss lossOp = new HingeLoss(sameDiff(), lossReduce, predictions, weights, label);
    return lossOp.outputVariable();
}


/**
 * Cosine similarity loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossCosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossCosineSimilarity is not yet implemented");
}


/**
 * Hinge loss (legacy dimension-based signature). Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this overload is not yet implemented
 */
public SDVariable lossHinge(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossHinge(SDVariable, SDVariable, int...) is not yet implemented");
}


/**
 * KL-divergence loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossKLD(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossKLD is not yet implemented");
}


/**
 * L1 loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossL1(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossL1 is not yet implemented");
}


/**
 * L2 loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossL2(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossL2 is not yet implemented");
}


public SDVariable lossMAE(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Huber loss: delegates to the {@code HuberLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @param delta       Delta parameter passed through to the Huber op
 * @return Output variable of the loss op
 */
public SDVariable lossHuber(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta){
    final HuberLoss lossOp = new HuberLoss(sameDiff(), lossReduce, predictions, weights, label, delta);
    return lossOp.outputVariable();
}


public SDVariable lossMAPE(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Log loss: delegates to the {@code LogLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @param epsilon     Epsilon parameter passed through to the op
 * @return Output variable of the loss op
 */
public SDVariable lossLog(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){
    final LogLoss lossOp = new LogLoss(sameDiff(), lossReduce, predictions, weights, label, epsilon);
    return lossOp.outputVariable();
}


public SDVariable lossMSE(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Mean pairwise squared error loss: delegates to the {@code MeanPairwiseSquaredErrorLoss} op.
 * Note: unlike the other loss methods, no {@code LossReduce} argument is taken
 * (reduction mode is not supported for this op per the commit notes).
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @return Output variable of the loss op
 */
public SDVariable lossMeanPairwiseSquaredError(SDVariable label, SDVariable predictions, SDVariable weights){
    final MeanPairwiseSquaredErrorLoss lossOp = new MeanPairwiseSquaredErrorLoss(sameDiff(), predictions, weights, label);
    return lossOp.outputVariable();
}


public SDVariable lossMCXENT(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Mean squared error loss: delegates to the {@code MeanSquaredErrorLoss} op.
 *
 * @param label       Label variable
 * @param predictions Predictions variable
 * @param weights     Weights variable
 * @param lossReduce  Reduction mode applied by the op
 * @return Output variable of the loss op
 */
public SDVariable lossMeanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){
    final MeanSquaredErrorLoss lossOp = new MeanSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label);
    return lossOp.outputVariable();
}


public SDVariable lossMSLE(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Sigmoid cross-entropy loss: delegates to the {@code SigmoidCrossEntropyLoss} op.
 *
 * @param labels         Label variable
 * @param logits         Logits (pre-sigmoid predictions) variable
 * @param weights        Weights variable
 * @param lossReduce     Reduction mode applied by the op
 * @param labelSmoothing Label smoothing value passed through to the op
 * @return Output variable of the loss op
 */
public SDVariable lossSigmoidCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
    final SigmoidCrossEntropyLoss lossOp = new SigmoidCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
    return lossOp.outputVariable();
}


public SDVariable lossNegativeLogLikelihood(SDVariable iX, SDVariable i_y, int... dimensions) {
throw new UnsupportedOperationException();

/**
 * Softmax cross-entropy loss: delegates to the {@code SoftmaxCrossEntropyLoss} op.
 *
 * @param labels         Label variable
 * @param logits         Logits (pre-softmax predictions) variable
 * @param weights        Weights variable
 * @param lossReduce     Reduction mode applied by the op
 * @param labelSmoothing Label smoothing value passed through to the op
 * @return Output variable of the loss op
 */
public SDVariable lossSoftmaxCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) {
    final SoftmaxCrossEntropyLoss lossOp = new SoftmaxCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing);
    return lossOp.outputVariable();
}


/**
 * Poisson loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossPoisson(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossPoisson is not yet implemented");
}


/**
 * Squared hinge loss. Not yet implemented.
 *
 * @param iX         First input variable
 * @param i_y        Second input variable
 * @param dimensions Dimensions argument (unused - method is unimplemented)
 * @throws UnsupportedOperationException always - this loss function is not yet implemented
 */
public SDVariable lossSquaredHinge(SDVariable iX, SDVariable i_y, int... dimensions) {
    // Named message instead of a bare exception so stack traces identify the missing op
    throw new UnsupportedOperationException("lossSquaredHinge is not yet implemented");
}

/**
 * Affine transform x*W + b: delegates to the {@code XwPlusB} op.
 *
 * @param input   Input variable
 * @param weights Weights variable
 * @param bias    Bias variable
 * @return Output variable of the op
 */
public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) {
    final XwPlusB op = new XwPlusB(sameDiff(), input, weights, bias);
    return op.outputVariable();
}
Expand Down

0 comments on commit 409a706

Please sign in to comment.