---
title: "Working with lightgbm and catboost"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Working with lightgbm and catboost}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  eval = FALSE,
  warning = FALSE,
  message = FALSE
)
```
```{r setup}
library(tidymodels)
library(treesnip)
data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)
# vfold resamples
diamonds_splits <- vfold_cv(diamonds, v = 5)
```
## Model spec
Both lightgbm and catboost are engines for `parsnip::boost_tree()`. All of the main hyperparameters are available, except `loss_reduction`, which the catboost engine does not support.
```{r}
model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")
# model specs
lightgbm_model <- model_spec %>% set_engine("lightgbm", nthread = 6)
catboost_model <- model_spec %>% set_engine("catboost", nthread = 6)
# workflows
lightgbm_wf <- workflow() %>% add_model(lightgbm_model)
catboost_wf <- workflow() %>% add_model(catboost_model)
```
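For reference, the remaining main `boost_tree()` arguments can be set in the same way. The sketch below is illustrative only; the values are not recommendations:
```{r}
# a spec with the other main boost_tree() arguments set explicitly;
# values here are placeholders, not tuned choices
full_spec <- boost_tree(
  mtry = 5,
  trees = 500,
  min_n = 10,
  tree_depth = 6,
  learn_rate = 0.1,
  loss_reduction = 0,  # not available for the catboost engine
  sample_size = 0.8
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", nthread = 6)
```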
## Recipes and categorical features
Unlike xgboost, lightgbm and catboost handle nominal columns natively, so no `step_dummy()` or any other encoding step is required.
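For example, a workflow built from just a formula fits directly on the raw factor columns. This is a minimal sketch reusing the objects defined above, with `price ~ .` assumed as the formula:
```{r}
# fit lightgbm on the raw diamonds data; the factor columns
# (cut, color, clarity) are passed to the engine without dummy encoding
lightgbm_raw_fit <- lightgbm_wf %>%
  add_formula(price ~ .) %>%
  fit(data = diamonds)
```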
### Encodings benchmark
In fact, avoiding `step_dummy()` seems to be a good idea when using lightgbm or catboost.
```{r fit_resample, cache = TRUE}
# three encodings of the nominal predictors:
# keep the ordered factors, dummy-encode them, or convert them to unordered factors
rec_ordered <- recipe(price ~ ., data = diamonds)
rec_dummy   <- rec_ordered %>% step_dummy(all_nominal())
rec_factor  <- rec_ordered %>% step_unorder(all_nominal())
# glimpse(juice(prep(rec_factor)))
lightgbm_fit_ordered <- fit_resamples(add_recipe(lightgbm_wf, rec_ordered), resamples = diamonds_splits)
lightgbm_fit_dummy <- fit_resamples(add_recipe(lightgbm_wf, rec_dummy), resamples = diamonds_splits)
lightgbm_fit_factor <- fit_resamples(add_recipe(lightgbm_wf, rec_factor), resamples = diamonds_splits)
catboost_fit_ordered <- fit_resamples(add_recipe(catboost_wf, rec_ordered), resamples = diamonds_splits)
catboost_fit_dummy <- fit_resamples(add_recipe(catboost_wf, rec_dummy), resamples = diamonds_splits)
catboost_fit_factor <- fit_resamples(add_recipe(catboost_wf, rec_factor), resamples = diamonds_splits)
```
```{r}
bind_rows(
collect_metrics(lightgbm_fit_ordered) %>% mutate(model = "lightgbm", encode = "ordered"),
collect_metrics(lightgbm_fit_dummy) %>% mutate(model = "lightgbm", encode = "dummy"),
collect_metrics(lightgbm_fit_factor) %>% mutate(model = "lightgbm", encode = "factor"),
collect_metrics(catboost_fit_ordered) %>% mutate(model = "catboost", encode = "ordered"),
collect_metrics(catboost_fit_dummy) %>% mutate(model = "catboost", encode = "dummy"),
collect_metrics(catboost_fit_factor) %>% mutate(model = "catboost", encode = "factor")
) %>%
filter(.metric == "rmse") %>%
ggplot(aes(x = encode, y = mean, group = model, colour = model)) +
geom_point() +
geom_errorbar(aes(ymin = mean - std_err, ymax = mean + std_err), width = 0.2) +
geom_line() +
labs(y = "RMSE (the lower the better)") +
theme_minimal(20)
```