-
Notifications
You must be signed in to change notification settings - Fork 1
/
tidymodel.R
99 lines (70 loc) · 2.13 KB
/
tidymodel.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
library(tidymodels)
## Data Sampling ----
# Fix the RNG so the train/test split (and therefore every downstream
# metric) is reproducible; without a seed each run of the script trains and
# evaluates on a different random partition of iris.
set.seed(1234)
# Keep 60% of rows for training and hold out the rest for testing.
# Stratifying on the outcome keeps the three Species levels in roughly the
# same proportions in both partitions.
iris_split <- initial_split(iris, prop = 0.6, strata = Species)
iris_split

# Peek at the training partition.
iris_split %>%
  training() %>%
  glimpse()
## Pre-process interface ----
# Estimate the preprocessing on the training partition only, so statistics
# learned here (correlations, means, sds) do not see the test rows.
#   * step_corr() removes predictors that are highly correlated with others.
#   * step_normalize() centres each remaining predictor and scales it to unit
#     standard deviation (the combined form of step_center() + step_scale()).
# all_predictors() already excludes the outcome, so no `-all_outcomes()`
# exclusion is needed.
iris_recipe <- training(iris_split) %>%
  recipe(Species ~ .) %>%
  step_corr(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  prep()
iris_recipe
## Execute the pre-processing ----
# Apply the trained recipe to the held-out rows.
iris_testing <- iris_recipe %>%
  bake(new_data = testing(iris_split))
glimpse(iris_testing)

# `new_data = NULL` returns the processed training set retained by prep();
# this is the current replacement for the superseded juice().
iris_training <- iris_recipe %>%
  bake(new_data = NULL)
glimpse(iris_training)
## Model Training ----
# Both engine packages must be installed for the fits below; attaching them
# here makes a missing package fail at this point.
library(randomForest)
library(ranger)

# One 100-tree classification spec shared by both engines, so the engine is
# the only difference between the two fitted models.
rf_spec <- rand_forest(trees = 100, mode = "classification")

# Random forests draw bootstrap samples and random feature subsets; seed each
# fit so the models (and the metrics computed from them) are reproducible.
set.seed(2024)
iris_ranger <- rf_spec %>%
  set_engine("ranger") %>%
  fit(Species ~ ., data = iris_training)

set.seed(2024)
iris_rf <- rf_spec %>%
  set_engine("randomForest") %>%
  fit(Species ~ ., data = iris_training)
## Predictions
# Hard class predictions from the ranger model on the processed test set.
predict(object = iris_ranger, new_data = iris_testing)

# The same predictions placed column-wise in front of the test rows.
glimpse(
  bind_cols(
    predict(object = iris_ranger, new_data = iris_testing),
    iris_testing
  )
)
## Model Validation ----
# Score a fitted classifier on a processed data set: attach its hard class
# predictions to `new_data` and compute yardstick's default class metrics
# against the true `Species` column.
#
# model    - a fitted parsnip classification model.
# new_data - preprocessed rows that include the `Species` outcome column.
# Returns the metrics tibble produced by metrics().
class_metrics <- function(model, new_data) {
  model %>%
    predict(new_data) %>%
    bind_cols(new_data) %>%
    metrics(truth = Species, estimate = .pred_class)
}

# Identical evaluation for both engines, so the results are comparable.
class_metrics(iris_ranger, iris_testing)
class_metrics(iris_rf, iris_testing)
## Per-class metrics ----
# Class-probability predictions from ranger: one `.pred_<level>` column per
# Species level. Predict once and reuse the result for every table and plot
# below instead of re-running predict() each time.
iris_probs <- iris_ranger %>%
  predict(iris_testing, type = "prob") %>%
  bind_cols(iris_testing)
glimpse(select(iris_probs, starts_with(".pred_")))
glimpse(iris_probs)

# Multiclass gain curve over the three probability columns: compute it
# once, then inspect and plot it.
iris_gain <- iris_probs %>%
  gain_curve(Species, .pred_setosa:.pred_virginica)
glimpse(iris_gain)
autoplot(iris_gain)

iris_probs %>%
  roc_curve(Species, .pred_setosa:.pred_virginica) %>%
  autoplot()

# Probabilities, hard class prediction and truth side by side. Supplying
# both `.pred_class` (estimate) and the probability columns (via `...`)
# lets metrics() report hard-class and probability-based metrics together.
iris_preds <- iris_probs %>%
  bind_cols(predict(iris_ranger, iris_testing)) %>%
  select(.pred_setosa:.pred_virginica, .pred_class, Species)
glimpse(iris_preds)

iris_preds %>%
  metrics(
    truth = Species,
    estimate = .pred_class,
    .pred_setosa:.pred_virginica
  )