Add regression tree #2905
Conversation
Hello @rcurtin, this is a follow-up from our discussion on IRC; I will move forward as discussed there. Another thing I want to clarify is how we should modify the constructors for the class, since the current implementation was designed for classification. There are two approaches that I have in mind to solve it.
From the end user's perspective, both are almost the same, but which will be better from the perspective of the design principles and the existing codebase? P.S. I have accumulated other ideas too, which I compiled here. This can provide some better insight into my thoughts.
Thanks for the clear writeup! In fact, it is sufficiently comprehensive that I don't have anything to add. I agree with your preference---creating a separate class. Where possible, I would suggest sharing functionality between the two classes.
 * @param minimum The minimum number of elements in a leaf.
 */
template<bool UseWeights, typename ResponsesType, typename WeightVecType>
void CalculateStatistics(const ResponsesType& responses,
Nice, this refactoring looks great! I have a few comments and they mostly have to do with naming and API simplicity. I think they would all be simple refactorings. Mostly the thought I have here is "how can we future-proof this API for other split types?" and also "how can we make it as general as possible, so we don't have to change it later?"
The first suggestion would be to prefix these functions' names with something like `Binary`, since these three functions (`CalculateStatistics()`, `UpdateStatistics()`, and the new overload of `Evaluate()`) are specific to the case where we are looking to find the best binary split by scanning an entire array.
Next, I see that you're calling `Evaluate()` with two different indices to get the left and right gain. But I wonder if it would be simpler to just return a `std::tuple<double, double>` (i.e. both gains at once), since the strategy we are using here is restricted to a binary split. (It seems possible but very nontrivial to generalize to a more-than-binary split, and we don't need it for our purposes anyway...) Another way to achieve the same thing would be to take two `double&`s that you set to the left and right gains in the function. That might "look" more like other functions inside of mlpack.

If you did return both gains at once, then actually it would be possible to simplify further and combine `UpdateStatistics()` with this new function that computes both gains.
So, for instance, we might have two functions like `BinaryScanInitialize()` (I believe this actually does not need any parameters---more comments below) and `BinaryGains(const ResponsesType& responses, const WeightVecType& weights, const size_t splitIndex, double& leftGain, double& rightGain)`.
Those are just some ideas... let me know what you think. Like any design, part of it is personal preference, so, don't feel obligated to take every suggestion.
Take a look at https://github.com/RishabhGarg108/mlpack-1/blob/ad64ad07d717c2b2d19be7bac82be03a2329071b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp#L407 from lines 407 to 417. First, we update statistics, then we check whether we can skip the gain computation for a particular index, and then we evaluate the gain and do other things.
Now, if we want to combine the `UpdateStatistics` and `Evaluate` methods, then we have to move the check that skips the gain computation into that function too, and it is ugly to `continue` the loop from inside the function. We would have to create a flag variable that is set to true when the skip condition holds inside the function, and based on that we would have to put another condition in the loop to `continue` it.

So, I think we can keep them separated. Let me know if I overlooked something or if there is some other way to achieve this.
One thing that we can definitely do is to return a tuple from the `Evaluate` method. 👍
Also, I am not that good at naming functions. Can you suggest the exact names for these functions that you would like? :)
Ahh, so the only thing I was thinking is that the update on line 407 could be removed, but then you would need to have the function that computes the gain allow updating the index by many points at once (as described here). But I do think either way is fine, so, up to you.
For names I might suggest:

- `BinaryScanInitialize()`
- `BinaryStep()` for `UpdateStatistics()` (since it 'step's one index at a time)
- `BinaryGains()` to compute the left and right gains
Let me know what you think. 👍
The names make sense. Thanks :)
if (UseWeights)
{
  const WType w = weights[index];
If the user does not pass `index` as 1 greater than the previous `index` that was used when `UpdateStatistics()` was called, this could give an incorrect result. I wonder if it might be better to internally store `lastIndex` (initialized to `0`), and loop over all values between `lastIndex + 1` and `index` (inclusive) to update the value, instead of just `index`.
I think it is just fine because we don't expect the user to call this function directly.
Moreover, iterating in this way allows us to check whether the data value has changed since the last index. This helps us skip computing the gain for indexes where the value doesn't change. I don't deny that this could also be done the way you are suggesting, but something similar would still be needed to skip some indexes. So, I think this is okay the way it is now. Let me know if it doesn't make sense :)
Yeah, agreed, there are a couple ways to do it---either you have two functions, one of which takes a step but doesn't return the gain, and the other of which takes a step and returns the gain (this is the way you have it now), or you have one function that allows 'fast forwarding', e.g., taking possibly multiple steps and returning the gain of the last one.
(By 'step' there I mean 'increase the index'.)
Up to you which you want to go with---personally, I think just one function is cleaner, but, I can see advantages and disadvantages to both approaches. (They are pretty minor tradeoffs though.)
Everything is looking good to me here! I have a couple small comments---some little style issues, but I think there is an SFINAE bug (it should be easy to fix).
It looks like there is a merge conflict---can you try merging master into this branch? Alternately, I know the history is a little messed up in this branch already, so it might be worth creating a new branch, cherry-picking the relevant commits from this branch, and then opening a new PR. Either way should work, I think. 👍
 */
template<bool UseWeights, typename VecType, typename ResponsesType,
         typename WeightVecType>
typename std::enable_if<
This one isn't marked `static`, but I think it should be. However, when you do this, I expect you will get an ambiguous function call compilation error, because `MSEGain` can match both this overload and the other one. Thus, you need to use `std::enable_if<>` with the other overload, with the negated conditional: `std::enable_if<!HasBinaryScanInitialize ... || !HasBinaryStep ..., double>`.
Note also you can use SFINAE as default-valued arguments to the function, like this:
SplitIfBetter(
    const double bestGain,
    const VecType& data,
    const ResponsesType& responses,
    const WeightVecType& weights,
    const size_t minimumLeafSize,
    const double minimumGainSplit,
    double& splitInfo,
    AuxiliarySplitInfo& /* aux */,
    typename std::enable_if<..., void>::type* = 0);
But, to my knowledge they both work the same way, so it doesn't make a difference which you want to use.
Hey @rcurtin, I added the `static` keyword to the function and it did compile. But when I added the negated condition to the regular overload, it gives me an error that "no matching overload of SplitIfBetter could be found". Can you please take a look at what is going wrong here?
Awesome---let's continue this in #3011.
This PR attempts to add regression tree support to mlpack. Relevant discussion #2619.
This is going to be a long PR, so hopefully I will divide it into multiple parts.
The following checklist will broadly keep track of the PR.