spark.kryo.registrator |
(none) |
- If you use Kryo serialization, set this class to register your custom classes with Kryo.
- It should be set to a class that extends
+ If you use Kryo serialization, set this class to register your custom classes with Kryo. This
+ property is useful if you need to register your classes in a custom way, e.g. to specify a custom
+ field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be
+ set to a class that extends
KryoRegistrator.
See the tuning guide for more details.
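For reference, the "custom way" this property enables looks roughly like the sketch below — a minimal example with placeholder classes `MyClass1` and `MyClass2`, mirroring the registrator snippet removed from tuning.md further down in this diff:

{% highlight scala %}
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoRegistrator

// Placeholder classes standing in for your own types.
case class MyClass1(x: Int)
case class MyClass2(s: String)

// A registrator gives full control over registration, e.g. attaching a
// custom field serializer to a class.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[MyClass1])
    kryo.register(classOf[MyClass2])
  }
}

val conf = new SparkConf().setAppName("kryo-registrator-example")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName)
{% endhighlight %}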
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 1511ae6dda4ed..11622414494e4 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -83,7 +83,7 @@ val idf = new IDF().fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
-MLLib's IDF implementation provides an option for ignoring terms which occur in less than a
+MLlib's IDF implementation provides an option for ignoring terms which occur in less than a
minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
can be used by passing the `minDocFreq` value to the IDF constructor.
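A minimal sketch of that option, assuming `documents` is an existing `RDD[Seq[String]]` of tokenized text:

{% highlight scala %}
import org.apache.spark.mllib.feature.{HashingTF, IDF}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

val tf: RDD[Vector] = new HashingTF().transform(documents)
tf.cache()
// Terms appearing in fewer than 2 documents receive an IDF of 0.
val idf = new IDF(minDocFreq = 2).fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}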
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index c4632413991f1..10a5131c07414 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -197,7 +197,7 @@ print Statistics.corr(data, method="pearson")
## Stratified sampling
-Unlike the other statistics functions, which reside in MLLib, stratified sampling methods,
+Unlike the other statistics functions, which reside in MLlib, stratified sampling methods,
`sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified
sampling, the keys can be thought of as a label and the value as a specific attribute. For example
the key can be man or woman, or document ids, and the respective values can be the list of ages
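A minimal sketch of both methods, assuming `sc` is an existing SparkContext; the keys and fractions are illustrative only:

{% highlight scala %}
import org.apache.spark.SparkContext._

// Key = stratum label, value = some attribute of interest.
val data = sc.parallelize(Seq(("man", 32), ("woman", 27), ("man", 45), ("woman", 19)))
// Desired sampling fraction for each stratum.
val fractions = Map("man" -> 0.5, "woman" -> 0.5)

// Approximate stratified sample: one pass, per-stratum sizes not guaranteed.
val approxSample = data.sampleByKey(withReplacement = false, fractions)
// Exact per-stratum sample sizes, at the cost of additional passes over the data.
val exactSample = data.sampleByKeyExact(withReplacement = false, fractions)
{% endhighlight %}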
diff --git a/docs/tuning.md b/docs/tuning.md
index 8fb2a0433b1a8..9b5c9adac6a4f 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -47,24 +47,11 @@ registration requirement, but we recommend trying it in any network-intensive ap
Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
-To register your own custom classes with Kryo, create a public class that extends
-[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
-`spark.kryo.registrator` config property to point to it, as follows:
+To register your own custom classes with Kryo, use the `registerKryoClasses` method.
{% highlight scala %}
-import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.serializer.KryoRegistrator
-
-class MyRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[MyClass1])
- kryo.register(classOf[MyClass2])
- }
-}
-
val conf = new SparkConf().setMaster(...).setAppName(...)
-conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-conf.set("spark.kryo.registrator", "mypackage.MyRegistrator")
+conf.registerKryoClasses(Seq(classOf[MyClass1], classOf[MyClass2]))
val sc = new SparkContext(conf)
{% endhighlight %}
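For completeness, the same classes can also be registered purely through configuration with `spark.kryo.classesToRegister` (a sketch; the class names are placeholders):

{% highlight scala %}
val conf = new SparkConf().setMaster(...).setAppName(...)
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
conf.set("spark.kryo.classesToRegister", "mypackage.MyClass1,mypackage.MyClass2")
{% endhighlight %}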
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
index 8d381d4e0a943..95a430f1da234 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
@@ -32,7 +32,7 @@
import scala.Tuple2;
/**
- * Example using MLLib ALS from Java.
+ * Example using MLlib ALS from Java.
*/
public final class JavaALS {
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
index f796123a25727..e575eedeb465c 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
@@ -30,7 +30,7 @@
import org.apache.spark.mllib.linalg.Vectors;
/**
- * Example using MLLib KMeans from Java.
+ * Example using MLlib KMeans from Java.
*/
public final class JavaKMeans {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 5622df5ce03ff..981bc4f0613a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -57,7 +57,7 @@ public class JavaCustomReceiver extends Receiver {
public static void main(String[] args) {
if (args.length < 2) {
- System.err.println("Usage: JavaNetworkWordCount ");
+ System.err.println("Usage: JavaCustomReceiver ");
System.exit(1);
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index e06f4dcd54442..e322d4ce5a745 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -18,17 +18,7 @@
package org.apache.spark.examples.bagel
import org.apache.spark._
-import org.apache.spark.SparkContext._
-import org.apache.spark.serializer.KryoRegistrator
-
import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-
-import com.esotericsoftware.kryo._
class PageRankUtils extends Serializable {
def computeWithCombiner(numVertices: Long, epsilon: Double)(
@@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable {
}
}
-class PRKryoRegistrator extends KryoRegistrator {
- def registerClasses(kryo: Kryo) {
- kryo.register(classOf[PRVertex])
- kryo.register(classOf[PRMessage])
- }
-}
-
class CustomPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index e4db3ec51313d..859abedf2a55e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -38,8 +38,7 @@ object WikipediaPageRank {
}
val sparkConf = new SparkConf()
sparkConf.setAppName("WikipediaPageRank")
- sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+ sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
val inputFile = args(0)
val threshold = args(1).toDouble
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index 45527d9382fd0..d70d93608a57c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,10 +46,8 @@ object Analytics extends Logging {
}
val options = mutable.Map(optionsList: _*)
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
- .set("spark.locality.wait", "100000")
+ val conf = new SparkConf().set("spark.locality.wait", "100000")
+ GraphXUtils.registerKryoClasses(conf)
val numEPart = options.remove("numEPart").map(_.toInt).getOrElse {
println("Set the number of edge partitions using --numEPart.")
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
index 5f35a5836462e..05676021718d9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
@@ -18,7 +18,7 @@
package org.apache.spark.examples.graphx
import org.apache.spark.SparkContext._
-import org.apache.spark.graphx.PartitionStrategy
+import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.graphx.util.GraphGenerators
import java.io.{PrintWriter, FileOutputStream}
@@ -80,8 +80,7 @@ object SynthBenchmark {
val conf = new SparkConf()
.setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)")
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index fc6678013b932..8796c28db8a66 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib
import scala.collection.mutable
-import com.esotericsoftware.kryo.Kryo
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
@@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
/**
* An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
*/
object MovieLensALS {
- class ALSRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[Rating])
- kryo.register(classOf[mutable.BitSet])
- }
- }
-
case class Params(
input: String = null,
kryo: Boolean = false,
@@ -108,8 +99,7 @@ object MovieLensALS {
def run(params: Params) {
val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
if (params.kryo) {
- conf.set("spark.serializer", classOf[KryoSerializer].getName)
- .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating]))
.set("spark.kryoserializer.buffer.mb", "8")
}
val sc = new SparkContext(conf)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index 1948c978c30bf..563c948957ecf 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.util.collection.OpenHashSet
-
/**
* Registers GraphX classes with Kryo for improved performance.
*/
+@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0")
class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
new file mode 100644
index 0000000000000..2cb07937eaa2a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkConf
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+
+import org.apache.spark.util.collection.{OpenHashSet, BitSet}
+import org.apache.spark.util.BoundedPriorityQueue
+
+object GraphXUtils {
+ /**
+ * Registers classes that GraphX uses with Kryo.
+ */
+ def registerKryoClasses(conf: SparkConf) {
+ conf.registerKryoClasses(Array(
+ classOf[Edge[Object]],
+ classOf[(VertexId, Object)],
+ classOf[EdgePartition[Object, Object]],
+ classOf[BitSet],
+ classOf[VertexIdToIndexMap],
+ classOf[VertexAttributeBlock[Object]],
+ classOf[PartitionStrategy],
+ classOf[BoundedPriorityQueue[Object]],
+ classOf[EdgeDirection],
+ classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]],
+ classOf[OpenHashSet[Int]],
+ classOf[OpenHashSet[Long]]))
+ }
+}
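Typical usage of the new helper, as a short sketch (the app name is arbitrary):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.GraphXUtils

val conf = new SparkConf().setAppName("graphx-kryo-example")
// Registers GraphX's internal classes; SparkConf.registerKryoClasses also
// switches spark.serializer to the Kryo serializer.
GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext(conf)
{% endhighlight %}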
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
index 47594a800a3b1..a3e28efc75a98 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -17,9 +17,6 @@
package org.apache.spark.graphx
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterEach
-
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
@@ -31,8 +28,7 @@ trait LocalSparkContext {
/** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
def withSpark[T](f: SparkContext => T) = {
val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext("local", "test", conf)
try {
f(sc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 9d00f76327e4c..db1dac6160080 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -129,9 +129,9 @@ class EdgePartitionSuite extends FunSuite {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index f9e771a900013..fe8304c1cdc32 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite {
val verts = Set((0L, 1), (1L, 1), (2L, 1))
val vp = VertexPartition(verts.iterator)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9a100170b75c6..b478c21537c2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -673,6 +673,11 @@ private[spark] object SerDe extends Serializable {
rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
+ /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
+ def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+ rdd.map(x => Array(x._1, x._2))
+ }
+
/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
index 87bdc8558aaf5..c67a6d3ae6cce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api
/**
- * Internal support for MLLib Python API.
+ * Internal support for MLlib Python API.
*
* @see [[org.apache.spark.mllib.api.python.PythonMLLibAPI]]
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..93a7353e2c070
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+ extends Logging with Serializable {
+
+ /**
+ * Compute the average precision of all the queries, truncated at ranking position k.
+ *
+ * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+ * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+ * ground truth set is less than k.
+ *
+ * If a query has an empty ground truth set, zero will be used as precision together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated precision, must be positive
+ * @return the average precision at the first k ranking positions
+ */
+ def precisionAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val n = math.min(pred.length, k)
+ var i = 0
+ var cnt = 0
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ }
+ i += 1
+ }
+ cnt.toDouble / k
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Returns the mean average precision (MAP) of all the queries.
+ * If a query has an empty ground truth set, the average precision will be zero and a log
+ * warning is generated.
+ */
+ lazy val meanAveragePrecision: Double = {
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ var i = 0
+ var cnt = 0
+ var precSum = 0.0
+ val n = pred.length
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ precSum += cnt.toDouble / (i + 1)
+ }
+ i += 1
+ }
+ precSum / labSet.size
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Compute the average NDCG value of all the queries, truncated at ranking position k.
+ * The discounted cumulative gain at position k is computed as:
+ * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+ * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current
+ * implementation, the relevance value is binary.
+ *
+ * If a query has an empty ground truth set, zero will be used as ndcg together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated ndcg, must be positive
+ * @return the average ndcg at the first k ranking positions
+ */
+ def ndcgAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val labSetSize = labSet.size
+ val n = math.min(math.max(pred.length, labSetSize), k)
+ var maxDcg = 0.0
+ var dcg = 0.0
+ var i = 0
+ while (i < n) {
+ val gain = 1.0 / math.log(i + 2)
+ if (labSet.contains(pred(i))) {
+ dcg += gain
+ }
+ if (i < labSetSize) {
+ maxDcg += gain
+ }
+ i += 1
+ }
+ dcg / maxDcg
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+}
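A minimal usage sketch, assuming `sc` is an existing SparkContext (the data mirrors the new test suite below):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.RankingMetrics

// One (predicted ranking, ground-truth set of relevant ids) pair per query.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
  (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))))

val metrics = new RankingMetrics(predictionAndLabels)
val p5 = metrics.precisionAt(5)         // precision truncated at position 5
val map = metrics.meanAveragePrecision  // mean average precision over all queries
val ndcg = metrics.ndcgAt(10)           // NDCG truncated at position 10
{% endhighlight %}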
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000000..a2d4bb41484b8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Ranking metrics: map, ndcg") {
+ val predictionAndLabels = sc.parallelize(
+ Seq(
+ (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+ (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
+ (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+ ), 2)
+ val eps: Double = 1E-5
+
+ val metrics = new RankingMetrics(predictionAndLabels)
+ val map = metrics.meanAveragePrecision
+
+ assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
+ assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
+
+ assert(map ~== 0.355026 absTol eps)
+
+ assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
+ assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
+ assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+ }
+}
diff --git a/pom.xml b/pom.xml
index 288bbf1114bea..a7e71f9ca5596 100644
--- a/pom.xml
+++ b/pom.xml
@@ -428,6 +428,11 @@
+ <dependency>
+ <groupId>org.roaringbitmap</groupId>
+ <artifactId>RoaringBitmap</artifactId>
+ <version>0.4.1</version>
+ </dependency>
<dependency>
<groupId>commons-net</groupId>
<artifactId>commons-net</artifactId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 01a5b20e7c51d..705937e3016e2 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -22,6 +22,7 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.genjavadocSettings
+import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings}
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
@@ -116,6 +117,7 @@ object SparkBuild extends PomBuild {
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
publishMavenStyle := true,
+ unidocGenjavadocVersion := "0.8",
resolvers += Resolver.mavenLocal,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
diff --git a/project/build.properties b/project/build.properties
index c12ef652adfcb..32a3aeefaf9fb 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.5
+sbt.version=0.13.6
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 678f5ed1ba610..9d50a50b109af 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -4,6 +4,8 @@ resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.
resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/"
+resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
+
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2")
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0")
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 17f96b8700bd7..22872dbbe3b55 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -53,6 +53,23 @@ class MatrixFactorizationModel(object):
>>> model = ALS.train(ratings, 1)
>>> model.predictAll(testset).count() == 2
True
+
+ >>> model = ALS.train(ratings, 4)
+ >>> model.userFeatures().count() == 2
+ True
+
+ >>> first_user = model.userFeatures().take(1)[0]
+ >>> latents = first_user[1]
+ >>> len(latents) == 4
+ True
+
+ >>> model.productFeatures().count() == 2
+ True
+
+ >>> first_product = model.productFeatures().take(1)[0]
+ >>> latents = first_product[1]
+ >>> len(latents) == 4
+ True
"""
def __init__(self, sc, java_model):
@@ -83,6 +100,20 @@ def predictAll(self, user_product):
return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
AutoBatchedSerializer(PickleSerializer()))
+ def userFeatures(self):
+ sc = self._context
+ juf = self._java_model.userFeatures()
+ juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
+ def productFeatures(self):
+ sc = self._context
+ jpf = self._java_model.productFeatures()
+ jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
class ALS(object):
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index a6019dadf781c..84baf12b906df 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -22,7 +22,7 @@
from functools import wraps
from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
@@ -107,7 +107,7 @@ def colStats(rdd):
array([ 2., 0., 0., -2.])
"""
sc = rdd.ctx
- jrdd = _to_java_object_rdd(rdd)
+ jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
return MultivariateStatisticalSummary(sc, cStats)
@@ -163,14 +163,15 @@ def corr(x, y=None, method=None):
if type(y) == str:
raise TypeError("Use 'method=' to specify method name.")
- jx = _to_java_object_rdd(x)
if not y:
+ jx = _to_java_object_rdd(x.map(_convert_to_vector))
resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
bytes = sc._jvm.SerDe.dumps(resultMat)
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
- jy = _to_java_object_rdd(y)
+ jx = _to_java_object_rdd(x.map(float))
+ jy = _to_java_object_rdd(y.map(float))
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 463faf7b6f520..d6fb87b378b4a 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -36,6 +36,8 @@
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.random import RandomRDDs
+from pyspark.mllib.stat import Statistics
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -202,6 +204,23 @@ def test_regression(self):
self.assertTrue(dt_model.predict(features[3]) > 0)
+class StatTests(PySparkTestCase):
+ # SPARK-4023
+ def test_col_with_different_rdds(self):
+ # numpy
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(1000, summary.count())
+ # array
+ data = self.sc.parallelize([range(10)] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+ # array
+ data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 55e247da0e4dc..528a181e8905a 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -31,7 +31,7 @@ def __init__(self, withReplacement, seed=None):
"Falling back to default random generator for sampling.")
self._use_numpy = False
- self._seed = seed if seed is not None else random.randint(0, sys.maxint)
+ self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
self._withReplacement = withReplacement
self._random = None
self._split = None
@@ -47,7 +47,7 @@ def initRandomGenerator(self, split):
for _ in range(0, split):
# discard the next few values in the sequence to have a
# different seed for the different splits
- self._random.randint(0, sys.maxint)
+ self._random.randint(0, 2 ** 32 - 1)
self._split = split
self._rand_initialized = True
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index dc9dc41121935..2f53fbd27b17a 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -79,7 +79,7 @@ class StreamingContext(object):
L{DStream} various input sources. It can be from an existing L{SparkContext}.
After creating and transforming DStreams, the streaming computation can
be started and stopped using `context.start()` and `context.stop()`,
- respectively. `context.awaitTransformation()` allows the current thread
+ respectively. `context.awaitTermination()` allows the current thread
to wait for the termination of the context by `stop()` or by an exception.
"""
_transformerSerializer = None
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index f5ccf31abb3fa..1a8e4150e63c3 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -433,6 +433,12 @@ def test_deleting_input_files(self):
os.unlink(tempFile.name)
self.assertRaises(Exception, lambda: filtered_data.count())
+ def test_sampling_default_seed(self):
+ # Test for SPARK-3995 (default seed setting)
+ data = self.sc.parallelize(range(1000), 1)
+ subset = data.takeSample(False, 10)
+ self.assertEqual(len(subset), 10)
+
def testAggregateByKey(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 0de29d5cffd0e..fd4f65e488259 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -67,10 +67,6 @@ class HadoopTableReader(
private val _broadcastedHiveConf =
sc.sparkContext.broadcast(new SerializableWritable(hiveExtraConf))
- def broadcastedHiveConf = _broadcastedHiveConf
-
- def hiveConf = _broadcastedHiveConf.value.value
-
override def makeRDDForTable(hiveTable: HiveTable): RDD[Row] =
makeRDDForTable(
hiveTable,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 5a8eef1372e23..23d6d1c5e50fa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -47,7 +47,7 @@ import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab
* The associated SparkContext can be accessed using `context.sparkContext`. After
* creating and transforming DStreams, the streaming computation can be started and stopped
* using `context.start()` and `context.stop()`, respectively.
- * `context.awaitTransformation()` allows the current thread to wait for the termination
+ * `context.awaitTermination()` allows the current thread to wait for the termination
* of the context by `stop()` or by an exception.
*/
class StreamingContext private[streaming] (
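The start/awaitTermination lifecycle mentioned in this doc comment, as a short sketch (batch interval and app name are arbitrary):

{% highlight scala %}
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf().setAppName("streaming-example")
val ssc = new StreamingContext(conf, Seconds(1))
// ... create input DStreams and define transformations and output operations ...
ssc.start()            // start the computation
ssc.awaitTermination() // block until stop() is called or an exception occurs
{% endhighlight %}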
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 9dc26dc6b32a1..7db66c69a6d73 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -46,7 +46,7 @@ import org.apache.spark.streaming.receiver.Receiver
* org.apache.spark.api.java.JavaSparkContext (see core Spark documentation) can be accessed
* using `context.sparkContext`. After creating and transforming DStreams, the streaming
* computation can be started and stopped using `context.start()` and `context.stop()`,
- * respectively. `context.awaitTransformation()` allows the current thread to wait for the
+ * respectively. `context.awaitTermination()` allows the current thread to wait for the
* termination of a context by `stop()` or by an exception.
*/
class JavaStreamingContext(val ssc: StreamingContext) extends Closeable {
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 5c7bca4541222..9c66c785848a5 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -137,15 +137,7 @@ object Client {
System.setProperty("SPARK_YARN_MODE", "true")
val sparkConf = new SparkConf
- try {
- val args = new ClientArguments(argStrings, sparkConf)
- new Client(args, sparkConf).run()
- } catch {
- case e: Exception =>
- Console.err.println(e.getMessage)
- System.exit(1)
- }
-
- System.exit(0)
+ val args = new ClientArguments(argStrings, sparkConf)
+ new Client(args, sparkConf).run()
}
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 0efac4ea63702..fb0e34bf5985e 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -417,17 +417,19 @@ private[spark] trait ClientBase extends Logging {
/**
* Report the state of an application until it has exited, either successfully or
- * due to some failure, then return the application state.
+ * due to some failure, then return a pair of the yarn application state (FINISHED, FAILED,
+ * KILLED, or RUNNING) and the final application state (UNDEFINED, SUCCEEDED, FAILED,
+ * or KILLED).
*
* @param appId ID of the application to monitor.
* @param returnOnRunning Whether to also return the application state when it is RUNNING.
* @param logApplicationReport Whether to log details of the application report every iteration.
- * @return state of the application, one of FINISHED, FAILED, KILLED, and RUNNING.
+ * @return A pair of the yarn application state and the final application state.
*/
def monitorApplication(
appId: ApplicationId,
returnOnRunning: Boolean = false,
- logApplicationReport: Boolean = true): YarnApplicationState = {
+ logApplicationReport: Boolean = true): (YarnApplicationState, FinalApplicationStatus) = {
val interval = sparkConf.getLong("spark.yarn.report.interval", 1000)
var lastState: YarnApplicationState = null
while (true) {
@@ -468,11 +470,11 @@ private[spark] trait ClientBase extends Logging {
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
- return state
+ return (state, report.getFinalApplicationStatus)
}
if (returnOnRunning && state == YarnApplicationState.RUNNING) {
- return state
+ return (state, report.getFinalApplicationStatus)
}
lastState = state
@@ -485,8 +487,23 @@ private[spark] trait ClientBase extends Logging {
/**
* Submit an application to the ResourceManager and monitor its state.
* This continues until the application has exited for any reason.
+ * If the application finishes with a failed, killed, or undefined status,
+ * throw an appropriate SparkException.
*/
- def run(): Unit = monitorApplication(submitApplication())
+ def run(): Unit = {
+ val (yarnApplicationState, finalApplicationStatus) = monitorApplication(submitApplication())
+ if (yarnApplicationState == YarnApplicationState.FAILED ||
+ finalApplicationStatus == FinalApplicationStatus.FAILED) {
+ throw new SparkException("Application finished with failed status")
+ }
+ if (yarnApplicationState == YarnApplicationState.KILLED ||
+ finalApplicationStatus == FinalApplicationStatus.KILLED) {
+ throw new SparkException("Application is killed")
+ }
+ if (finalApplicationStatus == FinalApplicationStatus.UNDEFINED) {
+ throw new SparkException("The final status of application is undefined")
+ }
+ }
/* --------------------------------------------------------------------------------------- *
| Methods that cannot be implemented here due to API differences across hadoop versions |
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 6bb4b82316ad4..d948a2aeedd45 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -99,7 +99,7 @@ private[spark] class YarnClientSchedulerBackend(
*/
private def waitForApplication(): Unit = {
assert(client != null && appId != null, "Application has not been submitted yet!")
- val state = client.monitorApplication(appId, returnOnRunning = true) // blocking
+ val (state, _) = client.monitorApplication(appId, returnOnRunning = true) // blocking
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 0b43e6ee20538..addaddb711d3c 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -135,15 +135,7 @@ object Client {
System.setProperty("SPARK_YARN_MODE", "true")
val sparkConf = new SparkConf
- try {
- val args = new ClientArguments(argStrings, sparkConf)
- new Client(args, sparkConf).run()
- } catch {
- case e: Exception =>
- Console.err.println(e.getMessage)
- System.exit(1)
- }
-
- System.exit(0)
+ val args = new ClientArguments(argStrings, sparkConf)
+ new Client(args, sparkConf).run()
}
}
diff --git a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index a826b2a78a8f5..d79b85e867fcd 100644
--- a/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/stable/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.Utils
@@ -123,21 +123,29 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
var result = File.createTempFile("result", null, tempDir)
- // The Client object will call System.exit() after the job is done, and we don't want
- // that because it messes up the scalatest monitoring. So replicate some of what main()
- // does here.
val args = Array("--class", main,
"--jar", "file:" + fakeSparkJar.getAbsolutePath(),
"--arg", "yarn-cluster",
"--arg", result.getAbsolutePath(),
"--num-executors", "1")
- val sparkConf = new SparkConf()
- val yarnConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
- val clientArgs = new ClientArguments(args, sparkConf)
- new Client(clientArgs, yarnConf, sparkConf).run()
+ Client.main(args)
checkResult(result)
}
+ test("run Spark in yarn-cluster mode unsuccessfully") {
+ val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
+
+ // Use only one argument so the driver will fail
+ val args = Array("--class", main,
+ "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
+ "--arg", "yarn-cluster",
+ "--num-executors", "1")
+ val exception = intercept[SparkException] {
+ Client.main(args)
+ }
+ assert(Utils.exceptionString(exception).contains("Application finished with failed status"))
+ }
+
/**
* This is a workaround for an issue with yarn-cluster mode: the Client class will not provide
* any sort of error when the job process finishes successfully, but the job itself fails. So