# The caret package {#sec-caret}
* We have already learned about several machine learning algorithms.
* Many of these algorithms are implemented in R.
* However, they are distributed via different packages, developed by different authors, and often use different syntax.
* The __caret__ package tries to consolidate these differences and provide consistency.
* It currently includes over 200 different methods, which are summarized in the __caret__ package manual^[https://topepo.github.io/caret/available-models.html]. Keep in mind that __caret__ does not install the packages these methods depend on: to use a method through __caret__, you still need to install the package that implements it.
* The required packages for each method are described in the package manual.
* The __caret__ package also provides a function that performs cross validation for us.
* Here we provide some examples showing how we use this incredibly helpful package.
* We will use the 2 or 7 example to illustrate, and in later sections we use the package to run algorithms on the larger MNIST dataset.
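
The package requirements mentioned above can also be checked from within R. The following is a minimal sketch, assuming the `library` entry returned by `getModelInfo` lists the packages a method depends on; the names `methods`, `required`, and `missing_pkgs` are made up for illustration.

```r
library(caret)

# Look up the exact entry for each method; regex = FALSE avoids partial
# matches (for example "kknn" when asking for "knn").
methods <- c("glm", "qda", "knn")
required <- lapply(methods, function(m) {
  getModelInfo(m, regex = FALSE)[[m]]$library
})
names(required) <- methods
required

# Hypothetical check: which of the required packages are not installed yet?
missing_pkgs <- setdiff(unlist(required), rownames(installed.packages()))
missing_pkgs
```

A method with no entry under `library` relies only on code that ships with __caret__ or base R.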
## The `train` function
The __caret__ `train` function lets us train different algorithms using similar syntax.
* So, for example, we can type:
```{r}
#| message: false
#| warning: false
#| cache: false
library(tidyverse)
library(dslabs)
library(caret)
train_glm <- train(y ~ ., method = "glm", data = mnist_27$train)
train_qda <- train(y ~ ., method = "qda", data = mnist_27$train)
train_knn <- train(y ~ ., method = "knn", data = mnist_27$train)
```
* To make predictions, we can use the output of this function directly without needing to look at the specifics of `predict.glm` and `predict.knn`.
* Instead, we can learn how to obtain predictions from `predict.train`.
The code looks the same for all three methods:
```{r}
y_hat_glm <- predict(train_glm, mnist_27$test, type = "raw")
y_hat_qda <- predict(train_qda, mnist_27$test, type = "raw")
y_hat_knn <- predict(train_knn, mnist_27$test, type = "raw")
```
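
The same `predict` interface can return estimated class probabilities instead of predicted classes. The sketch below assumes `type = "prob"` is supported for the fitted method (it is for `glm`); the object names `p_hat_glm` and `y_hat_from_prob` are illustrative only.

```r
library(dslabs)
library(caret)

train_glm <- train(y ~ ., method = "glm", data = mnist_27$train)

# type = "prob" returns a data frame with one column per class level
p_hat_glm <- predict(train_glm, mnist_27$test, type = "prob")
head(p_hat_glm)

# Turn the probabilities into classes with a 0.5 cutoff on class "7"
y_hat_from_prob <- factor(ifelse(p_hat_glm[["7"]] > 0.5, "7", "2"),
                          levels = levels(mnist_27$test$y))

# Proportion of test cases where this agrees with type = "raw"
mean(y_hat_from_prob == predict(train_glm, mnist_27$test, type = "raw"))
```

Working with probabilities is useful when a cutoff other than 0.5 is of interest.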
* This permits us to quickly compare the algorithms.
* For example, we can compare the accuracy like this:
```{r}
fits <- list(glm = y_hat_glm, qda = y_hat_qda, knn = y_hat_knn)
sapply(fits, function(fit) confusionMatrix(fit, mnist_27$test$y)$overall[["Accuracy"]])
```
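
The accuracy reported by `confusionMatrix` is simply the proportion of test cases predicted correctly. As a quick sanity check, the sketch below computes it by hand for one method; `acc_by_hand` and `acc_caret` are names introduced here for illustration.

```r
library(dslabs)
library(caret)

train_qda <- train(y ~ ., method = "qda", data = mnist_27$train)
y_hat_qda <- predict(train_qda, mnist_27$test, type = "raw")

# Proportion of correct predictions, computed directly
acc_by_hand <- mean(y_hat_qda == mnist_27$test$y)

# The same quantity extracted from the confusionMatrix output
acc_caret <- confusionMatrix(y_hat_qda, mnist_27$test$y)$overall[["Accuracy"]]

all.equal(acc_by_hand, acc_caret)
```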
## Cross validation {#sec-caret-cv}
* When an algorithm includes a tuning parameter, `train` automatically uses cross validation to decide among a few default values.
* To find out what parameter or parameters are optimized, you can read the manual^[https://topepo.github.io/caret/available-models.html] or study the output of:
```{r, eval=FALSE}
getModelInfo("knn")
```
* We can also use a quick lookup like this:
```{r, eval=FALSE}
modelLookup("knn")
```
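
The lookup table can also be used programmatically. The sketch below extracts the tuning parameter names for kNN and checks that a candidate grid uses those names; `info` and `grid` are hypothetical objects created only for this illustration.

```r
library(caret)

info <- modelLookup("knn")

# Name(s) of the tuning parameter(s) for this method
tuning_params <- as.character(info$parameter)
tuning_params

# A candidate grid of values; its column names should match tuning_params
grid <- data.frame(k = seq(5, 15, 2))
all(names(grid) %in% tuning_params)
```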
* If we run it with default values:
```{r}
train_knn <- train(y ~ ., method = "knn", data = mnist_27$train)
```
you can quickly see the results of the cross validation using the `ggplot` function.
* Setting `highlight = TRUE` marks the parameter value with the highest estimated accuracy:
```{r caret-highlight}
ggplot(train_knn, highlight = TRUE)
```
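* By default, the cross validation is performed with 25 bootstrap samples: each sample is drawn with replacement from the training set and has the same size as it, and the observations left out of a sample are used to estimate accuracy.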
* For the `knn` method, the default is to try $k=5,7,9$. We change this using the `tuneGrid` argument.
* The grid of values must be supplied by a data frame with the parameter names as specified in the `modelLookup` output.
* Here, we present an example where we try out 32 values between 9 and 71.
* To do this with __caret__, we need to define a column named `k`, so we use this: `data.frame(k = seq(9, 71, 2))`.
* Note that when running this code, we are fitting 32 versions of kNN to 25 bootstrapped samples.
* Since we are fitting $32 \times 25 = 800$ kNN models, running this code will take several seconds.
* We set the seed because cross validation is a random procedure and we want to make sure the result here is reproducible.
```{r train-knn-plot}
set.seed(2008)
train_knn <- train(y ~ ., method = "knn",
data = mnist_27$train,
tuneGrid = data.frame(k = seq(9, 71, 2)))
ggplot(train_knn, highlight = TRUE)
```
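
As a rough illustration of what bootstrap resampling involves, the base R sketch below assumes the standard bootstrap, in which each sample is drawn with replacement and has as many rows as the training set. The value of `n` is a hypothetical training set size, not a number taken from the text, and `oob_fraction` is a name introduced here.

```r
set.seed(2008)
n <- 1000   # hypothetical training set size (an assumption)
B <- 25     # number of bootstrap samples

# For each bootstrap sample, compute the share of rows that were never drawn.
# These left-out rows are the ones available for estimating accuracy.
oob_fraction <- replicate(B, {
  in_sample <- sample(n, n, replace = TRUE)
  mean(!seq_len(n) %in% in_sample)
})
mean(oob_fraction)   # close to exp(-1)

# Number of kNN fits needed during resampling for the grid used above
k_grid <- seq(9, 71, 2)
length(k_grid) * B
```

The count on the last line explains why a finer grid or more resamples quickly increases the running time.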
* To access the parameter that maximized the accuracy, you can use this:
```{r}
train_knn$bestTune
```
and the best performing model like this:
```{r}
train_knn$finalModel
```
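
To see how `bestTune`, `results`, and `finalModel` relate to each other, the sketch below refits kNN on a coarser grid (an arbitrary choice made here to keep the example fast) and pulls out the row of `results` that corresponds to the selected `k`. The name `best_row` is introduced for illustration.

```r
library(dslabs)
library(caret)

set.seed(2008)
train_knn <- train(y ~ ., method = "knn",
                   data = mnist_27$train,
                   tuneGrid = data.frame(k = seq(9, 71, 10)))

# The selected value of k
best_k <- train_knn$bestTune$k

# The resampling summary for that value of k
best_row <- subset(train_knn$results, k == best_k)
best_row

# The final model is a kNN fit on the whole training set with the selected k
train_knn$finalModel$k
```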
* The function `predict` will use this best performing model.
* Here is the accuracy of the best model when applied to the test set, which we have not used at all yet because the cross validation was done on the training set:
```{r}
confusionMatrix(predict(train_knn, mnist_27$test, type = "raw"),
mnist_27$test$y)$overall["Accuracy"]
```
* If we want to change how we perform cross validation, we can use the `trainControl` function.
* We can make the code above go a bit faster by using, for example, 10-fold cross validation.
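* This means the training set is split into 10 folds of 10% of the observations each; every fold is used once to estimate accuracy, with the model fit on the remaining 90%.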
* We accomplish this using the following code:
```{r cv-10-fold-accuracy-estimate}
control <- trainControl(method = "cv", number = 10, p = .9)
train_knn_cv <- train(y ~ ., method = "knn",
data = mnist_27$train,
tuneGrid = data.frame(k = seq(9, 71, 2)),
trControl = control)
ggplot(train_knn_cv, highlight = TRUE)
```
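
To make the mechanics of 10-fold cross validation concrete, here is a minimal sketch that carries out the procedure by hand for a single value of `k`. It is only an illustration and not how `train` is implemented internally; the value `k = 25` is chosen arbitrarily, and `folds` and `fold_accuracy` are names introduced here. It uses the __caret__ functions `createFolds` and `knn3`.

```r
library(dslabs)
library(caret)

set.seed(2008)
train_set <- mnist_27$train

# Split the row indices into 10 folds; each element holds the held-out rows
folds <- createFolds(train_set$y, k = 10)

# Fit on nine folds and evaluate on the remaining one, for each fold in turn
fold_accuracy <- sapply(folds, function(test_idx) {
  fit <- knn3(y ~ ., data = train_set[-test_idx, ], k = 25)
  y_hat <- predict(fit, train_set[test_idx, ], type = "class")
  mean(y_hat == train_set$y[test_idx])
})

mean(fold_accuracy)  # cross validation estimate of accuracy for k = 25
sd(fold_accuracy)    # variability of the estimate across folds
```

Repeating this for every value in the grid gives one accuracy estimate per `k`, which is what the plot above summarizes.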
* We notice that the accuracy estimates are more variable, which is expected since each estimate is now an average over 10 samples rather than 25.
* Note that the `results` component of the `train` output includes several summary statistics related to the variability of the cross validation estimates:
```{r}
names(train_knn$results)
```
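
One way to use these summary statistics is to add error bars to the accuracy plot. The sketch below assumes the `results` data frame contains the columns `Accuracy` and `AccuracySD`, and shows one standard deviation on either side of each estimate; it refits the 10-fold model so that the block can be run on its own.

```r
library(dslabs)
library(caret)
library(ggplot2)

set.seed(2008)
control <- trainControl(method = "cv", number = 10)
train_knn_cv <- train(y ~ ., method = "knn",
                      data = mnist_27$train,
                      tuneGrid = data.frame(k = seq(9, 71, 2)),
                      trControl = control)

# Accuracy estimate for each k, plus or minus one standard deviation
ggplot(train_knn_cv$results, aes(x = k, y = Accuracy)) +
  geom_line() +
  geom_point() +
  geom_errorbar(aes(ymin = Accuracy - AccuracySD,
                    ymax = Accuracy + AccuracySD))
```

Wide, overlapping error bars suggest that differences between neighboring values of `k` are within the noise of the resampling procedure.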
* We have only covered the basics.
* The __caret__ package manual^[https://topepo.github.io/caret/available-models.html] includes many more details.