diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java index 0149d7feec9d1..f365e4758cc94 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java @@ -113,6 +113,15 @@ public double getScore(long subsetFreq, long subsetSize, long supersetFreq, long return 0.0; } + // casting to `long` to round down to nearest whole number + double epsAllDocsInClass = (long)eps(allDocsInClass); + double epsAllDocsNotInClass = (long)eps(allDocsNotInClass); + + docsContainTermInClass += epsAllDocsInClass; + docsContainTermNotInClass += epsAllDocsNotInClass; + allDocsInClass += epsAllDocsInClass; + allDocsNotInClass += epsAllDocsNotInClass; + // Adjust counts to ignore ratio changes which are less than 5% // casting to `long` to round down to nearest whole number docsContainTermNotInClass = (long)(Math.min( @@ -120,41 +129,34 @@ public double getScore(long subsetFreq, long subsetSize, long supersetFreq, long docsContainTermInClass / allDocsInClass * allDocsNotInClass ) + 0.5); - // casting to `long` to round down to nearest whole number - double epsAllDocsInClass = (long)eps(allDocsInClass); - double epsAllDocsNotInClass = (long)eps(allDocsNotInClass); - - if ((allDocsInClass + epsAllDocsInClass) > Long.MAX_VALUE - || (docsContainTermInClass + epsAllDocsInClass) > Long.MAX_VALUE - || (allDocsNotInClass + epsAllDocsNotInClass) > Long.MAX_VALUE - || (docsContainTermNotInClass + epsAllDocsNotInClass) > Long.MAX_VALUE) { + if (allDocsInClass > Long.MAX_VALUE + || docsContainTermInClass > Long.MAX_VALUE + || allDocsNotInClass > Long.MAX_VALUE + || docsContainTermNotInClass > Long.MAX_VALUE) { throw new AggregationExecutionException( "too many documents in background and foreground sets, further restrict sets for execution" ); } double v1 = new LongBinomialDistribution( - (long)(allDocsInClass + epsAllDocsInClass), - (docsContainTermInClass + epsAllDocsInClass)/(allDocsInClass + epsAllDocsInClass) - ).logProbability((long)(docsContainTermInClass + epsAllDocsInClass)); + (long)allDocsInClass, docsContainTermInClass / allDocsInClass + ).logProbability((long)docsContainTermInClass); double v2 = new LongBinomialDistribution( - (long)(allDocsNotInClass + epsAllDocsNotInClass), - (docsContainTermNotInClass + epsAllDocsNotInClass)/(allDocsNotInClass + epsAllDocsNotInClass) - ).logProbability((long)(docsContainTermNotInClass + epsAllDocsNotInClass)); + (long)allDocsNotInClass, docsContainTermNotInClass / allDocsNotInClass + ).logProbability((long)docsContainTermNotInClass); - double p2 = (docsContainTermInClass + docsContainTermNotInClass + epsAllDocsNotInClass + epsAllDocsInClass) - / (allDocsInClass + allDocsNotInClass + epsAllDocsNotInClass + epsAllDocsInClass); + double p2 = (docsContainTermInClass + docsContainTermNotInClass) / (allDocsInClass + allDocsNotInClass); - double v3 = new LongBinomialDistribution((long)(allDocsInClass + epsAllDocsInClass), p2) - .logProbability((long)(docsContainTermInClass + epsAllDocsInClass)); + double v3 = new LongBinomialDistribution((long)allDocsInClass, p2) + .logProbability((long)docsContainTermInClass); - double v4 = new LongBinomialDistribution((long)(allDocsNotInClass + epsAllDocsNotInClass), p2) - .logProbability((long)(docsContainTermNotInClass + epsAllDocsNotInClass)); + double v4 = new LongBinomialDistribution((long)allDocsNotInClass, p2) + .logProbability((long)docsContainTermNotInClass); double logLikelihoodRatio = v1 + v2 - v3 - v4; double pValue = CHI_SQUARED_DISTRIBUTION.survivalFunction(2.0 * logLikelihoodRatio); - return -FastMath.log(FastMath.max(pValue, Double.MIN_NORMAL)); + return FastMath.max(-FastMath.log(FastMath.max(pValue, Double.MIN_NORMAL)), 0.0); } private double eps(double value) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java index 24119d31c080a..42346627b17d5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java @@ -22,7 +22,6 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -110,7 +109,7 @@ public void testPValueScore() { ); assertThat( FastMath.exp(-new PValueScore(false).getScore(10, 100, 10, 1000)), - closeTo(0.002988594884934073, eps) + closeTo(0.003972388976814195, eps) ); assertThat( FastMath.exp(-new PValueScore(false).getScore(10, 100, 200, 1000)), @@ -118,23 +117,23 @@ public void testPValueScore() { ); assertThat( FastMath.exp(-new PValueScore(false).getScore(20, 10000, 5, 10000)), - closeTo(0.6309430298306147, eps) + closeTo(1.0, eps) ); } public void testSmallChanges() { assertThat( FastMath.exp(-new PValueScore(false).getScore(1, 4205, 0, 821496)), - closeTo(0.9572480202044421, eps) + closeTo(0.9999037287868853, eps) ); // Same(ish) ratios assertThat( FastMath.exp(-new PValueScore(false).getScore(10, 4205, 195, 82149)), - closeTo(0.9893886454928338, eps) + closeTo(0.9995943820612134, eps) ); assertThat( FastMath.exp(-new PValueScore(false).getScore(10, 4205, 1950, 821496)), - closeTo(0.9867689169546193, eps) + closeTo(0.9999942565428899, eps) ); // 4% vs 0% @@ -145,12 +144,12 @@ public void testSmallChanges() { // 4% vs 2% assertThat( FastMath.exp(-new PValueScore(false).getScore(168, 4205, 16429, 821496)), - closeTo(4.78464746423625e-06, eps) + closeTo(8.542608559219833e-5, eps) ); // 4% vs 3.5% assertThat( FastMath.exp(-new PValueScore(false).getScore(168, 4205, 28752, 821496)), - closeTo(0.4728938449949742, eps) + closeTo(0.8833950526957098, eps) ); } @@ -186,7 +185,7 @@ public void testIncreasedSubsetIncreasedScore() { for (int j = 1; j < 11; j++) { double nextScore = getScore.apply(j*10L); assertThat(nextScore, greaterThanOrEqualTo(0.0)); - assertThat(nextScore, greaterThan(priorScore)); + assertThat(nextScore, greaterThanOrEqualTo(priorScore)); priorScore = nextScore; } }