
Prediction Results Do Not Match Original XGBoost #7

Closed
alex-r-xiao opened this issue Oct 4, 2017 · 20 comments

@alex-r-xiao

I trained a model with a single iteration:

param = {
    'nthread': 10,
    'objective': 'multi:softprob',
    'eval_metric': 'mlogloss',
    'num_class': 3,
    'silent': 1,
    'max_depth': 5,
    'min_child_weight': 5,
    'eta': 0.5,  # learning rate
    'subsample': 1,
    'colsample_bytree': 1,
    'gamma': 0, 
    'alpha': 0,
    'lambda': 1, 
}
watchlist  = [(dtest, 'eval')]
bst = xgb.train(param, dtrain, 1, watchlist)

Here, sample is a NumPy array with a single row of features. The output of:

bst.predict(xgb.DMatrix(sample))

differs noticeably from the output of:

bst_lite = treelite.Model.from_xgboost(bst)
bst_lite.export_lib(toolchain='gcc', libpath=model_path, verbose=True)
batch = Batch.from_npy2d(sample)
predictor = Predictor(model_path, verbose=True)
predictor.predict(batch)
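
For what it's worth, here is roughly how I'm measuring the difference (a small sketch; it assumes bst, predictor, and the single-row sample from above are in scope, and that the two outputs have the same shape):

import numpy as np

# Compare the two probability vectors element-wise and report the largest gap.
xgb_pred = np.asarray(bst.predict(xgb.DMatrix(sample)), dtype=np.float32)
tl_pred = np.asarray(predictor.predict(Batch.from_npy2d(sample)), dtype=np.float32)
print('max abs difference:', np.max(np.abs(xgb_pred - tl_pred)))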

The problem is compounded when more trees are added.

Any ideas of what could be going on here?

Thanks for your work on this, @hcho3.

@hcho3
Collaborator

hcho3 commented Oct 4, 2017

@alex-r-xiao Thanks for trying treelite. It may be a precision issue. Which dataset are you using?

@alex-r-xiao
Author

It's a custom dataset. The predictions were about 7% off, and I made sure all my features are float-comparison safe.

I won't have time for a couple of weeks, but maybe we can try to replicate the issue with the Boston Housing dataset or something similar.

@hcho3
Collaborator

hcho3 commented Oct 4, 2017

I'd be very interested in reproducing the problem. Would you be able to send us the dataset? I understand that you may not be able to do so, and in that case, I'll try to reproduce the problem with other datasets.

@alex-r-xiao
Author

Here is a sample of my use case, although it seems to work fine. I will continue to investigate on my end; I could have made a mistake somewhere.

from sklearn import datasets
import treelite
from treelite.runtime import Predictor, Batch
import xgboost as xgb
import numpy as np
import pandas as pd

model_path = './model.so'

# Load up dataset
boston = datasets.load_boston()
labels = pd.qcut(boston.target, [0, 0.33, 0.67, 1.0], [0, 1, 2])
dtrain = xgb.DMatrix(boston.data, label=labels, feature_names=boston.feature_names)

# Train the model
param = {
    'nthread': 1,
    'objective': 'multi:softprob',
    'eval_metric': 'mlogloss',
    'num_class': 3,
    'silent': 1,
    'max_depth': 5,
    'min_child_weight': 4,
    'eta': 0.3,
    'subsample': 1,
    'colsample_bytree': 1,
    'gamma': 0,
    'alpha': 0,
    'lambda': 1,
}
watchlist  = [(dtrain, 'eval')]
bst = xgb.train(param, dtrain, 2, watchlist)

# Build treelite model
bst_lite = treelite.Model.from_xgboost(bst)
bst_lite.export_lib(toolchain='gcc', libpath=model_path, verbose=True)
predictor = Predictor(model_path, verbose=True)

# Compare predictions
sample = boston.data[:2, :]
print('treelite prediction: ', predictor.predict(
    Batch.from_npy2d(sample)
))
print('xgboost  prediction: ', bst.predict(
    xgb.DMatrix(sample, feature_names=boston.feature_names)
))

@hcho3
Collaborator

hcho3 commented Oct 4, 2017

One thing you should try: predict score margins rather than probabilities, by writing

bst.predict(xgb.DMatrix(sample), output_margin=True)

and

predictor = Predictor(model_path, verbose=True)
predictor.predict(batch, pred_margin=True)

@alex-r-xiao
Author

alex-r-xiao commented Oct 4, 2017

OK, I did some more testing. Margins do not match in any of my test cases, including the Boston one above. What are score margins, by the way?

Further, I was able to replicate the issue I was having by increasing the number of features. The problem becomes noticeable around 300 features, I think. You can play around with the code below.

Perhaps the issue is some numerical propagation of errors. Also, there could be a separate issue with how margins are reported.

from sklearn import datasets
import treelite
from treelite.runtime import Predictor, Batch
import xgboost as xgb
import pandas as pd

model_path = './model.so'

# Load up dataset
data, target = datasets.make_classification(
    n_samples=10000,
    n_features=350,
    n_classes=3,
    n_informative=40
)
# boston = datasets.load_boston()
# data = boston.data
# target = pd.qcut(boston.target, [0, 0.33, 0.67, 1.0], [0, 1, 2])

dtrain = xgb.DMatrix(data, label=target)

# Train the model
param = {
    'nthread': 1,
    'objective': 'multi:softprob',
    'eval_metric': 'mlogloss',
    'num_class': 3,
    'silent': 1,
    'max_depth': 5,
    'min_child_weight': 4,
    'eta': 0.3,
    'subsample': 1,
    'colsample_bytree': 1,
    'gamma': 0,
    'alpha': 0,
    'lambda': 1,
}
watchlist = [(dtrain, 'eval')]
bst = xgb.train(param, dtrain, 2, watchlist)

# Build treelite model
bst_lite = treelite.Model.from_xgboost(bst)
bst_lite.export_lib(toolchain='gcc', libpath=model_path, verbose=True)
predictor = Predictor(model_path, verbose=True)

# Compare predictions
sample = data[:1, :]
batch = Batch.from_npy2d(sample)
dtest = xgb.DMatrix(sample)
print('treelite prediction: ', predictor.predict(batch))
print('xgboost  prediction: ', bst.predict(dtest)[0])
print('treelite margin    : ', predictor.predict(batch, pred_margin=True))
print('xgboost  margin    : ', bst.predict(dtest, output_margin=True)[0])

@hcho3
Collaborator

hcho3 commented Oct 12, 2017

@alex-r-xiao After some trial and error, I've located a bug in the prediction code. The bug only affects dense matrices, not sparse matrices. I will push a fix within a day or two.

Thanks a lot for taking time to try out treelite!
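
In the meantime, a possible workaround is to route prediction through the sparse code path, which is not affected (a sketch; it assumes this runtime version exposes Batch.from_csr alongside Batch.from_npy2d, so please treat it as untested):

import scipy.sparse

# Workaround sketch: feed the predictor a CSR batch so that the dense code
# path (where the bug lives) is bypassed. Batch.from_csr is assumed here,
# not verified.
sparse_sample = scipy.sparse.csr_matrix(sample)
batch = Batch.from_csr(sparse_sample)
print(predictor.predict(batch))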

hcho3 added a commit that referenced this issue Oct 13, 2017
Responding to issue #7.

What was the problem?
* Some calling parameters for TreeliteAssembleSparseBatch() and
  TreeliteAssembleDenseBatch() were incorrect.
* Array slice [rbegin,rend) must be explicitly saved to the handle object
  so that it doesn't get garbage collected
@hcho3
Collaborator

hcho3 commented Oct 13, 2017

@alex-r-xiao The bug fix has been uploaded. Be sure to install the latest binary release (0.1a5) from PyPI.

@hcho3 hcho3 closed this as completed Oct 13, 2017
@alex-r-xiao
Author

@hcho3 Predictions match now, but the score margins are still off; could that be an issue? What are score margins?

@hcho3 hcho3 reopened this Oct 17, 2017
@hcho3
Collaborator

hcho3 commented Oct 17, 2017

@alex-r-xiao Gradient boosted trees produce arbitrary real-valued outputs, which we refer to as "margin scores." The margin scores then get transformed into a proper probability distribution (all probabilities adding up to 1) by the softmax function.

@hcho3
Collaborator

hcho3 commented Oct 17, 2017

@alex-r-xiao
Let me give you an example: the following is one possible set of margin scores produced by a 2-iteration ensemble. We assume num_class=3. (The numbers are all made up, for the sake of illustration.)

Tree 0 produces  +0.5
Tree 1 produces  +1.5
Tree 2 produces  -2.3
Tree 3 produces  -1.5
Tree 4 produces  +0.1
Tree 5 produces  +0.7

Even though we trained for 2 iterations, we have a total of 6 decision trees. This is because, for multi-class classification, XGBoost produces [number of iterations] * [number of classes] trees. Let's walk through how these margin scores get transformed into the final prediction.

  1. Group the trees into output groups.
    There should be as many output groups as there are label classes. For this example, there are 3 label classes, so there are 3 output groups. The member trees are grouped as follows:
Output group 0:  Tree 0, Tree 3
Output group 1:  Tree 1, Tree 4
Output group 2:  Tree 2, Tree 5
  2. Compute the sum of margin scores for each output group.
Output group 0:  (+0.5) + (-1.5) = -1.0
Output group 1:  (+1.5) + (+0.1) = +1.6
Output group 2:  (-2.3) + (+0.7) = -1.6

The vector [-1.0, +1.6, -1.6] is not a proper probability distribution, as the numbers don't add up to 1. However, even this is quite useful: the sum for output group 1 is the largest, so the second class is the most probable class.

  3. Transform the sums using the softmax function.
    To apply what is known as softmax, first apply the exponential function (exp) to every element in the vector.
    So we start with
[-1.0, +1.6, -1.6]

and compute

np.exp([-1.0, +1.6, -1.6])

which gives us

[ 0.36787944,  4.95303242,  0.20189652]

Then normalize the vector by dividing it by its element sum:

np.exp([-1.0, +1.6, -1.6]) / np.sum(np.exp([-1.0, +1.6, -1.6]))

giving a proper probability distribution:

[ 0.06661094,  0.89683221,  0.03655686]

We interpret this vector as showing 6.7% probability for the first class, 89.7% for the second class, and 3.7% for the third class.
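
Here is the same walkthrough as a small NumPy sketch, using the made-up per-tree scores above (tree i is assigned to output group i % num_class, matching the grouping in step 1):

import numpy as np

# Per-tree margin scores from the example above: 2 iterations x 3 classes = 6 trees.
tree_scores = np.array([+0.5, +1.5, -2.3, -1.5, +0.1, +0.7])
num_class = 3

# Steps 1-2: group trees by class and sum each group's scores.
# Row r holds the trees from iteration r; column c is output group c.
margins = tree_scores.reshape(-1, num_class).sum(axis=0)
print(margins)   # [-1.   1.6 -1.6]

# Step 3: softmax turns the margin sums into a probability distribution.
probs = np.exp(margins) / np.sum(np.exp(margins))
print(probs)     # [ 0.06661094  0.89683221  0.03655686]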

@hcho3
Collaborator

hcho3 commented Oct 17, 2017

@alex-r-xiao As for the score margins being off, can you provide us with a reproducible example? So far, I have not found any discrepancy with synthetic datasets generated by make_classification.

@alex-r-xiao
Author

Thanks for the detailed explanation! When I rerun my last snippet, the predictions now match, whereas the margins do not.

@hcho3
Collaborator

hcho3 commented Oct 19, 2017

@alex-r-xiao How far off are the margins? Since we're using 32-bit floating-point numbers, margins can be off by up to 1e-8, I believe.
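
To see exactly how far off they are, something like this works (a minimal sketch, reusing predictor, batch, bst, and dtest from your snippet above, and assuming the two arrays line up in shape):

import numpy as np

# Report the largest absolute gap between the two margin vectors, then check
# whether it falls within a tolerance appropriate for 32-bit floats.
tl_margin = np.asarray(predictor.predict(batch, pred_margin=True))
xgb_margin = np.asarray(bst.predict(dtest, output_margin=True))
print('max abs difference:', np.max(np.abs(tl_margin - xgb_margin)))
print('within tolerance  :', np.allclose(tl_margin, xgb_margin, atol=1e-6))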

@alex-r-xiao
Author

I think they were off by something on the order of 1e-2.

hcho3 added a commit that referenced this issue Oct 20, 2017
Responding to issue #7.

Problem: margin predictions were off by 0.5

Diagnosis: XGBoost adds a global bias of 0.5 when training the model

Fix: Save the global bias parameter
@hcho3
Collaborator

hcho3 commented Oct 20, 2017

Indeed, the margins were off by 0.5. This is because XGBoost adds a global bias of 0.5 by default. I've pushed a fix. Now the global bias will be added to every prediction, reproducing the behavior of XGBoost. The updated binaries (version 0.1a8) will be available in a few hours.
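
If you want to confirm this on an affected (pre-0.1a8) build, adding XGBoost's default base_score of 0.5 to the treelite margins should close the gap (a sketch, again reusing the variables from the earlier snippet):

import numpy as np

# On a pre-fix build, the treelite margins lack the global bias term.
# Adding XGBoost's default base_score of 0.5 should make them match.
tl_margin = np.asarray(predictor.predict(batch, pred_margin=True))
xgb_margin = np.asarray(bst.predict(dtest, output_margin=True))
print(np.allclose(tl_margin + 0.5, xgb_margin, atol=1e-6))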

@hcho3
Collaborator

hcho3 commented Oct 22, 2017

The updated binary (version 0.1a8) is now available from PyPI. Install it with

pip install treelite==0.1a8

@alex-r-xiao
Author

That's a tough bug. Thanks!

@aazizisoufiane

Hello,
I still see the same problem of predictions not matching, on version 0.93. I even tried with and without margins. Can you help, please?

@hcho3
Collaborator

hcho3 commented Dec 10, 2020

Please post a new GitHub issue. Thanks!

@dmlc dmlc locked and limited conversation to collaborators Dec 10, 2020