From 4d9e560b5470029143926827b1cb9d72a0bfbeff Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 27 Apr 2015 19:02:51 -0700 Subject: [PATCH 01/18] [SPARK-7090] [MLLIB] Introduce LDAOptimizer to LDA to further improve extensibility jira: https://issues.apache.org/jira/browse/SPARK-7090 LDA was implemented with extensibility in mind. And with the development of OnlineLDA and Gibbs Sampling, we are collecting more detailed requirements from different algorithms. As Joseph Bradley jkbradley proposed in https://github.com/apache/spark/pull/4807 and with some further discussion, we'd like to adjust the code structure a little to present the common interface and extension point clearly. Basically class LDA would be a common entrance for LDA computing. And each LDA object will refer to a LDAOptimizer for the concrete algorithm implementation. Users can customize LDAOptimizer with specific parameters and assign it to LDA. Concrete changes: 1. Add a trait `LDAOptimizer`, which defines the common iterface for concrete implementations. Each subClass is a wrapper for a specific LDA algorithm. 2. Move EMOptimizer to file LDAOptimizer and inherits from LDAOptimizer, rename to EMLDAOptimizer. (in case a more generic EMOptimizer comes in the future) -adjust the constructor of EMOptimizer, since all the parameters should be passed in through initialState method. This can avoid unwanted confusion or overwrite. -move the code from LDA.initalState to initalState of EMLDAOptimizer 3. Add property ldaOptimizer to LDA and its getter/setter, and EMLDAOptimizer is the default Optimizer. 4. Change the return type of LDA.run from DistributedLDAModel to LDAModel. Further work: add OnlineLDAOptimizer and other possible Optimizers once ready. Author: Yuhao Yang Closes #5661 from hhbyyh/ldaRefactor and squashes the following commits: 0e2e006 [Yuhao Yang] respond to review comments 08a45da [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor e756ce4 [Yuhao Yang] solve mima exception d74fd8f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into ldaRefactor 0bb8400 [Yuhao Yang] refactor LDA with Optimizer ec2f857 [Yuhao Yang] protoptype for discussion --- .../spark/examples/mllib/JavaLDAExample.java | 2 +- .../spark/examples/mllib/LDAExample.scala | 4 +- .../apache/spark/mllib/clustering/LDA.scala | 181 +++------------ .../spark/mllib/clustering/LDAModel.scala | 2 +- .../spark/mllib/clustering/LDAOptimizer.scala | 210 ++++++++++++++++++ .../spark/mllib/clustering/JavaLDASuite.java | 2 +- .../spark/mllib/clustering/LDASuite.scala | 2 +- project/MimaExcludes.scala | 4 + 8 files changed, 256 insertions(+), 151 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index 36207ae38d9a9..fd53c81cc4974 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -58,7 +58,7 @@ public Tuple2 call(Tuple2 doc_id) { corpus.cache(); // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); // Output topics. Each is a distribution over words (matching word count vectors) System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 08a93595a2e17..a1850390c0a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD @@ -137,7 +137,7 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus) + val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. Summary:") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d006b39acb213..37bf88b73b911 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,16 +17,11 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize} - +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -42,16 +37,9 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. * * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] @@ -63,10 +51,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -220,6 +209,32 @@ class LDA private ( this } + + /** LDAOptimizer used to perform the actual calculation */ + def getOptimizer: LDAOptimizer = ldaOptimizer + + /** + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + */ + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } + + /** + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em" is supported. + */ + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em is supported but got $other.") + } + this + } + /** * Learn an LDA model using the given dataset. * @@ -229,9 +244,9 @@ class LDA private ( * Document IDs must be unique and >= 0. * @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, + seed, checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,12 +256,11 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -320,88 +334,10 @@ private[clustering] object LDA { private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. - * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. - */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } - /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -427,49 +363,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) - } - } - verticesTMP.reduceByKey(_ + _) - } - - val docTermVertices = createVertices() - - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 0a3f21ecee0dc..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -203,7 +203,7 @@ class DistributedLDAModel private ( import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 0000000000000..ffd72a294c6c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Experimental +trait LDAOptimizer{ + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. + */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: Experimental :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + */ +@Experimental +class EMLDAOptimizer extends LDAOptimizer{ + + import LDA._ + + /** + * Following fields will only be initialized through initialState method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + private[clustering] override def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = vocabSize + this.docConcentration = docConcentration + this.topicConcentration = topicConcentration + this.checkpointInterval = checkpointInterval + this.graphCheckpointer = new + PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + private[clustering] override def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(this, iterationTimes) + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1f..fbe171b4b1ab1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -88,7 +88,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index cc747dabb9968..41ec794146c69 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ef363a2f07ad..967961c2bf5c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,10 @@ object MimaExcludes { // SPARK-6703 Add getOrCreate method to SparkContext ProblemFilters.exclude[IncompatibleResultTypeProblem] ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") ) case v if v.startsWith("1.3") => From 874a2ca93d095a0dfa1acfdacf0e9d80388c4422 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Apr 2015 21:45:40 -0700 Subject: [PATCH 02/18] [SPARK-7174][Core] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread `HeartbeatReceiver` will call `TaskScheduler.executorHeartbeatReceived`, which is a blocking operation because `TaskScheduler.executorHeartbeatReceived` will call ```Scala blockManagerMaster.driverEndpoint.askWithReply[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) ``` finally. Even if it asks from a local Actor, it may block the current Akka thread. E.g., the reply may be dispatched to the same thread of the ask operation. So the reply cannot be processed. An extreme case is setting the thread number of Akka dispatch thread pool to 1. jstack log: ``` "sparkDriver-akka.actor.default-dispatcher-14" daemon prio=10 tid=0x00007f2a8c02d000 nid=0x725 waiting on condition [0x00007f2b1d6d0000] java.lang.Thread.State: TIMED_WAITING (parking) at sun.misc.Unsafe.park(Native Method) - parking to wait for <0x00000006197a0868> (a scala.concurrent.impl.Promise$CompletionLatch) at java.util.concurrent.locks.LockSupport.parkNanos(LockSupport.java:226) at java.util.concurrent.locks.AbstractQueuedSynchronizer.doAcquireSharedNanos(AbstractQueuedSynchronizer.java:1033) at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1326) at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208) at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218) at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223) at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:107) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread$$anon$3.block(ThreadPoolBuilder.scala:169) at scala.concurrent.forkjoin.ForkJoinPool.managedBlock(ForkJoinPool.java:3640) at akka.dispatch.MonitorableThreadFactory$AkkaForkJoinWorkerThread.blockOn(ThreadPoolBuilder.scala:167) at scala.concurrent.Await$.result(package.scala:107) at org.apache.spark.rpc.RpcEndpointRef.askWithReply(RpcEnv.scala:355) at org.apache.spark.scheduler.DAGScheduler.executorHeartbeatReceived(DAGScheduler.scala:169) at org.apache.spark.scheduler.TaskSchedulerImpl.executorHeartbeatReceived(TaskSchedulerImpl.scala:367) at org.apache.spark.HeartbeatReceiver$$anonfun$receiveAndReply$1.applyOrElse(HeartbeatReceiver.scala:103) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$processMessage(AkkaRpcEnv.scala:182) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1$$anonfun$applyOrElse$4.apply$mcV$sp(AkkaRpcEnv.scala:128) at org.apache.spark.rpc.akka.AkkaRpcEnv.org$apache$spark$rpc$akka$AkkaRpcEnv$$safelyCall(AkkaRpcEnv.scala:203) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1$$anonfun$receiveWithLogging$1.applyOrElse(AkkaRpcEnv.scala:127) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply$mcVL$sp(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:33) at scala.runtime.AbstractPartialFunction$mcVL$sp.apply(AbstractPartialFunction.scala:25) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:59) at org.apache.spark.util.ActorLogReceive$$anon$1.apply(ActorLogReceive.scala:42) at scala.PartialFunction$class.applyOrElse(PartialFunction.scala:118) at org.apache.spark.util.ActorLogReceive$$anon$1.applyOrElse(ActorLogReceive.scala:42) at akka.actor.Actor$class.aroundReceive(Actor.scala:465) at org.apache.spark.rpc.akka.AkkaRpcEnv$$anonfun$actorRef$lzycompute$1$1$$anon$1.aroundReceive(AkkaRpcEnv.scala:94) at akka.actor.ActorCell.receiveMessage(ActorCell.scala:516) at akka.actor.ActorCell.invoke(ActorCell.scala:487) at akka.dispatch.Mailbox.processMailbox(Mailbox.scala:238) at akka.dispatch.Mailbox.run(Mailbox.scala:220) at akka.dispatch.ForkJoinExecutorConfigurator$AkkaForkJoinTask.exec(AbstractDispatcher.scala:393) at scala.concurrent.forkjoin.ForkJoinTask.doExec(ForkJoinTask.java:260) at scala.concurrent.forkjoin.ForkJoinPool$WorkQueue.runTask(ForkJoinPool.java:1339) at scala.concurrent.forkjoin.ForkJoinPool.runWorker(ForkJoinPool.java:1979) at scala.concurrent.forkjoin.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:107) ``` This PR moved this blocking operation to a separated thread. Author: zsxwing Closes #5723 from zsxwing/SPARK-7174 and squashes the following commits: 98bfe48 [zsxwing] Use a single thread for checking timeout and reporting executorHeartbeatReceived 5b3b545 [zsxwing] Move calling `TaskScheduler.executorHeartbeatReceived` to another thread to avoid blocking the Akka thread pool --- .../org/apache/spark/HeartbeatReceiver.scala | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. + private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } From 29576e786072bd4218e10036ddfc8d367b1c1446 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 27 Apr 2015 23:10:14 -0700 Subject: [PATCH 03/18] [SPARK-6829] Added math functions for DataFrames Implemented almost all math functions found in scala.math (max, min and abs were already present). cc mengxr marmbrus Author: Burak Yavuz Closes #5616 from brkyvz/math-udfs and squashes the following commits: fb27153 [Burak Yavuz] reverted exception message 836a098 [Burak Yavuz] fixed test and addressed small comment e5f0d13 [Burak Yavuz] addressed code review v2.2 b26c5fb [Burak Yavuz] addressed review v2.1 2761f08 [Burak Yavuz] addressed review v2 6588a5b [Burak Yavuz] fixed merge conflicts b084e10 [Burak Yavuz] Addressed code review 029e739 [Burak Yavuz] fixed atan2 test 534cc11 [Burak Yavuz] added more tests, addressed comments fa68dbe [Burak Yavuz] added double specific test data 937d5a5 [Burak Yavuz] use doubles instead of ints 8e28fff [Burak Yavuz] Added apache header 7ec8f7f [Burak Yavuz] Added math functions for DataFrames --- .../catalyst/analysis/HiveTypeCoercion.scala | 19 + .../sql/catalyst/expressions/Expression.scala | 10 + .../expressions/mathfuncs/binary.scala | 93 +++ .../expressions/mathfuncs/unary.scala | 168 ++++++ .../ExpressionEvaluationSuite.scala | 165 +++++ .../org/apache/spark/sql/mathfunctions.scala | 562 ++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 + .../spark/sql/ColumnExpressionSuite.scala | 1 - .../spark/sql/MathExpressionsSuite.scala | 233 ++++++++ 9 files changed, 1259 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. + */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..5b4d912a64f71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..96cb77d487529 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. + * @param name The short name of the function + */ +abstract class MathematicalExpression(name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" +} + +/** + * A unary expression specifically for math functions that take a `Double` as input and return + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForDouble(f: Double => Double, name: String) + extends MathematicalExpression(name) { self: Product => + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take an `Int` as input and return + * an `Int`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForInt(f: Int => Int, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = IntegerType + override def expectedChildTypes: Seq[DataType] = Seq(IntegerType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Int]) + } +} + +/** + * A unary expression specifically for math functions that take a `Float` as input and return + * a `Float`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForFloat(f: Float => Float, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = FloatType + override def expectedChildTypes: Seq[DataType] = Seq(FloatType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Float]) + if (result.isNaN) null else result + } + } +} + +/** + * A unary expression specifically for math functions that take a `Long` as input and return + * a `Long`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class MathematicalExpressionForLong(f: Long => Long, name: String) + extends MathematicalExpression(name) { self: Product => + + override def dataType: DataType = LongType + override def expectedChildTypes: Seq[DataType] = Seq(LongType) + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) null else f(evalE.asInstanceOf[Long]) + } +} + +case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN") + +case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN") + +case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH") + +case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS") + +case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS") + +case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH") + +case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN") + +case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN") + +case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH") + +case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL") + +case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR") + +case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND") + +case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT") + +case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM") + +case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM") + +case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM") + +case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM") + +case class ToDegrees(child: Expression) + extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpressionForDouble(math.toRadians, "RADIANS") + +case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P") + +case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..5390ce43c6639 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,170 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("isignum") { + unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5)) + } + + test("fsignum") { + unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat)) + } + + test("lsignum") { + unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong)) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..84f62bf47f955 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + * @groupname int_funcs Functions that require IntegerType as an input + * @groupname float_funcs Functions that require FloatType as an input + * @groupname long_funcs Functions that require LongType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the sine of the given value. + * + * @group double_funcs + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + * + * @group double_funcs + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + * + * @group double_funcs + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + * + * @group double_funcs + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + * + * @group double_funcs + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the cosine of the given value. + * + * @group double_funcs + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + * + * @group double_funcs + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. + * + * @group double_funcs + */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + * + * @group double_funcs + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + * + * @group double_funcs + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the tangent of the given value. + * + * @group double_funcs + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + * + * @group double_funcs + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + * + * @group double_funcs + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + * + * @group double_funcs + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + * + * @group double_funcs + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + * + * @group double_funcs + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * @group double_funcs + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + * + * @group double_funcs + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) + + /** + * Computes the ceiling of the given value. + * + * @group double_funcs + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + * + * @group double_funcs + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the floor of the given value. + * + * @group double_funcs + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + * + * @group double_funcs + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + * + * @group double_funcs + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the cube-root of the given value. + * + * @group double_funcs + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + * + * @group double_funcs + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the signum of the given value. + * + * @group double_funcs + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + * + * @group double_funcs + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the signum of the given value. For IntegerType. + * + * @group int_funcs + */ + def isignum(e: Column): Column = ISignum(e.expr) + + /** + * Computes the signum of the given column. For IntegerType. + * + * @group int_funcs + */ + def isignum(columnName: String): Column = isignum(Column(columnName)) + + /** + * Computes the signum of the given value. For FloatType. + * + * @group float_funcs + */ + def fsignum(e: Column): Column = FSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group float_funcs + */ + def fsignum(columnName: String): Column = fsignum(Column(columnName)) + + /** + * Computes the signum of the given value. For LongType. + * + * @group long_funcs + */ + def lsignum(e: Column): Column = LSignum(e.expr) + + /** + * Computes the signum of the given column. For FloatType. + * + * @group long_funcs + */ + def lsignum(columnName: String): Column = lsignum(Column(columnName)) + + /** + * Computes the natural logarithm of the given value. + * + * @group double_funcs + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + * + * @group double_funcs + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + * + * @group double_funcs + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + * + * @group double_funcs + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + * + * @group double_funcs + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Computes the exponential of the given value. + * + * @group double_funcs + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + * + * @group double_funcs + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + * + * @group double_funcs + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + * + * @group double_funcs + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * @group double_funcs + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * @group double_funcs + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -41,6 +41,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +99,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 904073b8cb2aa..680b5c636960d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..561553cc925cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("isignum") { + testOneToOneMathFunction[Int](isignum, math.signum) + } + + test("fsignum") { + testOneToOneMathFunction[Float](fsignum, math.signum) + } + + test("lsignum") { + testOneToOneMathFunction[Long](lsignum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} From 9e4e82b7bca1129bcd5e0274b9ae1b1be3fb93da Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 27 Apr 2015 23:48:02 -0700 Subject: [PATCH 04/18] [SPARK-5946] [STREAMING] Add Python API for direct Kafka stream Currently only added `createDirectStream` API, I'm not sure if `createRDD` is also needed, since some Java object needs to be wrapped in Python. Please help to review, thanks a lot. Author: jerryshao Author: Saisai Shao Closes #4723 from jerryshao/direct-kafka-python-api and squashes the following commits: a1fe97c [jerryshao] Fix rebase issue eebf333 [jerryshao] Address the comments da40f4e [jerryshao] Fix Python 2.6 Syntax error issue 5c0ee85 [jerryshao] Style fix 4aeac18 [jerryshao] Fix bug in example code 7146d86 [jerryshao] Add unit test bf3bdd6 [jerryshao] Add more APIs and address the comments f5b3801 [jerryshao] Small style fix 8641835 [Saisai Shao] Rebase and update the code 589c05b [Saisai Shao] Fix the style d6fcb6a [Saisai Shao] Address the comments dfda902 [Saisai Shao] Style fix 0f7d168 [Saisai Shao] Add the doc and fix some style issues 67e6880 [Saisai Shao] Fix test bug 917b0db [Saisai Shao] Add Python createRDD API for Kakfa direct stream c3fc11d [jerryshao] Modify the docs 2c00936 [Saisai Shao] address the comments 3360f44 [jerryshao] Fix code style e0e0f0d [jerryshao] Code clean and bug fix 338c41f [Saisai Shao] Add python API and example for direct kafka stream --- .../streaming/direct_kafka_wordcount.py | 55 ++++++ .../spark/streaming/kafka/KafkaUtils.scala | 92 +++++++++- python/pyspark/streaming/kafka.py | 167 +++++++++++++++++- python/pyspark/streaming/tests.py | 84 ++++++++- 4 files changed, 383 insertions(+), 15 deletions(-) create mode 100644 examples/src/main/python/streaming/direct_kafka_wordcount.py diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..0721ddaf7055a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -234,7 +235,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. * @@ -558,4 +558,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. + :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. + """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() From bf35edd9d4b8b11df9f47b6ff43831bc95f06322 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 28 Apr 2015 00:38:14 -0700 Subject: [PATCH 05/18] [SPARK-7187] SerializationDebugger should not crash user code rxin Author: Andrew Or Closes #5734 from andrewor14/ser-deb and squashes the following commits: e8aad6c [Andrew Or] NonFatal 57d0ef4 [Andrew Or] try catch improveException --- .../spark/serializer/SerializationDebugger.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } From d94cd1a733d5715792e6c4eac87f0d5c81aebbe2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Apr 2015 00:39:08 -0700 Subject: [PATCH 06/18] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. Author: Reynold Xin Closes #5709 from rxin/inc-id and squashes the following commits: 7853611 [Reynold Xin] private sql. a9fda0d [Reynold Xin] Missed a few numbers. 343d896 [Reynold Xin] Self review feedback. a7136cb [Reynold Xin] [SPARK-7135][SQL] DataFrame expression for monotonically increasing IDs. --- python/pyspark/sql/functions.py | 22 +++++++- .../MonotonicallyIncreasingID.scala | 53 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 6 +-- .../org/apache/spark/sql/functions.scala | 16 ++++++ .../spark/sql/ColumnExpressionSuite.scala | 11 ++++ 5 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f48b7b5d10af7..7b86655d9c82f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -103,8 +103,28 @@ def countDistinct(col, *cols): return Column(jc) +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. + + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + def sparkPartitionId(): - """Returns a column for partition ID of the Spark task. + """A column for partition ID of the Spark task. Note that this is indeterministic because it depends on data partitioning and task scheduling. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index fe7607c6ac340..c2c6cbd491598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -18,16 +18,14 @@ package org.apache.spark.sql.execution.expressions import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.expressions.{Row, Expression} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -case object SparkPartitionID extends Expression with trees.LeafNode[Expression] { - self: Product => +private[sql] case object SparkPartitionID extends LeafExpression { override type EvaluatedType = Int diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 9738fd4f93bad..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -301,6 +301,22 @@ object functions { */ def lower(e: Column): Column = Lower(e.expr) + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 680b5c636960d..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -309,6 +309,17 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. + val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + test("sparkPartitionId") { val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( From e13cd86567a43672297bb488088dd8f40ec799bf Mon Sep 17 00:00:00 2001 From: Pei-Lun Lee Date: Tue, 28 Apr 2015 16:50:18 +0800 Subject: [PATCH 07/18] [SPARK-6352] [SQL] Custom parquet output committer Add new config "spark.sql.parquet.output.committer.class" to allow custom parquet output committer and an output committer class specific to use on s3. Fix compilation error introduced by https://github.com/apache/spark/pull/5042. Respect ParquetOutputFormat.ENABLE_JOB_SUMMARY flag. Author: Pei-Lun Lee Closes #5525 from ypcat/spark-6352 and squashes the following commits: 54c6b15 [Pei-Lun Lee] error handling 472870e [Pei-Lun Lee] add back custom parquet output committer ddd0f69 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ece5c5 [Pei-Lun Lee] compatibility with hadoop 1.x 8413fcd [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 fe65915 [Pei-Lun Lee] add support for parquet config parquet.enable.summary-metadata e17bf47 [Pei-Lun Lee] Merge branch 'master' of https://github.com/apache/spark into spark-6352 9ae7545 [Pei-Lun Lee] [SPARL-6352] [SQL] Change to allow custom parquet output committer. 0d540b9 [Pei-Lun Lee] [SPARK-6352] [SQL] add license c42468c [Pei-Lun Lee] [SPARK-6352] [SQL] add test case 0fc03ca [Pei-Lun Lee] [SPARK-6532] [SQL] hide class DirectParquetOutputCommitter 769bd67 [Pei-Lun Lee] DirectParquetOutputCommitter f75e261 [Pei-Lun Lee] DirectParquetOutputCommitter --- .../DirectParquetOutputCommitter.scala | 73 +++++++++++++++++++ .../sql/parquet/ParquetTableOperations.scala | 21 ++++++ .../spark/sql/parquet/ParquetIOSuite.scala | 22 ++++++ 3 files changed, 116 insertions(+) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. + try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { From 7f3b3b7eb7d14767124a28ec0062c4d60d6c16fc Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 28 Apr 2015 07:48:34 -0400 Subject: [PATCH 08/18] [SPARK-7168] [BUILD] Update plugin versions in Maven build and centralize versions Update Maven build plugin versions and centralize plugin version management Author: Sean Owen Closes #5720 from srowen/SPARK-7168 and squashes the following commits: 98a8947 [Sean Owen] Make install, deploy plugin versions explicit 4ecf3b2 [Sean Owen] Update Maven build plugin versions and centralize plugin version management --- assembly/pom.xml | 1 - core/pom.xml | 1 - network/common/pom.xml | 1 - pom.xml | 44 ++++++++++++++++++++++++++++++------------ sql/hive/pom.xml | 1 - 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..459ef66712c36 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -478,7 +478,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..928f5d0f5efad 100644 --- a/pom.xml +++ b/pom.xml @@ -1082,7 +1082,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1105,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1176,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1189,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1260,17 +1260,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1287,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1305,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1335,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1354,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1379,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies From 75905c57cd57bc5b650ac5f486580ef8a229b260 Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Tue, 28 Apr 2015 07:51:02 -0400 Subject: [PATCH 09/18] [SPARK-7100] [MLLIB] Fix persisted RDD leak in GradientBoostTrees This fixes a leak of a persisted RDD where GradientBoostTrees can call persist but never unpersists. Jira: https://issues.apache.org/jira/browse/SPARK-7100 Discussion: http://apache-spark-developers-list.1001551.n3.nabble.com/GradientBoostTrees-leaks-a-persisted-RDD-td11750.html Author: Jim Carroll Closes #5669 from jimfcarroll/gb-unpersist-fix and squashes the following commits: 45f4b03 [Jim Carroll] [SPARK-7100][MLLib] Fix persisted RDD leak in GradientBoostTrees --- .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, From 268c419f1586110b90e68f98cd000a782d18828c Mon Sep 17 00:00:00 2001 From: Masayoshi TSUZUKI Date: Tue, 28 Apr 2015 07:55:21 -0400 Subject: [PATCH 10/18] [SPARK-6435] spark-shell --jars option does not add all jars to classpath Modified to accept double-quotated args properly in spark-shell.cmd. Author: Masayoshi TSUZUKI Closes #5227 from tsudukim/feature/SPARK-6435-2 and squashes the following commits: ac55787 [Masayoshi TSUZUKI] removed unnecessary argument. 60789a7 [Masayoshi TSUZUKI] Merge branch 'master' of https://github.com/apache/spark into feature/SPARK-6435-2 1fee420 [Masayoshi TSUZUKI] fixed test code for escaping '='. 0d4dc41 [Masayoshi TSUZUKI] - escaped comman and semicolon in CommandBuilderUtils.java - added random string to the temporary filename - double-quotation followed by `cmd /c` did not worked properly - no need to escape `=` by `^` - if double-quoted string ended with `\` like classpath, the last `\` is parsed as the escape charactor and the closing `"` didn't work properly 2a332e5 [Masayoshi TSUZUKI] Merge branch 'master' into feature/SPARK-6435-2 04f4291 [Masayoshi TSUZUKI] [SPARK-6435] spark-shell --jars option does not add all jars to classpath --- bin/spark-class2.cmd | 5 ++++- .../org/apache/spark/launcher/CommandBuilderUtils.java | 9 ++++----- .../src/main/java/org/apache/spark/launcher/Main.java | 6 +----- .../apache/spark/launcher/CommandBuilderUtilsSuite.java | 5 ++++- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. -for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map Date: Tue, 28 Apr 2015 09:46:08 -0700 Subject: [PATCH 11/18] [SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN Author: DB Tsai Author: DB Tsai Closes #4259 from dbtsai/lir and squashes the following commits: a81c201 [DB Tsai] add import org.apache.spark.util.Utils back 9fc48ed [DB Tsai] rebase 2178b63 [DB Tsai] add comments 9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit. fcbaefe [DB Tsai] Refactoring 4eb078d [DB Tsai] first commit --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +- .../spark/ml/param/shared/sharedParams.scala | 34 ++ .../ml/regression/LinearRegression.scala | 304 ++++++++++++++++-- .../apache/spark/mllib/linalg/Vectors.scala | 8 +- .../spark/mllib/optimization/Gradient.scala | 6 +- .../spark/mllib/optimization/LBFGS.scala | 15 +- .../mllib/util/LinearDataGenerator.scala | 43 ++- .../ml/regression/LinearRegressionSuite.scala | 158 +++++++-- 8 files changed, 508 insertions(+), 64 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..3f7e8f5a0b22c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..7d2c76d6c62c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -276,4 +276,38 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * :: DeveloperApi :: + * Trait for shared param elasticNetParam. + */ +@DeveloperApi +trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter. + * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter") + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param tol. + */ +@DeveloperApi +trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. + * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..f92c6816eb54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,21 +17,29 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: @@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. + * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. + val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. + // The weights are trained in the scaled space; we're converting them back to + // the original space. + val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. + // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) if (handlePersistence) { - oldDataset.unpersist() + instances.unpersist() } - lrm + new LinearRegressionModel(this, paramMap, weights, intercept) } } @@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + + * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. + * This is correct for the averaged least squares loss function (mean squared error) + * L = 1/2n ||A weights-y||^2 + * See also the documentation for the precise formulation. + * + * @param weights weights/coefficients corresponding to features + * + * @param updater Updater to be used to update weights after every iteration. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0 + private var lossSum = 0.0 + private var diffSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + diffSum += diff + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. + */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + diffSum += other.diffSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + + val correction = { + val temp = effectiveWeightsArray.clone() + var i = 0 + while (i < temp.length) { + temp(i) *= featuresMean(i) + i += 1 + } + Vectors.dense(temp) + } + + axpy(-diffSum, correction, result) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. + */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..af0cfe22ca10d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -85,7 +85,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -284,7 +284,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. */ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -483,7 +483,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -543,7 +543,7 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..d7bb943e84f53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,47 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.size)(0.0), + Array.fill[Double](weights.size)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble)) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + }) + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. + * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. - val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } From b14cd2364932e504695bcc49486ffb4518fdf33d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 28 Apr 2015 09:59:36 -0700 Subject: [PATCH 12/18] [SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, we only scan the first 16 nonzeros. srowen Author: Xiangrui Meng Closes #5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo 8fb7d74 [Xiangrui Meng] update impl 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode --- .../apache/spark/mllib/linalg/Vectors.scala | 88 ++++++++++++++----- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index af0cfe22ca10d..34833e90d4af0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -317,7 +325,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +379,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +409,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +430,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +459,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } } object DenseVector { @@ -522,8 +546,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -547,7 +571,7 @@ class SparseVector( private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +580,28 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } } object SparseVector { From 52ccf1d3739694826915cdf01642bab02958eb78 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 10:24:00 -0700 Subject: [PATCH 13/18] [Core][test][minor] replace try finally block with tryWithSafeFinally Author: Zhang, Liye Closes #5739 from liyezhang556520/trySafeFinally and squashes the following commits: 55683e5 [Zhang, Liye] replace try finally block with tryWithSafeFinally --- .../apache/spark/deploy/history/FsHistoryProviderSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } From 8aab94d8984e9d12194dbda47b2e7d9dbc036889 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Tue, 28 Apr 2015 12:08:18 -0700 Subject: [PATCH 14/18] [SPARK-4286] Add an external shuffle service that can be run as a daemon. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows Mesos deployments to use the shuffle service (and implicitly dynamic allocation). It does so by adding a new "main" class and two corresponding scripts in `sbin`: - `sbin/start-shuffle-service.sh` - `sbin/stop-shuffle-service.sh` Specific options can be passed in `SPARK_SHUFFLE_OPTS`. This is picking up work from #3861 /cc tnachen Author: Iulian Dragos Closes #4990 from dragos/feature/external-shuffle-service and squashes the following commits: 6c2b148 [Iulian Dragos] Import order and wrong name fixup. 07804ad [Iulian Dragos] Moved ExternalShuffleService to the `deploy` package + other minor tweaks. 4dc1f91 [Iulian Dragos] Reviewer’s comments: 8145429 [Iulian Dragos] Add an external shuffle service that can be run as a daemon. --- conf/spark-env.sh.template | 3 +- ...ice.scala => ExternalShuffleService.scala} | 59 ++++++++++++++++--- .../apache/spark/deploy/worker/Worker.scala | 13 ++-- docs/job-scheduling.md | 2 +- .../launcher/SparkClassCommandBuilder.java | 4 ++ sbin/start-shuffle-service.sh | 33 +++++++++++ sbin/stop-shuffle-service.sh | 25 ++++++++ 7 files changed, 124 insertions(+), 15 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{worker/StandaloneWorkerShuffleService.scala => ExternalShuffleService.scala} (59%) create mode 100755 sbin/start-shuffle-service.sh create mode 100755 sbin/stop-shuffle-service.sh diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. */ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. - private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. * **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index e601a0a19f368..d80abf2a8676e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -69,6 +69,10 @@ public List buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the external shuffle service on the machine this script is executed on. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 From 2d222fb39dd978e5a33cde6ceb59307cbdf7b171 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Tue, 28 Apr 2015 12:18:55 -0700 Subject: [PATCH 15/18] [SPARK-5932] [CORE] Use consistent naming for size properties I've added an interface to JavaUtils to do byte conversion and added hooks within Utils.scala to handle conversion within Spark code (like for time strings). I've added matching tests for size conversion, and then updated all deprecated configs and documentation as per SPARK-5933. Author: Ilya Ganelin Closes #5574 from ilganeli/SPARK-5932 and squashes the following commits: 11f6999 [Ilya Ganelin] Nit fixes 49a8720 [Ilya Ganelin] Whitespace fix 2ab886b [Ilya Ganelin] Scala style fc85733 [Ilya Ganelin] Got rid of floating point math 852a407 [Ilya Ganelin] [SPARK-5932] Added much improved overflow handling. Can now handle sizes up to Long.MAX_VALUE Petabytes instead of being capped at Long.MAX_VALUE Bytes 9ee779c [Ilya Ganelin] Simplified fraction matches 22413b1 [Ilya Ganelin] Made MAX private 3dfae96 [Ilya Ganelin] Fixed some nits. Added automatic conversion of old paramter for kryoserializer.mb to new values. e428049 [Ilya Ganelin] resolving merge conflict 8b43748 [Ilya Ganelin] Fixed error in pattern matching for doubles 84a2581 [Ilya Ganelin] Added smoother handling of fractional values for size parameters. This now throws an exception and added a warning for old spark.kryoserializer.buffer d3d09b6 [Ilya Ganelin] [SPARK-5932] Fixing error in KryoSerializer fe286b4 [Ilya Ganelin] Resolved merge conflict c7803cd [Ilya Ganelin] Empty lines 54b78b4 [Ilya Ganelin] Simplified byteUnit class 69e2f20 [Ilya Ganelin] Updates to code f32bc01 [Ilya Ganelin] [SPARK-5932] Fixed error in API in SparkConf.scala where Kb conversion wasn't being done properly (was Mb). Added test cases for both timeUnit and ByteUnit conversion f15f209 [Ilya Ganelin] Fixed conversion of kryo buffer size 0f4443e [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 35a7fa7 [Ilya Ganelin] Minor formatting 928469e [Ilya Ganelin] [SPARK-5932] Converted some longs to ints 5d29f90 [Ilya Ganelin] [SPARK-5932] Finished documentation updates 7a6c847 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer afc9a38 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize and spark.storage.memoryMapThreshold ae7e9f6 [Ilya Ganelin] [SPARK-5932] Updated spark.io.compression.snappy.block.size 2d15681 [Ilya Ganelin] [SPARK-5932] Updated spark.executor.logs.rolling.size.maxBytes 1fbd435 [Ilya Ganelin] [SPARK-5932] Updated spark.broadcast.blockSize eba4de6 [Ilya Ganelin] [SPARK-5932] Updated spark.shuffle.file.buffer.kb b809a78 [Ilya Ganelin] [SPARK-5932] Updated spark.kryoserializer.buffer.max 0cdff35 [Ilya Ganelin] [SPARK-5932] Updated to use bibibytes in method names. Updated spark.kryoserializer.buffer.mb and spark.reducer.maxMbInFlight 475370a [Ilya Ganelin] [SPARK-5932] Simplified ByteUnit code, switched to using longs. Updated docs to clarify that we use kibi, mebi etc instead of kilo, mega 851d691 [Ilya Ganelin] [SPARK-5932] Updated memoryStringToMb to use new interfaces a9f4fcf [Ilya Ganelin] [SPARK-5932] Added unit tests for unit conversion 747393a [Ilya Ganelin] [SPARK-5932] Added unit tests for ByteString conversion 09ea450 [Ilya Ganelin] [SPARK-5932] Added byte string conversion to Jav utils 5390fd9 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-5932 db9a963 [Ilya Ganelin] Closing second spark context 1dc0444 [Ilya Ganelin] Added ref equality check 8c884fa [Ilya Ganelin] Made getOrCreate synchronized cb0c6b7 [Ilya Ganelin] Doc updates and code cleanup 270cfe3 [Ilya Ganelin] [SPARK-6703] Documentation fixes 15e8dea [Ilya Ganelin] Updated comments and added MiMa Exclude 0e1567c [Ilya Ganelin] Got rid of unecessary option for AtomicReference dfec4da [Ilya Ganelin] Changed activeContext to AtomicReference 733ec9f [Ilya Ganelin] Fixed some bugs in test code 8be2f83 [Ilya Ganelin] Replaced match with if e92caf7 [Ilya Ganelin] [SPARK-6703] Added test to ensure that getOrCreate both allows creation, retrieval, and a second context if desired a99032f [Ilya Ganelin] Spacing fix d7a06b8 [Ilya Ganelin] Updated SparkConf class to add getOrCreate method. Started test suite implementation --- .../scala/org/apache/spark/SparkConf.scala | 90 ++++++++++++++- .../spark/broadcast/TorrentBroadcast.scala | 3 +- .../apache/spark/io/CompressionCodec.scala | 8 +- .../spark/serializer/KryoSerializer.scala | 17 +-- .../shuffle/FileShuffleBlockManager.scala | 3 +- .../hash/BlockStoreShuffleFetcher.scala | 3 +- .../org/apache/spark/storage/DiskStore.scala | 3 +- .../scala/org/apache/spark/util/Utils.scala | 53 ++++++--- .../collection/ExternalAppendOnlyMap.scala | 6 +- .../util/collection/ExternalSorter.scala | 4 +- .../util/logging/RollingFileAppender.scala | 2 +- .../org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkConfSuite.scala | 19 ++++ .../KryoSerializerResizableOutputSuite.scala | 8 +- .../serializer/KryoSerializerSuite.scala | 2 +- .../BlockManagerReplicationSuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 6 +- .../org/apache/spark/util/UtilsSuite.scala | 100 +++++++++++++++- docs/configuration.md | 60 ++++++---- docs/tuning.md | 2 +- .../spark/examples/mllib/MovieLensALS.scala | 2 +- .../apache/spark/network/util/ByteUnit.java | 67 +++++++++++ .../apache/spark/network/util/JavaUtils.java | 107 ++++++++++++++++-- 23 files changed, 488 insertions(+), 81 deletions(-) create mode 100644 network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. + */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 342bc9a06db47..4c028c06a5138 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1020,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. */ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..ef3cac622505e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..6957bc72e9903 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. +The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - spark.executor.logs.rolling.size.maxBytes + spark.executor.logs.rolling.maxSize (none) Set the max size of the file by which the executor logs will be rolled over. - Rolling is disabled by default. Value is set in terms of bytes. - See spark.executor.logs.rolling.maxRetainedFiles + Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } From 80098109d908b738b43d397e024756ff617d0af4 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Tue, 28 Apr 2015 12:33:48 -0700 Subject: [PATCH 16/18] [SPARK-6314] [CORE] handle JsonParseException for history server This is handled in the same way with [SPARK-6197](https://issues.apache.org/jira/browse/SPARK-6197). The result of this PR is that exception showed in history server log will be replaced by a warning, and the application that with un-complete history log file will be listed on history server webUI Author: Zhang, Liye Closes #5736 from liyezhang556520/SPARK-6314 and squashes the following commits: b8d2d88 [Zhang, Liye] handle JsonParseException for history server --- .../org/apache/spark/deploy/history/FsHistoryProvider.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() } From 53befacced828bbac53c6e3a4976ec3f036bae9e Mon Sep 17 00:00:00 2001 From: Timothy Chen Date: Tue, 28 Apr 2015 13:31:08 -0700 Subject: [PATCH 17/18] [SPARK-5338] [MESOS] Add cluster mode support for Mesos This patch adds the support for cluster mode to run on Mesos. It introduces a new Mesos framework dedicated to launch new apps/drivers, and can be called with the spark-submit script and specifying --master flag to the cluster mode REST interface instead of Mesos master. Example: ./bin/spark-submit --deploy-mode cluster --class org.apache.spark.examples.SparkPi --master mesos://10.0.0.206:8077 --executor-memory 1G --total-executor-cores 100 examples/target/spark-examples_2.10-1.3.0-SNAPSHOT.jar 30 Part of this patch is also to abstract the StandaloneRestServer so it can have different implementations of the REST endpoints. Features of the cluster mode in this PR: - Supports supervise mode where scheduler will keep trying to reschedule exited job. - Adds a new UI for the cluster mode scheduler to see all the running jobs, finished jobs, and supervise jobs waiting to be retried - Supports state persistence to ZK, so when the cluster scheduler fails over it can pick up all the queued and running jobs Author: Timothy Chen Author: Luc Bourlier Closes #5144 from tnachen/mesos_cluster_mode and squashes the following commits: 069e946 [Timothy Chen] Fix rebase. e24b512 [Timothy Chen] Persist submitted driver. 390c491 [Timothy Chen] Fix zk conf key for mesos zk engine. e324ac1 [Timothy Chen] Fix merge. fd5259d [Timothy Chen] Address review comments. 1553230 [Timothy Chen] Address review comments. c6c6b73 [Timothy Chen] Pass spark properties to mesos cluster tasks. f7d8046 [Timothy Chen] Change app name to spark cluster. 17f93a2 [Timothy Chen] Fix head of line blocking in scheduling drivers. 6ff8e5c [Timothy Chen] Address comments and add logging. df355cd [Timothy Chen] Add metrics to mesos cluster scheduler. 20f7284 [Timothy Chen] Address review comments 7252612 [Timothy Chen] Fix tests. a46ad66 [Timothy Chen] Allow zk cli param override. 920fc4b [Timothy Chen] Fix scala style issues. 862b5b5 [Timothy Chen] Support asking driver status when it's retrying. 7f214c2 [Timothy Chen] Fix RetryState visibility e0f33f7 [Timothy Chen] Add supervise support and persist retries. 371ce65 [Timothy Chen] Handle cluster mode recovery and state persistence. 3d4dfa1 [Luc Bourlier] Adds support to kill submissions febfaba [Timothy Chen] Bound the finished drivers in memory 543a98d [Timothy Chen] Schedule multiple jobs 6887e5e [Timothy Chen] Support looking at SPARK_EXECUTOR_URI env variable in schedulers 8ec76bc [Timothy Chen] Fix Mesos dispatcher UI. d57d77d [Timothy Chen] Add documentation 825afa0 [Luc Bourlier] Supports more spark-submit parameters b8e7181 [Luc Bourlier] Adds a shutdown latch to keep the deamon running 0fa7780 [Luc Bourlier] Launch task through the mesos scheduler 5b7a12b [Timothy Chen] WIP: Making a cluster mode a mesos framework. 4b2f5ef [Timothy Chen] Specify user jar in command to be replaced with local. e775001 [Timothy Chen] Support fetching remote uris in driver runner. 7179495 [Timothy Chen] Change Driver page output and add logging 880bc27 [Timothy Chen] Add Mesos Cluster UI to display driver results 9986731 [Timothy Chen] Kill drivers when shutdown 67cbc18 [Timothy Chen] Rename StandaloneRestClient to RestClient and add sbin scripts e3facdd [Timothy Chen] Add Mesos Cluster dispatcher --- .../spark/deploy/FaultToleranceTest.scala | 2 +- .../{master => }/SparkCuratorUtil.scala | 10 +- .../org/apache/spark/deploy/SparkSubmit.scala | 48 +- .../spark/deploy/SparkSubmitArguments.scala | 11 +- .../apache/spark/deploy/master/Master.scala | 2 +- .../master/ZooKeeperLeaderElectionAgent.scala | 1 + .../master/ZooKeeperPersistenceEngine.scala | 1 + .../deploy/mesos/MesosClusterDispatcher.scala | 116 ++++ .../MesosClusterDispatcherArguments.scala | 101 +++ .../deploy/mesos/MesosDriverDescription.scala | 65 ++ .../deploy/mesos/ui/MesosClusterPage.scala | 114 ++++ .../deploy/mesos/ui/MesosClusterUI.scala | 48 ++ ...lient.scala => RestSubmissionClient.scala} | 35 +- .../deploy/rest/RestSubmissionServer.scala | 318 +++++++++ .../deploy/rest/StandaloneRestServer.scala | 344 ++-------- .../rest/SubmitRestProtocolRequest.scala | 2 +- .../rest/SubmitRestProtocolResponse.scala | 6 +- .../deploy/rest/mesos/MesosRestServer.scala | 158 +++++ .../mesos/CoarseMesosSchedulerBackend.scala | 82 +-- .../mesos/MesosClusterPersistenceEngine.scala | 134 ++++ .../cluster/mesos/MesosClusterScheduler.scala | 608 ++++++++++++++++++ .../mesos/MesosClusterSchedulerSource.scala | 40 ++ .../cluster/mesos/MesosSchedulerBackend.scala | 85 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 95 +++ .../spark/deploy/SparkSubmitSuite.scala | 2 +- .../rest/StandaloneRestSubmitSuite.scala | 46 +- .../mesos/MesosClusterSchedulerSuite.scala | 76 +++ docs/running-on-mesos.md | 23 +- sbin/start-mesos-dispatcher.sh | 40 ++ sbin/stop-mesos-dispatcher.sh | 27 + 30 files changed, 2147 insertions(+), 493 deletions(-) rename core/src/main/scala/org/apache/spark/deploy/{master => }/SparkCuratorUtil.scala (89%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala rename core/src/main/scala/org/apache/spark/deploy/rest/{StandaloneRestClient.scala => RestSubmissionClient.scala} (93%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala create mode 100644 core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala create mode 100755 sbin/start-mesos-dispatcher.sh create mode 100755 sbin/stop-mesos-dispatcher.sh diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a7c89276a045e..c048b78910f38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 5b22481ea8c5f..b8d3993540220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.master +package org.apache.spark.deploy import scala.collection.JavaConversions._ @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -private[deploy] object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 private val ZK_SESSION_TIMEOUT_MILLIS = 60000 private val RETRY_WAIT_MILLIS = 5000 private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 296a0764b8baf..f4f572e1e256e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -36,11 +36,11 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} - import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. @@ -114,18 +114,20 @@ object SparkSubmit { } } - /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + /** + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .killSubmission(args.master, args.submissionToKill) } /** * Request the status of an existing submission using the REST protocol. - * Standalone cluster mode only. + * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } @@ -252,6 +254,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files // too for packages that include Python code @@ -294,8 +297,9 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -377,15 +381,6 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - // Do not set CL arguments here because there are multiple possibilities for the main class - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), - OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, - sysProp = "spark.driver.supervise"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), @@ -413,7 +408,15 @@ object SparkSubmit { OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -452,7 +455,7 @@ object SparkSubmit { // All Spark parameters are expected to be passed to the client through system properties. if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class @@ -496,6 +499,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c896842943f2b..c621b8fc86f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -241,8 +241,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://")) { - SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") @@ -250,9 +251,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( - "Requesting submission statuses is only supported in standalone mode!") + "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") @@ -485,6 +486,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index ff2eed6dee70a..1c21c179562ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -130,7 +130,7 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 4823fd7cac0cb..52758d6a7c4be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a285783f72000..80db6d474b5c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -26,6 +26,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 0000000000000..5d4e5b899dfdc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. + */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 0000000000000..894cb78d8591a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + System.err.println("--master is required") + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 0000000000000..1948226800afe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 0000000000000..7b2005e0f1237 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

Mesos Framework ID: {state.frameworkId}

+
+
+

Queued Drivers:

+ {queuedTable} +

Launched Drivers:

+ {launchedTable} +

Finished Drivers:

+ {finishedTable} +

Supervise drivers waiting for retry:

+ {retryTable} +
+
; + UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + } + + private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { +
+ + + + + + } + + private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = { + + + + + + + + + + } + + private def retryRow(submission: MesosDriverDescription): Seq[Node] = { + + + + + + + + + } + + private def stateString(status: Option[TaskStatus]): String = { + if (status.isEmpty) { + return "" + } + val sb = new StringBuilder + val s = status.get + sb.append(s"State: ${s.getState}") + if (status.get.hasMessage) { + sb.append(s", Message: ${s.getMessage}") + } + if (status.get.hasHealthy) { + sb.append(s", Healthy: ${s.getHealthy}") + } + if (status.get.hasSource) { + sb.append(s", Source: ${s.getSource}") + } + if (status.get.hasReason) { + sb.append(s", Reason: ${s.getReason}") + } + if (status.get.hasTimestamp) { + sb.append(s", Time: ${s.getTimestamp}") + } + sb.toString() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala new file mode 100644 index 0000000000000..4865d46dbc4ab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index b8fd406fb6f9a..307cebfb4bd09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -30,9 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using a REST protocol. - * This client is intended to communicate with the [[StandaloneRestServer]] and is - * currently used for cluster mode only. + * A client that submits applications to a [[RestSubmissionServer]]. * * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], * where [action] can be one of create, kill, or status. Each type of request is represented in @@ -53,8 +51,10 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[deploy] class StandaloneRestClient extends Logging { - import StandaloneRestClient._ +private[spark] class RestSubmissionClient extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") /** * Submit an application specified by the parameters in the provided request. @@ -62,7 +62,7 @@ private[deploy] class StandaloneRestClient extends Logging { * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - private[rest] def createSubmission( + def createSubmission( master: String, request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") @@ -107,7 +107,7 @@ private[deploy] class StandaloneRestClient extends Logging { } /** Construct a message that captures the specified parameters for submitting an application. */ - private[rest] def constructSubmitRequest( + def constructSubmitRequest( appResource: String, mainClass: String, appArgs: Array[String], @@ -219,14 +219,23 @@ private[deploy] class StandaloneRestClient extends Logging { /** Return the base URL for communicating with the server, including the protocol version. */ private def getBaseUrl(master: String): String = { - val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") s"http://$masterUrl/$PROTOCOL_VERSION/submissions" } /** Throw an exception if this is not standalone mode. */ private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) } } @@ -295,7 +304,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } -private[rest] object StandaloneRestClient { +private[spark] object RestSubmissionClient { private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -315,7 +324,7 @@ private[rest] object StandaloneRestClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient + val client = new RestSubmissionClient val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) client.createSubmission(master, submitRequest) @@ -323,7 +332,7 @@ private[rest] object StandaloneRestClient { def main(args: Array[String]): Unit = { if (args.size < 2) { - sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) } val appResource = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 0000000000000..2e78d03e5c0cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. + */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. + */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. + requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." + val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 2d6b8d4204795..502b9bb701ccf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,26 +18,16 @@ package org.apache.spark.deploy.rest import java.io.File -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} - -import scala.io.Source +import javax.servlet.http.HttpServletResponse import akka.actor.ActorRef -import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** - * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * A server that responds to requests submitted by the [[RestSubmissionClient]]. * This is intended to be embedded in the standalone Master and used in cluster mode only. * * This server responds with different HTTP codes depending on the situation: @@ -54,173 +44,31 @@ import org.apache.spark.deploy.ClientArguments._ * * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master * @param masterActor reference to the Master actor to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to - * @param masterConf the conf used by the Master */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends Logging { - - import StandaloneRestServer._ - - private var _server: Option[Server] = None - - // A mapping from URL prefixes to servlets that serve them. Exposed for testing. - protected val baseContext = s"/$PROTOCOL_VERSION/submissions" - protected val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), - s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), - s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), - "/*" -> new ErrorServlet // default handler - ) - - /** Start the server and return the bound port. */ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - /** - * Map the servlets to their corresponding contexts and attach them to a server. - * Return a 2-tuple of the started server and the bound port. - */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val mainHandler = new ServletContextHandler - mainHandler.setContextPath("/") - contextToServlet.foreach { case (prefix, servlet) => - mainHandler.addServlet(new ServletHolder(servlet), prefix) - } - server.setHandler(mainHandler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } -} - -private[rest] object StandaloneRestServer { - val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION - val SC_UNKNOWN_PROTOCOL_VERSION = 468 -} - -/** - * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. - */ -private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** - * Serialize the given response message to JSON and send it through the response servlet. - * This validates the response before sending it to ensure it is properly constructed. - */ - protected def sendResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): Unit = { - val message = validateResponse(responseMessage, responseServlet) - responseServlet.setContentType("application/json") - responseServlet.setCharacterEncoding("utf-8") - responseServlet.getWriter.write(message.toJson) - } - - /** - * Return any fields in the client request message that the server does not know about. - * - * The mechanism for this is to reconstruct the JSON on the server side and compare the - * diff between this JSON and the one generated on the client side. Any fields that are - * only in the client JSON are treated as unexpected. - */ - protected def findUnknownFields( - requestJson: String, - requestMessage: SubmitRestProtocolMessage): Array[String] = { - val clientSideJson = parse(requestJson) - val serverSideJson = parse(requestMessage.toJson) - val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) - unknown match { - case j: JObject => j.obj.map { case (k, _) => k }.toArray - case _ => Array.empty[String] // No difference - } - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Throwable): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - protected def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** - * Parse a submission ID from the relative path, assuming it is the first part of the path. - * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. - * The returned submission ID cannot be empty. If the path is unexpected, return None. - */ - protected def parseSubmissionId(path: String): Option[String] = { - if (path == null || path.isEmpty) { - None - } else { - path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) - } - } - - /** - * Validate the response to ensure that it is correctly constructed. - * - * If it is, simply return the message as is. Otherwise, return an error response instead - * to propagate the exception back to the client and set the appropriate error code. - */ - private def validateResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - try { - responseMessage.validate() - responseMessage - } catch { - case e: Exception => - responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - handleError("Internal server error: " + formatException(e)) - } - } + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterActor, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterActor, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, have the Master kill the corresponding - * driver and return an appropriate response to the client. Otherwise, return error. - */ - protected override def doPost( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleKill).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -238,23 +86,8 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, request the status of the corresponding - * driver from the Master and include it in the response. Otherwise, return error. - */ - protected override def doGet( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleStatus).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -276,71 +109,11 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private[rest] class SubmitRequestServlet( +private[rest] class StandaloneSubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * Submit an application to the Master with parameters specified in the request. - * - * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. - * If the request is successfully processed, return an appropriate response to the - * client indicating so. Otherwise, return error instead. - */ - protected override def doPost( - requestServlet: HttpServletRequest, - responseServlet: HttpServletResponse): Unit = { - val responseMessage = - try { - val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - // The response should have already been validated on the client. - // In case this is not true, validate it ourselves to avoid potential NPEs. - requestMessage.validate() - handleSubmit(requestMessageJson, requestMessage, responseServlet) - } catch { - // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Malformed request: " + formatException(e)) - } - sendResponse(responseMessage, responseServlet) - } - - /** - * Handle the submit request and construct an appropriate response to return to the client. - * - * This assumes that the request message is already successfully validated. - * If the request message is not of the expected type, return error to the client. - */ - private def handleSubmit( - requestMessageJson: String, - requestMessage: SubmitRestProtocolMessage, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - requestMessage match { - case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) - val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val submitResponse = new CreateSubmissionResponse - submitResponse.serverSparkVersion = sparkVersion - submitResponse.message = response.message - submitResponse.success = response.success - submitResponse.submissionId = response.driverId.orNull - val unknownFields = findUnknownFields(requestMessageJson, requestMessage) - if (unknownFields.nonEmpty) { - // If there are fields that the server does not know about, warn the client - submitResponse.unknownFields = unknownFields - } - submitResponse - case unexpected => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError(s"Received message of unexpected type ${unexpected.messageType}.") - } - } + extends SubmitRequestServlet { /** * Build a driver description from the fields specified in the submit request. @@ -389,50 +162,37 @@ private[rest] class SubmitRequestServlet( new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } -} -/** - * A default servlet that handles error cases that are not captured by other servlets. - */ -private class ErrorServlet extends StandaloneRestServlet { - private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION - - /** Service a faulty request by returning an appropriate error message to the client. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList - var versionMismatch = false - var msg = - parts match { - case Nil => - // http://host:port/ - "Missing protocol version." - case `serverVersion` :: Nil => - // http://host:port/correct-version - "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: tail => - // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: tail => - // http://host:port/unknown-version/* - versionMismatch = true - s"Unknown protocol version '$unknownVersion'." - case _ => - // never reached - s"Malformed path $path." - } - msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." - val error = handleError(msg) - // If there is a version mismatch, include the highest protocol version that - // this server supports in case the client wants to retry with our version - if (versionMismatch) { - error.highestProtocolVersion = serverVersion - response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. + */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") } - sendResponse(error, response) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index d80abdf15fb34..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -61,7 +61,7 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertProperty[Int](key, "numeric", _.toInt) + assertProperty[Double](key, "numeric", _.toDouble) private def assertPropertyIsMemory(key: String): Unit = assertProperty[Int](key, "memory", Utils.memoryStringToMb) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 8fde8c142a4c1..0e226ee294cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -35,7 +35,7 @@ private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtoc /** * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. */ -private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -46,7 +46,7 @@ private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse /** * A response to a kill request in the REST application submission protocol. */ -private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -58,7 +58,7 @@ private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { /** * A response to a status request in the REST application submission protocol. */ -private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { var submissionId: String = null var driverState: String = null var workerId: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 0000000000000..fd17a980c9319 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[deploy] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. + * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + k.serverSparkVersion = sparkVersion + k + } +} + +private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 82f652dae0378..3412301e64fd7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -49,17 +46,10 @@ private[spark] class CoarseMesosSchedulerBackend( master: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt @@ -87,26 +77,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } def createCommand(offer: Offer, numCores: Int): CommandInfo = { @@ -150,8 +122,10 @@ private[spark] class CoarseMesosSchedulerBackend( conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - val uri = conf.get("spark.executor.uri", null) - if (uri == null) { + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" @@ -164,7 +138,7 @@ private[spark] class CoarseMesosSchedulerBackend( } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + @@ -173,7 +147,7 @@ private[spark] class CoarseMesosSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } command.build() } @@ -183,18 +157,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -245,14 +208,6 @@ private[spark] class CoarseMesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Build a Mesos resource protobuf object */ private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() @@ -284,7 +239,8 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + // In case we'd rejected everything before but have now lost a node + mesosDriver.reviveOffers() } } } @@ -296,8 +252,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 0000000000000..3efc536f1456c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConversions._ + +import org.apache.curator.framework.CuratorFramework +import org.apache.zookeeper.CreateMode +import org.apache.zookeeper.KeeperException.NoNodeException + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.util.Utils + +/** + * Persistence engine factory that is responsible for creating new persistence engines + * to store Mesos cluster mode state. + */ +private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) { + def createEngine(path: String): MesosClusterPersistenceEngine +} + +/** + * Mesos cluster persistence engine is responsible for persisting Mesos cluster mode + * specific state, so that on failover all the state can be recovered and the scheduler + * can resume managing the drivers. + */ +private[spark] trait MesosClusterPersistenceEngine { + def persist(name: String, obj: Object): Unit + def expunge(name: String): Unit + def fetch[T](name: String): Option[T] + def fetchAll[T](): Iterable[T] +} + +/** + * Zookeeper backed persistence engine factory. + * All Zk engines created from this factory shares the same Zookeeper client, so + * all of them reuses the same connection pool. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) + extends MesosClusterPersistenceEngineFactory(conf) { + + lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + + def createEngine(path: String): MesosClusterPersistenceEngine = { + new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) + } +} + +/** + * Black hole persistence engine factory that creates black hole + * persistence engines, which stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngineFactory + extends MesosClusterPersistenceEngineFactory(null) { + def createEngine(path: String): MesosClusterPersistenceEngine = { + new BlackHoleMesosClusterPersistenceEngine + } +} + +/** + * Black hole persistence engine that stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine { + override def persist(name: String, obj: Object): Unit = {} + override def fetch[T](name: String): Option[T] = None + override def expunge(name: String): Unit = {} + override def fetchAll[T](): Iterable[T] = Iterable.empty[T] +} + +/** + * Zookeeper based Mesos cluster persistence engine, that stores cluster mode state + * into Zookeeper. Each engine object is operating under one folder in Zookeeper, but + * reuses a shared Zookeeper client. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngine( + baseDir: String, + zk: CuratorFramework, + conf: SparkConf) + extends MesosClusterPersistenceEngine with Logging { + private val WORKING_DIR = + conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir + + SparkCuratorUtil.mkdir(zk, WORKING_DIR) + + def path(name: String): String = { + WORKING_DIR + "/" + name + } + + override def expunge(name: String): Unit = { + zk.delete().forPath(path(name)) + } + + override def persist(name: String, obj: Object): Unit = { + val serialized = Utils.serialize(obj) + val zkPath = path(name) + zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized) + } + + override def fetch[T](name: String): Option[T] = { + val zkPath = path(name) + + try { + val fileData = zk.getData().forPath(zkPath) + Some(Utils.deserialize[T](fileData)) + } catch { + case e: NoNodeException => None + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(zkPath) + None + } + } + } + + override def fetchAll[T](): Iterable[T] = { + zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala new file mode 100644 index 0000000000000..0396e62be5309 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.io.File +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.Protos.TaskStatus.Reason +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} + + +/** + * Tracks the current state of a Mesos Task that runs a Spark driver. + * @param driverDescription Submitted driver description from + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] + * @param taskId Mesos TaskID generated for the task + * @param slaveId Slave ID that the task is assigned to + * @param mesosTaskStatus The last known task status update. + * @param startDate The date the task was launched + */ +private[spark] class MesosClusterSubmissionState( + val driverDescription: MesosDriverDescription, + val taskId: TaskID, + val slaveId: SlaveID, + var mesosTaskStatus: Option[TaskStatus], + var startDate: Date) + extends Serializable { + + def copy(): MesosClusterSubmissionState = { + new MesosClusterSubmissionState( + driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + } +} + +/** + * Tracks the retry state of a driver, which includes the next time it should be scheduled + * and necessary information to do exponential backoff. + * This class is not thread-safe, and we expect the caller to handle synchronizing state. + * @param lastFailureStatus Last Task status when it failed. + * @param retries Number of times it has been retried. + * @param nextRetry Time at which it should be retried next + * @param waitTime The amount of time driver is scheduled to wait until next retry. + */ +private[spark] class MesosClusterRetryState( + val lastFailureStatus: TaskStatus, + val retries: Int, + val nextRetry: Date, + val waitTime: Int) extends Serializable { + def copy(): MesosClusterRetryState = + new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime) +} + +/** + * The full state of the cluster scheduler, currently being used for displaying + * information on the UI. + * @param frameworkId Mesos Framework id for the cluster scheduler. + * @param masterUrl The Mesos master url + * @param queuedDrivers All drivers queued to be launched + * @param launchedDrivers All launched or running drivers + * @param finishedDrivers All terminated drivers + * @param pendingRetryDrivers All drivers pending to be retried + */ +private[spark] class MesosClusterSchedulerState( + val frameworkId: String, + val masterUrl: Option[String], + val queuedDrivers: Iterable[MesosDriverDescription], + val launchedDrivers: Iterable[MesosClusterSubmissionState], + val finishedDrivers: Iterable[MesosClusterSubmissionState], + val pendingRetryDrivers: Iterable[MesosDriverDescription]) + +/** + * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode + * as Mesos tasks in a Mesos cluster. + * All drivers are launched asynchronously by the framework, which will eventually be launched + * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * sandbox which is accessible by visiting the Mesos UI. + * This scheduler supports recovery by persisting all its state and performs task reconciliation + * on recover, which gets all the latest state for all the drivers from Mesos master. + */ +private[spark] class MesosClusterScheduler( + engineFactory: MesosClusterPersistenceEngineFactory, + conf: SparkConf) + extends Scheduler with MesosSchedulerUtils { + var frameworkUrl: String = _ + private val metricsSystem = + MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf)) + private val master = conf.get("spark.master") + private val appName = conf.get("spark.app.name") + private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) + private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) + private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val schedulerState = engineFactory.createEngine("scheduler") + private val stateLock = new ReentrantLock() + private val finishedDrivers = + new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) + private var frameworkId: String = null + // Holds all the launched drivers and current launch state, keyed by driver id. + private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() + // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // All drivers that are loaded after failover are added here, as we need get the latest + // state of the tasks from Mesos. + private val pendingRecover = new mutable.HashMap[String, SlaveID]() + // Stores all the submitted drivers that hasn't been launched. + private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() + // All supervised drivers that are waiting to retry after termination. + private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() + private val queuedDriversState = engineFactory.createEngine("driverQueue") + private val launchedDriversState = engineFactory.createEngine("launchedDrivers") + private val pendingRetryDriversState = engineFactory.createEngine("retryList") + // Flag to mark if the scheduler is ready to be called, which is until the scheduler + // is registered with Mesos master. + @volatile protected var ready = false + private var masterInfo: Option[MasterInfo] = None + + def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { + val c = new CreateSubmissionResponse + if (!ready) { + c.success = false + c.message = "Scheduler is not ready to take requests" + return c + } + + stateLock.synchronized { + if (isQueueFull()) { + c.success = false + c.message = "Already reached maximum submission size" + return c + } + c.submissionId = desc.submissionId + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + c.success = true + } + c + } + + def killDriver(submissionId: String): KillSubmissionResponse = { + val k = new KillSubmissionResponse + if (!ready) { + k.success = false + k.message = "Scheduler is not ready to take requests" + return k + } + k.submissionId = submissionId + stateLock.synchronized { + // We look for the requested driver in the following places: + // 1. Check if submission is running or launched. + // 2. Check if it's still queued. + // 3. Check if it's in the retry list. + // 4. Check if it has already completed. + if (launchedDrivers.contains(submissionId)) { + val task = launchedDrivers(submissionId) + mesosDriver.killTask(task.taskId) + k.success = true + k.message = "Killing running driver" + } else if (removeFromQueuedDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's still pending" + } else if (removeFromPendingRetryDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's being retried" + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + k.success = false + k.message = "Driver already terminated" + } else { + k.success = false + k.message = "Cannot find driver" + } + } + k + } + + def getDriverStatus(submissionId: String): SubmissionStatusResponse = { + val s = new SubmissionStatusResponse + if (!ready) { + s.success = false + s.message = "Scheduler is not ready to take requests" + return s + } + s.submissionId = submissionId + stateLock.synchronized { + if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "QUEUED" + } else if (launchedDrivers.contains(submissionId)) { + s.success = true + s.driverState = "RUNNING" + launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "FINISHED" + finishedDrivers + .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus + .foreach(state => s.message = state.toString) + } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { + val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .get.retryState.get.lastFailureStatus + s.success = true + s.driverState = "RETRYING" + s.message = status.toString + } else { + s.success = false + s.driverState = "NOT_FOUND" + } + } + s + } + + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity + + /** + * Recover scheduler state that is persisted. + * We still need to do task reconciliation to be up to date of the latest task states + * as it might have changed while the scheduler is failing over. + */ + private def recoverState(): Unit = { + stateLock.synchronized { + launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => + launchedDrivers(state.taskId.getValue) = state + pendingRecover(state.taskId.getValue) = state.slaveId + } + queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) + // There is potential timing issue where a queued driver might have been launched + // but the scheduler shuts down before the queued driver was able to be removed + // from the queue. We try to mitigate this issue by walking through all queued drivers + // and remove if they're already launched. + queuedDrivers + .filter(d => launchedDrivers.contains(d.submissionId)) + .foreach(d => removeFromQueuedDrivers(d.submissionId)) + pendingRetryDriversState.fetchAll[MesosDriverDescription]() + .foreach(s => pendingRetryDrivers += s) + // TODO: Consider storing finished drivers so we can show them on the UI after + // failover. For now we clear the history on each recovery. + finishedDrivers.clear() + } + } + + /** + * Starts the cluster scheduler and wait until the scheduler is registered. + * This also marks the scheduler to be ready for requests. + */ + def start(): Unit = { + // TODO: Implement leader election to make sure only one framework running in the cluster. + val fwId = schedulerState.fetch[String]("frameworkId") + val builder = FrameworkInfo.newBuilder() + .setUser(Utils.getCurrentUserName()) + .setName(appName) + .setWebuiUrl(frameworkUrl) + .setCheckpoint(true) + .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash + fwId.foreach { id => + builder.setId(FrameworkID.newBuilder().setValue(id).build()) + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + startScheduler(master, MesosClusterScheduler.this, builder.build()) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. + val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc) + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val cmd = if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } + builder.setValue(cmd) + builder.setEnvironment(envBuilder.build()) + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. + */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = Resource.newBuilder() + .setName("cpus").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() + val memResource = Resource.newBuilder() + .setName("mem").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + .build() + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date()) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.map { o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + }.toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + scheduleTasks( + driversToRetry, + removeFromPendingRetryDrivers, + currentOffers, + tasks) + // Then we walk through the queued drivers and try to schedule them. + scheduleTasks( + queuedDrivers, + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, tasks) => + driver.launchTasks(Collections.singleton(offerId), tasks) + } + offers + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + def getSchedulerState(): MesosClusterSchedulerState = { + def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. + if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 0000000000000..1fe94974c8e36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d9d62b0e287ed..8346a2407489f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,23 +18,19 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a @@ -47,14 +43,7 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null + with MesosSchedulerUtils { // Which slave IDs we have executors on val slaveIdsWithExecutors = new HashSet[String] @@ -73,26 +62,9 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + classLoader = Thread.currentThread.getContextClassLoader + startScheduler(master, MesosSchedulerBackend.this, fwInfo) } def createExecutorInfo(execId: String): MesosExecutorInfo = { @@ -125,17 +97,19 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". - val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } val cpus = Resource.newBuilder() .setName("cpus") @@ -181,18 +155,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -287,14 +250,6 @@ private[spark] class MesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Turn a Spark TaskDescription into a Mesos task */ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -339,13 +294,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -380,7 +335,7 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 0000000000000..d11228f3d016a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.List +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConversions._ + +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} +import org.apache.mesos.{MesosSchedulerDriver, Scheduler} +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Shared trait for implementing a Mesos Scheduler. This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: MesosSchedulerDriver = null + + /** + * Starts the MesosSchedulerDriver with the provided information. This method returns + * only after the scheduler has registered with Mesos. + * @param masterUrl Mesos master connection URL + * @param scheduler Scheduler object + * @param fwInfo FrameworkInfo to pass to the Mesos master + */ + def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + /** + * Get the amount of resources for the specified type from the resource list + */ + protected def getResource(res: List[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + 0.0 + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4561e5b8e9663..c4e6f06146b0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -231,7 +231,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 8e09976636386..0a318a27ac212 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,9 +39,9 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new StandaloneRestClient + private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None - private var server: Option[StandaloneRestServer] = None + private var server: Option[RestSubmissionServer] = None override def afterEach() { actorSystem.foreach(_.shutdown()) @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -208,7 +208,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val json = constructSubmitRequest(masterUrl).toJson val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" @@ -238,7 +238,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths, bad requests") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" val statusRequestPath = s"$httpUrl/$v/submissions/status" @@ -276,7 +276,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("bad request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") @@ -292,7 +292,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(code5 === HttpServletResponse.SC_BAD_REQUEST) assert(code6 === HttpServletResponse.SC_BAD_REQUEST) assert(code7 === HttpServletResponse.SC_BAD_REQUEST) - assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + assert(code8 === RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) // all responses should be error responses val errorResponse1 = getErrorResponse(response1) val errorResponse2 = getErrorResponse(response2) @@ -310,13 +310,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(errorResponse5.highestProtocolVersion === null) assert(errorResponse6.highestProtocolVersion === null) assert(errorResponse7.highestProtocolVersion === null) - assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + assert(errorResponse8.highestProtocolVersion === RestSubmissionServer.PROTOCOL_VERSION) } test("server returns unknown fields") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val oldJson = constructSubmitRequest(masterUrl).toJson val oldFields = parse(oldJson).asInstanceOf[JObject].obj @@ -340,7 +340,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" @@ -400,9 +400,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) val _server = if (faulty) { - new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } else { - new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } val port = _server.start() // set these to clean them up after every test @@ -563,20 +563,18 @@ private class SmarterMaster extends Actor { private class FaultyStandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { - protected override val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new MalformedSubmitServlet, - s"$baseContext/kill/*" -> new InvalidKillServlet, - s"$baseContext/status/*" -> new ExplodingStatusServlet, - "/*" -> new ErrorServlet - ) + protected override val submitRequestServlet = new MalformedSubmitServlet + protected override val killRequestServlet = new InvalidKillServlet + protected override val statusRequestServlet = new ExplodingStatusServlet /** A faulty servlet that produces malformed responses. */ - class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + class MalformedSubmitServlet + extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -586,7 +584,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -595,7 +593,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. */ - class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..f28e29e9b8d8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.util.Date + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.{LocalSparkContext, SparkConf} + + +class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), null, null, null, null) + + test("can queue drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } +} diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 594bf78b67713..8f53d8201a089 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,7 +110,11 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. + +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master url (e.g: mesos://host:5050). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url +to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh new file mode 100755 index 0000000000000..ef1fc573d5c65 --- /dev/null +++ b/sbin/start-mesos-dispatcher.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Starts the Mesos Cluster Dispatcher on the machine this script is executed on. +# The Mesos Cluster Dispatcher is responsible for launching the Mesos framework and +# Rest server to handle driver requests for Mesos cluster mode. +# Only one cluster dispatcher is needed per Mesos cluster. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then + SPARK_MESOS_DISPATCHER_PORT=7077 +fi + +if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then + SPARK_MESOS_DISPATCHER_HOST=`hostname` +fi + + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh new file mode 100755 index 0000000000000..cb65d95b5e524 --- /dev/null +++ b/sbin/stop-mesos-dispatcher.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Stop the Mesos Cluster dispatcher on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +. "$sbin/spark-config.sh" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 + From 28b1af7420e0b5e7e2dfc09eafc45fe2ffcde5ec Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 28 Apr 2015 13:49:29 -0700 Subject: [PATCH 18/18] [MINOR] [CORE] Warn users who try to cache RDDs with dynamic allocation on. Author: Marcelo Vanzin Closes #5751 from vanzin/cached-rdd-warning and squashes the following commits: 554cc07 [Marcelo Vanzin] Change message. 9efb9da [Marcelo Vanzin] [minor] [core] Warn users who try to cache RDDs with dynamic allocation on. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 65b903a55d5bd..d0cf2a8dd01cd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1396,6 +1396,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd }
Property NameDefaultMeaning
spark.reducer.maxMbInFlight48spark.reducer.maxSizeInFlight48m - Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since + Maximum size of map outputs to fetch simultaneously from each reduce task. Since each output requires us to create a buffer to receive it, this represents a fixed memory overhead per reduce task, so keep it small unless you have a large amount of memory.
spark.shuffle.file.buffer.kb32spark.shuffle.file.buffer32k - Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers + Size of the in-memory buffer for each shuffle file output stream. These buffers reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
spark.io.compression.lz4.block.size32768spark.io.compression.lz4.blockSize32k - Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec + Block size used in LZ4 compression, in the case when LZ4 compression codec is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.
spark.io.compression.snappy.block.size32768spark.io.compression.snappy.blockSize32k - Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec + Block size used in Snappy compression, in the case when Snappy compression codec is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
spark.kryoserializer.buffer.max.mb64spark.kryoserializer.buffer.max64m - Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any + Maximum allowable size of Kryo serialization buffer. This must be larger than any object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception inside Kryo.
spark.kryoserializer.buffer.mb0.064spark.kryoserializer.buffer64k - Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer + Initial size of Kryo's serialization buffer. Note that there will be one buffer per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb if needed.
Property NameDefaultMeaning
spark.broadcast.blockSize40964m - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Size of each piece of a block for TorrentBroadcastFactory. Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit.
spark.storage.memoryMapThreshold20971522m - Size of a block, in bytes, above which Spark memory maps when reading a block from disk. + Size of a block above which Spark memory maps when reading a block from disk. This prevents Spark from memory mapping very small blocks. In general, memory mapping has high overhead for blocks close to or below the page size of the operating system.
{submission.submissionId}{submission.submissionDate}{submission.command.mainClass}cpus: {submission.cores}, mem: {submission.mem}
{state.driverDescription.submissionId}{state.driverDescription.submissionDate}{state.driverDescription.command.mainClass}cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem}{state.startDate}{state.slaveId.getValue}{stateString(state.mesosTaskStatus)}
{submission.submissionId}{submission.submissionDate}{submission.command.mainClass}{submission.retryState.get.lastFailureStatus}{submission.retryState.get.nextRetry}{submission.retryState.get.retries}