Copyright 2022 Vasile Rus, Andrew M. Olney and made available under [CC BY-SA](https://creativecommons.org/licenses/by-sa/4.0) for text and [Apache-2.0](http://www.apache.org/licenses/LICENSE-2.0) for code.


# Naïve Bayes

Naïve Bayes is a supervised data science method typically used for **classification/categorization** tasks, like those we saw in, for instance, the logistic regression notebook.
It works by estimating the probability of each value of the outcome variable, e.g., the probability of each category in classification.
To classify a particular object or instance $X$, the class with the highest probability among all possible classes $1$ to $C$ is taken as shown below:

$$class (X) = argmax_{i \in (1..C)} P(c_i|X)$$
        
While naïve Bayes is quite successful in classification tasks, the actual estimated probabilities for each class are not very reliable.
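As a minimal sketch of the decision rule above, suppose a model has already produced posterior probabilities for three classes. The class names and numbers here are made up for illustration:

```r
# Hypothetical posterior probabilities P(c_i | X) for one instance
posterior <- c(not_recom = 0.10, priority = 0.65, spec_prior = 0.25)

# argmax: the name of the class with the highest posterior probability
predicted_class <- names(posterior)[which.max(posterior)]
predicted_class
```

Only the ordering of the probabilities matters for the predicted class, which is why the classifier can do well even when the probabilities themselves are poorly calibrated.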

In this notebook, we focus on multinomial, hard classification tasks.

### What you will learn

In this notebook, you will learn about naïve Bayes, a data science paradigm used primarily for classification tasks, and how naïve Bayes based classifiers can be inferred from labeled/annotated data.  
We will study the following:

- The basics of naïve Bayes
- The meaning of “naïve” in naïve Bayes
- Details about how naïve Bayes models are trained
- Evaluation of performance for naïve Bayes classifiers


### When to use naïve Bayes

Naïve Bayes classifiers are useful when you have a categorical response/outcome variable and there are multiple features/predictors that can be used to predict the correct value of the outcome variable. 
The ultimate goal is to automatically build a probabilistic model to predict the correct value of the outcome variable for a new instance described by the set of predictors/features. 
Naïve Bayes outputs a probability distribution over the values of the outcome variable, so a probability value is generated for each class. 
The category with the highest probability is typically chosen as the correct/most-likely category for the corresponding instance. 
Naïve Bayes has the advantage of being simple and highly accurate for classification when features can be treated independently, comparable to logistic regression but more easily extended to many categories.

## Mathematical Foundations of Naïve Bayes for Multinomial, Hard Classification

We briefly review in this section the mathematical formulation of the naïve Bayes method for multinomial, hard classification problems. 
That is, we assume the outcome for one instance or object can be one and only one category out of $C$ possible categories.

The naïve Bayes method relies on Bayes' Theorem shown below:

$$P (Y|X) = \frac{P(Y)P(X|Y)}{P(X)}$$

The term $P (Y|X) $ is called the posterior, the term $P(Y)$ is called the prior, and the term $P(X|Y)$ is called the likelihood.

In a classification case, $Y$ can take as value any of the classes $c_i$, $i \in (1..C)$, and $X$ is described as a set of features/predictors $X=(x_1,..,x_P)$. 
Then Bayes' Theorem becomes:

$$P (Y=c_i| (x_1,..,x_P)) = \frac{P(Y=c_i)P(x_1,..,x_P|Y=c_i)}{P(x_1,..,x_P)}$$

The naïve Bayes method takes this theorem and makes the naive assumption that the predictors $x_j$ are independent of each other given the class, i.e., $P(x_1,..,x_P|Y=c_i)$ is approximated by $\prod \limits _{j=1} ^P P(x_j|c_i)$. It then re-writes the theorem in the following form:

$$P (Y=c_i| (x_1,..,x_P)) = \frac{P(Y=c_i) \prod \limits _{j=1} ^P P(x_j|c_i)}{P(x_1,..,x_P)}$$

This naive formulation of the theorem is more manageable in terms of estimating the parameters of the distributions involved and in particular of the likelihood probability.
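To get a feel for why the naive formulation is more manageable, here is a rough count of the probabilities that must be estimated per class, with and without the independence assumption. The level counts are taken from the variable table of the nursery example later in this notebook; the variable names in the code are only for illustration:

```r
# Number of levels per predictor, from the nursery variable table
n_levels <- c(parents = 3, has_nurs = 5, form = 4, children = 4,
              housing = 3, finance = 2, social = 3, health = 3)

# Without the independence assumption: one probability per combination of
# predictor values (minus one, because they sum to 1), for each class
full_joint_params <- prod(n_levels) - 1

# With the naive assumption: one small distribution per predictor, for each class
naive_params <- sum(n_levels - 1)

c(full_joint = full_joint_params, naive = naive_params)
```

With only a few thousand training rows, most of the combinations in the full joint distribution would never be observed, whereas every per-predictor distribution can be estimated from many rows.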

## Training a Naïve Bayes Classifier

Training a naïve Bayes classifier implies deriving the prior and likelihood distributions from training data based on the naive formulation of Bayes' Theorem.

The prior $P(c_i)$ is derived using the following expression:

$$ P(c_i)= \frac{{\#} c_i}{N}$$

where ${\#} c_i$ is the number of training instances labeled with class $c_i$ and $N$ is the total number of training instances.

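The likelihood $P (X | Y=c_i) = \prod \limits _{j=1} ^P P(x_j|c_i) = P(x_1|c_i)P(x_2|c_i)... P(x_P|c_i)$ is derived by multiplying individual conditional distributions for each predictor $x_j$, which are estimated as shown below:

$$ P(x_j|c_i) = \frac{{\#} (x_j, c_i)}{{\#} c_i}$$

where ${\#} (x_j, c_i)$ is the number of training instances labeled with class $c_i$ in which predictor $j$ takes the value $x_j$.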
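The counting involved in training can be sketched in a few lines of base R. The toy data below are made up for illustration, with a single predictor and two classes:

```r
# Toy training set (made up for illustration): one predictor, two classes
toy <- data.frame(
  finance = c("convenient", "convenient", "inconv", "inconv", "convenient", "inconv"),
  rank    = c("priority", "priority", "priority", "not_recom", "not_recom", "not_recom")
)

# Prior P(c_i): class counts divided by N
prior <- table(toy$rank) / nrow(toy)

# Likelihood P(x_j | c_i): counts of each value within a class divided by
# the class count, so each column sums to 1
likelihood <- prop.table(table(toy$finance, toy$rank), margin = 2)

prior
likelihood
```

With more predictors, the same likelihood table would simply be computed once per predictor.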

Once the prior and likelihood distributions are derived, to predict the most likely class for a new instance $X=(x_1, ..., x_P)$ we apply the naïve Bayes formula:

$$class (X) = argmax_{i \in (1..C)} {P(c_i|X)} = argmax_{i \in (1..C)} P (Y=c_i| (x_1,..,x_P)) = argmax_{i \in (1..C)} \frac{P(Y=c_i) P(x_1|c_i)P(x_2|c_i)... P(x_P|c_i)}{P(x_1,..,x_P)} $$

Since the denominator does not depend on $c_i$, the argument of argmax, we can ignore the denominator.
Then the most likely class can be simply obtained using this formula:

$$class (X) = argmax_{i \in (1..C)} P(c_i|X) = argmax_{i \in (1..C)} P(Y=c_i) P(x_1|c_i) P(x_2|c_i) ... P(x_P|c_i)$$ 

That is, the most likely class is the class corresponding to the highest posterior probability, estimated based on the above naive formulation of Bayes' Theorem.
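The prediction step can be sketched as follows, assuming the prior and likelihood tables have already been estimated. All numbers are hypothetical and chosen only to make the arithmetic easy to follow:

```r
# Hypothetical estimates for two classes and two predictors (made-up numbers)
prior <- c(priority = 0.5, not_recom = 0.5)
lik_finance <- rbind(convenient = c(priority = 0.7, not_recom = 0.4),
                     inconv     = c(priority = 0.3, not_recom = 0.6))
lik_housing <- rbind(convenient = c(priority = 0.6, not_recom = 0.2),
                     critical   = c(priority = 0.4, not_recom = 0.8))

# New instance X = (finance = "inconv", housing = "critical")
# Numerator of the naive Bayes formula, computed for every class at once
score <- prior * lik_finance["inconv", ] * lik_housing["critical", ]

# Dividing by the sum recovers the posterior; it does not change the argmax
posterior <- score / sum(score)

names(score)[which.max(score)]
```

Here the scores are 0.06 for `priority` and 0.24 for `not_recom`, so `not_recom` is predicted whether or not we normalize.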


## Example: Naïve Bayes

The data we will use is the `nursery` dataset, which ranks applications for nursery schools in Slovenia during the 1980s.
Because the original dataset is a fair bit larger, we've randomly sampled 2000 rows.

The goal is to predict `rank`.

| Variable | Type    | Description                                            |
|:---------|:--------|:-------------------------------------------------------|
| parents  | Nominal | usual, pretentious, great_pret                         |
| has_nurs | Nominal | proper, less_proper, improper, critical, very_crit     |
| form     | Nominal | complete, completed, incomplete, foster                |
| children | Nominal | 1, 2, 3, more                                          |
| housing  | Nominal | convenient, less_conv, critical                        |
| finance  | Nominal | convenient, inconv                                     |
| social   | Nominal | nonprob, slightly_prob, problematic                    |
| health   | Nominal | recommended, priority, not_recom                       |
| rank     | Nominal | not_recom, recommend, very_recom, priority, spec_prior |

Note that the `health` column is spelled `heath` in the CSV file, so that is the name used in the code and output below.

<div style="text-align:center;font-size: smaller">
 <b>Source:</b> This dataset was taken from the <a href="https://archive.ics.uci.edu/ml/datasets/Nursery">UCI Machine Learning Repository library
    </a></div>
<br>


### Load data

Start by loading `readr`, `dplyr`, `base` and `tidyr` so we can read data into a dataframe and manipulate it:

- `library readr`
- `library dplyr`
- `library tidyr`
- `library base`

In [1]:
library(readr)
library(dplyr)
library(tidyr)
library(base)





Load the dataframe and convert all character columns to factor at the same time.
This avoids having to specify `col_types`, which is tedious when many columns are nominal:

- Set `dataframe` to 
    - `pipe`
        - `with readr do read_csv`
            - using `"datasets/nursery.csv"`
        - to with `dplyr` do `mutate`
            - using `across(everything(), factor)`
         
- `dataframe` (to display)

In [3]:
dataframe = readr::read_csv("datasets/nursery.csv") %>%
    dplyr::mutate(across(everything(),factor))

dataframe

Rows: 2000 Columns: 9
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (9): parents, has_nurs, form, children, housing, finance, social, heath,...

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.


| parents | has_nurs | form | children | housing | finance | social | heath | rank |
|:--|:--|:--|:--|:--|:--|:--|:--|:--|
| `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` |
| pretentious | very_crit | complete | more | less_conv | convenient | problematic | not_recom | not_recom |
| pretentious | proper | foster | more | critical | convenient | slightly_prob | recommended | priority |
| pretentious | critical | foster | 3 | less_conv | convenient | slightly_prob | recommended | spec_prior |
| great_pret | critical | completed | 3 | less_conv | convenient | nonprob | not_recom | not_recom |
| usual | less_proper | incomplete | 3 | convenient | inconv | slightly_prob | priority | priority |
| pretentious | improper | foster | more | critical | convenient | slightly_prob | recommended | priority |
| usual | very_crit | completed | more | convenient | convenient | nonprob | recommended | priority |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| usual | proper | foster | 1 | less_conv | inconv | problematic | not_recom | not_recom |
| usual | improper | foster | 2 | convenient | inconv | slightly_prob | priority | priority |
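To see what the `mutate`/`across` step does in isolation, here is a small sketch on a made-up data frame (the column values are borrowed from the nursery data, but the rows are invented):

```r
library(dplyr)

# Toy data frame with character columns (made up for illustration)
toy <- data.frame(parents = c("usual", "pretentious", "usual"),
                  finance = c("inconv", "convenient", "inconv"))

# Apply factor() to every column
toy_factors <- toy %>%
    dplyr::mutate(across(everything(), factor))

# Each column is now a factor with its own set of levels
sapply(toy_factors, is.factor)
levels(toy_factors$parents)
```

By default, `factor` orders the levels alphabetically, which is why the levels in the outputs below do not follow the order of the variable table.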


Let's check for NA:

- `with base do summary`
    - using `dataframe`

In [4]:
base::summary(dataframe)


        parents           has_nurs           form     children  
 great_pret :718   critical   :379   complete  :478   1   :490  
 pretentious:618   improper   :400   completed :523   2   :499  
 usual      :664   less_proper:416   foster    :500   3   :536  
                   proper     :409   incomplete:499   more:475  
                   very_crit  :396                              
       housing          finance               social            heath    
 convenient:684   convenient: 976   nonprob      :661   not_recom  :682  
 critical  :670   inconv    :1024   problematic  :689   priority   :662  
 less_conv :646                     slightly_prob:650   recommended:656  
                                                                         
                                                                         
         rank    
 not_recom :682  
 priority  :637  
 spec_prior:631  
 very_recom: 50  
                 

There's no NA, and we see that the levels of most variables are pretty balanced. The exception is `rank`: `very_recom` is rare (50 of 2000 rows), and `recommend` doesn't appear in our sample at all.

### Explore data

Start by loading `ggplot2`, `psych`, and `janitor` for exploring the data:

- `library ggplot2`
- `library psych`
- `library janitor`

In [5]:
library(janitor)
library(ggplot2)
library(psych)




Let's take a closer look with descriptive statistics:

- `with psych do describe`
    - using `dataframe`

In [6]:
psych::describe(dataframe)


| | vars | n | mean | sd | median | trimmed | mad | min | max | range | skew | kurtosis | se |
|:--|--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|--:|
| | `<int>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| parents* | 1 | 2000 | 1.973 | 0.8310333 | 2 | 1.96625 | 1.4826 | 1 | 3 | 2 | 0.050410118 | -1.550994 | 0.01858247 |
| has_nurs* | 2 | 2000 | 3.0215 | 1.3982186 | 3 | 3.026875 | 1.4826 | 1 | 5 | 4 | -0.019588299 | -1.272229 | 0.03126512 |
| form* | 3 | 2000 | 2.51 | 1.1079322 | 2 | 2.5125 | 1.4826 | 1 | 4 | 3 | -0.002064703 | -1.33794 | 0.02477412 |
| children* | 4 | 2000 | 2.498 | 1.1025442 | 3 | 2.4975 | 1.4826 | 1 | 4 | 3 | -0.011721645 | -1.325196 | 0.02465364 |
| housing* | 5 | 2000 | 1.981 | 0.8154578 | 2 | 1.97625 | 1.4826 | 1 | 3 | 2 | 0.034838295 | -1.496121 | 0.01823419 |
| finance* | 6 | 2000 | 1.512 | 0.499981 | 2 | 1.515 | 0.0 | 1 | 2 | 1 | -0.047977824 | -1.998697 | 0.01117991 |
| social* | 7 | 2000 | 1.9945 | 0.8098133 | 2 | 1.993125 | 1.4826 | 1 | 3 | 2 | 0.010008813 | -1.475836 | 0.01810798 |
| heath* | 8 | 2000 | 1.987 | 0.8180254 | 2 | 1.98375 | 1.4826 | 1 | 3 | 2 | 0.023907085 | -1.505966 | 0.0182916 |
| rank* | 9 | 2000 | 2.0245 | 0.8696424 | 2 | 1.999375 | 1.4826 | 1 | 4 | 3 | 0.180824439 | -1.177968 | 0.0194458 |


We need to be careful with interpreting the means, etc. of these nominal variables (**note the star**).
In such cases, tables showing the relative frequency of levels (which we've already seen) are more useful.
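For example, the relative frequencies of the `rank` levels can be computed directly from the counts in the summary output above. This is only a sketch with the counts typed in by hand:

```r
# Counts of each `rank` level, copied from the summary output above
rank_counts <- c(not_recom = 682, priority = 637, spec_prior = 631, very_recom = 50)

# Relative frequencies
rank_freq <- rank_counts / sum(rank_counts)
round(rank_freq, 3)
```

These proportions are also what the prior $P(c_i)$ would be if it were estimated on all 2000 rows.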

Because the variables are nominal, many of our standard tools won't work.
For example, we can't use a correlation matrix/heatmap, because correlation isn't defined for nominal variables.
There is something called [Cramer's V](https://towardsdatascience.com/the-search-for-categorical-correlation-a1cf7f1888c9) that is close, but it requires some custom coding that's a bit out of scope for us.
Similarly a scatterplot matrix is not going to be very useful.
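For the curious, here is one way Cramér's V could be computed for a single pair of variables in base R. This is a rough sketch rather than a vetted implementation; the counts are copied by hand from the `parents` by `rank` contingency table that appears further down in this notebook:

```r
# Contingency table of `parents` by `rank`, counts copied from the tabyl
# output further down in this notebook
counts <- matrix(c(245, 139, 334,  0,
                   213, 203, 182, 20,
                   224, 295, 115, 30),
                 nrow = 3, byrow = TRUE,
                 dimnames = list(parents = c("great_pret", "pretentious", "usual"),
                                 rank = c("not_recom", "priority", "spec_prior", "very_recom")))

# Chi-squared statistic for the table
chi2 <- unname(stats::chisq.test(counts)$statistic)

# Cramér's V: sqrt(chi2 / (n * (min(rows, columns) - 1)))
cramers_v <- sqrt(chi2 / (sum(counts) * (min(dim(counts)) - 1)))
cramers_v
```

The result lies between 0 (no association) and 1 (perfect association).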

One thing we can do is make a contingency table for each variable against `rank` using `tabyl`.
That will show us what levels correspond to rank decisions.
We're introducing a new concept here, which is **mapping** a function over a list.
To do this, we need to load the `purrr` package:

- `library purrr`

In [61]:
library(purrr)

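Before applying it to our data, here is a minimal sketch of what mapping does, using a made-up vector of names and a simple function:

```r
library(purrr)

# A character vector of names (chosen only for illustration)
column_names <- c("parents", "finance", "social")

# map() calls the function once per element and returns a list of results
name_lengths <- purrr::map(column_names, \(x) nchar(x))
name_lengths
```

The `\(x)` syntax defines a small anonymous function; `map` passes each element in as `x`.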

Next we use a pipe to drop the `rank` column, get the column names, then map these names to `tabyl`.
The `x` stands for any column name, and `!!sym(x)` is special syntax since `tabyl` doesn't take variable names as character strings, which is what we're giving it:

- `pipe`
    - `dataframe`
    - to with `dplyr` do `select`
        - using `-rank`
    - then to with `base` do `colnames`
    - then to with `purrr` do `map`
        - using `\(x) tabyl(dataframe,!!sym(x),rank)`

In [62]:
dataframe %>%
    dplyr::select(-rank) %>%
    base::colnames() %>%
    purrr::map(\(x) tabyl(dataframe,!!sym(x),rank))


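| parents | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| great_pret | 245 | 139 | 334 | 0 |
| pretentious | 213 | 203 | 182 | 20 |
| usual | 224 | 295 | 115 | 30 |

| has_nurs | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| critical | 124 | 64 | 191 | 0 |
| improper | 145 | 129 | 110 | 16 |
| less_proper | 151 | 193 | 53 | 19 |
| proper | 135 | 220 | 39 | 15 |
| very_crit | 127 | 31 | 238 | 0 |

| form | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| complete | 158 | 165 | 137 | 18 |
| completed | 179 | 160 | 169 | 15 |
| foster | 160 | 163 | 170 | 7 |
| incomplete | 185 | 149 | 155 | 10 |

| children | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| 1 | 169 | 189 | 111 | 21 |
| 2 | 163 | 164 | 157 | 15 |
| 3 | 181 | 148 | 199 | 8 |
| more | 169 | 136 | 164 | 6 |

| housing | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| convenient | 229 | 249 | 169 | 37 |
| critical | 221 | 199 | 246 | 4 |
| less_conv | 232 | 189 | 216 | 9 |

| finance | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| convenient | 345 | 319 | 280 | 32 |
| inconv | 337 | 318 | 351 | 18 |

| social | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| nonprob | 215 | 230 | 187 | 29 |
| problematic | 235 | 187 | 267 | 0 |
| slightly_prob | 232 | 220 | 177 | 21 |

| heath | not_recom | priority | spec_prior | very_recom |
|:--|--:|--:|--:|--:|
| `<fct>` | `<dbl>` | `<dbl>` | `<dbl>` | `<dbl>` |
| not_recom | 682 | 0 | 0 | 0 |
| priority | 0 | 277 | 385 | 0 |
| recommended | 0 | 360 | 246 | 50 |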


There are a number of zeros for `very_recom` across different variables, so it certainly seems like some levels of those variables are influential for this recommendation.

It's also clear that `heath` is a close proxy for `rank` and should be dropped before building the model.

### Prepare train/test sets

We need to split the dataframe into training data and testing data.

First, load the package for splitting:

- `library rsample`

In [19]:
library(rsample)



Now split the data, but first specify a random seed so your results match mine and specify strata because some of our target class levels are very rare:
           
- with `base` do `set.seed` using `2`

- Set `data_split` to `with rsample do initial_split`
    - using `dataframe`
    - and `prop=.50`
    - and `strata = "rank"`
- Set `data_train` to `with rsample do training`
    - using `data_split`
- Set `data_test` to `with rsample do testing`
    - using `data_split`
- `data_train`


In [166]:
base::set.seed(2)

data_split = rsample::initial_split(dataframe,prop=.50,strata = "rank")
data_train = rsample::training(data_split)
data_test = rsample::testing(data_split)

data_train


| parents | has_nurs | form | children | housing | finance | social | heath | rank |
|:--|:--|:--|:--|:--|:--|:--|:--|:--|
| `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` | `<fct>` |
| pretentious | very_crit | complete | more | less_conv | convenient | problematic | not_recom | not_recom |
| pretentious | proper | foster | 1 | convenient | convenient | slightly_prob | not_recom | not_recom |
| great_pret | less_proper | completed | 1 | less_conv | inconv | nonprob | not_recom | not_recom |
| great_pret | critical | incomplete | 3 | less_conv | convenient | nonprob | not_recom | not_recom |
| great_pret | improper | completed | 3 | convenient | convenient | slightly_prob | not_recom | not_recom |
| pretentious | proper | completed | 1 | convenient | inconv | nonprob | not_recom | not_recom |
| great_pret | less_proper | complete | 1 | less_conv | convenient | problematic | not_recom | not_recom |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| pretentious | improper | completed | 3 | critical | convenient | nonprob | priority | spec_prior |
| great_pret | improper | complete | more | less_conv | inconv | problematic | priority | spec_prior |
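To illustrate the idea behind `strata`, here is a rough base R sketch of a stratified 50/50 split on made-up data. It is not how `rsample` is implemented (which may, for instance, treat very small strata differently); it only shows that sampling within each class keeps rare classes represented in both sets:

```r
set.seed(2)

# Toy data (made up): a common class and a rare class
toy <- data.frame(id = 1:20,
                  rank = rep(c("priority", "very_recom"), times = c(16, 4)))

# Sample half of the row indices within each class separately
train_idx <- unlist(lapply(
    split(seq_len(nrow(toy)), toy$rank),
    function(idx) idx[sample.int(length(idx), size = floor(length(idx) / 2))]
))

toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]

# Both sets keep the 16:4 class ratio
table(toy_train$rank)
table(toy_test$rank)
```

With a purely random split, by contrast, all four rare rows could land in the same set by chance.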


### Fit model

Load the `recipes` and `workflows` packages, along with `parsnip`, `generics`, `hardhat`, `discrim`, and `broom`:

- `library recipes`
- `library workflows`
- `library parsnip`
- `library generics`
- `library hardhat`
- `library broom`
- `library discrim`

In [167]:
library(recipes)
library(workflows)
library(generics)
library(parsnip)
library(hardhat)
library(broom)
library(discrim)


Let's make a workflow to predict `rank` with naive Bayes, including a recipe step that assigns `heath` to an `"ignore"` role so it isn't used as a predictor.


- Set `recipe` to 
    - `pipe` with `recipes` do `recipe`
        - using `rank ~ .`
        - and `data = data_train`
    - to with `recipes` do `update_role`
        - using `heath`
        - and `new_role="ignore"`
- Set `model` to 
    - `pipe` with `parsnip` do `naive_Bayes` 
        - using `smoothness = 1`
        - and `Laplace = .5`
    - to with `parsnip` do `set_mode` 
        - using `"classification"`
    - then to with `parsnip` do `set_engine`
        - using `"naivebayes"`
- Set `workflow` to 
    - `pipe` with `workflows` do `workflow`
    - to with `workflows` do `add_model`
        - using `model`
    - then to with `workflows` do `add_recipe` 
        - using `recipe`
        

In [168]:
recipe = recipes::recipe(rank ~ .,data = data_train) %>%
    recipes::update_role(heath,new_role="ignore")
model = parsnip::naive_Bayes(Laplace = .5,smoothness = 1) %>%
    parsnip::set_mode("classification") %>%
    parsnip::set_engine("naivebayes")
workflow = workflows::workflow() %>%
    workflows::add_model(model) %>%
    workflows::add_recipe(recipe)


Fit the workflow:

- Set `trained_model` to with `generics` do `fit`
    - using `workflow`
    - and `data = data_train`
- `trained_model`

In [169]:
trained_model = generics::fit(workflow,data = data_train)

trained_model


══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: naive_Bayes()

── Preprocessor ────────────────────────────────────────────────────────────────
0 Recipe Steps

── Model ───────────────────────────────────────────────────────────────────────

 
 Call: 
naive_bayes.default(x = maybe_data_frame(x), y = y, laplace = ~0.5, 
    usekernel = TRUE, adjust = ~1)

--------------------------------------------------------------------------------- 
 
Laplace smoothing: 0.5

--------------------------------------------------------------------------------- 
 
 A priori probabilities: 

 not_recom   priority spec_prior very_recom 
0.34434434 0.31831832 0.31731732 0.02002002 

--------------------------------------------------------------------------------- 
 
 Tables: 

--------------------------------------------------------------------------------- 
 ::: parents (Categorical) 
---------------------------------------------
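
The printed model reports a Laplace smoothing value and a table of a priori probabilities. As a rough illustration of what these quantities mean, the sketch below computes class priors and a smoothed conditional probability table for a single categorical feature in base R. The data frame `toy` and its values are hypothetical (they are not the nursery data), and the formula is the standard additive-smoothing form; we assume, but have not verified here, that the `naivebayes` engine follows the same idea.

```r
# Hypothetical toy data: one categorical feature and a class label.
toy <- data.frame(
  parents = factor(
    c("usual", "usual", "pretentious", "great_pret", "usual", "pretentious"),
    levels = c("usual", "pretentious", "great_pret")
  ),
  rank = factor(
    c("priority", "priority", "priority", "spec_prior", "spec_prior", "spec_prior")
  )
)

# A priori probabilities: the share of each class in the training data.
priors <- prop.table(table(toy$rank))

# Counts of each feature level within each class (rows = class, columns = level).
counts <- table(toy$rank, toy$parents)

# Without smoothing, a level never seen with a class gets probability 0.
unsmoothed <- counts / rowSums(counts)

# Additive (Laplace) smoothing: add `laplace` to every count, then renormalise
# each row so that it still sums to 1.
laplace <- 0.5
smoothed <- (counts + laplace) / (rowSums(counts) + laplace * ncol(counts))

print(priors)
print(round(unsmoothed, 3))
print(round(smoothed, 3))
```

In the toy data `great_pret` never occurs with `priority`, so its unsmoothed probability is zero and would zero out the whole product for that class. Smoothing replaces the zero with a small positive value.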

Now use `augment` to get predictions.

- Set `data_evaluation` to with `generics` do `augment`
    - using `trained_model`
    - and `data_test`
- `data_evaluation`

In [170]:
data_evaluation = generics::augment(trained_model,data_test)

data_evaluation


parents,has_nurs,form,children,housing,finance,social,heath,rank,.pred_class,.pred_not_recom,.pred_priority,.pred_spec_prior,.pred_very_recom
<fct>,<fct>,<fct>,<fct>,<fct>,<fct>,<fct>,<fct>,<fct>,<fct>,<dbl>,<dbl>,<dbl>,<dbl>
pretentious,proper,foster,more,critical,convenient,slightly_prob,recommended,priority,priority,0.3613081,0.50710876,0.11744426,1.413889e-02
great_pret,critical,completed,3,less_conv,convenient,nonprob,not_recom,not_recom,spec_prior,0.2478534,0.08888825,0.66316849,8.980999e-05
usual,less_proper,incomplete,3,convenient,inconv,slightly_prob,priority,priority,priority,0.4037144,0.51987855,0.05842622,1.798081e-02
usual,very_crit,completed,more,convenient,convenient,nonprob,recommended,priority,not_recom,0.4579088,0.16697993,0.35394566,2.116559e-02
usual,very_crit,incomplete,2,convenient,convenient,slightly_prob,not_recom,not_recom,not_recom,0.5269457,0.17419273,0.28670740,1.215420e-02
great_pret,less_proper,completed,1,critical,convenient,nonprob,priority,spec_prior,priority,0.3896007,0.46126119,0.14723828,1.899859e-03
pretentious,improper,completed,more,critical,inconv,problematic,recommended,spec_prior,not_recom,0.4124650,0.20726540,0.37972184,5.477855e-04
⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮,⋮
usual,critical,foster,1,critical,convenient,slightly_prob,not_recom,not_recom,not_recom,0.4368618,0.3038254,0.2560408,3.272053e-03
great_pret,critical,foster,3,critical,convenient,slightly_prob,not_recom,not_recom,spec_prior,0.2545135,0.0617370,0.6837208,2.873283e-05
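
The `.pred_class` column is simply the class whose `.pred_*` probability is largest in each row. As a small check, the sketch below copies the probabilities of the first three rows from the output above into a hypothetical data frame `probs` and recovers the predicted class with base R.

```r
# Class probabilities copied from the first three rows of the output above.
probs <- data.frame(
  .pred_not_recom  = c(0.3613081, 0.2478534, 0.4037144),
  .pred_priority   = c(0.50710876, 0.08888825, 0.51987855),
  .pred_spec_prior = c(0.11744426, 0.66316849, 0.05842622),
  .pred_very_recom = c(1.413889e-02, 8.980999e-05, 1.798081e-02)
)

# Strip the ".pred_" prefix to get the class names.
class_names <- sub("^\\.pred_", "", names(probs))

# For each row, pick the column with the highest probability (the argmax).
predicted <- class_names[max.col(as.matrix(probs), ties.method = "first")]

print(predicted)
# Each row is a probability distribution, so it sums to 1 (up to rounding).
print(rowSums(probs))
```

The result agrees with the `.pred_class` values shown for those rows: `priority`, `spec_prior`, and `priority`.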


### Evaluate the model

First load `yardstick` to get performance metrics:

- `library yardstick`

In [171]:
library(yardstick)


To get the accuracy, precision, recall, and F1, do the following:

- `pipe`
    - `create list with`
        - with `yardstick` do `accuracy`
            - using `data_evaluation`
            - and `truth=rank`
            - and `estimate=.pred_class`
        - with `yardstick` do `precision`
            - using `data_evaluation`
            - and `truth=rank`
            - and `estimate=.pred_class`
        - with `yardstick` do `recall`
            - using `data_evaluation`
            - and `truth=rank`
            - and `estimate=.pred_class`
        - with `yardstick` do `f_meas`
            - using `data_evaluation`
            - and `truth=rank`
            - and `estimate=.pred_class`
    - then with `dplyr` do `bind_rows`

In [172]:
list(
    yardstick::accuracy(data_evaluation, truth = rank, estimate = .pred_class),
    yardstick::precision(data_evaluation, truth = rank, estimate = .pred_class),
    yardstick::recall(data_evaluation, truth = rank, estimate = .pred_class),
    yardstick::f_meas(data_evaluation, truth = rank, estimate = .pred_class)
) %>%
    dplyr::bind_rows()


.metric,.estimator,.estimate
<chr>,<chr>,<dbl>
accuracy,multiclass,0.5104895
precision,macro,0.6265424
recall,macro,0.4061859
f_meas,macro,0.4009604


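Performance is poor: accuracy is about 0.51 and macro F1 is only about 0.40.
Recall that our most frequent level is `not_recom`, which in the original data occurs $682/2000 \approx 34\%$ of the time, so always predicting `not_recom` would already give about 34% accuracy; the model beats that baseline, but not by a wide margin.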

Let's look at the confusion matrix to better understand what's happening:

- with `yardstick` do `conf_mat`
    - using `data_evaluation`
    - and `truth=rank`
    - and `estimate=.pred_class`

In [174]:
yardstick::conf_mat(data_evaluation,truth=rank,estimate=.pred_class)


            Truth
Prediction   not_recom priority spec_prior very_recom
  not_recom         98       88         99          0
  priority         139      213         16         29
  spec_prior       101       18        199          0
  very_recom         0        0          0          1

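Look at the `not_recom` column.
The model is particularly bad here: only 98 of the 338 true `not_recom` cases are predicted correctly, and the rest are split between `priority` (139) and `spec_prior` (101). It can tell that these cases are not `very_recom`, but little else.
Things are better for the middle levels (213 of 319 `priority` and 199 of 314 `spec_prior` cases are correct), but terrible again for `very_recom`, where only 1 of 30 cases is correct and the other 29 are predicted as `priority`.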
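
To see how the summary metrics relate to the confusion matrix, the sketch below re-enters the counts printed above and recomputes accuracy and the macro-averaged precision, recall, and F1 in base R. The variable names are ours, and we assume that macro averaging means the unweighted mean of the per-class values.

```r
classes <- c("not_recom", "priority", "spec_prior", "very_recom")

# Counts copied from the confusion matrix above:
# rows are predictions, columns are the true classes.
cm <- matrix(
  c( 98,  88,  99,  0,
    139, 213,  16, 29,
    101,  18, 199,  0,
      0,   0,   0,  1),
  nrow = 4, byrow = TRUE,
  dimnames = list(Prediction = classes, Truth = classes)
)

correct   <- diag(cm)                 # correctly classified cases per class
accuracy  <- sum(correct) / sum(cm)   # overall share of correct predictions
recall    <- correct / colSums(cm)    # correct / all true cases of the class
precision <- correct / rowSums(cm)    # correct / all predictions of the class
f1        <- 2 * precision * recall / (precision + recall)

# Per-class values show where the model struggles.
print(round(rbind(precision, recall, f1), 3))

# Macro averages: unweighted means over the four classes.
macro <- c(
  accuracy  = accuracy,
  precision = mean(precision),
  recall    = mean(recall),
  f_meas    = mean(f1)
)
print(round(macro, 7))
```

The macro values reproduce the `yardstick` output shown earlier. Note that the perfect precision for `very_recom` comes from a single prediction, which inflates the macro precision relative to the other metrics.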

We could consider trying a model that uses the order of these levels instead of treating `rank` as nominal, or we could consider collapsing the ranks into broader categories based on the classifier's confusion (if that is allowed).
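
Both options start with recoding the outcome. The sketch below shows one way to do each in base R on a small hypothetical vector `rank_toy`. The level order used for the ordered factor is an assumption that should be checked against the dataset documentation, and merging `very_recom` into `priority` is just one possible collapse suggested by the confusion matrix.

```r
# Hypothetical outcome values standing in for the `rank` column.
rank_toy <- factor(c("not_recom", "priority", "spec_prior", "very_recom", "priority"))

# Option 1: treat the outcome as ordinal.
# ASSUMPTION: this order is illustrative; confirm it before using it.
rank_ordered <- factor(
  rank_toy,
  levels = c("not_recom", "very_recom", "priority", "spec_prior"),
  ordered = TRUE
)

# Option 2: collapse a rare, frequently confused level into another level.
# Assigning an existing level name merges the two levels.
rank_collapsed <- rank_toy
levels(rank_collapsed)[levels(rank_collapsed) == "very_recom"] <- "priority"

print(rank_ordered)
print(table(rank_collapsed))
```

An ordered factor only helps if the downstream model actually uses the ordering; a nominal classifier such as naïve Bayes would treat it like any other factor.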

<!-- ### Visualizing

#This is an example of how we might get feature importances out, per class
#but it seems way too complicated for our target audience

library(vip)

pred_fun <- function(object, newdata) {
  predict(object, new_data = newdata, type = "prob")$.pred_priority
}


trained_model %>%
  vi(method = "permute", target = "rank", metric = "auc", nsim = 10,
     pred_wrapper = pred_fun, train = data_train, reference_class = "priority") -->