how to get Shap interactions for LightGBM? #104

Closed
pecto2020 opened this issue Aug 23, 2023 · 6 comments

@pecto2020

Your package is great and very easy to use within the tidymodels framework. I was wondering whether it is possible to calculate SHAP interactions for LightGBM. I would like to use them in sv_dependence instead of the heuristic (which is an amazing solution, though). I've seen that this is possible for XGBoost by setting interactions = TRUE in shapviz(). Is there any solution or workaround for LightGBM?

@mayer79
Collaborator

mayer79 commented Aug 23, 2023

Unfortunately not via TreeSHAP in LightGBM. But you could crunch interactions via the {treeshap} package.

@mayer79
Collaborator

mayer79 commented Aug 23, 2023

@hbaniecki
Member

I assume it involves hacking C++ code, which I can't help with :/

@mayer79
Collaborator

mayer79 commented Aug 23, 2023

Oh, hmm...

@pecto2020
Author

I tried to use treeshap but got this error: Error in S_inter[, v, color_var]: subscript out of bounds

Here's the code

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.3
#> Warning: package 'broom' was built under R version 4.2.3
#> Warning: package 'dials' was built under R version 4.2.3
#> Warning: package 'dplyr' was built under R version 4.2.3
#> Warning: package 'ggplot2' was built under R version 4.2.3
#> Warning: package 'parsnip' was built under R version 4.2.3
#> Warning: package 'recipes' was built under R version 4.2.3
#> Warning: package 'tibble' was built under R version 4.2.3
#> Warning: package 'tidyr' was built under R version 4.2.3
#> Warning: package 'tune' was built under R version 4.2.3
#> Warning: package 'workflowsets' was built under R version 4.2.3
#> Warning: package 'yardstick' was built under R version 4.2.3
library(shapviz)
#> Warning: package 'shapviz' was built under R version 4.2.3
library(treeshap)
library(lightgbm)
#> Warning: package 'lightgbm' was built under R version 4.2.3
#> Loading required package: R6
#> 
#> Attaching package: 'lightgbm'
#> The following object is masked from 'package:dplyr':
#> 
#>     slice
library(datasets)
library(bonsai)
#> Warning: package 'bonsai' was built under R version 4.2.3

# Use the fifa20 dataset 
fifa20 <- fifa20$data %>% 
            select(-work_rate) %>%
            bind_cols(data.frame(target = fifa20$target)) 
# Split the data
set.seed(123)
split <- initial_split(fifa20)
train <- training(split)
test <- testing(split)

# Recipe 
rec <- recipe(target ~ ., data = train)

# Model specification
boost_spec <- boost_tree(
  mode = "regression", 
  trees = 200, 
  tree_depth = 6
) %>%
  set_engine("lightgbm")

# Workflow
workflow <- workflow() %>% 
  add_recipe(rec) %>%
  add_model(boost_spec)

# Fit the model
boost_model <- workflow %>% fit(data = train)

# Create shap object with shapviz
shap_lgbm <- shapviz(extract_fit_engine(boost_model), 
                    as.matrix(test %>% select(-target)),
                    test %>% select(-target))

# Create unified model representation
unified_lgbm <- treeshap::lightgbm.unify(extract_fit_engine(boost_model), train %>% select(-target))

# Derive interactions
interactions_lgbm <- treeshap::treeshap(unified_lgbm, test %>% select(-target), interactions = TRUE, verbose = FALSE)

# Plot dependences
shap_lgbm$S_inter <- interactions_lgbm$interactions

sv_dependence(shap_lgbm, v = "overall", interactions = TRUE, color_var = "height_cm")
#> Error in S_inter[, v, color_var]: subscript out of bounds

dim(shap_lgbm$S_inter)
#> [1]   54   54 4570

@mayer79
Collaborator

mayer79 commented Aug 25, 2023

An interaction array cannot be assigned to an existing shapviz object, so this line is wrong:

shap_lgbm$S_inter <- interactions_lgbm$interactions
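One likely reason for the "subscript out of bounds" error, judging from the dim() output of 54 x 54 x 4570 in the reprex: the treeshap array is laid out as features x features x observations, while the failing index S_inter[, v, color_var] looks up feature names on the second and third dimension. The base R sketch below reproduces that mismatch on a toy array. The sizes, the feature name "age", and the assumption that shapviz wants an observations x features x features layout are mine, not stated in the thread.

```r
# Toy sizes (hypothetical): p = 3 features, n = 5 observations
p <- 3
n <- 5
feats <- c("overall", "height_cm", "age")

# Layout as seen in the reprex: features x features x observations
inter_ppn <- array(
  seq_len(p * p * n),
  dim = c(p, p, n),
  dimnames = list(feats, feats, NULL)
)

# Indexing by feature name on the third dimension fails,
# because that dimension holds observations, not features
res <- tryCatch(
  inter_ppn[, "overall", "height_cm"],
  error = function(e) e
)
inherits(res, "error")

# Permuting to observations x features x features makes the same index valid
inter_npp <- aperm(inter_ppn, c(3, 1, 2))
dim(inter_npp)
vals <- inter_npp[, "overall", "height_cm"]  # one value per observation
```

This only illustrates the array layout. In practice it is simpler to let shapviz() build the object from the treeshap result than to permute and assign the array by hand.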

This works, but I would decompose fewer rows and divide the response by 1e6 (or so):

shap_lgbm <- shapviz(interactions_lgbm)
top4 <- names(head(sv_importance(shap_lgbm, kind = "no"), 4))
sv_interaction(shap_lgbm[1:1000, top4])
sv_dependence(shap_lgbm, v = "overall", color_var = top4, interactions = TRUE)
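The two suggestions above (explain fewer rows, rescale the response) can be prepared in plain R before the model is fitted and decomposed. A minimal sketch on a toy data frame follows. The data, the column target_mio, the object X_small, and the sample size of 500 are hypothetical, and the sketch assumes the response is a monetary value in the millions range.

```r
set.seed(123)

# Toy stand-in for the thread's data (hypothetical values)
toy <- data.frame(
  target    = runif(2000, min = 1e5, max = 1e8),
  overall   = sample(50:95, 2000, replace = TRUE),
  height_cm = sample(160:200, 2000, replace = TRUE)
)

# Rescale the response to millions BEFORE fitting, so that
# the SHAP values of the fitted model are on a readable scale
toy$target_mio <- toy$target / 1e6

# Explain a random subset of rows instead of the full data;
# interaction arrays grow with rows x features x features
idx <- sample(nrow(toy), 500)
X_small <- toy[idx, c("overall", "height_cm")]

nrow(X_small)
range(toy$target_mio)
```

X_small would then take the place of the full test set in the treeshap() call from the reprex, with the model refitted on the rescaled response.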


@mayer79 mayer79 closed this as completed Sep 5, 2023