---
title: "Introduction to train/predict/score"
author: "Michel Lang"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Introduction to train/predict/score}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
```{r setup, include = FALSE}
library(mlr3)
knitr::opts_knit$set(
  datatable.print.keys = FALSE,
  datatable.print.class = TRUE
)
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>"
)
set.seed(123)
```
In this introduction, we fit a classification tree on the iris task and determine the mean misclassification error.
## Task and learner objects
First, we need to generate the following `mlr3` objects from the task dictionary and the learner dictionary, respectively:
1. The classification task
```{r}
library(mlr3)
task = mlr_tasks$get("iris")
```
2. A learner for the classification tree
```{r}
learner = mlr_learners$get("classif.rpart")
```
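Both objects are retrieved from a dictionary by their key.
If you are unsure which keys are available, printing the dictionaries lists all registered tasks and learners; this is a quick sketch, and the exact output depends on your installed version and loaded extension packages:
```{r}
# list the keys of all registered tasks and learners
print(mlr_tasks)
print(mlr_learners)
```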
## Index vector for train/test splits
We opt to train on $\frac{4}{5}$ of all available observations and predict on the remaining $\frac{1}{5}$.
For this purpose, we create two index vectors:
```{r}
train.set = sample(task$nrow, 4/5 * task$nrow)
test.set = setdiff(seq_len(task$nrow), train.set)
```
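As a quick sanity check (plain base R, not part of the `mlr3` API), we can verify that the two index vectors are disjoint and together cover all rows of the task:
```{r}
# both vectors together should cover every row exactly once
length(train.set) + length(test.set) == task$nrow
length(intersect(train.set, test.set)) == 0
```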
## Setting up an experiment
The process of fitting a machine learning model, predicting on test data, and scoring the predictions by comparing predicted and true labels is called an experiment.
Therefore, we start by initializing an `Experiment` object:
```{r}
e = Experiment$new(task = task, learner = learner)
print(e)
```
The print output shows a summary of the experiment: its state is currently `[defined]`, and it includes the task and the learner.
## Training
To train the learner on the task, we call the `$train()` method of the experiment, passing the row ids of the training set:
```{r}
e$train(row_ids = train.set)
print(e)
```
The print output indicates that the `Experiment` object was modified (its state is now `[trained]`) and also extended, since it now includes an `rpart` model:
```{r}
rpart.model = e$model
print(rpart.model)
```
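Since `e$model` is a regular `rpart` fit, functions from the `rpart` package can be applied to it directly.
As an illustrative sketch (assuming the stored model is the raw `rpart` object printed above), we can inspect its complexity parameter table:
```{r}
# inspect the complexity parameter table of the fitted tree (plain rpart functionality)
rpart::printcp(rpart.model)
```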
## Predicting
After the training step, we can use the experiment to predict on observations of the task (alternatively, you may also pass new data as a `data.frame` here):
```{r}
e$predict(row_ids = test.set)
print(e)
```
The predictions can be retrieved as a simple `data.table`.
```{r}
library(data.table)
head(as.data.table(e$prediction))
```
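For a quick look at where the model errs, we can cross-tabulate true and predicted labels.
This is a plain base R sketch and assumes the prediction table contains the columns `truth` and `response` as printed above:
```{r}
pred = as.data.table(e$prediction)
# confusion table of true vs. predicted class labels
table(truth = pred$truth, response = pred$response)
```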
## Performance assessment
The last step of the experiment is quantifying the performance of the model by comparing the predicted labels with the true labels using a performance measure.
The default measure for the iris classification task is the mean misclassification error, which `$score()` uses if no other measure is specified:
```{r}
task$measures[[1L]]$id
e$score()
print(e)
e$performance["classif.mmce"]
```
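As a cross-check (again plain base R, assuming the `truth` and `response` columns from above), the mean misclassification error is simply the fraction of test observations whose predicted label differs from the true label:
```{r}
pred = as.data.table(e$prediction)
# manual computation of the mean misclassification error
mean(pred$truth != pred$response)
```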
The experiment is now "complete", which means we can access all of its methods.
## Chaining methods
Instead of calling the methods `$train()`, `$predict()` and `$score()` one after another, it is also possible to chain these calls:
```{r}
Experiment$new(task = task, learner = learner)$
  train(train.set)$
  predict(test.set)$
  score()
```
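Since each of these methods returns the experiment itself, the chained call can also be assigned to a variable for later inspection; a minimal sketch (the name `e2` is only illustrative):
```{r}
# store the chained experiment and look up its performance afterwards
e2 = Experiment$new(task = task, learner = learner)$
  train(train.set)$
  predict(test.set)$
  score()
e2$performance["classif.mmce"]
```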