
[MRG] Histogram computation optimization #14

Merged: 49 commits merged into ogrisel:master on Oct 19, 2018

Conversation

@NicolasHug (Collaborator) commented Oct 4, 2018

This is an attempt to use the relation

hist(parent) = hist(parent.left) + hist(parent.right)

to explicitly compute only one of hist(parent.left) or hist(parent.right), and obtain the sibling's histogram by a simple subtraction.

This is supposed to be faster because explicitly computing a histogram is O(n_samples_at_node) while the subtraction is only O(n_bins) and n_samples_at_node >> n_bins.
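To make the complexity argument concrete, here is a minimal NumPy sketch of what explicitly building one feature's histogram involves (the structured dtype and the function are illustrative assumptions, not the actual pygbm code); the subtraction counterpart only has to loop over the bins:

import numpy as np

HISTOGRAM_DTYPE = np.dtype([('sum_gradients', np.float32),
                            ('sum_hessians', np.float32),
                            ('count', np.uint32)])  # assumed layout

def build_histogram(n_bins, sample_indices, binned_feature, gradients, hessians):
    # O(n_samples_at_node): one pass over the samples of the node
    histogram = np.zeros(n_bins, dtype=HISTOGRAM_DTYPE)
    for i in sample_indices:
        b = binned_feature[i]
        histogram[b]['sum_gradients'] += gradients[i]
        histogram[b]['sum_hessians'] += hessians[i]
        histogram[b]['count'] += 1
    return histogram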

TODO:

  • Ideally, only the histogram of the sibling with the fewest samples should be explicitly computed. The current implementation does not support this and computes the histogram of whatever node comes first during the tree growing process.
  • The code is a bit ugly and could be refactored into something cleaner
  • Make it actually work...
  • Write some tests

The good news is that the predictions are the same as before. The bad news is that there is no clear performance gain... I'll try to understand what's going on.

@codecov-io commented Oct 5, 2018

Codecov Report

Merging #14 into master will increase coverage by 19.22%.
The diff coverage is 98.4%.

@@             Coverage Diff             @@
##           master      #14       +/-   ##
===========================================
+ Coverage   73.91%   93.13%   +19.22%     
===========================================
  Files           8        8               
  Lines         598      714      +116     
===========================================
+ Hits          442      665      +223     
+ Misses        156       49      -107
Impacted Files Coverage Δ
pygbm/splitting.py 100% <100%> (ø) ⬆️
pygbm/grower.py 89.75% <100%> (+3.14%) ⬆️
pygbm/histogram.py 100% <100%> (ø) ⬆️
pygbm/plotting.py 100% <100%> (+100%) ⬆️
pygbm/gradient_boosting.py 79.38% <72.72%> (+66.88%) ⬆️
pygbm/__init__.py 100% <0%> (ø) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 176a35b...1c1ed55.

@ogrisel (Owner) commented Oct 6, 2018

Ideally, only the histogram of the sibling with the fewest samples should be explicitly computed. The current implementation does not support this and computes the histogram of whatever node comes first during the tree growing process.

This can be fixed easily in the grower itself by choosing to compute the split-ability of the smallest node first.

@ogrisel (Owner) commented Oct 6, 2018

The bad news is that there is no clear performance gain... I'll try to understand what's going on.

Too bad. You might need to add some start_time = time() / duration += time() - start_time manual profiling probes around the interesting parts to understand what's going on.
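For instance, something along these lines (a sketch; the wrapper and the timings dict are made-up names, not existing pygbm code):

from time import time

def timed_find_node_split(splitter, sample_indices, timings):
    # Accumulate the time spent in find_node_split into a shared dict
    # so the totals can be reported after fitting.
    tic = time()
    result = splitter.find_node_split(sample_indices)
    timings['find_node_split'] = timings.get('find_node_split', 0.0) + (time() - tic)
    return result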

@NicolasHug (Collaborator Author) commented Oct 7, 2018

This can be fixed easily in the grower itself by choosing to compute the split-ability of the smallest node first.

So I tried doing just that, here's an example tree (graphviz plot not reproduced here):

In this plot, fast means that the node's histogram was computed with the histogram trick; time is the time needed to compute the histogram plus the best split; and ratio is (node.sibling.time) / (node.time) when the sibling's histogram was computed with the slow method, and 1 otherwise. It gives an idea of how much faster the fast method is.

Now something pretty weird is going on: I computed the average ratio of the fast nodes:

mean_ratio = np.mean([node['ratio'] for pred in pygbm_model.predictors_
                      for node in pred.nodes if node['ratio'] != 1])

and this value is < 1 when we use 'fast' on the sibling with the higher number of samples (which is supposed to be the correct way), and > 1 if we use the 'fast' method on the sibling with the fewer samples. This might be because time does not measure only the histogram computation but also the search for the best split.

Other thoughts:

  • ordered_gradients and ordered_hessians are not needed when using the fast method. We're computing them in all cases, which is a waste of time. That might also explain those discrepancies.
  • When computing the histogram of a sibling because it has fewer samples, if that sibling is a leaf (or will become one) we don't need to compute its splittability, only its histogram. We're currently wasting time computing the splittability every time.

@ogrisel (Owner) commented Oct 8, 2018

You should replot the above analysis on a larger dataset such as the Higgs boson dataset with 1e7 samples. Otherwise, the splits are too quick and the timing probes might induce some non-trivial overhead. To make it easier to debug, use a smaller tree, with a 5-leaf-node limit for instance.

In the above graph, only the first level splits with large sample counts benefit from the optimization.

@ogrisel (Owner) commented Oct 8, 2018

Actually, your PR is already a good improvement on a larger dataset:

When running the benchmark/bench_higgs_boson.py on master I get:

Training set with 10950000 records with 28 features.
Fitting a LightGBM model...
[..]
done in 13.362s, ROC AUC: 0.7872
JIT compiling code for the pygbm model...
done in 9.060s
Fitting a pygbm model...
Binning 1.226 GB of data: 4.193 s (292.462 MB/s)
Fitting gradient boosted rounds:
[...]
Fit 10 trees in 33.098 s, (310 total leaf nodes)
done in 33.130s, ROC AUC: 0.7893

on your branch (omitting the LightGBM output to avoid redundancy):

JIT compiling code for the pygbm model...
done in 13.738s
Fitting a pygbm model...
Binning 1.226 GB of data: 4.982 s (246.182 MB/s)
Fitting gradient boosted rounds:
[...]
Fit 10 trees in 24.658 s, (310 total leaf nodes)
done in 24.675s, ROC AUC: 0.7893

@ogrisel (Owner) commented Oct 8, 2018

When running single threaded (OMP_NUM_THREADS=1 and NUMBA_NUM_THREADS=1) I get the following numbers for the above benchmark:

  • LightGBM: 51.267s
  • pygbm (master): 104.512s
  • pygbm (this branch): 75.257s

So it's already a 25% time reduction w.r.t. master, halfway to LightGBM.

@ogrisel (Owner) commented Oct 8, 2018

Based on your tree, it seems that when count < 3e5 samples, it's always faster to compute the histogram directly (< 1 ms), while the histogram subtraction is slower (> 1 ms).

So, as a rule of thumb, I think you should enable the histogram subtraction only when the node count is larger than 5e5.

@ogrisel (Owner) commented Oct 8, 2018

Actually my analysis is wrong because it depends on the imbalance between the sibling nodes. Doing the histogram optimization probably never hurts, but it's not expected to yield huge improvements when the counts are below 5e5.

In your plot, I think it would be more interesting to report "count / time" split speeds and the ratio (count_fast / time_fast) / (count_slow / time_slow), that is, split_speed_fast / split_speed_slow.

@ogrisel (Owner) commented Oct 8, 2018

I think it would be interesting to build LightGBM with its TIMETAG feature enabled and collect similar statistics in pygbm to be able to compare.

https://github.com/Microsoft/LightGBM/blob/master/src/treelearner/serial_tree_learner.cpp#L34

It would also be interesting to count the number of times the histogram computation functions are called for each implementation to check that we are consistent.

@ogrisel (Owner) commented Oct 8, 2018

To enable TIMETAG, just edit the CMakeLists.txt at the root of the LightGBM source folder to add:

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 057302b..b735793 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,6 +37,9 @@ elseif(MSVC)
   cmake_minimum_required(VERSION 3.8)
 endif()
 
+
+add_definitions(-DTIMETAG)
+
 if(USE_SWIG)
   find_package(SWIG REQUIRED)
   find_package(Java REQUIRED)

then rebuild the python package: cd python-package; python setup.py install.

Passing verbose >= 2 should output the following (e.g. in the Higgs Boson benchmark):

Fitting a LightGBM model...
[LightGBM] [Info] Total Bins 6143
[LightGBM] [Info] Number of data: 10950000, number of used features: 28
[LightGBM] [Info] Start training from score 0.529920
[LightGBM] [Info] GBDT::boosting costs 0.188611
[LightGBM] [Info] GBDT::train_score costs 0.807898
[LightGBM] [Info] GBDT::out_of_bag_score costs 0.000005
[LightGBM] [Info] GBDT::valid_score costs 0.000001
[LightGBM] [Info] GBDT::metric costs 0.000000
[LightGBM] [Info] GBDT::bagging costs 0.000006
[LightGBM] [Info] GBDT::tree costs 7.310715
[LightGBM] [Info] SerialTreeLearner::init_train costs 0.122483
[LightGBM] [Info] SerialTreeLearner::init_split costs 0.000371
[LightGBM] [Info] SerialTreeLearner::hist_build costs 5.476519
[LightGBM] [Info] SerialTreeLearner::find_split costs 0.325841
[LightGBM] [Info] SerialTreeLearner::split costs 1.381069
[LightGBM] [Info] SerialTreeLearner::ordered_bin costs 0.000000
done in 12.431s, ROC AUC: 0.7872

@ogrisel (Owner) left a comment

skipping ordered_gradients computation when parent_histogram is not None might indeed bring some speed up.

This computation is a sequential section of the code, so skipping it might also increase parallelism.

histogram = np.zeros(n_bins, dtype=HISTOGRAM_DTYPE)
unrolled_upper = (n_bins // 4) * 4

for i in range(0, unrolled_upper, 4):
Owner

Are you sure that unrolling brings any speed-up for such a short loop? n_bins is expected to be 255 max.

Collaborator Author

No idea... I just thought it couldn't hurt, but indeed I don't think it's really useful.

@@ -20,6 +20,40 @@ def _build_ghc_histogram_naive(n_bins, sample_indices, binned_feature,
    return histogram


@njit
def _build_ghc_histogram_unrolled_fast(n_bins, parent_histogram,
                                       sibling_histogram):
Owner

I would rename this to _subtract_histograms, and rename parent_histogram to hist_a and sibling_histogram to hist_b.
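A sketch of what the suggested helper could look like (the structured dtype is an assumption, and the 4x unrolling used by the other kernels is omitted for brevity):

import numpy as np
from numba import njit

HISTOGRAM_DTYPE = np.dtype([('sum_gradients', np.float32),
                            ('sum_hessians', np.float32),
                            ('count', np.uint32)])  # assumed layout

@njit
def _subtract_histograms(n_bins, hist_a, hist_b):
    # hist_a - hist_b, field by field, one entry per bin: O(n_bins)
    histogram = np.zeros(n_bins, dtype=HISTOGRAM_DTYPE)
    for i in range(n_bins):
        histogram[i]['sum_gradients'] = hist_a[i]['sum_gradients'] - hist_b[i]['sum_gradients']
        histogram[i]['sum_hessians'] = hist_a[i]['sum_hessians'] - hist_b[i]['sum_hessians']
        histogram[i]['count'] = hist_a[i]['count'] - hist_b[i]['count']
    return histogram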

Owner

Also, I think that you should create a new method named find_node_split_by_subtraction that would call a new function named _parallel_subtract_histograms.

And keep find_node_split for the case where we actually compute histograms with _parallel_find_splits.

Collaborator Author

Yes this definitely needs some refactoring. Thanks for all the comments, I'll go back to it soon.

Also, any idea how to report execution time from inside numba? Importing the time module does not work in nopython mode. We're currently reporting time(hist + splittability) and what we're interested in is only time(hist).

@ogrisel (Owner), Oct 8, 2018

No idea: I found this question on stackoverflow but it hasn't been answered yet.

Maybe change the decorator from numba.njit to numba.jit and just use time() and see if that causes some significant overhead?
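For example (a sketch of the idea only; as discussed above, time() is not available in nopython mode, so plain numba.jit ends up compiling in object mode, which is exactly the overhead to measure):

from time import time
from numba import jit

@jit  # not njit: falls back to object mode so that time() can be called
def timed_sum(values):
    tic = time()
    total = 0.0
    for v in values:
        total += v
    return total, time() - tic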

@NicolasHug (Collaborator Author)

Soooo we're getting closer :)

I simply duplicated the whole pipeline; the only difference between calling find_node_split_subtraction and find_node_split is that find_node_split_subtraction will end up calling _subtract_ghc_histograms_unrolled instead of one of the other 'low' histogram building functions.

And just like that, I get to run the Higgs boson benchmark in 23 seconds (17 for LightGBM). That's pretty weird to me because find_node_split itself didn't change at all, but I won't complain. Numba mysteries, I suppose.

Now, I tried to get rid of all the ordered_gradient stuff. As far as I understand, we should be able to compute the hessians and gradients of the current node by just summing up the histogram, right? Unfortunately this will break my basic sanity check (AUC is not the same as before). Any idea what might be going on? Or am I simply mistaken about the computation?

It's super worth it though, because once we get rid of the ordered_stuff computation, boson runs in 12 seconds ^^

Last comments:

  • Code is super ugly but I just wanted your input
  • sorry about the first WIP, you can ignore it
  • no idea what's going on with travis, tests are passing locally

@ogrisel (Owner) commented Oct 11, 2018

And just like that, I get to run the Higgs boson benchmark in 23 seconds (17 for LightGBM). That's pretty weird to me because find_node_split itself didn't change at all, but I won't complain. Numba mysteries, I suppose.

Nice! It's possible that the numba type inference had trouble with the previous code layout and was missing additional typing hints.

It's always a good idea to check the types inferred for all the local variables in the performance-critical sections of the code, to ensure that there are no unnecessary type casts that might prevent the compiler from optimizing the generated code:

https://github.com/ogrisel/pygbm/#debugging-numba-type-inference

Now, I tried to get rid of all the ordered_gradient stuff. As far as I understand, we should be able to compute the hessians and gradients of the current node by just summing up the histogram, right?

Yes, I believe so. Maybe you could alternatively try to store the gradient and hessian sums on the grower nodes themselves, along with the histograms, as ancillary arrays.
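As a sketch (the field names are assumed), the per-node totals for one feature's histogram would just be:

import numpy as np

def node_sums(histogram):
    # Sum over the bins of a single feature's histogram: O(n_bins),
    # recovering the node totals up to float32 rounding.
    return (np.sum(histogram['sum_gradients']),
            np.sum(histogram['sum_hessians']))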

It's super worth it though, because once we get rid of the ordered_stuff computation, boson runs in 12 seconds ^^

Sounds promising. Once we get good performance on a few cores, it would be interesting to check the scalability on hardware with many more cores (e.g. 32 cores on a cloud machine) and contrast that with the performance profile of LightGBM.

Also note, the following test times out when disabling the numba jit:

https://travis-ci.org/ogrisel/pygbm/jobs/439885373

Maybe it could be made faster so that it does not fail when numba is disabled. Running without numba is useful to collect coverage data (and also possibly to check that numba itself is not introducing unexpected bugs :)

predicted_test = pygbm_model.predict(data_test)
roc_auc = roc_auc_score(target_test, predicted_test)

assert roc_auc == 0.9809811751028725
@ogrisel (Owner), Oct 11, 2018

There will likely be some rounding errors in the summing of floating point values (total gradient sums, gradient histograms, subtractions), leading to the choice of slightly different splits. Therefore this kind of test should use some tolerance, for instance using pytest.approx.
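For example, adapting the assertion quoted above (sketch; the tolerance is a guess, not a value from the PR):

import pytest

# tolerate float32 rounding differences between the two histogram code paths
assert roc_auc == pytest.approx(0.9809811751028725, rel=1e-6)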

@NicolasHug (Collaborator Author)

All right I just found what was wrong, it was the hessian computation:

hessian_parent = np.sum(histogram[:]['sum_hessians'])

this was run regardless of the constant_hessian value. As the hessian is constant here, the sum_hessians field of the sibling's histogram (computed with the slow method) would never be filled and would always be 0. So the current node's histogram would be wrong for the sum_hessians field too (always zero as well), and we would always hit

        if hessian_left < min_hessian_to_split:
            continue
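In other words, the parent hessian sum needs to special-case the constant-hessian mode, e.g. (a sketch; the discussion above mentions constant_hessian, but the helper and exact names here are hypothetical):

import numpy as np

def parent_sum_hessians(histogram, constant_hessian, constant_hessian_value, n_samples):
    # With a constant hessian the sum_hessians fields are never filled,
    # so the parent total must come from the sample count instead.
    if constant_hessian:
        return constant_hessian_value * n_samples
    return np.sum(histogram['sum_hessians'])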

There will likely be some rounding errors

Indeed, there are rounding errors. The relative error between the gradient_parent variable computed with the two methods is always less than 1e-3 though, and usually much less. As far as I can tell the splits are still the same (trees have the same structure), but the values of the leaves change slightly:

>       assert roc_auc == 0.9809811751028725
E       assert 0.9809811974530592 == 0.9809811751028725

I'll clean this all up!! :)

@NicolasHug (Collaborator Author)

Something fishy is going on (at least in my head ^^), I'll try to add some tests and identify a potential mistake I made.

I've added some tests and everything looks fine. I didn't have to fix anything, so the fishiness was probably just me not wrapping my head around it.

I've moved back the computation of gradients and hessians into find_node_split_subtraction(). Now both pipelines are really similar.

@NicolasHug (Collaborator Author)

Is there anything I should do w.r.t. the failing tests with NUMBA_DISABLE_JIT="1" ?

@ogrisel (Owner) commented Oct 16, 2018

Is there anything I should do w.r.t. the failing tests with NUMBA_DISABLE_JIT="1" ?

Yes: remove it :)

More seriously, it's just way too slow: use a much smaller dataset. You cannot expect the pure Python version to do large, compute-heavy tasks in a time-limited test.

@ogrisel (Owner) commented Oct 16, 2018

The new version of the code seems a bit slower than previously: I get 21+ s for pygbm on the Higgs boson benchmark (vs 18 s for LightGBM).

@NicolasHug (Collaborator Author)

Is it consistently slower than before? Because I get some variation between runs.

That makes me think: is there any way to have consistent benchmarks, so that we're sure a change actually causes a slowdown and that it's not just noise? I guess CI isn't an option given the size of the Higgs boson dataset.

@NicolasHug (Collaborator Author)

Looking at the logs from numba annotation, it seems that hessian and constant_hessian_value are inferred as float64.

Is there a way to set locals for jitclass methods? I couldn't find any. Do we need to flatten the class then?

@ogrisel (Owner) commented Oct 17, 2018

Is it consistently slower than before? Because I get some variation between runs.

It's not that big a deal. The difference is close to the noise level of repeated runs. We can optimize that later.

I opened an issue in numba/numba#3417 to track this. In the meantime it's ok to use more functions if we really need to.

Let's remove the test_sanity test and merge this PR.

@ogrisel (Owner) left a comment

Here is a batch of comments on the timing info collected in this PR.

I also think there should be an example script (in a new "examples/" folder) that builds a single grower tree with fewer than 10 nodes and plots it. The dataset should be large enough in terms of samples (at least one million samples, say) so that the timing information is useful to analyze. A bit like you did in the sanity test, but as an example instead.

pygbm/grower.py Outdated
# Computation time of the histograms, or more precisely time to compute
# splitability, which may involve some useless computations
time = 0
ratio = 1 # sibling.time / node.time if node.hist_subtraction, else 1
Owner

I think this attribute is hard to interpret correctly. I would rather not compute and report it but instead report construction_speed as sample_indices.shape[0] / time.

Collaborator Author

do you mean time / sample_indices.shape[0] ?

Owner

Processing speed is sample_indices.shape[0] / time, in samples per second.

pygbm/grower.py Outdated
parent = None # Link to parent node, None for root
# Computation time of the histograms, or more precisely time to compute
# splitability, which may involve some useless computations
time = 0
Owner

Maybe rename this to find_split_time and also collect apply_split_time that times grower.splitter.split_indices().

pygbm/grower.py Outdated
split_info, histograms = self.splitter.find_node_split(
node.sample_indices)
toc = time()
node.time = toc - tic
Owner

I think we should also accumulate the total time spent finding the best splits in a grower-level attribute, and do the same for the time spent applying the splits.

In the gradient boosting class we should also record the time spent in binning and finally in computing the predictions to get the new gradients and hessians for the next boosting iteration.

At the end of the gradient boosting fit, if verbose > 0, we should print a profiling report with the accumulated time for each of those types of operations, in a similar way to what LightGBM does.
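A rough sketch of what that accumulation could look like on the grower (the attribute names are made up, not the ones that were merged):

from time import time

class GrowerTimings:
    """Sketch: accumulate per-operation wall-clock time across all nodes."""
    def __init__(self):
        self.total_find_split_time = 0.0
        self.total_apply_split_time = 0.0

    def report(self):
        print(f"find_split costs {self.total_find_split_time:.3f}s")
        print(f"apply_split costs {self.total_apply_split_time:.3f}s")

# in the grower, around each call:
# tic = time(); split_info, histograms = splitter.find_node_split(...)
# timings.total_find_split_time += time() - tic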

Collaborator Author

Got it

There's a lot of

tic = time()
blablah()
toc = time()
time_spent = toc - tic

Do we want a context manager that would allow us to do something like

with Timer() as time_spent:
    blahblah()

?
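A minimal version of such a helper could look like this (sketch of the proposal only; note that the elapsed time is only known after the block, so the with target binds the timer object rather than a float):

from time import time

class Timer:
    """Record the elapsed wall-clock time of a with-block in .elapsed."""
    def __enter__(self):
        self.start = time()
        return self

    def __exit__(self, *exc):
        self.elapsed = time() - self.start
        return False

# usage:
# with Timer() as t:
#     blahblah()
# time_spent = t.elapsed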

Owner

As you wish but I think it's fine like it is now. Using a context manager will add an indentation level.

Collaborator Author

Ok let's keep it this way then

('time', np.float32),
('ratio', np.float32),
('sum_g', np.float32),
('sum_h', np.float32),
Owner

Do we really need to store the sum of gradients and hessians in the predictor? Maybe we should find a way to just plot the grower tree instead of the predictor tree if the goal is to present a tree-structured profile report of the cost of growing a tree.

The predictor tree is part of the public API of the model and I would rather not pollute it with attributes that are only useful for pygbm developers who try to optimize the implementation.

Collaborator Author

No we don't need those. Same goes for time and ratio, etc. Those were only useful for debugging.

I agree that we should instead plot the grower tree. That's quite easy to do I think, we would just need to keep a growers_ list attribute in the GradientBoostingMachine class and slightly change the plotting procedure. Should I do this? Here or another PR?

Owner

I don't want to keep the grower nodes on the GradientBoostingMachine. Instead, I think we should only plot it when calling the TreeGrower code directly, as we do in bench_grower.py. This is just for pygbm developers; it's not meant to be part of the public estimator API.

Owner

Plotting the predictor trees is still useful for the users to learn about the structure of the trees. So the plotting utility should still work in that case.

I just want to extend it to make it work with additional info on the grower tree and use it in an example folder stating explicitly that this is only to analyze the performance profile of a single small tree. E.g.:

example/plot_performance_profile_single_small_tree.py

@ogrisel merged commit fb4e579 into ogrisel:master on Oct 19, 2018
@ogrisel (Owner) commented Oct 19, 2018

Ok merged!

@ogrisel (Owner) commented Oct 19, 2018

I rebased / merged instead of squash merging... Too bad for the history of this project...

@ogrisel (Owner) commented Oct 19, 2018

BTW I also pushed 3f7c36f to fix the plotting example timings (pre-compiling). The results are now consistent: the split speed is always higher when using histogram subtraction.

@NicolasHug deleted the histogram_subtraction branch on October 19, 2018 18:53