---
title: "Getting started with cheem"
author: "Nicholas Spyrison"
date: '`r format(Sys.Date())`'
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Getting started with cheem}
  %\VignetteEncoding{UTF-8}
  %\VignetteEngine{knitr::rmarkdown}
editor_options:
  chunk_output_type: console
---
<!-- #Example vignette:
https://github.com/njtierney/naniar/blob/master/vignettes/getting-started-w-naniar.Rmd -->
```{r setup, include = FALSE}
knitr::opts_chunk$set(
echo = TRUE, # code
include = TRUE, # plots
results = "show", # text: "hide", "show"
## stop("REPLACE ME"):
eval = FALSE, # chunk code
message = FALSE,
warning = FALSE,
error = FALSE,
collapse = TRUE,
comment = "#>",
fig.height = 4,
fig.width = 6,
fig.align = "center",
cache = FALSE
)
```
__TL;DR__, you can jump straight into the visuals and application with `cheem::run_app()`, but we suggest you read the introduction to get situated with the context first.
# Introduction
Non-linear models regularly result in more accurate predictions than their linear counterparts. However, the number and complexity of their terms make them more opaque and harder to interpret. Our ability to understand how features (variables or predictors) influence predictions is important to a wide range of audiences. Attempts to bring interpretability to such complex models are an important aspect of eXplainable Artificial Intelligence (XAI).
_Local explanations_ are one such tool used in XAI. They attempt to approximate the feature importance in the vicinity of one instance (observation). That is to say that they give an approximation of linear terms at the position of one in-sample or out-of-sample observation.
```{r limenonlinear, echo=FALSE, fig.cap="Illustration of a non-linear classification boundary. The use of local explanations approximates the feature importance in the vicinity of one instance. This allows us to understand which feature changes would result in a red plus being classified as a blue circle. From _Ribeiro, M. et al. (2017). Why should I trust you?_"}
knitr::include_graphics("../inst/shiny_apps/cheem/www/lime_nonlinear.png")
```
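As a rough illustration of what "an approximation of linear terms at the position of one observation" can mean, the sketch below estimates local slopes of a toy non-linear function with finite differences. The function `f()` and the helper `local_slopes()` are hypothetical and for illustration only; this is not how SHAP or LIME compute their explanations.

```r
## Toy non-linear "model" of two features (hypothetical, for illustration only)
f <- function(x) x[1]^2 + sin(x[2])

## Central finite-difference slopes of f at x0:
## the coefficients of a local linear approximation around x0
local_slopes <- function(f, x0, h = 1e-5) {
  vapply(seq_along(x0), function(j) {
    step <- replace(numeric(length(x0)), j, h)
    (f(x0 + step) - f(x0 - step)) / (2 * h)
  }, numeric(1))
}

x0 <- c(1.5, 0)
slopes <- local_slopes(f, x0)
round(slopes, 4)

## The slopes depend on where we look
round(local_slopes(f, c(-0.5, pi)), 4)
```

For this toy function the analytic slopes are `2 * x[1]` and `cos(x[2])`, so the two features trade importance as the location changes. That location dependence is the reason an explanation is called local.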
If the analyst can explore how models lead to bad predictions, it can offer insight into issues with the data or suggest models that may be more robust to misclassifications or extreme residuals. An analyst may also want to explore the support of the feature contributions, to see where an explanation makes sense and where it may be completely unreliable. We propose conducting this sort of analysis with interactive graphics, as implemented in the R package titled __cheem__.
# Preprocessing
This framework is broadly applicable to any model and compatible local explanation. We will illustrate with an __xgboost__::xgboost() model (xgb) and the tree SHAP local explanation from __shapviz__::shapviz(). The model attempts to predict housing sales price from 11 predictors for 338 sale events from one neighborhood in the 2018 Ames data.
The first things we need are the predictions and a local explanation (or other embedded space). Here we create an xgb model, create predictions, and find the SHAP values of each observation.
```{r}
## Download if not installed
if(!require(cheem)) install.packages("cheem", dependencies = TRUE)
if(!require(xgboost)) install.packages("xgboost", dependencies = TRUE)
if(!require(shapviz)) install.packages("shapviz", dependencies = TRUE)
## Load onto session
library(cheem)
library(xgboost)
library(shapviz)
## Setup
X <- amesHousing2018_NorthAmes[, 1:9]
Y <- amesHousing2018_NorthAmes$SalePrice
clas <- amesHousing2018_NorthAmes$SubclassMS
## Model and predict
ames_train <- data.matrix(X) %>% xgb.DMatrix(label = Y)
ames_xgb_fit <- xgboost(data = ames_train, max.depth = 3, nrounds = 25)
ames_xgb_pred <- predict(ames_xgb_fit, newdata = ames_train)
ames_xgb_pred %>% head()
## SHAP values
shp <- shapviz(ames_xgb_fit, X_pred = ames_train, X = X)
## Keep just the [n, p] local explanations
ames_xgb_shap <- shp$S
ames_xgb_shap %>% head()
```
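To make the `[n, p]` attribution matrix more concrete, here is a minimal sketch of exact Shapley values for one instance of a toy model. Everything in it (`predict_toy()`, `coalition_value()`, `shapley_exact()`, and the use of a fixed baseline for "absent" features) is an assumption made for illustration. Tree SHAP arrives at attributions with a much more efficient, tree-specific algorithm, and its values will generally differ from this baseline-based sketch.

```r
## Toy prediction function of 3 features (hypothetical)
predict_toy <- function(x) 2 * x[1] + x[2] * x[3]

## Value of a coalition S: features in S take the instance's values,
## all others are fixed at a baseline (a simplifying assumption)
coalition_value <- function(S, x, baseline) {
  z <- baseline
  z[S] <- x[S]
  predict_toy(z)
}

## Exact Shapley values: average marginal contribution over all orderings
shapley_exact <- function(x, baseline) {
  p <- length(x)
  perms <- as.matrix(expand.grid(rep(list(seq_len(p)), p)))
  perms <- perms[apply(perms, 1, function(r) length(unique(r)) == p), , drop = FALSE]
  phi <- numeric(p)
  for (i in seq_len(nrow(perms))) {
    S <- integer(0)
    for (j in perms[i, ]) {
      phi[j] <- phi[j] +
        coalition_value(c(S, j), x, baseline) - coalition_value(S, x, baseline)
      S <- c(S, j)
    }
  }
  phi / nrow(perms)
}

x        <- c(1, 2, 3)
baseline <- c(0, 0, 0)
phi <- shapley_exact(x, baseline)
phi
## Attributions sum to the prediction minus the baseline prediction
sum(phi)
predict_toy(x) - predict_toy(baseline)
```

One such vector per observation, stacked row-wise, gives a matrix with the same `[n, p]` shape as the data. That shape is what is kept from `shp$S` above.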
Note that the choice of the model, prediction, and local explanation (or other embedding) is up to the analyst and is not facilitated by __cheem__. Now let's prepare for the visualization of these spaces with a `cheem::cheem_ls()` call before we start our analysis.
```{r}
## Preprocessing for cheem analysis
ames_chm <- cheem_ls(X, Y,
class = clas,
attr_df = ames_xgb_shap,
pred = ames_xgb_pred,
label = "Ames, xgb, shap")
names(ames_chm)
```
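Because the model, predictions, and attributions are produced outside of __cheem__, it may be worth confirming that they line up before preprocessing. The helper below is hypothetical (it is not part of __cheem__) and encodes an assumption: that the data and attributions are both `[n, p]` and that the response and predictions have length `n`.

```r
## Hypothetical helper (not part of cheem): check that inputs are aligned
check_cheem_inputs <- function(x, y, attr_df, pred) {
  stopifnot(
    nrow(x) == length(y),
    nrow(x) == length(pred),
    nrow(x) == nrow(attr_df),
    ncol(x) == ncol(attr_df)
  )
  invisible(TRUE)
}

## Stand-in data with the same shapes as X, Y, the attributions, and predictions
set.seed(1)
x_toy    <- data.frame(a = rnorm(5), b = rnorm(5))
y_toy    <- rnorm(5)
attr_toy <- matrix(rnorm(10), nrow = 5, dimnames = list(NULL, c("a", "b")))
pred_toy <- y_toy + rnorm(5, sd = 0.1)
check_cheem_inputs(x_toy, y_toy, attr_toy, pred_toy)
```

With the objects from this vignette the equivalent call would be `check_cheem_inputs(X, Y, ames_xgb_shap, ames_xgb_pred)`.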
# Cheem viewer
We have extracted tree SHAP, a feature importance measure in the vicinity of each observation. We need to identify an instance of interest to explore; we do so with the linked brushing available in the _global view_. Then we will vary contributions from different features to test the support of an explanation in a _radial tour_.
## Global view
To get a more complete view, let's look at approximations of the data space, attribution space, and model fits side-by-side, with linked brushing provided by __plotly__ and __crosstalk__. We have identified an observation with a large Mahalanobis distance (in data space) and its closest neighbor in attribution space.
<!-- The model information is essentially the confusion matrix for the classification case. -->
```{r, out.width="100%"}
prim <- 1
comp <- 17
global_view(ames_chm, primary_obs = prim, comparison_obs = comp,
height_px = 240, width_px = 720,
as_ggplot = TRUE, color = "log_maha.data")
```
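The selection rule described above (a large Mahalanobis distance in data space, then the closest neighbor in attribution space) can be sketched in base R. The matrices `dat_mat` and `attr_mat` are random stand-ins, and the use of Euclidean distance for "closest" is an assumption; the sketch does not reproduce how `prim` and `comp` were chosen for the Ames example.

```r
set.seed(2022)
## Stand-ins for the data space and the attribution space, both [n, p]
dat_mat  <- matrix(rnorm(200), ncol = 4)
attr_mat <- matrix(rnorm(200), ncol = 4)

## Mahalanobis distance of each row from the column means (data space);
## stats::mahalanobis() returns squared distances, hence the sqrt()
maha <- sqrt(mahalanobis(dat_mat, center = colMeans(dat_mat), cov = cov(dat_mat)))
prim <- which.max(maha) ## most outlying instance

## Nearest neighbor of prim in the attribution space (Euclidean distance)
d_attr <- as.matrix(dist(attr_mat))[prim, ]
d_attr[prim] <- Inf ## exclude the instance itself
comp <- unname(which.min(d_attr))
c(primary = prim, comparison = comp)
```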
From this _global view_ we want to identify a primary instance (PI) and optionally a comparison instance (CI) to explore. Misclassified instances or instances with high residuals are good targets for further exploration. In the penguins classification example, one point sticks out: instance 243 (shown as *) is a Gentoo (purple) penguin, while the model predicts it to be a Chinstrap penguin. Penguin 169 (shown as x) is reasonably close by and correctly predicted as Gentoo. In practice we use linked brushing and misclassification information to guide our search.
## Radial tour
There is a lot to unpack here. The normalized distributions of the feature attributions from all instances are shown as parallel coordinate lines. The PI and CI selected above are shown as a dashed and a dotted line, respectively. The first thing we notice is that the attribution of the PI is close to its (incorrect) prediction of Chinstrap (orange) in terms of bill length (`bl`) and flipper length (`fl`). In terms of bill depth and body mass (`bd` and `bm`) it is more like its observed species, Gentoo (purple). We select flipper length as the feature to manipulate.
```{r, out.width="100%", eval = FALSE}
## Normalized attribution basis of the PI
bas <- sug_basis(ames_xgb_shap, rownum = prim)
## Default feature to manipulate:
#### the feature with largest separation between PI and CI attribution
mv <- sug_manip_var(
ames_xgb_shap, primary_obs = prim, comparison_obs = comp)
## Make the radial tour
ggt <- radial_cheem_tour(
ames_chm, basis = bas, manip_var = mv,
primary_obs = prim, comparison_obs = comp, angle = .15)
## Animate it
animate_gganimate(ggt, fps = 6)
#height = 2, width = 4.5, units = "in", res = 150
## Or as a plotly html widget
#animate_plotly(ggt, fps = 6)
```
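The comments in the chunk above describe two ideas: a normalized attribution basis, and choosing the feature with the largest separation between the PI and CI attributions. The sketch below is one plausible reading of those comments in base R, using a hand-made attribution matrix. The helper `unit_basis()` is hypothetical, and the internals of `sug_basis()` and `sug_manip_var()` may differ.

```r
## Stand-in attribution matrix [n, p]; in practice this would be the SHAP values
attr_mat <- rbind(
  c(bl = 0.4, bd = -0.1, fl =  0.3, bm = 0.0),
  c(bl = 0.5, bd = -0.2, fl = -0.4, bm = 0.1)
)
prim <- 1
comp <- 2

## Scale one row to unit length so it can serve as a 1D basis, a [p, 1] matrix
unit_basis <- function(m, rownum) {
  v <- m[rownum, ]
  matrix(v / sqrt(sum(v^2)), ncol = 1, dimnames = list(colnames(m), NULL))
}
bas_prim <- unit_basis(attr_mat, prim)
bas_comp <- unit_basis(attr_mat, comp)

## Feature with the largest gap between the two normalized attributions
mv <- which.max(abs(bas_prim - bas_comp))
rownames(bas_prim)[mv]
```

In this toy matrix the two rows mostly agree except on `fl`, where the signs differ, so `fl` is picked as the feature to manipulate.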
```{r, echo=FALSE, out.width="100%"}
## To mitigate file size (CRAN note) and run time create a gif and include that instead of executing code to make inline.
if(FALSE){
prim <- 1
comp <- 17
bas <- sug_basis(shap_df, rownum = prim)
mv <- sug_manip_var(
shap_df, primary_obs = prim, comparison_obs = comp)
ggt <- radial_cheem_tour(
this_ls, basis = bas, manip_var = mv,
primary_obs = prim, comparison_obs = comp, angle = .15)
#### .gif is about .2 Mb saved, while HTML widget was about 7 Mb.
anim <- animate_gganimate(
ggt, fps = 6, height = 2, width = 4.5, units = "in", res = 150)
gganimate::anim_save("tour_penguins.gif", animation = anim)#, path = "./vignettes")
beepr::beep()
}
#knitr::include_graphics("tour_penguins.gif")
```
![](https://github.com/nspyrison/cheem/blob/main/vignettes/tour_penguins.gif?raw=true)
Starting from the attribution projection, this instance already looks more like its observed Gentoo than the predicted Chinstrap. However, by frame 8, the basis has a full contribution of flipper length and the instance does look more like the predicted Chinstrap. Looking at the parallel coordinate lines on the basis visual, we can see that flipper length has a large gap between the PI and CI. Let's check the original variables to digest this.
```{r}
library(ggplot2)
prim <- 243 ## The PI discussed above, instance 243
ggplot(penguins_na.rm, aes(x = bill_length_mm,
y = flipper_length_mm,
colour = species,
shape = species)) +
geom_point() +
## Highlight PI, *
geom_point(data = penguins_na.rm[prim, ],
shape = 8, size = 5, alpha = 0.8) +
## Theme, scaling, color, and labels
theme_bw() +
theme(aspect.ratio = 1) +
scale_color_brewer(palette = "Dark2") +
labs(y = "Flipper length [mm]", x = "Bill length [mm]",
color = "Observed species", shape = "Observed species")
```
This profile shows the two features that differ most between the PI and CI. The PI is nested in between the Chinstrap penguins. That makes this instance particularly hard for a random forest model to classify, as a decision tree can only partition on one feature value at a time (horizontal and vertical lines here).
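The point about single-value partitions can be illustrated with a toy, one-feature example. The data and the helper `best_split()` are hypothetical: the helper tries every threshold on one feature and counts the errors when each side predicts its majority class.

```r
## Hypothetical helper: best single threshold on one feature
best_split <- function(x, y) {
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2 ## midpoints between observed values
  errs <- vapply(cuts, function(ct) {
    left  <- y[x <  ct]
    right <- y[x >= ct]
    ## Misclassified count when each side predicts its majority class
    as.numeric((length(left) - max(table(left))) +
                 (length(right) - max(table(right))))
  }, numeric(1))
  c(cut = cuts[which.min(errs)], errors = min(errs))
}

## One Gentoo nested between Chinstraps along a single feature
x <- c(1, 2, 3, 4, 5, 6, 7)
y <- factor(c("Chinstrap", "Chinstrap", "Chinstrap", "Gentoo",
              "Chinstrap", "Chinstrap", "Chinstrap"))
best_split(x, y)
```

No single threshold isolates the nested instance, so at least one error remains after one split. A tree needs further splits to carve it out, which is part of why such instances are hard to classify.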
## Shiny application
We provide an interactive __shiny__ application. Interactive features are made possible with __plotly__, __crosstalk__, and __DT__. We have preprocessed simulated and modern datasets for you to explore this analysis with. Alternatively, bring your own data by saving the return of `cheem_ls()` as an rds file. Follow along with the example in `?cheem_ls`.
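Saving and restoring an object as an rds file can be done with base R. In the sketch below `chm_ls` is a stand-in list with made-up contents; in practice it would be the return of `cheem_ls()`. A temporary file path is used here only so the sketch leaves nothing behind.

```r
## Stand-in for the list returned by cheem_ls() (contents are hypothetical)
chm_ls <- list(label = "toy example", values = 1:3)

## Save to an .rds file; replace tempfile() with a path of your choosing
rds_path <- tempfile(fileext = ".rds")
saveRDS(chm_ls, file = rds_path)

## Reading the file back returns an identical object
identical(readRDS(rds_path), chm_ls)
```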
# Conclusion
Interpretability of black-box models is important to maintain. Local explanations extend this interpretability by approximating the feature importance in the vicinity of one instance. We propose post-hoc analysis of these local explanations. First we explore them in a global, full-instance context. Then we explore the support of the local explanation to see where it seems plausible or unreliable.
# Other local explanations (& models)
__cheem__ is agnostic to the model and local explanation, but requires both. Above we illustrated using a random forest to predict penguin species. Below we demonstrate other attribution spaces from different models.
## shapviz (& xgb classification)
__shapviz__ is being actively maintained and is hosted on CRAN. It is compatible with H2O, lgb, and xgb models.
https://github.com/ModelOriented/shapviz
```{r, eval=FALSE, echo=TRUE}
if(!require(shapviz)) install.packages("shapviz")
if(!require(xgboost)) install.packages("xgboost")
library(shapviz)
library(xgboost)
set.seed(3653)
## Setup
X <- spinifex::penguins_na.rm[, 1:4]
Y <- spinifex::penguins_na.rm$species
clas <- spinifex::penguins_na.rm$species
## Model and predict
peng_train <- data.matrix(X) %>%
xgb.DMatrix(label = Y)
peng_xgb_fit <- xgboost(data = peng_train, max.depth = 3, nrounds = 25)
peng_xgb_pred <- predict(peng_xgb_fit, newdata = peng_train)
## SHAP
peng_xgb_shap <- shapviz(peng_xgb_fit, X_pred = peng_train, X = X)
## Keep just the [n, p] local explanations
peng_xgb_shap <- peng_xgb_shap$S
```
## treeshap (& randomForest regression)
__treeshap__ is available on CRAN. It is compatible with many tree-based models including gbm, lgb, rf, ranger, and xgb models.
https://github.com/ModelOriented/treeshap
```{r, eval=FALSE, echo=TRUE}
if(!require(treeshap)) install.packages("treeshap")
if(!require(randomForest)) install.packages("randomForest")
library(treeshap)
library(randomForest)
## Setup
X <- spinifex::wine[, -(1:2)] ## drop the first two columns; `-1:2` would mix negative and positive indices
Y <- spinifex::wine$Alcohol
clas <- spinifex::wine$Type
## Fit randomForest::randomForest
wine_rf_fit <- randomForest::randomForest(
X, Y, ntree = 125,
mtry = ifelse(is_discrete(Y), sqrt(ncol(X)), ncol(X) / 3),
nodesize = max(ifelse(is_discrete(Y), 1, 5), nrow(X) / 500))
wine_rf_pred <- predict(wine_rf_fit)
## treeshap::treeshap()
wine_rf_tshap <- wine_rf_fit %>%
treeshap::randomForest.unify(X) %>%
treeshap::treeshap(X, interactions = FALSE, verbose = FALSE)
## Keep just the [n, p] local explanations
wine_rf_tshap <- wine_rf_tshap$shaps
```
## DALEX (& LM regression)
__DALEX__ is a popular and versatile XAI package available on CRAN. It is compatible with many models, but it uses the original, slower variant of SHAP local explanation. Expect long run times for sizable data or complex models.
https://ema.drwhy.ai/shapley.html#SHAPRcode
```{r, eval=FALSE, echo=TRUE}
if(!require(DALEX)) install.packages("DALEX")
library(DALEX)
## Setup
X <- dragons[, c(1:4, 6)]
Y <- dragons$life_length
clas <- dragons$colour
## Model and predict
drag_lm_fit <- lm(data = data.frame(Y, X), Y ~ .)
drag_lm_pred <- predict(drag_lm_fit)
## SHAP via DALEX, versatile but slow
drag_lm_exp <- explain(drag_lm_fit, data = X, y = Y,
label = "Dragons, LM, SHAP")
## DALEX::predict_parts_shap is flexible, but slow and one row at a time
drag_lm_shap <- matrix(NA, nrow(X), ncol(X))
sapply(1:nrow(X), function(i){
pps <- predict_parts_shap(drag_lm_exp, new_observation = X[i, ])
## Keep just the [n, p] local explanations
drag_lm_shap[i, ] <<- tapply(
pps$contribution, pps$variable, mean, na.rm = TRUE) %>% as.vector()
})
drag_lm_shap <- as.data.frame(drag_lm_shap)
```
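One detail to watch when collapsing per-variable contributions with `tapply()`: the result is ordered by the sorted group labels, which may not match the column order of `X`. The sketch below uses a small stand-in data frame rather than real `predict_parts_shap()` output; the column names `variable_name` and `contribution` are assumed to mirror that output, and the values are made up.

```r
## Stand-in for the contributions of one observation (values are made up)
pps_toy <- data.frame(
  variable_name = rep(c("weight", "height", "scars"), times = 2),
  contribution  = c(1.0, -2.0, 0.5, 3.0, -4.0, 1.5)
)
x_cols <- c("height", "weight", "scars") ## column order of the data

## tapply() returns the means ordered by sorted group label
avg <- tapply(pps_toy$contribution, pps_toy$variable_name, mean)
names(avg)

## Reorder by name so the entries line up with the columns of the data
shap_row <- avg[x_cols]
shap_row
```

Indexing by name before filling a row of the attribution matrix keeps each contribution under the feature it belongs to.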
<!-- # EXPERIMENTAL: Use with non-linear embedding spaces -->
<!-- Instead of creating a non-linear model with predictions The global view could compare data space and non-linear embedded space, though the radial tour isn't well defined as there isn't necessarily a y variable to be the y-axis for the tour. -->
<!-- ```{r, eval=FALSE, echo=TRUE} -->
<!-- ## NOTE: This is not supported. embedding are [n, d<p], and y axis of the tour is gone; closer to classification tour -->
<!-- if(!require(umap)) install.packages("umap") -->
<!-- library(umap) -->
<!-- library(cheem) -->
<!-- X <- spinifex::penguins_na.rm[, 1:4] -->
<!-- ## 2d umap embedding of 4d data space. No non-linear model or predictions. -->
<!-- peng_umap2 <- umap::umap(X, d = 4) -->
<!-- peng_umap2 <- peng_umap2$layout -->
<!-- peng_umap_chm <- cheem_ls(X, y = NULL, class = NULL, -->
<!-- attr_df = peng_umap2, -->
<!-- pred = NULL, -->
<!-- label = "Penguin, umap2") -->
<!-- global_view(peng_umap_chm, primary_obs = 115, comparison_obs = 296, -->
<!-- height_px = 240, width_px = 720, as_ggplot = FALSE) -->
<!-- ## Note that the radial tour isn't really applicable, as there is no y(/y axis). -->
<!-- ``` -->