From f7a25644ed5b3b49fe7f33743bec3d95cdf7913e Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 17 Apr 2015 11:02:31 +0100 Subject: [PATCH 01/33] SPARK-6846 [WEBUI] Stage kill URL easy to accidentally trigger and possibility for security issue kill endpoints now only accept a POST (kill stage, master kill app, master kill driver); kill link now POSTs Author: Sean Owen Closes #5528 from srowen/SPARK-6846 and squashes the following commits: 137ac9f [Sean Owen] Oops, fix scalastyle line length probelm 7c5f961 [Sean Owen] Add Imran's test of kill link 59f447d [Sean Owen] kill endpoints now only accept a POST (kill stage, master kill app, master kill driver); kill link now POSTs --- .../org/apache/spark/ui/static/webui.css | 6 +-- .../spark/deploy/master/ui/MasterPage.scala | 28 +++++++------ .../spark/deploy/master/ui/MasterWebUI.scala | 8 ++-- .../org/apache/spark/ui/JettyUtils.scala | 17 +++++++- .../scala/org/apache/spark/ui/SparkUI.scala | 4 +- .../org/apache/spark/ui/jobs/StageTable.scala | 27 ++++++------- .../org/apache/spark/ui/UISeleniumSuite.scala | 40 +++++++++++++------ 7 files changed, 78 insertions(+), 52 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 6c37cc8b98236..4910744d1d790 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -85,17 +85,13 @@ table.sortable td { filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } -span.kill-link { +a.kill-link { margin-right: 2px; margin-left: 20px; color: gray; float: right; } -span.kill-link a { - color: gray; -} - span.expand-details { font-size: 10pt; cursor: pointer; diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 399f07399a0aa..1f2c3fdbfb2bc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -190,12 +190,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def appRow(app: ApplicationInfo): Seq[Node] = { val killLink = if (parent.killEnabled && (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) { - val killLinkUri = s"app/kill?id=${app.id}&terminate=true" - val confirm = "return window.confirm(" + - s"'Are you sure you want to kill application ${app.id} ?');" - - (kill) - + val confirm = + s"if (window.confirm('Are you sure you want to kill application ${app.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
+ + + (kill) +
} @@ -223,12 +225,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { (driver.state == DriverState.RUNNING || driver.state == DriverState.SUBMITTED || driver.state == DriverState.RELAUNCHING)) { - val killLinkUri = s"driver/kill?id=${driver.id}&terminate=true" - val confirm = "return window.confirm(" + - s"'Are you sure you want to kill driver ${driver.id} ?');" - - (kill) - + val confirm = + s"if (window.confirm('Are you sure you want to kill driver ${driver.id} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
+ + + (kill) +
} {driver.id} {killLink} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 1b670418ab1ff..bb11e0642ddc6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -43,10 +43,10 @@ class MasterWebUI(val master: Master, requestedPort: Int) attachPage(new HistoryNotFoundPage(this)) attachPage(masterPage) attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler( - createRedirectHandler("/app/kill", "/", masterPage.handleAppKillRequest)) - attachHandler( - createRedirectHandler("/driver/kill", "/", masterPage.handleDriverKillRequest)) + attachHandler(createRedirectHandler( + "/app/kill", "/", masterPage.handleAppKillRequest, httpMethod = "POST")) + attachHandler(createRedirectHandler( + "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethod = "POST")) } /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 95f254a9ef22a..a091ca650c60c 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -114,10 +114,23 @@ private[spark] object JettyUtils extends Logging { srcPath: String, destPath: String, beforeRedirect: HttpServletRequest => Unit = x => (), - basePath: String = ""): ServletContextHandler = { + basePath: String = "", + httpMethod: String = "GET"): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { - override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "GET" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = { + httpMethod match { + case "POST" => doRequest(request, response) + case _ => response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + } + } + private def doRequest(request: HttpServletRequest, response: HttpServletResponse): Unit = { beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index adfa6bbada256..580ab8b1325f8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,8 +55,8 @@ private[spark] class SparkUI private ( attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) - attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + attachHandler(createRedirectHandler( + "/stages/stage/kill", "/stages", stagesTab.handleKillRequest, httpMethod = "POST")) } initialize() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 5865850fa09b5..cb72890a0fd20 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -73,20 +73,21 @@ private[ui] class StageTableBase( } private def makeDescription(s: StageInfo): Seq[Node] = { - // scalastyle:off + val basePathUri = UIUtils.prependBaseUri(basePath) + val killLink = if (killEnabled) { - val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" - .format(UIUtils.prependBaseUri(basePath), s.stageId) - val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" - .format(s.stageId) - - (kill) - + val killLinkUri = s"$basePathUri/stages/stage/kill/" + val confirm = + s"if (window.confirm('Are you sure you want to kill stage ${s.stageId} ?')) " + + "{ this.parentNode.submit(); return true; } else { return false; }" +
+ + + (kill) +
} - // scalastyle:on - val nameLinkUri ="%s/stages/stage?id=%s&attempt=%s" - .format(UIUtils.prependBaseUri(basePath), s.stageId, s.attemptId) + val nameLinkUri = s"$basePathUri/stages/stage?id=${s.stageId}&attempt=${s.attemptId}" val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) @@ -98,11 +99,9 @@ private[ui] class StageTableBase( diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 1cb594633f331..eb9db550fd74c 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui +import java.net.{HttpURLConnection, URL} import javax.servlet.http.HttpServletRequest import scala.collection.JavaConversions._ @@ -56,12 +57,13 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before * Create a test SparkContext with the SparkUI enabled. * It is safe to `get` the SparkUI directly from the SparkContext returned here. */ - private def newSparkContext(): SparkContext = { + private def newSparkContext(killEnabled: Boolean = true): SparkContext = { val conf = new SparkConf() .setMaster("local") .setAppName("test") .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") + .set("spark.ui.killEnabled", killEnabled.toString) val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -128,21 +130,12 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } test("spark.ui.killEnabled should properly control kill button display") { - def getSparkContext(killEnabled: Boolean): SparkContext = { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.ui.enabled", "true") - .set("spark.ui.killEnabled", killEnabled.toString) - new SparkContext(conf) - } - def hasKillLink: Boolean = find(className("kill-link")).isDefined def runSlowJob(sc: SparkContext) { sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() } - withSpark(getSparkContext(killEnabled = true)) { sc => + withSpark(newSparkContext(killEnabled = true)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -150,7 +143,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } - withSpark(getSparkContext(killEnabled = false)) { sc => + withSpark(newSparkContext(killEnabled = false)) { sc => runSlowJob(sc) eventually(timeout(5 seconds), interval(50 milliseconds)) { go to (sc.ui.get.appUIAddress.stripSuffix("/") + "/stages") @@ -233,7 +226,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before // because someone could change the error message and cause this test to pass by accident. // Instead, it's safer to check that each row contains a link to a stage details page. findAll(cssSelector("tbody tr")).foreach { row => - val link = row.underlying.findElement(By.xpath(".//a")) + val link = row.underlying.findElement(By.xpath("./td/div/a")) link.getAttribute("href") should include ("stage") } } @@ -356,4 +349,25 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers with Before } } } + + test("kill stage is POST only") { + def getResponseCode(url: URL, method: String): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + connection.connect() + val code = connection.getResponseCode() + connection.disconnect() + code + } + + withSpark(newSparkContext(killEnabled = true)) { sc => + sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() + eventually(timeout(5 seconds), interval(50 milliseconds)) { + val url = new URL( + sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0&terminate=true") + getResponseCode(url, "GET") should be (405) + getResponseCode(url, "POST") should be (200) + } + } + } } From 4527761bcd6501c362baf2780905a0018b9a74ba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 11:06:01 +0100 Subject: [PATCH 02/33] [SPARK-6046] [core] Reorganize deprecated config support in SparkConf. This change tries to follow the chosen way for handling deprecated configs in SparkConf: all values (old and new) are kept in the conf object, and newer names take precedence over older ones when retrieving the value. Warnings are logged when config options are set, which generally happens on the driver node (where the logs are most visible). Author: Marcelo Vanzin Closes #5514 from vanzin/SPARK-6046 and squashes the following commits: 9371529 [Marcelo Vanzin] Avoid math. 6cf3f11 [Marcelo Vanzin] Review feedback. 2445d48 [Marcelo Vanzin] Fix (and cleanup) update interval initialization. b6824be [Marcelo Vanzin] Clean up the other deprecated config use also. ab20351 [Marcelo Vanzin] Update FsHistoryProvider to only retrieve new config key. 2c93209 [Marcelo Vanzin] [SPARK-6046] [core] Reorganize deprecated config support in SparkConf. --- .../scala/org/apache/spark/SparkConf.scala | 174 ++++++++++-------- .../deploy/history/FsHistoryProvider.scala | 9 +- .../org/apache/spark/executor/Executor.scala | 5 +- .../org/apache/spark/SparkConfSuite.scala | 22 +++ docs/monitoring.md | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 3 +- 6 files changed, 124 insertions(+), 95 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 390e631647bd6..b0186e9a007b8 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -68,6 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { if (value == null) { throw new NullPointerException("null value for " + key) } + logDeprecationWarning(key) settings.put(key, value) this } @@ -134,13 +135,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Set multiple parameters together */ def setAll(settings: Traversable[(String, String)]): SparkConf = { - this.settings.putAll(settings.toMap.asJava) + settings.foreach { case (k, v) => set(k, v) } this } /** Set a parameter if it isn't already configured */ def setIfMissing(key: String, value: String): SparkConf = { - settings.putIfAbsent(key, value) + if (settings.putIfAbsent(key, value) == null) { + logDeprecationWarning(key) + } this } @@ -174,8 +177,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { getOption(key).getOrElse(defaultValue) } - /** - * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no + /** + * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. * @throws NoSuchElementException */ @@ -183,36 +186,35 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsSeconds(get(key)) } - /** - * Get a time parameter as seconds, falling back to a default if not set. If no + /** + * Get a time parameter as seconds, falling back to a default if not set. If no * suffix is provided then seconds are assumed. - * */ def getTimeAsSeconds(key: String, defaultValue: String): Long = { Utils.timeStringAsSeconds(get(key, defaultValue)) } - /** - * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no + * suffix is provided then milliseconds are assumed. * @throws NoSuchElementException */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) } - /** - * Get a time parameter as milliseconds, falling back to a default if not set. If no - * suffix is provided then milliseconds are assumed. + /** + * Get a time parameter as milliseconds, falling back to a default if not set. If no + * suffix is provided then milliseconds are assumed. */ def getTimeAsMs(key: String, defaultValue: String): Long = { Utils.timeStringAsMs(get(key, defaultValue)) } - + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { - Option(settings.get(key)) + Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) } /** Get all parameters as a list of pairs */ @@ -379,13 +381,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { } } } - - // Warn against the use of deprecated configs - deprecatedConfigs.values.foreach { dc => - if (contains(dc.oldName)) { - dc.warn() - } - } } /** @@ -400,19 +395,44 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { private[spark] object SparkConf extends Logging { + /** + * Maps deprecated config keys to information about the deprecation. + * + * The extra information is logged as a warning when the config is present in the user's + * configuration. + */ private val deprecatedConfigs: Map[String, DeprecatedConfig] = { val configs = Seq( - DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst", - "1.3"), - DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3", - "Use spark.{driver,executor}.userClassPathFirst instead."), - DeprecatedConfig("spark.history.fs.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead"), - DeprecatedConfig("spark.history.updateInterval", - "spark.history.fs.update.interval.seconds", - "1.3", "Use spark.history.fs.update.interval.seconds instead")) - configs.map { x => (x.oldName, x) }.toMap + DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", + "Please use spark.{driver,executor}.userClassPathFirst instead.")) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) + } + + /** + * Maps a current config key to alternate keys that were used in previous version of Spark. + * + * The alternates are used in the order defined in this map. If deprecated configs are + * present in the user's configuration, a warning is logged. + */ + private val configsWithAlternatives = Map[String, Seq[AlternateConfig]]( + "spark.executor.userClassPathFirst" -> Seq( + AlternateConfig("spark.files.userClassPathFirst", "1.3")), + "spark.history.fs.update.interval" -> Seq( + AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), + AlternateConfig("spark.history.fs.updateInterval", "1.3"), + AlternateConfig("spark.history.updateInterval", "1.3")) + ) + + /** + * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated + * config keys. + * + * Maps the deprecated config name to a 2-tuple (new config name, alternate config info). + */ + private val allAlternatives: Map[String, (String, AlternateConfig)] = { + configsWithAlternatives.keys.flatMap { key => + configsWithAlternatives(key).map { cfg => (cfg.key -> (key -> cfg)) } + }.toMap } /** @@ -443,61 +463,57 @@ private[spark] object SparkConf extends Logging { } /** - * Translate the configuration key if it is deprecated and has a replacement, otherwise just - * returns the provided key. - * - * @param userKey Configuration key from the user / caller. - * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed - * only once for each key. + * Looks for available deprecated keys for the given config option, and return the first + * value available. */ - private def translateConfKey(userKey: String, warn: Boolean = false): String = { - deprecatedConfigs.get(userKey) - .map { deprecatedKey => - if (warn) { - deprecatedKey.warn() - } - deprecatedKey.newName.getOrElse(userKey) - }.getOrElse(userKey) + def getDeprecatedConfig(key: String, conf: SparkConf): Option[String] = { + configsWithAlternatives.get(key).flatMap { alts => + alts.collectFirst { case alt if conf.contains(alt.key) => + val value = conf.get(alt.key) + alt.translation.map(_(value)).getOrElse(value) + } + } } /** - * Holds information about keys that have been deprecated or renamed. + * Logs a warning message if the given config key is deprecated. + */ + def logDeprecationWarning(key: String): Unit = { + deprecatedConfigs.get(key).foreach { cfg => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"may be removed in the future. ${cfg.deprecationMessage}") + } + + allAlternatives.get(key).foreach { case (newKey, cfg) => + logWarning( + s"The configuration key '$key' has been deprecated as of Spark ${cfg.version} and " + + s"and may be removed in the future. Please use the new key '$newKey' instead.") + } + } + + /** + * Holds information about keys that have been deprecated and do not have a replacement. * - * @param oldName Old configuration key. - * @param newName New configuration key, or `null` if key has no replacement, in which case the - * deprecated key will be used (but the warning message will still be printed). + * @param key The deprecated key. * @param version Version of Spark where key was deprecated. - * @param deprecationMessage Message to include in the deprecation warning; mandatory when - * `newName` is not provided. + * @param deprecationMessage Message to include in the deprecation warning. */ private case class DeprecatedConfig( - oldName: String, - _newName: String, + key: String, version: String, - deprecationMessage: String = null) { - - private val warned = new AtomicBoolean(false) - val newName = Option(_newName) + deprecationMessage: String) - if (newName == null && (deprecationMessage == null || deprecationMessage.isEmpty())) { - throw new IllegalArgumentException("Need new config name or deprecation message.") - } - - def warn(): Unit = { - if (warned.compareAndSet(false, true)) { - if (newName != null) { - val message = Option(deprecationMessage).getOrElse( - s"Please use the alternative '$newName' instead.") - logWarning( - s"The configuration option '$oldName' has been replaced as of Spark $version and " + - s"may be removed in the future. $message") - } else { - logWarning( - s"The configuration option '$oldName' has been deprecated as of Spark $version and " + - s"may be removed in the future. $deprecationMessage") - } - } - } + /** + * Information about an alternate configuration key that has been deprecated. + * + * @param key The deprecated config key. + * @param version The Spark version in which the key was deprecated. + * @param translation A translation function for converting old config values into new ones. + */ + private case class AlternateConfig( + key: String, + version: String, + translation: Option[String => String] = None) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9d40d8c8fd7a8..985545742df67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -49,11 +49,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val NOT_STARTED = "" // Interval between each check for event log updates - private val UPDATE_INTERVAL_MS = conf.getOption("spark.history.fs.update.interval.seconds") - .orElse(conf.getOption("spark.history.fs.updateInterval")) - .orElse(conf.getOption("spark.history.updateInterval")) - .map(_.toInt) - .getOrElse(10) * 1000 + private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds", @@ -130,8 +126,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis // Disable the background thread during tests. if (!conf.contains("spark.testing")) { // A task that periodically checks for event log updates on disk. - pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_MS, - TimeUnit.MILLISECONDS) + pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_S, TimeUnit.SECONDS) if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 1b5fdeba28ee2..327d155b38c22 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -89,10 +89,7 @@ private[spark] class Executor( ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME, new ExecutorEndpoint(env.rpcEnv, executorId)) // Whether to load classes in user jars before those in Spark jars - private val userClassPathFirst: Boolean = { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) - } + private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index e08210ae60d17..7d87ba5fd2610 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -197,6 +197,28 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro serializer.newInstance().serialize(new StringBuffer()) } + test("deprecated configs") { + val conf = new SparkConf() + val newName = "spark.history.fs.update.interval" + + assert(!conf.contains(newName)) + + conf.set("spark.history.updateInterval", "1") + assert(conf.get(newName) === "1") + + conf.set("spark.history.fs.updateInterval", "2") + assert(conf.get(newName) === "2") + + conf.set("spark.history.fs.update.interval.seconds", "3") + assert(conf.get(newName) === "3") + + conf.set(newName, "4") + assert(conf.get(newName) === "4") + + val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size + assert(count === 4) + } + } class Class1 {} diff --git a/docs/monitoring.md b/docs/monitoring.md index 6816671ffbf46..2a130224591ca 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -86,10 +86,10 @@ follows: - spark.history.fs.update.interval.seconds - 10 + spark.history.fs.update.interval + 10s - The period, in seconds, at which information displayed by this history server is updated. + The period at which information displayed by this history server is updated. Each update checks for any changes made to the event logs in persisted storage. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1091ff54b0463..52e4dee46c535 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1052,8 +1052,7 @@ object Client extends Logging { if (isDriver) { conf.getBoolean("spark.driver.userClassPathFirst", false) } else { - conf.getBoolean("spark.executor.userClassPathFirst", - conf.getBoolean("spark.files.userClassPathFirst", false)) + conf.getBoolean("spark.executor.userClassPathFirst", false) } } From f6a9a57a72767f48fcc02e5fda4d6eafa67aebde Mon Sep 17 00:00:00 2001 From: Punya Biswal Date: Fri, 17 Apr 2015 11:08:37 +0100 Subject: [PATCH 03/33] [SPARK-6952] Handle long args when detecting PID reuse sbin/spark-daemon.sh used ps -p "$TARGET_PID" -o args= to figure out whether the process running with the expected PID is actually a Spark daemon. When running with a large classpath, the output of ps gets truncated and the check fails spuriously. This weakens the check to see if it's a java command (which is something we do in other parts of the script) rather than looking for the specific main class name. This means that SPARK-4832 might happen under a slightly broader range of circumstances (a java program happened to reuse the same PID), but it seems worthwhile compared to failing consistently with a large classpath. Author: Punya Biswal Closes #5535 from punya/feature/SPARK-6952 and squashes the following commits: 7ea12d1 [Punya Biswal] Handle long args when detecting PID reuse --- sbin/spark-daemon.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index d8e0facb81169..de762acc8fa0e 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -129,7 +129,7 @@ run_command() { if [ -f "$pid" ]; then TARGET_ID="$(cat "$pid")" - if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then echo "$command running as process $TARGET_ID. Stop it first." exit 1 fi @@ -163,7 +163,7 @@ run_command() { echo "$newpid" > "$pid" sleep 2 # Check if the process has died; in that case we'll tail the log so the user can see - if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then + if [[ ! $(ps -p "$newpid" -o comm=) =~ "java" ]]; then echo "failed to launch $command:" tail -2 "$log" | sed 's/^/ /' echo "full log in $log" From dc48ba9f9f7449dd2f12cbad288b65c8119d9284 Mon Sep 17 00:00:00 2001 From: linweizhong Date: Fri, 17 Apr 2015 12:04:02 +0100 Subject: [PATCH 04/33] [SPARK-6604][PySpark]Specify ip of python server scoket In driver now will start a server socket and use a wildcard ip, use 127.0.0.0 is more reasonable, as we only use it by local Python process. /cc davies Author: linweizhong Closes #5256 from Sephiroth-Lin/SPARK-6604 and squashes the following commits: 7b3c633 [linweizhong] rephrase --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b1ffba4c546bf..7409dc2d866f6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -604,7 +604,7 @@ private[spark] object PythonRDD extends Logging { * The thread will terminate after all the data are sent or any exceptions happen. */ private def serveIterator[T](items: Iterator[T], threadName: String): Int = { - val serverSocket = new ServerSocket(0, 1) + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) From c84d91692aa25c01882bcc3f9fd5de3cfa786195 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 17 Apr 2015 11:29:27 -0500 Subject: [PATCH 05/33] [SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas ``` select(['cola', 'colb']) groupby(['colA', 'colB']) groupby([df.colA, df.colB]) df.sort('A', ascending=True) df.sort(['A', 'B'], ascending=True) df.sort(['A', 'B'], ascending=[1, 0]) ``` cc rxin Author: Davies Liu Closes #5544 from davies/compatibility and squashes the following commits: 4944058 [Davies Liu] add docstrings adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility bcbbcab [Davies Liu] support ascending as list 8dabdf0 [Davies Liu] improve API compatibility to pandas --- python/pyspark/sql/dataframe.py | 96 ++++++++++++++++++++++----------- python/pyspark/sql/functions.py | 11 ++-- python/pyspark/sql/tests.py | 2 +- 3 files changed, 70 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b9a3e6cfe7f49..326d22e72f104 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -485,13 +485,17 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix - def sort(self, *cols): + def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). - :param cols: list of :class:`Column` to sort by. + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: sort by ascending order or not, could be bool, int + or list of bool, int (default: True). >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.sort("age", ascending=False).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> df.orderBy(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] >>> from pyspark.sql.functions import * @@ -499,16 +503,42 @@ def sort(self, *cols): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.orderBy(desc("age"), "name").collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: raise ValueError("should sort by at least one column") - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) + + jdf = self._jdf.sort(self._jseq(jcols)) return DataFrame(jdf, self.sql_ctx) orderBy = sort + def _jseq(self, cols, converter=None): + """Return a JVM Seq of Columns from a list of Column or names""" + return _to_seq(self.sql_ctx._sc, cols, converter) + + def _jcols(self, *cols): + """Return a JVM Seq of Columns from a list of Column or column names + + If `cols` has only one list in it, cols[0] will be used as the list. + """ + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + return self._jseq(cols, _to_java_column) + def describe(self, *cols): """Computes statistics for numeric columns. @@ -523,9 +553,7 @@ def describe(self, *cols): min 2 max 5 """ - cols = ListConverter().convert(cols, - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)) + jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -607,9 +635,7 @@ def select(self, *cols): >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) def selectExpr(self, *expr): @@ -620,8 +646,9 @@ def selectExpr(self, *expr): >>> df.selectExpr("age * 2", "abs(age)").collect() [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)] """ - jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) - jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) + if len(expr) == 1 and isinstance(expr[0], list): + expr = expr[0] + jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx) @ignore_unicode_prefix @@ -659,6 +686,8 @@ def groupBy(self, *cols): so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. + :func:`groupby` is an alias for :func:`groupBy`. + :param cols: list of columns to group by. Each element should be a column name (string) or an expression (:class:`Column`). @@ -668,12 +697,14 @@ def groupBy(self, *cols): [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] >>> df.groupBy(df.name).avg().collect() [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + >>> df.groupBy(['name', df.age]).count().collect() + [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.groupBy(self._jcols(*cols)) return GroupedData(jdf, self.sql_ctx) + groupby = groupBy + def agg(self, *exprs): """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy.agg()``). @@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None): if thresh is None: thresh = len(subset) if how == 'any' else 1 - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx) + return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) def fillna(self, value, subset=None): """Replace null values, alias for ``na.fill()``. @@ -799,9 +828,7 @@ def fillna(self, value, subset=None): elif not isinstance(subset, (list, tuple)): raise ValueError("subset should be a list or tuple of column names") - cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client) - cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols) - return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx) + return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) @ignore_unicode_prefix def withColumn(self, colName, col): @@ -862,10 +889,8 @@ def _api(self): def df_varargs_api(f): def _api(self, *args): - jargs = ListConverter().convert(args, - self.sql_ctx._sc._gateway._gateway_client) name = f.__name__ - jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) return DataFrame(jdf, self.sql_ctx) _api.__name__ = f.__name__ _api.__doc__ = f.__doc__ @@ -912,9 +937,8 @@ def agg(self, *exprs): else: # Columns assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jcols = ListConverter().convert([c._jc for c in exprs[1:]], - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) return DataFrame(jdf, self.sql_ctx) @dfapi @@ -1006,6 +1030,19 @@ def _to_java_column(col): return jcol +def _to_seq(sc, cols, converter=None): + """ + Convert a list of Column (or names) into a JVM Seq of Column. + + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + jcols = ListConverter().convert(cols, sc._gateway._gateway_client) + return sc._jvm.PythonUtils.toSeq(jcols) + + def _unary_op(name, doc="unary operator"): """ Create a method for given unary operator """ def _(self): @@ -1177,8 +1214,7 @@ def inSet(self, *cols): cols = cols[0] cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] sc = SparkContext._active_spark_context - jcols = ListConverter().convert(cols, sc._gateway._gateway_client) - jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols)) + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) return Column(jc) # order diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1d6536952810f..bb47923f24b82 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -23,13 +23,11 @@ if sys.version < "3": from itertools import imap as map -from py4j.java_collections import ListConverter - from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column +from pyspark.sql.dataframe import Column, _to_java_column, _to_seq __all__ = ['countDistinct', 'approxCountDistinct', 'udf'] @@ -87,8 +85,7 @@ def countDistinct(col, *cols): [Row(c=2)] """ sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) - jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols)) + jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -138,9 +135,7 @@ def __del__(self): def __call__(self, *cols): sc = SparkContext._active_spark_context - jcols = ListConverter().convert([_to_java_column(c) for c in cols], - sc._gateway._gateway_client) - jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols)) + jc = self._judf.apply(_to_seq(sc, cols, _to_java_column)) return Column(jc) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6691e8c8dc44b..aa3aa1d164d9f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -282,7 +282,7 @@ def test_apply_schema(self): StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) - df = self.sqlCtx.applySchema(rdd, schema) + df = self.sqlCtx.createDataFrame(rdd, schema) results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), From 50ab8a6543ad5c31e89c16df374d0cb13222fd1e Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 14:21:51 -0500 Subject: [PATCH 06/33] [SPARK-2669] [yarn] Distribute client configuration to AM. Currently, when Spark launches the Yarn AM, the process will use the local Hadoop configuration on the node where the AM launches, if one is present. A more correct approach is to use the same configuration used to launch the Spark job, since the user may have made modifications (such as adding app-specific configs). The approach taken here is to use the distributed cache to make all files in the Hadoop configuration directory available to the AM. This is a little overkill since only the AM needs them (the executors use the broadcast Hadoop configuration from the driver), but is the easier approach. Even though only a few files in that directory may end up being used, all of them are uploaded. This allows supporting use cases such as when auxiliary configuration files are used for SSL configuration, or when uploading a Hive configuration directory. Not all of these may be reflected in a o.a.h.conf.Configuration object, but may be needed when a driver in cluster mode instantiates, for example, a HiveConf object instead. Author: Marcelo Vanzin Closes #4142 from vanzin/SPARK-2669 and squashes the following commits: f5434b9 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 013f0fb [Marcelo Vanzin] Review feedback. f693152 [Marcelo Vanzin] Le sigh. ed45b7d [Marcelo Vanzin] Zip all config files and upload them as an archive. 5927b6b [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 cbb9fb3 [Marcelo Vanzin] Remove stale test. e3e58d0 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 e3d0613 [Marcelo Vanzin] Review feedback. 34bdbd8 [Marcelo Vanzin] Fix test. 022a688 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 a77ddd5 [Marcelo Vanzin] Merge branch 'master' into SPARK-2669 79221c7 [Marcelo Vanzin] [SPARK-2669] [yarn] Distribute client configuration to AM. --- docs/running-on-yarn.md | 6 +- .../org/apache/spark/deploy/yarn/Client.scala | 125 +++++++++++++++--- .../spark/deploy/yarn/ExecutorRunnable.scala | 2 +- .../spark/deploy/yarn/ClientSuite.scala | 29 ++-- .../spark/deploy/yarn/YarnClusterSuite.scala | 6 +- 5 files changed, 132 insertions(+), 36 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 853c9f26b0ec9..0968fc5ad632b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,7 +211,11 @@ Most of the configs are the same for Spark on YARN as for other deployment modes # Launching Spark on YARN Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. +These configs are used to write to the dfs and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 52e4dee46c535..019afbd1a1743 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,15 +17,18 @@ package org.apache.spark.deploy.yarn +import java.io.{File, FileOutputStream} import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException} import java.nio.ByteBuffer +import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} import scala.reflect.runtime.universe import scala.util.{Try, Success, Failure} import com.google.common.base.Objects +import com.google.common.io.Files import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.conf.Configuration @@ -77,12 +80,6 @@ private[spark] class Client( def stop(): Unit = yarnClient.stop() - /* ------------------------------------------------------------------------------------- * - | The following methods have much in common in the stable and alpha versions of Client, | - | but cannot be implemented in the parent trait due to subtle API differences across | - | hadoop versions. | - * ------------------------------------------------------------------------------------- */ - /** * Submit an application running our ApplicationMaster to the ResourceManager. * @@ -223,6 +220,10 @@ private[spark] class Client( val fs = FileSystem.get(hadoopConf) val dst = new Path(fs.getHomeDirectory(), appStagingDir) val nns = getNameNodesToAccess(sparkConf) + dst + // Used to keep track of URIs added to the distributed cache. If the same URI is added + // multiple times, YARN will fail to launch containers for the app with an internal + // error. + val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) @@ -241,6 +242,17 @@ private[spark] class Client( "for alternatives.") } + def addDistributedUri(uri: URI): Boolean = { + val uriStr = uri.toString() + if (distributedUris.contains(uriStr)) { + logWarning(s"Resource $uri added multiple times to distributed cache.") + false + } else { + distributedUris += uriStr + true + } + } + /** * Copy the given main resource to the distributed cache if the scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. @@ -258,11 +270,13 @@ private[spark] class Client( if (!localPath.isEmpty()) { val localURI = new URI(localPath) if (localURI.getScheme != LOCAL_SCHEME) { - val src = getQualifiedLocalPath(localURI, hadoopConf) - val destPath = copyFileToRemote(dst, src, replication) - val destFs = FileSystem.get(destPath.toUri(), hadoopConf) - distCacheMgr.addResource(destFs, hadoopConf, destPath, - localResources, LocalResourceType.FILE, destName, statCache) + if (addDistributedUri(localURI)) { + val src = getQualifiedLocalPath(localURI, hadoopConf) + val destPath = copyFileToRemote(dst, src, replication) + val destFs = FileSystem.get(destPath.toUri(), hadoopConf) + distCacheMgr.addResource(destFs, hadoopConf, destPath, + localResources, LocalResourceType.FILE, destName, statCache) + } } else if (confKey != null) { // If the resource is intended for local use only, handle this downstream // by setting the appropriate property @@ -271,6 +285,13 @@ private[spark] class Client( } } + createConfArchive().foreach { file => + require(addDistributedUri(file.toURI())) + val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication) + distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE, + LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true) + } + /** * Do the same for any additional resources passed in through ClientArguments. * Each resource category is represented by a 3-tuple of: @@ -288,13 +309,15 @@ private[spark] class Client( flist.split(',').foreach { file => val localURI = new URI(file.trim()) if (localURI.getScheme != LOCAL_SCHEME) { - val localPath = new Path(localURI) - val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) - val destPath = copyFileToRemote(dst, localPath, replication) - distCacheMgr.addResource( - fs, hadoopConf, destPath, localResources, resType, linkname, statCache) - if (addToClasspath) { - cachedSecondaryJarLinks += linkname + if (addDistributedUri(localURI)) { + val localPath = new Path(localURI) + val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) + val destPath = copyFileToRemote(dst, localPath, replication) + distCacheMgr.addResource( + fs, hadoopConf, destPath, localResources, resType, linkname, statCache) + if (addToClasspath) { + cachedSecondaryJarLinks += linkname + } } } else if (addToClasspath) { // Resource is intended for local use only and should be added to the class path @@ -310,6 +333,57 @@ private[spark] class Client( localResources } + /** + * Create an archive with the Hadoop config files for distribution. + * + * These are only used by the AM, since executors will use the configuration object broadcast by + * the driver. The files are zipped and added to the job as an archive, so that YARN will explode + * it when distributing to the AM. This directory is then added to the classpath of the AM + * process, just to make sure that everybody is using the same default config. + * + * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR + * shows up in the classpath before YARN_CONF_DIR. + * + * Currently this makes a shallow copy of the conf directory. If there are cases where a + * Hadoop config directory contains subdirectories, this code will have to be fixed. + */ + private def createConfArchive(): Option[File] = { + val hadoopConfFiles = new HashMap[String, File]() + Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => + sys.env.get(envKey).foreach { path => + val dir = new File(path) + if (dir.isDirectory()) { + dir.listFiles().foreach { file => + if (!hadoopConfFiles.contains(file.getName())) { + hadoopConfFiles(file.getName()) = file + } + } + } + } + } + + if (!hadoopConfFiles.isEmpty) { + val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip", + new File(Utils.getLocalDir(sparkConf))) + + val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive)) + try { + hadoopConfStream.setLevel(0) + hadoopConfFiles.foreach { case (name, file) => + hadoopConfStream.putNextEntry(new ZipEntry(name)) + Files.copy(file, hadoopConfStream) + hadoopConfStream.closeEntry() + } + } finally { + hadoopConfStream.close() + } + + Some(hadoopConfArchive) + } else { + None + } + } + /** * Set up the environment for launching our ApplicationMaster container. */ @@ -317,7 +391,7 @@ private[spark] class Client( logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.driver.extraClassPath") - populateClasspath(args, yarnConf, sparkConf, env, extraCp) + populateClasspath(args, yarnConf, sparkConf, env, true, extraCp) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -718,6 +792,9 @@ object Client extends Logging { // Distribution-defined classpath to add to processes val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH" + // Subdirectory where the user's hadoop config files will be placed. + val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__" + /** * Find the user-defined Spark jar if configured, or return the jar containing this * class if not. @@ -831,11 +908,19 @@ object Client extends Logging { conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], + isAM: Boolean, extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) addClasspathEntry( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env ) + + if (isAM) { + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_HADOOP_CONF_DIR, env) + } + if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { val userClassPath = if (args != null) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index b06069c07f451..9d04d241dae9e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -277,7 +277,7 @@ class ExecutorRunnable( private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() val extraCp = sparkConf.getOption("spark.executor.extraClassPath") - Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp) + Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index c1b94ac9c5bdd..a51c2005cb472 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -20,6 +20,11 @@ package org.apache.spark.deploy.yarn import java.io.File import java.net.URI +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.reflect.ClassTag +import scala.util.Try + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.MRJobConfig @@ -30,11 +35,6 @@ import org.mockito.Matchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ HashMap => MutableHashMap } -import scala.reflect.ClassTag -import scala.util.Try - import org.apache.spark.{SparkException, SparkConf} import org.apache.spark.util.Utils @@ -93,7 +93,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - Client.populateClasspath(args, conf, sparkConf, env) + Client.populateClasspath(args, conf, sparkConf, env, true) val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -104,13 +104,16 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll { cp should not contain (uri.getPath()) } }) - if (classOf[Environment].getMethods().exists(_.getName == "$$")) { - cp should contain("{{PWD}}") - } else if (Utils.isWindows) { - cp should contain("%PWD%") - } else { - cp should contain(Environment.PWD.$()) - } + val pwdVar = + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + "{{PWD}}" + } else if (Utils.isWindows) { + "%PWD%" + } else { + Environment.PWD.$() + } + cp should contain(pwdVar) + cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}") cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index a18c94d4ab4a8..3877da4120e7c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -77,6 +77,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit private var yarnCluster: MiniYARNCluster = _ private var tempDir: File = _ private var fakeSparkJar: File = _ + private var hadoopConfDir: File = _ private var logConfDir: File = _ override def beforeAll() { @@ -120,6 +121,9 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}") fakeSparkJar = File.createTempFile("sparkJar", null, tempDir) + hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR) + assert(hadoopConfDir.mkdir()) + File.createTempFile("token", ".txt", hadoopConfDir) } override def afterAll() { @@ -258,7 +262,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit appArgs Utils.executeAndGetOutput(argv, - extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath())) + extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath())) } /** From a83571acc938582865efb41645aa1e414f339e46 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 17 Apr 2015 13:15:36 -0700 Subject: [PATCH 07/33] [SPARK-6113] [ml] Stabilize DecisionTree API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a PR for cleaning up and finalizing the DecisionTree API. PRs for ensembles will follow once this is merged. ### Goal Here is the description copied from the JIRA (for both trees and ensembles): > **Issue**: The APIs for DecisionTree and ensembles (RandomForests and GradientBoostedTrees) have been experimental for a long time. The API has become very convoluted because trees and ensembles have many, many variants, some of which we have added incrementally without a long-term design. > **Proposal**: This JIRA is for discussing changes required to finalize the APIs. After we discuss, I will make a PR to update the APIs and make them non-Experimental. This will require making many breaking changes; see the design doc for details. > **[Design doc](https://docs.google.com/document/d/1rJ_DZinyDG3PkYkAKSsQlY0QgCeefn4hUv7GsPkzBP4)** : This outlines current issues and the proposed API. Overall code layout: * The old API in mllib.tree.* will remain the same. * The new API will reside in ml.classification.* and ml.regression.* ### Summary of changes Old API * Exactly the same, except I made 1 method in Loss private (but that is not a breaking change since that method was introduced after the Spark 1.3 release). New APIs * Under Pipeline API * The new API preserves functionality, except: * New API does NOT store prob (probability of label in classification). I want to have it store the full vector of probabilities but feel that should be in a later PR. * Use abstractions for parameters, estimators, and models to avoid code duplication * Limit parameters to relevant algorithms * For enum-like types, only expose Strings * We can make these pluggable later on by adding new parameters. That is a far-future item. Test suites * I organized DecisionTreeSuite, but I made absolutely no changes to the tests themselves. * The test suites for the new API only test (a) similarity with the results of the old API and (b) elements of the new API. * After code is moved to this new API, we should move the tests from the old suites which test the internals. ### Details #### Changed names Parameters * useNodeIdCache -> cacheNodeIds #### Other changes * Split: Changed categories to set instead of list #### Non-decision tree changes * AttributeGroup * Added parentheses to toMetadata, toStructField methods (These were removed in a previous PR, but I ran into 1 issue with the Scala compiler not being able to disambiguate between a toMetadata method with no parentheses and a toMetadata method which takes 1 argument.) * Attributes * Renamed: toMetadata -> toMetadataImpl * Added toMetadata methods which return ML metadata (keyed with “ML_ATTR”) * NominalAttribute: Added getNumValues method which examines both numValues and values. * Params.inheritValues: Checks whether the parent param really belongs to the child (to allow Estimator-Model pairs with different sets of parameters) ### Questions for reviewers * Is "DecisionTreeClassificationModel" too long a name? * Is this OK in the docs? ``` class DecisionTreeRegressor extends TreeRegressor[DecisionTreeRegressionModel] with DecisionTreeParams[DecisionTreeRegressor] with TreeRegressorParams[DecisionTreeRegressor] ``` ### Future We should open up the abstractions at some point. E.g., it would be useful to be able to set tree-related parameters in 1 place and then pass those to multiple tree-based algorithms. Follow-up JIRAs will be (in this order): * Tree ensembles * Deprecate old tree code * Move DecisionTree implementation code to new API. * Move tests from the old suites which test the internals. * Update programming guide * Python API * Change RandomForest* to always use bootstrapping, even when numTrees = 1 * Provide the probability of the predicted label for classification. After we move code to the new API and update it to maintain probabilities for all labels, then we can add the probabilities to the new API. CC: mengxr manishamde codedeft chouqin MechCoder Author: Joseph K. Bradley Closes #5530 from jkbradley/dt-api-dt and squashes the following commits: 6aae255 [Joseph K. Bradley] Changed tree abstractions not to take type parameters, and for setters to return this.type instead ec17947 [Joseph K. Bradley] Updates based on code review. Main changes were: moving public types from ml.impl.tree to ml.tree, modifying CategoricalSplit to take an Array of categories but store a Set internally, making more types sealed or final 5626c81 [Joseph K. Bradley] style fixes f8fbd24 [Joseph K. Bradley] imported reorg of DecisionTreeSuite from old PR. small cleanups 7ef63ed [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example (for real this time) e11673f [Joseph K. Bradley] Added DecisionTreeRegressor, test suites, and example 119f407 [Joseph K. Bradley] added DecisionTreeClassifier example 0bdc486 [Joseph K. Bradley] fixed issues after param PR was merged f9fbb60 [Joseph K. Bradley] Done with DecisionTreeClassifier, but no save/load yet. Need to add example as well 2532c9a [Joseph K. Bradley] partial move to spark.ml API, not done yet c72c1a0 [Joseph K. Bradley] Copied changes for common items, plus DecisionTreeClassifier from original PR --- .../examples/ml/DecisionTreeExample.scala | 322 +++++++++++++++ .../spark/ml/attribute/AttributeGroup.scala | 10 +- .../spark/ml/attribute/attributes.scala | 43 +- .../DecisionTreeClassifier.scala | 155 ++++++++ .../spark/ml/feature/StringIndexer.scala | 2 +- .../spark/ml/impl/tree/treeParams.scala | 300 ++++++++++++++ .../scala/org/apache/spark/ml/package.scala | 12 + .../org/apache/spark/ml/param/params.scala | 3 +- .../ml/regression/DecisionTreeRegressor.scala | 145 +++++++ .../scala/org/apache/spark/ml/tree/Node.scala | 205 ++++++++++ .../org/apache/spark/ml/tree/Split.scala | 151 +++++++ .../org/apache/spark/ml/tree/treeModels.scala | 60 +++ .../apache/spark/ml/util/MetadataUtils.scala | 82 ++++ .../spark/mllib/tree/DecisionTree.scala | 5 +- .../mllib/tree/GradientBoostedTrees.scala | 12 +- .../spark/mllib/tree/RandomForest.scala | 2 +- .../tree/configuration/BoostingStrategy.scala | 10 +- .../spark/mllib/tree/loss/AbsoluteError.scala | 5 +- .../spark/mllib/tree/loss/LogLoss.scala | 5 +- .../apache/spark/mllib/tree/loss/Loss.scala | 4 +- .../spark/mllib/tree/loss/SquaredError.scala | 5 +- .../mllib/tree/model/DecisionTreeModel.scala | 4 +- .../apache/spark/mllib/tree/model/Node.scala | 2 +- .../mllib/tree/model/treeEnsembleModels.scala | 32 +- .../JavaDecisionTreeClassifierSuite.java | 98 +++++ .../JavaDecisionTreeRegressorSuite.java | 97 +++++ .../ml/attribute/AttributeGroupSuite.scala | 4 +- .../spark/ml/attribute/AttributeSuite.scala | 42 +- .../DecisionTreeClassifierSuite.scala | 274 +++++++++++++ .../spark/ml/feature/VectorIndexerSuite.scala | 2 +- .../org/apache/spark/ml/impl/TreeTests.scala | 132 +++++++ .../DecisionTreeRegressorSuite.scala | 91 +++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 373 +++++++++--------- 33 files changed, 2426 insertions(+), 263 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java create mode 100644 mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala new file mode 100644 index 0000000000000..d4cc8dede07ef --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.collection.mutable +import scala.language.reflectiveCalls + +import scopt.OptionParser + +import org.apache.spark.ml.tree.DecisionTreeModel +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.examples.mllib.AbstractParams +import org.apache.spark.ml.{Pipeline, PipelineStage} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{SQLContext, DataFrame} + + +/** + * An example runner for decision trees. Run with + * {{{ + * ./bin/run-example ml.DecisionTreeExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DecisionTreeExample { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + maxBins: Int = 32, + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", + fracTest: Double = 0.2, + cacheNodeIds: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DecisionTreeExample") { + head("DecisionTreeExample: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("maxBins") + .text(s"max number of bins, default: ${defaultParams.maxBins}") + .action((x, c) => c.copy(maxBins = x)) + opt[Int]("minInstancesPerNode") + .text(s"min number of instances required at child nodes to create the parent split," + + s" default: ${defaultParams.minInstancesPerNode}") + .action((x, c) => c.copy(minInstancesPerNode = x)) + opt[Double]("minInfoGain") + .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") + .action((x, c) => c.copy(minInfoGain = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("cacheNodeIds") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.cacheNodeIds}") + .action((x, c) => c.copy(cacheNodeIds = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + /** Load a dataset from the given path, using the given format */ + private[ml] def loadData( + sc: SparkContext, + path: String, + format: String, + expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = { + format match { + case "dense" => MLUtils.loadLabeledPoints(sc, path) + case "libsvm" => expectedNumFeatures match { + case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures) + case None => MLUtils.loadLibSVMFile(sc, path) + } + case _ => throw new IllegalArgumentException(s"Bad data format: $format") + } + } + + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset) + */ + private[ml] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: String, + fracTest: Double): (DataFrame, DataFrame) = { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + // Load training data + val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat) + + // Load or create test set + val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { + // Load testInput. + val numFeatures = origExamples.take(1)(0).features.size + val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures)) + Array(origExamples, origTestExamples) + } else { + // Split input into training, test. + origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) + } + + // For classification, convert labels to Strings since we will index them later with + // StringIndexer. + def labelsToStrings(data: DataFrame): DataFrame = { + algo.toLowerCase match { + case "classification" => + data.withColumn("labelString", data("label").cast(StringType)) + case "regression" => + data + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + } + val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache()) + + (dataframes(0), dataframes(1)) + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params") + val sc = new SparkContext(conf) + params.checkpointDir.foreach(sc.setCheckpointDir) + val algo = params.algo.toLowerCase + + println(s"DecisionTreeExample with parameters:\n$params") + + // Load training and test data and cache it. + val (training: DataFrame, test: DataFrame) = + loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest) + + val numTraining = training.count() + val numTest = test.count() + val numFeatures = training.select("features").first().getAs[Vector](0).size + println("Loaded data:") + println(s" numTraining = $numTraining, numTest = $numTest") + println(s" numFeatures = $numFeatures") + + // Set up Pipeline + val stages = new mutable.ArrayBuffer[PipelineStage]() + // (1) For classification, re-index classes. + val labelColName = if (algo == "classification") "indexedLabel" else "label" + if (algo == "classification") { + val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName) + stages += labelIndexer + } + // (2) Identify categorical features using VectorIndexer. + // Features with more than maxCategories values will be treated as continuous. + val featuresIndexer = new VectorIndexer().setInputCol("features") + .setOutputCol("indexedFeatures").setMaxCategories(10) + stages += featuresIndexer + // (3) Learn DecisionTree + val dt = algo match { + case "classification" => + new DecisionTreeClassifier().setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case "regression" => + new DecisionTreeRegressor().setFeaturesCol("indexedFeatures") + .setLabelCol(labelColName) + .setMaxDepth(params.maxDepth) + .setMaxBins(params.maxBins) + .setMinInstancesPerNode(params.minInstancesPerNode) + .setMinInfoGain(params.minInfoGain) + .setCacheNodeIds(params.cacheNodeIds) + .setCheckpointInterval(params.checkpointInterval) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + stages += dt + val pipeline = new Pipeline().setStages(stages.toArray) + + // Fit the Pipeline + val startTime = System.nanoTime() + val pipelineModel = pipeline.fit(training) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + + // Get the trained Decision Tree from the fitted PipelineModel + val treeModel: DecisionTreeModel = algo match { + case "classification" => + pipelineModel.getModel[DecisionTreeClassificationModel]( + dt.asInstanceOf[DecisionTreeClassifier]) + case "regression" => + pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor]) + case _ => throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + if (treeModel.numNodes < 20) { + println(treeModel.toDebugString) // Print full model. + } else { + println(treeModel) // Print model summary. + } + + // Predict on training + val trainingFullPredictions = pipelineModel.transform(training).cache() + val trainingPredictions = trainingFullPredictions.select("prediction") + .map(_.getDouble(0)) + val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0)) + // Predict on test data + val testFullPredictions = pipelineModel.transform(test).cache() + val testPredictions = testFullPredictions.select("prediction") + .map(_.getDouble(0)) + val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0)) + + // For classification, print number of classes for reference. + if (algo == "classification") { + val numClasses = + MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match { + case Some(n) => n + case None => throw new RuntimeException( + "DecisionTreeExample had unknown failure when indexing labels for classification.") + } + println(s"numClasses = $numClasses.") + } + + // Evaluate model on training, test data + algo match { + case "classification" => + val trainingAccuracy = + new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision + println(s"Train accuracy = $trainingAccuracy") + val testAccuracy = + new MulticlassMetrics(testPredictions.zip(testLabels)).precision + println(s"Test accuracy = $testAccuracy") + case "regression" => + val trainingRMSE = + new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError + println(s"Training root mean squared error (RMSE) = $trainingRMSE") + val testRMSE = + new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError + println(s"Test root mean squared error (RMSE) = $testRMSE") + case _ => + throw new IllegalArgumentException("Algo ${params.algo} not supported.") + } + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala index aa27a668f1695..d7dee8fed2a55 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala @@ -117,12 +117,12 @@ class AttributeGroup private ( case numeric: NumericAttribute => // Skip default numeric attributes. if (numeric.withoutIndex != NumericAttribute.defaultAttr) { - numericMetadata += numeric.toMetadata(withType = false) + numericMetadata += numeric.toMetadataImpl(withType = false) } case nominal: NominalAttribute => - nominalMetadata += nominal.toMetadata(withType = false) + nominalMetadata += nominal.toMetadataImpl(withType = false) case binary: BinaryAttribute => - binaryMetadata += binary.toMetadata(withType = false) + binaryMetadata += binary.toMetadataImpl(withType = false) } val attrBldr = new MetadataBuilder if (numericMetadata.nonEmpty) { @@ -151,7 +151,7 @@ class AttributeGroup private ( } /** Converts to ML metadata */ - def toMetadata: Metadata = toMetadata(Metadata.empty) + def toMetadata(): Metadata = toMetadata(Metadata.empty) /** Converts to a StructField with some existing metadata. */ def toStructField(existingMetadata: Metadata): StructField = { @@ -159,7 +159,7 @@ class AttributeGroup private ( } /** Converts to a StructField. */ - def toStructField: StructField = toStructField(Metadata.empty) + def toStructField(): StructField = toStructField(Metadata.empty) override def equals(other: Any): Boolean = { other match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index 00b7566aab434..5717d6ec2eaec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable { * Converts this attribute to [[Metadata]]. * @param withType whether to include the type info */ - private[attribute] def toMetadata(withType: Boolean): Metadata + private[attribute] def toMetadataImpl(withType: Boolean): Metadata /** * Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to * save space, because numeric type is the default attribute type. For nominal and binary * attributes, the type info is included. */ - private[attribute] def toMetadata(): Metadata = { + private[attribute] def toMetadataImpl(): Metadata = { if (attrType == AttributeType.Numeric) { - toMetadata(withType = false) + toMetadataImpl(withType = false) } else { - toMetadata(withType = true) + toMetadataImpl(withType = true) } } + /** Converts to ML metadata with some existing metadata. */ + def toMetadata(existingMetadata: Metadata): Metadata = { + new MetadataBuilder() + .withMetadata(existingMetadata) + .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl()) + .build() + } + + /** Converts to ML metadata */ + def toMetadata(): Metadata = toMetadata(Metadata.empty) + /** * Converts to a [[StructField]] with some existing metadata. * @param existingMetadata existing metadata to carry over @@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable { def toStructField(existingMetadata: Metadata): StructField = { val newMetadata = new MetadataBuilder() .withMetadata(existingMetadata) - .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata()) + .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl()) .build() StructField(name.get, DoubleType, nullable = false, newMetadata) } @@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable { /** Converts to a [[StructField]]. */ def toStructField(): StructField = toStructField(Metadata.empty) - override def toString: String = toMetadata(withType = true).toString + override def toString: String = toMetadataImpl(withType = true).toString } /** Trait for ML attribute factories. */ @@ -210,7 +221,7 @@ class NumericAttribute private[ml] ( override def isNominal: Boolean = false /** Convert this attribute to metadata. */ - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -353,6 +364,20 @@ class NominalAttribute private[ml] ( /** Copy without the `numValues`. */ def withoutNumValues: NominalAttribute = copy(numValues = None) + /** + * Get the number of values, either from `numValues` or from `values`. + * Return None if unknown. + */ + def getNumValues: Option[Int] = { + if (numValues.nonEmpty) { + numValues + } else if (values.nonEmpty) { + Some(values.get.length) + } else { + None + } + } + /** Creates a copy of this attribute with optional changes. */ private def copy( name: Option[String] = name, @@ -363,7 +388,7 @@ class NominalAttribute private[ml] ( new NominalAttribute(name, index, isOrdinal, numValues, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder() if (withType) bldr.putString(TYPE, attrType.name) @@ -465,7 +490,7 @@ class BinaryAttribute private[ml] ( new BinaryAttribute(name, index, values) } - private[attribute] override def toMetadata(withType: Boolean): Metadata = { + override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = { import org.apache.spark.ml.attribute.AttributeKeys._ val bldr = new MetadataBuilder if (withType) bldr.putString(TYPE, attrType.name) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala new file mode 100644 index 0000000000000..3855e396b5534 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassifier + extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + with DecisionTreeParams + with TreeClassifierParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = + super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeClassificationModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match { + case Some(n: Int) => n + case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + + s" with invalid label column, without the number of classes specified.") + // TODO: Automatically index labels. + } + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures, numClasses) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + override private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses) + strategy.algo = OldAlgo.Classification + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeClassifier { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@AlphaComponent +final class DecisionTreeClassificationModel private[ml] ( + override val parent: DecisionTreeClassifier, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeClassificationModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeClassificationModel = { + val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } +} + +private[ml] object DecisionTreeClassificationModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeClassifier, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 4d960df357fe9..23956c512c8a6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -118,7 +118,7 @@ class StringIndexerModel private[ml] ( } val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr - .withName(outputColName).withValues(labels).toStructField().metadata + .withName(outputColName).withValues(labels).toMetadata() dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala new file mode 100644 index 0000000000000..6f4509f03d033 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.impl.estimator.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy, + Impurity => OldImpurity, Variance => OldVariance} + + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree." + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.") + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. + * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.") + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.") + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.") + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.") + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = { + require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") + set(maxDepth, value) + this.asInstanceOf[this.type] + } + + /** @group getParam */ + def getMaxDepth: Int = getOrDefault(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = { + require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value") + set(maxBins, value) + this + } + + /** @group getParam */ + def getMaxBins: Int = getOrDefault(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = { + require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value") + set(minInstancesPerNode, value) + this + } + + /** @group getParam */ + def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = { + set(minInfoGain, value) + this + } + + /** @group getParam */ + def getMinInfoGain: Double = getOrDefault(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = { + require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value") + set(maxMemoryInMB, value) + this + } + + /** @group expertGetParam */ + def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = { + set(cacheNodeIds, value) + this + } + + /** @group expertGetParam */ + def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = { + require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value") + set(checkpointInterval, value) + this + } + + /** @group expertGetParam */ + def getCheckpointInterval: Int = getOrDefault(checkpointInterval) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and subsamplingRate (which is set to 1.0, + * the default for single trees). + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + val strategy = OldStrategy.defaultStategy(OldAlgo.Classification) + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = 1.0 // default for individual trees + strategy + } +} + +/** + * (private trait) Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeClassifierParams.supportedImpurities.contains(impurityStr), + s"Tree-based classifier was given unrecognized impurity: $value." + + s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * (private trait) Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = { + val impurityStr = value.toLowerCase + require(TreeRegressorParams.supportedImpurities.contains(impurityStr), + s"Tree-based regressor was given unrecognized impurity: $value." + + s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}") + set(impurity, impurityStr) + this + } + + /** @group getParam */ + def getImpurity: String = getOrDefault(impurity) + + /** Convert new impurity to old impurity. */ + protected def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala index b45bd1499b72e..ac75e9de1a8f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -32,6 +32,18 @@ package org.apache.spark * @groupname getParam Parameter getters * @groupprio getParam 6 * + * @groupname expertParam (expert-only) Parameters + * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can + * take. Users can set and get the parameter values through setters and getters, + * respectively. + * @groupprio expertParam 7 + * + * @groupname expertSetParam (expert-only) Parameter setters + * @groupprio expertSetParam 8 + * + * @groupname expertGetParam (expert-only) Parameter getters + * @groupprio expertGetParam 9 + * * @groupname Ungrouped Members * @groupprio Ungrouped 0 */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 849c60433c777..ddc5907e7facd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -296,8 +296,9 @@ private[spark] object Params { paramMap: ParamMap, parent: E, child: M): Unit = { + val childParams = child.params.map(_.name).toSet parent.params.foreach { param => - if (paramMap.contains(param)) { + if (paramMap.contains(param) && childParams.contains(param.name)) { child.set(child.getParam(param.name), paramMap(param)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala new file mode 100644 index 0000000000000..49a8b77acf960 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm + * for regression. + * It supports both continuous and categorical features. + */ +@AlphaComponent +final class DecisionTreeRegressor + extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams + with TreeRegressorParams { + + // Override parameter setters from parent trait for Java API compatibility. + + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) + + override def setCheckpointInterval(value: Int): this.type = + super.setCheckpointInterval(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override protected def train( + dataset: DataFrame, + paramMap: ParamMap): DecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap) + val strategy = getOldStrategy(categoricalFeatures) + val oldModel = OldDecisionTree.train(oldDataset, strategy) + DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures) + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0) + strategy.algo = OldAlgo.Regression + strategy.setImpurity(getOldImpurity) + strategy + } +} + +object DecisionTreeRegressor { + /** Accessor for supported impurities */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities +} + +/** + * :: AlphaComponent :: + * + * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression. + * It supports both continuous and categorical features. + * @param rootNode Root of the decision tree + */ +@AlphaComponent +final class DecisionTreeRegressionModel private[ml] ( + override val parent: DecisionTreeRegressor, + override val fittingParamMap: ParamMap, + override val rootNode: Node) + extends PredictionModel[Vector, DecisionTreeRegressionModel] + with DecisionTreeModel with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + override protected def predict(features: Vector): Double = { + rootNode.predict(features) + } + + override protected def copy(): DecisionTreeRegressionModel = { + val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + Params.inheritValues(this.extractParamMap(), this, m) + m + } + + override def toString: String = { + s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes" + } + + /** Convert to a model in the old API */ + private[ml] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) + } +} + +private[ml] object DecisionTreeRegressionModel { + + /** (private[ml]) Convert a model from the old API */ + def fromOld( + oldModel: OldDecisionTreeModel, + parent: DecisionTreeRegressor, + fittingParamMap: ParamMap, + categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, + s"Cannot convert non-regression DecisionTreeModel (old API) to" + + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") + val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) + new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala new file mode 100644 index 0000000000000..d6e2203d9f937 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, + Node => OldNode, Predict => OldPredict} + + +/** + * Decision tree node interface. + */ +sealed abstract class Node extends Serializable { + + // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree + // code into the new API and deprecate the old API. + + /** Prediction this node makes (or would make, if it is an internal node) */ + def prediction: Double + + /** Impurity measure at this node (for training data) */ + def impurity: Double + + /** Recursive prediction helper method */ + private[ml] def predict(features: Vector): Double = prediction + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. + */ + private[tree] def subtreeDepth: Int + + /** + * Create a copy of this node in the old Node format, recursively creating child nodes as needed. + * @param id Node ID using old format IDs + */ + private[ml] def toOld(id: Int): OldNode +} + +private[ml] object Node { + + /** + * Create a new Node from the old Node format, recursively creating child nodes as needed. + */ + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = { + if (oldNode.isLeaf) { + // TODO: Once the implementation has been moved to this API, then include sufficient + // statistics here. + new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + } + } +} + +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +final class LeafNode private[ml] ( + override val prediction: Double, + override val impurity: Double) extends Node { + + override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predict(features: Vector): Double = prediction + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, + None, None, None, None) + } +} + +/** + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. + * Values < 0 indicate missing values; this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +final class InternalNode private[ml] ( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: Node, + val rightChild: Node, + val split: Split) extends Node { + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predict(features: Vector): Double = { + if (split.shouldGoLeft(features)) { + leftChild.predict(features) + } else { + rightChild.predict(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + // NOTE: We do NOT store 'prob' in the new API currently. + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, + Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } +} + +private object InternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala new file mode 100644 index 0000000000000..cb940f62990ed --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} +import org.apache.spark.mllib.tree.model.{Split => OldSplit} + + +/** + * Interface for a "Split," which specifies a test made at a decision tree node + * to choose the left or right path. + */ +sealed trait Split extends Serializable { + + /** Index of feature which this split tests */ + def featureIndex: Int + + /** Return true (split to left) or false (split to right) */ + private[ml] def shouldGoLeft(features: Vector): Boolean + + /** Convert to old Split format */ + private[tree] def toOld: OldSplit +} + +private[ml] object Split { + + def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { + oldSplit.featureType match { + case OldFeatureType.Categorical => + new CategoricalSplit(featureIndex = oldSplit.feature, + leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) + case OldFeatureType.Continuous => + new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) + } + } +} + +/** + * Split which tests a categorical feature. + * @param featureIndex Index of the feature to test + * @param leftCategories If the feature value is in this set of categories, then the split goes + * left. Otherwise, it goes right. + * @param numCategories Number of categories for this feature. + */ +final class CategoricalSplit( + override val featureIndex: Int, + leftCategories: Array[Double], + private val numCategories: Int) + extends Split { + + require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + + s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}") + + /** + * If true, then "categories" is the set of categories for splitting to the left, and vice versa. + */ + private val isLeft: Boolean = leftCategories.length <= numCategories / 2 + + /** Set of categories determining the splitting rule, along with [[isLeft]]. */ + private val categories: Set[Double] = { + if (isLeft) { + leftCategories.toSet + } else { + setComplement(leftCategories.toSet) + } + } + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + if (isLeft) { + categories.contains(features(featureIndex)) + } else { + !categories.contains(features(featureIndex)) + } + } + + override def equals(o: Any): Boolean = { + o match { + case other: CategoricalSplit => featureIndex == other.featureIndex && + isLeft == other.isLeft && categories == other.categories + case _ => false + } + } + + override private[tree] def toOld: OldSplit = { + val oldCats = if (isLeft) { + categories + } else { + setComplement(categories) + } + OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) + } + + /** Get sorted categories which split to the left */ + def getLeftCategories: Array[Double] = { + val cats = if (isLeft) categories else setComplement(categories) + cats.toArray.sorted + } + + /** Get sorted categories which split to the right */ + def getRightCategories: Array[Double] = { + val cats = if (isLeft) setComplement(categories) else categories + cats.toArray.sorted + } + + /** [0, numCategories) \ cats */ + private def setComplement(cats: Set[Double]): Set[Double] = { + Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet + } +} + +/** + * Split which tests a continuous feature. + * @param featureIndex Index of the feature to test + * @param threshold If the feature value is <= this threshold, then the split goes left. + * Otherwise, it goes right. + */ +final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { + + override private[ml] def shouldGoLeft(features: Vector): Boolean = { + features(featureIndex) <= threshold + } + + override def equals(o: Any): Boolean = { + o match { + case other: ContinuousSplit => + featureIndex == other.featureIndex && threshold == other.threshold + case _ => + false + } + } + + override private[tree] def toOld: OldSplit = { + OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100644 index 0000000000000..8e3bc3849dcf0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.AlphaComponent + + +/** + * :: AlphaComponent :: + * + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions + */ +@AlphaComponent +trait DecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: Node + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala new file mode 100644 index 0000000000000..c84c8b4eb744f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import scala.collection.immutable.HashMap + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, + NumericAttribute} +import org.apache.spark.sql.types.StructField + + +/** + * :: Experimental :: + * + * Helper utilities for tree-based algorithms + */ +@Experimental +object MetadataUtils { + + /** + * Examine a schema to identify the number of classes in a label column. + * Returns None if the number of labels is not specified, or if the label column is continuous. + */ + def getNumClasses(labelSchema: StructField): Option[Int] = { + Attribute.fromStructField(labelSchema) match { + case numAttr: NumericAttribute => None + case binAttr: BinaryAttribute => Some(2) + case nomAttr: NominalAttribute => nomAttr.getNumValues + } + } + + /** + * Examine a schema to identify categorical (Binary and Nominal) features. + * + * @param featuresSchema Schema of the features column. + * If a feature does not have metadata, it is assumed to be continuous. + * If a feature is Nominal, then it must have the number of values + * specified. + * @return Map: feature index --> number of categories. + * The map's set of keys will be the set of categorical feature indices. + */ + def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = { + val metadata = AttributeGroup.fromStructField(featuresSchema) + if (metadata.attributes.isEmpty) { + HashMap.empty[Int, Int] + } else { + metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) => + if (attr == null) { + Iterator() + } else { + attr match { + case numAttr: NumericAttribute => Iterator() + case binAttr: BinaryAttribute => Iterator(idx -> 2) + case nomAttr: NominalAttribute => + nomAttr.getNumValues match { + case Some(numValues: Int) => Iterator(idx -> numValues) + case None => throw new IllegalArgumentException(s"Feature $idx is marked as" + + " Nominal (categorical), but it does not have the number of values specified.") + } + } + } + }.toMap + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b9d0c56dd1ea3..dfe3a0b6913ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging { } } - assert(splits.length > 0) + // TODO: Do not fail; just ignore the useless feature. + assert(splits.length > 0, + s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + + " Please remove this feature and then try again.") // set number of splits accordingly metadata.setNumSplits(featureIndex, splits.length) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index c02c79f094b66..0e31c7ed58df8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) /** * Method to validate a gradient boosting model * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @param validationInput Validation dataset: - RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - Should be different from and follow the same distribution as input. - e.g., these two datasets could be created from an original dataset - by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @param validationInput Validation dataset. + * This dataset should be different from the training dataset, + * but it should follow the same distribution. + * E.g., these two datasets could be created from an original dataset + * by using [[org.apache.spark.rdd.RDD.randomSplit()]] * @return a gradient boosted trees model that can be used for prediction */ def runWithValidation( @@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging { val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight - val startingModel = new GradientBoostedTreesModel( - Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1)) var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index db01f2e229e5a..055e60c7d9c95 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -249,7 +249,7 @@ private class RandomForest ( nodeIdCache.get.deleteAllCheckpoints() } catch { case e:IOException => - logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}") + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 664c8df019233..2d6b01524ff3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -89,14 +89,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: Algo): BoostingStrategy = { - val treeStragtegy = Strategy.defaultStategy(algo) - treeStragtegy.maxDepth = 3 + val treeStrategy = Strategy.defaultStategy(algo) + treeStrategy.maxDepth = 3 algo match { case Algo.Classification => - treeStragtegy.numClasses = 2 - new BoostingStrategy(treeStragtegy, LogLoss) + treeStrategy.numClasses = 2 + new BoostingStrategy(treeStrategy, LogLoss) case Algo.Regression => - new BoostingStrategy(treeStragtegy, SquaredError) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by boosting.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index 6f570b4e09c79..2bdef73c4a8f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object AbsoluteError extends Loss { if (label - prediction < 0) 1.0 else -1.0 } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = label - prediction math.abs(err) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 24ee9f3d51293..778c24526de70 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -47,10 +47,9 @@ object LogLoss extends Loss { - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction)) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val margin = 2.0 * label * prediction // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index d3b82b752fa0d..64ffccbce073f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -57,6 +58,5 @@ trait Loss extends Serializable { * @param label True label. * @return Measure of model error on datapoint. */ - def computeError(prediction: Double, label: Double): Double - + private[mllib] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 58857ae15e93e..a5582d3ef3324 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel -import org.apache.spark.rdd.RDD + /** * :: DeveloperApi :: @@ -45,9 +45,8 @@ object SquaredError extends Loss { 2.0 * (prediction - label) } - override def computeError(prediction: Double, label: Double): Double = { + override private[mllib] def computeError(prediction: Double, label: Double): Double = { val err = prediction - label err * err } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index c9bafd60fba4d..331af428533de 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable DecisionTreeModel.SaveLoadV1_0.save(sc, path, this) } - override protected def formatVersion: String = "1.0" + override protected def formatVersion: String = DecisionTreeModel.formatVersion } object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { + private[spark] def formatVersion: String = "1.0" + private[tree] object SaveLoadV1_0 { def thisFormatVersion: String = "1.0" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 4f72bb8014cc0..708ba04b567d3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -175,7 +175,7 @@ class Node ( } } -private[tree] object Node { +private[spark] object Node { /** * Return a node with the given node id (but nothing else set). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index fef3d2acb202a..8341219bfa71c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils + /** * :: Experimental :: * Represents a random forest model. @@ -47,7 +48,7 @@ import org.apache.spark.util.Utils */ @Experimental class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) - extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average) with Saveable { @@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis RandomForestModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override protected def formatVersion: String = RandomForestModel.formatVersion } object RandomForestModel extends Loader[RandomForestModel] { + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): RandomForestModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -102,15 +105,13 @@ class GradientBoostedTreesModel( extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) with Saveable { - require(trees.size == treeWeights.size) + require(trees.length == treeWeights.length) override def save(sc: SparkContext, path: String): Unit = { TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, GradientBoostedTreesModel.SaveLoadV1_0.thisClassName) } - override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion - /** * Method to compute error or loss for every iteration of gradient boosting. * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -138,7 +139,7 @@ class GradientBoostedTreesModel( evaluationArray(0) = predictionAndError.values.mean() val broadcastTrees = sc.broadcast(trees) - (1 until numIterations).map { nTree => + (1 until numIterations).foreach { nTree => predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => val currentTree = broadcastTrees.value(nTree) val currentTreeWeight = localTreeWeights(nTree) @@ -155,6 +156,7 @@ class GradientBoostedTreesModel( evaluationArray } + override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion } object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { @@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { loss: Loss): RDD[(Double, Double)] = { val newPredError = data.zip(predictionAndError).mapPartitions { iter => - iter.map { - case (lp, (pred, error)) => { - val newPred = pred + tree.predict(lp.features) * treeWeight - val newError = loss.computeError(newPred, lp.label) - (newPred, newError) - } + iter.map { case (lp, (pred, error)) => + val newPred = pred + tree.predict(lp.features) * treeWeight + val newError = loss.computeError(newPred, lp.label) + (newPred, newError) } } newPredError } + private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion + override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = { val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName @@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel( } /** - * Get number of trees in forest. + * Get number of trees in ensemble. */ - def numTrees: Int = trees.size + def numTrees: Int = trees.length /** - * Get total number of nodes, summed over all trees in the forest. + * Get total number of nodes, summed over all trees in the ensemble. */ def totalNumNodes: Int = trees.map(_.numNodes).sum } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java new file mode 100644 index 0000000000000..43b8787f9dd7e --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeClassifierSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]); + } + DecisionTreeClassificationModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model3.save(sc.sc(), path); + DecisionTreeClassificationModel sameModel = + DecisionTreeClassificationModel.load(sc.sc(), path); + TreeTests.checkEqual(model3, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java new file mode 100644 index 0000000000000..a3a339004f31c --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression; + +import java.io.File; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.impl.TreeTests; +import org.apache.spark.mllib.classification.LogisticRegressionSuite; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.util.Utils; + + +public class JavaDecisionTreeRegressorSuite implements Serializable { + + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runDT() { + int nPoints = 20; + double A = 2.0; + double B = -1.5; + + JavaRDD data = sc.parallelize( + LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); + Map categoricalFeatures = new HashMap(); + DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + + // This tests setters. Training with various options is tested in Scala. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setMaxDepth(2) + .setMaxBins(10) + .setMinInstancesPerNode(5) + .setMinInfoGain(0.0) + .setMaxMemoryInMB(256) + .setCacheNodeIds(false) + .setCheckpointInterval(10) + .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern + for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) { + dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]); + } + DecisionTreeRegressionModel model = dt.fit(dataFrame); + + model.transform(dataFrame); + model.numNodes(); + model.depth(); + model.toDebugString(); + + /* + // TODO: Add test once save/load are implemented. + File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark"); + String path = tempDir.toURI().toString(); + try { + model2.save(sc.sc(), path); + DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path); + TreeTests.checkEqual(model2, sameModel); + } finally { + Utils.deleteRecursively(tempDir); + } + */ + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala index 0dcfe5a2002dc..17ddd335deb6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala @@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite { group("abc") } assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name)) - assert(group === AttributeGroup.fromStructField(group.toStructField)) + assert(group === AttributeGroup.fromStructField(group.toStructField())) } test("attribute group without attributes") { @@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite { assert(group0.size === 10) assert(group0.attributes.isEmpty) assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name)) - assert(group0 === AttributeGroup.fromStructField(group0.toStructField)) + assert(group0 === AttributeGroup.fromStructField(group0.toStructField())) val group1 = new AttributeGroup("item") assert(group1.name === "item") diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index 6ec35b03656f9..3e1a7196e37cb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite { assert(attr.max.isEmpty) assert(attr.std.isEmpty) assert(attr.sparsity.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) intercept[NoSuchElementException] { @@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite { assert(!attr.isNominal) assert(attr.name === Some(name)) assert(attr.index === Some(index)) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = false) === metadata) - assert(attr.toMetadata(withType = true) === metadataWithType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = false) === metadata) + assert(attr.toMetadataImpl(withType = true) === metadataWithType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === Attribute.fromMetadata(metadataWithType)) val field = attr.toStructField() @@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite { assert(attr2.max === Some(1.0)) assert(attr2.std === Some(0.5)) assert(attr2.sparsity === Some(0.3)) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) } test("bad numeric attributes") { @@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite { assert(attr.values.isEmpty) assert(attr.numValues.isEmpty) assert(attr.isOrdinal.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite { assert(attr.values === Some(values)) assert(attr.indexOf("medium") === 1) assert(attr.getValue(1) === "medium") - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === NominalAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) @@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite { assert(attr2.index.isEmpty) assert(attr2.values.get === Array("small", "medium", "large", "x-large")) assert(attr2.indexOf("x-large") === 3) - assert(attr2 === Attribute.fromMetadata(attr2.toMetadata())) - assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false))) + assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl())) + assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false))) } test("bad nominal attributes") { @@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite { assert(attr.name.isEmpty) assert(attr.index.isEmpty) assert(attr.values.isEmpty) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) intercept[NoSuchElementException] { @@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite { assert(attr.name === Some(name)) assert(attr.index === Some(index)) assert(attr.values.get === values) - assert(attr.toMetadata() === metadata) - assert(attr.toMetadata(withType = true) === metadata) - assert(attr.toMetadata(withType = false) === metadataWithoutType) + assert(attr.toMetadataImpl() === metadata) + assert(attr.toMetadataImpl(withType = true) === metadata) + assert(attr.toMetadataImpl(withType = false) === metadataWithoutType) assert(attr === Attribute.fromMetadata(metadata)) assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType)) assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField())) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala new file mode 100644 index 0000000000000..af88595df5245 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeClassifierSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + DecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: Reinstate test once save/load are implemented + /* + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification) + val newModel = DecisionTreeClassificationModel.fromOld(oldModel) + + // Save model, load it back, and compare. + try { + newModel.save(sc, path) + val sameNewModel = DecisionTreeClassificationModel.load(sc, path) + TreeTests.checkEqual(newModel, sameNewModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + */ +} + +private[ml] object DecisionTreeClassifierSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala index 81ef831c42e55..1b261b2643854 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala @@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext { } val attrGroup = new AttributeGroup("features", featureAttributes) val densePoints1WithMeta = - densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata)) + densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata())) val vectorIndexer = getIndexer.setMaxCategories(2) val model = vectorIndexer.fit(densePoints1WithMeta) // Check that ML metadata are preserved. diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala new file mode 100644 index 0000000000000..2e57d4ce37f1d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl + +import scala.collection.JavaConverters._ + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +private[ml] object TreeTests extends FunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index -> number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val sqlContext = new SQLContext(data.sparkContext) + import sqlContext.implicits._ + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** Java-friendly version of [[setMetadata()]] */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: Node): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + (a, b) match { + case (aye: InternalNode, bee: InternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: LeafNode) => // do nothing + case _ => + throw new AssertionError("Found mismatched nodes") + } + } + + // TODO: Reinstate after adding ensembles + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + /* + def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + try { + a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) => + TreeTests.checkEqual(treeA, treeB) + } + assert(a.getTreeWeights === b.getTreeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + */ +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala new file mode 100644 index 0000000000000..0b40fe33fae9d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.scalatest.FunSuite + +import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, + DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + + +class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext { + + import DecisionTreeRegressorSuite.compareAPIs + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3, 1-> 3) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1-> 2) + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + // TODO: test("model save/load") +} + +private[ml] object DecisionTreeRegressorSuite extends FunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val oldStrategy = dt.getOldStrategy(categoricalFeatures) + val oldTree = OldDecisionTree.train(data, oldStrategy) + val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + // Use parent, fittingParamMap from newTree since these are not checked anyways. + val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent, + newTree.fittingParamMap, categoricalFeatures) + TreeTests.checkEqual(oldTreeAsNew, newTree) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4c162df810bb2..249b8eae19b17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -36,6 +36,10 @@ import org.apache.spark.util.Utils class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { + ///////////////////////////////////////////////////////////////////////////// + // Tests examining individual elements of training + ///////////////////////////////////////////////////////////////////////////// + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(bins(0).length === 0) } + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Second level node building with vs. without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, + numClasses = 2, maxBins = 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) + + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + val children1 = new Array[Node](2) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. + val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) + val children2 = new Array[Node](2) + children2(0) = rootNode2.leftNode.get + children2(1) = rootNode2.rightNode.get + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) + assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) + val stats1 = children1(i).stats.get + val stats2 = children2(i).stats.get + assert(stats1.gain === stats2.gain) + assert(stats1.impurity === stats2.impurity) + assert(stats1.leftImpurity === stats2.leftImpurity) + assert(stats1.rightImpurity === stats2.rightImpurity) + assert(children1(i).predict.predict === children2(i).predict.predict) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(rootNode.predict.predict === 1) } - test("Second level node building with vs. without groups") { - val arr = DecisionTreeSuite.generateOrderedLabeledPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - assert(splits.length === 2) - assert(splits(0).length === 99) - assert(bins.length === 2) - assert(bins(0).length === 100) - - // Train a 1-node model - val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, - numClasses = 2, maxBins = 100) - val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNode1 = modelOneNode.topNode.deepCopy() - val rootNode2 = modelOneNode.topNode.deepCopy() - assert(rootNode1.leftNode.nonEmpty) - assert(rootNode1.rightNode.nonEmpty) - - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - // Single group second level tree construction. - val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) - val treeToNodeToIndexInfo = Map((0, Map( - (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), - (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - val children1 = new Array[Node](2) - children1(0) = rootNode1.leftNode.get - children1(1) = rootNode1.rightNode.get - - // Train one second-level node at a time. - val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) - val treeToNodeToIndexInfoA = Map((0, Map( - (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) - val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) - val treeToNodeToIndexInfoB = Map((0, Map( - (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) - nodeQueue.clear() - DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), - nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) - val children2 = new Array[Node](2) - children2(0) = rootNode2.leftNode.get - children2(1) = rootNode2.rightNode.get - - // Verify whether the splits obtained using single group and multiple group level - // construction strategies are the same. - for (i <- 0 until 2) { - assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0) - assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0) - assert(children1(i).split === children2(i).split) - assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty) - val stats1 = children1(i).stats.get - val stats2 = children2(i).stats.get - assert(stats1.gain === stats2.gain) - assert(stats1.impurity === stats2.impurity) - assert(stats1.leftImpurity === stats2.leftImpurity) - assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict.predict === children2(i).predict.predict) - } - } - test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) @@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) - arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 2) @@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("Binary classification stump with 2 continuous features") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min instances per node requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClasses = 2, minInstancesPerNode = 2) @@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("do not choose split that does not satisfy min instance per node requirements") { // if a split does not satisfy min instances per node requirements, // this split is invalid, even though the information gain of split is large. - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, @@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { } test("split must satisfy min info gain requirements") { - val arr = new Array[LabeledPoint](3) - arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) - arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) - arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) val input = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, @@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { assert(gain == InformationGainStats.invalidInformationGainStats) } - test("Avoid aggregation on the last level") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue leaf nodes into node queue - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } - - test("Avoid aggregation if impurity is 0.0") { - val arr = new Array[LabeledPoint](4) - arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) - arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) - arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) - arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) - val input = sc.parallelize(arr) - - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, - numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) - - val topNode = Node.emptyNode(nodeIndex = 1) - assert(topNode.predict.predict === Double.MinValue) - assert(topNode.impurity === -1.0) - assert(topNode.isLeaf === false) - - val nodesForGroup = Map((0, Array(topNode))) - val treeToNodeToIndexInfo = Map((0, Map( - (topNode.id, new RandomForest.NodeIndexInfo(0, None)) - ))) - val nodeQueue = new mutable.Queue[(Int, Node)]() - DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), - nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) - - // don't enqueue a node into node queue if its impurity is 0.0 - assert(nodeQueue.isEmpty) - - // set impurity and predict for topNode - assert(topNode.predict.predict !== Double.MinValue) - assert(topNode.impurity !== -1.0) - - // set impurity and predict for child nodes - assert(topNode.leftNode.get.predict.predict === 0.0) - assert(topNode.rightNode.get.predict.predict === 1.0) - assert(topNode.leftNode.get.impurity === 0.0) - assert(topNode.rightNode.get.impurity === 0.0) - } + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// test("Node.subtreeIterator") { val model = DecisionTreeSuite.createModel(Classification) @@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite { /** * Create a tree model. This is deterministic and contains a variety of node and feature types. + * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.) */ - private[tree] def createModel(algo: Algo): DecisionTreeModel = { + private[mllib] def createModel(algo: Algo): DecisionTreeModel = { val topNode = createInternalNode(id = 1, Continuous) val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical)) val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7)) @@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite { * make mistakes such as creating loops of Nodes. * If the trees are not equal, this prints the two trees and throws an exception. */ - private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { + private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = { try { assert(a.algo === b.algo) checkEqual(a.topNode, b.topNode) From 59e206deb7346148412bbf5ba4ab626718fadf18 Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 17 Apr 2015 13:42:19 -0700 Subject: [PATCH 08/33] [SPARK-6807] [SparkR] Merge recent SparkR-pkg changes This PR pulls in recent changes in SparkR-pkg, including cartesian, intersection, sampleByKey, subtract, subtractByKey, except, and some API for StructType and StructField. Author: cafreeman Author: Davies Liu Author: Zongheng Yang Author: Shivaram Venkataraman Author: Shivaram Venkataraman Author: Sun Rui Closes #5436 from davies/R3 and squashes the following commits: c2b09be [Davies Liu] SQLTypes -> schema a5a02f2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R3 168b7fe [Davies Liu] sort generics b1fe460 [Davies Liu] fix conflict in README.md e74c04e [Davies Liu] fix schema.R 4f5ac09 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R5 41f8184 [Davies Liu] rm man ae78312 [Davies Liu] Merge pull request #237 from sun-rui/SPARKR-154_3 1bdcb63 [Zongheng Yang] Updates to README.md. 5a553e7 [cafreeman] Use object attribute instead of argument 71372d9 [cafreeman] Update docs and examples 8526d2e71 [cafreeman] Remove `tojson` functions 6ef5f2d [cafreeman] Fix spacing 7741d66 [cafreeman] Rename the SQL DataType function 141efd8 [Shivaram Venkataraman] Merge pull request #245 from hqzizania/upstream 9387402 [Davies Liu] fix style 40199eb [Shivaram Venkataraman] Move except into sorted position 07d0dbc [Sun Rui] [SPARKR-244] Fix test failure after integration of subtract() and subtractByKey() for RDD. 7e8caa3 [Shivaram Venkataraman] Merge pull request #246 from hlin09/fixCombineByKey ed66c81 [cafreeman] Update `subtract` to work with `generics.R` f3ba785 [cafreeman] Fixed duplicate export 275deb4 [cafreeman] Update `NAMESPACE` and tests 1a3b63d [cafreeman] new version of `CreateDF` 836c4bf [cafreeman] Update `createDataFrame` and `toDF` be5d5c1 [cafreeman] refactor schema functions 40338a4 [Zongheng Yang] Merge pull request #244 from sun-rui/SPARKR-154_5 20b97a6 [Zongheng Yang] Merge pull request #234 from hqzizania/assist ba54e34 [Shivaram Venkataraman] Merge pull request #238 from sun-rui/SPARKR-154_4 c9497a3 [Shivaram Venkataraman] Merge pull request #208 from lythesia/master b317aa7 [Zongheng Yang] Merge pull request #243 from hqzizania/master 136a07e [Zongheng Yang] Merge pull request #242 from hqzizania/stats cd66603 [cafreeman] new line at EOF 8b76e81 [Shivaram Venkataraman] Merge pull request #233 from redbaron/fail-early-on-missing-dep 7dd81b7 [cafreeman] Documentation 0e2a94f [cafreeman] Define functions for schema and fields --- R/pkg/DESCRIPTION | 2 +- R/pkg/NAMESPACE | 20 +- R/pkg/R/DataFrame.R | 18 +- R/pkg/R/RDD.R | 205 ++++++++++++------ R/pkg/R/SQLContext.R | 44 +--- R/pkg/R/SQLTypes.R | 64 ------ R/pkg/R/column.R | 2 +- R/pkg/R/generics.R | 46 +++- R/pkg/R/group.R | 2 +- R/pkg/R/pairRDD.R | 192 +++++++++++++--- R/pkg/R/schema.R | 162 ++++++++++++++ R/pkg/R/serialize.R | 9 +- R/pkg/R/utils.R | 80 +++++++ R/pkg/inst/tests/test_rdd.R | 193 ++++++++++++++--- R/pkg/inst/tests/test_shuffle.R | 12 + R/pkg/inst/tests/test_sparkSQL.R | 35 +-- R/pkg/inst/worker/worker.R | 59 ++++- .../scala/org/apache/spark/api/r/RRDD.scala | 131 +++++------ .../scala/org/apache/spark/api/r/SerDe.scala | 14 +- .../org/apache/spark/sql/api/r/SQLUtils.scala | 32 ++- 20 files changed, 971 insertions(+), 351 deletions(-) delete mode 100644 R/pkg/R/SQLTypes.R create mode 100644 R/pkg/R/schema.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 052f68c6c24e2..1c1779a763c7e 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -19,7 +19,7 @@ Collate: 'jobj.R' 'RDD.R' 'pairRDD.R' - 'SQLTypes.R' + 'schema.R' 'column.R' 'group.R' 'DataFrame.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index a354cdce74afa..80283643861ac 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -5,6 +5,7 @@ exportMethods( "aggregateByKey", "aggregateRDD", "cache", + "cartesian", "checkpoint", "coalesce", "cogroup", @@ -28,6 +29,7 @@ exportMethods( "fullOuterJoin", "glom", "groupByKey", + "intersection", "join", "keyBy", "keys", @@ -52,11 +54,14 @@ exportMethods( "reduceByKeyLocally", "repartition", "rightOuterJoin", + "sampleByKey", "sampleRDD", "saveAsTextFile", "saveAsObjectFile", "sortBy", "sortByKey", + "subtract", + "subtractByKey", "sumRDD", "take", "takeOrdered", @@ -95,6 +100,7 @@ exportClasses("DataFrame") exportMethods("columns", "distinct", "dtypes", + "except", "explain", "filter", "groupBy", @@ -118,7 +124,6 @@ exportMethods("columns", "show", "showDF", "sortDF", - "subtract", "toJSON", "toRDD", "unionAll", @@ -178,5 +183,14 @@ export("cacheTable", "toDF", "uncacheTable") -export("print.structType", - "print.structField") +export("sparkRSQL.init", + "sparkRHive.init") + +export("structField", + "structField.jobj", + "structField.character", + "print.structField", + "structType", + "structType.jobj", + "structType.structField", + "print.structType") diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 044fdb4d01223..861fe1c78b0db 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -17,7 +17,7 @@ # DataFrame.R - DataFrame class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R +#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R NULL setOldClass("jobj") @@ -1141,15 +1141,15 @@ setMethod("intersect", dataFrame(intersected) }) -#' Subtract +#' except #' #' Return a new DataFrame containing rows in this DataFrame #' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL. #' #' @param x A Spark DataFrame #' @param y A Spark DataFrame -#' @return A DataFrame containing the result of the subtract operation. -#' @rdname subtract +#' @return A DataFrame containing the result of the except operation. +#' @rdname except #' @export #' @examples #'\dontrun{ @@ -1157,13 +1157,15 @@ setMethod("intersect", #' sqlCtx <- sparkRSQL.init(sc) #' df1 <- jsonFile(sqlCtx, path) #' df2 <- jsonFile(sqlCtx, path2) -#' subtractDF <- subtract(df, df2) +#' exceptDF <- except(df, df2) #' } -setMethod("subtract", +#' @rdname except +#' @export +setMethod("except", signature(x = "DataFrame", y = "DataFrame"), function(x, y) { - subtracted <- callJMethod(x@sdf, "except", y@sdf) - dataFrame(subtracted) + excepted <- callJMethod(x@sdf, "except", y@sdf) + dataFrame(excepted) }) #' Save the contents of the DataFrame to a data source diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 820027ef67e3b..128431334ca52 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -730,6 +730,7 @@ setMethod("take", index <- -1 jrdd <- getJRDD(x) numPartitions <- numPartitions(x) + serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size # estimates similar to the scala version of `take`. @@ -748,13 +749,14 @@ setMethod("take", elems <- convertJListToRList(partition, flatten = TRUE, logicalUpperBound = size, - serializedMode = getSerializedMode(x)) - # TODO: Check if this append is O(n^2)? + serializedMode = serializedModeRDD) + resList <- append(resList, elems) } resList }) + #' First #' #' Return the first element of an RDD @@ -1092,21 +1094,42 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { if (num < length(part)) { # R limitation: order works only on primitive types! ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending) - list(part[ord[1:num]]) + part[ord[1:num]] } else { - list(part) + part } } - reduceFunc <- function(elems, part) { - newElems <- append(elems, part) - # R limitation: order works only on primitive types! - ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending) - newElems[ord[1:num]] - } - newRdd <- mapPartitions(x, partitionFunc) - reduce(newRdd, reduceFunc) + + resList <- list() + index <- -1 + jrdd <- getJRDD(newRdd) + numPartitions <- numPartitions(newRdd) + serializedModeRDD <- getSerializedMode(newRdd) + + while (TRUE) { + index <- index + 1 + + if (index >= numPartitions) { + ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending) + resList <- resList[ord[1:num]] + break + } + + # a JList of byte arrays + partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index))) + partition <- partitionArr[[1]] + + # elems is capped to have at most `num` elements + elems <- convertJListToRList(partition, + flatten = TRUE, + logicalUpperBound = num, + serializedMode = serializedModeRDD) + + resList <- append(resList, elems) + } + resList } #' Returns the first N elements from an RDD in ascending order. @@ -1465,67 +1488,105 @@ setMethod("zipRDD", stop("Can only zip RDDs which have the same number of partitions.") } - if (getSerializedMode(x) != getSerializedMode(other) || - getSerializedMode(x) == "byte") { - # Append the number of elements in each partition to that partition so that we can later - # check if corresponding partitions of both RDDs have the same number of elements. - # - # Note that this appending also serves the purpose of reserialization, because even if - # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded - # as a single byte array. For example, partitions of an RDD generated from partitionBy() - # may be encoded as multiple byte arrays. - appendLength <- function(part) { - part[[length(part) + 1]] <- length(part) + 1 - part - } - x <- lapplyPartition(x, appendLength) - other <- lapplyPartition(other, appendLength) - } + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other)) - # The zippedRDD's elements are of scala Tuple2 type. The serialized - # flag Here is used for the elements inside the tuples. - serializerMode <- getSerializedMode(x) - zippedRDD <- RDD(zippedJRDD, serializerMode) + mergePartitions(rdd, TRUE) + }) + +#' Cartesian product of this RDD and another one. +#' +#' Return the Cartesian product of this RDD and another one, +#' that is, the RDD of all pairs of elements (a, b) where a +#' is in this and b is in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @return A new RDD which is the Cartesian product of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:2) +#' sortByKey(cartesian(rdd, rdd)) +#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) +#'} +#' @rdname cartesian +#' @aliases cartesian,RDD,RDD-method +setMethod("cartesian", + signature(x = "RDD", other = "RDD"), + function(x, other) { + rdds <- appendPartitionLengths(x, other) + jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]])) + # The jrdd's elements are of scala Tuple2 type. The serialized + # flag here is used for the elements inside the tuples. + rdd <- RDD(jrdd, getSerializedMode(rdds[[1]])) - partitionFunc <- function(split, part) { - len <- length(part) - if (len > 0) { - if (serializerMode == "byte") { - lengthOfValues <- part[[len]] - lengthOfKeys <- part[[len - lengthOfValues]] - stopifnot(len == lengthOfKeys + lengthOfValues) - - # check if corresponding partitions of both RDDs have the same number of elements. - if (lengthOfKeys != lengthOfValues) { - stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") - } - - if (lengthOfKeys > 1) { - keys <- part[1 : (lengthOfKeys - 1)] - values <- part[(lengthOfKeys + 1) : (len - 1)] - } else { - keys <- list() - values <- list() - } - } else { - # Keys, values must have same length here, because this has - # been validated inside the JavaRDD.zip() function. - keys <- part[c(TRUE, FALSE)] - values <- part[c(FALSE, TRUE)] - } - mapply( - function(k, v) { - list(k, v) - }, - keys, - values, - SIMPLIFY = FALSE, - USE.NAMES = FALSE) - } else { - part - } + mergePartitions(rdd, FALSE) + }) + +#' Subtract an RDD with another RDD. +#' +#' Return an RDD with the elements from this that are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the elements from this that are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) +#' rdd2 <- parallelize(sc, list(2, 4)) +#' collect(subtract(rdd1, rdd2)) +#' # list(1, 1, 3) +#'} +#' @rdname subtract +#' @aliases subtract,RDD +setMethod("subtract", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + mapFunction <- function(e) { list(e, NA) } + rdd1 <- map(x, mapFunction) + rdd2 <- map(other, mapFunction) + keys(subtractByKey(rdd1, rdd2, numPartitions)) + }) + +#' Intersection of this RDD and another one. +#' +#' Return the intersection of this RDD and another one. +#' The output will not contain any duplicate elements, +#' even if the input RDDs did. Performs a hash partition +#' across the cluster. +#' Note that this method performs a shuffle internally. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions The number of partitions in the result RDD. +#' @return An RDD which is the intersection of these two RDDs. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) +#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) +#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) +#' # list(1, 2, 3) +#'} +#' @rdname intersection +#' @aliases intersection,RDD +setMethod("intersection", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + rdd1 <- map(x, function(v) { list(v, NA) }) + rdd2 <- map(other, function(v) { list(v, NA) }) + + filterFunction <- function(elem) { + iters <- elem[[2]] + all(as.vector( + lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical")) } - - PipelinedRDD(zippedRDD, partitionFunc) + + keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 930ada22f4c38..4f05ba524a01a 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -54,9 +54,9 @@ infer_type <- function(x) { # StructType types <- lapply(x, infer_type) fields <- lapply(1:length(x), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - list(type = "struct", fields = fields) + do.call(structType, fields) } } else if (length(x) > 1) { list(type = "array", elementType = type, containsNull = TRUE) @@ -65,30 +65,6 @@ infer_type <- function(x) { } } -#' dump the schema into JSON string -tojson <- function(x) { - if (is.list(x)) { - names <- names(x) - if (!is.null(names)) { - items <- lapply(names, function(n) { - safe_n <- gsub('"', '\\"', n) - paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '') - }) - d <- paste(items, collapse = ', ') - paste('{', d, '}', sep = '') - } else { - l <- paste(lapply(x, tojson), collapse = ', ') - paste('[', l, ']', sep = '') - } - } else if (is.character(x)) { - paste('"', x, '"', sep = '') - } else if (is.logical(x)) { - if (x) "true" else "false" - } else { - stop(paste("unexpected type:", class(x))) - } -} - #' Create a DataFrame from an RDD #' #' Converts an RDD to a DataFrame by infer the types. @@ -134,7 +110,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { stop(paste("unexpected type:", class(data))) } - if (is.null(schema) || is.null(names(schema))) { + if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) { row <- first(rdd) names <- if (is.null(schema)) { names(row) @@ -143,7 +119,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { } if (is.null(names)) { names <- lapply(1:length(row), function(x) { - paste("_", as.character(x), sep = "") + paste("_", as.character(x), sep = "") }) } @@ -159,20 +135,18 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) { types <- lapply(row, infer_type) fields <- lapply(1:length(row), function(i) { - list(name = names[[i]], type = types[[i]], nullable = TRUE) + structField(names[[i]], types[[i]], TRUE) }) - schema <- list(type = "struct", fields = fields) + schema <- do.call(structType, fields) } - stopifnot(class(schema) == "list") - stopifnot(schema$type == "struct") - stopifnot(class(schema$fields) == "list") - schemaString <- tojson(schema) + stopifnot(class(schema) == "structType") + # schemaString <- tojson(schema) jrdd <- getJRDD(lapply(rdd, function(x) x), "row") srdd <- callJMethod(jrdd, "rdd") sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF", - srdd, schemaString, sqlCtx) + srdd, schema$jobj, sqlCtx) dataFrame(sdf) } diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R deleted file mode 100644 index 962fba5b3cf03..0000000000000 --- a/R/pkg/R/SQLTypes.R +++ /dev/null @@ -1,64 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Utility functions for handling SparkSQL DataTypes. - -# Handler for StructType -structType <- function(st) { - obj <- structure(new.env(parent = emptyenv()), class = "structType") - obj$jobj <- st - obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) } - obj -} - -#' Print a Spark StructType. -#' -#' This function prints the contents of a StructType returned from the -#' SparkR JVM backend. -#' -#' @param x A StructType object -#' @param ... further arguments passed to or from other methods -print.structType <- function(x, ...) { - fieldsList <- lapply(x$fields(), function(i) { i$print() }) - print(fieldsList) -} - -# Handler for StructField -structField <- function(sf) { - obj <- structure(new.env(parent = emptyenv()), class = "structField") - obj$jobj <- sf - obj$name <- function() { callJMethod(sf, "name") } - obj$dataType <- function() { callJMethod(sf, "dataType") } - obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } - obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } - obj$nullable <- function() { callJMethod(sf, "nullable") } - obj$print <- function() { paste("StructField(", - paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "), - ")", sep = "") } - obj -} - -#' Print a Spark StructField. -#' -#' This function prints the contents of a StructField returned from the -#' SparkR JVM backend. -#' -#' @param x A StructField object -#' @param ... further arguments passed to or from other methods -print.structField <- function(x, ...) { - cat(x$print()) -} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index b282001d8b6b5..95fb9ff0887b6 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -17,7 +17,7 @@ # Column Class -#' @include generics.R jobj.R SQLTypes.R +#' @include generics.R jobj.R schema.R NULL setOldClass("jobj") diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5fb1ccaa84ee2..6c6233390134c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -230,6 +230,10 @@ setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") } ############ Binary Functions ############# +#' @rdname cartesian +#' @export +setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") }) + #' @rdname countByKey #' @export setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) @@ -238,6 +242,11 @@ setGeneric("countByKey", function(x) { standardGeneric("countByKey") }) #' @export setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") }) +#' @rdname intersection +#' @export +setGeneric("intersection", function(x, other, numPartitions = 1L) { + standardGeneric("intersection") }) + #' @rdname keys #' @export setGeneric("keys", function(x) { standardGeneric("keys") }) @@ -250,12 +259,18 @@ setGeneric("lookup", function(x, key) { standardGeneric("lookup") }) #' @export setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") }) +#' @rdname sampleByKey +#' @export +setGeneric("sampleByKey", + function(x, withReplacement, fractions, seed) { + standardGeneric("sampleByKey") + }) + #' @rdname values #' @export setGeneric("values", function(x) { standardGeneric("values") }) - ############ Shuffle Functions ############ #' @rdname aggregateByKey @@ -330,9 +345,24 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri #' @rdname sortByKey #' @export -setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) { - standardGeneric("sortByKey") -}) +setGeneric("sortByKey", + function(x, ascending = TRUE, numPartitions = 1L) { + standardGeneric("sortByKey") + }) + +#' @rdname subtract +#' @export +setGeneric("subtract", + function(x, other, numPartitions = 1L) { + standardGeneric("subtract") + }) + +#' @rdname subtractByKey +#' @export +setGeneric("subtractByKey", + function(x, other, numPartitions = 1L) { + standardGeneric("subtractByKey") + }) ################### Broadcast Variable Methods ################# @@ -357,6 +387,10 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @export setGeneric("explain", function(x, ...) { standardGeneric("explain") }) +#' @rdname except +#' @export +setGeneric("except", function(x, y) { standardGeneric("except") }) + #' @rdname filter #' @export setGeneric("filter", function(x, condition) { standardGeneric("filter") }) @@ -434,10 +468,6 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") }) #' @export setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") }) -#' @rdname subtract -#' @export -setGeneric("subtract", function(x, y) { standardGeneric("subtract") }) - #' @rdname tojson #' @export setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 855fbdfc7c4ca..02237b3672d6b 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -17,7 +17,7 @@ # group.R - GroupedData class and methods implemented in S4 OO classes -#' @include generics.R jobj.R SQLTypes.R column.R +#' @include generics.R jobj.R schema.R column.R NULL setOldClass("jobj") diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 5d64822859d1f..13efebc11c46e 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -430,7 +430,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner) }) convertEnvsToList(keys, combiners) @@ -443,7 +443,7 @@ setMethod("combineByKey", pred <- function(item) exists(item$hash, keys) lapply(part, function(item) { - item$hash <- as.character(item[[1]]) + item$hash <- as.character(hashCode(item[[1]])) updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity) }) convertEnvsToList(keys, combiners) @@ -452,19 +452,19 @@ setMethod("combineByKey", }) #' Aggregate a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using given combine functions #' and a neutral "zero value". This function can return a different result type, #' U, than the type of the values in this RDD, V. Thus, we need one operation -#' for merging a V into a U and one operation for merging two U's, The former -#' operation is used for merging values within a partition, and the latter is -#' used for merging values between partitions. To avoid memory allocation, both -#' of these functions are allowed to modify and return their first argument +#' for merging a V into a U and one operation for merging two U's, The former +#' operation is used for merging values within a partition, and the latter is +#' used for merging values between partitions. To avoid memory allocation, both +#' of these functions are allowed to modify and return their first argument #' instead of creating a new U. -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". -#' @param seqOp A function to aggregate the values of each key. It may return +#' @param seqOp A function to aggregate the values of each key. It may return #' a different result type from the type of the values. #' @param combOp A function to aggregate results of seqOp. #' @return An RDD containing the aggregation result. @@ -476,7 +476,7 @@ setMethod("combineByKey", #' zeroValue <- list(0, 0) #' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } -#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) +#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} #' @rdname aggregateByKey @@ -493,12 +493,12 @@ setMethod("aggregateByKey", }) #' Fold a pair RDD by each key. -#' +#' #' Aggregate the values of each key in an RDD, using an associative function "func" -#' and a neutral "zero value" which may be added to the result an arbitrary -#' number of times, and must not change the result (e.g., 0 for addition, or +#' and a neutral "zero value" which may be added to the result an arbitrary +#' number of times, and must not change the result (e.g., 0 for addition, or #' 1 for multiplication.). -#' +#' #' @param x An RDD. #' @param zeroValue A neutral "zero value". #' @param func An associative function for folding values of each key. @@ -548,11 +548,11 @@ setMethod("join", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)), doJoin) }) @@ -568,8 +568,8 @@ setMethod("join", #' @param y An RDD to be joined. Should be an RDD where each element is #' list(K, V). #' @param numPartitions Number of partitions to create. -#' @return For each element (k, v) in x, the resulting RDD will either contain -#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) +#' @return For each element (k, v) in x, the resulting RDD will either contain +#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples #'\dontrun{ @@ -586,11 +586,11 @@ setMethod("leftOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(FALSE, TRUE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) @@ -623,18 +623,18 @@ setMethod("rightOuterJoin", function(x, y, numPartitions) { xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) }) yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) }) - + doJoin <- function(v) { joinTaggedList(v, list(TRUE, FALSE)) } - + joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin) }) #' Full outer join two RDDs #' #' @description -#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). +#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V). #' The key types of the two RDDs should be the same. #' #' @param x An RDD to be joined. Should be an RDD where each element is @@ -644,7 +644,7 @@ setMethod("rightOuterJoin", #' @param numPartitions Number of partitions to create. #' @return For each element (k, v) in x and (k, w) in y, the resulting RDD #' will contain all pairs (k, (v, w)) for both (k, v) in x and -#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements +#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. #' @examples #'\dontrun{ @@ -683,7 +683,7 @@ setMethod("fullOuterJoin", #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) -#' cogroup(rdd1, rdd2, numPartitions = 2L) +#' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} #' @rdname cogroup @@ -694,7 +694,7 @@ setMethod("cogroup", rdds <- list(...) rddsLen <- length(rdds) for (i in 1:rddsLen) { - rdds[[i]] <- lapply(rdds[[i]], + rdds[[i]] <- lapply(rdds[[i]], function(x) { list(x[[1]], list(i, x[[2]])) }) } union.rdd <- Reduce(unionRDD, rdds) @@ -719,7 +719,7 @@ setMethod("cogroup", } }) } - cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), + cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), group.func) }) @@ -741,18 +741,18 @@ setMethod("sortByKey", signature(x = "RDD"), function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) { rangeBounds <- list() - + if (numPartitions > 1) { rddSize <- count(x) # constant from Spark's RangePartitioner maxSampleSize <- numPartitions * 20 fraction <- min(maxSampleSize / max(rddSize, 1), 1.0) - + samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L))) - + # Note: the built-in R sort() function only works on atomic vectors samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending) - + if (length(samples) > 0) { rangeBounds <- lapply(seq_len(numPartitions - 1), function(i) { @@ -764,24 +764,146 @@ setMethod("sortByKey", rangePartitionFunc <- function(key) { partition <- 0 - + # TODO: Use binary search instead of linear search, similar with Spark while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) { partition <- partition + 1 } - + if (ascending) { partition } else { numPartitions - partition - 1 } } - + partitionFunc <- function(part) { sortKeyValueList(part, decreasing = !ascending) } - + newRDD <- partitionBy(x, numPartitions, rangePartitionFunc) lapplyPartition(newRDD, partitionFunc) }) +#' Subtract a pair RDD with another pair RDD. +#' +#' Return an RDD with the pairs from x whose keys are not in other. +#' +#' @param x An RDD. +#' @param other An RDD. +#' @param numPartitions Number of the partitions in the result RDD. +#' @return An RDD with the pairs from x whose keys are not in other. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), +#' list("b", 5), list("a", 2))) +#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) +#' collect(subtractByKey(rdd1, rdd2)) +#' # list(list("b", 4), list("b", 5)) +#'} +#' @rdname subtractByKey +#' @aliases subtractByKey,RDD +setMethod("subtractByKey", + signature(x = "RDD", other = "RDD"), + function(x, other, numPartitions = SparkR::numPartitions(x)) { + filterFunction <- function(elem) { + iters <- elem[[2]] + (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) + } + + flatMapValues(filterRDD(cogroup(x, + other, + numPartitions = numPartitions), + filterFunction), + function (v) { v[[1]] }) + }) + +#' Return a subset of this RDD sampled by key. +#' +#' @description +#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates +#' for different keys as specified by fractions, a key to sampling rate map. +#' +#' @param x The RDD to sample elements by key, where each element is +#' list(K, V) or c(K, V). +#' @param withReplacement Sampling with replacement or not +#' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd <- parallelize(sc, 1:3000) +#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x) +#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }}) +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) +#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE +#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE +#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE +#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE +#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE +#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE +#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE +#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE +#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE +#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored +#' fractions <- list(a = 0.2, b = 0.1) +#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c" +#'} +#' @rdname sampleByKey +#' @aliases sampleByKey,RDD-method +setMethod("sampleByKey", + signature(x = "RDD", withReplacement = "logical", + fractions = "vector", seed = "integer"), + function(x, withReplacement, fractions, seed) { + + for (elem in fractions) { + if (elem < 0.0) { + stop(paste("Negative fraction value ", fractions[which(fractions == elem)])) + } + } + + # The sampler: takes a partition and returns its sampled version. + samplingFunc <- function(split, part) { + set.seed(bitwXor(seed, split)) + res <- vector("list", length(part)) + len <- 0 + + # mixing because the initial seeds are close to each other + runif(10) + + for (elem in part) { + if (elem[[1]] %in% names(fractions)) { + frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))]) + if (withReplacement) { + count <- rpois(1, frac) + if (count > 0) { + res[(len + 1):(len + count)] <- rep(list(elem), count) + len <- len + count + } + } else { + if (runif(1) < frac) { + len <- len + 1 + res[[len]] <- elem + } + } + } else { + stop("KeyError: \"", elem[[1]], "\"") + } + } + + # TODO(zongheng): look into the performance of the current + # implementation. Look into some iterator package? Note that + # Scala avoids many calls to creating an empty list and PySpark + # similarly achieves this using `yield'. (duplicated from sampleRDD) + if (len > 0) { + res[1:len] + } else { + list() + } + } + + lapplyPartitionsWithIndex(x, samplingFunc) + }) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R new file mode 100644 index 0000000000000..e442119086b17 --- /dev/null +++ b/R/pkg/R/schema.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField +# datatypes. These are used to create and interact with DataFrame schemas. + +#' structType +#' +#' Create a structType object that contains the metadata for a DataFrame. Intended for +#' use with createDataFrame and toDF. +#' +#' @param x a structField object (created with the field() function) +#' @param ... additional structField objects +#' @return a structType object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' schema <- structType(structField("a", "integer"), structField("b", "string")) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } +structType <- function(x, ...) { + UseMethod("structType", x) +} + +structType.jobj <- function(x) { + obj <- structure(list(), class = "structType") + obj$jobj <- x + obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) } + obj +} + +structType.structField <- function(x, ...) { + fields <- list(x, ...) + if (!all(sapply(fields, inherits, "structField"))) { + stop("All arguments must be structField objects.") + } + sfObjList <- lapply(fields, function(field) { + field$jobj + }) + stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructType", + listToSeq(sfObjList)) + structType(stObj) +} + +#' Print a Spark StructType. +#' +#' This function prints the contents of a StructType returned from the +#' SparkR JVM backend. +#' +#' @param x A StructType object +#' @param ... further arguments passed to or from other methods +print.structType <- function(x, ...) { + cat("StructType\n", + sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") }) + , sep = "") +} + +#' structField +#' +#' Create a structField object that contains the metadata for a single field in a schema. +#' +#' @param x The name of the field +#' @param type The data type of the field +#' @param nullable A logical vector indicating whether or not the field is nullable +#' @return a structField object +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlCtx <- sparkRSQL.init(sc) +#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) +#' field1 <- structField("a", "integer", TRUE) +#' field2 <- structField("b", "string", TRUE) +#' schema <- structType(field1, field2) +#' df <- createDataFrame(sqlCtx, rdd, schema) +#' } + +structField <- function(x, ...) { + UseMethod("structField", x) +} + +structField.jobj <- function(x) { + obj <- structure(list(), class = "structField") + obj$jobj <- x + obj$name <- function() { callJMethod(x, "name") } + obj$dataType <- function() { callJMethod(x, "dataType") } + obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") } + obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") } + obj$nullable <- function() { callJMethod(x, "nullable") } + obj +} + +structField.character <- function(x, type, nullable = TRUE) { + if (class(x) != "character") { + stop("Field name must be a string.") + } + if (class(type) != "character") { + stop("Field type must be a string.") + } + if (class(nullable) != "logical") { + stop("nullable must be either TRUE or FALSE") + } + options <- c("byte", + "integer", + "double", + "numeric", + "character", + "string", + "binary", + "raw", + "logical", + "boolean", + "timestamp", + "date") + dataType <- if (type %in% options) { + type + } else { + stop(paste("Unsupported type for Dataframe:", type)) + } + sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "createStructField", + x, + dataType, + nullable) + structField(sfObj) +} + +#' Print a Spark StructField. +#' +#' This function prints the contents of a StructField returned from the +#' SparkR JVM backend. +#' +#' @param x A StructField object +#' @param ... further arguments passed to or from other methods +print.structField <- function(x, ...) { + cat("StructField(name = \"", x$name(), + "\", type = \"", x$dataType.toString(), + "\", nullable = ", x$nullable(), + ")", + sep = "") +} diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 8a9c0c652ce24..c53d0a961016f 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -69,8 +69,9 @@ writeJobj <- function(con, value) { } writeString <- function(con, value) { - writeInt(con, as.integer(nchar(value) + 1)) - writeBin(value, con, endian = "big") + utfVal <- enc2utf8(value) + writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1)) + writeBin(utfVal, con, endian = "big") } writeInt <- function(con, value) { @@ -189,7 +190,3 @@ writeArgs <- function(con, args) { } } } - -writeStrings <- function(con, stringList) { - writeLines(unlist(stringList), con) -} diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c337fb0751e72..23305d3c67074 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -465,3 +465,83 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { } func } + +# Append partition lengths to each partition in two input RDDs if needed. +# param +# x An RDD. +# Other An RDD. +# return value +# A list of two result RDDs. +appendPartitionLengths <- function(x, other) { + if (getSerializedMode(x) != getSerializedMode(other) || + getSerializedMode(x) == "byte") { + # Append the number of elements in each partition to that partition so that we can later + # know the boundary of elements from x and other. + # + # Note that this appending also serves the purpose of reserialization, because even if + # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded + # as a single byte array. For example, partitions of an RDD generated from partitionBy() + # may be encoded as multiple byte arrays. + appendLength <- function(part) { + len <- length(part) + part[[len + 1]] <- len + 1 + part + } + x <- lapplyPartition(x, appendLength) + other <- lapplyPartition(other, appendLength) + } + list (x, other) +} + +# Perform zip or cartesian between elements from two RDDs in each partition +# param +# rdd An RDD. +# zip A boolean flag indicating this call is for zip operation or not. +# return value +# A result RDD. +mergePartitions <- function(rdd, zip) { + serializerMode <- getSerializedMode(rdd) + partitionFunc <- function(split, part) { + len <- length(part) + if (len > 0) { + if (serializerMode == "byte") { + lengthOfValues <- part[[len]] + lengthOfKeys <- part[[len - lengthOfValues]] + stopifnot(len == lengthOfKeys + lengthOfValues) + + # For zip operation, check if corresponding partitions of both RDDs have the same number of elements. + if (zip && lengthOfKeys != lengthOfValues) { + stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.") + } + + if (lengthOfKeys > 1) { + keys <- part[1 : (lengthOfKeys - 1)] + } else { + keys <- list() + } + if (lengthOfValues > 1) { + values <- part[(lengthOfKeys + 1) : (len - 1)] + } else { + values <- list() + } + + if (!zip) { + return(mergeCompactLists(keys, values)) + } + } else { + keys <- part[c(TRUE, FALSE)] + values <- part[c(FALSE, TRUE)] + } + mapply( + function(k, v) { list(k, v) }, + keys, + values, + SIMPLIFY = FALSE, + USE.NAMES = FALSE) + } else { + part + } + } + + PipelinedRDD(rdd, partitionFunc) +} diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index b76e4db03e715..3ba7d1716302a 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -35,7 +35,7 @@ test_that("get number of partitions in RDD", { test_that("first on RDD", { expect_true(first(rdd) == 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_true(first(newrdd) == 2) }) test_that("count and length on RDD", { @@ -48,7 +48,7 @@ test_that("count by values and keys", { actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) - + actual <- countByKey(intRdd) expected <- list(list(2L, 2L), list(1L, 2L)) expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) @@ -82,11 +82,11 @@ test_that("filterRDD on RDD", { filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collect(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) - + filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd) actual <- collect(filtered.rdd) expect_equal(actual, list(list(1L, -1))) - + # Filter out all elements. filtered.rdd <- filterRDD(rdd, function(x) { x > 10 }) actual <- collect(filtered.rdd) @@ -96,7 +96,7 @@ test_that("filterRDD on RDD", { test_that("lookup on RDD", { vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) - + vals <- lookup(intRdd, 3L) expect_equal(vals, list()) }) @@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) rdd2 <- lapply(rdd2, function(x) x + x) actual <- collect(rdd2) - expected <- list(24, 24, 24, 24, 24, + expected <- list(24, 24, 24, 24, 24, 168, 170, 172, 174, 176) expect_equal(actual, expected) }) @@ -248,10 +248,10 @@ test_that("flatMapValues() on pairwise RDDs", { l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4)))) actual <- collect(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) - + # Generate x to x+1 for every value actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) - expect_equal(actual, + expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) }) @@ -348,7 +348,7 @@ test_that("top() on RDDs", { rdd <- parallelize(sc, l) actual <- top(rdd, 6L) expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6]) - + l <- list("e", "d", "c", "d", "a") rdd <- parallelize(sc, l) actual <- top(rdd, 3L) @@ -358,7 +358,7 @@ test_that("top() on RDDs", { test_that("fold() on RDDs", { actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) - + rdd <- parallelize(sc, list()) actual <- fold(rdd, 0, "+") expect_equal(actual, 0) @@ -371,7 +371,7 @@ test_that("aggregateRDD() on RDDs", { combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(10, 4)) - + rdd <- parallelize(sc, list()) actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp) expect_equal(actual, list(0, 0)) @@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", { test_that("zipWithUniqueId() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 3), list("c", 1), + expected <- list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", { test_that("zipWithIndex() on RDDs", { rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) - + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) actual <- collect(zipWithIndex(rdd)) - expected <- list(list("a", 0), list("b", 1), list("c", 2), + expected <- list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) expect_equal(actual, expected) }) @@ -427,12 +427,12 @@ test_that("pipeRDD() on RDDs", { actual <- collect(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) - + trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n")) actual <- collect(pipeRDD(trailed.rdd, "sort")) expected <- list("", "1", "2", "3") expect_equal(actual, expected) - + rev.nums <- 9:0 rev.rdd <- parallelize(sc, rev.nums, 2L) actual <- collect(pipeRDD(rev.rdd, "sort")) @@ -446,11 +446,11 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd1, rdd2)) expect_equal(actual, list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))) - + mockFile = c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern="spark-test", fileext=".tmp") writeLines(mockFile, fileName) - + rdd <- textFile(sc, fileName, 1) actual <- collect(zipRDD(rdd, rdd)) expected <- lapply(mockFile, function(x) { list(x ,x) }) @@ -465,10 +465,125 @@ test_that("zipRDD() on RDDs", { actual <- collect(zipRDD(rdd, rdd1)) expected <- lapply(mockFile, function(x) { list(x, x) }) expect_equal(actual, expected) - + + unlink(fileName) +}) + +test_that("cartesian() on RDDs", { + rdd <- parallelize(sc, 1:3) + actual <- collect(cartesian(rdd, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(1, 1), list(1, 2), list(1, 3), + list(2, 1), list(2, 2), list(2, 3), + list(3, 1), list(3, 2), list(3, 3))) + + # test case where one RDD is empty + emptyRdd <- parallelize(sc, list()) + actual <- collect(cartesian(rdd, emptyRdd)) + expect_equal(actual, list()) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName) + actual <- collect(cartesian(rdd, rdd)) + expected <- list( + list("Spark is awesome.", "Spark is pretty."), + list("Spark is awesome.", "Spark is awesome."), + list("Spark is pretty.", "Spark is pretty."), + list("Spark is pretty.", "Spark is awesome.")) + expect_equal(sortKeyValueList(actual), expected) + + rdd1 <- parallelize(sc, 0:1) + actual <- collect(cartesian(rdd1, rdd)) + expect_equal(sortKeyValueList(actual), + list( + list(0, "Spark is pretty."), + list(0, "Spark is awesome."), + list(1, "Spark is pretty."), + list(1, "Spark is awesome."))) + + rdd1 <- map(rdd, function(x) { x }) + actual <- collect(cartesian(rdd, rdd1)) + expect_equal(sortKeyValueList(actual), expected) + unlink(fileName) }) +test_that("subtract() on RDDs", { + l <- list(1, 1, 2, 2, 3, 4) + rdd1 <- parallelize(sc, l) + + # subtract by itself + actual <- collect(subtract(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtract by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + l) + + rdd2 <- parallelize(sc, list(2, 4)) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="integer"))), + list(1, 1, 3)) + + l <- list("a", "a", "b", "b", "c", "d") + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list("b", "d")) + actual <- collect(subtract(rdd1, rdd2)) + expect_equal(as.list(sort(as.vector(actual, mode="character"))), + list("a", "a", "c")) +}) + +test_that("subtractByKey() on pairwise RDDs", { + l <- list(list("a", 1), list("b", 4), + list("b", 5), list("a", 2)) + rdd1 <- parallelize(sc, l) + + # subtractByKey by itself + actual <- collect(subtractByKey(rdd1, rdd1)) + expect_equal(actual, list()) + + # subtractByKey by an empty RDD + rdd2 <- parallelize(sc, list()) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(sortKeyValueList(actual), + sortKeyValueList(l)) + + rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list("b", 4), list("b", 5))) + + l <- list(list(1, 1), list(2, 4), + list(2, 5), list(1, 2)) + rdd1 <- parallelize(sc, l) + rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1))) + actual <- collect(subtractByKey(rdd1, rdd2)) + expect_equal(actual, + list(list(2, 4), list(2, 5))) +}) + +test_that("intersection() on RDDs", { + # intersection with self + actual <- collect(intersection(rdd, rdd)) + expect_equal(sort(as.integer(actual)), nums) + + # intersection with an empty RDD + emptyRdd <- parallelize(sc, list()) + actual <- collect(intersection(rdd, emptyRdd)) + expect_equal(actual, list()) + + rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) + rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8)) + actual <- collect(intersection(rdd1, rdd2)) + expect_equal(sort(as.integer(actual)), 1:3) +}) + test_that("join() on pairwise RDDs", { rdd1 <- parallelize(sc, list(list(1,1), list(2,4))) rdd2 <- parallelize(sc, list(list(1,2), list(1,3))) @@ -596,9 +711,9 @@ test_that("sortByKey() on pairwise RDDs", { sortedRdd3 <- sortByKey(rdd3) actual <- collect(sortedRdd3) expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4))) - + # test on the boundary cases - + # boundary case 1: the RDD to be sorted has only 1 partition rdd4 <- parallelize(sc, l, 1L) sortedRdd4 <- sortByKey(rdd4) @@ -623,7 +738,7 @@ test_that("sortByKey() on pairwise RDDs", { rdd7 <- parallelize(sc, l3, 2L) sortedRdd7 <- sortByKey(rdd7) actual <- collect(sortedRdd7) - expect_equal(actual, l3) + expect_equal(actual, l3) }) test_that("collectAsMap() on a pairwise RDD", { @@ -634,12 +749,36 @@ test_that("collectAsMap() on a pairwise RDD", { rdd <- parallelize(sc, list(list("a", 1), list("b", 2))) vals <- collectAsMap(rdd) expect_equal(vals, list(a = 1, b = 2)) - + rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4)) - + rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = "a", `2` = "b")) }) + +test_that("sampleByKey() on pairwise RDDs", { + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) + fractions <- list(a = 0.2, b = 0.1) + sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE) + expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE) + expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE) + expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE) + expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE) + expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE) + + rdd <- parallelize(sc, 1:2000) + pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) }) + fractions <- list(`2` = 0.2, `3` = 0.1) + sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L) + expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE) + expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE) + expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE) + expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE) + expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) + expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) +}) diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R index d1da8232aea81..d7dedda553c56 100644 --- a/R/pkg/inst/tests/test_shuffle.R +++ b/R/pkg/inst/tests/test_shuffle.R @@ -87,6 +87,18 @@ test_that("combineByKey for doubles", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) +test_that("combineByKey for characters", { + stringKeyRDD <- parallelize(sc, + list(list("max", 1L), list("min", 2L), + list("other", 3L), list("max", 4L)), 2L) + reduced <- combineByKey(stringKeyRDD, + function(x) { x }, "+", "+", 2L) + actual <- collect(reduced) + + expected <- list(list("max", 5L), list("min", 2L), list("other", 3L)) + expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) +}) + test_that("aggregateByKey", { # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index cf5cf6d1692af..25831ae2d9e18 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -44,9 +44,8 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(a = 1L, b = "2")), - list(type = "struct", - fields = list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)))) + structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE))) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), @@ -54,6 +53,18 @@ test_that("infer types", { valueContainsNull = TRUE)) }) +test_that("structType and structField", { + testField <- structField("a", "string") + expect_true(inherits(testField, "structField")) + expect_true(testField$name() == "a") + expect_true(testField$nullable()) + + testSchema <- structType(testField, structField("b", "integer")) + expect_true(inherits(testSchema, "structType")) + expect_true(inherits(testSchema$fields()[[2]], "structField")) + expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") +}) + test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlCtx, rdd, list("a", "b")) @@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlCtx, rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -94,9 +104,8 @@ test_that("toDF", { expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("_1", "_2")) - fields <- list(list(name = "a", type = "integer", nullable = TRUE), - list(name = "b", type = "string", nullable = TRUE)) - schema <- list(type = "struct", fields = fields) + schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), + structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) expect_true(inherits(df, "DataFrame")) expect_equal(columns(df), c("a", "b")) @@ -635,7 +644,7 @@ test_that("isLocal()", { expect_false(isLocal(df)) }) -test_that("unionAll(), subtract(), and intersect() on a DataFrame", { +test_that("unionAll(), except(), and intersect() on a DataFrame", { df <- jsonFile(sqlCtx, jsonPath) lines <- c("{\"name\":\"Bob\", \"age\":24}", @@ -650,10 +659,10 @@ test_that("unionAll(), subtract(), and intersect() on a DataFrame", { expect_true(count(unioned) == 6) expect_true(first(unioned)$name == "Michael") - subtracted <- sortDF(subtract(df, df2), desc(df$age)) + excepted <- sortDF(except(df, df2), desc(df$age)) expect_true(inherits(unioned, "DataFrame")) - expect_true(count(subtracted) == 2) - expect_true(first(subtracted)$name == "Justin") + expect_true(count(excepted) == 2) + expect_true(first(excepted)$name == "Justin") intersected <- sortDF(intersect(df, df2), df$age) expect_true(inherits(unioned, "DataFrame")) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index c6542928e8ddd..014bf7bd7b3fe 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -17,6 +17,23 @@ # Worker class +# Get current system time +currentTimeSecs <- function() { + as.numeric(Sys.time()) +} + +# Get elapsed time +elapsedSecs <- function() { + proc.time()[3] +} + +# Constants +specialLengths <- list(END_OF_STERAM = 0L, TIMING_DATA = -1L) + +# Timing R process boot +bootTime <- currentTimeSecs() +bootElap <- elapsedSecs() + rLibDir <- Sys.getenv("SPARKR_RLIBDIR") # Set libPaths to include SparkR package as loadNamespace needs this # TODO: Figure out if we can avoid this by not loading any objects that require @@ -37,7 +54,7 @@ serializer <- SparkR:::readString(inputCon) # Include packages as required packageNames <- unserialize(SparkR:::readRaw(inputCon)) for (pkg in packageNames) { - suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE)) + suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE)) } # read function dependencies @@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen)) env <- environment(computeFunc) parent.env(env) <- .GlobalEnv # Attach under global environment. +# Timing init envs for computing +initElap <- elapsedSecs() + # Read and set broadcast variables numBroadcastVars <- SparkR:::readInt(inputCon) if (numBroadcastVars > 0) { @@ -56,6 +76,9 @@ if (numBroadcastVars > 0) { } } +# Timing broadcast +broadcastElap <- elapsedSecs() + # If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int # as number of partitions to create. numPartitions <- SparkR:::readInt(inputCon) @@ -73,14 +96,23 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() + output <- computeFunc(partition, data) + # Timing computing + computeElap <- elapsedSecs() + if (serializer == "byte") { SparkR:::writeRawSerialize(outputCon, output) } else if (serializer == "row") { SparkR:::writeRowSerialize(outputCon, output) } else { - SparkR:::writeStrings(outputCon, output) + # write lines one-by-one with flag + lapply(output, function(line) SparkR:::writeString(outputCon, line)) } + # Timing output + outputElap <- elapsedSecs() } else { if (deserializer == "byte") { # Now read as many characters as described in funcLen @@ -90,6 +122,8 @@ if (isEmpty != 0) { } else if (deserializer == "row") { data <- SparkR:::readDeserializeRows(inputCon) } + # Timing reading input data for execution + inputElap <- elapsedSecs() res <- new.env() @@ -107,6 +141,8 @@ if (isEmpty != 0) { res[[bucket]] <- acc } invisible(lapply(data, hashTupleToEnvir)) + # Timing computing + computeElap <- elapsedSecs() # Step 2: write out all of the environment as key-value pairs. for (name in ls(res)) { @@ -116,13 +152,26 @@ if (isEmpty != 0) { length(res[[name]]$data) <- res[[name]]$counter SparkR:::writeRawSerialize(outputCon, res[[name]]$data) } + # Timing output + outputElap <- elapsedSecs() } +} else { + inputElap <- broadcastElap + computeElap <- broadcastElap + outputElap <- broadcastElap } +# Report timing +SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA) +SparkR:::writeDouble(outputCon, bootTime) +SparkR:::writeDouble(outputCon, initElap - bootElap) # init +SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast +SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input +SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute +SparkR:::writeDouble(outputCon, outputElap - computeElap) # output + # End of output -if (serializer %in% c("byte", "row")) { - SparkR:::writeInt(outputCon, 0L) -} +SparkR:::writeInt(outputCon, specialLengths$END_OF_STERAM) close(outputCon) close(inputCon) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 5fa4d483b8342..6fea5e1144f2f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( rLibDir: String, broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + // The parent may be also an RRDD, so we should launch it first. val parentIterator = firstParent[T].iterator(partition, context) @@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( // the socket used to receive the output of task val outSocket = serverSocket.accept() val inputStream = new BufferedInputStream(outSocket.getInputStream) - val dataStream = openDataStream(inputStream) + dataStream = new DataInputStream(inputStream) serverSocket.close() try { @@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( } else if (deserializer == SerializationFormats.ROW) { dataOut.write(elem.asInstanceOf[Array[Byte]]) } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) printOut.println(elem) } } @@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( }.start() } - protected def openDataStream(input: InputStream): Closeable + protected def readData(length: Int): U - protected def read(): U + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } } /** @@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag]( SerializationFormats.BYTE, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): (Int, Array[Byte]) = { - try { - val length = dataStream.readInt() - - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null // End of input - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } - } + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } } lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) @@ -247,28 +270,13 @@ private class RRDD[T: ClassTag]( parent, -1, func, deserializer, serializer, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: DataInputStream = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new DataInputStream(input) - dataStream - } - - override protected def read(): Array[Byte] = { - try { - val length = dataStream.readInt() - - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj, 0, length) - obj - case _ => null - } - } catch { - case eof: EOFException => { - throw new SparkException("R worker exited unexpectedly (crashed)", eof) - } + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null } } @@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - private var dataStream: BufferedReader = _ - - override protected def openDataStream(input: InputStream): Closeable = { - dataStream = new BufferedReader(new InputStreamReader(input)) - dataStream - } - - override protected def read(): String = { - try { - dataStream.readLine() - } catch { - case e: IOException => { - throw new SparkException("R worker exited unexpectedly (crashed)", e) - } + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null } } lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } +private object SpecialLengths { + val TIMING_DATA = -1 +} + private[r] class BufferedStreamThread( in: InputStream, name: String, diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index ccb2a371f4e48..371dfe454d1a2 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -85,13 +85,17 @@ private[spark] object SerDe { in.readDouble() } + def readStringBytes(in: DataInputStream, len: Int): String = { + val bytes = new Array[Byte](len) + in.readFully(bytes) + assert(bytes(len - 1) == 0) + val str = new String(bytes.dropRight(1), "UTF-8") + str + } + def readString(in: DataInputStream): String = { val len = in.readInt() - val asciiBytes = new Array[Byte](len) - in.readFully(asciiBytes) - assert(asciiBytes(len - 1) == 0) - val str = new String(asciiBytes.dropRight(1).map(_.toChar)) - str + readStringBytes(in, len) } def readBoolean(in: DataInputStream): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index d1ea7cc3e9162..ae77f72998a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} private[r] object SQLUtils { @@ -39,8 +39,34 @@ private[r] object SQLUtils { arr.toSeq } - def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = { - val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] + def createStructType(fields : Seq[StructField]): StructType = { + StructType(fields) + } + + def getSQLDataType(dataType: String): DataType = { + dataType match { + case "byte" => org.apache.spark.sql.types.ByteType + case "integer" => org.apache.spark.sql.types.IntegerType + case "double" => org.apache.spark.sql.types.DoubleType + case "numeric" => org.apache.spark.sql.types.DoubleType + case "character" => org.apache.spark.sql.types.StringType + case "string" => org.apache.spark.sql.types.StringType + case "binary" => org.apache.spark.sql.types.BinaryType + case "raw" => org.apache.spark.sql.types.BinaryType + case "logical" => org.apache.spark.sql.types.BooleanType + case "boolean" => org.apache.spark.sql.types.BooleanType + case "timestamp" => org.apache.spark.sql.types.TimestampType + case "date" => org.apache.spark.sql.types.DateType + case _ => throw new IllegalArgumentException(s"Invaid type $dataType") + } + } + + def createStructField(name: String, dataType: String, nullable: Boolean): StructField = { + val dtObj = getSQLDataType(dataType) + StructField(name, dtObj, nullable) + } + + def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = { val num = schema.fields.size val rowRDD = rdd.map(bytesToRow) sqlContext.createDataFrame(rowRDD, schema) From d305e686b3d73213784bd75cdad7d168b22a1dc4 Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Fri, 17 Apr 2015 16:23:10 -0500 Subject: [PATCH 09/33] SPARK-6988 : Fix documentation regarding DataFrames using the Java API This patch includes : * adding how to use map after an sql query using javaRDD * fixing the first few java examples that were written in Scala Thank you for your time, Olivier. Author: Olivier Girardot Closes #5564 from ogirardot/branch-1.3 and squashes the following commits: 9f8d60e [Olivier Girardot] SPARK-6988 : Fix documentation regarding DataFrames using the Java API (cherry picked from commit 6b528dc139da594ef2e651d84bd91fe0f738a39d) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 03500867df70f..d49233714a0bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -193,8 +193,8 @@ df.groupBy("age").count().show()
{% highlight java %} -val sc: JavaSparkContext // An existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) +JavaSparkContext sc // An existing SparkContext. +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc) // Create the DataFrame DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json"); @@ -308,8 +308,8 @@ val df = sqlContext.sql("SELECT * FROM table")
{% highlight java %} -val sqlContext = ... // An existing SQLContext -val df = sqlContext.sql("SELECT * FROM table") +SQLContext sqlContext = ... // An existing SQLContext +DataFrame df = sqlContext.sql("SELECT * FROM table") {% endhighlight %}
@@ -435,7 +435,7 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -599,7 +599,7 @@ DataFrame results = sqlContext.sql("SELECT name FROM people"); // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. -List names = results.map(new Function() { +List names = results.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } @@ -860,7 +860,7 @@ DataFrame parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); -List teenagerNames = teenagers.map(new Function() { +List teenagerNames = teenagers.javaRDD().map(new Function() { public String call(Row row) { return "Name: " + row.getString(0); } From a452c59210cf2c8ff8601cdb11401eea6dc14973 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 17 Apr 2015 16:30:13 -0500 Subject: [PATCH 10/33] Minor fix to SPARK-6958: Improve Python docstring for DataFrame.sort. As a follow up PR to #5544. cc davies Author: Reynold Xin Closes #5558 from rxin/sort-doc-improvement and squashes the following commits: f4c276f [Reynold Xin] Review feedback. d2dcf24 [Reynold Xin] Minor fix to SPARK-6958: Improve Python docstring for DataFrame.sort. --- python/pyspark/sql/dataframe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 326d22e72f104..d70c5b0a6930c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -489,8 +489,9 @@ def sort(self, *cols, **kwargs): """Returns a new :class:`DataFrame` sorted by the specified column(s). :param cols: list of :class:`Column` or column names to sort by. - :param ascending: sort by ascending order or not, could be bool, int - or list of bool, int (default: True). + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] @@ -519,7 +520,7 @@ def sort(self, *cols, **kwargs): jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)] else: - raise TypeError("ascending can only be bool or list, but got %s" % type(ascending)) + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) jdf = self._jdf.sort(self._jseq(jcols)) return DataFrame(jdf, self.sql_ctx) From c5ed510135aee3a1a0402057b3b5229892aa6f3a Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Fri, 17 Apr 2015 18:28:42 -0700 Subject: [PATCH 11/33] [SPARK-6703][Core] Provide a way to discover existing SparkContext's I've added a static getOrCreate method to the static SparkContext object that allows one to either retrieve a previously created SparkContext or to instantiate a new one with the provided config. The method accepts an optional SparkConf to make usage intuitive. Still working on a test for this, basically want to create a new context from scratch, then ensure that subsequent calls don't overwrite that. Author: Ilya Ganelin Closes #5501 from ilganeli/SPARK-6703 and squashes the following commits: db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkContext.scala | 49 ++++++++++++++++--- .../org/apache/spark/SparkContextSuite.scala | 20 ++++++++ project/MimaExcludes.scala | 4 ++ 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e106c5c4bef60..86269eac52db0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -23,7 +23,7 @@ import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -1887,11 +1887,12 @@ object SparkContext extends Logging { private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() /** - * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`. * - * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private var activeContext: Option[SparkContext] = None + private val activeContext: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null) /** * Points to a partially-constructed SparkContext if some thread is in the SparkContext @@ -1926,7 +1927,8 @@ object SparkContext extends Logging { logWarning(warnMsg) } - activeContext.foreach { ctx => + if (activeContext.get() != null) { + val ctx = activeContext.get() val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" @@ -1941,6 +1943,39 @@ object SparkContext extends Logging { } } + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(config: SparkConf): SparkContext = { + // Synchronize to ensure that multiple create requests don't trigger an exception + // from assertNoOtherContextIsRunning within setActiveContext + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + if (activeContext.get() == null) { + setActiveContext(new SparkContext(config), allowMultipleContexts = false) + } + activeContext.get() + } + } + + /** + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. + * + * This method allows not passing a SparkConf (useful if just retrieving). + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. + */ + def getOrCreate(): SparkContext = { + getOrCreate(new SparkConf()) + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -1967,7 +2002,7 @@ object SparkContext extends Logging { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { assertNoOtherContextIsRunning(sc, allowMultipleContexts) contextBeingConstructed = None - activeContext = Some(sc) + activeContext.set(sc) } } @@ -1978,7 +2013,7 @@ object SparkContext extends Logging { */ private[spark] def clearActiveContext(): Unit = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { - activeContext = None + activeContext.set(null) } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 94be1c6d6397c..728558a424780 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -67,6 +67,26 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { } } + test("Test getOrCreate") { + var sc2: SparkContext = null + SparkContext.clearActiveContext() + val conf = new SparkConf().setAppName("test").setMaster("local") + + sc = SparkContext.getOrCreate(conf) + + assert(sc.getConf.get("spark.app.name").equals("test")) + sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local")) + assert(sc2.getConf.get("spark.app.name").equals("test")) + assert(sc === sc2) + assert(sc eq sc2) + + // Try creating second context to confirm that it's still possible, if desired + sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local") + .set("spark.driver.allowMultipleContexts", "true")) + + sc2.stop() + } + test("BytesWritable implicit conversion is correct") { // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1564babefa62f..7ef363a2f07ad 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -68,6 +68,10 @@ object MimaExcludes { // SPARK-6693 add tostring with max lines and width for matrix ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrix.toString") + )++ Seq( + // SPARK-6703 Add getOrCreate method to SparkContext + ProblemFilters.exclude[IncompatibleResultTypeProblem] + ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") ) case v if v.startsWith("1.3") => From 6fbeb82e13db7117d8f216e6148632490a4bc5be Mon Sep 17 00:00:00 2001 From: Jongyoul Lee Date: Fri, 17 Apr 2015 18:30:55 -0700 Subject: [PATCH 12/33] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Defined executorCores from "spark.mesos.executor.cores" - Changed the amount of mesosExecutor's cores to executorCores. - Added new configuration option on running-on-mesos.md Author: Jongyoul Lee Closes #5063 from jongyoul/SPARK-6350 and squashes the following commits: 9238d6e [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs - Changed configuration name - Made mesosExecutorCores private 2d41241 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 89edb4f [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 8ba7694 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Fixed docs 7549314 [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Fixed docs 4ae7b0c [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Removed TODO c27efce [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Fixed Mesos*Suite for supporting integer WorkerOffers - Fixed Documentation 1fe4c03 [Jongyoul Lee] [SPARK-6453][Mesos] Some Mesos*Suite have a different package with their classes - Change available resources of cpus to integer value beacuse WorkerOffer support the amount cpus as integer value 5f3767e [Jongyoul Lee] Revert "[SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode" 4b7c69e [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Changed configruation name and description from "spark.mesos.executor.cores" to "spark.executor.frameworkCores" 0556792 [Jongyoul Lee] [SPARK-6350][Mesos] Make mesosExecutorCores configurable in mesos "fine-grained" mode - Defined executorCores from "spark.mesos.executor.cores" - Changed the amount of mesosExecutor's cores to executorCores. - Added new configuration option on running-on-mesos.md --- .../cluster/mesos/MesosSchedulerBackend.scala | 14 +++++++------- .../cluster/mesos/MesosSchedulerBackendSuite.scala | 4 ++-- docs/running-on-mesos.md | 10 ++++++++++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index b381436839227..d9d62b0e287ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -67,6 +67,8 @@ private[spark] class MesosSchedulerBackend( // The listener bus to publish executor added/removed events. val listenerBus = sc.listenerBus + + private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1) @volatile var appId: String = _ @@ -139,7 +141,7 @@ private[spark] class MesosSchedulerBackend( .setName("cpus") .setType(Value.Type.SCALAR) .setScalar(Value.Scalar.newBuilder() - .setValue(scheduler.CPUS_PER_TASK).build()) + .setValue(mesosExecutorCores).build()) .build() val memory = Resource.newBuilder() .setName("mem") @@ -220,10 +222,9 @@ private[spark] class MesosSchedulerBackend( val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue - // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? (mem >= MemoryUtils.calculateTotalMemory(sc) && // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || + cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) || (slaveIdsWithExecutors.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) } @@ -232,10 +233,9 @@ private[spark] class MesosSchedulerBackend( val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { getResource(o.getResourcesList, "cpus").toInt } else { - // If the executor doesn't exist yet, subtract CPU for executor - // TODO(pwendell): Should below just subtract "1"? - getResource(o.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK + // If the Mesos executor has not been started on this slave yet, set aside a few + // cores for the Mesos executor by offering fewer cores to the Spark executor + (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt } new WorkerOffer( o.getSlaveId.getValue, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index a311512e82c5e..cdd7be0fbe5dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -118,12 +118,12 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(0).getSlaveId.getValue, mesosOffers.get(0).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) expectedWorkerOffers.append(new WorkerOffer( mesosOffers.get(2).getSlaveId.getValue, mesosOffers.get(2).getHostname, - 2 + (minCpu - backend.mesosExecutorCores).toInt )) val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index c984639bd34cf..594bf78b67713 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -210,6 +210,16 @@ See the [configuration page](configuration.html) for information on Spark config Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting. + + spark.mesos.mesosExecutor.cores + 1.0 + + (Fine-grained mode only) Number of cores to give each Mesos executor. This does not + include the cores used to run the Spark tasks. In other words, even if no Spark task + is being run, each Mesos executor will occupy the number of cores configured here. + The value can be a floating point number. + + spark.mesos.executor.home driver side SPARK_HOME From 1991337336596f94698e79c2366f065c374128ab Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 17 Apr 2015 19:02:07 -0700 Subject: [PATCH 13/33] [SPARK-5933] [core] Move config deprecation warnings to SparkConf. I didn't find many deprecated configs after a grep-based search, but the ones I could find were moved to the centralized location in SparkConf. While there, I deprecated a couple more HS configs that mentioned time units. Author: Marcelo Vanzin Closes #5562 from vanzin/SPARK-5933 and squashes the following commits: dcb617e7 [Marcelo Vanzin] [SPARK-5933] [core] Move config deprecation warnings to SparkConf. --- .../main/scala/org/apache/spark/SparkConf.scala | 17 ++++++++++++++--- .../main/scala/org/apache/spark/SparkEnv.scala | 10 ++-------- .../deploy/history/FsHistoryProvider.scala | 15 +++------------ .../scala/org/apache/spark/SparkConfSuite.scala | 3 +++ docs/monitoring.md | 15 +++++++-------- .../spark/deploy/yarn/ApplicationMaster.scala | 9 +-------- 6 files changed, 30 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index b0186e9a007b8..e3a649d755450 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -403,6 +403,9 @@ private[spark] object SparkConf extends Logging { */ private val deprecatedConfigs: Map[String, DeprecatedConfig] = { val configs = Seq( + DeprecatedConfig("spark.cache.class", "0.8", + "The spark.cache.class property is no longer being used! Specify storage levels using " + + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", "Please use spark.{driver,executor}.userClassPathFirst instead.")) Map(configs.map { cfg => (cfg.key -> cfg) }:_*) @@ -420,7 +423,15 @@ private[spark] object SparkConf extends Logging { "spark.history.fs.update.interval" -> Seq( AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"), AlternateConfig("spark.history.fs.updateInterval", "1.3"), - AlternateConfig("spark.history.updateInterval", "1.3")) + AlternateConfig("spark.history.updateInterval", "1.3")), + "spark.history.fs.cleaner.interval" -> Seq( + AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")), + "spark.history.fs.cleaner.maxAge" -> Seq( + AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")), + "spark.yarn.am.waitTime" -> Seq( + AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", + // Translate old value to a duration, with 10s wait time per try. + translation = s => s"${s.toLong * 10}s")) ) /** @@ -470,7 +481,7 @@ private[spark] object SparkConf extends Logging { configsWithAlternatives.get(key).flatMap { alts => alts.collectFirst { case alt if conf.contains(alt.key) => val value = conf.get(alt.key) - alt.translation.map(_(value)).getOrElse(value) + if (alt.translation != null) alt.translation(value) else value } } } @@ -514,6 +525,6 @@ private[spark] object SparkConf extends Logging { private case class AlternateConfig( key: String, version: String, - translation: Option[String => String] = None) + translation: String => String = null) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 0171488e09562..959aefabd8de4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -103,7 +103,7 @@ class SparkEnv ( // actorSystem.awaitTermination() // Note that blockTransferService is stopped by BlockManager since it is started by it. - + // If we only stop sc, but the driver process still run as a services then we need to delete // the tmp dir, if not, it will create too many tmp dirs. // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the @@ -375,12 +375,6 @@ object SparkEnv extends Logging { "." } - // Warn about deprecated spark.cache.class property - if (conf.contains("spark.cache.class")) { - logWarning("The spark.cache.class property is no longer being used! Specify storage " + - "levels using the RDD.persist() method instead.") - } - val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } @@ -406,7 +400,7 @@ object SparkEnv extends Logging { shuffleMemoryManager, outputCommitCoordinator, conf) - + // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is // called, and we only need to do it for driver. Because driver may run as a service, and if we // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 985545742df67..47bdd7749ec3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -52,8 +52,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s") // Interval between each cleaner checks for event logs to delete - private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds", - DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S) * 1000 + private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d") private val logDir = conf.getOption("spark.history.fs.logDirectory") .map { d => Utils.resolveURI(d).toString } @@ -130,8 +129,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) { // A task that periodically cleans event logs on disk. - pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_MS, - TimeUnit.MILLISECONDS) + pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS) } } } @@ -270,8 +268,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) .getOrElse(Seq[FileStatus]()) - val maxAge = conf.getLong("spark.history.fs.cleaner.maxAge.seconds", - DEFAULT_SPARK_HISTORY_FS_MAXAGE_S) * 1000 + val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 val now = System.currentTimeMillis() val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() @@ -417,12 +414,6 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" - - // One day - val DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S = Duration(1, TimeUnit.DAYS).toSeconds - - // One week - val DEFAULT_SPARK_HISTORY_FS_MAXAGE_S = Duration(7, TimeUnit.DAYS).toSeconds } private class FsApplicationHistoryInfo( diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 7d87ba5fd2610..8e6c200c4ba00 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -217,6 +217,9 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size assert(count === 4) + + conf.set("spark.yarn.applicationMaster.waitTries", "42") + assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) } } diff --git a/docs/monitoring.md b/docs/monitoring.md index 2a130224591ca..8a85928d6d44d 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -153,19 +153,18 @@ follows: - spark.history.fs.cleaner.interval.seconds - 86400 + spark.history.fs.cleaner.interval + 1d - How often the job history cleaner checks for files to delete, in seconds. Defaults to 86400 (one day). - Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.seconds. + How often the job history cleaner checks for files to delete. + Files are only deleted if they are older than spark.history.fs.cleaner.maxAge. - spark.history.fs.cleaner.maxAge.seconds - 3600 * 24 * 7 + spark.history.fs.cleaner.maxAge + 7d - Job history files older than this many seconds will be deleted when the history cleaner runs. - Defaults to 3600 * 24 * 7 (1 week). + Job history files older than this will be deleted when the history cleaner runs. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c357b7ae9d4da..f7a84207e9da6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -373,14 +373,7 @@ private[spark] class ApplicationMaster( private def waitForSparkContextInitialized(): SparkContext = { logInfo("Waiting for spark context initialization") sparkContextRef.synchronized { - val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries") - .map(_.toLong * 10000L) - if (waitTries.isDefined) { - logWarning( - "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime") - } - val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", - s"${waitTries.getOrElse(100000L)}ms") + val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s") val deadline = System.currentTimeMillis() + totalWaitTime while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) { From d850b4bd3a294dd245881e03f7f94bf970a7ee79 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 17 Apr 2015 19:17:06 -0700 Subject: [PATCH 14/33] [SPARK-6975][Yarn] Fix argument validation error `numExecutors` checking is failed when dynamic allocation is enabled with default configuration. Details can be seen is [SPARK-6975](https://issues.apache.org/jira/browse/SPARK-6975). sryza, please help me to review this, not sure is this the correct way, I think previous you change this part :) Author: jerryshao Closes #5551 from jerryshao/SPARK-6975 and squashes the following commits: 4335da1 [jerryshao] Change according to the comments 77bdcbd [jerryshao] Fix argument validation error --- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index da6798cb1b279..1423533470fc0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -103,9 +103,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. */ private def validateArgs(): Unit = { - if (numExecutors <= 0) { + if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { throw new IllegalArgumentException( - "You must specify at least 1 executor!\n" + getUsageMessage()) + s""" + |Number of executors was $numExecutors, but must be at least 1 + |(or 0 if dynamic executor allocation is enabled). + |${getUsageMessage()} + """.stripMargin) } if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) { throw new SparkException("Executor cores must not be less than " + From 5f095d56054d57c54d81db1d36cd46312810fb6a Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Sat, 18 Apr 2015 00:31:01 -0700 Subject: [PATCH 15/33] SPARK-6992 : Fix documentation example for Spark SQL on StructType This patch is fixing the Java examples for Spark SQL when defining programmatically a Schema and mapping Rows. Author: Olivier Girardot Closes #5569 from ogirardot/branch-1.3 and squashes the following commits: c29e58d [Olivier Girardot] SPARK-6992 : Fix documentation example for Spark SQL on StructType (cherry picked from commit c9b1ba4b16a7afe93d45bf75b128cc0dd287ded0) Signed-off-by: Reynold Xin --- docs/sql-programming-guide.md | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d49233714a0bb..b2022546268a7 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -555,13 +555,16 @@ by `SQLContext`. For example: {% highlight java %} -// Import factory methods provided by DataType. -import org.apache.spark.sql.types.DataType; +import org.apache.spark.api.java.function.Function; +// Import factory methods provided by DataTypes. +import org.apache.spark.sql.types.DataTypes; // Import StructType and StructField import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructField; // Import Row. import org.apache.spark.sql.Row; +// Import RowFactory. +import org.apache.spark.sql.RowFactory; // sc is an existing JavaSparkContext. SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); @@ -575,16 +578,16 @@ String schemaString = "name age"; // Generate the schema based on the string of schema List fields = new ArrayList(); for (String fieldName: schemaString.split(" ")) { - fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); + fields.add(DataTypes.createStructField(fieldName, DataTypes.StringType, true)); } -StructType schema = DataType.createStructType(fields); +StructType schema = DataTypes.createStructType(fields); // Convert records of the RDD (people) to Rows. JavaRDD rowRDD = people.map( new Function() { public Row call(String record) throws Exception { String[] fields = record.split(","); - return Row.create(fields[0], fields[1].trim()); + return RowFactory.create(fields[0], fields[1].trim()); } }); From 327ebf0cb5e236579bece057eda27b21aed0e2dc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 18 Apr 2015 10:14:56 +0100 Subject: [PATCH 16/33] [core] [minor] Make sure ConnectionManager stops. My previous fix (force a selector wakeup) didn't seem to work since I ran into the hang again. So change the code a bit to be more explicit about the condition when the selector thread should exit. Author: Marcelo Vanzin Closes #5566 from vanzin/conn-mgr-hang and squashes the following commits: ddb2c03 [Marcelo Vanzin] [core] [minor] Make sure ConnectionManager stops. --- .../spark/network/nio/ConnectionManager.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 5a74c13b38bf7..1a68e621eaee7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -188,6 +188,7 @@ private[nio] class ConnectionManager( private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + @volatile private var isActive = true private val selectorThread = new Thread("connection-manager-thread") { override def run(): Unit = ConnectionManager.this.run() } @@ -342,7 +343,7 @@ private[nio] class ConnectionManager( def run() { try { - while(!selectorThread.isInterrupted) { + while (isActive) { while (!registerRequests.isEmpty) { val conn: SendingConnection = registerRequests.dequeue() addListeners(conn) @@ -398,7 +399,7 @@ private[nio] class ConnectionManager( } catch { // Explicitly only dealing with CancelledKeyException here since other exceptions // should be dealt with differently. - case e: CancelledKeyException => { + case e: CancelledKeyException => // Some keys within the selectors list are invalid/closed. clear them. val allKeys = selector.keys().iterator() @@ -420,8 +421,11 @@ private[nio] class ConnectionManager( } } } - } - 0 + 0 + + case e: ClosedSelectorException => + logDebug("Failed select() as selector is closed.", e) + return } if (selectedKeysCount == 0) { @@ -988,11 +992,11 @@ private[nio] class ConnectionManager( } def stop() { + isActive = false ackTimeoutMonitor.stop() - selector.wakeup() + selector.close() selectorThread.interrupt() selectorThread.join() - selector.close() val connections = connectionsByKey.values connections.foreach(_.close()) if (connectionsByKey.size != 0) { From 28683b4df5de06373b867068b9b8adfbcaf93176 Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sat, 18 Apr 2015 16:46:28 -0700 Subject: [PATCH 17/33] [SPARK-6219] Reuse pep8.py Per the discussion in the comments on [this commit](https://github.com/apache/spark/commit/f17d43b033d928dbc46aef8e367aa08902e698ad#commitcomment-10780649), this PR allows the Python lint script to reuse `pep8.py` when possible. Author: Nicholas Chammas Closes #5561 from nchammas/save-dem-pep8-bytes and squashes the following commits: b7c91e6 [Nicholas Chammas] reuse pep8.py --- dev/.gitignore | 1 + dev/lint-python | 21 +++++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) create mode 100644 dev/.gitignore diff --git a/dev/.gitignore b/dev/.gitignore new file mode 100644 index 0000000000000..4a6027429e0d3 --- /dev/null +++ b/dev/.gitignore @@ -0,0 +1 @@ +pep8*.py diff --git a/dev/lint-python b/dev/lint-python index fded654893a7c..f50d149dc4d44 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -32,18 +32,19 @@ compile_status="${PIPESTATUS[0]}" #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 #+ TODOs: #+ - Download pep8 from PyPI. It's more "official". -PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" -PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.6.2/pep8.py" +PEP8_VERSION="1.6.2" +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8-$PEP8_VERSION.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/$PEP8_VERSION/pep8.py" -# if [ ! -e "$PEP8_SCRIPT_PATH" ]; then -curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" -curl_status="$?" +if [ ! -e "$PEP8_SCRIPT_PATH" ]; then + curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" + curl_status="$?" -if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." - exit "$curl_status" + if [ "$curl_status" -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit "$curl_status" + fi fi -# fi # There is no need to write this output to a file #+ first, but we do so so that the check status can @@ -65,7 +66,7 @@ else echo "Python lint checks passed." fi -rm "$PEP8_SCRIPT_PATH" +# rm "$PEP8_SCRIPT_PATH" rm "$PYTHON_LINT_REPORT_PATH" exit "$lint_status" From 729885ec6b4be61144d04821f1a6e8d2134eea00 Mon Sep 17 00:00:00 2001 From: Gaurav Nanda Date: Sat, 18 Apr 2015 17:20:46 -0700 Subject: [PATCH 18/33] Fixed doc Just fixed a doc. Author: Gaurav Nanda Closes #5576 from gaurav324/master and squashes the following commits: 8a7323f [Gaurav Nanda] Fixed doc --- docs/mllib-linear-methods.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 9270741d439d9..2b2be4d9d0273 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -377,7 +377,7 @@ references. Here is an [detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297). -For multiclass classification problems, the algorithm will outputs a multinomial logistic regression +For multiclass classification problems, the algorithm will output a multinomial logistic regression model, which contains $K - 1$ binary logistic regression models regressed against the first class. Given a new data points, $K - 1$ models will be run, and the class with largest probability will be chosen as the predicted class. From 8fbd45c74e762dd6b071ea58a60f5bb649f74042 Mon Sep 17 00:00:00 2001 From: Olivier Girardot Date: Sat, 18 Apr 2015 18:21:44 -0700 Subject: [PATCH 19/33] SPARK-6993 : Add default min, max methods for JavaDoubleRDD The default method will use Guava's Ordering instead of java.util.Comparator.naturalOrder() because it's not available in Java 7, only in Java 8. Author: Olivier Girardot Closes #5571 from ogirardot/master and squashes the following commits: 7fe2e9e [Olivier Girardot] SPARK-6993 : Add default min, max methods for JavaDoubleRDD --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 14 ++++++++++++++ .../test/java/org/apache/spark/JavaAPISuite.java | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 79e4ebf2db578..61af867b11b9c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -163,6 +163,20 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) /** Add up the elements in this RDD. */ def sum(): JDouble = srdd.sum() + /** + * Returns the minimum element from this RDD as defined by + * the default comparator natural order. + * @return the minimum of the RDD + */ + def min(): JDouble = min(com.google.common.collect.Ordering.natural()) + + /** + * Returns the maximum element from this RDD as defined by + * the default comparator natural order. + * @return the maximum of the RDD + */ + def max(): JDouble = max(com.google.common.collect.Ordering.natural()) + /** * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and * count of the RDD's elements in one operation. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index d4b5bb519157c..8a4f2a08fe701 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -761,6 +761,20 @@ public void min() { Assert.assertEquals(1.0, max, 0.001); } + @Test + public void naturalMax() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.max(); + Assert.assertTrue(4.0 == max); + } + + @Test + public void naturalMin() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + double max = rdd.min(); + Assert.assertTrue(1.0 == max); + } + @Test public void takeOrdered() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); From 0424da68d4c81dc3a9944d8485feb1233c6633c4 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sun, 19 Apr 2015 09:37:09 +0100 Subject: [PATCH 20/33] [SPARK-6963][CORE]Flaky test: o.a.s.ContextCleanerSuite automatically cleanup checkpoint cc andrewor14 Author: GuoQiang Li Closes #5548 from witgo/SPARK-6963 and squashes the following commits: 964aea7 [GuoQiang Li] review commits b08b3c9 [GuoQiang Li] Flaky test: o.a.s.ContextCleanerSuite automatically cleanup checkpoint --- .../org/apache/spark/ContextCleaner.scala | 2 ++ .../apache/spark/ContextCleanerSuite.scala | 21 +++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 715b259057569..37198d887b07b 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -236,6 +236,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { try { logDebug("Cleaning rdd checkpoint data " + rddId) RDDCheckpointData.clearRDDCheckpointData(sc, rddId) + listeners.foreach(_.checkpointCleaned(rddId)) logInfo("Cleaned rdd checkpoint data " + rddId) } catch { @@ -260,4 +261,5 @@ private[spark] trait CleanerListener { def shuffleCleaned(shuffleId: Int) def broadcastCleaned(broadcastId: Long) def accumCleaned(accId: Long) + def checkpointCleaned(rddId: Long) } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 097e7076e5391..c7868ddcf770f 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -224,7 +224,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(fs.exists(path)) // the checkpoint is not cleaned by default (without the configuration set) - var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + var postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Nil) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() @@ -245,7 +245,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { assert(fs.exists(RDDCheckpointData.rddCheckpointDataPath(sc, rddId).get)) // Test that GC causes checkpoint data cleanup after dereferencing the RDD - postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil) + postGCTester = new CleanerTester(sc, Seq(rddId), Nil, Nil, Seq(rddId)) rdd = null // Make RDD out of scope runGC() postGCTester.assertCleanup() @@ -406,12 +406,14 @@ class CleanerTester( sc: SparkContext, rddIds: Seq[Int] = Seq.empty, shuffleIds: Seq[Int] = Seq.empty, - broadcastIds: Seq[Long] = Seq.empty) + broadcastIds: Seq[Long] = Seq.empty, + checkpointIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val toBeCheckpointIds = new HashSet[Long] with SynchronizedSet[Long] ++= checkpointIds val isDistributed = !sc.isLocal val cleanerListener = new CleanerListener { @@ -427,12 +429,17 @@ class CleanerTester( def broadcastCleaned(broadcastId: Long): Unit = { toBeCleanedBroadcstIds -= broadcastId - logInfo("Broadcast" + broadcastId + " cleaned") + logInfo("Broadcast " + broadcastId + " cleaned") } def accumCleaned(accId: Long): Unit = { logInfo("Cleaned accId " + accId + " cleaned") } + + def checkpointCleaned(rddId: Long): Unit = { + toBeCheckpointIds -= rddId + logInfo("checkpoint " + rddId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 @@ -456,7 +463,8 @@ class CleanerTester( /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty || + checkpointIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present rddIds.foreach { rddId => @@ -547,7 +555,8 @@ class CleanerTester( private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty && - toBeCleanedBroadcstIds.isEmpty + toBeCleanedBroadcstIds.isEmpty && + toBeCheckpointIds.isEmpty private def getRDDBlocks(rddId: Int): Seq[BlockId] = { blockManager.master.getMatchingBlockIds( _ match { From fa73da024000386eecef79573e8ac96d6f05b2c7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:33:51 -0700 Subject: [PATCH 21/33] [SPARK-6998][MLlib] Make StreamingKMeans 'Serializable' If `StreamingKMeans` is not `Serializable`, we cannot do checkpoint for applications that using `StreamingKMeans`. So we should make it `Serializable`. Author: zsxwing Closes #5582 from zsxwing/SPARK-6998 and squashes the following commits: 67c2a14 [zsxwing] Make StreamingKMeans 'Serializable' --- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index f483fd1c7d2cf..d4606fda37b0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -165,7 +165,7 @@ class StreamingKMeansModel( class StreamingKMeans( var k: Int, var decayFactor: Double, - var timeUnit: String) extends Logging { + var timeUnit: String) extends Logging with Serializable { def this() = this(2, 1.0, StreamingKMeans.BATCHES) From d8e1b7b06c499289ff3ce5ec91ff354493a17c48 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:35:43 -0700 Subject: [PATCH 22/33] [SPARK-6983][Streaming] Update ReceiverTrackerActor to use the new Rpc interface A subtask of [SPARK-5293](https://issues.apache.org/jira/browse/SPARK-5293) Author: zsxwing Closes #5557 from zsxwing/SPARK-6983 and squashes the following commits: e777e9f [zsxwing] Update ReceiverTrackerActor to use the new Rpc interface --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 2 +- .../receiver/ReceiverSupervisorImpl.scala | 52 +++++---------- .../streaming/scheduler/ReceiverInfo.scala | 4 +- .../streaming/scheduler/ReceiverTracker.scala | 64 ++++++++++--------- 4 files changed, 52 insertions(+), 70 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index f2c1c86af767e..cba038ca355d7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -258,7 +258,7 @@ private[spark] trait RpcEndpoint { final def stop(): Unit = { val _self = self if (_self != null) { - rpcEnv.stop(self) + rpcEnv.stop(_self) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 8f2f1fef76874..89af40330b9d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -21,18 +21,16 @@ import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import akka.actor.{ActorRef, Actor, Props} -import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] @@ -63,37 +61,23 @@ private[streaming] class ReceiverSupervisorImpl( } - /** Remote Akka actor for the ReceiverTracker */ - private val trackerActor = { - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = AkkaUtils.address( - AkkaUtils.protocol(env.actorSystem), - SparkEnv.driverActorSystemName, - ip, - port, - "ReceiverTracker") - env.actorSystem.actorSelection(url) - } - - /** Timeout for Akka actor messages */ - private val askTimeout = AkkaUtils.askTimeout(env.conf) + /** Remote RpcEndpointRef for the ReceiverTracker */ + private val trackerEndpoint = RpcUtils.makeDriverRef("ReceiverTracker", env.conf, env.rpcEnv) - /** Akka actor for receiving messages from the ReceiverTracker in the driver */ - private val actor = env.actorSystem.actorOf( - Props(new Actor { + /** RpcEndpointRef for receiving messages from the ReceiverTracker in the driver */ + private val endpoint = env.rpcEnv.setupEndpoint( + "Receiver-" + streamId + "-" + System.currentTimeMillis(), new ThreadSafeRpcEndpoint { + override val rpcEnv: RpcEnv = env.rpcEnv override def receive: PartialFunction[Any, Unit] = { case StopReceiver => logInfo("Received stop signal") - stop("Stopped by driver", None) + ReceiverSupervisorImpl.this.stop("Stopped by driver", None) case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) } - - def ref: ActorRef = self - }), "Receiver-" + streamId + "-" + System.currentTimeMillis()) + }) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) @@ -162,15 +146,14 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - val future = trackerActor.ask(AddBlock(blockInfo))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } /** Report error to the receiver tracker */ def reportError(message: String, error: Throwable) { val errorString = Option(error).map(Throwables.getStackTraceAsString).getOrElse("") - trackerActor ! ReportError(streamId, message, errorString) + trackerEndpoint.send(ReportError(streamId, message, errorString)) logWarning("Reported error " + message + " - " + error) } @@ -180,22 +163,19 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onStop(message: String, error: Option[Throwable]) { blockGenerator.stop() - env.actorSystem.stop(actor) + env.rpcEnv.stop(endpoint) } override protected def onReceiverStart() { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), actor) - val future = trackerActor.ask(msg)(askTimeout) - Await.result(future, askTimeout) + streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + trackerEndpoint.askWithReply[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - val future = trackerActor.ask( - DeregisterReceiver(streamId, message, errorString))(askTimeout) - Await.result(future, askTimeout) + trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index d7e39c528c519..52f08b9c9de68 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.scheduler -import akka.actor.ActorRef import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef /** * :: DeveloperApi :: @@ -28,7 +28,7 @@ import org.apache.spark.annotation.DeveloperApi case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val actor: ActorRef, + private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 98900473138fe..c4ead6f30a63d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,13 +17,11 @@ package org.apache.spark.streaming.scheduler - import scala.collection.mutable.{HashMap, SynchronizedMap} import scala.language.existentials -import akka.actor._ - import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException} +import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver} @@ -36,7 +34,7 @@ private[streaming] case class RegisterReceiver( streamId: Int, typ: String, host: String, - receiverActor: ActorRef + receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends ReceiverTrackerMessage @@ -67,19 +65,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus - // actor is created when generator starts. + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null - /** Start the actor and receiver execution thread. */ + /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (actor != null) { + if (endpoint != null) { throw new SparkException("ReceiverTracker already started") } if (!receiverInputStreams.isEmpty) { - actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor), - "ReceiverTracker") + endpoint = ssc.env.rpcEnv.setupEndpoint( + "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") } @@ -87,13 +85,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && actor != null) { + if (!receiverInputStreams.isEmpty && endpoint != null) { // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) - // Finally, stop the actor - ssc.env.actorSystem.stop(actor) - actor = null + // Finally, stop the endpoint + ssc.env.rpcEnv.stop(endpoint) + endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") } @@ -129,8 +127,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.actor) } - .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) } + receiverInfo.values.flatMap { info => Option(info.endpoint) } + .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } } } @@ -139,23 +137,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false streamId: Int, typ: String, host: String, - receiverActor: ActorRef, - sender: ActorRef + receiverEndpoint: RpcEndpointRef, + senderAddress: RpcAddress ) { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverActor, true, host) + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) } /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error) + oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, lastError = error) case None => logWarning("No prior receiver info") ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error) @@ -199,19 +197,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false receivedBlockTracker.hasUnallocatedReceivedBlocks } - /** Actor to receive messages from the receivers. */ - private class ReceiverTrackerActor extends Actor { + /** RpcEndpoint to receive messages from the receivers. */ + private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + override def receive: PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverActor) => - registerReceiver(streamId, typ, host, receiverActor, sender) - sender ! true - case AddBlock(receivedBlockInfo) => - sender ! addBlock(receivedBlockInfo) case ReportError(streamId, message, error) => reportError(streamId, message, error) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(true) + case AddBlock(receivedBlockInfo) => + context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) - sender ! true + context.reply(true) } } @@ -314,8 +316,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Stops the receivers. */ private def stopReceivers() { // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.actor)} - .foreach { _ ! StopReceiver } + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } From c776ee8a6fdcdc463746a815b7686e4e33a874a9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 19 Apr 2015 20:48:36 -0700 Subject: [PATCH 23/33] [SPARK-6979][Streaming] Replace JobScheduler.eventActor and JobGenerator.eventActor with EventLoop Title says it all. cc rxin tdas Author: zsxwing Closes #5554 from zsxwing/SPARK-6979 and squashes the following commits: 5304350 [zsxwing] Fix NotSerializableException e9d3479 [zsxwing] Add blank lines 633e279 [zsxwing] Fix NotSerializableException e496ace [zsxwing] Replace JobGenerator.eventActor with EventLoop ec6ec58 [zsxwing] Fix the import order ce0fa73 [zsxwing] Replace JobScheduler.eventActor with EventLoop --- .../mllib/clustering/StreamingKMeans.scala | 3 +- .../streaming/scheduler/JobGenerator.scala | 38 +++++++++--------- .../streaming/scheduler/JobScheduler.scala | 40 ++++++++++--------- 3 files changed, 42 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index d4606fda37b0d..812014a041719 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -20,8 +20,7 @@ package org.apache.spark.mllib.clustering import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.SparkContext._ -import org.apache.spark.annotation.{Experimental, DeveloperApi} +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.DStream diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 58e56638a2dca..2467d50839add 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,12 +19,10 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import akka.actor.{ActorRef, Props, Actor} - import org.apache.spark.{SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.{Clock, EventLoop, ManualClock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -58,7 +56,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + longTime => eventLoop.post(GenerateJobs(new Time(longTime))), "JobGenerator") // This is marked lazy so that this is initialized after checkpoint duration has been set // in the context and the generator has been started. @@ -70,22 +68,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { null } - // eventActor is created when generator starts. + // eventLoop is created when generator starts. // This not being null means the scheduler has been started and not stopped - private var eventActor: ActorRef = null + private var eventLoop: EventLoop[JobGeneratorEvent] = null // last batch whose completion,checkpointing and metadata cleanup has been completed private var lastProcessedBatch: Time = null /** Start generation of jobs */ def start(): Unit = synchronized { - if (eventActor != null) return // generator has already been started + if (eventLoop != null) return // generator has already been started + + eventLoop = new EventLoop[JobGeneratorEvent]("JobGenerator") { + override protected def onReceive(event: JobGeneratorEvent): Unit = processEvent(event) - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - override def receive: PartialFunction[Any, Unit] = { - case event: JobGeneratorEvent => processEvent(event) + override protected def onError(e: Throwable): Unit = { + jobScheduler.reportError("Error in job generator", e) } - }), "JobGenerator") + } + eventLoop.start() + if (ssc.isCheckpointPresent) { restart() } else { @@ -99,7 +101,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * checkpoints written. */ def stop(processReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // generator has already been stopped + if (eventLoop == null) return // generator has already been stopped if (processReceivedData) { logInfo("Stopping JobGenerator gracefully") @@ -146,9 +148,9 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { graph.stop() } - // Stop the actor and checkpoint writer + // Stop the event loop and checkpoint writer if (shouldCheckpoint) checkpointWriter.stop() - ssc.env.actorSystem.stop(eventActor) + eventLoop.stop() logInfo("Stopped JobGenerator") } @@ -156,7 +158,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { - eventActor ! ClearMetadata(time) + eventLoop.post(ClearMetadata(time)) } /** @@ -164,7 +166,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { */ def onCheckpointCompletion(time: Time, clearCheckpointDataLater: Boolean) { if (clearCheckpointDataLater) { - eventActor ! ClearCheckpointData(time) + eventLoop.post(ClearCheckpointData(time)) } } @@ -247,7 +249,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) } - eventActor ! DoCheckpoint(time, clearCheckpointDataLater = false) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = false)) } /** Clear DStream metadata for the given `time`. */ @@ -257,7 +259,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // If checkpointing is enabled, then checkpoint, // else mark batch to be fully processed if (shouldCheckpoint) { - eventActor ! DoCheckpoint(time, clearCheckpointDataLater = true) + eventLoop.post(DoCheckpoint(time, clearCheckpointDataLater = true)) } else { // If checkpointing is not enabled, then delete metadata information about // received blocks (block data not saved in any case). Otherwise, wait for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 95f1857b4c377..508b89278dcba 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,13 +17,15 @@ package org.apache.spark.streaming.scheduler -import scala.util.{Failure, Success, Try} -import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import akka.actor.{ActorRef, Actor, Props} -import org.apache.spark.{SparkException, Logging, SparkEnv} + +import scala.collection.JavaConversions._ +import scala.util.{Failure, Success} + +import org.apache.spark.Logging import org.apache.spark.rdd.PairRDDFunctions import org.apache.spark.streaming._ +import org.apache.spark.util.EventLoop private[scheduler] sealed trait JobSchedulerEvent @@ -46,20 +48,20 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { val listenerBus = new StreamingListenerBus() // These two are created only when scheduler starts. - // eventActor not being null means the scheduler has been started and not stopped + // eventLoop not being null means the scheduler has been started and not stopped var receiverTracker: ReceiverTracker = null - private var eventActor: ActorRef = null - + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { - if (eventActor != null) return // scheduler has already been started + if (eventLoop != null) return // scheduler has already been started logDebug("Starting JobScheduler") - eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { - override def receive: PartialFunction[Any, Unit] = { - case event: JobSchedulerEvent => processEvent(event) - } - }), "JobScheduler") + eventLoop = new EventLoop[JobSchedulerEvent]("JobScheduler") { + override protected def onReceive(event: JobSchedulerEvent): Unit = processEvent(event) + + override protected def onError(e: Throwable): Unit = reportError("Error in job scheduler", e) + } + eventLoop.start() listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) @@ -69,7 +71,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def stop(processAllReceivedData: Boolean): Unit = synchronized { - if (eventActor == null) return // scheduler has already been stopped + if (eventLoop == null) return // scheduler has already been stopped logDebug("Stopping JobScheduler") // First, stop receiving @@ -96,8 +98,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // Stop everything else listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - eventActor = null + eventLoop.stop() + eventLoop = null logInfo("Stopped JobScheduler") } @@ -117,7 +119,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } def reportError(msg: String, e: Throwable) { - eventActor ! ErrorReported(msg, e) + eventLoop.post(ErrorReported(msg, e)) } private def processEvent(event: JobSchedulerEvent) { @@ -172,14 +174,14 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private class JobHandler(job: Job) extends Runnable { def run() { - eventActor ! JobStarted(job) + eventLoop.post(JobStarted(job)) // Disable checks for existing output directories in jobs launched by the streaming scheduler, // since we may need to write output to an existing directory during checkpoint recovery; // see SPARK-4835 for more details. PairRDDFunctions.disableOutputSpecValidation.withValue(true) { job.run() } - eventActor ! JobCompleted(job) + eventLoop.post(JobCompleted(job)) } } } From 6fe690d5a8216ba7efde4b52e7a19fb00814341c Mon Sep 17 00:00:00 2001 From: dobashim Date: Mon, 20 Apr 2015 00:03:23 -0400 Subject: [PATCH 24/33] [doc][mllib] Fix typo of the page title in Isotonic regression documents * Fix the page title in Isotonic regression documents (Naive Bayes -> Isotonic regression) * Add a newline character at the end of the file Author: dobashim Closes #5581 from dobashim/master and squashes the following commits: d54a041 [dobashim] Fix typo of the page title in Isotonic regression documents --- docs/mllib-isotonic-regression.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index 12fb29d426741..b521c2f27cd6e 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -1,6 +1,6 @@ --- layout: global -title: Naive Bayes - MLlib +title: Isotonic regression - MLlib displayTitle: MLlib - Regression --- @@ -152,4 +152,4 @@ Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map( System.out.println("Mean Squared Error = " + meanSquaredError); {% endhighlight %}
- \ No newline at end of file + From 1be207078cef48c5935595969bf9f6b1ec1334ca Mon Sep 17 00:00:00 2001 From: jrabary Date: Mon, 20 Apr 2015 09:47:56 -0700 Subject: [PATCH 25/33] [SPARK-5924] Add the ability to specify withMean or withStd parameters with StandarScaler The current implementation call the default constructor of mllib.feature.StandarScaler without the possibility to specify withMean or withStd options. Author: jrabary Closes #4704 from jrabary/master and squashes the following commits: fae8568 [jrabary] style fix 8896b0e [jrabary] Comments fix ef96d73 [jrabary] style fix 8e52607 [jrabary] style fix edd9d48 [jrabary] Fix default param initialization 17e1a76 [jrabary] Fix default param initialization 298f405 [jrabary] Typo fix 45ed914 [jrabary] Add withMean and withStd params to StandarScaler --- .../spark/ml/feature/StandardScaler.scala | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1b102619b3524..447851ec034d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -30,7 +30,22 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[StandardScaler]] and [[StandardScalerModel]]. */ -private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * False by default. Centers the data with mean before scaling. + * It will build a dense output, so this does not work on sparse input + * and will raise an exception. + * @group param + */ + val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean") + + /** + * True by default. Scales the data to unit standard deviation. + * @group param + */ + val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation") +} /** * :: AlphaComponent :: @@ -40,18 +55,27 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with @AlphaComponent class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + setDefault(withMean -> false, withStd -> true) + /** @group setParam */ def setInputCol(value: String): this.type = set(inputCol, value) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - + + /** @group setParam */ + def setWithMean(value: Boolean): this.type = set(withMean, value) + + /** @group setParam */ + def setWithStd(value: Boolean): this.type = set(withStd, value) + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } - val scaler = new feature.StandardScaler().fit(input) - val model = new StandardScalerModel(this, map, scaler) + val scaler = new feature.StandardScaler(withMean = map(withMean), withStd = map(withStd)) + val scalerModel = scaler.fit(input) + val model = new StandardScalerModel(this, map, scalerModel) Params.inheritValues(map, this, model) model } From 968ad972175390bb0a96918fd3c7b318d70fa466 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Mon, 20 Apr 2015 09:54:21 -0700 Subject: [PATCH 26/33] [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints Currently we rely on the assumption that an exception will be raised and the channel closed if two endpoints cannot communicate over a Netty TCP channel. However, this guarantee does not hold in all network environments, and [SPARK-6962](https://issues.apache.org/jira/browse/SPARK-6962) seems to point to a case where only the server side of the connection detected a fault. This patch improves robustness of fetch/rpc requests by having an explicit timeout in the transport layer which closes the connection if there is a period of inactivity while there are outstanding requests. NB: This patch is actually only around 50 lines added if you exclude the testing-related code. Author: Aaron Davidson Closes #5584 from aarondav/timeout and squashes the following commits: 8699680 [Aaron Davidson] Address Reynold's comments 37ce656 [Aaron Davidson] [SPARK-7003] Improve reliability of connection failure detection between Netty block transfer service endpoints --- .../spark/network/TransportContext.java | 5 +- .../client/TransportResponseHandler.java | 14 +- .../server/TransportChannelHandler.java | 33 ++- .../spark/network/util/MapConfigProvider.java | 41 +++ .../apache/spark/network/util/NettyUtils.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 277 ++++++++++++++++++ .../network/TransportClientFactorySuite.java | 21 +- 7 files changed, 375 insertions(+), 18 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java create mode 100644 network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index f0a89c9d9116c..3fe69b1bd8851 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -22,6 +22,7 @@ import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -106,6 +107,7 @@ public TransportChannelHandler initializePipeline(SocketChannel channel) { .addLast("encoder", encoder) .addLast("frameDecoder", NettyUtils.createFrameDecoder()) .addLast("decoder", decoder) + .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. .addLast("handler", channelHandler); @@ -126,7 +128,8 @@ private TransportChannelHandler createChannelHandler(Channel channel) { TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); - return new TransportChannelHandler(client, responseHandler, requestHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler, + conf.connectionTimeoutMs()); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2044afb0d85db..94fc21af5e606 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -20,8 +20,8 @@ import java.io.IOException; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; -import com.google.common.annotations.VisibleForTesting; import io.netty.channel.Channel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,13 +50,18 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingRpcs; + /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */ + private final AtomicLong timeOfLastRequestNs; + public TransportResponseHandler(Channel channel) { this.channel = channel; this.outstandingFetches = new ConcurrentHashMap(); this.outstandingRpcs = new ConcurrentHashMap(); + this.timeOfLastRequestNs = new AtomicLong(0); } public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingFetches.put(streamChunkId, callback); } @@ -65,6 +70,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { } public void addRpcRequest(long requestId, RpcResponseCallback callback) { + timeOfLastRequestNs.set(System.nanoTime()); outstandingRpcs.put(requestId, callback); } @@ -161,8 +167,12 @@ public void handle(ResponseMessage message) { } /** Returns total number of outstanding requests (fetch requests + rpcs) */ - @VisibleForTesting public int numOutstandingRequests() { return outstandingFetches.size() + outstandingRpcs.size(); } + + /** Returns the time in nanoseconds of when the last request was sent out. */ + public long getTimeOfLastRequestNs() { + return timeOfLastRequestNs.get(); + } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index e491367fa4528..8e0ee709e38e3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -19,6 +19,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.timeout.IdleState; +import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,6 +42,11 @@ * Client. * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, * for the Client's responses to the Server's requests. + * + * This class also handles timeouts from a {@link io.netty.handler.timeout.IdleStateHandler}. + * We consider a connection timed out if there are outstanding fetch or RPC requests but no traffic + * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not + * timeout if the client is continuously sending but getting no responses, for simplicity. */ public class TransportChannelHandler extends SimpleChannelInboundHandler { private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); @@ -47,14 +54,17 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + boolean isActuallyOverdue = + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; + if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } + } + } } diff --git a/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java new file mode 100644 index 0000000000000..668d2356b955d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import com.google.common.collect.Maps; + +import java.util.Map; +import java.util.NoSuchElementException; + +/** ConfigProvider based on a Map (copied in the constructor). */ +public class MapConfigProvider extends ConfigProvider { + private final Map config; + + public MapConfigProvider(Map config) { + this.config = Maps.newHashMap(config); + } + + @Override + public String get(String name) { + String value = config.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java index dabd6261d2aa0..26c6399ce7dbc 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -98,7 +98,7 @@ public static ByteToMessageDecoder createFrameDecoder() { return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); } - /** Returns the remote address on the channel or "<remote address>" if none exists. */ + /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { return channel.remoteAddress().toString(); diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java new file mode 100644 index 0000000000000..84ebb337e6d54 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import com.google.common.collect.Maps; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.*; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** + * Suite which ensures that requests that go without a response for the network timeout period are + * failed, and the connection closed. + * + * In this suite, we use 2 seconds as the connection timeout, with some slack given in the tests, + * to ensure stability in different test environments. + */ +public class RequestTimeoutIntegrationSuite { + + private TransportServer server; + private TransportClientFactory clientFactory; + + private StreamManager defaultManager; + private TransportConf conf; + + // A large timeout that "shouldn't happen", for the sake of faulty tests not hanging forever. + private final int FOREVER = 60 * 1000; + + @Before + public void setUp() throws Exception { + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.connectionTimeout", "2s"); + conf = new TransportConf(new MapConfigProvider(configMap)); + + defaultManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + throw new UnsupportedOperationException(); + } + }; + } + + @After + public void tearDown() { + if (server != null) { + server.close(); + } + if (clientFactory != null) { + clientFactory.close(); + } + } + + // Basic suite: First request completes quickly, and second waits for longer than network timeout. + @Test + public void timeoutInactiveRequests() throws Exception { + final Semaphore semaphore = new Semaphore(1); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // First completes quickly (semaphore starts at 1). + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.success.length == response.length); + } + + // Second times out after 2 seconds, with slack. Must be IOException. + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client.sendRpc(new byte[0], callback1); + callback1.wait(4 * 1000); + assert (callback1.failure != null); + assert (callback1.failure instanceof IOException); + } + semaphore.release(); + } + + // A timeout will cause the connection to be closed, invalidating the current TransportClient. + // It should be the case that requesting a client from the factory produces a new, valid one. + @Test + public void timeoutCleanlyClosesClient() throws Exception { + final Semaphore semaphore = new Semaphore(0); + final byte[] response = new byte[16]; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + try { + semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS); + callback.onSuccess(response); + } catch (InterruptedException e) { + // do nothing + } + } + + @Override + public StreamManager getStreamManager() { + return defaultManager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + + // First request should eventually fail. + TransportClient client0 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback0 = new TestCallback(); + synchronized (callback0) { + client0.sendRpc(new byte[0], callback0); + callback0.wait(FOREVER); + assert (callback0.failure instanceof IOException); + assert (!client0.isActive()); + } + + // Increment the semaphore and the second request should succeed quickly. + semaphore.release(2); + TransportClient client1 = + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + TestCallback callback1 = new TestCallback(); + synchronized (callback1) { + client1.sendRpc(new byte[0], callback1); + callback1.wait(FOREVER); + assert (callback1.success.length == response.length); + assert (callback1.failure == null); + } + } + + // The timeout is relative to the LAST request sent, which is kinda weird, but still. + // This test also makes sure the timeout works for Fetch requests as well as RPCs. + @Test + public void furtherRequestsDelay() throws Exception { + final byte[] response = new byte[16]; + final StreamManager manager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); + return new NioManagedBuffer(ByteBuffer.wrap(response)); + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { + return manager; + } + }; + + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + + // Send one request, which will eventually fail. + TestCallback callback0 = new TestCallback(); + client.fetchChunk(0, 0, callback0); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + // Send a second request before the first has failed. + TestCallback callback1 = new TestCallback(); + client.fetchChunk(0, 1, callback1); + Uninterruptibles.sleepUninterruptibly(1200, TimeUnit.MILLISECONDS); + + synchronized (callback0) { + // not complete yet, but should complete soon + assert (callback0.success == null && callback0.failure == null); + callback0.wait(2 * 1000); + assert (callback0.failure instanceof IOException); + } + + synchronized (callback1) { + // failed at same time as previous + assert (callback0.failure instanceof IOException); + } + } + + /** + * Callback which sets 'success' or 'failure' on completion. + * Additionally notifies all waiters on this callback when invoked. + */ + class TestCallback implements RpcResponseCallback, ChunkReceivedCallback { + + byte[] success; + Throwable failure; + + @Override + public void onSuccess(byte[] response) { + synchronized(this) { + success = response; + this.notifyAll(); + } + } + + @Override + public void onFailure(Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + synchronized(this) { + try { + success = buffer.nioByteBuffer().array(); + this.notifyAll(); + } catch (IOException e) { + // weird + } + } + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + synchronized(this) { + failure = e; + this.notifyAll(); + } + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 416dc1b969fa4..35de5e57ccb98 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.util.Collections; import java.util.HashSet; -import java.util.NoSuchElementException; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import com.google.common.collect.Maps; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -36,9 +37,9 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -70,16 +71,10 @@ public void tearDown() { */ private void testClientReuse(final int maxConnections, boolean concurrent) throws IOException, InterruptedException { - TransportConf conf = new TransportConf(new ConfigProvider() { - @Override - public String get(String name) { - if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) { - return Integer.toString(maxConnections); - } else { - throw new NoSuchElementException(); - } - } - }); + + Map configMap = Maps.newHashMap(); + configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); + TransportConf conf = new TransportConf(new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); From 77176619a97d07811ab20e1dde4677359d85eb33 Mon Sep 17 00:00:00 2001 From: Elisey Zanko Date: Mon, 20 Apr 2015 10:44:09 -0700 Subject: [PATCH 27/33] [SPARK-6661] Python type errors should print type, not object Author: Elisey Zanko Closes #5361 from 31z4/spark-6661 and squashes the following commits: 73c5d79 [Elisey Zanko] Python type errors should print type, not object --- python/pyspark/accumulators.py | 2 +- python/pyspark/context.py | 2 +- python/pyspark/ml/param/__init__.py | 2 +- python/pyspark/ml/pipeline.py | 4 ++-- python/pyspark/mllib/linalg.py | 4 ++-- python/pyspark/mllib/regression.py | 2 +- python/pyspark/mllib/tests.py | 6 ++++-- python/pyspark/sql/_types.py | 12 ++++++------ python/pyspark/sql/context.py | 8 ++++---- python/pyspark/sql/dataframe.py | 2 +- 10 files changed, 23 insertions(+), 21 deletions(-) diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 7271809e43880..0d21a132048a5 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -83,7 +83,7 @@ >>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... -Exception:... +TypeError:... """ import sys diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 1dc2fec0ae5c8..6a743ac8bd600 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -671,7 +671,7 @@ def accumulator(self, value, accum_param=None): elif isinstance(value, complex): accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM else: - raise Exception("No default accumulator param for type %s" % type(value)) + raise TypeError("No default accumulator param for type %s" % type(value)) SparkContext._next_accum_id += 1 return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 9fccb65675185..49c20b4cf70cf 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -30,7 +30,7 @@ class Param(object): def __init__(self, parent, name, doc): if not isinstance(parent, Params): - raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) + raise TypeError("Parent must be a Params but got type %s." % type(parent)) self.parent = parent self.name = str(name) self.doc = str(doc) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index d94ecfff09f66..7c1ec3026da6f 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -131,8 +131,8 @@ def fit(self, dataset, params={}): stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): - raise ValueError( - "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) + raise TypeError( + "Cannot recognize a pipeline stage of type %s." % type(stage)) indexOfLastEstimator = -1 for i, stage in enumerate(stages): if isinstance(stage, Estimator): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 38b3aa3ad460e..ec8c879ea9389 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -145,7 +145,7 @@ def serialize(self, obj): values = [float(v) for v in obj] return (1, None, None, values) else: - raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + raise TypeError("cannot serialize %r of type %r" % (obj, type(obj))) def deserialize(self, datum): assert len(datum) == 4, \ @@ -561,7 +561,7 @@ def __getitem__(self, index): inds = self.indices vals = self.values if not isinstance(index, int): - raise ValueError( + raise TypeError( "Indices must be of type integer, got type %s" % type(index)) if index < 0: index += self.size diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cd7310a64f4ae..a0117c57133ae 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -170,7 +170,7 @@ def _regression_train_wrapper(train_func, modelClass, data, initial_weights): from pyspark.mllib.classification import LogisticRegressionModel first = data.first() if not isinstance(first, LabeledPoint): - raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) + raise TypeError("data should be an RDD of LabeledPoint, but got %s" % type(first)) if initial_weights is None: initial_weights = [0.0] * len(data.first().features) if (modelClass == LogisticRegressionModel): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index c6ed5acd1770e..849c88341a967 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -135,8 +135,10 @@ def test_sparse_vector_indexing(self): self.assertEquals(sv[-1], 2) self.assertEquals(sv[-2], 0) self.assertEquals(sv[-4], 0) - for ind in [4, -5, 7.8]: + for ind in [4, -5]: self.assertRaises(ValueError, sv.__getitem__, ind) + for ind in [7.8, '1']: + self.assertRaises(TypeError, sv.__getitem__, ind) def test_matrix_indexing(self): mat = DenseMatrix(3, 2, [0, 1, 4, 6, 8, 10]) @@ -450,7 +452,7 @@ def test_infer_schema(self): elif isinstance(v, DenseVector): self.assertEqual(v, self.dv1) else: - raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + raise TypeError("expecting a vector but got %r of type %r" % (v, type(v))) @unittest.skipIf(not _have_scipy, "SciPy not installed") diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 492c0cbdcf693..110d1152fbdb6 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -562,8 +562,8 @@ def _infer_type(obj): else: try: return _infer_schema(obj) - except ValueError: - raise ValueError("not supported type: %s" % type(obj)) + except TypeError: + raise TypeError("not supported type: %s" % type(obj)) def _infer_schema(row): @@ -584,7 +584,7 @@ def _infer_schema(row): items = sorted(row.__dict__.items()) else: - raise ValueError("Can not infer schema for type: %s" % type(row)) + raise TypeError("Can not infer schema for type: %s" % type(row)) fields = [StructField(k, _infer_type(v), True) for k, v in items] return StructType(fields) @@ -696,7 +696,7 @@ def _merge_type(a, b): return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) - raise TypeError("Can not merge type %s and %s" % (a, b)) + raise TypeError("Can not merge type %s and %s" % (type(a), type(b))) # same type if isinstance(a, StructType): @@ -773,7 +773,7 @@ def convert_struct(obj): elif hasattr(obj, "__dict__"): # object d = obj.__dict__ else: - raise ValueError("Unexpected obj: %s" % obj) + raise TypeError("Unexpected obj type: %s" % type(obj)) if convert_fields: return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) @@ -912,7 +912,7 @@ def _infer_schema_type(obj, dataType): return StructType(fields) else: - raise ValueError("Unexpected dataType: %s" % dataType) + raise TypeError("Unexpected dataType: %s" % type(dataType)) _acceptable_types = { diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index c90afc326ca0e..acf3c114548c0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -208,7 +208,7 @@ def applySchema(self, rdd, schema): raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): - raise TypeError("schema should be StructType, but got %s" % schema) + raise TypeError("schema should be StructType, but got %s" % type(schema)) return self.createDataFrame(rdd, schema) @@ -281,7 +281,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): # data could be list, tuple, generator ... rdd = self._sc.parallelize(data) except Exception: - raise ValueError("cannot create an RDD from type: %s" % type(data)) + raise TypeError("cannot create an RDD from type: %s" % type(data)) else: rdd = data @@ -293,8 +293,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(schema, (list, tuple)): first = rdd.first() if not isinstance(first, (list, tuple)): - raise ValueError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) + raise TypeError("each row in `rdd` should be list or tuple, " + "but got %r" % type(first)) row_cls = Row(*schema) schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d70c5b0a6930c..75c181c0c7f5e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -608,7 +608,7 @@ def __getitem__(self, item): jc = self._jdf.apply(self.columns[item]) return Column(jc) else: - raise TypeError("unexpected type: %s" % type(item)) + raise TypeError("unexpected item type: %s" % type(item)) def __getattr__(self, name): """Returns the :class:`Column` denoted by ``name``. From 1ebceaa55bec28850a48fb28b4cf7b44c8448a78 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Apr 2015 10:47:37 -0700 Subject: [PATCH 28/33] [Minor][MLlib] Incorrect path to test data is used in DecisionTreeExample It should load from `testInput` instead of `input` for test data. Author: Liang-Chi Hsieh Closes #5594 from viirya/use_testinput and squashes the following commits: 5e8b174 [Liang-Chi Hsieh] Fix style. b60b475 [Liang-Chi Hsieh] Use testInput. --- .../org/apache/spark/examples/ml/DecisionTreeExample.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index d4cc8dede07ef..921b396e799e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -173,7 +173,8 @@ object DecisionTreeExample { val splits: Array[RDD[LabeledPoint]] = if (testInput != "") { // Load testInput. val numFeatures = origExamples.take(1)(0).features.size - val origTestExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat, Some(numFeatures)) + val origTestExamples: RDD[LabeledPoint] = + loadData(sc, testInput, dataFormat, Some(numFeatures)) Array(origExamples, origTestExamples) } else { // Split input into training, test. From 97fda73db4efda2ba5b12937954de428258a5b56 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Mon, 20 Apr 2015 13:11:21 -0700 Subject: [PATCH 29/33] fixed doc The contribution is my original work. I license the work to the project under the project's open source license. Small typo in the programming guide. Author: Eric Chiang Closes #5599 from ericchiang/docs-typo and squashes the following commits: 1177942 [Eric Chiang] fixed doc --- docs/programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f4fabb0927b66..27816515c5de2 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1093,7 +1093,7 @@ for details. ### Shuffle operations Certain operations within Spark trigger an event known as the shuffle. The shuffle is Spark's -mechanism for re-distributing data so that is grouped differently across partitions. This typically +mechanism for re-distributing data so that it's grouped differently across partitions. This typically involves copying data across executors and machines, making the shuffle a complex and costly operation. From 517bdf36aecdc94ef569b68f0a96892e707b5c7b Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 20 Apr 2015 13:46:55 -0700 Subject: [PATCH 30/33] [doc][streaming] Fixed broken link in mllib section The commit message is pretty self-explanatory. Author: BenFradet Closes #5600 from BenFradet/master and squashes the following commits: 108492d [BenFradet] [doc][streaming] Fixed broken link in mllib section --- docs/streaming-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 262512a639046..2f2fea53168a3 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1588,7 +1588,7 @@ See the [DataFrames and SQL](sql-programming-guide.html) guide to learn more abo *** ## MLlib Operations -You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. (Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. +You can also easily use machine learning algorithms provided by [MLlib](mllib-guide.html). First of all, there are streaming machine learning algorithms (e.g. [Streaming Linear Regression](mllib-linear-methods.html#streaming-linear-regression), [Streaming KMeans](mllib-clustering.html#streaming-k-means), etc.) which can simultaneously learn from the streaming data as well as apply the model on the streaming data. Beyond these, for a much larger class of machine learning algorithms, you can learn a learning model offline (i.e. using historical data) and then apply the model online on streaming data. See the [MLlib](mllib-guide.html) guide for more details. *** From ce7ddabbcd330b19f6d0c17082304dfa6e1621b2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 20 Apr 2015 18:42:50 -0700 Subject: [PATCH 31/33] [SPARK-6368][SQL] Build a specialized serializer for Exchange operator. JIRA: https://issues.apache.org/jira/browse/SPARK-6368 Author: Yin Huai Closes #5497 from yhuai/serializer2 and squashes the following commits: da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 50e0c3d [Yin Huai] When no filed is emitted to shuffle, use SparkSqlSerializer for now. 9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 6d07678 [Yin Huai] Address comments. 4273b8c [Yin Huai] Enabled SparkSqlSerializer2. 09e587a [Yin Huai] Remove TODO. 791b96a [Yin Huai] Use UTF8String. 60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 3e09655 [Yin Huai] Use getAs for Date column. 43b9fb4 [Yin Huai] Test. 8297732 [Yin Huai] Fix test. c9373c8 [Yin Huai] Support DecimalType. 2379eeb [Yin Huai] ASF header. 39704ab [Yin Huai] Specialized serializer for Exchange. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../apache/spark/sql/execution/Exchange.scala | 59 ++- .../sql/execution/SparkSqlSerializer2.scala | 421 ++++++++++++++++++ .../execution/SparkSqlSerializer2Suite.scala | 195 ++++++++ 4 files changed, 673 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5c65f04ee8497..4fc5de7e824fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -64,6 +64,8 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 69a620e1ec929..5b2e46962cd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair object Exchange { @@ -77,9 +79,48 @@ case class Exchange( } } - override def execute(): RDD[Row] = attachTree(this , "execute") { - lazy val sparkConf = child.sqlContext.sparkContext.getConf + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf + + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out + // through write(key) and then write(value) instead of write((key, value)). Because + // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use + // it when spillToMergeableFile in ExternalSorter will be used. + // So, we will not use SparkSqlSerializer2 when + // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater + // then the bypassMergeThreshold; or + // - newOrdering is defined. + val cannotUseSqlSerializer2 = + (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + + // It is true when there is no field that needs to be write out. + // For now, we will not use SparkSqlSerializer2 when noField is true. + val noField = + (keySchema == null || keySchema.length == 0) && + (valueSchema == null || valueSchema.length == 0) + + val useSqlSerializer2 = + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + !noField + + val serializer = if (useSqlSerializer2) { + logInfo("Using SparkSqlSerializer2.") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Using SparkSqlSerializer.") + new SparkSqlSerializer(sparkConf) + } + + serializer + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -111,7 +152,10 @@ case class Exchange( } else { new ShuffledRDD[Row, Row, Row](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -134,7 +178,9 @@ case class Exchange( } else { new ShuffledRDD[Row, Null, Null](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) + shuffled.map(_._1) case SinglePartition => @@ -152,7 +198,8 @@ case class Exchange( } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 0000000000000..cec97de2cd8e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in + * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the + * [[Product2]] are constructed based on their schemata. + * The benefit of this serialization stream is that compared with general-purpose serializers like + * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower + * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: + * 1. It does not support complex types, i.e. Map, Array, and Struct. + * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when + * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because + * the objects passed in the serializer are not in the type of [[Product2]]. Also also see + * the comment of the `serializer` method in [[Exchange]] for more information on it. + */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The corresponding deserialization stream for [[Serializer2SerializationStream]]. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. + */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case decimal: DecimalType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + out.writeInt(bytes.length) + out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getAs[Int](i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[UTF8String](i).getBytes + out.writeInt(bytes.length) + out.write(bytes) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. + */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. + + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case decimal: DecimalType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // First, read in the unscaled value. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, UTF8String(bytes)) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 0000000000000..27f063d73a9a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + checkSupported(DecimalType(10, 5), isSupported = true) + checkSupported(DecimalType.Unlimited, isSupported = true) + + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. + checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + var numShufflePartitions: Int = _ + var useSerializer2: Boolean = _ + + override def beforeAll(): Unit = { + numShufflePartitions = conf.numShufflePartitions + useSerializer2 = conf.useSqlSerializer2 + + sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i)) + } + + createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + dropTempTable("shuffle") + sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + sql(s"set spark.sql.useSerializer2=$useSerializer2") + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + table("shuffle").collect()) + } + + test("value schema is null") { + val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("no map output field") { + val df = sql(s"SELECT 1 + 1 FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + super.beforeAll() + // Sort merge will not be triggered. + sql("set spark.sql.shuffle.partitions = 200") + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + super.beforeAll() + // To trigger the sort merge. + sql("set spark.sql.shuffle.partitions = 201") + } +} From c736220dac51cf73181fdd7f621c960c4e7bf0c2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 20 Apr 2015 18:54:01 -0700 Subject: [PATCH 32/33] [SPARK-6635][SQL] DataFrame.withColumn should replace columns with identical column names JIRA https://issues.apache.org/jira/browse/SPARK-6635 Author: Liang-Chi Hsieh Closes #5541 from viirya/replace_with_column and squashes the following commits: b539c7b [Liang-Chi Hsieh] For comment. 72f35b1 [Liang-Chi Hsieh] DataFrame.withColumn can replace original column with identical column name. --- .../scala/org/apache/spark/sql/DataFrame.scala | 14 +++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 8 ++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 17c21f6e3a0e9..45f5da387692e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -747,7 +747,19 @@ class DataFrame private[sql]( * Returns a new [[DataFrame]] by adding a column. * @group dfops */ - def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName)) + def withColumn(colName: String, col: Column): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName) else Column(name) + } + select(colNames :_*) + } else { + select(Column("*"), col.as(colName)) + } + } /** * Returns a new [[DataFrame]] with a column renamed. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3250ab476aeb4..b9b6a400ae195 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -473,6 +473,14 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) } + test("replace column using withColumn") { + val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df3 = df2.withColumn("x", df2("x") + 1) + checkAnswer( + df3.select("x"), + Row(2) :: Row(3) :: Row(4) :: Nil) + } + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") From 8136810dfad12008ac300116df7bc8448740f1ae Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 20 Apr 2015 23:18:42 -0700 Subject: [PATCH 33/33] [SPARK-6490][Core] Add spark.rpc.* and deprecate spark.akka.* Deprecated `spark.akka.num.retries`, `spark.akka.retry.wait`, `spark.akka.askTimeout`, `spark.akka.lookupTimeout`, and added `spark.rpc.num.retries`, `spark.rpc.retry.wait`, `spark.rpc.askTimeout`, `spark.rpc.lookupTimeout`. Author: zsxwing Closes #5595 from zsxwing/SPARK-6490 and squashes the following commits: e0d80a9 [zsxwing] Use getTimeAsMs and getTimeAsSeconds and other minor fixes 31dbe69 [zsxwing] Add spark.rpc.* and deprecate spark.akka.* --- .../scala/org/apache/spark/SparkConf.scala | 10 ++++++- .../org/apache/spark/deploy/Client.scala | 6 ++--- .../spark/deploy/client/AppClient.scala | 4 +-- .../apache/spark/deploy/master/Master.scala | 4 +-- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +-- .../deploy/rest/StandaloneRestServer.scala | 8 +++--- .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 10 +++---- .../cluster/YarnSchedulerBackend.scala | 4 +-- .../spark/storage/BlockManagerMaster.scala | 4 +-- .../org/apache/spark/util/AkkaUtils.scala | 26 +++---------------- .../org/apache/spark/util/RpcUtils.scala | 23 ++++++++++++++++ .../apache/spark/MapOutputTrackerSuite.scala | 4 +-- .../org/apache/spark/SparkConfSuite.scala | 24 ++++++++++++++++- .../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +-- 15 files changed, 86 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e3a649d755450..c1996e08756a6 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -431,7 +431,15 @@ private[spark] object SparkConf extends Logging { "spark.yarn.am.waitTime" -> Seq( AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. - translation = s => s"${s.toLong * 10}s")) + translation = s => s"${s.toLong * 10}s")), + "spark.rpc.numRetries" -> Seq( + AlternateConfig("spark.akka.num.retries", "1.4")), + "spark.rpc.retry.wait" -> Seq( + AlternateConfig("spark.akka.retry.wait", "1.4")), + "spark.rpc.askTimeout" -> Seq( + AlternateConfig("spark.akka.askTimeout", "1.4")), + "spark.rpc.lookupTimeout" -> Seq( + AlternateConfig("spark.akka.lookupTimeout", "1.4")) ) /** diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 8d13b2a2cd4f3..c2c3e9a9e4827 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,7 +27,7 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} /** * Proxy that relays messages to the driver. @@ -36,7 +36,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with ActorLogReceive with Logging { var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) override def preStart(): Unit = { masterActor = context.actorSelection( @@ -155,7 +155,7 @@ object Client { if (!driverArgs.logLevel.isGreaterOrEqual(Level.WARN)) { conf.set("spark.akka.logLifecycleEvents", "true") } - conf.set("spark.akka.askTimeout", "10") + conf.set("spark.rpc.askTimeout", "10") conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 4f06d7f96c46e..43c8a934c311a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -193,7 +193,7 @@ private[spark] class AppClient( def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) val future = actor.ask(StopAppClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index c5a6b1beac9be..ff2eed6dee70a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -47,7 +47,7 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} private[master] class Master( host: String, @@ -931,7 +931,7 @@ private[deploy] object Master extends Logging { securityManager = securityMgr) val actor = actorSystem.actorOf( Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) val portsRequest = actor.ask(BoundPortsRequest)(timeout) val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index bb11e0642ddc6..aad9c87bdb987 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -21,7 +21,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.master.Master import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. @@ -31,7 +31,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { val masterActorRef = master.self - val timeout = AkkaUtils.askTimeout(master.conf) + val timeout = RpcUtils.askTimeout(master.conf) val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 4f19af59f409f..2d6b8d4204795 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ @@ -223,7 +223,7 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) } protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) val k = new KillSubmissionResponse @@ -257,7 +257,7 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) } protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } @@ -321,7 +321,7 @@ private[rest] class SubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = AkkaUtils.askTimeout(conf) + val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index de6423beb543e..b3bb5f911dbd7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.worker.ui.WorkerWebUI._ import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone worker. @@ -38,7 +38,7 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - private[ui] val timeout = AkkaUtils.askTimeout(worker.conf) + private[ui] val timeout = RpcUtils.askTimeout(worker.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index cba038ca355d7..a5336b7563802 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import scala.reflect.ClassTag import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to @@ -38,7 +38,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] abstract class RpcEnv(conf: SparkConf) { - private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) + private[spark] val defaultLookupTimeout = RpcUtils.lookupTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -282,9 +282,9 @@ trait ThreadSafeRpcEndpoint extends RpcEndpoint private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) extends Serializable with Logging { - private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) - private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultAskTimeout = conf.getLong("spark.akka.askTimeout", 30) seconds + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) /** * return the address for the [[RpcEndpointRef]] diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f72566c370a6f..1406a36a669c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -24,7 +24,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{RpcUtils, Utils} import scala.util.control.NonFatal @@ -46,7 +46,7 @@ private[spark] abstract class YarnSchedulerBackend( private val yarnSchedulerEndpoint = rpcEnv.setupEndpoint( YarnSchedulerBackend.ENDPOINT_NAME, new YarnSchedulerEndpoint(rpcEnv)) - private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + private implicit val askTimeout = RpcUtils.askTimeout(sc.conf) /** * Request executors from the ApplicationMaster by specifying the total number desired. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ceacf043029f3..c798843bd5d8a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -23,7 +23,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.AkkaUtils +import org.apache.spark.util.RpcUtils private[spark] class BlockManagerMaster( @@ -32,7 +32,7 @@ class BlockManagerMaster( isDriver: Boolean) extends Logging { - val timeout = AkkaUtils.askTimeout(conf) + val timeout = RpcUtils.askTimeout(conf) /** Remove a dead executor from the driver endpoint. This is only called on the driver side. */ def removeExecutor(execId: String) { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 8e8cc7cc6389e..b725df3b44596 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await -import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.concurrent.duration.FiniteDuration import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask @@ -125,16 +125,6 @@ private[spark] object AkkaUtils extends Logging { (actorSystem, boundPort) } - /** Returns the default Spark timeout to use for Akka ask operations. */ - def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.askTimeout", 30), "seconds") - } - - /** Returns the default Spark timeout to use for Akka remote actor lookup. */ - def lookupTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") - } - private val AKKA_MAX_FRAME_SIZE_IN_MB = Int.MaxValue / 1024 / 1024 /** Returns the configured max frame size for Akka messages in bytes. */ @@ -150,16 +140,6 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. */ val reservedSizeBytes = 200 * 1024 - /** Returns the configured number of times to retry connecting */ - def numRetries(conf: SparkConf): Int = { - conf.getInt("spark.akka.num.retries", 3) - } - - /** Returns the configured number of milliseconds to wait on each retry */ - def retryWaitMs(conf: SparkConf): Int = { - conf.getInt("spark.akka.retry.wait", 3000) - } - /** * Send a message to the given actor and get its result within a default timeout, or * throw a SparkException if this fails. @@ -216,7 +196,7 @@ private[spark] object AkkaUtils extends Logging { val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = AkkaUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } @@ -230,7 +210,7 @@ private[spark] object AkkaUtils extends Logging { val executorActorSystemName = SparkEnv.executorActorSystemName Utils.checkHost(host, "Expected hostname") val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = AkkaUtils.lookupTimeout(conf) + val timeout = RpcUtils.lookupTimeout(conf) logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 6665b17c3d5df..5ae793e0e87a3 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,6 +17,9 @@ package org.apache.spark.util +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} @@ -32,4 +35,24 @@ object RpcUtils { Utils.checkHost(driverHost, "Expected hostname") rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) } + + /** Returns the configured number of times to retry connecting */ + def numRetries(conf: SparkConf): Int = { + conf.getInt("spark.rpc.numRetries", 3) + } + + /** Returns the configured number of milliseconds to wait on each retry */ + def retryWaitMs(conf: SparkConf): Long = { + conf.getTimeAsMs("spark.rpc.retry.wait", "3s") + } + + /** Returns the default Spark timeout to use for RPC ask operations. */ + def askTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.askTimeout", "30s") seconds + } + + /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ + def lookupTimeout(conf: SparkConf): FiniteDuration = { + conf.getTimeAsSeconds("spark.rpc.lookupTimeout", "30s") seconds + } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 6295d34be5ca9..6ed057a7cab97 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -154,7 +154,7 @@ class MapOutputTrackerSuite extends FunSuite { test("remote fetch below akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) val rpcEnv = createRpcEnv("spark") @@ -180,7 +180,7 @@ class MapOutputTrackerSuite extends FunSuite { test("remote fetch exceeds akka frame size") { val newConf = new SparkConf newConf.set("spark.akka.frameSize", "1") - newConf.set("spark.akka.askTimeout", "1") // Fail fast + newConf.set("spark.rpc.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) val rpcEnv = createRpcEnv("test") diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 8e6c200c4ba00..d7d8014a20498 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark import java.util.concurrent.{TimeUnit, Executors} +import scala.concurrent.duration._ +import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { @@ -222,6 +224,26 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420) } + test("akka deprecated configs") { + val conf = new SparkConf() + + assert(!conf.contains("spark.rpc.num.retries")) + assert(!conf.contains("spark.rpc.retry.wait")) + assert(!conf.contains("spark.rpc.askTimeout")) + assert(!conf.contains("spark.rpc.lookupTimeout")) + + conf.set("spark.akka.num.retries", "1") + assert(RpcUtils.numRetries(conf) === 1) + + conf.set("spark.akka.retry.wait", "2") + assert(RpcUtils.retryWaitMs(conf) === 2L) + + conf.set("spark.akka.askTimeout", "3") + assert(RpcUtils.askTimeout(conf) === (3 seconds)) + + conf.set("spark.akka.lookupTimeout", "4") + assert(RpcUtils.lookupTimeout(conf) === (4 seconds)) + } } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ada07ef11cd7a..5fbda37c7cb88 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -155,8 +155,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { }) val conf = new SparkConf() - conf.set("spark.akka.retry.wait", "0") - conf.set("spark.akka.num.retries", "1") + conf.set("spark.rpc.retry.wait", "0") + conf.set("spark.rpc.num.retries", "1") val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout")