Skip to content

Commit

Permalink
Fix case when we don't have outcome for test data in survival forests
Browse files Browse the repository at this point in the history
  • Loading branch information
ehrlinger committed Aug 10, 2016
1 parent ef5ed93 commit e13a47c
Show file tree
Hide file tree
Showing 34 changed files with 48 additions and 39 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: ggRandomForests
Type: Package
Title: Visually Exploring Random Forests
Version: 2.0.0
Version: 2.0.1
Date: 2016-05-23
Author: John Ehrlinger <john.ehrlinger@gmail.com>
Maintainer: John Ehrlinger <john.ehrlinger@gmail.com>
Expand Down
8 changes: 7 additions & 1 deletion NEWS
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
Package: ggRandomForests
Version: 2.0.0
Version: 2.0.1

ggRandomForests v2.0.1
===================
* Correct a bug in survival plots when predicting on future data without a known outcome.
* Additional Vignettes are now at https://github.com/ehrlinger/ggRFVignette
* Minor bug and documentation fixes.

ggRandomForests v2.0.0
===================
Expand Down
22 changes: 13 additions & 9 deletions R/gg_rfsrc.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,21 @@ gg_rfsrc.rfsrc <- function(object,
# Do we want all lines, or bootstrap confidence bands.
colnames(rng) <- object$time.interest

rng$ptid <- 1:nrow(rng)
rng$cens <- as.logical(object$yvar[,2])
rng$obs_id <- 1:nrow(rng)
if(is.null(object$yvar)){
rng$event = FALSE
}else{
rng$event <- as.logical(object$yvar[,2])
}
gg_dta <- rng

# If we don't specify either a conf band or group by variable...
# Then we want to plot a curve for each observation.
if(is.null(arg_list$conf.int) & missing(by)){
gathercols <- colnames(gg_dta)[-which(colnames(gg_dta) %in% c("ptid", "cens"))]
gathercols <- colnames(gg_dta)[-which(colnames(gg_dta) %in% c("obs_id", "event"))]
gg_dta.mlt <- tidyr::gather_(gg_dta, "variable", "value", gathercols)
gg_dta.mlt$variable <- as.numeric(as.character(gg_dta.mlt$variable))
gg_dta.mlt$ptid <- factor(gg_dta.mlt$ptid)
gg_dta.mlt$obs_id <- factor(gg_dta.mlt$obs_id)

gg_dta <- gg_dta.mlt

Expand Down Expand Up @@ -303,7 +307,7 @@ gg_rfsrc.rfsrc <- function(object,

bootstrap_survival <- function(gg_dta, bs.samples, level.set){
## Calculate the leave one out estimate of the mean survival
gg.t <- gg_dta[, -which(colnames(gg_dta) %in% c("ptid","cens", "group"))]
gg.t <- gg_dta[, -which(colnames(gg_dta) %in% c("obs_id","event", "group"))]
mn.bs <- t(sapply(1:bs.samples,
function(pat){
st <- sample(1:nrow(gg.t), size = nrow(gg.t), replace=T)
Expand All @@ -322,13 +326,13 @@ bootstrap_survival <- function(gg_dta, bs.samples, level.set){
time.interest <- as.numeric(colnames(gg.t))

dta <- data.frame(cbind(time.interest,
t(rng)[-which(colnames(gg_dta) %in% c("ptid", "cens")),],
mn[-which(colnames(gg_dta) %in% c("ptid", "cens"))]))
t(rng)[-which(colnames(gg_dta) %in% c("obs_id", "event")),],
mn[-which(colnames(gg_dta) %in% c("obs_id", "event"))]))

if(ncol(dta) == 5){
colnames(dta)<- c("time", "lower", "upper", "median", "mean")
colnames(dta)<- c("value", "lower", "upper", "median", "mean")
}else{
colnames(dta)<- c("time", level.set, "mean")
colnames(dta)<- c("value", level.set, "mean")
}
dta
}
Expand Down
14 changes: 7 additions & 7 deletions R/gg_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ gg_variable.rfsrc <- function(object,
gg_dta$yvar <- object$yvar

}else if(object$family == "surv"){
gg_dta$cens <- as.logical(object$yvar[,2])
colnames(gg_dta) <- c(object$xvar.names, "cens")
gg_dta$event <- as.logical(object$yvar[,2])
colnames(gg_dta) <- c(object$xvar.names, "event")

if(is.null(time)) time <- median(object$time.interest)
lng <- length(time)
Expand Down Expand Up @@ -251,15 +251,15 @@ gg_variable.randomForest <- function(object,

if(object$type == "regression"){
gg_dta$yhat <- object$predicted

}else{ # if(object$family == "class"){
colnames(object$predicted) <- paste("yhat.", colnames(object$predicted),
sep="")
gg_dta <- object$predicted
colnames(object$predicted) <- paste("yhat.", colnames(object$predicted),
sep="")
gg_dta <- object$predicted

gg_dta$yvar <- object$yvar

}
class(gg_dta) <- c("gg_variable", object$type, class(gg_dta))
invisible(gg_dta)
}
}
12 changes: 6 additions & 6 deletions R/plot.gg_rfsrc.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,20 +148,20 @@ plot.gg_rfsrc <- function(x,

if("group" %in% colnames(gg_dta)){
gg_plt <- ggplot(gg_dta) +
geom_ribbon(aes_string(x="time", ymin="lower", ymax="upper", fill="group"),
geom_ribbon(aes_string(x="value", ymin="lower", ymax="upper", fill="group"),
alpha=alph,...) +
geom_step(aes_string(x="time", y="median", color="group"), ...)
geom_step(aes_string(x="value", y="median", color="group"), ...)
}else{
gg_plt <- ggplot(gg_dta) +
geom_ribbon(aes_string(x="time", ymin="lower", ymax="upper"),alpha=alph) +
geom_step(aes_string(x="time", y="median"), ...)
geom_ribbon(aes_string(x="value", ymin="lower", ymax="upper"),alpha=alph) +
geom_step(aes_string(x="value", y="median"), ...)
}
}else{

# Lines by observation
gg_plt <- ggplot(gg_dta,
aes_string(x="variable", y="value", col="cens",
by="ptid")) +
aes_string(x="variable", y="value", col="event",
by="obs_id")) +
geom_step(...)
}

Expand Down
22 changes: 11 additions & 11 deletions R/plot.gg_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ plot.gg_variable <- function(x, xvar,
}

# These may be dangerous because they work on column names only.
if(sum(colnames(gg_dta) == "cens") != 0) family <- "surv"
if(sum(colnames(gg_dta) == "event") != 0) family <- "surv"

# Same here, but it's how I know there are multiple classes right now.
if(length(grep("yhat.", colnames(gg_dta))) > 0){
Expand Down Expand Up @@ -193,7 +193,7 @@ plot.gg_variable <- function(x, xvar,
if(missing(xvar)){
# We need to remove response variables here
cls <- c(grep("yhat", colnames(gg_dta)),
grep("cens", colnames(gg_dta)),
grep("event", colnames(gg_dta)),
grep("time", colnames(gg_dta))
)
xvar <- colnames(gg_dta)[-cls]
Expand Down Expand Up @@ -239,18 +239,18 @@ plot.gg_variable <- function(x, xvar,
## Survival plots
if(family == "surv"){
## response variables.
wch_y_var <- which(colnames(gg_dta) %in% c("cens", "yhat", "time"))
wch_y_var <- which(colnames(gg_dta) %in% c("event", "yhat", "time"))


# Handle categorical and continuous differently...
tmp_dta <- gg_dta[,c(wch_y_var, wch_x_var)]
gathercols <- colnames(tmp_dta)[-which(colnames(tmp_dta) %in% c("time", "cens", "yhat"))]
gathercols <- colnames(tmp_dta)[-which(colnames(tmp_dta) %in% c("time", "event", "yhat"))]
gg_dta.mlt <- tidyr::gather_(tmp_dta, "variable", "value", gathercols)

gg_dta.mlt$variable <- factor(gg_dta.mlt$variable, levels=xvar)
if(points){
gg_plt <- ggplot(gg_dta.mlt,
aes_string(x="value", y="yhat", color="cens", shape="cens"))
aes_string(x="value", y="yhat", color="event", shape="event"))
}else{
gg_plt <- ggplot(gg_dta.mlt,
aes_string(x="value", y="yhat"))
Expand Down Expand Up @@ -278,7 +278,7 @@ plot.gg_variable <- function(x, xvar,
gg_plt<- gg_plt +
geom_boxplot(aes_string(x="value", y="yhat"), color="grey",
..., outlier.shape = NA) +
geom_jitter(aes_string(x="value", y="yhat", color="cens", shape="cens"),
geom_jitter(aes_string(x="value", y="yhat", color="event", shape="event"),
...)

}
Expand Down Expand Up @@ -382,23 +382,23 @@ plot.gg_variable <- function(x, xvar,
# cat("2")
if(points){
gg_plt[[ind]] <- gg_plt[[ind]] +
geom_point(aes_string(x="var", y="yhat", color="cens", shape="cens"),
geom_point(aes_string(x="var", y="yhat", color="event", shape="event"),
...)
}else{
gg_plt[[ind]] <- gg_plt[[ind]] +
geom_smooth(aes_string(x="var", y="yhat"), ...)
}
if(smooth){
gg_plt[[ind]] <- gg_plt[[ind]] +
geom_smooth(...)
geom_smooth(aes_string(x="var", y="yhat"), ...)
}
}else{

# cat("3")
gg_plt[[ind]] <- gg_plt[[ind]] +
geom_boxplot(aes_string(x="var", y="yhat"), color="black",
..., outlier.shape = NA)+
geom_jitter(aes_string(x="var", y="yhat", color="cens", shape="cens"),
geom_jitter(aes_string(x="var", y="yhat", color="event", shape="event"),
...)

}
Expand Down Expand Up @@ -427,7 +427,7 @@ plot.gg_variable <- function(x, xvar,
if(smooth){
gg_plt[[ind]] <- gg_plt[[ind]] +
geom_smooth(...)
}
}
}else{

# cat("6")
Expand Down Expand Up @@ -488,4 +488,4 @@ plot.gg_variable <- function(x, xvar,
if(lng == 1) gg_plt <- gg_plt[[1]]
}
return(gg_plt)
}
}
7 changes: 3 additions & 4 deletions cran-comments.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
This is ggRandomForests package submission v2.0.0
This is ggRandomForests package submission v2.0.1
--------------------------------------------------------------------------------
* Added initial support for the randomForest package
* Updated cache files for randomForestSRC 2.2.0 release.
* Remove regression vignettes to meet CRAN size limts. These remain available at the package source https://github.com/ehrlinger/ggRandomForests
* Correct a bug in survival plots when predicting on future data without a known outcome.
* Additional Vignettes are now at https://github.com/ehrlinger/ggRFVignette
* Minor bug and documentation fixes.
Binary file added vignettes/fig-rfs/rfs-albumin-bili-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-albumin-coplot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-bili-albumin-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-bili-coplot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-categoricalEDA-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-continuousEDA-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-coplot_bilirubin-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-depthVimp-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-errorPlot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-gg_survival-bili-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-mindepth-plot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-pbc-partial-edema-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-pbc-partial-panel-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-plot_gg_cum_hazard-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-plot_gg_survival-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-predictPlot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-rf-vimp-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-rfsrc-mean2-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-rfsrc-plot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-rfsrc-plot3Mnth-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-surface3d-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-timeSurface3d-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-var_dep-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-variable-plot-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-variable-plotCat-1.pdf
Binary file not shown.
Binary file added vignettes/fig-rfs/rfs-variable-plotbili-1.pdf
Binary file not shown.
Binary file added vignettes/randomForestSRC-Survival.pdf
Binary file not shown.

0 comments on commit e13a47c

Please sign in to comment.