Skip to content

Commit

Permalink
Merge pull request #4461 from deeplearning4j/ab_4459_roc
Browse files Browse the repository at this point in the history
#4459 Fix ROC merging for exact mode
  • Loading branch information
AlexDBlack committed Jan 3, 2018
2 parents dbc0578 + ef71b0a commit ee42dbc
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 16 deletions.
Expand Up @@ -410,8 +410,7 @@ public void testROCMerging() {
int minibatch = 64;
int nROCs = 3;

for (int steps : new int[] {0}) { //0 steps: exact
// int steps = 20;
for (int steps : new int[] {0, 20}) { //0 steps: exact, 20 steps: thresholded

Nd4j.getRandom().setSeed(12345);
Random r = new Random(12345);
Expand Down Expand Up @@ -450,6 +449,50 @@ public void testROCMerging() {
}
}

/**
 * Checks that merging several exact-mode ({@code steps == 0}) ROC instances gives the same
 * result as evaluating every minibatch on one ROC instance.
 * <p>
 * The partial ROC instances are built with a small allocation block size (10), presumably so
 * that their internal storage has to grow during eval/merge - TODO confirm against ROC.
 * Each minibatch is fed to the reference ROC and to one of the partial ROCs (round-robin);
 * the partials are then merged into the first one and compared against the reference on
 * both AUC and the full ROC curve.
 */
@Test
public void testROCMerging2() {
    final int steps = 0; //Exact
    final int nROCs = 3;
    final int exactAllocBlockSize = 10;
    final int minibatch = 64;
    final int nArrays = 10;

    // Fixed seeds for both the ND4J RNG (predictions) and the Java RNG (labels)
    Nd4j.getRandom().setSeed(12345);
    Random labelRng = new Random(12345);

    // ROC instances that each see only a share of the minibatches
    List<ROC> partials = new ArrayList<>();
    while (partials.size() < nROCs) {
        partials.add(new ROC(steps, true, exactAllocBlockSize));
    }

    // ROC instance that sees every minibatch
    ROC reference = new ROC(steps);
    for (int batch = 0; batch < nArrays; batch++) {
        // Random predictions, each row divided by its own sum
        INDArray probs = Nd4j.rand(minibatch, 2);
        INDArray rowSums = probs.sum(1);
        probs.diviColumnVector(rowSums);

        // Labels: a single 1.0 per row, in a randomly chosen column
        INDArray labels = Nd4j.zeros(minibatch, 2);
        for (int row = 0; row < minibatch; row++) {
            int hotColumn = labelRng.nextInt(2);
            labels.putScalar(row, hotColumn, 1.0);
        }

        reference.eval(labels, probs);

        // Round-robin assignment of this minibatch to one partial ROC
        partials.get(batch % partials.size()).eval(labels, probs);
    }

    // Fold partials 1..n-1 into partial 0
    ROC merged = partials.get(0);
    for (ROC partial : partials.subList(1, nROCs)) {
        merged.merge(partial);
    }

    double referenceAuc = reference.calculateAUC();
    assertTrue(0.0 <= referenceAuc && referenceAuc <= 1.0);
    assertEquals(referenceAuc, merged.calculateAUC(), 1e-6);

    assertEquals(reference.getRocCurve(), merged.getRocCurve());
}


@Test
public void testROCMultiMerging() {
Expand Down
31 changes: 17 additions & 14 deletions deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/ROC.java
Expand Up @@ -25,6 +25,9 @@
import java.util.LinkedHashMap;
import java.util.Map;

import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;

/**
* ROC (Receiver Operating Characteristic) for binary classifiers.<br>
* ROC has 2 modes of operation:
Expand Down Expand Up @@ -123,7 +126,7 @@ protected INDArray getProbAndLabelUsed() {
if (probAndLabel == null || exampleCount == 0) {
return null;
}
return probAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all());
return probAndLabel.get(interval(0, exampleCount), all());
}

private double getAuc() {
Expand Down Expand Up @@ -205,8 +208,8 @@ public void eval(INDArray labels, INDArray predictions) {
INDArray newProbAndLabel = Nd4j.create(new int[] {newSize, 2}, 'c');
if (exampleCount > 0) {
//If statement to handle edge case: no examples, but we need to re-allocate right away
newProbAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all()).assign(
probAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all()));
newProbAndLabel.get(interval(0, exampleCount), all()).assign(
probAndLabel.get(interval(0, exampleCount), all()));
}
probAndLabel = newProbAndLabel;
}
Expand All @@ -222,10 +225,10 @@ public void eval(INDArray labels, INDArray predictions) {
labelClass1 = labels.getColumn(1);
}
int currMinibatchSize = labels.size(0);
probAndLabel.get(NDArrayIndex.interval(exampleCount, exampleCount + currMinibatchSize),
probAndLabel.get(interval(exampleCount, exampleCount + currMinibatchSize),
NDArrayIndex.point(0)).assign(probClass1);

probAndLabel.get(NDArrayIndex.interval(exampleCount, exampleCount + currMinibatchSize),
probAndLabel.get(interval(exampleCount, exampleCount + currMinibatchSize),
NDArrayIndex.point(1)).assign(labelClass1);

int countClass1CurrMinibatch = labelClass1.sumNumber().intValue();
Expand Down Expand Up @@ -350,16 +353,16 @@ as class 0, all others are predicted as class 1
*/

INDArray t = Nd4j.create(new int[] {length + 2, 1});
t.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()}, sorted.getColumn(0));
t.put(new INDArrayIndex[] {interval(1, length + 1), all()}, sorted.getColumn(0));

INDArray linspace = Nd4j.linspace(1, length, length);
INDArray precision = cumSumPos.div(linspace.reshape(cumSumPos.shape()));
INDArray prec = Nd4j.create(new int[] {length + 2, 1});
prec.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()}, precision);
prec.put(new INDArrayIndex[] {interval(1, length + 1), all()}, precision);

//Recall/TPR
INDArray rec = Nd4j.create(new int[] {length + 2, 1});
rec.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()},
rec.put(new INDArrayIndex[] {interval(1, length + 1), all()},
cumSumPos.div(countActualPositive));

//Edge cases
Expand Down Expand Up @@ -488,14 +491,14 @@ public RocCurve getRocCurve() {
int length = sorted.size(0);

INDArray t = Nd4j.create(new int[] {length + 2, 1});
t.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()}, sorted.getColumn(0));
t.put(new INDArrayIndex[] {interval(1, length + 1), all()}, sorted.getColumn(0));

INDArray fpr = Nd4j.create(new int[] {length + 2, 1});
fpr.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()},
fpr.put(new INDArrayIndex[] {interval(1, length + 1), all()},
cumSumNeg.div(countActualNegative));

INDArray tpr = Nd4j.create(new int[] {length + 2, 1});
tpr.put(new INDArrayIndex[] {NDArrayIndex.interval(1, length + 1), NDArrayIndex.all()},
tpr.put(new INDArrayIndex[] {interval(1, length + 1), all()},
cumSumPos.div(countActualPositive));

//Edge cases
Expand Down Expand Up @@ -669,13 +672,13 @@ public void merge(ROC other) {
//Allocate new array
int newSize = this.probAndLabel.size(0) + Math.max(other.probAndLabel.size(0), exactAllocBlockSize);
INDArray newProbAndLabel = Nd4j.create(newSize, 2);
newProbAndLabel.assign(probAndLabel.get(NDArrayIndex.interval(0, exampleCount), NDArrayIndex.all()));
newProbAndLabel.put(new INDArrayIndex[]{interval(0,exampleCount), all()}, probAndLabel.get(interval(0, exampleCount), all()));
probAndLabel = newProbAndLabel;
}

INDArray toPut = other.probAndLabel.get(NDArrayIndex.interval(0, other.exampleCount), NDArrayIndex.all());
INDArray toPut = other.probAndLabel.get(interval(0, other.exampleCount), all());
probAndLabel.put(new INDArrayIndex[] {
NDArrayIndex.interval(exampleCount, exampleCount + other.exampleCount), NDArrayIndex.all()},
interval(exampleCount, exampleCount + other.exampleCount), all()},
toPut);
} else {
for (Double d : this.counts.keySet()) {
Expand Down

0 comments on commit ee42dbc

Please sign in to comment.