Benchmark results with better parameters #30

Closed
Laurae2 opened this issue Nov 1, 2018 · 23 comments
Labels
perf Computational performance issue or improvement

Comments

@Laurae2
Contributor

Laurae2 commented Nov 1, 2018

Used a laptop for a better demo benchmark:

  • Intel Core i7-7700HQ (4 cores, 8 threads), unthrottled
  • 32GB RAM DDR4 2400 MHz (dual channel)
  • Python 3.6, scikit-learn 0.20, numba 0.40.1

Setup for the proper benchmarking:

  • No LightGBM / pygbm warmup allowed
  • 1 million training samples (10 million might crash on 64GB RAM? pygbm requires at least 24GB RAM for 1 million)
  • 500 training iterations
  • 255 leaves
  • 0.05 learning rate (can change to 0.10 actually for better comparison with independent benchmarks)

The benchmark in the master branch (https://github.com/ogrisel/pygbm/blob/master/benchmarks/bench_higgs_boson.py) is way too short and doesn't really test the speed of the whole model, because it finishes so quickly: there are diminishing returns as the number of iterations increases, and this is the part that is difficult to optimize once the tree construction itself is already optimized.

Results:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 45.260s | 0.8293 | Reference, runnable with 8GB RAM. |
| pygbm | 359.101s | 0.8180 | Requires over 24GB RAM. Slower as more trees are added over time. |

Conclusion:

  • pygbm is 5 to 10 times slower, but slower does not mean worse. It is actually very fast compared to xgboost with the exact method from 2 years ago, and as of today we can consider it competitive in speed with xgboost exact, provided you have enough RAM
  • pygbm requires way too much RAM; you only notice it when using many iterations, because memory usage seems to increase linearly with the number of iterations
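To check whether memory really grows linearly with the number of iterations, one could record the peak resident memory after each iteration. The following is a minimal sketch using only the standard library (Unix only, since it relies on the `resource` module); `peak_rss_mb` and `track_memory` are hypothetical helper names, and the `step` callback stands in for one boosting iteration.

```python
import resource
import sys


def peak_rss_mb():
    """Return the peak resident set size of this process in megabytes."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def track_memory(n_iterations, step):
    """Call step(i) n_iterations times, recording peak RSS after each call."""
    history = []
    for i in range(n_iterations):
        step(i)
        history.append(peak_rss_mb())
    return history


if __name__ == "__main__":
    retained = []  # stands in for per-iteration state that is never freed
    history = track_memory(5, lambda i: retained.append(bytearray(10_000_000)))
    print([f"{mb:.0f} MB" for mb in history])
```

Because this reports the peak, the values never decrease; a roughly constant increment per iteration would be consistent with linear growth.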

To run the benchmark, one can use the following for a clean setup. It is not optimized for the fastest performance, but it installs the prerequisites (0.20 scikit-learn, 0.39 numba):

pip install lightgbm
pip install -U scikit-learn
pip install -U numba

git clone https://github.com/ogrisel/pygbm.git
cd pygbm

Before installing pygbm, change the following at lines 146-147 of pygbm/grower.py (https://github.com/ogrisel/pygbm/blob/master/pygbm/grower.py#L146-L147):

            node.construction_speed = (node.sample_indices.shape[0] /
                                       node.find_split_time)

to:

            node.construction_speed = (node.sample_indices.shape[0] / 1.0)

This avoids the infamous divide-by-zero error, which occurs when node.find_split_time is 0.
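The change above replaces the measured time with a constant 1.0, so construction_speed then simply equals the sample count. If the measurement should be kept, one alternative is to clamp the elapsed time to a small positive value. This is only a sketch: the standalone function and the `eps` parameter are hypothetical and not part of pygbm.

```python
def construction_speed(n_samples, find_split_time, eps=1e-9):
    """Samples per second, with the elapsed time clamped to at least eps."""
    # max() guards against a timer that returned exactly 0.0 seconds.
    return n_samples / max(find_split_time, eps)
```

With a non-zero time the result is unchanged; with a zero time it becomes a very large finite number instead of raising ZeroDivisionError.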

Then, one can run the following:

pip install --editable .

If you have a slow Internet connection, download the HIGGS dataset beforehand from https://archive.ics.uci.edu/ml/machine-learning-databases/00280/ and uncompress it.
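For the uncompress step, a small streaming helper avoids loading the whole archive into memory. This is a sketch using only the standard library; `gunzip` is a hypothetical helper name and the file names in the usage line are assumptions about where the download was saved.

```python
import gzip
import shutil


def gunzip(src_path, dst_path, chunk_size=1 << 20):
    """Decompress the gzip file src_path into dst_path, streaming in chunks."""
    with gzip.open(src_path, "rb") as src, open(dst_path, "wb") as dst:
        shutil.copyfileobj(src, dst, chunk_size)
    return dst_path
```

Usage, assuming the archive was saved next to the script: `gunzip("HIGGS.csv.gz", "HIGGS.csv")`.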

Then, you may run a proper benchmark using the following (make sure to change the load_path to your HIGGS csv file):

import os
from time import time
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from pygbm import GradientBoostingMachine
from lightgbm import LGBMRegressor
import numba
import gc


n_leaf_nodes = 255
n_trees = 500
lr = 0.05
max_bins = 255
load_path = "mnt/HIGGS/HIGGS.csv"
subsample = 1000000 # Change this to 10000000 if you wish, or to None

df = pd.read_csv(load_path, header=None, dtype=np.float32)
target = df.values[:, 0]
data = np.ascontiguousarray(df.values[:, 1:])
data_train, data_test, target_train, target_test = train_test_split(
    data, target, test_size=50000, random_state=0)

if subsample is not None:
    data_train, target_train = data_train[:subsample], target_train[:subsample]

n_samples, n_features = data_train.shape
print(f"Training set with {n_samples} records with {n_features} features.")

# Includes warmup time penalty
print("Fitting a LightGBM model...")
tic = time()
lightgbm_model = LGBMRegressor(n_estimators=n_trees, num_leaves=n_leaf_nodes,
                               learning_rate=lr, silent=False)
lightgbm_model.fit(data_train, target_train)
toc = time()
predicted_test = lightgbm_model.predict(data_test)
roc_auc = roc_auc_score(target_test, predicted_test)
print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}")
del lightgbm_model
del predicted_test
gc.collect()

# Includes warmup time penalty
print("Fitting a pygbm model...")
tic = time()
pygbm_model = GradientBoostingMachine(learning_rate=lr, max_iter=n_trees,
                                      max_bins=max_bins,
                                      max_leaf_nodes=n_leaf_nodes,
                                      random_state=0, scoring=None,
                                      verbose=1, validation_split=None)
pygbm_model.fit(data_train, target_train)
toc = time()
predicted_test = pygbm_model.predict(data_test)
roc_auc = roc_auc_score(target_test, predicted_test)
print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}")
del pygbm_model
del predicted_test
gc.collect()


if hasattr(numba, 'threading_layer'):
    print("Threading layer chosen: %s" % numba.threading_layer())

If something is missing in the script, please let me know.

@ogrisel
Owner

ogrisel commented Nov 1, 2018

Indeed, I observed the memory usage issue. We will need to investigate why this is the case. The grower object is expected to be big because we store the sample indices in the nodes, but this should not be the case for the predictor objects. When the number of trees is increased, we only accumulate the predictor objects in a list. The grower objects should be garbage collected as we progress. Maybe we do not collect them correctly for some reason.
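One way to test the hypothesis that grower objects are not being collected is to keep weak references to them and count how many are still alive after training. The sketch below uses a hypothetical `Grower` stand-in, not the pygbm class; only the `weakref` and `gc` calls are real APIs.

```python
import gc
import weakref


class Grower:
    """Hypothetical stand-in for a large, short-lived tree grower."""

    def __init__(self, n_samples):
        self.sample_indices = list(range(n_samples))  # the big part

    def make_predictor(self):
        # The predictor keeps only small summary data, not the grower itself.
        return {"n_samples": len(self.sample_indices)}


def fit(n_trees, n_samples):
    """Build n_trees predictors and count growers that survive collection."""
    predictors, grower_refs = [], []
    for _ in range(n_trees):
        grower = Grower(n_samples)
        predictors.append(grower.make_predictor())
        grower_refs.append(weakref.ref(grower))
        del grower  # drop the only strong reference
    gc.collect()  # also clears reference cycles, if any
    alive = sum(ref() is not None for ref in grower_refs)
    return predictors, alive
```

If `alive` is greater than zero, something (for example a predictor or a closure) still holds a strong reference to a grower.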

I had not realized performance would degrade with more trees and larger number of leaf nodes.

@ogrisel
Owner

ogrisel commented Nov 1, 2018

Including the warm-up penalty (JIT compilation overhead) is a bit unfair because it would be possible to use numba with a compile cache or Ahead of Time compiling but we do not want to spend time on this while we are still developing the pygbm package.
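To report the warm-up penalty separately instead of folding it into the total, one could time the first and second call of the same function. This is a generic sketch with a hypothetical helper name; for a JIT-compiled function the first measurement includes compilation, and for an expensive fit one would presumably warm up on a small subset rather than run the full fit twice.

```python
from time import perf_counter


def time_with_warmup(fn, *args, **kwargs):
    """Return (first_call_seconds, second_call_seconds) for fn(*args, **kwargs)."""
    tic = perf_counter()
    fn(*args, **kwargs)  # includes any one-off cost such as JIT compilation
    first = perf_counter() - tic

    tic = perf_counter()
    fn(*args, **kwargs)  # steady-state cost only
    second = perf_counter() - tic
    return first, second
```

The difference between the two values is a rough estimate of the one-off overhead.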

@ogrisel
Owner

ogrisel commented Nov 1, 2018

To get additional info on where the time is spent in LightGBM you can compile it with the following:

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c222221..9309026 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -54,6 +54,8 @@ if(USE_R35)
     ADD_DEFINITIONS(-DR_VER_ABOVE_35)
 endif()
 
+add_definitions(-DTIMETAG)
+
 if(USE_MPI)
     find_package(MPI REQUIRED)
     ADD_DEFINITIONS(-DUSE_MPI)

It should report additional information that makes it possible to compare with the verbose output of pygbm.

@ogrisel
Owner

ogrisel commented Nov 1, 2018

On my laptop (XPS13 from 2 years ago) with the following change to the benchmark settings:

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index 3631edd..0097dcd 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -18,10 +18,10 @@ HERE = os.path.dirname(__file__)
 URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00280/"
        "HIGGS.csv.gz")
 m = Memory(location='/tmp', mmap_mode='r')
-n_leaf_nodes = 31
-n_trees = 10
+n_leaf_nodes = 255
+n_trees = 50
 subsample = None
-lr = 1.
+lr = 0.1
 max_bins = 255

I get:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 177.147s | 0.8188 | Saturates the 4 hyperthreads of my laptop |
| pygbm | 201.917s | 0.8075 | Requires over ~16GB RAM. Individual tree construction time tends to decrease from 4.3s to 3.85s. |

So LightGBM performs better than pygbm but the difference is not as strong as what you report.

Compile time overhead for pygbm (not included in the above results) is around 6-8s on my machine.

@ogrisel
Owner

ogrisel commented Nov 1, 2018

With your hyperparameters:

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index 3631edd..a2b3866 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -18,10 +18,10 @@ HERE = os.path.dirname(__file__)
 URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00280/"
        "HIGGS.csv.gz")
 m = Memory(location='/tmp', mmap_mode='r')
-n_leaf_nodes = 31
-n_trees = 10
-subsample = None
-lr = 1.
+n_leaf_nodes = 255
+n_trees = 500
+subsample = int(1e6)
+lr = 0.05
 max_bins = 255

I get the following:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 63.927s | 0.8293 | Saturates the 4 hyperthreads of my laptop |
| pygbm | 165.774s | 0.8180 | RES mem stays at 13GB but VIRT mem climbs to 24GB and slows down the last 200 trees (computer hanging) |

I only have 16GB of RAM on this laptop and the VIRT memory usage seems to be the cause of the slowdown.

@Laurae2
Contributor Author

Laurae2 commented Nov 1, 2018

New results with 1 million and n_leaf_nodes = 255 and n_trees = 50:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 9.037s | 0.8080 | Reference |
| pygbm | 28.707s | 0.7956 | Seems 0.55s/tree |

With subsample = None and n_leaf_nodes = 255 and n_trees = 50, I get the expected improvement vs your dual core:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 94.115s | 0.8082 | Reference |
| pygbm | 151.020s | 0.7971 | 9.173s 1st tree, drops to 2.933s last tree |

That's a VERY interesting result for a small model!

Results with 1 million and n_leaf_nodes = 255 and n_trees = 250:

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 31.793s | 0.8266 | Reference |
| pygbm | 154.504s | 0.8150 | From 7.494s to 0.614s per tree |

By the way, I noticed pygbm does not fully saturate my 8 threads. CPU usage is usually around 65% (physical cores full, hyperthreaded cores not full), which means around 25% of performance remains unused (75% of the hyperthreads are not fully exploited). A potential parallelism issue? I'll check later on my 72-thread server.

@Laurae2
Contributor Author

Laurae2 commented Nov 1, 2018

@ogrisel It seems with n_trees = 250 and n_leaf_nodes = 255 and 1 million samples, pygbm (and the dataset) wants in total about 16GB RAM.

@ogrisel
Owner

ogrisel commented Nov 1, 2018

> By the way, I noticed pygbm does not fully saturate my 8 threads. Usually, around 65% CPU usage (cores full, hyperthreaded cores not full), which means around 25% of performance remains unused (75% of the hyperthreads are not fully exploited).

I noticed that too. If you install numba with conda, you can set numba.config.THREADING_LAYER to "omp" to use the OpenMP thread pool, as LightGBM does (although LightGBM is compiled with GCC and uses libgomp while numba uses LLVM openmp implementation). When I use OpenMP I see better CPU usage in htop (than with the tbb threading layer used by default), but the computation performance is approximately the same, or even slightly worse.
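Besides assigning numba.config.THREADING_LAYER as described above, numba also reads the NUMBA_THREADING_LAYER environment variable when it is imported. The sketch below sets it from Python; `select_threading_layer` and `VALID_LAYERS` are hypothetical names, and the set only lists the three layers mentioned in this thread.

```python
import os

# Layers mentioned in this thread; numba accepts a few other values too.
VALID_LAYERS = {"tbb", "omp", "workqueue"}


def select_threading_layer(layer):
    """Request a numba threading layer via NUMBA_THREADING_LAYER."""
    if layer not in VALID_LAYERS:
        raise ValueError(f"unknown threading layer: {layer!r}")
    # Only takes effect if set before numba is imported; after the import,
    # assign numba.config.THREADING_LAYER instead.
    os.environ["NUMBA_THREADING_LAYER"] = layer
    return layer
```

Call `select_threading_layer("omp")` before `import numba`. Once a parallel function has run, `numba.threading_layer()` reports which layer was actually loaded.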

@ogrisel
Owner

ogrisel commented Nov 1, 2018

We might have a discrepancy in the hyperparameters that would explain the difference in AUC, but I am not sure which one. We have a test that checks that we get the same trees in non-pathological cases here:

https://github.com/ogrisel/pygbm/blob/master/tests/test_compare_lightgbm.py

But apparently this is not the case for Higgs boson. This would also require more investigation. I suspect our handling of shrinkage / learning rate is different.

@Laurae2
Contributor Author

Laurae2 commented Nov 1, 2018

You need to check that pygbm has equivalents for the following LightGBM hyperparameters:

  • num_iterations
  • learning_rate
  • num_leaves
  • max_depth (default: -1 for infinite)
  • min_data_in_leaf (default: 20, means at least 20 samples in a leaf)
  • min_sum_hessian_in_leaf (default: 1e-3, means at least 1e-3 hessian sum in a leaf)
  • lambda_l1 (default: 0)
  • lambda_l2 (default: 0)
  • min_gain_to_split (default: 0)
  • max_bin (default: 255)
  • min_data_in_bin (default: 3, means at least 3 samples in a bin)
  • bin_construct_sample_cnt (default: 200000, means use 200,000 samples to compute initial histogram for binning)
  • boost_from_average (default: True, means use optimal starting prediction value for all samples)
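The list above can be turned into a small checklist. In this sketch the defaults are copied from the list, while the LightGBM-to-pygbm name mapping is an assumption guessed from the parameter names used in the benchmark script earlier in this thread; it says nothing about whether the semantics actually match.

```python
# LightGBM defaults as given in the list above.
LIGHTGBM_DEFAULTS = {
    "max_depth": -1,
    "min_data_in_leaf": 20,
    "min_sum_hessian_in_leaf": 1e-3,
    "lambda_l1": 0,
    "lambda_l2": 0,
    "min_gain_to_split": 0,
    "max_bin": 255,
    "min_data_in_bin": 3,
    "bin_construct_sample_cnt": 200000,
    "boost_from_average": True,
}

# ASSUMPTION: correspondence guessed from names only, not verified.
LIGHTGBM_TO_PYGBM = {
    "num_iterations": "max_iter",
    "learning_rate": "learning_rate",
    "num_leaves": "max_leaf_nodes",
    "max_bin": "max_bins",
}


def unmatched_lightgbm_params(mapping=LIGHTGBM_TO_PYGBM,
                              defaults=LIGHTGBM_DEFAULTS):
    """LightGBM parameters with a listed default but no mapped counterpart."""
    return sorted(name for name in defaults if name not in mapping)
```

Each name returned by `unmatched_lightgbm_params()` is a candidate source of the AUC discrepancy that still needs to be checked by hand.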

@NicolasHug
Collaborator

One thing I noticed before is that min_samples_leaf in pygbm isn't really what we think it is. It does not behave as in scikit-learn:

min_samples_leaf : int, float, optional (default=1)
The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches

In pygbm the logic is different: the parent will be split even if the children have fewer samples than min_samples_leaf:

            if (self.min_samples_leaf is not None
                    and len(sample_indices_left) < self.min_samples_leaf):
                self._finalize_leaf(left_child_node)
            else:
                self._compute_spittability(left_child_node)
            if (self.min_samples_leaf is not None
                    and len(sample_indices_right) < self.min_samples_leaf):
                self._finalize_leaf(right_child_node)

Also, I seem to remember that LightGBM's min_child_samples was doing something quite weird (not equivalent to either sklearn or pygbm).
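The difference between the two semantics can be made concrete with a toy sketch. Both functions are hypothetical illustrations written from the scikit-learn docstring and the pygbm snippet quoted above; neither is library code.

```python
def sklearn_style_split_allowed(n_left, n_right, min_samples_leaf):
    """scikit-learn semantics: reject the split if either child is too small."""
    return n_left >= min_samples_leaf and n_right >= min_samples_leaf


def pygbm_style_children(n_left, n_right, min_samples_leaf):
    """Quoted pygbm logic: the split always happens; a child smaller than
    min_samples_leaf is finalized as a leaf, the other stays a candidate."""
    def status(n):
        if min_samples_leaf is not None and n < min_samples_leaf:
            return "leaf"
        return "candidate"
    return status(n_left), status(n_right)
```

For a parent with 100 samples split 3 / 97 and min_samples_leaf = 20, the first function rejects the split, while the second accepts it and produces a 3-sample leaf.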

@szilard

szilard commented Nov 1, 2018

Great work, both on the new implementation (which makes it possible to compare the effectiveness of LLVM/code generation vs. traditional implementation/compilation for GBMs) and on the benchmarking effort in this issue.

@NicolasHug
Collaborator

NicolasHug commented Nov 1, 2018

EDIT: nevermind I need some sleep ^^
Thanks @Laurae2 for correcting me

Hmmm, so num_leaves in LightGBM is exactly what the name suggests: the actual number of leaves. The docs say it's a maximum but that's not what I observe.

import lightgbm as lb
from sklearn.model_selection import train_test_split
import numpy as np
import pytest

from pygbm import GradientBoostingMachine
from pygbm import plotting

rng = np.random.RandomState(2)

n_samples = 100

max_leaf_nodes = 40
min_sample_leaf = 40

max_iter = 1

# data = linear target, 5 features, 3 irrelevant.
X = rng.normal(size=(n_samples, 5))
y = X[:, 0] - X[:, 1]

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

est_lightgbm = lb.LGBMRegressor(n_estimators=max_iter,
                                min_data=1, min_data_in_bin=1,
                                learning_rate=1,
                                min_sample_leaf=min_sample_leaf,
                                num_leaves=max_leaf_nodes)
est_pygbm = GradientBoostingMachine(validation_split=None)  # just train for plotting to work
est_lightgbm.fit(X_train, y_train)
est_pygbm.fit(X_train, y_train)

plotting.plot_tree(est_pygbm, est_lightgbm, view=True)
max_leaf_nodes = 40
min_sample_leaf = 40

I get a tree with 40 leaves and leaves with 1 sample. Changing max_leaf_nodes to 10 I get 10 leaves.

So a side effect is that min_sample_leaf is completely bypassed.

I don't know if this comes from the python binding or directly from the c++ source though.

@dhirschfeld

Are you using defaults::numpy? Anaconda patches numpy to have deeper integration with MKL, so I'm curious whether that provides any benefit to pygbm.

Also, I'm curious whether setting the KMP_COMPOSABILITY env var makes any difference:

# Composable OpenMP, exclusive mode
export KMP_COMPOSABILITY=mode=exclusive
# Composable OpenMP, counting mode
export KMP_COMPOSABILITY=mode=counting

https://github.com/scipy-conference/scipy_proceedings/blob/2018/papers/anton_malakhov/composability.rst

@guolinke

guolinke commented Nov 2, 2018

@NicolasHug LightGBM doesn't have a parameter named min_sample_leaf.
Refer to https://github.com/Microsoft/LightGBM/blob/dfe0fae4ea5a412d253c25fbd997224e9243bd9a/docs/Parameters.rst#min_data_in_leaf

@ogrisel
Owner

ogrisel commented Nov 2, 2018

@dhirschfeld there are no nested prange loops in pygbm so far and we don't do any linear algebra: numpy is just used as a passive data structure (no BLAS routines used), so composability is probably useless in this context.

@ogrisel ogrisel added the perf Computational performance issue or improvement label Nov 3, 2018
@ogrisel
Owner

ogrisel commented Nov 5, 2018

@Laurae2 I merged #36 with a fix that works around a memory leak in numba. Please feel free to try again; the results should be better.

I also noticed that on a many-core machine the tbb threading layer of numba gives much better performance than the workqueue backend, but LightGBM is still better at using all the cores very efficiently.

@ogrisel
Owner

ogrisel commented Nov 5, 2018

We still get a lower accuracy when the number of trees is large and with a small learning rate. This discrepancy is tracked in #32.

@ogrisel
Owner

ogrisel commented Nov 6, 2018

We also merged #37 that makes it possible to customize the benchmark parameters from the command line.

For instance:

$ python benchmarks/bench_higgs_boson.py --n-trees 500 --learning-rate 0.1 --n-leaf-nodes 255

Gives the following results (on a workstation with Intel(R) Xeon(R) CPU E5-2650 v4 @ 2.20GHz with 2 sockets, each with 12 cores, which means 48 hyperthreads in total):

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 235s | 0.8402 | Reference |
| pygbm | 429s | 0.8352 | Using TBB threading layer, memory usage fluctuates but stays approximately within reasonable bounds (5-6GB) |
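The command-line flags used above could be parsed along these lines. This is a sketch, not the code merged in #37: only the three flag names come from the command shown, and the defaults are assumed from the values that the diffs earlier in this thread replace.

```python
import argparse


def build_parser():
    """Build a parser for the three benchmark flags shown above."""
    parser = argparse.ArgumentParser(
        description="HIGGS benchmark options (sketch)")
    # ASSUMPTION: defaults taken from the old constants in the diffs above.
    parser.add_argument("--n-leaf-nodes", type=int, default=31)
    parser.add_argument("--n-trees", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=1.0)
    return parser
```

argparse converts the dashes to underscores, so the parsed values are available as `args.n_leaf_nodes`, `args.n_trees` and `args.learning_rate`.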

@ogrisel
Owner

ogrisel commented Nov 6, 2018

Here is another run on the same machine with a single core to check the scalability w.r.t. the number of threads:

NUMBA_NUM_THREADS=1 OMP_NUM_THREADS=1 python benchmarks/bench_higgs_boson.py --n-trees 100 --learning-rate 0.1 --n-leaf-nodes 255

| Model | Time | AUC | Comments |
|---|---|---|---|
| LightGBM | 1045s | 0.8282 | Reference |
| pygbm | 1129s | 0.8192 | |

So even with a single thread, pygbm is slightly slower.

And the AUC is slightly lower (tracked in #32).
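Since the multi-threaded run used 500 trees and the single-threaded run used 100, the two are only comparable after normalizing per tree. The sketch below does that arithmetic on the timings reported above. It assumes the total time scales linearly with the number of trees, which earlier comments suggest is only roughly true, so the resulting speedups (about 22x for LightGBM and about 13x for pygbm on 48 hyperthreads) are rough estimates.

```python
# (total seconds, number of trees) from the two runs reported above.
RUNS = {
    "LightGBM": {"multi": (235, 500), "single": (1045, 100)},
    "pygbm": {"multi": (429, 500), "single": (1129, 100)},
}


def seconds_per_tree(total_seconds, n_trees):
    """Average wall-clock time per tree, ignoring fixed overheads."""
    return total_seconds / n_trees


def speedup(model):
    """Single-thread time per tree divided by multi-thread time per tree."""
    multi = seconds_per_tree(*RUNS[model]["multi"])
    single = seconds_per_tree(*RUNS[model]["single"])
    return single / multi


if __name__ == "__main__":
    for model in RUNS:
        print(f"{model}: {speedup(model):.1f}x")
```

The gap between the two speedups supports the observation that LightGBM uses many cores more efficiently, while single-thread performance is nearly on par.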

@NicolasHug
Collaborator

Regarding the speed, there are still some places where we can parallelize the code, for example when computing the new gradients and hessians.

@ogrisel
Owner

ogrisel commented Nov 6, 2018

I will open an issue dedicated to the thread scalability with more details.

Edit: here it is: #38.

@ogrisel
Owner

ogrisel commented Nov 6, 2018

@Laurae2 I think we can close this issue. Remaining work will be tracked in #32 and #38.
