diff --git a/.rat-excludes b/.rat-excludes index 20e3372464386..d8bee1f8e49c9 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -44,6 +44,7 @@ SparkImports.scala SparkJLineCompletion.scala SparkJLineReader.scala SparkMemberHandlers.scala +SparkReplReporter.scala sbt sbt-launch-lib.bash plugins.sbt diff --git a/assembly/pom.xml b/assembly/pom.xml index 31a01e4d8e1de..c65192bde64c6 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -66,22 +66,22 @@ org.apache.spark - spark-repl_${scala.binary.version} + spark-streaming_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming_${scala.binary.version} + spark-graphx_${scala.binary.version} ${project.version} org.apache.spark - spark-graphx_${scala.binary.version} + spark-sql_${scala.binary.version} ${project.version} org.apache.spark - spark-sql_${scala.binary.version} + spark-repl_${scala.binary.version} ${project.version} @@ -197,6 +197,11 @@ spark-hive_${scala.binary.version} ${project.version} + + + + hive-thriftserver + org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 905bbaf99b374..298641f2684de 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -20,8 +20,6 @@ # This script computes Spark's classpath and prints it to stdout; it's used by both the "run" # script and the ExecutorRunner in standalone cluster mode. -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -36,7 +34,7 @@ else CLASSPATH="$CLASSPATH:$FWDIR/conf" fi -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" +ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" if [ -n "$JAVA_HOME" ]; then JAR_CMD="$JAVA_HOME/bin/jar" @@ -48,19 +46,19 @@ fi if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." 
>&2 - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -123,15 +121,15 @@ fi # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" + 
CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 6d4231b204595..356b3d49b2ffe 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then set +a fi fi + +# Setting SPARK_SCALA_VERSION if not already set. + +if [ -z "$SPARK_SCALA_VERSION" ]; then + + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + exit 1 + fi + + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi +fi diff --git a/bin/pyspark b/bin/pyspark index 96f30a260a09e..0b4f695dd06dd 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR" source "$FWDIR/bin/utils.sh" -SCALA_VERSION=2.10 +source "$FWDIR"/bin/load-spark-env.sh function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 @@ -40,7 +40,7 @@ fi # Exit if the user hasn't compiled Spark if [ ! -f "$FWDIR/RELEASE" ]; then # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null + ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -48,8 +48,6 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. "$FWDIR"/bin/load-spark-env.sh - # In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` # executable, while the worker would still be launched using PYSPARK_PYTHON. # @@ -134,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then gatherSparkSubmitOpts "$@" exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else - # PySpark shell requires special handling downstream - export PYSPARK_SHELL=1 exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 59415e9bdec2c..a542ec80b49d6 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -59,7 +59,6 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do ( ) if [%PYTHON_FILE%] == [] ( - set PYSPARK_SHELL=1 if [%IPYTHON%] == [1] ( ipython %IPYTHON_OPTS% ) else ( diff --git a/bin/run-example b/bin/run-example index 34dd71c71880e..3d932509426fc 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,12 +17,12 @@ # limitations under the License. # -SCALA_VERSION=2.10 - FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples +. 
"$FWDIR"/bin/load-spark-env.sh + if [ -n "$1" ]; then EXAMPLE_CLASS="$1" shift @@ -36,8 +36,8 @@ fi if [ -f "$FWDIR/RELEASE" ]; then export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" -elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" +elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`" fi if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then diff --git a/bin/spark-class b/bin/spark-class index 925367b0dd187..0d58d95c1aee3 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -24,8 +24,6 @@ case "`uname`" in CYGWIN*) cygwin=true;; esac -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -128,9 +126,9 @@ fi TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then +if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build @@ -149,7 +147,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 echo "You need to build Spark before running $1." 1>&2 exit 1 fi diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index f8ffbf64278fb..0886b0276fb90 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -28,7 +28,7 @@ # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. # - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. -# Options for the daemons used in the standalone deploy mode: +# Options for the daemons used in the standalone deploy mode # - SPARK_MASTER_IP, to bind the master to a different IP address or hostname # - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master # - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y") @@ -41,3 +41,10 @@ # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers + +# Generic options for the daemons used in the standalone deploy mode +# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf) +# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs) +# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp) +# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) +# - SPARK_NICENESS The scheduling priority for daemons. 
(Default: 0) diff --git a/core/pom.xml b/core/pom.xml index 41296e0eca330..492eddda744c2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,34 @@ Spark Project Core http://spark.apache.org/ + + com.twitter + chill_${scala.binary.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.apache.hadoop hadoop-client @@ -46,12 +74,12 @@ org.apache.spark - spark-network-common_2.10 + spark-network-common_${scala.binary.version} ${project.version} org.apache.spark - spark-network-shuffle_2.10 + spark-network-shuffle_${scala.binary.version} ${project.version} @@ -132,14 +160,6 @@ net.jpountz.lz4 lz4 - - com.twitter - chill_${scala.binary.version} - - - com.twitter - chill-java - org.roaringbitmap RoaringBitmap @@ -309,14 +329,16 @@ org.scalatest scalatest-maven-plugin - - - ${basedir}/.. - 1 - ${spark.classpath} - - + + + test + + test + + + + org.apache.maven.plugins @@ -424,4 +446,5 @@ + diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index ef93009a074e7..88adb892998af 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -28,7 +28,9 @@ import org.apache.spark.scheduler._ * the scheduler queue is not drained in N seconds, then new executors are added. If the queue * persists for another M seconds, then more executors are added and so on. The number added * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. + * number of executors has been reached. The upper bound is based both on a configured property + * and on the number of tasks pending: the policy will never increase the number of executor + * requests past the number needed to handle all pending tasks. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, @@ -82,6 +84,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + // TODO: The default value of 1 for spark.executor.cores works right now because dynamic + // allocation is only supported for YARN and the default number of cores per executor in YARN is + // 1, but it might need to be attained differently for different cluster managers + private val tasksPerExecutor = + conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + validateSettings() // Number of executors to add in the next round @@ -110,6 +118,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock + // Listener for Spark events that impact the allocation policy + private val listener = new ExecutorAllocationListener(this) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -141,6 +152,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. 
You may enable this through spark.shuffle.service.enabled.")
     }
+    if (tasksPerExecutor == 0) {
+      throw new SparkException("spark.executor.cores must not be less than spark.task.cpus")
+    }
   }

   /**
@@ -154,7 +168,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
    * Register for scheduler callbacks to decide when to add and remove executors.
    */
   def start(): Unit = {
-    val listener = new ExecutorAllocationListener(this)
     sc.addSparkListener(listener)
     startPolling()
   }
@@ -218,13 +231,27 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
       return 0
     }

-    // Request executors with respect to the upper bound
-    val actualNumExecutorsToAdd =
-      if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) {
-        numExecutorsToAdd
-      } else {
-        maxNumExecutors - numExistingExecutors
-      }
+    // The number of executors needed to satisfy all pending tasks is the number of tasks pending
+    // divided by the number of tasks each executor can fit, rounded up.
+    val maxNumExecutorsPending =
+      (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
+    if (numExecutorsPending >= maxNumExecutorsPending) {
+      logDebug(s"Not adding executors because there are already $numExecutorsPending " +
+        s"pending and pending tasks could only fill $maxNumExecutorsPending")
+      numExecutorsToAdd = 1
+      return 0
+    }
+
+    // It's never useful to request more executors than could satisfy all the pending tasks, so
+    // cap request at that amount.
+    // Also cap request with respect to the configured upper bound.
+    val maxNumExecutorsToAdd = math.min(
+      maxNumExecutorsPending - numExecutorsPending,
+      maxNumExecutors - numExistingExecutors)
+    assert(maxNumExecutorsToAdd > 0)
+
+    val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
+    val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
     val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd)
     if (addRequestAcknowledged) {
@@ -445,6 +472,16 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
       blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
       allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
     }
+
+    /**
+     * An estimate of the total number of pending tasks remaining for currently running stages. Does
+     * not account for tasks which may have failed and been resubmitted.
+     */
+    def totalPendingTasks(): Int = {
+      stageIdToNumTasks.map { case (stageId, numTasks) =>
+        numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
+      }.sum
+    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 03ea672c813d1..7cccf74003431 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -25,6 +25,7 @@ import java.util.{Arrays, Properties, UUID}
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.UUID.randomUUID
 import scala.collection.{Map, Set}
+import scala.collection.JavaConversions._
 import scala.collection.generic.Growable
 import scala.collection.mutable.HashMap
 import scala.reflect.{ClassTag, classTag}
@@ -57,11 +58,25 @@ import org.apache.spark.util._
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
  * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
  *
+ * Only one SparkContext may be active per JVM.
You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ +class SparkContext(config: SparkConf) extends Logging { + + // The call site where this SparkContext was constructed. + private val creationSite: CallSite = Utils.getCallSite() + + // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active + private val allowMultipleContexts: Boolean = + config.getBoolean("spark.driver.allowMultipleContexts", false) -class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having started construction. + // NOTE: this must be placed at the beginning of the SparkContext constructor. + SparkContext.markPartiallyConstructed(this, allowMultipleContexts) // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It @@ -228,6 +243,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { private[spark] val jobProgressListener = new JobProgressListener(conf) listenerBus.addListener(jobProgressListener) + val statusTracker = new SparkStatusTracker(this) + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { @@ -1001,6 +1018,69 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** The version of Spark on which this application is running. */ def version = SPARK_VERSION + /** + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.host + ":" + blockManagerId.port, mem) + } + } + + /** + * :: DeveloperApi :: + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + @DeveloperApi + def getRDDStorageInfo: Array[RDDInfo] = { + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) + } + + /** + * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. 
+ */ + def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap + + /** + * :: DeveloperApi :: + * Return information about blocks stored in all of the slaves + */ + @DeveloperApi + def getExecutorStorageStatus: Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + + /** + * :: DeveloperApi :: + * Return pools for fair scheduler + */ + @DeveloperApi + def getAllPools: Seq[Schedulable] = { + // TODO(xiajunluan): We should take nested pools into account + taskScheduler.rootPool.schedulableQueue.toSeq + } + + /** + * :: DeveloperApi :: + * Return the pool associated with the given name, if one exists + */ + @DeveloperApi + def getPoolForName(pool: String): Option[Schedulable] = { + Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) + } + + /** + * Return current scheduling mode + */ + def getSchedulingMode: SchedulingMode.SchedulingMode = { + taskScheduler.schedulingMode + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. @@ -1100,27 +1180,30 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** Shut down the SparkContext. */ def stop() { - postApplicationEnd() - ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - env.metricsSystem.report() - metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") + SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + postApplicationEnd() + ui.foreach(_.stop()) + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { + env.metricsSystem.report() + metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) + cleaner.foreach(_.stop()) + dagSchedulerCopy.stop() + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + listenerBus.stop() + eventLogger.foreach(_.stop()) + logInfo("Successfully stopped SparkContext") + SparkContext.clearActiveContext() + } else { + logInfo("SparkContext already stopped") + } } } @@ -1409,6 +1492,11 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { private[spark] def cleanup(cleanupTime: Long) { persistentRdds.clearOldValues(cleanupTime) } + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having finished construction. + // NOTE: this must be placed at the end of the SparkContext constructor. + SparkContext.setActiveContext(this, allowMultipleContexts) } /** @@ -1417,6 +1505,107 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { */ object SparkContext extends Logging { + /** + * Lock that guards access to global variables that track SparkContext construction. + */ + private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() + + /** + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. 
+ * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var activeContext: Option[SparkContext] = None + + /** + * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * constructor, or `None` if no SparkContext is being constructed. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var contextBeingConstructed: Option[SparkContext] = None + + /** + * Called to ensure that no other SparkContext is running in this JVM. + * + * Throws an exception if a running context is detected and logs a warning if another thread is + * constructing a SparkContext. This warning is necessary because the current locking scheme + * prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private def assertNoOtherContextIsRunning( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + contextBeingConstructed.foreach { otherContext => + if (otherContext ne sc) { // checks for reference equality + // Since otherContext might point to a partially-constructed context, guard against + // its creationSite field being null: + val otherContextCreationSite = + Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") + val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + + " constructor). This may indicate an error, since only one SparkContext may be" + + " running in this JVM (see SPARK-2243)." + + s" The other SparkContext was created at:\n$otherContextCreationSite" + logWarning(warnMsg) + } + + activeContext.foreach { ctx => + val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" + val exception = new SparkException(errMsg) + if (allowMultipleContexts) { + logWarning("Multiple running SparkContexts detected in the same JVM!", exception) + } else { + throw exception + } + } + } + } + } + + /** + * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking + * scheme prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private[spark] def markPartiallyConstructed( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = Some(sc) + } + } + + /** + * Called at the end of the SparkContext constructor to ensure that no other SparkContext has + * raced with this constructor and started. + */ + private[spark] def setActiveContext( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = None + activeContext = Some(sc) + } + } + + /** + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. 
It's + * also called in unit tests to prevent a flood of warnings from test suites that don't / can't + * properly clean up their SparkContexts. + */ + private[spark] def clearActiveContext(): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + activeContext = None + } + } + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e7454beddbfd0..e464b32e61dd6 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -168,9 +168,11 @@ object SparkEnv extends Logging { executorId: String, hostname: String, port: Int, + numCores: Int, isLocal: Boolean, actorSystem: ActorSystem = null): SparkEnv = { - create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem) + create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem, + numUsableCores = numCores) } /** @@ -184,7 +186,8 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean, listenerBus: LiveListenerBus = null, - defaultActorSystem: ActorSystem = null): SparkEnv = { + defaultActorSystem: ActorSystem = null, + numUsableCores: Int = 0): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -276,7 +279,7 @@ object SparkEnv extends Logging { val blockTransferService = conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { case "netty" => - new NettyBlockTransferService(conf, securityManager) + new NettyBlockTransferService(conf, securityManager, numUsableCores) case "nio" => new NioBlockTransferService(conf, securityManager) } @@ -287,7 +290,8 @@ object SparkEnv extends Logging { // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, + numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala deleted file mode 100644 index 1982499c5e1d3..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import scala.collection.Map -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SchedulingMode, Schedulable} -import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo} - -/** - * Trait that implements Spark's status APIs. This trait is designed to be mixed into - * SparkContext; it allows the status API code to live in its own file. - */ -private[spark] trait SparkStatusAPI { this: SparkContext => - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. - */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * :: DeveloperApi :: - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - @DeveloperApi - def getRDDStorageInfo: Array[RDDInfo] = { - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) - rddInfos.filter(_.isCached) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. - */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus - } - - /** - * :: DeveloperApi :: - * Return pools for fair scheduler - */ - @DeveloperApi - def getAllPools: Seq[Schedulable] = { - // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq - } - - /** - * :: DeveloperApi :: - * Return the pool associated with the given name, if one exists - */ - @DeveloperApi - def getPoolForName(pool: String): Option[Schedulable] = { - Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) - } - - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } - - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = { - jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray - } - } - - /** - * Returns job information, or `None` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): Option[SparkJobInfo] = { - jobProgressListener.synchronized { - jobProgressListener.jobIdToData.get(jobId).map { data => - new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) - } - } - } - - /** - * Returns stage information, or `None` if the stage info could not be found or was - * garbage collected. 
- */ - def getStageInfo(stageId: Int): Option[SparkStageInfo] = { - jobProgressListener.synchronized { - for ( - info <- jobProgressListener.stageIdToInfo.get(stageId); - data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) - ) yield { - new SparkStageInfoImpl( - stageId, - info.attemptId, - info.name, - info.numTasks, - data.numActiveTasks, - data.numCompleteTasks, - data.numFailedTasks) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala new file mode 100644 index 0000000000000..c18d763d7ff4d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `None` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class SparkStatusTracker private[spark] (sc: SparkContext) { + + private val jobProgressListener = sc.jobProgressListener + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = { + jobProgressListener.synchronized { + val jobData = jobProgressListener.jobIdToData.valuesIterator + jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + } + } + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeStages.values.map(_.stageId).toArray + } + } + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. 
+ */ + def getActiveJobIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeJobs.values.map(_.jobId).toArray + } + } + + /** + * Returns job information, or `None` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): Option[SparkJobInfo] = { + jobProgressListener.synchronized { + jobProgressListener.jobIdToData.get(jobId).map { data => + new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) + } + } + } + + /** + * Returns stage information, or `None` if the stage info could not be found or was + * garbage collected. + */ + def getStageInfo(stageId: Int): Option[SparkStageInfo] = { + jobProgressListener.synchronized { + for ( + info <- jobProgressListener.stageIdToInfo.get(stageId); + data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) + ) yield { + new SparkStageInfoImpl( + stageId, + info.attemptId, + info.name, + info.numTasks, + data.numActiveTasks, + data.numCompleteTasks, + data.numFailedTasks) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 5c6e8d32c5c8a..6a6d9bf6857d3 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -42,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. + * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround with Closeable { @@ -105,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env + def statusTracker = new JavaSparkStatusTracker(sc) + def isLocal: java.lang.Boolean = sc.isLocal def sparkUser: String = sc.sparkUser @@ -134,25 +139,6 @@ class JavaSparkContext(val sc: SparkContext) /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup) - - /** - * Returns job information, or `null` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull - - /** - * Returns stage information, or `null` if the stage info could not be found or was - * garbage collected. - */ - def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull - /** Distribute a local Scala collection to form an RDD. 
*/ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala new file mode 100644 index 0000000000000..3300cad9efbab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `null` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class JavaSparkStatusTracker private[spark] (sc: SparkContext) { + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup) + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds() + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds() + + /** + * Returns job information, or `null` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull + + /** + * Returns stage information, or `null` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 87f5cf944ed85..a5ea478f231d7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -39,7 +39,7 @@ import scala.reflect.ClassTag * * {{{ * scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) - * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) + * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) * * scala> broadcastVar.value * res0: Array[Int] = Array(1, 2, 3) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 7dade04273b08..31f0a462f84d8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -191,10 +191,12 @@ private[broadcast] object HttpBroadcast extends Logging { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) uc = newuri.toURL.openConnection() + uc.setConnectTimeout(httpReadTimeout) uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") uc = new URL(url).openConnection() + uc.setConnectTimeout(httpReadTimeout) } val in = { diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 4e802e02c4149..39150deab863c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import java.net.{URI, URISyntaxException} - import scala.collection.mutable.ListBuffer import org.apache.log4j.Level @@ -116,12 +114,5 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { - def isValidJarUrl(s: String): Boolean = { - try { - val uri = new URI(s) - uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar") - } catch { - case _: URISyntaxException => false - } - } + def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar") } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index af94b05ce3847..039c8719e2867 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -87,8 +87,8 @@ object PythonRunner { // Strip the URI scheme from the path formattedPath = new URI(formattedPath).getScheme match { - case Utils.windowsDrive(d) if windows => formattedPath case null => formattedPath + case Utils.windowsDrive(d) if windows => formattedPath case _ => new URI(formattedPath).getPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b43e68e40f791..8a62519bd2315 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -340,7 +340,7 @@ object SparkSubmit { e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { println(s"Failed to 
load main class $childMainClass.")
-        println("You need to build Spark with -Phive.")
+        println("You need to build Spark with -Phive and -Phive-thriftserver.")
       }
       System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
     }
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 2b894a796c8c6..aa3743ca7df63 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -129,6 +129,16 @@ private[spark] object SparkSubmitDriverBootstrapper {

     val process = builder.start()

+    // If we kill an app while it's running, its subprocess should be killed too.
+    Runtime.getRuntime().addShutdownHook(new Thread() {
+      override def run() = {
+        if (process != null) {
+          process.destroy()
+          sys.exit(process.waitFor())
+        }
+      }
+    })
+
     // Redirect stdout and stderr from the child JVM
     val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout")
     val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr")
@@ -139,14 +149,15 @@ private[spark] object SparkSubmitDriverBootstrapper {
     // subprocess there already reads directly from our stdin, so we should avoid spawning a
     // thread that contends with the subprocess in reading from System.in.
     val isWindows = Utils.isWindows
-    val isPySparkShell = sys.env.contains("PYSPARK_SHELL")
+    val isSubprocess = sys.env.contains("IS_SUBPROCESS")
     if (!isWindows) {
       val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
       stdinThread.start()
-      // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM
-      // should terminate on broken pipe, which signals that the parent process has exited. In
-      // Windows, the termination logic for the PySpark shell is handled in java_gateway.py
+      // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
+      // broken pipe, signaling that the parent process has exited. This is the case if the
+      // application is launched directly from python, as in the PySpark shell.
In Windows, + // the termination logic is handled in java_gateway.py + if (isSubprocess) { stdinThread.join() process.destroy() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala index d044e1d01d429..b9798963bab0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -39,7 +39,7 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) private val useSasl: Boolean = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(sparkConf) + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = new ExternalShuffleBlockHandler(transportConf) private val transportContext: TransportContext = { val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 3711824a40cfc..5f46f3b1f085e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,9 +57,9 @@ private[spark] class CoarseGrainedExecutorBackend( override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") - // Make this host instead of hostPort ? 
val (hostname, _) = Utils.parseHostPort(hostPort) - executor = new Executor(executorId, hostname, sparkProperties, isLocal = false, actorSystem) + executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false, + actorSystem) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index caf4d76713d49..4c378a278b4c1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -43,6 +43,7 @@ private[spark] class Executor( executorId: String, slaveHostname: String, properties: Seq[(String, String)], + numCores: Int, isLocal: Boolean = false, actorSystem: ActorSystem = null) extends Logging @@ -83,7 +84,7 @@ private[spark] class Executor( if (!isLocal) { val port = conf.getInt("spark.executor.port", 0) val _env = SparkEnv.createExecutorEnv( - conf, executorId, slaveHostname, port, isLocal, actorSystem) + conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) _env.blockManager.initialize(conf.getAppId) diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index bca0b152268ad..f15e6bc33fb41 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.executor import java.nio.ByteBuffer +import scala.collection.JavaConversions._ + import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} @@ -50,14 +52,23 @@ private[spark] class MesosExecutorBackend executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { - logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) + + // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. 
+ val cpusPerTask = executorInfo.getResourcesList + .find(_.getName == "cpus") + .map(_.getScalar.getValue.toInt) + .getOrElse(0) + val executorId = executorInfo.getExecutorId.getValue + + logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus") this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) executor = new Executor( - executorInfo.getExecutorId.getValue, + executorId, slaveInfo.getHostname, - properties) + properties, + cpusPerTask) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 183bce3d8d8d3..d3601cca832b2 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -19,14 +19,13 @@ package org.apache.spark.input import scala.collection.JavaConversions._ +import org.apache.hadoop.conf.{Configuration, Configurable} import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit /** * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for @@ -34,17 +33,24 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit * the value is the entire content of file. 
*/ -private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { +private[spark] class WholeTextFileInputFormat + extends CombineFileInputFormat[String, String] with Configurable { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + override def createRecordReader( split: InputSplit, context: TaskAttemptContext): RecordReader[String, String] = { - new CombineFileRecordReader[String, String]( - split.asInstanceOf[CombineFileSplit], - context, - classOf[WholeTextFileRecordReader]) + val reader = new WholeCombineFileRecordReader(split, context) + reader.setConf(conf) + reader } /** diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 3564ab2e2a162..6d59b24eb0596 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,11 +17,13 @@ package org.apache.spark.input +import org.apache.hadoop.conf.{Configuration, Configurable} import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -34,7 +36,13 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] { + extends RecordReader[String, String] with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf private[this] val path = split.getPath(index) private[this] val fs = path.getFileSystem(context.getConfiguration) @@ -57,8 +65,16 @@ private[spark] class WholeTextFileRecordReader( override def nextKeyValue(): Boolean = { if (!processed) { + val conf = new Configuration + val factory = new CompressionCodecFactory(conf) + val codec = factory.getCodec(path) // infers from file ext. val fileIn = fs.open(path) - val innerBuffer = ByteStreams.toByteArray(fileIn) + val innerBuffer = if (codec != null) { + ByteStreams.toByteArray(codec.createInputStream(fileIn)) + } else { + ByteStreams.toByteArray(fileIn) + } + value = new Text(innerBuffer).toString Closeables.close(fileIn, false) processed = true @@ -68,3 +84,33 @@ private[spark] class WholeTextFileRecordReader( } } } + + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. 
+ */ +private[spark] class WholeCombineFileRecordReader( + split: InputSplit, + context: TaskAttemptContext) + extends CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader] + ) with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + + override def initNextRecordReader(): Boolean = { + val r = super.initNextRecordReader() + if (r) { + this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf) + } + r + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index f8a7f640689a2..0027cbb0ff1fb 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -35,13 +35,13 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at a time. */ -class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager) +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. private val serializer = new JavaSerializer(conf) private val authEnabled = securityManager.isAuthenticationEnabled() - private val transportConf = SparkTransportConf.fromSparkConf(conf) + private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 9fa4fa77b8817..ce4225cae6d88 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -20,11 +20,22 @@ package org.apache.spark.network.netty import org.apache.spark.SparkConf import org.apache.spark.network.util.{TransportConf, ConfigProvider} -/** - * Utility for creating a [[TransportConf]] from a [[SparkConf]]. - */ object SparkTransportConf { - def fromSparkConf(conf: SparkConf): TransportConf = { + /** + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param numUsableCores if nonzero, this will restrict the server and client threads to only + * use the given number of cores, rather than all of the machine's cores. + * This restriction will only occur if these properties are not already set. + */ + def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + val conf = _conf.clone + if (numUsableCores > 0) { + // Only set if serverThreads/clientThreads not already set.
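Aside on the idiom the next lines rely on: conf.set(key, conf.get(key, fallback)) writes fallback only when key is unset, because SparkConf.get(key, default) returns the stored value whenever one exists. A short sketch with made-up numbers, showing that an explicitly configured thread count survives the core cap:

    import org.apache.spark.SparkConf

    val numUsableCores = 4
    // fromSparkConf clones the caller's conf before mutating it; mirrored here.
    val conf = new SparkConf().set("spark.shuffle.io.serverThreads", "8").clone
    conf.set("spark.shuffle.io.serverThreads",
      conf.get("spark.shuffle.io.serverThreads", numUsableCores.toString)) // stays "8"
    conf.set("spark.shuffle.io.clientThreads",
      conf.get("spark.shuffle.io.clientThreads", numUsableCores.toString)) // unset, becomes "4"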
+ conf.set("spark.shuffle.io.serverThreads", + conf.get("spark.shuffle.io.serverThreads", numUsableCores.toString)) + conf.set("spark.shuffle.io.clientThreads", + conf.get("spark.shuffle.io.clientThreads", numUsableCores.toString)) + } new TransportConf(new ConfigProvider { override def get(name: String): String = conf.get(name) }) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f198aa8564a54..df4b085d2251e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -18,13 +18,13 @@ package org.apache.spark.network.nio import java.io.IOException +import java.lang.ref.WeakReference import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} -import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ @@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} @@ -77,7 +78,8 @@ private[nio] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -139,7 +141,10 @@ private[nio] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -899,22 +904,41 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). 
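A condensed sketch (ours, simplified from the change that follows) of the leak-avoidance pattern: the timer task closes over only the message id and a weak reference to the promise, so neither the message payload nor a completed promise stays reachable for the full life of the timeout:

    import java.io.IOException
    import java.lang.ref.WeakReference
    import java.util.concurrent.TimeUnit
    import scala.concurrent.Promise
    import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}

    val timer = new HashedWheelTimer()

    def scheduleAckTimeout[T](messageId: Int, promise: Promise[T], seconds: Long): Timeout = {
      val promiseRef = new WeakReference(promise) // do not pin the promise itself
      timer.newTimeout(new TimerTask {
        override def run(timeout: Timeout): Unit = {
          val p = promiseRef.get // null once the promise has been collected
          if (p != null) {
            p.tryFailure(new IOException(
              s"ack for message $messageId not received within $seconds sec"))
          }
        }
      }, seconds, TimeUnit.SECONDS)
    }

The Timeout handle returned by newTimeout is what the ack path cancels, which is why the Timer-based schedule call further down is deleted.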
+ val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { + messageStatuses.remove(messageId).foreach { s => val e = new IOException("sendMessageReliably failed because ack " + s"was not received within $ackTimeout sec") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) } - }) + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskHandle.cancel() s match { case scala.util.Failure(e) => // Indicates a failure where we either never sent or never got ACK'd @@ -943,7 +967,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -953,7 +976,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 716f2dd17733b..e4025bcf48db6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1202,7 +1202,7 @@ abstract class RDD[T: ClassTag]( */ def checkpoint() { if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") + throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() @@ -1309,7 +1309,7 @@ abstract class RDD[T: ClassTag]( def debugSelf (rdd: RDD[_]): Seq[String] = { import Utils.bytesToString - val persistence = storageLevel.description + val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index c5f3493477bc5..d13795186c48e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -166,29 +166,16 @@ 
private[spark] class MesosSchedulerBackend( execArgs } - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() } - } finally { - restoreClassLoader(oldClassLoader) } } @@ -200,6 +187,16 @@ private[spark] class MesosSchedulerBackend( } } + private def inClassLoader()(fun: => Unit) = { + val oldClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + try { + fun + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) + } + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -210,66 +207,57 @@ private[spark] class MesosSchedulerBackend( * tasks are balanced across the cluster. */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableWorkers = new ArrayBuffer[WorkerOffer] - val offerableIndices = new HashMap[String, Int] - - def sufficientOffer(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val cpus = getResource(o.getResourcesList, "cpus") - val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) - } + inClassLoader() { + val (acceptedOffers, declinedOffers) = offers.partition { o => + val mem = getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") + val slaveId = o.getSlaveId.getValue + (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) + } - for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { - val slaveId = offer.getSlaveId.getValue - offerableIndices.put(slaveId, index) - val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { - getResource(offer.getResourcesList, "cpus").toInt - } else { - // If the executor doesn't exist yet, subtract CPU for executor - getResource(offer.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK - } - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - cpus) + val offerableWorkers = acceptedOffers.map { o => + val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + getResource(o.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + getResource(o.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK } + new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + cpus) + 
} - // Call into the TaskSchedulerImpl - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - for (taskDesc <- taskList) { - val slaveId = taskDesc.executorId - val offerNum = offerableIndices(slaveId) - slaveIdsWithExecutors += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } + val slaveIdToOffer = acceptedOffers.map(o => o.getSlaveId.getValue -> o).toMap + + val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] + + // Call into the TaskSchedulerImpl + scheduler.resourceOffers(offerableWorkers) + .filter(!_.isEmpty) + .foreach { offer => + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } } - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? - for (i <- 0 until offers.size) { - d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters) - } + // Reply to the offers + val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? + + mesosTasks.foreach { case (slaveId, tasks) => + d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } - } finally { - restoreClassLoader(oldClassLoader) + + declinedOffers.foreach(o => d.declineOffer(o.getId)) } } @@ -308,8 +296,7 @@ private[spark] class MesosSchedulerBackend( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { @@ -322,18 +309,13 @@ private[spark] class MesosSchedulerBackend( } } scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) } } override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logError("Mesos error: " + message) scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) } } @@ -350,15 +332,12 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { slaveIdsWithExecutors -= slaveId.getValue } scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index c0264836de738..a2f1f14264a99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -51,7 +51,7 @@ private[spark] class LocalActor( private val localExecutorHostname = "localhost" val executor = new Executor( - localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) + 
localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true) override def receiveWithLogging = { case ReviveOffers => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 39434f473a9d8..308c59eda594d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -73,7 +73,8 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, - securityManager: SecurityManager) + securityManager: SecurityManager, + numUsableCores: Int) extends BlockDataManager with Logging { val diskBlockManager = new DiskBlockManager(this, conf) @@ -121,8 +122,8 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTranserService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), securityManager, - securityManager.isAuthenticationEnabled()) + val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -174,9 +175,10 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService, - securityManager: SecurityManager) = { + securityManager: SecurityManager, + numUsableCores: Int) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 3312671b6f885..7bc1e24d58711 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -175,7 +175,7 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index e9c755e36f716..c82730f524eb7 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.net.URLDecoder import javax.servlet.http.HttpServletRequest import scala.util.Try @@ -29,7 +30,19 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage private val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val executorId = Option(request.getParameter("executorId")).getOrElse { + val executorId = Option(request.getParameter("executorId")).map { + executorId => + // Due to YARN-2844, "" in the url will be encoded to "%25253Cdriver%25253E" when + // running in yarn-cluster mode. `request.getParameter("executorId")` will return + // "%253Cdriver%253E". Therefore we need to decode it until we get the real id. + var id = executorId + var decodedId = URLDecoder.decode(id, "UTF-8") + while (id != decodedId) { + id = decodedId + decodedId = URLDecoder.decode(id, "UTF-8") + } + id + }.getOrElse { return Text(s"Missing executorId parameter") } val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 048fee3ce1ff4..71b59b1d078ca 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -139,8 +140,9 @@ private[ui] class ExecutorsPage( { if (threadDumpEnabled) { + val encodedId = URLEncoder.encode(info.id, "UTF-8") - Thread Dump + Thread Dump } else { Seq.empty diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index eae542df85d08..2ff561ccc7da0 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -175,7 +175,9 @@ private[ui] class StageTableBase( Seq.empty }} ++ {makeDescription(s)} - {submissionTime} + + {submissionTime} + {formattedDuration} {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index d7dccd4af8c6e..0e4c6d633a4a9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -105,7 +105,8 @@ private[spark] trait Spillable[C] { */ @inline private def logSpillage(size: Long) { val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else "")) + logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)" + .format(threadId, org.apache.spark.util.Utils.bytesToString(size), + _spillCount, if (_spillCount > 1) "s" else "")) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 66cf60d25f6d1..ce804f94f3267 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -37,20 +37,24 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { .set("spark.dynamicAllocation.enabled", "true") intercept[SparkException] { new SparkContext(conf) } SparkEnv.get.stop() // cleanup the created environment + SparkContext.clearActiveContext() // Only min val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1") intercept[SparkException] { new SparkContext(conf1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Only max val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2") intercept[SparkException] { new SparkContext(conf2) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, but min > max intercept[SparkException] { createSparkContext(2, 1) } SparkEnv.get.stop() + SparkContext.clearActiveContext() // Both min and max, and min == max val sc1 = createSparkContext(1, 1) @@ -76,6 +80,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("add executors") { sc = createSparkContext(1, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsPending(manager) === 0) @@ -117,6 +122,51 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsToAdd(manager) === 1) } + test("add executors capped by num pending tasks") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + + // Verify that we're capped at number of tasks in the stage + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 5) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task reduces the cap + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 6) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that re-running a task doesn't reduce the cap further + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 8) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 9) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task once we're at our limit 
doesn't blow things up + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 9) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -170,6 +220,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test ("interleaving add and remove") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors assert(addExecutors(manager) === 1) @@ -343,6 +394,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val clock = new TestClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 9623d665177ef..55799f55146cb 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -38,7 +38,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { - val transportConf = SparkTransportConf.fromSparkConf(conf) + val transportConf = SparkTransportConf.fromSparkConf(conf, numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf) val transportContext = new TransportContext(transportConf, rpcHandler) server = transportContext.createServer() diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 31edad1c56c73..9e454ddcc52a6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -21,9 +21,62 @@ import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable -class SparkContextSuite extends FunSuite { - //Regression test for SPARK-3121 +class SparkContextSuite extends FunSuite with LocalSparkContext { + + /** Allows system properties to be changed in tests */ + private def withSystemProperty[T](property: String, value: String)(block: => T): T = { + val originalValue = System.getProperty(property) + try { + System.setProperty(property, value) + block + } finally { + if (originalValue == null) { + System.clearProperty(property) + } else { + System.setProperty(property, originalValue) + } + } + } + + test("Only one SparkContext may be active at a time") { + // Regression test for SPARK-4180 + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + // A SparkContext is already running, so we shouldn't be able to create a second one + intercept[SparkException] { new SparkContext(conf) } + // After stopping the running context, we should be able to create a new one + resetSparkContext() + sc = new SparkContext(conf) + } + } + + test("Can still construct a new SparkContext after failing to construct a previous one") { + withSystemProperty("spark.driver.allowMultipleContexts", "false") { + // This is an invalid configuration (no 
app name or master URL) + intercept[SparkException] { + new SparkContext(new SparkConf()) + } + // Even though those earlier calls failed, we should still be able to create a new context + sc = new SparkContext(new SparkConf().setMaster("local").setAppName("test")) + } + } + + test("Check for multiple SparkContexts can be disabled via undocumented debug option") { + withSystemProperty("spark.driver.allowMultipleContexts", "true") { + var secondSparkContext: SparkContext = null + try { + val conf = new SparkConf().setAppName("test").setMaster("local") + sc = new SparkContext(conf) + secondSparkContext = new SparkContext(conf) + } finally { + Option(secondSparkContext).foreach(_.stop()) + } + } + } + test("BytesWritable implicit conversion is correct") { + // Regression test for SPARK-3121 val bytesWritable = new BytesWritable() val inputArray = (1 to 10).map(_.toByte).toArray bytesWritable.set(inputArray, 0, 10) diff --git a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala similarity index 69% rename from core/src/test/scala/org/apache/spark/StatusAPISuite.scala rename to core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 4468fba8c1dff..8577e4ac7e33e 100644 --- a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -27,9 +27,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ import org.apache.spark.SparkContext._ -class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { +class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { test("basic status API usage") { + sc = new SparkContext("local", "test", new SparkConf(false)) val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync() val jobId: Int = eventually(timeout(10 seconds)) { val jobIds = jobFuture.jobIds @@ -37,20 +38,20 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { jobIds.head } val jobInfo = eventually(timeout(10 seconds)) { - sc.getJobInfo(jobId).get + sc.statusTracker.getJobInfo(jobId).get } jobInfo.status() should not be FAILED val stageIds = jobInfo.stageIds() stageIds.size should be(2) val firstStageInfo = eventually(timeout(10 seconds)) { - sc.getStageInfo(stageIds(0)).get + sc.statusTracker.getStageInfo(stageIds(0)).get } firstStageInfo.stageId() should be(stageIds(0)) firstStageInfo.currentAttemptId() should be(0) firstStageInfo.numTasks() should be(2) eventually(timeout(10 seconds)) { - val updatedFirstStageInfo = sc.getStageInfo(stageIds(0)).get + val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get updatedFirstStageInfo.numCompletedTasks() should be(2) updatedFirstStageInfo.numActiveTasks() should be(0) updatedFirstStageInfo.numFailedTasks() should be(0) @@ -58,21 +59,31 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { } test("getJobIdsForGroup()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + // Passing `null` should return jobs that were not run in a job group: + val defaultJobGroupFuture = sc.parallelize(1 to 1000).countAsync() + val defaultJobGroupJobId = eventually(timeout(10 seconds)) { + defaultJobGroupFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup(null).toSet should be (Set(defaultJobGroupJobId)) + } + // Test jobs submitted in job groups: sc.setJobGroup("my-job-group", 
"description") - sc.getJobIdsForGroup("my-job-group") should be (Seq.empty) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq.empty) val firstJobFuture = sc.parallelize(1 to 1000).countAsync() val firstJobId = eventually(timeout(10 seconds)) { firstJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) } val secondJobFuture = sc.parallelize(1 to 1000).countAsync() val secondJobId = eventually(timeout(10 seconds)) { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) } } } \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala index 94a2bdd74e744..4161aede1d1d0 100644 --- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala @@ -29,12 +29,6 @@ class ClientSuite extends FunSuite with Matchers { ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo") should be (false) ClientArguments.isValidJarUrl("/missing/a/protocol/jarfile.jar") should be (false) ClientArguments.isValidJarUrl("not-even-a-path.jar") should be (false) - - // No authority - ClientArguments.isValidJarUrl("hdfs:someHost:1234/jarfile.jar") should be (false) - - // Invalid syntax - ClientArguments.isValidJarUrl("hdfs:") should be (false) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index d8cd0ff2c9026..eb7bd7ab3986e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,7 +21,7 @@ import java.io._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, TestUtils} +import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.Utils import org.scalatest.FunSuite @@ -451,24 +451,25 @@ class SparkSubmitSuite extends FunSuite with Matchers { } } -object JarCreationTest { +object JarCreationTest extends Logging { def main(args: Array[String]) { Utils.configTestLog4j("INFO") val conf = new SparkConf() val sc = new SparkContext(conf) val result = sc.makeRDD(1 to 100, 10).mapPartitions { x => - var foundClasses = false + var exception: String = null try { Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) Class.forName("SparkSubmitClassA", true, Thread.currentThread().getContextClassLoader) - foundClasses = true } catch { - case _: Throwable => // catch all + case t: Throwable => + exception = t + "\n" + t.getStackTraceString + exception = exception.replaceAll("\n", "\n\t") } - Seq(foundClasses).iterator + Option(exception).toSeq.iterator }.collect() - if (result.contains(false)) { - throw new Exception("Could not load user defined classes inside of executors") + if (result.nonEmpty) { + throw new Exception("Could not load user class from jar:\n" + result(0)) } } } diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 
12d1c7b2faba6..98b0a16ce88ba 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.io.Text import org.apache.spark.SparkContext import org.apache.spark.util.Utils +import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} /** * Tests the correctness of @@ -38,20 +39,32 @@ import org.apache.spark.util.Utils */ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { private var sc: SparkContext = _ + private var factory: CompressionCodecFactory = _ override def beforeAll() { sc = new SparkContext("local", "test") // Set the block size of local file system to test whether files are split right or not. sc.hadoopConfiguration.setLong("fs.local.block.size", 32) + sc.hadoopConfiguration.set("io.compression.codecs", + "org.apache.hadoop.io.compress.GzipCodec,org.apache.hadoop.io.compress.DefaultCodec") + factory = new CompressionCodecFactory(sc.hadoopConfiguration) } override def afterAll() { sc.stop() } - private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = { - val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName")) + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], + compress: Boolean) = { + val out = if (compress) { + val codec = new GzipCodec + val path = s"${inputDir.toString}/$fileName${codec.getDefaultExtension}" + codec.createOutputStream(new DataOutputStream(new FileOutputStream(path))) + } else { + val path = s"${inputDir.toString}/$fileName" + new DataOutputStream(new FileOutputStream(path)) + } out.write(contents, 0, contents.length) out.close() } @@ -68,7 +81,7 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { println(s"Local disk address is ${dir.toString}.") WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => - createNativeFile(dir, filename, contents) + createNativeFile(dir, filename, contents, false) } val res = sc.wholeTextFiles(dir.toString, 3).collect() @@ -86,6 +99,31 @@ class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { Utils.deleteRecursively(dir) } + + test("Correctness of WholeTextFileRecordReader with GzipCodec.") { + val dir = Utils.createTempDir() + println(s"Local disk address is ${dir.toString}.") + + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents, true) + } + + val res = sc.wholeTextFiles(dir.toString, 3).collect() + + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") + + for ((filename, contents) <- res) { + val shortName = filename.split('/').last.split('.')(0) + + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } + + Utils.deleteRecursively(dir) + } } /** diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 530f5d6db5a29..94bfa67451892 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ 
b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -104,11 +104,11 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) - val exec0 = new NettyBlockTransferService(conf0, securityManager0) + val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1) exec0.init(blockManager) val securityManager1 = new SecurityManager(conf1) - val exec1 = new NettyBlockTransferService(conf1, securityManager1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1) exec1.init(blockManager) val result = fetchBlock(exec0, exec1, "1", blockId) match { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..bef8d3a58ba63 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.mesos + +import org.scalatest.FunSuite +import org.apache.spark.{scheduler, SparkConf, SparkContext, LocalSparkContext} +import org.apache.spark.scheduler.{TaskDescription, WorkerOffer, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.mesos.{MemoryUtils, MesosSchedulerBackend} +import org.apache.mesos.SchedulerDriver +import org.apache.mesos.Protos._ +import org.scalatest.mock.EasyMockSugar +import org.apache.mesos.Protos.Value.Scalar +import org.easymock.{Capture, EasyMock} +import java.nio.ByteBuffer +import java.util.Collections +import java.util +import scala.collection.mutable + +class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with EasyMockSugar { + test("mesos resource offer is launching tasks") { + def createOffer(id: Int, mem: Int, cpu: Int) = { + val builder = Offer.newBuilder() + builder.addResourcesBuilder() + .setName("mem") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(mem)) + builder.addResourcesBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Scalar.newBuilder().setValue(cpu)) + builder.setId(OfferID.newBuilder().setValue(id.toString).build()).setFrameworkId(FrameworkID.newBuilder().setValue("f1")) + .setSlaveId(SlaveID.newBuilder().setValue("s1")).setHostname("localhost").build() + } + + val driver = EasyMock.createMock(classOf[SchedulerDriver]) + val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) + + val sc = EasyMock.createMock(classOf[SparkContext]) + + EasyMock.expect(sc.executorMemory).andReturn(100).anyTimes() + EasyMock.expect(sc.getSparkHome()).andReturn(Option("/path")).anyTimes() + EasyMock.expect(sc.executorEnvs).andReturn(new mutable.HashMap).anyTimes() + EasyMock.expect(sc.conf).andReturn(new SparkConf).anyTimes() + EasyMock.replay(sc) + val minMem = MemoryUtils.calculateTotalMemory(sc).toInt + val minCpu = 4 + val offers = new java.util.ArrayList[Offer] + offers.add(createOffer(1, minMem, minCpu)) + offers.add(createOffer(1, minMem - 1, minCpu)) + val backend = new MesosSchedulerBackend(taskScheduler, sc, "master") + val workerOffers = Seq(offers.get(0)).map(o => new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + 2 + )) + val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0))) + EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(workerOffers))).andReturn(Seq(Seq(taskDesc))) + EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes() + EasyMock.replay(taskScheduler) + val capture = new Capture[util.Collection[TaskInfo]] + EasyMock.expect( + driver.launchTasks( + EasyMock.eq(Collections.singleton(offers.get(0).getId)), + EasyMock.capture(capture), + EasyMock.anyObject(classOf[Filters]) + ) + ).andReturn(Status.valueOf(1)) + EasyMock.expect(driver.declineOffer(offers.get(1).getId)).andReturn(Status.valueOf(1)) + EasyMock.replay(driver) + backend.resourceOffers(driver, offers) + assert(capture.getValue.size() == 1) + val taskInfo = capture.getValue.iterator().next() + assert(taskInfo.getName.equals("n1")) + val cpus = taskInfo.getResourcesList.get(0) + assert(cpus.getName.equals("cpus")) + assert(cpus.getScalar.getValue.equals(2.0)) + assert(taskInfo.getSlaveId.getValue.equals("s1")) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f63e772bf1e59..c2903c8597997 100644 --- 
a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr) + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store store @@ -263,7 +263,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, - 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr) + 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test assert(master.getPeers(store.blockManagerId).toSet === Set(failableStore.blockManagerId)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9529502bc8e10..5554efbcbadf8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -74,7 +74,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, - mapOutputTracker, shuffleManager, transfer, securityMgr) + mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } @@ -795,7 +795,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, - new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr) + new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, + 0) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index f9d1af88f3a13..0ea2d13a83505 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -118,7 +118,7 @@ class SizeEstimatorSuite // TODO: If we sample 100 elements, this should always be 4176 ? 
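A brief note on why the assertion below bounds a range instead of an exact number: SizeEstimator samples the elements of large arrays and tracks visited objects by identity, so repeated references to one instance are cheap and the extrapolated total can drift slightly between runs. A hedged sketch, assuming code living under org.apache.spark (SizeEstimator is private[spark] at this point); the Dummy class is ours:

    import org.apache.spark.util.SizeEstimator

    class Dummy { val x: Long = 0L }
    val d = new Dummy
    val oneObject = SizeEstimator.estimate(d)
    // 1000 references to the same instance: the estimate is dominated by the
    // array of pointers and lands far below 1000 * oneObject.
    val arrayOfRefs = SizeEstimator.estimate(Array.fill(1000)(d))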
val estimatedSize = SizeEstimator.estimate(Array.fill(1000)(d1)) assert(estimatedSize >= 4000, "Estimated size " + estimatedSize + " should be more than 4000") - assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4100") + assert(estimatedSize <= 4200, "Estimated size " + estimatedSize + " should be less than 4200") } test("32-bit arch") { diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 16ea1a71290dc..0b7069f6e116a 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -30,71 +30,84 @@ import time import urllib2 -# Fill in release details here: -RELEASE_URL = "http://people.apache.org/~pwendell/spark-1.0.0-rc1/" -RELEASE_KEY = "9E4FE3AF" -RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" -RELEASE_VERSION = "1.0.0" +# Note: The following variables must be set before use! +RELEASE_URL = "http://people.apache.org/~andrewor14/spark-1.1.1-rc1/" +RELEASE_KEY = "XXXXXXXX" # Your 8-digit hex +RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1033" +RELEASE_VERSION = "1.1.1" SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" -# +# Do not set these LOG_FILE_NAME = "spark_audit_%s" % time.strftime("%h_%m_%Y_%I_%M_%S") LOG_FILE = open(LOG_FILE_NAME, 'w') WORK_DIR = "/tmp/audit_%s" % int(time.time()) MAVEN_CMD = "mvn" GPG_CMD = "gpg" +SBT_CMD = "sbt -Dsbt.log.noformat=true" -print "Starting tests, log output in %s. Test results printed below:" % LOG_FILE_NAME - -# Track failures +# Track failures to print them at the end failures = [] +# Log a message. Use sparingly because this flushes every write. +def log(msg): + LOG_FILE.write(msg + "\n") + LOG_FILE.flush() +def log_and_print(msg): + print msg + log(msg) + +# Prompt the user to delete the scratch directory used def clean_work_files(): - print "OK to delete scratch directory '%s'? (y/N): " % WORK_DIR - response = raw_input() + response = raw_input("OK to delete scratch directory '%s'? (y/N) " % WORK_DIR) if response == "y": shutil.rmtree(WORK_DIR) - print "Should I delete the log output file '%s'? 
(y/N): " % LOG_FILE_NAME - response = raw_input() - if response == "y": - os.unlink(LOG_FILE_NAME) - +# Run the given command and log its output to the log file def run_cmd(cmd, exit_on_failure=True): - print >> LOG_FILE, "Running command: %s" % cmd + log("Running command: %s" % cmd) ret = subprocess.call(cmd, shell=True, stdout=LOG_FILE, stderr=LOG_FILE) if ret != 0 and exit_on_failure: - print "Command failed: %s" % cmd + log_and_print("Command failed: %s" % cmd) clean_work_files() sys.exit(-1) return ret - def run_cmd_with_output(cmd): - print >> sys.stderr, "Running command: %s" % cmd + log_and_print("Running command: %s" % cmd) return subprocess.check_output(cmd, shell=True, stderr=LOG_FILE) +# Test if the given condition is successful +# If so, print the pass message; otherwise print the failure message +def test(cond, msg): + return passed(msg) if cond else failed(msg) -def test(bool, str): - if bool: - return passed(str) - failed(str) - - -def passed(str): - print "[PASSED] %s" % str - - -def failed(str): - failures.append(str) - print "[**FAILED**] %s" % str +def passed(msg): + log_and_print("[PASSED] %s" % msg) +def failed(msg): + failures.append(msg) + log_and_print("[**FAILED**] %s" % msg) def get_url(url): return urllib2.urlopen(url).read() +# If the path exists, prompt the user to delete it +# If the resource is not deleted, abort +def ensure_path_not_present(path): + full_path = os.path.expanduser(path) + if os.path.exists(full_path): + print "Found %s locally." % full_path + response = raw_input("This can interfere with testing published artifacts. OK to delete? (y/N) ") + if response == "y": + shutil.rmtree(full_path) + else: + print "Abort." + sys.exit(-1) + +log_and_print("|-------- Starting Spark audit tests for release %s --------|" % RELEASE_VERSION) +log_and_print("Log output can be found in %s" % LOG_FILE_NAME) original_dir = os.getcwd() @@ -114,37 +127,36 @@ def get_url(url): cache_ivy_spark = "~/.ivy2/cache/org.apache.spark" local_maven_kafka = "~/.m2/repository/org/apache/kafka" local_maven_kafka = "~/.m2/repository/org/apache/spark" - - -def ensure_path_not_present(x): - if os.path.exists(os.path.expanduser(x)): - print "Please remove %s, it can interfere with testing published artifacts." 
% x - sys.exit(-1) - map(ensure_path_not_present, [local_ivy_spark, cache_ivy_spark, local_maven_kafka]) # SBT build tests +log_and_print("==== Building SBT modules ====") os.chdir("blank_sbt_build") os.environ["SPARK_VERSION"] = RELEASE_VERSION os.environ["SCALA_VERSION"] = SCALA_VERSION os.environ["SPARK_RELEASE_REPOSITORY"] = RELEASE_REPOSITORY os.environ["SPARK_AUDIT_MASTER"] = "local" for module in modules: + log("==== Building module %s in SBT ====" % module) os.environ["SPARK_MODULE"] = module - ret = run_cmd("sbt clean update", exit_on_failure=False) - test(ret == 0, "sbt build against '%s' module" % module) + ret = run_cmd("%s clean update" % SBT_CMD, exit_on_failure=False) + test(ret == 0, "SBT build against '%s' module" % module) os.chdir(original_dir) # SBT application tests +log_and_print("==== Building SBT applications ====") for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive", "sbt_app_kinesis"]: + log("==== Building application %s in SBT ====" % app) os.chdir(app) - ret = run_cmd("sbt clean run", exit_on_failure=False) - test(ret == 0, "sbt application (%s)" % app) + ret = run_cmd("%s clean run" % SBT_CMD, exit_on_failure=False) + test(ret == 0, "SBT application (%s)" % app) os.chdir(original_dir) # Maven build tests os.chdir("blank_maven_build") +log_and_print("==== Building Maven modules ====") for module in modules: + log("==== Building module %s in maven ====" % module) cmd = ('%s --update-snapshots -Dspark.release.repository="%s" -Dspark.version="%s" ' '-Dspark.module="%s" clean compile' % (MAVEN_CMD, RELEASE_REPOSITORY, RELEASE_VERSION, module)) @@ -152,6 +164,8 @@ def ensure_path_not_present(x): test(ret == 0, "maven build against '%s' module" % module) os.chdir(original_dir) +# Maven application tests +log_and_print("==== Building Maven applications ====") os.chdir("maven_app_core") mvn_exec_cmd = ('%s --update-snapshots -Dspark.release.repository="%s" -Dspark.version="%s" ' '-Dscala.binary.version="%s" clean compile ' @@ -172,15 +186,14 @@ def ensure_path_not_present(x): artifact_regex = r = re.compile("") artifacts = r.findall(index_page) +# Verify artifact integrity for artifact in artifacts: - print "==== Verifying download integrity for artifact: %s ====" % artifact + log_and_print("==== Verifying download integrity for artifact: %s ====" % artifact) artifact_url = "%s/%s" % (RELEASE_URL, artifact) - run_cmd("wget %s" % artifact_url) - key_file = "%s.asc" % artifact + run_cmd("wget %s" % artifact_url) run_cmd("wget %s/%s" % (RELEASE_URL, key_file)) - run_cmd("wget %s%s" % (artifact_url, ".sha")) # Verify signature @@ -208,31 +221,17 @@ def ensure_path_not_present(x): os.chdir(WORK_DIR) -for artifact in artifacts: - print "==== Verifying build and tests for artifact: %s ====" % artifact - os.chdir(os.path.join(WORK_DIR, dir_name)) - - os.environ["MAVEN_OPTS"] = "-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - # Verify build - print "==> Running build" - run_cmd("sbt assembly") - passed("sbt build successful") - run_cmd("%s package -DskipTests" % MAVEN_CMD) - passed("Maven build successful") - - # Verify tests - print "==> Performing unit tests" - run_cmd("%s test" % MAVEN_CMD) - passed("Tests successful") - os.chdir(WORK_DIR) - -clean_work_files() - +# Report result +log_and_print("\n") if len(failures) == 0: - print "ALL TESTS PASSED" + log_and_print("*** ALL TESTS PASSED ***") else: - print "SOME TESTS DID NOT PASS" + log_and_print("XXXXX SOME TESTS DID NOT PASS XXXXX") for f in failures: - print f - + 
log_and_print(" %s" % f) os.chdir(original_dir) + +# Clean up +clean_work_files() + +log_and_print("|-------- Spark release audit complete --------|") diff --git a/dev/audit-release/blank_sbt_build/build.sbt b/dev/audit-release/blank_sbt_build/build.sbt index 696c7f651837c..62815542e5bd9 100644 --- a/dev/audit-release/blank_sbt_build/build.sbt +++ b/dev/audit-release/blank_sbt_build/build.sbt @@ -19,10 +19,12 @@ name := "Spark Release Auditor" version := "1.0" -scalaVersion := "2.9.3" +scalaVersion := System.getenv.get("SCALA_VERSION") libraryDependencies += "org.apache.spark" % System.getenv.get("SPARK_MODULE") % System.getenv.get("SPARK_VERSION") resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Eclipse Paho Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", + "Maven Repository" at "http://repo1.maven.org/maven2/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/build.sbt b/dev/audit-release/sbt_app_hive/build.sbt index a0d4f25da5842..c8824f2b15e55 100644 --- a/dev/audit-release/sbt_app_hive/build.sbt +++ b/dev/audit-release/sbt_app_hive/build.sbt @@ -25,4 +25,5 @@ libraryDependencies += "org.apache.spark" %% "spark-hive" % System.getenv.get("S resolvers ++= Seq( "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Maven Repository" at "http://repo1.maven.org/maven2/", "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml b/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml deleted file mode 100644 index 93b835813d535..0000000000000 --- a/dev/audit-release/sbt_app_hive/src/main/resources/hive-site.xml +++ /dev/null @@ -1,213 +0,0 @@ - - - - - - - - - - - - - - - - - - build.dir - ${user.dir}/build - - - - build.dir.hive - ${build.dir}/hive - - - - hadoop.tmp.dir - ${build.dir.hive}/test/hadoop-${user.name} - A base for other temporary directories. - - - - - - hive.exec.scratchdir - ${build.dir}/scratchdir - Scratch space for Hive jobs - - - - hive.exec.local.scratchdir - ${build.dir}/localscratchdir/ - Local scratch space for Hive jobs - - - - javax.jdo.option.ConnectionURL - - jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true - - - - javax.jdo.option.ConnectionDriverName - org.apache.derby.jdbc.EmbeddedDriver - - - - javax.jdo.option.ConnectionUserName - APP - - - - javax.jdo.option.ConnectionPassword - mine - - - - - hive.metastore.warehouse.dir - ${test.warehouse.dir} - - - - - hive.metastore.metadb.dir - ${build.dir}/test/data/metadb/ - - Required by metastore server or if the uris argument below is not supplied - - - - - test.log.dir - ${build.dir}/test/logs - - - - - test.src.dir - ${build.dir}/src/test - - - - - - - hive.jar.path - ${build.dir.hive}/ql/hive-exec-${version}.jar - - - - - hive.metastore.rawstore.impl - org.apache.hadoop.hive.metastore.ObjectStore - Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database - - - - hive.querylog.location - ${build.dir}/tmp - Location of the structured hive logs - - - - - - hive.task.progress - false - Track progress of a task - - - - hive.support.concurrency - false - Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks. 
- - - - fs.pfile.impl - org.apache.hadoop.fs.ProxyLocalFileSystem - A proxy for local file system used for cross file system testing - - - - hive.exec.mode.local.auto - false - - Let hive determine whether to run in local mode automatically - Disabling this for tests so that minimr is not affected - - - - - hive.auto.convert.join - false - Whether Hive enable the optimization about converting common join into mapjoin based on the input file size - - - - hive.ignore.mapjoin.hint - false - Whether Hive ignores the mapjoin hint - - - - hive.input.format - org.apache.hadoop.hive.ql.io.CombineHiveInputFormat - The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat. - - - - hive.default.rcfile.serde - org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe - The default SerDe hive will use for the rcfile format - - - diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh new file mode 100755 index 0000000000000..7473c20d28e09 --- /dev/null +++ b/dev/change-version-to-2.10.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +find . -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh new file mode 100755 index 0000000000000..3957a9f3ba258 --- /dev/null +++ b/dev/change-version-to-2.11.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +find . 
-name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 281e8d4de6d71..a6e90a15ee84b 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -27,6 +27,7 @@ # Would be nice to add: # - Send output to stderr and have useful logging in stdout +# Note: The following variables must be set before use! GIT_USERNAME=${GIT_USERNAME:-pwendell} GIT_PASSWORD=${GIT_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} @@ -101,7 +102,7 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh --name $NAME --tgz $FLAGS + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . rm -rf spark-$RELEASE_VERSION-bin-$NAME @@ -117,13 +118,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & +make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & wait # Copy data diff --git a/dev/run-tests b/dev/run-tests index de607e4344453..328a73bd8b26d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -139,9 +139,6 @@ echo "=========================================================================" CURRENT_BLOCK=$BLOCK_BUILD { - # We always build with Hive because the PySpark Spark SQL tests need it. - BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0" - # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc @@ -151,15 +148,17 @@ CURRENT_BLOCK=$BLOCK_BUILD # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? 
  # First build with 0.12 to ensure patches do not break the hive 12 build
+  HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0"
  echo "[info] Compile with hive 0.12"
  echo -e "q\n" \
-    | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \
+    | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \
    | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"

  # Then build with default version (0.13.1) because tests are based on this version
-  echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive"
+  echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\
+    " -Phive -Phive-thriftserver"
  echo -e "q\n" \
-    | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \
+    | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \
    | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -174,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
  # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
  # This must be a single argument, as it is.
  if [ -n "$_RUN_SQL_TESTS" ]; then
-    SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+    SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver"
  fi

  if [ -n "$_SQL_TESTS_ONLY" ]; then
diff --git a/dev/scalastyle b/dev/scalastyle
index ed1b6b730af6e..c3c6012e74ffa 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#

-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
  >> scalastyle.txt
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 238ddae15545e..bb18414092aae 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -101,25 +101,34 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski

# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
-add the `-Phive` profile to your existing build options. By default Spark
-will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using
-the `-Phive-0.12.0` profile.
+add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options.
+By default Spark will build with Hive 0.13.1 bindings. You can also build for
+Hive 0.12.0 using the `-Phive-0.12.0` profile.

{% highlight bash %}
# Apache Hadoop 2.4.X with Hive 13 support
-mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package

# Apache Hadoop 2.4.X with Hive 12 support
-mvn -Pyarn -Phive-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -Phive-0.12.0 -DskipTests clean package
{% endhighlight %}

+# Building for Scala 2.11
+To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property:
+
+    mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package
+
+Scala 2.11 support in Spark is experimental and does not support a few features.
+Specifically, Spark's external Kafka library and JDBC component are not yet
+supported in Scala 2.11 builds.
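+
+As a rough sketch, a full cross-build pairs the pom rewrite script added in this patch
+with the property above (the profile set here is only an example):
+
+    # rewrite _2.10 artifactIds in every pom.xml, then build with the 2.11 property
+    dev/change-version-to-2.11.sh
+    mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package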
+ # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package - mvn -Pyarn -Phadoop-2.3 -Phive test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package + mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -182,16 +191,16 @@ can be set to control the SBT build. For example: Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time. The following is an example of a correct (build, test) sequence: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive assembly - sbt/sbt -Pyarn -Phadoop-2.3 -Phive test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive "test-only org.apache.spark.repl.ReplSuite" + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" To run test suites of a specific sub project as follows: - sbt/sbt -Pyarn -Phadoop-2.3 -Phive core/test + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test # Speeding up Compilation with Zinc diff --git a/docs/configuration.md b/docs/configuration.md index f0b396e21f198..8839162c3a13e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -52,7 +52,7 @@ Then, you can supply configuration values at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} -The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) +The Spark shell and [`spark-submit`](submitting-applications.html) tool support two ways to load configurations dynamically. The first are command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 18420afb27e3c..49f319ba775e5 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. +Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. + {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) new SparkContext(conf) @@ -1131,7 +1133,7 @@ method. 
The code below shows this: {% highlight scala %} scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) -broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) +broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) scala> broadcastVar.value res0: Array[Int] = Array(1, 2, 3) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 2f7e4981e5bb9..dfe2db4b3fce8 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -39,7 +39,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.preserve.staging.files false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather then delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. @@ -159,7 +159,7 @@ For example: lib/spark-examples*.jar \ 10 -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: @@ -181,7 +181,7 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ffcce2c588879..5500da83b2b66 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -14,7 +14,7 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, [SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). 
SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with +[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). @@ -728,7 +728,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven). +Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. diff --git a/examples/pom.xml b/examples/pom.xml index 910eb55308b9d..85e133779e465 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,48 +34,6 @@ Spark Project Examples http://spark.apache.org/ - - - kinesis-asl - - - org.apache.spark - spark-streaming-kinesis-asl_${scala.binary.version} - ${project.version} - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - - - hbase-hadoop2 - - - hbase.profile - hadoop2 - - - - 0.98.7-hadoop2 - - - - hbase-hadoop1 - - - !hbase.profile - - - - 0.98.7-hadoop1 - - - - - @@ -124,11 +82,6 @@ spark-streaming-twitter_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming-kafka_${scala.binary.version} - ${project.version} - org.apache.spark spark-streaming-flume_${scala.binary.version} @@ -136,12 +89,12 @@ org.apache.spark - spark-streaming-zeromq_${scala.binary.version} + spark-streaming-mqtt_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming-mqtt_${scala.binary.version} + spark-streaming-zeromq_${scala.binary.version} ${project.version} @@ -260,15 +213,15 @@ test-jar test - - com.twitter - algebird-core_${scala.binary.version} - 0.1.11 - org.apache.commons commons-math3 + + com.twitter + algebird-core_${scala.binary.version} + 0.8.1 + org.scalatest scalatest_${scala.binary.version} @@ -401,4 +354,83 @@
    + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + + + hbase-hadoop2 + + + hbase.profile + hadoop2 + + + + 0.98.7-hadoop2 + + + + hbase-hadoop1 + + + !hbase.profile + + + + 0.98.7-hadoop1 + + + + + scala-2.10 + + !scala-2.11 + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.10/src/main/scala + scala-2.10/src/main/java + + + + + + + + + diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java similarity index 92% rename from examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java rename to examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index 430e96ab14d9d..e68ec74c3ed54 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -31,7 +31,7 @@ /** * Example of using Spark's status APIs from Java. */ -public final class JavaStatusAPIDemo { +public final class JavaStatusTrackerDemo { public static final String APP_NAME = "JavaStatusAPIDemo"; @@ -58,8 +58,8 @@ public static void main(String[] args) throws Exception { continue; } int currentJobId = jobIds.get(jobIds.size() - 1); - SparkJobInfo jobInfo = sc.getJobInfo(currentJobId); - SparkStageInfo stageInfo = sc.getStageInfo(jobInfo.stageIds()[0]); + SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + " active, " + stageInfo.numCompletedTasks() + " complete"); } diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java new file mode 100644 index 0000000000000..22ba68d8c354c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; +import org.apache.spark.SparkConf; + +/** + * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java + * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of + * this example {@link SimpleTextClassificationPipeline}. Run with + *
    + * bin/run-example ml.JavaSimpleTextClassificationPipeline
    + * 
    + */ +public class JavaSimpleTextClassificationPipeline { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training documents, which are labeled. + List localTraining = Lists.newArrayList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0)); + JavaSchemaRDD training = + jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + List localTest = Lists.newArrayList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop")); + JavaSchemaRDD test = + jsql.applySchema(jsc.parallelize(localTest), Document.class); + + // Make predictions on test documents. + model.transform(test).registerAsTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println(r); + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala new file mode 100644 index 0000000000000..ee7897d9062d9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml + +import scala.beans.BeanInfo + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.sql.SQLContext + +@BeanInfo +case class LabeledDocument(id: Long, text: String, label: Double) + +@BeanInfo +case class Document(id: Long, text: String) + +/** + * A simple text classification pipeline that recognizes "spark" from input text. This is to show + * how to create and configure an ML pipeline. Run with + * {{{ + * bin/run-example ml.SimpleTextClassificationPipeline + * }}} + */ +object SimpleTextClassificationPipeline { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training documents, which are labeled. + val training = sparkContext.parallelize(Seq( + LabeledDocument(0L, "a b c d e spark", 1.0), + LabeledDocument(1L, "b d", 0.0), + LabeledDocument(2L, "spark f g h", 1.0), + LabeledDocument(3L, "hadoop mapreduce", 0.0))) + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Prepare test documents, which are unlabeled. + val test = sparkContext.parallelize(Seq( + Document(4L, "spark i j k"), + Document(5L, "l m n"), + Document(6L, "mapreduce spark"), + Document(7L, "apache hadoop"))) + + // Make predictions on test documents. 
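+    // (transform() runs the fitted tokenizer/hashingTF/LR stages over the test rows and
+    // appends their output columns, including "score" and "prediction", read back below)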
+ model.transform(test) + .select('id, 'text, 'score, 'prediction) + .collect() + .foreach(println) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1edd2432a0352..a113653810b93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] + regParam: Double = 0.01) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index e1f9622350135..6815b1c052208 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -47,7 +47,7 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) extends AbstractParams[Params] + regParam: Double = 0.01) extends AbstractParams[Params] val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index a4d159bf38377..514252b89e74e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -18,12 +18,13 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf +import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every - * second. + * second starting with initial value of word count. * Usage: StatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. @@ -51,11 +52,18 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } + val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + } + val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") + // Initial RDD input to updateStateByKey + val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) + // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. 
generated by 'nc') val lines = ssc.socketTextStream(args(0), args(1).toInt) @@ -64,7 +72,8 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, + new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 28ac5929df44a..4d26b640e8d74 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -17,13 +17,12 @@ package org.apache.spark.streaming.kafka +import java.util.Properties + import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import java.util.Properties -import java.util.concurrent.Executors - -import kafka.consumer._ +import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties @@ -32,6 +31,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * Input stream that pulls messages from a Kafka Broker. @@ -51,12 +51,16 @@ class KafkaInputDStream[ @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], + useReliableReceiver: Boolean, storageLevel: StorageLevel ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { def getReceiver(): Receiver[(K, V)] = { - new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) - .asInstanceOf[Receiver[(K, V)]] + if (!useReliableReceiver) { + new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } else { + new ReliableKafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } } } @@ -69,14 +73,15 @@ class KafkaReceiver[ kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel - ) extends Receiver[Any](storageLevel) with Logging { + ) extends Receiver[(K, V)](storageLevel) with Logging { // Connection to Kafka - var consumerConnector : ConsumerConnector = null + var consumerConnector: ConsumerConnector = null def onStop() { if (consumerConnector != null) { consumerConnector.shutdown() + consumerConnector = null } } @@ -102,11 +107,11 @@ class KafkaReceiver[ .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] - // Create Threads for each Topic/Message Stream we are listening + // Create threads for each topic/message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Executors.newFixedThreadPool(topics.values.sum) + val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -117,13 +122,15 @@ class KafkaReceiver[ } } - // Handles Kafka Messages - private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V]) + // Handles Kafka messages + private class 
MessageHandler(stream: KafkaStream[K, V]) extends Runnable { def run() { logInfo("Starting MessageHandler.") try { - for (msgAndMetadata <- stream) { + val streamIterator = stream.iterator() + while (streamIterator.hasNext()) { + val msgAndMetadata = streamIterator.next() store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index ec812e1ef3b04..b4ac929e0c070 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -70,7 +70,8 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, storageLevel) + val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } /** @@ -99,7 +100,6 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. - * */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala new file mode 100644 index 0000000000000..be734b80272d1 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.util.Properties +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} + +import scala.collection.{Map, mutable} +import scala.reflect.{ClassTag, classTag} + +import kafka.common.TopicAndPartition +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder +import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.Utils + +/** + * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. 
+ * It is turned off by default and will be enabled when + * spark.streaming.receiver.writeAheadLog.enable is true. The difference compared to KafkaReceiver + * is that this receiver manages topic-partition/offset itself and updates the offset information + * after data is reliably stored as write-ahead log. Offsets will only be updated when data is + * reliably stored, so the potential data loss problem of KafkaReceiver can be eliminated. + * + * Note: ReliableKafkaReceiver will set auto.commit.enable to false to turn off automatic offset + * commit mechanism in Kafka consumer. So setting this configuration manually within kafkaParams + * will not take effect. + */ +private[streaming] +class ReliableKafkaReceiver[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel) + extends Receiver[(K, V)](storageLevel) with Logging { + + private val groupId = kafkaParams("group.id") + private val AUTO_OFFSET_COMMIT = "auto.commit.enable" + private def conf = SparkEnv.get.conf + + /** High level consumer to connect to Kafka. */ + private var consumerConnector: ConsumerConnector = null + + /** zkClient to connect to Zookeeper to commit the offsets. */ + private var zkClient: ZkClient = null + + /** + * A HashMap to manage the offset for each topic/partition, this HashMap is called in + * synchronized block, so mutable HashMap will not meet concurrency issue. + */ + private var topicPartitionOffsetMap: mutable.HashMap[TopicAndPartition, Long] = null + + /** A concurrent HashMap to store the stream block id and related offset snapshot. */ + private var blockOffsetMap: ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]] = null + + /** + * Manage the BlockGenerator in receiver itself for better managing block store and offset + * commit. + */ + private var blockGenerator: BlockGenerator = null + + /** Thread pool running the handlers for receiving message from multiple topics and partitions. */ + private var messageHandlerThreadPool: ThreadPoolExecutor = null + + override def onStart(): Unit = { + logInfo(s"Starting Kafka Consumer Stream with group: $groupId") + + // Initialize the topic-partition / offset hash map. + topicPartitionOffsetMap = new mutable.HashMap[TopicAndPartition, Long] + + // Initialize the stream block id / offset snapshot hash map. + blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() + + // Initialize the block generator for storing Kafka message. + blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + + if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { + logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + + "otherwise we will manually set it to false to turn off auto offset commit in Kafka") + } + + val props = new Properties() + kafkaParams.foreach(param => props.put(param._1, param._2)) + // Manually set "auto.commit.enable" to "false" no matter user explicitly set it to true, + // we have to make sure this property is set to false to turn off auto commit mechanism in + // Kafka. 
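+    // (left enabled, the high-level consumer would commit offsets to ZooKeeper on its own
+    // schedule, defeating the store-block-first-then-commit-offset ordering below)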
+ props.setProperty(AUTO_OFFSET_COMMIT, "false") + + val consumerConfig = new ConsumerConfig(props) + + assert(!consumerConfig.autoCommitEnable) + + logInfo(s"Connecting to Zookeeper: ${consumerConfig.zkConnect}") + consumerConnector = Consumer.create(consumerConfig) + logInfo(s"Connected to Zookeeper: ${consumerConfig.zkConnect}") + + zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, + consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) + + messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + topics.values.sum, "KafkaMessageHandler") + + blockGenerator.start() + + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[K]] + + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[V]] + + val topicMessageStreams = consumerConnector.createMessageStreams( + topics, keyDecoder, valueDecoder) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => + messageHandlerThreadPool.submit(new MessageHandler(stream)) + } + } + } + + override def onStop(): Unit = { + if (messageHandlerThreadPool != null) { + messageHandlerThreadPool.shutdown() + messageHandlerThreadPool = null + } + + if (consumerConnector != null) { + consumerConnector.shutdown() + consumerConnector = null + } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (blockGenerator != null) { + blockGenerator.stop() + blockGenerator = null + } + + if (topicPartitionOffsetMap != null) { + topicPartitionOffsetMap.clear() + topicPartitionOffsetMap = null + } + + if (blockOffsetMap != null) { + blockOffsetMap.clear() + blockOffsetMap = null + } + } + + /** Store a Kafka message and the associated metadata as a tuple. */ + private def storeMessageAndMetadata( + msgAndMetadata: MessageAndMetadata[K, V]): Unit = { + val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition) + val data = (msgAndMetadata.key, msgAndMetadata.message) + val metadata = (topicAndPartition, msgAndMetadata.offset) + blockGenerator.addDataWithCallback(data, metadata) + } + + /** Update stored offset */ + private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = { + topicPartitionOffsetMap.put(topicAndPartition, offset) + } + + /** + * Remember the current offsets for each topic and partition. This is called when a block is + * generated. + */ + private def rememberBlockOffsets(blockId: StreamBlockId): Unit = { + // Get a snapshot of current offset map and store with related block id. + val offsetSnapshot = topicPartitionOffsetMap.toMap + blockOffsetMap.put(blockId, offsetSnapshot) + topicPartitionOffsetMap.clear() + } + + /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + private def storeBlockAndCommitOffset( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } + + /** + * Commit the offset of Kafka's topic/partition, the commit mechanism follow Kafka 0.8.x's + * metadata schema in Zookeeper. 
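+   * (i.e. offsets are written under /consumers/[groupId]/offsets/[topic]/[partition],
+   * the path that ZKGroupTopicDirs.consumerOffsetDir plus the partition id resolves to)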
+ */ + private def commitOffset(offsetMap: Map[TopicAndPartition, Long]): Unit = { + if (zkClient == null) { + val thrown = new IllegalStateException("Zookeeper client is unexpectedly null") + stop("Zookeeper client is not initialized before commit offsets to ZK", thrown) + return + } + + for ((topicAndPart, offset) <- offsetMap) { + try { + val topicDirs = new ZKGroupTopicDirs(groupId, topicAndPart.topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/${topicAndPart.partition}" + + ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString) + } catch { + case e: Exception => + logWarning(s"Exception during commit offset $offset for topic" + + s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e) + } + + logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " + + s"partition ${topicAndPart.partition}") + } + } + + /** Class to handle received Kafka message. */ + private final class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { + override def run(): Unit = { + while (!isStopped) { + try { + val streamIterator = stream.iterator() + while (streamIterator.hasNext) { + storeMessageAndMetadata(streamIterator.next) + } + } catch { + case e: Exception => + logError("Error handling message", e) + } + } + } + } + + /** Class to handle blocks generated by the block generator. */ + private final class GeneratedBlockHandler extends BlockGeneratorListener { + + def onAddData(data: Any, metadata: Any): Unit = { + // Update the offset of the data that was added to the generator + if (metadata != null) { + val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)] + updateOffset(topicAndPartition, offset) + } + } + + def onGenerateBlock(blockId: StreamBlockId): Unit = { + // Remember the offsets of topics/partitions when a block has been generated + rememberBlockOffsets(blockId) + } + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + // Store block and commit the blocks offset + storeBlockAndCommitOffset(blockId, arrayBuffer) + } + + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index efb0099c7c850..6e1abf3f385ee 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -20,7 +20,10 @@ import java.io.Serializable; import java.util.HashMap; import java.util.List; +import java.util.Random; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.Duration; import scala.Predef; import scala.Tuple2; import scala.collection.JavaConverters; @@ -32,8 +35,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -42,25 +43,27 @@ import org.junit.After; import org.junit.Before; -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { - private transient KafkaStreamSuite testSuite = new 
KafkaStreamSuite(); +public class JavaKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; @Before - @Override public void setUp() { - testSuite.beforeFunction(); + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); System.clearProperty("spark.driver.port"); - //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, new Duration(500)); } @After - @Override public void tearDown() { ssc.stop(); ssc = null; System.clearProperty("spark.driver.port"); - testSuite.afterFunction(); + suiteBase.tearDownKafka(); } @Test @@ -74,15 +77,15 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - testSuite.createTopic(topic); + suiteBase.createTopic(topic); HashMap tmp = new HashMap(sent); - testSuite.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + suiteBase.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); - kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); JavaPairDStream stream = KafkaUtils.createStream(ssc, @@ -124,11 +127,16 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); - ssc.awaitTermination(3000); - + long startTime = System.currentTimeMillis(); + boolean sizeMatches = false; + while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { + sizeMatches = sent.size() == result.size(); + Thread.sleep(200); + } Assert.assertEquals(sent.size(), result.size()); for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } + ssc.stop(); } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 6943326eb750e..b19c053ebfc44 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -19,51 +19,57 @@ package org.apache.spark.streaming.kafka import java.io.File import java.net.InetSocketAddress -import java.util.{Properties, Random} +import java.util.Properties import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random import kafka.admin.CreateTopicCommand import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, ProducerConfig, Producer} -import kafka.utils.ZKStringSerializer +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} import kafka.server.{KafkaConfig, KafkaServer} - +import kafka.utils.ZKStringSerializer import 
org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually -import org.apache.zookeeper.server.ZooKeeperServer -import org.apache.zookeeper.server.NIOServerCnxnFactory - -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class KafkaStreamSuite extends TestSuiteBase { - import KafkaTestUtils._ - - val zkHost = "localhost" - var zkPort: Int = 0 - val zkConnectionTimeout = 6000 - val zkSessionTimeout = 6000 - - protected var brokerPort = 9092 - protected var brokerConf: KafkaConfig = _ - protected var zookeeper: EmbeddedZookeeper = _ - protected var zkClient: ZkClient = _ - protected var server: KafkaServer = _ - protected var producer: Producer[String, String] = _ - - override def useManualClock = false - - override def beforeFunction() { +/** + * This is an abstract base class for Kafka testsuites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + */ +abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { + + var zkAddress: String = _ + var zkClient: ZkClient = _ + + private val zkHost = "localhost" + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + private var zookeeper: EmbeddedZookeeper = _ + private var zkPort: Int = 0 + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + private var server: KafkaServer = _ + private var producer: Producer[String, String] = _ + + def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort + zkAddress = s"$zkHost:$zkPort" logInfo("==================== 0 ====================") - zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) logInfo("==================== 1 ====================") @@ -71,7 +77,7 @@ class KafkaStreamSuite extends TestSuiteBase { var bindSuccess: Boolean = false while(!bindSuccess) { try { - val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) logInfo("==================== 2 ====================") @@ -89,53 +95,30 @@ class KafkaStreamSuite extends TestSuiteBase { Thread.sleep(2000) logInfo("==================== 4 ====================") - super.beforeFunction() } - override def afterFunction() { - producer.close() - server.shutdown() - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - zkClient.close() - zookeeper.shutdown() - - super.afterFunction() - } - - test("Kafka input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val topic = "topic1" - val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + def tearDownKafka() { + if (producer != null) { + producer.close() + producer = null + } - val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", - "group.id" -> s"test-consumer-${random.nextInt(10000)}", - "auto.offset.reset" -> "smallest") + if (server != null) { + 
server.shutdown() + server = null + } - val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, - kafkaParams, - Map(topic -> 1), - StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() - stream.map { case (k, v) => v } - .countByValue() - .foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) - } - } - ssc.start() - ssc.awaitTermination(3000) + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - assert(sent.size === result.size) - sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + if (zkClient != null) { + zkClient.close() + zkClient = null + } - ssc.stop() + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } } private def createTestMessage(topic: String, sent: Map[String, Int]) @@ -150,58 +133,43 @@ class KafkaStreamSuite extends TestSuiteBase { CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") logInfo("==================== 5 ====================") // wait until metadata is propagated - waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + waitUntilMetadataIsPropagated(topic, 0) } def produceAndSendMessage(topic: String, sent: Map[String, Int]) { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) producer.send(createTestMessage(topic, sent): _*) + producer.close() logInfo("==================== 6 ====================") } -} - -object KafkaTestUtils { - val random = new Random() - def getBrokerConfig(port: Int, zkConnect: String): Properties = { + private def getBrokerConfig(): Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") - props.put("port", port.toString) + props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkConnect) + props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props } - def getProducerConfig(brokerList: String): Properties = { + private def getProducerConfig(): Properties = { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port val props = new Properties() - props.put("metadata.broker.list", brokerList) + props.put("metadata.broker.list", brokerAddr) props.put("serializer.class", classOf[StringEncoder].getName) props } - def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { - val startTime = System.currentTimeMillis() - while (true) { - if (condition()) - return true - if (System.currentTimeMillis() > startTime + waitTime) - return false - Thread.sleep(waitTime.min(100L)) + private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { + eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + assert( + server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) } - // Should never go to here - throw new RuntimeException("unexpected error") - } - - def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, - timeout: Long) { - assert(waitUntilTrue(() => - servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( - TopicAndPartition(topic, partition))), 
timeout), - s"Partition [$topic, $partition] metadata not propagated after timeout") } class EmbeddedZookeeper(val zkConnect: String) { @@ -227,3 +195,53 @@ object KafkaTestUtils { } } } + + +class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { + var ssc: StreamingContext = _ + + before { + setupKafka() + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + tearDownKafka() + } + + test("Kafka input stream") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map(_._2).countByValue().foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + assert(sent.size === result.size) + sent.keys.foreach { k => + assert(sent(k) === result(k).toInt) + } + } + ssc.stop() + } +} + diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala new file mode 100644 index 0000000000000..64ccc92c81fa9 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
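A note on the test pattern above: both the Java suite (hand-rolled polling loop) and the rewritten Scala suite replace the old fixed `ssc.awaitTermination(3000)` wait with retry-until-timeout checking, the Scala side via ScalaTest's `Eventually`. A minimal sketch of that pattern in isolation, using the same imports the suite uses (the background thread merely stands in for the streaming job filling in results):

```scala
import scala.concurrent.duration._
import scala.language.postfixOps

import org.scalatest.FunSuite
import org.scalatest.concurrent.Eventually

class PollingPatternSuite extends FunSuite with Eventually {
  test("retry an assertion until it passes or times out") {
    val result = new scala.collection.mutable.HashMap[String, Long]()
    new Thread(new Runnable {
      override def run() { Thread.sleep(500); result.synchronized { result("a") = 5L } }
    }).start()

    // Re-evaluates the block every 100 ms; fails only if still false after 10 s.
    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
      assert(result.synchronized { result.get("a") } === Some(5L))
    }
  }
}
```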
+ */ + +package org.apache.spark.streaming.kafka + + +import java.io.File + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.google.common.io.Files +import kafka.serializer.StringDecoder +import kafka.utils.{ZKGroupTopicDirs, ZkUtils} +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} + +class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { + + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + + + var groupId: String = _ + var kafkaParams: Map[String, String] = _ + var ssc: StreamingContext = _ + var tempDirectory: File = null + + before { + setupKafka() + groupId = s"test-consumer-${Random.nextInt(10000)}" + kafkaParams = Map( + "zookeeper.connect" -> zkAddress, + "group.id" -> groupId, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + tempDirectory = Files.createTempDir() + ssc.checkpoint(tempDirectory.getAbsolutePath) + } + + after { + if (ssc != null) { + ssc.stop() + } + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + tearDownKafka() + } + + + test("Reliable Kafka input stream with single topic") { + var topic = "test-topic" + createTopic(topic) + produceAndSendMessage(topic, data) + + // Verify whether the offset of this group/topic/partition is 0 before starting. + assert(getCommitOffset(groupId, topic, 0) === None) + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v }.foreachRDD { r => + val ret = r.collect() + ret.foreach { v => + val count = result.getOrElseUpdate(v, 0) + 1 + result.put(v, count) + } + } + ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { + // A basic process verification for ReliableKafkaReceiver. + // Verify whether received message number is equal to the sent message number. + assert(data.size === result.size) + // Verify whether each message is the same as the data to be verified. + data.keys.foreach { k => assert(data(k) === result(k).toInt) } + // Verify the offset number whether it is equal to the total message number. + assert(getCommitOffset(groupId, topic, 0) === Some(29L)) + } + ssc.stop() + } + + test("Reliable Kafka input stream with multiple topics") { + val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) + topics.foreach { case (t, _) => + createTopic(t) + produceAndSendMessage(t, data) + } + + // Before started, verify all the group/topic/partition offsets are 0. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === None) } + + // Consuming all the data sent to the broker which will potential commit the offsets internally. 
+ val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) + stream.foreachRDD(_ => Unit) + ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { + // Verify the offset for each group/topic to see whether they are equal to the expected one. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } + } + ssc.stop() + } + + + /** Getting partition offset from Zookeeper. */ + private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { + assert(zkClient != null, "Zookeeper client is not initialized") + val topicDirs = new ZKGroupTopicDirs(groupId, topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" + ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + } +} diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 371f1f1e9d39a..362a76e515938 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -52,11 +52,6 @@ mqtt-client 0.4.0
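For readers unfamiliar with the ZooKeeper layout that `getCommitOffset` relies on: Kafka's high-level consumer commits offsets under `/consumers/<group>/offsets/<topic>/<partition>`, and `ZKGroupTopicDirs` computes that prefix. A condensed restatement of the helper, assuming a connected `ZkClient` as the suite provides:

```scala
import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
import org.I0Itec.zkclient.ZkClient

/** Reads the committed offset for one group/topic/partition, if any was written. */
def committedOffset(zk: ZkClient, group: String, topic: String, partition: Int): Option[Long] = {
  val dirs = new ZKGroupTopicDirs(group, topic)
  // readDataMaybeNull returns (Option[String], Stat); an absent node yields None.
  ZkUtils.readDataMaybeNull(zk, s"${dirs.consumerOffsetDir}/$partition")._1.map(_.toLong)
}
```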
    - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} - org.scalatest scalatest_${scala.binary.version} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala new file mode 100644 index 0000000000000..f70715fca6eea --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +/** + * Represents an edge along with its neighboring vertices and allows sending messages along the + * edge. Used in [[Graph#aggregateMessages]]. + */ +abstract class EdgeContext[VD, ED, A] { + /** The vertex id of the edge's source vertex. */ + def srcId: VertexId + /** The vertex id of the edge's destination vertex. */ + def dstId: VertexId + /** The vertex attribute of the edge's source vertex. */ + def srcAttr: VD + /** The vertex attribute of the edge's destination vertex. */ + def dstAttr: VD + /** The attribute associated with the edge. */ + def attr: ED + + /** Sends a message to the source vertex. */ + def sendToSrc(msg: A): Unit + /** Sends a message to the destination vertex. */ + def sendToDst(msg: A): Unit + + /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */ + def toEdgeTriplet: EdgeTriplet[VD, ED] = { + val et = new EdgeTriplet[VD, ED] + et.srcId = srcId + et.srcAttr = srcAttr + et.dstId = dstId + et.dstAttr = dstAttr + et.attr = attr + et + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 5267560b3e5ce..cc70b396a8dd4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -17,14 +17,19 @@ package org.apache.spark.graphx -import scala.reflect.{classTag, ClassTag} +import scala.language.existentials +import scala.reflect.ClassTag -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.Dependency +import org.apache.spark.Partition +import org.apache.spark.SparkContext +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder +import org.apache.spark.graphx.impl.EdgeRDDImpl /** * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each @@ -32,33 +37,16 @@ import org.apache.spark.graphx.impl.EdgePartitionBuilder * edge to provide the triplet view. Shipping of the vertex attributes is managed by * `impl.ReplicatedVertexView`. 
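`EdgeContext` above is the handle handed to the `sendMsg` side of the new `aggregateMessages` operator introduced later in this patch. The in-degree computation from its scaladoc, expanded into a self-contained sketch (the graph contents are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph, VertexRDD}

object InDegreeExample {
  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("InDegree"))
    val edges = sc.parallelize(Seq(Edge(1L, 2L, "e1"), Edge(1L, 3L, "e2"), Edge(2L, 3L, "e3")))
    val graph = Graph.fromEdges(edges, defaultValue = 0)

    val inDeg: VertexRDD[Int] = graph.aggregateMessages[Int](
      ctx => ctx.sendToDst(1), // sendMsg: one message per edge, to the destination
      _ + _)                   // mergeMsg: sum messages arriving at the same vertex
    inDeg.collect().foreach { case (vid, d) => println(s"vertex $vid: in-degree $d") }
    sc.stop()
  }
}
```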
*/ -class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( - val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class EdgeRDD[ED]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("EdgeRDD") + private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - /** - * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the - * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new - * partitioner that allows co-partitioning with `partitionsRDD`. - */ - override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) - override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context) + val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { p.next._2.iterator.map(_.copy()) } else { @@ -66,45 +54,6 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( } } - override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() - - /** - * Persists the edge partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - /** The number of edges in the RDD. */ - override def count(): Long = { - partitionsRDD.map(_._2.size.toLong).reduce(_ + _) - } - - private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( - f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { - this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => - if (iter.hasNext) { - val (pid, ep) = iter.next() - Iterator(Tuple2(pid, f(pid, ep))) - } else { - Iterator.empty - } - }, preservesPartitioning = true)) - } - /** * Map the values in an edge partitioning preserving the structure but changing the values. * @@ -112,22 +61,14 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * @param f the function from an edge to a new edge value * @return a new EdgeRDD containing the new edge values */ - def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] = - mapEdgePartitions((pid, part) => part.map(f)) + def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] /** * Reverse all the edges in this RDD. * * @return a new EdgeRDD containing all the edges reversed */ - def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) - - /** Removes all edges but those matching `epred` and where both vertices match `vpred`. 
*/ - def filter( - epred: EdgeTriplet[VD, ED] => Boolean, - vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = { - mapEdgePartitions((pid, part) => part.filter(epred, vpred)) - } + def reverse: EdgeRDD[ED] /** * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same @@ -139,23 +80,8 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * with values supplied by `f` */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgeRDD[ED2, _]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = { - val ed2Tag = classTag[ED2] - val ed3Tag = classTag[ED3] - this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { - (thisIter, otherIter) => - val (pid, thisEPart) = thisIter.next() - val (_, otherEPart) = otherIter.next() - Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) - }) - } - - /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */ - private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( - partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] = { - new EdgeRDD(partitionsRDD, this.targetStorageLevel) - } + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] /** * Changes the target storage level while preserving all other properties of the @@ -164,11 +90,7 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * This does not actually trigger a cache; to do this, call * [[org.apache.spark.graphx.EdgeRDD#cache]] on the returned EdgeRDD. */ - private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] = { - new EdgeRDD(this.partitionsRDD, targetStorageLevel) - } - + private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED] } object EdgeRDD { @@ -178,7 +100,7 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = { + def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = { val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD] iter.foreach { e => @@ -195,8 +117,8 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( - edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDD[ED, VD] = { - new EdgeRDD(edgePartitions) + private[graphx] def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( + edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(edgePartitions) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index fa4b891754c40..637791543514c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. 
* */ - @transient val edges: EdgeRDD[ED, VD] + @transient val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -208,7 +208,37 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * */ def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - mapTriplets((pid, iter) => iter.map(map)) + mapTriplets((pid, iter) => iter.map(map), TripletFields.All) + } + + /** + * Transforms each edge attribute using the map function, passing it the adjacent vertex + * attributes as well. If adjacent vertex values are not required, + * consider using `mapEdges` instead. + * + * @note This does not change the structure of the + * graph or modify the values of this graph. As a consequence + * the underlying index structures can be reused. + * + * @param map the function from an edge object to a new edge value. + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes based on the attributes associated with each vertex. + * {{{ + * val rawGraph: Graph[Int, Int] = someLoadFunction() + * val graph = rawGraph.mapTriplets[Int]( edge => + * edge.src.data - edge.dst.data) + * }}} + * + */ + def mapTriplets[ED2: ClassTag]( + map: EdgeTriplet[VD, ED] => ED2, + tripletFields: TripletFields): Graph[VD, ED2] = { + mapTriplets((pid, iter) => iter.map(map), tripletFields) } /** @@ -223,12 +253,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * the underlying index structures can be reused. * * @param map the iterator transform + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. * * @tparam ED2 the new edge data type * */ - def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]) - : Graph[VD, ED2] + def mapTriplets[ED2: ClassTag]( + map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] /** * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned @@ -287,6 +320,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of * the map phase destined to each vertex. * + * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. + * * @tparam A the type of "message" to be sent to each vertex * * @param mapFunc the user defined map function which returns 0 or @@ -296,13 +331,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * be commutative and associative and is used to combine the output * of the map phase * - * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to - * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on - * edges with destination in the active set. If the direction is `Out`, - * `mapFunc` will only be run on edges originating from vertices in the active set. If the - * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set - * . 
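The new `tripletFields` parameter on `mapTriplets` is purely an optimization hint: it declares which triplet fields the map function actually reads, so the unused vertex attributes need not be shipped to edge partitions. A small sketch, assuming a `Graph[Int, Int]` as in the scaladoc:

```scala
import org.apache.spark.graphx.{Graph, TripletFields}

// The map reads srcAttr and attr but never dstAttr, so TripletFields.Src
// (an alias for SrcAndEdge, defined later in this patch) is sufficient.
def scaleBySourceAttr(rawGraph: Graph[Int, Int]): Graph[Int, Int] =
  rawGraph.mapTriplets[Int](et => et.srcAttr * et.attr, TripletFields.Src)
```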
If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the - * active set. The active set must have the same index as the graph's vertices. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run only on edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. * * @example We can use this function to compute the in-degree of each * vertex @@ -319,6 +356,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * predicate or implement PageRank. * */ + @deprecated("use aggregateMessages", "1.2.0") def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, @@ -326,8 +364,80 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab : VertexRDD[A] /** - * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The - * input table should contain at most one entry for each vertex. If no entry in `other` is + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * + * @example We can use this function to compute the in-degree of each + * vertex + * {{{ + * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") + * val inDeg: RDD[(VertexId, Int)] = + * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * }}} + * + * @note By expressing computation at the edge level we achieve + * maximum parallelism. This is one of the core functions in the + * Graph API in that enables neighborhood level computation. For + * example this function can be used to count neighbors satisfying a + * predicate or implement PageRank. + * + */ + def aggregateMessages[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[A] = { + aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None) + } + + /** + * Aggregates values from the neighboring edges and vertices of each vertex. 
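Migrating a deprecated `mapReduceTriplets` call is mechanical: messages returned as an `Iterator[(VertexId, A)]` become `sendToSrc`/`sendToDst` calls on the `EdgeContext`, and a `TripletFields` hint can be added where the message function ignores some fields. For example, the degree count used later in `GraphOps`:

```scala
import org.apache.spark.graphx.{Graph, TripletFields, VertexRDD}

def degrees[VD, ED](graph: Graph[VD, ED]): VertexRDD[Int] = {
  // Deprecated form:
  //   graph.mapReduceTriplets[Int](
  //     et => Iterator((et.srcId, 1), (et.dstId, 1)), _ + _)
  // aggregateMessages equivalent; the message function reads no vertex or
  // edge attributes, so TripletFields.None skips shipping them entirely.
  graph.aggregateMessages[Int](
    ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) },
    _ + _,
    TripletFields.None)
}
```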
The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * This variant can take an active set to restrict the computation and is intended for internal + * use only. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run on only edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. + */ + private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]) + : VertexRDD[A] + + /** + * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. + * The input table should contain at most one entry for each vertex. If no entry in `other` is * provided for a particular vertex in the graph, the map function receives `None`. 
* * @tparam U the type of entry in the table of updates diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d0dd45dba618e..d5150382d599b 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali */ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = { if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _) + graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _) + graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None) } else { // EdgeDirection.Either - graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _) + graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _, + TripletFields.None) } } @@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = { val nbrs = if (edgeDirection == EdgeDirection.Either) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _ - ) + graph.aggregateMessages[Array[VertexId]]( + ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }, + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToSrc(Array(ctx.dstId)), + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToDst(Array(ctx.srcId)), + _ ++ _, TripletFields.None) } else { throw new SparkException("It doesn't make sense to collect neighbor ids without a " + "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)") @@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @return the vertex set of neighboring vertex attributes for each vertex */ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { - val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]]( - edge => { - val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr))) - val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr))) - edgeDirection match { - case EdgeDirection.Either => Iterator(msgToSrc, msgToDst) - case EdgeDirection.In => Iterator(msgToDst) - case EdgeDirection.Out => Iterator(msgToSrc) - case EdgeDirection.Both => - throw new SparkException("collectNeighbors does not support EdgeDirection.Both. 
Use" + - "EdgeDirection.Either instead.") - } - }, - (a, b) => a ++ b) - - graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) => + val nbrs = edgeDirection match { + case EdgeDirection.Either => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => { + ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) + ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) + }, + (a, b) => a ++ b, TripletFields.SrcDstOnly) + case EdgeDirection.In => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), + (a, b) => a ++ b, TripletFields.SrcOnly) + case EdgeDirection.Out => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), + (a, b) => a ++ b, TripletFields.DstOnly) + case EdgeDirection.Both => + throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + + "EdgeDirection.Either instead.") + } + graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) => nbrsOpt.getOrElse(Array.empty[(VertexId, VD)]) } } // end of collectNeighbor @@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = { edgeDirection match { case EdgeDirection.Either => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))), - (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => { + ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + }, + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.In => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Out => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Both => throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + "EdgeDirection.Either instead.") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java new file mode 100644 index 0000000000000..34df4b7ee7a06 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.Serializable; + +/** + * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the + * system to populate only those fields for efficiency. + */ +public class TripletFields implements Serializable { + public final boolean useSrc; + public final boolean useDst; + public final boolean useEdge; + + public TripletFields() { + this(true, true, true); + } + + public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { + this.useSrc = useSrc; + this.useDst = useDst; + this.useEdge = useEdge; + } + + public static final TripletFields None = new TripletFields(false, false, false); + public static final TripletFields EdgeOnly = new TripletFields(false, false, true); + public static final TripletFields SrcOnly = new TripletFields(true, false, false); + public static final TripletFields DstOnly = new TripletFields(false, true, false); + public static final TripletFields SrcDstOnly = new TripletFields(true, true, false); + public static final TripletFields SrcAndEdge = new TripletFields(true, false, true); + public static final TripletFields Src = SrcAndEdge; + public static final TripletFields DstAndEdge = new TripletFields(false, true, true); + public static final TripletFields Dst = DstAndEdge; + public static final TripletFields All = new TripletFields(true, true, true); +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 12216d9d33d66..1db3df03c8052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock +import org.apache.spark.graphx.impl.VertexRDDImpl /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -53,62 +54,16 @@ import org.apache.spark.graphx.impl.VertexAttributeBlock * * @tparam VD the vertex attribute associated with each vertex in the set. */ -class VertexRDD[@specialized VD: ClassTag]( - val partitionsRDD: RDD[ShippableVertexPartition[VD]], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[(VertexId, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class VertexRDD[VD]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { - require(partitionsRDD.partitioner.isDefined) + implicit protected def vdTag: ClassTag[VD] - /** - * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting - * VertexRDD will be based on a different index and can no longer be quickly joined with this - * RDD. 
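Choosing among the `TripletFields` constants above never changes results, only how much vertex and edge data gets shipped; the rule of thumb is the narrowest constant covering everything `sendMsg` reads. The `collectNeighbors` rewrite in this patch uses `SrcDstOnly` for exactly that reason; a standalone restatement:

```scala
import scala.reflect.ClassTag

import org.apache.spark.graphx.{Graph, TripletFields, VertexId, VertexRDD}

// sendMsg reads srcAttr and dstAttr but never the edge attribute, so
// TripletFields.SrcDstOnly lets the edge data stay where it is.
def neighborAttrs[VD: ClassTag, ED](graph: Graph[VD, ED]): VertexRDD[Array[(VertexId, VD)]] =
  graph.aggregateMessages[Array[(VertexId, VD)]](
    ctx => {
      ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
      ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
    },
    _ ++ _,
    TripletFields.SrcDstOnly)
```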
- */ - def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) - - override val partitioner = partitionsRDD.partitioner + private[graphx] def partitionsRDD: RDD[ShippableVertexPartition[VD]] override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - override protected def getPreferredLocations(s: Partition): Seq[String] = - partitionsRDD.preferredLocations(s) - - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("VertexRDD") - - /** - * Persists the vertex partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - /** The number of vertices in the RDD. */ - override def count(): Long = { - partitionsRDD.map(_.size.toLong).reduce(_ + _) - } - /** * Provides the `RDD[(VertexId, VD)]` equivalent output. */ @@ -116,22 +71,28 @@ class VertexRDD[@specialized VD: ClassTag]( firstParent[ShippableVertexPartition[VD]].iterator(part, context).next.iterator } + /** + * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting + * VertexRDD will be based on a different index and can no longer be quickly joined with this + * RDD. + */ + def reindex(): VertexRDD[VD] + /** * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD. */ private[graphx] def mapVertexPartitions[VD2: ClassTag]( f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) - : VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) - this.withPartitionsRDD(newPartitionsRDD) - } - + : VertexRDD[VD2] /** * Restricts the vertex set to the set of vertices satisfying the given predicate. This operation * preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask * rather than allocating new memory. * + * It is declared and defined here to allow refining the return type from `RDD[(VertexId, VD)]` to + * `VertexRDD[VD]`. + * * @param pred the user defined predicate, which takes a tuple to conform to the * `RDD[(VertexId, VD)]` interface */ @@ -147,8 +108,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD */ - def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] /** * Maps each vertex attribute, additionally supplying the vertex ID. @@ -159,23 +119,13 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD. The resulting VertexRDD retains the same index. 
*/ - def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map(f)) + def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] /** * Hides vertices that are the same between `this` and `other`; for vertices that are different, * keeps the values from `other`. */ - def diff(other: VertexRDD[VD]): VertexRDD[VD] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.diff(otherPart)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + def diff(other: VertexRDD[VD]): VertexRDD[VD] /** * Left joins this RDD with another VertexRDD with the same index. This function will fail if @@ -192,16 +142,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a VertexRDD containing the results of `f` */ def leftZipJoin[VD2: ClassTag, VD3: ClassTag] - (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.leftJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] /** * Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -222,37 +163,14 @@ class VertexRDD[@specialized VD: ClassTag]( def leftJoin[VD2: ClassTag, VD3: ClassTag] (other: RDD[(VertexId, VD2)]) (f: (VertexId, VD, Option[VD2]) => VD3) - : VertexRDD[VD3] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. - // If the other set is a VertexRDD then we use the much more efficient leftZipJoin - other match { - case other: VertexRDD[_] => - leftZipJoin(other)(f) - case _ => - this.withPartitionsRDD[VD3]( - partitionsRDD.zipPartitions( - other.partitionBy(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) - } - ) - } - } + : VertexRDD[VD3] /** * Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See * [[innerJoin]] for the behavior of the join. */ def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.innerJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -266,21 +184,7 @@ class VertexRDD[@specialized VD: ClassTag]( * `this` and `other`, with values supplied by `f` */ def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
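A usage sketch for the join family above: `leftJoin` accepts any `RDD[(VertexId, VD2)]` and, in the implementation this patch moves into `VertexRDDImpl`, dispatches to the index-aligned `leftZipJoin` fast path when the argument is itself a `VertexRDD`. Annotating each vertex with its out-degree, defaulting to 0:

```scala
import scala.reflect.ClassTag

import org.apache.spark.graphx.{Graph, VertexRDD}

def withOutDegrees[VD: ClassTag, ED](graph: Graph[VD, ED]): VertexRDD[(VD, Int)] =
  // graph.outDegrees is a VertexRDD sharing the same index, so this takes
  // the leftZipJoin fast path; vertices without edges surface as None.
  graph.vertices.leftJoin(graph.outDegrees) { (vid, attr, degOpt) =>
    (attr, degOpt.getOrElse(0))
  }
```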
- // If the other set is a VertexRDD then we use the much more efficient innerZipJoin - other match { - case other: VertexRDD[_] => - innerZipJoin(other)(f) - case _ => - this.withPartitionsRDD( - partitionsRDD.zipPartitions( - other.partitionBy(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) - } - ) - } - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a @@ -294,38 +198,20 @@ class VertexRDD[@specialized VD: ClassTag]( * messages. */ def aggregateUsingIndex[VD2: ClassTag]( - messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = messages.partitionBy(this.partitioner.get) - val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => - thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) - } - this.withPartitionsRDD[VD2](parts) - } + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] /** * Returns a new `VertexRDD` reflecting a reversal of all edge directions in the corresponding * [[EdgeRDD]]. */ - def reverseRoutingTables(): VertexRDD[VD] = - this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + def reverseRoutingTables(): VertexRDD[VD] /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */ - def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = { - val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) - val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { - (partIter, routingTableIter) => - val routingTable = - if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty - partIter.map(_.withRoutingTable(routingTable)) - } - this.withPartitionsRDD(vertexPartitions) - } + def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */ private[graphx] def withPartitionsRDD[VD2: ClassTag]( - partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { - new VertexRDD(partitionsRDD, this.targetStorageLevel) - } + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] /** * Changes the target storage level while preserving all other properties of the @@ -335,20 +221,14 @@ class VertexRDD[@specialized VD: ClassTag]( * [[org.apache.spark.graphx.VertexRDD#cache]] on the returned VertexRDD. */ private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): VertexRDD[VD] = { - new VertexRDD(this.partitionsRDD, targetStorageLevel) - } + targetStorageLevel: StorageLevel): VertexRDD[VD] /** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. */ private[graphx] def shipVertexAttributes( - shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) - } + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] /** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. 
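`aggregateUsingIndex` above is the primitive that turns a shuffled message stream back into a `VertexRDD` sharing this RDD's index; it is what `aggregateMessages` ultimately relies on to assemble its result. A hedged sketch of direct use, summing duplicate messages per vertex:

```scala
import org.apache.spark.graphx.{VertexId, VertexRDD}
import org.apache.spark.rdd.RDD

def sumPerVertex(verts: VertexRDD[_], msgs: RDD[(VertexId, Long)]): VertexRDD[Long] =
  // Duplicate ids are reduced with `_ + _`; the result reuses verts' index,
  // so it can be cheaply zip-joined with verts afterwards.
  verts.aggregateUsingIndex(msgs, _ + _)
```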
*/ - private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) - } + private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] } // end of VertexRDD @@ -374,7 +254,7 @@ object VertexRDD { val vertexPartitions = vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -389,7 +269,7 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD): VertexRDD[VD] = { VertexRDD(vertices, edges, defaultVal, (a, b) => a) } @@ -406,7 +286,7 @@ object VertexRDD { * @param mergeFunc the commutative, associative duplicate vertex attribute merge function */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD, mergeFunc: (VD, VD) => VD ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices @@ -419,7 +299,7 @@ object VertexRDD { if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc)) } - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -434,18 +314,18 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def fromEdges[VD: ClassTag]( - edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { + edges: EdgeRDD[_], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions)) val vertexPartitions = routingTables.mapPartitions({ routingTableIter => val routingTable = if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(Iterator.empty, routingTable, defaultVal)) }, preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } - private def createRoutingTables( - edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { + private[graphx] def createRoutingTables( + edges: EdgeRDD[_], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { // Determine which vertices each edge partition needs by creating a mapping from vid to pid. val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap( Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java new file mode 100644 index 0000000000000..377ae849f045c --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl; + +/** + * Criteria for filtering edges based on activeness. For internal use only. + */ +public enum EdgeActiveness { + /** Neither the source vertex nor the destination vertex need be active. */ + Neither, + /** The source vertex must be active. */ + SrcOnly, + /** The destination vertex must be active. */ + DstOnly, + /** Both vertices must be active. */ + Both, + /** At least one vertex must be active. */ + Either +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index a5c9cd1f8b4e6..373af75448374 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -21,63 +21,94 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** - * A collection of edges stored in columnar format, along with any vertex attributes referenced. The - * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by - * src. There is an optional active vertex set for filtering computation on the edges. + * A collection of edges, along with referenced vertex attributes and an optional active vertex set + * for filtering computation on the edges. + * + * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All + * referenced global vertex ids are mapped to a compact set of local vertex ids according to the + * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the + * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global + * vertex id. The global vertex ids that are active are optionally stored in `activeSet`. + * + * The edges are clustered by source vertex id, and the mapping from global vertex id to the index + * of the corresponding edge cluster is stored in `index`. * * @tparam ED the edge attribute type * @tparam VD the vertex attribute type * - * @param srcIds the source vertex id of each edge - * @param dstIds the destination vertex id of each edge + * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and + * `vertexAttrs` + * @param localDstIds the local destination vertex id of each edge as an index into `local2global` + * and `vertexAttrs` * @param data the attribute associated with each edge - * @param index a clustered index on source vertex id - * @param vertices a map from referenced vertex ids to their corresponding attributes. Must - * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for - * those vertex ids. The mask is not used. 
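The remapping described above is the heart of the new layout: every referenced global 64-bit vertex id is interned to a small local int, so per-edge storage shrinks and vertex-attribute access becomes a plain array lookup. A deliberately simplified, hypothetical illustration of the interning idea (not the actual `EdgePartition`/builder code, which uses primitive open hash maps):

```scala
import scala.collection.mutable

// Simplified stand-ins for global2local / local2global.
class LocalIdTable {
  private val global2local = mutable.HashMap.empty[Long, Int]
  private val local2global = mutable.ArrayBuffer.empty[Long]

  /** Local id for `vid`, allocated on first sight. */
  def localId(vid: Long): Int =
    global2local.getOrElseUpdate(vid, { local2global += vid; local2global.length - 1 })

  /** Reverse mapping, as local2global provides. */
  def globalId(lid: Int): Long = local2global(lid)
}

val t = new LocalIdTable
val localSrcIds = Array(101L, 101L, 205L).map(t.localId) // Array(0, 0, 1)
```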
+ * @param index a clustered index on source vertex id as a map from each global source vertex id to + * the offset in the edge arrays where the cluster for that vertex id begins + * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs + * @param local2global an array of global vertex ids where the offsets are local vertex ids + * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids * @param activeSet an optional active vertex set for filtering computation on the edges */ private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - val srcIds: Array[VertexId] = null, - val dstIds: Array[VertexId] = null, - val data: Array[ED] = null, - val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val vertices: VertexPartition[VD] = null, - val activeSet: Option[VertexSet] = None - ) extends Serializable { + localSrcIds: Array[Int], + localDstIds: Array[Int], + data: Array[ED], + index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet]) + extends Serializable { - /** Return a new `EdgePartition` with the specified edge data. */ - def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = { - new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet) - } + /** No-arg constructor for serialization. */ + private def this() = this(null, null, null, null, null, null, null, null) - /** Return a new `EdgePartition` with the specified vertex partition. */ - def withVertices[VD2: ClassTag]( - vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = { - new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet) + /** Return a new `EdgePartition` with the specified edge data. */ + def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = { + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */ def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { - val newActiveSet = new VertexSet - iter.foreach(newActiveSet.add(_)) - new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet)) - } - - /** Return a new `EdgePartition` with the specified active set. */ - def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = { - new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_) + val activeSet = new VertexSet + while (iter.hasNext) { activeSet.add(iter.next()) } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, + Some(activeSet)) } /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { - this.withVertices(vertices.innerJoinKeepLeft(iter)) + val newVertexAttrs = new Array[VD](vertexAttrs.length) + System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length) + while (iter.hasNext) { + val kv = iter.next() + newVertexAttrs(global2local(kv._1)) = kv._2 + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) } + /** Return a new `EdgePartition` without any locally cached vertex attributes. 
*/ + def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = { + val newVertexAttrs = new Array[VD2](vertexAttrs.length) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) + } + + @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) + + @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) + + @inline private def attrs(pos: Int): ED = data(pos) + /** Look up vid in activeSet, throwing an exception if it is None. */ def isActive(vid: VertexId): Boolean = { activeSet.get.contains(vid) @@ -92,11 +123,19 @@ class EdgePartition[ * @return a new edge partition with all edges reversed. */ def reverse: EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD]) - for (e <- iterator) { - builder.add(e.dstId, e.srcId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet, size) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val srcId = local2global(localSrcId) + val dstId = local2global(localDstId) + val attr = data(i) + builder.add(dstId, srcId, localDstId, localSrcId, attr) + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -157,13 +196,25 @@ class EdgePartition[ def filter( epred: EdgeTriplet[VD, ED] => Boolean, vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { - val filtered = tripletIterator().filter(et => - vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) - val builder = new EdgePartitionBuilder[ED, VD] - for (e <- filtered) { - builder.add(e.srcId, e.dstId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) + var i = 0 + while (i < size) { + // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge. 
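+      // (A reused triplet would be mutated on the next loop iteration, so any triplet
+      // the user-supplied predicates retained would silently change underneath them.)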
+ val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val et = new EdgeTriplet[VD, ED] + et.srcId = local2global(localSrcId) + et.dstId = local2global(localDstId) + et.srcAttr = vertexAttrs(localSrcId) + et.dstAttr = vertexAttrs(localDstId) + et.attr = data(i) + if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) { + builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr) + } + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -183,28 +234,40 @@ class EdgePartition[ * @return a new edge partition without duplicate edges */ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder[ED, VD] + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] + var currLocalSrcId = -1 + var currLocalDstId = -1 var currAttr: ED = null.asInstanceOf[ED] + // Iterate through the edges, accumulating runs of identical edges using the curr* variables and + // releasing them to the builder when we see the beginning of the next run var i = 0 while (i < size) { if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) { + // This edge should be accumulated into the existing run currAttr = merge(currAttr, data(i)) } else { + // This edge starts a new run of edges if (i > 0) { - builder.add(currSrcId, currDstId, currAttr) + // First release the existing run to the builder + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } + // Then start accumulating for a new run currSrcId = srcIds(i) currDstId = dstIds(i) + currLocalSrcId = localSrcIds(i) + currLocalDstId = localDstIds(i) currAttr = data(i) } i += 1 } + // Finally, release the last accumulated run if (size > 0) { - builder.add(currSrcId, currDstId, currAttr) + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -220,7 +283,8 @@ class EdgePartition[ def innerJoin[ED2: ClassTag, ED3: ClassTag] (other: EdgePartition[ED2, _]) (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { - val builder = new EdgePartitionBuilder[ED3, VD] + val builder = new ExistingEdgePartitionBuilder[ED3, VD]( + global2local, local2global, vertexAttrs, activeSet) var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -233,12 +297,13 @@ class EdgePartition[ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 } if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) { // ... run `f` on the matching edge - builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j))) + builder.add(srcId, dstId, localSrcIds(i), localDstIds(i), + f(srcId, dstId, this.data(i), other.attrs(j))) } } i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -246,7 +311,7 @@ class EdgePartition[ * * @return size of the partition */ - val size: Int = srcIds.size + val size: Int = localSrcIds.size /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size @@ -280,55 +345,198 @@ class EdgePartition[ * It is safe to keep references to the objects from this iterator. 
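+   *
+   * Hedged usage sketch, assuming a populated partition `part`:
+   * {{{
+   *   // Materialize only source attributes; dstAttr is left unset.
+   *   val pairs = part.tripletIterator(includeSrc = true, includeDst = false)
+   *     .map(et => (et.srcId, et.srcAttr)).toList
+   * }}}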
*/ def tripletIterator( - includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = { - new EdgeTripletIterator(this, includeSrc, includeDst) + includeSrc: Boolean = true, includeDst: Boolean = true) + : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] { + private[this] var pos = 0 + + override def hasNext: Boolean = pos < EdgePartition.this.size + + override def next() = { + val triplet = new EdgeTriplet[VD, ED] + val localSrcId = localSrcIds(pos) + val localDstId = localDstIds(pos) + triplet.srcId = local2global(localSrcId) + triplet.dstId = local2global(localDstId) + if (includeSrc) { + triplet.srcAttr = vertexAttrs(localSrcId) + } + if (includeDst) { + triplet.dstAttr = vertexAttrs(localDstId) + } + triplet.attr = data(pos) + pos += 1 + triplet + } } /** - * Upgrade the given edge iterator into a triplet iterator. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning + * all edges sequentially. * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness + * + * @return iterator aggregated messages keyed by the receiving vertex id */ - def upgradeIterator( - edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true) - : Iterator[EdgeTriplet[VD, ED]] = { - new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst) + def aggregateMessagesEdgeScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val srcId = local2global(localSrcId) + val localDstId = localDstIds(i) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(srcId) + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(srcId) && isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(srcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD] + val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i)) + sendMsg(ctx) + } + i += 1 + } + + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } /** - * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The - * iterator is generated using an index scan, so it is efficient at skipping edges that don't - * match srcIdPred. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by + * filtering the source vertex index, then scanning each edge cluster. 
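+   * This is preferable to an edge scan when only a small fraction of source vertices
+   * is active, since entire clusters of edges from inactive sources are skipped
+   * without being examined.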
* - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. - */ - def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] = - index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator)) - - /** - * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The - * cluster must start at position `index`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @return iterator aggregated messages keyed by the receiving vertex id */ - private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { - private[this] val edge = new Edge[ED] - private[this] var pos = index + def aggregateMessagesIndexScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + index.iterator.foreach { cluster => + val clusterSrcId = cluster._1 + val clusterPos = cluster._2 + val clusterLocalSrcId = localSrcIds(clusterPos) - override def hasNext: Boolean = { - pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId + val scanCluster = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.DstOnly) true + else if (activeness == EdgeActiveness.Both) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.Either) true + else throw new Exception("unreachable") + + if (scanCluster) { + var pos = clusterPos + val srcAttr = + if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD] + ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr) + while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { + val localDstId = localDstIds(pos) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) true + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(clusterSrcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val dstAttr = + if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.setRest(dstId, localDstId, dstAttr, data(pos)) + sendMsg(ctx) + } + pos += 1 + } + } } - override def next(): Edge[ED] = { - assert(srcIds(pos) == srcId) - edge.srcId = srcIds(pos) - edge.dstId = dstIds(pos) - edge.attr = data(pos) - pos += 1 - edge + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } + } +} + +private class AggregatingEdgeContext[VD, ED, A]( + mergeMsg: (A, A) => A, + aggregates: Array[A], + bitset: BitSet) + extends EdgeContext[VD, ED, A] { + + private[this] var _srcId: VertexId = _ + private[this] var _dstId: VertexId 
= _ + private[this] var _localSrcId: Int = _ + private[this] var _localDstId: Int = _ + private[this] var _srcAttr: VD = _ + private[this] var _dstAttr: VD = _ + private[this] var _attr: ED = _ + + def set( + srcId: VertexId, dstId: VertexId, + localSrcId: Int, localDstId: Int, + srcAttr: VD, dstAttr: VD, + attr: ED) { + _srcId = srcId + _dstId = dstId + _localSrcId = localSrcId + _localDstId = localDstId + _srcAttr = srcAttr + _dstAttr = dstAttr + _attr = attr + } + + def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) { + _srcId = srcId + _localSrcId = localSrcId + _srcAttr = srcAttr + } + + def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) { + _dstId = dstId + _localDstId = localDstId + _dstAttr = dstAttr + _attr = attr + } + + override def srcId = _srcId + override def dstId = _dstId + override def srcAttr = _srcAttr + override def dstAttr = _dstAttr + override def attr = _attr + + override def sendToSrc(msg: A) { + send(_localSrcId, msg) + } + override def sendToDst(msg: A) { + send(_localDstId, msg) + } + + @inline private def send(localId: Int, msg: A) { + if (bitset.get(localId)) { + aggregates(localId) = mergeMsg(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 2b6137be25547..b0cb0fe47d461 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -25,10 +25,11 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +/** Constructs an EdgePartition from scratch. */ private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( size: Int = 64) { - var edges = new PrimitiveVector[Edge[ED]](size) + private[this] val edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ def add(src: VertexId, dst: VertexId, d: ED) { @@ -38,8 +39,67 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla def toEdgePartition: EdgePartition[ED, VD] = { val edgeArray = edges.trim().array Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering) - val srcIds = new Array[VertexId](edgeArray.size) - val dstIds = new Array[VertexId](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) + val data = new Array[ED](edgeArray.size) + val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val local2global = new PrimitiveVector[VertexId] + var vertexAttrs = Array.empty[VD] + // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and + // adding them to the index. Also populate a map from vertex id to a sequential local offset. 
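+    // (The first time a global vertex id is seen, `changeValue` assigns it the next
+    // sequential local id and appends it to `local2global`; subsequent occurrences
+    // simply return the already-assigned local id.)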
+ if (edgeArray.length > 0) { + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId + var currLocalId = -1 + var i = 0 + while (i < edgeArray.size) { + val srcId = edgeArray(i).srcId + val dstId = edgeArray(i).dstId + localSrcIds(i) = global2local.changeValue(srcId, + { currLocalId += 1; local2global += srcId; currLocalId }, identity) + localDstIds(i) = global2local.changeValue(dstId, + { currLocalId += 1; local2global += dstId; currLocalId }, identity) + data(i) = edgeArray(i).attr + if (srcId != currSrcId) { + currSrcId = srcId + index.update(currSrcId, i) + } + + i += 1 + } + vertexAttrs = new Array[VD](currLocalId + 1) + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs, + None) + } +} + +/** + * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables + * reuse of the local vertex ids. Intended for internal use in EdgePartition only. + */ +private[impl] +class ExistingEdgePartitionBuilder[ + @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet], + size: Int = 64) { + private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) + + /** Add a new edge to the partition. */ + def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { + edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d) + } + + def toEdgePartition: EdgePartition[ED, VD] = { + val edgeArray = edges.trim().array + Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) val data = new Array[ED](edgeArray.size) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and @@ -49,8 +109,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { - srcIds(i) = edgeArray(i).srcId - dstIds(i) = edgeArray(i).dstId + localSrcIds(i) = edgeArray(i).localSrcId + localDstIds(i) = edgeArray(i).localDstId data(i) = edgeArray(i).attr if (edgeArray(i).srcId != currSrcId) { currSrcId = edgeArray(i).srcId @@ -60,13 +120,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla } } - // Create and populate a VertexPartition with vids from the edges, but no attributes - val vidsIter = srcIds.iterator ++ dstIds.iterator - val vertexIds = new OpenHashSet[VertexId] - vidsIter.foreach(vid => vertexIds.add(vid)) - val vertices = new VertexPartition( - vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) + } +} - new EdgePartition(srcIds, dstIds, data, index, vertices) +private[impl] case class EdgeWithLocalIds[@specialized ED]( + srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) + +private[impl] object EdgeWithLocalIds { + implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { + override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { + if (a.srcId == b.srcId) { + if (a.dstId == b.dstId) 0 + else if (a.dstId < b.dstId) -1 + else 1 + } else if (a.srcId < b.srcId) -1 + else 1 + } 
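+    // Sorting by (srcId, dstId) keeps the edge array clustered by source vertex,
+    // which toEdgePartition relies on when building the clustered index.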
} + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala new file mode 100644 index 0000000000000..a8169613b4fd2 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( + override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("EdgeRDD") + + /** + * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the + * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new + * partitioner that allows co-partitioning with `partitionsRDD`. + */ + override val partitioner = + partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + + override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() + + /** + * Persists the edge partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of edges in the RDD. 
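+   * Runs a Spark job that sums the sizes of the individual edge partitions.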
*/ + override def count(): Long = { + partitionsRDD.map(_._2.size.toLong).reduce(_ + _) + } + + override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] = + mapEdgePartitions((pid, part) => part.map(f)) + + override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) + + def filter( + epred: EdgeTriplet[VD, ED] => Boolean, + vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = { + mapEdgePartitions((pid, part) => part.filter(epred, vpred)) + } + + override def innerJoin[ED2: ClassTag, ED3: ClassTag] + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = { + val ed2Tag = classTag[ED2] + val ed3Tag = classTag[ED3] + this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { + (thisIter, otherIter) => + val (pid, thisEPart) = thisIter.next() + val (_, otherEPart) = otherIter.next() + Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) + }) + } + + def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( + f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = { + this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => + if (iter.hasNext) { + val (pid, ep) = iter.next() + Iterator(Tuple2(pid, f(pid, ep))) + } else { + Iterator.empty + } + }, preservesPartitioning = true)) + } + + private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( + partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = { + new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala deleted file mode 100644 index 56f79a7097fce..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -/** - * The Iterator type returned when constructing edge triplets. This could be an anonymous class in - * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile. 
- */ -private[impl] -class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - // Current position in the array. - private var pos = 0 - - override def hasNext: Boolean = pos < edgePartition.size - - override def next() = { - val triplet = new EdgeTriplet[VD, ED] - triplet.srcId = edgePartition.srcIds(pos) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - triplet.dstId = edgePartition.dstIds(pos) - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet.attr = edgePartition.data(pos) - pos += 1 - triplet - } -} - -/** - * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous - * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug / - * profile. - */ -private[impl] -class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgeIter: Iterator[Edge[ED]], - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - private val triplet = new EdgeTriplet[VD, ED] - - override def hasNext = edgeIter.hasNext - - override def next() = { - triplet.set(edgeIter.next()) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 33f35cfb69a26..0eae2a673874a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils @@ -44,7 +43,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( /** Default constructor is provided to support serialization */ protected def this() = this(null, null) - @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges + @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges /** Return a RDD that brings edges together with their source and destination vertices. 
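+   *
+   * Illustrative sketch only (assumes a graph whose vertex attributes are strings):
+   * {{{
+   *   graph.triplets.map(et => s"${et.srcAttr} is connected to ${et.dstAttr}")
+   * }}}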
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { @@ -127,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapTriplets[ED2: ClassTag]( - f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { + f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] = { vertices.cache() - val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) => - part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr))) + part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst))) } new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } @@ -171,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } - vertices.cache() + val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") + val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") + val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) + + aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) + } + + override def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + vertices.cache() // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. 
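+    // tripletFields declares which vertex attributes `sendMsg` will actually read,
+    // so only those need to be shipped to the edge partitions.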
- val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val view = activeSetOpt match { case Some((activeSet, _)) => replicatedVertexView.withActiveSet(activeSet) @@ -193,42 +214,40 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( case (pid, edgePartition) => // Choose scan method val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat - val edgeIter = activeDirectionOpt match { + activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) - .filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } else { - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Either) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } else { - edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } case Some(EdgeDirection.In) => - edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.DstOnly) case _ => // None - edgePartition.iterator + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Neither) } - - // Scan edges and run the map function - val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr) - .flatMap(mapFunc(_)) - // Note: This doesn't allow users to send messages to arbitrary vertices. 
- edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator - }).setName("GraphImpl.mapReduceTriplets - preAgg") + }).setName("GraphImpl.aggregateMessages - preAgg") // do the final reduction reusing the index map - vertices.aggregateUsingIndex(preAgg, reduceFunc) - } // end of mapReduceTriplets + vertices.aggregateUsingIndex(preAgg, mergeMsg) + } override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) @@ -304,11 +323,10 @@ object GraphImpl { */ def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions( - (pid, part) => part.withVertices(part.vertices.map( - (vid, attr) => null.asInstanceOf[VD]))) + val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] + .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } @@ -319,8 +337,8 @@ object GraphImpl { */ def fromExistingRDDs[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = { - new GraphImpl(vertices, new ReplicatedVertexView(edges)) + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + new GraphImpl(vertices, new ReplicatedVertexView(edges.asInstanceOf[EdgeRDDImpl[ED, VD]])) } /** @@ -328,7 +346,7 @@ object GraphImpl { * `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD. */ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( - edges: EdgeRDD[ED, VD], + edges: EdgeRDDImpl[ED, VD], defaultVertexAttr: VD, edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 86b366eb9202b..8ab255bd4038c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx._ */ private[impl] class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( - var edges: EdgeRDD[ED, VD], + var edges: EdgeRDDImpl[ED, VD], var hasSrcId: Boolean = false, var hasDstId: Boolean = false) { @@ -42,7 +42,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. 
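+   * The shipped vertex attributes in `edges_` must therefore match `hasSrcId` and
+   * `hasDstId`, which are carried over unchanged.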
*/ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { new ReplicatedVertexView(edges_, hasSrcId, hasDstId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 7a7fa91aadfe1..eb3c997e0f3c0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -56,11 +56,9 @@ object RoutingTablePartition { // Determine which positions each vertex id appears in using a map where the low 2 bits // represent src and dst val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte] - edgePartition.srcIds.iterator.foreach { srcId => - map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) - } - edgePartition.dstIds.iterator.foreach { dstId => - map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) + edgePartition.iterator.foreach { e => + map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte) + map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => val vid = vidAndPosition._1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala new file mode 100644 index 0000000000000..d92a55a189298 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class VertexRDDImpl[VD] private[graphx] ( + val partitionsRDD: RDD[ShippableVertexPartition[VD]], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + (implicit override protected val vdTag: ClassTag[VD]) + extends VertexRDD[VD](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + require(partitionsRDD.partitioner.isDefined) + + override def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) + + override val partitioner = partitionsRDD.partitioner + + override protected def getPreferredLocations(s: Partition): Seq[String] = + partitionsRDD.preferredLocations(s) + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("VertexRDD") + + /** + * Persists the vertex partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of vertices in the RDD. */ + override def count(): Long = { + partitionsRDD.map(_.size).reduce(_ + _) + } + + override private[graphx] def mapVertexPartitions[VD2: ClassTag]( + f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) + : VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) + this.withPartitionsRDD(newPartitionsRDD) + } + + override def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + + override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map(f)) + + override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.diff(otherPart)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftZipJoin[VD2: ClassTag, VD3: ClassTag] + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.leftJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftJoin[VD2: ClassTag, VD3: ClassTag] + (other: RDD[(VertexId, VD2)]) + (f: (VertexId, VD, Option[VD2]) => VD3) + : VertexRDD[VD3] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
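+    // A VertexRDD carries its own hash index and partitioner, so it can be joined
+    // partition-by-partition without the shuffle a generic RDD would require.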
+ // If the other set is a VertexRDD then we use the much more efficient leftZipJoin + other match { + case other: VertexRDD[_] => + leftZipJoin(other)(f) + case _ => + this.withPartitionsRDD[VD3]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) + } + ) + } + } + + override def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.innerJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. + // If the other set is a VertexRDD then we use the much more efficient innerZipJoin + other match { + case other: VertexRDD[_] => + innerZipJoin(other)(f) + case _ => + this.withPartitionsRDD( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) + } + ) + } + } + + override def aggregateUsingIndex[VD2: ClassTag]( + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { + val shuffled = messages.partitionBy(this.partitioner.get) + val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => + thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) + } + this.withPartitionsRDD[VD2](parts) + } + + override def reverseRoutingTables(): VertexRDD[VD] = + this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + + override def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] = { + val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) + val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { + (partIter, routingTableIter) => + val routingTable = + if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty + partIter.map(_.withRoutingTable(routingTable)) + } + this.withPartitionsRDD(vertexPartitions) + } + + override private[graphx] def withPartitionsRDD[VD2: ClassTag]( + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { + new VertexRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): VertexRDD[VD] = { + new VertexRDDImpl(this.partitionsRDD, targetStorageLevel) + } + + override private[graphx] def shipVertexAttributes( + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) + } + + override private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 257e2f3a36115..e40ae0d615466 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -85,7 +85,7 @@ object PageRank extends Logging { // 
Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree - .mapTriplets( e => 1.0 / e.srcAttr ) + .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) @@ -96,8 +96,8 @@ object PageRank extends Logging { // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. - val rankUpdates = rankGraph.mapReduceTriplets[Double]( - e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + val rankUpdates = rankGraph.aggregateMessages[Double]( + ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.SrcAndEdge) // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index ccd7de537b6e3..f58587e10a820 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -74,9 +74,9 @@ object SVDPlusPlus { var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() // Calculate initial bias and norm - val t0 = g.mapReduceTriplets( - et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), - (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) + val t0 = g.aggregateMessages[(Long, Double)]( + ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, + (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) g = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), @@ -84,15 +84,17 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTrainF(conf: Conf, u: Double) + (ctx: EdgeContext[ + (DoubleMatrix, DoubleMatrix, Double, Double), + Double, + (DoubleMatrix, DoubleMatrix, Double)]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = et.attr - pred + val err = ctx.attr - pred val updateP = q.mul(err) .subColumnVector(p.mul(conf.gamma7)) .mul(conf.gamma2) @@ -102,16 +104,16 @@ object SVDPlusPlus { val updateY = q.mul(err * usr._4) .subColumnVector(itm._2.mul(conf.gamma7)) .mul(conf.gamma2) - Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), - (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) + ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) + ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.mapReduceTriplets( - et => Iterator((et.srcId, et.dstAttr._2)), - (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) + val t1 = 
g.aggregateMessages[DoubleMatrix]( + ctx => ctx.sendToSrc(ctx.dstAttr._2), + (g1, g2) => g1.addColumnVector(g2)) g = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => @@ -121,8 +123,8 @@ object SVDPlusPlus { // Phase 2, update p for user nodes and q, y for item nodes g.cache() - val t2 = g.mapReduceTriplets( - mapTrainF(conf, u), + val t2 = g.aggregateMessages( + sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { @@ -135,20 +137,18 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, Double)] = - { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTestF(conf: Conf, u: Double) + (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = (et.attr - pred) * (et.attr - pred) - Iterator((et.dstId, err)) + val err = (ctx.attr - pred) * (ctx.attr - pred) + ctx.sendToDst(err) } g.cache() - val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) + val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) g = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index 7c396e6e66a28..daf162085e3e4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -61,26 +61,27 @@ object TriangleCount { (vid, _, optSet) => optSet.getOrElse(null) } // Edge function computes intersection of smaller vertex with larger vertex - def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = { - assert(et.srcAttr != null) - assert(et.dstAttr != null) - val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) { - (et.srcAttr, et.dstAttr) + def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { + assert(ctx.srcAttr != null) + assert(ctx.dstAttr != null) + val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { + (ctx.srcAttr, ctx.dstAttr) } else { - (et.dstAttr, et.srcAttr) + (ctx.dstAttr, ctx.srcAttr) } val iter = smallSet.iterator var counter: Int = 0 while (iter.hasNext) { val vid = iter.next() - if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { + if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) { counter += 1 } } - Iterator((et.srcId, counter), (et.dstId, counter)) + ctx.sendToSrc(counter) + ctx.sendToDst(counter) } // compute the intersection along edges - val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _) + val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice g.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => diff --git 
a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 6506bac73d71c..df773db6e4326 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Each vertex should be replicated to at most 2 * sqrt(p) partitions val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) { val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound) @@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // This should not be true for the default hash partitioning val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound)) @@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("aggregateMessages") { + withSpark { sc => + val n = 5 + val agg = starGraph(sc, n).aggregateMessages[String]( + ctx => { + if (ctx.dstAttr != null) { + throw new Exception( + "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr) + } + ctx.sendToDst(ctx.srcAttr) + }, _ + _, TripletFields.SrcOnly) + assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet) + } + } + test("outerJoinVertices") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index db1dac6160080..515f3a9cd02eb 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite { assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges) } - test("upgradeIterator") { - val edges = List((0, 1, 0), (1, 0, 0)) - val verts = List((0L, 1), (1L, 2)) - val part = makeEdgePartition(edges).updateVertices(verts.iterator) - assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList === - part.tripletIterator().toList.map(_.toTuple)) - } - - test("indexIterator") { - val edgesFrom0 = List(Edge(0, 1, 0)) - val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0)) - val sortedEdges = edgesFrom0 ++ edgesFrom1 - val builder = new EdgePartitionBuilder[Int, Nothing] - for (e <- Random.shuffle(sortedEdges)) { - builder.add(e.srcId, e.dstId, e.attr) - } - - val edgePartition = builder.toEdgePartition - assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges) - assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0) - assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1) - } - test("innerJoin") { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0)) @@ -125,8 +102,18 @@ class EdgePartitionSuite extends FunSuite { assert(ep.numActives == 
Some(2)) } + test("tripletIterator") { + val builder = new EdgePartitionBuilder[Int, Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val ep = builder.toEdgePartition + val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } + test("serialization") { - val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) + val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) val javaSer = new JavaSerializer(new SparkConf()) val conf = new SparkConf() @@ -135,11 +122,7 @@ class EdgePartitionSuite extends FunSuite { for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.srcIds.toList === a.srcIds.toList) - assert(aSer.dstIds.toList === a.dstIds.toList) - assert(aSer.data.toList === a.data.toList) - assert(aSer.index != null) - assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + assert(aSer.tripletIterator().toList === a.tripletIterator().toList) } } } diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4d..2267b1aa08a6c 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -59,7 +59,7 @@ while (( "$#" )); do exit_with_usage ;; --with-hive) - echo "Error: '--with-hive' is no longer supported, use Maven option -Phive" + echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; --skip-java-test) @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +# This will fail if the -Pyarn profile is not provided +# In this case, silence the error and ignore the return code of this command +cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/mllib/pom.xml b/mllib/pom.xml index 87a7ddaba97f2..dd68b27a78bdc 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -100,6 +100,11 @@ junit-interface test + + org.mockito + mockito-all + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala new file mode 100644 index 0000000000000..fdbee743e8177 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for estimators that fit models to data. + */ +@AlphaComponent +abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + val map = new ParamMap().put(paramPairs: _*) + fit(dataset, map) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: SchemaRDD, paramMap: ParamMap): M + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * The default implementation fits one model per parameter map in a loop. + * Subclasses could override this to optimize multi-model training. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + paramMaps.map(fit(dataset, _)) + } + + // Java-friendly versions of fit. + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { + fit(dataset.schemaRDD, paramPairs: _*) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { + fit(dataset.schemaRDD, paramMap) + } + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { + fit(dataset.schemaRDD, paramMaps).asJava + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala new file mode 100644 index 0000000000000..db563dd550e56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
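A quick sketch of how the three fit overloads above compose, assuming a hypothetical estimator lr (e.g. the LogisticRegression added later in this patch) and a SchemaRDD named training; all value names here are illustrative, not part of the patch:

    // Embedded params only.
    val m0 = lr.fit(training)
    // Varargs ParamPair overload; the pairs overwrite embedded params.
    val m1 = lr.fit(training, lr.maxIter -> 50, lr.regParam -> 0.01)
    // Explicit ParamMap overload (the one subclasses must implement).
    val m2 = lr.fit(training, ParamMap(lr.maxIter -> 50))
    // Multi-model overload; the default implementation loops over the maps.
    val ms: Seq[LogisticRegressionModel] =
      lr.fit(training, Array(ParamMap(lr.maxIter -> 10), ParamMap(lr.maxIter -> 20)))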
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for evaluators that compute metrics from predictions. + */ +@AlphaComponent +abstract class Evaluator extends Identifiable { + + /** + * Evaluates the output. + * + * @param dataset a dataset that contains labels/observations and predictions. + * @param paramMap parameter map that specifies the input columns and output metrics + * @return metric + */ + def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala similarity index 58% rename from graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala rename to mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index 49b2704390fea..cd84b05bfb496 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -15,23 +15,19 @@ * limitations under the License. */ -package org.apache.spark.graphx.impl +package org.apache.spark.ml -import scala.reflect.ClassTag -import scala.util.Random +import java.util.UUID -import org.scalatest.FunSuite - -import org.apache.spark.graphx._ +/** + * Object with a unique id. + */ +private[ml] trait Identifiable extends Serializable { -class EdgeTripletIteratorSuite extends FunSuite { - test("iterator.toList") { - val builder = new EdgePartitionBuilder[Int, Int] - builder.add(1, 2, 0) - builder.add(1, 3, 0) - builder.add(1, 4, 0) - val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true) - val result = iter.toList.map(et => (et.srcId, et.dstId)) - assert(result === Seq((1, 2), (1, 3), (1, 4))) - } + /** + * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * random hex chars. + */ + private[ml] val uid: String = + this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala new file mode 100644 index 0000000000000..cae5082b51196 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
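To illustrate the uid scheme defined by Identifiable above (the id below is made up; real ids use the first 8 hex chars of a random UUID):

    // Hypothetical component inside package org.apache.spark.ml:
    private[ml] class Example extends Identifiable
    // new Example().uid  // e.g. "Example-3f2a9b1c"; fixed for the instance's lifetime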
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap + +/** + * :: AlphaComponent :: + * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. + * + * @tparam M model type + */ +@AlphaComponent +abstract class Model[M <: Model[M]] extends Transformer { + /** + * The parent estimator that produced this model. + */ + val parent: Estimator[M] + + /** + * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. + */ + val fittingParamMap: ParamMap +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala new file mode 100644 index 0000000000000..e545df1e37b9c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, Param, ParamMap} +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * :: AlphaComponent :: + * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. + */ +@AlphaComponent +abstract class PipelineStage extends Serializable with Logging { + + /** + * Derives the output schema from the input schema and parameters. + */ + private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + + /** + * Derives the output schema from the input schema and parameters, optionally with logging. + */ + protected def transformSchema( + schema: StructType, + paramMap: ParamMap, + logging: Boolean): StructType = { + if (logging) { + logDebug(s"Input schema: ${schema.json}") + } + val outputSchema = transformSchema(schema, paramMap) + if (logging) { + logDebug(s"Expected output schema: ${outputSchema.json}") + } + outputSchema + } +} + +/** + * :: AlphaComponent :: + * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * be called on the input dataset to fit a model. 
Then the model, which is a transformer, will be + * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], + * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * The fitted model from a [[Pipeline]] is a [[PipelineModel]], which consists of fitted models and + * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as + * an identity transformer. + */ +@AlphaComponent +class Pipeline extends Estimator[PipelineModel] { + + /** param for pipeline stages */ + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def getStages: Array[PipelineStage] = get(stages) + + /** + * Fits the pipeline to the input dataset with additional parameters. If a stage is an + * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * Then the model, which is a transformer, will be used to transform the dataset as the input to + * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is a + * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the + * pipeline stages. If there are no stages, the output model acts as an identity transformer. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted pipeline + */ + override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val theStages = map(stages) + // Search for the last estimator. + var indexOfLastEstimator = -1 + theStages.view.zipWithIndex.foreach { case (stage, index) => + stage match { + case _: Estimator[_] => + indexOfLastEstimator = index + case _ => + } + } + var curDataset = dataset + val transformers = ListBuffer.empty[Transformer] + theStages.view.zipWithIndex.foreach { case (stage, index) => + if (index <= indexOfLastEstimator) { + val transformer = stage match { + case estimator: Estimator[_] => + estimator.fit(curDataset, paramMap) + case t: Transformer => + t + case _ => + throw new IllegalArgumentException( + s"Does not support stage $stage of type ${stage.getClass}") + } + curDataset = transformer.transform(curDataset, paramMap) + transformers += transformer + } else { + transformers += stage.asInstanceOf[Transformer] + } + } + + new PipelineModel(this, map, transformers.toArray) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val theStages = map(stages) + require(theStages.toSet.size == theStages.size, + "Cannot have duplicate components in a pipeline.") + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + } +} + +/** + * :: AlphaComponent :: + * Represents a compiled pipeline. + */ +@AlphaComponent +class PipelineModel private[ml] ( + override val parent: Pipeline, + override val fittingParamMap: ParamMap, + private[ml] val stages: Array[Transformer]) + extends Model[PipelineModel] with Logging { + + /** + * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input + * estimator does not exist in the pipeline.
+ */ + def getModel[M <: Model[M]](stage: Estimator[M]): M = { + val matched = stages.filter { + case m: Model[_] => m.parent.eq(stage) + case _ => false + } + if (matched.isEmpty) { + throw new NoSuchElementException(s"Cannot find stage $stage in the pipeline.") + } else if (matched.size > 1) { + throw new IllegalStateException(s"Cannot have duplicate estimators in the same pipeline.") + } else { + matched.head.asInstanceOf[M] + } + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala new file mode 100644 index 0000000000000..490e6609ad311 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.types._ + +/** + * :: AlphaComponent :: + * Abstract class for transformers that transform one dataset into another. + */ +@AlphaComponent +abstract class Transformer extends PipelineStage with Params { + + /** + * Transforms the dataset with optional parameters. + * @param dataset input dataset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + val map = new ParamMap() + paramPairs.foreach(map.put(_)) + transform(dataset, map) + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + + // Java-friendly versions of transform. + + /** + * Transforms the dataset with optional parameters.
+ * @param dataset input dataset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD + } +} + +/** + * Abstract class for transformers that take one input column, apply a transformation, and output + * the result as a new column. + */ +private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] + extends Transformer with HasInputCol with HasOutputCol with Logging { + + def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] + + /** + * Creates the transform function using the given param map. The input param map already takes the + * embedded param map into account, so the param values should be determined solely by the input + * param map. + */ + protected def createTransformFunc(paramMap: ParamMap): IN => OUT + + /** + * Validates the input type. Throws an exception if it is invalid. + */ + protected def validateInputType(inputType: DataType): Unit = {} + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + validateInputType(inputType) + if (schema.fieldNames.contains(map(outputCol))) { + throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + } + val output = ScalaReflection.schemaFor[OUT] + val outputFields = schema.fields :+ + StructField(map(outputCol), output.dataType, output.nullable) + StructType(outputFields) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val udf = this.createTransformFunc(map) + dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala new file mode 100644 index 0000000000000..85b8899636ca5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
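To make the Transformer and UnaryTransformer contracts above concrete, here is a sketch of assembling a Pipeline from the concrete stages added later in this patch (Tokenizer, HashingTF, LogisticRegression); the column names and the trainingDataset value are assumptions for illustration:

    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))
    // Tokenizer and HashingTF are Transformers, so Pipeline.fit only calls transform on them;
    // LogisticRegression is an Estimator, so fit trains a LogisticRegressionModel.
    val model = pipeline.fit(trainingDataset)
    // Retrieve the fitted LR model from the resulting PipelineModel.
    val lrModel = model.getModel(lr)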
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.storage.StorageLevel + +/** + * :: AlphaComponent :: + * Params for logistic regression. + */ +@AlphaComponent +private[classification] trait LogisticRegressionParams extends Params + with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol + with HasScoreCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is called during fitting + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean): StructType = { + val map = this.paramMap ++ paramMap + val featuresType = schema(map(featuresCol)).dataType + // TODO: Support casting Array[Double] and Array[Float] to Vector. + require(featuresType.isInstanceOf[VectorUDT], + s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") + if (fitting) { + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") + } + val fieldNames = schema.fieldNames + require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") + require(!fieldNames.contains(map(predictionCol)), + s"Prediction column ${map(predictionCol)} already exists.") + val outputFields = schema.fields ++ Seq( + StructField(map(scoreCol), DoubleType, false), + StructField(map(predictionCol), DoubleType, false)) + StructType(outputFields) + } +} + +/** + * Logistic regression.
+ */ +class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + setThreshold(0.5) + + def setRegParam(value: Double): this.type = set(regParam, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + }.persist(StorageLevel.MEMORY_AND_DISK) + val lr = new LogisticRegressionWithLBFGS + lr.optimizer + .setRegParam(map(regParam)) + .setNumIterations(map(maxIter)) + val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) + instances.unpersist() + // copy model params + Params.inheritValues(map, this, lrm) + lrm + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true) + } +} + +/** + * :: AlphaComponent :: + * Model produced by [[LogisticRegression]]. + */ +@AlphaComponent +class LogisticRegressionModel private[ml] ( + override val parent: LogisticRegression, + override val fittingParamMap: ParamMap, + weights: Vector) + extends Model[LogisticRegressionModel] with LogisticRegressionParams { + + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val score: Vector => Double = (v) => { + val margin = BLAS.dot(v, weights) + 1.0 / (1.0 + math.exp(-margin)) + } + val t = map(threshold) + val predict: Double => Double = (score) => { + if (score > t) 1.0 else 0.0 + } + dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) + .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala new file mode 100644 index 0000000000000..0b0504e036ec9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
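A minimal usage sketch for the estimator/model pair above, assuming SchemaRDDs named training and test that carry the default "label" and "features" columns (the defaults come from HasLabelCol and HasFeaturesCol):

    val lr = new LogisticRegression()
      .setRegParam(0.01)
      .setMaxIter(50)
    val model = lr.fit(training)
    // Params can also be overwritten per call without mutating lr:
    val model2 = lr.fit(training, lr.regParam -> 0.1)
    // transform appends the default "score" and "prediction" columns:
    val predictions = model.transform(test)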
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} + +/** + * :: AlphaComponent :: + * Evaluator for binary classification, which expects two input columns: score and label. + */ +@AlphaComponent +class BinaryClassificationEvaluator extends Evaluator with Params + with HasScoreCol with HasLabelCol { + + /** param for metric name in evaluation */ + val metricName: Param[String] = new Param(this, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + def getMetricName: String = get(metricName) + def setMetricName(value: String): this.type = set(metricName, value) + + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + val map = this.paramMap ++ paramMap + + val schema = dataset.schema + val scoreType = schema(map(scoreCol)).dataType + require(scoreType == DoubleType, + s"Score column ${map(scoreCol)} must be double type but found $scoreType") + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Label column ${map(labelCol)} must be double type but found $labelType") + + import dataset.sqlContext._ + val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + .map { case Row(score: Double, label: Double) => + (score, label) + } + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val metric = map(metricName) match { + case "areaUnderROC" => + metrics.areaUnderROC() + case "areaUnderPR" => + metrics.areaUnderPR() + case other => + throw new IllegalArgumentException(s"Does not support metric $other.") + } + metrics.unpersist() + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala new file mode 100644 index 0000000000000..b98b1755a3584 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
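Continuing the sketch, the evaluator above reads the "score" and "label" columns by default; predictions is assumed to be the output of a model's transform:

    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderPR")  // the default is "areaUnderROC"
    val metric = evaluator.evaluate(predictions, ParamMap.empty)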
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.Vector + +/** + * :: AlphaComponent :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + */ +@AlphaComponent +class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + + /** number of features */ + val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + def setNumFeatures(value: Int) = set(numFeatures, value) + def getNumFeatures: Int = get(numFeatures) + + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { + val hashingTF = new feature.HashingTF(paramMap(numFeatures)) + hashingTF.transform + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala new file mode 100644 index 0000000000000..896a6b83b67bf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ + +/** + * Params for [[StandardScaler]] and [[StandardScalerModel]]. + */ +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol + +/** + * :: AlphaComponent :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. 
+ */ +@AlphaComponent +class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val input = dataset.select(map(inputCol).attr) + .map { case Row(v: Vector) => + v + } + val scaler = new feature.StandardScaler().fit(input) + val model = new StandardScalerModel(this, map, scaler) + Params.inheritValues(map, this, model) + model + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StandardScaler]]. + */ +@AlphaComponent +class StandardScalerModel private[ml] ( + override val parent: StandardScaler, + override val fittingParamMap: ParamMap, + scaler: feature.StandardScalerModel) + extends Model[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val scale: (Vector) => Vector = (v) => { + scaler.transform(v) + } + dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala new file mode 100644 index 0000000000000..0a6599b64c011 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
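A short sketch of the fit-then-transform flow for the scaler above (training and test are assumed SchemaRDDs with a vector-typed "features" column):

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
    // fit computes column summary statistics once, on the training set only...
    val scalerModel = scaler.fit(training)
    // ...and the resulting model applies the same statistics to any dataset:
    val scaledTest = scalerModel.transform(test)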
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.{DataType, StringType} + +/** + * :: AlphaComponent :: + * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + */ +@AlphaComponent +class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { + + protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + _.toLowerCase.split("\\s") + } + + protected override def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java new file mode 100644 index 0000000000000..00d9c802e930d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +@AlphaComponent +package org.apache.spark.ml; + +import org.apache.spark.annotation.AlphaComponent; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala new file mode 100644 index 0000000000000..51cd48c90432a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
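For the Tokenizer defined above, a one-line illustration of its behavior (dataset is an assumed SchemaRDD with a string-typed "text" column):

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    // "Hello World" becomes Seq("hello", "world"): lowercase first, then split on whitespace.
    val tokenized = tokenizer.transform(dataset)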
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala new file mode 100644 index 0000000000000..8fd46aef4b99d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import java.lang.reflect.Modifier + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.Identifiable + +/** + * :: AlphaComponent :: + * A param with self-contained documentation and optionally a default value. Primitive-typed params + * should use the specialized versions, which are more friendly to Java users. + * + * @param parent parent object + * @param name param name + * @param doc documentation + * @tparam T param value type + */ +@AlphaComponent +class Param[T] ( + val parent: Params, + val name: String, + val doc: String, + val defaultValue: Option[T] = None) + extends Serializable { + + /** + * Creates a param pair with the given value (for Java). + */ + def w(value: T): ParamPair[T] = this -> value + + /** + * Creates a param pair with the given value (for Scala). + */ + def ->(value: T): ParamPair[T] = ParamPair(this, value) + + override def toString: String = { + if (defaultValue.isDefined) { + s"$name: $doc (default: ${defaultValue.get})" + } else { + s"$name: $doc" + } + } +} + +// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... + +/** Specialized version of [[Param[Double]]] for Java. */ +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) + extends Param[Double](parent, name, doc, defaultValue) { + + override def w(value: Double): ParamPair[Double] = super.w(value) +} + +/** Specialized version of [[Param[Int]]] for Java. */ +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) + extends Param[Int](parent, name, doc, defaultValue) { + + override def w(value: Int): ParamPair[Int] = super.w(value) +} + +/** Specialized version of [[Param[Float]]] for Java.
*/ +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) + extends Param[Float](parent, name, doc, defaultValue) { + + override def w(value: Float): ParamPair[Float] = super.w(value) +} + +/** Specialized version of [[Param[Long]]] for Java. */ +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) + extends Param[Long](parent, name, doc, defaultValue) { + + override def w(value: Long): ParamPair[Long] = super.w(value) +} + +/** Specialized version of [[Param[Boolean]]] for Java. */ +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) + extends Param[Boolean](parent, name, doc, defaultValue) { + + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) +} + +/** + * A param and its value. + */ +case class ParamPair[T](param: Param[T], value: T) + +/** + * :: AlphaComponent :: + * Trait for components that take parameters. This also provides an internal param map to store + * parameter values attached to the instance. + */ +@AlphaComponent +trait Params extends Identifiable with Serializable { + + /** Returns all params. */ + def params: Array[Param[_]] = { + val methods = this.getClass.getMethods + methods.filter { m => + Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty + }.sortBy(_.getName) + .map(m => m.invoke(this).asInstanceOf[Param[_]]) + } + + /** + * Validates parameter values stored internally plus the input parameter map. + * Raises an exception if any parameter is invalid. + */ + def validate(paramMap: ParamMap): Unit = {} + + /** + * Validates parameter values stored internally. + * Raises an exception if any parameter value is invalid. + */ + def validate(): Unit = validate(ParamMap.empty) + + /** + * Returns the documentation of all params. + */ + def explainParams(): String = params.mkString("\n") + + /** Checks whether a param is explicitly set. */ + def isSet(param: Param[_]): Boolean = { + require(param.parent.eq(this)) + paramMap.contains(param) + } + + /** Gets a param by its name. */ + private[ml] def getParam(paramName: String): Param[Any] = { + val m = this.getClass.getMethod(paramName) + assert(Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty) + m.invoke(this).asInstanceOf[Param[Any]] + } + + /** + * Sets a parameter in the embedded param map. + */ + private[ml] def set[T](param: Param[T], value: T): this.type = { + require(param.parent.eq(this)) + paramMap.put(param.asInstanceOf[Param[Any]], value) + this + } + + /** + * Gets the value of a parameter in the embedded param map. + */ + private[ml] def get[T](param: Param[T]): T = { + require(param.parent.eq(this)) + paramMap(param) + } + + /** + * Internal param map. + */ + protected val paramMap: ParamMap = ParamMap.empty +} + +private[ml] object Params { + + /** + * Copies parameter values from the parent estimator to the child model it produced. + * @param paramMap the param map that holds parameters of the parent + * @param parent the parent estimator + * @param child the child model + */ + def inheritValues[E <: Params, M <: E]( + paramMap: ParamMap, + parent: E, + child: M): Unit = { + parent.params.foreach { param => + if (paramMap.contains(param)) { + child.set(child.getParam(param.name), paramMap(param)) + } + } + } +} + +/** + * :: AlphaComponent :: + * A param to value map.
+ */ +@AlphaComponent +class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { + + /** + * Creates an empty param map. + */ + def this() = this(mutable.Map.empty[Param[Any], Any]) + + /** + * Puts a (param, value) pair (overwrites if the input param exists). + */ + def put[T](param: Param[T], value: T): this.type = { + map(param.asInstanceOf[Param[Any]]) = value + this + } + + /** + * Puts a list of param pairs (overwrites if the input params exist). + */ + def put(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + put(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Optionally returns the value associated with a param or its default. + */ + def get[T](param: Param[T]): Option[T] = { + map.get(param.asInstanceOf[Param[Any]]) + .orElse(param.defaultValue) + .asInstanceOf[Option[T]] + } + + /** + * Gets the value of the input param, or its default value if no value is set. + * Raises a NoSuchElementException if neither a value nor a default exists for the input param. + */ + def apply[T](param: Param[T]): T = { + val value = get(param) + if (value.isDefined) { + value.get + } else { + throw new NoSuchElementException(s"Cannot find param ${param.name}.") + } + } + + /** + * Checks whether a parameter is explicitly specified. + */ + def contains(param: Param[_]): Boolean = { + map.contains(param.asInstanceOf[Param[Any]]) + } + + /** + * Filters this param map for the given parent. + */ + def filter(parent: Params): ParamMap = { + val filtered = map.filterKeys(_.parent == parent) + new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + } + + /** + * Makes a copy of this param map. + */ + def copy: ParamMap = new ParamMap(map.clone()) + + override def toString: String = { + map.map { case (param, value) => + s"\t${param.parent.uid}-${param.name}: $value" + }.mkString("{\n", ",\n", "\n}") + } + + /** + * Returns a new param map that contains parameters in this map and the given map, + * where the latter overwrites this map when conflicts exist. + */ + def ++(other: ParamMap): ParamMap = { + new ParamMap(this.map ++ other.map) + } + + /** + * Adds all parameters from the input param map into this param map. + */ + def ++=(other: ParamMap): this.type = { + this.map ++= other.map + this + } + + /** + * Converts this param map to a sequence of param pairs. + */ + def toSeq: Seq[ParamPair[_]] = { + map.toSeq.map { case (param, value) => + ParamPair(param, value) + } + } +} + +object ParamMap { + + /** + * Returns an empty param map. + */ + def empty: ParamMap = new ParamMap() + + /** + * Constructs a param map by specifying its entries. + */ + @varargs + def apply(paramPairs: ParamPair[_]*): ParamMap = { + new ParamMap().put(paramPairs: _*) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala new file mode 100644 index 0000000000000..ef141d3eb2b06 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
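A sketch of the ParamMap merge and default-value semantics defined above, reusing a hypothetical lr: LogisticRegression from the earlier sketches:

    val pm1 = ParamMap(lr.maxIter -> 10)
    val pm2 = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.01)
    val merged = pm1 ++ pm2   // the right-hand map wins on conflicts: maxIter -> 20
    merged(lr.maxIter)        // 20
    merged(lr.featuresCol)    // not set in either map, so the param's default "features" is returned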
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +private[ml] trait HasRegParam extends Params { + /** param for regularization parameter */ + val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + def getRegParam: Double = get(regParam) +} + +private[ml] trait HasMaxIter extends Params { + /** param for max number of iterations */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +private[ml] trait HasFeaturesCol extends Params { + /** param for features column name */ + val featuresCol: Param[String] = + new Param(this, "featuresCol", "features column name", Some("features")) + def getFeaturesCol: String = get(featuresCol) +} + +private[ml] trait HasLabelCol extends Params { + /** param for label column name */ + val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) + def getLabelCol: String = get(labelCol) +} + +private[ml] trait HasScoreCol extends Params { + /** param for score column name */ + val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) + def getScoreCol: String = get(scoreCol) +} + +private[ml] trait HasPredictionCol extends Params { + /** param for prediction column name */ + val predictionCol: Param[String] = + new Param(this, "predictionCol", "prediction column name", Some("prediction")) + def getPredictionCol: String = get(predictionCol) +} + +private[ml] trait HasThreshold extends Params { + /** param for threshold in (binary) prediction */ + val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + def getThreshold: Double = get(threshold) +} + +private[ml] trait HasInputCol extends Params { + /** param for input column name */ + val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + def getInputCol: String = get(inputCol) +} + +private[ml] trait HasOutputCol extends Params { + /** param for output column name */ + val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + def getOutputCol: String = get(outputCol) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala new file mode 100644 index 0000000000000..194b9bfd9a9e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
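The shared param traits above are meant to be mixed into components inside the org.apache.spark.ml packages; a hypothetical sketch of the intended pattern (each trait contributes the Param field and its getter, the component adds only the setter):

    private[ml] trait MyComponentParams extends Params with HasInputCol with HasMaxIter {
      // set is private[ml], so this sketch assumes the ml package.
      def setInputCol(value: String): this.type = set(inputCol, value)
      def setMaxIter(value: Int): this.type = set(maxIter, value)
    }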
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import com.github.fommil.netlib.F2jBLAS + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * Params for [[CrossValidator]] and [[CrossValidatorModel]]. + */ +private[ml] trait CrossValidatorParams extends Params { + /** param for the estimator to be cross-validated */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + def getEstimator: Estimator[_] = get(estimator) + + /** param for estimator param maps */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + + /** param for the evaluator for selection */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + def getEvaluator: Evaluator = get(evaluator) + + /** param for number of folds for cross validation */ + val numFolds: IntParam = + new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + def getNumFolds: Int = get(numFolds) +} + +/** + * :: AlphaComponent :: + * K-fold cross validation. 
+ */ +@AlphaComponent +class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { + + private val f2jBLAS = new F2jBLAS + + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + def setNumFolds(value: Int): this.type = set(numFolds, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + val map = this.paramMap ++ paramMap + val schema = dataset.schema + transformSchema(dataset.schema, paramMap, logging = true) + val sqlCtx = dataset.sqlContext + val est = map(estimator) + val eval = map(evaluator) + val epm = map(estimatorParamMaps) + val numModels = epm.size + val metrics = new Array[Double](epm.size) + val splits = MLUtils.kFold(dataset, map(numFolds), 0) + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => + val trainingDataset = sqlCtx.applySchema(training, schema).cache() + val validationDataset = sqlCtx.applySchema(validation, schema).cache() + // multi-model training + logDebug(s"Train split $splitIndex with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + var i = 0 + while (i < numModels) { + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + } + f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best cross-validation metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + val cvModel = new CrossValidatorModel(this, map, bestModel) + Params.inheritValues(map, this, cvModel) + cvModel + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + map(estimator).transformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model from k-fold cross validation. + */ +@AlphaComponent +class CrossValidatorModel private[ml] ( + override val parent: CrossValidator, + override val fittingParamMap: ParamMap, + val bestModel: Model[_]) + extends Model[CrossValidatorModel] with CrossValidatorParams { + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + bestModel.transform(dataset, paramMap) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + bestModel.transformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala new file mode 100644 index 0000000000000..dafe73d82c00a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
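Putting the pieces together, a hedged end-to-end sketch of model selection with the CrossValidator above; pipeline, lr, training, and test are the assumed values from the earlier sketches, and ParamGridBuilder is defined just below in this patch:

    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .addGrid(lr.maxIter, Array(10, 50, 100))
      .build()                       // 2 x 3 = 6 candidate ParamMaps
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setNumFolds(3)
    val cvModel = cv.fit(training)
    // CrossValidatorModel delegates transform to the best model found:
    val predictions = cvModel.transform(test)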
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ + +/** + * :: AlphaComponent :: + * Builder for a param grid used in grid search-based model selection. + */ +@AlphaComponent +class ParamGridBuilder { + + private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] + + /** + * Sets the given parameters in this grid to fixed values. + */ + def baseOn(paramMap: ParamMap): this.type = { + baseOn(paramMap.toSeq: _*) + this + } + + /** + * Sets the given parameters in this grid to fixed values. + */ + @varargs + def baseOn(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value)) + } + this + } + + /** + * Adds a param with multiple values (overwrites if the input param exists). + */ + def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { + paramGrid.put(param, values) + this + } + + // specialized versions of addGrid for Java. + + /** + * Adds a double param with multiple values. + */ + def addGrid(param: DoubleParam, values: Array[Double]): this.type = { + addGrid[Double](param, values) + } + + /** + * Adds an int param with multiple values. + */ + def addGrid(param: IntParam, values: Array[Int]): this.type = { + addGrid[Int](param, values) + } + + /** + * Adds a float param with multiple values. + */ + def addGrid(param: FloatParam, values: Array[Float]): this.type = { + addGrid[Float](param, values) + } + + /** + * Adds a long param with multiple values. + */ + def addGrid(param: LongParam, values: Array[Long]): this.type = { + addGrid[Long](param, values) + } + + /** + * Adds a boolean param with true and false. + */ + def addGrid(param: BooleanParam): this.type = { + addGrid[Boolean](param, Array(true, false)) + } + + /** + * Builds and returns all combinations of parameters specified by the param grid.
+ */ + def build(): Array[ParamMap] = { + var paramMaps = Array(new ParamMap) + paramGrid.foreach { case (param, values) => + val newParamMaps = values.flatMap { v => + paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v)) + } + paramMaps = newParamMaps.toArray + } + paramMaps + } +}
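The two new files above are designed to compose: ParamGridBuilder enumerates candidate ParamMaps and CrossValidator selects among them by k-fold evaluation, refitting the best setting on the full data. A minimal usage sketch (not part of the patch; `training` is an assumed SchemaRDD with "label" and "features" columns):

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression
// build() returns the cross product of the grids: 2 x 2 = 4 ParamMaps.
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.maxIter, Array(10, 50))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)
val cvModel = cv.fit(training) // wraps the best model refit on all of `training`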
- + " Can only be initialized using the following string values: [l1, l2, none].") + + " Can only be initialized using the following string values: ['l1', 'l2', None].") } trainRegressionModel( SVMAlg, @@ -213,9 +217,11 @@ class PythonMLLibAPI extends Serializable { LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) } else if (regType == "l1") { LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { + } else if (regType == null) { + LogRegAlg.optimizer.setUpdater(new SimpleUpdater) + } else { throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") + + " Can only be initialized using the following string values: ['l1', 'l2', None].") } trainRegressionModel( LogRegAlg, @@ -250,7 +256,7 @@ class PythonMLLibAPI extends Serializable { .setInitializationMode(initializationMode) // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. .disableUncachedWarning() - return kMeansAlg.run(data.rdd) + kMeansAlg.run(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 84d3c7cebd7c8..18b95f1edc0b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -71,9 +71,10 @@ class LogisticRegressionModel ( } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * + * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By + * default L2 regularization is used, which can be changed via + * [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1}. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( @@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a LogisticRegression object with default parameters + * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, + * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 0.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b2f1e84..ab9515b2a6db8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -72,7 +72,8 @@ class SVMModel ( } /** - * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 + * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b2f1e84..ab9515b2a6db8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -72,7 +72,8 @@ class SVMModel ( } /** - * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 + * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ class SVMWithSGD private ( @@ -92,9 +93,10 @@ class SVMWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a SVM object with default parameters + * Construct an SVM object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) @@ -185,6 +187,6 @@ object SVMWithSGD { * @return a SVMModel which has the weights and offset from training. */ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 562663ad36b40..be3319d60ce25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl def apply(c: BinaryConfusionMatrix): Double } -/** Precision. */ +/** Precision. Defined as 1.0 when there are no positive predictions. */ private[evaluation] object Precision extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) + override def apply(c: BinaryConfusionMatrix): Double = { + val totalPositives = c.numTruePositives + c.numFalsePositives + if (totalPositives == 0) { + 1.0 + } else { + c.numTruePositives.toDouble / totalPositives + } + } } -/** False positive rate. */ +/** False positive rate. Defined as 0.0 when there are no negative examples. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numFalsePositives.toDouble / c.numNegatives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numNegatives == 0) { + 0.0 + } else { + c.numFalsePositives.toDouble / c.numNegatives + } + } } -/** Recall. */ +/** Recall. Defined as 0.0 when there are no positive examples. */ private[evaluation] object Recall extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / c.numPositives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numPositives == 0) { + 0.0 + } else { + c.numTruePositives.toDouble / c.numPositives + } + } } /** - * F-Measure. + * F-Measure. Defined as 0 if both precision and recall are 0, e.g. when all examples + * are false positives.
* @param beta the beta constant in F-Measure * @see http://en.wikipedia.org/wiki/F1_score */ @@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati override def apply(c: BinaryConfusionMatrix): Double = { val precision = Precision(c) val recall = Recall(c) - (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + if (precision + recall == 0) { + 0.0 + } else { + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } + } }
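These guards replace divisions that previously produced NaN on degenerate confusion matrices. A standalone arithmetic check of the conventions above (the helper functions are illustrative, not part of the patch):

// Precision defaults to 1.0 with no positive predictions; recall and F-measure
// default to 0.0 so a score exists even when every prediction is wrong.
def precision(tp: Long, fp: Long): Double =
  if (tp + fp == 0) 1.0 else tp.toDouble / (tp + fp)
def recall(tp: Long, positives: Long): Double =
  if (positives == 0) 0.0 else tp.toDouble / positives
def f1(p: Double, r: Double): Double =
  if (p + r == 0) 0.0 else 2.0 * p * r / (p + r)

assert(precision(0L, 0L) == 1.0)                       // no predicted positives
assert(recall(0L, 0L) == 0.0)                          // no actual positives
assert(f1(precision(0L, 5L), recall(0L, 0L)) == 0.0)   // all false positives: 0.0, not NaN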
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 54ee930d61003..89539e600f48c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices. */ -private[mllib] object BLAS extends Serializable with Logging { +private[spark] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index ac217edc619ab..60ab2aaa8f27a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -115,6 +115,9 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { + // TODO: something wrong with UDT serialization + case v: Vector => + v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -234,7 +237,7 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - if (v.offset == 0 && v.stride == 1) { + if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { new DenseVector(v.data) } else { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 84d192db53e26..038edc3521f14 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.{abs, sqrt} -import scala.util.Random -import scala.util.Sorting +import scala.util.{Random, Sorting} import scala.util.hashing.byteswap32 import org.jblas.{DoubleMatrix, SimpleBlas, Solve} +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast -import org.apache.spark.{Logging, HashPartitioner, Partitioner} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -import org.apache.spark.mllib.optimization.NNLS /** * Out-link information for a user or product block. This includes the original user/product IDs @@ -325,6 +325,11 @@ class ALS private ( new MatrixFactorizationModel(rank, usersOut, productsOut) } + /** + * Java-friendly version of [[ALS.run]]. + */ + def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) + /** * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors * for each user (or product), in a distributed fashion. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 66b58ba770160..969e23be21623 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.recommendation +import java.lang.{Integer => JavaInteger} + import org.jblas.DoubleMatrix -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.SerDe +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.rdd.RDD /** * Model representing the result of matrix factorization. @@ -65,6 +65,13 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Java-friendly version of [[MatrixFactorizationModel.predict]]. + */ + def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { + predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() + } + + /** * Recommends products to a user. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 17c753c56681f..2067b36f246b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.regression +import scala.beans.BeanInfo + import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -27,6 +29,7 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. */ +@BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { "(%s,%s)".format(label, features)
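The @BeanInfo annotation just added to LabeledPoint is what lets the Java SQL API derive a schema from it by bean introspection, which the new Java suites below rely on through applySchema(..., LabeledPoint.class). A hedged illustration of the mechanism:

import java.beans.Introspector
import org.apache.spark.mllib.regression.LabeledPoint

// With @BeanInfo, java.beans.Introspector reports "label" and "features" as
// readable properties backed by the Scala accessors.
val names = Introspector.getBeanInfo(classOf[LabeledPoint])
  .getPropertyDescriptors.map(_.getName).toSet
assert(names.contains("label") && names.contains("features"))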
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index cb0d39e759a9f..f9791c6571782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -67,9 +67,9 @@ class LassoWithSGD private ( /** * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LassoModel(weights, intercept) @@ -161,6 +161,6 @@ object LassoWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a826deb695ee1..c8cad773f5efb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private ( /** * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new RidgeRegressionModel(weights, intercept) @@ -158,6 +158,6 @@ object RidgeRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fab7c4405c65d..654479ac2dd4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currMax: BDV[Double] = _ private var currMin: BDV[Double] = _ + /** + * Adds input value to position i. + */ + private[this] def add(i: Int, value: Double) = { + if (value != 0.0) { + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val prevMean = currMean(i) + val diff = value - prevMean + currMean(i) = prevMean + diff / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * diff + currM2(i) += value * value + currL1(i) += math.abs(value) + + nnz(i) += 1.0 + } + } + + /** + * Add a new sample to this summarizer, and update the statistical summary. + * + @@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample."
+ s" Expecting $n but got ${sample.size}.") - @inline def update(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } - sample match { case dv: DenseVector => { var j = 0 while (j < dv.size) { - update(j, dv.values(j)) + add(j, dv.values(j)) j += 1 } } case sv: SparseVector => var j = 0 while (j < sv.indices.size) { - update(sv.indices(j), sv.values(j)) + add(sv.indices(j), sv.values(j)) j += 1 } case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) @@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - val deltaMean: BDV[Double] = currMean - other.currMean var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - // merge m2 together - if (nnz(i) + other.nnz(i) != 0.0) { + val thisNnz = nnz(i) + val otherNnz = other.nnz(i) + val totalNnz = thisNnz + otherNnz + if (totalNnz != 0.0) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherNnz / totalNnz + // merge m2n together + currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz + // merge m2 together currM2(i) += other.currM2(i) - } - // merge l1 together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge l1 together currL1(i) += other.currL1(i) + // merge max and min + currMax(i) = math.max(currMax(i), other.currMax(i)) + currMin(i) = math.min(currMin(i), other.currMin(i)) } - - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } + nnz(i) = totalNnz i += 1 } - nnz += other.nnz } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n this.currMean = other.currMean.copy diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99ab26f9c..ac4d02ee3928b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99ab26f9c..ac4d02ee3928b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. + * + * @param features JavaRDD representing data points to be predicted + * @return JavaRDD of predictions for each of the given data points + */ + def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { + predict(features.rdd) + } + + /** * Get number of nodes in tree, including leaf nodes. */ diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java new file mode 100644 index 0000000000000..42846677ed285 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +/** + * Test Pipeline construction and fitting in Java.
+ */ +public class JavaPipelineSuite { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPipelineSuite"); + jsql = new JavaSQLContext(jsc); + JavaRDD<LabeledPoint> points = + jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); + dataset = jsql.applySchema(points, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void pipeline() { + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + LogisticRegression lr = new LogisticRegression() + .setFeaturesCol("scaledFeatures"); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {scaler, lr}); + PipelineModel model = pipeline.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000000..76eb7f00329f2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaLogisticRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionWithSetters() { + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold + .registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionFitWithVarargs() { + LogisticRegression lr = new LogisticRegression(); + lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java new file mode 100644 index 0000000000000..a266ebd2071a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.ml.tuning; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaCrossValidatorSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void crossValidationWithLogisticRegression() { + LogisticRegression lr = new LogisticRegression(); + ParamMap[] lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[] {0, 10}) + .build(); + BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(dataset); + ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); + Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); + Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index f6ca9643227f8..af688c504cf1e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -23,13 +23,14 @@ import scala.Tuple2; import scala.Tuple3; +import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; - import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -47,61 +48,48 @@ public void tearDown() { sc = null; } - static void validatePrediction( + void validatePrediction( MatrixFactorizationModel model, int users, int products, - int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List<Tuple2<Object, double[]>> userFeatures = model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2<Object, double[]> userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List<Tuple2<Object, double[]>> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2<Object, double[]> productFeature : productFeatures) {
- predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + List<Tuple2<Integer, Integer>> localUsersProducts = + Lists.newArrayListWithCapacity(users * products); + for (int u=0; u < users; ++u) { + for (int p=0; p < products; ++p) { + localUsersProducts.add(new Tuple2<Integer, Integer>(u, p)); } } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - + JavaPairRDD<Integer, Integer> usersProducts = sc.parallelizePairs(localUsersProducts); + List<Rating> predictedRatings = model.predict(usersProducts).collect(); + Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", - prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double correct = trueRatings.get(r.user(), r.product()); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); } } else { // For implicit prefs we use the confidence-weighted RMSE to test // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); - double err = confidence * (truePref - prediction) * (truePref - prediction); - sqErr += err; - denom += confidence; - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double truePref = truePrefs.get(r.user(), r.product()); + double confidence = 1.0 + + /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product())); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += confidence; } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), rmse < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -116,7 +104,7 @@ public void runALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -132,8 +120,8 @@ public void runALSUsingConstructor() { MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + .run(data); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -147,7 +135,7 @@ public void runImplicitALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -165,7 +153,7 @@ public void runImplicitALSUsingConstructor() {
.setIterations(iterations) .setImplicitPrefs(true) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -183,7 +171,7 @@ public void runImplicitALSWithNegativeWeight() { .setImplicitPrefs(true) .setSeed(8675309L) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala new file mode 100644 index 0000000000000..4515084bc7ae9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +class PipelineSuite extends FunSuite { + + abstract class MyModel extends Model[MyModel] + + test("pipeline") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val transformer3 = mock[Transformer] + val dataset0 = mock[SchemaRDD] + val dataset1 = mock[SchemaRDD] + val dataset2 = mock[SchemaRDD] + val dataset3 = mock[SchemaRDD] + val dataset4 = mock[SchemaRDD] + + when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) + when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) + when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages(0).eq(model0)) + assert(pipelineModel.stages(1).eq(transformer1)) + assert(pipelineModel.stages(2).eq(model2)) + assert(pipelineModel.stages(3).eq(transformer3)) + + assert(pipelineModel.getModel(estimator0).eq(model0)) + assert(pipelineModel.getModel(estimator2).eq(model2)) + 
intercept[NoSuchElementException] { + pipelineModel.getModel(mock[Estimator[MyModel]]) + } + val output = pipelineModel.transform(dataset0) + assert(output.eq(dataset4)) + } + + test("pipeline with duplicate stages") { + val estimator = mock[Estimator[MyModel]] + val pipeline = new Pipeline() + .setStages(Array(estimator, estimator)) + val dataset = mock[SchemaRDD] + intercept[IllegalArgumentException] { + pipeline.fit(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000000..e8030fef55b1d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("logistic regression") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + } + + test("logistic regression with setters") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + model.transform(dataset, model.threshold -> 0.8) // overwrite threshold + .select('label, 'score, 'prediction) + .collect() + } + + test("logistic regression fit and transform with varargs") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .collect() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala new file mode 100644 index 0000000000000..1ce2987612378 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import org.scalatest.FunSuite + +class ParamsSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param") { + assert(maxIter.name === "maxIter") + assert(maxIter.doc === "max number of iterations") + assert(maxIter.defaultValue.get === 100) + assert(maxIter.parent.eq(solver)) + assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") + assert(inputCol.defaultValue === None) + } + + test("param pair") { + val pair0 = maxIter -> 5 + val pair1 = maxIter.w(5) + val pair2 = ParamPair(maxIter, 5) + for (pair <- Seq(pair0, pair1, pair2)) { + assert(pair.param.eq(maxIter)) + assert(pair.value === 5) + } + } + + test("param map") { + val map0 = ParamMap.empty + + assert(!map0.contains(maxIter)) + assert(map0(maxIter) === maxIter.defaultValue.get) + map0.put(maxIter, 10) + assert(map0.contains(maxIter)) + assert(map0(maxIter) === 10) + + assert(!map0.contains(inputCol)) + intercept[NoSuchElementException] { + map0(inputCol) + } + map0.put(inputCol -> "input") + assert(map0.contains(inputCol)) + assert(map0(inputCol) === "input") + + val map1 = map0.copy + val map2 = ParamMap(maxIter -> 10, inputCol -> "input") + val map3 = new ParamMap() + .put(maxIter, 10) + .put(inputCol, "input") + val map4 = ParamMap.empty ++ map0 + val map5 = ParamMap.empty + map5 ++= map0 + + for (m <- Seq(map1, map2, map3, map4, map5)) { + assert(m.contains(maxIter)) + assert(m(maxIter) === 10) + assert(m.contains(inputCol)) + assert(m(inputCol) === "input") + } + } + + test("params") { + val params = solver.params + assert(params.size === 2) + assert(params(0).eq(inputCol), "params must be ordered by name") + assert(params(1).eq(maxIter)) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) + assert(solver.getParam("maxIter").eq(maxIter)) + intercept[NoSuchMethodException] { + solver.getParam("abc") + } + assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { + solver.validate() + } + solver.validate(ParamMap(inputCol -> "input")) + solver.setInputCol("input") + assert(solver.isSet(inputCol)) + assert(solver.getInputCol === "input") + solver.validate() + intercept[IllegalArgumentException] { + solver.validate(ParamMap(maxIter -> -10)) + } + solver.setMaxIter(-10) + intercept[IllegalArgumentException] { + solver.validate() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala new file mode 100644 index 0000000000000..1a65883d78a71 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +/** A subclass of Params for testing. */ +class TestParams extends Params { + + val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def getMaxIter: Int = get(maxIter) + + val inputCol = new Param[String](this, "inputCol", "input column name") + def setInputCol(value: String): this.type = { set(inputCol, value); this } + def getInputCol: String = get(inputCol) + + override def validate(paramMap: ParamMap) = { + val m = this.paramMap ++ paramMap + require(m(maxIter) >= 0) + require(m.contains(inputCol)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala new file mode 100644 index 0000000000000..41cc13da4d5b1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.scalatest.FunSuite + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { + + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("cross validation with logistic regression") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val bestParamMap = cvModel.bestModel.fittingParamMap + assert(bestParamMap(lr.regParam) === 0.001) + assert(bestParamMap(lr.maxIter) === 10) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala new file mode 100644 index 0000000000000..20aa100112bfe --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.ml.param.{ParamMap, TestParams} + +class ParamGridBuilderSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param grid builder") { + def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = { + assert(maps.size === expected.size) + maps.foreach { m => + val tuple = (m(maxIter), m(inputCol)) + assert(expected.contains(tuple)) + expected.remove(tuple) + } + assert(expected.isEmpty) + } + + val maps0 = new ParamGridBuilder() + .baseOn(maxIter -> 10) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected0 = mutable.Set( + (10, "input0"), + (10, "input1")) + validateGrid(maps0, expected0) + + val maps1 = new ParamGridBuilder() + .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten + .addGrid(maxIter, Array(10, 20)) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected1 = mutable.Set( + (10, "input0"), + (20, "input0"), + (10, "input1"), + (20, "input1")) + validateGrid(maps1, expected1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index e954baaf7d91e..4e812994405b3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -57,7 +57,7 @@ object LogisticRegressionSuite { } } -class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val testRDD = sc.parallelize(testData, 2) testRDD.cache() val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(20) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(20) val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) - assert(model.weights(0) ~== model.weights(0) relTol 0.01) - assert(model.intercept ~== model.intercept relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, 
B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Use half as many iterations as the previous test. val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(10) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(10) val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.01) - assert(model.intercept ~== 1.97 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.02) - assert(model.intercept ~== 1.97 relTol 0.02) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 80989bc074e84..e68fe89d6ccea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object NaiveBayesSuite { @@ -60,7 +60,7 @@ object NaiveBayesSuite { } } -class NaiveBayesSuite extends FunSuite with LocalSparkContext { +class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 65e5df58db4c7..a2de7fbd41383 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object SVMSuite { @@ -58,7 +58,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with LocalSparkContext { +class SVMSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index afa1f79b95a12..9ebef8466c831 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -22,10 +22,10 @@ import scala.util.Random import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class KMeansSuite extends FunSuite with LocalSparkContext { +class KMeansSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 994e0feb8629e..79847633ff0dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { +class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index a733f88b60b80..8a18e2971cab6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -19,44 +19,109 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { +class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { - def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 + private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def cond2(x: ((Double, Double), (Double, Double))): Boolean = + private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean = (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) + private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = { + assert(left.zip(right).forall(areWithinEpsilon)) + } + + private def assertTupleSequencesMatch(left: Seq[(Double, Double)], + right: Seq[(Double, Double)]): Unit = { + assert(left.zip(right).forall(pairsWithinEpsilon)) + } + + private def validateMetrics(metrics: BinaryClassificationMetrics, + expectedThresholds: Seq[Double], + expectedROCCurve: Seq[(Double, Double)], + expectedPRCurve: Seq[(Double, Double)], + expectedFMeasures1: Seq[Double], + expectedFmeasures2: Seq[Double], + expectedPrecisions: 
Seq[Double], + expectedRecalls: Seq[Double]) = { + + assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) + assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), + expectedThresholds.zip(expectedFMeasures1)) + assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), + expectedThresholds.zip(expectedFmeasures2)) + assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), + expectedThresholds.zip(expectedPrecisions)) + assertTupleSequencesMatch(metrics.recallByThreshold().collect(), + expectedThresholds.zip(expectedRecalls)) + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val thresholds = Seq(0.8, 0.6, 0.4, 0.1) val numTruePositives = Seq(1, 3, 3, 4) val numFalsePositives = Seq(0, 1, 2, 3) val numPositives = 4 val numNegatives = 3 - val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) => t.toDouble / (t + f) } - val recall = numTruePositives.map(t => t.toDouble / numPositives) + val recalls = numTruePositives.map(t => t.toDouble / numPositives) val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) - val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) - val pr = recall.zip(precision) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) - assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) - assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) - assert(metrics.pr().collect().zip(prCurve).forall(cond2)) - assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) - assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) - assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) - assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) - assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have positive label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(1.0) + val recalls = Seq(1.0) + val fpr = Seq(0.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, 
precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have negative label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(0.0) + val recalls = Seq(0.0) + val fpr = Seq(1.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 2.0 * (p * r) / (p + r) + } + val f2 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 5.0 * (p * r) / (4.0 * p + r) + } + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 1ea503971c864..7dc4f3cfbc4e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Matrices -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { +class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala index 342baa0274e9c..2537dd62c92f2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -class MultilabelMetricsSuite extends FunSuite with LocalSparkContext { +class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Multilabel evaluation metrics") { /* * Documents true labels (5x class0, 3x class1, 4x class2): diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index a2d4bb41484b8..609eed983ff4e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class RankingMetricsSuite extends FunSuite with LocalSparkContext { +class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { val predictionAndLabels = sc.parallelize( Seq( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 5396d7b2b74fa..670b4c34e6095 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class RegressionMetricsSuite extends FunSuite with LocalSparkContext { +class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { test("regression metrics") { val predictionAndObservations = sc.parallelize( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index a599e0d938569..0c4dfb7b97c7f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with LocalSparkContext { +class HashingTFSuite extends FunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 43974f84e3ca8..30147e7fd948f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with LocalSparkContext { +class IDFSuite extends FunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index 2bf9d9816ae45..85fdd271b5ed1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -22,10 +22,10 @@ import org.scalatest.FunSuite import breeze.linalg.{norm => brzNorm} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with LocalSparkContext { +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index e217b93cebbdb..4c93c0ca4f86c 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with LocalSparkContext { +class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index e34335d89eb75..52278690dbd89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class Word2VecSuite extends FunSuite with LocalSparkContext { +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 93a84fe07b32a..59cd85eab27d0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.linalg +import breeze.linalg.{DenseMatrix => BDM} import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -166,4 +167,10 @@ class VectorsSuite extends FunSuite { assert(v === udt.deserialize(udt.serialize(v))) } } + + test("fromBreeze") { + val x = BDM.zeros[Double](10, 10) + val v = Vectors.fromBreeze(x(::, 0)) + assert(v.size === x.rows) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index cd45438fb628f..f8709751efce6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with LocalSparkContext { +class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index f7c46f23b746d..e25bc02b06c9a 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -21,11 +21,11 @@ import org.scalatest.FunSuite import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext { +class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 63f3ed58c0d4d..dbf55ff81ca99 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -23,9 +23,9 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with LocalSparkContext { +class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index bf040110e228b..86481c6e66200 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -61,7 +61,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers { +class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ccba004baa007..70c64775e4c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { +class LBFGSSuite extends FunSuite with 
MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index c50b78bcbcc61..ea5889b3ecd5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 4ef67a40b9f49..681ce9263933b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.rdd import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with LocalSparkContext { +class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 017c39edb185f..603d0ad127b86 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.jblas.DoubleMatrix import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.recommendation.ALS.BlockStats object ALSSuite { @@ -85,7 +85,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with LocalSparkContext { +class ALSSuite extends FunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 7aa96421aed87..2668dcc14a842 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LassoSuite extends FunSuite with LocalSparkContext { +class LassoSuite extends FunSuite with 
MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 4f89112b650c5..864622a9296a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LinearRegressionSuite extends FunSuite with LocalSparkContext { +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 727bbd051ff15..18d3bf5ea4eca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -24,9 +24,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class RidgeRegressionSuite extends FunSuite with LocalSparkContext { +class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { predictions.zip(input).map { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index 34548c86ebc14..d20a09b4b4925 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -24,9 +24,9 @@ import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with LocalSparkContext { +class CorrelationSuite extends FunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 6de3840b3f198..15418e6035965 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.test.ChiSqTest -import org.apache.spark.mllib.util.LocalSparkContext 
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
-class HypothesisTestSuite extends FunSuite with LocalSparkContext {
+class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext {
 
   test("chi squared pearson goodness of fit") {
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 1e9415249104b..23b0eec865de6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
     assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
   }
+
+  test("merging summarizer when one side has zero mean (SPARK-4355)") {
+    val s0 = new MultivariateOnlineSummarizer()
+      .add(Vectors.dense(2.0))
+      .add(Vectors.dense(2.0))
+    val s1 = new MultivariateOnlineSummarizer()
+      .add(Vectors.dense(1.0))
+      .add(Vectors.dense(-1.0))
+    s0.merge(s1)
+    assert(s0.mean(0) ~== 1.0 absTol 1e-14)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index c579cb58549f5..972c905ec9ffa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,9 +30,9 @@ import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class DecisionTreeSuite extends FunSuite with LocalSparkContext {
+class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
 
   test("Binary classification with continuous features: split and bin calculation") {
     val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
index 99a02eda60baf..84de40103d8aa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -25,17 +25,17 @@ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
 import org.apache.spark.mllib.tree.impurity.Variance
 import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
 
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[GradientBoosting]].
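 *
 * A minimal sketch of the round trip these tests exercise; `trainRegressor` is assumed
 * here (only `weakHypotheses` appears verbatim in the assertions below), so treat the
 * wiring as illustrative rather than the suite's literal code:
 * {{{
 *   val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
 *   assert(gbt.weakHypotheses.size === numIterations)
 *   val gbtTree = gbt.weakHypotheses(0)  // comparable against a single DecisionTree fit
 * }}}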
 */
-class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
 
   test("Regression with continuous features: SquaredError") {
     GradientBoostingSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -53,7 +53,7 @@ class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
         assert(gbt.weakHypotheses.size === numIterations)
         val gbtTree = gbt.weakHypotheses(0)
 
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
 
         // Make sure trees are the same.
         assert(gbtTree.toString == dt.toString)
@@ -63,7 +63,7 @@ class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
   test("Regression with continuous features: Absolute Error") {
     GradientBoostingSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
@@ -81,7 +81,7 @@ class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
         assert(gbt.weakHypotheses.size === numIterations)
         val gbtTree = gbt.weakHypotheses(0)
 
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
 
         // Make sure trees are the same.
         assert(gbtTree.toString == dt.toString)
@@ -91,7 +91,7 @@ class GradientBoostingSuite extends FunSuite with MLlibTestSparkContext {
   test("Binary classification with continuous features: Log Loss") {
     GradientBoostingSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
         val rdd = sc.parallelize(arr)
         val categoricalFeaturesInfo = Map.empty[Int, Int]
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 73c4393c3581a..2734e089d62e6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -28,12 +28,12 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
 import org.apache.spark.mllib.tree.model.Node
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[RandomForest]].
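 *
 * A minimal usage sketch, assuming the Strategy-based overload implied by
 * `binaryClassificationTestWithContinuousFeatures(strategy: Strategy)` below;
 * parameter names are illustrative, not taken from this patch:
 * {{{
 *   val forest = RandomForest.trainClassifier(
 *     rdd, strategy, numTrees = 3, featureSubsetStrategy = "auto", seed = 42)
 * }}}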
 */
-class RandomForestSuite extends FunSuite with LocalSparkContext {
+class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
 
   def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
index 5cb433232e714..b184e936672ca 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree.impl
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.tree.EnsembleTestHelper
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
 /**
  * Test suite for [[BaggedPoint]].
  */
-class BaggedPointSuite extends FunSuite with LocalSparkContext {
+class BaggedPointSuite extends FunSuite with MLlibTestSparkContext {
 
   test("BaggedPoint RDD: without subsampling") {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 0dbe766b4d917..88bc49cc61f94 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.util.Utils
 
-class MLUtilsSuite extends FunSuite with LocalSparkContext {
+class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
 
   test("epsilon computation") {
     assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
similarity index 89%
rename from mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
rename to mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 7857d9e5ee5c4..b658889476d37 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -22,15 +22,15 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SparkConf, SparkContext}
 
-trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
+trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
   @transient var sc: SparkContext = _
 
   override def beforeAll() {
+    super.beforeAll()
     val conf = new SparkConf()
-      .setMaster("local")
-      .setAppName("test")
+      .setMaster("local[2]")
+      .setAppName("MLlibUnitTest")
     sc = new SparkContext(conf)
-    super.beforeAll()
   }
 
   override def afterAll() {
diff --git a/network/common/pom.xml b/network/common/pom.xml
index 8b24ebf1ba1f2..2bd0a7d2945dd 100644
--- a/network/common/pom.xml
+++ b/network/common/pom.xml
@@ -41,16 +41,16 @@
     <dependency>
       <groupId>io.netty</groupId>
       <artifactId>netty-all</artifactId>
     </dependency>
+
+
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
+      <scope>provided</scope>
     </dependency>
-
-
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
-      <version>11.0.2</version>
       <scope>provided</scope>
     </dependency>
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 397d3a8455c86..76bce8592816a 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -118,7 +118,8 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IOException
       .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
 
     // Use pooled buffers to reduce temporary buffer allocation
-    bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+    bootstrap.option(ChannelOption.ALLOCATOR, NettyUtils.createPooledByteBufAllocator(
+      conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()));
 
     final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
 
@@ -190,34 +191,4 @@ public void close() {
       workerGroup = null;
     }
   }
-
-  /**
-   * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
-   * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
-   * executor thread rather than the event loop thread. Those thread-local caches actually delay
-   * the recycling of buffers, leading to larger memory usage.
-   */
-  private PooledByteBufAllocator createPooledByteBufAllocator() {
-    return new PooledByteBufAllocator(
-      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
-      getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
-      getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
-      getPrivateStaticField("DEFAULT_PAGE_SIZE"),
-      getPrivateStaticField("DEFAULT_MAX_ORDER"),
-      0, // tinyCacheSize
-      0, // smallCacheSize
-      0  // normalCacheSize
-    );
-  }
-
-  /** Used to get defaults from Netty's private static fields. */
-  private int getPrivateStaticField(String name) {
-    try {
-      Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
-      f.setAccessible(true);
-      return f.getInt(null);
-    } catch (Exception e) {
-      throw new RuntimeException(e);
-    }
-  }
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
index 579676c2c3564..625c3257d764e 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java
@@ -72,8 +72,8 @@ private void init(int portToBind) {
       NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
     EventLoopGroup workerGroup = bossGroup;
 
-    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
-      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+    PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator(
+      conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads());
 
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index 2a7664fe89388..5c654a6fd6ebe 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -17,9 +17,11 @@
 
 package org.apache.spark.network.util;
 
+import java.lang.reflect.Field;
 import java.util.concurrent.ThreadFactory;
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.Channel;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.ServerChannel;
@@ -32,6 +34,7 @@ import io.netty.channel.socket.nio.NioSocketChannel;
 import io.netty.handler.codec.ByteToMessageDecoder;
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
+import io.netty.util.internal.PlatformDependent;
 
 /**
  * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
@@ -103,4 +106,40 @@ public static String getRemoteAddress(Channel channel) {
     }
     return "";
   }
+
+  /**
+   * Creates a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches
+   * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
+   * executor thread rather than the event loop thread. Those thread-local caches actually delay
+   * the recycling of buffers, leading to larger memory usage.
+   */
+  public static PooledByteBufAllocator createPooledByteBufAllocator(
+      boolean allowDirectBufs,
+      boolean allowCache,
+      int numCores) {
+    if (numCores == 0) {
+      numCores = Runtime.getRuntime().availableProcessors();
+    }
+    return new PooledByteBufAllocator(
+      allowDirectBufs && PlatformDependent.directBufferPreferred(),
+      Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores),
+      Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0),
+      getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+      getPrivateStaticField("DEFAULT_MAX_ORDER"),
+      allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0,
+      allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0,
+      allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0
+    );
+  }
+
+  /** Used to get defaults from Netty's private static fields. */
+  private static int getPrivateStaticField(String name) {
+    try {
+      Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+      f.setAccessible(true);
+      return f.getInt(null);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
 }
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index 27c8467687f10..12ff034cfe588 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -39,26 +39,26 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-network-common_2.10</artifactId>
+      <artifactId>spark-network-common_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
+
+
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>slf4j-api</artifactId>
+      <scope>provided</scope>
     </dependency>
-
-
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
-      <version>11.0.2</version>
       <scope>provided</scope>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-network-common_2.10</artifactId>
+      <artifactId>spark-network-common_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
index 60485bace643c..62fce9b0d16cd 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
@@ -23,6 +23,7 @@
 import io.netty.buffer.ByteBuf;
 
 import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
 
 /** Request to read a set of blocks. Returns {@link StreamHandle}.
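 *
 * A hedged usage sketch: the constructor mirrors this message's (appId, execId, blockIds)
 * fields, while {@code toByteArray()} is assumed from how these messages are shipped over
 * the RPC layer and may differ in name:
 * <pre>{@code
 *   OpenBlocks open = new OpenBlocks("app-1", "exec-7",
 *     new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });
 *   byte[] rpcPayload = open.toByteArray();  // decoded server-side via the Type tag above
 * }</pre>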
 */
public class OpenBlocks extends BlockTransferMessage {
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
index 38acae3b31d64..7eb4385044077 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java
@@ -21,6 +21,7 @@
 import io.netty.buffer.ByteBuf;
 
 import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
 
 /**
  * Initial registration message between an executor and its local shuffle server.
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
index 21369c8cfb0d6..bc9daa6158ba3 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java
@@ -17,11 +17,11 @@
 
 package org.apache.spark.network.shuffle.protocol;
 
-import java.io.Serializable;
-
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
 /**
  * Identifier for a fixed number of chunks to read from a stream created by an "open blocks"
  * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}.
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
index 38abe29cc585f..0b23e112bd512 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java
@@ -23,6 +23,8 @@
 import io.netty.buffer.ByteBuf;
 
 import org.apache.spark.network.protocol.Encoders;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
 /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array).
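 *
 * A hedged construction sketch (field order assumed: app and executor ids, block id,
 * serialized StorageLevel metadata, then the raw block bytes; {@code levelMeta} and
 * {@code blockBytes} are hypothetical placeholders):
 * <pre>{@code
 *   UploadBlock upload =
 *     new UploadBlock("app-1", "exec-7", "rdd_0_1", levelMeta, blockBytes);
 * }</pre>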
 */
public class UploadBlock extends BlockTransferMessage {
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index 6e6f6f3e79296..7845011ec3200 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -39,7 +39,7 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-network-shuffle_2.10</artifactId>
+      <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
     </dependency>
@@ -54,5 +54,38 @@
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <outputFile>${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar</outputFile>
+          <artifactSet>
+            <includes>
+              <include>*:*</include>
+            </includes>
+          </artifactSet>
+          <filters>
+            <filter>
+              <artifact>*:*</artifact>
+              <excludes>
+                <exclude>META-INF/*.SF</exclude>
+                <exclude>META-INF/*.DSA</exclude>
+                <exclude>META-INF/*.RSA</exclude>
+              </excludes>
+            </filter>
+          </filters>
+        </configuration>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>shade</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
   </build>
diff --git a/pom.xml b/pom.xml
index 4e0cd6c151d0b..cc7bce175778f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -97,30 +97,26 @@
     <module>sql/catalyst</module>
     <module>sql/core</module>
     <module>sql/hive</module>
-    <module>repl</module>
     <module>assembly</module>
     <module>external/twitter</module>
-    <module>external/kafka</module>
     <module>external/flume</module>
     <module>external/flume-sink</module>
-    <module>external/zeromq</module>
     <module>external/mqtt</module>
+    <module>external/zeromq</module>
     <module>examples</module>
+    <module>repl</module>
   </modules>

  <properties>
     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
-
+    <akka.group>org.spark-project.akka</akka.group>
+    <akka.version>2.3.4-spark</akka.version>
     <java.version>1.6</java.version>
     <sbt.project.name>spark</sbt.project.name>
-    <scala.version>2.10.4</scala.version>
-    <scala.binary.version>2.10</scala.binary.version>
     <scala.macros.version>2.0.1</scala.macros.version>
     <mesos.version>0.18.1</mesos.version>
     <mesos.classifier>shaded-protobuf</mesos.classifier>
-    <akka.group>org.spark-project.akka</akka.group>
-    <akka.version>2.3.4-spark</akka.version>
     <slf4j.version>1.7.5</slf4j.version>
     <log4j.version>1.2.17</log4j.version>
     <hadoop.version>1.0.4</hadoop.version>
@@ -137,7 +133,7 @@
     <parquet.version>1.6.0rc3</parquet.version>
     <jblas.version>1.2.3</jblas.version>
     <jetty.version>8.1.14.v20131031</jetty.version>
-    <chill.version>0.3.6</chill.version>
+    <chill.version>0.5.0</chill.version>
     <codahale.metrics.version>3.0.0</codahale.metrics.version>
     <avro.version>1.7.6</avro.version>
@@ -146,9 +142,13 @@
     <aws.kinesis.client.version>1.1.0</aws.kinesis.client.version>
     <commons.httpclient.version>4.2.6</commons.httpclient.version>
     <commons.math3.version>3.1.1</commons.math3.version>
-
+    <test_classpath_file>${project.build.directory}/spark-test-classpath.txt</test_classpath_file>
     <PermGen>64m</PermGen>
     <MaxPermGen>512m</MaxPermGen>
+    <scala.version>2.10.4</scala.version>
+    <scala.binary.version>2.10</scala.binary.version>
+    <jline.version>${scala.version}</jline.version>
+    <jline.groupid>org.scala-lang</jline.groupid>
+
@@ -267,19 +267,66 @@
+
+
-
     <dependency>
       <groupId>org.spark-project.spark</groupId>
       <artifactId>unused</artifactId>
       <version>1.0.0</version>
     </dependency>
+
+
+    <dependency>
+      <groupId>org.codehaus.groovy</groupId>
+      <artifactId>groovy-all</artifactId>
+      <version>2.3.7</version>
+      <scope>provided</scope>
+    </dependency>
+
+
+    <dependency>
+      <groupId>${jline.groupid}</groupId>
+      <artifactId>jline</artifactId>
+      <version>${jline.version}</version>
+    </dependency>
+
+
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill_${scala.binary.version}</artifactId>
+      <version>${chill.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill-java</artifactId>
+      <version>${chill.version}</version>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
     <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-util</artifactId>
@@ -366,7 +413,7 @@
     <dependency>
       <groupId>org.xerial.snappy</groupId>
       <artifactId>snappy-java</artifactId>
-      <version>1.1.1.3</version>
+      <version>1.1.1.6</version>
     </dependency>
     <dependency>
       <groupId>net.jpountz.lz4</groupId>
@@ -395,36 +442,6 @@
       <artifactId>protobuf-java</artifactId>
       <version>${protobuf.version}</version>
     </dependency>
-
-    <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill_${scala.binary.version}</artifactId>
-      <version>${chill.version}</version>
-      <exclusions>
-        <exclusion>
-          <groupId>org.ow2.asm</groupId>
-          <artifactId>asm</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.ow2.asm</groupId>
-          <artifactId>asm-commons</artifactId>
-        </exclusion>
-      </exclusions>
-    </dependency>
-    <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill-java</artifactId>
-      <version>${chill.version}</version>
-      <exclusions>
-        <exclusion>
-          <groupId>org.ow2.asm</groupId>
-          <artifactId>asm</artifactId>
-        </exclusion>
-        <exclusion>
-          <groupId>org.ow2.asm</groupId>
-          <artifactId>asm-commons</artifactId>
-        </exclusion>
-      </exclusions>
-    </dependency>
     <dependency>
       <groupId>${akka.group}</groupId>
       <artifactId>akka-actor_${scala.binary.version}</artifactId>
@@ -512,11 +529,6 @@
       <artifactId>scala-reflect</artifactId>
       <version>${scala.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>jline</artifactId>
-      <version>${scala.version}</version>
-    </dependency>
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-library</artifactId>
@@ -965,6 +977,8 @@
             <spark.test.home>${session.executionRootDirectory}</spark.test.home>
             <spark.testing>1</spark.testing>
             <spark.ui.enabled>false</spark.ui.enabled>
+            <spark.executor.extraClassPath>${test_classpath}</spark.executor.extraClassPath>
+            <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts>
@@ -1026,6 +1040,47 @@
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-dependency-plugin</artifactId>
+        <version>2.9</version>
+        <executions>
+          <execution>
+            <phase>test-compile</phase>
+            <goals>
+              <goal>build-classpath</goal>
+            </goals>
+            <configuration>
+              <includeScope>test</includeScope>
+              <outputFile>${test_classpath_file}</outputFile>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.codehaus.gmavenplus</groupId>
+        <artifactId>gmavenplus-plugin</artifactId>
+        <version>1.2</version>
+        <executions>
+          <execution>
+            <phase>process-test-classes</phase>
+            <goals>
+              <goal>execute</goal>
+            </goals>
+          </execution>
+        </executions>
+
+
+
+      </plugin>
+
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
@@ -1242,9 +1297,6 @@
     <profile>
       <id>mapr3</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
         <hadoop.version>1.0.3-mapr-3.0.3</hadoop.version>
         <yarn.version>2.3.0-mapr-4.0.0-FCS</yarn.version>
@@ -1255,9 +1307,6 @@
     <profile>
       <id>mapr4</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
         <hadoop.version>2.3.0-mapr-4.0.0-FCS</hadoop.version>
         <yarn.version>2.3.0-mapr-4.0.0-FCS</yarn.version>
@@ -1287,9 +1336,6 @@
     <profile>
       <id>hadoop-provided</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <dependencies>
         <dependency>
           <groupId>org.apache.hadoop</groupId>
@@ -1335,19 +1381,13 @@
     <profile>
-      <id>hive</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
+      <id>hive-thriftserver</id>
       <modules>
         <module>sql/hive-thriftserver</module>
       </modules>
     </profile>
     <profile>
       <id>hive-0.12.0</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
         <hive.version>0.12.0-protobuf-2.5</hive.version>
         <hive.version.short>0.12.0</hive.version.short>
@@ -1356,14 +1396,41 @@
     <profile>
       <id>hive-0.13.1</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
         <hive.version>0.13.1a</hive.version>
         <hive.version.short>0.13.1</hive.version.short>
         <derby.version>10.10.1.1</derby.version>
       </properties>
     </profile>
+
+    <profile>
+      <id>scala-2.10</id>
+      <activation>
+        <property><name>!scala-2.11</name></property>
+      </activation>
+      <properties>
+        <scala.version>2.10.4</scala.version>
+        <scala.binary.version>2.10</scala.binary.version>
+        <jline.version>${scala.version}</jline.version>
+        <jline.groupid>org.scala-lang</jline.groupid>
+      </properties>
+      <modules>
+        <module>external/kafka</module>
+      </modules>
+    </profile>
+
+    <profile>
+      <id>scala-2.11</id>
+      <activation>
+        <property><name>scala-2.11</name></property>
+      </activation>
+      <properties>
+        <scala.version>2.11.2</scala.version>
+        <scala.binary.version>2.11</scala.binary.version>
+        <jline.version>2.12</jline.version>
+        <jline.groupid>jline</jline.groupid>
+      </properties>
+    </profile>
+
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index a94d09be3bec6..8a2a865867fc4 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -85,6 +85,10 @@ object MimaExcludes {
             "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"),
           ProblemFilters.exclude[MissingTypesProblem](
             "org.apache.spark.rdd.PairRDDFunctions")
+        ) ++ Seq(
+          // SPARK-4062
+          ProblemFilters.exclude[MissingMethodProblem](
+            "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this")
         )
 
       case v if v.startsWith("1.1") =>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 657e4b4432775..1697b6d4f2d43 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -31,19 +31,19 @@ object BuildCommons {
   private val buildLocation = file(".").getAbsoluteFile.getParentFile
 
   val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
-    sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
-    streamingMqtt, streamingTwitter, streamingZeromq) =
+  sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
+  streamingMqtt, streamingTwitter, streamingZeromq) =
     Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
       "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
       "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
      "streaming-zeromq").map(ProjectRef(buildLocation, _))
 
-  val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, networkYarn, java8Tests,
-    sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", "network-yarn",
+  val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests,
+    sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha",
     "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _))
 
-  val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples")
-    .map(ProjectRef(buildLocation, _))
+  val assemblyProjects@Seq(assembly, examples, networkYarn) =
+    Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _))
 
   val tools = ProjectRef(buildLocation, "tools")
   // Root project.
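// A hedged aside on the profile-selection hunk below: it keys off a JVM system property,
// so with a stock sbt launcher that forwards -D flags to the JVM, the two builds would be
// invoked roughly as follows (illustrative command lines, not part of this patch):
//
//   sbt -Dscala-2.11 assembly   // adds the scala-2.11 profile
//   sbt assembly                // falls back to the default scala-2.10 profile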
@@ -68,8 +68,8 @@ object SparkBuild extends PomBuild {
     profiles ++= Seq("spark-ganglia-lgpl")
   }
   if (Properties.envOrNone("SPARK_HIVE").isDefined) {
-    println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.")
-    profiles ++= Seq("hive")
+    println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.")
+    profiles ++= Seq("hive", "hive-thriftserver")
   }
   Properties.envOrNone("SPARK_HADOOP_VERSION") match {
     case Some(v) =>
@@ -91,13 +91,24 @@ object SparkBuild extends PomBuild {
     profiles
   }
 
-  override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
+  override val profiles = {
+    val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
       case None => backwardCompatibility
       case Some(v) =>
         if (backwardCompatibility.nonEmpty)
           println("Note: We ignore environment variables, when use of profile is detected in " +
             "conjunction with environment variable.")
         v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
+    }
+
+    if (profiles.exists(_.contains("scala-"))) {
+      profiles
+    } else if (System.getProperty("scala-2.11") != null) {
+      profiles ++ Seq("scala-2.11")
+    } else {
+      println("Enabled default scala-2.10 profile")
+      profiles ++ Seq("scala-2.10")
+    }
   }
 
   Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
@@ -136,7 +147,8 @@ object SparkBuild extends PomBuild {
   // Note ordering of these settings matter.
   /* Enable shared settings on all projects */
-  (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings))
+  (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
+    .foreach(enable(sharedSettings ++ ExcludedDependencies.settings))
 
   /* Enable tests settings for all projects except examples, assembly and tools */
   (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -178,6 +190,16 @@ object Flume {
   lazy val settings = sbtavro.SbtAvro.avroSettings
 }
 
+/**
+ * This excludes library dependencies in sbt that are specified in Maven but are
+ * not needed by the sbt build.
+ */
+object ExcludedDependencies {
+  lazy val settings = Seq(
+    libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
+  )
+}
+
 /**
  * Following project only exists to pull previous artifacts of Spark for generating
 * Mima ignores.
@@ -270,8 +292,15 @@ object Assembly { lazy val settings = assemblySettings ++ Seq( test in assembly := {}, - jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" + - Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" }, + jarName in assembly <<= (version, moduleName) map { (v, mName) => + if (mName.contains("network-yarn")) { + // This must match the jar name used in Maven (see network/yarn/pom.xml) + "spark-" + v + "-yarn-shuffle.jar" + } else { + mName + "-" + v + "-hadoop" + + Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" + } + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -348,13 +377,17 @@ object TestSettings { javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", + javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, + // This places test scope jars on the classpath of executors during tests. + javaOptions in Test += + "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. + map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", - // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 3ef2d5451da0d..8863f272da415 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -26,7 +26,7 @@ import sbt.Keys._ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) - lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git") + lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") // There is actually no need to publish this artifact. def styleSettings = Defaults.defaultSettings ++ Seq ( diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index e39e6514d77a1..9556e4718e585 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -37,16 +37,6 @@ """ -# The following block allows us to import python's random instead of mllib.random for scripts in -# mllib that depend on top level pyspark packages, which transitively depend on python's random. -# Since Python's import logic looks for modules in the current package first, we eliminate -# mllib.random as a candidate for C{import random} by removing the first search path, the script's -# location, in order to force the loader to look in Python's top-level modules for C{random}.
-import sys -s = sys.path.pop(0) -import random -sys.path.insert(0, s) - from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD diff --git a/python/pyspark/context.py b/python/pyspark/context.py index faa5952258aef..b6c991453d4de 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -289,12 +289,29 @@ def stop(self): def parallelize(self, c, numSlices=None): """ - Distribute a local Python collection to form an RDD. + Distribute a local Python collection to form an RDD. Using xrange + is recommended if the input represents a range for performance. - >>> sc.parallelize(range(5), 5).glom().collect() - [[0], [1], [2], [3], [4]] + >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() + [[0], [2], [3], [4], [6]] + >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect() + [[], [0], [], [2], [4]] """ - numSlices = numSlices or self.defaultParallelism + numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism + if isinstance(c, xrange): + size = len(c) + if size == 0: + return self.parallelize([], numSlices) + step = c[1] - c[0] if size > 1 else 1 + start0 = c[0] + + def getStart(split): + return start0 + (split * size / numSlices) * step + + def f(split, iterator): + return xrange(getStart(split), getStart(split + 1), step) + + return self.parallelize([], numSlices).mapPartitionsWithIndex(f) # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 9c70fa5c16d0c..a975dc19cb78e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -45,7 +45,9 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + env = dict(os.environ) + env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows proc = Popen(command, stdout=PIPE, stdin=PIPE) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 4149f54931d1f..5030a655fcbba 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -24,3 +24,37 @@ import numpy if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") + +__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', + 'recommendation', 'regression', 'stat', 'tree', 'util'] + +import sys +import rand as random +random.__name__ = 'random' +random.RandomRDDs.__module__ = __name__ + '.random' + + +class RandomModuleHook(object): + """ + Hook to import pyspark.mllib.random + """ + fullname = __name__ + '.random' + + def find_module(self, name, path=None): + # skip all other modules + if not name.startswith(self.fullname): + return + return self + + def load_module(self, name): + if name == self.fullname: + return random + + cname = name.rsplit('.', 1)[-1] + try: + return getattr(random, cname) + except AttributeError: + raise ImportError + + +sys.meta_path.append(RandomModuleHook()) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 5d90dddb5df1c..b654813fb4cf6 100644 --- a/python/pyspark/mllib/classification.py +++ 
b/python/pyspark/mllib/classification.py @@ -76,7 +76,7 @@ class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False): """ Train a logistic regression model on the given data. @@ -87,16 +87,16 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater - - "none" for no regularizer + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -104,8 +104,9 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, are activated or not). """ def train(rdd, i): - return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), regType, + bool(intercept)) return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) @@ -145,8 +146,8 @@ def predict(self, x): class SVMWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False): + def train(cls, data, iterations=100, step=1.0, regParam=0.01, + miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): """ Train a support vector machine on the given data. @@ -154,7 +155,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, :param iterations: The number of iterations (default: 100). :param step: The step parameter used in SGD (default: 1.0). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.01). :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). @@ -162,11 +163,11 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, our model. :Allowed values: - - "l1" for using L1Updater - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization - (default: "none") + (default: "l2") @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -174,8 +175,9 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i, regType, intercept) + return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i, regType, + bool(intercept)) return _regression_train_wrapper(train, SVMModel, data, initialWeights) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 9ec28079aef43..8cb992df2d9c7 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -18,8 +18,11 @@ """ Python package for feature in MLlib. """ +from __future__ import absolute_import + import sys import warnings +import random from py4j.protocol import Py4JJavaError @@ -341,8 +344,6 @@ def __init__(self): """ Construct Word2Vec instance """ - import random # this can't be on the top because of mllib.random - self.vectorSize = 100 self.learningRate = 0.025 self.numPartitions = 1 @@ -411,8 +412,5 @@ def _test(): exit(-1) if __name__ == "__main__": - # remove current path from list of search paths to avoid importing mllib.random - # for C{import random}, which is done in an external dependency of pyspark during doctests. - import sys sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index e35202dca0acc..537b17657809c 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -614,8 +614,4 @@ def _test(): exit(-1) if __name__ == "__main__": - # remove current path from list of search paths to avoid importing mllib.random - # for C{import random}, which is done in an external dependency of pyspark during doctests. - import sys - sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/rand.py similarity index 100% rename from python/pyspark/mllib/random.py rename to python/pyspark/mllib/rand.py diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 66e25a48dfa71..f4f5e615fadc3 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -138,7 +138,7 @@ class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False): """ Train a linear regression model on the given data. @@ -149,16 +149,16 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, :param miniBatchFraction: Fraction of data to be used for each SGD iteration. :param initialWeights: The initial weights (default: None). - :param regParam: The regularizer parameter (default: 1.0). + :param regParam: The regularizer parameter (default: 0.0). :param regType: The type of regularizer used for training our model. :Allowed values: - - "l1" for using L1Updater, - - "l2" for using SquaredL2Updater, - - "none" for no regularizer. + - "l1" for using L1 regularization (lasso), + - "l2" for using L2 regularization (ridge), + - None for no regularization - (default: "none") + (default: None) @param intercept: Boolean parameter which indicates the use or not of the augmented representation for @@ -166,11 +166,11 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, are activated or not). 
""" def train(rdd, i): - return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step, - miniBatchFraction, i, regParam, regType, intercept) + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), + regType, bool(intercept)) - return _regression_train_wrapper(train, LinearRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -209,12 +209,13 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) @@ -254,15 +255,14 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" def train(rdd, i): - return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam, - miniBatchFraction, i) + return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) - return _regression_train_wrapper(train, RidgeRegressionModel, - data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) def _test(): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d1a3c0962796..ef0d556fac7bc 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -124,10 +124,13 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True + >>> model.predict(array([1.0])) + 1.0 + >>> model.predict(array([0.0])) + 0.0 + >>> rdd = sc.parallelize([[1.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @@ -170,14 +173,13 @@ def trainRegressor(data, categoricalFeaturesInfo, ... 
] >>> >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {1: 0.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) diff --git a/python/run-tests b/python/run-tests index a4f0cac059ff3..e66854b44dfa6 100755 --- a/python/run-tests +++ b/python/run-tests @@ -72,7 +72,7 @@ function run_mllib_tests() { run_test "pyspark/mllib/clustering.py" run_test "pyspark/mllib/feature.py" run_test "pyspark/mllib/linalg.py" - run_test "pyspark/mllib/random.py" + run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/stat.py" diff --git a/repl/pom.xml b/repl/pom.xml index af528c8914335..c2bf9fdfbcce7 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -35,9 +35,16 @@ repl /usr/share/spark root + scala-2.10/src/main/scala + scala-2.10/src/test/scala + + ${jline.groupid} + jline + ${jline.version} + org.apache.spark spark-core_${scala.binary.version} @@ -75,11 +82,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.slf4j jul-to-slf4j @@ -122,6 +124,51 @@ + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + ${extra.source.dir} + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + ${extra.testsource.dir} + + + + + + + + scala-2.11 + + scala-2.11 + + + scala-2.11/src/main/scala + scala-2.11/src/test/scala + + + diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/Main.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename to 
repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkImports.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala similarity index 100% rename from repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala rename to repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala new file mode 100644 index 0000000000000..5e93a71995072 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import org.apache.spark.util.Utils +import org.apache.spark._ + +import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.SparkILoop + +object Main extends Logging { + + val conf = new SparkConf() + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = conf.get("spark.repl.classdir", tmp) + val outputDir = Utils.createTempDir(rootDir) + val s = new Settings() + s.processArguments(List("-Yrepl-class-based", + "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true) + val classServer = new HttpServer(outputDir, new SecurityManager(conf)) + var sparkContext: SparkContext = _ + var interp = new SparkILoop // this is a public var because tests reset it. + + def main(args: Array[String]) { + if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + interp.process(s) // Repl starts and goes in loop of R.E.P.L + classServer.stop() + Option(sparkContext).map(_.stop) + } + + + def getAddedJars: Array[String] = { + val envJars = sys.env.get("ADD_JARS") + val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) } + val jars = propJars.orElse(envJars).getOrElse("") + Utils.resolveURIs(jars).split(",").filter(_.nonEmpty) + } + + def createSparkContext(): SparkContext = { + val execUri = System.getenv("SPARK_EXECUTOR_URI") + val jars = getAddedJars + val conf = new SparkConf() + .setMaster(getMaster) + .setAppName("Spark shell") + .setJars(jars) + .set("spark.repl.class.uri", classServer.uri) + logInfo("Spark class server started at " + classServer.uri) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) + } + if (System.getenv("SPARK_HOME") != null) { + conf.setSparkHome(System.getenv("SPARK_HOME")) + } + sparkContext = new SparkContext(conf) + logInfo("Created spark context..") + sparkContext + } + + private def getMaster: String = { + val master = { + val envMaster = sys.env.get("MASTER") + val propMaster = sys.props.get("spark.master") + propMaster.orElse(envMaster).getOrElse("local[*]") + } + master + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala new file mode 100644 index 0000000000000..8e519fa67f649 --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala @@ -0,0 +1,86 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Paul Phillips + */ + +package scala.tools.nsc +package interpreter + +import scala.tools.nsc.ast.parser.Tokens.EOF + +trait SparkExprTyper { + val repl: SparkIMain + + import repl._ + import global.{ reporter => _, Import => _, _ } + import naming.freshInternalVarName + + def symbolOfLine(code: String): Symbol = { + def asExpr(): Symbol = { + val name = freshInternalVarName() + // Typing it with a lazy val would give us the right type, but runs + // into compiler bugs with things 
like existentials, so we compile it + // behind a def and strip the NullaryMethodType which wraps the expr. + val line = "def " + name + " = " + code + + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + // drop NullaryMethodType + sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType) + case _ => NoSymbol + } + } + def asDefn(): Symbol = { + val old = repl.definedSymbolList.toSet + + interpretSynthetic(code) match { + case IR.Success => + repl.definedSymbolList filterNot old match { + case Nil => NoSymbol + case sym :: Nil => sym + case syms => NoSymbol.newOverloaded(NoPrefix, syms) + } + case _ => NoSymbol + } + } + def asError(): Symbol = { + interpretSynthetic(code) + NoSymbol + } + beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError() + } + + private var typeOfExpressionDepth = 0 + def typeOfExpression(expr: String, silent: Boolean = true): Type = { + if (typeOfExpressionDepth > 2) { + repldbg("Terminating typeOfExpression recursion for expression: " + expr) + return NoType + } + typeOfExpressionDepth += 1 + // Don't presently have a good way to suppress undesirable success output + // while letting errors through, so it is first trying it silently: if there + // is an error, and errors are desired, then it re-evaluates non-silently + // to induce the error message. + try beSilentDuring(symbolOfLine(expr).tpe) match { + case NoType if !silent => symbolOfLine(expr).tpe // generate error + case tpe => tpe + } + finally typeOfExpressionDepth -= 1 + } + + // This only works for proper types. + def typeOfTypeString(typeString: String): Type = { + def asProperType(): Option[Type] = { + val name = freshInternalVarName() + val line = "def %s: %s = ???" format (name, typeString) + interpretSynthetic(line) match { + case IR.Success => + val sym0 = symbolOfTerm(name) + Some(sym0.asMethod.returnType) + case _ => None + } + } + beSilentDuring(asProperType()) getOrElse NoType + } +} diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala new file mode 100644 index 0000000000000..a591e9fc4622b --- /dev/null +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -0,0 +1,966 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Alexander Spoon + */ + +package scala +package tools.nsc +package interpreter + +import scala.language.{ implicitConversions, existentials } +import scala.annotation.tailrec +import Predef.{ println => _, _ } +import interpreter.session._ +import StdReplTags._ +import scala.reflect.api.{Mirror, Universe, TypeCreator} +import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName } +import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream } +import scala.reflect.{ClassTag, classTag} +import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader } +import ScalaClassLoader._ +import scala.reflect.io.{ File, Directory } +import scala.tools.util._ +import scala.collection.generic.Clearable +import scala.concurrent.{ ExecutionContext, Await, Future, future } +import ExecutionContext.Implicits._ +import java.io.{ BufferedReader, FileReader } + +/** The Scala interactive shell. It provides a read-eval-print loop + * around the Interpreter class. + * After instantiation, clients should call the main() method. 
+ * + * If no in0 is specified, then input will come from the console, and + * the class will attempt to provide input editing feature such as + * input history. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + * @version 1.2 + */ +class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter) + extends AnyRef + with LoopCommands +{ + def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out) + def this() = this(None, new JPrintWriter(Console.out, true)) +// +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp +// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i + + var in: InteractiveReader = _ // the input stream from which commands come + var settings: Settings = _ + var intp: SparkIMain = _ + + var globalFuture: Future[Boolean] = _ + + protected def asyncMessage(msg: String) { + if (isReplInfo || isReplPower) + echoAndRefresh(msg) + } + + def initializeSpark() { + intp.beQuietDuring { + command( """ + @transient val sc = org.apache.spark.repl.Main.createSparkContext(); + """) + command("import org.apache.spark.SparkContext._") + } + echo("Spark context available as sc.") + } + + /** Print a welcome message */ + def printWelcome() { + import org.apache.spark.SPARK_VERSION + echo("""Welcome to + ____ __ + / __/__ ___ _____/ /__ + _\ \/ _ \/ _ `/ __/ '_/ + /___/ .__/\_,_/_/ /_/\_\ version %s + /_/ + """.format(SPARK_VERSION)) + val welcomeMsg = "Using Scala %s (%s, Java %s)".format( + versionString, javaVmName, javaVersion) + echo(welcomeMsg) + echo("Type in expressions to have them evaluated.") + echo("Type :help for more information.") + } + + override def echoCommandMessage(msg: String) { + intp.reporter printUntruncatedMessage msg + } + + // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals]) + def history = in.history + + // classpath entries added via :cp + var addedClasspath: String = "" + + /** A reverse list of commands to replay if the user requests a :replay */ + var replayCommandStack: List[String] = Nil + + /** A list of commands to replay if the user requests a :replay */ + def replayCommands = replayCommandStack.reverse + + /** Record a command for replay should the user request a :replay */ + def addReplay(cmd: String) = replayCommandStack ::= cmd + + def savingReplayStack[T](body: => T): T = { + val saved = replayCommandStack + try body + finally replayCommandStack = saved + } + def savingReader[T](body: => T): T = { + val saved = in + try body + finally in = saved + } + + /** Close the interpreter and set the var to null. */ + def closeInterpreter() { + if (intp ne null) { + intp.close() + intp = null + } + } + + class SparkILoopInterpreter extends SparkIMain(settings, out) { + outer => + + override lazy val formatting = new Formatting { + def prompt = SparkILoop.this.prompt + } + override protected def parentClassLoader = + settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader ) + } + + /** Create a new interpreter. 
*/ + def createInterpreter() { + if (addedClasspath != "") + settings.classpath append addedClasspath + + intp = new SparkILoopInterpreter + } + + /** print a friendly help message */ + def helpCommand(line: String): Result = { + if (line == "") helpSummary() + else uniqueCommand(line) match { + case Some(lc) => echo("\n" + lc.help) + case _ => ambiguousError(line) + } + } + private def helpSummary() = { + val usageWidth = commands map (_.usageMsg.length) max + val formatStr = "%-" + usageWidth + "s %s" + + echo("All commands can be abbreviated, e.g. :he instead of :help.") + + commands foreach { cmd => + echo(formatStr.format(cmd.usageMsg, cmd.help)) + } + } + private def ambiguousError(cmd: String): Result = { + matchingCommands(cmd) match { + case Nil => echo(cmd + ": no such command. Type :help for help.") + case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?") + } + Result(keepRunning = true, None) + } + private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd) + private def uniqueCommand(cmd: String): Option[LoopCommand] = { + // this lets us add commands willy-nilly and only requires enough command to disambiguate + matchingCommands(cmd) match { + case List(x) => Some(x) + // exact match OK even if otherwise appears ambiguous + case xs => xs find (_.name == cmd) + } + } + + /** Show the history */ + lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") { + override def usage = "[num]" + def defaultLines = 20 + + def apply(line: String): Result = { + if (history eq NoHistory) + return "No history available." + + val xs = words(line) + val current = history.index + val count = try xs.head.toInt catch { case _: Exception => defaultLines } + val lines = history.asStrings takeRight count + val offset = current - lines.size + 1 + + for ((line, index) <- lines.zipWithIndex) + echo("%3d %s".format(index + offset, line)) + } + } + + // When you know you are most likely breaking into the middle + // of a line being typed. This softens the blow. 
+ protected def echoAndRefresh(msg: String) = { + echo("\n" + msg) + in.redrawLine() + } + protected def echo(msg: String) = { + out println msg + out.flush() + } + + /** Search the history */ + def searchHistory(_cmdline: String) { + val cmdline = _cmdline.toLowerCase + val offset = history.index - history.size + 1 + + for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline) + echo("%d %s".format(index + offset, line)) + } + + private val currentPrompt = Properties.shellPromptString + + /** Prompt to print when awaiting input */ + def prompt = currentPrompt + + import LoopCommand.{ cmd, nullary } + + /** Standard commands **/ + lazy val standardCommands = List( + cmd("cp", "", "add a jar or directory to the classpath", addClasspath), + cmd("edit", "|", "edit history", editCommand), + cmd("help", "[command]", "print this summary or command-specific help", helpCommand), + historyCommand, + cmd("h?", "", "search the history", searchHistory), + cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand), + //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand), + cmd("javap", "", "disassemble a file or class name", javapCommand), + cmd("line", "|", "place line(s) at the end of history", lineCommand), + cmd("load", "", "interpret lines in a file", loadCommand), + cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand), + // nullary("power", "enable power user mode", powerCmd), + nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)), + nullary("replay", "reset execution and replay all previous commands", replay), + nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand), + cmd("save", "", "save replayable session to a file", saveCommand), + shCommand, + cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings), + nullary("silent", "disable/enable automatic printing of results", verbosity), +// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand), +// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand), + nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand) + ) + + /** Power user commands */ +// lazy val powerCommands: List[LoopCommand] = List( +// cmd("phase", "", "set the implicit phase for power commands", phaseCommand) +// ) + + private def importsCommand(line: String): Result = { + val tokens = words(line) + val handlers = intp.languageWildcardHandlers ++ intp.importHandlers + + handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach { + case (handler, idx) => + val (types, terms) = handler.importedSymbols partition (_.name.isTypeName) + val imps = handler.implicitSymbols + val found = tokens filter (handler importsSymbolNamed _) + val typeMsg = if (types.isEmpty) "" else types.size + " types" + val termMsg = if (terms.isEmpty) "" else terms.size + " terms" + val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit" + val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "") + val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")") + + intp.reporter.printMessage("%2d) %-30s %s%s".format( + idx + 1, + handler.importString, + statsMsg, + foundMsg + )) + } + } + + private def findToolsJar() = PathResolver.SupplementalLocations.platformTools + + private def 
addToolsJarToLoader() = { + val cl = findToolsJar() match { + case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader) + case _ => intp.classLoader + } + if (Javap.isAvailable(cl)) { + repldbg(":javap available.") + cl + } + else { + repldbg(":javap unavailable: no tools.jar at " + jdkHome) + intp.classLoader + } + } +// +// protected def newJavap() = +// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp)) +// +// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap()) + + // Still todo: modules. +// private def typeCommand(line0: String): Result = { +// line0.trim match { +// case "" => ":type [-v] " +// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + +// private def kindCommand(expr: String): Result = { +// expr.trim match { +// case "" => ":kind [-v] " +// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ") +// } +// } + + private def warningsCommand(): Result = { + if (intp.lastWarnings.isEmpty) + "Can't find any cached warnings." + else + intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) } + } + + private def changeSettings(args: String): Result = { + def showSettings() = { + for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString) + } + def updateSettings() = { + // put aside +flag options + val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+")) + val tmps = new Settings + val (ok, leftover) = tmps.processArguments(rest, processAll = true) + if (!ok) echo("Bad settings request.") + else if (leftover.nonEmpty) echo("Unprocessed settings.") + else { + // boolean flags set-by-user on tmp copy should be off, not on + val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting]) + val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg)) + // update non-flags + settings.processArguments(nonbools, processAll = true) + // also snag multi-value options for clearing, e.g. -Ylog: and -language: + for { + s <- settings.userSetSettings + if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting] + if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init)) + } s match { + case c: Clearable => c.clear() + case _ => + } + def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = { + for (b <- bs) + settings.lookupSetting(name(b)) match { + case Some(s) => + if (s.isInstanceOf[Settings#BooleanSetting]) setter(s) + else echo(s"Not a boolean flag: $b") + case _ => + echo(s"Not an option: $b") + } + } + update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off + update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on + } + } + if (args.isEmpty) showSettings() else updateSettings() + } + + private def javapCommand(line: String): Result = { +// if (javap == null) +// ":javap unavailable, no tools.jar at %s. 
Set JDK_HOME.".format(jdkHome) +// else if (line == "") +// ":javap [-lcsvp] [path1 path2 ...]" +// else +// javap(words(line)) foreach { res => +// if (res.isError) return "Failed: " + res.value +// else res.show() +// } + } + + private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent" + + private def phaseCommand(name: String): Result = { +// val phased: Phased = power.phased +// import phased.NoPhaseName +// +// if (name == "clear") { +// phased.set(NoPhaseName) +// intp.clearExecutionWrapper() +// "Cleared active phase." +// } +// else if (name == "") phased.get match { +// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)" +// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get) +// } +// else { +// val what = phased.parse(name) +// if (what.isEmpty || !phased.set(what)) +// "'" + name + "' does not appear to represent a valid phase." +// else { +// intp.setExecutionWrapper(pathToPhaseWrapper) +// val activeMessage = +// if (what.toString.length == name.length) "" + what +// else "%s (%s)".format(what, name) +// +// "Active phase is now: " + activeMessage +// } +// } + } + + /** Available commands */ + def commands: List[LoopCommand] = standardCommands ++ ( + // if (isReplPower) + // powerCommands + // else + Nil + ) + + val replayQuestionMessage = + """|That entry seems to have slain the compiler. Shall I replay + |your session? I can re-run each line except the last one. + |[y/n] + """.trim.stripMargin + + private val crashRecovery: PartialFunction[Throwable, Boolean] = { + case ex: Throwable => + val (err, explain) = ( + if (intp.isInitializeComplete) + (intp.global.throwableAsString(ex), "") + else + (ex.getMessage, "The compiler did not initialize.\n") + ) + echo(err) + + ex match { + case _: NoSuchMethodError | _: NoClassDefFoundError => + echo("\nUnrecoverable error.") + throw ex + case _ => + def fn(): Boolean = + try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() }) + catch { case _: RuntimeException => false } + + if (fn()) replay() + else echo("\nAbandoning crashed session.") + } + true + } + + // return false if repl should exit + def processLine(line: String): Boolean = { + import scala.concurrent.duration._ + Await.ready(globalFuture, 60.seconds) + + (line ne null) && (command(line) match { + case Result(false, _) => false + case Result(_, Some(line)) => addReplay(line) ; true + case _ => true + }) + } + + private def readOneLine() = { + out.flush() + in readLine prompt + } + + /** The main read-eval-print loop for the repl. It calls + * command() for each line of input, and stops when + * command() returns false. 
+ */ + @tailrec final def loop() { + if ( try processLine(readOneLine()) catch crashRecovery ) + loop() + } + + /** interpret all lines from a specified file */ + def interpretAllFrom(file: File) { + savingReader { + savingReplayStack { + file applyReader { reader => + in = SimpleReader(reader, out, interactive = false) + echo("Loading " + file + "...") + loop() + } + } + } + } + + /** create a new interpreter and replay the given commands */ + def replay() { + reset() + if (replayCommandStack.isEmpty) + echo("Nothing to replay.") + else for (cmd <- replayCommands) { + echo("Replaying: " + cmd) // flush because maybe cmd will have its own output + command(cmd) + echo("") + } + } + def resetCommand() { + echo("Resetting interpreter state.") + if (replayCommandStack.nonEmpty) { + echo("Forgetting this session history:\n") + replayCommands foreach echo + echo("") + replayCommandStack = Nil + } + if (intp.namedDefinedTerms.nonEmpty) + echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", ")) + if (intp.definedTypes.nonEmpty) + echo("Forgetting defined types: " + intp.definedTypes.mkString(", ")) + + reset() + } + def reset() { + intp.reset() + unleashAndSetPhase() + } + + def lineCommand(what: String): Result = editCommand(what, None) + + // :edit id or :edit line + def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR")) + + def editCommand(what: String, editor: Option[String]): Result = { + def diagnose(code: String) = { + echo("The edited code is incomplete!\n") + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("The compiler reports no errors.") + } + def historicize(text: String) = history match { + case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true + case _ => false + } + def edit(text: String): Result = editor match { + case Some(ed) => + val tmp = File.makeTemp() + tmp.writeAll(text) + try { + val pr = new ProcessResult(s"$ed ${tmp.path}") + pr.exitCode match { + case 0 => + tmp.safeSlurp() match { + case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.") + case Some(edited) => + echo(edited.lines map ("+" + _) mkString "\n") + val res = intp interpret edited + if (res == IR.Incomplete) diagnose(edited) + else { + historicize(edited) + Result(lineToRecord = Some(edited), keepRunning = true) + } + case None => echo("Can't read edited text. 
Did you delete it?") + } + case x => echo(s"Error exit from $ed ($x), ignoring") + } + } finally { + tmp.delete() + } + case None => + if (historicize(text)) echo("Placing text in recent history.") + else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text") + } + + // if what is a number, use it as a line number or range in history + def isNum = what forall (c => c.isDigit || c == '-' || c == '+') + // except that "-" means last value + def isLast = (what == "-") + if (isLast || !isNum) { + val name = if (isLast) intp.mostRecentVar else what + val sym = intp.symbolOfIdent(name) + intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match { + case Some(req) => edit(req.line) + case None => echo(s"No symbol in scope: $what") + } + } else try { + val s = what + // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur) + val (start, len) = + if ((s indexOf '+') > 0) { + val (a,b) = s splitAt (s indexOf '+') + (a.toInt, b.drop(1).toInt) + } else { + (s indexOf '-') match { + case -1 => (s.toInt, 1) + case 0 => val n = s.drop(1).toInt ; (history.index - n, n) + case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n) + case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n) + } + } + import scala.collection.JavaConverters._ + val index = (start - 1) max 0 + val text = history match { + case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n" + case _ => history.asStrings.slice(index, index + len) mkString "\n" + } + edit(text) + } catch { + case _: NumberFormatException => echo(s"Bad range '$what'") + echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)") + } + } + + /** fork a shell and run a command */ + lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") { + override def usage = "" + def apply(line: String): Result = line match { + case "" => showUsage() + case _ => + val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})" + intp interpret toRun + () + } + } + + def withFile[A](filename: String)(action: File => A): Option[A] = { + val res = Some(File(filename)) filter (_.exists) map action + if (res.isEmpty) echo("That file does not exist") // courtesy side-effect + res + } + + def loadCommand(arg: String) = { + var shouldReplay: Option[String] = None + withFile(arg)(f => { + interpretAllFrom(f) + shouldReplay = Some(":load " + arg) + }) + Result(keepRunning = true, shouldReplay) + } + + def saveCommand(filename: String): Result = ( + if (filename.isEmpty) echo("File name is required.") + else if (replayCommandStack.isEmpty) echo("No replay commands in session") + else File(filename).printlnAll(replayCommands: _*) + ) + + def addClasspath(arg: String): Unit = { + val f = File(arg).normalize + if (f.exists) { + addedClasspath = ClassPath.join(addedClasspath, f.path) + val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath) + echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath)) + replay() + } + else echo("The path '" + f + "' doesn't seem to exist.") + } + + def powerCmd(): Result = { + if (isReplPower) "Already in power mode." 
+ else enablePowerMode(isDuringInit = false) + } + def enablePowerMode(isDuringInit: Boolean) = { + replProps.power setValue true + unleashAndSetPhase() + // asyncEcho(isDuringInit, power.banner) + } + private def unleashAndSetPhase() { + if (isReplPower) { + // power.unleash() + // Set the phase to "typer" + // intp beSilentDuring phaseCommand("typer") + } + } + + def asyncEcho(async: Boolean, msg: => String) { + if (async) asyncMessage(msg) + else echo(msg) + } + + def verbosity() = { + val old = intp.printResults + intp.printResults = !old + echo("Switched " + (if (old) "off" else "on") + " result printing.") + } + + /** Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. */ + def command(line: String): Result = { + if (line startsWith ":") { + val cmd = line.tail takeWhile (x => !x.isWhitespace) + uniqueCommand(cmd) match { + case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace)) + case _ => ambiguousError(cmd) + } + } + else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler + else Result(keepRunning = true, interpretStartingWith(line)) + } + + private def readWhile(cond: String => Boolean) = { + Iterator continually in.readLine("") takeWhile (x => x != null && cond(x)) + } + + def pasteCommand(arg: String): Result = { + var shouldReplay: Option[String] = None + def result = Result(keepRunning = true, shouldReplay) + val (raw, file) = + if (arg.isEmpty) (false, None) + else { + val r = """(-raw)?(\s+)?([^\-]\S*)?""".r + arg match { + case r(flag, sep, name) => + if (flag != null && name != null && sep == null) + echo(s"""I assume you mean "$flag $name"?""") + (flag != null, Option(name)) + case _ => + echo("usage: :paste -raw file") + return result + } + } + val code = file match { + case Some(name) => + withFile(name)(f => { + shouldReplay = Some(s":paste $arg") + val s = f.slurp.trim + if (s.isEmpty) echo(s"File contains no code: $f") + else echo(s"Pasting file $f...") + s + }) getOrElse "" + case None => + echo("// Entering paste mode (ctrl-D to finish)\n") + val text = (readWhile(_ => true) mkString "\n").trim + if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n") + else echo("\n// Exiting paste mode, now interpreting.\n") + text + } + def interpretCode() = { + val res = intp interpret code + // if input is incomplete, let the compiler try to say why + if (res == IR.Incomplete) { + echo("The pasted code is incomplete!\n") + // Remembrance of Things Pasted in an object + val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}") + if (errless) echo("...but compilation found no error? Good luck with that.") + } + } + def compileCode() = { + val errless = intp compileSources new BatchSourceFile("", code) + if (!errless) echo("There were compilation errors!") + } + if (code.nonEmpty) { + if (raw) compileCode() else interpretCode() + } + result + } + + private object paste extends Pasted { + val ContinueString = " | " + val PromptString = "scala> " + + def interpret(line: String): Unit = { + echo(line.trim) + intp interpret line + echo("") + } + + def transcript(start: String) = { + echo("\n// Detected repl transcript paste: ctrl-D to finish.\n") + apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim)) + } + } + import paste.{ ContinueString, PromptString } + + /** Interpret expressions starting with the first line. 
+ * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ + def interpretStartingWith(code: String): Option[String] = { + // signal completion non-completion input has been received + in.completion.resetVerbosity() + + def reallyInterpret = { + val reallyResult = intp.interpret(code) + (reallyResult, reallyResult match { + case IR.Error => None + case IR.Success => Some(code) + case IR.Incomplete => + if (in.interactive && code.endsWith("\n\n")) { + echo("You typed two blank lines. Starting a new command.") + None + } + else in.readLine(ContinueString) match { + case null => + // we know compilation is going to fail since we're at EOF and the + // parser thinks the input is still incomplete, but since this is + // a file being read non-interactively we want to fail. So we send + // it straight to the compiler for the nice error message. + intp.compileString(code) + None + + case line => interpretStartingWith(code + "\n" + line) + } + }) + } + + /** Here we place ourselves between the user and the interpreter and examine + * the input they are ostensibly submitting. We intervene in several cases: + * + * 1) If the line starts with "scala> " it is assumed to be an interpreter paste. + * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation + * on the previous result. + * 3) If the Completion object's execute returns Some(_), we inject that value + * and avoid the interpreter, as it's likely not valid scala code. + */ + if (code == "") None + else if (!paste.running && code.trim.startsWith(PromptString)) { + paste.transcript(code) + None + } + else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") { + interpretStartingWith(intp.mostRecentVar + code) + } + else if (code.trim startsWith "//") { + // line comment, do nothing + None + } + else + reallyInterpret._2 + } + + // runs :load `file` on any files passed via -i + def loadFiles(settings: Settings) = settings match { + case settings: GenericRunnerSettings => + for (filename <- settings.loadfiles.value) { + val cmd = ":load " + filename + command(cmd) + addReplay(cmd) + echo("") + } + case _ => + } + + /** Tries to create a JLineReader, falling back to SimpleReader: + * unless settings or properties are such that it should start + * with SimpleReader. + */ + def chooseReader(settings: Settings): InteractiveReader = { + if (settings.Xnojline || Properties.isEmacsShell) + SimpleReader() + else try new JLineReader( + if (settings.noCompletion) NoCompletion + else new SparkJLineCompletion(intp) + ) + catch { + case ex @ (_: Exception | _: NoClassDefFoundError) => + echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.") + SimpleReader() + } + } + protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = + u.TypeTag[T]( + m, + new TypeCreator { + def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type = + m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type] + }) + + private def loopPostInit() { + // Bind intp somewhere out of the regular namespace where + // we can get at it in generated code. + intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain])) + // Auto-run code via some setting. 
+    ( replProps.replAutorunCode.option
+        flatMap (f => io.File(f).safeSlurp())
+        foreach (intp quietRun _)
+    )
+    // classloader and power mode setup
+    intp.setContextClassLoader()
+    if (isReplPower) {
+      // replProps.power setValue true
+      // unleashAndSetPhase()
+      // asyncMessage(power.banner)
+    }
+    // SI-7418 Now, and only now, can we enable TAB completion.
+    in match {
+      case x: JLineReader => x.consoleReader.postInit
+      case _ =>
+    }
+  }
+  def process(settings: Settings): Boolean = savingContextLoader {
+    this.settings = settings
+    createInterpreter()
+
+    // sets in to some kind of reader depending on environmental cues
+    in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
+    globalFuture = future {
+      intp.initializeSynchronous()
+      loopPostInit()
+      !intp.reporter.hasErrors
+    }
+    import scala.concurrent.duration._
+    Await.ready(globalFuture, 10 seconds)
+    printWelcome()
+    initializeSpark()
+    loadFiles(settings)
+
+    try loop()
+    catch AbstractOrMissingHandler()
+    finally closeInterpreter()
+
+    true
+  }
+
+  @deprecated("Use `process` instead", "2.9.0")
+  def main(settings: Settings): Unit = process(settings) // used by sbt
+}
+
+object SparkILoop {
+  implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
+
+  // Designed primarily for use by test code: takes a String with a
+  // bunch of code and prints out a transcript of what it would look
+  // like if you'd just typed it into the repl.
+  def runForTranscript(code: String, settings: Settings): String = {
+    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+    stringFromStream { ostream =>
+      Console.withOut(ostream) {
+        val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
+          override def write(str: String) = {
+            // completely skip continuation lines
+            if (str forall (ch => ch.isWhitespace || ch == '|')) ()
+            else super.write(str)
+          }
+        }
+        val input = new BufferedReader(new StringReader(code.trim + "\n")) {
+          override def readLine(): String = {
+            val s = super.readLine()
+            // helping out by printing the line being interpreted.
+            if (s != null)
+              output.println(s)
+            s
+          }
+        }
+        val repl = new SparkILoop(input, output)
+        if (settings.classpath.isDefault)
+          settings.classpath.value = sys.props("java.class.path")
+
+        repl process settings
+      }
+    }
+  }
+
+  /** Creates an interpreter loop with default settings and feeds
+   *  the given code to it as input.
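+   *
+   *  For example (illustrative): `run("1 + 1")` returns the repl's captured
+   *  console output for that input, including the usual result line
+   *  (e.g. `res0: Int = 2`).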
+   */
+  def run(code: String, sets: Settings = new Settings): String = {
+    import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+    stringFromStream { ostream =>
+      Console.withOut(ostream) {
+        val input = new BufferedReader(new StringReader(code))
+        val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
+        val repl = new SparkILoop(input, output)
+
+        if (sets.classpath.isDefault)
+          sets.classpath.value = sys.props("java.class.path")
+
+        repl process sets
+      }
+    }
+  }
+  def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
+}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
new file mode 100644
index 0000000000000..1bb62c84abddc
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -0,0 +1,1319 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Martin Odersky
+ */
+
+package scala
+package tools.nsc
+package interpreter
+
+import PartialFunction.cond
+import scala.language.implicitConversions
+import scala.beans.BeanProperty
+import scala.collection.mutable
+import scala.concurrent.{ Future, ExecutionContext }
+import scala.reflect.runtime.{ universe => ru }
+import scala.reflect.{ ClassTag, classTag }
+import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
+import scala.tools.util.PathResolver
+import scala.tools.nsc.io.AbstractFile
+import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
+import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
+import scala.tools.nsc.util.Exceptional.unwrap
+import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
+
+/** An interpreter for Scala code.
+ *
+ *  The main public entry points are compile(), interpret(), and bind().
+ *  The compile() method loads a complete Scala file. The interpret() method
+ *  executes one line of Scala code at the request of the user. The bind()
+ *  method binds an object to a variable that can then be used by later
+ *  interpreted code.
+ *
+ *  The overall approach is based on compiling the requested code and then
+ *  using a Java classloader and Java reflection to run the code
+ *  and access its results.
+ *
+ *  In more detail, a single compiler instance is used
+ *  to accumulate all successfully compiled or interpreted Scala code. To
+ *  "interpret" a line of code, the compiler generates a fresh object that
+ *  includes the line of code and which has public member(s) to export
+ *  all variables defined by that code. To extract the result of an
+ *  interpreted line to show the user, a second "result object" is created
+ *  which imports the variables exported by the above object and then
+ *  exports members called "$eval" and "$print". To accommodate user expressions
+ *  that read from variables or methods defined in previous statements, "import"
+ *  statements are used.
+ *
+ *  This interpreter shares the strengths and weaknesses of using the
+ *  full compiler-to-Java. The main strength is that interpreted code
+ *  behaves exactly as does compiled code, including running at full speed.
+ *  The main weakness is that redefining classes and methods is not handled
+ *  properly, because rebinding at the Java level is technically difficult.
+ *
+ *  @author Moez A. Abdel-Gawad
+ *  @author Lex Spoon
+ */
+class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
+    protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
+  imain =>
+
+  setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
+  object replOutput extends ReplOutput(settings.Yreploutdir) { }
+
+  @deprecated("Use replOutput.dir instead", "2.11.0")
+  def virtualDirectory = replOutput.dir
+  // Used in a test case.
+  def showDirectory() = replOutput.show(out)
+
+  private[nsc] var printResults = true        // whether to print result lines
+  private[nsc] var totalSilence = false       // whether to print anything
+  private var _initializeComplete = false     // compiler is initialized
+  private var _isInitialized: Future[Boolean] = null  // set up initialization future
+  private var bindExceptions = true           // whether to bind the lastException variable
+  private var _executionWrapper = ""          // code to be wrapped around all lines
+
+  /** We're going to go to some trouble to initialize the compiler asynchronously.
+   *  It's critical that nothing call into it until it's been initialized or we will
+   *  run into unrecoverable issues, but the perceived repl startup time goes
+   *  through the roof if we wait for it. So we initialize it with a future and
+   *  use a lazy val to ensure that any attempt to use the compiler object waits
+   *  on the future.
+   */
+  private var _classLoader: util.AbstractFileClassLoader = null  // active classloader
+  private val _compiler: ReplGlobal = newCompiler(settings, reporter)  // our private compiler
+
+  def compilerClasspath: Seq[java.net.URL] = (
+    if (isInitializeComplete) global.classPath.asURLs
+    else new PathResolver(settings).result.asURLs  // the compiler's classpath
+  )
+  def settings = initialSettings
+  // Run the code body with the given boolean settings flipped to true.
+  def withoutWarnings[T](body: => T): T = beQuietDuring {
+    val saved = settings.nowarn.value
+    if (!saved)
+      settings.nowarn.value = true
+
+    try body
+    finally if (!saved) settings.nowarn.value = false
+  }
+
+  /** construct an interpreter that reports to Console */
+  def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
+  def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
+  def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+  def this(factory: ScriptEngineFactory) = this(factory, new Settings())
+  def this() = this(new Settings())
+
+  lazy val formatting: Formatting = new Formatting {
+    val prompt = Properties.shellPromptString
+  }
+  lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
+
+  import formatting._
+  import reporter.{ printMessage, printUntruncatedMessage }
+
+  // This exists mostly because using the reporter too early leads to deadlock.
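+  // (Illustrative sketch of the hazard, per the note above: something like
+  //   val imain = new SparkIMain(new Settings())
+  //   imain.reporter.printMessage("booting")  // may block until the compiler
+  //                                           // initialization future completes
+  // whereas the Console-based echo below has no compiler dependency.)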
+  private def echo(msg: String) { Console println msg }
+  private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
+  private def _initialize() = {
+    try {
+      // if this crashes, REPL will hang its head in shame
+      val run = new _compiler.Run()
+      assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
+      run compileSources _initSources
+      _initializeComplete = true
+      true
+    }
+    catch AbstractOrMissingHandler()
+  }
+  private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
+  private val logScope = scala.sys.props contains "scala.repl.scope"
+  private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
+
+  // argument is a thunk to execute after init is done
+  def initialize(postInitSignal: => Unit) {
+    synchronized {
+      if (_isInitialized == null) {
+        _isInitialized =
+          Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
+      }
+    }
+  }
+  def initializeSynchronous(): Unit = {
+    if (!isInitializeComplete) {
+      _initialize()
+      assert(global != null, global)
+    }
+  }
+  def isInitializeComplete = _initializeComplete
+
+  lazy val global: Global = {
+    if (!isInitializeComplete) _initialize()
+    _compiler
+  }
+
+  import global._
+  import definitions.{ ObjectClass, termMember, dropNullaryMethod}
+
+  lazy val runtimeMirror = ru.runtimeMirror(classLoader)
+
+  private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
+
+  def getClassIfDefined(path: String) = (
+    noFatal(runtimeMirror staticClass path)
+      orElse noFatal(rootMirror staticClass path)
+  )
+  def getModuleIfDefined(path: String) = (
+    noFatal(runtimeMirror staticModule path)
+      orElse noFatal(rootMirror staticModule path)
+  )
+
+  implicit class ReplTypeOps(tp: Type) {
+    def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
+  }
+
+  // TODO: If we try to make naming a lazy val, we run into big time
+  // scalac unhappiness with what look like cycles. It has not been easy to
+  // reduce, but name resolution clearly takes different paths.
+  object naming extends {
+    val global: imain.global.type = imain.global
+  } with Naming {
+    // make sure we don't overwrite their unwisely named res3 etc.
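+    // e.g. (illustrative session):
+    //   scala> val res3 = "claimed by hand"
+    //   scala> 1 + 1   // must bind to a fresh name rather than clobber res3
+    // freshUserTermName below keeps generating candidates until one is unused.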
+    def freshUserTermName(): TermName = {
+      val name = newTermName(freshUserVarName())
+      if (replScope containsName name) freshUserTermName()
+      else name
+    }
+    def isInternalTermName(name: Name) = isInternalVarName("" + name)
+  }
+  import naming._
+
+  object deconstruct extends {
+    val global: imain.global.type = imain.global
+  } with StructuredTypeStrings
+
+  lazy val memberHandlers = new {
+    val intp: imain.type = imain
+  } with SparkMemberHandlers
+  import memberHandlers._
+
+  /** Temporarily be quiet */
+  def beQuietDuring[T](body: => T): T = {
+    val saved = printResults
+    printResults = false
+    try body
+    finally printResults = saved
+  }
+  def beSilentDuring[T](operation: => T): T = {
+    val saved = totalSilence
+    totalSilence = true
+    try operation
+    finally totalSilence = saved
+  }
+
+  def quietRun[T](code: String) = beQuietDuring(interpret(code))
+
+  /** takes AnyRef because it may be binding a Throwable or an Exceptional */
+  private def withLastExceptionLock[T](body: => T, alt: => T): T = {
+    assert(bindExceptions, "withLastExceptionLock called incorrectly.")
+    bindExceptions = false
+
+    try beQuietDuring(body)
+    catch logAndDiscard("withLastExceptionLock", alt)
+    finally bindExceptions = true
+  }
+
+  def executionWrapper = _executionWrapper
+  def setExecutionWrapper(code: String) = _executionWrapper = code
+  def clearExecutionWrapper() = _executionWrapper = ""
+
+  /** interpreter settings */
+  lazy val isettings = new SparkISettings(this)
+
+  /** Instantiate a compiler. Overridable. */
+  protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
+    settings.outputDirs setSingleOutput replOutput.dir
+    settings.exposeEmptyPackage.value = true
+    new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" }
+  }
+
+  /** Parent classloader. Overridable. */
+  protected def parentClassLoader: ClassLoader =
+    settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
+
+  /* A single class loader is used for all commands interpreted by this Interpreter.
+     It would also be possible to create a new class loader for each command
+     to interpret. The advantages of the current approach are:
+
+     - Expressions are only evaluated one time. This is especially
+       significant for I/O, e.g. "val x = Console.readLine"
+
+     The main disadvantage is:
+
+     - Objects, classes, and methods cannot be rebound. Instead, definitions
+       shadow the old ones, and old code objects refer to the old
+       definitions.
+  */
+  def resetClassLoader() = {
+    repldbg("Setting new classloader: was " + _classLoader)
+    _classLoader = null
+    ensureClassLoader()
+  }
+  final def ensureClassLoader() {
+    if (_classLoader == null)
+      _classLoader = makeClassLoader()
+  }
+  def classLoader: util.AbstractFileClassLoader = {
+    ensureClassLoader()
+    _classLoader
+  }
+
+  def backticked(s: String): String = (
+    (s split '.').toList map {
+      case "_" => "_"
+      case s if nme.keywords(newTermName(s)) => s"`$s`"
+      case s => s
+    } mkString "."
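+    // e.g. (illustrative) backticked("scala.type.Foo") yields "scala.`type`.Foo":
+    // keyword segments are quoted, and a bare "_" segment is left untouched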
+  )
+  def readRootPath(readPath: String) = getModuleIfDefined(readPath)
+
+  abstract class PhaseDependentOps {
+    def shift[T](op: => T): T
+
+    def path(name: => Name): String = shift(path(symbolOfName(name)))
+    def path(sym: Symbol): String = backticked(shift(sym.fullName))
+    def sig(sym: Symbol): String = shift(sym.defString)
+  }
+  object typerOp extends PhaseDependentOps {
+    def shift[T](op: => T): T = exitingTyper(op)
+  }
+  object flatOp extends PhaseDependentOps {
+    def shift[T](op: => T): T = exitingFlatten(op)
+  }
+
+  def originalPath(name: String): String = originalPath(name: TermName)
+  def originalPath(name: Name): String = typerOp path name
+  def originalPath(sym: Symbol): String = typerOp path sym
+  def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName
+  def translatePath(path: String) = {
+    val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
+    sym.toOption map flatPath
+  }
+  def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
+
+  private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
+    /** Overridden here to try translating a simple name to the generated
+     *  class name if the original attempt fails. This method is used by
+     *  getResourceAsStream as well as findClass.
+     */
+    override protected def findAbstractFile(name: String): AbstractFile =
+      super.findAbstractFile(name) match {
+        case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
+        case file => file
+      }
+  }
+  private def makeClassLoader(): util.AbstractFileClassLoader =
+    new TranslatingClassLoader(parentClassLoader match {
+      case null => ScalaClassLoader fromURLs compilerClasspath
+      case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
+    })
+
+  // Set the current Java "context" class loader to this interpreter's class loader
+  def setContextClassLoader() = classLoader.setAsContext()
+
+  def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted)
+  def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
+
+  /** Most recent tree handled which wasn't wholly synthetic. */
+  private def mostRecentlyHandledTree: Option[Tree] = {
+    prevRequests.reverse foreach { req =>
+      req.handlers.reverse foreach {
+        case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
+        case _ => ()
+      }
+    }
+    None
+  }
+
+  private def updateReplScope(sym: Symbol, isDefined: Boolean) {
+    def log(what: String) {
+      val mark = if (sym.isType) "t " else "v "
+      val name = exitingTyper(sym.nameString)
+      val info = cleanTypeAfterTyper(sym)
+      val defn = sym defStringSeenAs info
+
+      scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
+    }
+    if (ObjectClass isSubClass sym.owner) return
+    // unlink previous
+    replScope lookupAll sym.name foreach { sym =>
+      log("unlink")
+      replScope unlink sym
+    }
+    val what = if (isDefined) "define" else "import"
+    log(what)
+    replScope enter sym
+  }
+
+  def recordRequest(req: Request) {
+    if (req == null)
+      return
+
+    prevRequests += req
+
+    // warning about serially defining companions. It'd be easy
+    // enough to just redefine them together but that may not always
+    // be what people want so I'm waiting until I can do it better.
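+    // (Illustrative: interpreting `class C(private val x: Int)` and, on a later
+    //  line, `object C { def peek(c: C) = c.x }` puts the two definitions in
+    //  different wrapper objects, so they are not real companions and the
+    //  private member is inaccessible; the check below warns and suggests
+    //  :paste so both halves share one compilation unit.)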
+    exitingTyper {
+      req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
+        val oldSym = replScope lookup newSym.name.companionName
+        if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
+          replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
+          replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
+        }
+      }
+    }
+    exitingTyper {
+      req.imports foreach (sym => updateReplScope(sym, isDefined = false))
+      req.defines foreach (sym => updateReplScope(sym, isDefined = true))
+    }
+  }
+
+  private[nsc] def replwarn(msg: => String) {
+    if (!settings.nowarnings)
+      printMessage(msg)
+  }
+
+  def compileSourcesKeepingRun(sources: SourceFile*) = {
+    val run = new Run()
+    assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
+    reporter.reset()
+    run compileSources sources.toList
+    (!reporter.hasErrors, run)
+  }
+
+  /** Compile an nsc SourceFile. Returns true if there are
+   *  no compilation errors, or false otherwise.
+   */
+  def compileSources(sources: SourceFile*): Boolean =
+    compileSourcesKeepingRun(sources: _*)._1
+
+  /** Compile a string. Returns true if there are no
+   *  compilation errors, or false otherwise.
+   */
+  def compileString(code: String): Boolean =
+    compileSources(new BatchSourceFile("