diff --git a/.rat-excludes b/.rat-excludes index bccb043c2bb55..eaefef1b0aa2e 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -25,6 +25,7 @@ log4j-defaults.properties bootstrap-tooltip.js jquery-1.11.1.min.js sorttable.js +.*avsc .*txt .*json .*data diff --git a/README.md b/README.md index f87e07aa5cc90..a1a48f5bd0819 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,15 @@ If your project is built with Maven, add this to your POM file's ` +## A Note About Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. + + $ sbt/sbt -Phive-thriftserver assembly + + ## Configuration Please refer to the [Configuration guide](http://spark.apache.org/docs/latest/configuration.html) diff --git a/bin/beeline b/bin/beeline index 09fe366c609fa..1bda4dba50605 100755 --- a/bin/beeline +++ b/bin/beeline @@ -17,29 +17,14 @@ # limitations under the License. # -# Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +# +# Shell script for starting BeeLine -# Find the java binary -if [ -n "${JAVA_HOME}" ]; then - RUNNER="${JAVA_HOME}/bin/java" -else - if [ `command -v java` ]; then - RUNNER="java" - else - echo "JAVA_HOME is not set" >&2 - exit 1 - fi -fi +# Enter posix mode for bash +set -o posix -# Compute classpath using external script -classpath_output=$($FWDIR/bin/compute-classpath.sh) -if [[ "$?" != "0" ]]; then - echo "$classpath_output" - exit 1 -else - CLASSPATH=$classpath_output -fi +# Figure out where Spark is installed +FWDIR="$(cd `dirname $0`/..; pwd)" CLASS="org.apache.hive.beeline.BeeLine" -exec "$RUNNER" -cp "$CLASSPATH" $CLASS "$@" +exec "$FWDIR/bin/spark-class" $CLASS "$@" diff --git a/bin/pyspark b/bin/pyspark index 39a20e2a24a3c..01d42025c978e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -23,12 +23,18 @@ FWDIR="$(cd `dirname $0`/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" +source $FWDIR/bin/utils.sh + SCALA_VERSION=2.10 -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then +function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 +} + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + usage fi # Exit if the user hasn't compiled Spark @@ -66,10 +72,11 @@ fi # Build up arguments list manually to preserve quotes and backslashes. # We export Spark submit arguments as an environment variable because shell.py must run as a # PYTHONSTARTUP script, which does not take in arguments. This is required for IPython notebooks. - +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" PYSPARK_SUBMIT_ARGS="" whitespace="[[:space:]]" -for i in "$@"; do +for i in "${SUBMISSION_OPTS[@]}"; do if [[ $i =~ \" ]]; then i=$(echo $i | sed 's/\"/\\\"/g'); fi if [[ $i =~ $whitespace ]]; then i=\"$i\"; fi PYSPARK_SUBMIT_ARGS="$PYSPARK_SUBMIT_ARGS $i" @@ -90,7 +97,10 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 
1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - exec $FWDIR/bin/spark-submit "$@" + primary=$1 + shift + gatherSparkSubmitOpts "$@" + exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" else # Only use ipython if no command line arguments were provided [SPARK-1134] if [[ "$IPYTHON" = "1" ]]; then diff --git a/bin/spark-shell b/bin/spark-shell index 756c8179d12b6..8b7ccd7439551 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -31,13 +31,21 @@ set -o posix ## Global script variables FWDIR="$(cd `dirname $0`/..; pwd)" +function usage() { + echo "Usage: ./bin/spark-shell [options]" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + exit 0 +} + if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 - exit 0 + usage fi -function main(){ +source $FWDIR/bin/utils.sh +SUBMIT_USAGE_FUNCTION=usage +gatherSparkSubmitOpts "$@" + +function main() { if $cygwin; then # Workaround for issue involving JLine and Cygwin # (see http://sourceforge.net/p/jline/bugs/40/). @@ -46,11 +54,11 @@ function main(){ # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main spark-shell "$@" + $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" fi } diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index b56d69801171c..2ee60b4e2a2b3 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -19,4 +19,4 @@ rem set SPARK_HOME=%~dp0.. 
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd spark-shell --class org.apache.spark.repl.Main %* +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-sql b/bin/spark-sql index bba7f897b19bc..564f1f419060f 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -23,14 +23,72 @@ # Enter posix mode for bash set -o posix +CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" + # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/spark-sql [options]" +function usage { + echo "Usage: ./bin/spark-sql [options] [cli option]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|--help" + pattern+="\|=======" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "CLI options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +CLI_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + -d | --define | --database | -f | -h | --hiveconf | --hivevar | -i | -p) + ensure_arg_number $# 2 + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift + ;; + + -e) + ensure_arg_number $# 2 + CLI_ARGS+=("$1"); shift + CLI_ARGS+=("$1"); shift + ;; + + -s | --silent) + CLI_ARGS+=("$1"); shift + ;; + + -v | --verbose) + # Both SparkSubmit and SparkSQLCLIDriver recognizes -v | --verbose + CLI_ARGS+=("$1") + SUBMISSION_ARGS+=("$1"); shift + ;; + + *) + SUBMISSION_ARGS+=("$1"); shift + ;; + esac +done + +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${CLI_ARGS[@]}" diff --git a/bin/utils.sh b/bin/utils.sh new file mode 100644 index 0000000000000..0804b1ed9f231 --- /dev/null +++ b/bin/utils.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Gather all all spark-submit options into SUBMISSION_OPTS +function gatherSparkSubmitOpts() { + + if [ -z "$SUBMIT_USAGE_FUNCTION" ]; then + echo "Function for printing usage of $0 is not set." 
1>&2 + echo "Please set usage function to shell variable 'SUBMIT_USAGE_FUNCTION' in $0" 1>&2 + exit 1 + fi + + # NOTE: If you add or remove spark-sumbmit options, + # modify NOT ONLY this script but also SparkSubmitArgument.scala + SUBMISSION_OPTS=() + APPLICATION_OPTS=() + while (($#)); do + case "$1" in + --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \ + --conf | --properties-file | --driver-memory | --driver-java-options | \ + --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \ + --total-executor-cores | --executor-cores | --queue | --num-executors | --archives) + if [[ $# -lt 2 ]]; then + "$SUBMIT_USAGE_FUNCTION" + exit 1; + fi + SUBMISSION_OPTS+=("$1"); shift + SUBMISSION_OPTS+=("$1"); shift + ;; + + --verbose | -v | --supervise) + SUBMISSION_OPTS+=("$1"); shift + ;; + + *) + APPLICATION_OPTS+=("$1"); shift + ;; + esac + done + + export SUBMISSION_OPTS + export APPLICATION_OPTS +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java deleted file mode 100644 index 0d31894d6ec7a..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.util.concurrent.TimeUnit; - -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioSocketChannel; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -class FileClient { - - private static final Logger LOG = LoggerFactory.getLogger(FileClient.class.getName()); - - private final FileClientHandler handler; - private Channel channel = null; - private Bootstrap bootstrap = null; - private EventLoopGroup group = null; - private final int connectTimeout; - private final int sendTimeout = 60; // 1 min - - FileClient(FileClientHandler handler, int connectTimeout) { - this.handler = handler; - this.connectTimeout = connectTimeout; - } - - public void init() { - group = new OioEventLoopGroup(); - bootstrap = new Bootstrap(); - bootstrap.group(group) - .channel(OioSocketChannel.class) - .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.TCP_NODELAY, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout) - .handler(new FileClientChannelInitializer(handler)); - } - - public void connect(String host, int port) { - try { - // Start the connection attempt. 
- channel = bootstrap.connect(host, port).sync().channel(); - // ChannelFuture cf = channel.closeFuture(); - //cf.addListener(new ChannelCloseListener(this)); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted while trying to connect", e); - close(); - } - } - - public void waitForClose() { - try { - channel.closeFuture().sync(); - } catch (InterruptedException e) { - LOG.warn("FileClient interrupted", e); - } - } - - public void sendRequest(String file) { - //assert(file == null); - //assert(channel == null); - try { - // Should be able to send the message to network link channel. - boolean bSent = channel.writeAndFlush(file + "\r\n").await(sendTimeout, TimeUnit.SECONDS); - if (!bSent) { - throw new RuntimeException("Failed to send"); - } - } catch (InterruptedException e) { - LOG.error("Error", e); - } - } - - public void close() { - if (group != null) { - group.shutdownGracefully(); - group = null; - bootstrap = null; - } - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java deleted file mode 100644 index 63d3d927255f9..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; - -import org.apache.spark.storage.BlockId; - -abstract class FileClientHandler extends SimpleChannelInboundHandler { - - private FileHeader currentHeader = null; - - private volatile boolean handlerCalled = false; - - public boolean isComplete() { - return handlerCalled; - } - - public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(BlockId blockId); - - @Override - public void channelRead0(ChannelHandlerContext ctx, ByteBuf in) { - // get header - if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) { - currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE())); - } - // get file - if(in.readableBytes() >= currentHeader.fileLen()) { - handle(ctx, in, currentHeader); - handlerCalled = true; - currentHeader = null; - ctx.close(); - } - } - -} - diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java deleted file mode 100644 index c93425e2787dc..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty; - -import java.net.InetSocketAddress; - -import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelOption; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.oio.OioEventLoopGroup; -import io.netty.channel.socket.oio.OioServerSocketChannel; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Server that accept the path of a file an echo back its content. - */ -class FileServer { - - private static final Logger LOG = LoggerFactory.getLogger(FileServer.class.getName()); - - private EventLoopGroup bossGroup = null; - private EventLoopGroup workerGroup = null; - private ChannelFuture channelFuture = null; - private int port = 0; - - FileServer(PathResolver pResolver, int port) { - InetSocketAddress addr = new InetSocketAddress(port); - - // Configure the server. - bossGroup = new OioEventLoopGroup(); - workerGroup = new OioEventLoopGroup(); - - ServerBootstrap bootstrap = new ServerBootstrap(); - bootstrap.group(bossGroup, workerGroup) - .channel(OioServerSocketChannel.class) - .option(ChannelOption.SO_BACKLOG, 100) - .option(ChannelOption.SO_RCVBUF, 1500) - .childHandler(new FileServerChannelInitializer(pResolver)); - // Start the server. - channelFuture = bootstrap.bind(addr); - try { - // Get the address we bound to. 
- InetSocketAddress boundAddress = - ((InetSocketAddress) channelFuture.sync().channel().localAddress()); - this.port = boundAddress.getPort(); - } catch (InterruptedException ie) { - this.port = 0; - } - } - - /** - * Start the file server asynchronously in a new thread. - */ - public void start() { - Thread blockingThread = new Thread() { - @Override - public void run() { - try { - channelFuture.channel().closeFuture().sync(); - LOG.info("FileServer exiting"); - } catch (InterruptedException e) { - LOG.error("File server start got interrupted", e); - } - // NOTE: bootstrap is shutdown in stop() - } - }; - blockingThread.setDaemon(true); - blockingThread.start(); - } - - public int getPort() { - return port; - } - - public void stop() { - // Close the bound channel. - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly(); - channelFuture = null; - } - - // Shutdown event groups - if (bossGroup != null) { - bossGroup.shutdownGracefully(); - bossGroup = null; - } - - if (workerGroup != null) { - workerGroup.shutdownGracefully(); - workerGroup = null; - } - // TODO: Shutdown all accepted channels as well ? - } -} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java deleted file mode 100644 index c0133e19c7f79..0000000000000 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty; - -import java.io.File; -import java.io.FileInputStream; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.channel.DefaultFileRegion; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; - -class FileServerHandler extends SimpleChannelInboundHandler { - - private static final Logger LOG = LoggerFactory.getLogger(FileServerHandler.class.getName()); - - private final PathResolver pResolver; - - FileServerHandler(PathResolver pResolver){ - this.pResolver = pResolver; - } - - @Override - public void channelRead0(ChannelHandlerContext ctx, String blockIdString) { - BlockId blockId = BlockId.apply(blockIdString); - FileSegment fileSegment = pResolver.getBlockLocation(blockId); - // if getBlockLocation returns null, close the channel - if (fileSegment == null) { - //ctx.close(); - return; - } - File file = fileSegment.file(); - if (file.exists()) { - if (!file.isFile()) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - long length = fileSegment.length(); - if (length > Integer.MAX_VALUE || length <= 0) { - ctx.write(new FileHeader(0, blockId).buffer()); - ctx.flush(); - return; - } - int len = (int) length; - ctx.write((new FileHeader(len, blockId)).buffer()); - try { - ctx.write(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), fileSegment.offset(), fileSegment.length())); - } catch (Exception e) { - LOG.error("Exception: ", e); - } - } else { - ctx.write(new FileHeader(0, blockId).buffer()); - } - ctx.flush(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - LOG.error("Exception: ", cause); - ctx.close(); - } -} diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index bf3c3a6ceb5ef..3848734d6f639 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -66,10 +66,15 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { /** * Whether the cleaning thread will block on cleanup tasks. - * This is set to true only for tests. + * + * Due to SPARK-3015, this is set to true by default. This is intended to be only a temporary + * workaround for the issue, which is ultimately caused by the way the BlockManager actors + * issue inter-dependent blocking Akka messages to each other at high frequencies. This happens, + * for instance, when the driver performs a GC and cleans up all broadcast blocks that are no + * longer in scope. */ private val blockOnCleanupTasks = sc.conf.getBoolean( - "spark.cleaner.referenceTracking.blocking", false) + "spark.cleaner.referenceTracking.blocking", true) @volatile private var stopped = false @@ -174,9 +179,6 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def blockManagerMaster = sc.env.blockManager.master private def broadcastManager = sc.env.broadcastManager private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - - // Used for testing. These methods explicitly blocks until cleanup is completed - // to ensure that more reliable testing. 
} private object ContextCleaner { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 24ccce21b62ca..83ae57b7f1516 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -21,6 +21,7 @@ import akka.actor.Actor import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.util.ActorLogReceive /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -36,8 +37,10 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) extends Actor { - override def receive = { +private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index f40baa8e43592..5c262bcbddf76 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. 
- if (context.interrupted) { + if (context.isInterrupted) { throw new TaskKilledException } else { delegate.hasNext diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 894091761485d..51705c895a55c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -38,10 +38,10 @@ private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - def receive = { + override def receiveWithLogging = { case GetMapOutputStatuses(shuffleId: Int) => val hostPort = sender.path.address.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 13f0bff7ee507..605df0e929faa 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -45,7 +45,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { /** Create a SparkConf that loads defaults from system properties and the classpath */ def this() = this(true) - private val settings = new HashMap[String, String]() + private[spark] val settings = new HashMap[String, String]() if (loadDefaults) { // Load any spark.* system properties @@ -210,6 +210,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { new SparkConf(false).setAll(settings) } + /** + * By using this instead of System.getenv(), environment variables can be mocked + * in unit tests. + */ + private[spark] def getenv(name: String): String = System.getenv(name) + /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ private[spark] def validateSettings() { @@ -227,7 +233,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { // Validate spark.executor.extraJavaOptions settings.get(executorOptsKey).map { javaOpts => if (javaOpts.contains("-Dspark")) { - val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts)'. " + + val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " + "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." 
throw new Exception(msg) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 9d4edeb6d96cf..fc36e37c53f5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -156,11 +156,9 @@ object SparkEnv extends Logging { conf.set("spark.driver.port", boundPort.toString) } - // Create an instance of the class named by the given Java system property, or by - // defaultClassName if the property is not set, and return it as a T - def instantiateClass[T](propertyName: String, defaultClassName: String): T = { - val name = conf.get(propertyName, defaultClassName) - val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) + // Create an instance of the class with the given name, possibly initializing it with our conf + def instantiateClass[T](className: String): T = { + val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader) // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just // SparkConf, then one taking no arguments try { @@ -178,11 +176,17 @@ object SparkEnv extends Logging { } } - val serializer = instantiateClass[Serializer]( + // Create an instance of the class named by the given SparkConf property, or defaultClassName + // if the property is not set, possibly initializing it with our conf + def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = { + instantiateClass[T](conf.get(propertyName, defaultClassName)) + } + + val serializer = instantiateClassFromConf[Serializer]( "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val closureSerializer = instantiateClass[Serializer]( + val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") def registerOrLookup(name: String, newActor: => Actor): ActorRef = { @@ -206,12 +210,22 @@ object SparkEnv extends Logging { "MapOutputTracker", new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + // Let the user specify short names for shuffle managers + val shortShuffleMgrNames = Map( + "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", + "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) + + val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker) + serializer, conf, securityManager, mapOutputTracker, shuffleManager) val connectionManager = blockManager.connectionManager @@ -246,11 +260,6 @@ object SparkEnv extends Logging { "." } - val shuffleManager = instantiateClass[ShuffleManager]( - "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") - - val shuffleMemoryManager = new ShuffleMemoryManager(conf) - // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! 
Specify storage " + diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 51f40c339d13c..2b99b8a5af250 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.TaskCompletionListener + /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task */ @DeveloperApi class TaskContext( @@ -39,13 +47,45 @@ class TaskContext( def splitId = partitionId // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] // Whether the corresponding task has been killed. - @volatile var interrupted: Boolean = false + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + /** Checks whether the task has completed. */ + def isCompleted: Boolean = completed - // Whether the task has completed, before the onCompleteCallbacks are executed. - @volatile var completed: Boolean = false + /** Checks whether the task has been killed. */ + def isInterrupted: Boolean = interrupted + + // TODO: Also track whether the task has completed successfully or with exception. + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } /** * Add a callback function to be executed on task completion. An example use @@ -53,13 +93,22 @@ class TaskContext( * Will be called in any situation - success, failure, or cancellation. * @param f Callback function. */ + @deprecated("use addTaskCompletionListener", "1.1.0") def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += f + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } } - def executeOnCompleteCallbacks() { + /** Marks the task as completed and triggers the listeners. 
*/ + private[spark] def markTaskCompleted(): Unit = { completed = true // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { _() } + onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) } + } + + /** Marks the task for interruption, i.e. cancellation. */ + private[spark] def markInterrupted(): Unit = { + interrupted = true } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 76d4193e96aea..feeb6c02caa78 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double], - exact: Boolean, seed: Long): JavaPairRDD[K, V] = - new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed)) + new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed)) /** * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use additional passes over - * the RDD to create a sample size that's exactly equal to the sum of + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of * math.ceil(numItems * samplingRate) over all key values. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], - exact: Boolean): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong) + fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** - * Return a subset of this RDD sampled by key (via stratified sampling). - * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. 
+ * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. */ - def sampleByKey(withReplacement: Boolean, + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double], seed: Long): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, seed) + new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, seed)) /** - * Return a subset of this RDD sampled by key (via stratified sampling). + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * Produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with one pass over the RDD via - * simple random sampling. + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. * - * Use Utils.random.nextLong as the default seed for the random number generator + * Use Utils.random.nextLong as the default seed for the random number generator. */ - def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = - sampleByKey(withReplacement, fractions, false, Utils.random.nextLong) + @Experimental + def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** * Return the union of this RDD and another one. 
Any identical elements will appear multiple diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index f3b05e1243045..49dc95f349eac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils import org.apache.spark.{Logging, SerializableWritable, SparkException} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io._ @@ -42,7 +43,7 @@ private[python] object Converter extends Logging { defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { converterClass.map { cc => Try { - val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName(cc).newInstance().asInstanceOf[Converter[Any, Any]] logInfo(s"Loaded converter: $cc") c } match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0b5322c6fb965..747023812f754 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -62,13 +62,13 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - context.addOnCompleteCallback { () => + context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() // Cleanup the worker socket. This will also cause the Python worker to exit. @@ -137,7 +137,7 @@ private[spark] class PythonRDD( } } catch { - case e: Exception if context.interrupted => + case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) throw new TaskKilledException @@ -176,7 +176,7 @@ private[spark] class PythonRDD( /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { - assert(context.completed) + assert(context.isCompleted) this.interrupt() } @@ -209,7 +209,7 @@ private[spark] class PythonRDD( PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) dataOut.flush() } catch { - case e: Exception if context.completed || context.interrupted => + case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) case e: Exception => @@ -235,10 +235,10 @@ private[spark] class PythonRDD( override def run() { // Kill the worker if it is interrupted, checking until task completion. // TODO: This has a race condition if interruption occurs, as completed may still become true. 
- while (!context.interrupted && !context.completed) { + while (!context.isInterrupted && !context.isCompleted) { Thread.sleep(2000) } - if (!context.completed) { + if (!context.isCompleted) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") env.destroyPythonWorker(pythonExec, envVars.toMap, worker) @@ -315,6 +315,14 @@ private[spark] object PythonRDD extends Logging { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } + def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { // The right way to implement this would be to use TypeTags to get the full // type of T. Since I don't want to introduce breaking changes throughout the @@ -372,8 +380,8 @@ private[spark] object PythonRDD extends Logging { batchSize: Int) = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, @@ -440,9 +448,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { @@ -509,9 +517,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Class.forName(keyClass).asInstanceOf[Class[K]] - val vc = Class.forName(valueClass).asInstanceOf[Class[V]] - val fc = Class.forName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] + val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { @@ -558,7 +566,7 @@ private[spark] object PythonRDD extends Logging { for { k <- Option(keyClass) v <- Option(valueClass) - } yield (Class.forName(k), Class.forName(v)) + } yield (Utils.classForName(k), Utils.classForName(v)) } private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, @@ -621,10 +629,10 @@ private[spark] object PythonRDD extends Logging { val (kc, vc) = getKeyValueTypes(keyClass, valueClass).getOrElse( inferKeyValueTypes(rdd, keyConverterClass, valueConverterClass)) val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) - val codec = 
Option(compressionCodecClass).map(Class.forName(_).asInstanceOf[Class[C]]) + val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]]) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec=codec) } @@ -653,7 +661,7 @@ private[spark] object PythonRDD extends Logging { val mergedConf = getMergedConf(confAsMap, pyRDD.context.hadoopConfiguration) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Class.forName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] converted.saveAsNewAPIHadoopFile(path, kc, vc, fc, mergedConf) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7af260d0b7f26..bf716a8ab025b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -68,7 +68,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val socket = new Socket(daemonHost, daemonPort) val pid = new DataInputStream(socket.getInputStream).readInt() if (pid < 0) { - throw new IllegalStateException("Python daemon failed to launch worker") + throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } daemonWorkers.put(socket, pid) socket diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index a8c827030a1ef..6a187b40628a2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + + /** + * Creates a new broadcast variable. 
+ * + * @param value value to broadcast + * @param isLocal whether we are in local mode (single JVM process) + * @param id unique id representing this broadcast variable + */ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit + def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 86731b684f441..6173fd3a69fc7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,53 +17,117 @@ package org.apache.spark.broadcast -import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} +import java.io._ +import java.nio.ByteBuffer +import scala.collection.JavaConversions.asJavaEnumeration import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils +import org.apache.spark.util.ByteBufferInputStream /** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like - * protocol to do a distributed transfer of the broadcasted data to the executors. - * The mechanism is as follows. The driver divides the serializes the broadcasted data, - * divides it into smaller chunks, and stores them in the BlockManager of the driver. - * These chunks are reported to the BlockManagerMaster so that all the executors can - * learn the location of those chunks. The first time the broadcast variable (sent as - * part of task) is deserialized at a executor, all the chunks are fetched using - * the BlockManager. When all the chunks are fetched (initially from the driver's - * BlockManager), they are combined and deserialized to recreate the broadcasted data. - * However, the chunks are also stored in the BlockManager and reported to the - * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns - * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be - * made to other executors who already have those chunks, resulting in a distributed - * fetching. This prevents the driver from being the bottleneck in sending out multiple - * copies of the broadcast data (one per executor) as done by the - * [[org.apache.spark.broadcast.HttpBroadcast]]. + * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. + * + * The mechanism is as follows: + * + * The driver divides the serialized object into small chunks and + * stores those chunks in the BlockManager of the driver. + * + * On each executor, the executor first attempts to fetch the object from its BlockManager. If + * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or + * other executors if available. Once it gets the chunks, it puts the chunks in its own + * BlockManager, ready for other executors to fetch from. + * + * This prevents the driver from being the bottleneck in sending out multiple copies of the + * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * + * @param obj object to broadcast + * @param isLocal whether Spark is running in local mode (single JVM process). 
+ * @param id A unique identifier for the broadcast variable. */ private[spark] class TorrentBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) + obj : T, + @transient private val isLocal: Boolean, + id: Long) extends Broadcast[T](id) with Logging with Serializable { - override protected def getValue() = value_ + /** + * Value of the broadcast object. On driver, this is set directly by the constructor. + * On executors, this is reconstructed by [[readObject]], which builds this value by reading + * blocks from the driver and/or other executors. + */ + @transient private var _value: T = obj private val broadcastId = BroadcastBlockId(id) - TorrentBroadcast.synchronized { + /** Total number of blocks this broadcast variable contains. */ + private val numBlocks: Int = writeBlocks() + + override protected def getValue() = _value + + /** + * Divide the object into multiple blocks and put those blocks in the block manager. + * + * @return number of blocks this broadcast variable is divided into + */ + private def writeBlocks(): Int = { + // For local mode, just put the object in the BlockManager so we can find it later. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + if (!isLocal) { + val blocks = TorrentBroadcast.blockifyObject(_value) + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + } + blocks.length + } else { + 0 + } } - @transient private var arrayOfBlocks: Array[TorrentBlock] = null - @transient private var totalBlocks = -1 - @transient private var totalBytes = -1 - @transient private var hasBlocks = 0 + /** Fetch torrent blocks from the driver and/or other executors. */ + private def readBlocks(): Array[ByteBuffer] = { + // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported + // to the driver, so other executors can pull these chunks from this executor as well. + val blocks = new Array[ByteBuffer](numBlocks) + val bm = SparkEnv.get.blockManager - if (!isLocal) { - sendBroadcast() + for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { + val pieceId = BroadcastBlockId(id, "piece" + pid) + + // First try getLocalBytes because there is a chance that previous attempts to fetch the + // broadcast blocks have already fetched some of the blocks. In that case, some blocks + // would be available locally (on this executor). + var blockOpt = bm.getLocalBytes(pieceId) + if (!blockOpt.isDefined) { + blockOpt = bm.getRemoteBytes(pieceId) + blockOpt match { + case Some(block) => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + // If we get here, the option is defined. 
+ blocks(pid) = blockOpt.get + } + blocks } /** @@ -81,30 +145,6 @@ private[spark] class TorrentBroadcast[T: ClassTag]( TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) } - private def sendBroadcast() { - val tInfo = TorrentBroadcast.blockifyObject(value_) - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - hasBlocks = tInfo.totalBlocks - - // Store meta-info - val metaId = BroadcastBlockId(id, "meta") - val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } - - // Store individual pieces - for (i <- 0 until totalBlocks) { - val pieceId = BroadcastBlockId(id, "piece" + i) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - } - } - } - /** Used by the JVM when serializing this object. */ private def writeObject(out: ObjectOutputStream) { assertValid() @@ -115,110 +155,42 @@ private[spark] class TorrentBroadcast[T: ClassTag]( private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(broadcastId) match { + SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - value_ = x.asInstanceOf[T] + _value = x.asInstanceOf[T] case None => - val start = System.nanoTime logInfo("Started reading broadcast variable " + id) - - // Initialize @transient variables that will receive garbage values from the master. - resetWorkerVariables() - - if (receiveBroadcast()) { - value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - - /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. - * This creates a trade-off between memory usage and latency. Storing copy doubles - * the memory footprint; not storing doubles deserialization cost. Also, - * this does not need to be reported to BlockManagerMaster since other executors - * does not need to access this block (they only need to fetch the chunks, - * which are reported). - */ - SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - - // Remove arrayOfBlocks from memory once value_ is on local cache - resetWorkerVariables() - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 + val start = System.nanoTime() + val blocks = readBlocks() + val time = (System.nanoTime() - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - private def resetWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - } - - private def receiveBroadcast(): Boolean = { - // Receive meta-info about the size of broadcast data, - // the number of chunks it is divided into, etc. 
- val metaId = BroadcastBlockId(id, "meta") - var attemptId = 10 - while (attemptId > 0 && totalBlocks == -1) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => - val tInfo = x.asInstanceOf[TorrentInfo] - totalBlocks = tInfo.totalBlocks - totalBytes = tInfo.totalBytes - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - hasBlocks = 0 - - case None => - Thread.sleep(500) - } + _value = TorrentBroadcast.unBlockifyObject[T](blocks) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. + SparkEnv.get.blockManager.putSingle( + broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } - attemptId -= 1 } - if (totalBlocks == -1) { - return false - } - - /* - * Fetch actual chunks of data. Note that all these chunks are stored in - * the BlockManager and reported to the master, so that other executors - * can find out and pull the chunks from this executor. - */ - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) - for (pid <- recvOrder) { - val pieceId = BroadcastBlockId(id, "piece" + pid) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - hasBlocks += 1 - SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) - - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } - } - } - - hasBlocks == totalBlocks } - } -private[broadcast] object TorrentBroadcast extends Logging { + +private object TorrentBroadcast extends Logging { + /** Size of each block. Default value is 4MB. */ private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + private var compress: Boolean = false + private var compressionCodec: CompressionCodec = null def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { + compress = conf.getBoolean("spark.broadcast.compress", true) + compressionCodec = CompressionCodec.createCodec(conf) initialized = true } } @@ -228,43 +200,42 @@ private[broadcast] object TorrentBroadcast extends Logging { initialized = false } - def blockifyObject[T](obj: T): TorrentInfo = { - val byteArray = Utils.serialize[T](obj) + def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { + // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks + // so we don't need to do the extra memory copy. 
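The TODO above amounts to a small data structure: an OutputStream that accumulates fixed-size chunks as bytes are written, so the serialized object never has to exist as one contiguous array before being split. A rough sketch of such a stream (not part of this patch; the class name and details are assumptions):

    import java.io.OutputStream
    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    /** Collects written bytes into fixed-size chunks instead of one contiguous array. */
    class ChunkedOutputStream(chunkSize: Int) extends OutputStream {
      private val fullChunks = new ArrayBuffer[Array[Byte]]()
      private var current = new Array[Byte](chunkSize)
      private var pos = 0

      override def write(b: Int): Unit = {
        if (pos == chunkSize) {
          // The current chunk is full: stash it and start a fresh one.
          fullChunks += current
          current = new Array[Byte](chunkSize)
          pos = 0
        }
        current(pos) = b.toByte
        pos += 1
      }

      /** All chunks written so far, with the last (partial) chunk trimmed to its real size. */
      def toByteBuffers: Array[ByteBuffer] = {
        val all = fullChunks :+ current.take(pos)
        all.filter(_.nonEmpty).map(bytes => ByteBuffer.wrap(bytes)).toArray
      }
    }

Writing the serializer stream straight into something like this would let blockifyObject skip the toByteArray / ByteArrayInputStream round trip that the current version still performs.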
+ val bos = new ByteArrayOutputStream() + val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject[T](obj).close() + val byteArray = bos.toByteArray val bais = new ByteArrayInputStream(byteArray) + val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt + val blocks = new Array[ByteBuffer](numBlocks) - var blockNum = byteArray.length / BLOCK_SIZE - if (byteArray.length % BLOCK_SIZE != 0) { - blockNum += 1 - } - - val blocks = new Array[TorrentBlock](blockNum) var blockId = 0 - for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) val tempByteArray = new Array[Byte](thisBlockSize) bais.read(tempByteArray, 0, thisBlockSize) - blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blocks(blockId) = ByteBuffer.wrap(tempByteArray) blockId += 1 } bais.close() - - val info = TorrentInfo(blocks, blockNum, byteArray.length) - info.hasBlocks = blockNum - info + blocks } - def unBlockifyObject[T]( - arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { - val retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) - } - Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { + val is = new SequenceInputStream( + asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is + + val ser = SparkEnv.get.serializer.newInstance() + val serIn = ser.deserializeStream(in) + val obj = serIn.readObject[T]() + serIn.close() + obj } /** @@ -272,22 +243,6 @@ private[broadcast] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. */ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { - synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - } + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } - -private[broadcast] case class TorrentBlock( - blockID: Int, - byteArray: Array[Byte]) - extends Serializable - -private[broadcast] case class TorrentInfo( - @transient arrayOfBlocks: Array[TorrentBlock], - totalBlocks: Int, - totalBytes: Int) - extends Serializable { - - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index c07003784e8ac..065ddda50e65e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -27,12 +27,14 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * Proxy that relays messages to the driver. 
*/ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends Actor with Logging { +private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) + extends Actor with ActorLogReceive with Logging { + var masterActor: ActorSelection = _ val timeout = AkkaUtils.askTimeout(conf) @@ -114,7 +116,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends } } - override def receive = { + override def receiveWithLogging = { case SubmitDriverResponse(success, driverId, message) => println(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9391f24e71ed7..d545f58c5da7e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -219,11 +219,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { /** Fill in values by parsing user options. */ private def parseOpts(opts: Seq[String]): Unit = { - var inSparkOpts = true + val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r // Delineates parsing of Spark options from parsing of user options. parse(opts) + /** + * NOTE: If you add or remove spark-submit options, + * modify NOT ONLY this file but also utils.sh + */ def parse(opts: Seq[String]): Unit = opts match { case ("--name") :: value :: tail => name = value @@ -322,33 +326,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { verbose = true parse(tail) + case EQ_SEPARATED_OPT(opt, value) :: tail => + parse(opt :: value :: tail) + + case value :: tail if value.startsWith("-") => + SparkSubmit.printErrorAndExit(s"Unrecognized option '$value'.") + case value :: tail => - if (inSparkOpts) { - value match { - // convert --foo=bar to --foo bar - case v if v.startsWith("--") && v.contains("=") && v.split("=").size == 2 => - val parts = v.split("=") - parse(Seq(parts(0), parts(1)) ++ tail) - case v if v.startsWith("-") => - val errMessage = s"Unrecognized option '$value'." - SparkSubmit.printErrorAndExit(errMessage) - case v => - primaryResource = - if (!SparkSubmit.isShell(v) && !SparkSubmit.isInternal(v)) { - Utils.resolveURI(v).toString - } else { - v - } - inSparkOpts = false - isPython = SparkSubmit.isPython(v) - parse(tail) + primaryResource = + if (!SparkSubmit.isShell(value) && !SparkSubmit.isInternal(value)) { + Utils.resolveURI(value).toString + } else { + value } - } else { - if (!value.isEmpty) { - childArgs += value - } - parse(tail) - } + isPython = SparkSubmit.isPython(value) + childArgs ++= tail case Nil => } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index d38e9e79204c2..32790053a6be8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -30,7 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} /** * Interface allowing applications to speak with a Spark deploy cluster. 
Takes a master URL, @@ -56,7 +56,7 @@ private[spark] class AppClient( var registered = false var activeMasterUrl: String = null - class ClientActor extends Actor with Logging { + class ClientActor extends Actor with ActorLogReceive with Logging { var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times @@ -119,7 +119,7 @@ private[spark] class AppClient( .contains(remoteUrl.hostPort) } - override def receive = { + override def receiveWithLogging = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 72d0589689e71..d3674427b1271 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -46,6 +46,11 @@ private[spark] class ApplicationInfo( init() + private def readObject(in: java.io.ObjectInputStream): Unit = { + in.defaultReadObject() + init() + } + private def init() { state = ApplicationState.WAITING executors = new mutable.HashMap[Int, ExecutorInfo] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index c87b66f047dc8..38db02cd2421b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source class ApplicationSource(val application: ApplicationInfo) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.%s.%s".format("application", application.desc.name, + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.%s.%s".format("application", application.desc.name, System.currentTimeMillis()) metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index a70ecdb375373..5017273e87c07 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -42,14 +42,14 @@ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class Master( host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -167,7 +167,7 @@ private[spark] class Master( context.stop(leaderElectionAgent) } - override def receive = { + override def receiveWithLogging = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -697,7 +697,7 @@ 
private[spark] class Master( appIdToUI(app.id) = ui webUi.attachSparkUI(ui) // Application UI is successfully rebuilt, so link the Master UI to it - app.desc.appUiUrl = ui.basePath + app.desc.appUiUrl = ui.getBasePath true } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 36c1b87b7f684..9c3f79f1244b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class MasterSource(val master: Master) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "master" + override val metricRegistry = new MetricRegistry() + override val sourceName = "master" // Gauge for worker numbers in cluster metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 458d9947bd873..81400af22c0bf 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,7 +34,7 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} /** * @param masterUrls Each url should look like spark://host:port. 
@@ -51,7 +51,7 @@ private[spark] class Worker( workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher Utils.checkHost(host, "Expected hostname") @@ -72,7 +72,6 @@ private[spark] class Worker( val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) val testing: Boolean = sys.props.contains("spark.testing") - val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null var activeMasterUrl: String = "" @@ -136,7 +135,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) + webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -145,18 +144,16 @@ private[spark] class Worker( } def changeMaster(url: String, uiUrl: String) { - masterLock.synchronized { - activeMasterUrl = url - activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = activeMasterUrl match { - case Master.sparkUrlRegex(_host, _port) => - Address("akka.tcp", Master.systemName, _host, _port.toInt) - case x => - throw new SparkException("Invalid spark URL: " + x) - } - connected = true + activeMasterUrl = url + activeMasterWebUiUrl = uiUrl + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(_host, _port) => + Address("akka.tcp", Master.systemName, _host, _port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) } + connected = true } def tryRegisterAllMasters() { @@ -187,7 +184,7 @@ private[spark] class Worker( } } - override def receive = { + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true @@ -199,9 +196,7 @@ private[spark] class Worker( } case SendHeartbeat => - masterLock.synchronized { - if (connected) { master ! Heartbeat(workerId) } - } + if (connected) { master ! Heartbeat(workerId) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor @@ -244,9 +239,7 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) - } + master ! ExecutorStateChanged(appId, execId, manager.state, None, None) } catch { case e: Exception => { logError("Failed to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) @@ -254,17 +247,13 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) - } + master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, None, None) } } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - masterLock.synchronized { - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) - } + master ! 
ExecutorStateChanged(appId, execId, state, message, exitStatus) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -330,9 +319,7 @@ private[spark] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - masterLock.synchronized { - master ! DriverStateChanged(driverId, state, exception) - } + master ! DriverStateChanged(driverId, state, exception) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem @@ -373,7 +360,8 @@ private[spark] class Worker( private[spark] object Worker extends Logging { def main(argStrings: Array[String]) { SignalLogger.register(log) - val args = new WorkerArguments(argStrings) + val conf = new SparkConf + val args = new WorkerArguments(argStrings, conf) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) actorSystem.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index dc5158102054e..1e295aaa48c30 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker import java.lang.management.ManagementFactory import org.apache.spark.util.{IntParam, MemoryParam, Utils} +import org.apache.spark.SparkConf /** * Command-line parser for the worker. */ -private[spark] class WorkerArguments(args: Array[String]) { +private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 0 var webUiPort = 8081 @@ -46,6 +47,9 @@ private[spark] class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index b7ddd8c816cbc..df1e01b23b932 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -22,8 +22,8 @@ import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.spark.metrics.source.Source private[spark] class WorkerSource(val worker: Worker) extends Source { - val sourceName = "worker" - val metricRegistry = new MetricRegistry() + override val sourceName = "worker" + override val metricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] { override def getValue: Int = worker.executors.size diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 530c147000904..6d0d0bbe5ecec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -22,13 +22,15 @@ import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, Di import org.apache.spark.Logging import 
org.apache.spark.deploy.DeployMessages.SendHeartbeat +import org.apache.spark.util.ActorLogReceive /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) extends Actor - with Logging { +private[spark] class WorkerWatcher(workerUrl: String) + extends Actor with ActorLogReceive with Logging { + override def preStart() { context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) @@ -48,7 +50,7 @@ private[spark] class WorkerWatcher(workerUrl: String) extends Actor def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receive = { + override def receiveWithLogging = { case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 47fbda600bea7..b07942a9ca729 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -34,8 +34,8 @@ private[spark] class WorkerWebUI( val worker: Worker, val workDir: File, - port: Option[Int] = None) - extends WebUI(worker.securityMgr, getUIPort(port, worker.conf), worker.conf, name = "WorkerUI") + requestedPort: Int) + extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) @@ -55,10 +55,5 @@ class WorkerWebUI( } private[spark] object WorkerWebUI { - val DEFAULT_PORT = 8081 val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR - - def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { - requestedPort.getOrElse(conf.getInt("spark.worker.ui.port", WorkerWebUI.DEFAULT_PORT)) - } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 1f46a0f176490..13af5b6f5812d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -31,14 +31,15 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, cores: Int, - sparkProperties: Seq[(String, String)]) extends Actor with ExecutorBackend with Logging { + sparkProperties: Seq[(String, String)]) + extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -52,7 +53,7 @@ private[spark] class CoarseGrainedExecutorBackend( context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } - override def receive = { + override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? 
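CoarseGrainedExecutorBackend follows the same mechanical change as ClientActor, AppClient, Master, Worker and WorkerWatcher above: mix in ActorLogReceive and implement receiveWithLogging instead of receive. The trait itself is not shown in this diff; judging only from how it is used here, the idea is roughly the following (a sketch under that assumption, not the actual Spark source):

    import akka.actor.Actor
    import org.slf4j.{Logger, LoggerFactory}

    /** Sketch of a mix-in that wraps an actor's receive behavior with debug logging. */
    trait ActorLogReceiveSketch extends Actor {
      private val log: Logger = LoggerFactory.getLogger(getClass)

      /** Subclasses implement this instead of receive. */
      def receiveWithLogging: Actor.Receive

      override def receive: Actor.Receive = new Actor.Receive {
        private val delegate = receiveWithLogging
        override def isDefinedAt(msg: Any): Boolean = delegate.isDefinedAt(msg)
        override def apply(msg: Any): Unit = {
          log.debug(s"[actor] received message $msg")
          delegate(msg)
          log.debug(s"[actor] handled message $msg")
        }
      }
    }

Each concrete actor then only has to supply receiveWithLogging, which matches the receive to receiveWithLogging renames applied throughout this patch.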
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c2b9c660ddaec..2f76e532aeb76 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -62,16 +62,6 @@ private[spark] class Executor( val conf = new SparkConf(true) conf.setAll(properties) - // If we are in yarn mode, systems can have different disk layouts so we must set it - // to what Yarn on this system said was available. This will be used later when SparkEnv - // created. - if (java.lang.Boolean.valueOf( - System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) { - conf.set("spark.local.dir", getYarnLocalDirs()) - } else if (sys.env.contains("SPARK_LOCAL_DIRS")) { - conf.set("spark.local.dir", sys.env("SPARK_LOCAL_DIRS")) - } - if (!isLocal) { // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire @@ -99,6 +89,9 @@ private[spark] class Executor( private val urlClassLoader = createClassLoader() private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + // Set the classloader for serializer + env.serializer.setDefaultClassLoader(urlClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) @@ -131,21 +124,6 @@ private[spark] class Executor( threadPool.shutdown() } - /** Get the Yarn approved local directories. */ - private def getYarnLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. 
- // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty) { - throw new Exception("Yarn Local dirs can't be empty") - } - localDirs - } - class TaskRunner( execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer) extends Runnable { @@ -374,6 +352,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + metrics.updateShuffleReadMetrics tasksMetrics += ((taskRunner.taskId, metrics)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 0ed52cfe9df61..d6721586566c2 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -35,9 +35,10 @@ private[spark] class ExecutorSource(val executor: Executor, executorId: String) }) } - val metricRegistry = new MetricRegistry() + override val metricRegistry = new MetricRegistry() + // TODO: It would be nice to pass the application name here - val sourceName = "executor.%s".format(executorId) + override val sourceName = "executor.%s".format(executorId) // Gauge for executor thread pool's actively executing task counts metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 56cd8723a3a22..99a88c13456df 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.{BlockId, BlockStatus} @@ -81,12 +83,27 @@ class TaskMetrics extends Serializable { var inputMetrics: Option[InputMetrics] = None /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here + * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. + * This includes read metrics aggregated over all the task's shuffle dependencies. */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None def shuffleReadMetrics = _shuffleReadMetrics + /** + * This should only be used when recreating TaskMetrics, not when updating read metrics in + * executors. + */ + private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { + _shuffleReadMetrics = shuffleReadMetrics + } + + /** + * ShuffleReadMetrics per dependency for collecting independently while task is in progress. + */ + @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = + new ArrayBuffer[ShuffleReadMetrics]() + /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected * here @@ -98,19 +115,31 @@ class TaskMetrics extends Serializable { */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None - /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. 
*/ - def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { - _shuffleReadMetrics match { - case Some(existingMetrics) => - existingMetrics.shuffleFinishTime = math.max( - existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) - existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime - existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched - existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched - existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead - case None => - _shuffleReadMetrics = Some(newMetrics) + /** + * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization + * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each + * dependency, and merge these metrics before reporting them to the driver. This method returns + * a ShuffleReadMetrics for a dependency and registers it for merging later. + */ + private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics() + depsShuffleReadMetrics += readMetrics + readMetrics + } + + /** + * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. + */ + private[spark] def updateShuffleReadMetrics() = synchronized { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.fetchWaitTime += depMetrics.fetchWaitTime + merged.localBlocksFetched += depMetrics.localBlocksFetched + merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched + merged.remoteBytesRead += depMetrics.remoteBytesRead + merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime) } + _shuffleReadMetrics = Some(merged) } } @@ -190,10 +219,10 @@ class ShuffleWriteMetrics extends Serializable { /** * Number of bytes written for the shuffle by this task */ - var shuffleBytesWritten: Long = _ + @volatile var shuffleBytesWritten: Long = _ /** * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ - var shuffleWriteTime: Long = _ + @volatile var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 1b66218d86dd9..ef9c43ecf14f6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -46,17 +46,24 @@ trait CompressionCodec { private[spark] object CompressionCodec { + + private val shortCompressionCodecNames = Map( + "lz4" -> classOf[LZ4CompressionCodec].getName, + "lzf" -> classOf[LZFCompressionCodec].getName, + "snappy" -> classOf[SnappyCompressionCodec].getName) + def createCodec(conf: SparkConf): CompressionCodec = { createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC)) } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val ctor = Class.forName(codecName, true, Utils.getContextOrSparkClassLoader) + val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader) .getConstructor(classOf[SparkConf]) ctor.newInstance(conf).asInstanceOf[CompressionCodec] } - val DEFAULT_COMPRESSION_CODEC = classOf[SnappyCompressionCodec].getName + val DEFAULT_COMPRESSION_CODEC = "snappy" } diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala 
b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index f865f9648a91e..635bff2cd7ec8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -21,12 +21,9 @@ import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} private[spark] class JvmSource extends Source { - val sourceName = "jvm" - val metricRegistry = new MetricRegistry() + override val sourceName = "jvm" + override val metricRegistry = new MetricRegistry() - val gcMetricSet = new GarbageCollectorMetricSet - val memGaugeSet = new MemoryUsageGaugeSet - - metricRegistry.registerAll(gcMetricSet) - metricRegistry.registerAll(memGaugeSet) + metricRegistry.registerAll(new GarbageCollectorMetricSet) + metricRegistry.registerAll(new MemoryUsageGaugeSet) } diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index 04df2f3b0d696..af35f1fc3e459 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, + hasError, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 4c00225280cce..b3e951ded6e77 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,10 +17,12 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ +import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} @@ -45,22 +47,32 @@ private[spark] class ConnectionManager( name: String = "Connection manager") extends Logging { + /** + * Used by sendMessageReliably to track messages being sent. 
+ * @param message the message that was sent + * @param connectionManagerId the connection manager that sent this message + * @param completionHandler callback that's invoked when the send has completed or failed + */ class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, completionHandler: MessageStatus => Unit) { + /** This is non-None if message has been ack'd */ var ackMessage: Option[Message] = None - var attempted = false - var acked = false - def markDone() { completionHandler(this) } + def markDone(ackMessage: Option[Message]) { + this.ackMessage = ackMessage + completionHandler(this) + } } private val selector = SelectorProvider.provider.openSelector() + private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) // default to 30 second timeout waiting for authentication private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) + private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -442,11 +454,7 @@ private[spark] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.markDone() - } + status.markDone(None) }) messageStatuses.retain((i, status) => { @@ -459,7 +467,7 @@ private[spark] class ConnectionManager( val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId) if (!sendingConnectionOpt.isDefined) { - logError("Corresponding SendingConnectionManagerId not found") + logError(s"Corresponding SendingConnection to ${remoteConnectionManagerId} not found") return } @@ -475,11 +483,7 @@ private[spark] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.synchronized { - s.attempted = true - s.acked = false - s.markDone() - } + s.markDone(None) } messageStatuses.retain((i, status) => { @@ -547,13 +551,13 @@ private[spark] class ConnectionManager( val securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString) val message = securityMsgResp.toBufferMessage - if (message == null) throw new Exception("Error creating security message") + if (message == null) throw new IOException("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("Error evaluating sasl response: " + e) + throw new IOException("Error evaluating sasl response: ", e) } } } @@ -649,46 +653,59 @@ private[spark] class ConnectionManager( } } if (bufferMessage.hasAckId()) { - val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status + status.markDone(Some(message)) } case None => { - throw new Exception("Could not find reference for received ack message " + - message.id) + /** + * We can fall down on this code because of following 2 cases + * + * (1) Invalid ack sent due to buggy code. 
+ * + * (2) Late-arriving ack for a SendMessageStatus + * To avoid unwilling late-arriving ack + * caused by long pause like GC, you can set + * larger value than default to spark.core.connection.ack.wait.timeout + */ + logWarning(s"Could not find reference for received ack Message ${message.id}") } } } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.markDone() - } } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logDebug("Not calling back as callback is null") - None - } + var ackMessage : Option[Message] = None + try { + ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logDebug("Not calling back as callback is null") + None + } - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " - + ackMessage.get.getClass) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logDebug("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + + ackMessage.get.getClass) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logDebug("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } + } + } catch { + case e: Exception => { + logError(s"Exception was thrown while processing message", e) + val m = Message.createBufferMessage(bufferMessage.id) + m.hasError = true + ackMessage = Some(m) } + } finally { + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) } } case _ => throw new Exception("Unknown type message received") @@ -800,11 +817,7 @@ private[spark] class ConnectionManager( case Some(msgStatus) => { messageStatuses -= message.id logInfo("Notifying " + msgStatus.connectionManagerId) - msgStatus.synchronized { - msgStatus.attempted = true - msgStatus.acked = false - msgStatus.markDone() - } + msgStatus.markDone(None) } case None => { logError("no messageStatus for failed message id: " + message.id) @@ -823,28 +836,57 @@ private[spark] class ConnectionManager( selector.wakeup() } + /** + * Send a message and block until an acknowldgment is received or an error occurs. + * @param connectionManagerId the message's destination + * @param message the message being sent + * @return a Future that either returns the acknowledgment message or captures an exception. 
+ */ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message) - : Future[Option[Message]] = { - val promise = Promise[Option[Message]] - val status = new MessageStatus( - message, connectionManagerId, s => promise.success(s.ackMessage)) + : Future[Message] = { + val promise = Promise[Message]() + + val timeoutTask = new TimerTask { + override def run(): Unit = { + messageStatuses.synchronized { + messageStatuses.remove(message.id).foreach ( s => { + promise.failure( + new IOException(s"sendMessageReliably failed because ack " + + "was not received within ${ackTimeout} sec")) + }) + } + } + } + + val status = new MessageStatus(message, connectionManagerId, s => { + timeoutTask.cancel() + s.ackMessage match { + case None => // Indicates a failure where we either never sent or never got ACK'd + promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) + case Some(ackMessage) => + if (ackMessage.hasError) { + promise.failure( + new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + } else { + promise.success(ackMessage) + } + } + }) messageStatuses.synchronized { messageStatuses += ((message.id, status)) } + + ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, - message: Message): Option[Message] = { - Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) - } - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { onReceiveCallback = callback } def stop() { + ackTimeoutMonitor.cancel() selectorThread.interrupt() selectorThread.join() selector.close() @@ -862,6 +904,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { + import ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf @@ -896,7 +939,7 @@ private[spark] object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), Duration.Inf) }) println("--------------------------") println() @@ -917,8 +960,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -952,8 +997,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis @@ -982,8 +1029,10 @@ private[spark] object ConnectionManager { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) }).foreach(f => { - val g = Await.result(f, 1 second) - if (!g.isDefined) println("Failed") + f.onFailure { + case e => println("Failed due to " + e) + } + Await.ready(f, 1 second) }) val finishTime = System.currentTimeMillis 
Thread.sleep(1000) diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 7caccfdbb44f9..04ea50f62918c 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var startTime = -1L var finishTime = -1L var isSecurityNeg = false + var hasError = false def size: Int @@ -87,6 +88,7 @@ private[spark] object Message { case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) } + newMessage.hasError = header.hasError newMessage.senderAddress = header.address newMessage } diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index ead663ede7a1c..f3ecca5f992e0 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val hasError: Boolean, val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]). putInt(securityNeg). putInt(ip.size). put(ip). @@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader( private[spark] object MessageChunkHeader { - val HEADER_SIZE = 44 + val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val hasError = buffer.get() != 0 val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index b8ea7c2cff9a2..ea2ad104ecae1 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -20,6 +20,10 @@ package org.apache.spark.network import java.nio.ByteBuffer import org.apache.spark.{SecurityManager, SparkConf} +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.Try + private[spark] object SenderTest { def main(args: Array[String]) { @@ -51,7 +55,8 @@ private[spark] object SenderTest { val dataMessage = Message.createBufferMessage(buffer.duplicate) val startTime = System.currentTimeMillis /* println("Started timer at " + startTime) */ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) + val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) + val responseStr: String = Try(Await.result(promise, Duration.Inf)) .map { response => val buffer = response.asInstanceOf[BufferMessage].buffers(0) new 
String(buffer.array, "utf-8") diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala deleted file mode 100644 index 136c1912045aa..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import io.netty.buffer._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, TestBlockId} - -private[spark] class FileHeader ( - val fileLen: Int, - val blockId: BlockId) extends Logging { - - lazy val buffer = { - val buf = Unpooled.buffer() - buf.capacity(FileHeader.HEADER_SIZE) - buf.writeInt(fileLen) - buf.writeInt(blockId.name.length) - blockId.name.foreach((x: Char) => buf.writeByte(x)) - // padding the rest of header - if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { - buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) - } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") - } - buf - } - -} - -private[spark] object FileHeader { - - val HEADER_SIZE = 40 - - def getFileLenOffset = 0 - def getFileLenSize = Integer.SIZE/8 - - def create(buf: ByteBuf): FileHeader = { - val length = buf.readInt - val idLength = buf.readInt - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buf.readByte().asInstanceOf[Char] - } - val blockId = BlockId(idBuilder.toString()) - new FileHeader(length, blockId) - } - - def main (args:Array[String]) { - val header = new FileHeader(25, TestBlockId("my_block")) - val buf = header.buffer - val newHeader = FileHeader.create(buf) - System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala new file mode 100644 index 0000000000000..b5870152c5a64 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf + +/** + * A central location that tracks all the settings we exposed to users. + */ +private[spark] +class NettyConfig(conf: SparkConf) { + + /** Port the server listens on. Default to a random port. */ + private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) + + /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ + private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase + + /** Connect timeout in secs. Default 60 secs. */ + private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 + + /** + * Percentage of the desired amount of time spent for I/O in the child event loops. + * Only applicable in nio and epoll. + */ + private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + + /** Requested maximum length of the queue of incoming connections. */ + private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + private[netty] val receiveBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + + /** Send buffer size (SO_SNDBUF). */ + private[netty] val sendBuf: Option[Int] = + conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) +} diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala old mode 100755 new mode 100644 similarity index 80% rename from core/src/main/java/org/apache/spark/network/netty/PathResolver.java rename to core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala index 7ad8d03efbadc..0d7695072a7b1 --- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java +++ b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty -import org.apache.spark.storage.BlockId; -import org.apache.spark.storage.FileSegment; +import org.apache.spark.storage.{BlockId, FileSegment} -public interface PathResolver { +trait PathResolver { /** Get the file segment in which the given block resides. */ - FileSegment getBlockLocation(BlockId blockId); + def getBlockLocation(blockId: BlockId): FileSegment } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala deleted file mode 100644 index e7b2855e1ec91..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.util.concurrent.Executors - -import scala.collection.JavaConverters._ - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.BlockId - -private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { - - def getBlock(host: String, port: Int, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000) - val fc = new FileClient(handler, connectTimeout) - - try { - fc.init() - fc.connect(host, port) - fc.sendRequest(blockId.name) - fc.waitForClose() - fc.close() - } catch { - // Handle any socket-related exceptions in FileClient - case e: Exception => { - logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e) - handler.handleError(blockId) - } - } - } - - def getBlock(cmId: ConnectionManagerId, blockId: BlockId, - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) - } - - def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(BlockId, Long)], - resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { - - for ((blockId, size) <- blocks) { - getBlock(cmId, blockId, resultCollectCallback) - } - } -} - - -private[spark] object ShuffleCopier extends Logging { - - private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) - extends FileClientHandler with Logging { - - override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") - resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) - } - - override def handleError(blockId: BlockId) { - if (!isComplete) { - resultCollectCallBack(blockId, -1, null) - } - } - } - - def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { - if (size != -1) { - logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") - } - } - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: ShuffleCopier ") - System.exit(1) - } - val host = args(0) - val port = args(1).toInt - val blockId = BlockId(args(2)) - val threads = if (args.length > 3) args(3).toInt else 10 - - val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { - Executors.callable(new Runnable() { - def run() { - val copier = new ShuffleCopier(new SparkConf) - copier.getBlock(host, port, 
blockId, echoResultCollectCallBack) - } - }) - }).asJava - copiers.invokeAll(tasks) - copiers.shutdown() - System.exit(0) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala deleted file mode 100644 index 7ef7aecc6a9fb..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import java.io.File - -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import org.apache.spark.storage.{BlockId, FileSegment} - -private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { - - val server = new FileServer(pResolver, portIn) - server.start() - - def stop() { - server.stop() - } - - def port: Int = server.getPort() -} - - -/** - * An application for testing the shuffle sender as a standalone program. - */ -private[spark] object ShuffleSender { - - def main(args: Array[String]) { - if (args.length < 3) { - System.err.println( - "Usage: ShuffleSender ") - System.exit(1) - } - - val port = args(0).toInt - val subDirsPerLocalDir = args(1).toInt - val localDirs = args.drop(2).map(new File(_)) - - val pResovler = new PathResolver { - override def getBlockLocation(blockId: BlockId): FileSegment = { - if (!blockId.isShuffle) { - throw new Exception("Block " + blockId + " is not a shuffle block") - } - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId.name) - new FileSegment(file, 0, file.length()) - } - } - val sender = new ShuffleSender(port, pResovler) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala new file mode 100644 index 0000000000000..e28219dd7745b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
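The PathResolver built inside the deleted ShuffleSender test program maps a block to a file by hashing the block name into one of the local directories and then into a hex-named subdirectory. A minimal standalone sketch of that mapping follows; the helper name and the inline non-negative hash are assumptions for illustration, not Spark API.

    import java.io.File

    // Map a block name to <localDir>/<two-digit-hex-subdir>/<blockName>,
    // mirroring the scheme in the deleted ShuffleSender's resolver.
    def blockFile(blockName: String, localDirs: Array[File], subDirsPerLocalDir: Int): File = {
      val hash = blockName.hashCode & Integer.MAX_VALUE          // non-negative hash
      val dirId = hash % localDirs.length
      val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
      new File(new File(localDirs(dirId), "%02x".format(subDirId)), blockName)
    }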
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.util.EventListener + + +trait BlockClientListener extends EventListener { + + def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit + + def onFetchFailure(blockId: String, errorMsg: String): Unit + +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala new file mode 100644 index 0000000000000..5aea7ba2f3673 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.socket.SocketChannel +import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder +import io.netty.handler.codec.string.StringEncoder +import io.netty.util.CharsetUtil + +import org.apache.spark.Logging + +/** + * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. + * Use [[BlockFetchingClientFactory]] to instantiate this client. + * + * The constructor blocks until a connection is successfully established. + * + * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +@throws[TimeoutException] +private[spark] +class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) + extends Logging { + + private val handler = new BlockFetchingClientHandler + + /** Netty Bootstrap for creating the TCP connection. 
*/ + private val bootstrap: Bootstrap = { + val b = new Bootstrap + b.group(factory.workerGroup) + .channel(factory.socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) + + b.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) + // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 + .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) + .addLast("handler", handler) + } + }) + b + } + + /** Netty ChannelFuture for the connection. */ + private val cf: ChannelFuture = bootstrap.connect(hostname, port) + if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") + } + + /** + * Ask the remote server for a sequence of blocks, and execute the callback. + * + * Note that this is asynchronous and returns immediately. Upstream caller should throttle the + * rate of fetching; otherwise we could run out of memory. + * + * @param blockIds sequence of block ids to fetch. + * @param listener callback to fire on fetch success / failure. + */ + def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { + // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. + // It's also best to limit the number of "flush" calls since it requires system calls. + // Let's concatenate the string and then call writeAndFlush once. + // This is also why this implementation might be more efficient than multiple, separate + // fetch block calls. + var startTime: Long = 0 + logTrace { + startTime = System.nanoTime + s"Sending request $blockIds to $hostname:$port" + } + + blockIds.foreach { blockId => + handler.addRequest(blockId, listener) + } + + val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") + writeFuture.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 + s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + blockIds.foreach { blockId => + listener.onFetchFailure(blockId, errorMsg) + handler.removeRequest(blockId) + } + } + } + }) + } + + def waitForClose(): Unit = { + cf.channel().closeFuture().sync() + } + + def close(): Unit = cf.channel().close() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala new file mode 100644 index 0000000000000..2b28402c52b49 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
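Putting the client pieces together, a caller would obtain a BlockFetchingClient from the factory defined in the next hunk and hand fetchBlocks a listener for the results. The sketch below is illustrative only: it ignores the private[spark] visibility, and the host, port, and block ids are placeholders.

    import org.apache.spark.SparkConf
    import org.apache.spark.network.netty.client.{
      BlockClientListener, BlockFetchingClientFactory, ReferenceCountedBuffer}

    val factory = new BlockFetchingClientFactory(new SparkConf)
    val client = factory.createClient("shuffle-host.example.com", 12345)  // placeholder endpoint

    client.fetchBlocks(Seq("shuffle_0_1_2", "shuffle_0_1_3"), new BlockClientListener {
      override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
        println(s"fetched $blockId: ${data.byteBuffer().remaining()} bytes")
        data.release()  // the caller owns a reference to the Netty-backed buffer
      }
      override def onFetchFailure(blockId: String, errorMsg: String): Unit =
        println(s"failed to fetch $blockId: $errorMsg")
    })

    // fetchBlocks is asynchronous; wait or synchronize via the listener, then tear down:
    client.close()
    factory.stop()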
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.channel.socket.oio.OioSocketChannel +import io.netty.channel.{EventLoopGroup, Channel} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.util.Utils + +/** + * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses + * the worker thread pool for Netty. + * + * Concurrency: createClient is safe to be called from multiple threads concurrently. + */ +private[spark] +class BlockFetchingClientFactory(val conf: NettyConfig) { + + def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) + + /** A thread factory so the threads are named (for debugging). */ + val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + + /** The following two are instantiated by the [[init]] method, depending ioMode. */ + var socketChannelClass: Class[_ <: Channel] = _ + var workerGroup: EventLoopGroup = _ + + init() + + /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ + private def init(): Unit = { + def initOio(): Unit = { + socketChannelClass = classOf[OioSocketChannel] + workerGroup = new OioEventLoopGroup(0, threadFactory) + } + def initNio(): Unit = { + socketChannelClass = classOf[NioSocketChannel] + workerGroup = new NioEventLoopGroup(0, threadFactory) + } + def initEpoll(): Unit = { + socketChannelClass = classOf[EpollSocketChannel] + workerGroup = new EpollEventLoopGroup(0, threadFactory) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. 
+ */ + def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { + new BlockFetchingClient(this, remoteHost, remotePort) + } + + def stop(): Unit = { + if (workerGroup != null) { + workerGroup.shutdownGracefully() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala new file mode 100644 index 0000000000000..83265b164299d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import io.netty.buffer.ByteBuf +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging + + +/** + * Handler that processes server responses. It uses the protocol documented in + * [[org.apache.spark.network.netty.server.BlockServer]]. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +private[client] +class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { + + /** Tracks the list of outstanding requests and their listeners on success/failure. */ + private val outstandingRequests = java.util.Collections.synchronizedMap { + new java.util.HashMap[String, BlockClientListener] + } + + def addRequest(blockId: String, listener: BlockClientListener): Unit = { + outstandingRequests.put(blockId, listener) + } + + def removeRequest(blockId: String): Unit = { + outstandingRequests.remove(blockId) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" + logError(errorMsg, cause) + + // Fire the failure callback for all outstanding blocks + outstandingRequests.synchronized { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onFetchFailure(entry.getKey, errorMsg) + } + outstandingRequests.clear() + } + + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { + val totalLen = in.readInt() + val blockIdLen = in.readInt() + val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) + in.readBytes(blockIdBytes) + val blockId = new String(blockIdBytes) + val blockSize = totalLen - math.abs(blockIdLen) - 4 + + def server = ctx.channel.remoteAddress.toString + + // blockIdLen is negative when it is an error message. 
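The length arithmetic in channelRead0 follows directly from the framing set up on the client pipeline: the LengthFieldBasedFrameDecoder is constructed with initialBytesToStrip left at 0, so each delivered frame still begins with the 4-byte total length, followed by the block id length (negative for errors), the block id bytes, and the payload. A small self-contained sketch of that parse; the Frame case class exists only for illustration.

    import io.netty.buffer.ByteBuf

    case class Frame(blockId: String, payloadLength: Int, isError: Boolean)

    // Frame layout: [total-length][block-id-length][block-id bytes][payload].
    // total-length excludes its own 4 bytes; a negative block-id-length marks an error frame.
    def parseFrameHeader(in: ByteBuf): Frame = {
      val totalLen = in.readInt()
      val blockIdLen = in.readInt()
      val idBytes = new Array[Byte](math.abs(blockIdLen))
      in.readBytes(idBytes)
      Frame(new String(idBytes, "UTF-8"), totalLen - math.abs(blockIdLen) - 4, blockIdLen < 0)
    }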
+ if (blockIdLen < 0) { + val errorMessageBytes = new Array[Byte](blockSize) + in.readBytes(errorMessageBytes) + val errorMsg = new String(errorMessageBytes) + logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") + + val listener = outstandingRequests.get(blockId) + if (listener == null) { + // Ignore callback + logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") + } else { + outstandingRequests.remove(blockId) + listener.onFetchFailure(blockId, errorMsg) + } + } else { + logTrace(s"Received block $blockId ($blockSize B) from $server") + + val listener = outstandingRequests.get(blockId) + if (listener == null) { + // Ignore callback + logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") + } else { + outstandingRequests.remove(blockId) + listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala similarity index 50% rename from core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java rename to core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala index 46efec8f8d963..9740ee64d1f2d 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerChannelInitializer.java +++ b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala @@ -15,27 +15,30 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.network.netty.client -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.DelimiterBasedFrameDecoder; -import io.netty.handler.codec.Delimiters; -import io.netty.handler.codec.string.StringDecoder; - -class FileServerChannelInitializer extends ChannelInitializer { +/** + * A simple iterator that lazily initializes the underlying iterator. + * + * The use case is that sometimes we might have many iterators open at the same time, and each of + * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). + * This could lead to too many buffers open. If this iterator is used, we lazily initialize those + * buffers. 
+ */ +private[spark] +class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - private final PathResolver pResolver; + lazy val proxy = createIterator - FileServerChannelInitializer(PathResolver pResolver) { - this.pResolver = pResolver; + override def hasNext: Boolean = { + val gotNext = proxy.hasNext + if (!gotNext) { + close() + } + gotNext } - @Override - public void initChannel(SocketChannel channel) { - channel.pipeline() - .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter())) - .addLast("stringDecoder", new StringDecoder()) - .addLast("handler", new FileServerHandler(pResolver)); - } + override def next(): Any = proxy.next() + + def close(): Unit = Unit } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala new file mode 100644 index 0000000000000..ea1abf5eccc26 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.{ByteBuf, ByteBufInputStream} + + +/** + * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. + * This is a Scala value class. + * + * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of + * reference by the retain method and release method. + */ +private[spark] +class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { + + /** Return the nio ByteBuffer view of the underlying buffer. */ + def byteBuffer(): ByteBuffer = underlying.nioBuffer + + /** Creates a new input stream that starts from the current position of the buffer. */ + def inputStream(): InputStream = new ByteBufInputStream(underlying) + + /** Increment the reference counter by one. */ + def retain(): Unit = underlying.retain() + + /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ + def release(): Unit = underlying.release() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala new file mode 100644 index 0000000000000..162e9cc6828d4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
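Because LazyInitIterator takes its argument by name, the underlying iterator, and whatever buffers it allocates, is not created until the first hasNext or next call. A small usage sketch, ignoring the private[spark] visibility; the println stands in for real buffer allocation.

    import org.apache.spark.network.netty.client.LazyInitIterator

    def expensiveIterator(): Iterator[Any] = {
      println("allocating decompression/deserialization buffers...")  // placeholder setup
      Iterator(1, 2, 3)
    }

    val it = new LazyInitIterator(expensiveIterator())  // nothing allocated yet
    println(it.hasNext)  // triggers creation: prints the setup message, then true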
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +/** + * Header describing a block. This is used only in the server pipeline. + * + * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. + * + * @param blockSize length of the block content, excluding the length itself. + * If positive, this is the header for a block (not part of the header). + * If negative, this is the header and content for an error message. + * @param blockId block id + * @param error some error message from reading the block + */ +private[server] +class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala new file mode 100644 index 0000000000000..8e4dda4ef8595 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.handler.codec.MessageToByteEncoder + +/** + * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. 
+ */ +private[server] +class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { + override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { + // message = message length (4 bytes) + block id length (4 bytes) + block id + block data + // message length = block id length (4 bytes) + size of block id + size of block data + val blockIdBytes = msg.blockId.getBytes + msg.error match { + case Some(errorMsg) => + val errorBytes = errorMsg.getBytes + out.writeInt(4 + blockIdBytes.length + errorBytes.size) + out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors + out.writeBytes(blockIdBytes) // next is blockId itself + out.writeBytes(errorBytes) // error message + case None => + out.writeInt(4 + blockIdBytes.length + msg.blockSize) + out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length + out.writeBytes(blockIdBytes) // next is blockId itself + // msg of size blockSize will be written by ServerHandler + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala new file mode 100644 index 0000000000000..7b2f9a8d4dfd0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.net.InetSocketAddress + +import io.netty.bootstrap.ServerBootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} +import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.channel.socket.oio.OioServerSocketChannel +import io.netty.handler.codec.LineBasedFrameDecoder +import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.network.netty.NettyConfig +import org.apache.spark.storage.BlockDataProvider +import org.apache.spark.util.Utils + + +/** + * Server for serving Spark data blocks. + * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. + * + * Protocol for requesting blocks (client to server): + * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" + * + * Protocol for sending blocks (server to client): + * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. 
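As a worked example of the layout written by encode above, assume a successful response for a 13-byte block id and a 1,000-byte block; the numbers and block id are illustrative only.

    import io.netty.buffer.Unpooled

    val blockIdBytes = "shuffle_0_0_0".getBytes("UTF-8")  // 13 bytes
    val blockSize = 1000

    val out = Unpooled.buffer()
    out.writeInt(4 + blockIdBytes.length + blockSize)  // frame length = 1017 (excludes these 4 bytes)
    out.writeInt(blockIdBytes.length)                  // 13; an error frame would write -13 here
    out.writeBytes(blockIdBytes)
    // The 1,000 bytes of block data are then streamed by BlockServerHandler.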
+ * + * frame-length should not include the length of itself. + * If block-id-length is negative, then this is an error message rather than block-data. The real + * length is the absolute value of the frame-length. + * + */ +private[spark] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { + + def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { + this(new NettyConfig(sparkConf), dataProvider) + } + + def port: Int = _port + + def hostName: String = _hostName + + private var _port: Int = conf.serverPort + private var _hostName: String = "" + private var bootstrap: ServerBootstrap = _ + private var channelFuture: ChannelFuture = _ + + init() + + /** Initialize the server. */ + private def init(): Unit = { + bootstrap = new ServerBootstrap + val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") + val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + + // Use only one thread to accept connections, and 2 * num_cores for worker. + def initNio(): Unit = { + val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) + } + def initOio(): Unit = { + val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) + val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) + bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) + } + def initEpoll(): Unit = { + val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) + val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) + workerGroup.setIoRatio(conf.ioRatio) + bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) + } + + conf.ioMode match { + case "nio" => initNio() + case "oio" => initOio() + case "epoll" => initEpoll() + case "auto" => + // For auto mode, first try epoll (only available on Linux), then nio. + try { + initEpoll() + } catch { + // TODO: Should we log the throwable? But that always happen on non-Linux systems. + // Perhaps the right thing to do is to check whether the system is Linux, and then only + // call initEpoll on Linux. + case e: Throwable => initNio() + } + } + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + + // Various (advanced) user-configured settings. 
+ conf.backLog.foreach { backLog => + bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) + } + conf.receiveBuf.foreach { receiveBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + } + conf.sendBuf.foreach { sendBuf => + bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + } + + bootstrap.childHandler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) + } + }) + + channelFuture = bootstrap.bind(new InetSocketAddress(_port)) + channelFuture.sync() + + val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] + _port = addr.getPort + _hostName = addr.getHostName + } + + /** Shutdown the server. */ + def stop(): Unit = { + if (channelFuture != null) { + channelFuture.channel().close().awaitUninterruptibly() + channelFuture = null + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully() + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully() + } + bootstrap = null + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala new file mode 100644 index 0000000000000..cc70bd0c5c477 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import io.netty.channel.ChannelInitializer +import io.netty.channel.socket.SocketChannel +import io.netty.handler.codec.LineBasedFrameDecoder +import io.netty.handler.codec.string.StringDecoder +import io.netty.util.CharsetUtil +import org.apache.spark.storage.BlockDataProvider + + +/** Channel initializer that sets up the pipeline for the BlockServer. 
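The server binds to spark.shuffle.io.port, which defaults to 0, and then reads the actual address back from the channel, so an ephemeral port works out of the box. Below is a minimal, non-Spark Netty sketch of that bind-then-discover pattern (empty pipeline, NIO transport only), just to make the mechanism concrete.

    import java.net.InetSocketAddress
    import io.netty.bootstrap.ServerBootstrap
    import io.netty.channel.ChannelInitializer
    import io.netty.channel.nio.NioEventLoopGroup
    import io.netty.channel.socket.SocketChannel
    import io.netty.channel.socket.nio.NioServerSocketChannel

    val boss = new NioEventLoopGroup(1)
    val workers = new NioEventLoopGroup()
    val bootstrap = new ServerBootstrap()
      .group(boss, workers)
      .channel(classOf[NioServerSocketChannel])
      .childHandler(new ChannelInitializer[SocketChannel] {
        override def initChannel(ch: SocketChannel): Unit = ()  // no handlers in this sketch
      })

    val future = bootstrap.bind(new InetSocketAddress(0)).sync()  // port 0 = let the OS choose
    val addr = future.channel.localAddress.asInstanceOf[InetSocketAddress]
    println(s"bound to ${addr.getHostName}:${addr.getPort}")

    future.channel.close().sync()
    boss.shutdownGracefully()
    workers.shutdownGracefully()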
*/ +private[netty] +class BlockServerChannelInitializer(dataProvider: BlockDataProvider) + extends ChannelInitializer[SocketChannel] { + + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 + .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) + .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("handler", new BlockServerHandler(dataProvider)) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala new file mode 100644 index 0000000000000..40dd5e5d1a2ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.server + +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.channels.FileChannel + +import io.netty.buffer.Unpooled +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first + * so channelRead0 is called once per line (i.e. per block id). + */ +private[server] +class BlockServerHandler(dataProvider: BlockDataProvider) + extends SimpleChannelInboundHandler[String] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { + def client = ctx.channel.remoteAddress.toString + + // A helper function to send error message back to the client. + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + def writeFileSegment(segment: FileSegment): Unit = { + // Send error message back if the block is too large. Even though we are capable of sending + // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. + // Once we fixed the receiving end to be able to process large blocks, this should be removed. + // Also make sure we update BlockHeaderEncoder to support length > 2G. 
+ + // See [[BlockHeaderEncoder]] for the way length is encoded. + if (segment.length + blockId.length + 4 > Int.MaxValue) { + respondWithError(s"Block $blockId size ($segment.length) greater than 2G") + return + } + + var fileChannel: FileChannel = null + try { + fileChannel = new FileInputStream(segment.file).getChannel + } catch { + case e: Exception => + logError( + s"Error opening channel for $blockId in ${segment.file} for request from $client", e) + respondWithError(e.getMessage) + } + + // Found the block. Send it back. + if (fileChannel != null) { + // Write the header and block data. In the case of failures, the listener on the block data + // write should close the connection. + ctx.write(new BlockHeader(segment.length.toInt, blockId)) + + val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) + ctx.writeAndFlush(region).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${segment.length} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + } + + def writeByteBuffer(buf: ByteBuffer): Unit = { + ctx.write(new BlockHeader(buf.remaining, blockId)) + ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") + } else { + logError(s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + }) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + var blockData: Either[FileSegment, ByteBuffer] = null + + // First make sure we can find the block. If not, send error back to the user. + try { + blockData = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + blockData match { + case Left(segment) => writeFileSegment(segment) + case Right(buf) => writeByteBuffer(buf) + } + + } // end of channelRead0 +} diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 34c51b833025e..20938781ac694 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging { val deserializeStream = serializer.deserializeStream(fileInputStream) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => deserializeStream.close()) + context.addTaskCompletionListener(context => deserializeStream.close()) deserializeStream.asIterator.asInstanceOf[Iterator[T]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 9ca971c8a4c27..e0494ee39657c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -95,7 +95,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * If the elements in RDD do not vary (max == min) always returns a single bucket. 
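writeFileSegment above sends the header as a BlockHeader and then hands the file range to Netty as a DefaultFileRegion, which lets the kernel stream the bytes (sendfile-style) rather than copying them through user-space buffers. A small hedged sketch of that pattern outside Spark; the path, offset, and the eventual write target are placeholders.

    import java.io.FileInputStream
    import io.netty.channel.DefaultFileRegion

    val channel = new FileInputStream("/tmp/shuffle_0_0_0.data").getChannel  // placeholder file
    val region = new DefaultFileRegion(channel, 0L, channel.size())          // offset, length
    // ctx.writeAndFlush(region) would then transfer the range without an intermediate copy,
    // and the listener on that future decides whether to keep or close the connection.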
*/ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { - // Compute the minimum and the maxium + // Scala's built-in range has issues. See #SI-8782 + def customRange(min: Double, max: Double, steps: Int): IndexedSeq[Double] = { + val span = max - min + Range.Int(0, steps, 1).map(s => min + (s * span) / steps) :+ max + } + // Compute the minimum and the maximum val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(Double.NegativeInfinity, Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => @@ -107,9 +112,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment = (max-min)/bucketCount.toDouble - val range = if (increment != 0) { - Range.Double.inclusive(min, max, increment) + val range = if (min != max) { + // Range.Double.inclusive(min, max, increment) + // The above code doesn't always work. See Scala bug #SI-8782. + // https://issues.scala-lang.org/browse/SI-8782 + customRange(min, max, bucketCount) } else { List(min, min) } @@ -119,11 +126,11 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** * Compute a histogram using the provided buckets. The buckets are all open - * to the left except for the last which is closed + * to the right except for the last which is closed * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 8d92ea01d9a3f..c8623314c98eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -197,7 +197,7 @@ class HadoopRDD[K, V]( reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. 
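The doc fix above also pins down the bucket semantics for explicit buckets: every bucket is half-open on the right except the last, which is closed, so the maximum value is counted. A quick worked example against an existing SparkContext named sc (assumed here):

    // Buckets [1, 10, 20, 50] mean [1, 10), [10, 20), [20, 50].
    val counts = sc.parallelize(Seq(1.0, 50.0)).histogram(Array(1.0, 10.0, 20.0, 50.0))
    // counts == Array(1L, 0L, 1L): both 1 and 50 fall into a bucket,
    // matching the corrected "1, 0, 1" example in the comment above.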
- context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val key: K = reader.createKey() val value: V = reader.createValue() diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 8947e66f4577c..0e38f224ac81d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag]( } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { - context.addOnCompleteCallback{ () => closeIfNeeded() } + context.addTaskCompletionListener{ context => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7dfec9a18ec67..58f707b9b4634 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -129,7 +129,7 @@ class NewHadoopRDD[K, V]( context.taskMetrics.inputMetrics = Some(inputMetrics) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) + context.addTaskCompletionListener(context => close()) var havePair = false var finished = false diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93af50c0a9cd1..f6d9d12fe9006 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling). * * Create a sample of this RDD using variable sampling rates for different keys as specified by - * `fractions`, a key to sampling rate map. - * - * If `exact` is set to false, create the sample via simple random sampling, with one pass - * over the RDD, to produce a sample of size that's approximately equal to the sum of - * math.ceil(numItems * samplingRate) over all key values; otherwise, use - * additional passes over the RDD to create a sample size that's exactly equal to the sum of - * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling - * without replacement, we need one additional pass over the RDD to guarantee sample size; - * when sampling with replacement, we need two additional passes. + * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the + * RDD, to produce a sample of size that's approximately equal to the sum of + * math.ceil(numItems * samplingRate) over all key values. 
* * @param withReplacement whether to sample with or without replacement * @param fractions map of specific keys to sampling rates * @param seed seed for the random number generator - * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key * @return RDD containing the sampled subset */ def sampleByKey(withReplacement: Boolean, fractions: Map[K, Double], - exact: Boolean = false, - seed: Long = Utils.random.nextLong): RDD[(K, V)]= { + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") val samplingFunc = if (withReplacement) { - StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed) } else { - StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, exact, seed) + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed) + } + self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) + } + + /** + * ::Experimental:: + * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly + * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). + * + * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) + * over all key values with a 99.99% confidence. When sampling without replacement, we need one + * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need + * two additional passes. + * + * @param withReplacement whether to sample with or without replacement + * @param fractions map of specific keys to sampling rates + * @param seed seed for the random number generator + * @return RDD containing the sampled subset + */ + @Experimental + def sampleByKeyExact(withReplacement: Boolean, + fractions: Map[K, Double], + seed: Long = Utils.random.nextLong): RDD[(K, V)] = { + + require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.") + + val samplingFunc = if (withReplacement) { + StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, true, seed) + } else { + StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, true, seed) } self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true) } @@ -237,6 +260,25 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey[V]((v: V) => v, func, func, partitioner) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. + */ + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) + } + + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. 
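With this split, sampleByKey stays a single approximate pass, while the new sampleByKeyExact pays for one extra pass (two with replacement) to hit the exact per-stratum sizes. A short usage sketch, again assuming an existing SparkContext named sc:

    val pairs = sc.parallelize(1 to 1000).map(i => (i % 2, i))
    val fractions = Map(0 -> 0.1, 1 -> 0.2)

    // Approximate sizes, one pass over the data.
    val approx = pairs.sampleByKey(withReplacement = false, fractions)
    // Exactly math.ceil(numItems * rate) per key, at the cost of extra passes.
    val exact = pairs.sampleByKeyExact(withReplacement = false, fractions)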
+ */ + def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { + reduceByKey(defaultPartitioner(self), func) + } + /** * Merge the values for each key using an associative reduce function, but return the results * immediately to the master as a Map. This will also perform the merging locally on each mapper @@ -374,15 +416,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) countApproxDistinctByKey(relativeSD, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. - */ - def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numPartitions), func) - } - /** * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. @@ -482,16 +515,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } - /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ - * parallelism level. - */ - def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { - reduceByKey(defaultPartitioner(self), func) - } - /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e1c49e35abecd..daea2617e62ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1004,7 +1004,7 @@ abstract class RDD[T: ClassTag]( }, (h1: HyperLogLogPlus, h2: HyperLogLogPlus) => { h1.addAll(h2) - h2 + h1 }).cardinality() } @@ -1233,6 +1233,11 @@ abstract class RDD[T: ClassTag]( dependencies.head.rdd.asInstanceOf[RDD[U]] } + /** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */ + protected[spark] def parent[U: ClassTag](j: Int) = { + dependencies(j).rdd.asInstanceOf[RDD[U]] + } + /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */ def context = sc @@ -1294,6 +1299,19 @@ abstract class RDD[T: ClassTag]( /** A description of this RDD and its recursive dependencies for debugging. 
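The small-looking change in countApproxDistinct matters: addAll merges h2 into h1, so h1 is the HyperLogLog sketch holding the union, and returning h2 would drop whatever had already been accumulated on the left side of the combine. Call-site behavior is unchanged; a quick example with an assumed SparkContext sc:

    val approx = sc.parallelize(Seq(1, 2, 2, 3, 3, 3)).countApproxDistinct(0.05)
    // approx is an estimate close to 3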
*/ def toDebugString: String = { + // Get a debug description of an rdd without its children + def debugSelf (rdd: RDD[_]): Seq[String] = { + import Utils.bytesToString + + val persistence = storageLevel.description + val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => + " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( + info.numCachedPartitions, bytesToString(info.memSize), + bytesToString(info.tachyonSize), bytesToString(info.diskSize))) + + s"$rdd [$persistence]" +: storageInfo + } + // Apply a different rule to the last child def debugChildren(rdd: RDD[_], prefix: String): Seq[String] = { val len = rdd.dependencies.length @@ -1319,7 +1337,11 @@ abstract class RDD[T: ClassTag]( val partitionStr = "(" + rdd.partitions.size + ")" val leftOffset = (partitionStr.length - 1) / 2 val nextPrefix = (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset)) - Seq(partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix $desc" + } ++ debugChildren(rdd, nextPrefix) } def shuffleDebugString(rdd: RDD[_], prefix: String = "", isLastChild: Boolean): Seq[String] = { val partitionStr = "(" + rdd.partitions.size + ")" @@ -1329,7 +1351,11 @@ abstract class RDD[T: ClassTag]( thisPrefix + (if (isLastChild) " " else "| ") + (" " * leftOffset) + "|" + (" " * (partitionStr.length - leftOffset))) - Seq(thisPrefix + "+-" + partitionStr + " " + rdd) ++ debugChildren(rdd, nextPrefix) + + debugSelf(rdd).zipWithIndex.map{ + case (desc: String, 0) => s"$thisPrefix+-$partitionStr $desc" + case (desc: String, _) => s"$nextPrefix$desc" + } ++ debugChildren(rdd, nextPrefix) } def debugString(rdd: RDD[_], prefix: String = "", @@ -1337,9 +1363,8 @@ abstract class RDD[T: ClassTag]( isLastChild: Boolean = false): Seq[String] = { if (isShuffle) { shuffleDebugString(rdd, prefix, isLastChild) - } - else { - Seq(prefix + rdd) ++ debugChildren(rdd, prefix) + } else { + debugSelf(rdd).map(prefix + _) ++ debugChildren(rdd, prefix) } } firstDebugString(this).mkString("\n") diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 197167ecad0bd..0c97eb0aaa51f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -83,8 +83,7 @@ class UnionRDD[T: ClassTag]( override def compute(s: Partition, context: TaskContext): Iterator[T] = { val part = s.asInstanceOf[UnionPartition[T]] - val parentRdd = dependencies(part.parentRddIndex).rdd.asInstanceOf[RDD[T]] - parentRdd.iterator(part.parentPartition, context) + parent[T](part.parentRddIndex).iterator(part.parentPartition, context) } override def getPreferredLocations(s: Partition): Seq[String] = diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 430e45ada5808..b86cfbfa48fbe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -121,6 +121,9 @@ class DAGScheduler( private[scheduler] var eventProcessActor: ActorRef = _ + /** If enabled, we may run certain actions like take() and first() locally. 
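With debugSelf folded in, each line of toDebugString now carries the RDD's storage level and, when storage info is available for a cached RDD, its cached partition count and sizes. A small way to see it, assuming an existing SparkContext sc; the exact output shown in comments is illustrative.

    val cached = sc.parallelize(1 to 100).map(_ * 2).cache()
    cached.count()                 // materialize the cache so storage info exists
    println(cached.toDebugString)
    // Each line now ends with something like:
    //   [Memory Deserialized 1x Replicated]
    //   CachedPartitions: 8; MemorySize: ...; TachyonSize: ...; DiskSize: ...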
*/ + private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -631,7 +634,7 @@ class DAGScheduler( val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { - taskContext.executeOnCompleteCallbacks() + taskContext.markTaskCompleted() } } catch { case e: Exception => @@ -732,7 +735,9 @@ class DAGScheduler( logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + val shouldRunLocally = + localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 + if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 5878e733908f5..94944399b134a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.DAGScheduler".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.DAGScheduler".format(sc.appName) metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failedStages.size diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 406147f167bf3..370fcd85aa680 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -54,7 +54,8 @@ private[spark] class EventLoggingListener( private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 private val logBaseDir = sparkConf.get("spark.eventLog.dir", DEFAULT_LOG_DIR).stripSuffix("/") - private val name = appName.replaceAll("[ :/]", "-").toLowerCase + "-" + System.currentTimeMillis + private val name = appName.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_") + .toLowerCase + "-" + System.currentTimeMillis val logDir = Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/") protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize, @@ -127,6 +128,8 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) override def onApplicationEnd(event: SparkListenerApplicationEnd) = logEvent(event, flushLogger = true) + // No-op because logging every update would be overkill + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { } /** * Stop logging events. 
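The DAGScheduler change above puts driver-side local execution of trivially small jobs behind a new flag: even when a job has a single partition and no parent stages, take() and first() now go through the cluster unless spark.localExecution.enabled is set. A sketch of opting back in (only the flag itself comes from the diff; the rest is boilerplate):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("local-execution-sketch")
      .setMaster("local[2]")
      // Defaults to false; when true, single-partition actions with no parent
      // stages may be computed directly on the driver.
      .set("spark.localExecution.enabled", "true")

    val sc = new SparkContext(conf)
    val first = sc.parallelize(1 to 1000, numSlices = 8).first()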
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 47dd112f68325..4d6b5c81883b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -162,6 +162,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime val inputMetrics = taskMetrics.inputMetrics match { case Some(metrics) => " READ_METHOD=" + metrics.readMethod.toString + @@ -179,11 +180,13 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener case None => "" } val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case Some(metrics) => + " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + + " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + inputMetrics + shuffleReadMetrics + - writeMetrics) + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + shuffleReadMetrics + writeMetrics) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index d09fd7aa57642..2ccbd8edeb028 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U]( try { func(context, rdd.iterator(partition, context)) } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 11255c07469d4..381eff2147e95 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask( } throw e } finally { - context.executeOnCompleteCallbacks() + context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c5e421404a21..6aa0cca06878d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -46,7 +46,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) - context.taskMetrics.hostname = Utils.localHostName(); + context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { kill(interruptThread = false) @@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex def kill(interruptThread: Boolean) { _killed = true if (context != null) { - context.interrupted = true + context.markInterrupted() } if (interruptThread && taskThread != null) { taskThread.interrupt() diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 20a4bd12f93f6..d9d53faf843ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -690,8 +690,7 @@ private[spark] class TaskSetManager( handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } // recalculate valid locality levels and waits when executor is lost - myLocalityLevels = computeValidLocalityLevels() - localityWaits = myLocalityLevels.map(getLocalityWait) + recomputeLocality() } /** @@ -775,9 +774,15 @@ private[spark] class TaskSetManager( levels.toArray } - def executorAdded() { + def recomputeLocality() { + val previousLocalityLevel = myLocalityLevels(currentLocalityIndex) myLocalityLevels = computeValidLocalityLevels() localityWaits = myLocalityLevels.map(getLocalityWait) + currentLocalityIndex = getLocalityIndex(previousLocalityLevel) + } + + def executorAdded() { + recomputeLocality() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9f085eef46720..2a3711ae2a78c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,7 +30,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} import org.apache.spark.ui.JettyUtils /** @@ -47,21 +47,24 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) - var totalExpectedExecutors = new AtomicInteger(0) + var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered executors / total expected executors) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = conf.getDouble("spark.scheduler.minRegisteredExecutorsRatio", 0) - if (minRegisteredRatio > 1) minRegisteredRatio = 1 - // Whatever minRegisteredExecutorsRatio is arrived, submit tasks after the time(milliseconds). 
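The scheduler backend hunk that starts here generalizes the registration gate from executor counts to resources: spark.scheduler.minRegisteredExecutorsRatio and spark.scheduler.maxRegisteredExecutorsWaitingTime become spark.scheduler.minRegisteredResourcesRatio and spark.scheduler.maxRegisteredResourcesWaitingTime, and standalone mode measures the ratio in cores. A hedged tuning example:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("registration-gate-sketch")
      .set("spark.cores.max", "32")
      // Wait until at least 80% of the requested cores (26 of 32) have registered...
      .set("spark.scheduler.minRegisteredResourcesRatio", "0.8")
      // ...or until 60 seconds (the value is in milliseconds) have passed.
      .set("spark.scheduler.maxRegisteredResourcesWaitingTime", "60000")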
+ var minRegisteredRatio = + math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) + // Submit tasks after maxRegisteredWaitingTime milliseconds + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = - conf.getInt("spark.scheduler.maxRegisteredExecutorsWaitingTime", 30000) + conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() - var ready = if (minRegisteredRatio <= 0) true else false - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor { + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + + override protected def log = CoarseGrainedSchedulerBackend.this.log + private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] @@ -79,7 +82,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } - def receive = { + def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorActor.contains(executorId)) { @@ -94,12 +97,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A executorAddress(executorId) = sender.path.address addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) - if (executorActor.size >= totalExpectedExecutors.get() * minRegisteredRatio && !ready) { - ready = true - logInfo("SchedulerBackend is ready for scheduling beginning, registered executors: " + - executorActor.size + ", total expected executors: " + totalExpectedExecutors.get() + - ", minRegisteredExecutorsRatio: " + minRegisteredRatio) - } + totalRegisteredExecutors.addAndGet(1) makeOffers() } @@ -268,14 +266,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } + def sufficientResourcesRegistered(): Boolean = true + override def isReady(): Boolean = { - if (ready) { + if (sufficientResourcesRegistered) { + logInfo("SchedulerBackend is ready for scheduling beginning after " + + s"reached minRegisteredResourcesRatio: $minRegisteredRatio") return true } if ((System.currentTimeMillis() - createTime) >= maxRegisteredWaitingTime) { - ready = true logInfo("SchedulerBackend is ready for scheduling beginning after waiting " + - "maxRegisteredExecutorsWaitingTime: " + maxRegisteredWaitingTime) + s"maxRegisteredResourcesWaitingTime: $maxRegisteredWaitingTime(ms)") return true } false diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index a28446f6c8a6b..589dba2e40d20 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -36,6 +36,7 @@ private[spark] class SparkDeploySchedulerBackend( var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = conf.getOption("spark.cores.max").map(_.toInt) + val totalExpectedCores = maxCores.getOrElse(0) override def start() { super.start() @@ -97,7 +98,6 @@ private[spark] class SparkDeploySchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, 
memory: Int) { - totalExpectedExecutors.addAndGet(1) logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } @@ -110,4 +110,8 @@ private[spark] class SparkDeploySchedulerBackend( logInfo("Executor %s removed: %s".format(fullId, message)) removeExecutor(fullId.split("/")(1), reason.toString) } + + override def sufficientResourcesRegistered(): Boolean = { + totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 3d1cf312ccc97..bec9502f20466 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -23,9 +23,9 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.{TaskMetrics, Executor, ExecutorBackend} +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -43,7 +43,7 @@ private case class StopExecutor() private[spark] class LocalActor( scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, - private val totalCores: Int) extends Actor with Logging { + private val totalCores: Int) extends Actor with ActorLogReceive with Logging { private var freeCores = totalCores @@ -53,7 +53,7 @@ private[spark] class LocalActor( val executor = new Executor( localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) - def receive = { + override def receiveWithLogging = { case ReviveOffers => reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 34bc3124097bb..554a33ce7f1a6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -63,8 +63,11 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { - def serialize[T: ClassTag](t: T): ByteBuffer = { + +private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) + extends SerializerInstance { + + override def serialize[T: ClassTag](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) out.writeObject(t) @@ -72,23 +75,23 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis) - in.readObject().asInstanceOf[T] + in.readObject() } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val bis = new ByteBufferInputStream(bytes) val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] + in.readObject() } - def serializeStream(s: OutputStream): SerializationStream = 
{ + override def serializeStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s, counterReset) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) } @@ -109,7 +112,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) - def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + override def newInstance(): SerializerInstance = { + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) + new JavaSerializerInstance(counterReset, classLoader) + } override def writeExternal(out: ObjectOutput) { out.writeInt(counterReset) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 407cb9db6ee9a..87ef9bb0b43c6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf) val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() kryo.setRegistrationRequired(registrationRequired) - val classLoader = Thread.currentThread.getContextClassLoader + + val oldClassLoader = Thread.currentThread.getContextClassLoader + val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. @@ -79,15 +81,21 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Allow the user to register their own classes by setting spark.kryo.registrator - try { - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) + for (regCls <- registrator) { + logDebug("Running user registrator: " + regCls) + try { val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] + + // Use the default classloader when calling the user registrator. 
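With this KryoSerializer hunk, the user registrator is resolved and invoked under the serializer's default class loader, and any failure now surfaces as a SparkException instead of being logged and ignored. A minimal registrator sketch (the class name and registered types are illustrative):

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoRegistrator

    // Registers application classes with Kryo; the registrator itself is loaded via the
    // class loader that can see the application jars.
    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[Array[Double]])
        kryo.register(classOf[scala.collection.immutable.Range])
      }
    }

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .set("spark.kryo.registrator", "MyRegistrator")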
+ Thread.currentThread.setContextClassLoader(classLoader) reg.registerClasses(kryo) + } catch { + case e: Exception => + throw new SparkException(s"Failed to invoke $regCls", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } - } catch { - case e: Exception => logError("Failed to run spark.kryo.registrator", e) } // Register Chill's classes; we do this after our ranges and the user's own classes to let @@ -98,7 +106,7 @@ class KryoSerializer(conf: SparkConf) kryo } - def newInstance(): SerializerInstance = { + override def newInstance(): SerializerInstance = { new KryoSerializerInstance(this) } } @@ -107,20 +115,20 @@ private[spark] class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { val output = new KryoOutput(outStream) - def writeObject[T: ClassTag](t: T): SerializationStream = { + override def writeObject[T: ClassTag](t: T): SerializationStream = { kryo.writeClassAndObject(output, t) this } - def flush() { output.flush() } - def close() { output.close() } + override def flush() { output.flush() } + override def close() { output.close() } } private[spark] class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { - val input = new KryoInput(inStream) + private val input = new KryoInput(inStream) - def readObject[T: ClassTag](): T = { + override def readObject[T: ClassTag](): T = { try { kryo.readClassAndObject(input).asInstanceOf[T] } catch { @@ -130,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser } } - def close() { + override def close() { // Kryo's Input automatically closes the input stream it is using. input.close() } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.newKryo() + private val kryo = ks.newKryo() // Make these lazy vals to avoid creating a buffer unless we use them - lazy val output = ks.newKryoOutput() - lazy val input = new KryoInput() + private lazy val output = ks.newKryoOutput() + private lazy val input = new KryoInput() - def serialize[T: ClassTag](t: T): ByteBuffer = { + override def serialize[T: ClassTag](t: T): ByteBuffer = { output.clear() kryo.writeClassAndObject(output, t) ByteBuffer.wrap(output.toBytes) } - def deserialize[T: ClassTag](bytes: ByteBuffer): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = { input.setBuffer(bytes.array) kryo.readClassAndObject(input).asInstanceOf[T] } - def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) input.setBuffer(bytes.array) @@ -163,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ obj } - def serializeStream(s: OutputStream): SerializationStream = { + override def serializeStream(s: OutputStream): SerializationStream = { new KryoSerializationStream(kryo, s) } - def deserializeStream(s: InputStream): DeserializationStream = { + override def deserializeStream(s: InputStream): DeserializationStream = { new KryoDeserializationStream(kryo, s) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index f2f5cea469c61..a9144cdd97b8c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ 
b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -43,11 +43,30 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator} * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi -trait Serializer { +abstract class Serializer { + + /** + * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should + * make sure it is using this when set. + */ + @volatile protected var defaultClassLoader: Option[ClassLoader] = None + + /** + * Sets a class loader for the serializer to use in deserialization. + * + * @return this Serializer object + */ + def setDefaultClassLoader(classLoader: ClassLoader): Serializer = { + defaultClassLoader = Some(classLoader) + this + } + + /** Creates a new [[SerializerInstance]]. */ def newInstance(): SerializerInstance } +@DeveloperApi object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer @@ -64,7 +83,7 @@ object Serializer { * An instance of a serializer, for use by one thread at a time. */ @DeveloperApi -trait SerializerInstance { +abstract class SerializerInstance { def serialize[T: ClassTag](t: T): ByteBuffer def deserialize[T: ClassTag](bytes: ByteBuffer): T @@ -74,21 +93,6 @@ trait SerializerInstance { def serializeStream(s: OutputStream): SerializationStream def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new ByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.wrap(stream.toByteArray) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).asIterator - } } /** @@ -96,7 +100,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -trait SerializationStream { +abstract class SerializationStream { def writeObject[T: ClassTag](t: T): SerializationStream def flush(): Unit def close(): Unit @@ -115,7 +119,7 @@ trait SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -trait DeserializationStream { +abstract class DeserializationStream { def readObject[T: ClassTag](): T def close(): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/package-info.java b/core/src/main/scala/org/apache/spark/serializer/package-info.java index 4c0b73ab36a00..207c6e02e4293 100644 --- a/core/src/main/scala/org/apache/spark/serializer/package-info.java +++ b/core/src/main/scala/org/apache/spark/serializer/package-info.java @@ -18,4 +18,4 @@ /** * Pluggable serializers for RDD and shuffle data. 
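Serializer, SerializerInstance and the two stream types are now abstract classes rather than traits, which makes it possible to add concrete members such as defaultClassLoader and setDefaultClassLoader without forcing existing implementations to be recompiled. A rough sketch of a serializer that honors the new hook by delegating to JavaSerializer (the wrapper class is illustrative):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerInstance}

    // A toy Serializer that defers the real work to JavaSerializer but shows the new
    // contract: newInstance() consults the class loader Spark set on this serializer.
    class DelegatingSerializer(conf: SparkConf) extends Serializer {
      private val delegate = new JavaSerializer(conf)

      override def newInstance(): SerializerInstance = {
        val loader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
        // Propagate the loader so that deserialization can resolve user classes.
        delegate.setDefaultClassLoader(loader).newInstance()
      }
    }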
*/ -package org.apache.spark.serializer; \ No newline at end of file +package org.apache.spark.serializer; diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 99788828981c7..12b475658e29d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) + serializer: Serializer, + shuffleMetrics: ShuffleReadMetrics) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics() }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7c9dc8e5f88ef..7bed97a63f0f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, + readMetrics) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { @@ -58,7 +60,7 @@ private[spark] class HashShuffleReader[K, C]( // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. 
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) - sorter.write(aggregatedIter) + sorter.insertAll(aggregatedIter) context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled sorter.iterator diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 45d3b8b9b8725..51e454d9313c9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V]( // we don't try deleting files, etc twice. private var stopping = false + private val writeMetrics = new ShuffleWriteMetrics() + metrics.shuffleWriteMetrics = Some(writeMetrics) + private val blockManager = SparkEnv.get.blockManager private val shuffleBlockManager = blockManager.shuffleBlockManager private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) - private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser, + writeMetrics) /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { @@ -99,22 +103,12 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => writer.commitAndClose() val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.shuffleWriteMetrics = Some(shuffleMetrics) - new MapStatus(blockManager.blockManagerId, compressedSizes) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 24db2f287a47b..22f656fa371ea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -44,6 +44,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var sorter: ExternalSorter[K, V, _] = null private var outputFile: File = null + private var indexFile: File = null // Are we in the process of stopping? 
Because map tasks can call stop() with success = true // and then call stop() with success = false if they get an exception, we want to make sure @@ -52,89 +53,41 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null + private val writeMetrics = new ShuffleWriteMetrics() + context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics) + /** Write a bunch of records to this task's output */ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = { - // Get an iterator with the elements for each partition ID - val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = { - if (dep.mapSideCombine) { - if (!dep.aggregator.isDefined) { - throw new IllegalStateException("Aggregator is empty for map-side combine") - } - sorter = new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) - sorter.write(records) - sorter.partitionedIterator - } else { - // In this case we pass neither an aggregator nor an ordering to the sorter, because we - // don't care whether the keys get sorted in each partition; that will be done on the - // reduce side if the operation being run is sortByKey. - sorter = new ExternalSorter[K, V, V]( - None, Some(dep.partitioner), None, dep.serializer) - sorter.write(records) - sorter.partitionedIterator + if (dep.mapSideCombine) { + if (!dep.aggregator.isDefined) { + throw new IllegalStateException("Aggregator is empty for map-side combine") } + sorter = new ExternalSorter[K, V, C]( + dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + sorter.insertAll(records) + } else { + // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't + // care whether the keys get sorted in each partition; that will be done on the reduce side + // if the operation being run is sortByKey. + sorter = new ExternalSorter[K, V, V]( + None, Some(dep.partitioner), None, dep.serializer) + sorter.insertAll(records) } // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later // serve different ranges of this file using an index file that we create at the end. 
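The rewrite above hands the heavy lifting to ExternalSorter.writePartitionedFile and keeps only bookkeeping in the writer; the companion .index file is just numPartitions + 1 cumulative offsets into the single data file (offsets(0) = 0, offsets(i + 1) = offsets(i) + lengths(i), as in the removed loop below). A hedged sketch of how a reader can turn those offsets back into a segment for one reduce partition (the helper is illustrative, not the actual SortShuffleManager code):

    import java.io.{DataInputStream, File, FileInputStream}

    // Returns the (start offset, length) of a reduce partition inside the data file,
    // assuming the index file stores numPartitions + 1 Longs of cumulative offsets.
    def segmentFor(indexFile: File, reduceId: Int): (Long, Long) = {
      val in = new DataInputStream(new FileInputStream(indexFile))
      try {
        in.skipBytes(reduceId * 8)      // each offset is a Long, i.e. 8 bytes
        val offset = in.readLong()
        val nextOffset = in.readLong()
        (offset, nextOffset - offset)
      } finally {
        in.close()
      }
    }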
val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0) - outputFile = blockManager.diskBlockManager.getFile(blockId) - - // Track location of each range in the output file - val offsets = new Array[Long](numPartitions + 1) - val lengths = new Array[Long](numPartitions) - - // Statistics - var totalBytes = 0L - var totalTime = 0L - - for ((id, elements) <- partitions) { - if (elements.hasNext) { - val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize) - for (elem <- elements) { - writer.write(elem) - } - writer.commitAndClose() - val segment = writer.fileSegment() - offsets(id + 1) = segment.offset + segment.length - lengths(id) = segment.length - totalTime += writer.timeWriting() - totalBytes += segment.length - } else { - // The partition is empty; don't create a new writer to avoid writing headers, etc - offsets(id + 1) = offsets(id) - } - } - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics) - context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled - context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled + outputFile = blockManager.diskBlockManager.getFile(blockId) + indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index") - // Write an index file with the offsets of each block, plus a final offset at the end for the - // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure - // out where each block begins and ends. - - val diskBlockManager = blockManager.diskBlockManager - val indexFile = diskBlockManager.getFile(blockId.name + ".index") - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) - try { - var i = 0 - while (i < numPartitions + 1) { - out.writeLong(offsets(i)) - i += 1 - } - } finally { - out.close() - } + val partitionLengths = sorter.writePartitionedFile(blockId, context) // Register our map output with the ShuffleBlockManager, which handles cleaning it over time blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions) mapStatus = new MapStatus(blockManager.blockManagerId, - lengths.map(MapOutputTracker.compressSize)) + partitionLengths.map(MapOutputTracker.compressSize)) } /** Close this writer, passing along whether the map completed */ @@ -151,6 +104,9 @@ private[spark] class SortShuffleWriter[K, V, C]( if (outputFile != null) { outputFile.delete() } + if (indexFile != null) { + indexFile.delete() + } return None } } finally { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala similarity index 66% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala rename to core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala index 2deaf4ae8dcab..5b6d086630834 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala @@ -15,14 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.model +package org.apache.spark.storage + +import java.nio.ByteBuffer + /** - * Filter specifying a split and type of comparison to be applied on features - * @param split split specifying the feature index, type and threshold - * @param comparison integer specifying <,=,> + * An interface for providing data for blocks. 
+ * + * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. + * + * Aside from unit tests, [[BlockManager]] is the main class that implements this. */ -private[tree] case class Filter(split: Split, comparison: Int) { - // Comparison -1,0,1 signifies <.=,> - override def toString = " split = " + split + "comparison = " + comparison +private[spark] trait BlockDataProvider { + def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index ccf830e118ee7..ca60ec78b62ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -18,17 +18,17 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue +import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue - -import io.netty.buffer.ByteBuf +import scala.util.{Failure, Success} import org.apache.spark.{Logging, SparkException} +import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.network.netty.ShuffleCopier import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -46,39 +46,43 @@ import org.apache.spark.util.Utils private[storage] trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() - def numLocalBlocks: Int - def numRemoteBlocks: Int - def fetchWaitTime: Long - def remoteBytesRead: Long } private[storage] object BlockFetcherIterator { - // A request to fetch one or more blocks, complete with their sizes + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } - // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize - // the block (since we want all deserializaton to happen in the calling thread); can also - // represent a fetch failure if size == -1. + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } + // TODO: Refactor this whole thing to make code more reusable. 
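BlockDataProvider, introduced just above, is deliberately tiny: a block server asks it for a block and gets back either an on-disk FileSegment, which the Netty transport can send zero-copy, or an in-memory ByteBuffer. Since the types involved are private[spark], the following is only a sketch of the contract, with a provider that serves blocks straight from a directory of files:

    import java.io.File
    import java.nio.ByteBuffer
    import org.apache.spark.storage.{BlockDataProvider, BlockNotFoundException, FileSegment}

    // Every block is assumed to live in its own file named after the block id, so the
    // transport layer can always take the zero-copy FileSegment path.
    class DirectoryBlockProvider(dir: File) extends BlockDataProvider {
      override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
        val f = new File(dir, blockId)
        if (f.exists()) {
          Left(new FileSegment(f, 0L, f.length()))
        } else {
          throw new BlockNotFoundException(blockId)
        }
      }
    }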
class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) + serializer: Serializer, + readMetrics: ShuffleReadMetrics) extends BlockFetcherIterator { import blockManager._ - private var _remoteBytesRead = 0L - private var _fetchWaitTime = 0L - if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } @@ -88,13 +92,9 @@ object BlockFetcherIterator { protected var startTime = System.currentTimeMillis - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks protected val remoteBlocksToFetch = new HashSet[BlockId]() @@ -103,10 +103,10 @@ object BlockFetcherIterator { // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that // the number of bytes in flight is limited to maxBytesInFlight - private val fetchRequests = new Queue[FetchRequest] + protected val fetchRequests = new Queue[FetchRequest] // Current bytes in flight from our requests - private var bytesInFlight = 0L + protected var bytesInFlight = 0L protected def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( @@ -118,8 +118,8 @@ object BlockFetcherIterator { bytesInFlight += req.size val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onSuccess { - case Some(message) => { + future.onComplete { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) for (blockMessage <- blockMessageArray) { @@ -131,12 +131,15 @@ object BlockFetcherIterator { val networkSize = blockMessage.getData.limit() results.put(new FetchResult(blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize + // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can + // be incrementing bytes read at the same time (SPARK-2625). + readMetrics.remoteBytesRead += networkSize + readMetrics.remoteBlocksFetched += 1 logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - case None => { - logError("Could not get block(s) from " + cmId) + case Failure(exception) => { + logError("Could not get block(s) from " + cmId, exception) for ((blockId, size) <- req.blocks) { results.put(new FetchResult(blockId, -1, null)) } @@ -154,14 +157,14 @@ object BlockFetcherIterator { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. 
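The splitting logic that follows packs remote blocks greedily into FetchRequests whose estimated size stays under a target derived from maxBytesInFlight, so several requests can be outstanding at once. A simplified, self-contained sketch of that packing (treat the divisor of 5 as an assumption about the target size, not something stated in this hunk):

    // Greedily packs (blockId, size) pairs into requests of roughly targetRequestSize bytes,
    // skipping zero-sized blocks, so multiple requests can be in flight concurrently.
    def packIntoRequests(
        blocks: Seq[(String, Long)],
        maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
      import scala.collection.mutable.ArrayBuffer
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  // assumed divisor
      val requests = ArrayBuffer[Seq[(String, Long)]]()
      var current = ArrayBuffer[(String, Long)]()
      var currentSize = 0L
      for ((id, size) <- blocks if size > 0) {
        current += ((id, size))
        currentSize += size
        if (currentSize >= targetRequestSize) {
          requests += current.toSeq
          current = ArrayBuffer[(String, Long)]()
          currentSize = 0L
        }
      }
      if (current.nonEmpty) requests += current.toSeq
      requests.toSeq
    }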
val remoteRequests = new ArrayBuffer[FetchRequest] + var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size if (address == blockManagerId) { - numLocal = blockInfos.size // Filter out zero-sized blocks localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) _numBlocksToFetch += localBlocksToFetch.size } else { - numRemote += blockInfos.size val iterator = blockInfos.iterator var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] @@ -191,7 +194,7 @@ object BlockFetcherIterator { } } logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - (numLocal + numRemote) + " blocks") + totalBlocks + " blocks") remoteRequests } @@ -204,6 +207,7 @@ object BlockFetcherIterator { // getLocalFromDisk never return None but throws BlockException val iter = getLocalFromDisk(id, serializer).get // Pass 0 as size since it's not in flight + readMetrics.localBlocksFetched += 1 results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } catch { @@ -237,12 +241,6 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. @volatile protected var resultsGotten = 0 @@ -254,7 +252,7 @@ object BlockFetcherIterator { val startFetchWait = System.currentTimeMillis() val result = results.take() val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) + readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (! result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { @@ -268,80 +266,62 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + serializer: Serializer, + readMetrics: ShuffleReadMetrics) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - import blockManager._ + override protected def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest] + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID + + // This could throw a TimeoutException. In that case we will just retry the task. 
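The Netty fetch path that continues below no longer polls copier threads; results come back through BlockClientListener callbacks and reference-counted buffers. A standalone sketch of that callback shape (this listener only counts outcomes; the real iterator wraps each buffer in a lazily deserialized FetchResult instead):

    import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer}

    class CountingListener extends BlockClientListener {
      @volatile var succeeded = 0
      @volatile var failed = 0

      override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
        // The real iterator calls data.retain() here to keep the buffer alive past the callback.
        succeeded += 1
      }

      override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
        failed += 1
      }
    }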
+ val client = blockManager.nettyBlockClientFactory.createClient( + cmId.host, req.address.nettyPort) + val blocks = req.blocks.map(_._1.toString) + + client.fetchBlocks( + blocks, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = { + logError(s"Could not get block(s) from $cmId with error: $errorMsg") + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } - private def startCopiers(numCopiers: Int): List[_ <: Thread] = { - (for ( i <- Range(0,numCopiers) ) yield { - val copier = new Thread { - override def run(){ - try { - while(!isInterrupted && !fetchRequestsSync.isEmpty) { - sendRequest(fetchRequestsSync.take()) + override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + // Increment the reference count so the buffer won't be recycled. + // TODO: This could result in memory leaks when the task is stopped due to exception + // before the iterator is exhausted. + data.retain() + val buf = data.byteBuffer() + val blockSize = buf.remaining() + val bid = BlockId(blockId) + + // TODO: remove code duplication between here and BlockManager.dataDeserialization. + results.put(new FetchResult(bid, sizeMap(bid), () => { + def createIterator: Iterator[Any] = { + val stream = blockManager.wrapForCompression(bid, data.inputStream()) + serializer.newInstance().deserializeStream(stream).asIterator + } + new LazyInitIterator(createIterator) { + // Release the buffer when we are done traversing it. + override def close(): Unit = data.release() } - } catch { - case x: InterruptedException => logInfo("Copier Interrupted") - // case _ => throw new SparkException("Exception Throw in Shuffle Copier") + })) + + readMetrics.synchronized { + readMetrics.remoteBytesRead += blockSize + readMetrics.remoteBlocksFetched += 1 } + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } - copier.start - copier - }).toList - } - - // keep this to interrupt the threads when necessary - private def stopCopiers() { - for (copier <- copiers) { - copier.interrupt() - } - } - - override protected def sendRequest(req: FetchRequest) { - - def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { - val fetchResult = new FetchResult(blockId, blockSize, - () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) - results.put(fetchResult) - } - - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.host)) - val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort) - val cpier = new ShuffleCopier(blockManager.conf) - cpier.getBlocks(cmId, req.blocks, putResult) - logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host ) - } - - private var copiers: List[_ <: Thread] = null - - override def initialize() { - // Split Local Remote Blocks and set numBlocksToFetch - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - for (request <- Utils.randomize(remoteRequests)) { - fetchRequestsSync.put(request) - } - - copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6)) - logInfo("Started " + fetchRequestsSync.size + " remote fetches in " + - Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - override def next(): (BlockId, 
Option[Iterator[Any]]) = { - resultsGotten += 1 - val result = results.take() - // If all the results has been retrieved, copiers will exit automatically - (result.blockId, if (result.failed) None else Some(result.deserialize())) + ) } } // End of NettyBlockFetcherIterator diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3876cf43e2a7d..12a92d44f4c36 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -25,16 +25,20 @@ import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Cancellable, Props} +import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.netty.client.BlockFetchingClientFactory +import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ + private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -57,13 +61,13 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) - extends Logging { + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) + extends BlockDataProvider with Logging { private val port = conf.getInt("spark.blockManager.port", 0) - val shuffleBlockManager = new ShuffleBlockManager(this) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) val connectionManager = new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") @@ -86,13 +90,25 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } + private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) + // If we use Netty for shuffle, start a new Netty-based shuffle sender service. 
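The spark.shuffle.use.netty flag read just above still defaults to false; only when it is enabled does the BlockManager create the new BlockFetchingClientFactory and start a BlockServer in place of the old Netty shuffle sender. A sketch of enabling it:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("netty-shuffle-sketch")
      // Off by default; when enabled, shuffle blocks are served by the new
      // Netty-based BlockServer and fetched by NettyBlockFetcherIterator.
      .set("spark.shuffle.use.netty", "true")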
- private val nettyPort: Int = { - val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - val nettyPortConfig = conf.getInt("spark.shuffle.sender.port", 0) - if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 + private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { + if (useNetty) new BlockFetchingClientFactory(conf) else null } + private val nettyBlockServer: BlockServer = { + if (useNetty) { + val server = new BlockServer(conf, this) + logInfo(s"Created NettyBlockServer binding to port: ${server.port}") + server + } else { + null + } + } + + private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 + val blockManagerId = BlockManagerId( executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) @@ -142,9 +158,10 @@ private[spark] class BlockManager( serializer: Serializer, conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker) = { + mapOutputTracker: MapOutputTracker, + shuffleManager: ShuffleManager) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker) + conf, securityManager, mapOutputTracker, shuffleManager) } /** @@ -216,6 +233,20 @@ private[spark] class BlockManager( } } + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + val bid = BlockId(blockId) + if (bid.isShuffle) { + Left(diskBlockManager.getBlockLocation(bid)) + } else { + val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + if (blockBytesOpt.isDefined) { + Right(blockBytesOpt.get) + } else { + throw new BlockNotFoundException(blockId) + } + } + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. @@ -539,12 +570,15 @@ private[spark] class BlockManager( */ def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer): BlockFetcherIterator = { + serializer: Serializer, + readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { val iter = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } iter.initialize() iter @@ -562,17 +596,19 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. This is currently used for - * writing shuffle files out. Callers should handle error cases. + * The Block will be appended to the File specified by filename. Callers should handle error + * cases. 
*/ def getDiskWriter( blockId: BlockId, file: File, serializer: Serializer, - bufferSize: Int): BlockObjectWriter = { + bufferSize: Int, + writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites, + writeMetrics) } /** @@ -1056,6 +1092,14 @@ private[spark] class BlockManager( connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() + + if (nettyBlockClientFactory != null) { + nettyBlockClientFactory.stop() + } + if (nettyBlockServer != null) { + nettyBlockServer.stop() + } + actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index bd31e3c5a187f..3ab07703b6f85 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -31,7 +31,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} /** * BlockManagerMasterActor is an actor on the master node to track statuses of @@ -39,7 +39,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -55,8 +55,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000)) - val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", - 60000) + val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) var timeoutCheckingTask: Cancellable = null @@ -67,9 +66,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus super.preStart() } - def receive = { + override def receiveWithLogging = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => - logInfo("received a register") register(blockManagerId, maxMemSize, slaveActor) sender ! true @@ -118,7 +116,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus sender ! true case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") sender ! 
true if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 10b65286fb7db..2ba16b8476600 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -53,7 +53,7 @@ private[spark] object BlockManagerMessages { sender: ActorRef) extends ToBlockManagerMaster - class UpdateBlockInfo( + case class UpdateBlockInfo( var blockManagerId: BlockManagerId, var blockId: BlockId, var storageLevel: StorageLevel, @@ -84,24 +84,6 @@ private[spark] object BlockManagerMessages { } } - object UpdateBlockInfo { - def apply( - blockManagerId: BlockManagerId, - blockId: BlockId, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long, - tachyonSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize) - } - - // For pattern-matching - def unapply(h: UpdateBlockInfo) - : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize)) - } - } - case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 6d4db064dff58..c194e0fed3367 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -23,6 +23,7 @@ import akka.actor.{ActorRef, Actor} import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.ActorLogReceive /** * An actor to take commands from the master to execute options. 
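A short illustration of why the hand-written UpdateBlockInfo companion above can be dropped: declaring the message as a case class makes the compiler synthesize apply() and unapply(), so construction and pattern matching keep working unchanged. The sketch below trims the field list and omits the Externalizable plumbing of the real message:

    // Trimmed, hypothetical field list; the point is only the synthesized apply/unapply.
    case class UpdateBlockInfoSketch(
        blockManagerId: String,
        blockId: String,
        memSize: Long,
        diskSize: Long)

    object CaseClassSketch {
      def describe(msg: Any): String = msg match {
        // unapply() comes for free with the case class
        case UpdateBlockInfoSketch(bmId, blockId, mem, disk) =>
          s"$blockId on $bmId: mem=$mem disk=$disk"
        case other =>
          s"unknown message: $other"
      }

      def main(args: Array[String]): Unit = {
        // apply() comes for free as well, so no explicit factory method is needed
        println(describe(UpdateBlockInfoSketch("executor-1", "rdd_0_0", 1024L, 0L)))
      }
    }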
For example, @@ -32,12 +33,12 @@ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends Actor with Logging { + extends Actor with ActorLogReceive with Logging { import context.dispatcher // Operations that involve removing blocks may be slow and should be done asynchronously - override def receive = { + override def receiveWithLogging = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index 3f14c40ec61cb..49fea6d9e2a76 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -24,8 +24,8 @@ import org.apache.spark.metrics.source.Source private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: SparkContext) extends Source { - val metricRegistry = new MetricRegistry() - val sourceName = "%s.BlockManager".format(sc.appName) + override val metricRegistry = new MetricRegistry() + override val sourceName = "%s.BlockManager".format(sc.appName) metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index c7766a3a65671..bf002a42d5dc5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -23,6 +23,10 @@ import org.apache.spark.Logging import org.apache.spark.network._ import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.{Try, Failure, Success} + /** * A network interface for BlockManager. Each slave should have one * BlockManagerWorker. 
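The imports added above (Await, Duration, Try) support the pattern used in the synchronous helpers that follow: with sendMessageReliablySync gone, a blocking call is expressed as Try(Await.result(...)) over the Future-returning send. A self-contained sketch of that pattern, where sendReliably is a stand-in for ConnectionManager.sendMessageReliably:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration
    import scala.util.{Failure, Success, Try}

    object SyncOverAsyncSketch {
      // Stand-in for ConnectionManager.sendMessageReliably, which returns a Future.
      def sendReliably(payload: String): Future[String] = Future(s"ack: $payload")

      // Block on the Future and fold the outcome into Success/Failure instead of relying
      // on a dedicated synchronous send method.
      def syncSend(payload: String): Boolean =
        Try(Await.result(sendReliably(payload), Duration.Inf)) match {
          case Success(reply) => println(s"response received: $reply"); true
          case Failure(e)     => println(s"no response received: ${e.getMessage}"); false
        }

      def main(args: Array[String]): Unit = {
        println(syncSend("PutBlock rdd_0_0"))
      }
    }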
@@ -44,13 +48,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message", e) - None + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } } } case otherMessage: Any => { logError("Unknown type message received: " + otherMessage) - None + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) } } } @@ -109,9 +119,9 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromPutBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) - resultMessage.isDefined + val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) + resultMessage.isSuccess } def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { @@ -119,10 +129,10 @@ private[spark] object BlockManagerWorker extends Logging { val connectionManager = blockManager.connectionManager val blockMessage = BlockMessage.fromGetBlock(msg) val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage) + val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( + toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) responseMessage match { - case Some(message) => { + case Success(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] logDebug("Response message received " + bufferMessage) BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { @@ -130,7 +140,7 @@ private[spark] object BlockManagerWorker extends Logging { return blockMessage.getData }) } - case None => logDebug("No response message received") + case Failure(exception) => logDebug("No response message received") } null } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala new file mode 100644 index 0000000000000..9ef453605f4f1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + + +class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 01d46e1ffc960..adda971fd7b47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -22,6 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} +import org.apache.spark.executor.ShuffleWriteMetrics /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) { * This is only valid after commitAndClose() has been called. */ def fileSegment(): FileSegment - - /** - * Cumulative time spent performing blocking writes, in ns. - */ - def timeWriting(): Long - - /** - * Number of bytes written so far - */ - def bytesWritten: Long } -/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +/** + * BlockObjectWriter which writes directly to a file on disk. Appends to the given file. + * The given write metrics will be updated incrementally, but will not necessarily be current until + * commitAndClose is called. + */ private[spark] class DiskBlockObjectWriter( blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int, compressStream: OutputStream => OutputStream, - syncWrites: Boolean) + syncWrites: Boolean, + writeMetrics: ShuffleWriteMetrics) extends BlockObjectWriter(blockId) with Logging { - /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { - def timeWriting = _timeWriting - private var _timeWriting = 0L - - private def callWithTiming(f: => Unit) = { - val start = System.nanoTime() - f - _timeWriting += (System.nanoTime() - start) - } - def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) @@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter( private val initialPosition = file.length() private var finalPosition: Long = -1 private var initialized = false - private var _timeWriting = 0L + + /** Calling channel.position() to update the write metrics can be a little bit expensive, so we + * only call it every N writes */ + private var writesSinceMetricsUpdate = 0 + private var lastPosition = initialPosition override def open(): BlockObjectWriter = { fos = new FileOutputStream(file, true) @@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter( if (syncWrites) { // Force outstanding writes to disk and track how long it takes objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - _timeWriting += System.nanoTime() - start + def sync = fos.getFD.sync() + callWithTiming(sync) } objOut.close() - _timeWriting += ts.timeWriting - channel = null bs = null fos = null @@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter( // serializer stream and the lower level stream. 
objOut.flush() bs.flush() + updateBytesWritten() close() } finalPosition = file.length() @@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. override def revertPartialWritesAndClose() { try { + writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition) + if (initialized) { objOut.flush() bs.flush() @@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter( if (!initialized) { open() } + objOut.writeObject(value) + + if (writesSinceMetricsUpdate == 32) { + writesSinceMetricsUpdate = 0 + updateBytesWritten() + } else { + writesSinceMetricsUpdate += 1 + } } override def fileSegment(): FileSegment = { - new FileSegment(file, initialPosition, bytesWritten) + new FileSegment(file, initialPosition, finalPosition - initialPosition) } - // Only valid if called after close() - override def timeWriting() = _timeWriting + private def updateBytesWritten() { + val pos = channel.position() + writeMetrics.shuffleBytesWritten += (pos - lastPosition) + lastPosition = pos + } + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + writeMetrics.shuffleWriteTime += (System.nanoTime() - start) + } - // Only valid if called after commit() - override def bytesWritten: Long = { - assert(finalPosition != -1, "bytesWritten is only valid after successful commit()") - finalPosition - initialPosition + // For testing + private[spark] def flush() { + objOut.flush() + bs.flush() } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 4d66ccea211fa..ec022ce9c048a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -21,9 +21,9 @@ import java.io.File import java.text.SimpleDateFormat import java.util.{Date, Random, UUID} -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.{SparkConf, SparkEnv, Logging} import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.network.netty.{PathResolver, ShuffleSender} +import org.apache.spark.network.netty.PathResolver import org.apache.spark.util.Utils import org.apache.spark.shuffle.sort.SortShuffleManager @@ -33,9 +33,10 @@ import org.apache.spark.shuffle.sort.SortShuffleManager * However, it is also possible to have a block map to only a segment of a file, by calling * mapBlockToFileSegment(). * - * @param rootDirs The directories to use for storing block files. Data will be hashed among these. + * Block files are hashed among the directories listed in spark.local.dir (or in + * SPARK_LOCAL_DIRS, if it's set). */ -private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String) +private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, conf: SparkConf) extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @@ -46,13 +47,12 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at the top level. 
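A simplified sketch of the sampled-metrics idea in the writer above: polling channel.position() on every record would be costly, so bytes written are folded into the metrics only every N records (32 in this patch) and once more at commit, which is why the counters are only guaranteed current after commitAndClose. The names SketchWriteMetrics and SampledMetricsWriter are stand-ins, and the timing here simply wraps the whole write call, which is coarser than the patch's TimeTrackingOutputStream:

    import java.io.{BufferedOutputStream, File, FileOutputStream, ObjectOutputStream}

    // Stand-in for ShuffleWriteMetrics; the counters mirror shuffleBytesWritten/shuffleWriteTime.
    class SketchWriteMetrics {
      var bytesWritten = 0L
      var writeTimeNanos = 0L
    }

    class SampledMetricsWriter(file: File, metrics: SketchWriteMetrics, updateEvery: Int = 32) {
      private val fos = new FileOutputStream(file)
      private val channel = fos.getChannel
      private val out = new ObjectOutputStream(new BufferedOutputStream(fos))
      private var lastPosition = 0L
      private var writesSinceUpdate = 0

      // channel.position() is comparatively expensive, so fold it into the metrics only
      // occasionally; the byte count is therefore approximate until commitAndClose().
      private def updateBytesWritten(): Unit = {
        val pos = channel.position()
        metrics.bytesWritten += pos - lastPosition
        lastPosition = pos
      }

      def write(value: AnyRef): Unit = {
        val start = System.nanoTime()
        out.writeObject(value)
        metrics.writeTimeNanos += System.nanoTime() - start
        writesSinceUpdate += 1
        if (writesSinceUpdate == updateEvery) {
          writesSinceUpdate = 0
          updateBytesWritten()
        }
      }

      def commitAndClose(): Unit = {
        out.flush()            // push buffered bytes down to the channel
        updateBytesWritten()   // make bytesWritten exact before closing
        out.close()
      }
    }

    object SampledMetricsWriterDemo {
      def main(args: Array[String]): Unit = {
        val metrics = new SketchWriteMetrics
        val tmp = File.createTempFile("sampled-writer", ".bin")
        val writer = new SampledMetricsWriter(tmp, metrics)
        (1 to 100).foreach(i => writer.write(s"record-$i"))
        writer.commitAndClose()
        println(s"bytes=${metrics.bytesWritten} writeTimeNanos=${metrics.writeTimeNanos}")
        tmp.delete()
      }
    }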
*/ - val localDirs: Array[File] = createLocalDirs() + val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - private var shuffleSender : ShuffleSender = null addShutdownHook() @@ -131,10 +131,9 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, (blockId, getFile(blockId)) } - private def createLocalDirs(): Array[File] = { - logDebug(s"Creating local directories at root dirs '$rootDirs'") + private def createLocalDirs(conf: SparkConf): Array[File] = { val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").flatMap { rootDir => + Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir => var foundLocalDir = false var localDir: File = null var localDirId: String = null @@ -186,15 +185,5 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, } } } - - if (shuffleSender != null) { - shuffleSender.stop() - } - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - shuffleSender = new ShuffleSender(port, this) - logInfo(s"Created ShuffleSender binding to port: ${shuffleSender.port}") - shuffleSender.port } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 28f675c2bbb1e..0a09c24d61879 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -238,7 +238,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // If our vector's size has exceeded the threshold, request more memory val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { - val amountToRequest = (currentSize * (memoryGrowthFactor - 1)).toLong + val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { @@ -254,7 +254,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } // New threshold is currentSize * memoryGrowthFactor - memoryThreshold = currentSize + amountToRequest + memoryThreshold += amountToRequest } } elementsUnrolled += 1 diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index f9fdffae8bd8f..b8f5d3a5b02aa 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -25,10 +25,12 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.executor.ShuffleWriteMetrics /** A group of writers for a ShuffleMapTask, one writer per reducer. 
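On the MemoryStore change above, the way the new arithmetic reads: each request is sized as the gap between the target threshold (currentSize * memoryGrowthFactor) and the memory already reserved, and the threshold then grows by exactly the amount requested, keeping it in sync with what has been asked of the shuffle memory pool. A small worked example under that reading, with illustrative numbers only:

    object UnrollThresholdSketch {
      def main(args: Array[String]): Unit = {
        val memoryGrowthFactor = 1.5
        var memoryThreshold = 1024L * 1024   // what has been requested from the pool so far
        var requestedTotal = memoryThreshold

        // The vector has grown past the threshold; aim for currentSize * memoryGrowthFactor.
        val currentSize = 4L * 1024 * 1024
        val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong

        requestedTotal += amountToRequest
        memoryThreshold += amountToRequest   // new threshold: exactly currentSize * growth factor

        println(s"requested $amountToRequest more bytes, threshold is now $memoryThreshold")
        assert(memoryThreshold == (currentSize * memoryGrowthFactor).toLong)
        assert(memoryThreshold == requestedTotal)   // accounting stays in sync with the pool
      }
    }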
*/ private[spark] trait ShuffleWriterGroup { @@ -61,7 +63,8 @@ private[spark] trait ShuffleWriterGroup { */ // TODO: Factor this into a separate class for each ShuffleManager implementation private[spark] -class ShuffleBlockManager(blockManager: BlockManager) extends Logging { +class ShuffleBlockManager(blockManager: BlockManager, + shuffleManager: ShuffleManager) extends Logging { def conf = blockManager.conf // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. @@ -70,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { conf.getBoolean("spark.shuffle.consolidateFiles", false) // Are we using sort-based shuffle? - val sortBasedShuffle = - conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName + val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager] private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 @@ -111,7 +113,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { * Get a ShuffleWriterGroup for the given map task, which will register it as complete * when the writers are closed successfully */ - def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { + def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer, + writeMetrics: ShuffleWriteMetrics) = { new ShuffleWriterGroup { shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets)) private val shuffleState = shuffleStates(shuffleId) @@ -121,7 +124,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { fileGroup = getUnusedFileGroup() Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize) + blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize, + writeMetrics) } } else { Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => @@ -136,7 +140,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { logWarning(s"Failed to remove existing shuffle file $blockFile") } } - blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize) + blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 75c2e09a6bbb8..aa83ea90ee9ee 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.util.concurrent.ArrayBlockingQueue import akka.actor._ +import org.apache.spark.shuffle.hash.HashShuffleManager import util.Random import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} @@ -101,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala 
b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 29e9cf947856f..6b4689291097f 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -93,7 +93,7 @@ private[spark] object JettyUtils extends Logging { def createServletHandler( path: String, servlet: HttpServlet, - basePath: String = ""): ServletContextHandler = { + basePath: String): ServletContextHandler = { val prefixedPath = attachPrefix(basePath, path) val contextHandler = new ServletContextHandler val holder = new ServletHolder(servlet) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6c788a37dc70b..cccd59d122a92 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -76,6 +76,8 @@ private[spark] class SparkUI( } } + def getAppName = appName + /** Set the app name for this UI. */ def setAppName(name: String) { appName = name @@ -100,6 +102,13 @@ private[spark] class SparkUI( private[spark] def appUIAddress = s"http://$appUIHostPort" } +private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) + extends WebUITab(parent, prefix) { + + def appName: String = parent.getAppName + +} + private[spark] object SparkUI { val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 715cc2f4df8dd..bee6dad3387e5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -163,17 +163,15 @@ private[spark] object UIUtils extends Logging { /** Returns a spark page with correctly formatted headers */ def headerSparkPage( - content: => Seq[Node], - basePath: String, - appName: String, title: String, - tabs: Seq[WebUITab], - activeTab: WebUITab, + content: => Seq[Node], + activeTab: SparkUITab, refreshInterval: Option[Int] = None): Seq[Node] = { - val header = tabs.map { tab => + val appName = activeTab.appName + val header = activeTab.headerTabs.map { tab =>
  - {tab.name} + {tab.name}
  • } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5f52f95088007..5d88ca403a674 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -50,6 +50,7 @@ private[spark] abstract class WebUI( protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) private val className = Utils.getFormattedClassName(this) + def getBasePath: String = basePath def getTabs: Seq[WebUITab] = tabs.toSeq def getHandlers: Seq[ServletContextHandler] = handlers.toSeq def getSecurityManager: SecurityManager = securityManager @@ -135,6 +136,8 @@ private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) { /** Get a list of header tabs from the parent UI. */ def headerTabs: Seq[WebUITab] = parent.getTabs + + def basePath: String = parent.getBasePath } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index b347eb1b83c1f..f0a1174a71d34 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -24,8 +24,6 @@ import scala.xml.Node import org.apache.spark.ui.{UIUtils, WebUIPage} private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -45,7 +43,7 @@ private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("")

    Classpath Entries

    {classpathEntriesTable} - UIUtils.headerSparkPage(content, basePath, appName, "Environment", parent.headerTabs, parent) + UIUtils.headerSparkPage("Environment", content, parent) } private def propertyHeader = Seq("Name", "Value") diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index bbbe55ecf44a1..0d158fbe638d3 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -21,9 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.ui._ -private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "environment") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { val listener = new EnvironmentListener attachPage(new EnvironmentPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b814b0e6b8509..02df4e8fe61af 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -43,8 +43,6 @@ private case class ExecutorSummaryInfo( maxMemory: Long) private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -101,8 +99,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { ; - UIUtils.headerSparkPage(content, basePath, appName, "Executors (" + execInfo.size + ")", - parent.headerTabs, parent) + UIUtils.headerSparkPage("Executors (" + execInfo.size + ")", content, parent) } /** Render an HTML row representing an executor */ diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 5c2d1d1fe75d3..61eb111cd9100 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -23,11 +23,9 @@ import org.apache.spark.ExceptionFailure import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} -private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "executors") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { val listener = new ExecutorsListener(parent.storageStatusListener) attachPage(new ExecutorsPage(this)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index a57a354620163..74cd637d88155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -153,6 +153,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val (errorMessage, metrics): (Option[String], Option[TaskMetrics]) = taskEnd.reason match { case 
org.apache.spark.Success => + stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 (None, Option(taskEnd.taskMetrics)) case e: ExceptionFailure => // Handle ExceptionFailure because we might have metrics @@ -199,6 +200,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.shuffleReadBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta + val inputBytesDelta = + (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) + - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) + stageData.inputBytes += inputBytesDelta + execSummary.inputBytes += inputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 0da62892118d4..a82f71ed08475 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -94,7 +92,7 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
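A minimal sketch of why the listener above starts recording completed task indices in a set instead of only bumping a counter: with speculative copies or retried attempts the same partition index can finish more than once, so a raw counter can overshoot numTasks while the set of distinct indices cannot. scala.collection.mutable.HashSet stands in for OpenHashSet here:

    import scala.collection.mutable

    object CompletedIndicesSketch {
      def main(args: Array[String]): Unit = {
        val numTasks = 4
        // Partition index 1 completes twice, e.g. because a speculative copy also finished.
        val completionEvents = Seq(0, 1, 1, 2, 3)

        var numCompleteTasks = 0
        val completedIndices = mutable.HashSet.empty[Int]
        completionEvents.foreach { index =>
          numCompleteTasks += 1
          completedIndices.add(index)
        }

        println(s"raw counter:      $numCompleteTasks / $numTasks")        // 5 / 4, overshoots
        println(s"distinct indices: ${completedIndices.size} / $numTasks") // 4 / 4
      }
    }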

    Failed Stages ({failedStages.size})

    ++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", parent.headerTabs, parent) + UIUtils.headerSparkPage("Spark Stages", content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala index 8a01ec80c9dd6..c16542c9db30f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -21,12 +21,10 @@ import javax.servlet.http.HttpServletRequest import org.apache.spark.SparkConf import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, WebUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} /** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stages") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { val live = parent.live val sc = parent.sc val conf = if (live) sc.conf else new SparkConf @@ -53,4 +51,5 @@ private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stag Thread.sleep(100) } } + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 0a2bf31833d2b..7a6c7d1a497ed 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -26,8 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { - private val appName = parent.appName - private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc private val listener = parent.listener @@ -51,8 +49,7 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
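A simplified sketch of the UI refactor pattern running through these hunks: instead of every page and tab copying appName/basePath from its parent, tabs extend a shared base that reads them from the parent UI on demand, and pages pass just (title, content, parent) to headerSparkPage. The class names below are stand-ins, not the real SparkUI API:

    // Stand-ins for WebUI / SparkUITab / a page's render call; not the real API.
    class SketchWebUI(val basePath: String) {
      private var appName: String = "uninitialized"
      def getAppName: String = appName
      def setAppName(name: String): Unit = { appName = name }
    }

    class SketchUITab(parent: SketchWebUI, val prefix: String) {
      // Read from the parent on demand instead of caching a copy at construction time.
      def appName: String = parent.getAppName
      def basePath: String = parent.basePath
    }

    object UiRefactorSketch {
      // Pages now pass only (title, content, activeTab); the helper pulls appName from the tab.
      def headerSparkPage(title: String, content: Seq[String], activeTab: SketchUITab): String =
        s"[${activeTab.appName}] $title :: ${content.mkString(" | ")}"

      def main(args: Array[String]): Unit = {
        val ui = new SketchWebUI("/proxy/app-1")
        val stagesTab = new SketchUITab(ui, "stages")
        ui.setAppName("My Spark App")   // set after tab creation; the tab still sees the new name
        println(headerSparkPage("Spark Stages", Seq("active", "completed", "failed"), stagesTab))
      }
    }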

    Summary

    ++ poolTable.toNodeSeq ++

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Fair Scheduler Pool: " + poolName, - parent.headerTabs, parent) + UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index f4b68f241966d..64178e1e33d41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -25,7 +25,6 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { - private val basePath = parent.basePath private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -59,11 +58,11 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { case Some(stages) => stages.size case None => 0 } + val href = "%s/stages/pool?poolname=%s" + .format(UIUtils.prependBaseUri(parent.basePath), p.name) - - {p.name} - + {p.name} {p.minShare} {p.weight} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8bc1ba758cf77..d4eb02722ad12 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -29,8 +29,6 @@ import org.apache.spark.scheduler.AccumulableInfo /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -44,8 +42,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(content, basePath, appName, - "Details for Stage %s".format(stageId), parent.headerTabs, parent) + return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent) } val stageData = stageDataOption.get @@ -227,8 +224,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { maybeAccumulableTable ++

    Tasks

    ++ taskTable - UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), - parent.headerTabs, parent) + UIUtils.headerSparkPage("Details for Stage %d".format(stageId), content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 3dcfaf76e4aba..16ad0df45aa0d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -32,7 +32,6 @@ private[ui] class StageTableBase( parent: JobProgressTab, killEnabled: Boolean = false) { - private val basePath = parent.basePath private val listener = parent.listener protected def isFairScheduler = parent.isFairScheduler @@ -88,17 +87,19 @@ private[ui] class StageTableBase( private def makeDescription(s: StageInfo): Seq[Node] = { // scalastyle:off val killLink = if (killEnabled) { + val killLinkUri = "%s/stages/stage/kill?id=%s&terminate=true" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val confirm = "return window.confirm('Are you sure you want to kill stage %s ?');" + .format(s.stageId) - (kill) + (kill) } // scalastyle:on - val nameLink = - - {s.name} - + val nameLinkUri ="%s/stages/stage?id=%s" + .format(UIUtils.prependBaseUri(parent.basePath), s.stageId) + val nameLink = {s.name} val cachedRddInfos = s.rddInfos.filter(_.numCachedPartitions > 0) val details = if (s.details.nonEmpty) { @@ -111,7 +112,7 @@ private[ui] class StageTableBase( Text("RDD: ") ++ // scalastyle:off cachedRddInfos.map { i => - {i.name} + {i.name} } // scalastyle:on }} @@ -157,7 +158,7 @@ private[ui] class StageTableBase( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(parent.basePath), stageData.schedulingPool)}> {stageData.schedulingPool} @@ -168,7 +169,7 @@ private[ui] class StageTableBase( {submissionTime} {formattedDuration} - {makeProgressBar(stageData.numActiveTasks, stageData.numCompleteTasks, + {makeProgressBar(stageData.numActiveTasks, stageData.completedIndices.size, stageData.numFailedTasks, s.numTasks)} {inputReadWithUnit} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 85db15472a00c..a336bf7e1ed02 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -19,6 +19,7 @@ package org.apache.spark.ui.jobs import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} +import org.apache.spark.util.collection.OpenHashSet import scala.collection.mutable.HashMap @@ -38,6 +39,7 @@ private[jobs] object UIData { class StageUIData { var numActiveTasks: Int = _ var numCompleteTasks: Int = _ + var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ var executorRunTime: Long = _ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 84ac53da47552..8a0075ae8daf7 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -27,8 +27,6 @@ import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: 
HttpServletRequest): Seq[Node] = { @@ -36,8 +34,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage(Seq[Node](), basePath, appName, "RDD Not Found", - parent.headerTabs, parent) + return UIUtils.headerSparkPage("RDD Not Found", Seq[Node](), parent) } // Worker table @@ -96,8 +93,7 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { ; - UIUtils.headerSparkPage(content, basePath, appName, "RDD Storage Info for " + rddInfo.name, - parent.headerTabs, parent) + UIUtils.headerSparkPage("RDD Storage Info for " + rddInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 9813d9330ac7f..716591c9ed449 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -27,14 +27,12 @@ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { - private val appName = parent.appName - private val basePath = parent.basePath private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val rdds = listener.rddInfoList val content = UIUtils.listingTable(rddHeader, rddRow, rdds) - UIUtils.headerSparkPage(content, basePath, appName, "Storage ", parent.headerTabs, parent) + UIUtils.headerSparkPage("Storage", content, parent) } /** Header fields for the RDD table */ @@ -52,7 +50,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { // scalastyle:off - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 5f6740d495521..67f72a94f0269 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -25,9 +25,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.storage._ /** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage") { - val appName = parent.appName - val basePath = parent.basePath +private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { val listener = new StorageListener(parent.storageStatusListener) attachPage(new StoragePage(this)) diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala new file mode 100644 index 0000000000000..332d0cbb2dc0c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import akka.actor.Actor +import org.slf4j.Logger + +/** + * A trait to enable logging all Akka actor messages. Here's an example of using this: + * + * {{{ + * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { + * ... + * override def receiveWithLogging = { + * case GetLocations(blockId) => + * sender ! getLocations(blockId) + * ... + * } + * ... + * } + * }}} + * + */ +private[spark] trait ActorLogReceive { + self: Actor => + + override def receive: Actor.Receive = new Actor.Receive { + + private val _receiveWithLogging = receiveWithLogging + + override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + + override def apply(o: Any): Unit = { + if (log.isDebugEnabled) { + log.debug(s"[actor] received message $o from ${self.sender}") + } + val start = System.nanoTime + _receiveWithLogging.apply(o) + val timeTaken = (System.nanoTime - start).toDouble / 1000000 + if (log.isDebugEnabled) { + log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") + } + } + } + + def receiveWithLogging: Actor.Receive + + protected def log: Logger +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b112b359368cd..1e18ec688c40d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -72,8 +72,9 @@ private[spark] object JsonProtocol { case applicationEnd: SparkListenerApplicationEnd => applicationEndToJson(applicationEnd) - // Not used, but keeps compiler happy + // These aren't used, but keeps compiler happy case SparkListenerShutdown => JNothing + case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } @@ -560,9 +561,8 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => - metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) - } + metrics.setShuffleReadMetrics( + Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala new file mode 100644 index 0000000000000..c1b8bf052c0ca --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.EventListener + +import org.apache.spark.TaskContext +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution completes. + */ +@DeveloperApi +trait TaskCompletionListener extends EventListener { + def onTaskCompletion(context: TaskContext) +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c60be4f8a11d2..d6d74ce269219 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -146,6 +146,9 @@ private[spark] object Utils extends Logging { Try { Class.forName(clazz, false, getContextOrSparkClassLoader) }.isSuccess } + /** Preferred alternative to Class.forName(className) */ + def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) + /** * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. */ @@ -284,17 +287,32 @@ private[spark] object Utils extends Logging { /** Copy all data from an InputStream to an OutputStream */ def copyStream(in: InputStream, out: OutputStream, - closeStreams: Boolean = false) + closeStreams: Boolean = false): Long = { + var count = 0L try { - val buf = new Array[Byte](8192) - var n = 0 - while (n != -1) { - n = in.read(buf) - if (n != -1) { - out.write(buf, 0, n) + if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) { + // When both streams are File stream, use transferTo to improve copy performance. + val inChannel = in.asInstanceOf[FileInputStream].getChannel() + val outChannel = out.asInstanceOf[FileOutputStream].getChannel() + val size = inChannel.size() + + // In case transferTo method transferred less data than we have required. + while (count < size) { + count += inChannel.transferTo(count, size - count, outChannel) + } + } else { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + count += n + } } } + count } finally { if (closeStreams) { try { @@ -431,12 +449,71 @@ private[spark] object Utils extends Logging { } /** - * Get a temporary directory using Spark's spark.local.dir property, if set. This will always - * return a single directory, even though the spark.local.dir property might be a list of - * multiple paths. + * Get the path of a temporary directory. Spark's local directories can be configured through + * multiple settings, which are used with the following precedence: + * + * - If called from inside of a YARN container, this will return a directory chosen by YARN. + * - If the SPARK_LOCAL_DIRS environment variable is set, this will return a directory from it. + * - Otherwise, if the spark.local.dir is set, this will return a directory from it. 
+ * - Otherwise, this will return java.io.tmpdir. + * + * Some of these configuration options might be lists of multiple paths, but this method will + * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) + getOrCreateLocalRootDirs(conf)(0) + } + + private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { + // These environment variables are set by YARN. + // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs()) + // For Hadoop 2.X, we check for CONTAINER_ID. + conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null + } + + /** + * Gets or creates the directories listed in spark.local.dir or SPARK_LOCAL_DIRS, + * and returns only the directories that exist / could be created. + * + * If no directories could be created, this will return an empty list. + */ + private[spark] def getOrCreateLocalRootDirs(conf: SparkConf): Array[String] = { + val confValue = if (isRunningInYarnContainer(conf)) { + // If we are in yarn mode, systems can have different disk layouts so we must set it + // to what Yarn on this system said was available. + getYarnLocalDirs(conf) + } else { + Option(conf.getenv("SPARK_LOCAL_DIRS")).getOrElse( + conf.get("spark.local.dir", System.getProperty("java.io.tmpdir"))) + } + val rootDirs = confValue.split(',') + logDebug(s"Getting/creating local root dirs at '$confValue'") + + rootDirs.flatMap { rootDir => + val localDir: File = new File(rootDir) + val foundLocalDir = localDir.exists || localDir.mkdirs() + if (!foundLocalDir) { + logError(s"Failed to create local root dir in $rootDir. Ignoring this directory.") + None + } else { + Some(rootDir) + } + } + } + + /** Get the Yarn approved local directories. */ + private def getYarnLocalDirs(conf: SparkConf): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so lets check both. We assume one of the 2 is set. 
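A self-contained sketch of the local-directory resolution order described above: a YARN-provided list wins inside a container, then SPARK_LOCAL_DIRS, then spark.local.dir, then java.io.tmpdir, and the chosen value is split on commas keeping only directories that exist or can be created. The env and conf maps below are stand-ins for SparkConf.getenv and SparkConf.get:

    import java.io.File

    object LocalDirsSketch {
      def getOrCreateLocalRootDirs(
          env: Map[String, String],
          conf: Map[String, String]): Array[String] = {
        val inYarnContainer = env.contains("CONTAINER_ID") || env.contains("YARN_LOCAL_DIRS")
        val confValue =
          if (inYarnContainer) {
            // YARN decides the disk layout; check both Hadoop 0.23.x and 2.x variable names.
            env.getOrElse("YARN_LOCAL_DIRS", env.getOrElse("LOCAL_DIRS",
              throw new Exception("Yarn local dirs can't be empty")))
          } else {
            env.getOrElse("SPARK_LOCAL_DIRS",
              conf.getOrElse("spark.local.dir", System.getProperty("java.io.tmpdir")))
          }
        confValue.split(',').flatMap { rootDir =>
          val dir = new File(rootDir)
          if (dir.exists || dir.mkdirs()) Some(rootDir) else None
        }
      }

      def main(args: Array[String]): Unit = {
        val dirs = getOrCreateLocalRootDirs(
          env = Map("SPARK_LOCAL_DIRS" -> "/tmp/spark-local-a,/tmp/spark-local-b"),
          conf = Map("spark.local.dir" -> "/tmp/ignored-because-env-wins"))
        println(dirs.mkString(", "))
      }
    }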
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS")) + .getOrElse(Option(conf.getenv("LOCAL_DIRS")) + .getOrElse("")) + + if (localDirs.isEmpty) { + throw new Exception("Yarn Local dirs can't be empty") + } + localDirs } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 260a5c3888aa7..9f85b94a70800 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator +import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: @@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C]( private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + private val keyComparator = new HashComparator[K] private val ser = serializer.newInstance() @@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C]( logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -199,7 +205,9 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, + curWriteMetrics) } } if (objectsWritten > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3f93afd57b3ad..5d8a648d9551e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -25,9 +25,10 @@ import scala.collection.mutable import com.google.common.io.ByteStreams -import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner} +import org.apache.spark._ import org.apache.spark.serializer.{DeserializationStream, Serializer} -import org.apache.spark.storage.BlockId +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.storage.{BlockObjectWriter, BlockId} /** * Sorts and potentially merges a number of 
key-value pairs of type (K, V) to produce key-combiner @@ -66,6 +67,13 @@ import org.apache.spark.storage.BlockId * for equality to merge values. * * - Users are expected to call stop() at the end to delete all the intermediate files. + * + * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is + * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to + * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can + * then concatenate these files to produce a single sorted file, without having to serialize and + * de-serialize each item twice (as is needed during the merge). This speeds up the map side of + * groupBy, sort, etc operations since they do no partial aggregation. */ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, @@ -112,14 +120,29 @@ private[spark] class ExternalSorter[K, V, C]( // What threshold of elementsRead we start estimating map size at. private val trackMemoryThreshold = 1000 - // Spilling statistics + // Total spilling statistics private var spillCount = 0 private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L + // Write metrics for current spill + private var curWriteMetrics: ShuffleWriteMetrics = _ + // How much of the shared memory pool this collection has claimed private var myMemoryThreshold = 0L + // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need + // local aggregation and sorting, write numPartitions files directly and just concatenate them + // at the end. This avoids doing serialization and deserialization twice to merge together the + // spilled files, which would happen with the normal code path. The downside is having multiple + // files open at a time and thus more memory allocated to buffers. + private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200) + private val bypassMergeSort = + (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty) + + // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled + private var partitionWriters: Array[BlockObjectWriter] = null + // A comparator for keys K that orders them within a partition to allow aggregation or sorting. // Can be a partial ordering by hash code if a total ordering is not provided through by the // user. 
(A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some @@ -133,7 +156,14 @@ private[spark] class ExternalSorter[K, V, C]( } }) - // A comparator for (Int, K) elements that orders them by partition and then possibly by key + // A comparator for (Int, K) pairs that orders them by only their partition ID + private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] { + override def compare(a: (Int, K), b: (Int, K)): Int = { + a._1 - b._1 + } + } + + // A comparator that orders (Int, K) pairs by partition ID and then possibly by key private val partitionKeyComparator: Comparator[(Int, K)] = { if (ordering.isDefined || aggregator.isDefined) { // Sort by partition ID then key comparator @@ -149,11 +179,7 @@ private[spark] class ExternalSorter[K, V, C]( } } else { // Just sort it by partition ID - new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + partitionComparator } } @@ -167,7 +193,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsPerPartition: Array[Long]) private val spills = new ArrayBuffer[SpilledFile] - def write(records: Iterator[_ <: Product2[K, V]]): Unit = { + def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = { // TODO: stop combining if we find that the reduction factor isn't high val shouldCombine = aggregator.isDefined @@ -238,8 +264,41 @@ private[spark] class ExternalSorter[K, V, C]( val threadId = Thread.currentThread().getId logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + + if (bypassMergeSort) { + spillToPartitionFiles(collection) + } else { + spillToMergeableFile(collection) + } + + if (usingMap) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } else { + buffer = new SizeTrackingPairBuffer[(Int, K), C] + } + + // Release our memory back to the shuffle pool so that other threads can grab it + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0 + + _memoryBytesSpilled += memorySize + } + + /** + * Spill our in-memory collection to a sorted file that we can merge later (normal code path). + * We add this file into spilledFiles to find it later. + * + * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition. + * See spillToPartitionedFiles() for that code path. 
+ * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(!bypassMergeSort) + val (blockId, file) = diskBlockManager.createTempBlock() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) var objectsWritten = 0 // Objects written since the last flush // List of batch sizes (bytes) in the order they are written to disk @@ -254,9 +313,8 @@ private[spark] class ExternalSorter[K, V, C]( val w = writer writer = null w.commitAndClose() - val bytesWritten = w.bytesWritten - batchSizes.append(bytesWritten) - _diskBytesSpilled += bytesWritten + _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten + batchSizes.append(curWriteMetrics.shuffleBytesWritten) objectsWritten = 0 } @@ -275,7 +333,8 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize) + curWriteMetrics = new ShuffleWriteMetrics() + writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { @@ -299,18 +358,36 @@ private[spark] class ExternalSorter[K, V, C]( } } - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } + spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) + } - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 + /** + * Spill our in-memory collection to separate files, one for each partition. This is used when + * there's no aggregator and ordering and the number of partitions is small, because it allows + * writePartitionedFile to just concatenate files without deserializing data. 
+ * + * @param collection whichever collection we're using (map or buffer) + */ + private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { + assert(bypassMergeSort) + + // Create our file writers if we haven't done so yet + if (partitionWriters == null) { + curWriteMetrics = new ShuffleWriteMetrics() + partitionWriters = Array.fill(numPartitions) { + val (blockId, file) = diskBlockManager.createTempBlock() + blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open() + } + } - spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)) - _memoryBytesSpilled += memorySize + val it = collection.iterator // No need to sort stuff, just write each element out + while (it.hasNext) { + val elem = it.next() + val partitionId = elem._1._1 + val key = elem._1._2 + val value = elem._2 + partitionWriters(partitionId).write((key, value)) + } } /** @@ -474,7 +551,6 @@ private[spark] class ExternalSorter[K, V, C]( skipToNextPartition() - // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams var fileStream: FileInputStream = null @@ -614,23 +690,25 @@ private[spark] class ExternalSorter[K, V, C]( def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = { val usingMap = aggregator.isDefined val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - if (spills.isEmpty) { + if (spills.isEmpty && partitionWriters == null) { // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps // we don't even need to sort by anything other than partition ID if (!ordering.isDefined) { - // The user isn't requested sorted keys, so only sort by partition ID, not key - val partitionComparator = new Comparator[(Int, K)] { - override def compare(a: (Int, K), b: (Int, K)): Int = { - a._1 - b._1 - } - } + // The user hasn't requested sorted keys, so only sort by partition ID, not key groupByPartition(collection.destructiveSortedIterator(partitionComparator)) } else { // We do need to sort by both partition ID and key groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator)) } + } else if (bypassMergeSort) { + // Read data from each partition file and merge it together with the data in memory; + // note that there's no ordering or aggregator in this case -- we just partition objects + val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator)) + collIter.map { case (partitionId, values) => + (partitionId, values ++ readPartitionFile(partitionWriters(partitionId))) + } } else { - // General case: merge spilled and in-memory data + // Merge spilled and in-memory data merge(spills, collection.destructiveSortedIterator(partitionKeyComparator)) } } @@ -640,9 +718,112 @@ private[spark] class ExternalSorter[K, V, C]( */ def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2) + /** + * Write all the data added into this ExternalSorter into a file in the disk store, creating + * an .index file for it as well with the offsets of each partition. This is called by the + * SortShuffleWriter and can go through an efficient path of just concatenating binary files + * if we decided to avoid merge-sorting. + * + * @param blockId block ID to write to. The index file will be blockId.name + ".index". + * @param context a TaskContext for a running Spark task, for us to update shuffle metrics. 
+ * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = { + val outputFile = blockManager.diskBlockManager.getFile(blockId) + + // Track location of each range in the output file + val offsets = new Array[Long](numPartitions + 1) + val lengths = new Array[Long](numPartitions) + + if (bypassMergeSort && partitionWriters != null) { + // We decided to write separate files for each partition, so just concatenate them. To keep + // this simple we spill out the current in-memory collection so that everything is in files. + spillToPartitionFiles(if (aggregator.isDefined) map else buffer) + partitionWriters.foreach(_.commitAndClose()) + var out: FileOutputStream = null + var in: FileInputStream = null + try { + out = new FileOutputStream(outputFile) + for (i <- 0 until numPartitions) { + in = new FileInputStream(partitionWriters(i).fileSegment().file) + val size = org.apache.spark.util.Utils.copyStream(in, out, false) + in.close() + in = null + lengths(i) = size + offsets(i + 1) = offsets(i) + lengths(i) + } + } finally { + if (out != null) { + out.close() + } + if (in != null) { + in.close() + } + } + } else { + // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by + // partition and just write everything directly. + for ((id, elements) <- this.partitionedIterator) { + if (elements.hasNext) { + val writer = blockManager.getDiskWriter( + blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get) + for (elem <- elements) { + writer.write(elem) + } + writer.commitAndClose() + val segment = writer.fileSegment() + offsets(id + 1) = segment.offset + segment.length + lengths(id) = segment.length + } else { + // The partition is empty; don't create a new writer to avoid writing headers, etc + offsets(id + 1) = offsets(id) + } + } + } + + context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled + context.taskMetrics.diskBytesSpilled += diskBytesSpilled + + // Write an index file with the offsets of each block, plus a final offset at the end for the + // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure + // out where each block begins and ends. 
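The concatenation branch above accumulates one length per partition, and the .index file is just the running total of those lengths. A small self-contained sketch of that bookkeeping with hypothetical sizes (plain java.io, not the patch code itself):

    import java.io.{BufferedOutputStream, DataOutputStream, File, FileOutputStream}

    // Hypothetical byte counts for a 4-partition map output; partition 1 happens to be empty.
    val lengths = Array[Long](120L, 0L, 4096L, 512L)
    val offsets = new Array[Long](lengths.length + 1)
    for (i <- lengths.indices) {
      offsets(i + 1) = offsets(i) + lengths(i)  // partition i occupies [offsets(i), offsets(i + 1))
    }

    // The ".index" file is simply numPartitions + 1 longs; a reader can seek straight to any
    // partition's byte range in the single concatenated data file.
    val indexFile = new File("/tmp/example_shuffle.index")  // placeholder path
    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
    try {
      offsets.foreach(o => out.writeLong(o))
    } finally {
      out.close()
    }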
+ + val diskBlockManager = blockManager.diskBlockManager + val indexFile = diskBlockManager.getFile(blockId.name + ".index") + val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) + try { + var i = 0 + while (i < numPartitions + 1) { + out.writeLong(offsets(i)) + i += 1 + } + } finally { + out.close() + } + + lengths + } + + /** + * Read a partition file back as an iterator (used in our iterator method) + */ + def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = { + if (writer.isOpen) { + writer.commitAndClose() + } + blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]] + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() + if (partitionWriters != null) { + partitionWriters.foreach { w => + w.revertPartialWritesAndClose() + diskBlockManager.getFile(w.blockId).delete() + } + partitionWriters = null + } } def memoryBytesSpilled: Long = _memoryBytesSpilled diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 56150caa5d6ba..e1c13de04a0be 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1239,12 +1239,28 @@ public Tuple2 call(Integer i) { Assert.assertTrue(worCounts.size() == 2); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); - JavaPairRDD wrExact = rdd2.sampleByKey(true, fractions, true, 1L); + } + + @Test + @SuppressWarnings("unchecked") + public void sampleByKeyExact() { + JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); + JavaPairRDD rdd2 = rdd1.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(Integer i) { + return new Tuple2(i % 2, 1); + } + }); + Map fractions = Maps.newHashMap(); + fractions.put(0, 0.5); + fractions.put(1, 1.0); + JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); Map wrExactCounts = (Map) (Object) wrExact.countByKey(); Assert.assertTrue(wrExactCounts.size() == 2); Assert.assertTrue(wrExactCounts.get(0) == 2); Assert.assertTrue(wrExactCounts.get(1) == 4); - JavaPairRDD worExact = rdd2.sampleByKey(false, fractions, true, 1L); + JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); Map worExactCounts = (Map) (Object) worExact.countByKey(); Assert.assertTrue(worExactCounts.size() == 2); Assert.assertTrue(worExactCounts.get(0) == 2); diff --git a/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java new file mode 100644 index 0000000000000..3d50ab4fabe42 --- /dev/null +++ b/core/src/test/java/org/apache/spark/serializer/TestJavaSerializerImpl.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import scala.Option; +import scala.reflect.ClassTag; + + +/** + * A simple Serializer implementation to make sure the API is Java-friendly. + */ +class TestJavaSerializerImpl extends Serializer { + + @Override + public SerializerInstance newInstance() { + return null; + } + + static class SerializerInstanceImpl extends SerializerInstance { + @Override + public ByteBuffer serialize(T t, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag evidence$1) { + return null; + } + + @Override + public T deserialize(ByteBuffer bytes, ClassTag evidence$1) { + return null; + } + + @Override + public SerializationStream serializeStream(OutputStream s) { + return null; + } + + @Override + public DeserializationStream deserializeStream(InputStream s) { + return null; + } + } + + static class SerializationStreamImpl extends SerializationStream { + + @Override + public SerializationStream writeObject(T t, ClassTag evidence$1) { + return null; + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + } + + static class DeserializationStreamImpl extends DeserializationStream { + + @Override + public T readObject(ClassTag evidence$1) { + return null; + } + + @Override + public void close() { + + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java similarity index 58% rename from core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java rename to core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 264cf97d0209f..af34cdb03e4d1 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientChannelInitializer.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -15,25 +15,25 @@ * limitations under the License. */ -package org.apache.spark.network.netty; +package org.apache.spark.util; -import io.netty.channel.ChannelInitializer; -import io.netty.channel.socket.SocketChannel; -import io.netty.handler.codec.string.StringEncoder; +import org.apache.spark.TaskContext; -class FileClientChannelInitializer extends ChannelInitializer { - private final FileClientHandler fhandler; - - FileClientChannelInitializer(FileClientHandler handler) { - fhandler = handler; - } +/** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. 
+ */ +public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { @Override - public void initChannel(SocketChannel channel) { - // file no more than 2G - channel.pipeline() - .addLast("encoder", new StringEncoder()) - .addLast("handler", fhandler); + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.runningLocally(); + context.taskMetrics(); + context.addTaskCompletionListener(this); } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 7c3d0208b195a..978a6ded80829 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.broadcast -import org.apache.spark.storage.{BroadcastBlockId, _} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} import org.scalatest.FunSuite +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage._ + + class BroadcastSuite extends FunSuite with LocalSparkContext { private val httpConf = broadcastConf("HttpBroadcastFactory") @@ -44,7 +46,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing HttpBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val conf = httpConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -69,7 +74,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val conf = torrentConf.clone + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.broadcast.compress", "true") + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -118,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === 1) statuses.head match { case (bm, status) => assert(bm.executorId === "", "Block should only be on the driver") @@ -133,14 +139,14 @@ class BroadcastSuite extends 
FunSuite with LocalSparkContext { } if (distributed) { // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") } } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) assert(statuses.size === numSlaves + 1) statuses.foreach { case (_, status) => assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) @@ -151,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. In the latter case, also verify that the broadcast file is deleted on the driver. - def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - assert(blockIds.size === 1) - val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + val blockId = BroadcastBlockId(broadcastId) + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) val expectedNumBlocks = if (removeFromDriver) 0 else 1 val possiblyNot = if (removeFromDriver) "" else " not" assert(statuses.size === expectedNumBlocks, "Block should%s be unpersisted on the driver".format(possiblyNot)) if (distributed && removeFromDriver) { // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file should%s be deleted".format(possiblyNot)) } } - testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -179,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { val numSlaves = if (distributed) 2 else 0 - def getBlockIds(id: Long) = { - val broadcastBlockId = BroadcastBlockId(id) - val metaBlockId = BroadcastBlockId(id, "meta") - // Assume broadcast value is small enough to fit into 1 piece - val pieceBlockId = BroadcastBlockId(id, "piece0") - if (distributed) { - // the metadata and piece blocks are generated only in distributed mode - Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) - } else { - Seq[BroadcastBlockId](broadcastBlockId) - } + // Verify that blocks are persisted only on the driver + def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === 1) + + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === (if (distributed) 1 else 0)) } - // Verify that blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = 
bmm.getBlockStatus(blockIds.head, askSlaves = true) + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves + 1) + } else { assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.executorId === "", "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } } - } - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - if (blockId.field == "meta") { - // Meta data is only on the driver - assert(statuses.size === 1) - statuses.head match { case (bm, _) => assert(bm.executorId === "") } - } else { - // Other blocks are on both the executors and the driver - assert(statuses.size === numSlaves + 1, - blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } + blockId = BroadcastBlockId(broadcastId, "piece0") + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (distributed) { + assert(statuses.size === numSlaves + 1) + } else { + assert(statuses.size === 0) } } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. 
- def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - blockIds.foreach { blockId => - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - } + def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { + var blockId = BroadcastBlockId(broadcastId) + var expectedNumBlocks = if (removeFromDriver) 0 else 1 + var statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) + + blockId = BroadcastBlockId(broadcastId, "piece0") + expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1 + statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -256,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { distributed: Boolean, numSlaves: Int, // used only when distributed = true broadcastConf: SparkConf, - getBlockIds: Long => Seq[BroadcastBlockId], - afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, - afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterCreation: (Long, BlockManagerMaster) => Unit, + afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, + afterUnpersist: (Long, BlockManagerMaster) => Unit, removeFromDriver: Boolean) { sc = if (distributed) { @@ -272,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Create broadcast variable val broadcast = sc.broadcast(list) - val blocks = getBlockIds(broadcast.id) - afterCreation(blocks, blockManagerMaster) + afterCreation(broadcast.id, blockManagerMaster) // Use broadcast variable on all executors val partitions = 10 assert(partitions > numSlaves) val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) - afterUsingBroadcast(blocks, blockManagerMaster) + afterUsingBroadcast(broadcast.id, blockManagerMaster) // Unpersist broadcast if (removeFromDriver) { @@ -288,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } else { broadcast.unpersist(blocking = true) } - afterUnpersist(blocks, blockManagerMaster) + afterUnpersist(broadcast.id, blockManagerMaster) // If the broadcast is removed from driver, all subsequent uses of the broadcast variable // should throw SparkExceptions. Otherwise, the result should be the same as before. 
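The BroadcastSuite refactor above passes only the broadcast id into the check callbacks and rebuilds the block ids from it: one block for the value, plus piece blocks for TorrentBroadcast in distributed mode. A short sketch of that naming, assuming the BroadcastBlockId constructors used in the hunk:

    import org.apache.spark.storage.BroadcastBlockId

    val broadcastId = 42L                                     // placeholder id
    val valueBlock = BroadcastBlockId(broadcastId)            // the broadcast value itself
    val firstPiece = BroadcastBlockId(broadcastId, "piece0")  // first Torrent piece (distributed mode only)
    // The suite now asks the BlockManagerMaster for the status of exactly these ids before use,
    // after use, and after unpersist, instead of threading a precomputed Seq of block ids around.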
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index a5cdcfb5de03b..7e1ef80c84561 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -106,6 +106,18 @@ class SparkSubmitSuite extends FunSuite with Matchers { appArgs.childArgs should be (Seq("some", "--weird", "args")) } + test("handles arguments to user program with name collision") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "userjar.jar", + "--master", "local", + "some", + "--weird", "args") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) + } + test("handles YARN cluster mode") { val clArgs = Seq( "--deploy-mode", "cluster", diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 3f882a724b047..25be7f25c21bb 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -56,15 +56,33 @@ class CompressionCodecSuite extends FunSuite { testCodec(codec) } + test("lz4 compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lz4") + assert(codec.getClass === classOf[LZ4CompressionCodec]) + testCodec(codec) + } + test("lzf compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } + test("lzf compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "lzf") + assert(codec.getClass === classOf[LZFCompressionCodec]) + testCodec(codec) + } + test("snappy compression codec") { val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } + + test("snappy compression codec short form") { + val codec = CompressionCodec.createCodec(conf, "snappy") + assert(codec.getClass === classOf[SnappyCompressionCodec]) + testCodec(codec) + } } diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala index 415ad8c432c12..e2f4d4c57cdb5 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala @@ -17,14 +17,21 @@ package org.apache.spark.network +import java.io.IOException import java.nio._ +import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers._ + +import scala.concurrent.TimeoutException import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. 
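The CompressionCodecSuite additions test the new short-form codec names, which let the codec be chosen by alias instead of its fully qualified class name. A minimal configuration sketch, assuming the standard spark.io.compression.codec key:

    import org.apache.spark.SparkConf

    // Short alias, equivalent to org.apache.spark.io.SnappyCompressionCodec
    val conf = new SparkConf().set("spark.io.compression.codec", "snappy")
    // "lz4" and "lzf" are accepted the same way, per the new short-form tests.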
@@ -46,7 +53,7 @@ class ConnectionManagerSuite extends FunSuite { buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) + Await.result(manager.sendMessageReliably(manager.id, bufferMessage), 10 seconds) assert(receivedMessage == true) @@ -79,7 +86,7 @@ class ConnectionManagerSuite extends FunSuite { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) }) assert(numReceivedServerMessages == 10) @@ -118,7 +125,10 @@ class ConnectionManagerSuite extends FunSuite { val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(managerServer.id, bufferMessage) + // Expect managerServer to close connection, which we'll report as an error: + intercept[IOException] { + Await.result(manager.sendMessageReliably(managerServer.id, bufferMessage), 10 seconds) + } assert(numReceivedServerMessages == 0) assert(numReceivedMessages == 0) @@ -163,6 +173,8 @@ class ConnectionManagerSuite extends FunSuite { val g = Await.result(f, 1 second) assert(false) } catch { + case i: IOException => + assert(true) case e: TimeoutException => { // we should timeout here since the client can't do the negotiation assert(true) @@ -209,7 +221,6 @@ class ConnectionManagerSuite extends FunSuite { }).foreach(f => { try { val g = Await.result(f, 1 second) - if (!g.isDefined) assert(false) else assert(true) } catch { case e: Exception => { assert(false) @@ -223,7 +234,68 @@ class ConnectionManagerSuite extends FunSuite { managerServer.stop() } + test("Ack error message") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + val managerServer = new ConnectionManager(0, conf, securityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + throw new Exception + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + + intercept[IOException] { + Await.result(future, 1 second) + } + + manager.stop() + managerServer.stop() + + } + + test("sendMessageReliably timeout") { + val clientConf = new SparkConf + clientConf.set("spark.authenticate", "false") + val ackTimeout = 30 + clientConf.set("spark.core.connection.ack.wait.timeout", s"${ackTimeout}") + + val clientSecurityManager = new SecurityManager(clientConf) + val manager = new ConnectionManager(0, clientConf, clientSecurityManager) + + val serverConf = new SparkConf + serverConf.set("spark.authenticate", "false") + val serverSecurityManager = new SecurityManager(serverConf) + val managerServer = new ConnectionManager(0, serverConf, serverSecurityManager) + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + // sleep 60 sec > ack timeout for simulating server slow down or hang up + Thread.sleep(ackTimeout * 3 * 1000) + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val 
bufferMessage = Message.createBufferMessage(buffer.duplicate) + + val future = manager.sendMessageReliably(managerServer.id, bufferMessage) + // Future should throw IOException in 30 sec. + // Otherwise TimeoutExcepton is thrown from Await.result. + // We expect TimeoutException is not thrown. + intercept[IOException] { + Await.result(future, (ackTimeout * 2) second) + } + + manager.stop() + managerServer.stop() + } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala new file mode 100644 index 0000000000000..02d0ffc86f58f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.{RandomAccessFile, File} +import java.nio.ByteBuffer +import java.util.{Collections, HashSet} +import java.util.concurrent.{TimeUnit, Semaphore} + +import scala.collection.JavaConversions._ + +import io.netty.buffer.{ByteBufUtil, Unpooled} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf +import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} +import org.apache.spark.network.netty.server.BlockServer +import org.apache.spark.storage.{FileSegment, BlockDataProvider} + + +/** + * Test suite that makes sure the server and the client implementations share the same protocol. 
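The ConnectionManagerSuite hunks above replace the blocking sendMessageReliablySync with a Future-returning sendMessageReliably that callers await, so connection errors, nacks, and ack timeouts now surface as exceptions from the Future. A condensed sketch of the new pattern, modeled closely on the test setup shown in the hunks (ports, sizes, and the no-op handler are illustrative):

    import java.nio.ByteBuffer
    import scala.concurrent.Await
    import scala.concurrent.duration._
    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, Message}

    val conf = new SparkConf().set("spark.authenticate", "false")
    val manager = new ConnectionManager(0, conf, new SecurityManager(conf))
    val server = new ConnectionManager(0, conf, new SecurityManager(conf))
    server.onReceiveMessage((msg: Message, id: ConnectionManagerId) => None)  // reply with default ack

    val buffer = ByteBuffer.allocate(16).put(Array.tabulate[Byte](16)(_.toByte))
    buffer.flip()
    val future = manager.sendMessageReliably(server.id, Message.createBufferMessage(buffer))
    Await.result(future, 10.seconds)  // failures now arrive as IOException/TimeoutException here

    manager.stop()
    server.stop()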
+ */ +class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { + + val bufSize = 100000 + var buf: ByteBuffer = _ + var testFile: File = _ + var server: BlockServer = _ + var clientFactory: BlockFetchingClientFactory = _ + + val bufferBlockId = "buffer_block" + val fileBlockId = "file_block" + + val fileContent = new Array[Byte](1024) + scala.util.Random.nextBytes(fileContent) + + override def beforeAll() = { + buf = ByteBuffer.allocate(bufSize) + for (i <- 1 to bufSize) { + buf.put(i.toByte) + } + buf.flip() + + testFile = File.createTempFile("netty-test-file", "txt") + val fp = new RandomAccessFile(testFile, "rw") + fp.write(fileContent) + fp.close() + + server = new BlockServer(new SparkConf, new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + if (blockId == bufferBlockId) { + Right(buf) + } else if (blockId == fileBlockId) { + Left(new FileSegment(testFile, 10, testFile.length - 25)) + } else { + throw new Exception("Unknown block id " + blockId) + } + } + }) + + clientFactory = new BlockFetchingClientFactory(new SparkConf) + } + + override def afterAll() = { + server.stop() + clientFactory.stop() + } + + /** A ByteBuf for buffer_block */ + lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf) + + /** A ByteBuf for file_block */ + lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) + + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = + { + val client = clientFactory.createClient(server.hostName, server.port) + val sem = new Semaphore(0) + val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) + val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + + client.fetchBlocks( + blockIds, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = { + errorBlockIds.add(blockId) + sem.release() + } + + override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + receivedBlockIds.add(blockId) + data.retain() + receivedBuffers.add(data) + sem.release() + } + } + ) + if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server") + } + client.close() + (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet) + } + + test("fetch a ByteBuffer block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a FileSegment block via zero-copy send") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) + assert(blockIds === Set(fileBlockId)) + assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) + assert(blockIds.isEmpty) + assert(buffers.isEmpty) + assert(failBlockIds === Set("random-block")) + } + + test("fetch both ByteBuffer block and FileSegment block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) + assert(blockIds === Set(bufferBlockId, fileBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, 
fileBlockReference)) + assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) + } + + test("fetch both ByteBuffer block and a non-existent block") { + val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) + assert(blockIds === Set(bufferBlockId)) + assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala new file mode 100644 index 0000000000000..903ab09ae4322 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty.client + +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.{PrivateMethodTester, FunSuite} + + +class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + val totalLength = 4 + blockId.length + blockData.length + + var parsedBlockId: String = "" + var parsedBlockData: String = "" + val handler = new BlockFetchingClientHandler + handler.addRequest(blockId, + new BlockClientListener { + override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? 
+ override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { + parsedBlockId = bid + val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) + refCntBuf.byteBuffer().get(bytes) + parsedBlockData = new String(bytes) + } + } + ) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(blockId.length) + buf.put(blockId.getBytes) + buf.put(blockData.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedBlockData === blockData) + + assert(handler.invokePrivate(outstandingRequests()).size === 0) + + channel.close() + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val errorMsg = "error erro5r error err4or error3 error6 error erro1r" + val totalLength = 4 + blockId.length + errorMsg.length + + var parsedBlockId: String = "" + var parsedErrorMsg: String = "" + val handler = new BlockFetchingClientHandler + handler.addRequest(blockId, new BlockClientListener { + override def onFetchFailure(bid: String, msg: String) ={ + parsedBlockId = bid + parsedErrorMsg = msg + } + override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? + }) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself + buf.putInt(totalLength) + buf.putInt(-blockId.length) + buf.put(blockId.getBytes) + buf.put(errorMsg.getBytes) + buf.flip() + + channel.writeInbound(Unpooled.wrappedBuffer(buf)) + assert(parsedBlockId === blockId) + assert(parsedErrorMsg === errorMsg) + + assert(handler.invokePrivate(outstandingRequests()).size === 0) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala new file mode 100644 index 0000000000000..3ee281cb1350b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty.server + +import io.netty.buffer.ByteBuf +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +class BlockHeaderEncoderSuite extends FunSuite { + + test("encode normal block data") { + val blockId = "test_block" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, None)) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + 17) + assert(out.readInt() === blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + assert(out.readableBytes() === 0) + + channel.close() + } + + test("encode error message") { + val blockId = "error_block" + val errorMsg = "error encountered" + val channel = new EmbeddedChannel(new BlockHeaderEncoder) + channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) + val out = channel.readOutbound().asInstanceOf[ByteBuf] + assert(out.readInt() === 4 + blockId.length + errorMsg.length) + assert(out.readInt() === -blockId.length) + + val blockIdBytes = new Array[Byte](blockId.length) + out.readBytes(blockIdBytes) + assert(new String(blockIdBytes) === blockId) + + val errorMsgBytes = new Array[Byte](errorMsg.length) + out.readBytes(errorMsgBytes) + assert(new String(errorMsgBytes) === errorMsg) + assert(out.readableBytes() === 0) + + channel.close() + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala new file mode 100644 index 0000000000000..3239c710f1639 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty.server + +import java.io.{RandomAccessFile, File} +import java.nio.ByteBuffer + +import io.netty.buffer.{Unpooled, ByteBuf} +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + +import org.apache.spark.storage.{BlockDataProvider, FileSegment} + + +class BlockServerHandlerSuite extends FunSuite { + + test("ByteBuffer block") { + val expectedBlockId = "test_bytebuffer_block" + val buf = ByteBuffer.allocate(10000) + for (i <- 1 to 10000) { + buf.put(i.toByte) + } + buf.flip() + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[ByteBuf] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === buf.remaining) + assert(out1.error === None) + + assert(out2.equals(Unpooled.wrappedBuffer(buf))) + + channel.close() + } + + test("FileSegment block via zero-copy") { + val expectedBlockId = "test_file_block" + + // Create random file data + val fileContent = new Array[Byte](1024) + scala.util.Random.nextBytes(fileContent) + val testFile = File.createTempFile("netty-test-file", "txt") + val fp = new RandomAccessFile(testFile, "rw") + fp.write(fileContent) + fp.close() + + val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + Left(new FileSegment(testFile, 15, testFile.length - 25)) + } + })) + + channel.writeInbound(expectedBlockId) + assert(channel.outboundMessages().size === 2) + + val out1 = channel.readOutbound().asInstanceOf[BlockHeader] + val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] + + assert(out1.blockId === expectedBlockId) + assert(out1.blockSize === testFile.length - 25) + assert(out1.error === None) + + assert(out2.count === testFile.length - 25) + assert(out2.position === 15) + } + + test("pipeline exception propagation") { + val blockServerHandler = new BlockServerHandler(new BlockDataProvider { + override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? 
+ }) + val exceptionHandler = new SimpleChannelInboundHandler[String]() { + override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { + throw new Exception("this is an error") + } + } + + val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) + assert(channel.isOpen) + channel.writeInbound("a message to trigger the error") + assert(!channel.isOpen) + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index a822bd18bfdbd..f89bdb6e07dea 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -245,6 +245,29 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsForLargerDatasets") { + // Verify the case of slighly larger datasets + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(8) + val expectedHistogramResults = + Array(12, 12, 11, 12, 12, 11, 12, 12) + val expectedHistogramBuckets = + Array(6.0, 17.625, 29.25, 40.875, 52.5, 64.125, 75.75, 87.375, 99.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithIrrationalBucketEdges") { + // Verify the case of buckets with irrational edges. See #SPARK-2862. + val rdd = sc.parallelize(6 to 99) + val (histogramBuckets, histogramResults) = rdd.histogram(9) + val expectedHistogramResults = + Array(11, 10, 11, 10, 10, 11, 10, 10, 11) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets(0) === 6.0) + assert(histogramBuckets(9) === 99.0) + } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 4f49d4a1d4d34..63d3ddb4af98a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } test("sampleByKey") { - def stratifier (fractionPositive: Double) = { - (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" - } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { - if (exact) { - return expected == actual - } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev + val defaultSeed = 1L + + // vary RDD size + for (n <- List(100, 1000, 1000000)) { + val data = sc.parallelize(1 to n, 2) + val fractionPositive = 0.3 + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // Without replacement validation - def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> 
samplingRate) - val sample = stratifiedData.sampleByKey(false, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } - assert(takeSample.size === takeSample.toSet.size) - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // vary fractionPositive + for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } - // With replacement validation - def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)], - exact: Boolean, - samplingRate: Double, - seed: Long, - n: Long) = { - val expectedSampleSize = stratifiedData.countByKey().mapValues(count => - math.ceil(count * samplingRate).toInt) - val fractions = Map("1" -> samplingRate, "0" -> samplingRate) - val sample = stratifiedData.sampleByKey(true, fractions, exact, seed) - val sampleCounts = sample.countByKey() - val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) } - val groupedByKey = takeSample.groupBy(_._1) - for ((key, v) <- groupedByKey) { - if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { - // sample large enough for there to be repeats with high likelihood - assert(v.toSet.size < expectedSampleSize(key)) - } else { - if (exact) { - assert(v.toSet.size <= expectedSampleSize(key)) - } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) - } - } - } - takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + // Use the same data for the rest of the tests + val fractionPositive = 0.3 + val n = 100 + val data = sc.parallelize(1 to n, 2) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) + + // vary seed + for (seed <- defaultSeed to defaultSeed + 5L) { + val samplingRate = 0.1 + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n) } - def checkAllCombos(stratifiedData: RDD[(String, Int)], - samplingRate: Double, - seed: Long, - n: Long) = { - takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n) - takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, n) + // vary sampling rate + for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { + StratifiedAuxiliary.testSample(stratifiedData, samplingRate, defaultSeed, n) } + } + test("sampleByKeyExact") { val defaultSeed = 1L // vary RDD size for (n <- List(100, 1000, 1000000)) { val data = sc.parallelize(1 to n, 2) val fractionPositive = 0.3 - val stratifiedData = data.keyBy(stratifier(fractionPositive)) - + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // vary fractionPositive for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) { val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) 
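The PairRDDFunctionsSuite refactor around this hunk follows the API split: the former exact flag of sampleByKey is now a separate sampleByKeyExact method, mirroring the JavaAPISuite change earlier in the patch. A brief usage sketch, assuming a running SparkContext named sc:

    import org.apache.spark.SparkContext._  // pair-RDD functions via implicit conversion

    val pairs = sc.parallelize(1 to 1000).map(i => (if (i % 10 < 3) "1" else "0", i))
    val fractions = Map("1" -> 0.1, "0" -> 0.2)
    val quick = pairs.sampleByKey(false, fractions, 1L)       // single pass, approximate per-key sizes
    val exact = pairs.sampleByKeyExact(false, fractions, 1L)  // extra pass, math.ceil(count * fraction) per key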
- + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } // Use the same data for the rest of the tests val fractionPositive = 0.3 val n = 100 val data = sc.parallelize(1 to n, 2) - val stratifiedData = data.keyBy(stratifier(fractionPositive)) + val stratifiedData = data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive)) // vary seed for (seed <- defaultSeed to defaultSeed + 5L) { val samplingRate = 0.1 - checkAllCombos(stratifiedData, samplingRate, seed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, n) } // vary sampling rate for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) { - checkAllCombos(stratifiedData, samplingRate, defaultSeed, n) + StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, defaultSeed, n) } } @@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { intercept[IllegalArgumentException] {shuffled.lookup(-1)} } + private object StratifiedAuxiliary { + def stratifier (fractionPositive: Double) = { + (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" + } + + def checkSize(exact: Boolean, + withReplacement: Boolean, + expected: Long, + actual: Long, + p: Double): Boolean = { + if (exact) { + return expected == actual + } + val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) + // Very forgiving margin since we're dealing with very small sample sizes most of the time + math.abs(actual - expected) <= 6 * stdev + } + + def testSampleExact(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, true, samplingRate, seed, n) + testPoisson(stratifiedData, true, samplingRate, seed, n) + } + + def testSample(stratifiedData: RDD[(String, Int)], + samplingRate: Double, + seed: Long, + n: Long) = { + testBernoulli(stratifiedData, false, samplingRate, seed, n) + testPoisson(stratifiedData, false, samplingRate, seed, n) + } + + // Without replacement validation + def testBernoulli(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey() + .mapValues(count => math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(false, fractions, seed) + } else { + stratifiedData.sampleByKey(false, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = sample.collect() + sampleCounts.foreach { case(k, v) => + assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + assert(takeSample.size === takeSample.toSet.size) + takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } + } + + // With replacement validation + def testPoisson(stratifiedData: RDD[(String, Int)], + exact: Boolean, + samplingRate: Double, + seed: Long, + n: Long) = { + val expectedSampleSize = stratifiedData.countByKey().mapValues(count => + math.ceil(count * samplingRate).toInt) + val fractions = Map("1" -> samplingRate, "0" -> samplingRate) + val sample = if (exact) { + stratifiedData.sampleByKeyExact(true, fractions, seed) + } else { + stratifiedData.sampleByKey(true, fractions, seed) + } + val sampleCounts = sample.countByKey() + val takeSample = 
sample.collect() + sampleCounts.foreach { case (k, v) => + assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + } + val groupedByKey = takeSample.groupBy(_._1) + for ((key, v) <- groupedByKey) { + if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) { + // sample large enough for there to be repeats with high likelihood + assert(v.toSet.size < expectedSampleSize(key)) + } else { + if (exact) { + assert(v.toSet.size <= expectedSampleSize(key)) + } else { + assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + } + } + } + takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]")) + } + } + } /* diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index b31e3a09e5b9c..926d4fecb5b91 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -81,11 +81,11 @@ class RDDSuite extends FunSuite with SharedSparkContext { def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble - val size = 100 - val uniformDistro = for (i <- 1 to 100000) yield i % size - val simpleRdd = sc.makeRDD(uniformDistro) - assert(error(simpleRdd.countApproxDistinct(4, 0), size) < 0.4) - assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.1) + val size = 1000 + val uniformDistro = for (i <- 1 to 5000) yield i % size + val simpleRdd = sc.makeRDD(uniformDistro, 10) + assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.2) + assert(error(simpleRdd.countApproxDistinct(12, 0), size) < 0.1) } test("SparkContext.union") { @@ -726,6 +726,16 @@ class RDDSuite extends FunSuite with SharedSparkContext { jrdd.rdd.retag.collect() } + test("parent method") { + val rdd1 = sc.parallelize(1 to 10, 2) + val rdd2 = rdd1.filter(_ % 2 == 0) + val rdd3 = rdd2.map(_ + 1) + val rdd4 = new UnionRDD(sc, List(rdd1, rdd2, rdd3)) + assert(rdd4.parent(0).isInstanceOf[ParallelCollectionRDD[_]]) + assert(rdd4.parent(1).isInstanceOf[FilteredRDD[_]]) + assert(rdd4.parent(2).isInstanceOf[MappedRDD[_, _]]) + } + test("getNarrowAncestors") { val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8c1b0fed11f72..bd829752eb401 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -141,7 +141,9 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } before { - sc = new SparkContext("local", "DAGSchedulerSuite") + // Enable local execution for this test + val conf = new SparkConf().set("spark.localExecution.enabled", "true") + sc = new SparkContext("local", "DAGSchedulerSuite", conf) sparkListener.successfulStages.clear() sparkListener.failedStages.clear() failure = null diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 270f7e661045a..db2ad829a48f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val rdd = new RDD[String](sc, List()) { override def getPartitions = 
Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => TaskContextSuite.completed = true) + context.addTaskCompletionListener(context => TaskContextSuite.completed = true) sys.error("failed") } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ffd23380a886f..93e8ddacf8865 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -154,6 +154,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000) val MAX_TASK_FAILURES = 4 + override def beforeEach() { + super.beforeEach() + FakeRackUtil.cleanUp() + } + test("TaskSet with no preferences") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -471,7 +476,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { test("new executors get added and lost") { // Assign host2 to rack2 - FakeRackUtil.cleanUp() FakeRackUtil.assignHostToRack("host2", "rack2") sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc) @@ -504,7 +508,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { } test("test RACK_LOCAL tasks") { - FakeRackUtil.cleanUp() // Assign host1 to rack1 FakeRackUtil.assignHostToRack("host1", "rack1") // Assign host2 to rack1 @@ -607,6 +610,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2) } + test("Ensure TaskSetManager is usable after addition of levels") { + // Regression test for SPARK-2931 + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(2, + Seq(TaskLocation("host1", "execA")), + Seq(TaskLocation("host2", "execB.1"))) + val clock = new FakeClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + // Only ANY is valid + assert(manager.myLocalityLevels.sameElements(Array(ANY))) + // Add a new executor + sched.addExecutor("execA", "host1") + sched.addExecutor("execB.2", "host2") + manager.executorAdded() + assert(manager.pendingTasksWithNoPrefs.size === 0) + // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY + assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) + assert(manager.resourceOffer("execA", "host1", ANY) !== None) + clock.advance(LOCALITY_WAIT * 4) + assert(manager.resourceOffer("execB.2", "host2", ANY) !== None) + sched.removeExecutor("execA") + sched.removeExecutor("execB.2") + manager.executorLost("execA", "host1") + manager.executorLost("execB.2", "host2") + clock.advance(LOCALITY_WAIT * 4) + sched.addExecutor("execC", "host3") + manager.executorAdded() + // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException: + assert(manager.resourceOffer("execC", "host3", ANY) !== None) + } + + def createTaskResult(id: Int): DirectTaskResult[Int] = { val valueSer = SparkEnv.get.serializer.newInstance() new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala new file mode 100644 index 0000000000000..11e8c9c4cb37f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.apache.spark.util.Utils + +import com.esotericsoftware.kryo.Kryo +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils} +import org.apache.spark.SparkContext._ +import org.apache.spark.serializer.KryoDistributedTest._ + +class KryoSerializerDistributedSuite extends FunSuite { + + test("kryo objects are serialised consistently in different processes") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + conf.set("spark.task.maxFailures", "1") + + val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) + conf.setJars(List(jar.getPath)) + + val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val original = Thread.currentThread.getContextClassLoader + val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + SparkEnv.get.serializer.setDefaultClassLoader(loader) + + val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache() + + // Randomly mix the keys so that the join below will require a shuffle with each partition + // sending data to multiple other partitions. + val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)} + + // Join the two RDDs, and force evaluation + assert(shuffledRDD.join(cachedRDD).collect().size == 1) + + LocalSparkContext.stop(sc) + } +} + +object KryoDistributedTest { + class MyCustomClass + + class AppJarRegistrator extends KryoRegistrator { + override def registerClasses(k: Kryo) { + val classLoader = Thread.currentThread.getContextClassLoader + k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader)) + } + } + + object AppJarRegistrator { + val customClassName = "KryoSerializerDistributedSuiteCustomClass" + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala new file mode 100644 index 0000000000000..967c9e9899c9d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.SparkContext +import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkException + + +class KryoSerializerResizableOutputSuite extends FunSuite { + + // trial and error showed this will not serialize with 1mb buffer + val x = (1 to 400000).toArray + + test("kryo without resizable output buffer should fail on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "1") + val sc = new SparkContext("local", "test", conf) + intercept[SparkException](sc.parallelize(x).collect()) + LocalSparkContext.stop(sc) + } + + test("kryo with resizable output buffer should succeed on large array") { + val conf = new SparkConf(false) + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer.max.mb", "2") + val sc = new SparkContext("local", "test", conf) + assert(sc.parallelize(x).collect() === x) + LocalSparkContext.stop(sc) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 789b773bae316..e1e35b688d581 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,9 +23,10 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ + class KryoSerializerSuite extends FunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", classOf[MyRegistrator].getName) @@ -207,39 +208,43 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x assert(10 + control.sum === result) } -} - -class KryoSerializerResizableOutputSuite extends FunSuite { - import org.apache.spark.SparkConf - import org.apache.spark.SparkContext - import org.apache.spark.LocalSparkContext - import org.apache.spark.SparkException - // trial and error showed this will not serialize with 1mb buffer - val x = (1 to 400000).toArray + test("kryo with nonexistent custom registrator should fail") { + import org.apache.spark.{SparkConf, SparkException} - test("kryo without resizable output buffer should fail on large array") { 
val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") - val sc = new SparkContext("local", "test", conf) - intercept[SparkException](sc.parallelize(x).collect) - LocalSparkContext.stop(sc) + conf.set("spark.kryo.registrator", "this.class.does.not.exist") + + val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance()) + assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist")) } - test("kryo with resizable output buffer should succeed on large array") { - val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") - val sc = new SparkContext("local", "test", conf) - assert(sc.parallelize(x).collect === x) - LocalSparkContext.stop(sc) + test("default class loader can be set by a different thread") { + val ser = new KryoSerializer(new SparkConf) + + // First serialize the object + val serInstance = ser.newInstance() + val bytes = serInstance.serialize(new ClassLoaderTestingObject) + + // Deserialize the object to make sure normal deserialization works + serInstance.deserialize[ClassLoaderTestingObject](bytes) + + // Set a special, broken ClassLoader and make sure we get an exception on deserialization + ser.setDefaultClassLoader(new ClassLoader() { + override def loadClass(name: String) = throw new UnsupportedOperationException + }) + intercept[UnsupportedOperationException] { + ser.newInstance().deserialize[ClassLoaderTestingObject](bytes) + } } } + +class ClassLoaderTestingObject + + object KryoTest { + case class CaseClass(i: Int, s: String) {} class ClassWithNoArgConstructor { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 8dca2ebb312f5..bcbfe8baf36ad 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -17,18 +17,23 @@ package org.apache.spark.storage +import java.io.IOException +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + import org.scalatest.{FunSuite, Matchers} -import org.scalatest.PrivateMethodTester._ import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.stubbing.Answer import org.mockito.invocation.InvocationOnMock -import org.apache.spark._ import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, ConnectionManagerId, - Message} +import org.apache.spark.network.{ConnectionManager, Message} +import org.apache.spark.executor.ShuffleReadMetrics class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -66,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -117,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => 
(blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -137,4 +142,90 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") } + test("block fetch from remote fails using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val f = future { + throw new IOException("Send failed or we received an error ACK") + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null, new ShuffleReadMetrics()) + + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (!r.isDefined) should be(true) + } + } + } + + test("block fetch from remote succeed using BasicBlockFetcherIterator") { + val blockManager = mock(classOf[BlockManager]) + val connManager = mock(classOf[ConnectionManager]) + when(blockManager.connectionManager).thenReturn(connManager) + + val blId1 = ShuffleBlockId(0,0,0) + val blId2 = ShuffleBlockId(0,1,0) + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) + val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) + val blockMessageArray = new BlockMessageArray( + Seq(blockMessage1, blockMessage2)) + + val bufferMessage = blockMessageArray.toBufferMessage + val buffer = ByteBuffer.allocate(bufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + bufferMessage.buffers.foreach{ b => + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + + val f = future { + Message.createBufferMessage(arrayBuffer) + } + when(connManager.sendMessageReliably(any(), + any())).thenReturn(f) + when(blockManager.futureExecContext).thenReturn(global) + + when(blockManager.blockManagerId).thenReturn( + BlockManagerId("test-client", "test-client", 1, 0)) + when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) + + val bmId = BlockManagerId("test-server", "test-server",1 , 0) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L))) + ) + + val iterator = new BasicBlockFetcherIterator(blockManager, + blocksByAddress, null, new ShuffleReadMetrics()) + iterator.initialize() + iterator.foreach{ + case (_, r) => { + (r.isDefined) should be(true) + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 0ac0269d7cfc1..f32ce6f9fcc7f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -24,8 +24,13 @@ import java.util.concurrent.TimeUnit import akka.actor._ import akka.pattern.ask import akka.util.Timeout +import org.apache.spark.shuffle.hash.HashShuffleManager + +import org.mockito.invocation.InvocationOnMock +import org.mockito.Matchers.any +import org.mockito.Mockito.{doAnswer, mock, spy, when} +import org.mockito.stubbing.Answer -import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ @@ -33,6 +38,7 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod +import org.apache.spark.network.{Message, ConnectionManagerId} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -56,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) + val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -66,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager( - name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, + mapOutputTracker, shuffleManager) } before { @@ -786,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker) + securityMgr, mapOutputTracker, shuffleManager) // The put should fail since a1 is not serializable. 
class UnserializableClass @@ -818,8 +825,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter val blockManager = mock(classOf[BlockManager]) val shuffleBlockManager = mock(classOf[ShuffleBlockManager]) when(shuffleBlockManager.conf).thenReturn(conf) - val diskBlockManager = new DiskBlockManager(shuffleBlockManager, - System.getProperty("java.io.tmpdir")) + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) @@ -1000,6 +1006,109 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } + test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker, shuffleManager) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + val reqBufferMessage = reqBlockMessages.toBufferMessage + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + throw new Exception + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + // Test when exception was thrown during processing block messages + var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + + assert(ackMessage.isDefined, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When Exception was thrown in " + + "BlockManagerWorker#processBlockMessage, " + + "ackMessage should have error") + + val notBufferMessage = mock(classOf[Message]) + + // Test when a non-BufferMessage was received + ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When a non-BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should be defined") + assert(ackMessage.get.hasError, "When a non-BufferMessage was passed to " + + "BlockManagerWorker#onBlockMessageReceive, " + + "ackMessage should have error") + } + + test("return ack message when no error occurred in BlockManagerWorker#onBlockMessageReceive") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker, shuffleManager) + + val worker = spy(new BlockManagerWorker(store)) + val connManagerId = mock(classOf[ConnectionManagerId]) + + // setup request block messages + val reqBlId1 = ShuffleBlockId(0,0,0) + val reqBlId2 = ShuffleBlockId(0,1,0) + val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) + val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) + val reqBlockMessages = new BlockMessageArray( + Seq(reqBlockMessage1, reqBlockMessage2)) + + val tmpBufferMessage = reqBlockMessages.toBufferMessage + val buffer = ByteBuffer.allocate(tmpBufferMessage.size) + val arrayBuffer = new ArrayBuffer[ByteBuffer] + tmpBufferMessage.buffers.foreach{ b
=> + buffer.put(b) + } + buffer.flip() + arrayBuffer += buffer + val reqBufferMessage = Message.createBufferMessage(arrayBuffer) + + // setup ack block messages + val buf1 = ByteBuffer.allocate(4) + val buf2 = ByteBuffer.allocate(4) + buf1.putInt(1) + buf1.flip() + buf2.putInt(1) + buf2.flip() + val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) + val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) + + val answer = new Answer[Option[BlockMessage]] { + override def answer(invocation: InvocationOnMock) + :Option[BlockMessage]= { + if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( + reqBlockMessage1)) { + return Some(ackBlockMessage1) + } else { + return Some(ackBlockMessage2) + } + } + } + + doAnswer(answer).when(worker).processBlockMessage(any()) + + val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) + assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should be defined") + assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + + "was executed successfully, ackMessage should not have error") + } + test("reserve/release unroll memory") { store = makeBlockManager(12000) val memoryStore = store.memoryStore diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala new file mode 100644 index 0000000000000..bbc7e1357b90d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.storage + +import org.scalatest.FunSuite +import java.io.File +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.SparkConf + +class BlockObjectWriterSuite extends FunSuite { + test("verify write metrics") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.commitAndClose() + assert(file.length() == writeMetrics.shuffleBytesWritten) + } + + test("verify write metrics on revert") { + val file = new File("somefile") + file.deleteOnExit() + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file, + new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20)) + // Metrics don't update on every write + assert(writeMetrics.shuffleBytesWritten == 0) + // After 32 writes, metrics should update + for (i <- 0 until 32) { + writer.flush() + writer.write(Long.box(i)) + } + assert(writeMetrics.shuffleBytesWritten > 0) + writer.revertPartialWritesAndClose() + assert(writeMetrics.shuffleBytesWritten == 0) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 985ac9394738c..aabaeadd7a071 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.shuffle.hash.HashShuffleManager + import scala.collection.mutable import scala.language.reflectiveCalls @@ -30,6 +32,7 @@ import org.apache.spark.SparkConf import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.executor.ShuffleWriteMetrics class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { private val testConf = new SparkConf(false) @@ -41,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // so we coerce consolidation if not already enabled. 
testConf.set("spark.shuffle.consolidateFiles", "true") - val shuffleBlockManager = new ShuffleBlockManager(null) { + private val shuffleManager = new HashShuffleManager(testConf.clone) + + val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) { override def conf = testConf.clone var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -66,7 +71,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before } override def beforeEach() { - diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) + val conf = testConf.clone + conf.set("spark.local.dir", rootDirs) + diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) shuffleBlockManager.idToSegmentMap.clear() } @@ -147,13 +154,13 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), confCopy) val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null) + securityManager, null, shuffleManager) try { val shuffleManager = store.shuffleBlockManager - val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer) + val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics) for (writer <- shuffle1.writers) { writer.write("test1") writer.write("test2") @@ -165,7 +172,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val shuffle1Segment = shuffle1.writers(0).fileSegment() shuffle1.releaseWriters(success = true) - val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf)) + val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle2.writers) { writer.write("test3") @@ -183,7 +191,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before // of block based on remaining data in file : which could mess things up when there is concurrent read // and writes happening to the same shuffle group. - val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf)) + val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf), + new ShuffleWriteMetrics) for (writer <- shuffle3.writers) { writer.write("test3") writer.write("test4") diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala new file mode 100644 index 0000000000000..dae7bf0e336de --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.io.File + +import org.apache.spark.util.Utils +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf + + +/** + * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. + */ +class LocalDirsSuite extends FunSuite { + + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { + // Regression test for SPARK-2974 + assert(!new File("/NONEXISTENT_DIR").exists()) + val conf = new SparkConf(false) + .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") + assert(new File(Utils.getLocalDir(conf)).exists()) + } + + test("SPARK_LOCAL_DIRS override also affects driver") { + // Regression test for SPARK-2975 + assert(!new File("/NONEXISTENT_DIR").exists()) + // SPARK_LOCAL_DIRS is a valid directory: + class MySparkConf extends SparkConf(false) { + override def getenv(name: String) = { + if (name == "SPARK_LOCAL_DIRS") System.getProperty("java.io.tmpdir") + else super.getenv(name) + } + + override def clone: SparkConf = { + new MySparkConf().setAll(settings) + } + } + // spark.local.dir only contains invalid directories, but that's not a problem since + // SPARK_LOCAL_DIRS will override it on both the driver and workers: + val conf = new MySparkConf().set("spark.local.dir", "/NONEXISTENT_PATH") + assert(new File(Utils.getLocalDir(conf)).exists()) + } + +} diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index cb8252515238e..147ec0bc52e39 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.{LocalSparkContext, SparkConf, Success} -import org.apache.spark.executor.{ShuffleWriteMetrics, ShuffleReadMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.remoteBytesRead = base + 1 shuffleReadMetrics.remoteBlocksFetched = base + 2 @@ -150,6 +150,9 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.executorRunTime = base + 4 taskMetrics.diskBytesSpilled = base + 5 taskMetrics.memoryBytesSpilled = base + 6 + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + taskMetrics.inputMetrics = Some(inputMetrics) + inputMetrics.bytesRead = base + 7 taskMetrics } @@ -182,6 +185,8 @@ class JobProgressListenerSuite extends 
FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 205) assert(stage0Data.memoryBytesSpilled == 112) assert(stage1Data.memoryBytesSpilled == 206) + assert(stage0Data.inputBytes == 114) + assert(stage1Data.inputBytes == 207) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get @@ -208,6 +213,8 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc assert(stage1Data.diskBytesSpilled == 610) assert(stage0Data.memoryBytesSpilled == 412) assert(stage1Data.memoryBytesSpilled == 612) + assert(stage0Data.inputBytes == 414) + assert(stage1Data.inputBytes == 614) assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 2002a817d9168..97ffb07662482 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite { sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.updateShuffleReadMetrics(sr) + t.setShuffleReadMetrics(Some(sr)) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 57dcb4ffabac1..706faed980f31 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer -import org.scalatest.FunSuite +import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ -class ExternalSorterSuite extends FunSuite with LocalSparkContext { +class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) // Make the Java serializer write a reset instruction (TC_RESET) after each object to test @@ -36,6 +36,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf } + private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort") + } + + private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = { + val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort) + assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort") + } + test("empty data stream") { val conf = new SparkConf(false) conf.set("spark.shuffle.memoryFraction", "0.001") @@ -86,28 +96,28 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), Some(ord), None) - sorter.write(elements.iterator) + sorter.insertAll(elements.iterator) 
assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( Some(agg), Some(new HashPartitioner(7)), None, None) - sorter2.write(elements.iterator) + sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), Some(ord), None) - sorter3.write(elements.iterator) + sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter4.write(elements.iterator) + sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() } @@ -118,13 +128,37 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") sc = new SparkContext("local", "test", conf) - val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(7)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll(elements) + assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled + val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) + assert(iter.next() === (0, Nil)) + assert(iter.next() === (1, List((1, 1)))) + assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList)) + assert(iter.next() === (3, Nil)) + assert(iter.next() === (4, Nil)) + assert(iter.next() === (5, List((5, 5)))) + assert(iter.next() === (6, Nil)) + sorter.stop() + } + + test("empty partitions with spilling, bypass merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(7)), None, None) - sorter.write(elements) + assertBypassedMergeSort(sorter) + sorter.insertAll(elements) assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) assert(iter.next() === (0, Nil)) @@ -286,14 +320,43 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter2) + 
sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) + assert(diskBlockManager.getAllFiles().length > 0) + assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) + sorter2.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) sorter.stop() assert(diskBlockManager.getAllBlocks().length === 0) val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter2.write((0 until 100000).iterator.map(i => (i, i))) + assertBypassedMergeSort(sorter2) + sorter2.insertAll((0 until 100000).iterator.map(i => (i, i))) assert(diskBlockManager.getAllFiles().length > 0) assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet) sorter2.stop() @@ -307,9 +370,35 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val ord = implicitly[Ordering[Int]] + + val sorter = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(3)), Some(ord), None) + assertDidNotBypassMergeSort(sorter) + intercept[SparkException] { + sorter.insertAll((0 until 100000).iterator.map(i => { + if (i == 99990) { + throw new SparkException("Intentional failure") + } + (i, i) + })) + } + assert(diskBlockManager.getAllFiles().length > 0) + sorter.stop() + assert(diskBlockManager.getAllBlocks().length === 0) + } + + test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") { + val conf = createSparkConf(true) // Load defaults, otherwise SPARK_HOME is not found + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager + val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) + assertBypassedMergeSort(sorter) intercept[SparkException] { - sorter.write((0 until 100000).iterator.map(i => { + sorter.insertAll((0 until 100000).iterator.map(i => { if (i == 99990) { throw new SparkException("Intentional failure") } @@ -365,7 +454,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 4, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 4, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet) @@ -381,7 +470,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = 
new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -397,7 +486,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -414,7 +503,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i / 2, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i / 2, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet val expected = (0 until 3).map(p => { (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet) @@ -431,7 +520,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -448,7 +537,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val ord = implicitly[Ordering[Int]] val sorter = new ExternalSorter[Int, Int, Int]( None, Some(new HashPartitioner(3)), Some(ord), None) - sorter.write((0 until 100000).iterator.map(i => (i, i))) + sorter.insertAll((0 until 100000).iterator.map(i => (i, i))) val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq val expected = (0 until 3).map(p => { (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq) @@ -495,7 +584,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap) - sorter.write(toInsert) + sorter.insertAll(toInsert) // A map of collision pairs in both directions val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap @@ -524,7 +613,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1) - sorter.write(toInsert.iterator) + sorter.insertAll(toInsert.iterator) val it = sorter.iterator var count = 0 @@ -548,7 +637,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) + sorter.insertAll((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) val it = sorter.iterator while (it.hasNext) { @@ -572,7 +661,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( Some(agg), None, None, None) - sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( + sorter.insertAll((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), ("1", null.asInstanceOf[String]), (null.asInstanceOf[String], null.asInstanceOf[String]) @@ -584,4 +673,38 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext { it.next() } } + + test("conditions for bypassing merge-sort") { + val conf = createSparkConf(false) + conf.set("spark.shuffle.memoryFraction", "0.001") + conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") + sc = new SparkContext("local", "test", conf) + + val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) + val ord = implicitly[Ordering[Int]] + + // Numbers of partitions that are above and below the default bypassMergeThreshold + val FEW_PARTITIONS = 50 + val MANY_PARTITIONS = 10000 + + // Sorters with no ordering or aggregator: should bypass unless # of partitions is high + + val sorter1 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertBypassedMergeSort(sorter1) + + val sorter2 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter2) + + // Sorters with an ordering or aggregator: should not bypass even if they have few partitions + + val sorter3 = new ExternalSorter[Int, Int, Int]( + None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None) + assertDidNotBypassMergeSort(sorter3) + + val sorter4 = new ExternalSorter[Int, Int, Int]( + Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) + assertDidNotBypassMergeSort(sorter4) + } } diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 42473629d4f15..28f26d2368254 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -35,6 +35,12 @@ RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} RC_NAME=${RC_NAME:-rc2} USER_NAME=${USER_NAME:-pwendell} +if [ -z "$JAVA_HOME" ]; then + echo "Error: JAVA_HOME is not set, cannot proceed." 
+ exit -1 +fi +JAVA_7_HOME=${JAVA_7_HOME:-$JAVA_HOME} + set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME @@ -111,12 +117,13 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" -make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & make_binary_release "hadoop2" \ - "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Phive -Phive-thriftserver -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & make_binary_release "hadoop2-without-hive" \ - "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" + "-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -Pyarn.version=2.2.0" & +wait # Copy data echo "Copying release tarballs" @@ -130,7 +137,8 @@ scp spark-* \ cd spark sbt/sbt clean cd docs -PRODUCTION=1 jekyll build +# Compile docs with Java 7 to use nicer format +JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs ssh $USER_NAME@people.apache.org \ diff --git a/dev/lint-python b/dev/lint-python new file mode 100755 index 0000000000000..4efddad839387 --- /dev/null +++ b/dev/lint-python @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" + +cd $SPARK_ROOT_DIR + +# Get pep8 at runtime so that we don't rely on it being installed on the build server. +#+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 +#+ TODOs: +#+ - Dynamically determine latest release version of pep8 and use that. +#+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) +PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" +PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" + +curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" +curl_status=$? + +if [ $curl_status -ne 0 ]; then + echo "Failed to download pep8.py from \"$PEP8_SCRIPT_REMOTE_PATH\"." + exit $curl_status +fi + + +# There is no need to write this output to a file +#+ first, but we do so so that the check status can +#+ be output before the report, like with the +#+ scalastyle and RAT checks. +python $PEP8_SCRIPT_PATH ./python > "$PEP8_REPORT_PATH" +pep8_status=${PIPESTATUS[0]} #$? + +if [ $pep8_status -ne 0 ]; then + echo "PEP 8 checks failed." + cat "$PEP8_REPORT_PATH" +else + echo "PEP 8 checks passed." 
+fi + +rm -f "$PEP8_REPORT_PATH" +rm "$PEP8_SCRIPT_PATH" + +exit $pep8_status diff --git a/.travis.yml b/dev/lint-scala old mode 100644 new mode 100755 similarity index 68% rename from .travis.yml rename to dev/lint-scala index 8ebd0d68429fc..c676dfdf4f44e --- a/.travis.yml +++ b/dev/lint-scala @@ -1,3 +1,6 @@ +#!/usr/bin/env bash + +# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -12,21 +15,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - language: scala - scala: - - "2.10.3" - jdk: - - oraclejdk7 - env: - matrix: - - TEST="scalastyle assembly/assembly" - - TEST="catalyst/test sql/test streaming/test mllib/test graphx/test bagel/test" - - TEST=hive/test - cache: - directories: - - $HOME/.m2 - - $HOME/.ivy2 - - $HOME/.sbt - script: - - "sbt ++$TRAVIS_SCALA_VERSION $TEST" +# + +SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" +SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" + +"$SCRIPT_DIR/scalastyle" diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 53df9b5a3f1d5..d48c8bde12905 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -74,8 +74,10 @@ def fail(msg): def run_cmd(cmd): if isinstance(cmd, list): + print " ".join(cmd) return subprocess.check_output(cmd) else: + print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/dev/mima b/dev/mima index 4c3e65039b160..09e4482af5f3d 100755 --- a/dev/mima +++ b/dev/mima @@ -26,7 +26,9 @@ cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" ` +export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" + ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" ret_val=$? diff --git a/dev/run-tests b/dev/run-tests index d401c90f41d7b..132f696d6447a 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -58,7 +58,7 @@ if [ -n "$AMPLAB_JENKINS" ]; then diffs=`git diff --name-only master | grep "^sql/"` if [ -n "$diffs" ]; then echo "Detected changes in SQL. Will run Hive test suite." - export _RUN_SQL_TESTS=true # exported for PySpark tests + _RUN_SQL_TESTS=true fi fi @@ -66,34 +66,54 @@ fi set -e set -o pipefail +echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" dev/check-license +echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" -dev/scalastyle +dev/lint-scala +echo "" +echo "=========================================================================" +echo "Running Python style checks" +echo "=========================================================================" +dev/lint-python + +echo "" echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +# Build Spark; we always build with Hive because the PySpark SparkSQL tests need it. 
+# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, +# etc to quit or retry. This echo is there to make it not block. +BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver " +echo -e "q\n" | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly | \ + grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + +# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled: if [ -n "$_RUN_SQL_TESTS" ]; then SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi -# echo "q" is needed because sbt on encountering a build file with failure -# (either resolution or compilation) prompts the user for input either q, r, +# echo "q" is needed because sbt on encountering a build file with failure +# (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. -echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS clean package assembly/assembly test | \ +echo -e "q\n" | sbt/sbt $SBT_MAVEN_PROFILES_ARGS test | \ grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" ./python/run-tests +echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 3076eb847b420..31506e28e05af 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -19,67 +19,148 @@ # Wrapper script that runs the Spark tests then reports QA results # to github via its API. +# Environment variables are populated by the code here: +#+ https://github.com/jenkinsci/ghprb-plugin/blob/master/src/main/java/org/jenkinsci/plugins/ghprb/GhprbTrigger.java#L139 # Go to the Spark project root directory FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" +PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" -function post_message { - message=$1 - data="{\"body\": \"$message\"}" - echo "Attempting to post to Github:" - echo "$data" +COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" +# GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( +SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" - curl -D- -u x-oauth-basic:$GITHUB_OAUTH_KEY -X POST --data "$data" -H \ - "Content-Type: application/json" \ - $COMMENTS_URL | head -n 8 +# NOTE: Jenkins will kill the whole build after 120 minutes. +# Tests are a large part of that, but not all of it. +TESTS_TIMEOUT="120m" + +function post_message () { + local message=$1 + local data="{\"body\": \"$message\"}" + local HTTP_CODE_HEADER="HTTP Response Code: " + + echo "Attempting to post to Github..." 
+ + local curl_output=$( + curl `#--dump-header -` \ + --silent \ + --user x-oauth-basic:$GITHUB_OAUTH_KEY \ + --request POST \ + --data "$data" \ + --write-out "${HTTP_CODE_HEADER}%{http_code}\n" \ + --header "Content-Type: application/json" \ + "$COMMENTS_URL" #> /dev/null #| "$FWDIR/dev/jq" .id #| head -n 8 + ) + local curl_status=${PIPESTATUS[0]} + + if [ "$curl_status" -ne 0 ]; then + echo "Failed to post message to GitHub." >&2 + echo " > curl_status: ${curl_status}" >&2 + echo " > curl_output: ${curl_output}" >&2 + echo " > data: ${data}" >&2 + # exit $curl_status + fi + + local api_response=$( + echo "${curl_output}" \ + | grep -v -e "^${HTTP_CODE_HEADER}" + ) + + local http_code=$( + echo "${curl_output}" \ + | grep -e "^${HTTP_CODE_HEADER}" \ + | sed -r -e "s/^${HTTP_CODE_HEADER}//g" + ) + + if [ -n "$http_code" ] && [ "$http_code" -ne "201" ]; then + echo " > http_code: ${http_code}." >&2 + echo " > api_response: ${api_response}" >&2 + echo " > data: ${data}" >&2 + fi + + if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then + echo " > Post successful." + fi } -start_message="QA tests have started for PR $ghprbPullId." -if [ "$sha1" == "$ghprbActualCommit" ]; then - start_message="$start_message This patch DID NOT merge cleanly! " -else - start_message="$start_message This patch merges cleanly. " -fi -start_message="$start_message
    View progress: " -start_message="$start_message${BUILD_URL}consoleFull" - -post_message "$start_message" - -./dev/run-tests -test_result="$?" - -result_message="QA results for PR $ghprbPullId:
    " - -if [ "$test_result" -eq "0" ]; then - result_message="$result_message- This patch PASSES unit tests.
    " -else - result_message="$result_message- This patch FAILED unit tests.
    " -fi - -if [ "$sha1" != "$ghprbActualCommit" ]; then - result_message="$result_message- This patch merges cleanly
    " - non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") - new_public_classes=$(git diff master $non_test_files \ - | grep -e "trait " -e "class " \ - | grep -e "{" -e "(" \ - | grep -v -e \@\@ -e private \ - | grep \+ \ - | sed "s/\+ *//" \ - | tr "\n" "~" \ - | sed "s/~/
    /g") - if [ "$new_public_classes" == "" ]; then - result_message="$result_message- This patch adds no public classes
    " +# check PR merge-ability and check for new public classes +{ + if [ "$sha1" == "$ghprbActualCommit" ]; then + merge_note=" * This patch **does not** merge cleanly!" else - result_message="$result_message- This patch adds the following public classes (experimental):
    " - result_message="$result_message$new_public_classes" + merge_note=" * This patch merges cleanly." + + non_test_files=$(git diff master --name-only | grep -v "\/test" | tr "\n" " ") + new_public_classes=$( + git diff master ${non_test_files} `# diff this patch against master and...` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) + + if [ "$new_public_classes" == "" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" + fi fi -fi -result_message="${result_message}
    For more information see test ouptut:" -result_message="${result_message}
    ${BUILD_URL}consoleFull" +} -post_message "$result_message" +# post start message +{ + start_message="\ + [QA tests have started](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." + + start_message="${start_message}\n${merge_note}" + # start_message="${start_message}\n${public_classes_note}" + + post_message "$start_message" +} + +# run tests +{ + timeout "${TESTS_TIMEOUT}" ./dev/run-tests + test_result="$?" + + if [ "$test_result" -eq "124" ]; then + fail_message="**Tests timed out** after a configured wait of \`${TESTS_TIMEOUT}\`." + post_message "$fail_message" + exit $test_result + else + if [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes** unit tests." + else + test_result_note=" * This patch **fails** unit tests." + fi + fi +} + +# post end message +{ + result_message="\ + [QA tests have finished](${BUILD_URL}consoleFull) for \ + PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." + + result_message="${result_message}\n${test_result_note}" + result_message="${result_message}\n${merge_note}" + result_message="${result_message}\n${public_classes_note}" + + post_message "$result_message" +} exit $test_result diff --git a/dev/scalastyle b/dev/scalastyle index d9f2b91a3a091..b53053a04ff42 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -30,5 +30,5 @@ if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 else - echo -e "Scalastyle checks passed.\n" + echo -e "Scalastyle checks passed." fi diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 672d0ef114f6d..4d87ab92cec5b 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -96,6 +96,15 @@ mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package {% endhighlight %} +# Building Thrift JDBC server and CLI for Spark SQL + +Spark SQL supports Thrift JDBC server and CLI. +See sql-programming-guide.md for more information about those features. +You can use those features by setting `-Phive-thriftserver` when building Spark as follows. +{% highlight bash %} +mvn -Phive-thriftserver assembly +{% endhighlight %} + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). diff --git a/docs/configuration.md b/docs/configuration.md index 5e3eb0f0871af..981170d8b49b7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -281,6 +281,24 @@ Apart from these, the following properties are also available, and may be useful overhead per reduce task, so keep it small unless you have a large amount of memory. + + spark.shuffle.manager + HASH + + Implementation to use for shuffling data. A hash-based shuffle manager is the default, but + starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more + memory-efficient in environments with small executors, such as YARN. To use that, change + this value to SORT. + + + + spark.shuffle.sort.bypassMergeThreshold + 200 + + (Advanced) In the sort-based shuffle manager, avoid merge-sorting data if there is no + map-side aggregation and there are at most this many reduce partitions. + + #### Spark UI @@ -355,10 +373,12 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - org.apache.spark.io.
    SnappyCompressionCodec + snappy - The codec used to compress internal data such as RDD partitions and shuffle outputs. - By default, Spark provides three codecs: org.apache.spark.io.LZ4CompressionCodec, + The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, + Spark provides three codecs: lz4, lzf, and snappy. You + can also use fully qualified class names to specify the codec, e.g. + org.apache.spark.io.LZ4CompressionCodec, org.apache.spark.io.LZFCompressionCodec, and org.apache.spark.io.SnappyCompressionCodec. @@ -542,7 +562,7 @@ Apart from these, the following properties are also available, and may be useful - spark.hadoop.validateOutputSpecs + spark.hadoop.validateOutputSpecs true If set to true, validates the output specification (e.g. checking if the output directory already exists) used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing @@ -550,7 +570,7 @@ Apart from these, the following properties are also available, and may be useful previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. - spark.executor.heartbeatInterval + spark.executor.heartbeatInterval 10000 Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let the driver know that the executor is still alive and update it with metrics for in-progress @@ -807,24 +827,34 @@ Apart from these, the following properties are also available, and may be useful - spark.scheduler.minRegisteredExecutorsRatio + spark.scheduler.minRegisteredResourcesRatio 0 - The minimum ratio of registered executors (registered executors / total expected executors) + The minimum ratio of registered resources (registered resources / total expected resources) + (resources are executors in yarn mode, CPU cores in standalone mode) to wait for before scheduling begins. Specified as a double between 0 and 1. - Regardless of whether the minimum ratio of executors has been reached, + Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime - spark.scheduler.maxRegisteredExecutorsWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime 30000 - Maximum amount of time to wait for executors to register before scheduling begins + Maximum amount of time to wait for resources to register before scheduling begins (in milliseconds). + + spark.localExecution.enabled + false + + Enables Spark to run certain jobs, such as first() or take() on the driver, without sending + tasks to the cluster. This can make certain jobs execute very quickly, but may require + shipping a whole partition of data to the driver. + + #### Security @@ -854,6 +884,15 @@ Apart from these, the following properties are also available, and may be useful out and giving up. + + spark.core.connection.ack.wait.timeout + 60 + + Number of seconds for the connection to wait for ack to occur before timing + out and giving up. To avoid unwilling timeout caused by long pause like GC, + you can set larger value. + + spark.ui.filters None diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 156a727026790..f5ac6d894e1eb 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,14 +12,16 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. 
You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster +launches a set of instances, which are tagged with the cluster name, +and placed into EC2 security groups. If you don't specify a security +group, the `spark-ec2` script will create security groups based on the +cluster name you request. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. +`test-slaves`. You can also specify a security group prefix to be used +in place of the cluster name. Machines in a cluster can be identified +by looking for the "Name" tag of the instance in the Amazon EC2 Console. # Before You Start diff --git a/docs/mllib-basics.md b/docs/mllib-basics.md index f9585251fafac..8752df412950a 100644 --- a/docs/mllib-basics.md +++ b/docs/mllib-basics.md @@ -9,17 +9,17 @@ displayTitle: MLlib - Basics MLlib supports local vectors and matrices stored on a single machine, as well as distributed matrices backed by one or more RDDs. -In the current implementation, local vectors and matrices are simple data models -to serve public interfaces. The underlying linear algebra operations are provided by +Local vectors and local matrices are simple data models +that serve as public interfaces. The underlying linear algebra operations are provided by [Breeze](http://www.scalanlp.org/) and [jblas](http://jblas.org/). -A training example used in supervised learning is called "labeled point" in MLlib. +A training example used in supervised learning is called a "labeled point" in MLlib. ## Local vector A local vector has integer-typed and 0-based indices and double-typed values, stored on a single machine. MLlib supports two types of local vectors: dense and sparse. A dense vector is backed by a double array representing its entry values, while a sparse vector is backed by two parallel -arrays: indices and values. For example, a vector $(1.0, 0.0, 3.0)$ can be represented in dense +arrays: indices and values. For example, a vector `(1.0, 0.0, 3.0)` can be represented in dense format as `[1.0, 0.0, 3.0]` or in sparse format as `(3, [0, 2], [1.0, 3.0])`, where `3` is the size of the vector. @@ -44,8 +44,7 @@ val sv1: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)) val sv2: Vector = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0))) {% endhighlight %} -***Note*** - +***Note:*** Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`. @@ -110,8 +109,8 @@ sv2 = sps.csc_matrix((np.array([1.0, 3.0]), np.array([0, 2]), np.array([0, 2])), A labeled point is a local vector, either dense or sparse, associated with a label/response. In MLlib, labeled points are used in supervised learning algorithms. We use a double to store a label, so we can use labeled points in both regression and classification. -For binary classification, label should be either $0$ (negative) or $1$ (positive). 
-For multiclass classification, labels should be class indices staring from zero: $0, 1, 2, \ldots$. +For binary classification, a label should be either `0` (negative) or `1` (positive). +For multiclass classification, labels should be class indices starting from zero: `0, 1, 2, ...`.
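For illustration, a minimal sketch (not part of this patch) of the labeling convention described above, using the MLlib `LabeledPoint` and `Vectors` APIs referenced in this guide; the feature values are arbitrary placeholders:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Binary classification: the label is 0.0 (negative) or 1.0 (positive).
val pos = LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0))
val neg = LabeledPoint(0.0, Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0)))

// Multiclass classification: labels are class indices starting from zero.
val classTwo = LabeledPoint(2.0, Vectors.dense(0.5, 1.5, 2.5))
{% endhighlight %}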
    @@ -172,7 +171,7 @@ neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0])) It is very common in practice to have sparse training data. MLlib supports reading training examples stored in `LIBSVM` format, which is the default format used by [`LIBSVM`](http://www.csie.ntu.edu.tw/~cjlin/libsvm/) and -[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format. Each line +[`LIBLINEAR`](http://www.csie.ntu.edu.tw/~cjlin/liblinear/). It is a text format in which each line represents a labeled sparse feature vector using the following format: ~~~ @@ -226,7 +225,7 @@ examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") ## Local matrix A local matrix has integer-typed row and column indices and double-typed values, stored on a single -machine. MLlib supports dense matrix, whose entry values are stored in a single double array in +machine. MLlib supports dense matrices, whose entry values are stored in a single double array in column major. For example, the following matrix `\[ \begin{pmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ @@ -234,7 +233,6 @@ column major. For example, the following matrix `\[ \begin{pmatrix} \end{pmatrix} \]` is stored in a one-dimensional array `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]` with the matrix size `(3, 2)`. -We are going to add sparse matrix in the next release.
    @@ -242,7 +240,7 @@ We are going to add sparse matrix in the next release. The base class of local matrices is [`Matrix`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrix), and we provide one implementation: [`DenseMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.DenseMatrix). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/scala/index.html#org.apache.spark.mllib.linalg.Matrices) to create local matrices. @@ -259,7 +257,7 @@ val dm: Matrix = Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0)) The base class of local matrices is [`Matrix`](api/java/org/apache/spark/mllib/linalg/Matrix.html), and we provide one implementation: [`DenseMatrix`](api/java/org/apache/spark/mllib/linalg/DenseMatrix.html). -Sparse matrix will be added in the next release. We recommend using the factory methods implemented +We recommend using the factory methods implemented in [`Matrices`](api/java/org/apache/spark/mllib/linalg/Matrices.html) to create local matrices. @@ -279,28 +277,30 @@ Matrix dm = Matrices.dense(3, 2, new double[] {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); A distributed matrix has long-typed row and column indices and double-typed values, stored distributively in one or more RDDs. It is very important to choose the right format to store large and distributed matrices. Converting a distributed matrix to a different format may require a -global shuffle, which is quite expensive. We implemented three types of distributed matrices in -this release and will add more types in the future. +global shuffle, which is quite expensive. Three types of distributed matrices have been implemented +so far. The basic type is called `RowMatrix`. A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, e.g., a collection of feature vectors. It is backed by an RDD of its rows, where each row is a local vector. -We assume that the number of columns is not huge for a `RowMatrix`. +We assume that the number of columns is not huge for a `RowMatrix` so that a single +local vector can be reasonably communicated to the driver and can also be stored / +operated on using a single node. An `IndexedRowMatrix` is similar to a `RowMatrix` but with row indices, -which can be used for identifying rows and joins. -A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix) format, +which can be used for identifying rows and executing joins. +A `CoordinateMatrix` is a distributed matrix stored in [coordinate list (COO)](https://en.wikipedia.org/wiki/Sparse_matrix#Coordinate_list_.28COO.29) format, backed by an RDD of its entries. ***Note*** The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size. -It is always error-prone to have non-deterministic RDDs. +In general the use of non-deterministic RDDs can lead to errors. ### RowMatrix A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD -of its rows, where each row is a local vector. This is similar to `data matrix` in the context of -multivariate statistics. Since each row is represented by a local vector, the number of columns is +of its rows, where each row is a local vector. +Since each row is represented by a local vector, the number of columns is limited by the integer range but it should be much smaller in practice.
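A minimal sketch of constructing a `RowMatrix` from an RDD of local vectors, assuming an existing `SparkContext` named `sc`; the matrix values are arbitrary:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.rdd.RDD

// Each element of the RDD is one row of the matrix.
val rows: RDD[Vector] = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0),
  Vectors.dense(3.0, 4.0),
  Vectors.dense(5.0, 6.0)))

val mat = new RowMatrix(rows)
val m = mat.numRows()  // 3
val n = mat.numCols()  // 2
{% endhighlight %}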
    @@ -344,70 +344,10 @@ long n = mat.numCols();
    -#### Multivariate summary statistics - -We provide column summary statistics for `RowMatrix`. -If the number of columns is not large, say, smaller than 3000, you can also compute -the covariance matrix as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the -number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, -which could be faster if the rows are sparse. - -
    -
    - -[`RowMatrix#computeColumnSummaryStatistics`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of -[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight scala %} -import org.apache.spark.mllib.linalg.Matrix -import org.apache.spark.mllib.linalg.distributed.RowMatrix -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary - -val mat: RowMatrix = ... // a RowMatrix - -// Compute column summary statistics. -val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() -println(summary.mean) // a dense vector containing the mean value for each column -println(summary.variance) // column-wise variance -println(summary.numNonzeros) // number of nonzeros in each column - -// Compute the covariance matrix. -val cov: Matrix = mat.computeCovariance() -{% endhighlight %} -
    - -
    - -[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of -[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), -which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the -total count. - -{% highlight java %} -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; -import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; - -RowMatrix mat = ... // a RowMatrix - -// Compute column summary statistics. -MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); -System.out.println(summary.mean()); // a dense vector containing the mean value for each column -System.out.println(summary.variance()); // column-wise variance -System.out.println(summary.numNonzeros()); // number of nonzeros in each column - -// Compute the covariance matrix. -Matrix cov = mat.computeCovariance(); -{% endhighlight %} -
    -
    - ### IndexedRowMatrix An `IndexedRowMatrix` is similar to a `RowMatrix` but with meaningful row indices. It is backed by -an RDD of indexed rows, which each row is represented by its index (long-typed) and a local vector. +an RDD of indexed rows, so that each row is represented by its index (long-typed) and a local vector.
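Similarly, a small sketch of building an `IndexedRowMatrix` and dropping its indices, again assuming `sc` is an existing `SparkContext` and toy values:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}

// Each row pairs a long index with a local vector; indices need not be contiguous.
val indexedRows = sc.parallelize(Seq(
  IndexedRow(0L, Vectors.dense(1.0, 2.0)),
  IndexedRow(5L, Vectors.dense(3.0, 4.0))))

val mat = new IndexedRowMatrix(indexedRows)
// Dropping the row indices yields a plain RowMatrix.
val rowMat = mat.toRowMatrix()
{% endhighlight %}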
    @@ -467,7 +407,7 @@ RowMatrix rowMat = mat.toRowMatrix(); A `CoordinateMatrix` is a distributed matrix backed by an RDD of its entries. Each entry is a tuple of `(i: Long, j: Long, value: Double)`, where `i` is the row index, `j` is the column index, and -`value` is the entry value. A `CoordinateMatrix` should be used only in the case when both +`value` is the entry value. A `CoordinateMatrix` should be used only when both dimensions of the matrix are huge and the matrix is very sparse.
    @@ -477,9 +417,9 @@ A [`CoordinateMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.CoordinateMatrix) can be created from an `RDD[MatrixEntry]` instance, where [`MatrixEntry`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.MatrixEntry) is a -wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. In this release, we do not provide other -computation for `CoordinateMatrix`. +wrapper over `(Long, Long, Double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight scala %} import org.apache.spark.mllib.linalg.distributed.{CoordinateMatrix, MatrixEntry} @@ -503,8 +443,9 @@ A [`CoordinateMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.html) can be created from a `JavaRDD` instance, where [`MatrixEntry`](api/java/org/apache/spark/mllib/linalg/distributed/MatrixEntry.html) is a -wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to a `IndexedRowMatrix` -with sparse rows by calling `toIndexedRowMatrix`. +wrapper over `(long, long, double)`. A `CoordinateMatrix` can be converted to an `IndexedRowMatrix` +with sparse rows by calling `toIndexedRowMatrix`. Other computations for +`CoordinateMatrix` are not currently supported. {% highlight java %} import org.apache.spark.api.java.JavaRDD; diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md new file mode 100644 index 0000000000000..719cc95767b00 --- /dev/null +++ b/docs/mllib-classification-regression.md @@ -0,0 +1,37 @@ +--- +layout: global +title: Classification and Regression - MLlib +displayTitle: MLlib - Classification and Regression +--- + +MLlib supports various methods for +[binary classification](http://en.wikipedia.org/wiki/Binary_classification), +[multiclass +classification](http://en.wikipedia.org/wiki/Multiclass_classification), and +[regression analysis](http://en.wikipedia.org/wiki/Regression_analysis). The table below outlines +the supported algorithms for each type of problem. + + + + + + + + + + + + + + + + +
    <tr><th>Problem Type</th><th>Supported Methods</th></tr>
    <tr><td>Binary Classification</td><td>linear SVMs, logistic regression, decision trees, naive Bayes</td></tr>
    <tr><td>Multiclass Classification</td><td>decision trees, naive Bayes</td></tr>
    <tr><td>Regression</td><td>linear least squares, Lasso, ridge regression, decision trees</td></tr>
    + +More details for these methods can be found here: + +* [Linear models](mllib-linear-methods.html) + * [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification) + * [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) +* [Decision trees](mllib-decision-tree.html) +* [Naive Bayes](mllib-naive-bayes.html) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 561de48910132..dfd9cd572888c 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -38,7 +38,7 @@ a given dataset, the algorithm returns the best clustering result).
    -Following code snippets can be executed in `spark-shell`. +The following code snippets can be executed in `spark-shell`. In the following example after loading and parsing data, we use the [`KMeans`](api/scala/index.html#org.apache.spark.mllib.clustering.KMeans) object to cluster the data @@ -70,7 +70,7 @@ All of MLlib's methods use Java-friendly types, so you can import and call them way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by calling `.rdd()` on your `JavaRDD` object. A standalone application example -that is equivalent to the provided example in Scala is given bellow: +that is equivalent to the provided example in Scala is given below: {% highlight java %} import org.apache.spark.api.java.*; @@ -113,14 +113,15 @@ public class KMeansExample { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
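The Scala snippet this guide refers to is not reproduced in the hunk above; a minimal sketch of the corresponding `spark-shell` session, assuming the standard `KMeans.train` API and the sample file `data/mllib/kmeans_data.txt`:

{% highlight scala %}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Parse whitespace-separated numbers into an RDD of dense vectors.
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()

// Cluster the data into two classes.
val numClusters = 2
val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)

// Evaluate the clustering by computing the within-set sum of squared errors.
val WSSSE = clusters.computeCost(parsedData)
println("Within Set Sum of Squared Errors = " + WSSSE)
{% endhighlight %}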
    -Following examples can be tested in the PySpark shell. +The following examples can be tested in the PySpark shell. In the following example after loading and parsing data, we use the KMeans object to cluster the data into two clusters. The number of desired clusters is passed to the algorithm. We then compute diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 0d28b5f7c89b3..ab10b2f01f87b 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -14,13 +14,13 @@ is commonly used for recommender systems. These techniques aim to fill in the missing entries of a user-item association matrix. MLlib currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. -In particular, we implement the [alternating least squares +MLlib uses the [alternating least squares (ALS)](http://dl.acm.org/citation.cfm?id=1608614) algorithm to learn these latent factors. The implementation in MLlib has the following parameters: * *numBlocks* is the number of blocks used to parallelize computation (set to -1 to auto-configure). -* *rank* is the number of latent factors in our model. +* *rank* is the number of latent factors in the model. * *iterations* is the number of iterations to run. * *lambda* specifies the regularization parameter in ALS. * *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for @@ -86,8 +86,8 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => println("Mean Squared Error = " + MSE) {% endhighlight %} -If the rating matrix is derived from other source of information (i.e., it is inferred from -other signals), you can use the trainImplicit method to get better results. +If the rating matrix is derived from another source of information (e.g., it is inferred from +other signals), you can use the `trainImplicit` method to get better results. {% highlight scala %} val alpha = 0.01 @@ -174,10 +174,11 @@ public class CollaborativeFiltering { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
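A compact sketch of how the parameters listed above map onto the `ALS` API; the input path and the parameter values are illustrative assumptions:

{% highlight scala %}
import org.apache.spark.mllib.recommendation.{ALS, Rating}

// Each line of the input is assumed to be "user,product,rating".
val ratings = sc.textFile("data/mllib/als/test.data").map(_.split(',') match {
  case Array(user, product, rate) => Rating(user.toInt, product.toInt, rate.toDouble)
})

val rank = 10          // number of latent factors
val numIterations = 20
val lambda = 0.01      // regularization parameter
val model = ALS.train(ratings, rank, numIterations, lambda)

// For implicit-feedback data, use trainImplicit with a confidence parameter alpha.
val alpha = 0.01
val implicitModel = ALS.trainImplicit(ratings, rank, numIterations, lambda, alpha)
{% endhighlight %}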
    @@ -219,5 +220,5 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) ## Tutorial -[AMP Camp](http://ampcamp.berkeley.edu/) provides a hands-on tutorial for -[personalized movie recommendation with MLlib](http://ampcamp.berkeley.edu/big-data-mini-course/movie-recommendation-with-mllib.html). +The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for +[personalized movie recommendation with MLlib](https://databricks-training.s3.amazonaws.com/movie-recommendation-with-mllib.html). diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 9cbd880897578..c01a92a9a1b26 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -84,8 +84,8 @@ Section 9.2.4 in [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for details). For example, for a binary classification problem with one categorical feature with three categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical -features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B -and A , B \| C where \| denotes the split. A similar heuristic is used for multiclass classification +features are ordered as A followed by C followed B or A, C, B. The two split candidates are A \| C, B +and A , C \| B where \| denotes the split. A similar heuristic is used for multiclass classification when `$2^(M-1)-1$` is greater than the number of bins -- the impurity for each categorical feature value is used for ordering. diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 8e434998c15ea..065d646496131 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -9,9 +9,9 @@ displayTitle: MLlib - Dimensionality Reduction [Dimensionality reduction](http://en.wikipedia.org/wiki/Dimensionality_reduction) is the process of reducing the number of variables under consideration. -It is used to extract latent features from raw and noisy features, +It can be used to extract latent features from raw and noisy features or compress data while maintaining the structure. -In this release, we provide preliminary support for dimensionality reduction on tall-and-skinny matrices. +MLlib provides support for dimensionality reduction on tall-and-skinny matrices. ## Singular value decomposition (SVD) @@ -30,17 +30,17 @@ where * $V$ is an orthonormal matrix, whose columns are called right singular vectors. For large matrices, usually we don't need the complete factorization but only the top singular -values and its associated singular vectors. This can save storage, and more importantly, de-noise +values and its associated singular vectors. This can save storage, de-noise and recover the low-rank structure of the matrix. -If we keep the top $k$ singular values, then the dimensions of the return will be: +If we keep the top $k$ singular values, then the dimensions of the resulting low-rank matrix will be: * `$U$`: `$m \times k$`, * `$\Sigma$`: `$k \times k$`, * `$V$`: `$n \times k$`. -In this release, we provide SVD computation to row-oriented matrices that have only a few columns, -say, less than $1000$, but many rows, which we call *tall-and-skinny*. +MLlib provides SVD functionality to row-oriented matrices that have only a few columns, +say, less than $1000$, but many rows, i.e., *tall-and-skinny* matrices.
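The call that produces the `svd` value shown in the next hunk falls outside the diff context; a minimal sketch assuming the `RowMatrix.computeSVD` API and a small toy matrix:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 0.0),
  Vectors.dense(0.0, 2.0, 0.0),
  Vectors.dense(0.0, 0.0, 3.0),
  Vectors.dense(1.0, 1.0, 1.0)))
val mat = new RowMatrix(rows)

// Keep the top 2 singular values; computeU = true also materializes U.
val svd = mat.computeSVD(2, computeU = true)
val U: RowMatrix = svd.U // left singular vectors, distributed
val s: Vector = svd.s    // singular values, local
val V: Matrix = svd.V    // right singular vectors, local
{% endhighlight %}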
    @@ -58,15 +58,10 @@ val s: Vector = svd.s // The singular values are stored in a local dense vector. val V: Matrix = svd.V // The V factor is a local dense matrix. {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
    -In order to run the following standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. - {% highlight java %} import java.util.LinkedList; @@ -104,8 +99,16 @@ public class SVD { } } {% endhighlight %} -Same code applies to `IndexedRowMatrix`. -The only difference that the `U` matrix becomes an `IndexedRowMatrix`. + +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`. + +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. +
    @@ -116,7 +119,7 @@ statistical method to find a rotation such that the first coordinate has the lar possible, and each succeeding coordinate in turn has the largest variance possible. The columns of the rotation matrix are called principal components. PCA is used widely in dimensionality reduction. -In this release, we implement PCA for tall-and-skinny matrices stored in row-oriented format. +MLlib supports PCA for tall-and-skinny matrices stored in row-oriented format.
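A short sketch of the PCA workflow described here, assuming the `RowMatrix.computePrincipalComponents` API and an arbitrary toy matrix:

{% highlight scala %}
import org.apache.spark.mllib.linalg.{Matrix, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

val rows = sc.parallelize(Seq(
  Vectors.dense(1.0, 0.0, 7.0),
  Vectors.dense(2.0, 5.0, 1.0),
  Vectors.dense(4.0, 3.0, 9.0)))
val mat = new RowMatrix(rows)

// Compute the top 2 principal components (a local matrix) and project the rows onto them.
val pc: Matrix = mat.computePrincipalComponents(2)
val projected: RowMatrix = mat.multiply(pc)
{% endhighlight %}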
    @@ -180,9 +183,10 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
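The "include *spark-mllib* in your build file" note above could look like the following in an sbt build definition; the version string is an assumption and should match the Spark release actually in use:

{% highlight scala %}
// build.sbt -- the version shown here is only an example.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.1.0"
{% endhighlight %}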
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md new file mode 100644 index 0000000000000..4b3cb715c58c7 --- /dev/null +++ b/docs/mllib-feature-extraction.md @@ -0,0 +1,73 @@ +--- +layout: global +title: Feature Extraction - MLlib +displayTitle: MLlib - Feature Extraction +--- + +* Table of contents +{:toc} + +## Word2Vec + +Word2Vec computes distributed vector representations of words. The main advantage of the distributed +representations is that similar words are close in the vector space, which makes generalization to +novel patterns easier and model estimation more robust. Distributed vector representations have been +shown to be useful in many natural language processing applications such as named entity +recognition, disambiguation, parsing, tagging and machine translation. + +### Model + +In our implementation of Word2Vec, we use the skip-gram model. The training objective of skip-gram is +to learn word vector representations that are good at predicting a word's context in the same sentence. +Mathematically, given a sequence of training words `$w_1, w_2, \dots, w_T$`, the objective of the +skip-gram model is to maximize the average log-likelihood +`\[ +\frac{1}{T} \sum_{t = 1}^{T}\sum_{j=-k}^{j=k} \log p(w_{t+j} | w_t) +\]` +where $k$ is the size of the training window. + +In the skip-gram model, every word $w$ is associated with two vectors $u_w$ and $v_w$, which are +vector representations of $w$ as word and context respectively. The probability of correctly +predicting word $w_i$ given word $w_j$ is determined by the softmax model, which is +`\[ +p(w_i | w_j ) = \frac{\exp(u_{w_i}^{\top}v_{w_j})}{\sum_{l=1}^{V} \exp(u_l^{\top}v_{w_j})} +\]` +where $V$ is the vocabulary size. + +The skip-gram model with softmax is expensive because the cost of computing $\log p(w_i | w_j)$ +is proportional to $V$, which can easily be on the order of millions. To speed up the training of Word2Vec, +we use hierarchical softmax, which reduces the complexity of computing $\log p(w_i | w_j)$ to +$O(\log(V))$. + +### Example + +The example below demonstrates how to load a text file, parse it as an RDD of `Seq[String]`, +construct a `Word2Vec` instance and then fit a `Word2VecModel` with the input data. Finally, +we display the top 40 synonyms of the specified word. To run the example, first download +the [text8](http://mattmahoney.net/dc/text8.zip) data and extract it to your preferred directory. +Here we assume the extracted file is `text8` and is in the same directory in which you run the Spark shell.
    +
    +{% highlight scala %} +import org.apache.spark._ +import org.apache.spark.rdd._ +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.feature.Word2Vec + +val input = sc.textFile("text8").map(line => line.split(" ").toSeq) + +val word2vec = new Word2Vec() + +val model = word2vec.fit(input) + +val synonyms = model.findSynonyms("china", 40) + +for((synonym, cosineSimilarity) <- synonyms) { + println(s"$synonym $cosineSimilarity") +} +{% endhighlight %} +
    +
    + +## TFIDF \ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 95ee6bc96801f..ca0a84a8c53fd 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -3,18 +3,19 @@ layout: global title: Machine Learning Library (MLlib) --- -MLlib is a Spark implementation of some common machine learning algorithms and utilities, +MLlib is Spark's scalable machine learning library consisting of common learning algorithms and utilities, including classification, regression, clustering, collaborative -filtering, dimensionality reduction, as well as underlying optimization primitives: +filtering, dimensionality reduction, as well as underlying optimization primitives, as outlined below: -* [Basics](mllib-basics.html) - * data types +* [Data types](mllib-basics.html) +* [Basic statistics](mllib-stats.html) + * random data generation + * stratified sampling * summary statistics -* Classification and regression - * [linear support vector machine (SVM)](mllib-linear-methods.html#linear-support-vector-machine-svm) - * [logistic regression](mllib-linear-methods.html#logistic-regression) - * [linear least squares, Lasso, and ridge regression](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression) - * [decision tree](mllib-decision-tree.html) + * hypothesis testing +* [Classification and regression](mllib-classification-regression.html) + * [linear models (SVMs, logistic regression, linear regression)](mllib-linear-methods.html) + * [decision trees](mllib-decision-tree.html) * [naive Bayes](mllib-naive-bayes.html) * [Collaborative filtering](mllib-collaborative-filtering.html) * alternating least squares (ALS) @@ -23,17 +24,18 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv * [Dimensionality reduction](mllib-dimensionality-reduction.html) * singular value decomposition (SVD) * principal component analysis (PCA) -* [Optimization](mllib-optimization.html) +* [Feature extraction and transformation](mllib-feature-extraction.html) +* [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) -MLlib is a new component under active development. +MLlib is under active development. The APIs marked `Experimental`/`DeveloperApi` may change in future releases, -and we will provide migration guide between releases. +and the migration guide below will explain all changes between releases. # Dependencies -MLlib uses linear algebra packages [Breeze](http://www.scalanlp.org/), which depends on +MLlib uses the linear algebra package [Breeze](http://www.scalanlp.org/), which depends on [netlib-java](https://github.com/fommil/netlib-java), and [jblas](https://github.com/mikiobraun/jblas). `netlib-java` and `jblas` depend on native Fortran routines. @@ -56,7 +58,7 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few breaking changes. If your data is sparse, please store it in a sparse format instead of dense to -take advantage of sparsity in both storage and computation. +take advantage of sparsity in both storage and computation. Details are described below.
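As a small illustration of the sparse-input advice above, a sketch assuming the sample LIBSVM file shipped with Spark and an existing `SparkContext` named `sc`:

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

// Equivalent vectors: prefer the sparse form when most entries are zero.
val dense  = Vectors.dense(1.0, 0.0, 0.0, 0.0, 3.0)
val sparse = Vectors.sparse(5, Array(0, 4), Array(1.0, 3.0))

// LIBSVM files are already sparse and load directly into an RDD[LabeledPoint].
val examples = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
{% endhighlight %}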
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 254201147edc1..9137f9dc1b692 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -33,24 +33,24 @@ the task of finding a minimizer of a convex function `$f$` that depends on a var Formally, we can write this as the optimization problem `$\min_{\wv \in\R^d} \; f(\wv)$`, where the objective function is of the form `\begin{equation} - f(\wv) := - \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) + - \lambda\, R(\wv_i) + f(\wv) := \lambda\, R(\wv) + + \frac1n \sum_{i=1}^n L(\wv;\x_i,y_i) \label{eq:regPrimal} \ . \end{equation}` Here the vectors `$\x_i\in\R^d$` are the training data examples, for `$1\le i\le n$`, and `$y_i\in\R$` are their corresponding labels, which we want to predict. We call the method *linear* if $L(\wv; \x, y)$ can be expressed as a function of $\wv^T x$ and $y$. -Several MLlib's classification and regression algorithms fall into this category, +Several of MLlib's classification and regression algorithms fall into this category, and are discussed here. The objective function `$f$` has two parts: -the loss that measures the error of the model on the training data, -and the regularizer that measures the complexity of the model. -The loss function `$L(\wv;.)$` must be a convex function in `$\wv$`. -The fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) defines the trade-off -between the two goals of small loss and small model complexity. +the regularizer that controls the complexity of the model, +and the loss that measures the error of the model on the training data. +The loss function `$L(\wv;.)$` is typically a convex function in `$\wv$`. The +fixed regularization parameter `$\lambda \ge 0$` (`regParam` in the code) +defines the trade-off between the two goals of minimizing the loss (i.e., +training error) and minimizing model complexity (i.e., to avoid overfitting). ### Loss functions @@ -80,10 +80,10 @@ methods MLlib supports: ### Regularizers -The purpose of the [regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to -encourage simple models, by punishing the complexity of the model `$\wv$`, in order to e.g. avoid -over-fitting. -We support the following regularizers in MLlib: +The purpose of the +[regularizer](http://en.wikipedia.org/wiki/Regularization_(mathematics)) is to +encourage simple models and avoid overfitting. We support the following +regularizers in MLlib: @@ -106,27 +106,28 @@ Here `$\mathrm{sign}(\wv)$` is the vector consisting of the signs (`$\pm1$`) of of `$\wv$`. L2-regularized problems are generally easier to solve than L1-regularized due to smoothness. -However, L1 regularization can help promote sparsity in weights, leading to simpler models, which is -also used for feature selection. It is not recommended to train models without any regularization, +However, L1 regularization can help promote sparsity in weights leading to smaller and more interpretable models, the latter of which can be useful for feature selection. +It is not recommended to train models without any regularization, especially when the number of training examples is small. ## Binary classification -[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) is to divide items into -two categories: positive and negative. MLlib supports two linear methods for binary classification: -linear support vector machine (SVM) and logistic regression. 
The training data set is represented -by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the mathematical -formulation, a training label $y$ is either $+1$ (positive) or $-1$ (negative), which is convenient -for the formulation. *However*, the negative label is represented by $0$ in MLlib instead of $-1$, -to be consistent with multiclass labeling. +[Binary classification](http://en.wikipedia.org/wiki/Binary_classification) +aims to divide items into two categories: positive and negative. MLlib +supports two linear methods for binary classification: linear support vector +machines (SVMs) and logistic regression. For both methods, MLlib supports +L1 and L2 regularized variants. The training data set is represented by an RDD +of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the +mathematical formulation in this guide, a training label $y$ is denoted as +either $+1$ (positive) or $-1$ (negative), which is convenient for the +formulation. *However*, the negative label is represented by $0$ in MLlib +instead of $-1$, to be consistent with multiclass labeling. -### Linear support vector machine (SVM) +### Linear support vector machines (SVMs) The [linear SVM](http://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM) -has become a standard choice for large-scale classification tasks. -The name "linear SVM" is actually ambiguous. -By "linear SVM", we mean specifically the linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the hinge loss +is a standard method for large-scale classification tasks. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss function in the formulation given by the hinge loss: + `\[ L(\wv;\x,y) := \max \{0, 1-y \wv^T \x \}. \]` @@ -134,39 +135,44 @@ By default, linear SVMs are trained with an L2 regularization. We also support alternative L1 regularization. In this case, the problem becomes a [linear program](http://en.wikipedia.org/wiki/Linear_programming). -Linear SVM algorithm outputs a SVM model, which makes predictions based on the value of $\wv^T \x$. -By the default, if $\wv^T \x \geq 0$, the outcome is positive, or negative otherwise. -However, quite often in practice, the default threshold $0$ is not a good choice. -The threshold should be determined via model evaluation. +The linear SVMs algorithm outputs an SVM model. Given a new data point, +denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$. +By the default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative +otherwise. ### Logistic regression [Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a -binary response. It is a linear method with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the logistic loss +binary response. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the logistic loss: `\[ L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)). \]` -Logistic regression algorithm outputs a logistic regression model, which makes predictions by +The logistic regression algorithm outputs a logistic regression model. Given a +new data point, denoted by $\x$, the model makes predictions by applying the logistic function `\[ \mathrm{f}(z) = \frac{1}{1 + e^{-z}} \]` where $z = \wv^T \x$. -By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or negative otherwise. 
-For the same reason mentioned above, quite often in practice, this default threshold is not a good choice. -The threshold should be determined via model evaluation. +By default, if $\mathrm{f}(\wv^T x) > 0.5$, the outcome is positive, or +negative otherwise, though unlike linear SVMs, the raw output of the logistic regression +model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability +that $\x$ is positive). ### Evaluation metrics -MLlib supports common evaluation metrics for binary classification (not available in Python). This +MLlib supports common evaluation metrics for binary classification (not available in PySpark). +This includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score), [receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic), precision-recall curve, and [area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve). -Among the metrics, area under ROC is commonly used to compare models and precision/recall/F-measure -can help determine the threshold to use. +AUC is commonly used to compare the performance of various models while +precision/recall/F-measure can help determine the appropriate threshold to use +for prediction purposes. ### Examples @@ -233,8 +239,7 @@ svmAlg.optimizer. val modelL1 = svmAlg.run(training) {% endhighlight %} -Similarly, you can use replace `SVMWithSGD` by -[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD). +[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`. @@ -318,10 +323,11 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency.
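As a rough illustration of the "similar fashion" remark above, here is a minimal sketch that swaps `SVMWithSGD` for `LogisticRegressionWithSGD` and evaluates area under the ROC curve; the `training` and `test` variable names are assumptions meant to match the split used in the example above, not part of the guide itself.

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.rdd.RDD

// Assumes `training` and `test` are RDD[LabeledPoint], e.g. from a randomSplit
// of the LIBSVM data loaded in the example above.
val numIterations = 100
val lrModel = LogisticRegressionWithSGD.train(training, numIterations)

// Clear the default threshold so that predict() returns raw scores,
// then evaluate area under the ROC curve on the test set.
lrModel.clearThreshold()
val scoreAndLabels: RDD[(Double, Double)] =
  test.map(point => (lrModel.predict(point.features), point.label))
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
println("Area under ROC = " + metrics.areaUnderROC())
{% endhighlight %}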
    @@ -354,24 +360,22 @@ print("Training Error = " + str(trainErr)) ## Linear least squares, Lasso, and ridge regression -Linear least squares is a family of linear methods with the loss function in formulation -`$\eqref{eq:regPrimal}$` given by the squared loss +Linear least squares is the most common formulation for regression problems. +It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss +function in the formulation given by the squared loss: `\[ L(\wv;\x,y) := \frac{1}{2} (\wv^T \x - y)^2. \]` -Depending on the regularization type, we call the method -[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or simply -[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) if there -is no regularization, [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) if L2 -regularization is used, and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) if L1 -regularization is used. This average loss $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$ is also +Various related regression methods are derived by using different types of regularization: +[*ordinary least squares*](http://en.wikipedia.org/wiki/Ordinary_least_squares) or +[*linear least squares*](http://en.wikipedia.org/wiki/Linear_least_squares_(mathematics)) uses + no regularization; [*ridge regression*](http://en.wikipedia.org/wiki/Ridge_regression) uses L2 +regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1 +regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error). -Note that the squared loss is sensitive to outliers. -Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually necessary in practice. - ### Examples
    @@ -379,7 +383,7 @@ Regularization or a robust alternative (e.g., $\ell_1$ regression) is usually ne
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight scala %} @@ -407,9 +411,8 @@ val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean() println("training Mean Squared Error = " + MSE) {% endhighlight %} -Similarly you can use [`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) -and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD). +[`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) and [`LassoWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) can be used in a similar fashion as `LinearRegressionWithSGD`.
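To make that "similar fashion" concrete, here is a minimal sketch; the `parsedData` RDD of `LabeledPoint` is assumed to be the one built in the linear regression example above (the variable name is an assumption).

{% highlight scala %}
import org.apache.spark.mllib.regression.{LassoWithSGD, RidgeRegressionWithSGD}

// Assumes `parsedData: RDD[LabeledPoint]` as in the linear regression example above.
val numIterations = 100
val ridgeModel = RidgeRegressionWithSGD.train(parsedData, numIterations)
val lassoModel = LassoWithSGD.train(parsedData, numIterations)

// Both models are used exactly like the LinearRegressionModel above.
val ridgePreds = parsedData.map(p => (p.label, ridgeModel.predict(p.features)))
{% endhighlight %}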
@@ -479,16 +482,17 @@ public class LinearRegression { } {% endhighlight %} -In order to run the above standalone application using Spark framework make -sure that you follow the instructions provided at section [Standalone -Applications](quick-start.html) of the quick-start guide. What is more, you -should include to your build file *spark-mllib* as a dependency. +In order to run the above standalone application, follow the instructions +provided in the [Standalone +Applications](quick-start.html#standalone-applications) section of the Spark +quick-start guide. Be sure to also include *spark-mllib* in your build file as +a dependency.
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint. The example then uses LinearRegressionWithSGD to build a simple linear model to predict label -values. We compute the Mean Squared Error at the end to evaluate +values. We compute the mean squared error at the end to evaluate [goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). {% highlight python %} @@ -514,6 +518,81 @@ print("Mean Squared Error = " + str(MSE))
+## Streaming linear regression + +When data arrives in a streaming fashion, it is useful to fit regression models online, +updating the parameters of the model as new data arrives. MLlib currently supports +streaming linear regression using ordinary least squares. The fitting is similar +to that performed offline, except fitting occurs on each batch of data, so that +the model continually updates to reflect the data from the stream. + +### Examples + +The following example demonstrates how to load training and testing data from two different +input streams of text files, parse the streams as labeled points, fit a linear regression model +online to the first stream, and make predictions on the second stream. + +
    + +
+ +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD + +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created; see the [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight scala %} + +val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse) +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) + +{% endhighlight %} + +We create our model by initializing the weights to zero. + +{% highlight scala %} + +val numFeatures = 3 +val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.zeros(numFeatures)) + +{% endhighlight %} + +Now we register the streams for training and testing and start the job. +Printing predictions alongside true labels lets us easily see the result. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
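As a small sanity check of the `(y,[x1,x2,x3])` format described above, the following sketch parses a single hypothetical line with `LabeledPoint.parse`, the same function applied to the streams.

{% highlight scala %}
import org.apache.spark.mllib.regression.LabeledPoint

// One hypothetical line from a file dropped into /training/data/dir.
val lp = LabeledPoint.parse("(1.0,[0.5,1.2,3.4])")
println(lp.label)    // 1.0
println(lp.features) // [0.5,1.2,3.4]
{% endhighlight %}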
    + +
    + + ## Implementation (developer) Behind the scene, MLlib implements a simple distributed version of stochastic gradient descent diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index b1650c83c98b9..7f9d4c6563944 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -4,23 +4,24 @@ title: Naive Bayes - MLlib displayTitle: MLlib - Naive Bayes --- -Naive Bayes is a simple multiclass classification algorithm with the assumption of independence -between every pair of features. Naive Bayes can be trained very efficiently. Within a single pass to -the training data, it computes the conditional probability distribution of each feature given label, -and then it applies Bayes' theorem to compute the conditional probability distribution of label -given an observation and use it for prediction. For more details, please visit the Wikipedia page -[Naive Bayes classifier](http://en.wikipedia.org/wiki/Naive_Bayes_classifier). - -In MLlib, we implemented multinomial naive Bayes, which is typically used for document -classification. Within that context, each observation is a document, each feature represents a term, -whose value is the frequency of the term. For its formulation, please visit the Wikipedia page -[Multinomial Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes) -or the section -[Naive Bayes text classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) -from the book Introduction to Information -Retrieval. [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) is a simple +multiclass classification algorithm with the assumption of independence between +every pair of features. Naive Bayes can be trained very efficiently. Within a +single pass to the training data, it computes the conditional probability +distribution of each feature given label, and then it applies Bayes' theorem to +compute the conditional probability distribution of label given an observation +and use it for prediction. + +MLlib supports [multinomial naive +Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes), +which is typically used for [document +classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html). +Within that context, each observation is a document and each +feature represents a term whose value is the frequency of the term. +Feature values must be nonnegative to represent term frequencies. +[Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature -vectors are usually sparse. Please supply sparse vectors as input to take advantage of +vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of sparsity. Since the training data is only used once, it is not necessary to cache it. 
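As a minimal sketch of the multinomial naive Bayes API just described (the guide's own examples follow under the next heading), assuming an existing `SparkContext` named `sc` and toy term-frequency vectors:

{% highlight scala %}
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Toy term-frequency vectors; feature values must be nonnegative.
val training = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 2.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 3.0, 0.0))))

// lambda is the additive-smoothing parameter described above.
val model = NaiveBayes.train(training, lambda = 1.0)
println(model.predict(Vectors.dense(0.0, 1.0, 0.0)))
{% endhighlight %}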
## Examples diff --git a/docs/mllib-stats.md b/docs/mllib-stats.md new file mode 100644 index 0000000000000..f25dca746ba3a --- /dev/null +++ b/docs/mllib-stats.md @@ -0,0 +1,167 @@ +--- +layout: global +title: Statistics Functionality - MLlib +displayTitle: MLlib - Statistics Functionality +--- + +* Table of contents +{:toc} + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + +## Random data generation + +Random data generation is useful for randomized algorithms, prototyping, and performance testing. +MLlib supports generating random RDDs with i.i.d. values drawn from a given distribution: +uniform, standard normal, or Poisson. + +
    +
+[`RandomRDDs`](api/scala/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follow the standard normal +distribution `N(0, 1)`, and then maps it to `N(1, 4)`. + +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.random.RandomRDDs._ + +val sc: SparkContext = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +val u = normalRDD(sc, 1000000L, 10) +// Apply a transform to get a random double RDD following `N(1, 4)`. +val v = u.map(x => 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
+[`RandomRDDs`](api/java/index.html#org.apache.spark.mllib.random.RandomRDDs) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follow the standard normal +distribution `N(0, 1)`, and then maps it to `N(1, 4)`. + +{% highlight java %} +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.DoubleFunction; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +JavaSparkContext jsc = ... + +// Generate a random double RDD that contains 1 million i.i.d. values drawn from the +// standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +JavaDoubleRDD u = normalJavaRDD(jsc, 1000000L, 10); +// Apply a transform to get a random double RDD following `N(1, 4)`. +JavaDoubleRDD v = u.mapToDouble( + new DoubleFunction<Double>() { + public double call(Double x) { + return 1.0 + 2.0 * x; + } + }); +{% endhighlight %} +
    + +
+[`RandomRDDs`](api/python/pyspark.mllib.random.RandomRDDs-class.html) provides factory +methods to generate random double RDDs or vector RDDs. +The following example generates a random double RDD, whose values follow the standard normal +distribution `N(0, 1)`, and then maps it to `N(1, 4)`. + +{% highlight python %} +from pyspark.mllib.random import RandomRDDs + +sc = ... # SparkContext + +# Generate a random double RDD that contains 1 million i.i.d. values drawn from the +# standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. +u = RandomRDDs.normalRDD(sc, 1000000L, 10) +# Apply a transform to get a random double RDD following `N(1, 4)`. +v = u.map(lambda x: 1.0 + 2.0 * x) +{% endhighlight %} +
    + +
    + +## Stratified Sampling + +## Summary Statistics + +### Multivariate summary statistics + +We provide column summary statistics for `RowMatrix` (note: this functionality is not currently supported in `IndexedRowMatrix` or `CoordinateMatrix`). +If the number of columns is not large, e.g., on the order of thousands, then the +covariance matrix can also be computed as a local matrix, which requires $\mathcal{O}(n^2)$ storage where $n$ is the +number of columns. The total CPU time is $\mathcal{O}(m n^2)$, where $m$ is the number of rows, +and is faster if the rows are sparse. + +
    +
    + +[`computeColumnSummaryStatistics()`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.RowMatrix) returns an instance of +[`MultivariateStatisticalSummary`](api/scala/index.html#org.apache.spark.mllib.stat.MultivariateStatisticalSummary), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight scala %} +import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +val mat: RowMatrix = ... // a RowMatrix + +// Compute column summary statistics. +val summary: MultivariateStatisticalSummary = mat.computeColumnSummaryStatistics() +println(summary.mean) // a dense vector containing the mean value for each column +println(summary.variance) // column-wise variance +println(summary.numNonzeros) // number of nonzeros in each column + +// Compute the covariance matrix. +val cov: Matrix = mat.computeCovariance() +{% endhighlight %} +
    + +
    + +[`RowMatrix#computeColumnSummaryStatistics`](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html#computeColumnSummaryStatistics()) returns an instance of +[`MultivariateStatisticalSummary`](api/java/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.html), +which contains the column-wise max, min, mean, variance, and number of nonzeros, as well as the +total count. + +{% highlight java %} +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary; + +RowMatrix mat = ... // a RowMatrix + +// Compute column summary statistics. +MultivariateStatisticalSummary summary = mat.computeColumnSummaryStatistics(); +System.out.println(summary.mean()); // a dense vector containing the mean value for each column +System.out.println(summary.variance()); // column-wise variance +System.out.println(summary.numNonzeros()); // number of nonzeros in each column + +// Compute the covariance matrix. +Matrix cov = mat.computeCovariance(); +{% endhighlight %} +
    +
    + + +## Hypothesis Testing diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index cd6543945c385..34accade36ea9 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -605,6 +605,11 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script comes with Hive. +To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, +users can set the `spark.sql.thriftserver.scheduler.pool` variable: + + SET spark.sql.thriftserver.scheduler.pool=accounting; + ### Migration Guide for Shark Users #### Reducer number diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index 1e045a3dd0ca9..27cd085782f66 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -186,7 +186,7 @@ JavaDStream words = lines.flatMap(new FlatMapFunction() ... {% endhighlight %} -The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/streaming/examples/JavaCustomReceiver.java). +The full source code is in the example [JavaCustomReceiver.java](https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java). @@ -215,7 +215,7 @@ And a new input stream can be created with this custom actor as val lines = ssc.actorStream[String](Props(new CustomActor()), "CustomReceiver") {% endhighlight %} -See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala) +See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) for an end-to-end example. diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md index 801c905c88df8..16ad3222105a2 100644 --- a/docs/streaming-kinesis.md +++ b/docs/streaming-kinesis.md @@ -3,56 +3,57 @@ layout: global title: Spark Streaming Kinesis Receiver --- -### Kinesis -Build notes: -
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • -
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • -
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
• +## Kinesis +### Design
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • +
• The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, and checkpointing through the concepts of Workers, Checkpoints, and Shard Leases.
  • +
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • +
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • +
• Horizontal scaling is achieved by autoscaling additional KinesisReceivers (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • -Kinesis examples notes: -
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • -
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • -
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • -
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • +### Build +
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • +
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • +
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • +
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • -Deployment and runtime notes: -
  • A single KinesisReceiver can process many shards of a stream.
  • -
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream.
  • -
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • -
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    - 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    - 2) Java System Properties - aws.accessKeyId and aws.secretKey
    - 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    - 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    -
  • -
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    - http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, -retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • -
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). -Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, -it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
• +### Example
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • +
• You need to set up a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • +
• While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, it instead demonstrates unioning multiple KinesisReceivers into a single DStream (see the sketch after this list). Note that this can be a bit confusing in local mode.
  • +
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • +
• The example has been configured to immediately replicate incoming stream data to another node by using StorageLevel.MEMORY_AND_DISK_2. +
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • +
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • +
• The example uses the InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • +
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
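The bullets above describe the example's receiver-per-shard design. The following is a rough sketch of that pattern, assuming a `StreamingContext` named `ssc`, a stream named `myKinesisStream` with `numShards` shards, and the `KinesisUtils.createStream` factory from the **spark-streaming-kinesis-asl** module; the exact parameter list may differ between Spark versions, so treat this as an outline rather than the example's actual code.

{% highlight scala %}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Milliseconds
import org.apache.spark.streaming.kinesis.KinesisUtils

// Assumed values; in the bundled example these come from the command line.
val endpointUrl = "https://kinesis.us-east-1.amazonaws.com"
val checkpointInterval = Milliseconds(2000)

// One KinesisReceiver per shard, unioned into a single DStream of raw byte arrays,
// replicated with MEMORY_AND_DISK_2 as noted above.
val streams = (0 until numShards).map { _ =>
  KinesisUtils.createStream(ssc, "myKinesisStream", endpointUrl,
    checkpointInterval, InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
}
val unionStream = ssc.union(streams)
{% endhighlight %}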
  • -Failure recovery notes: -
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    - 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    - 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    - 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
+### Deployment and Runtime
  • A Kinesis application name must be unique for a given account and region.
  • +
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • +
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • +
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • +
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • +
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    +1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    +2) Java System Properties - aws.accessKeyId and aws.secretKey
    +3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    +4) Instance profile credentials - delivered through the Amazon EC2 metadata service
  • -
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • -
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) -or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data -depending on the checkpoint frequency.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
• + +### Fault-Tolerance
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • +
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • +
• If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • Record processing should be idempotent when possible.
  • -
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • -
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • +
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • +
• If possible, the KinesisReceiver should be shut down cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors and avoid duplicate record processing (a minimal sketch follows this list).
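A minimal sketch of the two points above, assuming a `StreamingContext` named `ssc`; the checkpoint directory shown is hypothetical.

{% highlight scala %}
// Only needed if stateful or window-based operations (e.g., updateStateByKey,
// reduceByWindow) are added; enables Spark RDD/metadata checkpointing.
ssc.checkpoint("hdfs://namenode:8020/checkpoints/kinesis-wordcount")

// Shut down gracefully so in-flight batches finish and receivers can perform
// a final Kinesis checkpoint before exiting.
ssc.stop(stopSparkContext = true, stopGracefully = true)
{% endhighlight %}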
  • \ No newline at end of file diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 0c2f85a3868f4..3a8c816cfffa1 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -124,7 +124,7 @@ def parse_args(): help="The SSH user you want to connect as (default: root)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") + help="When destroying a cluster, delete the security groups that were created.") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -138,7 +138,9 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - + parser.add_option( + "--security-group-prefix", type="string", default=None, + help="Use this prefix for the security group rather than the cluster name.") (opts, args) = parser.parse_args() if len(args) != 2: @@ -285,8 +287,12 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") + if opts.security_group_prefix is None: + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") + else: + master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") + slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) master_group.authorize(src_group=slave_group) @@ -310,12 +316,11 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') - # Check if instances are already running in our groups + # Check if instances are already running with the cluster name existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances running in " + - "group %s or %s" % (master_group.name, slave_group.name)) + print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) sys.exit(1) # Figure out Spark AMI @@ -371,9 +376,13 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] + outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) + if i in id_to_req: + if id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) + else: + outstanding_request_ids.append(i) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -382,8 +391,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves) + print "%d of %d slaves granted, waiting longer for request ids including %s" % ( + len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) except: print "Canceling spot 
instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -440,14 +449,29 @@ def launch_cluster(conn, opts, cluster_name): print "Launched master in %s, regid = %s" % (zone, master_res.id) # Give the instances descriptive names + # TODO: Add retry logic for tagging with name since it's used to identify a cluster. for master in master_nodes: - master.add_tag( - key='Name', - value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) + name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) + for i in range(0, 5): + try: + master.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) + + for slave in slave_nodes: - slave.add_tag( - key='Name', - value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) + name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) + for i in range(0, 5): + try: + slave.add_tag(key='Name', value=name) + except: + print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) + if (i == 5): + raise "Error - failed max attempts to add name tag" + time.sleep(5) # Return all the instances return (master_nodes, slave_nodes) @@ -463,10 +487,10 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - group_names = [g.name for g in inst.groups] - if group_names == [cluster_name + "-master"]: + name = inst.tags.get(u'Name', "") + if name.startswith(cluster_name + "-master"): master_nodes.append(inst) - elif group_names == [cluster_name + "-slaves"]: + elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) @@ -474,7 +498,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) @@ -816,7 +840,10 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." - group_names = [cluster_name + "-master", cluster_name + "-slaves"] + if opts.security_group_prefix is None: + group_names = [cluster_name + "-master", cluster_name + "-slaves"] + else: + group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] attempt = 1 while attempt <= 3: diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index c862650b0aa1d..5b1fa4d997eeb 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -97,3 +97,5 @@ def update(i, vec, mat, ratings): error = rmse(R, ms, us) print "Iteration %d:" % i print "\nRMSE: %5.4f\n" % error + + sc.stop() diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py new file mode 100644 index 0000000000000..e902ae29753c0 --- /dev/null +++ b/examples/src/main/python/avro_inputformat.py @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.avro in local Spark distro: + +$ cd $SPARK_HOME +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} + +To read name and favorite_color fields only, specify the following reader schema: + +$ cat examples/src/main/resources/user.avsc +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} + +$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +> examples/src/main/resources/users.avro examples/src/main/resources/user.avsc +{u'favorite_color': None, u'name': u'Alyssa'} +{u'favorite_color': u'red', u'name': u'Ben'} +""" +if __name__ == "__main__": + if len(sys.argv) != 2 and len(sys.argv) != 3: + print >> sys.stderr, """ + Usage: avro_inputformat [reader_schema_file] + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. 
+ """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="AvroKeyInputFormat") + + conf = None + if len(sys.argv) == 3: + schema_rdd = sc.textFile(sys.argv[2], 1).collect() + conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + + avro_rdd = sc.newAPIHadoopFile(path, + "org.apache.avro.mapreduce.AvroKeyInputFormat", + "org.apache.avro.mapred.AvroKey", + "org.apache.hadoop.io.NullWritable", + keyConverter="org.apache.spark.examples.pythonconverters.AvroWrapperToJavaConverter", + conf=conf) + output = avro_rdd.map(lambda x: x[0]).collect() + for k in output: + print k diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index 39fa6b0d22ef5..e4a897f61e39d 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -77,3 +77,5 @@ output = cass_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index 1dfbf98604425..836c35b5c6794 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -81,3 +81,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", valueConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLValueConverter") + + sc.stop() diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index c9fa8e171c2a1..befacee0dea56 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -71,3 +71,5 @@ output = hbase_rdd.collect() for (k, v) in output: print (k, v) + + sc.stop() diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index 5e11548fd13f7..49bbc5aebdb0b 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -63,3 +63,5 @@ conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") + + sc.stop() diff --git a/examples/src/main/python/kmeans.py b/examples/src/main/python/kmeans.py index 036bdf4c4f999..86ef6f32c84e8 100755 --- a/examples/src/main/python/kmeans.py +++ b/examples/src/main/python/kmeans.py @@ -77,3 +77,5 @@ def closestPoint(p, centers): kPoints[x] = y print "Final centers: " + str(kPoints) + + sc.stop() diff --git a/examples/src/main/python/logistic_regression.py b/examples/src/main/python/logistic_regression.py index 8456b272f9c05..3aa56b0528168 100755 --- a/examples/src/main/python/logistic_regression.py +++ b/examples/src/main/python/logistic_regression.py @@ -80,3 +80,5 @@ def add(x, y): w -= points.map(lambda m: gradient(m, w)).reduce(add) print "Final w: " + str(w) + + sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py new file mode 100755 index 0000000000000..6b16a56e44af7 --- /dev/null +++ b/examples/src/main/python/mllib/correlations.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Correlations using MLlib. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1,2]: + print >> sys.stderr, "Usage: correlations ()" + exit(-1) + sc = SparkContext(appName="PythonCorrelations") + if len(sys.argv) == 2: + filepath = sys.argv[1] + else: + filepath = 'data/mllib/sample_linear_regression_data.txt' + corrType = 'pearson' + + points = MLUtils.loadLibSVMFile(sc, filepath)\ + .map(lambda lp: LabeledPoint(lp.label, lp.features.toArray())) + + print + print 'Summary of data file: ' + filepath + print '%d data points' % points.count() + + # Statistics (correlations) + print + print 'Correlation (%s) between label and each feature' % corrType + print 'Feature\tCorrelation' + numFeatures = points.take(1)[0].features.size + labelRDD = points.map(lambda lp: lp.label) + for i in range(numFeatures): + featureRDD = points.map(lambda lp: lp.features[i]) + corr = Statistics.corr(labelRDD, featureRDD, corrType) + print '%d\t%g' % (i, corr) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index 8efadb5223f56..6e4a4a0cb6be0 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -17,6 +17,8 @@ """ Decision tree classification and regression using MLlib. + +This example requires NumPy (http://www.numpy.org/). """ import numpy, os, sys @@ -117,6 +119,7 @@ def usage(): if len(sys.argv) == 2: dataPath = sys.argv[1] if not os.path.isfile(dataPath): + sc.stop() usage() points = MLUtils.loadLibSVMFile(sc, dataPath) @@ -124,10 +127,14 @@ def usage(): (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + categoricalFeaturesInfo={} # no categorical features + model = DecisionTree.trainClassifier(reindexedData, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. 
print "Trained DecisionTree for classification:" print " Model numNodes: %d\n" % model.numNodes() print " Model depth: %d\n" % model.depth() print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) print model + + sc.stop() diff --git a/examples/src/main/python/mllib/kmeans.py b/examples/src/main/python/mllib/kmeans.py index b308132c9aeeb..2eeb1abeeb12b 100755 --- a/examples/src/main/python/mllib/kmeans.py +++ b/examples/src/main/python/mllib/kmeans.py @@ -42,3 +42,4 @@ def parseVector(line): k = int(sys.argv[2]) model = KMeans.train(data, k) print "Final centers: " + str(model.clusterCenters) + sc.stop() diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 9d547ff77c984..8cae27fc4a52d 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -50,3 +50,4 @@ def parsePoint(line): model = LogisticRegressionWithSGD.train(points, iterations) print "Final weights: " + str(model.weights) print "Final intercept: " + str(model.intercept) + sc.stop() diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py new file mode 100755 index 0000000000000..b388d8d83fb86 --- /dev/null +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly generated RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.random import RandomRDDs + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: random_rdd_generation" + exit(-1) + + sc = SparkContext(appName="PythonRandomRDDGeneration") + + numExamples = 10000 # number of examples to generate + fraction = 0.1 # fraction of data to sample + + # Example: RandomRDDs.normalRDD + normalRDD = RandomRDDs.normalRDD(sc, numExamples) + print 'Generated RDD of %d examples sampled from the standard normal distribution'\ + % normalRDD.count() + print ' First 5 samples:' + for sample in normalRDD.take(5): + print ' ' + str(sample) + print + + # Example: RandomRDDs.normalVectorRDD + normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + print 'Generated RDD of %d examples of length-2 vectors.' 
% normalVectorRDD.count() + print ' First 5 samples:' + for sample in normalVectorRDD.take(5): + print ' ' + str(sample) + print + + sc.stop() diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py new file mode 100755 index 0000000000000..ec64a5978c672 --- /dev/null +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Randomly sampled RDDs. +""" + +import sys + +from pyspark import SparkContext +from pyspark.mllib.util import MLUtils + + +if __name__ == "__main__": + if len(sys.argv) not in [1, 2]: + print >> sys.stderr, "Usage: sampled_rdds " + exit(-1) + if len(sys.argv) == 2: + datapath = sys.argv[1] + else: + datapath = 'data/mllib/sample_binary_classification_data.txt' + + sc = SparkContext(appName="PythonSampledRDDs") + + fraction = 0.1 # fraction of data to sample + + examples = MLUtils.loadLibSVMFile(sc, datapath) + numExamples = examples.count() + if numExamples == 0: + print >> sys.stderr, "Error: Data file had no samples to load." + exit(1) + print 'Loaded data with %d examples from file: %s' % (numExamples, datapath) + + # Example: RDD.sample() and RDD.takeSample() + expectedSampleSize = int(numExamples * fraction) + print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ + % (fraction, expectedSampleSize) + sampledRDD = examples.sample(withReplacement = True, fraction = fraction) + print ' RDD.sample(): sample has %d examples' % sampledRDD.count() + sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize) + print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) + + print + + # Example: RDD.sampleByKey() + keyedRDD = examples.map(lambda lp: (int(lp.label), lp.features)) + print ' Keyed data using label (Int) as key ==> Orig' + # Count examples per label in original data. + keyCountsA = keyedRDD.countByKey() + + # Subsample, and count examples per label in sampled data. + fractions = {} + for k in keyCountsA.keys(): + fractions[k] = fraction + sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions) + keyCountsB = sampledByKeyRDD.countByKey() + sizeB = sum(keyCountsB.values()) + print ' Sampled %d examples using approximate stratified sampling (by label). 
==> Sample' \ + % sizeB + + # Compare samples + print ' \tFractions of examples with key' + print 'Key\tOrig\tSample' + for k in sorted(keyCountsA.keys()): + fracA = keyCountsA[k] / float(numExamples) + if sizeB != 0: + fracB = keyCountsB.get(k, 0) / float(sizeB) + else: + fracB = 0 + print '%d\t%g\t%g' % (k, fracA, fracB) + + sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index 0b96343158d44..b539c4128cdcc 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -68,3 +68,5 @@ def parseNeighbors(urls): # Collects all URL ranks and dump them to console. for (link, rank) in ranks.collect(): print "%s has rank: %s." % (link, rank) + + sc.stop() diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index 21d94a2cd4b64..fc37459dc74aa 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -37,3 +37,5 @@ def f(_): count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) + + sc.stop() diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 41d00c1b79133..bb686f17518a0 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -34,3 +34,5 @@ output = sortedCount.collect() for (num, unitcount) in output: print num + + sc.stop() diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py index 8698369b13d84..bf331b542c438 100755 --- a/examples/src/main/python/transitive_closure.py +++ b/examples/src/main/python/transitive_closure.py @@ -64,3 +64,5 @@ def generateGraph(): break print "TC has %i edges" % tc.count() + + sc.stop() diff --git a/examples/src/main/python/wordcount.py b/examples/src/main/python/wordcount.py index dcc095fdd0ed9..ae6cd13b83d92 100755 --- a/examples/src/main/python/wordcount.py +++ b/examples/src/main/python/wordcount.py @@ -33,3 +33,5 @@ output = counts.collect() for (word, count) in output: print "%s: %i" % (word, count) + + sc.stop() diff --git a/examples/src/main/resources/user.avsc b/examples/src/main/resources/user.avsc new file mode 100644 index 0000000000000..4995357ab3736 --- /dev/null +++ b/examples/src/main/resources/user.avsc @@ -0,0 +1,8 @@ +{"namespace": "example.avro", + "type": "record", + "name": "User", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "favorite_color", "type": ["string", "null"]} + ] +} diff --git a/examples/src/main/resources/users.avro b/examples/src/main/resources/users.avro new file mode 100644 index 0000000000000..27c526ab114b2 Binary files /dev/null and b/examples/src/main/resources/users.avro differ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 56b02b65d8724..a6f78d2441db1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -21,7 +21,7 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.classification.{LogisticRegressionWithSGD, SVMWithSGD} +import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.util.MLUtils import 
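The new user.avsc resource declares a User record with a required name and an optional favorite_color (a union of string and null). A small sketch of how such a schema could be loaded and instantiated with the plain Avro API; the file path and object name here are assumptions, not part of this patch:

    import java.io.File

    import org.apache.avro.Schema
    import org.apache.avro.generic.GenericData

    object UserSchemaSketch {
      def main(args: Array[String]): Unit = {
        // Path assumes the example is run from the repository root.
        val schema = new Schema.Parser().parse(new File("examples/src/main/resources/user.avsc"))
        val user = new GenericData.Record(schema)
        user.put("name", "Alyssa")
        user.put("favorite_color", null)  // allowed because the field's type is ["string", "null"]
        println(user)                     // prints the record as JSON-like text
      }
    }
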
org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} @@ -66,7 +66,8 @@ object BinaryClassification { .text("number of iterations") .action((x, c) => c.copy(numIterations = x)) opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") + .text("initial step size (ignored by logistic regression), " + + s"default: ${defaultParams.stepSize}") .action((x, c) => c.copy(stepSize = x)) opt[String]("algorithm") .text(s"algorithm (${Algorithm.values.mkString(",")}), " + @@ -125,10 +126,9 @@ object BinaryClassification { val model = params.algorithm match { case LR => - val algorithm = new LogisticRegressionWithSGD() + val algorithm = new LogisticRegressionWithLBFGS() algorithm.optimizer .setNumIterations(params.numIterations) - .setStepSize(params.stepSize) .setUpdater(updater) .setRegParam(params.regParam) algorithm.run(training).clearThreshold() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala new file mode 100644 index 0000000000000..d6b2fe430e5a4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.Correlations + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
+ */ +object Correlations { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("Correlations") { + head("Correlations: an example app for computing correlations") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.Correlations \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"Correlations with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Calculate label -- feature correlations + val labelRDD = examples.map(_.label) + val numFeatures = examples.take(1)(0).features.size + val corrType = "pearson" + println() + println(s"Correlation ($corrType) between label and each feature") + println(s"Feature\tCorrelation") + var feature = 0 + while (feature < numFeatures) { + val featureRDD = examples.map(_.features(feature)) + val corr = Statistics.corr(labelRDD, featureRDD) + println(s"$feature\t$corr") + feature += 1 + } + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala new file mode 100644 index 0000000000000..4532512c01f84 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkConf, SparkContext} + + +/** + * An example app for summarizing multivariate data from a file. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.MultivariateSummarizer + * }}} + * By default, this loads a synthetic dataset from `data/mllib/sample_linear_regression_data.txt`. + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
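The Correlations app above computes a Pearson correlation per feature against the label column. For reference, a minimal sketch of the underlying two-RDD call, Statistics.corr(x, y, method); the data and object name are made up for illustration:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.stat.Statistics

    object CorrSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("CorrSketch"))
        // y is a linear function of x, so the Pearson correlation should come out as 1.0.
        val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
        val y = x.map(v => 2.0 * v + 1.0)
        println("pearson corr = " + Statistics.corr(x, y, "pearson"))
        sc.stop()
      }
    }
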
+ */ +object MultivariateSummarizer { + + case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + + def main(args: Array[String]) { + + val defaultParams = Params() + + val parser = new OptionParser[Params]("MultivariateSummarizer") { + head("MultivariateSummarizer: an example app for MultivariateOnlineSummarizer") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.MultivariateSummarizer \ + | examples/target/scala-*/spark-examples-*.jar \ + | --input data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"MultivariateSummarizer with $params") + val sc = new SparkContext(conf) + + val examples = MLUtils.loadLibSVMFile(sc, params.input).cache() + + println(s"Summary of data file: ${params.input}") + println(s"${examples.count()} data points") + + // Summarize labels + val labelSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(Vectors.dense(lp.label)), + (sum1, sum2) => sum1.merge(sum2)) + + // Summarize features + val featureSummary = examples.aggregate(new MultivariateOnlineSummarizer())( + (summary, lp) => summary.add(lp.features), + (sum1, sum2) => sum1.merge(sum2)) + + println() + println(s"Summary statistics") + println(s"\tLabel\tFeatures") + println(s"mean\t${labelSummary.mean(0)}\t${featureSummary.mean.toArray.mkString("\t")}") + println(s"var\t${labelSummary.variance(0)}\t${featureSummary.variance.toArray.mkString("\t")}") + println( + s"nnz\t${labelSummary.numNonzeros(0)}\t${featureSummary.numNonzeros.toArray.mkString("\t")}") + println(s"max\t${labelSummary.max(0)}\t${featureSummary.max.toArray.mkString("\t")}") + println(s"min\t${labelSummary.min(0)}\t${featureSummary.min.toArray.mkString("\t")}") + println() + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala new file mode 100644 index 0000000000000..924b586e3af99 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
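MultivariateSummarizer above aggregates a MultivariateOnlineSummarizer by hand to demonstrate the add/merge API. When only the column statistics are needed, the same numbers are available through Statistics.colStats; a short sketch with hypothetical data and object name:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.Statistics

    object ColStatsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("ColStatsSketch"))
        val vectors = sc.parallelize(Seq(
          Vectors.dense(1.0, 10.0),
          Vectors.dense(2.0, 20.0),
          Vectors.dense(3.0, 30.0)))
        val summary = Statistics.colStats(vectors)  // column-wise mean, variance, nnz, max, min
        println(s"mean = ${summary.mean}, variance = ${summary.variance}")
        sc.stop()
      }
    }
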
+ */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.random.RandomRDDs +import org.apache.spark.rdd.RDD + +import org.apache.spark.{SparkConf, SparkContext} + +/** + * An example app for randomly generated RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.RandomRDDGeneration + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object RandomRDDGeneration { + + def main(args: Array[String]) { + + val conf = new SparkConf().setAppName(s"RandomRDDGeneration") + val sc = new SparkContext(conf) + + val numExamples = 10000 // number of examples to generate + val fraction = 0.1 // fraction of data to sample + + // Example: RandomRDDs.normalRDD + val normalRDD: RDD[Double] = RandomRDDs.normalRDD(sc, numExamples) + println(s"Generated RDD of ${normalRDD.count()}" + + " examples sampled from the standard normal distribution") + println(" First 5 samples:") + normalRDD.take(5).foreach( x => println(s" $x") ) + + // Example: RandomRDDs.normalVectorRDD + val normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + println(s"Generated RDD of ${normalVectorRDD.count()} examples of length-2 vectors.") + println(" First 5 samples:") + normalVectorRDD.take(5).foreach( x => println(s" $x") ) + + println() + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala new file mode 100644 index 0000000000000..f01b8266e3fe3 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.util.MLUtils +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ + +/** + * An example app for randomly generated and sampled RDDs. Run with + * {{{ + * bin/run-example org.apache.spark.examples.mllib.SampledRDDs + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. 
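RandomRDDGeneration above sticks to the normal generators. RandomRDDs also exposes uniform and Poisson variants with the same shape of API; a brief sketch, with sizes and object name chosen arbitrarily:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._  // for mean() on RDD[Double]
    import org.apache.spark.mllib.random.RandomRDDs

    object MoreRandomRDDsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("MoreRandomRDDsSketch"))
        val uniform = RandomRDDs.uniformRDD(sc, 1000L)       // samples from U(0, 1)
        val poisson = RandomRDDs.poissonRDD(sc, 2.0, 1000L)  // samples from Poisson(mean = 2.0)
        println(s"uniform mean ~ ${uniform.mean()}, poisson mean ~ ${poisson.mean()}")
        sc.stop()
      }
    }
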
+ */ +object SampledRDDs { + + case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("SampledRDDs") { + head("SampledRDDs: an example app for randomly generated and sampled RDDs.") + opt[String]("input") + .text(s"Input path to labeled examples in LIBSVM format, default: ${defaultParams.input}") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.SampledRDDs \ + | examples/target/scala-*/spark-examples-*.jar + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName(s"SampledRDDs with $params") + val sc = new SparkContext(conf) + + val fraction = 0.1 // fraction of data to sample + + val examples = MLUtils.loadLibSVMFile(sc, params.input) + val numExamples = examples.count() + if (numExamples == 0) { + throw new RuntimeException("Error: Data file had no samples to load.") + } + println(s"Loaded data with $numExamples examples from file: ${params.input}") + + // Example: RDD.sample() and RDD.takeSample() + val expectedSampleSize = (numExamples * fraction).toInt + println(s"Sampling RDD using fraction $fraction. Expected sample size = $expectedSampleSize.") + val sampledRDD = examples.sample(withReplacement = true, fraction = fraction) + println(s" RDD.sample(): sample has ${sampledRDD.count()} examples") + val sampledArray = examples.takeSample(withReplacement = true, num = expectedSampleSize) + println(s" RDD.takeSample(): sample has ${sampledArray.size} examples") + + println() + + // Example: RDD.sampleByKey() and RDD.sampleByKeyExact() + val keyedRDD = examples.map { lp => (lp.label.toInt, lp.features) } + println(s" Keyed data using label (Int) as key ==> Orig") + // Count examples per label in original data. + val keyCounts = keyedRDD.countByKey() + + // Subsample, and count examples per label in sampled data. (approximate) + val fractions = keyCounts.keys.map((_, fraction)).toMap + val sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = true, fractions = fractions) + val keyCountsB = sampledByKeyRDD.countByKey() + val sizeB = keyCountsB.values.sum + println(s" Sampled $sizeB examples using approximate stratified sampling (by label)." + + " ==> Approx Sample") + + // Subsample, and count examples per label in sampled data. (approximate) + val sampledByKeyRDDExact = + keyedRDD.sampleByKeyExact(withReplacement = true, fractions = fractions) + val keyCountsBExact = sampledByKeyRDDExact.countByKey() + val sizeBExact = keyCountsBExact.values.sum + println(s" Sampled $sizeBExact examples using exact stratified sampling (by label)." 
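To make the approximate-versus-exact comparison above concrete, here is a self-contained sketch with a known key distribution: 100 records under key 0 and 50 under key 1, sampled at 10 percent per key, so the expected per-key counts are about 10 and 5, which sampleByKeyExact hits precisely. The data and object name are illustrative only:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._  // for pair-RDD operations

    object StratifiedSamplingSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("StratifiedSamplingSketch"))
        val keyed = sc.parallelize((1 to 100).map(i => (0, i)) ++ (1 to 50).map(i => (1, i)))
        val fractions = Map(0 -> 0.1, 1 -> 0.1)
        val approx = keyed.sampleByKey(withReplacement = false, fractions = fractions)
        val exact = keyed.sampleByKeyExact(withReplacement = false, fractions = fractions)
        println(s"approx counts = ${approx.countByKey()}, exact counts = ${exact.countByKey()}")
        sc.stop()
      }
    }
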
+ + " ==> Exact Sample") + + // Compare samples + println(s" \tFractions of examples with key") + println(s"Key\tOrig\tApprox Sample\tExact Sample") + keyCounts.keys.toSeq.sorted.foreach { key => + val origFrac = keyCounts(key) / numExamples.toDouble + val approxFrac = if (sizeB != 0) { + keyCountsB.getOrElse(key, 0L) / sizeB.toDouble + } else { + 0 + } + val exactFrac = if (sizeBExact != 0) { + keyCountsBExact.getOrElse(key, 0L) / sizeBExact.toDouble + } else { + 0 + } + println(s"$key\t$origFrac\t$approxFrac\t$exactFrac") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index 1fd37edfa7427..c5bd5b0b178d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,8 +18,7 @@ package org.apache.spark.examples.mllib import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD +import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -56,14 +55,14 @@ object StreamingLinearRegression { val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) - val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0)) - val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1)) + val trainingData = ssc.textFileStream(args(0)).map(LabeledPoint.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) val model = new StreamingLinearRegressionWithSGD() - .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0))) + .setInitialWeights(Vectors.zeros(args(3).toInt)) model.trainOn(trainingData) - model.predictOn(testData).print() + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala new file mode 100644 index 0000000000000..1b25983a38453 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.pythonconverters + +import java.util.{Collection => JCollection, Map => JMap} + +import scala.collection.JavaConversions._ + +import org.apache.avro.generic.{GenericFixed, IndexedRecord} +import org.apache.avro.mapred.AvroWrapper +import org.apache.avro.Schema +import org.apache.avro.Schema.Type._ + +import org.apache.spark.api.python.Converter +import org.apache.spark.SparkException + + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { + if (obj == null) { + return null + } + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") + } + } + + def unpackRecord(obj: Any): JMap[String, Any] = { + val map = new java.util.HashMap[String, Any] + obj match { + case record: IndexedRecord => + record.getSchema.getFields.zipWithIndex.foreach { case (f, i) => + map.put(f.name, fromAvro(record.get(i), f.schema)) + } + case other => throw new SparkException( + s"Unsupported RECORD type ${other.getClass.getName}") + } + map + } + + def unpackMap(obj: Any, schema: Schema): JMap[String, Any] = { + obj.asInstanceOf[JMap[_, _]].map { case (key, value) => + (key.toString, fromAvro(value, schema.getValueType)) + } + } + + def unpackFixed(obj: Any, schema: Schema): Array[Byte] = { + unpackBytes(obj.asInstanceOf[GenericFixed].bytes()) + } + + def unpackBytes(obj: Any): Array[Byte] = { + val bytes: Array[Byte] = obj match { + case buf: java.nio.ByteBuffer => buf.array() + case arr: Array[Byte] => arr + case other => throw new SparkException( + s"Unknown BYTES type ${other.getClass.getName}") + } + val bytearray = new Array[Byte](bytes.length) + System.arraycopy(bytes, 0, bytearray, 0, bytes.length) + bytearray + } + + def unpackArray(obj: Any, schema: Schema): JCollection[Any] = obj match { + case c: JCollection[_] => + c.map(fromAvro(_, schema.getElementType)) + case arr: Array[_] if arr.getClass.getComponentType.isPrimitive => + arr.toSeq + case arr: Array[_] => + arr.map(fromAvro(_, schema.getElementType)).toSeq + case other => throw new SparkException( + s"Unknown ARRAY type ${other.getClass.getName}") + } + + def unpackUnion(obj: Any, schema: Schema): Any = { + schema.getTypes.toList match { + case List(s) => fromAvro(obj, s) + case List(n, s) if n.getType == NULL => fromAvro(obj, s) + case List(s, n) if n.getType == NULL => fromAvro(obj, s) + case _ => throw new SparkException( + "Unions may only consist of a concrete type and null") + } + } + + def fromAvro(obj: Any, schema: Schema): Any = { + if (obj == null) { + return null + } + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") + } + } +} diff --git a/external/flume-sink/pom.xml 
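AvroWrapperToJavaConverter plugs into the org.apache.spark.api.python.Converter hook, which only requires a convert method; the converter class is then referenced by its fully qualified name from the PySpark side. A deliberately trivial sketch of the same hook; this class is hypothetical and not part of the patch:

    import org.apache.spark.api.python.Converter

    // Converts any key or value to its String form so the Python side sees plain strings.
    class ToStringConverter extends Converter[Any, Any] {
      override def convert(obj: Any): Any = {
        if (obj == null) null else obj.toString
      }
    }
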
b/external/flume-sink/pom.xml index d0bf1cf1ea796..0c68defa5e101 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -72,6 +72,13 @@ org.scalatest scalatest_${scala.binary.version} + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + target/scala-${scala.binary.version}/classes diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 7b735133e3d14..98ae7d783aec8 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -53,7 +53,6 @@ import org.apache.flume.sink.AbstractSink * */ -private[flume] class SparkSink extends AbstractSink with Logging with Configurable { // Size of the pool to use for holding transaction processors. @@ -131,6 +130,14 @@ class SparkSink extends AbstractSink with Logging with Configurable { blockingLatch.await() Status.BACKOFF } + + private[flume] def getPort(): Int = { + serverOpt + .map(_.getPort) + .getOrElse( + throw new RuntimeException("Server was not started!") + ) + } } /** diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala new file mode 100644 index 0000000000000..44b27edf85ce8 --- /dev/null +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
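The new getPort() accessor exists so tests can start the sink on port 0 and then ask which port the OS actually assigned, instead of picking a random port and hoping it is free. The same idea with a bare ServerSocket, as a sketch unrelated to the Flume or Spark APIs:

    import java.net.ServerSocket

    object EphemeralPortSketch {
      def main(args: Array[String]): Unit = {
        val server = new ServerSocket(0)  // port 0 asks the OS for any free port
        try {
          println(s"bound to ephemeral port ${server.getLocalPort}")
        } finally {
          server.close()
        }
      }
    }
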
+ */ +package org.apache.spark.streaming.flume.sink + +import java.net.InetSocketAddress +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} + +import scala.collection.JavaConversions._ +import scala.concurrent.{ExecutionContext, Future} +import scala.util.{Failure, Success} + +import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.event.EventBuilder +import org.apache.spark.streaming.TestSuiteBase +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory + +class SparkSinkSuite extends TestSuiteBase { + val eventsPerBatch = 1000 + val channelCapacity = 5000 + + test("Success") { + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + + putEvents(channel, eventsPerBatch) + + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + client.ack(events.getSequenceNumber) + assert(events.getEvents.size() === 1000) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Nack") { + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + putEvents(channel, eventsPerBatch) + + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + assert(events.getEvents.size() === 1000) + client.nack(events.getSequenceNumber) + assert(availableChannelSlots(channel) === 4000) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Timeout") { + val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig + .CONF_TRANSACTION_TIMEOUT -> 1.toString)) + channel.start() + sink.start() + putEvents(channel, eventsPerBatch) + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + + val (transceiver, client) = getTransceiverAndClient(address, 1)(0) + val events = client.getEventBatch(1000) + assert(events.getEvents.size() === 1000) + Thread.sleep(1000) + assert(availableChannelSlots(channel) === 4000) + sink.stop() + channel.stop() + transceiver.close() + } + + test("Multiple consumers") { + testMultipleConsumers(failSome = false) + } + + test("Multiple consumers with some failures") { + testMultipleConsumers(failSome = true) + } + + def testMultipleConsumers(failSome: Boolean): Unit = { + implicit val executorContext = ExecutionContext + .fromExecutorService(Executors.newFixedThreadPool(5)) + val (channel, sink) = initializeChannelAndSink() + channel.start() + sink.start() + (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) + val port = sink.getPort + val address = new InetSocketAddress("0.0.0.0", port) + val transceiversAndClients = getTransceiverAndClient(address, 5) + val batchCounter = new CountDownLatch(5) + val counter = new AtomicInteger(0) + transceiversAndClients.foreach(x => { + Future { + val client = x._2 + val events = client.getEventBatch(1000) + if (!failSome || counter.getAndIncrement() % 2 == 0) { + client.ack(events.getSequenceNumber) + } else { + client.nack(events.getSequenceNumber) + throw new RuntimeException("Sending NACK for failure!") + } + events + }.onComplete { + case Success(events) => + 
assert(events.getEvents.size() === 1000) + batchCounter.countDown() + case Failure(t) => + // Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout + batchCounter.countDown() + } + }) + batchCounter.await() + TimeUnit.SECONDS.sleep(1) // Allow the sink to commit the transactions. + executorContext.shutdown() + if(failSome) { + assert(availableChannelSlots(channel) === 3000) + } else { + assertChannelIsEmpty(channel) + } + sink.stop() + channel.stop() + transceiversAndClients.foreach(x => x._1.close()) + } + + private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel, + SparkSink) = { + val channel = new MemoryChannel() + val channelContext = new Context() + + channelContext.put("capacity", channelCapacity.toString) + channelContext.put("transactionCapacity", 1000.toString) + channelContext.put("keep-alive", 0.toString) + channelContext.putAll(overrides) + channel.configure(channelContext) + + val sink = new SparkSink() + val sinkContext = new Context() + sinkContext.put(SparkSinkConfig.CONF_HOSTNAME, "0.0.0.0") + sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) + sink.configure(sinkContext) + sink.setChannel(channel) + (channel, sink) + } + + private def putEvents(ch: MemoryChannel, count: Int): Unit = { + val tx = ch.getTransaction + tx.begin() + (1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes))) + tx.commit() + tx.close() + } + + private def getTransceiverAndClient(address: InetSocketAddress, + count: Int): Seq[(NettyTransceiver, SparkFlumeProtocol.Callback)] = { + + (1 to count).map(_ => { + lazy val channelFactoryExecutor = + Executors.newCachedThreadPool(new ThreadFactoryBuilder().setDaemon(true). + setNameFormat("Flume Receiver Channel Thread - %d").build()) + lazy val channelFactory = + new NioClientSocketChannelFactory(channelFactoryExecutor, channelFactoryExecutor) + val transceiver = new NettyTransceiver(address, channelFactory) + val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) + (transceiver, client) + }) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + assert(availableChannelSlots(channel) === channelCapacity) + } + + private def availableChannelSlots(channel: MemoryChannel): Int = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 27bf2ac962721..32a19787a28e1 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -22,6 +22,8 @@ import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} import java.util.Random +import org.apache.spark.TestUtils + import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -35,29 +37,45 @@ import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} import 
org.apache.spark.streaming.flume.sink._ +import org.apache.spark.util.Utils class FlumePollingStreamSuite extends TestSuiteBase { - val random = new Random() - /** Return a port in the ephemeral range. */ - def getTestPort = random.nextInt(16382) + 49152 val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 + val maxAttempts = 5 test("flume polling test") { - val testPort = getTestPort - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", testPort)), - StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) - val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] - with SynchronizedBuffer[Seq[SparkFlumeEvent]] - val outputStream = new TestOutputStream(flumeStream, outputBuffer) - outputStream.register() + testMultipleTimes(testFlumePolling) + } + + test("flume polling test multiple hosts") { + testMultipleTimes(testFlumePollingMultipleHost) + } + + /** + * Run the given test until no more java.net.BindException's are thrown. + * Do this only up to a certain attempt limit. + */ + private def testMultipleTimes(test: () => Unit): Unit = { + var testPassed = false + var attempt = 0 + while (!testPassed && attempt < maxAttempts) { + try { + test() + testPassed = true + } catch { + case e: Exception if Utils.isBindCollision(e) => + logWarning("Exception when running flume polling test: " + e) + attempt += 1 + } + } + assert(testPassed, s"Test failed after $attempt attempts!") + } + private def testFlumePolling(): Unit = { // Start the channel and sink. val context = new Context() context.put("capacity", channelCapacity.toString) @@ -68,31 +86,28 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() - ssc.start() - - writeAndVerify(Seq(channel), ssc, outputBuffer) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() - } - - test("flume polling test multiple hosts") { - val testPort = getTestPort // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = Seq(testPort, testPort + 1).map(new InetSocketAddress("localhost", _)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = - FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + FlumeUtils.createPollingStream(ssc, Seq(new InetSocketAddress("localhost", sink.getPort())), + StorageLevel.MEMORY_AND_DISK, eventsPerBatch, 1) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() + ssc.start() + + writeAndVerify(Seq(channel), ssc, outputBuffer) + assertChannelIsEmpty(channel) + sink.stop() + channel.stop() + } + private def testFlumePollingMultipleHost(): Unit = { // Start the channel and sink. 
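The testMultipleTimes helper above reruns a test while it keeps failing with bind collisions, up to maxAttempts. A generic sketch of the same retry shape, simplified to catch java.net.BindException directly instead of Spark's wrapped-exception check; the names here are illustrative:

    import java.net.BindException

    object RetryOnBindSketch {
      // Rerun the body until it succeeds or the attempt budget runs out,
      // retrying only when the failure is a bind collision.
      def retryOnBind[T](maxAttempts: Int)(body: => T): T = {
        var attempt = 0
        var result: Option[T] = None
        while (result.isEmpty) {
          attempt += 1
          try {
            result = Some(body)
          } catch {
            case e: BindException if attempt < maxAttempts =>
              println(s"bind collision on attempt $attempt, retrying: ${e.getMessage}")
          }
        }
        result.get
      }

      def main(args: Array[String]): Unit = {
        println(retryOnBind(3) { 42 })
      }
    }
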
val context = new Context() context.put("capacity", channelCapacity.toString) @@ -106,17 +121,29 @@ class FlumePollingStreamSuite extends TestSuiteBase { val sink = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink, context) sink.setChannel(channel) sink.start() val sink2 = new SparkSink() context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(testPort + 1)) + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) Configurables.configure(sink2, context) sink2.setChannel(channel2) sink2.start() + + // Set up the streaming context and input streams + val ssc = new StreamingContext(conf, batchDuration) + val addresses = Seq(sink.getPort(), sink2.getPort()).map(new InetSocketAddress("localhost", _)) + val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = + FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, + eventsPerBatch, 5) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + outputStream.register() + ssc.start() writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) assertChannelIsEmpty(channel) @@ -171,7 +198,7 @@ class FlumePollingStreamSuite extends TestSuiteBase { } def assertChannelIsEmpty(channel: MemoryChannel) = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining"); + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") queueRemaining.setAccessible(true) val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) diff --git a/mllib/pom.xml b/mllib/pom.xml index 9a33bd1cf6ad1..c7a1e2ae75c84 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -57,7 +57,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.7 + 0.9 @@ -91,6 +91,13 @@ junit-interface test + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fd0b9556c7d54..4343124f102a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -25,18 +25,16 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} +import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import 
org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -50,182 +48,7 @@ import org.apache.spark.util.Utils */ @DeveloperApi class PythonMLLibAPI extends Serializable { - private val DENSE_VECTOR_MAGIC: Byte = 1 - private val SPARSE_VECTOR_MAGIC: Byte = 2 - private val DENSE_MATRIX_MAGIC: Byte = 3 - private val LABELED_POINT_MAGIC: Byte = 4 - - private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { - require(bytes.length - offset >= 5, "Byte array too short") - val magic = bytes(offset) - if (magic == DENSE_VECTOR_MAGIC) { - deserializeDenseVector(bytes, offset) - } else if (magic == SPARSE_VECTOR_MAGIC) { - deserializeSparseVector(bytes, offset) - } else { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - } - - private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { - require(bytes.length - offset == 8, "Wrong size byte array for Double") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - bb.getDouble - } - - private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 5, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val length = bb.getInt() - require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) - val db = bb.asDoubleBuffer() - val ans = new Array[Double](length.toInt) - db.get(ans) - Vectors.dense(ans) - } - - private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { - val packetLength = bytes.length - offset - require(packetLength >= 9, "Byte array too short") - val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) - val size = bb.getInt() - val nonZeros = bb.getInt() - require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) - val ib = bb.asIntBuffer() - val indices = new Array[Int](nonZeros) - ib.get(indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - val values = new Array[Double](nonZeros) - db.get(values) - Vectors.sparse(size, indices, values) - } - - /** - * Returns an 8-byte array for the input Double. - * - * Note: we currently do not use a magic byte for double for storage efficiency. - * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. - * The corresponding deserializer, deserializeDouble, needs to be modified as well if the - * serialization scheme changes. 
- */ - private[python] def serializeDouble(double: Double): Array[Byte] = { - val bytes = new Array[Byte](8) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putDouble(double) - bytes - } - - private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { - val len = doubles.length - val bytes = new Array[Byte](5 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_VECTOR_MAGIC) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(doubles) - bytes - } - - private def serializeSparseVector(vector: SparseVector): Array[Byte] = { - val nonZeros = vector.indices.length - val bytes = new Array[Byte](9 + 12 * nonZeros) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(SPARSE_VECTOR_MAGIC) - bb.putInt(vector.size) - bb.putInt(nonZeros) - val ib = bb.asIntBuffer() - ib.put(vector.indices) - bb.position(bb.position() + 4 * nonZeros) - val db = bb.asDoubleBuffer() - db.put(vector.values) - bytes - } - - private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { - case s: SparseVector => - serializeSparseVector(s) - case _ => - serializeDenseVector(vector.toArray) - } - - private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { - val packetLength = bytes.length - if (packetLength < 9) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.get() - if (magic != DENSE_MATRIX_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val rows = bb.getInt() - val cols = bb.getInt() - if (packetLength != 9 + 8 * rows * cols) { - throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") - } - val db = bb.asDoubleBuffer() - val ans = new Array[Array[Double]](rows.toInt) - for (i <- 0 until rows.toInt) { - ans(i) = new Array[Double](cols.toInt) - db.get(ans(i)) - } - ans - } - private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { - val rows = doubles.length - var cols = 0 - if (rows > 0) { - cols = doubles(0).length - } - val bytes = new Array[Byte](9 + 8 * rows * cols) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(DENSE_MATRIX_MAGIC) - bb.putInt(rows) - bb.putInt(cols) - val db = bb.asDoubleBuffer() - for (i <- 0 until rows) { - db.put(doubles(i)) - } - bytes - } - - private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { - val fb = serializeDoubleVector(p.features) - val bytes = new Array[Byte](1 + 8 + fb.length) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.put(LABELED_POINT_MAGIC) - bb.putDouble(p.label) - bb.put(fb) - bytes - } - - private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { - require(bytes.length >= 9, "Byte array too short") - val magic = bytes(0) - if (magic != LABELED_POINT_MAGIC) { - throw new IllegalArgumentException("Magic " + magic + " is wrong.") - } - val labelBytes = ByteBuffer.wrap(bytes, 1, 8) - labelBytes.order(ByteOrder.nativeOrder()) - val label = labelBytes.asDoubleBuffer().get(0) - LabeledPoint(label, deserializeDoubleVector(bytes, 9)) - } /** * Loads and serializes labeled points saved with `RDD#saveAsTextFile`. 
@@ -238,17 +61,17 @@ class PythonMLLibAPI extends Serializable { jsc: JavaSparkContext, path: String, minPartitions: Int): JavaRDD[Array[Byte]] = - MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint) + MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint) private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) - val initialWeights = deserializeDoubleVector(initialWeightsBA) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) + val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(SerDe.serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -407,12 +230,12 @@ class PythonMLLibAPI extends Serializable { def trainNaiveBayes( dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(Vectors.dense(model.labels))) - ret.add(serializeDoubleVector(Vectors.dense(model.pi))) - ret.add(serializeDoubleMatrix(model.theta)) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels))) + ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi))) + ret.add(SerDe.serializeDoubleMatrix(model.theta)) ret } @@ -425,52 +248,13 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes)) + val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes)) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) + ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) ret } - /** Unpack a Rating object from an array of bytes */ - private def unpackRating(ratingBytes: Array[Byte]): Rating = { - val bb = ByteBuffer.wrap(ratingBytes) - bb.order(ByteOrder.nativeOrder()) - val user = bb.getInt() - val product = bb.getInt() - val rating = bb.getDouble() - new Rating(user, product, rating) - } - - /** Unpack a tuple of Ints from an array of bytes */ - private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { - val bb = ByteBuffer.wrap(tupleBytes) - bb.order(ByteOrder.nativeOrder()) - val v1 = bb.getInt() - val v2 = bb.getInt() - (v1, v2) - } - - /** - * Serialize a Rating object into an array of bytes. - * It can be deserialized using RatingDeserializer(). 
- * - * @param rate the Rating object to serialize - * @return - */ - private[spark] def serializeRating(rate: Rating): Array[Byte] = { - val len = 3 - val bytes = new Array[Byte](4 + 8 * len) - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - bb.putInt(len) - val db = bb.asDoubleBuffer() - db.put(rate.user.toDouble) - db.put(rate.product.toDouble) - db.put(rate.rating) - bytes - } - /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care @@ -483,7 +267,7 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -500,7 +284,7 @@ class PythonMLLibAPI extends Serializable { lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { - val ratings = ratingsBytesJRDD.rdd.map(unpackRating) + val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } @@ -521,19 +305,10 @@ class PythonMLLibAPI extends Serializable { maxDepth: Int, maxBins: Int): DecisionTreeModel = { - val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint) - val algo: Algo = algoStr match { - case "classification" => Classification - case "regression" => Regression - case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") - } - val impurity: Impurity = impurityStr match { - case "gini" => Gini - case "entropy" => Entropy - case "variance" => Variance - case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") - } + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) val strategy = new Strategy( algo = algo, @@ -556,7 +331,7 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, featuresBytes: Array[Byte]): Double = { - val features: Vector = deserializeDoubleVector(featuresBytes) + val features: Vector = SerDe.deserializeDoubleVector(featuresBytes) model.predict(features) } @@ -570,8 +345,17 @@ class PythonMLLibAPI extends Serializable { def predictDecisionTreeModel( model: DecisionTreeModel, dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) - model.predict(data).map(serializeDouble) + val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes)) + model.predict(data).map(SerDe.serializeDouble) + } + + /** + * Java stub for mllib Statistics.colStats(X: RDD[Vector]). + * TODO figure out return type. + */ + def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = { + val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_))) + new MultivariateStatisticalSummarySerialized(cStats) } /** @@ -580,17 +364,17 @@ class PythonMLLibAPI extends Serializable { * pyspark. 
*/ def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = { - val inputMatrix = X.rdd.map(deserializeDoubleVector(_)) + val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_)) val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method)) - serializeDoubleMatrix(to2dArray(result)) + SerDe.serializeDoubleMatrix(SerDe.to2dArray(result)) } /** * Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String). */ def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = { - val xDeser = x.rdd.map(deserializeDouble(_)) - val yDeser = y.rdd.map(deserializeDouble(_)) + val xDeser = x.rdd.map(SerDe.deserializeDouble(_)) + val yDeser = y.rdd.map(SerDe.deserializeDouble(_)) Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method)) } @@ -599,12 +383,6 @@ class PythonMLLibAPI extends Serializable { if (method == null) CorrelationNames.defaultCorrName else method } - // Reformat a Matrix into Array[Array[Double]] for serialization - private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { - val values = matrix.toArray - Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) - } - // Used by the *RDD methods to get default seed if not passed in from pyspark private def getSeedOrDefault(seed: java.lang.Long): Long = { if (seed == null) Utils.random.nextLong else seed @@ -632,7 +410,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -644,7 +422,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble) + RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble) } /** @@ -657,7 +435,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble) + RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble) } /** @@ -670,7 +448,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -683,7 +461,7 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector) } /** @@ -697,7 +475,256 @@ class PythonMLLibAPI extends Serializable { seed: java.lang.Long): JavaRDD[Array[Byte]] = { val parts = getNumPartitionsOrDefault(numPartitions, jsc) val s = getSeedOrDefault(seed) - RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) + RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, 
s).map(SerDe.serializeDoubleVector) + } + +} + +/** + * :: DeveloperApi :: + * MultivariateStatisticalSummary with Vector fields serialized. + */ +@DeveloperApi +class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary) + extends Serializable { + + def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean) + + def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance) + + def count: Long = summary.count + + def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros) + + def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max) + + def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min) +} + +/** + * SerDe utility functions for PythonMLLibAPI. + */ +private[spark] object SerDe extends Serializable { + private val DENSE_VECTOR_MAGIC: Byte = 1 + private val SPARSE_VECTOR_MAGIC: Byte = 2 + private val DENSE_MATRIX_MAGIC: Byte = 3 + private val LABELED_POINT_MAGIC: Byte = 4 + + private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + require(bytes.length - offset >= 5, "Byte array too short") + val magic = bytes(offset) + if (magic == DENSE_VECTOR_MAGIC) { + deserializeDenseVector(bytes, offset) + } else if (magic == SPARSE_VECTOR_MAGIC) { + deserializeSparseVector(bytes, offset) + } else { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + } + + private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = { + require(bytes.length - offset == 8, "Wrong size byte array for Double") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + bb.getDouble + } + + private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 5, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val length = bb.getInt() + require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](length.toInt) + db.get(ans) + Vectors.dense(ans) + } + + private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 9, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val size = bb.getInt() + val nonZeros = bb.getInt() + require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) + val ib = bb.asIntBuffer() + val indices = new Array[Int](nonZeros) + ib.get(indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + val values = new Array[Double](nonZeros) + db.get(values) + Vectors.sparse(size, indices, values) + } + + /** + * Returns an 8-byte array for the input Double. + * + * Note: we currently do not use a magic byte for double for storage efficiency. + * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety. + * The corresponding deserializer, deserializeDouble, needs to be modified as well if the + * serialization scheme changes. 
+ */ + private[python] def serializeDouble(double: Double): Array[Byte] = { + val bytes = new Array[Byte](8) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putDouble(double) + bytes + } + + private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { + val len = doubles.length + val bytes = new Array[Byte](5 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_VECTOR_MAGIC) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(doubles) + bytes + } + + private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = { + val nonZeros = vector.indices.length + val bytes = new Array[Byte](9 + 12 * nonZeros) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(SPARSE_VECTOR_MAGIC) + bb.putInt(vector.size) + bb.putInt(nonZeros) + val ib = bb.asIntBuffer() + ib.put(vector.indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + db.put(vector.values) + bytes + } + + private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + case s: SparseVector => + serializeSparseVector(s) + case _ => + serializeDenseVector(vector.toArray) } + private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { + val packetLength = bytes.length + if (packetLength < 9) { + throw new IllegalArgumentException("Byte array too short.") + } + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + if (magic != DENSE_MATRIX_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val rows = bb.getInt() + val cols = bb.getInt() + if (packetLength != 9 + 8 * rows * cols) { + throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") + } + val db = bb.asDoubleBuffer() + val ans = new Array[Array[Double]](rows.toInt) + for (i <- 0 until rows.toInt) { + ans(i) = new Array[Double](cols.toInt) + db.get(ans(i)) + } + ans + } + + private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = { + val rows = doubles.length + var cols = 0 + if (rows > 0) { + cols = doubles(0).length + } + val bytes = new Array[Byte](9 + 8 * rows * cols) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(DENSE_MATRIX_MAGIC) + bb.putInt(rows) + bb.putInt(cols) + val db = bb.asDoubleBuffer() + for (i <- 0 until rows) { + db.put(doubles(i)) + } + bytes + } + + private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = { + val fb = serializeDoubleVector(p.features) + val bytes = new Array[Byte](1 + 8 + fb.length) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(LABELED_POINT_MAGIC) + bb.putDouble(p.label) + bb.put(fb) + bytes + } + + private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + require(bytes.length >= 9, "Byte array too short") + val magic = bytes(0) + if (magic != LABELED_POINT_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val labelBytes = ByteBuffer.wrap(bytes, 1, 8) + labelBytes.order(ByteOrder.nativeOrder()) + val label = labelBytes.asDoubleBuffer().get(0) + LabeledPoint(label, deserializeDoubleVector(bytes, 9)) + } + + // Reformat a Matrix into Array[Array[Double]] for serialization + private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = { + val values = matrix.toArray + Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows)) + } + 
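For context on the SerDe object consolidated above: dense vectors are written as [magic byte 1][Int length][Double values] and sparse vectors as [magic byte 2][Int size][Int nnz][Int indices][Double values], all in native byte order. A minimal round-trip sketch, purely illustrative and not part of the patch; it assumes the snippet is compiled inside org.apache.spark.mllib.api.python, since SerDe and its helpers are package-private:

    import org.apache.spark.mllib.linalg.Vectors

    val v = Vectors.dense(1.0, 2.0, 3.0)
    // 1 magic byte + 4-byte length + 8 bytes per value = 29 bytes for a size-3 dense vector.
    val bytes = SerDe.serializeDoubleVector(v)
    assert(bytes.length == 5 + 8 * v.size)
    // Vector.equals compares the dense contents, so the round trip is exact.
    assert(SerDe.deserializeDoubleVector(bytes) == v)
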
+ + /** Unpack a Rating object from an array of bytes */ + private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = { + val bb = ByteBuffer.wrap(ratingBytes) + bb.order(ByteOrder.nativeOrder()) + val user = bb.getInt() + val product = bb.getInt() + val rating = bb.getDouble() + new Rating(user, product, rating) + } + + /** Unpack a tuple of Ints from an array of bytes */ + def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = { + val bb = ByteBuffer.wrap(tupleBytes) + bb.order(ByteOrder.nativeOrder()) + val v1 = bb.getInt() + val v2 = bb.getInt() + (v1, v2) + } + + /** + * Serialize a Rating object into an array of bytes. + * It can be deserialized using RatingDeserializer(). + * + * @param rate the Rating object to serialize + * @return + */ + def serializeRating(rate: Rating): Array[Byte] = { + val len = 3 + val bytes = new Array[Byte](4 + 8 * len) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.putInt(len) + val db = bb.asDoubleBuffer() + db.put(rate.user.toDouble) + db.put(rate.product.toDouble) + db.put(rate.rating) + bytes + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2242329b7918e..486bdbfa9cb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -62,7 +62,7 @@ class LogisticRegressionModel ( override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept - val score = 1.0/ (1.0 + math.exp(-margin)) + val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { case Some(t) => if (score < t) 0.0 else 1.0 case None => score @@ -73,6 +73,8 @@ class LogisticRegressionModel ( /** * Train a classification model for Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * Using [[LogisticRegressionWithLBFGS]] is recommended over this. */ class LogisticRegressionWithSGD private ( private var stepSize: Double, @@ -101,7 +103,7 @@ class LogisticRegressionWithSGD private ( } /** - * Top-level methods for calling Logistic Regression. + * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. * NOTE: Labels used in Logistic Regression should be {0, 1} */ object LogisticRegressionWithSGD { @@ -188,3 +190,22 @@ object LogisticRegressionWithSGD { train(input, numIterations, 1.0, 1.0) } } + +/** + * Train a classification model for Logistic Regression using Limited-memory BFGS. + * Standard feature scaling and L2 regularization are used by default. 
+ * NOTE: Labels used in Logistic Regression should be {0, 1} + */ +class LogisticRegressionWithLBFGS + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { + + this.setFeatureScaling(true) + + override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater) + + override protected val validators = List(DataValidators.binaryLabelValidator) + + override protected def createModel(weights: Vector, intercept: Double) = { + new LogisticRegressionModel(weights, intercept) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6c7be0a4f1dcb..8c8e4a161aa5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] ( * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for * document classification. By making every vector a 0-1 vector, it can also be used as - * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). + * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative. */ class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { @@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ def run(data: RDD[LabeledPoint]) = { + val requireNonnegativeValues: Vector => Unit = (v: Vector) => { + val values = v match { + case sv: SparseVector => + sv.values + case dv: DenseVector => + dv.values + } + if (!values.forall(_ >= 0.0)) { + throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. 
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( - createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), - mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), + createCombiner = (v: Vector) => { + requireNonnegativeValues(v) + (1L, v.toBreeze.toDenseVector) + }, + mergeValue = (c: (Long, BDV[Double]), v: Vector) => { + requireNonnegativeValues(v) + (c._1 + 1L, c._2 += v.toBreeze) + }, mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => (c1._1 + c2._1, c1._2 += c2._2) ).collect() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 0f6d5809e098f..c53475818395f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.Utils * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * - * @param numFeatures number of features (default: 1000000) + * @param numFeatures number of features (default: 2^20^) */ @Experimental class HashingTF(val numFeatures: Int) extends Serializable { - def this() = this(1000000) + def this() = this(1 << 20) /** * Returns the index of the input term. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 7ed611a857acc..d40d5553c1d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -36,87 +36,25 @@ class IDF { // TODO: Allow different IDF formulations. - private var brzIdf: BDV[Double] = _ - /** * Computes the inverse document frequency. * @param dataset an RDD of term frequency vectors */ - def fit(dataset: RDD[Vector]): this.type = { - brzIdf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + def fit(dataset: RDD[Vector]): IDFModel = { + val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( seqOp = (df, v) => df.add(v), combOp = (df1, df2) => df1.merge(df2) ).idf() - this + new IDFModel(idf) } /** * Computes the inverse document frequency. * @param dataset a JavaRDD of term frequency vectors */ - def fit(dataset: JavaRDD[Vector]): this.type = { + def fit(dataset: JavaRDD[Vector]): IDFModel = { fit(dataset.rdd) } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors. - * @param dataset an RDD of term frequency vectors - * @return an RDD of TF-IDF vectors - */ - def transform(dataset: RDD[Vector]): RDD[Vector] = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. 
Call fit first.") - } - val theIdf = brzIdf - val bcIdf = dataset.context.broadcast(theIdf) - dataset.mapPartitions { iter => - val thisIdf = bcIdf.value - iter.map { v => - val n = v.size - v match { - case sv: SparseVector => - val nnz = sv.indices.size - val newValues = new Array[Double](nnz) - var k = 0 - while (k < nnz) { - newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) - k += 1 - } - Vectors.sparse(n, sv.indices, newValues) - case dv: DenseVector => - val newValues = new Array[Double](n) - var j = 0 - while (j < n) { - newValues(j) = dv.values(j) * thisIdf(j) - j += 1 - } - Vectors.dense(newValues) - case other => - throw new UnsupportedOperationException( - s"Only sparse and dense vectors are supported but got ${other.getClass}.") - } - } - } - } - - /** - * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). - * @param dataset a JavaRDD of term frequency vectors - * @return a JavaRDD of TF-IDF vectors - */ - def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { - transform(dataset.rdd).toJavaRDD() - } - - /** Returns the IDF vector. */ - def idf(): Vector = { - if (!initialized) { - throw new IllegalStateException("Haven't learned IDF yet. Call fit first.") - } - Vectors.fromBreeze(brzIdf) - } - - private def initialized: Boolean = brzIdf != null } private object IDF { @@ -177,18 +115,72 @@ private object IDF { private def isEmpty: Boolean = m == 0L /** Returns the current IDF vector. */ - def idf(): BDV[Double] = { + def idf(): Vector = { if (isEmpty) { throw new IllegalStateException("Haven't seen any document yet.") } val n = df.length - val inv = BDV.zeros[Double](n) + val inv = new Array[Double](n) var j = 0 while (j < n) { inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) j += 1 } - inv + Vectors.dense(inv) } } } + +/** + * :: Experimental :: + * Represents an IDF model that can transform term frequency vectors. + */ +@Experimental +class IDFModel private[mllib] (val idf: Vector) extends Serializable { + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors. + * @param dataset an RDD of term frequency vectors + * @return an RDD of TF-IDF vectors + */ + def transform(dataset: RDD[Vector]): RDD[Vector] = { + val bcIdf = dataset.context.broadcast(idf) + dataset.mapPartitions { iter => + val thisIdf = bcIdf.value + iter.map { v => + val n = v.size + v match { + case sv: SparseVector => + val nnz = sv.indices.size + val newValues = new Array[Double](nnz) + var k = 0 + while (k < nnz) { + newValues(k) = sv.values(k) * thisIdf(sv.indices(k)) + k += 1 + } + Vectors.sparse(n, sv.indices, newValues) + case dv: DenseVector => + val newValues = new Array[Double](n) + var j = 0 + while (j < n) { + newValues(j) = dv.values(j) * thisIdf(j) + j += 1 + } + Vectors.dense(newValues) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } + } + } + + /** + * Transforms term frequency (TF) vectors to TF-IDF vectors (Java version). 
+ * @param dataset a JavaRDD of term frequency vectors + * @return a JavaRDD of TF-IDF vectors + */ + def transform(dataset: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(dataset.rdd).toJavaRDD() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index ea9fd0a80d8e0..3afb47767281c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * :: DeveloperApi :: + * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * * @param p Normalization in L^p^ space, p = 2 by default. */ -@DeveloperApi +@Experimental class Normalizer(p: Double) extends VectorTransformer { def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index cc2d7579c2901..4dfd1f0ab8134 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: DeveloperApi :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. * @@ -34,38 +35,56 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ -@DeveloperApi -class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { +@Experimental +class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { def this() = this(false, true) - require(withMean || withStd, s"withMean and withStd both equal to false. Doing nothing.") - - private var mean: BV[Double] = _ - private var factor: BV[Double] = _ + if (!(withMean || withStd)) { + logWarning("Both withMean and withStd are false. The model does nothing.") + } /** * Computes the mean and variance and stores as a model to be used for later scaling. * * @param data The data used to compute the mean and variance to build the transformation model. - * @return This StandardScalar object. 
+ * @return a StandardScalarModel */ - def fit(data: RDD[Vector]): this.type = { + def fit(data: RDD[Vector]): StandardScalerModel = { + // TODO: skip computation if both withMean and withStd are false val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) + new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + } +} - mean = summary.mean.toBreeze - factor = summary.variance.toBreeze - require(mean.length == factor.length) +/** + * :: Experimental :: + * Represents a StandardScaler model that can transform vectors. + * + * @param withMean whether to center the data before scaling + * @param withStd whether to scale the data to have unit standard deviation + * @param mean column mean values + * @param variance column variance values + */ +@Experimental +class StandardScalerModel private[mllib] ( + val withMean: Boolean, + val withStd: Boolean, + val mean: Vector, + val variance: Vector) extends VectorTransformer { + require(mean.size == variance.size) + + private lazy val factor: BDV[Double] = { + val f = BDV.zeros[Double](variance.size) var i = 0 - while (i < factor.length) { - factor(i) = if (factor(i) != 0.0) 1.0 / math.sqrt(factor(i)) else 0.0 + while (i < f.size) { + f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 i += 1 } - - this + f } /** @@ -76,13 +95,7 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor * for the column with zero variance. */ override def transform(vector: Vector): Vector = { - if (mean == null || factor == null) { - throw new IllegalStateException( - "Haven't learned column summary statistics yet. Call fit first.") - } - - require(vector.size == mean.length) - + require(mean.size == vector.size) if (withMean) { vector.toBreeze match { case dv: BDV[Double] => @@ -115,5 +128,4 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransfor vector } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3bf44ad7c44e3..fc1444705364a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -17,6 +17,9 @@ package org.apache.spark.mllib.feature +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -25,8 +28,8 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -115,7 +118,6 @@ class Word2Vec extends Serializable with Logging { private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 private val MAX_SENTENCE_LENGTH = 1000 - private val layer1Size = vectorSize /** context words from [-window, window] */ private val window = 5 @@ -127,7 +129,6 @@ class Word2Vec extends Serializable with Logging { private var vocabSize = 0 private var vocab: Array[VocabWord] = null private var vocabHash = mutable.HashMap.empty[String, Int] - private var alpha = startingAlpha 
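Stepping back from the Word2Vec hunk for a moment: the IDF and StandardScaler changes above share the same shift from a mutable fit-then-transform object to fit returning an immutable model. A minimal sketch of the new call pattern, assuming an existing SparkContext named sc and toy input data (both hypothetical):

    import org.apache.spark.mllib.feature.{IDF, StandardScaler}
    import org.apache.spark.mllib.linalg.Vectors

    // Hypothetical term-frequency-like vectors; `sc` is an existing SparkContext.
    val data = sc.parallelize(Seq(Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 0.0)))

    val idfModel = new IDF().fit(data)          // fit now returns an IDFModel
    val tfidf = idfModel.transform(data)        // transform moved onto the model

    val scalerModel = new StandardScaler().fit(data)              // returns a StandardScalerModel
    val scaled = scalerModel.transform(Vectors.dense(1.0, 2.0))   // per-vector transform shown above
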
private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) @@ -239,7 +240,7 @@ class Word2Vec extends Serializable with Logging { a += 1 } } - + /** * Computes the vector representation of each word in vocabulary. * @param dataset an RDD of words @@ -282,13 +283,15 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) - var syn0Global = - Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) - var syn1Global = new Array[Float](vocabSize * layer1Size) - + val syn0Global = + Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) + val syn1Global = new Array[Float](vocabSize * vectorSize) + var alpha = startingAlpha for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) + val syn0Modify = new Array[Int](vocabSize) + val syn1Modify = new Array[Int](vocabSize) val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount @@ -313,24 +316,27 @@ class Word2Vec extends Serializable with Logging { val c = pos - window + a if (c >= 0 && c < sentence.size) { val lastWord = sentence(c) - val l1 = lastWord * layer1Size - val neu1e = new Array[Float](layer1Size) + val l1 = lastWord * vectorSize + val neu1e = new Array[Float](vectorSize) // Hierarchical softmax var d = 0 while (d < bcVocab.value(word).codeLen) { - val l2 = bcVocab.value(word).point(d) * layer1Size + val inner = bcVocab.value(word).point(d) + val l2 = inner * vectorSize // Propagate hidden -> output - var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) + var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1) if (f > -MAX_EXP && f < MAX_EXP) { val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt f = expTable.value(ind) val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat - blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) - blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) + blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1) + blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1) + syn1Modify(inner) += 1 } d += 1 } - blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) + blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1) + syn0Modify(lastWord) += 1 } } a += 1 @@ -339,21 +345,37 @@ class Word2Vec extends Serializable with Logging { } (syn0, syn1, lwc, wc) } - Iterator(model) + val syn0Local = model._1 + val syn1Local = model._2 + // Only output modified vectors. 
+ Iterator.tabulate(vocabSize) { index => + if (syn0Modify(index) > 0) { + Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None + } + }.flatten ++ Iterator.tabulate(vocabSize) { index => + if (syn1Modify(index) > 0) { + Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None + } + }.flatten } - val (aggSyn0, aggSyn1, _, _) = - partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) => - val n = syn0_1.length - val weight1 = 1.0f * wc_1 / (wc_1 + wc_2) - val weight2 = 1.0f * wc_2 / (wc_1 + wc_2) - blas.sscal(n, weight1, syn0_1, 1) - blas.sscal(n, weight1, syn1_1, 1) - blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1) - blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1) - (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2) + val synAgg = partial.reduceByKey { case (v1, v2) => + blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1) + v1 + }.collect() + var i = 0 + while (i < synAgg.length) { + val index = synAgg(i)._1 + if (index < vocabSize) { + Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize) + } else { + Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize) } - syn0Global = aggSyn0 - syn1Global = aggSyn1 + i += 1 + } } newSentences.unpersist() @@ -361,19 +383,30 @@ class Word2Vec extends Serializable with Logging { var i = 0 while (i < vocabSize) { val word = bcVocab.value(i).word - val vector = new Array[Float](layer1Size) - Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) + val vector = new Array[Float](vectorSize) + Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) word2VecMap += word -> vector i += 1 } new Word2VecModel(word2VecMap.toMap) } + + /** + * Computes the vector representation of each word in vocabulary (Java version). + * @param dataset a JavaRDD of words + * @return a Word2VecModel + */ + def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { + fit(dataset.rdd.map(_.asScala)) + } } /** -* Word2Vec model + * :: Experimental :: + * Word2Vec model */ +@Experimental class Word2VecModel private[mllib] ( private val model: Map[String, Array[Float]]) extends Serializable { @@ -400,15 +433,6 @@ class Word2VecModel private[mllib] ( } } - /** - * Transforms an RDD to its vector representation - * @param dataset a an RDD of words - * @return RDD of vector representation - */ - def transform(dataset: RDD[String]): RDD[Vector] = { - dataset.map(word => transform(word)) - } - /** * Find synonyms of a word * @param word a word diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala new file mode 100644 index 0000000000000..70e23033c8754 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} + +/** + * BLAS routines for MLlib's vectors and matrices. + */ +private[mllib] object BLAS extends Serializable { + + @transient private var _f2jBLAS: NetlibBLAS = _ + + // For level-1 routines, we use Java implementation. + private def f2jBLAS: NetlibBLAS = { + if (_f2jBLAS == null) { + _f2jBLAS = new F2jBLAS + } + _f2jBLAS + } + + /** + * y += a * x + */ + def axpy(a: Double, x: Vector, y: Vector): Unit = { + require(x.size == y.size) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + axpy(a, sx, dy) + case dx: DenseVector => + axpy(a, dx, dy) + case _ => + throw new UnsupportedOperationException( + s"axpy doesn't support x type ${x.getClass}.") + } + case _ => + throw new IllegalArgumentException( + s"axpy only supports adding to a dense vector but got type ${y.getClass}.") + } + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { + val n = x.size + f2jBLAS.daxpy(n, a, x.values, 1, y.values, 1) + } + + /** + * y += a * x + */ + private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { + val nnz = x.indices.size + if (a == 1.0) { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += x.values(k) + k += 1 + } + } else { + var k = 0 + while (k < nnz) { + y.values(x.indices(k)) += a * x.values(k) + k += 1 + } + } + } + + /** + * dot(x, y) + */ + def dot(x: Vector, y: Vector): Double = { + require(x.size == y.size) + (x, y) match { + case (dx: DenseVector, dy: DenseVector) => + dot(dx, dy) + case (sx: SparseVector, dy: DenseVector) => + dot(sx, dy) + case (dx: DenseVector, sy: SparseVector) => + dot(sy, dx) + case (sx: SparseVector, sy: SparseVector) => + dot(sx, sy) + case _ => + throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).") + } + } + + /** + * dot(x, y) + */ + private def dot(x: DenseVector, y: DenseVector): Double = { + val n = x.size + f2jBLAS.ddot(n, x.values, 1, y.values, 1) + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: DenseVector): Double = { + val nnz = x.indices.size + var sum = 0.0 + var k = 0 + while (k < nnz) { + sum += x.values(k) * y.values(x.indices(k)) + k += 1 + } + sum + } + + /** + * dot(x, y) + */ + private def dot(x: SparseVector, y: SparseVector): Double = { + var kx = 0 + val nnzx = x.indices.size + var ky = 0 + val nnzy = y.indices.size + var sum = 0.0 + // y catching x + while (kx < nnzx && ky < nnzy) { + val ix = x.indices(kx) + while (ky < nnzy && y.indices(ky) < ix) { + ky += 1 + } + if (ky < nnzy && y.indices(ky) == ix) { + sum += x.values(kx) * y.values(ky) + ky += 1 + } + kx += 1 + } + sum + } + + /** + * y = x + */ + def copy(x: Vector, y: Vector): Unit = { + val n = y.size + require(x.size == n) + y match { + case dy: DenseVector => + x match { + case sx: SparseVector => + var i = 0 + var k = 0 + val nnz = sx.indices.size + while (k < nnz) { + val j = sx.indices(k) + while (i < j) { + dy.values(i) = 0.0 + i += 1 + } + dy.values(i) = sx.values(k) + i += 1 + k += 1 + } + while 
(i < n) { + dy.values(i) = 0.0 + i += 1 + } + case dx: DenseVector => + Array.copy(dx.values, 0, dy.values, 0, n) + } + case _ => + throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") + } + } + + /** + * x = a * x + */ + def scal(a: Double, x: Vector): Unit = { + x match { + case sx: SparseVector => + f2jBLAS.dscal(sx.values.size, a, sx.values, 1) + case dx: DenseVector => + f2jBLAS.dscal(dx.values.size, a, dx.values, 1) + case _ => + throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 77b3e8c714997..a45781d12e41e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.linalg import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} -import java.util.Arrays +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -30,6 +30,8 @@ import org.apache.spark.SparkException /** * Represents a numeric vector, whose index type is Int and value type is Double. + * + * Note: Users should not implement this interface. */ trait Vector extends Serializable { @@ -46,12 +48,12 @@ trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { case v: Vector => - Arrays.equals(this.toArray, v.toArray) + util.Arrays.equals(this.toArray, v.toArray) case _ => false } } - override def hashCode(): Int = Arrays.hashCode(this.toArray) + override def hashCode(): Int = util.Arrays.hashCode(this.toArray) /** * Converts the instance to a breeze vector. @@ -63,6 +65,13 @@ trait Vector extends Serializable { * @param i index */ def apply(i: Int): Double = toBreeze(i) + + /** + * Makes a deep copy of this vector. + */ + def copy: Vector = { + throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") + } } /** @@ -127,6 +136,16 @@ object Vectors { }.toSeq) } + /** + * Creates a dense vector of all zeros. + * + * @param size vector size + * @return a zero vector + */ + def zeros(size: Int): Vector = { + new DenseVector(new Array[Double](size)) + } + /** * Parses a string resulted from `Vector#toString` into * an [[org.apache.spark.mllib.linalg.Vector]]. 
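The Gradient and LBFGS hunks below replace Breeze operations with the BLAS helpers added above and the Vector additions in this file (zeros, copy). A minimal sketch of those primitives, assuming the code sits under org.apache.spark.mllib since BLAS is private[mllib]:

    import org.apache.spark.mllib.linalg.{BLAS, Vectors}

    val x = Vectors.dense(1.0, 2.0, 3.0)
    val y = Vectors.zeros(3)      // new factory for an all-zero dense vector
    BLAS.axpy(2.0, x, y)          // y += 2.0 * x; y must be dense
    val d = BLAS.dot(x, y)        // dense/dense, sparse/dense and sparse/sparse are supported
    BLAS.scal(0.5, y)             // y *= 0.5, in place
    val c = x.copy                // new deep copy; clones the underlying array
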
@@ -142,7 +161,7 @@ object Vectors { case Seq(size: Double, indices: Array[Double], values: Array[Double]) => Vectors.sparse(size.toInt, indices.map(_.toInt), values) case other => - throw new SparkException(s"Cannot parse $other.") + throw new SparkException(s"Cannot parse $other.") } } @@ -183,6 +202,10 @@ class DenseVector(val values: Array[Double]) extends Vector { private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int) = values(i) + + override def copy: DenseVector = { + new DenseVector(values.clone()) + } } /** @@ -213,5 +236,9 @@ class SparseVector( data } + override def copy: SparseVector = { + new SparseVector(size, indices.clone(), values.clone()) + } + private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 45486b2c7d82d..2e414a73be8e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -53,8 +53,14 @@ class RowMatrix( /** Gets or computes the number of columns. */ override def numCols(): Long = { if (nCols <= 0) { - // Calling `first` will throw an exception if `rows` is empty. - nCols = rows.first().size + try { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().size + } catch { + case err: UnsupportedOperationException => + sys.error("Cannot determine the number of cols because it is not specified in the " + + "constructor and the rows RDD is empty.") + } } nCols } @@ -222,7 +228,7 @@ class RowMatrix( EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] - val (uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) + val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") @@ -293,6 +299,10 @@ class RowMatrix( (s1._1 + s2._1, s1._2 += s2._2) ) + if (m <= 1) { + sys.error(s"RowMatrix.computeCovariance called on matrix with only $m rows." 
+ + " Cannot compute the covariance of a RowMatrix with <= 1 row.") + } updateNumRows(m) mean :/= m.toDouble @@ -338,7 +348,7 @@ class RowMatrix( val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] - val (u: BDM[Double], _, _) = brzSvd(Cov) + val brzSvd.SVD(u: BDM[Double], _, _) = brzSvd(Cov) if (k == n) { Matrices.dense(n, k, u.data) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 9d82f011e674a..fdd67160114ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import breeze.linalg.{axpy => brzAxpy} - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal} /** * :: DeveloperApi :: @@ -61,11 +60,10 @@ abstract class Gradient extends Serializable { @DeveloperApi class LogisticGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - val gradient = brzData * gradientMultiplier + val gradient = data.copy + scal(gradientMultiplier, gradient) val loss = if (label > 0) { math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p @@ -73,7 +71,7 @@ class LogisticGradient extends Gradient { math.log1p(math.exp(margin)) - margin } - (Vectors.fromBreeze(gradient), loss) + (gradient, loss) } override def compute( @@ -81,13 +79,9 @@ class LogisticGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val margin: Double = -1.0 * brzWeights.dot(brzData) + val margin = -1.0 * dot(data, weights) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) - + axpy(gradientMultiplier, data, cumGradient) if (label > 0) { math.log1p(math.exp(margin)) } else { @@ -106,13 +100,11 @@ class LogisticGradient extends Gradient { @DeveloperApi class LeastSquaresGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label + val diff = dot(data, weights) - label val loss = diff * diff - val gradient = brzData * (2.0 * diff) - - (Vectors.fromBreeze(gradient), loss) + val gradient = data.copy + scal(2.0 * diff, gradient) + (gradient, loss) } override def compute( @@ -120,12 +112,8 @@ class LeastSquaresGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val diff = brzWeights.dot(brzData) - label - - brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) - + val diff = dot(data, weights) - label + axpy(2.0 * diff, data, cumGradient) diff * diff } } @@ -139,18 +127,16 @@ class LeastSquaresGradient extends Gradient { @DeveloperApi class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, 
weights: Vector): (Vector, Double) = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + val gradient = data.copy + scal(-labelScaled, gradient) + (gradient, 1.0 - labelScaled * dotProduct) } else { - (Vectors.dense(new Array[Double](weights.size)), 0.0) + (Vectors.sparse(weights.size, Array.empty, Array.empty), 0.0) } } @@ -159,16 +145,12 @@ class HingeGradient extends Gradient { label: Double, weights: Vector, cumGradient: Vector): Double = { - val brzData = data.toBreeze - val brzWeights = weights.toBreeze - val dotProduct = brzWeights.dot(brzData) - + val dotProduct = dot(data, weights) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 - if (1.0 > labelScaled * dotProduct) { - brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + axpy(-labelScaled, data, cumGradient) 1.0 - labelScaled * dotProduct } else { 0.0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index 26a2b62e76ed0..d16d0daf08565 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -19,14 +19,15 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, axpy} +import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS.axpy import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: @@ -68,8 +69,17 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) /** * Set the maximal number of iterations for L-BFGS. Default 100. + * @deprecated use [[LBFGS#setNumIterations]] instead */ + @deprecated("use setNumIterations instead", "1.1.0") def setMaxNumIterations(iters: Int): this.type = { + this.setNumIterations(iters) + } + + /** + * Set the maximal number of iterations for L-BFGS. Default 100. + */ + def setNumIterations(iters: Int): this.type = { this.maxNumIterations = iters this } @@ -192,31 +202,29 @@ object LBFGS extends Logging { regParam: Double, numExamples: Long) extends DiffFunction[BDV[Double]] { - private var i = 0 - - override def calculate(weights: BDV[Double]) = { + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { // Have a local copy to avoid the serialization of CostFun object which is not serializable. 
+ val w = Vectors.fromBreeze(weights) + val n = w.size + val bcW = data.context.broadcast(w) val localGradient = gradient - val n = weights.length - val bcWeights = data.context.broadcast(weights) - val (gradientSum, lossSum) = data.treeAggregate((BDV.zeros[Double](n), 0.0))( + val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))( seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => val l = localGradient.compute( - features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad)) + features, label, bcW.value, grad) (grad, loss + l) }, combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + axpy(1.0, grad2, grad1) + (grad1, loss1 + loss2) }) /** * regVal is sum of weight squares if it's L2 updater; * for other updater, the same logic is followed. */ - val regVal = updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 + val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2 val loss = lossSum / numExamples + regVal /** @@ -236,17 +244,13 @@ object LBFGS extends Logging { */ // The following gradientTotal is actually the regularization part of gradient. // Will add the gradientSum computed from the data with weights in the next step. - val gradientTotal = weights - updater.compute( - Vectors.fromBreeze(weights), - Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze + val gradientTotal = w.copy + axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal) // gradientTotal = gradientSum / numExamples + gradientTotal axpy(1.0 / numExamples, gradientSum, gradientTotal) - i += 1 - - (loss, gradientTotal) + (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]]) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9cab49f6ed1f0..28179fbc450c0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -20,14 +20,14 @@ package org.apache.spark.mllib.random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} /** - * :: Experimental :: + * :: DeveloperApi :: * Trait for random data generators that generate i.i.d. data. */ -@Experimental +@DeveloperApi trait RandomDataGenerator[T] extends Pseudorandom with Serializable { /** @@ -43,10 +43,10 @@ trait RandomDataGenerator[T] extends Pseudorandom with Serializable { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from U[0.0, 1.0] */ -@Experimental +@DeveloperApi class UniformGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. @@ -62,10 +62,10 @@ class UniformGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the standard normal distribution. */ -@Experimental +@DeveloperApi class StandardNormalGenerator extends RandomDataGenerator[Double] { // XORShiftRandom for better performance. Thread safety isn't necessary here. 
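One more note on the LBFGS hunk above: setMaxNumIterations stays for source compatibility but is deprecated in favor of setNumIterations. A usage sketch; the gradient and updater pair is only an example:

    import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

    val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater)
      .setNumIterations(50)       // preferred name going forward
    // .setMaxNumIterations(50)   // still compiles, but warns: deprecated since 1.1.0
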
@@ -81,12 +81,12 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { } /** - * :: Experimental :: + * :: DeveloperApi :: * Generates i.i.d. samples from the Poisson distribution with the given mean. * * @param mean mean for the Poisson distribution. */ -@Experimental +@DeveloperApi class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { private var rng = new Poisson(mean, new DRand) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala deleted file mode 100644 index b0a0593223910..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDGenerators.scala +++ /dev/null @@ -1,493 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.random - -import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomVectorRDD, RandomRDD} -import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils - -import scala.reflect.ClassTag - -/** - * :: Experimental :: - * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. - */ -@Experimental -object RandomRDDGenerators { - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. - */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { - val uniform = new UniformGenerator() - randomRDD(sc, uniform, size, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n, p).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. 
- */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - uniformRDD(sc, size, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the uniform distribution on [0.0, 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from U[0.0, 1.0] to U[a, b], use - * `RandomRDDGenerators.uniformRDD(sc, n).map(v => a + (b - a) * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ U[0.0, 1.0]. - */ - @Experimental - def uniformRDD(sc: SparkContext, size: Long): RDD[Double] = { - uniformRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). - */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int, seed: Long): RDD[Double] = { - val normal = new StandardNormalGenerator() - randomRDD(sc, normal, size, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n, p).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). - */ - @Experimental - def normalRDD(sc: SparkContext, size: Long, numPartitions: Int): RDD[Double] = { - normalRDD(sc, size, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * To transform the distribution in the generated RDD from standard normal to some other normal - * N(mean, sigma), use `RandomRDDGenerators.normalRDD(sc, n).map(v => mean + sigma * v)`. - * - * @param sc SparkContext used to create the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). - */ - @Experimental - def normalRDD(sc: SparkContext, size: Long): RDD[Double] = { - normalRDD(sc, size, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. 
samples ~ Pois(mean). - */ - @Experimental - def poissonRDD(sc: SparkContext, - mean: Double, - size: Long, - numPartitions: Int, - seed: Long): RDD[Double] = { - val poisson = new PoissonGenerator(mean) - randomRDD(sc, poisson, size, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). - */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long, numPartitions: Int): RDD[Double] = { - poissonRDD(sc, mean, size, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). - */ - @Experimental - def poissonRDD(sc: SparkContext, mean: Double, size: Long): RDD[Double] = { - poissonRDD(sc, mean, size, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. - */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], - size: Long, - numPartitions: Int, - seed: Long): RDD[T] = { - new RandomRDD[T](sc, size, numPartitions, generator, seed) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. - */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], - size: Long, - numPartitions: Int): RDD[T] = { - randomRDD[T](sc, generator, size, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD comprised of i.i.d. samples produced by the input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param size Size of the RDD. - * @return RDD[Double] comprised of i.i.d. samples produced by generator. - */ - @Experimental - def randomRDD[T: ClassTag](sc: SparkContext, - generator: RandomDataGenerator[T], - size: Long): RDD[T] = { - randomRDD[T](sc, generator, size, sc.defaultParallelism, Utils.random.nextLong) - } - - // TODO Generate RDD[Vector] from multivariate distributions. 
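The removal of RandomRDDGenerators above deletes only the old entry point; the underlying i.i.d. generators stay, and the PythonMLLibAPI hunks earlier in this patch still reach them through the RG alias. A sketch of the U[0.0, 1.0] to U[a, b] rescaling trick documented in the removed scaladoc, assuming the replacement factory object is named RandomRDDs with the same signatures (the rename itself is not visible in this excerpt) and that sc is an existing SparkContext:

    import org.apache.spark.mllib.random.RandomRDDs

    val a = -1.0
    val b = 1.0
    // 1000 i.i.d. samples from U[0.0, 1.0] in 4 partitions, rescaled to U[a, b].
    val u = RandomRDDs.uniformRDD(sc, 1000L, 4)
    val rescaled = u.map(v => a + (b - a) * v)
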
- - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d samples ~ U[0.0, 1.0]. - */ - @Experimental - def uniformVectorRDD(sc: SparkContext, - numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - val uniform = new UniformGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. - */ - @Experimental - def uniformVectorRDD(sc: SparkContext, - numRows: Long, - numCols: Int, - numPartitions: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * uniform distribution on [0.0 1.0]. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ U[0.0, 1.0]. - */ - @Experimental - def uniformVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - uniformVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). - */ - @Experimental - def normalVectorRDD(sc: SparkContext, - numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - val uniform = new StandardNormalGenerator() - randomVectorRDD(sc, uniform, numRows, numCols, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). 
- */ - @Experimental - def normalVectorRDD(sc: SparkContext, - numRows: Long, - numCols: Int, - numPartitions: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * standard normal distribution. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ N(0.0, 1.0). - */ - @Experimental - def normalVectorRDD(sc: SparkContext, numRows: Long, numCols: Int): RDD[Vector] = { - normalVectorRDD(sc, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). - */ - @Experimental - def poissonVectorRDD(sc: SparkContext, - mean: Double, - numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - val poisson = new PoissonGenerator(mean) - randomVectorRDD(sc, poisson, numRows, numCols, numPartitions, seed) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). - */ - @Experimental - def poissonVectorRDD(sc: SparkContext, - mean: Double, - numRows: Long, - numCols: Int, - numPartitions: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the - * Poisson distribution with the input mean. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param mean Mean, or lambda, for the Poisson distribution. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). - */ - @Experimental - def poissonVectorRDD(sc: SparkContext, - mean: Double, - numRows: Long, - numCols: Int): RDD[Vector] = { - poissonVectorRDD(sc, mean, numRows, numCols, sc.defaultParallelism, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. 
- * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @param seed Seed for the RNG that generates the seed for the generator in each partition. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. - */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], - numRows: Long, - numCols: Int, - numPartitions: Int, - seed: Long): RDD[Vector] = { - new RandomVectorRDD(sc, numRows, numCols, numPartitions, generator, seed) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @param numPartitions Number of partitions in the RDD. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. - */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], - numRows: Long, - numCols: Int, - numPartitions: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, numPartitions, Utils.random.nextLong) - } - - /** - * :: Experimental :: - * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the - * input DistributionGenerator. - * sc.defaultParallelism used for the number of partitions in the RDD. - * - * @param sc SparkContext used to create the RDD. - * @param generator DistributionGenerator used to populate the RDD. - * @param numRows Number of Vectors in the RDD. - * @param numCols Number of elements in each Vector. - * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. - */ - @Experimental - def randomVectorRDD(sc: SparkContext, - generator: RandomDataGenerator[Double], - numRows: Long, - numCols: Int): RDD[Vector] = { - randomVectorRDD(sc, generator, numRows, numCols, - sc.defaultParallelism, Utils.random.nextLong) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala new file mode 100644 index 0000000000000..c5f4b084321f7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomRDDs.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.random + +import scala.reflect.ClassTag + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.rdd.{RandomRDD, RandomVectorRDD} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * Generator methods for creating RDDs comprised of i.i.d. samples from some distribution. + */ +@Experimental +object RandomRDDs { + + /** + * Generates an RDD comprised of i.i.d. samples from the uniform distribution `U(0.0, 1.0)`. + * + * To transform the distribution in the generated RDD from `U(0.0, 1.0)` to `U(a, b)`, use + * `RandomRDDs.uniformRDD(sc, n, p, seed).map(v => a + (b - a) * v)`. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + */ + def uniformRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val uniform = new UniformGenerator() + randomRDD(sc, uniform, size, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#uniformRDD]]. + */ + def uniformJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions, seed)) + } + + /** + * [[RandomRDDs#uniformJavaRDD]] with the default seed. + */ + def uniformJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size, numPartitions)) + } + + /** + * [[RandomRDDs#uniformJavaRDD]] with the default number of partitions and the default seed. + */ + def uniformJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(uniformRDD(jsc.sc, size)) + } + + /** + * Generates an RDD comprised of i.i.d. samples from the standard normal distribution. + * + * To transform the distribution in the generated RDD from standard normal to some other normal + * `N(mean, sigma^2^)`, use `RandomRDDs.normalRDD(sc, n, p, seed).map(v => mean + sigma * v)`. + * + * @param sc SparkContext used to create the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples ~ N(0.0, 1.0). + */ + def normalRDD( + sc: SparkContext, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val normal = new StandardNormalGenerator() + randomRDD(sc, normal, size, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#normalRDD]]. + */ + def normalJavaRDD( + jsc: JavaSparkContext, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions, seed)) + } + + /** + * [[RandomRDDs#normalJavaRDD]] with the default seed. 
+ */ + def normalJavaRDD(jsc: JavaSparkContext, size: Long, numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size, numPartitions)) + } + + /** + * [[RandomRDDs#normalJavaRDD]] with the default number of partitions and the default seed. + */ + def normalJavaRDD(jsc: JavaSparkContext, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(normalRDD(jsc.sc, size)) + } + + /** + * Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples ~ Pois(mean). + */ + def poissonRDD( + sc: SparkContext, + mean: Double, + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Double] = { + val poisson = new PoissonGenerator(mean) + randomRDD(sc, poisson, size, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#poissonRDD]]. + */ + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: Int, + seed: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions, seed)) + } + + /** + * [[RandomRDDs#poissonJavaRDD]] with the default seed. + */ + def poissonJavaRDD( + jsc: JavaSparkContext, + mean: Double, + size: Long, + numPartitions: Int): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size, numPartitions)) + } + + /** + * [[RandomRDDs#poissonJavaRDD]] with the default number of partitions and the default seed. + */ + def poissonJavaRDD(jsc: JavaSparkContext, mean: Double, size: Long): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(poissonRDD(jsc.sc, mean, size)) + } + + /** + * :: DeveloperApi :: + * Generates an RDD comprised of i.i.d. samples produced by the input RandomDataGenerator. + * + * @param sc SparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param size Size of the RDD. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Double] comprised of i.i.d. samples produced by generator. + */ + @DeveloperApi + def randomRDD[T: ClassTag]( + sc: SparkContext, + generator: RandomDataGenerator[T], + size: Long, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[T] = { + new RandomRDD[T](sc, size, numPartitionsOrDefault(sc, numPartitions), generator, seed) + } + + // TODO Generate RDD[Vector] from multivariate distributions. + + /** + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * uniform distribution on `U(0.0, 1.0)`. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD. + * @param seed Seed for the RNG that generates the seed for the generator in each partition. + * @return RDD[Vector] with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. 
+ */ + def uniformVectorRDD( + sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + val uniform = new UniformGenerator() + randomVectorRDD(sc, uniform, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#uniformVectorRDD]]. + */ + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#uniformJavaVectorRDD]] with the default seed. + */ + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#uniformJavaVectorRDD]] with the default number of partitions and the default seed. + */ + def uniformJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + uniformVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() + } + + /** + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * standard normal distribution. + * + * @param sc SparkContext used to create the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + */ + def normalVectorRDD( + sc: SparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + val normal = new StandardNormalGenerator() + randomVectorRDD(sc, normal, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#normalVectorRDD]]. + */ + def normalJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#normalJavaVectorRDD]] with the default seed. + */ + def normalJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#normalJavaVectorRDD]] with the default number of partitions and the default seed. + */ + def normalJavaVectorRDD( + jsc: JavaSparkContext, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + normalVectorRDD(jsc.sc, numRows, numCols).toJavaRDD() + } + + /** + * Generates an RDD[Vector] with vectors containing i.i.d. samples drawn from the + * Poisson distribution with the input mean. + * + * @param sc SparkContext used to create the RDD. + * @param mean Mean, or lambda, for the Poisson distribution. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`) + * @param seed Random seed (default: a random long integer). + * @return RDD[Vector] with vectors containing i.i.d. samples ~ Pois(mean). 
+ */ + def poissonVectorRDD( + sc: SparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + val poisson = new PoissonGenerator(mean) + randomVectorRDD(sc, poisson, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), seed) + } + + /** + * Java-friendly version of [[RandomRDDs#poissonVectorRDD]]. + */ + def poissonJavaVectorRDD( + jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: Int, + seed: Long): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions, seed).toJavaRDD() + } + + /** + * [[RandomRDDs#poissonJavaVectorRDD]] with the default seed. + */ + def poissonJavaVectorRDD( + jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int, + numPartitions: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols, numPartitions).toJavaRDD() + } + + /** + * [[RandomRDDs#poissonJavaVectorRDD]] with the default number of partitions and the default seed. + */ + def poissonJavaVectorRDD( + jsc: JavaSparkContext, + mean: Double, + numRows: Long, + numCols: Int): JavaRDD[Vector] = { + poissonVectorRDD(jsc.sc, mean, numRows, numCols).toJavaRDD() + } + + /** + * :: DeveloperApi :: + * Generates an RDD[Vector] with vectors containing i.i.d. samples produced by the + * input RandomDataGenerator. + * + * @param sc SparkContext used to create the RDD. + * @param generator RandomDataGenerator used to populate the RDD. + * @param numRows Number of Vectors in the RDD. + * @param numCols Number of elements in each Vector. + * @param numPartitions Number of partitions in the RDD (default: `sc.defaultParallelism`). + * @param seed Random seed (default: a random long integer). + * @return RDD[Vector] with vectors containing i.i.d. samples produced by generator. + */ + @DeveloperApi + def randomVectorRDD(sc: SparkContext, + generator: RandomDataGenerator[Double], + numRows: Long, + numCols: Int, + numPartitions: Int = 0, + seed: Long = Utils.random.nextLong()): RDD[Vector] = { + new RandomVectorRDD( + sc, numRows, numCols, numPartitionsOrDefault(sc, numPartitions), generator, seed) + } + + /** + * Returns `numPartitions` if it is positive, or `sc.defaultParallelism` otherwise. + */ + private def numPartitionsOrDefault(sc: SparkContext, numPartitions: Int): Int = { + if (numPartitions > 0) numPartitions else sc.defaultMinPartitions + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index c8db3910c6eab..910eff9540a47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -105,16 +105,16 @@ private[mllib] object RandomRDD { def getPointIterator[T: ClassTag](partition: RandomRDDPartition[T]): Iterator[T] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(generator.nextValue()).toIterator + Iterator.fill(partition.size)(generator.nextValue()) } // The RNG has to be reset every time the iterator is requested to guarantee same data // every time the content of the RDD is examined. 
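For illustration, a minimal sketch of how the RandomRDDs methods introduced above might be called from Scala; the app name, local master, sizes, partition counts, and seed are placeholder values:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._
    import org.apache.spark.mllib.random.RandomRDDs

    object RandomRDDsExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("RandomRDDsExample").setMaster("local[2]"))

        // 1 million i.i.d. samples from U(0.0, 1.0) in 10 partitions, then shifted to U(5.0, 10.0).
        val u = RandomRDDs.uniformRDD(sc, 1000000L, 10).map(v => 5.0 + 5.0 * v)

        // Standard normal samples with an explicit seed, transformed to N(2.0, 3.0^2).
        val n = RandomRDDs.normalRDD(sc, 1000000L, 10, seed = 11L).map(v => 2.0 + 3.0 * v)

        // 10,000 vectors of 50 Poisson(4.0) samples each, using the default partitions and seed.
        val pv = RandomRDDs.poissonVectorRDD(sc, mean = 4.0, numRows = 10000L, numCols = 50)

        println(s"uniform mean ~ ${u.mean()}, normal mean ~ ${n.mean()}, vector rows = ${pv.count()}")
        sc.stop()
      }
    }

The uniformJavaRDD / normalJavaRDD / poissonJavaVectorRDD wrappers above expose the same functionality to Java callers via JavaDoubleRDD and JavaRDD[Vector].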
- def getVectorIterator(partition: RandomRDDPartition[Double], - vectorSize: Int): Iterator[Vector] = { + def getVectorIterator( + partition: RandomRDDPartition[Double], + vectorSize: Int): Iterator[Vector] = { val generator = partition.generator.copy() generator.setSeed(partition.seed) - Array.fill(partition.size)(new DenseVector( - (0 until vectorSize).map { _ => generator.nextValue() }.toArray)).toIterator + Iterator.fill(partition.size)(new DenseVector(Array.fill(vectorSize)(generator.nextValue()))) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 8ebc7e27ed4dd..84d192db53e26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -111,11 +111,17 @@ class ALS private ( */ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0) + /** If true, do alternating nonnegative least squares. */ + private var nonnegative = false + + /** storage level for user/product in/out links */ + private var intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + /** * Set the number of blocks for both user blocks and product blocks to parallelize the computation * into; pass -1 for an auto-configured number of blocks. Default: -1. */ - def setBlocks(numBlocks: Int): ALS = { + def setBlocks(numBlocks: Int): this.type = { this.numUserBlocks = numBlocks this.numProductBlocks = numBlocks this @@ -124,7 +130,7 @@ class ALS private ( /** * Set the number of user blocks to parallelize the computation. */ - def setUserBlocks(numUserBlocks: Int): ALS = { + def setUserBlocks(numUserBlocks: Int): this.type = { this.numUserBlocks = numUserBlocks this } @@ -132,31 +138,31 @@ class ALS private ( /** * Set the number of product blocks to parallelize the computation. */ - def setProductBlocks(numProductBlocks: Int): ALS = { + def setProductBlocks(numProductBlocks: Int): this.type = { this.numProductBlocks = numProductBlocks this } /** Set the rank of the feature matrices computed (number of features). Default: 10. */ - def setRank(rank: Int): ALS = { + def setRank(rank: Int): this.type = { this.rank = rank this } /** Set the number of iterations to run. Default: 10. */ - def setIterations(iterations: Int): ALS = { + def setIterations(iterations: Int): this.type = { this.iterations = iterations this } /** Set the regularization parameter, lambda. Default: 0.01. */ - def setLambda(lambda: Double): ALS = { + def setLambda(lambda: Double): this.type = { this.lambda = lambda this } /** Sets whether to use implicit preference. Default: false. */ - def setImplicitPrefs(implicitPrefs: Boolean): ALS = { + def setImplicitPrefs(implicitPrefs: Boolean): this.type = { this.implicitPrefs = implicitPrefs this } @@ -166,29 +172,38 @@ class ALS private ( * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. */ @Experimental - def setAlpha(alpha: Double): ALS = { + def setAlpha(alpha: Double): this.type = { this.alpha = alpha this } /** Sets a random seed to have deterministic results. */ - def setSeed(seed: Long): ALS = { + def setSeed(seed: Long): this.type = { this.seed = seed this } - /** If true, do alternating nonnegative least squares. */ - private var nonnegative = false - /** * Set whether the least-squares problems solved at each iteration should have * nonnegativity constraints. 
*/ - def setNonnegative(b: Boolean): ALS = { + def setNonnegative(b: Boolean): this.type = { this.nonnegative = b this } + /** + * :: DeveloperApi :: + * Sets storage level for intermediate RDDs (user/product in/out links). The default value is + * `MEMORY_AND_DISK`. Users can change it to a serialized storage, e.g., `MEMORY_AND_DISK_SER` and + * set `spark.rdd.compress` to `true` to reduce the space requirement, at the cost of speed. + */ + @DeveloperApi + def setIntermediateRDDStorageLevel(storageLevel: StorageLevel): this.type = { + this.intermediateRDDStorageLevel = storageLevel + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -441,8 +456,8 @@ class ALS private ( }, preservesPartitioning = true) val inLinks = links.mapValues(_._1) val outLinks = links.mapValues(_._2) - inLinks.persist(StorageLevel.MEMORY_AND_DISK) - outLinks.persist(StorageLevel.MEMORY_AND_DISK) + inLinks.persist(intermediateRDDStorageLevel) + outLinks.persist(intermediateRDDStorageLevel) (inLinks, outLinks) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index a1a76fcbe9f9c..478c6485052b6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.PythonMLLibAPI +import org.apache.spark.mllib.api.python.SerDe /** * Model representing the result of matrix factorization. @@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] ( */ @DeveloperApi def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { - val pythonAPI = new PythonMLLibAPI() - val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes)) - predict(usersProducts).map(rate => pythonAPI.serializeRating(rate)) + val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes)) + predict(usersProducts).map(rate => SerDe.serializeRating(rate)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 54854252d7477..20c1fdd2269ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.feature.StandardScaler import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ @@ -94,6 +95,22 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected var validateData: Boolean = true + /** + * Whether to perform feature scaling before model training to reduce the condition numbers + * which can significantly help the optimizer converging faster. The scaling correction will be + * translated back to resulting model weights, so it's transparent to users. 
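As a hedged usage sketch for the ALS changes above (setters now return this.type, and the storage level of the intermediate link RDDs is configurable), something like the following could be written; the ratings RDD and hyperparameter values are placeholders:

    import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    def trainRecommender(ratings: RDD[Rating]): MatrixFactorizationModel = {
      new ALS()
        .setRank(20)
        .setIterations(10)
        .setLambda(0.01)
        .setNonnegative(true)
        // Serialized storage for the cached in/out link RDDs trades speed for memory,
        // optionally combined with spark.rdd.compress=true.
        .setIntermediateRDDStorageLevel(StorageLevel.MEMORY_AND_DISK_SER)
        .run(ratings)
    }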
+ * Note: This technique is used in both libsvm and glmnet packages. Default false. + */ + private var useFeatureScaling = false + + /** + * Set if the algorithm should use feature scaling to improve the convergence during optimization. + */ + private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = { + this.useFeatureScaling = useFeatureScaling + this + } + /** * Create a model given the weights and intercept */ @@ -137,11 +154,45 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] throw new SparkException("Input validation failed.") } + /** + * Scaling columns to unit variance as a heuristic to reduce the condition number: + * + * During the optimization process, the convergence (rate) depends on the condition number of + * the training dataset. Scaling the variables often reduces this condition number + * heuristically, thus improving the convergence rate. Without reducing the condition number, + * some training datasets mixing the columns with different scales may not be able to converge. + * + * GLMNET and LIBSVM packages perform the scaling to reduce the condition number, and return + * the weights in the original scale. + * See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf + * + * Here, if useFeatureScaling is enabled, we will standardize the training features by dividing + * the variance of each column (without subtracting the mean), and train the model in the + * scaled space. Then we transform the coefficients from the scaled space to the original scale + * as GLMNET and LIBSVM do. + * + * Currently, it's only enabled in LogisticRegressionWithLBFGS + */ + val scaler = if (useFeatureScaling) { + (new StandardScaler).fit(input.map(x => x.features)) + } else { + null + } + // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + if(useFeatureScaling) { + input.map(labeledPoint => + (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features)))) + } else { + input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features))) + } } else { - input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + if (useFeatureScaling) { + input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features))) + } else { + input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) + } } val initialWeightsWithIntercept = if (addIntercept) { @@ -153,13 +204,25 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0 - val weights = + var weights = if (addIntercept) { Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)) } else { weightsWithIntercept } + /** + * The weights and intercept are trained in the scaled space; we're converting them back to + * the original scale. + * + * Math shows that if we only perform standardization without subtracting means, the intercept + * will not be changed. w_i = w_i' / v_i where w_i' is the coefficient in the scaled space, w_i + * is the coefficient in the original space, and v_i is the variance of the column i. 
+ */ + if (useFeatureScaling) { + weights = scaler.transform(weights) + } + createModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 62a03af4a9964..17c753c56681f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -36,7 +36,7 @@ case class LabeledPoint(label: Double, features: Vector) { /** * Parser for [[org.apache.spark.mllib.regression.LabeledPoint]]. */ -private[mllib] object LabeledPointParser { +object LabeledPoint { /** * Parses a string resulted from `LabeledPoint#toString` into * an [[org.apache.spark.mllib.regression.LabeledPoint]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index b8b0b42611775..8db0442a7a569 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -17,8 +17,12 @@ package org.apache.spark.mllib.regression -import org.apache.spark.annotation.DeveloperApi +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.DStream /** @@ -92,15 +96,30 @@ abstract class StreamingLinearAlgorithm[ /** * Use the model to make predictions on batches of data from a DStream * - * @param data DStream containing labeled data + * @param data DStream containing feature vectors * @return DStream containing predictions */ - def predictOn(data: DStream[LabeledPoint]): DStream[Double] = { + def predictOn(data: DStream[Vector]): DStream[Double] = { if (Option(model.weights) == None) { - logError("Initial weights must be set before starting prediction") - throw new IllegalArgumentException + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) } - data.map(x => model.predict(x.features)) + data.map(model.predict) } + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. 
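A hedged sketch of how the new predictOn(DStream[Vector]) and predictOnValues methods, together with the now-public LabeledPoint.parse, might be wired into a streaming regression job; the directories, batch interval, and feature dimension are placeholders:

    import org.apache.spark.SparkConf
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object StreamingRegressionExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("StreamingRegressionExample").setMaster("local[2]")
        val ssc = new StreamingContext(conf, Seconds(1))

        val model = new StreamingLinearRegressionWithSGD()
          .setInitialWeights(Vectors.zeros(3))

        // Training stream of LabeledPoints in the "(label,[f0,f1,f2])" text format.
        val trainingData = ssc.textFileStream("/tmp/training").map(LabeledPoint.parse)
        model.trainOn(trainingData)

        // Keep the true label as the key and predict on the feature vectors.
        val testData = ssc.textFileStream("/tmp/test").map(LabeledPoint.parse)
        model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()

        ssc.start()
        ssc.awaitTermination()
      }
    }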
+ * @param data DStream containing feature vectors + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Double)] = { + if (Option(model.weights) == None) { + val msg = "Initial weights must be set before starting prediction" + logError(msg) + throw new IllegalArgumentException(msg) + } + data.mapValues(model.predict) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index 8851097050318..1d11fde24712c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.Vector /** * Train or predict a linear regression model on streaming data. Training uses diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 5105b5c37aaaa..7d845c44365dd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -55,8 +55,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ def add(sample: Vector): this.type = { if (n == 0) { - require(sample.toBreeze.length > 0, s"Vector should have dimension larger than zero.") - n = sample.toBreeze.length + require(sample.size > 0, s"Vector should have dimension larger than zero.") + n = sample.size currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) @@ -65,8 +65,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMin = BDV.fill(n)(Double.MaxValue) } - require(n == sample.toBreeze.length, s"Dimensions mismatch when adding new sample." + - s" Expecting $n but got ${sample.toBreeze.length}.") + require(n == sample.size, s"Dimensions mismatch when adding new sample." + + s" Expecting $n but got ${sample.size}.") sample.toBreeze.activeIterator.foreach { case (_, 0.0) => // Skip explicit zero elements. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f416a9fbb323d..3cf4e807b4cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -18,8 +18,11 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations +import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult} import org.apache.spark.rdd.RDD /** @@ -28,6 +31,18 @@ import org.apache.spark.rdd.RDD @Experimental object Statistics { + /** + * :: Experimental :: + * Computes column-wise summary statistics for the input RDD[Vector]. 
+ * + * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. + * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + */ + @Experimental + def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { + new RowMatrix(X).computeColumnSummaryStatistics() + } + /** * :: Experimental :: * Compute the Pearson correlation matrix for the input RDD of Vectors. @@ -89,4 +104,66 @@ object Statistics { */ @Experimental def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the + * expected distribution. + * + * Note: the two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @param expected Vector containing the expected categorical counts/relative frequencies. + * `expected` is rescaled if the `expected` sum differs from the `observed` sum. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + ChiSqTest.chiSquared(observed, expected) + } + + /** + * :: Experimental :: + * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform + * distribution, with each category having an expected frequency of `1 / observed.size`. + * + * Note: `observed` cannot contain negative values. + * + * @param observed Vector containing the observed categorical counts/relative frequencies. + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test on the input contingency matrix, which cannot contain + * negative entries or columns or rows that sum up to 0. + * + * @param observed The contingency matrix (containing either counts or relative frequencies). + * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, + * the method used, and the null hypothesis. + */ + @Experimental + def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) + + /** + * :: Experimental :: + * Conduct Pearson's independence test for every feature against the label across the input RDD. + * For each feature, the (feature, label) pairs are converted into a contingency matrix for which + * the chi-squared statistic is computed. All label and feature values must be categorical. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * @return an array containing the ChiSquaredTestResult for every feature against the label. + * The order of the elements in the returned array reflects the order of input features. 
+ */ + @Experimental + def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { + ChiSqTest.chiSquaredFeatures(data) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala index 9bd0c2cd05de4..4a6c677f06d28 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.stat.correlation import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, HashPartitioner} +import org.apache.spark.Logging import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.{DenseVector, Matrix, Vector} -import org.apache.spark.rdd.{CoGroupedRDD, RDD} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors} +import org.apache.spark.rdd.RDD /** * Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix @@ -43,87 +43,51 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging { /** * Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the * correlation between column i and j. - * - * Input RDD[Vector] should be cached or checkpointed if possible since it would be split into - * numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector]. */ override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = { - val indexed = X.zipWithUniqueId() - - val numCols = X.first.size - if (numCols > 50) { - logWarning("Computing the Spearman correlation matrix can be slow for large RDDs with more" - + " than 50 columns.") - } - val ranks = new Array[RDD[(Long, Double)]](numCols) - - // Note: we use a for loop here instead of a while loop with a single index variable - // to avoid race condition caused by closure serialization - for (k <- 0 until numCols) { - val column = indexed.map { case (vector, index) => (vector(k), index) } - ranks(k) = getRanks(column) + // ((columnIndex, value), rowUid) + val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) => + vec.toArray.view.zipWithIndex.map { case (v, j) => + ((j, v), uid) + } } - - val ranksMat: RDD[Vector] = makeRankMatrix(ranks, X) - PearsonCorrelation.computeCorrelationMatrix(ranksMat) - } - - /** - * Compute the ranks for elements in the input RDD, using the average method for ties. - * - * With the average method, elements with the same value receive the same rank that's computed - * by taking the average of their positions in the sorted list. - * e.g. ranks([2, 1, 0, 2]) = [2.5, 1.0, 0.0, 2.5] - * Note that positions here are 0-indexed, instead of the 1-indexed as in the definition for - * ranks in the standard definition for Spearman's correlation. This does not affect the final - * results and is slightly more performant. - * - * @param indexed RDD[(Double, Long)] containing pairs of the format (originalValue, uniqueId) - * @return RDD[(Long, Double)] containing pairs of the format (uniqueId, rank), where uniqueId is - * copied from the input RDD. 
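A hedged end-to-end sketch of the Statistics entry points added above (colStats, corr including the rank-based Spearman implementation, and the chiSqTest variants); the tiny in-memory datasets are placeholders:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.{Matrices, Vectors}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.stat.Statistics

    object StatisticsExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("StatisticsExample").setMaster("local[2]"))

        val observations = sc.parallelize(Seq(
          Vectors.dense(1.0, 10.0, 100.0),
          Vectors.dense(2.0, 20.0, 200.0),
          Vectors.dense(3.0, 30.0, 300.0)))

        // Column-wise summary statistics.
        val summary = Statistics.colStats(observations)
        println(s"mean = ${summary.mean}, variance = ${summary.variance}")

        // Correlation matrices; "spearman" exercises the rank-based implementation above.
        println(Statistics.corr(observations))
        println(Statistics.corr(observations, "spearman"))

        // Goodness-of-fit test against the uniform distribution.
        println(Statistics.chiSqTest(Vectors.dense(2.0, 8.0, 10.0, 5.0)))

        // Independence test on a 3x2 contingency matrix (column-major values).
        println(Statistics.chiSqTest(Matrices.dense(3, 2, Array(1.0, 3.0, 5.0, 2.0, 4.0, 6.0))))

        // Per-feature independence tests against the label; values are treated as categorical.
        val labeled = sc.parallelize(Seq(
          LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
          LabeledPoint(1.0, Vectors.dense(1.0, 2.0)),
          LabeledPoint(1.0, Vectors.dense(0.0, 2.0))))
        Statistics.chiSqTest(labeled).foreach(println)

        sc.stop()
      }
    }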
- */ - private def getRanks(indexed: RDD[(Double, Long)]): RDD[(Long, Double)] = { - // Get elements' positions in the sorted list for computing average rank for duplicate values - val sorted = indexed.sortByKey().zipWithIndex() - - val ranks: RDD[(Long, Double)] = sorted.mapPartitions { iter => - // add an extra element to signify the end of the list so that flatMap can flush the last - // batch of duplicates - val end = -1L - val padded = iter ++ Iterator[((Double, Long), Long)](((Double.NaN, end), end)) - val firstEntry = padded.next() - var lastVal = firstEntry._1._1 - var firstRank = firstEntry._2.toDouble - val idBuffer = ArrayBuffer(firstEntry._1._2) - padded.flatMap { case ((v, id), rank) => - if (v == lastVal && id != end) { - idBuffer += id - Iterator.empty - } else { - val entries = if (idBuffer.size == 1) { - Iterator((idBuffer(0), firstRank)) - } else { - val averageRank = firstRank + (idBuffer.size - 1.0) / 2.0 - idBuffer.map(id => (id, averageRank)) - } - lastVal = v - firstRank = rank - idBuffer.clear() - idBuffer += id - entries + // global sort by (columnIndex, value) + val sorted = colBased.sortByKey() + // assign global ranks (using average ranks for tied values) + val globalRanks = sorted.zipWithIndex().mapPartitions { iter => + var preCol = -1 + var preVal = Double.NaN + var startRank = -1.0 + var cachedUids = ArrayBuffer.empty[Long] + val flush: () => Iterable[(Long, (Int, Double))] = () => { + val averageRank = startRank + (cachedUids.size - 1) / 2.0 + val output = cachedUids.map { uid => + (uid, (preCol, averageRank)) } + cachedUids.clear() + output } + iter.flatMap { case (((j, v), uid), rank) => + // If we see a new value or cachedUids is too big, we flush ids with their average rank. + if (j != preCol || v != preVal || cachedUids.size >= 10000000) { + val output = flush() + preCol = j + preVal = v + startRank = rank + cachedUids += uid + output + } else { + cachedUids += uid + Iterator.empty + } + } ++ flush() } - ranks - } - - private def makeRankMatrix(ranks: Array[RDD[(Long, Double)]], input: RDD[Vector]): RDD[Vector] = { - val partitioner = new HashPartitioner(input.partitions.size) - val cogrouped = new CoGroupedRDD[Long](ranks, partitioner) - cogrouped.map { - case (_, values: Array[Iterable[_]]) => - val doubles = values.asInstanceOf[Array[Iterable[Double]]] - new DenseVector(doubles.flatten.toArray) + // Replace values in the input matrix by their ranks compared with values in the same column. + // Note that shifting all ranks in a column by a constant value doesn't affect result. + val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) => + // sort by column index and then convert values to a vector + Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray) } + PearsonCorrelation.computeCorrelationMatrix(groupedRanks) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala new file mode 100644 index 0000000000000..0089419c2c5d4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import breeze.linalg.{DenseMatrix => BDM} +import cern.jet.stat.Probability.chiSquareComplemented + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +import scala.collection.mutable + +/** + * Conduct the chi-squared test for the input RDDs using the specified method. + * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted + * on an input of type `Matrix` in which independence between columns is assessed. + * We also provide a method for computing the chi-squared statistic between each feature and the + * label for an input `RDD[LabeledPoint]`, returning an `Array[ChiSquaredTestResult]` of size = + * number of features in the input RDD. + * + * Supported methods for goodness of fit: `pearson` (default) + * Supported methods for independence: `pearson` (default) + * + * More information on Chi-squared test: http://en.wikipedia.org/wiki/Chi-squared_test + */ +private[stat] object ChiSqTest extends Logging { + + /** + * @param name String name for the method. + * @param chiSqFunc Function for computing the statistic given the observed and expected counts. + */ + case class Method(name: String, chiSqFunc: (Double, Double) => Double) + + // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test + val PEARSON = new Method("pearson", (observed: Double, expected: Double) => { + val dev = observed - expected + dev * dev / expected + }) + + // Null hypothesis for the two different types of chi-squared tests to be included in the result. + object NullHypothesis extends Enumeration { + type NullHypothesis = Value + val goodnessOfFit = Value("observed follows the same distribution as expected.") + val independence = Value("the occurrence of the outcomes is statistically independent.") + } + + // Method identification based on input methodName string + private def methodFromString(methodName: String): Method = { + methodName match { + case PEARSON.name => PEARSON + case _ => throw new IllegalArgumentException("Unrecognized method for Chi squared test.") + } + } + + /** + * Conduct Pearson's independence test for each feature against the label across the input RDD. + * The contingency table is constructed from the raw (feature, label) pairs and used to conduct + * the independence test. + * Returns an array containing the ChiSquaredTestResult for every feature against the label.
+ */ + def chiSquaredFeatures(data: RDD[LabeledPoint], + methodName: String = PEARSON.name): Array[ChiSqTestResult] = { + val maxCategories = 10000 + val numCols = data.first().features.size + val results = new Array[ChiSqTestResult](numCols) + var labels: Map[Double, Int] = null + // at most 1000 columns at a time + val batchSize = 1000 + var batch = 0 + while (batch * batchSize < numCols) { + // The following block of code can be cleaned up and made public as + // chiSquared(data: RDD[(V1, V2)]) + val startCol = batch * batchSize + val endCol = startCol + math.min(batchSize, numCols - startCol) + val pairCounts = data.mapPartitions { iter => + val distinctLabels = mutable.HashSet.empty[Double] + val allDistinctFeatures: Map[Int, mutable.HashSet[Double]] = + Map((startCol until endCol).map(col => (col, mutable.HashSet.empty[Double])): _*) + var i = 1 + iter.flatMap { case LabeledPoint(label, features) => + if (i % 1000 == 0) { + if (distinctLabels.size > maxCategories) { + throw new SparkException(s"Chi-square test expects factors (categorical values) but " + + s"found more than $maxCategories distinct label values.") + } + allDistinctFeatures.foreach { case (col, distinctFeatures) => + if (distinctFeatures.size > maxCategories) { + throw new SparkException(s"Chi-square test expects factors (categorical values) but " + + s"found more than $maxCategories distinct values in column $col.") + } + } + } + i += 1 + distinctLabels += label + features.toArray.view.zipWithIndex.slice(startCol, endCol).map { case (feature, col) => + allDistinctFeatures(col) += feature + (col, feature, label) + } + } + }.countByValue() + + if (labels == null) { + // Do this only once for the first column since labels are invariant across features. + labels = + pairCounts.keys.filter(_._1 == startCol).map(_._3).toArray.distinct.zipWithIndex.toMap + } + val numLabels = labels.size + pairCounts.keys.groupBy(_._1).map { case (col, keys) => + val features = keys.map(_._2).toArray.distinct.zipWithIndex.toMap + val numRows = features.size + val contingency = new BDM(numRows, numLabels, new Array[Double](numRows * numLabels)) + keys.foreach { case (_, feature, label) => + val i = features(feature) + val j = labels(label) + contingency(i, j) += pairCounts((col, feature, label)) + } + results(col) = chiSquaredMatrix(Matrices.fromBreeze(contingency), methodName) + } + batch += 1 + } + results + } + + /* + * Pearson's goodness of fit test on the input observed and expected counts/relative frequencies. + * Uniform distribution is assumed when `expected` is not passed in. + */ + def chiSquared(observed: Vector, + expected: Vector = Vectors.dense(Array[Double]()), + methodName: String = PEARSON.name): ChiSqTestResult = { + + // Validate input arguments + val method = methodFromString(methodName) + if (expected.size != 0 && observed.size != expected.size) { + throw new IllegalArgumentException("observed and expected must be of the same size.") + } + val size = observed.size + if (size > 1000) { + logWarning("Chi-squared approximation may not be accurate due to low expected frequencies " + + s"as a result of a large number of categories: $size.") + } + val obsArr = observed.toArray + val expArr = if (expected.size == 0) Array.tabulate(size)(_ => 1.0 / size) else expected.toArray + if (!obsArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the observed vector.") + } + if (expected.size != 0 && !
expArr.forall(_ >= 0.0)) { + throw new IllegalArgumentException("Negative entries disallowed in the expected vector.") + } + + // Determine the scaling factor for expected + val obsSum = obsArr.sum + val expSum = if (expected.size == 0.0) 1.0 else expArr.sum + val scale = if (math.abs(obsSum - expSum) < 1e-7) 1.0 else obsSum / expSum + + // compute chi-squared statistic + val statistic = obsArr.zip(expArr).foldLeft(0.0) { case (stat, (obs, exp)) => + if (exp == 0.0) { + if (obs == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input vectors due" + + " to 0.0 values in both observed and expected.") + } else { + return new ChiSqTestResult(0.0, size - 1, Double.PositiveInfinity, PEARSON.name, + NullHypothesis.goodnessOfFit.toString) + } + } + if (scale == 1.0) { + stat + method.chiSqFunc(obs, exp) + } else { + stat + method.chiSqFunc(obs, exp * scale) + } + } + val df = size - 1 + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) + } + + /* + * Pearson's independence test on the input contingency matrix. + * TODO: optimize for SparseMatrix when it becomes supported. + */ + def chiSquaredMatrix(counts: Matrix, methodName: String = PEARSON.name): ChiSqTestResult = { + val method = methodFromString(methodName) + val numRows = counts.numRows + val numCols = counts.numCols + + // get row and column sums + val colSums = new Array[Double](numCols) + val rowSums = new Array[Double](numRows) + val colMajorArr = counts.toArray + var i = 0 + while (i < colMajorArr.size) { + val elem = colMajorArr(i) + if (elem < 0.0) { + throw new IllegalArgumentException("Contingency table cannot contain negative entries.") + } + colSums(i / numRows) += elem + rowSums(i % numRows) += elem + i += 1 + } + val total = colSums.sum + + // second pass to collect statistic + var statistic = 0.0 + var j = 0 + while (j < colMajorArr.size) { + val col = j / numRows + val colSum = colSums(col) + if (colSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to " + + s"0 sum in column [$col].") + } + val row = j % numRows + val rowSum = rowSums(row) + if (rowSum == 0.0) { + throw new IllegalArgumentException("Chi-squared statistic undefined for input matrix due to " + + s"0 sum in row [$row].") + } + val expected = colSum * rowSum / total + statistic += method.chiSqFunc(colMajorArr(j), expected) + j += 1 + } + val df = (numCols - 1) * (numRows - 1) + val pValue = chiSquareComplemented(df, statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala new file mode 100644 index 0000000000000..4784f9e947908 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/TestResult.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat.test + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Trait for hypothesis test results. + * @tparam DF Return type of `degreesOfFreedom`. + */ +@Experimental +trait TestResult[DF] { + + /** + * The probability of obtaining a test statistic result at least as extreme as the one that was + * actually observed, assuming that the null hypothesis is true. + */ + def pValue: Double + + /** + * Returns the degree(s) of freedom of the hypothesis test. + * Return type should be Number(e.g. Int, Double) or tuples of Numbers for toString compatibility. + */ + def degreesOfFreedom: DF + + /** + * Test statistic. + */ + def statistic: Double + + /** + * Null hypothesis of the test. + */ + def nullHypothesis: String + + /** + * String explaining the hypothesis test result. + * Specific classes implementing this trait should override this method to output test-specific + * information. + */ + override def toString: String = { + + // String explaining what the p-value indicates. + val pValueExplain = if (pValue <= 0.01) { + s"Very strong presumption against null hypothesis: $nullHypothesis." + } else if (0.01 < pValue && pValue <= 0.05) { + s"Strong presumption against null hypothesis: $nullHypothesis." + } else if (0.05 < pValue && pValue <= 0.1) { + s"Low presumption against null hypothesis: $nullHypothesis." + } else { + s"No presumption against null hypothesis: $nullHypothesis." + } + + s"degrees of freedom = ${degreesOfFreedom.toString} \n" + + s"statistic = $statistic \n" + + s"pValue = $pValue \n" + pValueExplain + } +} + +/** + * :: Experimental :: + * Object containing the test results for the chi-squared hypothesis test. 
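The default toString above assembles a readable summary from pValue, degreesOfFreedom, statistic, and nullHypothesis. A minimal sketch of that contract using a toy implementation (the class and its values are hypothetical; in practice callers receive a ChiSqTestResult from the test methods rather than subclassing the trait themselves):

    import org.apache.spark.mllib.stat.test.TestResult

    // Hypothetical subclass, only to show the formatted summary produced by the trait.
    class ToyTestResult extends TestResult[Int] {
      override val pValue: Double = 0.0455
      override val degreesOfFreedom: Int = 1
      override val statistic: Double = 4.0
      override val nullHypothesis: String = "the two variables are independent."
    }

    println(new ToyTestResult)
    // degrees of freedom = 1
    // statistic = 4.0
    // pValue = 0.0455
    // Strong presumption against null hypothesis: the two variables are independent.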
+ */ +@Experimental +class ChiSqTestResult private[stat] (override val pValue: Double, + override val degreesOfFreedom: Int, + override val statistic: Double, + val method: String, + override val nullHypothesis: String) extends TestResult[Int] { + + override def toString: String = { + "Chi squared test summary:\n" + + s"method: $method\n" + + super.toString + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1d03e6e3b36cf..5cdd258f6c20b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,18 +17,25 @@ package org.apache.spark.mllib.tree +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging +import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom + /** * :: Experimental :: * A class which implements a decision tree learning algorithm for classification and regression. @@ -40,6 +47,8 @@ import org.apache.spark.util.random.XORShiftRandom @Experimental class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { + strategy.assertValid() + /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] @@ -47,39 +56,45 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - // Cache input RDD for speedup during multiple passes. - val retaggedInput = input.retag(classOf[LabeledPoint]).cache() + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy) + timer.start("findSplitsBins") + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) val numBins = bins(0).length + timer.stop("findSplitsBins") logDebug("numBins = " + numBins) + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. 
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + .persist(StorageLevel.MEMORY_AND_DISK) + + val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1 - // Initialize an array to hold filters applied to points for each node. - val filters = new Array[List[Filter]](maxNumNodes) - // The filter at the top node is an empty list. - filters(0) = List() + val maxNumNodes = (2 << maxDepth) - 1 // Initialize an array to hold parent impurity calculations for each node. val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) - // num features - val numFeatures = retaggedInput.take(1)(0).features.size // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins, - strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures, - strategy.algo) + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -90,12 +105,13 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) + timer.stop("init") + /* * The main idea here is to perform level-wise training of the decision tree nodes thus * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., - * the sample is only used for the split calculation at the node if the sampled would have - * still survived the filters of the parent nodes. + * Each data sample is handled by a particular node at that level (or it reaches a leaf + * beforehand and is not used in later levels. */ var level = 0 @@ -107,18 +123,39 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities, - strategy, level, filters, splits, bins, maxLevelForSingleGroup) + timer.start("findBestSplits") + val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + timer.stop("findBestSplits") + val levelNodeIndexOffset = (1 << level) - 1 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - // Extract info for nodes at the current level. + val nodeIndex = levelNodeIndexOffset + index + val isLeftChild = level != 0 && nodeIndex % 2 == 1 + val parentNodeIndex = if (isLeftChild) { // -1 for root node + (nodeIndex - 1) / 2 + } else { + (nodeIndex - 2) / 2 + } + // Extract info for this node (index) at the current level. + timer.start("extractNodeInfo") extractNodeInfo(nodeSplitStats, level, index, nodes) + timer.stop("extractNodeInfo") + if (level != 0) { + // Set parent. 
+ if (isLeftChild) { + nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) + } else { + nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) + } + } // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) + timer.start("extractInfoForLowerLevels") + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) + timer.stop("extractInfoForLowerLevels") logDebug("final best split = " + nodeSplitStats._1) } - require(math.pow(2, level) == splitsStatsForLevel.length) + require((1 << level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -138,6 +175,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + new DecisionTreeModel(topNode, strategy.algo) } @@ -151,7 +193,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = math.pow(2, level).toInt - 1 + index + val nodeIndex = (1 << level) - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -166,31 +208,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo index: Int, maxDepth: Int, nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double], - filters: Array[List[Filter]]): Unit = { - // 0 corresponds to the left child node and 1 corresponds to the right child node. - var i = 0 - while (i <= 1) { - // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i - if (level < maxDepth) { - val impurity = if (i == 0) { - nodeSplitStats._2.leftImpurity - } else { - nodeSplitStats._2.rightImpurity - } - logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) - // noting the parent impurities - parentImpurities(nodeIndex) = impurity - // noting the parents filters for the child nodes - val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) - filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) - for (filter <- filters(nodeIndex)) { - logDebug("Filter = " + filter) - } - } - i += 1 + parentImpurities: Array[Double]): Unit = { + + if (level >= maxDepth) { + return } + + val leftNodeIndex = (2 << level) - 1 + 2 * index + val leftImpurity = nodeSplitStats._2.leftImpurity + logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) + parentImpurities(leftNodeIndex) = leftImpurity + + val rightNodeIndex = leftNodeIndex + 1 + val rightImpurity = nodeSplitStats._2.rightImpurity + logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) + parentImpurities(rightNodeIndex) = rightImpurity } } @@ -200,6 +232,10 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. 
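The level-wise loop above relies on the usual complete-binary-tree numbering: the root is node 0, level L starts at index (1 << L) - 1, and node n has children 2n + 1 and 2n + 2. A small sketch of that arithmetic (the helper names are hypothetical); note that (n - 1) / 2 under integer division recovers the parent for both children, which is what the two branches above compute explicitly.

    // Complete-binary-tree indexing used by the level-wise training loop.
    def levelOf(nodeIndex: Int): Int = 31 - Integer.numberOfLeadingZeros(nodeIndex + 1) // floor(log2(n + 1))
    def parentOf(nodeIndex: Int): Int = (nodeIndex - 1) / 2
    def leftChildOf(nodeIndex: Int): Int = 2 * nodeIndex + 1
    def rightChildOf(nodeIndex: Int): Int = 2 * nodeIndex + 2
    def firstNodeAtLevel(level: Int): Int = (1 << level) - 1 // the levelNodeIndexOffset above

    // e.g. node 4 sits at level 2, its parent is node 1, and level 2 starts at index 3
    assert(levelOf(4) == 2 && parentOf(4) == 1 && firstNodeAtLevel(2) == 3)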
* + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. + * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -213,10 +249,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -237,10 +275,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -263,11 +303,12 @@ object DecisionTree extends Serializable with Logging { } /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The decision tree method supports binary classification and - * regression. For the binary classification, the label for each instance should either be 0 or - * 1 to denote the two classes. The method also supports categorical features inputs where the - * number of categories can specified using the categoricalFeaturesInfo option. + * Method to train a decision tree model. + * The method supports binary and multiclass classification and regression. + * + * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. @@ -279,11 +320,9 @@ object DecisionTree extends Serializable with Logging { * @param numClassesForClassification number of classes for classification. Default value of 2. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, - * an entry (n -> k) implies the feature n is categorical with k - * categories 0, 1, 2, ... , k-1. It's important to note that - * features are zero-indexed. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction */ def train( @@ -300,78 +339,163 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, + categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 
+ * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @return DecisionTreeModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + val impurityType = Impurities.fromString(impurity) + train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + impurity: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + impurity, maxDepth, maxBins) + } + + private val InvalidBinIndex = -1 /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. - * @return array of splits with best splits for all nodes at a given level. + * @return array (over nodes) of splits with best split for each node at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + maxLevelForSingleGroup: Int, + timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { // When information for all nodes at a given level cannot be stored in memory, // the nodes are divided into multiple groups at each level with the number of groups // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = 1 << level - maxLevelForSingleGroup logDebug("numGroups = " + numGroups) var bestSplits = new Array[(Split, InformationGainStats)](0) // Iterate over each group of nodes at a level. 
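For reference, the trainClassifier and trainRegressor entry points defined above can be called as in the sketch below. The SparkContext sc and the synthetic dataset are assumptions of the sketch, not part of the patch.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree

    // Assumes an existing SparkContext `sc`; 200 synthetic points with a threshold at x = 0.5.
    val data = sc.parallelize(0 until 200).map { i =>
      val x = i / 200.0
      LabeledPoint(if (x > 0.5) 1.0 else 0.0, Vectors.dense(x, 1.0 - x))
    }

    // Binary classification with no categorical features, Gini impurity, and the suggested depth/bins.
    val classifier = DecisionTree.trainClassifier(data, numClassesForClassification = 2,
      categoricalFeaturesInfo = Map[Int, Int](), impurity = "gini", maxDepth = 4, maxBins = 100)

    // Regression on the same points with variance impurity.
    val regressor = DecisionTree.trainRegressor(data,
      categoricalFeaturesInfo = Map[Int, Int](), impurity = "variance", maxDepth = 4, maxBins = 100)

    println(classifier.predict(Vectors.dense(0.9, 0.1)))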
var groupIndex = 0 while (groupIndex < numGroups) { - val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, - filters, splits, bins, numGroups, groupIndex) + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level, + nodes, splits, bins, timer, numGroups, groupIndex) bestSplits = Array.concat(bestSplits, bestSplitsForGroup) groupIndex += 1 } bestSplits } else { - findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer) } } - /** + /** * Returns an array of optimal splits for a group of nodes at a given level * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param parentImpurities Impurities for all parent nodes for the current level - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for constructing the DecisionTree + * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param filters Filters for all nodes at a given level * @param splits possible splits for all features - * @param bins possible bins for all features + * @param bins possible bins for all features, indexed as (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[TreePoint], parentImpurities: Array[Double], - strategy: Strategy, + metadata: DecisionTreeMetadata, level: Int, - filters: Array[List[Filter]], + nodes: Array[Node], splits: Array[Array[Split]], bins: Array[Array[Bin]], + timer: TimeTracker, numGroups: Int = 1, groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { @@ -387,7 +511,7 @@ object DecisionTree extends Serializable with Logging { * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * bin. Each bin is an interval between a low and high split. Since each split, and thus bin, * is ordered (read ordering for categorical variables in the findSplitsBins method), * we exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. @@ -403,258 +527,124 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = math.pow(2, level).toInt / numGroups + val numNodes = (1 << level) / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. 
- val numFeatures = input.first().features.size + val numFeatures = metadata.numFeatures logDebug("numFeatures = " + numFeatures) // numBins: Number of bins = 1 + number of possible splits val numBins = bins(0).length logDebug("numBins = " + numBins) - val numClasses = strategy.numClassesForClassification + val numClasses = metadata.numClasses logDebug("numClasses = " + numClasses) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) - val isMulticlassClassificationWithCategoricalFeatures - = strategy.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + - isMulticlassClassificationWithCategoricalFeatures) + val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ - def findParentFilters(nodeIndex: Int): List[Filter] = { - if (level == 0) { - List[Filter]() - } else { - val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift - filters(nodeFilterIndex) - } - } - /** - * Find whether the sample is valid input for the current node, i.e., whether it passes through - * all the filters for the current node. + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { - // leaf - if ((level > 0) && (parentFilters.length == 0)) { - return false - } - - // Apply each filter and check sample validity. Return false when invalid condition found. - for (filter <- parentFilters) { - val features = labeledPoint.features - val featureIndex = filter.split.feature - val threshold = filter.split.threshold - val comparison = filter.comparison - val categories = filter.split.categories - val isFeatureContinuous = filter.split.featureType == Continuous - val feature = features(featureIndex) - if (isFeatureContinuous) { - comparison match { - case -1 => if (feature > threshold) return false - case 1 => if (feature <= threshold) return false + def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. 
+ featureValueUpperBound <= node.split.get.threshold } - } else { - val containsFeature = categories.contains(feature) - comparison match { - case -1 => if (!containsFeature) return false - case 1 => if (containsFeature) return false + case Categorical => { + val featureValue = if (metadata.isUnordered(featureIndex)) { + binnedFeatures(featureIndex) + } else { + val binIndex = binnedFeatures(featureIndex) + bins(featureIndex)(binIndex).category + } + node.split.get.categories.contains(featureValue) } - + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") } - } - - // Return true when the sample is valid for all filters. - true - } - - /** - * Find bin for one (labeledPoint, feature). - */ - def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, - isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } - else if (lowThreshold >= feature) { - right = mid - 1 - } - else { - left = mid + 1 + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + node.id * 2 + 1 // left + } else { + node.id * 2 + 2 // right } - } - -1 - } - - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures) } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) } - -1 } + } - if (isFeatureContinuous) { - // Perform binary search for finding bin for continuous features. 
- val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new UnknownError("no bin was found for continuous variable.") - } - binIndex + def nodeIndexToLevel(idx: Int): Int = { + if (idx == 0) { + 0 } else { - // Perform sequential search to find bin for categorical features. - val binIndex = { - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - } - if (binIndex == -1) { - throw new UnknownError("no bin was found for categorical variable.") - } - binIndex + math.floor(math.log(idx) / math.log(2)).toInt } } + // Used for treePointToNodeIndex + val levelOffset = (1 << level) - 1 + /** - * Finds bins for all nodes (and all features) at a given level. - * For l nodes, k features the storage is as follows: - * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and the categorical feature value in multiclass classification. - * Invalid sample is denoted by noting bin for feature 1 as -1. - * - * For unordered features, the "bin index" returned is actually the feature value (category). - * - * @return Array of size 1 + numFeatures * numNodes, where - * arr(0) = label for labeledPoint, and - * arr(1 + numFeatures * nodeIndex + featureIndex) = - * bin index for this labeledPoint - * (or InvalidBinIndex if labeledPoint is not handled by this node) + * Find the node index for the given example. + * Nodes are indexed from 0 at the start of this (level, group). + * If the example does not reach this level, returns a value < 0. */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { - // Calculate bin index and label per feature per node. - val arr = new Array[Double](1 + (numFeatures * numNodes)) - // First element of the array is the label of the instance. - arr(0) = labeledPoint.label - // Iterate over nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - val parentFilters = findParentFilters(nodeIndex) - // Find out whether the sample qualifies for the particular node. - val sampleValid = isSampleValid(parentFilters, labeledPoint) - val shift = 1 + numFeatures * nodeIndex - if (!sampleValid) { - // Mark one bin as -1 is sufficient. - arr(shift) = InvalidBinIndex - } else { - var featureIndex = 0 - while (featureIndex < numFeatures) { - val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) - val isFeatureContinuous = featureInfo.isEmpty - if (isFeatureContinuous) { - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) - } else { - val featureCategories = featureInfo.get - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - arr(shift + featureIndex) - = findBin(featureIndex, labeledPoint, isFeatureContinuous, - isSpaceSufficientForAllCategoricalSplits) - } - featureIndex += 1 - } - } - nodeIndex += 1 + def treePointToNodeIndex(treePoint: TreePoint): Int = { + if (level == 0) { + 0 + } else { + val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) + // Get index for this (level, group). + globalNodeIndex - levelOffset - groupShift } - arr } - // Find feature bins for all nodes at a level. 
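predictNodeIndex above routes a binned example without recovering the raw feature value: for a continuous feature it compares the upper threshold of the example's bin against the node's split threshold, since bins are separated by splits. A minimal sketch of that comparison (the thresholds are hypothetical):

    // Bin b covers (lower(b), upper(b)]; an example in bin b goes left iff upper(b) <= split threshold.
    val binUpperBounds = Array(0.25, 0.5, 0.75, Double.MaxValue) // highSplit.threshold for each bin
    val splitThreshold = 0.5                                      // node.split.get.threshold

    def goesLeft(binIndex: Int): Boolean = binUpperBounds(binIndex) <= splitThreshold

    assert(goesLeft(0) && goesLeft(1) && !goesLeft(2) && !goesLeft(3))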
- val binMappedRDD = input.map(x => findBinsForLevel(x)) - /** * Increment aggregate in location for (node, feature, bin, label). * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ def updateBinForOrderedFeature( - arr: Array[Double], + treePoint: TreePoint, agg: Array[Double], nodeIndex: Int, - label: Double, featureIndex: Int): Unit = { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. val aggIndex = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - numClasses * arr(arrIndex).toInt + - label.toInt + numClasses * treePoint.binnedFeatures(featureIndex) + + treePoint.label.toInt agg(aggIndex) += 1 } @@ -663,8 +653,8 @@ object DecisionTree extends Serializable with Logging { * where [bins] ranges over all bins. * Updates left or right side of aggregate depending on split. * - * @param arr arr(0) = label. - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category) + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param treePoint Data point being aggregated. * @param agg Indexed by (left/right, node, feature, bin, label) * where label is the least significant bit. * The left/right specifier is a 0/1 index indicating left/right child info. @@ -673,21 +663,18 @@ object DecisionTree extends Serializable with Logging { def updateBinForUnorderedFeature( nodeIndex: Int, featureIndex: Int, - arr: Array[Double], - label: Double, + treePoint: TreePoint, agg: Array[Double], rightChildShift: Int): Unit = { - // Find the bin index for this feature. - val arrIndex = 1 + numFeatures * nodeIndex + featureIndex - val featureValue = arr(arrIndex).toInt + val featureValue = treePoint.binnedFeatures(featureIndex) // Update the left or right count for one bin. val aggShift = numClasses * numBins * numFeatures * nodeIndex + numClasses * numBins * featureIndex + - label.toInt + treePoint.label.toInt // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + val featureCategories = metadata.featureArity(featureIndex) + val numCategoricalBins = (1 << featureCategories - 1) - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val aggIndex = aggShift + binIndex * numClasses @@ -703,80 +690,51 @@ object DecisionTree extends Serializable with Logging { /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). * @param agg Array storing aggregate calculation, of size: * numClasses * numBins * numFeatures * numNodes. * Indexed by (node, feature, bin, label) where label is the least significant bit. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). 
*/ - def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - featureIndex += 1 - } - } - nodeIndex += 1 + def binaryOrNotCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) + featureIndex += 1 } } + val rightChildShift = numClasses * numBins * numFeatures * numNodes + /** * Helper for binSeqOp. * - * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label. - * Array of size 1 + (numFeatures * numNodes). - * For ordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index. - * For unordered features, - * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category). * @param agg Array storing aggregate calculation. * For ordered features, this is of size: * numClasses * numBins * numFeatures * numNodes. * For unordered features, this is of size: * 2 * numClasses * numBins * numFeatures * numNodes. + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). */ - def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, - rightChildShift) - } else { - updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) - } - } - featureIndex += 1 - } + def multiclassWithCategoricalBinSeqOp( + agg: Array[Double], + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. 
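The ordered-feature aggregate updated above is one flat array laid out by (node, feature, bin, label), with the label as the least significant dimension. A small sketch of that index arithmetic (the dimensions are hypothetical):

    // Flat layout of the classification bin aggregate, matching updateBinForOrderedFeature above.
    val (numNodes, numFeatures, numBins, numClasses) = (4, 3, 8, 2)
    val agg = new Array[Double](numClasses * numBins * numFeatures * numNodes)

    def aggIndex(node: Int, feature: Int, bin: Int, label: Int): Int =
      numClasses * numBins * numFeatures * node +
        numClasses * numBins * feature +
        numClasses * bin +
        label

    // Incrementing one (node, feature, bin, label) cell, as the seqOp does for each data point.
    agg(aggIndex(node = 1, feature = 2, bin = 5, label = 1)) += 1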
+ var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) + } else { + updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) } - nodeIndex += 1 + featureIndex += 1 } } @@ -787,36 +745,25 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, updated by this function. * Size: 3 * numBins * numFeatures * numNodes - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). * @return agg */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - // actual class label - val label = arr(0) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update count, sum, and sum^2 for one bin. - val aggShift = 3 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 - agg(aggIndex) = agg(aggIndex) + 1 - agg(aggIndex + 1) = agg(aggIndex + 1) + label - agg(aggIndex + 2) = agg(aggIndex + 2) + label*label - featureIndex += 1 - } - } - nodeIndex += 1 + def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { + val label = treePoint.label + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Update count, sum, and sum^2 for one bin. + val binIndex = treePoint.binnedFeatures(featureIndex) + val aggIndex = + 3 * numBins * numFeatures * nodeIndex + + 3 * numBins * featureIndex + + 3 * binIndex + agg(aggIndex) += 1 + agg(aggIndex + 1) += label + agg(aggIndex + 2) += label * label + featureIndex += 1 } } @@ -835,26 +782,30 @@ object DecisionTree extends Serializable with Logging { * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. * Size for regression: * 3 * numBins * numFeatures * numNodes. - * @param arr Bin mapping from findBinsForLevel. - * Array of size 1 + (numFeatures * numNodes). + * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { - strategy.algo match { - case Classification => - if(isMulticlassClassificationWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(arr, agg) + def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + val nodeIndex = treePointToNodeIndex(treePoint) + // If the example does not reach this level, then nodeIndex < 0. + // If the example reaches this level but is handled in a different group, + // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). 
+ if (nodeIndex >= 0 && nodeIndex < numNodes) { + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) } else { - binaryOrNotCategoricalBinSeqOp(arr, agg) + binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) } - case Regression => regressionBinSeqOp(arr, agg) + } else { + regressionBinSeqOp(agg, treePoint, nodeIndex) + } } agg } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, - isMulticlassClassificationWithCategoricalFeatures, strategy.algo) + val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -874,135 +825,134 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregates. + timer.start("aggregation") val binAggregates = { - binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) } + timer.stop("aggregation") logDebug("binAggregates.length = " + binAggregates.length) /** - * Calculates the information gain for all splits based upon left/right split aggregates. - * @param leftNodeAgg left node aggregates - * @param featureIndex feature index - * @param splitIndex split index - * @param rightNodeAgg right node aggregate + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftNodeAgg left node aggregates for this (feature, split) + * @param rightNodeAgg right node aggregate for this (feature, split) * @param topImpurity impurity of the parent node * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int, - splitIndex: Int, - rightNodeAgg: Array[Array[Array[Double]]], + leftNodeAgg: Array[Double], + rightNodeAgg: Array[Double], topImpurity: Double): InformationGainStats = { - strategy.algo match { - case Classification => - val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex) - val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex) - val leftTotalCount = leftCounts.sum - val rightTotalCount = rightCounts.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) - classIndex += 1 - } - strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) + if (metadata.isClassification) { + val leftTotalCount = leftNodeAgg.sum + val rightTotalCount = rightNodeAgg.sum + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val rootNodeCounts = new Array[Double](numClasses) + var classIndex = 0 + while (classIndex < numClasses) { + rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) + classIndex += 1 } + metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } + } - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. 
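binSeqOp and binCombOp above follow the standard seqOp/combOp aggregation contract; the patch hands them to treeAggregate (from mllib's RDDFunctions) so that partial aggregates are merged in a multi-level tree instead of all at once on the driver. A minimal sketch of the same pattern, reduced to a plain histogram and using the core RDD.aggregate for brevity (an existing SparkContext sc is assumed):

    // Histogram over an RDD of bin indices via seqOp/combOp.
    val numBins = 8
    val binIndices = sc.parallelize(Seq(0, 3, 3, 7, 1, 3))

    val histogram = binIndices.aggregate(new Array[Double](numBins))(
      (agg, bin) => { agg(bin) += 1; agg },                  // seqOp: fold one record into a local array
      (a, b) => { a.indices.foreach(i => a(i) += b(i)); a }  // combOp: merge two partial arrays
    )
    // histogram(3) == 3.0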
- return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } + val totalCount = leftTotalCount + rightTotalCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - // Sum of count for each label - val leftRightCounts: Array[Double] = - leftCounts.zip(rightCounts).map { case (leftCount, rightCount) => - leftCount + rightCount - } + // Sum of count for each label + val leftrightNodeAgg: Array[Double] = + leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => + leftCount + rightCount + } - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex + 1) - } - if (result._1 < 0) 0 else result._1 + def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("DecisionTree internal error:" + + " calculateGainForSplit failed in indexOfLargestArrayElement") } + result._1 + } - val predict = indexOfLargestArrayElement(leftRightCounts) - val prob = leftRightCounts(predict) / totalCount + val predict = indexOfLargestArrayElement(leftrightNodeAgg) + val prob = leftrightNodeAgg(predict) / totalCount - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - strategy.impurity.calculate(leftCounts, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { + val leftImpurity = if (leftTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(leftNodeAgg, leftTotalCount) + } + val rightImpurity = if (rightTotalCount == 0) { + topImpurity + } else { + metadata.impurity.calculate(rightNodeAgg, rightTotalCount) + } + + val leftWeight = leftTotalCount / totalCount + val rightWeight = rightTotalCount / totalCount + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + + } else { + // Regression + + val leftCount = leftNodeAgg(0) + val leftSum = leftNodeAgg(1) + val leftSumSquares = leftNodeAgg(2) + + val rightCount = rightNodeAgg(0) + val rightSum = rightNodeAgg(1) + val rightSumSquares = rightNodeAgg(2) + + val impurity = { + if (level > 0) { topImpurity } else { - strategy.impurity.calculate(rightCounts, rightTotalCount) + // Calculate impurity for root node. 
+ val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + metadata.impurity.calculate(count, sum, sumSquares) } + } - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - case Regression => - val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) - val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - - val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) - val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - strategy.impurity.calculate(count, sum, sumSquares) - } - } + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, + Double.MinValue, leftSum / leftCount) + } - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity ,topImpurity, - Double.MinValue, leftSum / leftCount) - } + val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) - val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - val gain = { - if (level > 0) { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } else { - impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - } - } - - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) } } @@ -1025,6 +975,19 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + /** + * The input binData is indexed as (feature, bin, class). + * This computes cumulative sums over splits. + * Each (feature, class) pair is handled separately. + * Note: numSplits = numBins - 1. + * @param leftNodeAgg Each (feature, class) slice is an array over splits. + * Element i (i = 0, ..., numSplits - 2) is set to be + * the cumulative sum (from left) over binData for bins 0, ..., i. + * @param rightNodeAgg Each (feature, class) slice is an array over splits. 
+ * Element i (i = 1, ..., numSplits - 1) is set to be + * the cumulative sum (from right) over binData for bins + * numBins - 1, ..., numBins - 1 - i. + */ def findAggForOrderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], @@ -1129,45 +1092,32 @@ object DecisionTree extends Serializable with Logging { } } - strategy.algo match { - case Classification => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (isMulticlassClassificationWithCategoricalFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isSpaceSufficientForAllCategoricalSplits) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - } - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - - (leftNodeAgg, rightNodeAgg) - case Regression => - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 + if (metadata.isClassification) { + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (metadata.isUnordered(featureIndex)) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } - (leftNodeAgg, rightNodeAgg) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } else { + // Regression + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + // Iterate over all features. 
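Putting the pieces above together for one ordered feature: the left and right aggregates are cumulative sums over bins taken from the two ends, and each candidate split is then scored with the weighted impurity gain from calculateGainForSplit. The sketch below is illustrative only, using Gini impurity and made-up bin counts.

    // One ordered feature with numBins = 4 and two labels; split s sends bins 0..s to the left child.
    def gini(counts: Array[Double]): Double = {
      val total = counts.sum
      1.0 - counts.map(c => (c / total) * (c / total)).sum
    }

    val binCounts = Array(Array(5.0, 3.0, 2.0, 6.0),  // per-bin counts for label 0
                          Array(1.0, 2.0, 7.0, 4.0))  // per-bin counts for label 1
    val numSplits = binCounts(0).length - 1

    val gains = (0 until numSplits).map { s =>
      val left = binCounts.map(_.slice(0, s + 1).sum)   // cumulative sum from the left
      val right = binCounts.map(_.drop(s + 1).sum)      // cumulative sum from the right
      val parent = left.zip(right).map { case (l, r) => l + r }
      val total = parent.sum
      gini(parent) - (left.sum / total) * gini(left) - (right.sum / total) * gini(right)
    }
    val bestSplit = gains.indices.maxBy(gains)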
+ var featureIndex = 0 + while (featureIndex < numFeatures) { + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) } } @@ -1180,15 +1130,38 @@ object DecisionTree extends Serializable with Logging { nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - for (featureIndex <- 0 until numFeatures) { - for (splitIndex <- 0 until numBins - 1) { - gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, - splitIndex, rightNodeAgg, nodeImpurity) + var featureIndex = 0 + while (featureIndex < numFeatures) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplitsForFeature) { + gains(featureIndex)(splitIndex) = + calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), + rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) + splitIndex += 1 } + featureIndex += 1 } gains } + /** + * Get the number of splits for a feature. + */ + def getNumSplitsForFeature(featureIndex: Int): Int = { + if (metadata.isContinuous(featureIndex)) { + numBins - 1 + } else { + // Categorical feature + val featureCategories = metadata.featureArity(featureIndex) + if (metadata.isUnordered(featureIndex)) { + (1 << featureCategories - 1) - 1 + } else { + featureCategories + } + } + } + /** * Find the best split for a node. * @param binData Bin data slice for this node, given by getBinDataForNode. @@ -1207,7 +1180,7 @@ object DecisionTree extends Serializable with Logging { // Calculate gains for all splits. val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - val (bestFeatureIndex,bestSplitIndex, gainStats) = { + val (bestFeatureIndex, bestSplitIndex, gainStats) = { // Initialize with infeasible values. var bestFeatureIndex = Int.MinValue var bestSplitIndex = Int.MinValue @@ -1217,22 +1190,8 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 - val maxSplitIndex: Double = { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous) { - numBins - 1 - } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 - if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { - math.pow(2.0, featureCategories - 1).toInt - 1 - } else { // Binary classification - featureCategories - } - } - } - while (splitIndex < maxSplitIndex) { + val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + while (splitIndex < numSplitsForFeature) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1256,38 +1215,39 @@ object DecisionTree extends Serializable with Logging { * Get bin data for one node. 
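+     * Returns the slice of binAggregates for this node. For multiclass classification with
+     * categorical features, the left child and right child bin aggregates are concatenated.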
*/ def getBinDataForNode(node: Int): Array[Double] = { - strategy.algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (metadata.isClassification) { + if (isMulticlassWithCategoricalFeatures) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData } - case Regression => - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } + } else { + // Regression + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode } } // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") val bestSplits = new Array[(Split, InformationGainStats)](numNodes) // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = (1 << level) - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) @@ -1295,6 +1255,8 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } + timer.stop("chooseSplits") + bestSplits } @@ -1303,20 +1265,15 @@ object DecisionTree extends Serializable with Logging { * * @param numBins Number of bins = 1 + number of possible splits. 
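+   * @return Number of bin-aggregate values stored per node: numClasses * numBins * numFeatures
+   *         for classification (doubled for multiclass classification with categorical features),
+   *         or 3 * numBins * numFeatures for regression.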
*/ - private def getElementsPerNode( - numFeatures: Int, - numBins: Int, - numClasses: Int, - isMulticlassClassificationWithCategoricalFeatures: Boolean, - algo: Algo): Int = { - algo match { - case Classification => - if (isMulticlassClassificationWithCategoricalFeatures) { - 2 * numClasses * numBins * numFeatures - } else { - numClasses * numBins * numFeatures - } - case Regression => 3 * numBins * numFeatures + private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + if (metadata.isClassification) { + if (metadata.isMulticlassWithCategoricalFeatures) { + 2 * metadata.numClasses * numBins * metadata.numFeatures + } else { + metadata.numClasses * numBins * metadata.numFeatures + } + } else { + 3 * numBins * metadata.numFeatures } } @@ -1331,21 +1288,19 @@ object DecisionTree extends Serializable with Logging { * Categorical features: * For each feature, there is 1 bin per split. * Splits and bins are handled in 2 ways: - * (a) For multiclass classification with a low-arity feature + * (a) "unordered features" + * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are 2^(maxFeatureValue - 1) - 1 splits. - * (b) For regression and binary classification, + * There are (1 << maxFeatureValue - 1) - 1 splits. + * (b) "ordered features" + * For regression and binary classification, * and for multiclass classification with a high-arity feature, - * there is one split per category. - - * Categorical case (a) features are called unordered features. - * Other cases are called ordered features. + * there is one bin per category. * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing - * parameters for construction the DecisionTree - * @return A tuple of (splits,bins). + * @param metadata Learning and dataset metadata + * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] * of size (numFeatures, numBins - 1). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] @@ -1353,33 +1308,35 @@ object DecisionTree extends Serializable with Logging { */ protected[tree] def findSplitsBins( input: RDD[LabeledPoint], - strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() // Find the number of features by looking at the first sample val numFeatures = input.take(1)(0).features.size - val maxBins = strategy.maxBins + val maxBins = metadata.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) - val isMulticlassClassification = strategy.isMulticlassClassification - logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlass = metadata.isMulticlass + logDebug("isMulticlass = " + isMulticlass) /* - * Ensure #bins is always greater than the categories. For multiclass classification, - * #bins should be greater than 2^(maxCategories - 1) - 1. + * Ensure numBins is always greater than the categories. For multiclass classification, + * numBins should be greater than 2^(maxCategories - 1) - 1. * It's a limitation of the current implementation but a reasonable trade-off since features * with large number of categories get favored over continuous features. 
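+     * For example, a categorical feature with 5 categories requires numBins > 2^4 - 1 = 15
+     * in the multiclass (unordered) case.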
+ * + * This needs to be checked here instead of in Strategy since numBins can be determined + * by the number of training examples. + * TODO: Allow this case, where we simply will know nothing about some categories. */ - if (strategy.categoricalFeaturesInfo.size > 0) { - val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + if (metadata.featureArity.size > 0) { + val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") } - // Calculate the number of sample for approximate quantile calculation. val requiredSamples = numBins*numBins val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 @@ -1393,7 +1350,7 @@ object DecisionTree extends Serializable with Logging { val stride: Double = numSamples.toDouble / numBins logDebug("stride = " + stride) - strategy.quantileCalculationStrategy match { + metadata.quantileStrategy match { case Sort => val splits = Array.ofDim[Split](numFeatures, numBins - 1) val bins = Array.ofDim[Bin](numFeatures, numBins) @@ -1404,7 +1361,7 @@ object DecisionTree extends Serializable with Logging { var featureIndex = 0 while (featureIndex < numFeatures) { // Check whether the feature is continuous. - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted val stride: Double = numSamples.toDouble / numBins @@ -1417,18 +1374,14 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val isSpaceSufficientForAllCategoricalSplits - = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + val featureCategories = metadata.featureArity(featureIndex) // Use different bin/split calculation strategy for categorical features in multiclass // classification that satisfy the space constraint. - val isUnorderedFeature = - isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits - if (isUnorderedFeature) { + if (metadata.isUnordered(featureIndex)) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 - while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { + while (index < (1 << featureCategories - 1) - 1) { val categories: List[Double] = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) @@ -1458,7 +1411,7 @@ object DecisionTree extends Serializable with Logging { * centroidForCategories is a mapping: category (for the given feature) --> centroid */ val centroidForCategories = { - if (isMulticlassClassification) { + if (isMulticlass) { // For categorical variables in multiclass classification, // each bin is a category. The bins are sorted and they // are ordered by calculating the impurity of their corresponding labels. @@ -1466,7 +1419,7 @@ object DecisionTree extends Serializable with Logging { .groupBy(_._1) .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum))) + .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) } else { // regression or binary classification // For categorical variables in regression and binary classification, // each bin is a category. 
The bins are sorted and they @@ -1518,7 +1471,7 @@ object DecisionTree extends Serializable with Logging { // Find all bins. featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + val isFeatureContinuous = metadata.isContinuous(featureIndex) if (isFeatureContinuous) { // Bins for categorical variables are already assigned. bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) @@ -1532,7 +1485,7 @@ object DecisionTree extends Serializable with Logging { } featureIndex += 1 } - (splits,bins) + (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") case ApproxHist => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 79a01f58319e8..0ef9c6181a0a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental object Algo extends Enumeration { type Algo = Value val Classification, Regression = Value + + private[mllib] def fromString(name: String): Algo = name match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 4ee4bcd0bcbc7..cfc8192a85abd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -20,29 +20,37 @@ package org.apache.spark.mllib.tree.configuration import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ /** * :: Experimental :: * Stores all the configuration options for tree construction - * @param algo classification or regression - * @param impurity criterion used for information gain calculation + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree. * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param numClassesForClassification Number of classes for classification. + * (Ignored for regression.) 
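+ *                                    Labels should take values in
+ *                                    {0, 1, ..., numClassesForClassification - 1}.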
+ * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * */ @Experimental class Strategy ( @@ -64,20 +72,7 @@ class Strategy ( = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** - * Java-friendly constructor. - * - * @param algo classification or regression - * @param impurity criterion used for information gain calculation - * @param maxDepth Maximum depth of the tree. - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * @param numClassesForClassification number of classes for classification. Default value is 2 - * leads to binary classification - * @param maxBins maximum number of bins used for splitting features - * @param categoricalFeaturesInfo A map storing information about the categorical variables and - * the number of discrete values they take. For example, an entry - * (n -> k) implies the feature n is categorical with k categories - * 0, 1, 2, ... , k-1. It's important to note that features are - * zero-indexed. + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ def this( algo: Algo, @@ -90,4 +85,37 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + algo match { + case Classification => + require(numClassesForClassification >= 2, + s"DecisionTree Strategy for Classification must have numClassesForClassification >= 2," + + s" but numClassesForClassification = $numClassesForClassification.") + require(Set(Gini, Entropy).contains(impurity), + s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + + s" Valid settings: Gini, Entropy") + case Regression => + require(impurity == Variance, + s"DecisionTree Strategy given invalid impurity for Regression: $impurity." + + s" Valid settings: Variance") + case _ => + throw new IllegalArgumentException( + s"DecisionTree Strategy given invalid algo parameter: $algo." + + s" Valid settings are: Classification, Regression.") + } + require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." + + s" Valid values are integers >= 0.") + require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + + s" Valid values are integers >= 2.") + categoricalFeaturesInfo.foreach { case (feature, arity) => + require(arity >= 2, + s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + + s" feature $feature has $arity categories. 
The number of categories should be >= 2.") + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000000..d9eda354dc986 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. 
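+ * @param unorderedFeatures Set of indices of categorical features treated as unordered,
+ *                          i.e., multiclass classification with arity low enough that all
+ *                          category-subset splits fit within maxBins.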
+ */ +private[tree] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + +} + +private[tree] object DecisionTreeMetadata { + + def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + + val numFeatures = input.take(1)(0).features.size + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClassesForClassification + case Regression => 0 + } + + val maxBins = math.min(strategy.maxBins, numExamples).toInt + val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + + val unorderedFeatures = new mutable.HashSet[Int]() + if (numClasses > 2) { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + if (k - 1 < log2MaxBinsp1) { + // Note: The above check is equivalent to checking: + // numUnorderedBins = (1 << k - 1) - 1 < maxBins + unorderedFeatures.add(f) + } else { + // TODO: Allow this case, where we simply will know nothing about some categories? + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + } else { + strategy.categoricalFeaturesInfo.foreach { case (f, k) => + require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + + s"in categorical features (>= $k)") + } + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, + strategy.impurity, strategy.quantileCalculationStrategy) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000000..d215d68c4279e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +import org.apache.spark.annotation.Experimental + +/** + * Time tracker implementation which holds labeled timers. 
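+ *
+ * Example usage (illustrative sketch):
+ * {{{
+ *   val timer = new TimeTracker()
+ *   timer.start("total")
+ *   // ... timed work ...
+ *   val elapsedSeconds = timer.stop("total")
+ * }}}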
+ */ +@Experimental +private[tree] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. + */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala new file mode 100644 index 0000000000000..170e43e222083 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.Bin +import org.apache.spark.rdd.RDD + + +/** + * Internal representation of LabeledPoint for DecisionTree. + * This bins feature values based on a subsampled of data as follows: + * (a) Continuous features are binned into ranges. + * (b) Unordered categorical features are binned based on subsets of feature values. + * "Unordered categorical features" are categorical features with low arity used in + * multiclass classification. + * (c) Ordered categorical features are binned based on feature values. + * "Ordered categorical features" are categorical features with high arity, + * or any categorical feature used in regression or binary classification. + * + * @param label Label from LabeledPoint + * @param binnedFeatures Binned feature values. + * Same length as LabeledPoint.features, but values are bin indices. 
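+ *                       For example (illustrative), a continuous value is replaced by the index
+ *                       of the bin whose (lowSplit, highSplit] range contains it, and an ordered
+ *                       categorical value is mapped to the bin for its category.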
+ */ +private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) + extends Serializable { +} + +private[tree] object TreePoint { + + /** + * Convert an input dataset into its TreePoint representation, + * binning feature values in preparation for DecisionTree training. + * @param input Input dataset. + * @param bins Bins for features, of size (numFeatures, numBins). + * @param metadata Learning and dataset metadata + * @return TreePoint dataset representation + */ + def convertToTreeRDD( + input: RDD[LabeledPoint], + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): RDD[TreePoint] = { + input.map { x => + TreePoint.labeledPointToTreePoint(x, bins, metadata) + } + } + + /** + * Convert one LabeledPoint into its TreePoint representation. + * @param bins Bins for features, of size (numFeatures, numBins). + */ + private def labeledPointToTreePoint( + labeledPoint: LabeledPoint, + bins: Array[Array[Bin]], + metadata: DecisionTreeMetadata): TreePoint = { + + val numFeatures = labeledPoint.features.size + val numBins = bins(0).size + val arr = new Array[Int](numFeatures) + var featureIndex = 0 + while (featureIndex < numFeatures) { + arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), + metadata.isUnordered(featureIndex), bins, metadata.featureArity) + featureIndex += 1 + } + + new TreePoint(labeledPoint.label, arr) + } + + /** + * Find bin for one (labeledPoint, feature). + * + * @param isUnorderedFeature (only applies if feature is categorical) + * @param bins Bins for features, of size (numFeatures, numBins). + * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity + */ + private def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean, + isUnorderedFeature: Boolean, + bins: Array[Array[Bin]], + categoricalFeaturesInfo: Map[Int, Int]): Int = { + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) && (highThreshold >= feature)) { + return mid + } else if (lowThreshold >= feature) { + right = mid - 1 + } else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. + */ + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } + + /** + * Sequential search helper method to find bin for categorical feature + * (for classification and regression). 
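+     * Returns the index of the bin whose highSplit.categories contains the feature value,
+     * or -1 if no such bin exists (out-of-range values throw an IllegalArgumentException).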
+ */ + def sequentialBinSearchForOrderedCategoricalFeature(): Int = { + val featureCategories = categoricalFeaturesInfo(featureIndex) + val featureValue = labeledPoint.features(featureIndex) + var binIndex = 0 + while (binIndex < featureCategories) { + val bin = bins(featureIndex)(binIndex) + val categories = bin.highSplit.categories + if (categories.contains(featureValue)) { + return binIndex + } + binIndex += 1 + } + if (featureValue < 0 || featureValue >= featureCategories) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureCategories - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1) { + throw new RuntimeException("No bin was found for continuous feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = if (isUnorderedFeature) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() + } else { + sequentialBinSearchForOrderedCategoricalFeature() + } + if (binIndex == -1) { + throw new RuntimeException("No bin was found for categorical feature." + + " This error can occur when given invalid data values (such as NaN)." + + s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + } + binIndex + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala new file mode 100644 index 0000000000000..9a6452aa13a61 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +/** + * Factory for Impurity instances. 
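+ * Maps the lowercase names "gini", "entropy", and "variance" to the corresponding
+ * singleton Impurity objects.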
+ */ +private[mllib] object Impurities { + + def fromString(name: String): Impurity = name match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name") + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index c89c1e371a40e..af35d88f713e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,15 +20,25 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. For a continuous - * feature, a bin is determined by a low and a high "split". For a categorical feature, - * the a bin is determined using a single label value (category). + * Used for "binning" the features bins for faster best split calculation. + * + * For a continuous feature, the bin is determined by a low and a high split, + * where an example with featureValue falls into the bin s.t. + * lowSplit.threshold < featureValue <= highSplit.threshold. + * + * For ordered categorical features, there is a 1-1-1 correspondence between + * bins, splits, and feature values. The bin is determined by category/feature value. + * However, the bins are not necessarily ordered by feature value; + * they are ordered using impurity. + * For unordered categorical features, there is a 1-1 correspondence between bins, splits, + * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for binary classification + * @param category categorical label value accepted in the bin for ordered features */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index 3d3406b5d5f22..0594fd0749d21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -39,7 +39,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable * @return Double prediction from the trained model */ def predict(features: Vector): Double = { - topNode.predictIfLeaf(features) + topNode.predict(features) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 944f11c2c2e4f..0eee6262781c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -69,24 +69,24 @@ class Node ( /** * predict value if node is not leaf - * @param feature feature value + * @param features feature value * @return predicted value */ - def predictIfLeaf(feature: Vector) : Double = { + def predict(features: Vector) : Double = { if (isLeaf) { 
predict } else{ if (split.get.featureType == Continuous) { - if (feature(split.get.feature) <= split.get.threshold) { - leftNode.get.predictIfLeaf(feature) + if (features(split.get.feature) <= split.get.threshold) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } else { - if (split.get.categories.contains(feature(split.get.feature))) { - leftNode.get.predictIfLeaf(feature) + if (split.get.categories.contains(features(split.get.feature))) { + leftNode.get.predict(features) } else { - rightNode.get.predictIfLeaf(feature) + rightNode.get.predict(features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index d7ffd386c05ee..50fb48b40de3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -24,9 +24,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType * :: DeveloperApi :: * Split applied to a feature * @param feature feature index - * @param threshold threshold for continuous feature + * @param threshold Threshold for continuous feature. + * Split left if feature <= threshold, else right. * @param featureType type of feature -- categorical or continuous - * @param categories accepted values for categorical variables + * @param categories Split left if categorical feature value is in this set, else right. */ @DeveloperApi case class Split( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index f4cce86a65ba7..ca35100aa99c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD import org.apache.spark.util.random.BernoulliSampler -import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext @@ -185,7 +185,7 @@ object MLUtils { * @return labeled points stored as an RDD[LabeledPoint] */ def loadLabeledPoints(sc: SparkContext, path: String, minPartitions: Int): RDD[LabeledPoint] = - sc.textFile(path, minPartitions).map(LabeledPointParser.parse) + sc.textFile(path, minPartitions).map(LabeledPoint.parse) /** * Loads labeled points saved using `RDD[LabeledPoint].saveAsTextFile` with the default number of @@ -194,19 +194,6 @@ object MLUtils { def loadLabeledPoints(sc: SparkContext, dir: String): RDD[LabeledPoint] = loadLabeledPoints(sc, dir, sc.defaultMinPartitions) - /** - * Loads streaming labeled points from a stream of text files - * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`. 
- * See `StreamingContext.textFileStream` for more details on how to - * generate a stream from files - * - * @param ssc Streaming context - * @param dir Directory path in any Hadoop-supported file system URI - * @return Labeled points stored as a DStream[LabeledPoint] - */ - def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] = - ssc.textFileStream(dir).map(LabeledPointParser.parse) - /** * Load labeled data from a file. The data format used here is * , ... diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000..fb7afe8c6434b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import com.google.common.base.Strings; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaWord2VecSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void word2Vec() { + // The tests are to check Java compatibility. + String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); + List words = Lists.newArrayList(sentence.split(" ")); + List> localDoc = Lists.newArrayList(words, words); + JavaRDD> doc = sc.parallelize(localDoc); + Word2Vec word2vec = new Word2Vec() + .setVectorSize(10) + .setSeed(42L); + Word2VecModel model = word2vec.fit(doc); + Tuple2[] syms = model.findSynonyms("a", 2); + Assert.assertEquals(2, syms.length); + Assert.assertEquals("b", syms[0]._1()); + Assert.assertEquals("c", syms[1]._1()); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java new file mode 100644 index 0000000000000..a725736ca1a58 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/random/JavaRandomRDDsSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.random; + +import com.google.common.collect.Lists; +import org.apache.spark.api.java.JavaRDD; +import org.junit.Assert; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import static org.apache.spark.mllib.random.RandomRDDs.*; + +public class JavaRandomRDDsSuite { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaRandomRDDsSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testUniformRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = uniformJavaRDD(sc, m); + JavaDoubleRDD rdd2 = uniformJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = uniformJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testNormalRDD() { + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = normalJavaRDD(sc, m); + JavaDoubleRDD rdd2 = normalJavaRDD(sc, m, p); + JavaDoubleRDD rdd3 = normalJavaRDD(sc, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + public void testPoissonRDD() { + double mean = 2.0; + long m = 1000L; + int p = 2; + long seed = 1L; + JavaDoubleRDD rdd1 = poissonJavaRDD(sc, mean, m); + JavaDoubleRDD rdd2 = poissonJavaRDD(sc, mean, m, p); + JavaDoubleRDD rdd3 = poissonJavaRDD(sc, mean, m, p, seed); + for (JavaDoubleRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testUniformVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = uniformJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = uniformJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = uniformJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testNormalVectorRDD() { + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = normalJavaVectorRDD(sc, m, n); + JavaRDD rdd2 = normalJavaVectorRDD(sc, m, n, p); + JavaRDD rdd3 = normalJavaVectorRDD(sc, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testPoissonVectorRDD() { + double mean = 2.0; + long m = 100L; + int n = 10; + int p = 2; + long seed = 1L; + JavaRDD rdd1 = poissonJavaVectorRDD(sc, mean, m, n); + JavaRDD rdd2 = poissonJavaVectorRDD(sc, mean, 
m, n, p); + JavaRDD rdd3 = poissonJavaVectorRDD(sc, mean, m, n, p, seed); + for (JavaRDD rdd: Lists.newArrayList(rdd1, rdd2, rdd3)) { + Assert.assertEquals(m, rdd.count()); + Assert.assertEquals(n, rdd.first().size()); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index bd413a80f5107..092d67bbc5238 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends FunSuite { - val py = new PythonMLLibAPI test("vector serialization") { val vectors = Seq( @@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite { Vectors.sparse(1, Array.empty[Int], Array.empty[Double]), Vectors.sparse(2, Array(1), Array(-2.0))) vectors.foreach { v => - val bytes = py.serializeDoubleVector(v) - val u = py.deserializeDoubleVector(bytes) + val bytes = SerDe.serializeDoubleVector(v) + val u = SerDe.deserializeDoubleVector(bytes) assert(u.getClass === v.getClass) assert(u === v) } @@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite { LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])), LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0)))) points.foreach { p => - val bytes = py.serializeLabeledPoint(p) - val q = py.deserializeLabeledPoint(bytes) + val bytes = SerDe.serializeLabeledPoint(p) + val q = SerDe.deserializeLabeledPoint(bytes) assert(q.label === p.label) assert(q.features.getClass === p.features.getClass) assert(q.features === p.features) @@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite { test("double serialization") { for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) { - val bytes = py.serializeDouble(x) - val deser = py.deserializeDouble(bytes) + val bytes = SerDe.serializeDouble(x) + val deser = SerDe.deserializeDouble(bytes) // We use `equals` here for comparison because we cannot use `==` for NaN assert(x.equals(deser)) } @@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite { test("matrix to 2D array") { val values = Array[Double](0, 1.2, 3, 4.56, 7, 8) val matrix = Matrices.dense(2, 3, values) - val arr = py.to2dArray(matrix) + val arr = SerDe.to2dArray(matrix) val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8)) assert(arr === expected) // Test conversion for empty matrix val empty = Array[Double]() val emptyMatrix = Matrices.dense(0, 0, empty) - val empty2D = py.to2dArray(emptyMatrix) + val empty2D = SerDe.to2dArray(emptyMatrix) assert(empty2D === Array[Array[Double]]()) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index da7c633bbd2af..862178694a50e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match } // Test if we can correctly learn A, B where Y = logistic(A + B*X) - test("logistic regression") { + test("logistic regression with SGD") { val nPoints = 10000 val A = 2.0 val B = 
-1.5 @@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } - test("logistic regression with initial weights") { + // Test if we can correctly learn A, B where Y = logistic(A + B*X) + test("logistic regression with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD) + + // Test the weights + assert(model.weights(0) ~== -1.52 relTol 0.01) + assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== model.weights(0) relTol 0.01) + assert(model.intercept ~== model.intercept relTol 0.01) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("logistic regression with initial weights with SGD") { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -125,11 +154,99 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("logistic regression with initial weights with LBFGS") { + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialB = -1.0 + val initialWeights = Vectors.dense(initialB) + + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + // Use half as many iterations as the previous test. + val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + val model = lr.run(testRDD, initialWeights) + + // Test the weights + assert(model.weights(0) ~== -1.50 relTol 0.02) + assert(model.intercept ~== 1.97 relTol 0.02) + + val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } + + test("numerical stability of scaling features using logistic regression with LBFGS") { + /** + * If we rescale the features, the condition number will be changed so the convergence rate + * and the solution will not equal to the original solution multiple by the scaling factor + * which it should be. + * + * However, since in the LogisticRegressionWithLBFGS, we standardize the training dataset first, + * no matter how we multiple a scaling factor into the dataset, the convergence rate should be + * the same, and the solution should equal to the original solution multiple by the scaling + * factor. 
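+     * (Concretely: if every feature value is multiplied by a factor c, the learned weights
+     * should come out divided by c; the assertions on the standardized models below check this
+     * for c = 1.0E3 and 1.0E6.)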
+ */ + + val nPoints = 10000 + val A = 2.0 + val B = -1.5 + + val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) + + val initialWeights = Vectors.dense(0.0) + + val testRDD1 = sc.parallelize(testData, 2) + + val testRDD2 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E3))), 2) + + val testRDD3 = sc.parallelize( + testData.map(x => LabeledPoint(x.label, Vectors.fromBreeze(x.features.toBreeze * 1.0E6))), 2) + + testRDD1.cache() + testRDD2.cache() + testRDD3.cache() + + val lrA = new LogisticRegressionWithLBFGS().setIntercept(true) + val lrB = new LogisticRegressionWithLBFGS().setIntercept(true).setFeatureScaling(false) + + val modelA1 = lrA.run(testRDD1, initialWeights) + val modelA2 = lrA.run(testRDD2, initialWeights) + val modelA3 = lrA.run(testRDD3, initialWeights) + + val modelB1 = lrB.run(testRDD1, initialWeights) + val modelB2 = lrB.run(testRDD2, initialWeights) + val modelB3 = lrB.run(testRDD3, initialWeights) + + // For model trained with feature standardization, the weights should + // be the same in the scaled space. Note that the weights here are already + // in the original space, we transform back to scaled space to compare. + assert(modelA1.weights(0) ~== modelA2.weights(0) * 1.0E3 absTol 0.01) + assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01) + + // Training data with different scales without feature standardization + // will not yield the same result in the scaled space due to poor + // convergence rate. + assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1) + assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1) + } + } class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext { - test("task size should be small in both training and prediction") { + test("task size should be small in both training and prediction using SGD optimizer") { val m = 4 val n = 200000 val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => @@ -139,6 +256,30 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont // If we serialize data directly in the task closure, the size of the serialized task would be // greater than 1MB and hence Spark would throw an error. val model = LogisticRegressionWithSGD.train(points, 2) + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() } + + test("task size should be small in both training and prediction using LBFGS optimizer") { + val m = 4 + val n = 200000 + val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) => + val random = new Random(idx) + iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble())))) + }.cache() + // If we serialize data directly in the task closure, the size of the serialized task would be + // greater than 1MB and hence Spark would throw an error. 
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + lr.optimizer.setNumIterations(2) + val model = lr.run(points) + + val predictions = model.predict(points.map(_.features)) + + // Materialize the RDDs + predictions.count() + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 06cdd04f5fdae..80989bc074e84 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.FunSuite +import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} @@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("detect negative values") { + val dense = Seq( + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(-1.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(0.0))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(dense, 2)) + } + val sparse = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(sparse, 2)) + } + val nan = Seq( + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), + LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))), + LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))) + intercept[SparkException] { + NaiveBayes.train(sc.makeRDD(nan, 2)) + } + } } class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 78a2804ff204b..53d9c0c640b98 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -36,18 +36,12 @@ class IDFSuite extends FunSuite with LocalSparkContext { val m = localTermFrequencies.size val termFrequencies = sc.parallelize(localTermFrequencies, 2) val idf = new IDF - intercept[IllegalStateException] { - idf.idf() - } - intercept[IllegalStateException] { - idf.transform(termFrequencies) - } - idf.fit(termFrequencies) + val model = idf.fit(termFrequencies) val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => math.log((m.toDouble + 1.0) / (x + 1.0)) }) - assert(idf.idf() ~== expected absTol 1e-12) - val tfidf = idf.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(model.idf ~== expected absTol 1e-12) + val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() assert(tfidf.size === 3) val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] assert(tfidf0.indices === Array(1, 3)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 5a9be923a8625..e217b93cebbdb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -50,23 +50,17 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - withClue("Using a standardizer before fitting the model should throw exception.") { - intercept[IllegalStateException] { - data.map(standardizer1.transform) - } - } - - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) - val data1RDD = standardizer1.transform(dataRDD) - val data2RDD = standardizer2.transform(dataRDD) - val data3RDD = standardizer3.transform(dataRDD) + val data1RDD = model1.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) + val data3RDD = model3.transform(dataRDD) val summary = computeSummary(dataRDD) val summary1 = computeSummary(data1RDD) @@ -129,25 +123,25 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler() val standardizer3 = new StandardScaler(withMean = true, withStd = false) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(standardizer2.transform) + val data2 = data.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer1.transform) + data.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(standardizer3.transform) + data.map(model3.transform) } } - val data2RDD = standardizer2.transform(dataRDD) + val data2RDD = model2.transform(dataRDD) val summary2 = computeSummary(data2RDD) @@ -181,13 +175,13 @@ class StandardScalerSuite extends FunSuite with LocalSparkContext { val standardizer2 = new StandardScaler(withMean = true, withStd = false) val standardizer3 = new StandardScaler(withMean = false, withStd = true) - standardizer1.fit(dataRDD) - standardizer2.fit(dataRDD) - standardizer3.fit(dataRDD) + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(standardizer1.transform) - val data2 = data.map(standardizer2.transform) - val data3 = data.map(standardizer3.transform) + val data1 = data.map(model1.transform) + val data2 = data.map(model2.transform) + val data3 = data.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala new file mode 100644 index 0000000000000..1952e6734ecf7 --- /dev/null 
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.linalg.BLAS._ + +class BLASSuite extends FunSuite { + + test("copy") { + val sx = Vectors.sparse(4, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0, 0.0) + val sy = Vectors.sparse(4, Array(0, 1, 3), Array(2.0, 1.0, 1.0)) + val dy = Array(2.0, 1.0, 0.0, 1.0) + + val dy1 = Vectors.dense(dy.clone()) + copy(sx, dy1) + assert(dy1 ~== dx absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + copy(dx, dy2) + assert(dy2 ~== dx absTol 1e-15) + + intercept[IllegalArgumentException] { + copy(sx, sy) + } + + intercept[IllegalArgumentException] { + copy(dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + copy(sx, Vectors.dense(0.0, 1.0, 2.0)) + } + } + } + + test("scal") { + val a = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + + scal(a, sx) + assert(sx ~== Vectors.sparse(3, Array(0, 2), Array(0.1, -0.2)) absTol 1e-15) + + scal(a, dx) + assert(dx ~== Vectors.dense(0.1, 0.0, -0.2) absTol 1e-15) + } + + test("axpy") { + val alpha = 0.1 + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val dy = Array(2.0, 1.0, 0.0) + val expected = Vectors.dense(2.1, 1.0, -0.2) + + val dy1 = Vectors.dense(dy.clone()) + axpy(alpha, sx, dy1) + assert(dy1 ~== expected absTol 1e-15) + + val dy2 = Vectors.dense(dy.clone()) + axpy(alpha, dx, dy2) + assert(dy2 ~== expected absTol 1e-15) + + val sy = Vectors.sparse(4, Array(0, 1), Array(2.0, 1.0)) + + intercept[IllegalArgumentException] { + axpy(alpha, sx, sy) + } + + intercept[IllegalArgumentException] { + axpy(alpha, dx, sy) + } + + withClue("vector sizes must match") { + intercept[Exception] { + axpy(alpha, sx, Vectors.dense(1.0, 2.0)) + } + } + } + + test("dot") { + val sx = Vectors.sparse(3, Array(0, 2), Array(1.0, -2.0)) + val dx = Vectors.dense(1.0, 0.0, -2.0) + val sy = Vectors.sparse(3, Array(0, 1), Array(2.0, 1.0)) + val dy = Vectors.dense(2.0, 1.0, 0.0) + + assert(dot(sx, sy) ~== 2.0 absTol 1e-15) + assert(dot(sy, sx) ~== 2.0 absTol 1e-15) + assert(dot(sx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, sx) ~== 2.0 absTol 1e-15) + assert(dot(dx, dy) ~== 2.0 absTol 1e-15) + assert(dot(dy, dx) ~== 2.0 absTol 1e-15) + + assert(dot(sx, sx) ~== 5.0 absTol 1e-15) + assert(dot(dx, dx) ~== 5.0 absTol 1e-15) + assert(dot(sx, dx) ~== 5.0 absTol 1e-15) + assert(dot(dx, sx) ~== 5.0 absTol 1e-15) + + val sx1 = Vectors.sparse(10, Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0)) 
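These BLAS assertions, like the other suites in this patch, compare doubles with the `~==` operator from `org.apache.spark.mllib.util.TestingUtils` together with `absTol`/`relTol` bounds. A rough sketch of the semantics being relied on; this is an illustration under an assumed definition of relative tolerance, not the actual TestingUtils code:

```scala
// absTol: the absolute error must stay within eps.
def approxEqAbs(a: Double, b: Double, eps: Double): Boolean =
  math.abs(a - b) <= eps

// relTol: the error must stay within eps relative to the magnitudes involved
// (one common formulation; TestingUtils may bind the tolerance to the expected value).
def approxEqRel(a: Double, b: Double, eps: Double): Boolean =
  math.abs(a - b) <= eps * math.max(math.abs(a), math.abs(b))

approxEqAbs(-0.2, -0.19999999999999998, 1e-15) // the kind of check `absTol 1e-15` performs
approxEqRel(-1.515, -1.52, 0.01)               // the kind of check `relTol 0.01` performs
```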
+ val sx2 = Vectors.sparse(10, Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0)) + assert(dot(sx1, sx2) ~== 20.0 absTol 1e-15) + assert(dot(sx2, sx1) ~== 20.0 absTol 1e-15) + + withClue("vector sizes must match") { + intercept[Exception] { + dot(sx, Vectors.dense(2.0, 1.0)) + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 7972ceea1fe8a..cd651fe2d2ddf 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -125,4 +125,34 @@ class VectorsSuite extends FunSuite { } } } + + test("zeros") { + assert(Vectors.zeros(3) === Vectors.dense(0.0, 0.0, 0.0)) + } + + test("Vector.copy") { + val sv = Vectors.sparse(4, Array(0, 2), Array(1.0, 2.0)) + val svCopy = sv.copy + (sv, svCopy) match { + case (sv: SparseVector, svCopy: SparseVector) => + assert(sv.size === svCopy.size) + assert(sv.indices === svCopy.indices) + assert(sv.values === svCopy.values) + assert(!sv.indices.eq(svCopy.indices)) + assert(!sv.values.eq(svCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${svCopy.getClass} on ${sv.getClass}.") + } + + val dv = Vectors.dense(1.0, 0.0, 2.0) + val dvCopy = dv.copy + (dv, dvCopy) match { + case (dv: DenseVector, dvCopy: DenseVector) => + assert(dv.size === dvCopy.size) + assert(dv.values === dvCopy.values) + assert(!dv.values.eq(dvCopy.values)) + case _ => + throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 325b817980f68..1d3a3221365cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -99,7 +99,7 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { val localMat = mat.toBreeze() - val (localU, localSigma, localVt) = brzSvd(localMat) + val brzSvd.SVD(localU, localSigma, localVt) = brzSvd(localMat) val localV: BDM[Double] = localVt.t.toDenseMatrix for (k <- 1 to n) { val skip = (mode == "local-eigs" || mode == "dist-eigs") && k == n diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 5f4c24115ac80..ccba004baa007 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -55,7 +55,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { val initialWeightsWithIntercept = Vectors.dense(1.0 +: initialWeights.toArray) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (_, loss) = LBFGS.runLBFGS( dataRDD, @@ -63,7 +63,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { simpleUpdater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -99,7 +99,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. 
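For reference, the 20.0 expected by the sparse `dot` checks in the BLASSuite above comes from multiplying only the indices the two vectors share (3 and 7). A plain-Scala sketch of that merge, not the mllib BLAS implementation:

```scala
// Sparse dot product by merging two sorted index arrays.
def sparseDot(xIdx: Array[Int], xVal: Array[Double],
              yIdx: Array[Int], yVal: Array[Double]): Double = {
  var i = 0; var j = 0; var s = 0.0
  while (i < xIdx.length && j < yIdx.length) {
    if (xIdx(i) == yIdx(j)) { s += xVal(i) * yVal(j); i += 1; j += 1 }
    else if (xIdx(i) < yIdx(j)) i += 1
    else j += 1
  }
  s
}

// sx1 and sx2 above share indices 3 and 7: 2.0 * 2.0 + 4.0 * 4.0 = 20.0
sparseDot(Array(0, 3, 5, 7, 8), Array(1.0, 2.0, 3.0, 4.0, 5.0),
          Array(1, 3, 6, 7, 9), Array(1.0, 2.0, 3.0, 4.0, 5.0))
```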
val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val (weightLBFGS, lossLBFGS) = LBFGS.runLBFGS( dataRDD, @@ -107,7 +107,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -140,10 +140,10 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { /** * For the first run, we set the convergenceTol to 0.0, so that the algorithm will - * run up to the maxNumIterations which is 8 here. + * run up to the numIterations which is 8 here. */ val initialWeightsWithIntercept = Vectors.dense(0.0, 0.0) - val maxNumIterations = 8 + val numIterations = 8 var convergenceTol = 0.0 val (_, lossLBFGS1) = LBFGS.runLBFGS( @@ -152,7 +152,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -167,7 +167,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -182,7 +182,7 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { squaredL2Updater, numCorrections, convergenceTol, - maxNumIterations, + numIterations, regParam, initialWeightsWithIntercept) @@ -200,12 +200,12 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { // Prepare another non-zero weights to compare the loss in the first iteration. val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12) val convergenceTol = 1e-12 - val maxNumIterations = 10 + val numIterations = 10 val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater) .setNumCorrections(numCorrections) .setConvergenceTol(convergenceTol) - .setMaxNumIterations(maxNumIterations) + .setNumIterations(numIterations) .setRegParam(regParam) val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept) @@ -241,7 +241,7 @@ class LBFGSClusterSuite extends FunSuite with LocalClusterSparkContext { val lbfgs = new LBFGS(new LogisticGradient, new SquaredL2Updater) .setNumCorrections(1) .setConvergenceTol(1e-12) - .setMaxNumIterations(1) + .setNumIterations(1) .setRegParam(1.0) val random = new Random(0) // If we serialize data directly in the task closure, the size of the serialized task would be diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala similarity index 88% rename from mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 96e0bc63b0fa4..c50b78bcbcc61 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDGeneratorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, @@ -113,18 +113,18 @@ class 
RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformRDD(sc, size, numPartitions, seed) + val uniform = RandomRDDs.uniformRDD(sc, size, numPartitions, seed) testGeneratedRDD(uniform, size, numPartitions, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalRDD(sc, size, numPartitions, seed) + val normal = RandomRDDs.normalRDD(sc, size, numPartitions, seed) testGeneratedRDD(normal, size, numPartitions, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonRDD(sc, poissonMean, size, numPartitions, seed) + val poisson = RandomRDDs.poissonRDD(sc, poissonMean, size, numPartitions, seed) testGeneratedRDD(poisson, size, numPartitions, poissonMean, math.sqrt(poissonMean), 0.1) } // mock distribution to check that partitions have unique seeds - val random = RandomRDDGenerators.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) + val random = RandomRDDs.randomRDD(sc, new MockDistro(), 1000L, 1000, 0L) assert(random.collect.size === random.collect.distinct.size) } @@ -135,13 +135,13 @@ class RandomRDDGeneratorsSuite extends FunSuite with LocalSparkContext with Seri val poissonMean = 100.0 for (seed <- 0 until 5) { - val uniform = RandomRDDGenerators.uniformVectorRDD(sc, rows, cols, parts, seed) + val uniform = RandomRDDs.uniformVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(uniform, rows, cols, parts, 0.5, 1 / math.sqrt(12)) - val normal = RandomRDDGenerators.normalVectorRDD(sc, rows, cols, parts, seed) + val normal = RandomRDDs.normalVectorRDD(sc, rows, cols, parts, seed) testGeneratedVectorRDD(normal, rows, cols, parts, 0.0, 1.0) - val poisson = RandomRDDGenerators.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) + val poisson = RandomRDDs.poissonVectorRDD(sc, poissonMean, rows, cols, parts, seed) testGeneratedVectorRDD(poisson, rows, cols, parts, poissonMean, math.sqrt(poissonMean), 0.1) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala index d9308aaba6ee1..110c44a7193fd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LabeledPointSuite.scala @@ -28,12 +28,12 @@ class LabeledPointSuite extends FunSuite { LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0)))) points.foreach { p => - assert(p === LabeledPointParser.parse(p.toString)) + assert(p === LabeledPoint.parse(p.toString)) } } test("parse labeled points with v0.9 format") { - val point = LabeledPointParser.parse("1.0,1.0 0.0 -2.0") + val point = LabeledPoint.parse("1.0,1.0 0.0 -2.0") assert(point === LabeledPoint(1.0, Vectors.dense(1.0, 0.0, -2.0))) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index ed21f84472c9a..03b71301e9ab1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.mllib.regression -import java.io.File -import java.nio.charset.Charset - import scala.collection.mutable.ArrayBuffer -import com.google.common.io.Files import org.scalatest.FunSuite import 
org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils} -import org.apache.spark.streaming.{Milliseconds, StreamingContext} -import org.apache.spark.util.Utils +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.TestSuiteBase + +class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase { -class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { + // use longer wait time to ensure job completion + override def maxWaitTimeMillis = 20000 // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { @@ -49,35 +48,25 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data - test("streaming linear regression parameter accuracy") { - - val testDir = Files.createTempDir() - val numBatches = 10 - val batchDuration = Milliseconds(1000) - val ssc = new StreamingContext(sc, batchDuration) - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + test("parameter accuracy") { + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0, 0.0)) .setStepSize(0.1) - .setNumIterations(50) - - model.trainOn(data) - - ssc.start() + .setNumIterations(25) - // write data to a file stream - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput( - 0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) - - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) // check accuracy of final parameter estimates assertEqual(model.latestModel().intercept, 0.0, 0.1) @@ -91,45 +80,63 @@ class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext { } // Test that parameter estimates improve when learning Y = 10*X1 on streaming data - test("streaming linear regression parameter convergence") { - - val testDir = Files.createTempDir() - val batchDuration = Milliseconds(2000) - val ssc = new StreamingContext(sc, batchDuration) - val numBatches = 5 - val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString) + test("parameter convergence") { + // create model val model = new StreamingLinearRegressionWithSGD() .setInitialWeights(Vectors.dense(0.0)) .setStepSize(0.1) - .setNumIterations(50) + .setNumIterations(25) - model.trainOn(data) - - ssc.start() - - // write data to a file stream - val history = new ArrayBuffer[Double](numBatches) - for (i <- 0 until numBatches) { - val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) - val file = new File(testDir, i.toString) - Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8")) - Thread.sleep(batchDuration.milliseconds) - // wait 
an extra few seconds to make sure the update finishes before new data arrive - Thread.sleep(4000) - history.append(math.abs(model.latestModel().weights(0) - 10.0)) + // generate sequence of simulated data + val numBatches = 10 + val input = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1)) } - ssc.stop(stopSparkContext=false) + // create buffer to store intermediate fits + val history = new ArrayBuffer[Double](numBatches) - System.clearProperty("spark.driver.port") - Utils.deleteRecursively(testDir) + // apply model training to input stream, storing the intermediate results + // (we add a count to ensure the result is a DStream) + val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + // compute change in error val deltas = history.drop(1).zip(history.dropRight(1)) // check error stability (it always either shrinks, or increases with small tol) assert(deltas.forall(x => (x._1 - x._2) <= 0.1)) // check that error shrunk on at least 2 batches assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1) - } + // Test predictions on a stream + test("predictions") { + // create model initialized with true weights + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(10.0, 10.0)) + .setStepSize(0.1) + .setNumIterations(25) + + // generate sequence of simulated data for testing + val numBatches = 10 + val nPoints = 100 + val testInput = (0 until numBatches).map { i => + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), nPoints, 42 * (i + 1)) + } + + // apply model predictions to test stream + val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + }) + // collect the output as (true, estimated) tuples + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + + // compute the mean absolute error and check that it's always less than 0.1 + val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints) + assert(errors.forall(x => x <= 0.1)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index a3f76f77a5dcc..34548c86ebc14 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -39,6 +39,17 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { Vectors.dense(9.0, 0.0, 0.0, 1.0) ) + test("corr(x, y) pearson, 1 value in data") { + val x = sc.parallelize(Array(1.0)) + val y = sc.parallelize(Array(4.0)) + intercept[RuntimeException] { + Statistics.corr(x, y, "pearson") + } + intercept[RuntimeException] { + Statistics.corr(x, y, "spearman") + } + } + test("corr(x, y) default, pearson") { val x = sc.parallelize(xData) val y = sc.parallelize(yData) @@ -58,7 +69,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z).isNaN()) + assert(Statistics.corr(x, z).isNaN) } test("corr(x, y) spearman") { @@ -78,7 +89,7 @@ class CorrelationSuite extends FunSuite with LocalSparkContext { // RDD of zero variance => zero 
variance in ranks val z = sc.parallelize(zeros) - assert(Statistics.corr(x, z, "spearman").isNaN()) + assert(Statistics.corr(x, z, "spearman").isNaN) } test("corr(X) default, pearson") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala new file mode 100644 index 0000000000000..6de3840b3f198 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import java.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.test.ChiSqTest +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class HypothesisTestSuite extends FunSuite with LocalSparkContext { + + test("chi squared pearson goodness of fit") { + + val observed = new DenseVector(Array[Double](4, 6, 5)) + val pearson = Statistics.chiSqTest(observed) + + // Results validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` + assert(pearson.statistic === 0.4) + assert(pearson.degreesOfFreedom === 2) + assert(pearson.pValue ~== 0.8187 relTol 1e-4) + assert(pearson.method === ChiSqTest.PEARSON.name) + assert(pearson.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // different expected and observed sum + val observed1 = new DenseVector(Array[Double](21, 38, 43, 80)) + val expected1 = new DenseVector(Array[Double](3, 5, 7, 20)) + val pearson1 = Statistics.chiSqTest(observed1, expected1) + + // Results validated against the R command + // `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))` + assert(pearson1.statistic ~== 14.1429 relTol 1e-4) + assert(pearson1.degreesOfFreedom === 3) + assert(pearson1.pValue ~== 0.002717 relTol 1e-4) + assert(pearson1.method === ChiSqTest.PEARSON.name) + assert(pearson1.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // Vectors with different sizes + val observed3 = new DenseVector(Array(1.0, 2.0, 3.0)) + val expected3 = new DenseVector(Array(1.0, 2.0, 3.0, 4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(observed3, expected3)) + + // negative counts in observed + val negObs = new DenseVector(Array(1.0, 2.0, 3.0, -4.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(negObs, expected1)) + + // count = 0.0 in expected but not observed + val zeroExpected = new DenseVector(Array(1.0, 0.0, 3.0)) + val inf = Statistics.chiSqTest(observed, zeroExpected) + 
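The goodness-of-fit numbers asserted above are easy to verify by hand: with uniform expected probabilities the expected count is 5 per cell, giving a Pearson statistic of 0.4 on k - 1 = 2 degrees of freedom. A small sketch of that arithmetic (illustration only; the suite itself calls Statistics.chiSqTest):

```scala
// Pearson chi-squared statistic: sum over cells of (observed - expected)^2 / expected.
val observed = Array(4.0, 6.0, 5.0)
val expected = observed.sum / observed.length      // 5.0 under p = (1/3, 1/3, 1/3)
val statistic = observed.map(o => (o - expected) * (o - expected) / expected).sum
// statistic == 0.4 with observed.length - 1 == 2 degrees of freedom,
// matching R's `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))` (p-value ~ 0.8187).
```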
assert(inf.statistic === Double.PositiveInfinity) + assert(inf.degreesOfFreedom === 2) + assert(inf.pValue === 0.0) + assert(inf.method === ChiSqTest.PEARSON.name) + assert(inf.nullHypothesis === ChiSqTest.NullHypothesis.goodnessOfFit.toString) + + // 0.0 in expected and observed simultaneously + val zeroObserved = new DenseVector(Array(2.0, 0.0, 1.0)) + intercept[IllegalArgumentException](Statistics.chiSqTest(zeroObserved, zeroExpected)) + } + + test("chi squared pearson matrix independence") { + val data = Array(40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0) + // [[40.0, 56.0, 31.0, 30.0], + // [24.0, 32.0, 10.0, 15.0], + // [29.0, 42.0, 0.0, 12.0]] + val chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + // Results validated against R command + // `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))` + assert(chi.statistic ~== 21.9958 relTol 1e-4) + assert(chi.degreesOfFreedom === 6) + assert(chi.pValue ~== 0.001213 relTol 1e-4) + assert(chi.method === ChiSqTest.PEARSON.name) + assert(chi.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + + // Negative counts + val negCounts = Array(4.0, 5.0, 3.0, -3.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, negCounts))) + + // Row sum = 0.0 + val rowZero = Array(0.0, 1.0, 0.0, 2.0) + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, rowZero))) + + // Column sum = 0.0 + val colZero = Array(0.0, 0.0, 2.0, 2.0) + // IllegalArgumentException thrown here since it's thrown on driver, not inside a task + intercept[IllegalArgumentException](Statistics.chiSqTest(Matrices.dense(2, 2, colZero))) + } + + test("chi squared pearson RDD[LabeledPoint]") { + // labels: 1.0 (2 / 6), 0.0 (4 / 6) + // feature1: 0.5 (1 / 6), 1.5 (2 / 6), 3.5 (3 / 6) + // feature2: 10.0 (1 / 6), 20.0 (1 / 6), 30.0 (2 / 6), 40.0 (2 / 6) + val data = Seq( + LabeledPoint(0.0, Vectors.dense(0.5, 10.0)), + LabeledPoint(0.0, Vectors.dense(1.5, 20.0)), + LabeledPoint(1.0, Vectors.dense(1.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 30.0)), + LabeledPoint(0.0, Vectors.dense(3.5, 40.0)), + LabeledPoint(1.0, Vectors.dense(3.5, 40.0))) + for (numParts <- List(2, 4, 6, 8)) { + val chi = Statistics.chiSqTest(sc.parallelize(data, numParts)) + val feature1 = chi(0) + assert(feature1.statistic === 0.75) + assert(feature1.degreesOfFreedom === 2) + assert(feature1.pValue ~== 0.6873 relTol 1e-4) + assert(feature1.method === ChiSqTest.PEARSON.name) + assert(feature1.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + val feature2 = chi(1) + assert(feature2.statistic === 1.5) + assert(feature2.degreesOfFreedom === 3) + assert(feature2.pValue ~== 0.6823 relTol 1e-4) + assert(feature2.method === ChiSqTest.PEARSON.name) + assert(feature2.nullHypothesis === ChiSqTest.NullHypothesis.independence.toString) + } + + // Test that the right number of results is returned + val numCols = 1001 + val sparseData = Array( + new LabeledPoint(0.0, Vectors.sparse(numCols, Seq((100, 2.0)))), + new LabeledPoint(0.1, Vectors.sparse(numCols, Seq((200, 1.0))))) + val chi = Statistics.chiSqTest(sc.parallelize(sparseData)) + assert(chi.size === numCols) + assert(chi(1000) != null) // SPARK-3087 + + // Detect continous features or labels + val random = new Random(11L) + val continuousLabel = + Seq.fill(100000)(LabeledPoint(random.nextDouble(), Vectors.dense(random.nextInt(2)))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousLabel, 2)) + } + val 
continuousFeature = + Seq.fill(100000)(LabeledPoint(random.nextInt(2), Vectors.dense(random.nextDouble()))) + intercept[SparkException] { + Statistics.chiSqTest(sc.parallelize(continuousFeature, 2)) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index db13f142df517..1e9415249104b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -139,7 +139,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } @@ -167,7 +168,8 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer.numNonzeros ~== Vectors.dense(3, 5, 2) absTol 1E-5, "numNonzeros mismatch") assert(summarizer.variance ~== - Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, "variance mismatch") + Vectors.dense(3.857666666666, 7.0456666666666, 2.48166666666666) absTol 1E-5, + "variance mismatch") assert(summarizer.count === 6) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 70ca7c8a266f2..2f36fd907772c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,11 +21,12 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split} -import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { prediction != expected.label } val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy) + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") } def validateRegressor( @@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { err * err }.sum val mse = squaredError / input.length - assert(mse <= requiredMSE) + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } test("split and bin calculation") { @@ -62,7 +64,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = 
sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -80,7 +83,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) assert(splits(0).length === 99) @@ -160,7 +164,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Check splits. @@ -277,7 +282,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -371,7 +377,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) // 2^10 - 1 > 100, so categorical variables will be ordered @@ -426,9 +433,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -453,9 +462,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) val split = 
bestSplits(0)._1 assert(split.categories.length === 1) @@ -491,7 +502,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -499,8 +511,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -513,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -521,8 +535,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -536,7 +551,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -544,8 +560,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -559,7 +576,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = 
DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -567,8 +585,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._2.gain === 0) @@ -582,7 +601,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, 3, 2, 100) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) assert(bins.length === 2) @@ -590,13 +610,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) + // Train a 1-node model + val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) + val modelOneNode = DecisionTree.train(rdd, strategyOneNode) + val nodes: Array[Node] = new Array[Node](7) + nodes(0) = modelOneNode.topNode + nodes(0).leftNode = None + nodes(0).rightNode = None + val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. - val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes, splits, bins, 10) assert(bestSplits.length === 2) assert(bestSplits(0)._2.gain > 0) @@ -604,8 +630,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second // level tree construction. 
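Most of the changes above swap the old Filter-based internals for DecisionTreeMetadata, TreePoint and Node, but those are internal helpers; the public entry point the suites ultimately validate is DecisionTree.train with a Strategy, as used elsewhere in this file. A minimal, self-contained sketch of that path (the toy data and the passed-in SparkContext are assumptions of this example):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Train a small classification tree; `sc` is an existing SparkContext as in the suites.
def trainToyTree(sc: SparkContext): Double = {
  val data = Seq(
    LabeledPoint(0.0, Vectors.dense(0.0)),
    LabeledPoint(1.0, Vectors.dense(1.0)),
    LabeledPoint(1.0, Vectors.dense(2.0)),
    LabeledPoint(1.0, Vectors.dense(3.0)))
  val rdd = sc.parallelize(data)
  // Same Strategy shape used above: Gini impurity, maxDepth 3, 2 classes, 100 bins.
  val strategy = new Strategy(Classification, Gini, 3, 2, 100)
  val model = DecisionTree.train(rdd, strategy)
  model.predict(Vectors.dense(2.0)) // expected to be 1.0 for this toy data
}
```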
- val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, - filters, splits, bins, 0) + val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, + nodes, splits, bins, 0) assert(bestSplitsWithGroups.length === 2) assert(bestSplitsWithGroups(0)._2.gain > 0) assert(bestSplitsWithGroups(1)._2.gain > 0) @@ -620,18 +646,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) } - } test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -647,11 +674,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0)) arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) @@ -678,19 +705,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for multiclass classification, with just enough bins") { val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, - numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + numClassesForClassification = 3, maxBins = maxBins, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val 
bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -705,17 +735,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -729,17 +761,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 @@ -752,13 +786,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with categorical variables for ordered multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, - Array[List[Filter]](), splits, bins, 10) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) + val bestSplits = 
DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 diff --git a/pom.xml b/pom.xml index 4ab027bad55c0..0d44cf4ea5f92 100644 --- a/pom.xml +++ b/pom.xml @@ -143,11 +143,10 @@ - maven-repo + central Maven Repository - - http://repo.maven.apache.org/maven2 + https://repo1.maven.org/maven2 true @@ -213,7 +212,7 @@ spring-releases Spring Release Repository - http://repo.spring.io/libs-release + https://repo.spring.io/libs-release true @@ -222,6 +221,18 @@ + + + central + https://repo1.maven.org/maven2 + + true + + + false + + + @@ -305,7 +316,7 @@ org.xerial.snappy snappy-java - 1.0.5 + 1.1.1.3 net.jpountz.lz4 @@ -409,7 +420,7 @@ io.netty netty-all - 4.0.17.Final + 4.0.23.Final org.apache.derby diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 537ca0dcf267d..300589394b96f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -61,6 +61,17 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.storage.MemoryStore.Entry") ) ++ + Seq( + // Serializer interface change. See SPARK-3045. + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.DeserializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.Serializer"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializationStream"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.serializer.SerializerInstance") + )++ Seq( // Renamed putValues -> putArray + putIterator ProblemFilters.exclude[MissingMethodProblem]( @@ -110,6 +121,21 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") + ) ++ + Seq( // package-private classes removed in MLlib + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") + ) ++ + Seq( // new Vector methods in MLlib (binary compatible assuming users do not implement Vector) + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.copy") + ) ++ + Seq( // synthetic methods generated in LabeledPoint + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.regression.LabeledPoint$"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.regression.LabeledPoint.apply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LabeledPoint.toString") + ) ++ + Seq ( // Scala 2.11 compatibility fix + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.$default$2") ) case v if v.startsWith("1.0") => Seq( diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 40b588512ff08..49d52aefca17a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,11 +30,11 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, spark, + val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, streaming, 
streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", + "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = @@ -44,8 +44,9 @@ object BuildCommons { val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") .map(ProjectRef(buildLocation, _)) - val tools = "tools" - + val tools = ProjectRef(buildLocation, "tools") + // Root project. + val spark = ProjectRef(buildLocation, "spark") val sparkHome = buildLocation } @@ -115,7 +116,8 @@ object SparkBuild extends PomBuild { retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, - + + resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) @@ -125,26 +127,6 @@ object SparkBuild extends PomBuild { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn ) - /** Following project only exists to pull previous artifacts of Spark for generating - Mima ignores. For more information see: SPARK 2071 */ - lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings) - - def versionArtifact(id: String): Option[sbt.ModuleID] = { - val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.0.0") - } - - def oldDepsSettings() = Defaults.defaultSettings ++ Seq( - name := "old-deps", - scalaVersion := "2.10.4", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", - libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", - "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", - "spark-core").map(versionArtifact(_).get intransitive()) - ) - def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { val existingSettings = projectsMap.getOrElse(projectRef.project, Seq[Setting[_]]()) projectsMap += (projectRef.project -> (existingSettings ++ settings)) @@ -183,7 +165,7 @@ object SparkBuild extends PomBuild { super.projectDefinitions(baseDirectory).map { x => if (projectsMap.exists(_._1 == x.id)) x.settings(projectsMap(x.id): _*) else x.settings(Seq[Setting[_]](): _*) - } ++ Seq[Project](oldDeps) + } ++ Seq[Project](OldDeps.project) } } @@ -192,6 +174,31 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + * Following project only exists to pull previous artifacts of Spark for generating + * Mima ignores. 
For more information see: SPARK 2071 + */ +object OldDeps { + + lazy val project = Project("oldDeps", file("dev"), settings = oldDepsSettings) + + def versionArtifact(id: String): Option[sbt.ModuleID] = { + val fullId = id + "_2.10" + Some("org.apache.spark" % fullId % "1.0.0") + } + + def oldDepsSettings() = Defaults.defaultSettings ++ Seq( + name := "old-deps", + scalaVersion := "2.10.4", + retrieveManaged := true, + retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", + libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", + "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", + "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", + "spark-core").map(versionArtifact(_).get intransitive()) + ) +} + object Catalyst { lazy val settings = Seq( addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full), @@ -221,7 +228,6 @@ object SQL { object Hive { lazy val settings = Seq( - javaOptions += "-XX:MaxPermSize=1g", // Multiple queries rely on the TestHive singleton. See comments there for more details. parallelExecution in Test := false, @@ -284,9 +290,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { diff --git a/project/plugins.sbt b/project/plugins.sbt index 06d18e193076e..2a61f56c2ea60 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,6 +23,6 @@ addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") addSbtPlugin("com.alpinenow" % "junit_xml_listener" % "0.5.1") -addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.0") +addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.3.1") addSbtPlugin("com.cavorite" % "sbt-avro" % "0.3.2") diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 45d36e5d0e764..f133cf6f7befc 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -110,6 +110,7 @@ def _deserialize_accumulator(aid, zero_value, accum_param): class Accumulator(object): + """ A shared variable that can be accumulated, i.e., has a commutative and associative "add" operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=} @@ -166,6 +167,7 @@ def __repr__(self): class AccumulatorParam(object): + """ Helper object that defines how to accumulate values of a given type. """ @@ -186,6 +188,7 @@ def addInPlace(self, value1, value2): class AddingAccumulatorParam(AccumulatorParam): + """ An AccumulatorParam that uses the + operators to add values. Designed for simple types such as integers, floats, and lists. Requires the zero value for the underlying type @@ -210,6 +213,7 @@ def addInPlace(self, value1, value2): class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + """ This handler will keep polling updates from the same socket until the server is shutdown. 
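The accumulator classes above share a small contract: AccumulatorParam.zero(initial) supplies the identity value and addInPlace(v1, v2) merges two partial results. A minimal local sketch of that contract, with no SparkContext involved; VectorAccumulatorParam is a hypothetical name used only for illustration.

    # Sketch of the AccumulatorParam contract (zero + addInPlace), exercised locally.
    class VectorAccumulatorParam(object):

        def zero(self, value):
            # Identity element with the same shape as the initial value.
            return [0.0] * len(value)

        def addInPlace(self, v1, v2):
            # Merge v2 into v1 and return v1.
            for i in xrange(len(v1)):
                v1[i] += v2[i]
            return v1

    if __name__ == "__main__":
        param = VectorAccumulatorParam()
        total = param.zero([0.0, 0.0, 0.0])
        for update in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]):
            total = param.addInPlace(total, update)
        print total   # [5.0, 7.0, 9.0]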
@@ -228,7 +232,9 @@ def handle(self): # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) + class AccumulatorServer(SocketServer.TCPServer): + """ A simple TCP server that intercepts shutdown() in order to interrupt our continuous polling on the handler. @@ -239,6 +245,7 @@ def shutdown(self): self.server_shutdown = True SocketServer.TCPServer.shutdown(self) + def _start_update_server(): """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 43f40f8783bfd..675a2fcd2ff4e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -21,18 +21,16 @@ >>> b = sc.broadcast([1, 2, 3, 4, 5]) >>> b.value [1, 2, 3, 4, 5] - ->>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.bid] = b ->>> from cPickle import dumps, loads ->>> loads(dumps(b)).value -[1, 2, 3, 4, 5] - >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] +>>> b.unpersist() >>> large_broadcast = sc.broadcast(list(range(10000))) """ +import os + +from pyspark.serializers import CompressedSerializer, PickleSerializer + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -45,23 +43,45 @@ def _from_id(bid): class Broadcast(object): + """ A broadcast variable created with L{SparkContext.broadcast()}. Access its value through C{.value}. """ - def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, + pickle_registry=None, path=None): """ Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.value = value self.bid = bid + if path is None: + self.value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry + self.path = path + + def unpersist(self, blocking=False): + self._jbroadcast.unpersist(blocking) + os.unlink(self.path) def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) + + def __getattr__(self, item): + if item == 'value' and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + value = ser.load_stream(open(self.path)).next() + self.value = value + return value + + raise AttributeError(item) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b4c82f519bd53..fb716f6753a45 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -56,6 +56,7 @@ class SparkConf(object): + """ Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2e80eb50f2207..a90870ed3a353 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer + PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -47,6 +47,7 @@ class SparkContext(object): + """ Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark cluster, and can be used to create L{RDD}s and @@ -213,7 +214,7 @@ def _ensure_initialized(cls, instance=None, gateway=None): if instance: if (SparkContext._active_spark_context and - SparkContext._active_spark_context != instance): + SparkContext._active_spark_context != instance): currentMaster = SparkContext._active_spark_context.master currentAppName = SparkContext._active_spark_context.appName callsite = SparkContext._active_spark_context._callsite @@ -406,7 +407,7 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, - keyConverter, valueConverter, minSplits, batchSize) + keyConverter, valueConverter, minSplits, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -437,7 +438,8 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -465,7 +467,8 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -496,7 +499,8 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, - valueClass, keyConverter, valueConverter, jconf, batchSize) + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, @@ -523,8 +527,9 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, jconf = self._dictToJavaMap(conf) batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() - jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, - keyConverter, valueConverter, jconf, batchSize) + jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, + jconf, batchSize) return RDD(jrdd, self, ser) def _checkpointFile(self, name, input_deserializer): @@ -555,21 +560,25 @@ def union(self, 
rdds): first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] rest = ListConverter().convert(rest, self._gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self, - rdds[0]._jrdd_deserializer) + return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer) def broadcast(self, value): """ Broadcast a read-only variable to the cluster, returning a L{Broadcast} - object for reading it in distributed functions. The variable will be - sent to each cluster only once. + object for reading it in distributed functions. The variable will + be sent to each cluster only once. + + :keep: Keep the `value` in driver or not. """ - pickleSer = PickleSerializer() - pickled = pickleSer.dumps(value) - jbroadcast = self._jsc.broadcast(bytearray(pickled)) - return Broadcast(jbroadcast.id(), value, jbroadcast, - self._pickled_broadcast_vars) + ser = CompressedSerializer(PickleSerializer()) + # pass large object by py4j is very slow and need much memory + tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) + ser.dump_stream([value], tempFile) + tempFile.close() + jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) + return Broadcast(jbroadcast.id(), None, jbroadcast, + self._pickled_broadcast_vars, tempFile.name) def accumulator(self, value, accum_param=None): """ @@ -610,7 +619,7 @@ def addFile(self, path): >>> def func(iterator): ... with open(SparkFiles.get("test.txt")) as testFile: ... fileVal = int(testFile.readline()) - ... return [x * 100 for x in iterator] + ... return [x * fileVal for x in iterator] >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect() [100, 200, 300, 400] """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index b00da833d06f1..22ab8d30c0ae3 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -22,7 +22,8 @@ import socket import sys import traceback -from errno import EINTR, ECHILD +import time +from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main @@ -43,7 +44,7 @@ def worker(sock): """ # Redirect stdout to stderr os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 + sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) @@ -80,6 +81,17 @@ def waitSocketClose(sock): os._exit(compute_real_exit_code(exit_code)) +# Cleanup zombie children +def cleanup_dead_children(): + try: + while True: + pid, _ = os.waitpid(0, os.WNOHANG) + if not pid: + break + except: + pass + + def manager(): # Create a new process group to corral our children os.setpgid(0, 0) @@ -102,29 +114,21 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP - # Cleanup zombie children - def handle_sigchld(*args): - try: - pid, status = os.waitpid(0, os.WNOHANG) - if status != 0: - msg = "worker %s crashed abruptly with exit status %s" % (pid, status) - print >> sys.stderr, msg - except EnvironmentError as err: - if err.errno not in (ECHILD, EINTR): - raise - signal.signal(SIGCHLD, handle_sigchld) - # Initialization complete sys.stdout.close() try: while True: try: - ready_fds = select.select([0, listen_sock], [], [])[0] + ready_fds = select.select([0, listen_sock], [], [], 1)[0] except select.error as ex: if ex[0] == EINTR: continue 
else: raise + + # cleanup in signal handler will cause deadlock + cleanup_dead_children() + if 0 in ready_fds: try: worker_pid = read_int(sys.stdin) @@ -134,33 +138,44 @@ def handle_sigchld(*args): try: os.kill(worker_pid, signal.SIGKILL) except OSError: - pass # process already died - + pass # process already died if listen_sock in ready_fds: - sock, addr = listen_sock.accept() + try: + sock, _ = listen_sock.accept() + except OSError as e: + if e.errno == EINTR: + continue + raise + # Launch a worker process try: pid = os.fork() - if pid == 0: - listen_sock.close() - try: - worker(sock) - except: - traceback.print_exc() - os._exit(1) - else: - os._exit(0) + except OSError as e: + if e.errno in (EAGAIN, EINTR): + time.sleep(1) + pid = os.fork() # error here will shutdown daemon else: + outfile = sock.makefile('w') + write_int(e.errno, outfile) # Signal that the fork failed + outfile.flush() + outfile.close() sock.close() - - except OSError as e: - print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e - outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) - write_int(-1, outfile) # Signal that the fork failed - outfile.flush() - outfile.close() + continue + + if pid == 0: + # in child process + listen_sock.close() + try: + worker(sock) + except: + traceback.print_exc() + os._exit(1) + else: + os._exit(0) + else: sock.close() + finally: shutdown(1) diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 57ee14eeb7776..331de9a9b2212 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -19,6 +19,7 @@ class SparkFiles(object): + """ Resolves paths to files added through L{SparkContext.addFile()}. diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 2c129679f47f3..c7f7c1fe591b0 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -39,7 +39,7 @@ def launch_gateway(): submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS") submit_args = submit_args if submit_args is not None else "" submit_args = shlex.split(submit_args) - command = [os.path.join(SPARK_HOME, script), "pyspark-shell"] + submit_args + command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"] if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): @@ -65,6 +65,7 @@ def preexec_func(): # Create a thread to echo output from the GatewayServer, which is required # for Java log output to show up: class EchoOutputThread(Thread): + def __init__(self, stream): Thread.__init__(self) self.daemon = True diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 9c1565affbdac..bb60d3d0c8463 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -16,6 +16,7 @@ # import struct +import sys import numpy from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD @@ -72,12 +73,20 @@ # Python interpreter must agree on what endian the machine is. -DENSE_VECTOR_MAGIC = 1 +DENSE_VECTOR_MAGIC = 1 SPARSE_VECTOR_MAGIC = 2 -DENSE_MATRIX_MAGIC = 3 +DENSE_MATRIX_MAGIC = 3 LABELED_POINT_MAGIC = 4 +# Workaround for SPARK-2954: before Python 2.7, struct.unpack couldn't unpack bytearray()s. 
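The reworked daemon loop above reaps exited workers with a non-blocking waitpid loop called from the select loop, instead of doing it inside a SIGCHLD handler (which could deadlock). A POSIX-only sketch of that reaping pattern using just the standard library; the children here exit immediately.

    # Sketch of the non-blocking zombie reaping now done inside the daemon's select loop.
    import os
    import time

    def cleanup_dead_children():
        # Reap any exited children without blocking; stop when none are ready.
        try:
            while True:
                pid, _ = os.waitpid(0, os.WNOHANG)
                if not pid:
                    break
        except OSError:
            pass  # no children left, or the call was interrupted

    if __name__ == "__main__":
        for _ in range(3):
            if os.fork() == 0:
                os._exit(0)          # child: exit right away
        time.sleep(0.1)              # give the children time to exit
        cleanup_dead_children()      # parent: reap all of them in one pass
        print "reaped"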
+if sys.version_info[:2] <= (2, 6): + def _unpack(fmt, string): + return struct.unpack(fmt, buffer(string)) +else: + _unpack = struct.unpack + + def _deserialize_numpy_array(shape, ba, offset, dtype=float64): """ Deserialize a numpy array of the given type from an offset in @@ -191,7 +200,7 @@ def _deserialize_double(ba, offset=0): raise TypeError("_deserialize_double called on a %s; wanted bytearray" % type(ba)) if len(ba) - offset != 8: raise TypeError("_deserialize_double called on a %d-byte array; wanted 8 bytes." % nb) - return struct.unpack("d", ba[offset:])[0] + return _unpack("d", ba[offset:])[0] def _deserialize_double_vector(ba, offset=0): @@ -443,6 +452,7 @@ def _serialize_rating(r): class RatingDeserializer(Serializer): + def loads(self, stream): length = struct.unpack("!i", stream.read(4))[0] ba = stream.read(length) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 5ec1a8084d269..ffdda7ee19302 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -31,6 +31,7 @@ class LogisticRegressionModel(LinearModel): + """A linear binary classification model derived from logistic regression. >>> data = [ @@ -60,6 +61,7 @@ class LogisticRegressionModel(LinearModel): >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -72,6 +74,7 @@ def predict(self, x): class LogisticRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -108,6 +111,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class SVMModel(LinearModel): + """A support vector machine. >>> data = [ @@ -131,6 +135,7 @@ class SVMModel(LinearModel): >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0 True """ + def predict(self, x): _linear_predictor_typecheck(x, self._coeff) margin = _dot(x, self._coeff) + self._intercept @@ -138,6 +143,7 @@ def predict(self, x): class SVMWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False): @@ -173,6 +179,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class NaiveBayesModel(object): + """ Model for Naive Bayes classifiers. @@ -213,6 +220,7 @@ def predict(self, x): class NaiveBayes(object): + @classmethod def train(cls, data, lambda_=1.0): """ diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index b380e8f6c8725..a0630d1d5c58b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -27,6 +27,7 @@ class KMeansModel(object): + """A clustering model derived from the k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) @@ -55,6 +56,7 @@ class KMeansModel(object): >>> type(model.clusterCenters) """ + def __init__(self, centers): self.centers = centers @@ -76,6 +78,7 @@ def predict(self, x): class KMeans(object): + @classmethod def train(cls, data, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 54720c2324ca6..f485a69db1fa2 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -23,10 +23,12 @@ SciPy is available in their environment. 
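The classifiers above all reduce prediction to the same linear margin, dot(x, coeff) + intercept, thresholded at zero; LogisticRegressionModel additionally squashes the margin through a sigmoid to get a probability. A small numpy sketch of that decision rule with made-up weights; predict_margin is an illustrative helper, not part of the MLlib API.

    # Sketch of the linear decision rule shared by LogisticRegressionModel and SVMModel.
    import numpy as np

    def predict_margin(x, coeff, intercept):
        # margin = dot(x, coeff) + intercept, as computed in the models' predict() methods
        return np.dot(x, coeff) + intercept

    if __name__ == "__main__":
        coeff, intercept = np.array([1.5, -2.0]), 0.25
        x = np.array([1.0, 0.5])
        margin = predict_margin(x, coeff, intercept)
        prob = 1.0 / (1.0 + np.exp(-margin))    # logistic squashing of the margin
        print margin, prob, int(margin > 0)     # the 0/1 label is the sign of the margin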
""" +import numpy from numpy import array, array_equal, ndarray, float64, int32 class SparseVector(object): + """ A simple sparse vector class for passing data to MLlib. Users may alternatively pass SciPy's {scipy.sparse} data types. @@ -159,6 +161,15 @@ def squared_distance(self, other): j += 1 return result + def toArray(self): + """ + Returns a copy of this SparseVector as a 1-dimensional NumPy array. + """ + arr = numpy.zeros(self.size) + for i in xrange(self.indices.size): + arr[self.indices[i]] = self.values[i] + return arr + def __str__(self): inds = "[" + ",".join([str(i) for i in self.indices]) + "]" vals = "[" + ",".join([str(v) for v in self.values]) + "]" @@ -192,6 +203,7 @@ def __ne__(self, other): class Vectors(object): + """ Factory methods for working with vectors. Note that dense vectors are simply represented as NumPy array objects, so there is no need diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 36e710dbae7a8..4dc1a4a912421 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -24,7 +24,8 @@ from pyspark.mllib._common import _deserialize_double, _deserialize_double_vector from pyspark.serializers import NoOpSerializer -class RandomRDDGenerators: + +class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from some distribution. @@ -34,40 +35,40 @@ class RandomRDDGenerators: def uniformRDD(sc, size, numPartitions=None, seed=None): """ Generates an RDD comprised of i.i.d. samples from the - uniform distribution on [0.0, 1.0]. + uniform distribution U(0.0, 1.0). - To transform the distribution in the generated RDD from U[0.0, 1.0] - to U[a, b], use - C{RandomRDDGenerators.uniformRDD(sc, n, p, seed)\ + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ .map(lambda v: a + (b - a) * v)} - >>> x = RandomRDDGenerators.uniformRDD(sc, 100).collect() + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() >>> len(x) 100 >>> max(x) <= 1.0 and min(x) >= 0.0 True - >>> RandomRDDGenerators.uniformRDD(sc, 100, 4).getNumPartitions() + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() 4 - >>> parts = RandomRDDGenerators.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() >>> parts == sc.defaultParallelism True """ jrdd = sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod def normalRDD(sc, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the standard normal + Generates an RDD comprised of i.i.d. samples from the standard normal distribution. 
To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma), use - C{RandomRDDGenerators.normal(sc, n, p, seed)\ + to some other normal N(mean, sigma^2), use + C{RandomRDDs.normal(sc, n, p, seed)\ .map(lambda v: mean + sigma * v)} - >>> x = RandomRDDGenerators.normalRDD(sc, 1000, seed=1L) + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -77,17 +78,17 @@ def normalRDD(sc, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod def poissonRDD(sc, mean, size, numPartitions=None, seed=None): """ - Generates an RDD comprised of i.i.d samples from the Poisson + Generates an RDD comprised of i.i.d. samples from the Poisson distribution with the input mean. >>> mean = 100.0 - >>> x = RandomRDDGenerators.poissonRDD(sc, mean, 1000, seed=1L) + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) >>> stats = x.stats() >>> stats.count() 1000L @@ -98,37 +99,37 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None): True """ jrdd = sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double(bytearray(bytes))) @staticmethod def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn - from the uniform distribution on [0.0 1.0]. + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). >>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) >>> mat.max() <= 1.0 and mat.min() >= 0.0 True - >>> RandomRDDGenerators.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 """ jrdd = sc._jvm.PythonMLLibAPI() \ .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - uniform = RDD(jrdd, sc, NoOpSerializer()) + uniform = RDD(jrdd, sc, NoOpSerializer()) return uniform.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the standard normal distribution. 
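The RandomRDDs docstrings above describe turning U(0.0, 1.0) samples into U(a, b) via a + (b - a) * v and standard normal samples into N(mean, sigma^2) via mean + sigma * v. A local numpy sketch of those two affine transforms, with no RDDs involved, just to make the arithmetic concrete.

    # Local sketch of the transforms described in the RandomRDDs docstrings.
    import numpy as np

    n = 10000
    u = np.random.uniform(0.0, 1.0, n)      # stands in for RandomRDDs.uniformRDD(sc, n)
    z = np.random.standard_normal(n)        # stands in for RandomRDDs.normalRDD(sc, n)

    a, b = 2.0, 5.0
    mean, sigma = 10.0, 3.0

    u_ab = a + (b - a) * u                  # approximately U(a, b)
    x = mean + sigma * z                    # approximately N(mean, sigma^2)

    print round(u_ab.min(), 1), round(u_ab.max(), 1)   # close to 2.0 and 5.0
    print round(x.mean(), 1), round(x.std(), 1)        # close to 10.0 and 3.0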
>>> import numpy as np - >>> mat = np.matrix(RandomRDDGenerators.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) >>> mat.shape (100, 100) >>> abs(mat.mean() - 0.0) < 0.1 @@ -138,18 +139,18 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - normal = RDD(jrdd, sc, NoOpSerializer()) + normal = RDD(jrdd, sc, NoOpSerializer()) return normal.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) @staticmethod def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ - Generates an RDD comprised of vectors containing i.i.d samples drawn + Generates an RDD comprised of vectors containing i.i.d. samples drawn from the Poisson distribution with the input mean. >>> import numpy as np >>> mean = 100.0 - >>> rdd = RandomRDDGenerators.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) >>> mat = np.mat(rdd.collect()) >>> mat.shape (100, 100) @@ -161,7 +162,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): """ jrdd = sc._jvm.PythonMLLibAPI() \ .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) - poisson = RDD(jrdd, sc, NoOpSerializer()) + poisson = RDD(jrdd, sc, NoOpSerializer()) return poisson.map(lambda bytes: _deserialize_double_vector(bytearray(bytes))) diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 6c385042ffa5f..e863fc249ec36 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -26,6 +26,7 @@ class MatrixFactorizationModel(object): + """A matrix factorisation model trained by regularized alternating least-squares. @@ -58,6 +59,7 @@ def predictAll(self, usersProducts): class ALS(object): + @classmethod def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): sc = ratings.context diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 041b119269427..d8792cf44872f 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -27,6 +27,7 @@ class LabeledPoint(object): + """ The features and labels of a data point. @@ -34,6 +35,7 @@ class LabeledPoint(object): @param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ + def __init__(self, label, features): self.label = label if (type(features) == ndarray or type(features) == SparseVector @@ -49,7 +51,9 @@ def __str__(self): class LinearModel(object): + """A linear model that has a vector of coefficients and an intercept.""" + def __init__(self, weights, intercept): self._coeff = weights self._intercept = intercept @@ -64,6 +68,7 @@ def intercept(self): class LinearRegressionModelBase(LinearModel): + """A linear regression model. 
>>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) @@ -72,6 +77,7 @@ class LinearRegressionModelBase(LinearModel): >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 True """ + def predict(self, x): """Predict the value of the dependent variable given a vector x""" """containing values for the independent variables.""" @@ -80,6 +86,7 @@ def predict(self, x): class LinearRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit. >>> from pyspark.mllib.regression import LabeledPoint @@ -111,6 +118,7 @@ class LinearRegressionModel(LinearRegressionModelBase): class LinearRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=1.0, regType=None, intercept=False): @@ -146,6 +154,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, class LassoModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_1 penalty term. @@ -178,6 +187,7 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): @@ -189,6 +199,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0, class RidgeRegressionModel(LinearRegressionModelBase): + """A linear regression model derived from a least-squares fit with an l_2 penalty term. @@ -221,6 +232,7 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): + @classmethod def train(cls, data, iterations=100, step=1.0, regParam=1.0, miniBatchFraction=1.0, initialWeights=None): diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index 0a08a562d1f1f..feef0d16cd644 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -22,10 +22,75 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix + _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +class MultivariateStatisticalSummary(object): + + """ + Trait for multivariate statistical summary of a data matrix. + """ + + def __init__(self, sc, java_summary): + """ + :param sc: Spark context + :param java_summary: Handle to Java summary object + """ + self._sc = sc + self._java_summary = java_summary + + def __del__(self): + self._sc._gateway.detach(self._java_summary) + + def mean(self): + return _deserialize_double_vector(self._java_summary.mean()) + + def variance(self): + return _deserialize_double_vector(self._java_summary.variance()) + + def count(self): + return self._java_summary.count() + + def numNonzeros(self): + return _deserialize_double_vector(self._java_summary.numNonzeros()) + + def max(self): + return _deserialize_double_vector(self._java_summary.max()) + + def min(self): + return _deserialize_double_vector(self._java_summary.min()) + class Statistics(object): + @staticmethod + def colStats(X): + """ + Computes column-wise summary statistics for the input RDD[Vector]. + + >>> from linalg import Vectors + >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), + ... Vectors.dense([4, 5, 0, 3]), + ... 
Vectors.dense([6, 7, 0, 8])]) + >>> cStats = Statistics.colStats(rdd) + >>> cStats.mean() + array([ 4., 4., 0., 3.]) + >>> cStats.variance() + array([ 4., 13., 0., 25.]) + >>> cStats.count() + 3L + >>> cStats.numNonzeros() + array([ 3., 2., 0., 3.]) + >>> cStats.max() + array([ 6., 7., 0., 8.]) + >>> cStats.min() + array([ 2., 0., 0., -2.]) + """ + sc = X.ctx + Xser = _get_unmangled_double_vector_rdd(X) + cStats = sc._jvm.PythonMLLibAPI().colStats(Xser._jrdd) + return MultivariateStatisticalSummary(sc, cStats) + @staticmethod def corr(x, y=None, method=None): """ @@ -53,16 +118,18 @@ def corr(x, y=None, method=None): >>> from linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]), ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])]) - >>> Statistics.corr(rdd) - array([[ 1. , 0.05564149, nan, 0.40047142], - [ 0.05564149, 1. , nan, 0.91359586], - [ nan, nan, 1. , nan], - [ 0.40047142, 0.91359586, nan, 1. ]]) - >>> Statistics.corr(rdd, method="spearman") - array([[ 1. , 0.10540926, nan, 0.4 ], - [ 0.10540926, 1. , nan, 0.9486833 ], - [ nan, nan, 1. , nan], - [ 0.4 , 0.9486833 , nan, 1. ]]) + >>> pearsonCorr = Statistics.corr(rdd) + >>> print str(pearsonCorr).replace('nan', 'NaN') + [[ 1. 0.05564149 NaN 0.40047142] + [ 0.05564149 1. NaN 0.91359586] + [ NaN NaN 1. NaN] + [ 0.40047142 0.91359586 NaN 1. ]] + >>> spearmanCorr = Statistics.corr(rdd, method="spearman") + >>> print str(spearmanCorr).replace('nan', 'NaN') + [[ 1. 0.10540926 NaN 0.4 ] + [ 0.10540926 1. NaN 0.9486833 ] + [ NaN NaN 1. NaN] + [ 0.4 0.9486833 NaN 1. ]] >>> try: ... Statistics.corr(rdd, "spearman") ... print "Method name as second argument without 'method=' shouldn't be allowed." diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 9d1e5be637a9a..8a851bd35c0e8 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -19,8 +19,13 @@ Fuller unit tests for Python MLlib. """ +import sys from numpy import array, array_equal -import unittest + +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest from pyspark.mllib._common import _convert_vector, _serialize_double_vector, \ _deserialize_double_vector, _dot, _squared_distance @@ -39,6 +44,7 @@ class VectorTests(unittest.TestCase): + def test_serialize(self): sv = SparseVector(4, {1: 1, 3: 2}) dv = array([1., 2., 3., 4.]) @@ -81,6 +87,7 @@ def test_squared_distance(self): class ListTests(PySparkTestCase): + """ Test MLlib algorithms on plain lists, to make sure they're passed through as NumPy arrays. 
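Statistics.colStats above hands back a MultivariateStatisticalSummary wrapping a Java-side object, but the quantities it exposes are ordinary column-wise statistics, with variance computed as the unbiased sample variance. A numpy sketch that reproduces the doctest's numbers locally, without a SparkContext.

    # Reproduce the colStats doctest values with plain numpy (no SparkContext needed).
    import numpy as np

    X = np.array([[2.0, 0.0, 0.0, -2.0],
                  [4.0, 5.0, 0.0,  3.0],
                  [6.0, 7.0, 0.0,  8.0]])

    print X.mean(axis=0)              # [ 4.  4.  0.  3.]
    print X.var(axis=0, ddof=1)       # [  4.  13.   0.  25.]  (unbiased sample variance)
    print len(X)                      # 3
    print (X != 0).sum(axis=0)        # [3 2 0 3]
    print X.max(axis=0), X.min(axis=0)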
@@ -128,7 +135,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = \ DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) @@ -168,7 +175,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = \ DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -179,6 +186,7 @@ def test_regression(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """ Test both vector operations and MLlib algorithms with SciPy sparse matrices, if SciPy is available. @@ -276,7 +284,7 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) @@ -315,7 +323,7 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) - categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 1e0006df75ac6..e9d778df5a24b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -25,7 +25,9 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.serializers import NoOpSerializer + class DecisionTreeModel(object): + """ A decision tree model for classification or regression. @@ -77,6 +79,7 @@ def __str__(self): class DecisionTree(object): + """ Learning algorithm for a decision tree model for classification or regression. @@ -85,7 +88,8 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: - >>> from numpy import array, ndarray + >>> from numpy import array + >>> import sys >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -96,15 +100,15 @@ class DecisionTree(object): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> - >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) - >>> print(model) + >>> categoricalFeaturesInfo = {} # no categorical features + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, + ... categoricalFeaturesInfo=categoricalFeaturesInfo) + >>> sys.stdout.write(model) DecisionTreeModel classifier If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 True >>> model.predict(array([0.0])) == 0 @@ -116,7 +120,8 @@ class DecisionTree(object): ... 
LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> - >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), + ... categoricalFeaturesInfo=categoricalFeaturesInfo) >>> model.predict(array([0.0, 1.0])) == 1 True >>> model.predict(array([0.0, 0.0])) == 0 @@ -128,7 +133,7 @@ class DecisionTree(object): """ @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + def trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for classification. @@ -147,12 +152,20 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "classification", numClasses, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, "classification", + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) @staticmethod - def trainRegressor(data, categoricalFeaturesInfo={}, + def trainRegressor(data, categoricalFeaturesInfo, impurity="variance", maxDepth=4, maxBins=100): """ Train a DecisionTreeModel for regression. @@ -170,43 +183,14 @@ def trainRegressor(data, categoricalFeaturesInfo={}, :param maxBins: Number of bins used for finding splits at each node. :return: DecisionTreeModel """ - return DecisionTree.train(data, "regression", 0, - categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - - - @staticmethod - def train(data, algo, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins=100): - """ - Train a DecisionTreeModel for classification or regression. - - :param data: Training data: RDD of LabeledPoint. - For classification, labels are integers - {0,1,...,numClasses}. - For regression, labels are real numbers. - :param algo: "classification" or "regression" - :param numClasses: Number of classes for classification. - :param categoricalFeaturesInfo: Map from categorical feature index - to number of categories. - Any feature not in this map - is treated as continuous. - :param impurity: For classification: "entropy" or "gini". - For regression: "variance". - :param maxDepth: Max depth of tree. - E.g., depth 0 means 1 leaf node. - Depth 1 means 1 internal node + 2 leaf nodes. - :param maxBins: Number of bins used for finding splits at each node. - :return: DecisionTreeModel - """ sc = data.context dataBytes = _get_unmangled_labeled_point_rdd(data) categoricalFeaturesInfoJMap = \ MapConverter().convert(categoricalFeaturesInfo, sc._gateway._gateway_client) model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - dataBytes._jrdd, algo, - numClasses, categoricalFeaturesInfoJMap, + dataBytes._jrdd, "regression", + 0, categoricalFeaturesInfoJMap, impurity, maxDepth, maxBins) dataBytes.unpersist() return DecisionTreeModel(sc, model) diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 639cda6350229..4962d05491c03 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -26,6 +26,7 @@ class MLUtils: + """ Helper methods to load, save and pre-process data used in MLlib. 
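With DecisionTree.train folded into trainClassifier and trainRegressor above, categoricalFeaturesInfo is now a required argument; pass an empty dict when every feature is continuous. A usage sketch of the updated call, assuming it is run where a SparkContext named sc already exists (for example inside bin/pyspark), as in the module's own doctests.

    # Usage sketch for the updated DecisionTree API; assumes a live SparkContext `sc`.
    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree

    data = sc.parallelize([LabeledPoint(0.0, [0.0]),
                           LabeledPoint(1.0, [1.0]),
                           LabeledPoint(1.0, [2.0])])

    # categoricalFeaturesInfo must now be passed explicitly; {} means all continuous.
    model = DecisionTree.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={})
    print model.predict(array([1.5]))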
""" diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 309f5a9b6038d..3eefc878d274e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -30,12 +30,13 @@ from threading import Thread import warnings import heapq +import bisect from random import Random from math import sqrt, log from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long + PickleSerializer, pack_long, CompressedSerializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -134,6 +135,7 @@ class MaxHeapQ(object): """ An implementation of MaxHeap. + >>> import pyspark.rdd >>> heap = pyspark.rdd.MaxHeapQ(5) >>> [heap.insert(i) for i in range(10)] @@ -233,7 +235,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): def _toPickleSerialization(self): if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): + self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): return self else: return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) @@ -381,6 +383,7 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): def getNumPartitions(self): """ Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) >>> rdd.getNumPartitions() 2 @@ -570,7 +573,12 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. # noqa + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey().first() + ('1', 3) + >>> sc.parallelize(tmp).sortByKey(True, 1).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> sc.parallelize(tmp).sortByKey(True, 2).collect() [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] @@ -581,42 +589,36 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): if numPartitions is None: numPartitions = self._defaultReducePartitions() - bounds = list() + def sortPartition(iterator): + return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending)) + + if numPartitions == 1: + if self.getNumPartitions() > 1: + self = self.coalesce(1) + return self.mapPartitions(sortPartition) # first compute the boundary of each part via sampling: we want to partition # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them - if numPartitions > 1: - rddSize = self.count() - # constant from Spark's RangePartitioner - maxSampleSize = numPartitions * 20.0 - fraction = min(maxSampleSize / max(rddSize, 1), 1.0) - - samples = self.sample(False, fraction, 1).map( - lambda (k, v): k).collect() - samples = sorted(samples, reverse=(not ascending), key=keyfunc) - - # we have numPartitions many parts but one of the them has - # an implicit boundary - for i in range(0, numPartitions - 1): - index = (len(samples) - 1) * (i + 1) / numPartitions - bounds.append(samples[index]) - - def rangePartitionFunc(k): - p = 0 - while p < len(bounds) and keyfunc(k) > bounds[p]: - p += 1 + rddSize = self.count() + maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + fraction = min(maxSampleSize / max(rddSize, 1), 1.0) + samples = self.sample(False, fraction, 1).map(lambda (k, v): 
k).collect() + samples = sorted(samples, reverse=(not ascending), key=keyfunc) + + # we have numPartitions many parts but one of the them has + # an implicit boundary + bounds = [samples[len(samples) * (i + 1) / numPartitions] + for i in range(0, numPartitions - 1)] + + def rangePartitioner(k): + p = bisect.bisect_left(bounds, keyfunc(k)) if ascending: return p else: return numPartitions - 1 - p - def mapFunc(iterator): - yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) - - return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) - .mapPartitions(mapFunc, preservesPartitioning=True) - .flatMap(lambda x: x, preservesPartitioning=True)) + return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True) def sortBy(self, keyfunc, ascending=True, numPartitions=None): """ @@ -1079,7 +1081,9 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, jconf) def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): """ @@ -1125,8 +1129,10 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No pickledRDD = self._toPickleSerialization() batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, - outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, - jconf, compressionCodecClass) + outputFormatClass, + keyClass, valueClass, + keyConverter, valueConverter, + jconf, compressionCodecClass) def saveAsSequenceFile(self, path, compressionCodecClass=None): """ @@ -1183,7 +1189,9 @@ def func(split, iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = self.mapPartitionsWithIndex(func) keyed._bypass_serializer = True keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path) @@ -1205,6 +1213,7 @@ def collectAsMap(self): def keys(self): """ Return an RDD with the keys of each tuple. + >>> m = sc.parallelize([(1, 2), (3, 4)]).keys() >>> m.collect() [1, 3] @@ -1214,6 +1223,7 @@ def keys(self): def values(self): """ Return an RDD with the values of each tuple. 
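The rewritten sortByKey above derives partition boundaries from a sample of the keys and then routes each key with bisect.bisect_left. A pure-Python sketch of that boundary computation and routing on a local list; compute_bounds and range_partition are illustrative names, not part of the RDD API.

    # Local sketch of the bisect-based range partitioning used by the new sortByKey.
    import bisect

    def compute_bounds(samples, numPartitions):
        # One boundary fewer than partitions; the last partition's bound stays implicit.
        samples = sorted(samples)
        return [samples[len(samples) * (i + 1) // numPartitions]
                for i in range(numPartitions - 1)]

    def range_partition(key, bounds):
        return bisect.bisect_left(bounds, key)

    if __name__ == "__main__":
        keys = [15, 3, 27, 8, 42, 1, 33, 19, 24, 6]
        bounds = compute_bounds(keys, 3)
        print bounds                                        # [8, 24]
        print [(k, range_partition(k, bounds)) for k in sorted(keys)]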
+ >>> m = sc.parallelize([(1, 2), (3, 4)]).values() >>> m.collect() [2, 4] @@ -1348,7 +1358,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash): outputSerializer = self.ctx._unbatched_serializer limit = (_parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) / 2) + "spark.python.worker.memory", "512m")) / 2) def add_shuffle_key(split, iterator): @@ -1430,12 +1440,12 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners, spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true') memory = _parse_memory(self.ctx._conf.get( - "spark.python.worker.memory", "512m")) + "spark.python.worker.memory", "512m")) agg = Aggregator(createCombiner, mergeValue, mergeCombiners) def combineLocally(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeValues(iterator) return merger.iteritems() @@ -1444,7 +1454,7 @@ def combineLocally(iterator): def _mergeCombiners(iterator): merger = ExternalMerger(agg, memory, serializer) \ - if spill else InMemoryMerger(agg) + if spill else InMemoryMerger(agg) merger.mergeCombiners(iterator) return merger.iteritems() @@ -1588,7 +1598,7 @@ def sampleByKey(self, withReplacement, fractions, seed=None): """ for fraction in fractions.values(): assert fraction >= 0.0, "Negative fraction value: %s" % fraction - return self.mapPartitionsWithIndex( \ + return self.mapPartitionsWithIndex( RDDStratifiedSampler(withReplacement, fractions, seed).func, True) def subtractByKey(self, other, numPartitions=None): @@ -1638,6 +1648,7 @@ def repartition(self, numPartitions): Internally, this uses a shuffle to redistribute data. If you are decreasing the number of partitions in this RDD, consider using `coalesce`, which can avoid performing a shuffle. + >>> rdd = sc.parallelize([1,2,3,4,5,6,7], 4) >>> sorted(rdd.glom().collect()) [[1], [2, 3], [4, 5], [6, 7]] @@ -1652,6 +1663,7 @@ def repartition(self, numPartitions): def coalesce(self, numPartitions, shuffle=False): """ Return a new RDD that is reduced into `numPartitions` partitions. + >>> sc.parallelize([1, 2, 3, 4, 5], 3).glom().collect() [[1], [2, 3], [4, 5]] >>> sc.parallelize([1, 2, 3, 4, 5], 3).coalesce(1).glom().collect() @@ -1673,6 +1685,31 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") + + def get_batch_size(ser): + if isinstance(ser, BatchedSerializer): + return ser.batchSize + return 0 + + def batch_as(rdd, batchSize): + ser = rdd._jrdd_deserializer + if isinstance(ser, BatchedSerializer): + ser = ser.serializer + return rdd._reserialize(BatchedSerializer(ser, batchSize)) + + my_batch = get_batch_size(self._jrdd_deserializer) + other_batch = get_batch_size(other._jrdd_deserializer) + if my_batch != other_batch: + # use the greatest batchSize to batch the other one. + if my_batch > other_batch: + other = batch_as(other, my_batch) + else: + self = batch_as(self, other_batch) + + # There will be an Exception in JVM if there are different number + # of items in each partitions. pairRDD = self._jrdd.zip(other._jrdd) deserializer = PairDeserializer(self._jrdd_deserializer, other._jrdd_deserializer) @@ -1690,6 +1727,7 @@ def name(self): def setName(self, name): """ Assign a name to this RDD. 
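The new zip() above rejects RDDs with different partition counts and, when the two sides were batch-serialized with different batch sizes, re-batches the smaller side to the larger size so the JVM-side zip pairs items correctly. A local sketch of that rebatching idea on plain lists; batch and rebatch are illustrative helpers, not pyspark functions.

    # Local sketch of the batch-size matching idea behind the new RDD.zip().
    from itertools import izip

    def batch(items, size):
        items = list(items)
        return [items[i:i + size] for i in xrange(0, len(items), size)]

    def rebatch(batches, size):
        flat = [x for b in batches for x in b]
        return batch(flat, size)

    if __name__ == "__main__":
        left = batch(range(10), 2)          # one side batched in twos
        right = batch(range(100, 110), 5)   # the other side batched in fives
        left = rebatch(left, 5)             # re-batch to the greater size, as zip() now does
        pairs = [(a, b) for lb, rb in izip(left, right) for a, b in izip(lb, rb)]
        print pairs[:3]                     # [(0, 100), (1, 101), (2, 102)]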
+ >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.setName('RDD1') >>> rdd1.name() @@ -1749,6 +1787,7 @@ class PipelinedRDD(RDD): """ Pipelined maps: + >>> rdd = sc.parallelize([1, 2, 3, 4]) >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect() [4, 8, 12, 16] @@ -1796,7 +1835,8 @@ def _jrdd(self): self._jrdd_deserializer = NoOpSerializer() command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) - pickled_command = CloudPickleSerializer().dumps(command) + ser = CloudPickleSerializer() + pickled_command = ser.dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 2df000fdb08ca..55e247da0e4dc 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -20,6 +20,7 @@ class RDDSamplerBase(object): + def __init__(self, withReplacement, seed=None): try: import numpy @@ -95,6 +96,7 @@ def shuffle(self, vals): class RDDSampler(RDDSamplerBase): + def __init__(self, withReplacement, fraction, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fraction = fraction @@ -113,7 +115,9 @@ def func(self, split, iterator): if self.getUniformSample(split) <= self._fraction: yield obj + class RDDStratifiedSampler(RDDSamplerBase): + def __init__(self, withReplacement, fractions, seed=None): RDDSamplerBase.__init__(self, withReplacement, seed) self._fractions = fractions diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index df34740fc8176..ef04c82866e6c 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -21,9 +21,11 @@ class ResultIterable(collections.Iterable): + """ A special result iterable. This is used because the standard iterator can not be pickled """ + def __init__(self, data): self.data = data self.index = 0 diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a10f85b55ad30..fc49aa42dbaf9 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -67,6 +67,7 @@ import sys import types import collections +import zlib from pyspark import cloudpickle @@ -111,6 +112,7 @@ def __ne__(self, other): class FramedSerializer(Serializer): + """ Serializer that writes objects as a stream of (length, data) pairs, where C{length} is a 32-bit integer and data is C{length} bytes. @@ -162,6 +164,7 @@ def loads(self, obj): class BatchedSerializer(Serializer): + """ Serializes a stream of objects in batches by calling its wrapped Serializer with streams of objects. @@ -207,6 +210,7 @@ def __str__(self): class CartesianDeserializer(FramedSerializer): + """ Deserializes the JavaRDD cartesian() of two PythonRDDs. """ @@ -240,6 +244,7 @@ def __str__(self): class PairDeserializer(CartesianDeserializer): + """ Deserializes the JavaRDD zip() of two PythonRDDs. 
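The broadcast changes earlier in this patch hand large values to the JVM through a temp file written with the new CompressedSerializer, which is just zlib framing around an inner serializer. A standalone sketch of that round trip against plain cPickle, so it runs without pyspark; compressed_dumps and compressed_loads mirror the dumps/loads pair added to serializers.py.

    # Standalone sketch of the CompressedSerializer round trip (zlib around pickle).
    import zlib
    from cPickle import dumps, loads

    def compressed_dumps(obj):
        # Compression level 1 favours speed, matching zlib.compress(..., 1) in the patch.
        return zlib.compress(dumps(obj, 2), 1)

    def compressed_loads(data):
        return loads(zlib.decompress(data))

    if __name__ == "__main__":
        value = {"weights": range(1000), "name": u"broadcast payload"}
        blob = compressed_dumps(value)
        assert compressed_loads(blob) == value
        print len(dumps(value, 2)), "->", len(blob), "bytes after compression"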
""" @@ -250,6 +255,9 @@ def __init__(self, key_ser, val_ser): def load_stream(self, stream): for (keys, vals) in self.prepare_keys_values(stream): + if len(keys) != len(vals): + raise ValueError("Can not deserialize RDD with different number of items" + " in pair: (%d, %d)" % (len(keys), len(vals))) for pair in izip(keys, vals): yield pair @@ -289,6 +297,7 @@ def _hack_namedtuple(cls): """ Make class generated by namedtuple picklable """ name = cls.__name__ fields = cls._fields + def __reduce__(self): return (_restore, (name, fields, tuple(self))) cls.__reduce__ = __reduce__ @@ -301,15 +310,16 @@ def _hijack_namedtuple(): if hasattr(collections.namedtuple, "__hijack"): return - global _old_namedtuple # or it will put in closure + global _old_namedtuple # or it will put in closure + def _copy_func(f): return types.FunctionType(f.func_code, f.func_globals, f.func_name, - f.func_defaults, f.func_closure) + f.func_defaults, f.func_closure) _old_namedtuple = _copy_func(collections.namedtuple) - def namedtuple(name, fields, verbose=False, rename=False): - cls = _old_namedtuple(name, fields, verbose, rename) + def namedtuple(*args, **kwargs): + cls = _old_namedtuple(*args, **kwargs) return _hack_namedtuple(cls) # replace namedtuple with new one @@ -323,15 +333,16 @@ def namedtuple(name, fields, verbose=False, rename=False): # so only hack those in __main__ module for n, o in sys.modules["__main__"].__dict__.iteritems(): if (type(o) is type and o.__base__ is tuple - and hasattr(o, "_fields") - and "__reduce__" not in o.__dict__): - _hack_namedtuple(o) # hack inplace + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__): + _hack_namedtuple(o) # hack inplace _hijack_namedtuple() class PickleSerializer(FramedSerializer): + """ Serializes objects using Python's cPickle serializer: @@ -354,6 +365,7 @@ def dumps(self, obj): class MarshalSerializer(FramedSerializer): + """ Serializes objects using Python's Marshal serializer: @@ -367,9 +379,11 @@ class MarshalSerializer(FramedSerializer): class AutoSerializer(FramedSerializer): + """ Choose marshal or cPickle as serialization protocol autumatically """ + def __init__(self): FramedSerializer.__init__(self) self._type = None @@ -393,7 +407,24 @@ def loads(self, obj): raise ValueError("invalid sevialization type: %s" % _type) +class CompressedSerializer(FramedSerializer): + """ + compress the serialized data + """ + + def __init__(self, serializer): + FramedSerializer.__init__(self) + self.serializer = serializer + + def dumps(self, obj): + return zlib.compress(self.serializer.dumps(obj), 1) + + def loads(self, obj): + return self.serializer.loads(zlib.decompress(obj)) + + class UTF8Deserializer(Serializer): + """ Deserializes streams written by String.getBytes. """ diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index e3923d1c36c57..1ebe7df418327 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -45,7 +45,7 @@ def get_used_memory(): return int(line.split()[1]) >> 10 else: warnings.warn("Please install psutil to have better " - "support with spilling") + "support with spilling") if platform.system() == "Darwin": import resource rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss @@ -141,7 +141,7 @@ class ExternalMerger(Merger): This class works as follows: - - It repeatedly combine the items and save them in one dict in + - It repeatedly combine the items and save them in one dict in memory. 
- When the used memory goes above memory limit, it will split @@ -190,12 +190,12 @@ class ExternalMerger(Merger): MAX_TOTAL_PARTITIONS = 4096 def __init__(self, aggregator, memory_limit=512, serializer=None, - localdirs=None, scale=1, partitions=59, batch=1000): + localdirs=None, scale=1, partitions=59, batch=1000): Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + BatchedSerializer(PickleSerializer(), 1024) self.localdirs = localdirs or self._get_dirs() # number of partitions when spill data into disks self.partitions = partitions @@ -214,7 +214,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, def _get_dirs(self): """ Get all the directories """ - path = os.environ.get("SPARK_LOCAL_DIR", "/tmp") + path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp") dirs = path.split(",") return [os.path.join(d, "python", str(os.getpid()), str(id(self))) for d in dirs] @@ -341,7 +341,7 @@ def _spill(self): self.pdata[i].clear() self.spills += 1 - gc.collect() # release the memory as much as possible + gc.collect() # release the memory as much as possible def iteritems(self): """ Return all merged items as iterator """ @@ -370,8 +370,8 @@ def _external_items(self): if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS and j < self.spills - 1 and get_used_memory() > hard_limit): - self.data.clear() # will read from disk again - gc.collect() # release the memory as much as possible + self.data.clear() # will read from disk again + gc.collect() # release the memory as much as possible for v in self._recursive_merged_items(i): yield v return @@ -409,9 +409,9 @@ def _recursive_merged_items(self, start): for i in range(start, self.partitions): subdirs = [os.path.join(d, "parts", str(i)) - for d in self.localdirs] + for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -419,7 +419,7 @@ def _recursive_merged_items(self, start): path = self._get_spill_dir(j) p = os.path.join(path, str(i)) m._partitioned_mergeCombiners( - self.serializer.load_stream(open(p))) + self.serializer.load_stream(open(p))) if get_used_memory() > limit: m._spill() diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f1093701ddc89..d4ca0cc8f336e 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -45,6 +45,7 @@ class DataType(object): + """Spark SQL DataType""" def __repr__(self): @@ -62,6 +63,7 @@ def __ne__(self, other): class PrimitiveTypeSingleton(type): + """Metaclass for PrimitiveType""" _instances = {} @@ -73,6 +75,7 @@ def __call__(cls): class PrimitiveType(DataType): + """Spark SQL PrimitiveType""" __metaclass__ = PrimitiveTypeSingleton @@ -83,6 +86,7 @@ def __eq__(self, other): class StringType(PrimitiveType): + """Spark SQL StringType The data type representing string values. @@ -90,6 +94,7 @@ class StringType(PrimitiveType): class BinaryType(PrimitiveType): + """Spark SQL BinaryType The data type representing bytearray values. @@ -97,6 +102,7 @@ class BinaryType(PrimitiveType): class BooleanType(PrimitiveType): + """Spark SQL BooleanType The data type representing bool values. @@ -104,6 +110,7 @@ class BooleanType(PrimitiveType): class TimestampType(PrimitiveType): + """Spark SQL TimestampType The data type representing datetime.datetime values. 
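Besides the indentation fixes, the spill directories above now come from `SPARK_LOCAL_DIRS` (plural), matching the variable Spark actually sets. The merger itself is driven the same way the tests later in this patch drive it; roughly, assuming pyspark is importable:

    from pyspark.shuffle import Aggregator, ExternalMerger

    # Combine values per key, spilling to the local dirs when the rough
    # memory limit (in MB) is exceeded.
    agg = Aggregator(lambda v: [v],                       # createCombiner
                     lambda c, v: c.append(v) or c,       # mergeValue
                     lambda c1, c2: c1.extend(c2) or c1)  # mergeCombiners

    merger = ExternalMerger(agg, 10)                      # small limit, may spill
    merger.mergeValues((i % 16, i) for i in range(1 << 16))
    grouped = dict(merger.iteritems())
    print(len(grouped), merger.spills)                    # 16 keys, spill count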
@@ -111,6 +118,7 @@ class TimestampType(PrimitiveType): class DecimalType(PrimitiveType): + """Spark SQL DecimalType The data type representing decimal.Decimal values. @@ -118,6 +126,7 @@ class DecimalType(PrimitiveType): class DoubleType(PrimitiveType): + """Spark SQL DoubleType The data type representing float values. @@ -125,6 +134,7 @@ class DoubleType(PrimitiveType): class FloatType(PrimitiveType): + """Spark SQL FloatType The data type representing single precision floating-point values. @@ -132,6 +142,7 @@ class FloatType(PrimitiveType): class ByteType(PrimitiveType): + """Spark SQL ByteType The data type representing int values with 1 singed byte. @@ -139,6 +150,7 @@ class ByteType(PrimitiveType): class IntegerType(PrimitiveType): + """Spark SQL IntegerType The data type representing int values. @@ -146,6 +158,7 @@ class IntegerType(PrimitiveType): class LongType(PrimitiveType): + """Spark SQL LongType The data type representing long values. If the any value is @@ -155,6 +168,7 @@ class LongType(PrimitiveType): class ShortType(PrimitiveType): + """Spark SQL ShortType The data type representing int values with 2 signed bytes. @@ -162,6 +176,7 @@ class ShortType(PrimitiveType): class ArrayType(DataType): + """Spark SQL ArrayType The data type representing list values. An ArrayType object @@ -187,10 +202,11 @@ def __init__(self, elementType, containsNull=False): def __str__(self): return "ArrayType(%s,%s)" % (self.elementType, - str(self.containsNull).lower()) + str(self.containsNull).lower()) class MapType(DataType): + """Spark SQL MapType The data type representing dict values. A MapType object comprises @@ -226,10 +242,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True): def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, - str(self.valueContainsNull).lower()) + str(self.valueContainsNull).lower()) class StructField(DataType): + """Spark SQL StructField Represents a field in a StructType. @@ -263,10 +280,11 @@ def __init__(self, name, dataType, nullable): def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, - str(self.nullable).lower()) + str(self.nullable).lower()) class StructType(DataType): + """Spark SQL StructType The data type representing rows. 
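For reference, the composite types being reformatted here nest in the obvious way; a schema mixing them can be built directly (assuming pyspark is importable):

    from pyspark.sql import (StructType, StructField, StringType,
                             IntegerType, ArrayType, MapType)

    # name: string, scores: array of ints, props: string-to-string map.
    schema = StructType([
        StructField("name", StringType(), False),
        StructField("scores", ArrayType(IntegerType(), containsNull=False), True),
        StructField("props", MapType(StringType(), StringType()), True),
    ])
    print(schema)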
@@ -291,7 +309,7 @@ def __init__(self, fields): def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self.fields)) def _parse_datatype_list(datatype_list_string): @@ -319,7 +337,7 @@ def _parse_datatype_list(datatype_list_string): _all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) + if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) def _parse_datatype_string(datatype_string): @@ -459,16 +477,16 @@ def _infer_schema(row): items = sorted(row.items()) elif isinstance(row, tuple): - if hasattr(row, "_fields"): # namedtuple + if hasattr(row, "_fields"): # namedtuple items = zip(row._fields, tuple(row)) - elif hasattr(row, "__FIELDS__"): # Row + elif hasattr(row, "__FIELDS__"): # Row items = zip(row.__FIELDS__, tuple(row)) elif all(isinstance(x, tuple) and len(x) == 2 for x in row): items = row else: raise ValueError("Can't infer schema from tuple") - elif hasattr(row, "__dict__"): # object + elif hasattr(row, "__dict__"): # object items = sorted(row.__dict__.items()) else: @@ -480,10 +498,7 @@ def _infer_schema(row): def _create_converter(obj, dataType): """Create an converter to drop the names of fields in obj """ - if not _has_struct(dataType): - return lambda x: x - - elif isinstance(dataType, ArrayType): + if isinstance(dataType, ArrayType): conv = _create_converter(obj[0], dataType.elementType) return lambda row: map(conv, row) @@ -492,6 +507,9 @@ def _create_converter(obj, dataType): conv = _create_converter(value, dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif not isinstance(dataType, StructType): + return lambda x: x + # dataType must be StructType names = [f.name for f in dataType.fields] @@ -499,7 +517,7 @@ def _create_converter(obj, dataType): conv = lambda o: tuple(o.get(n) for n in names) elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple + if hasattr(obj, "_fields"): # namedtuple conv = tuple elif hasattr(obj, "__FIELDS__"): conv = tuple @@ -508,11 +526,10 @@ def _create_converter(obj, dataType): else: raise ValueError("unexpected tuple") - elif hasattr(obj, "__dict__"): # object + elif hasattr(obj, "__dict__"): # object conv = lambda o: [o.__dict__.get(n, None) for n in names] - nested = any(_has_struct(f.dataType) for f in dataType.fields) - if not nested: + if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): return conv row = conv(obj) @@ -660,7 +677,7 @@ def _infer_schema_type(obj, dataType): assert len(fs) == len(obj), \ "Obj(%s) have different length with fields(%s)" % (obj, fs) fields = [StructField(f.name, _infer_schema_type(o, f.dataType), True) - for o, f in zip(obj, fs)] + for o, f in zip(obj, fs)] return StructType(fields) else: @@ -683,6 +700,7 @@ def _infer_schema_type(obj, dataType): StructType: (tuple, list), } + def _verify_type(obj, dataType): """ Verify the type of obj against dataType, raise an exception if @@ -728,7 +746,7 @@ def _verify_type(obj, dataType): elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with" - "length of fields (%d)" % (len(obj), len(dataType.fields))) + "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) @@ -861,6 +879,7 @@ def __reduce__(self): raise Exception("unexpected data type: %s" % dataType) class 
Row(tuple): + """ Row in SchemaRDD """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) @@ -872,7 +891,7 @@ class Row(tuple): def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__FIELDS__)) + for n in self.__FIELDS__)) def __reduce__(self): return (_restore_object, (self.__DATATYPE__, tuple(self))) @@ -881,6 +900,7 @@ def __reduce__(self): class SQLContext: + """Main entry point for SparkSQL functionality. A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as @@ -891,6 +911,8 @@ def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. @param sparkContext: The SparkContext to wrap. + @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + SQLContext in the JVM, instead we make all calls to this object. >>> srdd = sqlCtx.inferSchema(rdd) >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL @@ -960,7 +982,7 @@ def registerFunction(self, name, f, returnType=StringType()): env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, - self._sc._gateway._gateway_client) + self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, bytearray(CloudPickleSerializer().dumps(command)), env, @@ -1012,9 +1034,10 @@ def inferSchema(self, rdd): first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " - "can not infer schema") + "can not infer schema") if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated") + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -1070,8 +1093,8 @@ def applySchema(self, rdd, schema): >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.1 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)] + ... "float + 1.5 as float FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), @@ -1231,13 +1254,22 @@ def jsonRDD(self, rdd, schema=None): ... "field3.field5[0] as f3 from table3") >>> srdd6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] + + >>> sqlCtx.jsonRDD(sc.parallelize(['{}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] + >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', + ... '{"key0": {"key1": "value1"}}'])).collect() + [Row(key0=None), Row(key0=Row(key1=u'value1'))] """ def func(iterator): for x in iterator: if not isinstance(x, basestring): x = unicode(x) - yield x.encode("utf-8") + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x keyed = rdd.mapPartitions(func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) @@ -1280,12 +1312,25 @@ def uncacheTable(self, tableName): class HiveContext(SQLContext): + """A variant of Spark SQL that integrates with data stored in Hive. Configuration for Hive is read from hive-site.xml on the classpath. It supports running both SQL and HiveQL commands. 
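As the updated warning says, `inferSchema` on an RDD of dicts is deprecated in favour of `pyspark.Row`, while `applySchema` covers the explicit-schema path. A short usage sketch, assuming an existing SparkContext `sc`:

    from pyspark.sql import (SQLContext, Row, StructType, StructField,
                             StringType, IntegerType)

    sqlCtx = SQLContext(sc)

    # Infer the schema from Row objects instead of dicts.
    people = sc.parallelize([Row(name="Alice", age=1), Row(name="Bob", age=2)])
    srdd = sqlCtx.inferSchema(people)
    sqlCtx.registerRDDAsTable(srdd, "people")
    print(sqlCtx.sql("SELECT name FROM people WHERE age = 2").collect())

    # Or attach an explicit schema to an RDD of tuples.
    schema = StructType([StructField("name", StringType(), False),
                         StructField("age", IntegerType(), False)])
    print(sqlCtx.applySchema(sc.parallelize([("Carol", 3)]), schema).collect())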
""" + def __init__(self, sparkContext, hiveContext=None): + """Create a new HiveContext. + + @param sparkContext: The SparkContext to wrap. + @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + HiveContext in the JVM, instead we make all calls to this object. + """ + SQLContext.__init__(self, sparkContext) + + if hiveContext: + self._scala_HiveContext = hiveContext + @property def _ssql_ctx(self): try: @@ -1320,6 +1365,7 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): + """Starts up an instance of hive where metadata is stored locally. An in-process metadata data is created with data stored in ./metadata. @@ -1350,7 +1396,7 @@ class LocalHiveContext(HiveContext): def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " - "Use HiveContext instead.", DeprecationWarning) + "Use HiveContext instead.", DeprecationWarning) def _get_hive_ctx(self): return self._jvm.LocalHiveContext(self._jsc.sc()) @@ -1369,6 +1415,7 @@ def _create_row(fields, values): class Row(tuple): + """ A row in L{SchemaRDD}. The fields in it can be accessed like attributes. @@ -1410,7 +1457,6 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") - # let obect acs like class def __call__(self, *args): """create new Row object""" @@ -1436,12 +1482,13 @@ def __reduce__(self): def __repr__(self): if hasattr(self, "__FIELDS__"): return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) + for k, v in zip(self.__FIELDS__, self)) else: return "" % ", ".join(self) class SchemaRDD(RDD): + """An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can @@ -1652,7 +1699,7 @@ def subtract(self, other, numPartitions=None): rdd = self._jschema_rdd.subtract(other._jschema_rdd) else: rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) + numPartitions) return SchemaRDD(rdd, self.sql_ctx) else: raise ValueError("Can only subtract another SchemaRDD") @@ -1679,9 +1726,9 @@ def _test(): jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' - '"field6":[{"field7": "row2"}]}', + '"field6":[{"field7": "row2"}]}', '{"field1" : null, "field2": "row3", ' - '"field3":{"field4":33, "field5": []}}' + '"field3":{"field4":33, "field5": []}}' ] globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 5d77a131f2856..2aa0fb9d2c1ed 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -19,6 +19,7 @@ class StorageLevel: + """ Flags for controlling the storage of an RDD. 
Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 4ac94ba729d35..51bfbb47e53c2 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -29,12 +29,17 @@ import sys import tempfile import time -import unittest import zipfile +if sys.version_info[:2] <= (2, 6): + import unittest2 as unittest +else: + import unittest + + from pyspark.context import SparkContext from pyspark.files import SparkFiles -from pyspark.serializers import read_int +from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger _have_scipy = False @@ -62,53 +67,53 @@ def setUp(self): self.N = 1 << 16 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) - self.agg = Aggregator(lambda x: [x], - lambda x, y: x.append(y) or x, - lambda x, y: x.extend(y) or x) + self.agg = Aggregator(lambda x: [x], + lambda x, y: x.append(y) or x, + lambda x, y: x.extend(y) or x) def test_in_memory(self): m = InMemoryMerger(self.agg) m.mergeValues(self.data) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = InMemoryMerger(self.agg) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_small_dataset(self): m = ExternalMerger(self.agg, 1000) m.mergeValues(self.data) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 1000) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data)) self.assertEqual(m.spills, 0) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) def test_medium_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N))) + sum(xrange(self.N))) m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (x, y): (x, [y]), self.data * 3)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), - sum(xrange(self.N)) * 3) + sum(xrange(self.N)) * 3) def test_huge_dataset(self): m = ExternalMerger(self.agg, 10) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), - self.N * 10) + self.N * 10) m._cleanup() @@ -188,6 +193,7 @@ def test_add_py_file(self): log4j = self.sc._jvm.org.apache.log4j old_level = log4j.LogManager.getRootLogger().getLevel() log4j.LogManager.getRootLogger().setLevel(log4j.Level.FATAL) + def func(x): from userlibrary import UserClass return UserClass().hello() @@ -250,6 +256,15 @@ def test_save_as_textfile_with_unicode(self): raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_save_as_textfile_with_utf8(self): + x = u"\u00A1Hola, mundo!" 
+ data = self.sc.parallelize([x.encode("utf-8")]) + tempFile = tempfile.NamedTemporaryFile(delete=True) + tempFile.close() + data.saveAsTextFile(tempFile.name) + raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*"))) + self.assertEqual(x, unicode(raw_contents.strip(), "utf-8")) + def test_transforming_cartesian_result(self): # Regression test for SPARK-1034 rdd1 = self.sc.parallelize([1, 2]) @@ -317,6 +332,38 @@ def test_namedtuple_in_rdd(self): theDoes = self.sc.parallelize([jon, jane]) self.assertEquals([jon, jane], theDoes.collect()) + def test_large_broadcast(self): + N = 100000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 270MB + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEquals(N, m) + + def test_zip_with_different_serializers(self): + a = self.sc.parallelize(range(5)) + b = self.sc.parallelize(range(100, 105)) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + b = b._reserialize(MarshalSerializer()) + self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) + + def test_zip_with_different_number_of_items(self): + a = self.sc.parallelize(range(5), 2) + # different number of partitions + b = self.sc.parallelize(range(100, 106), 3) + self.assertRaises(ValueError, lambda: a.zip(b)) + # different number of batched items in JVM + b = self.sc.parallelize(range(100, 104), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # different number of items in one pair + b = self.sc.parallelize(range(100, 106), 2) + self.assertRaises(Exception, lambda: a.zip(b).count()) + # same total number of items, but different distributions + a = self.sc.parallelize([2, 3], 2).flatMap(range) + b = self.sc.parallelize([3, 2], 2).flatMap(range) + self.assertEquals(a.count(), b.count()) + self.assertRaises(Exception, lambda: a.zip(b).count()) + class TestIO(PySparkTestCase): @@ -355,8 +402,8 @@ def test_sequencefiles(self): self.assertEqual(doubles, ed) bytes = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbytes/", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.BytesWritable").collect()) + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BytesWritable").collect()) ebs = [(1, bytearray('aa', 'utf-8')), (1, bytearray('aa', 'utf-8')), (2, bytearray('aa', 'utf-8')), @@ -428,9 +475,9 @@ def test_sequencefiles(self): self.assertEqual(clazz[0], ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", - "org.apache.hadoop.io.Text", - "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable", + batchSize=1).collect()) self.assertEqual(unbatched_clazz[0], ec) def test_oldhadoop(self): @@ -443,7 +490,7 @@ def test_oldhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - oldconf = {"mapred.input.dir" : hellopath} + oldconf = {"mapred.input.dir": hellopath} hello = self.sc.hadoopRDD("org.apache.hadoop.mapred.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -462,7 +509,7 @@ def test_newhadoop(self): self.assertEqual(ints, ei) hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") - newconf = {"mapred.input.dir" : hellopath} + newconf = {"mapred.input.dir": hellopath} hello = 
self.sc.newAPIHadoopRDD("org.apache.hadoop.mapreduce.lib.input.TextInputFormat", "org.apache.hadoop.io.LongWritable", "org.apache.hadoop.io.Text", @@ -517,6 +564,7 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + class TestOutputFormat(PySparkTestCase): def setUp(self): @@ -574,8 +622,8 @@ def test_sequencefiles(self): def test_oldhadoop(self): basepath = self.tempdir.name dict_data = [(1, {}), - (1, {"row1" : 1.0}), - (2, {"row2" : 2.0})] + (1, {"row1": 1.0}), + (2, {"row2": 2.0})] self.sc.parallelize(dict_data).saveAsHadoopFile( basepath + "/oldhadoop/", "org.apache.hadoop.mapred.SequenceFileOutputFormat", @@ -589,12 +637,13 @@ def test_oldhadoop(self): self.assertEqual(result, dict_data) conf = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.MapWritable", - "mapred.output.dir" : basepath + "/olddataset/"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.MapWritable", + "mapred.output.dir": basepath + "/olddataset/" + } self.sc.parallelize(dict_data).saveAsHadoopDataset(conf) - input_conf = {"mapred.input.dir" : basepath + "/olddataset/"} + input_conf = {"mapred.input.dir": basepath + "/olddataset/"} old_dataset = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -602,6 +651,7 @@ def test_oldhadoop(self): conf=input_conf).collect()) self.assertEqual(old_dataset, dict_data) + @unittest.skipIf(sys.version_info[:2] <= (2, 6), "Skipped on 2.6 until SPARK-2951 is fixed") def test_newhadoop(self): basepath = self.tempdir.name # use custom ArrayWritable types and converters to handle arrays @@ -622,14 +672,17 @@ def test_newhadoop(self): valueConverter="org.apache.spark.api.python.WritableToDoubleArrayConverter").collect()) self.assertEqual(result, array_data) - conf = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.spark.api.python.DoubleArrayWritable", - "mapred.output.dir" : basepath + "/newdataset/"} - self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset(conf, + conf = { + "mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.spark.api.python.DoubleArrayWritable", + "mapred.output.dir": basepath + "/newdataset/" + } + self.sc.parallelize(array_data).saveAsNewAPIHadoopDataset( + conf, valueConverter="org.apache.spark.api.python.DoubleArrayToWritableConverter") - input_conf = {"mapred.input.dir" : basepath + "/newdataset/"} + input_conf = {"mapred.input.dir": basepath + "/newdataset/"} new_dataset = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -640,7 +693,7 @@ def test_newhadoop(self): def test_newolderror(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/newolderror/saveAsHadoopFile/", 
"org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat")) @@ -650,7 +703,7 @@ def test_newolderror(self): def test_bad_inputs(self): basepath = self.tempdir.name - rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x )) + rdd = self.sc.parallelize(range(1, 4)).map(lambda x: (x, "a" * x)) self.assertRaises(Exception, lambda: rdd.saveAsHadoopFile( basepath + "/badinputs/saveAsHadoopFile/", "org.apache.hadoop.mapred.NotValidOutputFormat")) @@ -685,30 +738,32 @@ def test_reserialization(self): result1 = sorted(self.sc.sequenceFile(basepath + "/reserialize/sequence").collect()) self.assertEqual(result1, data) - rdd.saveAsHadoopFile(basepath + "/reserialize/hadoop", - "org.apache.hadoop.mapred.SequenceFileOutputFormat") + rdd.saveAsHadoopFile( + basepath + "/reserialize/hadoop", + "org.apache.hadoop.mapred.SequenceFileOutputFormat") result2 = sorted(self.sc.sequenceFile(basepath + "/reserialize/hadoop").collect()) self.assertEqual(result2, data) - rdd.saveAsNewAPIHadoopFile(basepath + "/reserialize/newhadoop", - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") + rdd.saveAsNewAPIHadoopFile( + basepath + "/reserialize/newhadoop", + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") result3 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newhadoop").collect()) self.assertEqual(result3, data) conf4 = { - "mapred.output.format.class" : "org.apache.hadoop.mapred.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/dataset"} + "mapred.output.format.class": "org.apache.hadoop.mapred.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/dataset"} rdd.saveAsHadoopDataset(conf4) result4 = sorted(self.sc.sequenceFile(basepath + "/reserialize/dataset").collect()) self.assertEqual(result4, data) - conf5 = {"mapreduce.outputformat.class" : - "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", - "mapred.output.key.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.value.class" : "org.apache.hadoop.io.IntWritable", - "mapred.output.dir" : basepath + "/reserialize/newdataset"} + conf5 = {"mapreduce.outputformat.class": + "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat", + "mapred.output.key.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.value.class": "org.apache.hadoop.io.IntWritable", + "mapred.output.dir": basepath + "/reserialize/newdataset"} rdd.saveAsNewAPIHadoopDataset(conf5) result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) @@ -719,25 +774,28 @@ def test_unbatched_save_and_read(self): self.sc.parallelize(ei, numSlices=len(ei)).saveAsSequenceFile( basepath + "/unbatched/") - unbatched_sequence = sorted(self.sc.sequenceFile(basepath + "/unbatched/", + unbatched_sequence = sorted(self.sc.sequenceFile( + basepath + "/unbatched/", batchSize=1).collect()) self.assertEqual(unbatched_sequence, ei) - unbatched_hadoopFile = sorted(self.sc.hadoopFile(basepath + "/unbatched/", + unbatched_hadoopFile = sorted(self.sc.hadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) 
self.assertEqual(unbatched_hadoopFile, ei) - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(basepath + "/unbatched/", + unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( + basepath + "/unbatched/", "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", "org.apache.hadoop.io.Text", batchSize=1).collect()) self.assertEqual(unbatched_newAPIHadoopFile, ei) - oldconf = {"mapred.input.dir" : basepath + "/unbatched/"} + oldconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( "org.apache.hadoop.mapred.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -746,7 +804,7 @@ def test_unbatched_save_and_read(self): batchSize=1).collect()) self.assertEqual(unbatched_hadoopRDD, ei) - newconf = {"mapred.input.dir" : basepath + "/unbatched/"} + newconf = {"mapred.input.dir": basepath + "/unbatched/"} unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", "org.apache.hadoop.io.IntWritable", @@ -763,7 +821,9 @@ def test_malformed_RDD(self): self.assertRaises(Exception, lambda: rdd.saveAsSequenceFile( basepath + "/malformed/sequence")) + class TestDaemon(unittest.TestCase): + def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -810,12 +870,15 @@ def test_termination_sigterm(self): class TestWorker(PySparkTestCase): + def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) temp.close() path = temp.name + def sleep(x): - import os, time + import os + import time with open(path, 'w') as f: f.write("%d %d" % (os.getppid(), os.getpid())) time.sleep(100) @@ -845,7 +908,7 @@ def run(): os.kill(worker_pid, 0) time.sleep(0.1) except OSError: - break # worker was killed + break # worker was killed else: self.fail("worker has not been killed after 5 seconds") @@ -855,12 +918,13 @@ def run(): self.fail("daemon had been killed") def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default + N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) class TestSparkSubmit(unittest.TestCase): + def setUp(self): self.programDir = tempfile.mkdtemp() self.sparkSubmit = os.path.join(os.environ.get("SPARK_HOME"), "bin", "spark-submit") @@ -888,8 +952,9 @@ def createFileInZip(self, name, content): pattern = re.compile(r'^ *\|', re.MULTILINE) content = re.sub(pattern, '', content.strip()) path = os.path.join(self.programDir, name + ".zip") - with zipfile.ZipFile(path, 'w') as zip: - zip.writestr(name, content) + zip = zipfile.ZipFile(path, 'w') + zip.writestr(name, content) + zip.close() return path def test_single_script(self): @@ -953,9 +1018,9 @@ def test_module_dependency_on_cluster(self): |def myfunc(x): | return x + 1 """) - proc = subprocess.Popen( - [self.sparkSubmit, "--py-files", zip, "--master", "local-cluster[1,1,512]", script], - stdout=subprocess.PIPE) + proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", + "local-cluster[1,1,512]", script], + stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out) @@ -981,6 +1046,7 @@ def test_single_script_on_cluster(self): @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): + """General PySpark tests that depend on scipy """ def test_serialize(self): @@ -993,15 +1059,16 @@ def test_serialize(self): @unittest.skipIf(not 
_have_numpy, "NumPy not installed") class NumPyTests(PySparkTestCase): + """General PySpark tests that depend on numpy """ def test_statcounter_array(self): - x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])]) + x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])]) s = x.stats() - self.assertSequenceEqual([2.0,2.0], s.mean().tolist()) - self.assertSequenceEqual([1.0,1.0], s.min().tolist()) - self.assertSequenceEqual([3.0,3.0], s.max().tolist()) - self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist()) + self.assertSequenceEqual([2.0, 2.0], s.mean().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.min().tolist()) + self.assertSequenceEqual([3.0, 3.0], s.max().tolist()) + self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist()) if __name__ == "__main__": diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2770f63059853..6805063e06798 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,7 +30,8 @@ from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ + CompressedSerializer pickleSer = PickleSerializer() @@ -65,9 +66,10 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) + ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = pickleSer._read_with_length(infile) + value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) command = pickleSer._read_with_length(infile) diff --git a/python/run-tests b/python/run-tests index 48feba2f5bd63..7b1ee3e1cddba 100755 --- a/python/run-tests +++ b/python/run-tests @@ -48,12 +48,18 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +# Try to test with Python 2.6, since that's the minimum version that we support: +if [ $(which python2.6) ]; then + export PYSPARK_PYTHON="python2.6" +fi + +echo "Testing with Python version:" +$PYSPARK_PYTHON --version + run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" -if [ -n "$_RUN_SQL_TESTS" ]; then - run_test "pyspark/sql.py" -fi +run_test "pyspark/sql.py" # These tests are included in the module-level docs, and so must # be handled on a higher level rather than within the python file. export PYSPARK_DOC_TEST=1 @@ -70,7 +76,9 @@ run_test "pyspark/mllib/linalg.py" run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" +run_test "pyspark/mllib/stat.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py index 8e4a6292bc17c..73fd26e71f10d 100755 --- a/python/test_support/userlibrary.py +++ b/python/test_support/userlibrary.py @@ -19,6 +19,8 @@ Used to test shipping of code depenencies with SparkContext.addPyFile(). """ + class UserClass(object): + def hello(self): return "Hello World!" 
diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 8398e6f19b511..2c4452473ccbc 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -26,11 +26,53 @@ set -o posix # Figure out where Spark is installed FWDIR="$(cd `dirname $0`/..; pwd)" -if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then - echo "Usage: ./sbin/start-thriftserver [options]" +CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + +function usage { + echo "Usage: ./sbin/start-thriftserver [options] [thrift server options]" + pattern="usage" + pattern+="\|Spark assembly has been built with Hive" + pattern+="\|NOTE: SPARK_PREPEND_CLASSES is set" + pattern+="\|Spark Command: " + pattern+="\|=======" + pattern+="\|--help" + $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + echo + echo "Thrift server options:" + $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 +} + +function ensure_arg_number { + arg_number=$1 + at_least=$2 + + if [[ $arg_number -lt $at_least ]]; then + usage + exit 1 + fi +} + +if [[ "$@" = --help ]] || [[ "$@" = -h ]]; then + usage exit 0 fi -CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" -exec "$FWDIR"/bin/spark-submit --class $CLASS spark-internal $@ +THRIFT_SERVER_ARGS=() +SUBMISSION_ARGS=() + +while (($#)); do + case $1 in + --hiveconf) + ensure_arg_number $# 2 + THRIFT_SERVER_ARGS+=("$1"); shift + THRIFT_SERVER_ARGS+=("$1"); shift + ;; + + *) + SUBMISSION_ARGS+=("$1"); shift + ;; + esac +done + +exec "$FWDIR"/bin/spark-submit --class $CLASS "${SUBMISSION_ARGS[@]}" spark-internal "${THRIFT_SERVER_ARGS[@]}" diff --git a/sql/README.md b/sql/README.md index 14d5555f0c713..31f9152344086 100644 --- a/sql/README.md +++ b/sql/README.md @@ -3,10 +3,11 @@ Spark SQL This module provides support for executing relational queries expressed in either SQL or a LINQ-like Scala DSL. -Spark SQL is broken up into three subprojects: +Spark SQL is broken up into four subprojects: - Catalyst (sql/catalyst) - An implementation-agnostic framework for manipulating trees of relational operators and expressions. - Execution (sql/core) - A query planner / execution engine for translating Catalyst’s logical query plans into Spark RDDs. This component also includes a new public interface, SQLContext, that allows users to execute SQL or LINQ statements against existing RDDs and Parquet files. - Hive Support (sql/hive) - Includes an extension of SQLContext called HiveContext that allows users to write queries using a subset of HiveQL and access data from a Hive Metastore using Hive SerDes. There are also wrappers that allows users to run queries that include Hive UDFs, UDAFs, and UDTFs. + - HiveServer and CLI support (sql/hive-thriftserver) - Includes support for the SQL CLI (bin/spark-sql) and a HiveServer2 (for JDBC/ODBC) compatible server. 
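The rewritten start script forwards `--hiveconf key=value` pairs to HiveThriftServer2 and everything else to spark-submit. The partitioning logic amounts to the following (a plain-Python sketch of the shell loop, purely illustrative):

    def split_thriftserver_args(argv):
        """Separate spark-submit options from '--hiveconf k=v' pairs,
        mirroring the while/case loop in start-thriftserver.sh."""
        submission, server = [], []
        i = 0
        while i < len(argv):
            if argv[i] == "--hiveconf":
                if i + 1 >= len(argv):
                    raise SystemExit("--hiveconf requires an argument")
                server += argv[i:i + 2]
                i += 2
            else:
                submission.append(argv[i])
                i += 1
        return submission, server

    print(split_thriftserver_args(
        ["--master", "yarn", "--hiveconf", "hive.server2.thrift.port=10001"]))
    # (['--master', 'yarn'], ['--hiveconf', 'hive.server2.thrift.port=10001'])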
Other dependencies for developers diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 58d44e7923bee..830711a46a35b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -77,28 +77,28 @@ org.apache.maven.plugins maven-jar-plugin - - - test-jar - - - - test-jar-on-compile - compile - - test-jar - - + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5b398695bf560..de2d67ce82ff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -78,7 +78,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin .build( new CacheLoader[InType, OutType]() { override def load(in: InType): OutType = globalLock.synchronized { - create(in) + val startTime = System.nanoTime() + val result = create(in) + val endTime = System.nanoTime() + def timeMs = (endTime - startTime).toDouble / 1000000 + logInfo(s"Code generated expression $in in $timeMs ms") + result } }) @@ -413,7 +418,19 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin """.children } - EvaluatedExpression(code, nullTerm, primitiveTerm, objectTerm) + // Only inject debugging code if debugging is turned on. + val debugCode = + if (log.isDebugEnabled) { + val localLogger = log + val localLoggerTree = reify { localLogger } + q""" + $localLoggerTree.debug(${e.toString} + ": " + (if($nullTerm) "null" else $primitiveTerm)) + """ :: Nil + } else { + Nil + } + + EvaluatedExpression(code ++ debugCode, nullTerm, primitiveTerm, objectTerm) } protected def getColumn(inputRow: TermName, dataType: DataType, ordinal: Int) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3d41acb79e5fd..e99c5b452d183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -86,19 +86,19 @@ case class Explode(attributeNames: Seq[String], child: Expression) (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et, _) => et :: Nil - case MapType(kt,vt, _) => kt :: vt :: Nil + case ArrayType(et, containsNull) => (et, containsNull) :: Nil + case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } // TODO: Move this pattern into Generator. 
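The CacheLoader change above times each code-generation cache miss and logs the duration. The pattern itself is language-agnostic; a rough Python analogue follows (the cache and logger names here are illustrative, not Spark's):

    import logging
    import time

    log = logging.getLogger("codegen")
    _cache = {}

    def compile_expr(expr, create):
        """Return the cached compiled form of `expr`; on a miss, time the
        call to `create` and log how long generation took."""
        if expr in _cache:
            return _cache[expr]
        start = time.time()
        result = create(expr)
        log.info("Code generated expression %s in %.3f ms",
                 expr, (time.time() - start) * 1000.0)
        _cache[expr] = result
        return result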
protected def makeOutput() = if (attributeNames.size == elementTypes.size) { attributeNames.zip(elementTypes).map { - case (n, t) => AttributeReference(n, t, nullable = true)() + case (n, (t, nullable)) => AttributeReference(n, t, nullable)() } } else { elementTypes.zipWithIndex.map { - case (t, i) => AttributeReference(s"c_$i", t, nullable = true)() + case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index ce6d99c911ab3..e88c5d4fa178a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -60,6 +60,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def eval(input: Row): Any = { child.eval(input) == null } + + override def toString = s"IS NULL $child" } case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[Expression] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 0fd7aaaa36eb8..5cc41a83cc792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,11 +25,17 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" + val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" + val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" + val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" + + // This is only used for the thriftserver + val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -48,6 +54,7 @@ private[spark] object SQLConf { trait SQLConf { import SQLConf._ + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) @@ -71,6 +78,9 @@ trait SQLConf { /** When true tables cached using the in-memory columnar caching will be compressed. */ private[spark] def useCompression: Boolean = getConf(COMPRESS_CACHED, "false").toBoolean + /** The number of rows that will be */ + private[spark] def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE, "1000").toInt + /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt @@ -83,8 +93,7 @@ trait SQLConf { * * Defaults to false as this feature is currently experimental. 
*/ - private[spark] def codegenEnabled: Boolean = - if (getConf(CODEGEN_ENABLED, "false") == "true") true else false + private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to @@ -104,6 +113,12 @@ trait SQLConf { private[spark] def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, (autoBroadcastJoinThreshold + 1).toString).toLong + /** + * When set to true, we always treat byte arrays in Parquet files as strings. + */ + private[spark] def isParquetBinaryAsString: Boolean = + getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 71d338d21d0f2..af9f7c62a1d25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -273,7 +273,7 @@ class SQLContext(@transient val sparkContext: SparkContext) currentTable.logicalPlan case _ => - InMemoryRelation(useCompression, executePlan(currentTable).executedPlan) + InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) } catalog.registerTable(None, tableName, asInMemoryRelation) @@ -284,7 +284,7 @@ class SQLContext(@transient val sparkContext: SparkContext) table(tableName).queryExecution.analyzed match { // This is kind of a hack to make sure that if this was just an RDD registered as a table, // we reregister the RDD as a table. - case inMem @ InMemoryRelation(_, _, e: ExistingRdd) => + case inMem @ InMemoryRelation(_, _, _, e: ExistingRdd) => inMem.cachedColumnBuffers.unpersist() catalog.unregisterTable(None, tableName) catalog.registerTable(None, tableName, SparkLogicalPlan(e)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 57df79321b35d..33b2ed1b3a399 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -382,21 +382,26 @@ class SchemaRDD( private[sql] def javaToPython: JavaRDD[Array[Byte]] = { import scala.collection.Map - def toJava(obj: Any, dataType: DataType): Any = dataType match { - case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct) - case array: ArrayType => obj match { - case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava - case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava - case arr if arr != null && arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) - case other => other - } - case mt: MapType => obj.asInstanceOf[Map[_, _]].map { + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (obj: Row, struct: StructType) => rowToArray(obj, struct) + + case (seq: Seq[Any], array: ArrayType) => + seq.map(x => toJava(x, array.elementType)).asJava + case (list: JList[_], array: ArrayType) => + list.map(x => toJava(x, array.elementType)).asJava + case (arr, array: ArrayType) if arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + + case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type 
}.asJava + // Pyrolite can handle Timestamp - case other => obj + case (other, _) => other } + def rowToArray(row: Row, structType: StructType): Array[Any] = { val fields = structType.fields.map(field => field.dataType) row.zip(fields).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 88901debbb4e9..e63b4903041f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -20,21 +20,21 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{SparkPlan, LeafNode} -import org.apache.spark.sql.Row -import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} object InMemoryRelation { - def apply(useCompression: Boolean, child: SparkPlan): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, child)() + def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = + new InMemoryRelation(child.output, useCompression, batchSize, child)() } private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, + batchSize: Int, child: SparkPlan) (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) extends LogicalPlan with MultiInstanceRelation { @@ -43,22 +43,33 @@ private[sql] case class InMemoryRelation( // As in Spark, the actual work of caching is lazy. 
if (_cachedColumnBuffers == null) { val output = child.output - val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { attribute => - ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name, useCompression) - }.toArray - - var row: Row = null - while (iterator.hasNext) { - row = iterator.next() - var i = 0 - while (i < row.length) { - columnBuilders(i).appendFrom(row, i) - i += 1 + val cached = child.execute().mapPartitions { baseIterator => + new Iterator[Array[ByteBuffer]] { + def next() = { + val columnBuilders = output.map { attribute => + val columnType = ColumnType(attribute.dataType) + val initialBufferSize = columnType.defaultSize * batchSize + ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression) + }.toArray + + var row: Row = null + var rowCount = 0 + + while (baseIterator.hasNext && rowCount < batchSize) { + row = baseIterator.next() + var i = 0 + while (i < row.length) { + columnBuilders(i).appendFrom(row, i) + i += 1 + } + rowCount += 1 + } + + columnBuilders.map(_.build()) } - } - Iterator.single(columnBuilders.map(_.build())) + def hasNext = baseIterator.hasNext + } }.cache() cached.setName(child.toString) @@ -74,6 +85,7 @@ private[sql] case class InMemoryRelation( new InMemoryRelation( output.map(_.newInstance), useCompression, + batchSize, child)( _cachedColumnBuffers).asInstanceOf[this.type] } @@ -90,22 +102,31 @@ private[sql] case class InMemoryColumnarTableScan( override def execute() = { relation.cachedColumnBuffers.mapPartitions { iterator => - val columnBuffers = iterator.next() - assert(!iterator.hasNext) + // Find the ordinals of the requested columns. If none are requested, use the first. + val requestedColumns = + if (attributes.isEmpty) { + Seq(0) + } else { + attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) + } new Iterator[Row] { - // Find the ordinals of the requested columns. If none are requested, use the first. 
- val requestedColumns = - if (attributes.isEmpty) { - Seq(0) - } else { - attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) - } + private[this] var columnBuffers: Array[ByteBuffer] = null + private[this] var columnAccessors: Seq[ColumnAccessor] = null + nextBatch() + + private[this] val nextRow = new GenericMutableRow(columnAccessors.length) - val columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) - val nextRow = new GenericMutableRow(columnAccessors.length) + def nextBatch() = { + columnBuffers = iterator.next() + columnAccessors = requestedColumns.map(columnBuffers(_)).map(ColumnAccessor(_)) + } override def next() = { + if (!columnAccessors.head.hasNext) { + nextBatch() + } + var i = 0 while (i < nextRow.length) { columnAccessors(i).extractTo(nextRow, i) @@ -114,7 +135,7 @@ private[sql] case class InMemoryColumnarTableScan( nextRow } - override def hasNext = columnAccessors.head.hasNext + override def hasNext = columnAccessors.head.hasNext || iterator.hasNext } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0027f3cf1fc79..f9dfa3c92f1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -303,3 +303,15 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } } + +/** + * :: DeveloperApi :: + * A plan node that does nothing but lie about the output of its child. Used to spice a + * (hopefully structurally equivalent) tree from a different optimization sequence into an already + * resolved tree. + */ +@DeveloperApi +case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan { + def children = child :: Nil + def execute() = child.execute() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 51bb61530744c..b08f9aacc1fcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.{ArrayBuffer, BitSet} +import java.util.{HashMap => JavaHashMap} + import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ import scala.concurrent.duration._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.util.collection.CompactBuffer @DeveloperApi sealed abstract class BuildSide @@ -65,7 +66,7 @@ trait HashJoin { def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { // TODO: Use Spark's HashMap implementation. 
- val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() + val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() var currentRow: Row = null // Create a mapping of buildKeys -> rows @@ -75,7 +76,7 @@ trait HashJoin { if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { - val newMatchList = new ArrayBuffer[Row]() + val newMatchList = new CompactBuffer[Row]() hashTable.put(rowKey, newMatchList) newMatchList } else { @@ -87,7 +88,7 @@ trait HashJoin { new Iterator[Row] { private[this] var currentStreamedRow: Row = _ - private[this] var currentHashMatches: ArrayBuffer[Row] = _ + private[this] var currentHashMatches: CompactBuffer[Row] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. @@ -136,17 +137,9 @@ trait HashJoin { } } -/** - * Constant Value for Binary Join Node - */ -object HashOuterJoin { - val DUMMY_LIST = Seq[Row](null) - val EMPTY_LIST = Seq[Row]() -} - /** * :: DeveloperApi :: - * Performs a hash based outer join for two child relations by shuffling the data using + * Performs a hash based outer join for two child relations by shuffling the data using * the join keys. This operator requires loading the associated partition in both side into memory. */ @DeveloperApi @@ -168,29 +161,43 @@ case class HashOuterJoin( override def requiredChildDistribution = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - def output = left.output ++ right.output + override def output = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case x => + throw new Exception(s"HashOuterJoin should not take $x as the JoinType") + } + } + + @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) + @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala - // iterator for performance purpose. + // iterator for performance purpose. private[this] def leftOuterIterator( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - leftIter.iterator.flatMap { l => + leftIter.iterator.flatMap { l => joinedRow.withLeft(l) var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => + (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => matched = true joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the // records in right side. 
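The hash-join hunks here swap ArrayBuffer for Spark's CompactBuffer when grouping build-side rows by key. The sketch below is a deliberately simplified, hypothetical version of that idea (the real CompactBuffer is more careful about boxing and growth): keep the first couple of values in plain fields and only allocate a growable array for larger groups, so keys with one or two matching rows avoid per-key array overhead.

    import scala.collection.mutable.ArrayBuffer

    class TinyCompactBuffer[T] {
      private var first: T = _
      private var second: T = _
      private var rest: ArrayBuffer[T] = null
      private var count = 0

      def +=(value: T): this.type = {
        count match {
          case 0 => first = value
          case 1 => second = value
          case _ =>
            if (rest == null) rest = new ArrayBuffer[T]
            rest += value
        }
        count += 1
        this
      }

      def size: Int = count

      def iterator: Iterator[T] = count match {
        case 0 => Iterator.empty
        case 1 => Iterator(first)
        case 2 => Iterator(first, second)
        case _ => Iterator(first, second) ++ rest.iterator
      }
    }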
// If we didn't get any proper row, then append a single row with empty right joinedRow.withRight(rightNullRow).copy @@ -202,20 +209,20 @@ case class HashOuterJoin( key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - rightIter.iterator.flatMap { r => + rightIter.iterator.flatMap { r => joinedRow.withRight(r) var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => + (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => matched = true joinedRow.copy } else { Nil - }) ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the + }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all of the // records in left side. // If we didn't get any proper row, then append a single row with empty left. joinedRow.withLeft(leftNullRow).copy @@ -228,7 +235,7 @@ case class HashOuterJoin( val joinedRow = new JoinedRow() val leftNullRow = new GenericRow(left.output.length) val rightNullRow = new GenericRow(right.output.length) - val boundCondition = + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) if (!key.anyNull) { @@ -238,8 +245,8 @@ case class HashOuterJoin( leftIter.iterator.flatMap[Row] { l => joinedRow.withLeft(l) var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, + rightIter.zipWithIndex.collect { + // 1. For those matched (satisfy the join condition) records with both sides filled, // append them directly case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> { @@ -248,11 +255,11 @@ case class HashOuterJoin( rightMatchedSet.add(idx) joinedRow.copy } - } ++ HashOuterJoin.DUMMY_LIST.filter(_ => !matched).map( _ => { + } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { // 2. For those unmatched records in left, append additional records with empty right. - // HashOuterJoin.DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all + // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, + // as we don't know whether we need to append it until finish iterating all // of the records in right side. // If we didn't get any proper row, then append a single row with empty right. joinedRow.withRight(rightNullRow).copy @@ -260,8 +267,8 @@ case class HashOuterJoin( } ++ rightIter.zipWithIndex.collect { // 3. For those unmatched records in right, append additional records with empty left. - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. + // Re-visiting the records in right, and append additional row with empty left, if its not + // in the matched set. 
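The outer-join iterators above pad unmatched rows with nulls and, for the full outer case, track which right-side rows ever matched so the leftovers can be emitted afterwards. Below is a self-contained sketch of that per-key shape on plain values, with Option in place of null padding; it illustrates the algorithm, not the operator's actual code.

    import scala.collection.mutable

    object FullOuterSketch {
      def fullOuter[L, R](
          left: Seq[L],
          right: Seq[R],
          condition: (L, R) => Boolean): Seq[(Option[L], Option[R])] = {
        val rightMatched = new mutable.BitSet(right.size)

        val leftSide = left.flatMap { l =>
          var matched = false
          val pairs: Seq[(Option[L], Option[R])] = right.zipWithIndex.collect {
            case (r, idx) if condition(l, r) =>
              matched = true
              rightMatched += idx
              (Some(l), Some(r))
          }
          // Append a right-null row only when nothing matched for this left row.
          if (matched) pairs else pairs :+ ((Some(l), None): (Option[L], Option[R]))
        }

        // Re-visit the right side and pad rows whose index never matched.
        val rightOnly: Seq[(Option[L], Option[R])] = right.zipWithIndex.collect {
          case (r, idx) if !rightMatched.contains(idx) => (None, Some(r))
        }

        leftSide ++ rightOnly
      }
    }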
case (r, idx) if (!rightMatchedSet.contains(idx)) => { joinedRow(leftNullRow, r).copy } @@ -276,18 +283,22 @@ case class HashOuterJoin( } private[this] def buildHashTable( - iter: Iterator[Row], keyGenerator: Projection): Map[Row, ArrayBuffer[Row]] = { - // TODO: Use Spark's HashMap implementation. - val hashTable = scala.collection.mutable.Map[Row, ArrayBuffer[Row]]() + iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = { + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]() while (iter.hasNext) { val currentRow = iter.next() val rowKey = keyGenerator(currentRow) - val existingMatchList = hashTable.getOrElseUpdate(rowKey, {new ArrayBuffer[Row]()}) + var existingMatchList = hashTable.get(rowKey) + if (existingMatchList == null) { + existingMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, existingMatchList) + } + existingMatchList += currentRow.copy() } - - hashTable.toMap[Row, ArrayBuffer[Row]] + + hashTable } def execute() = { @@ -298,21 +309,22 @@ case class HashOuterJoin( // Build HashMap for current partition in right relation val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val boundCondition = + import scala.collection.JavaConversions._ + val boundCondition = condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) joinType match { case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST), - rightHashTable.getOrElse(key, HashOuterJoin.EMPTY_LIST)) + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST)) } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } @@ -411,7 +423,7 @@ case class BroadcastHashJoin( UnspecifiedDistribution :: UnspecifiedDistribution :: Nil @transient - lazy val broadcastFuture = future { + val broadcastFuture = future { sparkContext.broadcast(buildPlan.executeCollect()) } @@ -537,7 +549,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => - val matchedRows = new ArrayBuffer[Row] + val matchedRows = new CompactBuffer[Row] // TODO: Use Spark's BitSet. val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) @@ -589,20 +601,20 @@ case class BroadcastNestedLoopJoin( val rightNulls = new GenericMutableRow(right.output.size) /** Rows from broadcasted joined with nulls. 
*/ val broadcastRowsWithNulls: Seq[Row] = { - val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer() + val buf: CompactBuffer[Row] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) + case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) case _ => } } i += 1 } - arrBuf.toSeq + buf.toSeq } // TODO: Breaks lineage. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index a3d2a1c7a51f8..1c0b03c684f10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -109,7 +109,9 @@ private[sql] object JsonRDD extends Logging { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) - case struct: StructType => nullTypeToStringType(struct) + case ArrayType(struct: StructType, containsNull) => + ArrayType(nullTypeToStringType(struct), containsNull) + case struct: StructType =>nullTypeToStringType(struct) case other: DataType => other } StructField(fieldName, newType, nullable) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index cc575bedd8fcb..2298a9b933df5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -201,8 +201,9 @@ object ParquetFilters { (leftFilter, rightFilter) match { case (None, Some(filter)) => Some(filter) case (Some(filter), None) => Some(filter) - case (_, _) => - Some(new AndFilter(leftFilter.get, rightFilter.get)) + case (Some(leftF), Some(rightF)) => + Some(new AndFilter(leftF, rightF)) + case _ => None } } case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b3bae5db0edbc..1713ae6fb5d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -47,7 +47,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} private[sql] case class ParquetRelation( path: String, @transient conf: Option[Configuration], - @transient sqlContext: SQLContext) + @transient sqlContext: SQLContext, + partitioningAttributes: Seq[Attribute] = Nil) extends LeafNode with MultiInstanceRelation { self: Product => @@ -60,9 +61,14 @@ private[sql] case class ParquetRelation( .getSchema /** Attributes */ - override val output = ParquetTypesConverter.readSchemaFromFile(new Path(path), conf) - - override def newInstance = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] + override val output = + partitioningAttributes ++ + ParquetTypesConverter.readSchemaFromFile( + new Path(path.split(",").head), + conf, + sqlContext.isParquetBinaryAsString) + + 
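The ParquetFilters change above fixes AND pushdown: the conjunction is pushed as a combined filter only when both children convert, a single side is pushed on its own when only one converts (safe for AND, since any row it drops would fail the conjunction anyway), and the result is None when neither converts, instead of calling .get on empty Options. A generic sketch of that rule, with the filter type and combinator left abstract:

    object AndPushdownSketch {
      // `convert` turns an expression into a pushable filter if possible;
      // `and` combines two pushable filters into their conjunction.
      def combineAnd[E, F](left: E, right: E)
                          (convert: E => Option[F])
                          (and: (F, F) => F): Option[F] =
        (convert(left), convert(right)) match {
          case (Some(l), Some(r)) => Some(and(l, r))
          case (Some(l), None)    => Some(l)
          case (None, Some(r))    => Some(r)
          case (None, None)       => None
        }
    }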
override def newInstance() = ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 759a2a586b926..f6cfab736d98a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql.parquet -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.Try - import java.io.IOException import java.lang.{Long => JLong} import java.text.SimpleDateFormat -import java.util.{Date, List => JList} +import java.util.concurrent.{Callable, TimeUnit} +import java.util.{ArrayList, Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.Try +import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} @@ -41,7 +43,8 @@ import parquet.io.ParquetDecodingException import parquet.schema.MessageType import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} @@ -59,11 +62,18 @@ case class ParquetTableScan( // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes // by exprId. 
note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map { a => - relation.output - .find(o => o.exprId == a.exprId) - .getOrElse(sys.error(s"Invalid parquet attribute $a in ${relation.output.mkString(",")}")) - } + val normalOutput = + attributes + .filterNot(a => relation.partitioningAttributes.map(_.exprId).contains(a.exprId)) + .flatMap(a => relation.output.find(o => o.exprId == a.exprId)) + + val partOutput = + attributes.flatMap(a => relation.partitioningAttributes.find(o => o.exprId == a.exprId)) + + def output = partOutput ++ normalOutput + + assert(normalOutput.size + partOutput.size == attributes.size, + s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { val sc = sqlContext.sparkContext @@ -71,16 +81,19 @@ case class ParquetTableScan( ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) val conf: Configuration = ContextUtil.getConfiguration(job) - val qualifiedPath = { - val path = new Path(relation.path) - path.getFileSystem(conf).makeQualified(path) + + relation.path.split(",").foreach { curPath => + val qualifiedPath = { + val path = new Path(curPath) + path.getFileSystem(conf).makeQualified(path) + } + NewFileInputFormat.addInputPath(job, qualifiedPath) } - NewFileInputFormat.addInputPath(job, qualifiedPath) // Store both requested and original schema in `Configuration` conf.set( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) + ParquetTypesConverter.convertToString(normalOutput)) conf.set( RowWriteSupport.SPARK_ROW_SCHEMA, ParquetTypesConverter.convertToString(relation.output)) @@ -96,13 +109,46 @@ case class ParquetTableScan( ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) } - sc.newAPIHadoopRDD( - conf, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[Row]) - .map(_._2) - .filter(_ != null) // Parquet's record filters may produce null values + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.set( + SQLConf.PARQUET_CACHE_METADATA, + sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, "false")) + + val baseRDD = + new org.apache.spark.rdd.NewHadoopRDD( + sc, + classOf[FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row], + conf) + + if (partOutput.nonEmpty) { + baseRDD.mapPartitionsWithInputSplit { case (split, iter) => + val partValue = "([^=]+)=([^=]+)".r + val partValues = + split.asInstanceOf[parquet.hadoop.ParquetInputSplit] + .getPath + .toString + .split("/") + .flatMap { + case partValue(key, value) => Some(key -> value) + case _ => None + }.toMap + + val partitionRowValues = + partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) + + new Iterator[Row] { + private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null) + + def hasNext = iter.hasNext + + def next() = joinedRow.withRight(iter.next()._2) + } + } + } else { + baseRDD.map(_._2) + }.filter(_ != null) // Parquet's record filters may produce null values } /** @@ -323,10 +369,40 @@ private[parquet] class FilteringParquetRowInputFormat } override def getFooters(jobContext: JobContext): JList[Footer] = { + import FilteringParquetRowInputFormat.footerCache + if (footers eq null) { + val conf = ContextUtil.getConfiguration(jobContext) + val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) val statuses = listStatus(jobContext) fileStatuses = statuses.map(file => file.getPath -> 
file).toMap - footers = getFooters(ContextUtil.getConfiguration(jobContext), statuses) + if (statuses.isEmpty) { + footers = Collections.emptyList[Footer] + } else if (!cacheMetadata) { + // Read the footers from HDFS + footers = getFooters(conf, statuses) + } else { + // Read only the footers that are not in the footerCache + val foundFooters = footerCache.getAllPresent(statuses) + val toFetch = new ArrayList[FileStatus] + for (s <- statuses) { + if (!foundFooters.containsKey(s)) { + toFetch.add(s) + } + } + val newFooters = new mutable.HashMap[FileStatus, Footer] + if (toFetch.size > 0) { + val fetched = getFooters(conf, toFetch) + for ((status, i) <- toFetch.zipWithIndex) { + newFooters(status) = fetched.get(i) + } + footerCache.putAll(newFooters) + } + footers = new ArrayList[Footer](statuses.size) + for (status <- statuses) { + footers.add(newFooters.getOrElse(status, foundFooters.get(status))) + } + } } footers @@ -339,6 +415,10 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { + import FilteringParquetRowInputFormat.blockLocationCache + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -366,16 +446,23 @@ private[parquet] class FilteringParquetRowInputFormat for (footer <- footers) { val fs = footer.getFile.getFileSystem(configuration) val file = footer.getFile - val fileStatus = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - val fileBlockLocations = fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen) + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } splits.addAll( generateSplits.invoke( null, blocks, - fileBlockLocations, - fileStatus, + blockLocations, + status, parquetMetaData.getFileMetaData, readContext.getRequestedSchema.toString, readContext.getReadSupportMetadata, @@ -387,6 +474,17 @@ private[parquet] class FilteringParquetRowInputFormat } } +private[parquet] object FilteringParquetRowInputFormat { + private val footerCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .build[FileStatus, Footer]() + + private val blockLocationCache = CacheBuilder.newBuilder() + .maximumSize(20000) + .expireAfterWrite(15, TimeUnit.MINUTES) // Expire locations since HDFS files might move + .build[FileStatus, Array[BlockLocation]]() +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 6d4ce32ac5bfa..6a657c20fe46c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -80,9 +80,10 @@ private[parquet] class RowReadSupport 
extends ReadSupport[Row] with Logging { } } // if both unavailable, fall back to deducing the schema from the given Parquet schema + // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes(parquetSchema) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index aaef1a1d474fe..c79a9ac2dad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -43,10 +43,13 @@ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass - def toPrimitiveDataType(parquetType: ParquetPrimitiveType): DataType = + def toPrimitiveDataType( + parquetType: ParquetPrimitiveType, + binayAsString: Boolean): DataType = parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if parquetType.getOriginalType == ParquetOriginalType.UTF8 => StringType + if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || + binayAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -85,7 +88,7 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param parquetType The type to convert. * @return The corresponding Catalyst type. */ - def toDataType(parquetType: ParquetType): DataType = { + def toDataType(parquetType: ParquetType, isBinaryAsString: Boolean): DataType = { def correspondsToMap(groupType: ParquetGroupType): Boolean = { if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { false @@ -107,7 +110,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType) + toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString) } else { val groupType = parquetType.asGroupType() parquetType.getOriginalType match { @@ -116,7 +119,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field), containsNull = false) + ArrayType(toDataType(field, isBinaryAsString), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -126,9 +129,9 @@ private[parquet] object ParquetTypesConverter extends Logging { assert( keyValueGroup.getFieldCount == 2, "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. 
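The ParquetTypes hunks thread an isBinaryAsString flag through the conversion so BINARY columns without a UTF8 annotation can still surface as strings when the option is on. A minimal sketch of that single decision, using string labels instead of Parquet's and Catalyst's real type classes:

    object BinaryAsStringSketch {
      // BINARY maps to a string type when it is annotated UTF8 *or* the
      // binary-as-string flag is set; other primitives ignore the flag.
      def toPrimitiveDataType(
          primitiveName: String,
          annotatedUtf8: Boolean,
          binaryAsString: Boolean): String = primitiveName match {
        case "BINARY" if annotatedUtf8 || binaryAsString => "StringType"
        case "BINARY"  => "BinaryType"
        case "BOOLEAN" => "BooleanType"
        case "DOUBLE"  => "DoubleType"
        case other     => sys.error(s"unhandled primitive: $other")
      }
    }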
@@ -138,22 +141,22 @@ private[parquet] object ParquetTypesConverter extends Logging { // Note: the order of these checks is important! if (correspondsToMap(groupType)) { // MapType val keyValueGroup = groupType.getFields.apply(0).asGroupType() - val keyType = toDataType(keyValueGroup.getFields.apply(0)) + val keyType = toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString) assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - val valueType = toDataType(keyValueGroup.getFields.apply(1)) + val valueType = toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true // at here. MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType - val elementType = toDataType(groupType.getFields.apply(0)) + val elementType = toDataType(groupType.getFields.apply(0), isBinaryAsString) ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields .map(ptype => new StructField( ptype.getName, - toDataType(ptype), + toDataType(ptype, isBinaryAsString), ptype.getRepetition != Repetition.REQUIRED)) StructType(fields) } @@ -276,7 +279,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } } - def convertToAttributes(parquetSchema: ParquetType): Seq[Attribute] = { + def convertToAttributes(parquetSchema: ParquetType, isBinaryAsString: Boolean): Seq[Attribute] = { parquetSchema .asGroupType() .getFields @@ -284,7 +287,7 @@ private[parquet] object ParquetTypesConverter extends Logging { field => new AttributeReference( field.getName, - toDataType(field), + toDataType(field, isBinaryAsString), field.getRepetition != Repetition.REQUIRED)()) } @@ -373,8 +376,9 @@ private[parquet] object ParquetTypesConverter extends Logging { } ParquetRelation.enableLogForwarding() - val children = fs.listStatus(path).filterNot { - _.getPath.getName == FileOutputCommitter.SUCCEEDED_FILE_NAME + val children = fs.listStatus(path).filterNot { status => + val name = status.getPath.getName + name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME } // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row @@ -402,7 +406,10 @@ private[parquet] object ParquetTypesConverter extends Logging { * @param conf The Hadoop configuration to use. * @return A list of attributes that make up the schema. 
*/ - def readSchemaFromFile(origPath: Path, conf: Option[Configuration]): Seq[Attribute] = { + def readSchemaFromFile( + origPath: Path, + conf: Option[Configuration], + isBinaryAsString: Boolean): Seq[Attribute] = { val keyValueMetadata: java.util.Map[String, String] = readMetaData(origPath, conf) .getFileMetaData @@ -411,7 +418,7 @@ private[parquet] object ParquetTypesConverter extends Logging { convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) } else { val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema) + readMetaData(origPath, conf).getFileMetaData.getSchema, isBinaryAsString) log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") attributes } diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index dffd15a61838b..c7e0ff1cf6494 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -36,6 +36,9 @@ log4j.appender.FA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c{1}: %m%n log4j.appender.FA.Threshold = INFO # Some packages are noisy for no good reason. +log4j.additivity.parquet.hadoop.ParquetRecordReader=false +log4j.logger.parquet.hadoop.ParquetRecordReader=OFF + log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fbf9bd9dbcdea..befef46d93973 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -22,9 +22,19 @@ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableSca import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ +case class BigData(s: String) + class CachedTableSuite extends QueryTest { TestData // Load test tables. 
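readSchemaFromFile above resolves the schema in two steps: prefer the Spark schema serialized into the footer's key/value metadata, and only fall back to converting the raw Parquet schema (now honoring the binary-as-string flag) when that key is absent. A hedged sketch of that lookup-then-fallback shape; the key name, Attribute type and converter functions below are placeholders for illustration.

    case class Attribute(name: String, dataType: String)

    object SchemaFromFooterSketch {
      // Illustrative metadata key; the real constant lives in RowReadSupport.
      val SparkMetadataKey = "spark.sql.row.metadata"

      def readSchema(
          keyValueMetadata: java.util.Map[String, String],
          parquetSchema: String,
          isBinaryAsString: Boolean,
          convertFromString: String => Seq[Attribute],
          convertToAttributes: (String, Boolean) => Seq[Attribute]): Seq[Attribute] = {
        val serialized = keyValueMetadata.get(SparkMetadataKey)
        if (serialized != null) {
          convertFromString(serialized)                          // schema written by Spark itself
        } else {
          convertToAttributes(parquetSchema, isBinaryAsString)   // derive from Parquet types
        }
      }
    }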
+ test("too big for memory") { + val data = "*" * 10000 + sparkContext.parallelize(1 to 1000000, 1).map(_ => BigData(data)).registerTempTable("bigData") + cacheTable("bigData") + assert(table("bigData").count() === 1000000L) + uncacheTable("bigData") + } + test("SPARK-1669: cacheTable should be idempotent") { assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) @@ -37,7 +47,7 @@ class CachedTableSuite extends QueryTest { cacheTable("testData") table("testData").queryExecution.analyzed match { - case InMemoryRelation(_, _, _: InMemoryColumnarTableScan) => + case InMemoryRelation(_, _, _, _: InMemoryColumnarTableScan) => fail("cacheTable is not idempotent") case _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index b561b44ad7ee2..736c0f8571e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -28,14 +28,14 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("simple columnar query") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) } test("projection") { val plan = TestSQLContext.executePlan(testData.select('value, 'key).logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -44,7 +44,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan - val scan = InMemoryRelation(useCompression = true, plan) + val scan = InMemoryRelation(useCompression = true, 5, plan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 75c0589eb208e..58b1e23891a3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -213,7 +213,8 @@ class JsonSuite extends QueryTest { StructField("arrayOfStruct", ArrayType( StructType( StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true) :: + StructField("field2", StringType, true) :: + StructField("field3", StringType, true) :: Nil)), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: StructField("field2", DecimalType, true) :: Nil), true) :: @@ -263,8 +264,12 @@ class JsonSuite extends QueryTest { // Access elements of an array of structs. checkAnswer( - sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2] from jsonTable"), - (true :: "str1" :: Nil, false :: null :: Nil, null) :: Nil + sql("select arrayOfStruct[0], arrayOfStruct[1], arrayOfStruct[2], arrayOfStruct[3] " + + "from jsonTable"), + (true :: "str1" :: null :: Nil, + false :: null :: null :: Nil, + null :: null :: null :: Nil, + null) :: Nil ) // Access a struct and fields inside of it. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index d0180f3754f22..a88310b5f1b46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -43,7 +43,7 @@ object TestJsonData { "arrayOfDouble":[1.2, 1.7976931348623157E308, 4.9E-324, 2.2250738585072014E-308], "arrayOfBoolean":[true, false, true], "arrayOfNull":[null, null, null, null], - "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}], + "arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "arrayOfArray1":[[1, 2, 3], ["str1", "str2"]], "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9933575038bd3..172dcd6aa0ee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -21,8 +21,6 @@ import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import parquet.schema.MessageTypeParser - import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job @@ -33,7 +31,6 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.util.Utils @@ -138,6 +135,57 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } } + test("Treat binary as string") { + val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString + + // Create the test file. + val file = getTempFilePath("parquet") + val path = file.toString + val range = (0 to 255) + val rowRDD = TestSQLContext.sparkContext.parallelize(range) + .map(i => org.apache.spark.sql.Row(i, s"val_$i".getBytes)) + // We need to ask Parquet to store the String column as a Binary column. + val schema = StructType( + StructField("c1", IntegerType, false) :: + StructField("c2", BinaryType, false) :: Nil) + val schemaRDD1 = applySchema(rowRDD, schema) + schemaRDD1.saveAsParquetFile(path) + val resultWithBinary = parquetFile(path).collect + range.foreach { + i => + assert(resultWithBinary(i).getInt(0) === i) + assert(resultWithBinary(i)(1) === s"val_$i".getBytes) + } + + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true") + // This ParquetRelation always use Parquet types to derive output. 
+ val parquetRelation = new ParquetRelation( + path.toString, + Some(TestSQLContext.sparkContext.hadoopConfiguration), + TestSQLContext) { + override val output = + ParquetTypesConverter.convertToAttributes( + ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema, + TestSQLContext.isParquetBinaryAsString) + } + val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation) + val resultWithString = schemaRDD.collect + range.foreach { + i => + assert(resultWithString(i).getInt(0) === i) + assert(resultWithString(i)(1) === s"val_$i") + } + + schemaRDD.registerTempTable("tmp") + checkAnswer( + sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"), + (5, "val_5") :: + (7, "val_7") :: Nil) + + // Set it back. + TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString) + } + test("Read/Write All Types with non-primitive type") { val tempDir = getTempFilePath("parquetTest").getCanonicalPath val range = (0 to 255) @@ -381,11 +429,14 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val predicate5 = new GreaterThan(attribute1, attribute2) val badfilter = ParquetFilters.createFilter(predicate5) assert(badfilter.isDefined === false) + + val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) + val badfilter2 = ParquetFilters.createFilter(predicate6) + assert(badfilter2.isDefined === false) } test("test filter by predicate pushdown") { for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { - println(s"testing field $myval") val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") assert( query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index 08d3f983d9e71..cadf7aaf42157 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -40,7 +40,6 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -61,7 +60,7 @@ private[hive] object HiveThriftServer2 extends Logging { Runtime.getRuntime.addShutdownHook( new Thread() { override def run() { - SparkSQLEnv.sparkContext.stop() + SparkSQLEnv.stop() } } ) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 4d0c506c5a397..b092f42372171 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -32,7 +32,7 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory} +import 
org.apache.hadoop.hive.ql.processors.{SetProcessor, CommandProcessor, CommandProcessorFactory} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket @@ -278,7 +278,7 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { val proc: CommandProcessor = CommandProcessorFactory.get(tokens(0), hconf) if (proc != null) { - if (proc.isInstanceOf[Driver]) { + if (proc.isInstanceOf[Driver] || proc.isInstanceOf[SetProcessor]) { val driver = new SparkSQLDriver driver.init() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index dee092159dd4c..699a1103f3248 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,24 +17,24 @@ package org.apache.spark.sql.hive.thriftserver.server -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.math.{random, round} - import java.sql.Timestamp import java.util.{Map => JMap} +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, Map} +import scala.math.{random, round} + import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession - import org.apache.spark.Logging +import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.catalyst.plans.logical.SetCommand import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
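The SparkSQLCLIDriver change above routes statements whose Hive processor is the SetProcessor through the Spark SQL driver as well, so a `SET key=value` issued at the CLI reaches the Spark SQL session instead of only Hive's SessionState. A simplified sketch of that dispatch with placeholder types (CommandProcessor, DriverProc, SetProc and SqlRunner are illustrative, not Hive's classes):

    sealed trait CommandProcessor
    case object DriverProc extends CommandProcessor // regular queries
    case object SetProc extends CommandProcessor    // SET statements
    case object OtherProc extends CommandProcessor  // everything else

    trait SqlRunner { def run(statement: String): Unit }

    object CliDispatchSketch {
      def process(
          proc: CommandProcessor,
          statement: String,
          sparkSqlDriver: SqlRunner,
          hiveFallback: SqlRunner): Unit = proc match {
        case DriverProc | SetProc => sparkSqlDriver.run(statement) // handled by Spark SQL
        case _                    => hiveFallback.run(statement)   // delegated to Hive
      }
    }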
@@ -43,6 +43,9 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") + // TODO: Currenlty this will grow infinitely, even as sessions expire + val sessionToActivePool = Map[HiveSession, String]() + override def newExecuteStatementOperation( parentSession: HiveSession, statement: String, @@ -73,35 +76,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage var curCol = 0 while (curCol < sparkRow.length) { - dataTypes(curCol) match { - case StringType => - row.addString(sparkRow(curCol).asInstanceOf[String]) - case IntegerType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol))) - case BooleanType => - row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol))) - case DoubleType => - row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol))) - case FloatType => - row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol))) - case DecimalType => - val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal - row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) - case LongType => - row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol))) - case ByteType => - row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol))) - case ShortType => - row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol))) - case TimestampType => - row.addColumnValue( - ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp])) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((sparkRow.get(curCol), dataTypes(curCol))) - row.addColumnValue(ColumnValue.stringValue(hiveString)) + if (sparkRow.isNullAt(curCol)) { + addNullColumnValue(sparkRow, row, curCol) + } else { + addNonNullColumnValue(sparkRow, row, curCol) } curCol += 1 } @@ -112,6 +90,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } } + def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(from(ordinal).asInstanceOf[String]) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal))) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal))) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) + case DecimalType => + val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) + case LongType => + to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal))) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) + case ShortType => + to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal))) + case TimestampType => + to.addColumnValue( + ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) + to.addColumnValue(ColumnValue.stringValue(hiveString)) + } + } + + def addNullColumnValue(from: SparkRow, to: 
Row, ordinal: Int) { + dataTypes(ordinal) match { + case StringType => + to.addString(null) + case IntegerType => + to.addColumnValue(ColumnValue.intValue(null)) + case BooleanType => + to.addColumnValue(ColumnValue.booleanValue(null)) + case DoubleType => + to.addColumnValue(ColumnValue.doubleValue(null)) + case FloatType => + to.addColumnValue(ColumnValue.floatValue(null)) + case DecimalType => + to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) + case LongType => + to.addColumnValue(ColumnValue.longValue(null)) + case ByteType => + to.addColumnValue(ColumnValue.byteValue(null)) + case ShortType => + to.addColumnValue(ColumnValue.intValue(null)) + case TimestampType => + to.addColumnValue(ColumnValue.timestampValue(null)) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + to.addColumnValue(ColumnValue.stringValue(null: String)) + } + } + def getResultSetSchema: TableSchema = { logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { @@ -130,9 +168,28 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage try { result = hiveContext.sql(statement) logDebug(result.queryExecution.toString()) + result.queryExecution.logical match { + case SetCommand(Some(key), Some(value)) if (key == SQLConf.THRIFTSERVER_POOL) => + sessionToActivePool(parentSession) = value + logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.") + case _ => + } + val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) - iter = result.queryExecution.toRdd.toLocalIterator + sessionToActivePool.get(parentSession).foreach { pool => + hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool) + } + iter = { + val resultRdd = result.queryExecution.toRdd + val useIncrementalCollect = + hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean + if (useIncrementalCollect) { + resultRdd.toLocalIterator + } else { + resultRdd.collect().iterator + } + } dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray setHasResultSet(true) } catch { diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt new file mode 100644 index 0000000000000..ae08c640e6c13 --- /dev/null +++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt @@ -0,0 +1,10 @@ +238val_238 + +311val_311 +val_27 +val_165 +val_409 +255val_255 +278val_278 +98val_98 +val_484 diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 69f19f826a802..2bf8cfdcacd22 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.thriftserver import java.io.{BufferedReader, InputStreamReader, PrintWriter} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { @@ -27,15 +28,15 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with TestUtils { val METASTORE_PATH = TestUtils.getMetastorePath("cli") override def beforeAll() { - val pb = new 
ProcessBuilder( - "../../bin/spark-sql", - "--master", - "local", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - "hive.metastore.warehouse.dir=" + WAREHOUSE_PATH) - + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val commands = + s"""../../bin/spark-sql + | --master local + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$WAREHOUSE_PATH + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(commands: _*) process = pb.start() outputWriter = new PrintWriter(process.getOutputStream, true) inputReader = new BufferedReader(new InputStreamReader(process.getInputStream)) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index b7b7c9957ac34..aedef6ce1f5f2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -25,6 +25,7 @@ import java.io.{BufferedReader, InputStreamReader} import java.net.ServerSocket import java.sql.{Connection, DriverManager, Statement} +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.Logging @@ -63,16 +64,18 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt // Forking a new process to start the Hive Thrift server. The reason to do this is it is // hard to clean up Hive resources entirely, so we just start a new process and kill // that process for cleanup. 
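The operation manager change above also lets the Thrift server either materialize an entire result with collect() or stream it with toLocalIterator, chosen by the spark.sql.thriftServer.incrementalCollect flag shown in the hunk. A minimal sketch of that toggle with an RDD-like placeholder abstraction:

    // `ResultSetLike` stands in for an RDD handle; it is not a Spark API.
    trait ResultSetLike[T] {
      def collectAll(): Array[T]       // analogous to RDD.collect()
      def localIterator(): Iterator[T] // analogous to RDD.toLocalIterator
    }

    object FetchModeSketch {
      def resultIterator[T](result: ResultSetLike[T], conf: Map[String, String]): Iterator[T] = {
        val useIncrementalCollect =
          conf.getOrElse("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
        if (useIncrementalCollect) {
          result.localIterator()       // lower driver memory, one job per partition
        } else {
          result.collectAll().iterator // single job, whole result held on the driver
        }
      }
    }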
- val defaultArgs = Seq( - "../../sbin/start-thriftserver.sh", - "--master local", - "--hiveconf", - "hive.root.logger=INFO,console", - "--hiveconf", - s"javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$METASTORE_PATH;create=true", - "--hiveconf", - s"hive.metastore.warehouse.dir=$WAREHOUSE_PATH") - val pb = new ProcessBuilder(defaultArgs ++ args) + val jdbcUrl = s"jdbc:derby:;databaseName=$METASTORE_PATH;create=true" + val command = + s"""../../sbin/start-thriftserver.sh + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}="$jdbcUrl" + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$METASTORE_PATH + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=$HOST + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$PORT + """.stripMargin.split("\\s+") + + val pb = new ProcessBuilder(command ++ args: _*) val environment = pb.environment() environment.put("HIVE_SERVER2_THRIFT_PORT", PORT.toString) environment.put("HIVE_SERVER2_THRIFT_BIND_HOST", HOST) @@ -110,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt val stmt = createStatement() stmt.execute("DROP TABLE IF EXISTS test") stmt.execute("DROP TABLE IF EXISTS test_cached") - stmt.execute("CREATE TABLE test(key int, val string)") + stmt.execute("CREATE TABLE test(key INT, val STRING)") stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test") - stmt.execute("CREATE TABLE test_cached as select * from test limit 4") + stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4") stmt.execute("CACHE TABLE test_cached") - var rs = stmt.executeQuery("select count(*) from test") + var rs = stmt.executeQuery("SELECT COUNT(*) FROM test") rs.next() assert(rs.getInt(1) === 5) - rs = stmt.executeQuery("select count(*) from test_cached") + rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached") rs.next() assert(rs.getInt(1) === 4) stmt.close() } + test("SPARK-3004 regression: result set containing NULL") { + Thread.sleep(5 * 1000) + val dataFilePath = getDataFile("data/files/small_kv_with_null.txt") + val stmt = createStatement() + stmt.execute("DROP TABLE IF EXISTS test_null") + stmt.execute("CREATE TABLE test_null(key INT, val STRING)") + stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") + + val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + var count = 0 + while (rs.next()) { + count += 1 + } + assert(count === 5) + + stmt.close() + } + def getConnection: Connection = { val connectURI = s"jdbc:hive2://localhost:$PORT/" DriverManager.getConnection(connectURI, System.getProperty("user.name"), "") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4fef071161719..210753efe7678 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -635,6 +635,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "serde_regex", "serde_reported_schema", "set_variable_sub", + "show_create_table_partitioned", + "show_create_table_delimited", + "show_create_table_alter", + "show_create_table_view", + "show_create_table_serde", + "show_create_table_db_table", + "show_create_table_does_not_exist", + 
"show_create_table_index", "show_describe_func_quotes", "show_functions", "show_partitions", diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 93d00f7c37c9b..30ff277e67c88 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -36,6 +36,11 @@ + + com.twitter + parquet-hive-bundle + 1.5.0 + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 53f3dc11dbb9f..ff32c7c90a0d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -39,7 +39,8 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException @@ -78,6 +79,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // Change the default SQL dialect to HiveQL override private[spark] def dialect: String = getConf(SQLConf.DIALECT, "hiveql") + /** + * When true, enables an experimental feature where metastore tables that use the parquet SerDe + * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive + * SerDe. + */ + private[spark] def convertMetastoreParquet: Boolean = + getConf("spark.sql.hive.convertMetastoreParquet", "false") == "true" + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } @@ -119,10 +128,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { * in the Hive metastore. */ def analyze(tableName: String) { - val relation = catalog.lookupRelation(None, tableName) match { - case LowerCaseSchema(r) => r - case o => o - } + val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { case relation: MetastoreRelation => { @@ -328,6 +334,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { TakeOrdered, ParquetOperations, InMemoryScans, + ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 354fcd53f303b..943bbaa8ce25e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -71,6 +71,9 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType)) + + // Hive seems to return this for struct types? + case c: Class[_] if c == classOf[java.lang.Object] => NullType } /** Converts hive types to native catalyst types. 
*/ @@ -147,7 +150,10 @@ private[hive] trait HiveInspectors { case t: java.sql.Timestamp => t case s: Seq[_] => seqAsJavaList(s.map(wrap)) case m: Map[_,_] => - mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) }) + // Some UDFs seem to assume we pass in a HashMap. + val hashMap = new java.util.HashMap[AnyRef, AnyRef]() + hashMap.putAll(m.map { case (k, v) => wrap(k) -> wrap(v) }) + hashMap case null => null } @@ -214,6 +220,12 @@ private[hive] trait HiveInspectors { import TypeInfoFactory._ def toTypeInfo: TypeInfo = dt match { + case ArrayType(elemType, _) => + getListTypeInfo(elemType.toTypeInfo) + case StructType(fields) => + getStructTypeInfo(fields.map(_.name), fields.map(_.dataType.toTypeInfo)) + case MapType(keyType, valueType, _) => + getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo case BooleanType => booleanTypeInfo case ByteType => byteTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 301cf51c00e2b..3b371211e14cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.hive import scala.util.parsing.combinator.RegexParsers -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{FieldSchema, StorageDescriptor, SerDeInfo} import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition} import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table} @@ -39,6 +37,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.hive.execution.HiveTableScan +import org.apache.spark.util.Utils /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -138,7 +137,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with castChildOutput(p, table, child) case p @ logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), _, child, _) => castChildOutput(p, table, child) } @@ -288,7 +287,10 @@ private[hive] case class MetastoreRelation ) val tableDesc = new TableDesc( - Class.forName(hiveQlTable.getSerializationLib).asInstanceOf[Class[Deserializer]], + Class.forName( + hiveQlTable.getSerializationLib, + true, + Utils.getContextOrSparkClassLoader).asInstanceOf[Class[Deserializer]], hiveQlTable.getInputFormatClass, // The class of table should be org.apache.hadoop.hive.ql.metadata.Table because // getOutputFormatClass will use HiveFileFormatUtils.getOutputFormatSubstitute to diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index bc2fefafd58c8..1d9ba1b24a7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -46,11 +46,14 @@ private[hive] case class AddFile(filePath: String) extends Command private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command +private[hive] case class AnalyzeTable(tableName: String) extends Command + /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
*/ private[hive] object HiveQl { protected val nativeCommands = Seq( "TOK_DESCFUNCTION", "TOK_DESCDATABASE", + "TOK_SHOW_CREATETABLE", "TOK_SHOW_TABLESTATUS", "TOK_SHOWDATABASES", "TOK_SHOWFUNCTIONS", @@ -74,7 +77,6 @@ private[hive] object HiveQl { "TOK_CREATEFUNCTION", "TOK_DROPFUNCTION", - "TOK_ANALYZE", "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", @@ -92,7 +94,6 @@ private[hive] object HiveQl { "TOK_ALTERTABLE_SKEWED", "TOK_ALTERTABLE_TOUCH", "TOK_ALTERTABLE_UNARCHIVE", - "TOK_ANALYZE", "TOK_CREATEDATABASE", "TOK_CREATEFUNCTION", "TOK_CREATEINDEX", @@ -239,7 +240,6 @@ private[hive] object HiveQl { ShellCommand(sql.drop(1)) } else { val tree = getAst(sql) - if (nativeCommands contains tree.getText) { NativeCommand(sql) } else { @@ -387,6 +387,22 @@ private[hive] object HiveQl { ifExists) => val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") DropTable(tableName, ifExists.nonEmpty) + // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: + isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) + } // Just fake explain for any of the native commands. case Token("TOK_EXPLAIN", explainArgs) if noExplainCommands.contains(explainArgs.head.getText) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 2175c5f3835a6..389ace726d205 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.SQLContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.execution._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} + +import scala.collection.JavaConversions._ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. @@ -32,6 +38,115 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext + /** + * :: Experimental :: + * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet + * table scan operator. + * + * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring + * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. 
+ * + * Other issues: + * - Much of this logic assumes case insensitive resolution. + */ + @Experimental + object ParquetConversion extends Strategy { + implicit class LogicalPlanHacks(s: SchemaRDD) { + def lowerCase = + new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan)) + + def addPartitioningAttributes(attrs: Seq[Attribute]) = + new SchemaRDD( + s.sqlContext, + s.logicalPlan transform { + case p: ParquetRelation => p.copy(partitioningAttributes = attrs) + }) + } + + implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { + def fakeOutput(newOutput: Seq[Attribute]) = + OutputFaker( + originalPlan.output.map(a => + newOutput.find(a.name.toLowerCase == _.name.toLowerCase) + .getOrElse( + sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), + originalPlan) + } + + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) + if relation.tableDesc.getSerdeClassName.contains("Parquet") && + hiveContext.convertMetastoreParquet => + + // Filter out all predicates that only deal with partition keys + val partitionKeyIds = relation.partitionKeys.map(_.exprId).toSet + val (pruningPredicates, otherPredicates) = predicates.partition { + _.references.map(_.exprId).subsetOf(partitionKeyIds) + } + + // We are going to throw the predicates and projection back at the whole optimization + // sequence so lets unresolve all the attributes, allowing them to be rebound to the + // matching parquet attributes. + val unresolvedOtherPredicates = otherPredicates.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }).reduceOption(And).getOrElse(Literal(true)) + + val unresolvedProjection = projectList.map(_ transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + }) + + if (relation.hiveQlTable.isPartitioned) { + val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) + // Translate the predicate so that it automatically casts the input values to the correct + // data types during evaluation + val castedPredicate = rawPredicate transform { + case a: AttributeReference => + val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) + val key = relation.partitionKeys(idx) + Cast(BoundReference(idx, StringType, nullable = true), key.dataType) + } + + val inputData = new GenericMutableRow(relation.partitionKeys.size) + val pruningCondition = + if(codegenEnabled) { + GeneratePredicate(castedPredicate) + } else { + InterpretedPredicate(castedPredicate) + } + + val partitions = relation.hiveQlPartitions.filter { part => + val partitionValues = part.getValues + var i = 0 + while (i < partitionValues.size()) { + inputData(i) = partitionValues(i) + i += 1 + } + pruningCondition(inputData) + } + + hiveContext + .parquetFile(partitions.map(_.getLocation).mkString(",")) + .addPartitioningAttributes(relation.partitionKeys) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)):: Nil + } else { + hiveContext + .parquetFile(relation.hiveQlTable.getDataLocation.getPath) + .lowerCase + .where(unresolvedOtherPredicates) + .select(unresolvedProjection:_*) + .queryExecution + .executedPlan + .fakeOutput(projectList.map(_.toAttribute)) :: Nil + } + case _ => Nil + } + } + object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child) => @@ 
-45,7 +160,7 @@ private[hive] trait HiveStrategies { case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case logical.InsertIntoTable( - InMemoryRelation(_, _, + InMemoryRelation(_, _, _, HiveTableScan(_, table, _)), partition, child, overwrite) => InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil case _ => Nil @@ -83,6 +198,8 @@ private[hive] trait HiveStrategies { case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed resolvedTable match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index d890df866fbe5..a013f3f7a805f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -70,6 +70,13 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { setConf("hive.metastore.warehouse.dir", warehousePath) } + val testTempDir = File.createTempFile("testTempFiles", "spark.hive.tmp") + testTempDir.delete() + testTempDir.mkdir() + + // For some hive test case which contain ${system:test.tmp.dir} + System.setProperty("test.tmp.dir", testTempDir.getCanonicalPath) + configure() // Must be called before initializing the catalog below. /** The location of the compiled hive distribution */ @@ -109,6 +116,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { hiveFilesTemp.mkdir() hiveFilesTemp.deleteOnExit() + val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) { new File("src" + File.separator + "test" + File.separator + "resources" + File.separator) } else { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 8920e2a76a27f..577ca928b43b6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -72,17 +72,12 @@ case class HiveTableScan( } private def addColumnMetadataToConf(hiveConf: HiveConf) { - // Specifies IDs and internal names of columns to be scanned. - val neededColumnIDs = attributes.map(a => relation.output.indexWhere(_.name == a.name): Integer) - val columnInternalNames = neededColumnIDs.map(HiveConf.getColumnInternalName(_)).mkString(",") - - if (attributes.size == relation.output.size) { - // SQLContext#pruneFilterProject guarantees no duplicated value in `attributes` - ColumnProjectionUtils.setFullyReadColumns(hiveConf) - } else { - ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) - } + // Specifies needed column IDs for those non-partitioning columns. + val neededColumnIDs = + attributes.map(a => + relation.attributes.indexWhere(_.name == a.name): Integer).filter(index => index >= 0) + ColumnProjectionUtils.appendReadColumnIDs(hiveConf, neededColumnIDs) ColumnProjectionUtils.appendReadColumnNames(hiveConf, attributes.map(_.name)) // Specifies types and object inspectors of columns to be scanned. 
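The AnalyzeTable command wired up above is only produced for the unpartitioned, `noscan` form of the statement; every other ANALYZE variant still falls through to Hive as a native command. A minimal sketch of exercising the new path from a HiveContext (the object, table name and data file below are illustrative, not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    object AnalyzeExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("analyze-example"))
        val hiveContext = new HiveContext(sc)

        hiveContext.sql("CREATE TABLE IF NOT EXISTS kv (key INT, value STRING)")
        hiveContext.sql(
          "LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE kv")

        // Parsed into AnalyzeTable and executed via HiveContext.analyze, which updates
        // the table size recorded in the Hive metastore. Note the trailing `noscan`:
        // without it the statement is still handed to Hive as a native command.
        hiveContext.sql("ANALYZE TABLE kv COMPUTE STATISTICS noscan")

        sc.stop()
      }
    }

The refreshed size is what feeds `statistics.sizeInBytes` on the relation, which the StatisticsSuite changes further down verify.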
@@ -99,7 +94,7 @@ case class HiveTableScan( .mkString(",") hiveConf.set(serdeConstants.LIST_COLUMN_TYPES, columnTypeNames) - hiveConf.set(serdeConstants.LIST_COLUMNS, columnInternalNames) + hiveConf.set(serdeConstants.LIST_COLUMNS, relation.attributes.map(_.name).mkString(",")) } addColumnMetadataToConf(context.hiveconf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala similarity index 72% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 9cd0c86c6c796..2985169da033c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DropTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -23,6 +23,32 @@ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.HiveContext +/** + * :: DeveloperApi :: + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ +@DeveloperApi +case class AnalyzeTable(tableName: String) extends LeafNode with Command { + + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + def output = Seq.empty + + override protected[sql] lazy val sideEffectResult = { + hiveContext.analyze(tableName) + Seq.empty[Any] + } + + override def execute(): RDD[Row] = { + sideEffectResult + sparkContext.emptyRDD[Row] + } +} + /** * :: DeveloperApi :: * Drops a table from the metastore and removes it if it is cached. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 179aac5cbd5cd..c6497a15efa0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -55,7 +55,10 @@ private[hive] abstract class HiveFunctionRegistry HiveSimpleUdf( functionClassName, - children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) } + children.zip(expectedDataTypes).map { + case (e, NullType) => e + case (e, t) => Cast(e, t) + } ) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUdf(functionClassName, children) @@ -115,22 +118,26 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ c.getParameterTypes.size == 1 && primitiveClasses.contains(c.getParameterTypes.head) } - val constructor = matchingConstructor.getOrElse( - sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) - - (a: Any) => { - logDebug( - s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") - // We must make sure that primitives get boxed java style. - if (a == null) { - null - } else { - constructor.newInstance(a match { - case i: Int => i: java.lang.Integer - case bd: BigDecimal => new HiveDecimal(bd.underlying()) - case other: AnyRef => other - }).asInstanceOf[AnyRef] - } + matchingConstructor match { + case Some(constructor) => + (a: Any) => { + logDebug( + s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} $constructor.") + // We must make sure that primitives get boxed java style. 
+ if (a == null) { + null + } else { + constructor.newInstance(a match { + case i: Int => i: java.lang.Integer + case bd: BigDecimal => new HiveDecimal(bd.underlying()) + case other: AnyRef => other + }).asInstanceOf[AnyRef] + } + } + case None => + (a: Any) => a match { + case wrapper => wrap(wrapper) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala new file mode 100644 index 0000000000000..544abfc32423c --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.parquet + +import java.util.Properties + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category +import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.io.Writable + +/** + * A placeholder that allows SparkSQL users to create metastore tables that are stored as + * parquet files. It is only intended to pass the checks that the serde is valid and exists + * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan + * when "spark.sql.hive.convertMetastoreParquet" is set to true. 
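A rough sketch of how the new flag is meant to be used together with a Parquet-declared metastore table (the table name, location, and `hiveContext` variable are illustrative; the DDL mirrors the one in ParquetMetastoreSuite further down):

    // Enable the experimental conversion; it defaults to "false".
    hiveContext.setConf("spark.sql.hive.convertMetastoreParquet", "true")

    // Any SerDe whose class name contains "Parquet" qualifies, so either the
    // parquet-hive-bundle SerDe below or the FakeParquetSerDe placeholder works.
    hiveContext.sql("""
      CREATE EXTERNAL TABLE parquet_events (
        intField INT,
        stringField STRING
      )
      ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
      STORED AS
        INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
        OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
      LOCATION '/tmp/parquet_events'
    """)

    // Planned as a native ParquetTableScan instead of a HiveTableScan while the
    // flag is on; flipping it back to "false" restores the Hive SerDe path.
    hiveContext.sql("SELECT COUNT(*) FROM parquet_events").collect()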
+ */ +@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " + + "placeholder in the Hive MetaStore") +class FakeParquetSerDe extends SerDe { + override def getObjectInspector: ObjectInspector = new ObjectInspector { + override def getCategory: Category = Category.PRIMITIVE + + override def getTypeName: String = "string" + } + + override def deserialize(p1: Writable): AnyRef = throwError + + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getSerializedClass: Class[_ <: Writable] = throwError + + override def getSerDeStats: SerDeStats = throwError + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError + + private def throwError = + sys.error( + "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe") +} diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 b/sql/hive/src/test/resources/golden/show_create_table_alter-0-813886d6cf0875c62e89cd1d06b8b0b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..3c1fc128bedce --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,18 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_alter-10-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 b/sql/hive/src/test/resources/golden/show_create_table_alter-2-928cc85c025440b731e5ee33e437e404 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2ece813dd7d56 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-3-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,22 @@ +CREATE TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'temporary table' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'EXTERNAL'='FALSE', + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git 
a/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 b/sql/hive/src/test/resources/golden/show_create_table_alter-4-c2cb6a7d942d4dddd1aababccb1239f9 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..2af657bd29506 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-5-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132100', + 'transient_lastDdlTime'='1407132100') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb b/sql/hive/src/test/resources/golden/show_create_table_alter-6-fdd1bd7f9acf0b2c8c9b7503d4046cb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..f793ffb7a0bfd --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-7-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e b/sql/hive/src/test/resources/golden/show_create_table_alter-8-22ab6ed5b15a018756f454dd2294847e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..c65aff26a7fc1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_alter-9-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,21 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key smallint, + value float) +COMMENT 'changed comment' +CLUSTERED BY ( + key) +SORTED BY ( + value DESC) +INTO 5 BUCKETS +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='1') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 
'last_modified_by'='tianyi', + 'last_modified_time'='1407132101', + 'transient_lastDdlTime'='1407132101') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 b/sql/hive/src/test/resources/golden/show_create_table_db_table-0-67509558a4b2d39b25787cca33f52635 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 new file mode 100644 index 0000000000000..707b2ae3ed1df --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-1-549981e00a3d95f03dd5a9ef6044aa20 @@ -0,0 +1,2 @@ +default +tmp_feng diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee b/sql/hive/src/test/resources/golden/show_create_table_db_table-2-34ae7e611d0aedbc62b6e420347abee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/show_create_table_db_table-3-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 new file mode 100644 index 0000000000000..b5a18368ed85e --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_db_table-4-b585371b624cbab2616a49f553a870a0 @@ -0,0 +1,13 @@ +CREATE TABLE tmp_feng.tmp_showcrt( + key string, + value int) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_feng.db/tmp_showcrt' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132107') diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 b/sql/hive/src/test/resources/golden/show_create_table_db_table-5-964757b7e7f2a69fe36132c1a5712199 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e b/sql/hive/src/test/resources/golden/show_create_table_db_table-6-ac09cf81e7e734cf10406f30b9fa566e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-97228478b9925f06726ceebb6571bf34 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..d36ad25dc8273 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_delimited-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,17 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +ROW FORMAT DELIMITED + FIELDS TERMINATED BY ',' + COLLECTION ITEMS TERMINATED BY '|' + MAP KEYS TERMINATED BY '%' + LINES TERMINATED BY '\n' 
+STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132730') diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_delimited-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c b/sql/hive/src/test/resources/golden/show_create_table_partitioned-0-4be9a3b1ff0840786a1f001cba170a0c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..9e572c0d7df6a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_partitioned-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,16 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + newvalue boolean COMMENT 'a new value') +COMMENT 'temporary table' +PARTITIONED BY ( + value bigint COMMENT 'some value') +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.mapred.TextInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132112') diff --git a/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_partitioned-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c b/sql/hive/src/test/resources/golden/show_create_table_serde-0-33f15d91810b75ee05c7b9dea0abb01c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..69a38e1a7b20a --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-1-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,15 @@ +CREATE TABLE tmp_showcrt1( + key int, + value string, + newvalue bigint) +COMMENT 'temporary table' +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED AS INPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileInputFormat' +OUTPUTFORMAT + 'org.apache.hadoop.hive.ql.io.RCFileOutputFormat' +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-2-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b b/sql/hive/src/test/resources/golden/show_create_table_serde-3-fd12b3e0fe30f5d71c67676791b4a33b new file mode 
100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 new file mode 100644 index 0000000000000..b4e693dc622fb --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_serde-4-2a91d52719cf4552ebeb867204552a26 @@ -0,0 +1,14 @@ +CREATE EXTERNAL TABLE tmp_showcrt1( + key string, + value boolean) +ROW FORMAT SERDE + 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' +STORED BY + 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' +WITH SERDEPROPERTIES ( + 'serialization.format'='$', + 'field.delim'=',') +LOCATION + 'file:/tmp/sparkHiveWarehouse1280221975983654134/tmp_showcrt1' +TBLPROPERTIES ( + 'transient_lastDdlTime'='1407132115') diff --git a/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de b/sql/hive/src/test/resources/golden/show_create_table_serde-5-259d978ed9543204c8b9c25b6e25b0de new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 b/sql/hive/src/test/resources/golden/show_create_table_view-0-ecef6821e4e9212e553ca38142fd0250 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 new file mode 100644 index 0000000000000..be3fb3ce30960 --- /dev/null +++ b/sql/hive/src/test/resources/golden/show_create_table_view-1-1e931ea3fa6065107859ffbb29bb0ed7 @@ -0,0 +1 @@ +CREATE VIEW tmp_copy_src AS SELECT `src`.`key`, `src`.`value` FROM `default`.`src` diff --git a/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f b/sql/hive/src/test/resources/golden/show_create_table_view-2-ed97e9e56d95c5b3db57485cba5ad17f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index bf5931bbf97ee..7c82964b5ecdc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,13 +19,54 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag + import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.catalyst.plans.logical.NativeCommand import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest { + test("parse analyze commands") { + def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { + val parsed = HiveQl.parseSql(analyzeCommand) + val operators = parsed.collect { + case a: AnalyzeTable => a + case o => o + } + + assert(operators.size === 1) + if (operators(0).getClass() != c) { + fail( + s"""$analyzeCommand expected command: $c, but got ${operators(0)} + |parsed command: + |$parsed + """.stripMargin) + } + } + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + 
"ANALYZE TABLE Table1 PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS", + classOf[NativeCommand]) + assertAnalyzeCommand( + "ANALYZE TABLE Table1 PARTITION(ds, hr) COMPUTE STATISTICS noscan", + classOf[NativeCommand]) + + assertAnalyzeCommand( + "ANALYZE TABLE Table1 COMPUTE STATISTICS nOscAn", + classOf[AnalyzeTable]) + } + test("analyze MetastoreRelations") { def queryTotalSize(tableName: String): BigInt = catalog.lookupRelation(None, tableName).statistics.sizeInBytes @@ -37,7 +78,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable") === defaultSizeInBytes) - analyze("analyzeTable") + sql("ANALYZE TABLE analyzeTable COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable") === BigInt(11624)) @@ -66,7 +107,7 @@ class StatisticsSuite extends QueryTest { assert(queryTotalSize("analyzeTable_part") === defaultSizeInBytes) - analyze("analyzeTable_part") + sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan") assert(queryTotalSize("analyzeTable_part") === BigInt(17436)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0ebaf6ffd5458..502ce8fb297e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -161,6 +161,7 @@ abstract class HiveComparisonTest "transient_lastDdlTime", "grantTime", "lastUpdateTime", + "last_modified_by", "last_modified_time", "Owner:", // The following are hive specific schema parameters which we do not need to match exactly. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala new file mode 100644 index 0000000000000..0723be7298e15 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -0,0 +1,171 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import java.io.File + +import org.apache.spark.sql.hive.execution.HiveTableScan +import org.scalatest.BeforeAndAfterAll + +import scala.reflect.ClassTag + +import org.apache.spark.sql.{SQLConf, QueryTest} +import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHive._ + +case class ParquetData(intField: Int, stringField: String) + +/** + * Tests for our SerDe -> Native parquet scan conversion. 
+ */ +class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "true") + } + + override def afterAll(): Unit = { + setConf("spark.sql.hive.convertMetastoreParquet", "false") + } + + val partitionedTableDir = File.createTempFile("parquettests", "sparksql") + partitionedTableDir.delete() + partitionedTableDir.mkdir() + + (1 to 10).foreach { p => + val partDir = new File(partitionedTableDir, s"p=$p") + sparkContext.makeRDD(1 to 10) + .map(i => ParquetData(i, s"part-$p")) + .saveAsParquetFile(partDir.getCanonicalPath) + } + + sql(s""" + create external table partitioned_parquet + ( + intField INT, + stringField STRING + ) + PARTITIONED BY (p int) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${partitionedTableDir.getCanonicalPath}' + """) + + sql(s""" + create external table normal_parquet + ( + intField INT, + stringField STRING + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS + INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + location '${new File(partitionedTableDir, "p=1").getCanonicalPath}' + """) + + (1 to 10).foreach { p => + sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") + } + + test("project the partitioning column") { + checkAnswer( + sql("SELECT p, count(*) FROM partitioned_parquet group by p"), + (1, 10) :: + (2, 10) :: + (3, 10) :: + (4, 10) :: + (5, 10) :: + (6, 10) :: + (7, 10) :: + (8, 10) :: + (9, 10) :: + (10, 10) :: Nil + ) + } + + test("project partitioning and non-partitioning columns") { + checkAnswer( + sql("SELECT stringField, p, count(intField) " + + "FROM partitioned_parquet GROUP BY p, stringField"), + ("part-1", 1, 10) :: + ("part-2", 2, 10) :: + ("part-3", 3, 10) :: + ("part-4", 4, 10) :: + ("part-5", 5, 10) :: + ("part-6", 6, 10) :: + ("part-7", 7, 10) :: + ("part-8", 8, 10) :: + ("part-9", 9, 10) :: + ("part-10", 10, 10) :: Nil + ) + } + + test("simple count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet"), + 100) + } + + test("pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p = 1"), + 10) + } + + test("multi-partition pruned count") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE p IN (1,2,3)"), + 30) + } + + test("non-partition predicates") { + checkAnswer( + sql("SELECT COUNT(*) FROM partitioned_parquet WHERE intField IN (1,2,3)"), + 30) + } + + test("sum") { + checkAnswer( + sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"), + 1 + 2 + 3 + ) + } + + test("non-part select(*)") { + checkAnswer( + sql("SELECT COUNT(*) FROM normal_parquet"), + 10 + ) + } + + test("conversion is working") { + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: HiveTableScan => true + }.isEmpty) + assert( + sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { + case _: ParquetTableScan => true + }.nonEmpty) + } +} diff --git a/streaming/pom.xml b/streaming/pom.xml index 1072f74aea0d9..ce35520a28609 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -81,11 +81,11 @@ org.apache.maven.plugins @@ -97,8 +97,8 @@ - 
test-jar-on-compile - compile + test-jar-on-test-compile + test-compile test-jar diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e0677b795cb94..101cec1c7a7c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -98,9 +98,15 @@ class StreamingContext private[streaming] ( * @param hadoopConf Optional, configuration object if necessary for reading from * HDFS compatible filesystems */ - def this(path: String, hadoopConf: Configuration = new Configuration) = + def this(path: String, hadoopConf: Configuration) = this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + /** + * Recreate a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + */ + def this(path: String) = this(path, new Configuration) + if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + "both SparkContext and checkpoint as null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala index 774adc3c23c21..75f0e8716dc7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingSource.scala @@ -23,10 +23,10 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.streaming.ui.StreamingJobProgressListener private[streaming] class StreamingSource(ssc: StreamingContext) extends Source { - val metricRegistry = new MetricRegistry - val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) + override val metricRegistry = new MetricRegistry + override val sourceName = "%s.StreamingMetrics".format(ssc.sparkContext.appName) - val streamingListener = ssc.uiTab.listener + private val streamingListener = ssc.uiTab.listener private def registerGauge[T](name: String, f: StreamingJobProgressListener => T, defaultValue: T) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 451b23e01c995..1353e487c72cf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -42,8 +42,7 @@ private[ui] class StreamingPage(parent: StreamingTab)

    <h4>Statistics over last {listener.retainedCompletedBatches.size} processed batches</h4>

    ++ generateReceiverStats() ++ generateBatchStatsTable() - UIUtils.headerSparkPage( - content, parent.basePath, parent.appName, "Streaming", parent.headerTabs, parent, Some(5000)) + UIUtils.headerSparkPage("Streaming", content, parent, Some(5000)) } /** Generate basic stats of the streaming program */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index 51448d15c6516..34ac254f337eb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -19,15 +19,13 @@ package org.apache.spark.streaming.ui import org.apache.spark.Logging import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.WebUITab +import org.apache.spark.ui.SparkUITab /** Spark Web UI tab that shows statistics of a streaming job */ private[spark] class StreamingTab(ssc: StreamingContext) - extends WebUITab(ssc.sc.ui, "streaming") with Logging { + extends SparkUITab(ssc.sc.ui, "streaming") with Logging { val parent = ssc.sc.ui - val appName = parent.appName - val basePath = parent.basePath val listener = new StreamingJobProgressListener(ssc) ssc.addStreamingListener(listener) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index cc178fba12c9d..759baacaa4308 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} -import org.apache.spark.streaming.util.ManualClock +import java.io.{ObjectInputStream, IOException} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.SynchronizedBuffer import scala.reflect.ClassTag -import java.io.{ObjectInputStream, IOException} - import org.scalatest.{BeforeAndAfter, FunSuite} +import com.google.common.io.Files -import org.apache.spark.{SparkContext, SparkConf, Logging} +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.rdd.RDD /** @@ -119,7 +119,12 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { def batchDuration = Seconds(1) // Directory where the checkpoint data will be saved - def checkpointDir = "checkpoint" + lazy val checkpointDir = { + val dir = Files.createTempDir() + logDebug(s"checkpointDir: $dir") + dir.deleteOnExit() + dir.toString + } // Number of partitions of the input parallel collections created for testing def numInputPartitions = 2 @@ -242,7 +247,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = ssc.graph.getOutputStreams. + filter(_.isInstanceOf[TestOutputStreamWithPartitions[_]]). 
+ head.asInstanceOf[TestOutputStreamWithPartitions[V]] val output = outputStream.output try { diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8a05fcb449aa6..17bf7c2541d13 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils +import org.apache.spark.executor.ShuffleWriteMetrics /** * Internal utility for micro-benchmarking shuffle write performance. @@ -56,7 +57,7 @@ object StoragePerfTester { def writeOutputBytes(mapId: Int, total: AtomicLong) = { val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer(sc.conf)) + new KryoSerializer(sc.conf), new ShuffleWriteMetrics()) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) diff --git a/tox.ini b/tox.ini index 44766e529bf7f..a1fefdd0e176f 100644 --- a/tox.ini +++ b/tox.ini @@ -15,3 +15,4 @@ [pep8] max-line-length=100 +exclude=cloudpickle.py diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 62b5c3bc5f0f3..4d4848b1bd8f8 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -72,10 +72,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var registered = false def run() { - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - // set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -138,20 +134,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, params) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) @@ -267,12 +249,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - // Exists the loop if the user thread exits. - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + // Exits the loop if the user thread exits. 
+ while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { + checkNumExecutorsFailed() yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) @@ -303,11 +283,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + while (userThread.isAlive && !isFinished) { + checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". @@ -327,6 +304,22 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, t } + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } + } + } + private def sendProgress() { logDebug("Sending progress") // Simulated with an allocate request with no nodes requested ... diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index 184e2ad6c82cd..c3310fbc24a98 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -95,11 +95,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } def run() { - - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() @@ -152,20 +147,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp System.exit(0) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) @@ -249,7 +230,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? 
// TODO: Handle container failure - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { yarnAllocator.allocateContainers( math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0)) checkNumExecutorsFailed() @@ -271,7 +253,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning if (missingExecutorCount > 0) { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 4c383ab574abe..424b0fb0936f2 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -29,7 +29,7 @@ class ApplicationMasterArguments(val args: Array[String]) { var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS parseArgs(args.toList) - + private def parseArgs(inputArgs: List[String]): Unit = { val userArgsBuffer = new ArrayBuffer[String]() @@ -47,7 +47,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userClass = value args = tail - case ("--args") :: value :: tail => + case ("--args" | "--arg") :: value :: tail => userArgsBuffer += value args = tail @@ -75,7 +75,7 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgs = userArgsBuffer.readOnly } - + def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { if (unknownParam != null) { System.err.println("Unknown/unsupported param " + unknownParam) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 1da0a1b675554..3897b3a373a8c 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -300,11 +300,11 @@ trait ClientBase extends Logging { } def userArgsToString(clientArgs: ClientArguments): String = { - val prefix = " --args " + val prefix = " --arg " val args = clientArgs.userArgs val retval = new StringBuilder() for (arg <- args) { - retval.append(prefix).append(" '").append(arg).append("' ") + retval.append(prefix).append(" ").append(YarnSparkHadoopUtil.escapeForShell(arg)) } retval.toString } @@ -386,7 +386,7 @@ trait ClientBase extends Logging { // TODO: it might be nicer to pass these as an internal environment variable rather than // as Java options, due to complications with string parsing of nested quotes. 
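With the ApplicationMasterArguments change above, the AM accepts --arg as well as the old --args, and ClientBase.userArgsToString now emits one --arg flag per user argument, each escaped with the escapeForShell helper introduced further down, instead of wrapping the raw values in bare single quotes. Roughly, for two hypothetical user arguments (values invented for illustration):

    import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil

    // Hypothetical user arguments, to show the shape of the generated AM command fragment.
    val userArgs = Seq("a b", "$HOME")

    // Old form:  " --args  'a b'  '$HOME' "  -- breaks if an argument itself contains a quote.
    // New form: one --arg flag per value, escaped for Yarn's bash -c wrapper.
    val fragment = userArgs.map(a => " --arg " + YarnSparkHadoopUtil.escapeForShell(a)).mkString
    // fragment is roughly:  " --arg 'a b' --arg '\$HOME'"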
for ((k, v) <- sparkConf.getAll) { - javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" + javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } if (args.amClass == classOf[ApplicationMaster].getName) { @@ -400,7 +400,8 @@ trait ClientBase extends Logging { // Command for the ApplicationMaster val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ javaOpts ++ - Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar, + Seq(args.amClass, "--class", YarnSparkHadoopUtil.escapeForShell(args.userClass), + "--jar ", YarnSparkHadoopUtil.escapeForShell(args.userJar), userArgsToString(args), "--executor-memory", args.executorMemory.toString, "--executor-cores", args.executorCores.toString, diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index 71a9e42846b2b..312d82a649792 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -68,10 +68,10 @@ trait ExecutorRunnableUtil extends Logging { // authentication settings. sparkConf.getAll. filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } sparkConf.getAkkaConf. - foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" } + foreach { case (k, v) => javaOpts += YarnSparkHadoopUtil.escapeForShell(s"-D$k=$v") } // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index e98308cdbd74e..10aef5eb2486f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -148,4 +148,29 @@ object YarnSparkHadoopUtil { } } + /** + * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands + * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The + * argument is enclosed in single quotes and some key characters are escaped. + * + * @param arg A single argument. + * @return Argument quoted for execution via Yarn's generated shell script. 
+ */ + def escapeForShell(arg: String): String = { + if (arg != null) { + val escaped = new StringBuilder("'") + for (i <- 0 to arg.length() - 1) { + arg.charAt(i) match { + case '$' => escaped.append("\\$") + case '"' => escaped.append("\\\"") + case '\'' => escaped.append("'\\''") + case c => escaped.append(c) + } + } + escaped.append("'").toString() + } else { + arg + } + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index f8fb96b312f23..833e249f9f612 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -30,15 +30,15 @@ private[spark] class YarnClientSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with Logging { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } var client: Client = null var appId: ApplicationId = null var checkerThread: Thread = null var stopping: Boolean = false + var totalExpectedExecutors = 0 private[spark] def addArg(optionName: String, envVar: String, sysProp: String, arrayBuf: ArrayBuffer[String]) { @@ -84,7 +84,7 @@ private[spark] class YarnClientSchedulerBackend( logDebug("ClientArguments called with: " + argsArrayBuf) val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors.set(args.numExecutors) + totalExpectedExecutors = args.numExecutors client = new Client(args, conf) appId = client.runApp() waitForApp() @@ -150,4 +150,7 @@ private[spark] class YarnClientSchedulerBackend( logInfo("Stopped") } + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 0ad1794d19538..55665220a6f96 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -27,19 +27,24 @@ private[spark] class YarnClusterSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { - if (conf.getOption("spark.scheduler.minRegisteredExecutorsRatio").isEmpty) { + var totalExpectedExecutors = 0 + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 - ready = false } override def start() { super.start() - var numExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS + totalExpectedExecutors = ApplicationMasterArguments.DEFAULT_NUMBER_EXECUTORS if (System.getenv("SPARK_EXECUTOR_INSTANCES") != null) { - numExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")).getOrElse(numExecutors) + totalExpectedExecutors = IntParam.unapply(System.getenv("SPARK_EXECUTOR_INSTANCES")) + .getOrElse(totalExpectedExecutors) } // System property can override environment variable. 
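Because YARN launches containers through bash -c "...", the old approach of wrapping -D options in escaped double quotes falls apart for values containing spaces, quotes, or $; the new escapeForShell above single-quotes each argument and escapes the few characters that still matter in that context. A small usage sketch (the option value is invented for illustration):

    import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil

    // An option value containing spaces and double quotes (made-up example).
    val opt = """-Dspark.app.name=my "cool" app"""
    val escaped = YarnSparkHadoopUtil.escapeForShell(opt)
    // escaped is the literal text:  '-Dspark.app.name=my \"cool\" app'
    // which survives Yarn's generated bash -c "..." wrapper as a single,
    // unexpanded argument to the launched process.

The new YarnSparkHadoopUtilSuite below round-trips a set of such arguments through an actual bash invocation to verify exactly this behaviour.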
- numExecutors = sc.getConf.getInt("spark.executor.instances", numExecutors) - totalExpectedExecutors.set(numExecutors) + totalExpectedExecutors = sc.getConf.getInt("spark.executor.instances", totalExpectedExecutors) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..7650bd4396c12 --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{File, IOException} + +import com.google.common.io.{ByteStreams, Files} +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.Logging + +class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { + + val hasBash = + try { + val exitCode = Runtime.getRuntime().exec(Array("bash", "--version")).waitFor() + exitCode == 0 + } catch { + case e: IOException => + false + } + + if (!hasBash) { + logWarning("Cannot execute bash, skipping bash tests.") + } + + def bashTest(name: String)(fn: => Unit) = + if (hasBash) test(name)(fn) else ignore(name)(fn) + + bashTest("shell script escaping") { + val scriptFile = File.createTempFile("script.", ".sh") + val args = Array("arg1", "${arg.2}", "\"arg3\"", "'arg4'", "$arg5", "\\arg6") + try { + val argLine = args.map(a => YarnSparkHadoopUtil.escapeForShell(a)).mkString(" ") + Files.write(("bash -c \"echo " + argLine + "\"").getBytes(), scriptFile) + scriptFile.setExecutable(true) + + val proc = Runtime.getRuntime().exec(Array(scriptFile.getAbsolutePath())) + val out = new String(ByteStreams.toByteArray(proc.getInputStream())).trim() + val err = new String(ByteStreams.toByteArray(proc.getErrorStream())) + val exitCode = proc.waitFor() + exitCode should be (0) + out should be (args.mkString(" ")) + } finally { + scriptFile.delete() + } + } + +} diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 035356d390c80..1c4005fd8e78e 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -72,10 +72,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private var registered = false def run() { - // Setup the directories so things go to YARN approved directories rather - // than user 
specified and /tmp. - System.setProperty("spark.local.dir", getLocalDirs()) - // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box System.setProperty("spark.ui.port", "0") @@ -144,20 +140,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, "spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) } - // Get the Yarn approved local directories. - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn local dirs can't be empty") - case Some(l) => l - } - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) @@ -247,13 +229,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, yarnAllocator.allocateResources() // Exits the loop if the user thread exits. - var iters = 0 - while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { + while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive + && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL) - iters += 1 } } logInfo("All executors have launched.") @@ -271,8 +252,17 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, private def checkNumExecutorsFailed() { if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + logInfo("max number of executor failures reached") finishApplicationMaster(FinalApplicationStatus.FAILED, "max number of executor failures reached") + // make sure to stop the user thread + val sparkContext = ApplicationMaster.sparkContextRef.get() + if (sparkContext != null) { + logInfo("Invoking sc stop from checkNumExecutorsFailed") + sparkContext.stop() + } else { + logError("sparkContext is null when should shutdown") + } } } @@ -289,7 +279,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val t = new Thread { override def run() { - while (userThread.isAlive) { + while (userThread.isAlive && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress") diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index fc7b8320d734d..45925f1fea005 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -94,11 +94,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp } def run() { - - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. 
- System.setProperty("spark.local.dir", getLocalDirs()) - amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) amClient.start() @@ -141,20 +136,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp System.exit(0) } - /** Get the Yarn approved local directories. */ - private def getLocalDirs(): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .orElse(Option(System.getenv("LOCAL_DIRS"))) - - localDirs match { - case None => throw new Exception("Yarn Local dirs can't be empty") - case Some(l) => l - } - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { val appUIAddress = sparkConf.get("spark.driver.appUIAddress", "") logInfo(s"Registering the ApplicationMaster with appUIAddress: $appUIAddress") @@ -217,7 +198,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp // Wait until all containers have launched yarnAllocator.addResourceRequests(args.numExecutors) yarnAllocator.allocateResources() - while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed) && + !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() yarnAllocator.allocateResources() @@ -249,7 +231,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { - while (!driverClosed) { + while (!driverClosed && !isFinished) { checkNumExecutorsFailed() allocateMissingExecutor() logDebug("Sending progress")