Skip to content

Commit

Permalink
plotPartialDependency: fix for bug 1174 (#1180)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmjones authored and berndbischl committed Aug 23, 2016
1 parent 63f8bdd commit 37a18ac
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 80 deletions.
31 changes: 20 additions & 11 deletions R/generatePartialDependence.R
Expand Up @@ -540,7 +540,7 @@ doAggregatePartialDependence = function(out, td, target, features, rng) {
stop("function argument must return a sorted numeric vector ordered lowest to highest.")

if (all(target %in% td$class.levels)) {
out = melt(out, id.vars = features, variable = "Class", value.name = "Probability")
out = melt(out, id.vars = features, variable = "Class", value.name = "Probability", variable.factor = TRUE)
out$Class = stri_replace_all(out$Class, "", regex = "^prob\\.")
}
out
Expand All @@ -556,7 +556,7 @@ doIndividualPartialDependence = function(out, td, n, rng, target, features, cent
rng = rng[rep(seq_len(nrow(rng)), each = n), , drop = FALSE]
out = cbind(out, rng, idx, row.names = NULL)
out = melt(out, id.vars = c(features, "idx"),
variable.name = "Class", value.name = "Probability")
variable.name = "Class", value.name = "Probability", variable.factor = TRUE)
out$idx = interaction(out$idx, out$Class)
} else {
out = as.data.frame(do.call("rbind", out))
Expand All @@ -566,7 +566,8 @@ doIndividualPartialDependence = function(out, td, n, rng, target, features, cent
out = cbind(out, rng)
out = melt(out, id.vars = features, variable.name = "idx", value.name = target)
if (td$type == "classif")
out = melt(out, id.vars = c(features, "idx"), value.name = "Probability", variable.name = "Class")
out = melt(out, id.vars = c(features, "idx"), value.name = "Probability", variable.name = "Class",
variable.factor = TRUE)
}
out
}
Expand Down Expand Up @@ -693,23 +694,26 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, facet.wrap.nr
length(features) < 3L & geom == "line")

if (geom == "line") {
idx = which(sapply(obj$data, class) == "factor" & colnames(obj$data) %in% features)
# explicit casting previously done implicitly by reshape2::melt.data.frame
for (id in idx) obj$data[, id] = as.numeric(obj$data[, id])
obj$data = setDF(melt(data.table(obj$data),
id.vars = colnames(obj$data)[!colnames(obj$data) %in% features],
variable = "Feature", value.name = "Value", na.rm = TRUE))
id.vars = colnames(obj$data)[!colnames(obj$data) %in% features],
variable = "Feature", value.name = "Value", na.rm = TRUE, variable.factor = TRUE))
if (!obj$individual) {
if (obj$task.desc$type %in% c("regr", "surv"))
plt = ggplot(obj$data, aes_string("Value", target)) +
geom_line(color = ifelse(is.null(data), "black", "red"))
geom_line(color = ifelse(is.null(data), "black", "red")) + geom_point()
else
plt = ggplot(obj$data, aes_string("Value", "Probability", group = "Class", color = "Class")) +
geom_line()
geom_line() + geom_point()
} else {
if (obj$task.desc$type %in% c("regr", "surv")) {
plt = ggplot(obj$data, aes_string("Value", target, group = "idx")) +
geom_line(alpha = .25, color = ifelse(is.null(data), "black", "red"))
geom_line(alpha = .25, color = ifelse(is.null(data), "black", "red")) + geom_point()
} else {
plt = ggplot(obj$data, aes_string("Value", "Probability", group = "idx", color = "Class")) +
geom_line(alpha = .25)
geom_line(alpha = .25) + geom_point()
}
}

Expand All @@ -720,9 +724,11 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, facet.wrap.nr
plt = plt + labs(x = features)
}

# bounds from fun or se estimation
if (bounds)
plt = plt + geom_ribbon(aes_string(ymin = "lower", ymax = "upper"), alpha = .5)

# labels for ice plots
if (obj$center)
plt = plt + ylab(stri_paste(target, "(centered)", sep = " "))

Expand All @@ -738,24 +744,27 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, facet.wrap.nr
plt = ggplot(obj$data, aes_string(x = features[1], y = features[2], fill = target))
plt = plt + geom_raster(aes_string(fill = target))

# labels for ICE plots
if (obj$center)
plt = plt + scale_fill_continuous(guide = guide_colorbar(title = stri_paste(target, "(centered)", sep = " ")))

if (obj$derivative)
plt = plt + scale_fill_continuous(guide = guide_colorbar(title = stri_paste(target, "(derivative)", sep = " ")))
}

# facetting
if (!is.null(facet)) {
plt = plt + facet_wrap(as.formula(stri_paste("~", facet)), scales = scales,
nrow = facet.wrap.nrow, ncol = facet.wrap.ncol)
}

# data overplotting
if (!is.null(data)) {
data = data[, colnames(data) %in% c(obj$features, obj$task.desc$target)]
if (!is.null(facet)) {
if (!facet %in% obj$features)
data = melt(data, id.vars = c(obj$task.desc$target, obj$features[obj$features == facet]),
variable = "Feature", value.name = "Value", na.rm = TRUE)
variable = "Feature", value.name = "Value", na.rm = TRUE, variable.factor = TRUE)
if (facet %in% obj$features) {
if (!is.factor(data[[facet]]))
data[[facet]] = stri_paste(facet, "=", as.factor(signif(data[[facet]], 2)), sep = " ")
Expand Down Expand Up @@ -844,7 +853,7 @@ plotPartialDependenceGGVIS = function(obj, interact = NULL, p = 1) {
} else if (!obj$interaction & length(obj$features) > 1L) {
id = colnames(obj$data)[!colnames(obj$data) %in% obj$features]
obj$data = melt(obj$data, id.vars = id, variable.name = "Feature",
value.name = "Value", na.rm = TRUE)
value.name = "Value", na.rm = TRUE, variable.factor = TRUE)
interact = "Feature"
choices = obj$features
} else
Expand Down
143 changes: 74 additions & 69 deletions tests/testthat/test_base_generatePartialDependence.R
Expand Up @@ -24,7 +24,7 @@ test_that("generateFunctionalANOVAData", {
expect_error(generateFunctionalANOVAData(fr, regr.task, c("lstat", "age"), 3L, mean, gridsize = gridsize))

dr1b = generateFunctionalANOVAData(fr, regr.task, c("lstat", "age"), 1L,
function(x) quantile(x, c(.025, .5, .975)), gridsize = gridsize)
function(x) quantile(x, c(.025, .5, .975)), gridsize = gridsize)
expect_that(dim(dr1b$data), equals(c(gridsize * length(dr1b$features), 6L)))
plotPartialDependence(dr1b)
dir = tempdir()
Expand All @@ -33,7 +33,7 @@ test_that("generateFunctionalANOVAData", {
doc = XML::xmlParse(path)

dr2b = generateFunctionalANOVAData(fr, regr.task, c("lstat", "age"), 2L,
function(x) quantile(x, c(.025, .5, .975)), gridsize = gridsize)
function(x) quantile(x, c(.025, .5, .975)), gridsize = gridsize)
expect_that(dim(dr2b$data), equals(c(gridsize^length(dr2b$features), 6L)))
expect_that(dr2b$interaction, is_true())
plotPartialDependence(dr2b, "tile")
Expand All @@ -49,11 +49,11 @@ test_that("generateFunctionalANOVAData", {
test_that("generatePartialDependenceData", {
gridsize = 3L

## test regression with interactions, centering, and mixed factor features
# test regression with interactions, centering, and mixed factor features
fr = train("regr.rpart", regr.task)
dr = generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "chas"),
interaction = TRUE, fmin = list("lstat" = 1, "chas" = NA),
fmax = list("lstat" = 40, "chas" = NA), gridsize = gridsize)
interaction = TRUE, fmin = list("lstat" = 1, "chas" = NA),
fmax = list("lstat" = 40, "chas" = NA), gridsize = gridsize)
nfeat = length(dr$features)
nfacet = length(unique(regr.df[["chas"]]))
n = getTaskSize(regr.task)
Expand All @@ -65,146 +65,151 @@ test_that("generatePartialDependenceData", {
path = paste0(dir, "/test.svg")
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
#expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize))
## plotPartialDependenceGGVIS(dr, interact = "chas")
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
# expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize))
# plotPartialDependenceGGVIS(dr, interact = "chas")

## check that if the input is a data.frame things work
# check that if the input is a data.frame things work
dr.df = generatePartialDependenceData(fr, input = regr.df, features = "lstat")

## check that the interactions and centering work with ICE
# check that the interactions and centering work with ICE
dr = generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "chas"),
interaction = TRUE, individual = TRUE,
fmin = list("lstat" = 1, "chas" = NA),
fmax = list("lstat" = 40, "chas" = NA), gridsize = gridsize)
interaction = TRUE, individual = TRUE,
fmin = list("lstat" = 1, "chas" = NA),
fmax = list("lstat" = 40, "chas" = NA), gridsize = gridsize)
expect_that(max(dr$data$lstat), equals(40.))
expect_that(min(dr$data$lstat), equals(1.))
expect_that(nrow(dr$data), equals(gridsize * nfeat * n))

plotPartialDependence(dr, facet = "chas", data = regr.df, p = .25)
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
## black.xpath counts points which are omitted when individual = TRUE
## expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize * n))
## plotPartialDependenceGGVIS(dr, interact = "chas")
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
# black.xpath counts points which are omitted when individual = TRUE
# expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize * n))
# plotPartialDependenceGGVIS(dr, interact = "chas")

## check that multiple features w/o interaction work with a label outputting classifier with
## an appropriate aggregation function
# check that multiple features w/o interaction work with a label outputting classifier with
# an appropriate aggregation function
fc = train("classif.rpart", multiclass.task)
dc = generatePartialDependenceData(fc, input = multiclass.task, features = c("Petal.Width", "Petal.Length"),
fun = function(x) table(x) / length(x), gridsize = gridsize)
fun = function(x) table(x) / length(x), gridsize = gridsize)
nfeat = length(dc$features)
n = getTaskSize(multiclass.task)
plotPartialDependence(dc, data = multiclass.df)
ggsave(path)
doc = XML::xmlParse(path)
## minus one because the of the legend
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
#expect_that(length(XML::getNodeSet(doc, red.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
#expect_that(length(XML::getNodeSet(doc, blue.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
#expect_that(length(XML::getNodeSet(doc, green.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
## plotPartialDependenceGGVIS(dc)

## test that an inappropriate function for a classification task throws an error
## bounds cannot be used on classifiers
# minus one because the of the legend
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
# expect_that(length(XML::getNodeSet(doc, red.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
# expect_that(length(XML::getNodeSet(doc, blue.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
# expect_that(length(XML::getNodeSet(doc, green.xpath, ns.svg)) - 1, equals(nfeat * gridsize))
# plotPartialDependenceGGVIS(dc)

# test that an inappropriate function for a classification task throws an error
# bounds cannot be used on classifiers
fcp = train(makeLearner("classif.rpart", predict.type = "prob"), multiclass.task)
expect_error(generatePartialDependence(fcp, input = multiclass.task, features = "Petal.Width",
fun = function(x) quantile(x, c(.025, .5, .975))), gridsize = gridsize)
fun = function(x) quantile(x, c(.025, .5, .975))), gridsize = gridsize)

## check that probability outputting classifiers work w/ interactions
# check that probability outputting classifiers work w/ interactions
dcp = generatePartialDependenceData(fcp, input = multiclass.task, features = c("Petal.Width", "Petal.Length"),
interaction = TRUE, gridsize = gridsize)
interaction = TRUE, gridsize = gridsize)
nfacet = length(unique(dcp$data$Petal.Length))
ntarget = length(dcp$target)
plotPartialDependence(dcp, "tile")

## check that probability outputting classifiers work with ICE
# check that probability outputting classifiers work with ICE
dcp = generatePartialDependenceData(fcp, input = multiclass.task, features = c("Petal.Width", "Petal.Length"),
interaction = TRUE, individual = TRUE, gridsize = gridsize)
interaction = TRUE, individual = TRUE, gridsize = gridsize)

## check that survival tasks work with multiple features
# check that survival tasks work with multiple features
fs = train("surv.rpart", surv.task)
ds = generatePartialDependenceData(fs, input = surv.task, features = c("x1", "x2"), gridsize = gridsize)
nfeat = length(ds$features)
n = getTaskSize(surv.task)
plotPartialDependence(ds, data = surv.df)
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
#expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(gridsize * nfeat))
## plotPartialDependenceGGVIS(ds)
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
# expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(gridsize * nfeat))
# plotPartialDependenceGGVIS(ds)

## check that bounds work for regression
# issue 1180 test
pd = generatePartialDependenceData(fr, input = regr.task,
features = c("lstat", "chas"), gridsize = gridsize)
plotPartialDependence(pd)

# check that bounds work for regression
db = generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "chas"),
interaction = TRUE,
fun = function(x) quantile(x, c(.25, .5, .75)), gridsize = gridsize)
interaction = TRUE,
fun = function(x) quantile(x, c(.25, .5, .75)), gridsize = gridsize)
nfacet = length(unique(regr.df[["chas"]]))
n = getTaskSize(regr.task)
expect_that(colnames(db$data), equals(c("medv", "lstat", "chas", "lower", "upper")))
plotPartialDependence(db, facet = "chas", data = regr.df)
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
#expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize))
## plotPartialDependenceGGVIS(db, interact = "chas")
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
# expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfacet * gridsize))
# plotPartialDependenceGGVIS(db, interact = "chas")

## check derivative and factor feature failure
# check derivative and factor feature failure
expect_error(generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "chas"),
derivative = TRUE))
derivative = TRUE))

## check interaction + derivative failure
# check interaction + derivative failure
expect_error(generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "chas"),
interaction = TRUE, derivative = TRUE))
interaction = TRUE, derivative = TRUE))

## check that bounds work w/o interaction
# check that bounds work w/o interaction
db2 = generatePartialDependenceData(fr, input = regr.task, features = c("lstat", "crim"),
fun = function(x) quantile(x, c(.25, .5, .75)), gridsize = gridsize)
fun = function(x) quantile(x, c(.25, .5, .75)), gridsize = gridsize)
nfeat = length(db2$features)
n = getTaskSize(regr.task)
plotPartialDependence(db2, data = regr.df)
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
#expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfeat * gridsize))
## plotPartialDependenceGGVIS(db2)
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
# expect_that(length(XML::getNodeSet(doc, black.xpath, ns.svg)), equals(nfeat * gridsize))
# plotPartialDependenceGGVIS(db2)

fcpb = train(makeLearner("classif.rpart", predict.type = "prob"), binaryclass.task)
bc = generatePartialDependenceData(fcpb, input = binaryclass.task, features = c("V11", "V12"),
individual = TRUE, gridsize = gridsize)
individual = TRUE, gridsize = gridsize)
nfeat = length(bc$features)
n = getTaskSize(binaryclass.task)
plotPartialDependence(bc, data = binaryclass.df)
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
## again, omission of points for individual = TRUE
#expect_that(length(XML::getNodeSet(doc, red.line.xpath, ns.svg)) - 1, equals(nfeat * n))
## plotPartialDependenceGGVIS(bc)
# expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfeat))
# again, omission of points for individual = TRUE
# expect_that(length(XML::getNodeSet(doc, red.line.xpath, ns.svg)) - 1, equals(nfeat * n))
# plotPartialDependenceGGVIS(bc)

## check that derivative estimation works for ICE and pd for classification and regression
# check that derivative estimation works for ICE and pd for classification and regression
subset = 1:5
fr = train(makeLearner("regr.ksvm"), regr.task)
pfr = generatePartialDependenceData(fr, input = regr.df[subset, ], features = c("lstat", "crim"),
derivative = TRUE, individual = TRUE, gridsize = gridsize)
derivative = TRUE, individual = TRUE, gridsize = gridsize)
fc = train(makeLearner("classif.ksvm", predict.type = "prob"), multiclass.task)
pfc = generatePartialDependenceData(fc, input = multiclass.df[subset, ],
features = c("Petal.Width", "Petal.Length"),
derivative = TRUE, gridsize = gridsize)
features = c("Petal.Width", "Petal.Length"),
derivative = TRUE, gridsize = gridsize)
fs = train(makeLearner("surv.coxph"), surv.task)
pfs = generatePartialDependenceData(fs, input = surv.df[subset, ],
features = c("x1", "x2"),
derivative = TRUE, gridsize = gridsize)
features = c("x1", "x2"),
derivative = TRUE, gridsize = gridsize)

## check that se estimation works
# check that se estimation works
fse = train(makeLearner("regr.lm", predict.type = "se"), regr.task)
pfse = generatePartialDependenceData(fse, input = regr.task, features = c("lstat", "crim"),
bounds = c(-2, 2), gridsize = gridsize)
bounds = c(-2, 2), gridsize = gridsize)

## check that tile + contour plots work for two and three features with regression and survival
expect_error(plotPartialDependence(ds, geom = "tile")) ## interaction == FALSE
# check that tile + contour plots work for two and three features with regression and survival
expect_error(plotPartialDependence(ds, geom = "tile")) # interaction == FALSE
tfr = generatePartialDependenceData(fr, regr.df, features = c("lstat", "crim", "chas"),
interaction = TRUE, gridsize = gridsize)
interaction = TRUE, gridsize = gridsize)
plotPartialDependence(tfr, geom = "tile", facet = "chas", data = regr.df)
tfs = generatePartialDependenceData(fs, surv.df, c("x1", "x2"), interaction = TRUE)
plotPartialDependence(tfs, geom = "tile", data = surv.df)
Expand Down

0 comments on commit 37a18ac

Please sign in to comment.