In [1]:
!pip install optuna

Collecting optuna
  Downloading optuna-3.4.0-py3-none-any.whl (409 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m409.6/409.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.12.1-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.8/226.8 kB[0m [31m21.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting colorlog (from optuna)
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading Mako-1.3.0-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: Mako, colorlog, alembic, optuna
Successfully installed Mako-1.3.0 alembic-1.12.1 colorlog-6.7.0 optuna-3.4.0


In [2]:
!pip install catboost

Collecting catboost
  Downloading catboost-1.2.2-cp310-cp310-manylinux2014_x86_64.whl (98.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.7/98.7 MB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: catboost
Successfully installed catboost-1.2.2


In [5]:
import numpy as np
import optuna
from optuna.integration import CatBoostPruningCallback

import catboost
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

In [14]:
# Fixed seed so the synthetic dataset and the train/validation split are the
# same on every Restart & Run All; without it each run tunes on different data.
SEED = 42

# 1000 synthetic samples with 10 features; 25% is held out for validation.
X, y = make_classification(n_samples=1000, n_features=10, random_state=SEED)
train_x, valid_x, train_y, valid_y = train_test_split(
    X, y, test_size=0.25, random_state=SEED
)

In [None]:
def objective(trial: optuna.Trial) -> float:
    """Fit one CatBoost classifier with trial-suggested settings and score it.

    Reads the module-level ``train_x``/``train_y`` (training data) and
    ``valid_x``/``valid_y`` (evaluation data).

    Args:
        trial: The Optuna trial used to draw hyperparameters and to report
            intermediate "Accuracy" values through the pruning callback.

    Returns:
        ``accuracy_score`` of the fitted model's rounded predictions on
        ``valid_x`` against ``valid_y``.

    Note:
        ``check_pruned()`` is called after ``fit``; presumably it raises
        ``optuna.TrialPruned`` when the callback flagged the trial -- confirm
        against the Optuna documentation.
    """
    # The suggestions are drawn in a fixed order so that the sampler is
    # queried in the same sequence for every trial.
    loss_name = trial.suggest_categorical("objective", ["Logloss", "CrossEntropy"])
    level_fraction = trial.suggest_float("colsample_bylevel", 0.01, 0.1, log=True)
    tree_depth = trial.suggest_int("depth", 1, 12)
    boosting_kind = trial.suggest_categorical("boosting_type", ["Ordered", "Plain"])
    bootstrap_kind = trial.suggest_categorical(
        "bootstrap_type", ["Bayesian", "Bernoulli", "MVS"]
    )

    settings = {
        "objective": loss_name,
        "colsample_bylevel": level_fraction,
        "depth": tree_depth,
        "boosting_type": boosting_kind,
        "bootstrap_type": bootstrap_kind,
        # Fixed (not tuned) settings.
        "used_ram_limit": "3gb",
        "eval_metric": "Accuracy",
    }

    # Each bootstrap type has its own extra knob; "MVS" gets none here.
    if bootstrap_kind == "Bayesian":
        settings["bagging_temperature"] = trial.suggest_float("bagging_temperature", 0, 10)
    elif bootstrap_kind == "Bernoulli":
        settings["subsample"] = trial.suggest_float("subsample", 0.1, 1, log=True)

    classifier = catboost.CatBoostClassifier(**settings)
    prune_hook = CatBoostPruningCallback(trial, "Accuracy")

    classifier.fit(
        train_x,
        train_y,
        eval_set=[(valid_x, valid_y)],
        verbose=0,
        early_stopping_rounds=100,
        callbacks=[prune_hook],
    )

    # Must run after fit() returns: the callback only records the pruning
    # decision during training, it is acted upon here.
    prune_hook.check_pruned()

    rounded_predictions = np.rint(classifier.predict(valid_x))
    return accuracy_score(valid_y, rounded_predictions)


if __name__ == "__main__":
    # Passing an explicitly seeded TPESampler makes the sequence of suggested
    # hyperparameters repeatable across runs (TPE is, per the Optuna docs, the
    # default sampler for a single-objective study, so only the seed is new).
    # The number of completed trials can still vary because of `timeout`.
    study = optuna.create_study(
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
        direction="maximize",
    )
    # Stop after 100 trials or 600 seconds, whichever comes first.
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    # `trial` is kept as a module-level name: a later cell reads `trial.params`.
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

In [20]:
# Refit on the full dataset using only the hyperparameters sampled for the
# best trial (`trial` is bound by the study cell above).
model = catboost.CatBoostClassifier(**trial.params)
# Reassign instead of ending on a bare call so the cell emits no Out[] value.
model = model.fit(X, y)

0:	learn: 0.6931453	total: 11.1ms	remaining: 11.1s
1:	learn: 0.6931435	total: 18.5ms	remaining: 9.22s
2:	learn: 0.6931419	total: 27.8ms	remaining: 9.25s
3:	learn: 0.6931403	total: 35.4ms	remaining: 8.82s
4:	learn: 0.6931388	total: 42.8ms	remaining: 8.51s
5:	learn: 0.6928812	total: 51.8ms	remaining: 8.57s
6:	learn: 0.6928798	total: 59.6ms	remaining: 8.46s
7:	learn: 0.6928786	total: 68.7ms	remaining: 8.52s
8:	learn: 0.6928748	total: 77.4ms	remaining: 8.53s
9:	learn: 0.6928738	total: 84.7ms	remaining: 8.38s
10:	learn: 0.6928727	total: 96.4ms	remaining: 8.67s
11:	learn: 0.6928717	total: 103ms	remaining: 8.52s
12:	learn: 0.6926908	total: 112ms	remaining: 8.54s
13:	learn: 0.6926899	total: 120ms	remaining: 8.48s
14:	learn: 0.6926891	total: 128ms	remaining: 8.39s
15:	learn: 0.6926883	total: 136ms	remaining: 8.38s
16:	learn: 0.6926876	total: 143ms	remaining: 8.27s
17:	learn: 0.6926870	total: 151ms	remaining: 8.23s
18:	learn: 0.6655613	total: 159ms	remaining: 8.19s
19:	learn: 0.6655612	total: 16