
Custom objective and evaluation functions #1230

Closed
nimamox opened this issue Feb 5, 2018 · 15 comments


@nimamox commented Feb 5, 2018

Hi folks,
The problem is that when I set fobj to my custom objective function, the prediction error I get after training differs from what feval reports at the last iteration. The errors are identical, as expected, if I remove fobj.

import numpy as np
import lightgbm as lgb
from sklearn import metrics

def my_logistic_obj(y_hat, dtrain):
    y = dtrain.get_label()
    p = y_hat
    grad = 4 * p * y + p - 5 * y
    hess = (4 * y + 1) * (p * (1.0 - p))
    return grad, hess

def my_err_rate(y_hat, dtrain):
    y = dtrain.get_label()
    y_hat = np.clip(y_hat, 10e-7, 1-10e-7)
    loss_fn = y*np.log(y_hat)
    loss_fp = (1.0 - y)*np.log(1.0 - y_hat)
    return 'myloss', np.sum(-(5*loss_fn+loss_fp))/len(y), False

def calc_loss(y, yp): #same as my_err_rate
    yp = np.clip(yp, 10e-7, 1.0-10e-7)
    loss_fn = y*np.log(yp)
    loss_fp = (1.0-y)*np.log(1.0-yp)
    return np.sum(-(5*loss_fn+loss_fp))/y.shape[0]
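As a sanity check (not part of the original report), the grad/hess expressions above match the derivatives of the weighted logloss L = -(5*y*log(p) + (1-y)*log(1-p)) taken with respect to a raw score z with p = sigmoid(z), which becomes relevant further down the thread. A quick finite-difference check:

def weighted_logloss(z, y):
    p = 1.0 / (1.0 + np.exp(-z))
    return -(5 * y * np.log(p) + (1 - y) * np.log(1 - p))

z, y, eps = 0.3, 1.0, 1e-4
p = 1.0 / (1.0 + np.exp(-z))
num_grad = (weighted_logloss(z + eps, y) - weighted_logloss(z - eps, y)) / (2 * eps)
num_hess = (weighted_logloss(z + eps, y) - 2 * weighted_logloss(z, y)
            + weighted_logloss(z - eps, y)) / eps ** 2
print(num_grad, 4 * p * y + p - 5 * y)        # both ~ -2.128
print(num_hess, (4 * y + 1) * p * (1.0 - p))  # both ~ 1.222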

params = {
    'task': 'train',
    'objective': 'regression',
    'boosting': 'gbdt',
    'metric': 'auc',
    'train_metric': True,
    'num_leaves': 260,
    'learning_rate': 0.0245,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'max_depth': 15,
    'max_bin': 512
}

categoricals = None
lgb_train = lgb.Dataset(X_trn, y_trn, feature_name=features, categorical_feature=categoricals, free_raw_data=True) 
lgb_eval = lgb.Dataset(X_val, y_val, feature_name=features, categorical_feature=categoricals, reference=lgb_train, free_raw_data=True)

evals_result = {}
gbm = lgb.train(params,
                lgb_train,
                num_boost_round = 5,
                valid_sets=[lgb_eval, lgb_train],
#                 early_stopping_rounds = 2,
#                 fobj = my_logistic_obj, # <~~ First without fobj
                feval = my_err_rate,
                evals_result=evals_result,
                verbose_eval=1
               )

This gives me,

...
[4]	training's auc: 0.725577	training's myloss: 0.921026	valid_0's auc: 0.664285	valid_0's myloss: 0.826622
[5]	training's auc: 0.726053	training's myloss: 0.916696	valid_0's auc: 0.665518	valid_0's myloss: 0.824463

And if I evaluate both metrics on predictions, it correctly gives the same results:

y_pred = gbm.predict(X_val)
calc_loss(y_val, y_pred) # ~> 0.824463   --  this is to show calc_loss() is identical to my_err_rate()
fpr, tpr, thresholds = metrics.roc_curve(y_val, y_pred)
metrics.auc(fpr, tpr) # ~> 0.6655177490096028

But the problem is that if I enable my custom objective function, the AUC stays the same but my own loss is different!
Enabling fobj, I get:

...
[4]	training's auc: 0.724176	training's myloss: 0.638512	valid_0's auc: 0.663375	valid_0's myloss: 0.620981
[5]	training's auc: 0.727059	training's myloss: 0.635095	valid_0's auc: 0.666306	valid_0's myloss: 0.61966

And then I get:

y_pred = gbm.predict(X_val)
calc_loss(y_val, y_pred) # ~> 0.644286 
fpr, tpr, thresholds = metrics.roc_curve(y_val, y_pred)
metrics.auc(fpr, tpr) # ~> 0.6663059118977099

Now, while the AUC is identical to what verbose-mode training reports for the last iteration, my loss, as you can see, shows a significantly worse error than what is reported.

Environment info

Operating System: Linux (Ubuntu Server 16.04)
CPU: x86_64 E5-2697 v2 (48 cores)
Python version: 2.7.12

@nimamox changed the title from "Custom objective and evaluation function" to "Custom objective and evaluation functions" on Feb 5, 2018
@guolinke (Member) commented Feb 5, 2018

@nimamox can you reproduce this with randomly generated data?

@nimamox (Author) commented Feb 5, 2018

@guolinke Sure! Create the dataset as below:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
X, y = make_classification(300000, 30, 4, 5, n_clusters_per_class=7, 
                           weights=[.9], flip_y=.02, class_sep=.4, random_state=11)
X_trn, X_val, y_trn, y_val = train_test_split(X, y, test_size=0.3, random_state=111)
lgb_train = lgb.Dataset(X_trn, y_trn)
lgb_eval = lgb.Dataset(X_val, y_val)
Training as above (with fobj enabled) gives:

[1]	training's auc: 0.5	training's myloss: 1.30535	valid_0's auc: 0.5	valid_0's myloss: 1.30191
[2]	training's auc: 0.76412	training's myloss: 0.809083	valid_0's auc: 0.747303	valid_0's myloss: 0.84469
[3]	training's auc: 0.796385	training's myloss: 0.774707	valid_0's auc: 0.772937	valid_0's myloss: 0.820682
[4]	training's auc: 0.812655	training's myloss: 0.748513	valid_0's auc: 0.784416	valid_0's myloss: 0.804421
[5]	training's auc: 0.82571	training's myloss: 0.730843	valid_0's auc: 0.791629	valid_0's myloss: 0.793847
y_pred = gbm.predict(X_val)
print calc_loss(y_val, y_pred)  # ~> 0.825364 (instead of 0.793847)
fpr, tpr, thresholds = metrics.roc_curve(y_val, y_pred)
print metrics.auc(fpr, tpr)     # ~> 0.7916288878057938 (correct!)

Oh, and I don't have this problem with XGBoost.

@guolinke (Member) commented Feb 6, 2018

@nimamox you should set objective=None when using fobj.
I will create a PR to auto-set it.
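For reference, a minimal sketch of the adjusted call, assuming the params-dict route and the 'none' string that the discussion below settles on:

params_custom = dict(params, objective='none')  # replaces 'regression'
gbm = lgb.train(params_custom,
                lgb_train,
                num_boost_round=5,
                valid_sets=[lgb_eval, lgb_train],
                fobj=my_logistic_obj,
                feval=my_err_rate,
                evals_result=evals_result,
                verbose_eval=1)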

@nimamox (Author) commented Feb 7, 2018

@guolinke Do you mean setting objective: None inside the params dictionary? That doesn't seem to work, and I still get a different result calling predict() :/ Should I set it to objective: 'none', as your commit suggests? That doesn't work either; the training actually goes wrong:

With objective: 'none':

[1]	training's auc: 0.5	training's myloss: 4.41656	valid_0's auc: 0.5	valid_0's myloss: 3.88887
Training until validation scores don't improve for 2 rounds.
[2]	training's auc: 0.5	training's myloss: 4.41656	valid_0's auc: 0.5	valid_0's myloss: 3.88887
[3]	training's auc: 0.5	training's myloss: 4.41656	valid_0's auc: 0.5	valid_0's myloss: 3.88887
Early stopping, best iteration is:
[1]	training's auc: 0.5	training's myloss: 4.41656	valid_0's auc: 0.5	valid_0's myloss: 3.88887

@guolinke (Member) commented Feb 7, 2018

Yeah, it is because your fobj is wrong, due to the many zero hessians.

The split gain and leaf output are calculated from sum_grad / sum_hess, so it is better to avoid zero hessians in your fobj.

You can apply a sigmoid to your y_hat, so that y_hat starts from 0.5 and stays in the range (0, 1).
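A minimal sketch of that suggestion, reusing the weighted-logloss derivatives from the original fobj (they already correspond to p = sigmoid(raw score), so only the first transformation changes):

def my_logistic_obj(y_hat, dtrain):
    y = dtrain.get_label()
    p = 1.0 / (1.0 + np.exp(-y_hat))    # y_hat is the raw score; map it into (0, 1)
    grad = 4 * p * y + p - 5 * y
    hess = (4 * y + 1) * p * (1.0 - p)  # positive while p stays strictly inside (0, 1)
    return grad, hess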

@nimamox (Author) commented Feb 7, 2018

@guolinke Thank you. My derivatives are correct, I guess. So I should add this line to my fobj and feval?
y_hat = 1.0 / (1.0 + np.exp(-y_hat))
But then how come my loss becomes considerably worse than before?

With objective: 'regression' and without applying sigmoid:

Training until validation scores don't improve for 2 rounds.
[2]	training's auc: 0.669729	training's myloss: 0.671651	valid_0's auc: 0.634059	valid_0's myloss: 0.639334
[3]	training's auc: 0.673228	training's myloss: 0.668724	valid_0's auc: 0.638485	valid_0's myloss: 0.637083
[4]	training's auc: 0.675611	training's myloss: 0.666609	valid_0's auc: 0.641027	valid_0's myloss: 0.63558
[5]	training's auc: 0.677461	training's myloss: 0.664942	valid_0's auc: 0.643508	valid_0's myloss: 0.634298
Did not meet early stopping. Best iteration is:
[5]	training's auc: 0.677461	training's myloss: 0.664942	valid_0's auc: 0.643508	valid_0's myloss: 0.634298

With objective: 'none' and applying the sigmoid within fobj and feval:

Training until validation scores don't improve for 2 rounds.
[2]	training's auc: 0.673328	training's myloss: 0.852722	valid_0's auc: 0.639955	valid_0's myloss: 0.830982
[3]	training's auc: 0.674358	training's myloss: 0.844482	valid_0's auc: 0.640572	valid_0's myloss: 0.82241
[4]	training's auc: 0.674572	training's myloss: 0.836628	valid_0's auc: 0.640497	valid_0's myloss: 0.81424
[5]	training's auc: 0.675242	training's myloss: 0.829155	valid_0's auc: 0.641332	valid_0's myloss: 0.806509
Did not meet early stopping. Best iteration is:
[5]	training's auc: 0.675242	training's myloss: 0.829155	valid_0's auc: 0.641332	valid_0's myloss: 0.806509

Anyway, I'm actually not relying on the raw predictions, but rather on the leaf predictions. Does the code I originally started with give wrong leaf predictions as well (wrong as in not optimizing my custom objective function)?

@guolinke (Member) commented Feb 7, 2018

You can run more iterations. regression will boost from the average of the label, which is better in the beginning.
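For what it's worth, one way to give objective: 'none' a similar head start is to supply the initial score yourself. A sketch (not from this thread), assuming fobj applies a sigmoid so that the natural starting point is the log-odds of the average label:

p_bar = y_trn.mean()                # positive rate of the training labels
z0 = np.log(p_bar / (1.0 - p_bar))  # log-odds, since scores pass through a sigmoid
lgb_train.set_init_score(np.full(len(y_trn), z0))
lgb_eval.set_init_score(np.full(len(y_val), z0))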

@nimamox (Author) commented Feb 7, 2018

Sorry to bother you again. Sure, more iterations help, but they still don't make up the ~0.2 difference in loss compared with the original "wrong" code. LightGBM gave me results comparable to XGBoost with an identical objective and loss before, but it doesn't now. :(

On a second note, the advanced_example.py code in the repository defines a custom objective but doesn't set objective: 'none'. Why is that? My loss is simply logloss, only penalizing false negatives by a factor of 5.

@guolinke (Member) commented Feb 7, 2018

You can use scale_pos_weight to achieve the penalization of false negatives.

When using regression, training starts from the average value of the label; that is the only difference when you set objective=none.
BTW, you should add the sigmoid in both fobj and feval (before the clip).
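A sketch of the feval with the sigmoid applied before the clip, per the advice above:

def my_err_rate(y_hat, dtrain):
    y = dtrain.get_label()
    y_hat = 1.0 / (1.0 + np.exp(-y_hat))      # sigmoid first...
    y_hat = np.clip(y_hat, 10e-7, 1 - 10e-7)  # ...then clip away exact 0 and 1
    loss_fn = y * np.log(y_hat)
    loss_fp = (1.0 - y) * np.log(1.0 - y_hat)
    return 'myloss', np.sum(-(5 * loss_fn + loss_fp)) / len(y), False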

@guolinke (Member) commented Feb 7, 2018

BTW, from your results, I find that the fobj version converges much faster. I think it will be better in the end.

@guolinke (Member) commented Feb 7, 2018

If it still doesn't work, you can provide a similar example with randomly generated data, and I can help you figure it out.

@nimamox (Author) commented Feb 8, 2018

@guolinke As I mentioned, I'm using the leaf predictions as categorical features for another model (see the sketch below), so I cannot afford a high number of trees; it makes the other model heavy, and, well, you know the curse of dimensionality... I don't know what the best practices are for using GBDT leaf predictions. Do later trees produce more informative leaf predictions? In that case I could "warm up" first, so to speak, and then use the predictions of the later trees.
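For context, extracting those leaf indices looks roughly like this (a sketch using the pred_leaf flag of Booster.predict):

leaf_idx = gbm.predict(X_val, pred_leaf=True)  # shape: (n_samples, n_trees)
# each column is the index of the leaf a sample lands in for one tree;
# these indices can then be one-hot encoded as categorical features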

I had looked at scale_pos_weight, but it doesn't only penalize FNs more; it weights all positive samples more, and I need to stick with my loss function because it is ultimately what will be used to evaluate my work. I'll check with synthetic data and report back to you. Thanks :)

@nimamox (Author) commented Feb 13, 2018

@guolinke I guess you were right! The problem was that when I changed to objective: 'none', my learning rate became effectively very small, and I needed to increase it by an order of magnitude! I had wondered before why XGBoost's appropriate learning rate is about 0.25, while anything above 0.023 makes my LightGBM overfit.
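In other words, roughly (the exact value is illustrative, following the order-of-magnitude note above):

params_custom = dict(params, objective='none')
params_custom['learning_rate'] = 0.245  # ~10x the 0.0245 used with 'regression'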

@guolinke (Member) commented Feb 13, 2018

Happy to know your problem has been solved.

@guolinke closed this Feb 13, 2018
@friedsela commented Apr 24, 2019

Hi, I want to write a custom loss function (for RMSPE, if it matters), and I understand that it needs to return the gradient and the Hessian. The gradient is a vector of the size of the input, and the Hessian is a symmetric matrix of the size of the input. Are those the shapes of grad and hess that my function should return?
