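## Welcome to: Machine Learning in R with mlr3verse, Case Studies
# Introduction to Machine Learning in R, Part 01: Classifying Iris Species with Logistic Regression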

For comparison, see the Python project: [A. Introduction to Machine Learning Series [1]: Logistic Regression Classification on the Iris Dataset](https://www.heywhale.com/mw/project/6414063236218140149eb4fd)

The theory behind the algorithm is omitted here; we go straight to the worked example.

In [5]:
library(tidyverse)
library(mlr3verse)

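## 1 Logistic Regression: Binary Classification

### 1.1 Preparing the Data

The iris data has three classes. We keep only two of them (the first 100 rows, i.e. setosa and versicolor). Note that the now-unused factor level must be dropped: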

In [3]:
df = iris %>% 
  slice(1:100) %>% 
  mutate(Species = fct_drop(Species))
head(df)

| | Sepal.Length | Sepal.Width | Petal.Length | Petal.Width | Species |
|---|---|---|---|---|---|
| | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<fct>` |
| 1 | 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 2 | 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 3 | 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 4 | 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 5 | 5.0 | 3.6 | 1.4 | 0.2 | setosa |
| 6 | 5.4 | 3.9 | 1.7 | 0.4 | setosa |
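To see why the unused level has to be dropped, here is a minimal sketch in base R. It uses `droplevels()` as a dependency-free stand-in for `fct_drop()` (that substitution is an assumption made for this illustration; the notebook itself uses forcats). Subsetting rows keeps all three levels attached to the factor, so without the drop the target would still be declared with three classes.

```r
# Subsetting rows does not remove unused factor levels
sub = iris[1:100, ]
levels(sub$Species)            # still lists all three species
table(sub$Species)             # virginica is present with a count of 0

# Drop the unused level (base-R counterpart of forcats::fct_drop())
sub$Species = droplevels(sub$Species)
nlevels(sub$Species)           # 2
```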


### 1.2 Creating the Task

A **task** wraps the data together with its metadata, such as which column is the target.

In [4]:
task = as_task_classif(df, target = "Species")
task

<TaskClassif:df> (100 x 5)
* Target: Species
* Properties: twoclass
* Features (4):
  - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
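The fields printed above can also be read programmatically. The sketch below rebuilds the two-class data with base R so that it runs on its own; the names `df2`, `task2`, and the id `"iris_binary"` are illustrative and not part of the original notebook.

```r
library(mlr3)

# Rebuild the two-class data with base R so the snippet stands alone
df2 = droplevels(iris[1:100, ])
task2 = as_task_classif(df2, target = "Species", id = "iris_binary")  # id is an illustrative name

task2$nrow            # number of observations
task2$feature_names   # names of the four feature columns
task2$class_names     # class labels of the target
task2$positive        # positive class of the binary task (can be set via `positive =`)
```

The positive class matters later on, because measures such as AUC are computed with respect to it.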

Visualize the task (exploratory data analysis):

In [6]:
autoplot(task, type = "target")

In [7]:
autoplot(task, type = "duo")

Registered S3 method overwritten by 'GGally':
  method from   
  +.gg   ggplot2



In [8]:
autoplot(task, type = "pairs")

[1m[22m`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
[1m[22m`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
[1m[22m`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
[1m[22m`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.


### 1.3 Splitting into Training and Test Sets

In [6]:
set.seed(123)
split = partition(task, ratio = 0.8)     # stratified on the target variable by default
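The object returned by `partition()` is a list of row ids. The following sketch inspects it on a stand-alone copy of the binary task; `task2` and `split2` are illustrative names introduced here, not objects from the notebook.

```r
library(mlr3)

set.seed(123)
task2 = as_task_classif(droplevels(iris[1:100, ]), target = "Species", id = "iris_binary")
split2 = partition(task2, ratio = 0.8)

length(split2$train)                      # 80 row ids for training
length(split2$test)                       # 20 row ids for testing
intersect(split2$train, split2$test)      # empty: the two sets are disjoint
table(task2$truth(split2$test))           # class counts in the test set
```

Checking the class counts in the test set is a quick way to confirm whether the split was stratified.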

### 1.4 Choosing a Learner

A **learner** wraps a machine learning algorithm.

In [7]:
log_reg = lrn("classif.log_reg", predict_type = "prob")
log_reg

<LearnerClassifLogReg:classif.log_reg>
* Model: -
* Parameters: list()
* Packages: mlr3, mlr3learners, stats
* Predict Types:  response, [prob]
* Feature Types: logical, integer, numeric, character, factor, ordered
* Properties: loglik, twoclass

### 1.5 Training the Model

Train the model on the training set, then extract the fitted model and print it in tidy form:

In [9]:
log_reg$train(task, row_ids = split$train)
log_reg$model %>% broom::tidy()

Warning: glm.fit: algorithm did not converge
Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred


| term | estimate | std.error | statistic | p.value |
|---|---|---|---|---|
| `<chr>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| (Intercept) | -5.120078 | 647515.6 | -7.907265e-06 | 0.9999937 |
| Petal.Length | -18.559977 | 158016.9 | -0.0001174557 | 0.9999063 |
| Petal.Width | -26.814618 | 255176.1 | -0.0001050828 | 0.9999162 |
| Sepal.Length | 9.891978 | 203881.8 | 4.85182e-05 | 0.9999613 |
| Sepal.Width | 6.841599 | 100840.4 | 6.784582e-05 | 0.9999459 |
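The convergence warnings and the very large standard errors are what one would expect under perfect separation: when a feature splits the two classes without any overlap, the maximum-likelihood estimates of a logistic regression do not exist and the coefficients drift toward infinity. The base-R sketch below checks this for `Petal.Length` on the full two-class subset (not just the training rows); variable names are illustrative.

```r
# Two-class subset, as in section 1.1
sub = droplevels(iris[1:100, ])

# Petal lengths within each class
setosa_len     = sub$Petal.Length[sub$Species == "setosa"]
versicolor_len = sub$Petal.Length[sub$Species == "versicolor"]

range(setosa_len)
range(versicolor_len)

# TRUE if every setosa petal is shorter than every versicolor petal
separable = max(setosa_len) < min(versicolor_len)
separable
```

With a separable problem like this one, the predictions can still be perfectly accurate, but the individual coefficients and p-values should not be interpreted.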


### 1.6 Making Predictions

Predict on the test set:

In [10]:
pred = log_reg$predict(task, row_ids = split$test)
pred

<PredictionClassif> for 20 observations:
    row_ids      truth   response  prob.setosa prob.versicolor
          1     setosa     setosa 1.000000e+00    2.220446e-16
          2     setosa     setosa 1.000000e+00    2.220446e-16
          6     setosa     setosa 1.000000e+00    2.220446e-16
---                                                           
         94 versicolor versicolor 7.004507e-13    1.000000e+00
         95 versicolor versicolor 2.220446e-16    1.000000e+00
         97 versicolor versicolor 2.220446e-16    1.000000e+00
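The value `2.220446e-16` that recurs in the probability columns appears to be R's machine epsilon rather than a meaningful estimate. A plausible explanation, sketched below with base R only, is that the inverse logit link used by binomial GLMs bounds the probability away from exactly 0 and 1 when the linear predictor is extreme.

```r
# Inverse logit link used by binomial GLMs
inv_logit = binomial()$linkinv

inv_logit(0)                 # 0.5
p_low = inv_logit(-100)      # extreme negative linear predictor
p_low
.Machine$double.eps          # machine epsilon for doubles

# Compare the two values up to floating-point tolerance
isTRUE(all.equal(p_low, .Machine$double.eps))
```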

### 1.7 Evaluating Performance

In [11]:
pred$confusion                     # confusion matrix
pred$score(msr("classif.acc"))     # accuracy
pred$score(msr("classif.auc"))     # AUC

            truth
response     setosa versicolor
  setosa         10          0
  versicolor      0         10
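Several measures can also be scored in a single call by passing a list of measures built with `msrs()`. The sketch below repeats the binary workflow on stand-alone objects (`task2`, `split2`, `learner2`, `pred2` are illustrative names); the choice of the three measures is an assumption for demonstration.

```r
library(mlr3)
library(mlr3learners)

set.seed(123)
task2 = as_task_classif(droplevels(iris[1:100, ]), target = "Species", id = "iris_binary")
split2 = partition(task2, ratio = 0.8)

learner2 = lrn("classif.log_reg", predict_type = "prob")
# Convergence warnings are expected on this separable data, so they are silenced here
suppressWarnings(learner2$train(task2, row_ids = split2$train))
pred2 = suppressWarnings(learner2$predict(task2, row_ids = split2$test))

# Score several measures in one call; the result is a named numeric vector
scores = pred2$score(msrs(c("classif.acc", "classif.auc", "classif.ce")))
scores
```

Accuracy (`classif.acc`) and classification error (`classif.ce`) always sum to 1, which makes for a simple sanity check.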

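For binary classification we can also plot the ROC curve and the precision-recall (PR) curve (this requires the `precrec` package to be installed):

In [15]:
autoplot(pred, type = "roc")    # ROC curve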

ERROR: Error: The following packages could not be loaded: precrec


In [16]:
autoplot(pred, type = "prc")

ERROR: Error: The following packages could not be loaded: precrec
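Both errors above come from the missing `precrec` package. A small guard like the following, using only base R, can check for the dependency before plotting; the variable name `has_precrec` and the message wording are illustrative.

```r
# Check for the optional plotting dependency before calling autoplot()
has_precrec = requireNamespace("precrec", quietly = TRUE)

if (!has_precrec) {
  message("Package 'precrec' is missing; install it with install.packages(\"precrec\")")
} else {
  message("Package 'precrec' is available; ROC and PR plots can be drawn")
}
```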


## 2 Logistic Regression: Multiclass Classification

### 2.1 Creating the Task

In [12]:
task = as_task_classif(iris, target = "Species")
task

<TaskClassif:iris> (150 x 5)
* Target: Species
* Properties: multiclass
* Features (4):
  - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width

### 2.2 Splitting into Training and Test Sets

In [15]:
set.seed(1)
split = partition(task, ratio = 0.8)    # stratified on the target variable by default

### 2.3 Choosing a Learner

In [16]:
nnet = lrn("classif.multinom", predict_type = "prob")
nnet

<LearnerClassifMultinom:classif.multinom>
* Model: -
* Parameters: list()
* Packages: mlr3, mlr3learners, nnet
* Predict Types:  response, [prob]
* Feature Types: logical, integer, numeric, factor
* Properties: loglik, multiclass, twoclass, weights

### 2.4 Training the Model

In [17]:
nnet$train(task, row_ids = split$train)

# weights:  18 (10 variable)
initial  value 131.833475 
iter  10 value 11.599606
iter  20 value 5.459655
iter  30 value 5.184589
iter  40 value 5.077373
iter  50 value 5.070058
iter  60 value 5.066652
iter  70 value 5.065916
iter  80 value 5.065263
final  value 5.065204 
converged
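The "10 variable" in the first line of the trace matches the number of free parameters of a multinomial logit model with three classes and four features: one class serves as the reference, and each of the other two gets an intercept plus four slopes, so (3 - 1) x (4 + 1) = 10. The sketch below verifies this by calling `nnet::multinom()` directly on the full iris data, which is a simplification relative to the notebook (it trains on the training rows only through mlr3).

```r
library(nnet)

# Fit directly with nnet::multinom(); trace = FALSE hides the iteration log
fit = multinom(Species ~ ., data = iris, trace = FALSE)

coefs = coef(fit)
dim(coefs)        # 2 rows (non-reference classes) x 5 columns (intercept + 4 features)
length(coefs)     # 10 free parameters
rownames(coefs)   # classes modelled relative to the reference class
```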


### 2.5 Making Predictions

In [18]:
pred = nnet$predict(task, row_ids = split$test)
pred

<PredictionClassif> for 30 observations:
    row_ids     truth  response  prob.setosa prob.versicolor prob.virginica
          2    setosa    setosa 9.999983e-01    1.677754e-06   2.303844e-28
          5    setosa    setosa 1.000000e+00    1.202973e-08   1.296068e-31
         11    setosa    setosa 1.000000e+00    3.012739e-09   1.348837e-32
---                                                                        
        144 virginica virginica 6.595540e-20    1.188276e-07   9.999999e-01
        148 virginica virginica 1.299052e-13    1.226760e-03   9.987732e-01
        150 virginica virginica 1.196920e-13    9.403741e-03   9.905963e-01
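In the output above, the three probability columns of each row sum to 1 and the `response` is the class with the largest probability. The sketch below illustrates the same relationship with `nnet::multinom()` called directly on the full iris data (an assumption made to keep the example independent of the mlr3 objects).

```r
library(nnet)

fit = multinom(Species ~ ., data = iris, trace = FALSE)
probs = predict(fit, newdata = iris, type = "probs")   # one column per class

# Each row of class probabilities sums to 1
range(rowSums(probs))

# The predicted class is the column with the largest probability
argmax_class = colnames(probs)[max.col(probs, ties.method = "first")]
predicted = as.character(predict(fit, newdata = iris, type = "class"))
all(argmax_class == predicted)
```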

### 2.6 Evaluating Performance

In [19]:
pred$confusion                      # confusion matrix
pred$score(msr("classif.acc"))      # accuracy

            truth
response     setosa versicolor virginica
  setosa         10          0         0
  versicolor      0          9         0
  virginica       0          1        10
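Accuracy can be recomputed by hand from the confusion matrix as the sum of the diagonal divided by the total count. The base-R sketch below copies the counts from the output above; the per-class recall is an extra illustration that is not computed in the notebook.

```r
# Confusion matrix copied from the output above (rows: response, columns: truth).
# matrix() fills column by column, so each group of three is one truth class.
classes = c("setosa", "versicolor", "virginica")
cm = matrix(c(10, 0,  0,    # truth = setosa
               0, 9,  1,    # truth = versicolor
               0, 0, 10),   # truth = virginica
            nrow = 3,
            dimnames = list(response = classes, truth = classes))

accuracy = sum(diag(cm)) / sum(cm)   # 29 correct out of 30
accuracy

recall = diag(cm) / colSums(cm)      # share of each true class that was recovered
recall
```

The single error is a versicolor flower predicted as virginica, so only the versicolor recall falls below 1.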

A custom function for plotting the confusion matrix:

In [20]:
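plot_confusion = function(cm) {
  # Convert the table to long format with columns response, truth, n
  as_tibble(cm) %>% 
    # Reverse the truth levels so the diagonal runs from top-left to bottom-right
    mutate(response = factor(response),
           truth = factor(truth, rev(levels(response)))) %>% 
    ggplot(aes(response, truth, fill = n)) +
    geom_tile() +
    geom_text(aes(label = n)) +      # print the count in each cell
    # Legend breaks are hard-coded for counts between 0 and 10
    scale_fill_gradientn(colors = rev(hcl.colors(10, "Blues")),
                         breaks = seq(0,10,2)) +
    coord_fixed() +                  # square tiles
    theme_minimal()
}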


In [21]:
plot_confusion(pred$confusion)