[ML] Add r_squared eval metric to regression #44248
Looks good Ben. I have one minor comment on normalisation of variance, but otherwise LGTM
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
    null :
    new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
One question: does extendedStats use the maximum likelihood or the unbiased estimate of variance? The multiplier would be either * extendedStats.getCount() or * (extendedStats.getCount() - 1) depending on the answer. I'm guessing this would have been flushed out by the tests, but the difference is obviously small for a large doc count. Either way, it is probably worth a comment here to say what the aggregation does use.
Looking at InternalExtendedStats, this does use the maximum likelihood form, so multiplying by count is the right thing to do. So maybe it is just worthwhile adding the comment that the definition of the sample variance used by getVariance() is sum squared residuals divided by count.
@tveasey I will add the comment. FWIW, I have tested the calculation on various datasets to make sure it matches what other tools (namely scikit) calculate for the r_squared metric.
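The normalisation point in this thread can be sketched in plain Java. This is only an illustration under stated assumptions: the class and method names are hypothetical, the data lives in in-memory arrays rather than aggregation results, and the variance is computed in the maximum likelihood form (divide by count) that the thread attributes to getVariance().

```java
class RSquaredSketch {
    /**
     * Computes R^2 = 1 - SSres / SStot. SStot is recovered as variance * count,
     * where variance is the maximum likelihood form (divided by count, not count - 1).
     */
    static double rSquared(double[] actual, double[] predicted) {
        if (actual.length == 0 || actual.length != predicted.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int count = actual.length;

        double sum = 0.0;
        for (double a : actual) {
            sum += a;
        }
        double mean = sum / count;

        double ssRes = 0.0;     // residual sum of squares
        double sumSqDev = 0.0;  // sum of squared deviations from the mean
        for (int i = 0; i < count; i++) {
            double diff = actual[i] - predicted[i];
            ssRes += diff * diff;
            double dev = actual[i] - mean;
            sumSqDev += dev * dev;
        }

        double variance = sumSqDev / count;  // maximum likelihood form
        double ssTot = variance * count;     // multiplying by count undoes the division
        return 1 - ssRes / ssTot;
    }

    public static void main(String[] args) {
        double[] actual = {1, 2, 3, 4};
        double[] predicted = {1.1, 1.9, 3.2, 3.8};
        System.out.println(rSquared(actual, predicted)); // approximately 0.98
    }
}
```

If the variance were the unbiased form instead, the multiplier would have to be count - 1 to recover the same total sum of squares.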
…ml/dataframe/evaluation/regression/RSquared.java Co-Authored-By: David Kyle <david.kyle@elastic.co>
LGTM
@elasticmachine update branch
LGTM
/**
 * Calculates R-Squared between two known numerical fields.
 *
 * equation: mse = 1 - SSres/SStot
Suggested change:
-  * equation: mse = 1 - SSres/SStot
+  * equation: R-Squared = 1 - SSres/SStot
@@ -88,7 +94,10 @@ public BinarySoftClassification(String actualField, String predictedProbabilityF
                                     @Nullable List<EvaluationMetric> metrics) {
         this.actualField = Objects.requireNonNull(actualField);
         this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField);
-        this.metrics = Objects.requireNonNull(metrics);
+        if (metrics != null) {
What does it mean if there are no Evaluation Metrics, does that make any sense to allow this?
Either way this is an improvement on the code that tagged metrics as nullable then has requireNonNull(metrics) 3 lines later.
Oh I see now. The server side has a set of default metrics that are used if this parameter is null
Yeah, null => use default value. This follows the implementation pattern we have elsewhere in the HLRC.
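A rough sketch of the "null => use default value" pattern discussed here, for readers unfamiliar with it. Everything in it is an assumption made for illustration: the class name, the use of plain strings for metrics, and the contents of the default list are hypothetical and are not taken from the HLRC or the server code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

class NullableMetricsSketch {
    // Hypothetical stand-in for the defaults the server would apply.
    static final List<String> SERVER_DEFAULTS =
        Collections.unmodifiableList(Arrays.asList("mean_squared_error", "r_squared"));

    private final String actualField;
    private final List<String> metrics; // null means "let the server pick its defaults"

    NullableMetricsSketch(String actualField, List<String> metrics) {
        this.actualField = Objects.requireNonNull(actualField);
        // Only copy the list when the caller supplied one; null is a valid value.
        this.metrics = metrics == null
            ? null
            : Collections.unmodifiableList(new ArrayList<>(metrics));
    }

    String getActualField() {
        return actualField;
    }

    /** What would effectively be evaluated once defaults are applied. */
    List<String> effectiveMetrics() {
        return metrics == null ? SERVER_DEFAULTS : metrics;
    }
}
```

The point of the pattern is that a parameter annotated as nullable is never passed through requireNonNull, so the annotation and the runtime behaviour agree.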
private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
private static final String SS_RES = "residual_sum_of_squares";

private static String buildScript(Object...args) {
nit: why make this a varargs parameter when we know the template takes 2 replacements?
This allows the call down to MessageFormat#format to be done without warnings or errors and with no explicit casting. Using String var1, String var2 requires casting to get around warnings, and using Object var1, Object var2 requires manual construction of an array so that it matches the appropriate function definition.
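The varargs argument can be illustrated with a small standalone class. The template string is copied from the diff above; the class name and the body of buildScript are assumptions, since the diff does not show the method body. The instance method format(Object) is inherited from java.text.Format and accepts the Object[] that varargs produces.

```java
import java.text.MessageFormat;
import java.util.Locale;

class ScriptTemplateSketch {
    // In MessageFormat patterns, '' produces a literal single quote.
    private static final String PAINLESS_TEMPLATE =
        "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";

    // Varargs arrive as Object[], which can be handed to format(...) as-is.
    static String buildScript(Object... args) {
        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
    }

    public static void main(String[] args) {
        // prints: def diff = doc['actual'].value - doc['predicted'].value;return diff * diff;
        System.out.println(buildScript("actual", "predicted"));
    }
}
```

With two fixed parameters, the method would have to build the array itself, for example new Object[] { var1, var2 }, before calling format.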
…benwtrent/elasticsearch into feature/ml-add-r_squared-metric-to-eval
run elasticsearch-ci/2
* [ML] Add r_squared eval metric to regression
* fixing tests and binarysoftclassification class
* Update RSquared.java
* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java
  Co-Authored-By: David Kyle <david.kyle@elastic.co>
* removing unnecessary debug test
This adds the RSquared metric to Regression evaluations. RSquared (also called Coefficient of determination) is a useful, well known, and widely used evaluation metric for regression type models.
This was easily done using a sum aggregation with a script (for the residual sum of squares) together with the variance from the extended_stats aggregation.

I initially thought I could use the sum_of_squares result from the extended_stats aggregation, but that value is literally the sum of the squares of each of the values. So, for values of [1, 2, 3], the sum_of_squares would be 1 + 4 + 9 = 14. This is not to be confused with the total sum of squares, which is what we actually need.
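The distinction can be checked by hand on the same [1, 2, 3] example. The following is a minimal sketch with hypothetical class and method names; it works on an in-memory array and does not call any Elasticsearch API.

```java
class SumOfSquaresSketch {
    /** Sum of the squares of the raw values, as described for sum_of_squares above. */
    static double sumOfSquares(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v * v;
        }
        return total;
    }

    /** Total sum of squares: squared deviations from the mean, as needed for R-Squared. */
    static double totalSumOfSquares(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        double mean = sum / values.length;

        double total = 0.0;
        for (double v : values) {
            double dev = v - mean;
            total += dev * dev;
        }
        return total;
    }

    public static void main(String[] args) {
        double[] values = {1, 2, 3};
        System.out.println(sumOfSquares(values));      // 14.0
        System.out.println(totalSumOfSquares(values)); // 2.0
    }
}
```

For [1, 2, 3] the mean is 2, so the total sum of squares is 1 + 0 + 1 = 2, which matches variance * count when the variance is divided by count (2/3 * 3).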