You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Я думаю, в коде catboost вычисляющем precision где-то перепутаны предсказания и истинные значения, поэтому ранняя остановка по точности срабатывает сразу же. Прошу расследовать )
Также хорошо бы добавить юзер параметр pos_label, как в sklearn, с дефолтом 1, потому что для бизнеса обычно важнее точность минорного класса 1, и это по сути стандарт.
Воспроизведение ошибки:
importnumpyasnpfromcollectionsimportCounterfromcatboostimportCatBoostClassifierfromsklearn.model_selectionimporttrain_test_splitfromcatboostimportversionprint("catboost",version.VERSION)
# create binary target with majority label 0, split to train/testy=np.concatenate([np.zeros(1000),np.ones(100)])
X=np.random.random(len(y))
print("Y distribution:",Counter(y))
X_train,X_test,y_train,y_test=train_test_split(X.reshape(-1,1),y,test_size=0.2,shuffle=True)
print("Y_test distribution:",Counter(y_test))
# create & train a catboost instance with ES on Precision:loss_function="Logloss"eval_metric="Precision"custom_metric="AUC"task_type="CPU"iterations=100use_best_model=Truelearning_rate=0.2verbose=Falserandom_seed=Noneearly_stopping_rounds=iterations//5gbm=CatBoostClassifier(
loss_function=loss_function,
eval_metric=eval_metric,
custom_metric=custom_metric,
iterations=iterations,
task_type=task_type,
early_stopping_rounds=early_stopping_rounds,use_best_model=use_best_model,
verbose=verbose
)
gbm.fit(X_train, y_train, eval_set=(X_test, y_test), use_best_model=use_best_model, plot=True)
preds=gbm.predict(X_test)
print(f"Predictions of a model that has best {eval_metric} as per Catboost ES:",Counter(preds))
print("Metrics of the best model:", gbm.get_best_score())
Фактический результат:
catboost 1.2
Y distribution: Counter({0.0: 1000, 1.0: 100})
Y_test distribution: Counter({0.0: 200, 1.0: 20})
Predictions of a model that has best Precision as per Catboost ES: Counter({0.0: 220})
Metrics of the best model: {'learn': {'Logloss': 0.32802699483212905, 'Precision': 1.0}, 'validation': {'Logloss': 0.33153463822915086, 'AUC': 0.5509999999999999, 'Precision': 1.0}}
То есть модель, которая предсказывает все нули (мажорный класс), катбуст считает имеющей точность 1, а это нехорошо и нелогично со всех сторон (кроме явного указания юзером pos_label=0, но это пока не поддерживается). Где-то внутри должна быть ошибка.
Ожидаемый результат: что-то более близкое к
Predictions of a model that has best Precision as per Catboost ES: Counter({0.0: 200, 1.0:20})
The text was updated successfully, but these errors were encountered:
Также хорошо бы добавить юзер параметр pos_label, как в sklearn, с дефолтом 1, потому что для бизнеса обычно важнее точность минорного класса 1, и это по сути стандарт.
А, так это не только во время early stopping так считается. Поясните, плиз, почему у катбуста получается 1.
from catboost.utils import eval_metric
from sklearn.metrics import precision_score
print(eval_metric(label=[0,0,0,1,1,1], approx=[0,0,0,0,0,0],metric='Precision'))
print(precision_score(y_true=[0,0,0,1,1,1], y_pred=[0,0,0,0,0,0]))
[1.0]
0.0
CatBoost действительно ставил по умолчанию если нет ни одного предсказанного объекта позитивного класса. Это скорее плохое поведение по вышеозвученным причинам, так что в 9a54fe3 исправлено на 0 + warning по аналогии с scikit-learn.
Problem:
catboost version: 1.2
Operating System: Win
CPU: +
GPU: +
Я думаю, в коде catboost вычисляющем precision где-то перепутаны предсказания и истинные значения, поэтому ранняя остановка по точности срабатывает сразу же. Прошу расследовать )
Также хорошо бы добавить юзер параметр pos_label, как в sklearn, с дефолтом 1, потому что для бизнеса обычно важнее точность минорного класса 1, и это по сути стандарт.
Воспроизведение ошибки:
Фактический результат:
То есть модель, которая предсказывает все нули (мажорный класс), катбуст считает имеющей точность 1, а это нехорошо и нелогично со всех сторон (кроме явного указания юзером pos_label=0, но это пока не поддерживается). Где-то внутри должна быть ошибка.
Ожидаемый результат: что-то более близкое к
The text was updated successfully, but these errors were encountered: