local_attributions fails for classification: incorrect number of subscripts on matrix #69

Closed
agilebean opened this issue Feb 3, 2020 · 10 comments
Labels
bug 💣 Bug to fix

@agilebean
agilebean commented Feb 3, 2020

For classification, local_attributions() returns the error:

Error in contribution[nrow(contribution), ] <- cummulative[nrow(contribution),  : 
  incorrect number of subscripts on matrix

One hint at the root cause might be the warning message thrown by the explainer, which tries to compute numeric residuals and of course cannot do so here:

DALEX.explainer <- DALEX::explain(
  model = model_object,
  data = features,
  y = training.set$.outcome == TARGET.VALUE,
  label = paste(model_object$method, " model"),
  colorize = TRUE
)

A new explainer has been created!
Warning message:
In mean.default(residuals) :
  argument is not numeric or logical: returning NA

Reproducible example:

random.case <- structure(list(anger = 0.166666666666667, anticipation = 0, disgust = 0.166666666666667, 
    fear = 0.166666666666667, joy = 0, negative = 0.25, positive = 0.0833333333333333, 
    sadness = 0.0833333333333333, surprise = 0.0833333333333333, 
    trust = 0), class = "data.frame", row.names = c(NA, -1L))

training.set <- structure(list(.outcome = structure(c(3L, 4L, 5L, 4L, 4L, 5L, 
5L, 4L, 3L, 3L, 3L, 5L, 4L, 3L, 3L, 1L, 4L, 3L, 4L, 5L, 3L, 2L, 
5L, 5L, 5L), .Label = c("1", "2", "3", "4", "5"), class = "factor"), 
    anger = c(0, 0.0434782608695652, 0, 0, 0, 0.1, 0, 0.037037037037037, 
    0.0192307692307692, 0, 0, 0, 0, 0.0673076923076923, 0.181818181818182, 
    0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 
    0.0952380952380952, 0, 0.0441176470588235, 0), anticipation = c(0.333333333333333, 
    0.217391304347826, 0.125, 0.15, 0.2, 0.2, 0.217391304347826, 
    0.111111111111111, 0.173076923076923, 0.166666666666667, 
    0.111111111111111, 0.157894736842105, 0.214285714285714, 
    0.115384615384615, 0.0909090909090909, 0.0408163265306122, 
    0, 0.166666666666667, 0, 0.114285714285714, 0.184210526315789, 
    0.0476190476190476, 0.133333333333333, 0.102941176470588, 
    0.176470588235294), disgust = c(0, 0, 0, 0, 0, 0, 0, 0.0185185185185185, 
    0.0192307692307692, 0.0833333333333333, 0.0740740740740741, 
    0, 0, 0.0288461538461538, 0, 0.0204081632653061, 0, 0, 0.111111111111111, 
    0, 0, 0.0952380952380952, 0, 0.0294117647058824, 0), fear = c(0, 
    0.0434782608695652, 0, 0.05, 0, 0, 0, 0.0185185185185185, 
    0, 0, 0, 0, 0, 0.0673076923076923, 0, 0.0408163265306122, 
    0, 0.0833333333333333, 0.111111111111111, 0, 0.0263157894736842, 
    0.0952380952380952, 0, 0.0294117647058824, 0), joy = c(0, 
    0.130434782608696, 0.166666666666667, 0.15, 0.233333333333333, 
    0.2, 0.173913043478261, 0.166666666666667, 0.0961538461538462, 
    0.166666666666667, 0.037037037037037, 0.210526315789474, 
    0.214285714285714, 0.0961538461538462, 0.181818181818182, 
    0.0204081632653061, 0.333333333333333, 0.0833333333333333, 
    0.222222222222222, 0.2, 0.105263157894737, 0.0952380952380952, 
    0.2, 0.147058823529412, 0.176470588235294), negative = c(0, 
    0.0869565217391304, 0.0833333333333333, 0.1, 0, 0, 0, 0.0555555555555556, 
    0.0769230769230769, 0.166666666666667, 0.0740740740740741, 
    0.0526315789473684, 0.0714285714285714, 0.105769230769231, 
    0.181818181818182, 0.204081632653061, 0, 0.166666666666667, 
    0.222222222222222, 0.0285714285714286, 0.105263157894737, 
    0.19047619047619, 0, 0.102941176470588, 0.0294117647058824
    ), positive = c(0.333333333333333, 0.217391304347826, 0.291666666666667, 
    0.4, 0.3, 0.3, 0.347826086956522, 0.333333333333333, 0.326923076923077, 
    0.25, 0.259259259259259, 0.315789473684211, 0.285714285714286, 
    0.240384615384615, 0.181818181818182, 0.244897959183673, 
    0.333333333333333, 0.25, 0.222222222222222, 0.4, 0.342105263157895, 
    0.238095238095238, 0.4, 0.235294117647059, 0.352941176470588
    ), sadness = c(0.333333333333333, 0.0434782608695652, 0.0416666666666667, 
    0, 0, 0, 0, 0.0185185185185185, 0.0576923076923077, 0, 0.0740740740740741, 
    0, 0, 0.0480769230769231, 0.0909090909090909, 0.142857142857143, 
    0, 0, 0.111111111111111, 0, 0.0526315789473684, 0.0952380952380952, 
    0, 0.0441176470588235, 0.0294117647058824), surprise = c(0, 
    0.0434782608695652, 0.0833333333333333, 0.05, 0.0666666666666667, 
    0, 0.0434782608695652, 0.037037037037037, 0.0192307692307692, 
    0, 0.111111111111111, 0.0526315789473684, 0, 0.0865384615384615, 
    0, 0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 
    0, 0.0666666666666667, 0.0735294117647059, 0.0294117647058824
    ), trust = c(0, 0.173913043478261, 0.208333333333333, 0.1, 
    0.2, 0.2, 0.217391304347826, 0.203703703703704, 0.211538461538462, 
    0.166666666666667, 0.259259259259259, 0.210526315789474, 
    0.214285714285714, 0.144230769230769, 0.0909090909090909, 
    0.204081632653061, 0.333333333333333, 0.25, 0, 0.2, 0.0789473684210526, 
    0.0476190476190476, 0.2, 0.191176470588235, 0.205882352941176
    )), row.names = c(NA, 25L), class = "data.frame")

library(dplyr)  # provides %>% and select()

# Random forest classifier on the 5-level factor .outcome
model.rf <- caret::train(
  form = .outcome ~ .,
  data = training.set,
  method = "rf",
  trControl = caret::trainControl(
    method = "repeatedcv", number = 5, repeats = 5)
)

target <- training.set$.outcome
features <- training.set %>% select(-.outcome)

TARGET.VALUE <- "1"

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  label = paste(model.rf$method, " model"),
  colorize = TRUE
)

# This call raises the "incorrect number of subscripts on matrix" error
DALEX.attribution <- DALEX.explainer %>%
  iBreakDown::local_attributions(random.case)

@agilebean agilebean changed the title local_attributions fails for classification local_attributions fails for classification: incorrect number of subscripts on matrix Feb 3, 2020
@hbaniecki
Member

Hi @agilebean! Can you provide the model, please? The problem might be with predict_function.
Your code works for me with DALEX v0.9.4 and iBreakDown v0.9.9:

library(dplyr)
random.case <- structure(list(anger = 0.166666666666667, anticipation = 0, disgust = 0.166666666666667, 
    fear = 0.166666666666667, joy = 0, negative = 0.25, positive = 0.0833333333333333, 
    sadness = 0.0833333333333333, surprise = 0.0833333333333333, 
    trust = 0), class = "data.frame", row.names = c(NA, -1L))

training.set <- structure(list(.outcome = structure(c(3L, 4L, 5L, 4L, 4L, 5L, 
5L, 4L, 3L, 3L, 3L, 5L, 4L, 3L, 3L, 1L, 4L, 3L, 4L, 5L, 3L, 2L, 
5L, 5L, 5L), .Label = c("1", "2", "3", "4", "5"), class = "factor"), 
    anger = c(0, 0.0434782608695652, 0, 0, 0, 0.1, 0, 0.037037037037037, 
    0.0192307692307692, 0, 0, 0, 0, 0.0673076923076923, 0.181818181818182, 
    0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 
    0.0952380952380952, 0, 0.0441176470588235, 0), anticipation = c(0.333333333333333, 
    0.217391304347826, 0.125, 0.15, 0.2, 0.2, 0.217391304347826, 
    0.111111111111111, 0.173076923076923, 0.166666666666667, 
    0.111111111111111, 0.157894736842105, 0.214285714285714, 
    0.115384615384615, 0.0909090909090909, 0.0408163265306122, 
    0, 0.166666666666667, 0, 0.114285714285714, 0.184210526315789, 
    0.0476190476190476, 0.133333333333333, 0.102941176470588, 
    0.176470588235294), disgust = c(0, 0, 0, 0, 0, 0, 0, 0.0185185185185185, 
    0.0192307692307692, 0.0833333333333333, 0.0740740740740741, 
    0, 0, 0.0288461538461538, 0, 0.0204081632653061, 0, 0, 0.111111111111111, 
    0, 0, 0.0952380952380952, 0, 0.0294117647058824, 0), fear = c(0, 
    0.0434782608695652, 0, 0.05, 0, 0, 0, 0.0185185185185185, 
    0, 0, 0, 0, 0, 0.0673076923076923, 0, 0.0408163265306122, 
    0, 0.0833333333333333, 0.111111111111111, 0, 0.0263157894736842, 
    0.0952380952380952, 0, 0.0294117647058824, 0), joy = c(0, 
    0.130434782608696, 0.166666666666667, 0.15, 0.233333333333333, 
    0.2, 0.173913043478261, 0.166666666666667, 0.0961538461538462, 
    0.166666666666667, 0.037037037037037, 0.210526315789474, 
    0.214285714285714, 0.0961538461538462, 0.181818181818182, 
    0.0204081632653061, 0.333333333333333, 0.0833333333333333, 
    0.222222222222222, 0.2, 0.105263157894737, 0.0952380952380952, 
    0.2, 0.147058823529412, 0.176470588235294), negative = c(0, 
    0.0869565217391304, 0.0833333333333333, 0.1, 0, 0, 0, 0.0555555555555556, 
    0.0769230769230769, 0.166666666666667, 0.0740740740740741, 
    0.0526315789473684, 0.0714285714285714, 0.105769230769231, 
    0.181818181818182, 0.204081632653061, 0, 0.166666666666667, 
    0.222222222222222, 0.0285714285714286, 0.105263157894737, 
    0.19047619047619, 0, 0.102941176470588, 0.0294117647058824
    ), positive = c(0.333333333333333, 0.217391304347826, 0.291666666666667, 
    0.4, 0.3, 0.3, 0.347826086956522, 0.333333333333333, 0.326923076923077, 
    0.25, 0.259259259259259, 0.315789473684211, 0.285714285714286, 
    0.240384615384615, 0.181818181818182, 0.244897959183673, 
    0.333333333333333, 0.25, 0.222222222222222, 0.4, 0.342105263157895, 
    0.238095238095238, 0.4, 0.235294117647059, 0.352941176470588
    ), sadness = c(0.333333333333333, 0.0434782608695652, 0.0416666666666667, 
    0, 0, 0, 0, 0.0185185185185185, 0.0576923076923077, 0, 0.0740740740740741, 
    0, 0, 0.0480769230769231, 0.0909090909090909, 0.142857142857143, 
    0, 0, 0.111111111111111, 0, 0.0526315789473684, 0.0952380952380952, 
    0, 0.0441176470588235, 0.0294117647058824), surprise = c(0, 
    0.0434782608695652, 0.0833333333333333, 0.05, 0.0666666666666667, 
    0, 0.0434782608695652, 0.037037037037037, 0.0192307692307692, 
    0, 0.111111111111111, 0.0526315789473684, 0, 0.0865384615384615, 
    0, 0.0408163265306122, 0, 0, 0, 0.0285714285714286, 0.0526315789473684, 
    0, 0.0666666666666667, 0.0735294117647059, 0.0294117647058824
    ), trust = c(0, 0.173913043478261, 0.208333333333333, 0.1, 
    0.2, 0.2, 0.217391304347826, 0.203703703703704, 0.211538461538462, 
    0.166666666666667, 0.259259259259259, 0.210526315789474, 
    0.214285714285714, 0.144230769230769, 0.0909090909090909, 
    0.204081632653061, 0.333333333333333, 0.25, 0, 0.2, 0.0789473684210526, 
    0.0476190476190476, 0.2, 0.191176470588235, 0.205882352941176
    )), row.names = c(NA, 25L), class = "data.frame")

target <- training.set$.outcome
features <- training.set %>% select(-.outcome)

TARGET.VALUE <- "1"

colnames(training.set)[1] <- "outcome"
model_object <- lm(outcome==TARGET.VALUE~., data = training.set)

DALEX.explainer <- DALEX::explain(
  model = model_object,
  data = features,
  y = training.set$outcome == TARGET.VALUE,
  label = paste(model_object$method, " model"),
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>% iBreakDown::local_attributions(random.case) 
DALEX.attribution

@agilebean
Author

agilebean commented Feb 3, 2020

Ha, crossing thoughts - I just included the model in the description!

@agilebean
Author

I just verified that I have iBreakDown_0.9.9, and that my DALEX is only one patch release behind yours, i.e. DALEX_0.9.3.

@agilebean
Author

I just ran it again with the same error. However, when I run the same analysis - but with the model trained as regression instead of classification, it WORKS! Double checked just now.

@agilebean
Author

@hbaniecki Did you try it with the model I specified in the description above?

@hbaniecki
Member

Yes, it is a weird problem with data.frame/matrix behavior. I believe that this part

cummulative <- do.call(rbind, c(list(baseline_yhat), yhats_mean, list(target_yhat)))
contribution <- rbind(0,apply(cummulative, 2, diff))
contribution[1,] <- cummulative[1,]
contribution[nrow(contribution),] <- cummulative[nrow(contribution),]

could be handled more robustly (to be fixed).
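
As a side note on the error text itself, here is a minimal base-R sketch of how that message arises. It does not reconstruct the actual code path inside iBreakDown; `contribution_mat` and `contribution_vec` are hypothetical stand-ins. The assumption is only that, somewhere along the way, an object expected to be two-dimensional has lost its dimensions.

```r
# Hypothetical stand-ins, not objects taken from iBreakDown.
contribution_mat <- matrix(0, nrow = 3, ncol = 2)
contribution_mat[3, ] <- c(0.1, 0.9)   # fine: two subscripts on a 2-d object

contribution_vec <- c(0, 0, 0)          # a plain vector, no dim attribute

# Row-style assignment on the plain vector fails; capture the message.
msg <- tryCatch(
  {
    contribution_vec[3, ] <- 0.1
    "no error"
  },
  error = function(e) conditionMessage(e)
)
print(msg)
```

The same `[i, ] <-` assignment works or fails depending only on whether the left-hand object still has two dimensions, which is why the shape of what predict_function returns matters here.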

While running your example, there is a red warning (in the explainer output) saying that predict_function returns probabilities for multiple classes. For now, if you want to use local_attributions for one class (e.g. target = "1"), you can use a custom predict_function and pass it to the explainer.

custom_predict_caret_oneclass <- function(model, data, target = "1") {
  return(predict(model, data, type = "prob")[, target])
}

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = custom_predict_caret_oneclass,
  label = paste(model.rf$method, " model"),
  colorize = TRUE
)

DALEX.attribution <- DALEX.explainer %>%
  iBreakDown::local_attributions(random.case)

DALEX.attribution
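
To illustrate what the one-class predict function above changes, here is a small sketch that needs neither caret nor DALEX. `probs` is a hypothetical stand-in for the data.frame of class probabilities that predict(model, data, type = "prob") returns for a caret model; the numbers are made up.

```r
# Hypothetical class probabilities: one column per class, one row per case.
probs <- data.frame(
  "1" = c(0.10, 0.05),
  "2" = c(0.20, 0.15),
  "3" = c(0.70, 0.80),
  check.names = FALSE   # keep "1", "2", "3" as column names
)

# Selecting a single class column drops to a plain numeric vector,
# i.e. one prediction per observation.
one_class <- probs[, "1"]

is.data.frame(probs)      # the full output is a data.frame
is.numeric(one_class)     # the selected column is numeric
is.null(dim(one_class))   # and has no dimensions
```

With one number per observation, the explainer sees the same kind of output it would get from a regression model, which matches the report above that the regression variant worked.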

@hbaniecki hbaniecki added the bug 💣 Bug to fix label Feb 3, 2020
@agilebean
Author

Great analysis.
Thanks for the oneclass predict_function code!
But that's a bummer, I need it for a publication.
Speaking of which, this issue on numbers on plots is extremely important for publications.

@pbiecek
Member

pbiecek commented Feb 15, 2020

Thanks. The problem was that predict returns a data.frame instead of a matrix.
This is fixed in the latest DALEX on the ema branch (it will be on master at the beginning of the week and on CRAN in about a week).

In the meantime you can use a user-defined predict_function:

DALEX.explainer <- DALEX::explain(
  model = model.rf,
  data = features,
  y = target == TARGET.VALUE,
  predict_function = function(m,x) as.matrix(predict(m, newdata = x, type = "prob")),
  label = paste(model.rf$method, " model"),
  colorize = TRUE
)
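
A small base-R sketch of the data.frame versus matrix difference that the as.matrix() wrapper addresses. `probs` is again a hypothetical stand-in for the probability output, with made-up values.

```r
# Hypothetical two-class probability output for two observations.
probs <- data.frame(
  "1" = c(0.10, 0.05),
  "2" = c(0.90, 0.95),
  check.names = FALSE
)
probs_mat <- as.matrix(probs)

# Taking one row behaves differently for the two types:
class(probs[1, ])       # a one-row data.frame
class(probs_mat[1, ])   # a plain numeric vector

dim(probs_mat)          # the matrix keeps one column per class
```

Code written with matrix semantics in mind (row extraction, rbind, apply) can behave differently when handed a data.frame, so converting once at the predict_function boundary keeps everything downstream on matrices.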

@pbiecek
Member

pbiecek commented Feb 17, 2020

This is now fixed in the latest DALEX, starting with version 0.9.8, as in
https://github.com/ModelOriented/DALEX/tree/DALEX_1.0_ema_version
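
For readers who want to check whether their installation is at or past that version, a minimal sketch follows. `has_fix` is a hypothetical helper, not part of DALEX; it only compares version strings with base R.

```r
# Hypothetical helper: TRUE when `installed` is at least the fixed version.
has_fix <- function(installed, fixed_in = "0.9.8") {
  package_version(installed) >= package_version(fixed_in)
}

has_fix("0.9.3")   # the DALEX version reported earlier in this thread
has_fix("0.9.8")   # the version named above as containing the fix

# In a session where DALEX is installed, the installed version can be
# passed in with:
#   has_fix(as.character(utils::packageVersion("DALEX")))
```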

@pbiecek pbiecek closed this as completed Feb 17, 2020
@agilebean
Author

I can confirm it works now - just ran local_attributions() on a classification dataset.
Wonderful.
Returns this plot:
image

3 participants