HSTree Multiclass Classification Support #122

Closed
Innixma opened this issue Jul 30, 2022 · 4 comments · Fixed by #130
Labels
enhancement New feature or request

Comments


Innixma commented Jul 30, 2022

Does HSTree support multiclass classification problems with RandomForest / ExtraTrees as the estimator?

From my initial tests, it appears buggy: calling predict_proba on the final model returns many NaN predictions, and training emits warnings such as:

/Users/neerick/workspace/virtual/autogluon/lib/python3.8/site-packages/imodels/tree/hierarchical_shrinkage.py:87: RuntimeWarning: invalid value encountered in double_scalars
  val = tree.value[i][0, 1] / (tree.value[i][0, 0] + tree.value[i][0, 1])  # binary classification

If helpful I can try to create a reproducible example.
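
A minimal sketch of how the expression in the warning above can produce NaN on a multiclass problem. The node values here are hypothetical class counts shaped like the `tree.value[i]` rows that the expression indexes; the helper names `binary_ratio` and `class_proportions` are illustrative and are not part of imodels. When a node contains only samples from a third class, the first two entries are both zero and the ratio becomes 0/0.

```python
import warnings
import numpy as np

# Hypothetical per-node class counts with shape (1, n_classes), n_classes = 3.
node_values = [
    np.array([[5.0, 3.0, 2.0]]),  # mixed node
    np.array([[0.0, 0.0, 4.0]]),  # pure node holding only class 2
]

def binary_ratio(value):
    """Mirror the binary-only expression quoted in the warning."""
    return value[0, 1] / (value[0, 0] + value[0, 1])

def class_proportions(value):
    """Normalize over every class instead of only the first two."""
    return value[0] / value[0].sum()

with warnings.catch_warnings():
    # Silence the same RuntimeWarning that shows up during training.
    warnings.simplefilter("ignore", RuntimeWarning)
    ratios = [binary_ratio(v) for v in node_values]

props = [class_proportions(v) for v in node_values]

print(ratios)  # the second entry is NaN because it evaluates 0.0 / 0.0
print(props)   # every row is finite and sums to 1
```

Normalizing over all classes, as in `class_proportions`, stays finite for any node that contains at least one sample.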

Here is an example result comparing against sklearn's default RF (the _og_ models), using accuracy as the metric. Because HSTree returns many NaN predictions, its scores are very low.

One observation is that the scores get worse as the number of trees in the HSTree forest grows. My guess is that the likelihood of returning a NaN result increases with the number of trees.

                       model  score_test  score_val  pred_time_test  pred_time_val  fit_time  pred_time_test_marginal  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order
0       RandomForest_og_n300    0.711651   0.723618        0.985573       0.050956  0.519926                 0.985573                0.050956           0.519926            1       True          1
1       RandomForest_og_n100    0.710154   0.748744        0.453769       0.019050  0.170951                 0.453769                0.019050           0.170951            1       True          2
2        WeightedEnsemble_L2    0.710154   0.748744        0.464755       0.019376  0.295161                 0.010986                0.000326           0.124210            2       True         36
3        RandomForest_og_n40    0.700636   0.698492        0.193009       0.010738  0.088012                 0.193009                0.010738           0.088012            1       True          3
4        RandomForest_og_n20    0.692039   0.698492        0.103616       0.007549  0.057396                 0.103616                0.007549           0.057396            1       True          4
5        RandomForest_og_n10    0.674165   0.688442        0.075296       0.006166  0.041720                 0.075296                0.006166           0.041720            1       True          5
6     RandomForest_hs=10_n10    0.521949   0.537688        0.070260       0.005246  0.082384                 0.070260                0.005246           0.082384            1       True         15
7     RandomForest_hs=50_n10    0.520839   0.517588        0.075151       0.004875  0.071219                 0.075151                0.004875           0.071219            1       True         20
8    RandomForest_hs=0.1_n10    0.520796   0.537688        0.074070       0.005233  0.093299                 0.074070                0.005233           0.093299            1       True         35
9      RandomForest_hs=1_n10    0.520692   0.542714        0.077687       0.005690  0.075061                 0.077687                0.005690           0.075061            1       True         10
10   RandomForest_hs=100_n10    0.519246   0.517588        0.075059       0.006019  0.082536                 0.075059                0.006019           0.082536            1       True         25
11   RandomForest_hs=500_n10    0.488877   0.517588        0.072145       0.005125  0.072223                 0.072145                0.005125           0.072223            1       True         30
12     RandomForest_hs=1_n20    0.485125   0.472362        0.113002       0.006484  0.123639                 0.113002                0.006484           0.123639            1       True          9
13   RandomForest_hs=0.1_n20    0.485005   0.472362        0.111342       0.005953  0.146246                 0.111342                0.005953           0.146246            1       True         34
14    RandomForest_hs=10_n20    0.484833   0.482412        0.104076       0.006577  0.131909                 0.104076                0.006577           0.131909            1       True         14
15    RandomForest_hs=50_n20    0.482896   0.482412        0.115057       0.006263  0.130512                 0.115057                0.006263           0.130512            1       True         19
16   RandomForest_hs=100_n20    0.480840   0.482412        0.108625       0.006045  0.135224                 0.108625                0.006045           0.135224            1       True         24
17   RandomForest_hs=500_n20    0.458035   0.467337        0.108658       0.006302  0.123907                 0.108658                0.006302           0.123907            1       True         29
18     RandomForest_hs=1_n40    0.451434   0.467337        0.185129       0.010619  0.210639                 0.185129                0.010619           0.210639            1       True          8
19   RandomForest_hs=0.1_n40    0.451382   0.467337        0.170597       0.009024  0.244322                 0.170597                0.009024           0.244322            1       True         33
20    RandomForest_hs=10_n40    0.451322   0.467337        0.173382       0.009955  0.210795                 0.173382                0.009955           0.210795            1       True         13
21    RandomForest_hs=50_n40    0.450350   0.467337        0.170041       0.008673  0.236081                 0.170041                0.008673           0.236081            1       True         18
22   RandomForest_hs=100_n40    0.449119   0.467337        0.169396       0.010918  0.226784                 0.169396                0.010918           0.226784            1       True         23
23   RandomForest_hs=500_n40    0.435832   0.472362        0.162881       0.009256  0.202447                 0.162881                0.009256           0.202447            1       True         28
24    RandomForest_hs=1_n100    0.420419   0.452261        0.442328       0.017688  0.480776                 0.442328                0.017688           0.480776            1       True          7
25  RandomForest_hs=0.1_n100    0.420411   0.452261        0.354523       0.018247  0.548557                 0.354523                0.018247           0.548557            1       True         32
26   RandomForest_hs=10_n100    0.419981   0.452261        0.355097       0.017487  0.469547                 0.355097                0.017487           0.469547            1       True         12
27   RandomForest_hs=50_n100    0.419034   0.447236        0.344341       0.021125  0.465810                 0.344341                0.021125           0.465810            1       True         17
28  RandomForest_hs=100_n100    0.418672   0.447236        0.372041       0.018402  0.477048                 0.372041                0.018402           0.477048            1       True         22
29  RandomForest_hs=500_n100    0.415256   0.457286        0.338696       0.017128  0.492786                 0.338696                0.017128           0.492786            1       True         27
30  RandomForest_hs=0.1_n300    0.381049   0.391960        0.967061       0.045552  1.533075                 0.967061                0.045552           1.533075            1       True         31
31   RandomForest_hs=10_n300    0.381049   0.391960        1.109062       0.054005  1.442369                 1.109062                0.054005           1.442369            1       True         11
32    RandomForest_hs=1_n300    0.381040   0.391960        1.677277       0.055421  2.346773                 1.677277                0.055421           2.346773            1       True          6
33   RandomForest_hs=50_n300    0.380945   0.391960        0.889030       0.053650  1.320377                 0.889030                0.053650           1.320377            1       True         16
34  RandomForest_hs=100_n300    0.380885   0.391960        1.031198       0.045266  1.254918                 1.031198                0.045266           1.254918            1       True         21
35  RandomForest_hs=500_n300    0.380816   0.391960        0.948715       0.050209  1.266396                 0.948715                0.050209           1.266396            1       True         26
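
The guess that NaN results become more likely with more trees can be illustrated with a rough model. This sketch assumes that a forest averages its trees' probabilities, so a single NaN poisons the mean, and that each tree independently returns NaN for a given row with some probability. The value `p = 0.01` is made up for illustration and was not measured from the run above.

```python
import math

def prob_forest_nan(p_tree_nan, n_trees):
    """P(at least one of n independent trees returns NaN for a row)."""
    return 1.0 - (1.0 - p_tree_nan) ** n_trees

p = 0.01  # hypothetical per-tree NaN rate, not taken from the issue
tree_counts = [10, 20, 40, 100, 300]
probs = [prob_forest_nan(p, n) for n in tree_counts]

# A single NaN among the averaged values makes the average NaN.
avg = sum([0.2, 0.7, float("nan")]) / 3

for n, q in zip(tree_counts, probs):
    print(f"n_trees={n:>3}  P(NaN row) = {q:.3f}")
print("mean with one NaN:", avg)
```

Under these assumptions the share of affected rows grows monotonically with the tree count, which matches the direction of the trend in the table.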

Owner

csinva commented Jul 31, 2022

Hi Nick, you're right, this is currently not supported (the shrink function is written only for univariate regression/binary classification and misbehaves with multiple classes). It's a pretty straightforward extension, though, and we can get to it soon!
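
One way such an extension might look, as a sketch only: this is not the code merged in #130. It assumes the usual hierarchical shrinkage rule, where each child's shrunk value is the parent's shrunk value plus the child-minus-parent difference divided by `1 + reg_param / N(parent)`, and applies it to the whole class-probability vector rather than a single scalar. The tree layout (a list of dicts with `counts` and `children`) and the function names are hypothetical.

```python
def normalize(counts):
    """Turn per-class sample counts into class proportions."""
    total = sum(counts)  # assumes every node holds at least one sample
    return [c / total for c in counts]

def shrink_multiclass(nodes, reg_param):
    """Return shrunk class-probability vectors for every node (root is node 0)."""
    shrunk = [None] * len(nodes)
    shrunk[0] = normalize(nodes[0]["counts"])  # the root is left unshrunk
    stack = [0]
    while stack:
        parent = stack.pop()
        parent_probs = normalize(nodes[parent]["counts"])
        n_parent = sum(nodes[parent]["counts"])
        for child in nodes[parent]["children"]:
            child_probs = normalize(nodes[child]["counts"])
            # Damp the parent-to-child change, one class at a time.
            shrunk[child] = [
                s + (c - p) / (1.0 + reg_param / n_parent)
                for s, c, p in zip(shrunk[parent], child_probs, parent_probs)
            ]
            stack.append(child)
    return shrunk

# Hypothetical 3-class tree; node 4 is pure in class 2, the case that
# yields 0/0 when only the first two classes are considered.
nodes = [
    {"counts": [4, 3, 3], "children": (1, 2)},
    {"counts": [4, 1, 0], "children": ()},
    {"counts": [0, 2, 3], "children": (3, 4)},
    {"counts": [0, 2, 0], "children": ()},
    {"counts": [0, 0, 3], "children": ()},
]

shrunk = shrink_multiclass(nodes, reg_param=10.0)
unshrunk = shrink_multiclass(nodes, reg_param=0.0)
for node_id, probs in enumerate(shrunk):
    print(node_id, [round(q, 3) for q in probs])
```

Because the per-class differences between a child and its parent sum to zero, every shrunk vector still sums to 1, and `reg_param=0.0` recovers the raw node proportions.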

@aagarwal1996 @yanshuotan

csinva added the enhancement (New feature or request) label on Jul 31, 2022
Author

Innixma commented Jul 31, 2022

That would be amazing, thanks!

csinva linked a pull request on Aug 22, 2022 that will close this issue
Owner

csinva commented Aug 25, 2022

@Innixma should work now :)

Author

Innixma commented Aug 25, 2022

Awesome, thank you so much!


3 participants