Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Add r_squared eval metric to regression #44248

Merged

Conversation

benwtrent
Copy link
Member

This adds the RSquared metric to Regression evaluations. RSquared (also called Coefficient of determination) is a useful, well known, and widely used evaluation metric for regression type models.

This was easily enough done utilizing a sum aggregation with a script (for sum of residual squares) and utilizing the variance in the extended_stats aggregation.

I initially thought I could use the sum_of_squares result from the extended_stats aggregation, but the value is literally the sum of the squares of each of the values. So, values of [1, 2, 3] the sum_of_squares would be 1 + 4 + 9 = 14. This is not to be confused with total sum of squares, which is what we actually need.

Copy link
Contributor

@tveasey tveasey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good Ben. I have one minor comment on normalisation of variance, but otherwise LGTM

ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
null :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One question, does extendedStats use the maximum likelihood or unbiased estimate of variance? This would be either a * extendedStats.getCount() or * (extendedStats.getCount() - 1) depending on the answer. I'm guessing this would have been flushed by the tests, but the difference is obviously small for large doc count, I think either way it is probably worth a comment here to say what the aggregation does use.

Copy link
Contributor

@tveasey tveasey Jul 12, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at InternalExtendedStats this does use the maximum likelihood form, so multiplying by count is the right thing to do. So maybe just worthwhile adding the comment that the definition of the sample variance used by getVariance() is sum squared residuals divided by count.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tveasey I will add the comment. FWIW, I have tested the calculation on some various datasets to make sure it matched what other tools (namely scikit) calculate for the r_squared metric.

…ml/dataframe/evaluation/regression/RSquared.java

Co-Authored-By: David Kyle <david.kyle@elastic.co>
Copy link
Contributor

@tveasey tveasey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@benwtrent
Copy link
Member Author

@elasticmachine update branch

@benwtrent benwtrent requested a review from davidkyle July 12, 2019 17:59
Copy link
Member

@davidkyle davidkyle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

/**
* Calculates R-Squared between two known numerical fields.
*
* equation: mse = 1 - SSres/SStot
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* equation: mse = 1 - SSres/SStot
* equation: R-Squared = 1 - SSres/SStot

@@ -88,7 +94,10 @@ public BinarySoftClassification(String actualField, String predictedProbabilityF
@Nullable List<EvaluationMetric> metrics) {
this.actualField = Objects.requireNonNull(actualField);
this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField);
this.metrics = Objects.requireNonNull(metrics);
if (metrics != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it mean if there are no Evaluation Metrics, does that make any sense to allow this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either way this is an improvement on the code that tagged metrics as nullable then has requireNonNull(metrics) 3 lines later

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see now. The server side has a set of default metrics that are used if this parameter is null

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, null => use default value. This follows the implementation pattern we have elsewhere in the HLRC.

private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
private static final String SS_RES = "residual_sum_of_squares";

private static String buildScript(Object...args) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: why make this a varargs parameter when we know the template takes 2 replacements

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows the call down to the MessageFormat#format to be done without warnings or errors and no explicit casting. using a String var1, String var2 requires casting to get around warnings, and using Object var1, Object var2 requires manual construction of an array so that it matches the appropriate function definition.

@benwtrent
Copy link
Member Author

run elasticsearch-ci/2
run elasticsearch-ci/packaging-sample

@benwtrent benwtrent merged commit b4e16b6 into elastic:master Jul 15, 2019
@benwtrent benwtrent deleted the feature/ml-add-r_squared-metric-to-eval branch July 15, 2019 19:00
benwtrent added a commit to benwtrent/elasticsearch that referenced this pull request Jul 15, 2019
* [ML] Add r_squared eval metric to regression

* fixing tests and binarysoftclassification class

* Update RSquared.java

* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

Co-Authored-By: David Kyle <david.kyle@elastic.co>

* removing unnecessary debug test
benwtrent added a commit that referenced this pull request Jul 16, 2019
* [ML] Add r_squared eval metric to regression

* fixing tests and binarysoftclassification class

* Update RSquared.java

* Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java

Co-Authored-By: David Kyle <david.kyle@elastic.co>

* removing unnecessary debug test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants