generated from just-the-docs/just-the-docs-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo12-2.Rmd
88 lines (72 loc) · 2.19 KB
/
demo12-2.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
---
title: "Untitled"
output:
  pagedown::book_crc:
    highlight: "kate"
date: "`r Sys.Date()`"
---
```{r setup, include=FALSE}
# Global chunk options: show the code in the rendered book, but keep package
# warnings and startup/status messages out of the output.
knitr::opts_chunk$set(
  message = FALSE,
  warning = FALSE,
  echo = TRUE
)
```
```{r}
# Attach the packages used throughout this document. Attach order matters:
# a package attached later masks same-named functions from earlier ones, so
# check for conflicts before reordering these calls.
library(DALEXtra)             # explain_tidymodels(); DALEX's model_profile()
library(SummarizedExperiment) # assay(), colData() on the `se` object
library(glue)                 # glue() string interpolation
library(parsnip)              # boost_tree(), fit(); tidymodels attaches it too
library(tidymodels)           # modelling meta-package
library(tidyverse)            # dplyr/tidyr/purrr/ggplot2 verbs used below
library(vip)                  # vip() variable-importance plots
# Use the black-and-white ggplot2 theme for every plot in the document.
theme_set(theme_bw())
```
We'll study the Type 1 Diabetes (T1D) data. The two objects below hold the studies
combined (`combined_data`) and separately (`split_data`, one data frame per study).
```{r}
# Load the SummarizedExperiment `se` and keep only healthy and T1D samples
# (samples are the columns of `se`, so subsetting is column-wise).
load("T1D.rda")
se <- se[, colData(se)$disease %in% c("healthy", "T1D")]

# Transpose so samples are rows and taxa are columns, then name the taxa
# columns ASV1, ASV2, ... explicitly. This avoids relying on magrittr's `.`
# being visible inside a glue() string, which the pipe cannot detect as a
# placeholder.
counts <- t(assay(se))
colnames(counts) <- paste0("ASV", seq_len(ncol(counts)))
x <- as_tibble(counts)

# Fix the outcome's level order explicitly: factor() without `levels` sorts
# by the session locale's collation (e.g. the C locale puts "T1D" before
# "healthy"). Tools downstream pick classes by position, so the order must
# not depend on the machine.
combined_data <- bind_cols(
  x,
  y = factor(colData(se)$disease, levels = c("healthy", "T1D")),
  study_name = colData(se)$study_name
)

# One data frame per study, named by study_name.
split_data <- split(combined_data, combined_data$study_name)
```
We'll fit models at the two extremes: completely separate fits (one model per
study) and a completely combined fit (one model on all studies pooled together).
```{r}
# Boosted-tree classification spec (50 trees) shared by every fit below.
gbm <- boost_tree(trees = 50, mode = "classification")

# Extreme 1: a single model on all studies pooled (study label excluded).
combined_fit <- gbm |>
  fit(y ~ ., data = select(combined_data, -study_name))

# Extreme 2: an independent model for each study.
separate_fits <- split_data |>
  map(\(study_df) fit(gbm, y ~ ., data = select(study_df, -study_name)))
```
Let's compare which features each model considers most important.
```{r}
# Top-20 variable importance: first the pooled model, then each study's model.
vip(combined_fit, num_features = 20)
separate_fits |>
  map(\(study_fit) vip(study_fit, num_features = 20))
```
We can also interpret the fitted models by profiling how their predictions change as individual taxa vary.
```{r}
# Prediction profiles of the pooled model for a few taxa of interest.
focus_taxa <- c("ASV39", "ASV166")

# Predictors: drop the outcome and study label by name. A positional range
# such as `-study_name:-y` silently depends on those columns being adjacent.
# DALEX expects a numeric/logical `y`, so encode the outcome as 0/1 with 1 for
# the second factor level (the class whose probability DALEX reports for a
# binary tidymodels fit).
explainer <- explain_tidymodels(
  combined_fit,
  data = select(combined_data, -c(study_name, y)),
  y = as.integer(combined_data$y) - 1L
)
profiles <- model_profile(explainer, variables = focus_taxa)
plot(profiles, geom = "profiles", variables = focus_taxa)
```
```{r}
# One explainer per study. map2() pairs each study's fitted model with that
# study's own data, by position (both lists come from `split_data`).
# The outcome must come from the data frame: a parsnip model_fit has no `y`
# element, so taking it from the model would pass y = NULL.
explainers <- map2(
  separate_fits,
  split_data,
  \(study_fit, study_df) explain_tidymodels(
    study_fit,
    # Drop outcome and study label by name, not by positional range.
    data = select(study_df, -c(study_name, y)),
    # DALEX expects numeric/logical y: 1 for the second factor level.
    y = as.integer(study_df$y) - 1L
  )
)

# Profile and plot the same taxa under every study-specific model.
focus_taxa <- c("ASV79", "ASV40", "ASV39")
explainers |>
  map(\(study_explainer) model_profile(study_explainer, variables = focus_taxa)) |>
  map(\(study_profile) plot(study_profile, geom = "profiles", variables = focus_taxa))
```
```{r}
# Distribution of log(1 + value) for selected taxa, split by outcome (fill)
# and faceted by study.
focus_taxa <- c("ASV39", "ASV166", "ASV40", "ASV79", "ASV108", "ASV74")

# all_of(): selecting with a bare external character vector is deprecated in
# tidyselect; all_of() states explicitly that the names come from a variable.
combined_long <- combined_data |>
  select(y, study_name, all_of(focus_taxa)) |>
  pivot_longer(starts_with("ASV"), names_to = "ASV")

ggplot(combined_long) +
  geom_boxplot(aes(log(1 + value), ASV, fill = y)) +
  facet_grid(. ~ study_name, scales = "free")
```