diff --git a/pkg/R/sparkR.R b/pkg/R/sparkR.R index 8a3fca68713e9..a623d84812d17 100644 --- a/pkg/R/sparkR.R +++ b/pkg/R/sparkR.R @@ -88,9 +88,7 @@ sparkR.init <- function( sparkEnvir = list(), sparkExecutorEnv = list(), sparkJars = "", - sparkRLibDir = "", - sparkRBackendPort = as.integer(Sys.getenv("SPARKR_BACKEND_PORT", "12345")), - sparkRRetryCount = 6) { + sparkRLibDir = "") { if (exists(".sparkRjsc", envir = .sparkREnv)) { cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n") @@ -121,45 +119,49 @@ sparkR.init <- function( if (sparkRExistingPort != "") { sparkRBackendPort <- sparkRExistingPort } else { + path <- tempfile(pattern = "backend_port") if (Sys.getenv("SPARKR_USE_SPARK_SUBMIT", "") == "") { launchBackend(classPath = cp, mainClass = "edu.berkeley.cs.amplab.sparkr.SparkRBackend", - args = as.character(sparkRBackendPort), + args = path, javaOpts = paste("-Xmx", sparkMem, sep = "")) } else { # TODO: We should deprecate sparkJars and ask users to add it to the # command line (using --jars) which is picked up by SparkSubmit launchBackendSparkSubmit( mainClass = "edu.berkeley.cs.amplab.sparkr.SparkRBackend", - args = as.character(sparkRBackendPort), + args = path, appJar = .sparkREnv$assemblyJarPath, sparkHome = sparkHome, sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "")) } + # wait at most 100 seconds for JVM to launch + wait <- 0.1 + for (i in 1:25) { + Sys.sleep(wait) + if (file.exists(path)) { + break + } + wait <- wait * 1.25 + } + if (!file.exists(path)) { + stop("JVM is not ready after 100 seconds") + } + f <- file(path, open='rb') + sparkRBackendPort <- readInt(f) + close(f) + file.remove(path) + if (length(sparkRBackendPort) == 0) { + stop("JVM failed to launch") + } } .sparkREnv$sparkRBackendPort <- sparkRBackendPort - cat("Waiting for JVM to come up...\n") - tries <- 0 - while (tries < sparkRRetryCount) { - if (!connExists(.sparkREnv)) { - Sys.sleep(2 ^ tries) - tryCatch({ - 
connectBackend("localhost", .sparkREnv$sparkRBackendPort) - }, error = function(err) { - cat("Error in Connection, retrying...\n") - }, warning = function(war) { - cat("No Connection Found, retrying...\n") - }) - tries <- tries + 1 - } else { - cat("Connection ok.\n") - break - } - } - if (tries == sparkRRetryCount) { - stop(sprintf("Failed to connect JVM after %d tries.\n", sparkRRetryCount)) - } + tryCatch({ + connectBackend("localhost", sparkRBackendPort) + }, error = function(err) { + stop("Failed to connect JVM\n") + }) if (nchar(sparkHome) != 0) { sparkHome <- normalizePath(sparkHome) diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackend.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackend.scala index e6c606a519c3f..89b2e7ada6b4c 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackend.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRBackend.scala @@ -1,7 +1,7 @@ package edu.berkeley.cs.amplab.sparkr -import java.io.IOException -import java.net.InetSocketAddress +import java.io.{File, FileOutputStream, DataOutputStream, IOException} +import java.net.{InetSocketAddress, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -24,9 +24,7 @@ class SparkRBackend { var bootstrap: ServerBootstrap = null var bossGroup: EventLoopGroup = null - def init(port: Int) { - val socketAddr = new InetSocketAddress(port) - + def init(): Int = { bossGroup = new NioEventLoopGroup(SparkRConf.numServerThreads) val workerGroup = bossGroup val handler = new SparkRBackendHandler(this) @@ -51,9 +49,9 @@ class SparkRBackend { } }) - channelFuture = bootstrap.bind(socketAddr) + channelFuture = bootstrap.bind(new InetSocketAddress(0)) channelFuture.syncUninterruptibly() - println("SparkR Backend server started on port :" + port) + channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() } def run() = { @@ -80,18 +78,26 @@ class 
SparkRBackend { object SparkRBackend { def main(args: Array[String]) { if (args.length < 1) { - System.err.println("Usage: SparkRBackend <port>") + System.err.println("Usage: SparkRBackend <tempFilePath>") System.exit(-1) } val sparkRBackend = new SparkRBackend() try { - sparkRBackend.init(args(0).toInt) + // bind to random port + val boundPort = sparkRBackend.init() + // tell the R process via temporary file + val path = args(0) + val f = new File(path + ".tmp") + val dos = new DataOutputStream(new FileOutputStream(f)) + dos.writeInt(boundPort) + dos.close() + f.renameTo(new File(path)) sparkRBackend.run() } catch { case e: IOException => System.err.println("Server shutting down: failed with exception ", e) sparkRBackend.close() - System.exit(0) + System.exit(1) } System.exit(0) } diff --git a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRRunner.scala b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRRunner.scala index fb356f89b2d1d..1b27f78844551 100644 --- a/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRRunner.scala +++ b/pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/SparkRRunner.scala @@ -22,8 +22,6 @@ object SparkRRunner { // Time to wait for SparkR backend to initialize in seconds val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt - // TODO: Can we get this from SparkConf ? - val sparkRBackendPort = sys.env.getOrElse("SPARKR_BACKEND_PORT", "12345").toInt val rCommand = "Rscript" // Check if the file path exists. @@ -39,23 +37,19 @@ object SparkRRunner { // Launch a SparkR backend server for the R process to connect to; this will let it see our // Java system properties etc. 
val sparkRBackend = new SparkRBackend() - val sparkRBackendThread = new Thread() { - val finishedInit = new Semaphore(0) - + @volatile var sparkRBackendPort = 0 + val initialized = new Semaphore(0) + val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackend.init(sparkRBackendPort) - finishedInit.release() + sparkRBackendPort = sparkRBackend.init() + initialized.release() sparkRBackend.run() } - - def stopBackend() { - sparkRBackend.close() - } } sparkRBackendThread.start() // Wait for SparkRBackend initialization to finish - if (sparkRBackendThread.finishedInit.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { + if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) { // Launch R val returnCode = try { val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs) @@ -68,7 +62,7 @@ object SparkRRunner { process.waitFor() } finally { - sparkRBackendThread.stopBackend() + sparkRBackend.close() } System.exit(returnCode) } else {