Skip to content

Commit

Permalink
cross-validation (cv) now correctly uses mc.cores argument for parall…
Browse files Browse the repository at this point in the history
…el computing on suppported platforms.
  • Loading branch information
kkholst committed Sep 25, 2017
1 parent 0969868 commit 517b5dc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
13 changes: 7 additions & 6 deletions R/cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ rmse1 <- function(fit,data,response=NULL,...) {
##' f0 <- function(data,...) lm(...,data)
##' f1 <- function(data,...) lm(Sepal.Length~Species,data)
##' f2 <- function(data,...) lm(Sepal.Length~Species+Petal.Length,data)
##' x <- cv(list(model0=f0,model1=f1,model2=f2),rep=10, data=iris, formula=Sepal.Length~.)
##' x <- cv(list(m0=f0,m1=f1,m2=f2),rep=10, data=iris, formula=Sepal.Length~.)
##' x2 <- cv(list(f0(iris),f1(iris),f2(iris)),rep=10, data=iris)
##' @export
cv <- function(modelList, data, K=5, rep=1, perf, seed=NULL, mc.cores=1, ...) {
if (missing(perf)) perf <- rmse1
if (!is.list(modelList)) modelList <- list(modelList)
nam <- names(modelList)
if (is.null(nam)) nam <- rep("",length(modelList))
if (is.null(nam)) nam <- paste0("model",seq_along(modelList))
args <- list(...)
## Models run on full data:
if (is.function(modelList[[1]])) {
Expand All @@ -43,18 +43,19 @@ cv <- function(modelList, data, K=5, rep=1, perf, seed=NULL, mc.cores=1, ...) {
names(fit0) <- names(perf0) <- nam
n <- nrow(data)
M <- length(perf0) # Number of models
P <- length(perf0[[1]]) # Number of performance measures
P <- length(perf0[[1]]) # Number of performance measures
if (!is.null(seed)) {
if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE))
runif(1)
R.seed <- get(".Random.seed", envir = .GlobalEnv)
set.seed(seed)
RNGstate <- structure(seed, kind = as.list(RNGkind()))
on.exit(assign(".Random.seed", R.seed, envir = .GlobalEnv))
}

nam <- list(NULL,NULL,nam,namPerf)
dim <- c(rep,K,M,P)
PerfArr <- array(0,dim)
dimnames(PerfArr) <- nam
dimnames(PerfArr) <- list(NULL,NULL,nam,namPerf)
folds <- foldr(n,K,rep)
arg <- expand.grid(R=seq(rep),K=seq(K)) #,M=seq_along(modelList))

Expand All @@ -73,7 +74,7 @@ cv <- function(modelList, data, K=5, rep=1, perf, seed=NULL, mc.cores=1, ...) {
do.call(rbind,perfs)
}
if (mc.cores>1) {
val <- parallel::mcmapply(ff,seq(nrow(arg)),SIMPLIFY=FALSE)
val <- parallel::mcmapply(ff,seq(nrow(arg)),SIMPLIFY=FALSE,mc.cores=mc.cores)
} else {
val <- mapply(ff,seq(nrow(arg)),SIMPLIFY=FALSE)
}
Expand Down
2 changes: 1 addition & 1 deletion man/cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 517b5dc

Please sign in to comment.