
Add regression tree #2905

Closed (118 commits)

Conversation

@RishabhGarg108 (Member) commented Apr 5, 2021

This PR attempts to add regression tree support to mlpack. Relevant discussion #2619.
This is going to be a long PR, so I will hopefully divide it into multiple parts.

The following checklist will broadly keep track of the PR.

  • MAD gain
  • MSE gain
  • All categorical splitter
  • Best numeric splitter
  • Random numeric splitter
  • Add DecisionTreeRegressor class

@RishabhGarg108 (Member, Author) commented Apr 7, 2021

Hello @rcurtin, this is a follow-up from our discussion on IRC; I will move forward as discussed there. Another thing I want to clarify is how we should modify the constructors for the class. The current implementation was written for classification, so it requires numClasses as an argument, but that doesn't apply to the regression case.

There are two approaches I have in mind to solve it.

  1. One is to add 6 new constructors without the numClasses argument, and in those constructors call the corresponding Train() function.
  2. Another is to set numClasses = 0 by default. If the user needs to do classification, they will need to set it to some non-zero value, and in the default case it does regression. Inside the constructors, we can add an if statement that calls the corresponding Train() function based on whether numClasses is zero or not (see the sketch below).
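For illustration, a minimal sketch of approach 2, with all names hypothetical (the real constructors also take the data, the responses, and several other parameters):

  #include <cstddef>
  #include <iostream>

  class TreeSketch
  {
   public:
    // numClasses defaults to 0, which selects the regression path; any
    // non-zero value selects the classification path.
    explicit TreeSketch(const std::size_t numClasses = 0)
    {
      if (numClasses == 0)
        TrainRegression();
      else
        TrainClassification(numClasses);
    }

   private:
    void TrainRegression() { std::cout << "regression\n"; }
    void TrainClassification(const std::size_t) { std::cout << "classification\n"; }
  };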

From the end user's perspective, both are almost the same, but which is better from the perspective of design principles and the existing codebase?

PS: I have accumulated some other ideas too, which I compiled here. They may provide better insight into my thinking.

@rcurtin (Member) commented Apr 8, 2021

Thanks for the clear writeup! In fact it is sufficiently comprehensive that I don't have anything to add. I agree with your preference---creating a separate class. Where possible, I would suggest sharing functionality between DecisionTree<> and DecisionTreeRegressor<>. So, it may make sense to factor out the overload of Train() that actually trains the tree into a generic standalone function that can be used by both classes. 👍
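A minimal sketch of that factoring idea (all names hypothetical; the real Train() overloads work on Armadillo matrices and take many more parameters):

  #include <cstddef>
  #include <vector>

  // The shared split/recursion logic lives in one free function template;
  // both tree classes call it with their own response types and fitness
  // functions.
  template<typename ResponsesType, typename FitnessFunction>
  double TrainTreeImpl(const std::vector<double>& data,
                       const ResponsesType& responses,
                       FitnessFunction fitness)
  {
    // Real code would find the best split and recurse; here we just score
    // the node so the sketch stays small.
    (void) data;
    return fitness(responses);
  }

  struct DecisionTreeSketch
  {
    // Classification: responses are class labels.
    double Train(const std::vector<double>& data,
                 const std::vector<std::size_t>& labels)
    {
      return TrainTreeImpl(data, labels,
          [](const std::vector<std::size_t>& l)
          { return static_cast<double>(l.size()); });
    }
  };

  struct DecisionTreeRegressorSketch
  {
    // Regression: responses are real values.
    double Train(const std::vector<double>& data,
                 const std::vector<double>& responses)
    {
      return TrainTreeImpl(data, responses,
          [](const std::vector<double>& r)
          { return static_cast<double>(r.size()); });
    }
  };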

Review thread on the diff:

   * @param minimum The minimum number of elements in a leaf.
   */
  template<bool UseWeights, typename ResponsesType, typename WeightVecType>
  void CalculateStatistics(const ResponsesType& responses,

Member commented:

Nice, this refactoring looks great! I have a few comments and they mostly have to do with naming and API simplicity. I think they would all be simple refactorings. Mostly the thought I have here is "how can we future-proof this API for other split types?" and also "how can we make it as general as possible, so we don't have to change it later?"

The first suggestion would be to prefix these functions' names with something like Binary, since these three functions (CalculateStatistics(), UpdateStatistics(), and the new overload of Evaluate()) are specific to the case where we are looking to find the best binary split by scanning an entire array.

Next, I see that you're calling Evaluate() with two different indices to get the left and right gain. But I wonder if it would be simpler to just return a std::tuple<double, double> (i.e. both gains at once), since the strategy we are using here is restricted to a binary split. (It seems possible but very nontrivial to generalize to a more-than-binary split, and we don't need it for our purposes anyway...) Another way to achieve the same thing would be to take two double&s that you set to the left and right gains in the function. That might "look" more like other functions inside of mlpack.

If you did return both gains at once, then actually it would be possible to simplify further and combine UpdateStatistics() with this new function that computes both gains.

So, for instance, we might have two functions like BinaryScanInitialize() (I believe this actually does not need any parameters---more comments below) and BinaryGains(const ResponsesType& responses, const WeightVecType& weights, const size_t splitIndex, double& leftGain, double& rightGain).
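For concreteness, a sketch of what that second suggested signature might look like (placeholder computation only; a real version would use the running statistics accumulated during the scan):

  #include <cstddef>
  #include <vector>

  // Hypothetical BinaryGains(): returns both gains at once through output
  // references instead of calling Evaluate() twice.
  template<typename ResponsesType, typename WeightVecType>
  void BinaryGains(const ResponsesType& responses,
                   const WeightVecType& weights,
                   const std::size_t splitIndex,
                   double& leftGain,
                   double& rightGain)
  {
    // Placeholder values; a real implementation would compute both gains
    // from statistics accumulated up to splitIndex.
    (void) weights;
    leftGain = -static_cast<double>(splitIndex);
    rightGain = -static_cast<double>(responses.size() - splitIndex);
  }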

Those are just some ideas... let me know what you think. Like any design, part of it is personal preference, so, don't feel obligated to take every suggestion.

Member Author replied:

Take a look at https://github.com/RishabhGarg108/mlpack-1/blob/ad64ad07d717c2b2d19be7bac82be03a2329071b/src/mlpack/methods/decision_tree/best_binary_numeric_split_impl.hpp#L407, lines 407 to 417. First we update the statistics, then we check whether we can skip the gain computation for a particular index, and only then do we evaluate the gain and do the rest.

Now, if we want to combine the UpdateStatistics() and Evaluate() methods, we would have to move the skip check into that function too, and it would be ugly to continue the loop from inside it: the function would have to set a flag when the skip condition holds, and the loop would need an extra condition to continue based on that flag.
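A simplified, self-contained version of that loop shape (updateStatistics/evaluate are stand-in stubs, not the actual mlpack code):

  #include <algorithm>
  #include <cstddef>
  #include <vector>

  double ScanForBestSplit(const std::vector<double>& sortedData,
                          const std::vector<double>& responses)
  {
    auto updateStatistics = [](std::size_t) { /* fold point into running sums */ };
    auto evaluate = [&](std::size_t i) { return -responses[i]; };

    double bestGain = -1e300;
    for (std::size_t index = 1; index < sortedData.size(); ++index)
    {
      updateStatistics(index);

      // Equal adjacent values give no new split point, so skip the gain
      // computation; merging this check into a combined update-and-evaluate
      // function would require threading a flag back out of it.
      if (sortedData[index] == sortedData[index - 1])
        continue;

      bestGain = std::max(bestGain, evaluate(index));
    }
    return bestGain;
  }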

So, I think we can keep them separated. Let me know if I overlooked something or if there is some other way to achieve this.

One thing that we can definitely do is to return a tuple from the Evaluate method. 👍

Also, I am not that good at naming functions. Can you suggest the exact names you would like for these functions? :)

Member replied:

Ahh, the only thing I was thinking is that the update on line 407 could be removed, but then the function that computes the gain would need to allow advancing the index by many points at once (as described here). But I do think either way is fine, so, up to you.

For names I might suggest:

  • BinaryScanInitialize()
  • BinaryStep() for UpdateStatistics() (since it 'step's one index at a time)
  • BinaryGains() to compute the left and right gains

Let me know what you think. 👍

Member Author replied:

The names make sense. Thanks :)

src/mlpack/methods/decision_tree/mse_gain.hpp: review thread resolved (outdated)

Review thread on the diff:

  if (UseWeights)
  {
    const WType w = weights[index];

Member commented:

If the user does not pass index as 1 greater than the previous index that was used when UpdateStatistics() was called, this could give an incorrect result. I wonder if it might be better to internally store lastIndex (initialized to 0), and loop over all values between lastIndex + 1 and index (inclusive) to update the value, instead of just index.
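A sketch of that suggestion (hypothetical class; the real statistics are more involved):

  #include <cstddef>
  #include <vector>

  class RunningStats
  {
   public:
    RunningStats() : lastIndex(0), weightedSum(0.0) { }

    // Fast-forward: fold in every point between the last updated index and
    // the requested one, so a caller that skips ahead still gets a correct
    // result.
    void UpdateTo(const std::vector<double>& responses,
                  const std::vector<double>& weights,
                  const std::size_t index)
    {
      for (std::size_t i = lastIndex + 1; i <= index; ++i)
        weightedSum += weights[i] * responses[i];
      lastIndex = index;
    }

    double WeightedSum() const { return weightedSum; }

   private:
    std::size_t lastIndex;
    double weightedSum;
  };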

Member Author replied:

I think it is just fine because we don't expect the user to call this function directly.

Moreover, iterating this way allows us to check whether the data value has changed since the last index, which lets us skip computing the gain for indexes where the value doesn't change. I don't deny that this could also be done the way you are suggesting, but something similar would still be needed there to skip those indexes. So, I think this is okay the way it is now. Let me know if it doesn't make sense :)

Member replied:

Yeah, agreed, there are a couple ways to do it---either you have two functions, one of which takes a step but doesn't return the gain, and the other of which takes a step and returns the gain (this is the way you have it now), or you have one function that allows 'fast forwarding', e.g., taking possibly multiple steps and returning the gain of the last one.

(By 'step' there I mean 'increase the index'.)

Up to you which you want to go with---personally, I think just one function is cleaner, but, I can see advantages and disadvantages to both approaches. (They are pretty minor tradeoffs though.)

src/mlpack/methods/decision_tree/mse_gain.hpp: review thread resolved
@RishabhGarg108 changed the title from "[WIP] Add regression tree" to "Add regression tree" on Jul 7, 2021
@rcurtin (Member) left a comment:

Everything is looking good to me here! I have a couple small comments---some little style issues, but I think there is an SFINAE bug (it should be easy to fix).

It looks like there is a merge conflict---can you try merging master into this branch? Alternately, I know the history is a little messed up in this branch already, so it might be worth creating a new branch, cherry-picking the relevant commits from this branch, and then opening a new PR. Either way should work, I think. 👍

HISTORY.md: review thread resolved
Review thread on the diff:

   */
  template<bool UseWeights, typename VecType, typename ResponsesType,
           typename WeightVecType>
  typename std::enable_if<

Member commented:

This one isn't marked static, but I think it should be. However, when you do this, I expect you will get an ambiguous function call compilation error, because MSEGain can match both this overload and the other one. Thus, you need to use std::enable_if<> with the other overload too, with the negated conditional: std::enable_if<!HasBinaryScanInitialize ... || !HasBinaryStep ..., double>.

Note also you can use SFINAE as default-valued arguments to the function, like this:

  SplitIfBetter(
      const double bestGain,
      const VecType& data,
      const ResponsesType& responses,
      const WeightVecType& weights,
      const size_t minimumLeafSize,
      const double minimumGainSplit,
      double& splitInfo,
      AuxiliarySplitInfo& /* aux */,
      typename std::enable_if<..., void>::type* = 0);

But, to my knowledge they both work the same way, so it doesn't make a difference which you want to use.
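A self-contained sketch of the complementary-condition pattern (HasFastPath stands in for the HasBinaryScanInitialize/HasBinaryStep detection; all names hypothetical). Because the two conditions are negations of each other, exactly one overload is viable for any F, so marking both static never produces an ambiguous call:

  #include <type_traits>
  #include <utility>

  // Detect whether F has a member function BinaryStep().
  template<typename F, typename = void>
  struct HasFastPath : std::false_type { };

  template<typename F>
  struct HasFastPath<F, std::void_t<decltype(std::declval<F&>().BinaryStep())>>
      : std::true_type { };

  struct SplitSketch
  {
    // Used when F provides the incremental scan machinery.
    template<typename F>
    static typename std::enable_if<HasFastPath<F>::value, double>::type
    SplitIfBetter(F& fitness)
    {
      fitness.BinaryStep();
      return 1.0;
    }

    // Fallback for fitness functions without it; the negated condition
    // keeps this overload out of consideration in the first case.
    template<typename F>
    static typename std::enable_if<!HasFastPath<F>::value, double>::type
    SplitIfBetter(F& /* fitness */)
    {
      return 0.0;
    }
  };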

Member Author replied:

Hey @rcurtin, I added the static keyword to the function and it compiled. But when I added the negated condition to the regular overload, I got an error that "no matching overload of SplitIfBetter could be found". Can you please take a look at what is going wrong here?

@rcurtin (Member) commented Jul 15, 2021

Awesome---let's continue this in #3011.
