[ML] Stop early (based on non-decreasing validation error) when adding trees to the forest #875
Conversation
retest
LGTM, although I also think this isn't the final version. I have only a minor comment and a question about the computeEta call.
    computeMaximumNumberTrees(m_TreeImpl->m_Eta)};
double eta{m_TreeImpl->m_EtaOverride != boost::none
               ? *m_TreeImpl->m_EtaOverride
               : computeEta(frame.numberColumns())};
Why is it not sufficient to have computeEta within initializeHyperparameters? This seems an unexpected place for the call.
There is an ordering problem: we need to set up progress monitoring before starting initializeHyperparameters. However, this needs to anticipate the correct value for eta so that we monitor progress "correctly". In fact, there is still a problem because eta gets set to different values in the hyperparameter optimisation loop, but this is at least better than what we had before. (I want to re-evaluate our strategy for eta in a following PR, at which point I'll also fully fix progress monitoring.)
…g trees to the regression/classification forest (elastic#875)
We can relatively cheaply (around a 2% overhead) compute predictions for the test rows at the same time as we compute them for the training rows. This means we can cheaply track the validation error as we add additional trees to the forest during training.
The validation error curve is fairly predictable: it decreases quickly (typically exponentially) at the start, hits a minimum and then often increases slightly as the model starts to overfit the training data. This change introduces a very simple exit condition designed so that we pay a fixed relative runtime cost to guard against exiting too soon. Specifically, we add at least f * "maximum number of trees" trees after reaching the lowest validation error forest, and resize the forest to minimise validation error at the end of training. We set f to 0.05, i.e. 5% of the total cost of training the forest.
I don't think this is the final version of early stopping. When we have a compute budget, or some way for the user to choose the trade-off between runtime and accuracy they'd be happy with, we can be more clever about how we stop. However, I've tested this on 15 benchmark data sets: it often slightly improves QoR, and I've seen large (6x) drops in runtime for some data sets. I think this therefore represents a clear improvement over our current strategy.