Skip to content

Commit

Permalink
adds tile plots for classification tasks (#923)
Browse files Browse the repository at this point in the history
  • Loading branch information
zmjones authored and larskotthoff committed Jun 3, 2016
1 parent 3f6666e commit e71a126
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 29 deletions.
1 change: 1 addition & 0 deletions NEWS
Expand Up @@ -30,6 +30,7 @@ mlr_2.9:
- rename generatePartialPrediction, plotPartialPrediction, and plotPartialPredictionGGVIS to
generatePartialDependence, plotPartialDependence, and plotPartialDependenceGGVIS
- add stability.nugget parameter for "regr.km"
- plotPartialDependence can now create plots for classification tasks with more than one interacted feature

- new learners
-- surv.cv.CoxBoost
Expand Down
36 changes: 12 additions & 24 deletions R/generatePartialDependence.R
Expand Up @@ -460,26 +460,8 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, p = 1) {
if (length(obj$features) > 2L & geom != "tile" & obj$interaction)
stop("To plot more than 2 features geom must be 'tile'!")
assertChoice(geom, c("tile", "line"))
if (geom == "tile") {
if (!(obj$task.desc$type %in% c("regr", "surv"))) {
if (length(obj$task.desc$class.levels) > 2L)
stop("Only visualization of binary classification works with tiling!")
}

feat_classes = sapply(obj$data, class)
if (any(feat_classes == "factor")) {
fact_feats = names(feat_classes[feat_classes == "factor"])
if (!is.null(facet))
fact_feats = fact_feats[which(fact_feats != facet)]
do_not_contour = length(fact_feats) > 0L
} else
do_not_contour = FALSE

if (do_not_contour)
warning("Factor features cannot be used to create contour plots! only tiles will be displayed.")
if (!obj$interaction)
if (geom == "tile" & !obj$interaction)
stop("generatePartialDependenceData was called with interaction = FALSE!")
}

if (!is.null(facet)) {
assertChoice(facet, obj$features)
Expand All @@ -499,6 +481,8 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, p = 1) {
facet = "Feature"
scales = "free_x"
}
if (obj$task.desc$type == "classif" & geom == "tile" & length(features) == 2L)
scales = "free_x"
}

if (p != 1) {
Expand Down Expand Up @@ -554,14 +538,18 @@ plotPartialDependence = function(obj, geom = "line", facet = NULL, p = 1) {
if (obj$derivative)
plt = plt + ylab(stri_paste(target, "(derivative)", sep = " "))
} else { ## tiling
plt = ggplot(obj$data, aes_string(x = features[1], y = features[2], z = target))
plt = plt + geom_tile(aes_string(fill = target))
if (!do_not_contour)
plt = plt + stat_contour()
if (obj$task.desc$type == "classif") {
plt = ggplot(obj$data, aes_string(x = features[1], y = features[2], fill = "Probability"))
plt = plt + geom_raster()
facet = "Class"
} else {
plt = ggplot(obj$data, aes_string(x = features[1], y = features[2], z = target))
plt = plt + geom_raster(aes_string(fill = target))
}
}

if (!is.null(facet))
plt = plt + facet_wrap(as.formula(stri_paste("~ ", facet)), scales = scales)
plt = plt + facet_wrap(as.formula(stri_paste("~", facet)), scales = scales)

plt
}
Expand Down
8 changes: 3 additions & 5 deletions tests/testthat/test_base_generatePartialDependence.R
Expand Up @@ -72,7 +72,7 @@ test_that("generatePartialDependenceData", {
interaction = TRUE, gridsize = gridsize)
nfacet = length(unique(dcp$data$Petal.Length))
ntarget = length(dcp$target)
plotPartialDependence(dcp, facet = "Petal.Length")
plotPartialDependence(dcp, geom = "tile")
ggsave(path)
doc = XML::xmlParse(path)
#expect_that(length(XML::getNodeSet(doc, grey.xpath, ns.svg)), equals(nfacet))
Expand All @@ -84,11 +84,11 @@ test_that("generatePartialDependenceData", {
## check that probability outputting classifiers work with ICE
dcp = generatePartialDependenceData(fcp, input = multiclass.task, features = c("Petal.Width", "Petal.Length"),
interaction = TRUE, individual = TRUE, gridsize = gridsize)
plotPartialDependence(dcp, facet = "Petal.Length")
plotPartialDependence(dcp, geom = "tile")
## plotPartialDependenceGGVIS(dcp, interact = "Petal.Length")

## check that survival tasks work with multiple features
fs = train("surv.rpart", surv.task)
fs = train("surv.coxph", surv.task)
ds = generatePartialDependenceData(fs, input = surv.task, features = c("x1", "x2"),
gridsize = gridsize)
nfeat = length(ds$features)
Expand Down Expand Up @@ -167,8 +167,6 @@ test_that("generatePartialDependenceData", {
bounds = c(-2, 2), gridsize = gridsize)

## check that tile plots work for two and three features with regression and survival
expect_warning(plotPartialDependence(db, "tile")) ## factor feature
expect_error(plotPartialDependence(dcp, geom = "tile")) ## no multiclass support
expect_error(plotPartialDependence(ds, geom = "tile")) ## interaction == FALSE
tfr = generatePartialDependenceData(fr, regr.df, features = c("lstat", "crim", "chas"),
interaction = TRUE, gridsize = gridsize)
Expand Down

0 comments on commit e71a126

Please sign in to comment.