In [34]:
from catboost import CatBoostClassifier
from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np

In [35]:
X = datasets.fetch_covtype().data[:3000]
y = datasets.fetch_covtype().target[:3000]
X_1, X_test, y_1, y_test = train_test_split(X, y, test_size=0.1)
X_train, X_va, y_train, y_va = train_test_split(X_1, y_1)

print(X_train.shape, y_train.shape)
print(np.unique(y_train))  # 7分类任务

(2025, 54) (2025,)
[1 2 3 4 5 6 7]


In [36]:
model = CatBoostClassifier(
    objective='MultiClass',
    thread_count=-1,
    allow_writing_files=False,
    task_type='CPU',
    leaf_estimation_method='Gradient',
    eval_metric='MultiClass',
    n_estimators=1000,
    max_depth=6,
    learning_rate=0.01,
    colsample_bylevel=1,
    reg_lambda=3.0
)
model

<catboost.core.CatBoostClassifier at 0x1874ef259d0>

In [37]:
model.fit(X_train, y_train,
          eval_set=[(X_va, y_va)],  # 自动将训练数据集加入评估
          sample_weight=1 / (y_train + 1),  # 对应Pool(weight)
          use_best_model=True,
          verbose=100,  # 对应verbose_eval
          early_stopping_rounds=100,
          )

0:	learn: 1.9290727	test: 1.9308676	best: 1.9308676 (0)	total: 18.3ms	remaining: 18.3s
100:	learn: 1.0211085	test: 1.1156802	best: 1.1156802 (100)	total: 1.73s	remaining: 15.4s
200:	learn: 0.7620121	test: 0.8810746	best: 0.8810746 (200)	total: 3.42s	remaining: 13.6s
300:	learn: 0.6463485	test: 0.7754755	best: 0.7754755 (300)	total: 5.1s	remaining: 11.8s
400:	learn: 0.5752489	test: 0.7111412	best: 0.7111412 (400)	total: 6.65s	remaining: 9.94s
500:	learn: 0.5235204	test: 0.6656630	best: 0.6656630 (500)	total: 7.92s	remaining: 7.89s
600:	learn: 0.4811194	test: 0.6293004	best: 0.6293004 (600)	total: 9.12s	remaining: 6.06s
700:	learn: 0.4490036	test: 0.6033213	best: 0.6033213 (700)	total: 10.4s	remaining: 4.42s
800:	learn: 0.4196243	test: 0.5805033	best: 0.5805033 (800)	total: 11.6s	remaining: 2.89s
900:	learn: 0.3942867	test: 0.5624550	best: 0.5624550 (900)	total: 12.9s	remaining: 1.42s
999:	learn: 0.3708435	test: 0.5456734	best: 0.5456734 (999)	total: 14.2s	remaining: 0us

bestTest = 0.54

<catboost.core.CatBoostClassifier at 0x1874ef259d0>

In [38]:
# 预测结果(num_sample, 1)
model.predict(X_test)

array([[6],
       [5],
       [4],
       [2],
       [2],
       [2],
       [2],
       [6],
       [1],
       [2],
       [4],
       [6],
       [1],
       [1],
       [5],
       [2],
       [2],
       [2],
       [2],
       [2],
       [3],
       [6],
       [2],
       [1],
       [2],
       [2],
       [5],
       [2],
       [2],
       [1],
       [3],
       [2],
       [1],
       [3],
       [3],
       [5],
       [2],
       [5],
       [5],
       [2],
       [1],
       [6],
       [1],
       [5],
       [2],
       [2],
       [5],
       [5],
       [2],
       [4],
       [5],
       [5],
       [6],
       [5],
       [5],
       [1],
       [7],
       [2],
       [6],
       [4],
       [2],
       [2],
       [1],
       [1],
       [2],
       [5],
       [2],
       [5],
       [2],
       [4],
       [1],
       [5],
       [1],
       [2],
       [1],
       [5],
       [2],
       [5],
       [3],
       [2],
       [3],
       [2],
       [2],
    

In [39]:
# 预测结果的概率矩阵
model.predict_proba(X_test)

array([[0.00614346, 0.01481774, 0.23432367, ..., 0.01228857, 0.70353162,
        0.00536308],
       [0.00963242, 0.02798403, 0.00674305, ..., 0.93943933, 0.00573211,
        0.00355752],
       [0.01520494, 0.02388158, 0.30910162, ..., 0.01482156, 0.25633912,
        0.00932615],
       ...,
       [0.00953938, 0.02027623, 0.52012657, ..., 0.02328807, 0.32789957,
        0.00836302],
       [0.01668658, 0.13603143, 0.00801589, ..., 0.81731689, 0.00797729,
        0.0053387 ],
       [0.63472793, 0.31253841, 0.00720136, ..., 0.00947448, 0.00761593,
        0.02185876]])

In [40]:
# 准确率
model.score(X_test, y_test)

0.8166666666666667

In [41]:
# 特征相对重要性
model.feature_importances_

array([2.63547480e+01, 3.35582233e+00, 2.96319671e+00, 6.70575114e+00,
       4.87746347e+00, 1.54087640e+01, 4.20576353e+00, 3.43741707e+00,
       3.64888737e+00, 9.73667891e+00, 5.62791247e+00, 0.00000000e+00,
       1.28770361e+00, 7.73096191e+00, 7.42420361e-02, 2.61341613e-02,
       2.97505649e-02, 1.09479932e-02, 9.91420384e-03, 1.58909291e-02,
       0.00000000e+00, 0.00000000e+00, 3.24536145e-03, 8.70531436e-02,
       1.19629277e-02, 8.52675097e-01, 1.56265479e-01, 2.52564093e-04,
       0.00000000e+00, 1.04287770e-01, 3.30942417e-02, 5.31907353e-02,
       1.42738120e-03, 2.81426277e-01, 0.00000000e+00, 2.63970320e-02,
       5.53116209e-01, 4.24547511e-02, 0.00000000e+00, 0.00000000e+00,
       0.00000000e+00, 6.77701672e-05, 8.97523745e-01, 1.21573477e+00,
       1.96779539e-02, 5.55578530e-03, 1.65915648e-02, 9.43996217e-04,
       0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 8.67032266e-02,
       4.05076840e-02, 1.89410238e-03])

In [42]:
model.evals_result_

{'learn': {'MultiClass': [1.9290727401930146,
   1.9127226901790793,
   1.8962962418551366,
   1.8799582498341538,
   1.8636519701464072,
   1.8475571255213616,
   1.8316538868742596,
   1.8155899407255207,
   1.80015905617703,
   1.784211173864234,
   1.7698777064486237,
   1.7553469696433122,
   1.7409809632501985,
   1.7265168674493736,
   1.7129668664328608,
   1.6994855295277889,
   1.6856433830325475,
   1.6720829807829192,
   1.6589058044857479,
   1.6459197402323071,
   1.6327014175012464,
   1.6200223866377377,
   1.6074802014921203,
   1.5953214896899663,
   1.5826914203531859,
   1.571163461468948,
   1.5594644683520074,
   1.54769262142582,
   1.5361504495909883,
   1.5242570993763285,
   1.5130322072294944,
   1.5018491565933454,
   1.4906140863808788,
   1.4793741921000307,
   1.4683648144076369,
   1.4577552562763583,
   1.4471755089400253,
   1.436867369336384,
   1.4269324242597208,
   1.416554741188675,
   1.4072461781597003,
   1.397762712900427,
   1.388192442979126

In [43]:
model.get_best_score()

{'learn': {'MultiClass': 0.3708434500514064},
 'validation': {'MultiClass': 0.5456734484112032}}

In [44]:
model.save_model('cat1.model')

In [45]:
bst = CatBoostClassifier()
bst.load_model('cat1.model')
bst.score(X_test, y_test)

0.8166666666666667