From 5930f64bf0d2516304b21bd49eac361a54caabdd Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 14 Nov 2014 14:33:37 -0800 Subject: [PATCH 01/22] [SPARK-4062][Streaming]Add ReliableKafkaReceiver in Spark Streaming Kafka connector Add ReliableKafkaReceiver in Kafka connector to prevent data loss if WAL in Spark Streaming is enabled. Details and design doc can be seen in [SPARK-4062](https://issues.apache.org/jira/browse/SPARK-4062). Author: jerryshao Author: Tathagata Das Author: Saisai Shao Closes #2991 from jerryshao/kafka-refactor and squashes the following commits: 5461f1c [Saisai Shao] Merge pull request #8 from tdas/kafka-refactor3 eae4ad6 [Tathagata Das] Refectored KafkaStreamSuiteBased to eliminate KafkaTestUtils and made Java more robust. fab14c7 [Tathagata Das] minor update. 149948b [Tathagata Das] Fixed mistake 14630aa [Tathagata Das] Minor updates. d9a452c [Tathagata Das] Minor updates. ec2e95e [Tathagata Das] Removed the receiver's locks and essentially reverted to Saisai's original design. 2a20a01 [jerryshao] Address some comments 9f636b3 [Saisai Shao] Merge pull request #5 from tdas/kafka-refactor b2b2f84 [Tathagata Das] Refactored Kafka receiver logic and Kafka testsuites e501b3c [jerryshao] Add Mima excludes b798535 [jerryshao] Fix the missed issue e5e21c1 [jerryshao] Change to while loop ea873e4 [jerryshao] Further address the comments 98f3d07 [jerryshao] Fix comment style 4854ee9 [jerryshao] Address all the comments 96c7a1d [jerryshao] Update the ReliableKafkaReceiver unit test 8135d31 [jerryshao] Fix flaky test a949741 [jerryshao] Address the comments 16bfe78 [jerryshao] Change the ordering of imports 0894aef [jerryshao] Add some comments 77c3e50 [jerryshao] Code refactor and add some unit tests dd9aeeb [jerryshao] Initial commit for reliable Kafka receiver --- .../streaming/kafka/KafkaInputDStream.scala | 33 +- .../spark/streaming/kafka/KafkaUtils.scala | 4 +- .../kafka/ReliableKafkaReceiver.scala | 282 ++++++++++++++++++ .../streaming/kafka/JavaKafkaStreamSuite.java | 44 +-- .../streaming/kafka/KafkaStreamSuite.scala | 216 ++++++++------ .../kafka/ReliableKafkaStreamSuite.scala | 140 +++++++++ project/MimaExcludes.scala | 4 + .../streaming/receiver/BlockGenerator.scala | 55 +++- .../receiver/ReceiverSupervisorImpl.scala | 8 +- .../spark/streaming/ReceiverSuite.scala | 8 +- 10 files changed, 651 insertions(+), 143 deletions(-) create mode 100644 external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala create mode 100644 external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 28ac5929df44a..4d26b640e8d74 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -17,13 +17,12 @@ package org.apache.spark.streaming.kafka +import java.util.Properties + import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import java.util.Properties -import java.util.concurrent.Executors - -import kafka.consumer._ +import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties @@ -32,6 +31,7 @@ import org.apache.spark.storage.StorageLevel import 
org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * Input stream that pulls messages from a Kafka Broker. @@ -51,12 +51,16 @@ class KafkaInputDStream[ @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], + useReliableReceiver: Boolean, storageLevel: StorageLevel ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { def getReceiver(): Receiver[(K, V)] = { - new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) - .asInstanceOf[Receiver[(K, V)]] + if (!useReliableReceiver) { + new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } else { + new ReliableKafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } } } @@ -69,14 +73,15 @@ class KafkaReceiver[ kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel - ) extends Receiver[Any](storageLevel) with Logging { + ) extends Receiver[(K, V)](storageLevel) with Logging { // Connection to Kafka - var consumerConnector : ConsumerConnector = null + var consumerConnector: ConsumerConnector = null def onStop() { if (consumerConnector != null) { consumerConnector.shutdown() + consumerConnector = null } } @@ -102,11 +107,11 @@ class KafkaReceiver[ .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] - // Create Threads for each Topic/Message Stream we are listening + // Create threads for each topic/message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Executors.newFixedThreadPool(topics.values.sum) + val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -117,13 +122,15 @@ class KafkaReceiver[ } } - // Handles Kafka Messages - private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V]) + // Handles Kafka messages + private class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { def run() { logInfo("Starting MessageHandler.") try { - for (msgAndMetadata <- stream) { + val streamIterator = stream.iterator() + while (streamIterator.hasNext()) { + val msgAndMetadata = streamIterator.next() store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index ec812e1ef3b04..b4ac929e0c070 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -70,7 +70,8 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, storageLevel) + val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } /** @@ -99,7 +100,6 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. 
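For context, a minimal sketch of how an application opts into the reliable receiver path added here: everything hinges on the spark.streaming.receiver.writeAheadLog.enable flag that KafkaUtils.createStream now reads. The master, topic, checkpoint path and batch interval below are illustrative assumptions, not part of the patch.

    import kafka.serializer.StringDecoder
    import org.apache.spark.SparkConf
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kafka.KafkaUtils

    val conf = new SparkConf()
      .setMaster("local[4]")
      .setAppName("ReliableKafkaExample")
      .set("spark.streaming.receiver.writeAheadLog.enable", "true") // selects ReliableKafkaReceiver
    val ssc = new StreamingContext(conf, Seconds(2))
    ssc.checkpoint("/tmp/reliable-kafka-checkpoint") // the write-ahead log is kept under the checkpoint dir

    val kafkaParams = Map(
      "zookeeper.connect" -> "localhost:2181",
      "group.id" -> "example-consumer",
      "auto.offset.reset" -> "smallest")
    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map("example-topic" -> 1), StorageLevel.MEMORY_AND_DISK_SER)
    stream.map(_._2).print()
    ssc.start()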
- * */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala new file mode 100644 index 0000000000000..be734b80272d1 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.util.Properties +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} + +import scala.collection.{Map, mutable} +import scala.reflect.{ClassTag, classTag} + +import kafka.common.TopicAndPartition +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder +import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.Utils + +/** + * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. + * It is turned off by default and will be enabled when + * spark.streaming.receiver.writeAheadLog.enable is true. The difference compared to KafkaReceiver + * is that this receiver manages topic-partition/offset itself and updates the offset information + * after data is reliably stored as write-ahead log. Offsets will only be updated when data is + * reliably stored, so the potential data loss problem of KafkaReceiver can be eliminated. + * + * Note: ReliableKafkaReceiver will set auto.commit.enable to false to turn off automatic offset + * commit mechanism in Kafka consumer. So setting this configuration manually within kafkaParams + * will not take effect. + */ +private[streaming] +class ReliableKafkaReceiver[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel) + extends Receiver[(K, V)](storageLevel) with Logging { + + private val groupId = kafkaParams("group.id") + private val AUTO_OFFSET_COMMIT = "auto.commit.enable" + private def conf = SparkEnv.get.conf + + /** High level consumer to connect to Kafka. */ + private var consumerConnector: ConsumerConnector = null + + /** zkClient to connect to Zookeeper to commit the offsets. 
*/ + private var zkClient: ZkClient = null + + /** + * A HashMap to manage the offset for each topic/partition, this HashMap is called in + * synchronized block, so mutable HashMap will not meet concurrency issue. + */ + private var topicPartitionOffsetMap: mutable.HashMap[TopicAndPartition, Long] = null + + /** A concurrent HashMap to store the stream block id and related offset snapshot. */ + private var blockOffsetMap: ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]] = null + + /** + * Manage the BlockGenerator in receiver itself for better managing block store and offset + * commit. + */ + private var blockGenerator: BlockGenerator = null + + /** Thread pool running the handlers for receiving message from multiple topics and partitions. */ + private var messageHandlerThreadPool: ThreadPoolExecutor = null + + override def onStart(): Unit = { + logInfo(s"Starting Kafka Consumer Stream with group: $groupId") + + // Initialize the topic-partition / offset hash map. + topicPartitionOffsetMap = new mutable.HashMap[TopicAndPartition, Long] + + // Initialize the stream block id / offset snapshot hash map. + blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() + + // Initialize the block generator for storing Kafka message. + blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + + if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { + logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + + "otherwise we will manually set it to false to turn off auto offset commit in Kafka") + } + + val props = new Properties() + kafkaParams.foreach(param => props.put(param._1, param._2)) + // Manually set "auto.commit.enable" to "false" no matter user explicitly set it to true, + // we have to make sure this property is set to false to turn off auto commit mechanism in + // Kafka. 
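    // Caller-side implication (illustrative values, not part of the patch): any
    // "auto.commit.enable" passed through kafkaParams is accepted but then overridden;
    // the receiver logs the warning above and forces the property to "false" just below.
    val exampleKafkaParams = Map(
      "zookeeper.connect"  -> "localhost:2181",
      "group.id"           -> "example-group",
      "auto.commit.enable" -> "true")  // has no effect with ReliableKafkaReceiver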
+ props.setProperty(AUTO_OFFSET_COMMIT, "false") + + val consumerConfig = new ConsumerConfig(props) + + assert(!consumerConfig.autoCommitEnable) + + logInfo(s"Connecting to Zookeeper: ${consumerConfig.zkConnect}") + consumerConnector = Consumer.create(consumerConfig) + logInfo(s"Connected to Zookeeper: ${consumerConfig.zkConnect}") + + zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, + consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) + + messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + topics.values.sum, "KafkaMessageHandler") + + blockGenerator.start() + + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[K]] + + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[V]] + + val topicMessageStreams = consumerConnector.createMessageStreams( + topics, keyDecoder, valueDecoder) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => + messageHandlerThreadPool.submit(new MessageHandler(stream)) + } + } + } + + override def onStop(): Unit = { + if (messageHandlerThreadPool != null) { + messageHandlerThreadPool.shutdown() + messageHandlerThreadPool = null + } + + if (consumerConnector != null) { + consumerConnector.shutdown() + consumerConnector = null + } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (blockGenerator != null) { + blockGenerator.stop() + blockGenerator = null + } + + if (topicPartitionOffsetMap != null) { + topicPartitionOffsetMap.clear() + topicPartitionOffsetMap = null + } + + if (blockOffsetMap != null) { + blockOffsetMap.clear() + blockOffsetMap = null + } + } + + /** Store a Kafka message and the associated metadata as a tuple. */ + private def storeMessageAndMetadata( + msgAndMetadata: MessageAndMetadata[K, V]): Unit = { + val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition) + val data = (msgAndMetadata.key, msgAndMetadata.message) + val metadata = (topicAndPartition, msgAndMetadata.offset) + blockGenerator.addDataWithCallback(data, metadata) + } + + /** Update stored offset */ + private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = { + topicPartitionOffsetMap.put(topicAndPartition, offset) + } + + /** + * Remember the current offsets for each topic and partition. This is called when a block is + * generated. + */ + private def rememberBlockOffsets(blockId: StreamBlockId): Unit = { + // Get a snapshot of current offset map and store with related block id. + val offsetSnapshot = topicPartitionOffsetMap.toMap + blockOffsetMap.put(blockId, offsetSnapshot) + topicPartitionOffsetMap.clear() + } + + /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + private def storeBlockAndCommitOffset( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } + + /** + * Commit the offset of Kafka's topic/partition, the commit mechanism follow Kafka 0.8.x's + * metadata schema in Zookeeper. 
+ */ + private def commitOffset(offsetMap: Map[TopicAndPartition, Long]): Unit = { + if (zkClient == null) { + val thrown = new IllegalStateException("Zookeeper client is unexpectedly null") + stop("Zookeeper client is not initialized before commit offsets to ZK", thrown) + return + } + + for ((topicAndPart, offset) <- offsetMap) { + try { + val topicDirs = new ZKGroupTopicDirs(groupId, topicAndPart.topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/${topicAndPart.partition}" + + ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString) + } catch { + case e: Exception => + logWarning(s"Exception during commit offset $offset for topic" + + s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e) + } + + logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " + + s"partition ${topicAndPart.partition}") + } + } + + /** Class to handle received Kafka message. */ + private final class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { + override def run(): Unit = { + while (!isStopped) { + try { + val streamIterator = stream.iterator() + while (streamIterator.hasNext) { + storeMessageAndMetadata(streamIterator.next) + } + } catch { + case e: Exception => + logError("Error handling message", e) + } + } + } + } + + /** Class to handle blocks generated by the block generator. */ + private final class GeneratedBlockHandler extends BlockGeneratorListener { + + def onAddData(data: Any, metadata: Any): Unit = { + // Update the offset of the data that was added to the generator + if (metadata != null) { + val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)] + updateOffset(topicAndPartition, offset) + } + } + + def onGenerateBlock(blockId: StreamBlockId): Unit = { + // Remember the offsets of topics/partitions when a block has been generated + rememberBlockOffsets(blockId) + } + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + // Store block and commit the blocks offset + storeBlockAndCommitOffset(blockId, arrayBuffer) + } + + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index efb0099c7c850..6e1abf3f385ee 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -20,7 +20,10 @@ import java.io.Serializable; import java.util.HashMap; import java.util.List; +import java.util.Random; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.Duration; import scala.Predef; import scala.Tuple2; import scala.collection.JavaConverters; @@ -32,8 +35,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -42,25 +43,27 @@ import org.junit.After; import org.junit.Before; -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { - private transient KafkaStreamSuite testSuite = new 
KafkaStreamSuite(); +public class JavaKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; @Before - @Override public void setUp() { - testSuite.beforeFunction(); + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); System.clearProperty("spark.driver.port"); - //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, new Duration(500)); } @After - @Override public void tearDown() { ssc.stop(); ssc = null; System.clearProperty("spark.driver.port"); - testSuite.afterFunction(); + suiteBase.tearDownKafka(); } @Test @@ -74,15 +77,15 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - testSuite.createTopic(topic); + suiteBase.createTopic(topic); HashMap tmp = new HashMap(sent); - testSuite.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + suiteBase.produceAndSendMessage(topic, + JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); - kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); JavaPairDStream stream = KafkaUtils.createStream(ssc, @@ -124,11 +127,16 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); - ssc.awaitTermination(3000); - + long startTime = System.currentTimeMillis(); + boolean sizeMatches = false; + while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { + sizeMatches = sent.size() == result.size(); + Thread.sleep(200); + } Assert.assertEquals(sent.size(), result.size()); for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } + ssc.stop(); } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 6943326eb750e..b19c053ebfc44 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -19,51 +19,57 @@ package org.apache.spark.streaming.kafka import java.io.File import java.net.InetSocketAddress -import java.util.{Properties, Random} +import java.util.Properties import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random import kafka.admin.CreateTopicCommand import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, ProducerConfig, Producer} -import kafka.utils.ZKStringSerializer +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} import kafka.server.{KafkaConfig, KafkaServer} - +import kafka.utils.ZKStringSerializer import 
org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually -import org.apache.zookeeper.server.ZooKeeperServer -import org.apache.zookeeper.server.NIOServerCnxnFactory - -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class KafkaStreamSuite extends TestSuiteBase { - import KafkaTestUtils._ - - val zkHost = "localhost" - var zkPort: Int = 0 - val zkConnectionTimeout = 6000 - val zkSessionTimeout = 6000 - - protected var brokerPort = 9092 - protected var brokerConf: KafkaConfig = _ - protected var zookeeper: EmbeddedZookeeper = _ - protected var zkClient: ZkClient = _ - protected var server: KafkaServer = _ - protected var producer: Producer[String, String] = _ - - override def useManualClock = false - - override def beforeFunction() { +/** + * This is an abstract base class for Kafka testsuites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. + */ +abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { + + var zkAddress: String = _ + var zkClient: ZkClient = _ + + private val zkHost = "localhost" + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + private var zookeeper: EmbeddedZookeeper = _ + private var zkPort: Int = 0 + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + private var server: KafkaServer = _ + private var producer: Producer[String, String] = _ + + def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort + zkAddress = s"$zkHost:$zkPort" logInfo("==================== 0 ====================") - zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) logInfo("==================== 1 ====================") @@ -71,7 +77,7 @@ class KafkaStreamSuite extends TestSuiteBase { var bindSuccess: Boolean = false while(!bindSuccess) { try { - val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) logInfo("==================== 2 ====================") @@ -89,53 +95,30 @@ class KafkaStreamSuite extends TestSuiteBase { Thread.sleep(2000) logInfo("==================== 4 ====================") - super.beforeFunction() } - override def afterFunction() { - producer.close() - server.shutdown() - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - zkClient.close() - zookeeper.shutdown() - - super.afterFunction() - } - - test("Kafka input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val topic = "topic1" - val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + def tearDownKafka() { + if (producer != null) { + producer.close() + producer = null + } - val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", - "group.id" -> s"test-consumer-${random.nextInt(10000)}", - "auto.offset.reset" -> "smallest") + if (server != null) { + 
server.shutdown() + server = null + } - val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, - kafkaParams, - Map(topic -> 1), - StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() - stream.map { case (k, v) => v } - .countByValue() - .foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) - } - } - ssc.start() - ssc.awaitTermination(3000) + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - assert(sent.size === result.size) - sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + if (zkClient != null) { + zkClient.close() + zkClient = null + } - ssc.stop() + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } } private def createTestMessage(topic: String, sent: Map[String, Int]) @@ -150,58 +133,43 @@ class KafkaStreamSuite extends TestSuiteBase { CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") logInfo("==================== 5 ====================") // wait until metadata is propagated - waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + waitUntilMetadataIsPropagated(topic, 0) } def produceAndSendMessage(topic: String, sent: Map[String, Int]) { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) producer.send(createTestMessage(topic, sent): _*) + producer.close() logInfo("==================== 6 ====================") } -} - -object KafkaTestUtils { - val random = new Random() - def getBrokerConfig(port: Int, zkConnect: String): Properties = { + private def getBrokerConfig(): Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") - props.put("port", port.toString) + props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkConnect) + props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props } - def getProducerConfig(brokerList: String): Properties = { + private def getProducerConfig(): Properties = { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port val props = new Properties() - props.put("metadata.broker.list", brokerList) + props.put("metadata.broker.list", brokerAddr) props.put("serializer.class", classOf[StringEncoder].getName) props } - def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { - val startTime = System.currentTimeMillis() - while (true) { - if (condition()) - return true - if (System.currentTimeMillis() > startTime + waitTime) - return false - Thread.sleep(waitTime.min(100L)) + private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { + eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + assert( + server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) } - // Should never go to here - throw new RuntimeException("unexpected error") - } - - def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, - timeout: Long) { - assert(waitUntilTrue(() => - servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( - TopicAndPartition(topic, partition))), 
timeout), - s"Partition [$topic, $partition] metadata not propagated after timeout") } class EmbeddedZookeeper(val zkConnect: String) { @@ -227,3 +195,53 @@ object KafkaTestUtils { } } } + + +class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { + var ssc: StreamingContext = _ + + before { + setupKafka() + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + tearDownKafka() + } + + test("Kafka input stream") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map(_._2).countByValue().foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + } + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + assert(sent.size === result.size) + sent.keys.foreach { k => + assert(sent(k) === result(k).toInt) + } + } + ssc.stop() + } +} + diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala new file mode 100644 index 0000000000000..64ccc92c81fa9 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.kafka + + +import java.io.File + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.google.common.io.Files +import kafka.serializer.StringDecoder +import kafka.utils.{ZKGroupTopicDirs, ZkUtils} +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} + +class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { + + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + + + var groupId: String = _ + var kafkaParams: Map[String, String] = _ + var ssc: StreamingContext = _ + var tempDirectory: File = null + + before { + setupKafka() + groupId = s"test-consumer-${Random.nextInt(10000)}" + kafkaParams = Map( + "zookeeper.connect" -> zkAddress, + "group.id" -> groupId, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + tempDirectory = Files.createTempDir() + ssc.checkpoint(tempDirectory.getAbsolutePath) + } + + after { + if (ssc != null) { + ssc.stop() + } + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + tearDownKafka() + } + + + test("Reliable Kafka input stream with single topic") { + var topic = "test-topic" + createTopic(topic) + produceAndSendMessage(topic, data) + + // Verify whether the offset of this group/topic/partition is 0 before starting. + assert(getCommitOffset(groupId, topic, 0) === None) + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v }.foreachRDD { r => + val ret = r.collect() + ret.foreach { v => + val count = result.getOrElseUpdate(v, 0) + 1 + result.put(v, count) + } + } + ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { + // A basic process verification for ReliableKafkaReceiver. + // Verify whether received message number is equal to the sent message number. + assert(data.size === result.size) + // Verify whether each message is the same as the data to be verified. + data.keys.foreach { k => assert(data(k) === result(k).toInt) } + // Verify the offset number whether it is equal to the total message number. + assert(getCommitOffset(groupId, topic, 0) === Some(29L)) + } + ssc.stop() + } + + test("Reliable Kafka input stream with multiple topics") { + val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) + topics.foreach { case (t, _) => + createTopic(t) + produceAndSendMessage(t, data) + } + + // Before started, verify all the group/topic/partition offsets are 0. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === None) } + + // Consuming all the data sent to the broker which will potential commit the offsets internally. 
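    // Hedged sketch (names illustrative), mirroring the getCommitOffset helper defined at the
    // end of this suite: the committed offsets asserted below live in ZooKeeper under the
    // Kafka 0.8 high-level-consumer layout /consumers/<group>/offsets/<topic>/<partition>.
    import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
    def readCommittedOffset(group: String, topic: String, partition: Int): Option[Long] = {
      val dirs = new ZKGroupTopicDirs(group, topic)
      ZkUtils.readDataMaybeNull(zkClient, s"${dirs.consumerOffsetDir}/$partition")._1.map(_.toLong)
    }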
+ val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) + stream.foreachRDD(_ => Unit) + ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { + // Verify the offset for each group/topic to see whether they are equal to the expected one. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } + } + ssc.stop() + } + + + /** Getting partition offset from Zookeeper. */ + private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { + assert(zkClient != null, "Zookeeper client is not initialized") + val topicDirs = new ZKGroupTopicDirs(groupId, topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" + ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a94d09be3bec6..8a2a865867fc4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -85,6 +85,10 @@ object MimaExcludes { "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), ProblemFilters.exclude[MissingTypesProblem]( "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") ) case v if v.startsWith("1.1") => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 0316b6862f195..55765dc90698b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -27,9 +27,38 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} /** Listener object for BlockGenerator events */ private[streaming] trait BlockGeneratorListener { - /** Called when a new block needs to be pushed */ + /** + * Called after a data item is added into the BlockGenerator. The data addition and this + * callback are synchronized with the block generation and its associated callback, + * so block generation waits for the active data addition+callback to complete. This is useful + * for updating metadata on successful buffering of a data item, specifically that metadata + * that will be useful when a block is generated. Any long blocking operation in this callback + * will hurt the throughput. + */ + def onAddData(data: Any, metadata: Any) + + /** + * Called when a new block of data is generated by the block generator. The block generation + * and this callback are synchronized with the data addition and its associated callback, so + * the data addition waits for the block generation+callback to complete. This is useful + * for updating metadata when a block has been generated, specifically metadata that will + * be useful when the block has been successfully stored. Any long blocking operation in this + * callback will hurt the throughput. + */ + def onGenerateBlock(blockId: StreamBlockId) + + /** + * Called when a new block is ready to be pushed. Callers are supposed to store the block into + * Spark in this method. Internally this is called from a single + * thread, that is not synchronized with any other callbacks. Hence it is okay to do long + * blocking operation in this callback. 
+ */ def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) - /** Called when an error has occurred in BlockGenerator */ + + /** + * Called when an error has occurred in the BlockGenerator. Can be called form many places + * so better to not do any long block operation in this callback. + */ def onError(message: String, throwable: Throwable) } @@ -80,9 +109,20 @@ private[streaming] class BlockGenerator( * Push a single data item into the buffer. All received data items * will be periodically pushed into BlockManager. */ - def += (data: Any): Unit = synchronized { + def addData (data: Any): Unit = synchronized { + waitToPush() + currentBuffer += data + } + + /** + * Push a single data item into the buffer. After buffering the data, the + * `BlockGeneratorListnere.onAddData` callback will be called. All received data items + * will be periodically pushed into BlockManager. + */ + def addDataWithCallback(data: Any, metadata: Any) = synchronized { waitToPush() currentBuffer += data + listener.onAddData(data, metadata) } /** Change the buffer to which single records are added to. */ @@ -93,14 +133,15 @@ private[streaming] class BlockGenerator( if (newBlockBuffer.size > 0) { val blockId = StreamBlockId(receiverId, time - blockInterval) val newBlock = new Block(blockId, newBlockBuffer) + listener.onGenerateBlock(blockId) blocksForPushing.put(newBlock) // put is blocking when queue is full logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) } } catch { case ie: InterruptedException => logInfo("Block updating timer thread was interrupted") - case t: Throwable => - reportError("Error in block updating thread", t) + case e: Exception => + reportError("Error in block updating thread", e) } } @@ -126,8 +167,8 @@ private[streaming] class BlockGenerator( } catch { case ie: InterruptedException => logInfo("Block pushing thread was interrupted") - case t: Throwable => - reportError("Error in block pushing thread", t) + case e: Exception => + reportError("Error in block pushing thread", e) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 5360412330d37..3b1233e86c210 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -27,10 +27,10 @@ import akka.actor.{Actor, Props} import akka.pattern.ask import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration + import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.util.WriteAheadLogFileSegment import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -99,6 +99,10 @@ private[streaming] class ReceiverSupervisorImpl( /** Divides received data records into data blocks for pushing in BlockManager. */ private val blockGenerator = new BlockGenerator(new BlockGeneratorListener { + def onAddData(data: Any, metadata: Any): Unit = { } + + def onGenerateBlock(blockId: StreamBlockId): Unit = { } + def onError(message: String, throwable: Throwable) { reportError(message, throwable) } @@ -110,7 +114,7 @@ private[streaming] class ReceiverSupervisorImpl( /** Push a single record of received data into block generator. 
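Taken together, the expanded BlockGeneratorListener contract lets a receiver tie per-record metadata to the block that eventually holds it, which is exactly how ReliableKafkaReceiver tracks offsets. A hedged sketch of an implementation (class and field names are illustrative; the trait is private[streaming], so real code would have to live under that package):

    package org.apache.spark.streaming.example // BlockGeneratorListener is private[streaming]

    import scala.collection.mutable
    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.storage.StreamBlockId
    import org.apache.spark.streaming.receiver.BlockGeneratorListener

    // onAddData and onGenerateBlock are synchronized with each other; onPushBlock runs on the
    // separate pushing thread and is where the block should actually be stored.
    class MetadataTrackingListener extends BlockGeneratorListener {
      private val pending = new ArrayBuffer[Any]
      private val metadataByBlock = new mutable.HashMap[StreamBlockId, Seq[Any]]

      def onAddData(data: Any, metadata: Any): Unit = {
        if (metadata != null) pending += metadata // record metadata alongside the buffered item
      }

      def onGenerateBlock(blockId: StreamBlockId): Unit = {
        metadataByBlock(blockId) = pending.toSeq // snapshot metadata for the block just cut
        pending.clear()
      }

      def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit = {
        // store arrayBuffer into Spark here, then act on metadataByBlock.remove(blockId),
        // e.g. commit the offsets captured for this block
      }

      def onError(message: String, throwable: Throwable): Unit = { }
    }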
*/ def pushSingle(data: Any) { - blockGenerator += (data) + blockGenerator.addData(data) } /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 0f6a9489dbe0d..e26c0c6859e57 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -138,7 +138,7 @@ class ReceiverSuite extends FunSuite with Timeouts { blockGenerator.start() var count = 0 while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator += count + blockGenerator.addData(count) generatedData += count count += 1 Thread.sleep(10) @@ -168,7 +168,7 @@ class ReceiverSuite extends FunSuite with Timeouts { blockGenerator.start() var count = 0 while(System.currentTimeMillis - startTime < waitTime) { - blockGenerator += count + blockGenerator.addData(count) generatedData += count count += 1 Thread.sleep(1) @@ -299,6 +299,10 @@ class ReceiverSuite extends FunSuite with Timeouts { val arrayBuffers = new ArrayBuffer[ArrayBuffer[Int]] val errors = new ArrayBuffer[Throwable] + def onAddData(data: Any, metadata: Any) { } + + def onGenerateBlock(blockId: StreamBlockId) { } + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) { val bufferOfInts = arrayBuffer.map(_.asInstanceOf[Int]) arrayBuffers += bufferOfInts From a0300ea32a9d92bd51c72930bc3979087b0082b2 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 14 Nov 2014 14:56:57 -0800 Subject: [PATCH 02/22] [SPARK-4390][SQL] Handle NaN cast to decimal correctly Author: Michael Armbrust Closes #3256 from marmbrus/NanDecimal and squashes the following commits: 4c3ba46 [Michael Armbrust] fix style d360f83 [Michael Armbrust] Handle NaN cast to decimal --- .../org/apache/spark/sql/catalyst/expressions/Cast.scala | 6 +++++- .../NaN to Decimal-0-6ca781bc343025635d72321ef0a9d425 | 1 + .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 sql/hive/src/test/resources/golden/NaN to Decimal-0-6ca781bc343025635d72321ef0a9d425 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 55319e7a79103..34697a1249644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -290,7 +290,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case LongType => b => changePrecision(Decimal(b.asInstanceOf[Long]), target) case x: NumericType => // All other numeric types can be represented precisely as Doubles - b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + b => try { + changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) + } catch { + case _: NumberFormatException => null + } } // DoubleConverter diff --git a/sql/hive/src/test/resources/golden/NaN to Decimal-0-6ca781bc343025635d72321ef0a9d425 b/sql/hive/src/test/resources/golden/NaN to Decimal-0-6ca781bc343025635d72321ef0a9d425 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/NaN to Decimal-0-6ca781bc343025635d72321ef0a9d425 @@ -0,0 +1 @@ +NULL diff 
--git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 684d22807c0c6..0dd766f25348d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -56,6 +56,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { Locale.setDefault(originalLocale) } + createQueryTest("NaN to Decimal", + "SELECT CAST(CAST('NaN' AS DOUBLE) AS DECIMAL(1,1)) FROM src LIMIT 1") + createQueryTest("constant null testing", """SELECT |IF(FALSE, CAST(NULL AS STRING), CAST(1 AS STRING)) AS COL1, From e47c38763914aaf89a7a851c5f41b7549a75615b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 14 Nov 2014 14:59:35 -0800 Subject: [PATCH 03/22] [SPARK-4391][SQL] Configure parquet filters using SQLConf This is more uniform with the rest of SQL configuration and allows it to be turned on and off without restarting the SparkContext. In this PR I also turn off filter pushdown by default due to a number of outstanding issues (in particular SPARK-4258). When those are fixed we should turn it back on by default. Author: Michael Armbrust Closes #3258 from marmbrus/parquetFilters and squashes the following commits: 5655bfe [Michael Armbrust] Remove extra line. 15e9a98 [Michael Armbrust] Enable filters for tests 75afd39 [Michael Armbrust] Fix comments 78fa02d [Michael Armbrust] off by default e7f9e16 [Michael Armbrust] First draft of correctly configuring parquet filter pushdown --- .../main/scala/org/apache/spark/sql/SQLConf.scala | 8 +++++++- .../spark/sql/execution/SparkStrategies.scala | 7 +++++-- .../apache/spark/sql/parquet/ParquetFilters.scala | 2 -- .../spark/sql/parquet/ParquetTableOperations.scala | 13 +++++++------ .../spark/sql/parquet/ParquetQuerySuite.scala | 2 ++ 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 279495aa64755..cd7d78e684791 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -22,7 +22,6 @@ import scala.collection.JavaConversions._ import java.util.Properties - private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" @@ -32,9 +31,12 @@ private[spark] object SQLConf { val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" val DIALECT = "spark.sql.dialect" + val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" val PARQUET_CACHE_METADATA = "spark.sql.parquet.cacheMetadata" val PARQUET_COMPRESSION = "spark.sql.parquet.compression.codec" + val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.parquet.filterPushdown" + val COLUMN_NAME_OF_CORRUPT_RECORD = "spark.sql.columnNameOfCorruptRecord" // This is only used for the thriftserver @@ -90,6 +92,10 @@ private[sql] trait SQLConf { /** Number of partitions to use for shuffle operators. */ private[spark] def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS, "200").toInt + /** When true predicates will be passed to the parquet record reader when possible. 
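Because the flag now comes from SQLConf rather than SparkConf, pushdown can be toggled per session without restarting the SparkContext. A hedged usage sketch (file path, table name and predicate are illustrative assumptions):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("PushdownExample"))
    val sqlContext = new SQLContext(sc)

    sqlContext.setConf("spark.sql.parquet.filterPushdown", "true") // disabled by default after this patch
    sqlContext.parquetFile("/path/to/table.parquet").registerTempTable("parquet_table")
    sqlContext.sql("SELECT * FROM parquet_table WHERE id > 100").collect()
    // simple column predicates such as id > 100 are handed to the Parquet record reader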
*/ + private[spark] def parquetFilterPushDown = + getConf(PARQUET_FILTER_PUSHDOWN_ENABLED, "false").toBoolean + /** * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode * that evaluates expressions found in queries. In general this custom code runs much faster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cc7e0c05ffc70..03cd5bd6272bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -208,7 +208,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => val prunePushedDownFilters = - if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { + if (sqlContext.parquetFilterPushDown) { (filters: Seq[Expression]) => { filters.filter { filter => // Note: filters cannot be pushed down to Parquet if they contain more complex @@ -234,7 +234,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { projectList, filters, prunePushedDownFilters, - ParquetTableScan(_, relation, filters)) :: Nil + ParquetTableScan( + _, + relation, + if (sqlContext.parquetFilterPushDown) filters else Nil)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 1e67799e8399a..9a3f6d388d621 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -43,8 +43,6 @@ import org.apache.spark.sql.parquet.ParquetColumns._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - // set this to false if pushdown should be disabled - val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" def createRecordFilter(filterExpressions: Seq[Expression]): Filter = { val filters: Seq[CatalystFilter] = filterExpressions.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 74c43e053b03c..5f93279a08dd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -23,6 +23,8 @@ import java.text.SimpleDateFormat import java.util.concurrent.{Callable, TimeUnit} import java.util.{ArrayList, Collections, Date, List => JList} +import org.apache.spark.annotation.DeveloperApi + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Try @@ -52,6 +54,7 @@ import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} /** + * :: DeveloperApi :: * Parquet table scan operator. Imports the file that backs the given * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. 
*/ @@ -108,15 +111,11 @@ case class ParquetTableScan( // Note 1: the input format ignores all predicates that cannot be expressed // as simple column predicate filters in Parquet. Here we just record // the whole pruning predicate. - // Note 2: you can disable filter predicate pushdown by setting - // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. - if (columnPruningPred.length > 0 && - sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { - + if (columnPruningPred.length > 0) { // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred) if (filter != null){ - val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate() + val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate ParquetInputFormat.setFilterPredicate(conf, filterPredicate) } } @@ -193,6 +192,7 @@ case class ParquetTableScan( } /** + * :: DeveloperApi :: * Operator that acts as a sink for queries on RDDs and can be used to * store the output inside a directory of Parquet files. This operator * is similar to Hive's INSERT INTO TABLE operation in the sense that @@ -208,6 +208,7 @@ case class ParquetTableScan( * cause unpredicted behaviour and therefore results in a RuntimeException * (only detected via filename pattern so will not catch all cases). */ +@DeveloperApi case class InsertIntoParquetTable( relation: ParquetRelation, child: SparkPlan, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 3cccafe92d4f3..80a3e0b4c91ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -95,6 +95,8 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA testRDD.registerTempTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) .registerTempTable("testfiltersource") + + setConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED, "true") } override def afterAll() { From f805025e8efe9cd522e8875141ec27df8d16bbe0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 14 Nov 2014 15:00:42 -0800 Subject: [PATCH 04/22] [SQL] Minor cleanup of comments, errors and override. Author: Michael Armbrust Closes #3257 from marmbrus/minorCleanup and squashes the following commits: d8b5abc [Michael Armbrust] Use interpolation. 2fdf903 [Michael Armbrust] Better error message when coalesce can't be resolved. f9fa6cf [Michael Armbrust] Methods in a final class do not also need to be final, use override. 
199fd98 [Michael Armbrust] Fix typo --- .../sql/catalyst/expressions/aggregates.scala | 2 +- .../expressions/codegen/GenerateProjection.scala | 16 ++++++++-------- .../sql/catalyst/expressions/nullFunctions.scala | 4 +++- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 2b364fc1df1d8..3ceb5ecaf66e4 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -304,7 +304,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN child.dataType match { case DecimalType.Fixed(_, _) => - // Turn the results to unlimited decimals for the divsion, before going back to fixed + // Turn the results to unlimited decimals for the division, before going back to fixed val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) SplitEvaluation( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 7871a62620478..2ff61169a17db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -53,8 +53,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val nullFunctions = q""" private[this] var nullBits = new Array[Boolean](${expressions.size}) - final def setNullAt(i: Int) = { nullBits(i) = true } - final def isNullAt(i: Int) = nullBits(i) + override def setNullAt(i: Int) = { nullBits(i) = true } + override def isNullAt(i: Int) = nullBits(i) """.children val tupleElements = expressions.zipWithIndex.flatMap { @@ -82,7 +82,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val iLit = ru.Literal(Constant(i)) q"if(isNullAt($iLit)) { null } else { ${newTermName(s"c$i")} }" } - q"final def iterator = Iterator[Any](..$allColumns)" + q"override def iterator = Iterator[Any](..$allColumns)" } val accessorFailure = q"""scala.sys.error("Invalid ordinal:" + i)""" @@ -94,7 +94,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { q"if(i == $ordinal) { if(isNullAt($i)) return null else return $elementName }" } - q"final def apply(i: Int): Any = { ..$cases; $accessorFailure }" + q"override def apply(i: Int): Any = { ..$cases; $accessorFailure }" } val updateFunction = { @@ -114,7 +114,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return }""" } - q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" + q"override def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" } val specificAccessorFunctions = NativeType.all.map { dataType => @@ -128,7 +128,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } q""" - final def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { + override def ${accessorForType(dataType)}(i: Int):${termForType(dataType)} = { ..$ifStatements; $accessorFailure }""" @@ -145,7 +145,7 @@ object GenerateProjection 
extends CodeGenerator[Seq[Expression], Projection] { } q""" - final def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { + override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}): Unit = { ..$ifStatements; $accessorFailure }""" @@ -193,7 +193,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val copyFunction = q""" - final def copy() = new $genericRowType(this.toArray) + override def copy() = new $genericRowType(this.toArray) """ val classBody = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 086d0a3e073e5..84a3567895175 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -37,7 +37,9 @@ case class Coalesce(children: Seq[Expression]) extends Expression { def dataType = if (resolved) { children.head.dataType } else { - throw new UnresolvedException(this, "Coalesce cannot have children of different types.") + val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") + throw new UnresolvedException( + this, s"Coalesce cannot have children of different types. $childTypes") } override def eval(input: Row): Any = { From 4b4b50c9e596673c1534df97effad50d107a8007 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 14 Nov 2014 15:03:23 -0800 Subject: [PATCH 05/22] [SQL] Don't shuffle code generated rows When sort based shuffle and code gen are on we were trying to ship the code generated rows during a shuffle. This doesn't work because the classes don't exist on the other side. Instead we now copy into a generic row before shipping. Author: Michael Armbrust Closes #3263 from marmbrus/aggCodeGen and squashes the following commits: f6ba8cf [Michael Armbrust] fix and test --- .../scala/org/apache/spark/sql/execution/Exchange.scala | 4 ++-- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 927f40063e47e..cff7a012691dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,8 +47,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una // TODO: Eliminate redundant expressions in grouping key and value. 
val rdd = if (sortBasedShuffleOn) { child.execute().mapPartitions { iter => - val hashExpressions = newProjection(expressions, child.output) - iter.map(r => (hashExpressions(r), r.copy())) + val hashExpressions = newMutableProjection(expressions, child.output)() + iter.map(r => (hashExpressions(r).copy(), r.copy())) } } else { child.execute().mapPartitions { iter => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8a80724c08c7c..5dd777f1fb3b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -72,6 +72,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { 2.5) } + test("aggregation with codegen") { + val originalValue = codegenEnabled + setConf(SQLConf.CODEGEN_ENABLED, "true") + sql("SELECT key FROM testData GROUP BY key").collect() + setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), From 0c7b66bd449093bb5d2dafaf91d54e63e601e320 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 14 Nov 2014 15:09:36 -0800 Subject: [PATCH 06/22] [SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields While resolving struct fields, the resulted `GetField` expression is wrapped with an `Alias` to make it a named expression. Assume `a` is a struct instance with a field `b`, then `"a.b"` will be resolved as `Alias(GetField(a, "b"), "b")`. Thus, for this following SQL query: ```sql SELECT a.b + 1 FROM t GROUP BY a.b + 1 ``` the grouping expression is ```scala Add(GetField(a, "b"), Literal(1, IntegerType)) ``` while the aggregation expression is ```scala Add(Alias(GetField(a, "b"), "b"), Literal(1, IntegerType)) ``` This mismatch makes the above SQL query fail during the both analysis and execution phases. This PR fixes this issue by removing the alias when substituting aggregation expressions. [Review on Reviewable](https://reviewable.io/reviews/apache/spark/3248) Author: Cheng Lian Closes #3248 from liancheng/spark-4322 and squashes the following commits: 23a46ea [Cheng Lian] Code simplification dd20a79 [Cheng Lian] Should only trim aliases around `GetField`s 7f46532 [Cheng Lian] Enables struct fields as sub expressions of grouping fields --- .../sql/catalyst/analysis/Analyzer.scala | 27 +++++++++---------- .../sql/catalyst/planning/patterns.scala | 15 ++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++- 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a448c794213ae..d3b4cf8e34242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimAliases :: + TrimGroupingAliases :: typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, @@ -93,17 +93,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool /** * Removes no-op Alias expressions from the plan. 
*/ - object TrimAliases extends Rule[LogicalPlan] { + object TrimGroupingAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Aggregate(groups, aggs, child) => - Aggregate( - groups.map { - _ transform { - case Alias(c, _) => c - } - }, - aggs, - child) + Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) } } @@ -122,10 +115,15 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case e => e.children.forall(isValidAggregateExpression) } - aggregateExprs.foreach { e => - if (!isValidAggregateExpression(e)) { - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } + aggregateExprs.find { e => + !isValidAggregateExpression(e.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g: GetField, _) => g + }) + }.foreach { e => + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") } aggregatePlan @@ -328,4 +326,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] { case Subquery(_, child) => child } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0fd9a8b9a46e..310d127506d68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -151,8 +151,15 @@ object PartialAggregation { val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute + + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + namedGroupingExpressions + .get(e.transform { case Alias(g: GetField, _) => g }) + .map(_.toAttribute) + .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation = @@ -188,7 +195,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
- val (joinPredicates, otherPredicates) = + val (joinPredicates, otherPredicates) = condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true @@ -203,7 +210,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5dd777f1fb3b7..ce5672c08653a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -551,7 +551,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil) } - test("INTERSECT") { + test("INTERSECT") { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), (1, "a") :: @@ -949,4 +949,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), (1 to 99).map(i => Seq(i))) } + + test("SPARK-4322 Grouping field with struct field as sub expression") { + jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data") + checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1) + dropTempTable("data") + + jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2) + dropTempTable("data") + } } From f76b9683706232c3d4e8e6e61627b8188dcb79dc Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Fri, 14 Nov 2014 15:11:53 -0800 Subject: [PATCH 07/22] [SPARK-4386] Improve performance when writing Parquet files. If you profile the writing of a Parquet file, the single worst time consuming call inside of org.apache.spark.sql.parquet.MutableRowWriteSupport.write is actually in the scala.collection.AbstractSequence.size call. This is because the size call actually ends up COUNTING the elements in a scala.collection.LinearSeqOptimized.length ("optimized?"). This doesn't need to be done. "size" is called repeatedly where needed rather than called once at the top of the method and stored in a 'val'. Author: Jim Carroll Closes #3254 from jimfcarroll/parquet-perf and squashes the following commits: 30cc0b5 [Jim Carroll] Improve performance when writing Parquet files. 
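To make the cost concrete, here is a minimal sketch of the before/after pattern this patch applies, using hypothetical `writeRow` and `fields` names rather than the actual `RowWriteSupport` code, and assuming the field list is a linear sequence whose `size` is recomputed on every call:

```scala
// Hypothetical stand-in for RowWriteSupport.write: `fields` plays the role of
// `attributes`. When `fields` is a LinearSeq (e.g. a List), `fields.size`
// walks the whole sequence, so using it as the loop bound makes each row write
// quadratic in the number of columns; computing it once keeps the write linear.
def writeRow(fields: Seq[String], row: Array[Any]): Unit = {
  val fieldCount = fields.size            // hoisted once, as the patch does with attributes.size
  if (fieldCount > row.length) {
    throw new IndexOutOfBoundsException(
      s"Trying to write more fields than contained in row ($fieldCount > ${row.length})")
  }
  var index = 0
  while (index < fieldCount) {            // previously the bound was fields.size on every pass
    if (row(index) != null) {
      println(s"${fields(index)} -> ${row(index)}")   // stand-in for writing the column value
    }
    index += 1
  }
}
```

The diff below makes exactly this change in both `RowWriteSupport.write` and `MutableRowWriteSupport.write`.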
--- .../spark/sql/parquet/ParquetTableSupport.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 7bc249660053a..ef3687e692964 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -152,14 +152,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } override def write(record: Row): Unit = { - if (attributes.size > record.size) { + val attributesSize = attributes.size + if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") } var index = 0 writer.startMessage() - while(index < attributes.size) { + while(index < attributesSize) { // null values indicate optional fields but we do not check currently if (record(index) != null) { writer.startField(attributes(index).name, index) @@ -312,14 +313,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { // Optimized for non-nested rows private[parquet] class MutableRowWriteSupport extends RowWriteSupport { override def write(record: Row): Unit = { - if (attributes.size > record.size) { + val attributesSize = attributes.size + if (attributesSize > record.size) { throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + s"Trying to write more fields than contained in row (${attributesSize}>${record.size})") } var index = 0 writer.startMessage() - while(index < attributes.size) { + while(index < attributesSize) { // null values indicate optional fields but we do not check currently if (record(index) != null && record(index) != Nil) { writer.startField(attributes(index).name, index) From 63ca3af66f9680fd12adee82fb4d342caae5cea4 Mon Sep 17 00:00:00 2001 From: Yash Datta Date: Fri, 14 Nov 2014 15:16:36 -0800 Subject: [PATCH 08/22] [SPARK-4365][SQL] Remove unnecessary filter call on records returned from parquet library Since parquet library has been updated , we no longer need to filter the records returned from parquet library for null records , as now the library skips those : from parquet-hadoop/src/main/java/parquet/hadoop/InternalParquetRecordReader.java public boolean nextKeyValue() throws IOException, InterruptedException { boolean recordFound = false; while (!recordFound) { // no more records left if (current >= total) { return false; } try { checkRead(); currentValue = recordReader.read(); current ++; if (recordReader.shouldSkipCurrentRecord()) { // this record is being filtered via the filter2 package if (DEBUG) LOG.debug("skipping record"); continue; } if (currentValue == null) { // only happens with FilteredRecordReader at end of block current = totalCountLoadedSoFar; if (DEBUG) LOG.debug("filtered record reader reached end of block"); continue; } recordFound = true; if (DEBUG) LOG.debug("read value: " + currentValue); } catch (RuntimeException e) { throw new ParquetDecodingException(format("Can not read value at %d in block %d in file %s", current, currentBlock, file), e); } } return true; } Author: Yash Datta Closes #3229 from saucam/remove_filter and squashes the following commits: 8909ae9 
[Yash Datta] SPARK-4365: Remove unnecessary filter call on records returned from parquet library --- .../org/apache/spark/sql/parquet/ParquetTableOperations.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 5f93279a08dd8..f6bed5016fbfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -159,7 +159,7 @@ case class ParquetTableScan( } } else { baseRDD.map(_._2) - }.filter(_ != null) // Parquet's record filters may produce null values + } } /** From 37482ce5a7b875f17d32a5e8c561cc8e9772c9b3 Mon Sep 17 00:00:00 2001 From: Jim Carroll Date: Fri, 14 Nov 2014 15:33:21 -0800 Subject: [PATCH 09/22] [SPARK-4412][SQL] Fix Spark's control of Parquet logging. The Spark ParquetRelation.scala code makes the assumption that the parquet.Log class has already been loaded. If ParquetRelation.enableLogForwarding executes prior to the parquet.Log class being loaded then the code in enableLogForwarding has no affect. ParquetRelation.scala attempts to override the parquet logger but, at least currently (and if your application simply reads a parquet file before it does anything else with Parquet), the parquet.Log class hasn't been loaded yet. Therefore the code in ParquetRelation.enableLogForwarding has no affect. If you look at the code in parquet.Log there's a static initializer that needs to be called prior to enableLogForwarding or whatever enableLogForwarding does gets undone by this static initializer. The "fix" would be to force the static initializer to get called in parquet.Log as part of enableForwardLogging. Author: Jim Carroll Closes #3271 from jimfcarroll/parquet-logging and squashes the following commits: 37bdff7 [Jim Carroll] Fix Spark's control of Parquet logging. --- .../spark/sql/parquet/ParquetRelation.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 82130b5459174..b237a07c72d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -84,6 +84,21 @@ private[sql] case class ParquetRelation( private[sql] object ParquetRelation { def enableLogForwarding() { + // Note: the parquet.Log class has a static initializer that + // sets the java.util.logging Logger for "parquet". This + // checks first to see if there's any handlers already set + // and if not it creates them. If this method executes prior + // to that class being loaded then: + // 1) there's no handlers installed so there's none to + // remove. But when it IS finally loaded the desired affect + // of removing them is circumvented. + // 2) The parquet.Log static initializer calls setUseParentHanders(false) + // undoing the attempt to override the logging here. + // + // Therefore we need to force the class to be loaded. + // This should really be resolved by Parquet. + Class.forName(classOf[parquet.Log].getName()) + // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. 
val parquetLogger = java.util.logging.Logger.getLogger("parquet") From ad42b283246b93654c5fd731cd618fee74d8c4da Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Fri, 14 Nov 2014 15:51:05 -0800 Subject: [PATCH 10/22] SPARK-4214. With dynamic allocation, avoid outstanding requests for more... ... executors than pending tasks need. WIP. Still need to add and fix tests. Author: Sandy Ryza Closes #3204 from sryza/sandy-spark-4214 and squashes the following commits: 35cf0e0 [Sandy Ryza] Add comment 13b53df [Sandy Ryza] Review feedback 067465f [Sandy Ryza] Whitespace fix 6ae080c [Sandy Ryza] Add tests and get num pending tasks from ExecutorAllocationListener 531e2b6 [Sandy Ryza] SPARK-4214. With dynamic allocation, avoid outstanding requests for more executors than pending tasks need. --- .../spark/ExecutorAllocationManager.scala | 55 ++++++++++++++++--- .../ExecutorAllocationManagerSuite.scala | 48 ++++++++++++++++ 2 files changed, 94 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index ef93009a074e7..88adb892998af 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -28,7 +28,9 @@ import org.apache.spark.scheduler._ * the scheduler queue is not drained in N seconds, then new executors are added. If the queue * persists for another M seconds, then more executors are added and so on. The number added * in each round increases exponentially from the previous round until an upper bound on the - * number of executors has been reached. + * number of executors has been reached. The upper bound is based both on a configured property + * and on the number of tasks pending: the policy will never increase the number of executor + * requests past the number needed to handle all pending tasks. * * The rationale for the exponential increase is twofold: (1) Executors should be added slowly * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, @@ -82,6 +84,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + // TODO: The default value of 1 for spark.executor.cores works right now because dynamic + // allocation is only supported for YARN and the default number of cores per executor in YARN is + // 1, but it might need to be attained differently for different cluster managers + private val tasksPerExecutor = + conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + validateSettings() // Number of executors to add in the next round @@ -110,6 +118,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging // Clock used to schedule when executors should be added and removed private var clock: Clock = new RealClock + // Listener for Spark events that impact the allocation policy + private val listener = new ExecutorAllocationListener(this) + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -141,6 +152,9 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging throw new SparkException("Dynamic allocation of executors requires the external " + "shuffle service. 
You may enable this through spark.shuffle.service.enabled.") } + if (tasksPerExecutor == 0) { + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + } } /** @@ -154,7 +168,6 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging * Register for scheduler callbacks to decide when to add and remove executors. */ def start(): Unit = { - val listener = new ExecutorAllocationListener(this) sc.addSparkListener(listener) startPolling() } @@ -218,13 +231,27 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging return 0 } - // Request executors with respect to the upper bound - val actualNumExecutorsToAdd = - if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) { - numExecutorsToAdd - } else { - maxNumExecutors - numExistingExecutors - } + // The number of executors needed to satisfy all pending tasks is the number of tasks pending + // divided by the number of tasks each executor can fit, rounded up. + val maxNumExecutorsPending = + (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor + if (numExecutorsPending >= maxNumExecutorsPending) { + logDebug(s"Not adding executors because there are already $numExecutorsPending " + + s"pending and pending tasks could only fill $maxNumExecutorsPending") + numExecutorsToAdd = 1 + return 0 + } + + // It's never useful to request more executors than could satisfy all the pending tasks, so + // cap request at that amount. + // Also cap request with respect to the configured upper bound. + val maxNumExecutorsToAdd = math.min( + maxNumExecutorsPending - numExecutorsPending, + maxNumExecutors - numExistingExecutors) + assert(maxNumExecutorsToAdd > 0) + + val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) + val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd) if (addRequestAcknowledged) { @@ -445,6 +472,16 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) } + + /** + * An estimate of the total number of pending tasks remaining for currently running stages. Does + * not account for tasks which may have failed and been resubmitted. 
+ */ + def totalPendingTasks(): Int = { + stageIdToNumTasks.map { case (stageId, numTasks) => + numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) + }.sum + } } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 66cf60d25f6d1..4b27477790212 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -76,6 +76,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test("add executors") { sc = createSparkContext(1, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsPending(manager) === 0) @@ -117,6 +118,51 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { assert(numExecutorsToAdd(manager) === 1) } + test("add executors capped by num pending tasks") { + sc = createSparkContext(1, 10) + val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + + // Verify that we're capped at number of tasks in the stage + assert(numExecutorsPending(manager) === 0) + assert(numExecutorsToAdd(manager) === 1) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 1) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 3) + assert(numExecutorsToAdd(manager) === 4) + assert(addExecutors(manager) === 2) + assert(numExecutorsPending(manager) === 5) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task reduces the cap + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 6) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 7) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that re-running a task doesn't reduce the cap further + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 8) + assert(numExecutorsToAdd(manager) === 2) + assert(addExecutors(manager) === 1) + assert(numExecutorsPending(manager) === 9) + assert(numExecutorsToAdd(manager) === 1) + + // Verify that running a task once we're at our limit doesn't blow things up + sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + assert(addExecutors(manager) === 0) + assert(numExecutorsPending(manager) === 9) + } + test("remove executors") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get @@ -170,6 +216,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { test ("interleaving add and remove") { sc = createSparkContext(5, 10) val manager = sc.executorAllocationManager.get + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors 
assert(addExecutors(manager) === 1) @@ -343,6 +390,7 @@ class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext { val clock = new TestClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) From 303a4e4d23e5cd93b541480cf88d5badb9cf9622 Mon Sep 17 00:00:00 2001 From: WangTao Date: Fri, 14 Nov 2014 20:11:51 -0800 Subject: [PATCH 11/22] [SPARK-4404]SparkSubmitDriverBootstrapper should stop after its SparkSubmit sub-proc... ...ess ends https://issues.apache.org/jira/browse/SPARK-4404 When we have spark.driver.extra* or spark.driver.memory in SPARK_SUBMIT_PROPERTIES_FILE, spark-class will use SparkSubmitDriverBootstrapper to launch driver. If we get process id of SparkSubmitDriverBootstrapper and wanna kill it during its running, we expect its SparkSubmit sub-process stop also. Author: WangTao Author: WangTaoTheTonic Closes #3266 from WangTaoTheTonic/killsubmit and squashes the following commits: e03eba5 [WangTaoTheTonic] add comments 57b5ca1 [WangTao] SparkSubmitDriverBootstrapper should stop after its SparkSubmit sub-process ends --- .../spark/deploy/SparkSubmitDriverBootstrapper.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 2b894a796c8c6..7ffff29122d4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -129,6 +129,16 @@ private[spark] object SparkSubmitDriverBootstrapper { val process = builder.start() + // If we kill an app while it's running, its sub-process should be killed too. + Runtime.getRuntime().addShutdownHook(new Thread() { + override def run() = { + if (process != null) { + process.destroy() + sys.exit(process.waitFor()) + } + } + }) + // Redirect stdout and stderr from the child JVM val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") From 7fe08b43c78bf9e8515f671e72aa03a83ea782f8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 14 Nov 2014 20:13:46 -0800 Subject: [PATCH 12/22] [SPARK-4415] [PySpark] JVM should exit after Python exit When JVM is started in a Python process, it should exit once the stdin is closed. 
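The JVM side can notice the parent going away by watching its own stdin; a minimal sketch of that pattern (hypothetical `StdinWatcher` name, not the actual bootstrapper code) looks like:

```scala
// Block on this process's stdin from a daemon thread and run a callback once
// the parent closes its end of the pipe (EOF or broken pipe), which is what
// happens when the launching Python process exits.
object StdinWatcher {
  def watch(onParentExit: () => Unit): Unit = {
    val t = new Thread("watch-parent-stdin") {
      override def run(): Unit = {
        try {
          // read() blocks; it returns -1 once the parent's end of the pipe is closed
          while (System.in.read() != -1) { /* discard any input */ }
        } catch {
          case _: java.io.IOException => // a broken pipe also means the parent is gone
        }
        onParentExit()
      }
    }
    t.setDaemon(true)
    t.start()
  }
}

// e.g. StdinWatcher.watch(() => sys.exit(0))
```
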
test: add spark.driver.memory in conf/spark-defaults.conf ``` daviesdm:~/work/spark$ cat conf/spark-defaults.conf spark.driver.memory 8g daviesdm:~/work/spark$ bin/pyspark >>> quit daviesdm:~/work/spark$ jps 4931 Jps 286 daviesdm:~/work/spark$ python wc.py 943738 0.719928026199 daviesdm:~/work/spark$ jps 286 4990 Jps ``` Author: Davies Liu Closes #3274 from davies/exit and squashes the following commits: df0e524 [Davies Liu] address comments ce8599c [Davies Liu] address comments 050651f [Davies Liu] JVM should exit after Python exit --- bin/pyspark | 2 -- bin/pyspark2.cmd | 1 - .../spark/deploy/SparkSubmitDriverBootstrapper.scala | 11 ++++++----- python/pyspark/java_gateway.py | 4 +++- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index 1d8c94d43d285..0b4f695dd06dd 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -132,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then gatherSparkSubmitOpts "$@" exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else - # PySpark shell requires special handling downstream - export PYSPARK_SHELL=1 exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 59415e9bdec2c..a542ec80b49d6 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -59,7 +59,6 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do ( ) if [%PYTHON_FILE%] == [] ( - set PYSPARK_SHELL=1 if [%IPYTHON%] == [1] ( ipython %IPYTHON_OPTS% ) else ( diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index 7ffff29122d4b..aa3743ca7df63 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -149,14 +149,15 @@ private[spark] object SparkSubmitDriverBootstrapper { // subprocess there already reads directly from our stdin, so we should avoid spawning a // thread that contends with the subprocess in reading from System.in. val isWindows = Utils.isWindows - val isPySparkShell = sys.env.contains("PYSPARK_SHELL") + val isSubprocess = sys.env.contains("IS_SUBPROCESS") if (!isWindows) { val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin") stdinThread.start() - // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM - // should terminate on broken pipe, which signals that the parent process has exited. In - // Windows, the termination logic for the PySpark shell is handled in java_gateway.py - if (isPySparkShell) { + // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on + // broken pipe, signaling that the parent process has exited. This is the case if the + // application is launched directly from python, as in the PySpark shell. 
In Windows, + // the termination logic is handled in java_gateway.py + if (isSubprocess) { stdinThread.join() process.destroy() } diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 9c70fa5c16d0c..a975dc19cb78e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -45,7 +45,9 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + env = dict(os.environ) + env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows proc = Popen(command, stdout=PIPE, stdin=PIPE) From dba14058230194122a715c219e35ab8eaa786321 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 14 Nov 2014 22:25:41 -0800 Subject: [PATCH 13/22] [SPARK-4379][Core] Change Exception to SparkException in checkpoint It's better to change to SparkException. However, it's a breaking change since it will change the exception type. Author: zsxwing Closes #3241 from zsxwing/SPARK-4379 and squashes the following commits: 409f3af [zsxwing] Change Exception to SparkException in checkpoint --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 716f2dd17733b..cb64d43c6c54a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1202,7 +1202,7 @@ abstract class RDD[T: ClassTag]( */ def checkpoint() { if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") + throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() From 861223ee5bea8e434a9ebb0d53f436ce23809f9c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 14 Nov 2014 22:28:48 -0800 Subject: [PATCH 14/22] [SPARK-4363][Doc] Update the Broadcast example Author: zsxwing Closes #3226 from zsxwing/SPARK-4363 and squashes the following commits: 8109914 [zsxwing] Update the Broadcast example --- core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala | 2 +- docs/programming-guide.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 87f5cf944ed85..a5ea478f231d7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -39,7 +39,7 @@ import scala.reflect.ClassTag * * {{{ * scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) - * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) + * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) * * scala> broadcastVar.value * res0: Array[Int] = Array(1, 2, 3) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 18420afb27e3c..9de2f914b8b4c 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -1131,7 +1131,7 @@ method. 
The code below shows this: {% highlight scala %} scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) -broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) +broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) scala> broadcastVar.value res0: Array[Int] = Array(1, 2, 3) From 60969b0336930449a826821a48f83f65337e8856 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 14 Nov 2014 22:36:56 -0800 Subject: [PATCH 15/22] [SPARK-4260] Httpbroadcast should set connection timeout. Httpbroadcast sets read timeout but doesn't set connection timeout. Author: Kousuke Saruta Closes #3122 from sarutak/httpbroadcast-timeout and squashes the following commits: c7f3a56 [Kousuke Saruta] Added Connection timeout for Http Connection to HttpBroadcast.scala --- .../main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 7dade04273b08..31f0a462f84d8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -191,10 +191,12 @@ private[broadcast] object HttpBroadcast extends Logging { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) uc = newuri.toURL.openConnection() + uc.setConnectTimeout(httpReadTimeout) uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") uc = new URL(url).openConnection() + uc.setConnectTimeout(httpReadTimeout) } val in = { From cbddac23696d89b672dce380cc7360a873e27b3b Mon Sep 17 00:00:00 2001 From: kai Date: Fri, 14 Nov 2014 23:44:23 -0800 Subject: [PATCH 16/22] Added contains(key) to Metadata Add contains(key) to org.apache.spark.sql.catalyst.util.Metadata to test the existence of a key. Otherwise, Class Metadata's get methods may throw NoSuchElement exception if the key does not exist. Testcases are added to MetadataSuite as well. Author: kai Closes #3273 from kai-zeng/metadata-fix and squashes the following commits: 74b3d03 [kai] Added contains(key) to Metadata --- .../apache/spark/sql/catalyst/util/Metadata.scala | 3 +++ .../spark/sql/catalyst/util/MetadataSuite.scala | 13 +++++++++++++ 2 files changed, 16 insertions(+) mode change 100644 => 100755 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala mode change 100644 => 100755 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala old mode 100644 new mode 100755 index 2f2082fa3c863..8172733e94dd5 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala @@ -34,6 +34,9 @@ import org.json4s.jackson.JsonMethods._ */ sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { + /** Tests whether this Metadata contains a binding for a key. */ + def contains(key: String): Boolean = map.contains(key) + /** Gets a Long. 
*/ def getLong(key: String): Long = get(key) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala old mode 100644 new mode 100755 index 0063d31666c85..f005b7df21043 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -56,17 +56,30 @@ class MetadataSuite extends FunSuite { .build() test("metadata builder and getters") { + assert(age.contains("summary") === false) + assert(age.contains("index") === true) assert(age.getLong("index") === 1L) + assert(age.contains("average") === true) assert(age.getDouble("average") === 45.0) + assert(age.contains("categorical") === true) assert(age.getBoolean("categorical") === false) + assert(age.contains("name") === true) assert(age.getString("name") === "age") + assert(metadata.contains("purpose") === true) assert(metadata.getString("purpose") === "ml") + assert(metadata.contains("isBase") === true) assert(metadata.getBoolean("isBase") === false) + assert(metadata.contains("summary") === true) assert(metadata.getMetadata("summary") === summary) + assert(metadata.contains("long[]") === true) assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) + assert(metadata.contains("double[]") === true) assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) + assert(metadata.contains("boolean[]") === true) assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) + assert(gender.contains("categories") === true) assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) + assert(metadata.contains("features") === true) assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) } From 40eb8b6ef3a67e36d0d9492c044981a1da76351d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 14 Nov 2014 23:46:25 -0800 Subject: [PATCH 17/22] [SPARK-2321] Several progress API improvements / refactorings This PR refactors / extends the status API introduced in #2696. - Change StatusAPI from a mixin trait to a class. Before, the new status API methods were directly accessible through SparkContext, whereas now they're accessed through a `sc.statusAPI` field. As long as we were going to add these methods directly to SparkContext, the mixin trait seemed like a good idea, but this might be simpler to reason about and may avoid pitfalls that I've run into while attempting to refactor other parts of SparkContext to use mixins (see #3071, for example). - Change the name from SparkStatusAPI to SparkStatusTracker. - Make `getJobIdsForGroup(null)` return ids for jobs that aren't associated with any job group. - Add `getActiveStageIds()` and `getActiveJobIds()` methods that return the ids of whatever's currently active in this SparkContext. This should simplify davies's progress bar code. Author: Josh Rosen Closes #3197 from JoshRosen/progress-api-improvements and squashes the following commits: 30b0afa [Josh Rosen] Rename SparkStatusAPI to SparkStatusTracker. d1b08d8 [Josh Rosen] Add missing newlines 2cc7353 [Josh Rosen] Add missing file. d5eab1f [Josh Rosen] Add getActive[Stage|Job]Ids() methods. a227984 [Josh Rosen] getJobIdsForGroup(null) should return jobs for default group c47e294 [Josh Rosen] Remove StatusAPI mixin trait. 
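As a quick usage sketch, assuming a live `SparkContext` named `sc` (the getter names on `SparkJobInfo` / `SparkStageInfo` below are assumed from the fields shown in this patch and may differ slightly):

```scala
// Poll the new tracker for whatever is currently running.
val tracker = sc.statusTracker

for (jobId <- tracker.getActiveJobIds()) {
  tracker.getJobInfo(jobId).foreach { job =>
    println(s"job $jobId: status=${job.status()} stages=${job.stageIds().mkString(", ")}")
  }
}

for (stageId <- tracker.getActiveStageIds()) {
  tracker.getStageInfo(stageId).foreach { stage =>
    println(s"stage $stageId (${stage.name()}): " +
      s"${stage.numCompletedTasks()} of ${stage.numTasks()} tasks complete")
  }
}
```

Because the tracker only retains recent jobs and stages, both lookups can return `None` even for an id obtained a moment earlier, which is why the sketch goes through `foreach` rather than unconditionally unwrapping the result.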
--- .../scala/org/apache/spark/SparkContext.scala | 68 ++++++++- .../org/apache/spark/SparkStatusAPI.scala | 142 ------------------ .../org/apache/spark/SparkStatusTracker.scala | 107 +++++++++++++ .../spark/api/java/JavaSparkContext.scala | 21 +-- .../api/java/JavaSparkStatusTracker.scala | 72 +++++++++ ...PISuite.scala => StatusTrackerSuite.scala} | 25 ++- ...PIDemo.java => JavaStatusTrackerDemo.java} | 6 +- 7 files changed, 269 insertions(+), 172 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/SparkStatusAPI.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkStatusTracker.scala create mode 100644 core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala rename core/src/test/scala/org/apache/spark/{StatusAPISuite.scala => StatusTrackerSuite.scala} (69%) rename examples/src/main/java/org/apache/spark/examples/{JavaStatusAPIDemo.java => JavaStatusTrackerDemo.java} (92%) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 03ea672c813d1..65edeeffb837a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -25,6 +25,7 @@ import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger import java.util.UUID.randomUUID import scala.collection.{Map, Set} +import scala.collection.JavaConversions._ import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.reflect.{ClassTag, classTag} @@ -61,7 +62,7 @@ import org.apache.spark.util._ * this config overrides the default configs as well as system properties. */ -class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { +class SparkContext(config: SparkConf) extends Logging { // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It @@ -228,6 +229,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { private[spark] val jobProgressListener = new JobProgressListener(conf) listenerBus.addListener(jobProgressListener) + val statusTracker = new SparkStatusTracker(this) + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { @@ -1001,6 +1004,69 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging { /** The version of Spark on which this application is running. */ def version = SPARK_VERSION + /** + * Return a map from the slave to the max memory available for caching and the remaining + * memory available for caching. + */ + def getExecutorMemoryStatus: Map[String, (Long, Long)] = { + env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => + (blockManagerId.host + ":" + blockManagerId.port, mem) + } + } + + /** + * :: DeveloperApi :: + * Return information about what RDDs are cached, if they are in mem or on disk, how much space + * they take, etc. + */ + @DeveloperApi + def getRDDStorageInfo: Array[RDDInfo] = { + val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray + StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) + rddInfos.filter(_.isCached) + } + + /** + * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. + * Note that this does not necessarily mean the caching or computation was successful. 
+ */ + def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap + + /** + * :: DeveloperApi :: + * Return information about blocks stored in all of the slaves + */ + @DeveloperApi + def getExecutorStorageStatus: Array[StorageStatus] = { + env.blockManager.master.getStorageStatus + } + + /** + * :: DeveloperApi :: + * Return pools for fair scheduler + */ + @DeveloperApi + def getAllPools: Seq[Schedulable] = { + // TODO(xiajunluan): We should take nested pools into account + taskScheduler.rootPool.schedulableQueue.toSeq + } + + /** + * :: DeveloperApi :: + * Return the pool associated with the given name, if one exists + */ + @DeveloperApi + def getPoolForName(pool: String): Option[Schedulable] = { + Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) + } + + /** + * Return current scheduling mode + */ + def getSchedulingMode: SchedulingMode.SchedulingMode = { + taskScheduler.schedulingMode + } + /** * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala deleted file mode 100644 index 1982499c5e1d3..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import scala.collection.Map -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SchedulingMode, Schedulable} -import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo} - -/** - * Trait that implements Spark's status APIs. This trait is designed to be mixed into - * SparkContext; it allows the status API code to live in its own file. - */ -private[spark] trait SparkStatusAPI { this: SparkContext => - - /** - * Return a map from the slave to the max memory available for caching and the remaining - * memory available for caching. - */ - def getExecutorMemoryStatus: Map[String, (Long, Long)] = { - env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) => - (blockManagerId.host + ":" + blockManagerId.port, mem) - } - } - - /** - * :: DeveloperApi :: - * Return information about what RDDs are cached, if they are in mem or on disk, how much space - * they take, etc. - */ - @DeveloperApi - def getRDDStorageInfo: Array[RDDInfo] = { - val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray - StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus) - rddInfos.filter(_.isCached) - } - - /** - * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. 
- * Note that this does not necessarily mean the caching or computation was successful. - */ - def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - - /** - * :: DeveloperApi :: - * Return information about blocks stored in all of the slaves - */ - @DeveloperApi - def getExecutorStorageStatus: Array[StorageStatus] = { - env.blockManager.master.getStorageStatus - } - - /** - * :: DeveloperApi :: - * Return pools for fair scheduler - */ - @DeveloperApi - def getAllPools: Seq[Schedulable] = { - // TODO(xiajunluan): We should take nested pools into account - taskScheduler.rootPool.schedulableQueue.toSeq - } - - /** - * :: DeveloperApi :: - * Return the pool associated with the given name, if one exists - */ - @DeveloperApi - def getPoolForName(pool: String): Option[Schedulable] = { - Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)) - } - - /** - * Return current scheduling mode - */ - def getSchedulingMode: SchedulingMode.SchedulingMode = { - taskScheduler.schedulingMode - } - - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = { - jobProgressListener.synchronized { - val jobData = jobProgressListener.jobIdToData.valuesIterator - jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray - } - } - - /** - * Returns job information, or `None` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): Option[SparkJobInfo] = { - jobProgressListener.synchronized { - jobProgressListener.jobIdToData.get(jobId).map { data => - new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) - } - } - } - - /** - * Returns stage information, or `None` if the stage info could not be found or was - * garbage collected. - */ - def getStageInfo(stageId: Int): Option[SparkStageInfo] = { - jobProgressListener.synchronized { - for ( - info <- jobProgressListener.stageIdToInfo.get(stageId); - data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) - ) yield { - new SparkStageInfoImpl( - stageId, - info.attemptId, - info.name, - info.numTasks, - data.numActiveTasks, - data.numCompleteTasks, - data.numFailedTasks) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala new file mode 100644 index 0000000000000..c18d763d7ff4d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `None` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class SparkStatusTracker private[spark] (sc: SparkContext) { + + private val jobProgressListener = sc.jobProgressListener + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = { + jobProgressListener.synchronized { + val jobData = jobProgressListener.jobIdToData.valuesIterator + jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + } + } + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeStages.values.map(_.stageId).toArray + } + } + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeJobs.values.map(_.jobId).toArray + } + } + + /** + * Returns job information, or `None` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): Option[SparkJobInfo] = { + jobProgressListener.synchronized { + jobProgressListener.jobIdToData.get(jobId).map { data => + new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) + } + } + } + + /** + * Returns stage information, or `None` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): Option[SparkStageInfo] = { + jobProgressListener.synchronized { + for ( + info <- jobProgressListener.stageIdToInfo.get(stageId); + data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) + ) yield { + new SparkStageInfoImpl( + stageId, + info.attemptId, + info.name, + info.numTasks, + data.numActiveTasks, + data.numCompleteTasks, + data.numFailedTasks) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 5c6e8d32c5c8a..d50ed32ca085c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -105,6 +105,8 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env + def statusTracker = new JavaSparkStatusTracker(sc) + def isLocal: java.lang.Boolean = sc.isLocal def sparkUser: String = sc.sparkUser @@ -134,25 +136,6 @@ class JavaSparkContext(val sc: SparkContext) /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions - - /** - * Return a list of all known jobs in a particular job group. The returned list may contain - * running, failed, and completed jobs, and may vary across invocations of this method. This - * method does not guarantee the order of the elements in its result. - */ - def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup) - - /** - * Returns job information, or `null` if the job info could not be found or was garbage collected. - */ - def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull - - /** - * Returns stage information, or `null` if the stage info could not be found or was - * garbage collected. - */ - def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull - /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val ctag: ClassTag[T] = fakeClassTag diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala new file mode 100644 index 0000000000000..3300cad9efbab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} + +/** + * Low-level status reporting APIs for monitoring job and stage progress. 
+ * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `null` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class JavaSparkStatusTracker private[spark] (sc: SparkContext) { + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup) + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds() + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds() + + /** + * Returns job information, or `null` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull + + /** + * Returns stage information, or `null` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull +} diff --git a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala similarity index 69% rename from core/src/test/scala/org/apache/spark/StatusAPISuite.scala rename to core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala index 4468fba8c1dff..8577e4ac7e33e 100644 --- a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala +++ b/core/src/test/scala/org/apache/spark/StatusTrackerSuite.scala @@ -27,9 +27,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.JobExecutionStatus._ import org.apache.spark.SparkContext._ -class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { +class StatusTrackerSuite extends FunSuite with Matchers with LocalSparkContext { test("basic status API usage") { + sc = new SparkContext("local", "test", new SparkConf(false)) val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync() val jobId: Int = eventually(timeout(10 seconds)) { val jobIds = jobFuture.jobIds @@ -37,20 +38,20 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { jobIds.head } val jobInfo = eventually(timeout(10 seconds)) { - sc.getJobInfo(jobId).get + sc.statusTracker.getJobInfo(jobId).get } jobInfo.status() should not be FAILED val stageIds = jobInfo.stageIds() stageIds.size should be(2) val firstStageInfo = eventually(timeout(10 seconds)) { - sc.getStageInfo(stageIds(0)).get + sc.statusTracker.getStageInfo(stageIds(0)).get } firstStageInfo.stageId() should be(stageIds(0)) firstStageInfo.currentAttemptId() should be(0) firstStageInfo.numTasks() should be(2) eventually(timeout(10 seconds)) { - val updatedFirstStageInfo = sc.getStageInfo(stageIds(0)).get + val updatedFirstStageInfo = sc.statusTracker.getStageInfo(stageIds(0)).get updatedFirstStageInfo.numCompletedTasks() should be(2) updatedFirstStageInfo.numActiveTasks() should be(0) updatedFirstStageInfo.numFailedTasks() should be(0) @@ -58,21 +59,31 @@ class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext { } test("getJobIdsForGroup()") { + sc = new SparkContext("local", "test", new SparkConf(false)) + // Passing `null` should return jobs that were not run in a job group: + val defaultJobGroupFuture = sc.parallelize(1 to 1000).countAsync() + val defaultJobGroupJobId = eventually(timeout(10 seconds)) { + defaultJobGroupFuture.jobIds.head + } + eventually(timeout(10 seconds)) { + sc.statusTracker.getJobIdsForGroup(null).toSet should be (Set(defaultJobGroupJobId)) + } + // Test jobs submitted in job groups: sc.setJobGroup("my-job-group", "description") - sc.getJobIdsForGroup("my-job-group") should be (Seq.empty) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq.empty) val firstJobFuture = sc.parallelize(1 to 1000).countAsync() val firstJobId = eventually(timeout(10 seconds)) { firstJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId)) } val secondJobFuture = sc.parallelize(1 to 1000).countAsync() val secondJobId = eventually(timeout(10 seconds)) { secondJobFuture.jobIds.head } eventually(timeout(10 seconds)) { - sc.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId)) + sc.statusTracker.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, 
secondJobId)) } } } \ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java similarity index 92% rename from examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java rename to examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java index 430e96ab14d9d..e68ec74c3ed54 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -31,7 +31,7 @@ /** * Example of using Spark's status APIs from Java. */ -public final class JavaStatusAPIDemo { +public final class JavaStatusTrackerDemo { public static final String APP_NAME = "JavaStatusAPIDemo"; @@ -58,8 +58,8 @@ public static void main(String[] args) throws Exception { continue; } int currentJobId = jobIds.get(jobIds.size() - 1); - SparkJobInfo jobInfo = sc.getJobInfo(currentJobId); - SparkStageInfo stageInfo = sc.getStageInfo(jobInfo.stageIds()[0]); + SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + " active, " + stageInfo.numCompletedTasks() + " complete"); } From 7d8e152eecc7e822b7b1e40b791267a8911e01cf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 15 Nov 2014 22:22:34 -0800 Subject: [PATCH 18/22] [SPARK-4419] Upgrade snappy-java to 1.1.1.6 This upgrades snappy-java to 1.1.1.6, which includes a patch that improves error messages when attempting to deserialize empty inputs using SnappyInputStream (see xerial/snappy-java#89). We previously tried to upgrade to 1.1.1.5 in #2911 but reverted that patch after discovering a memory leak in snappy-java. This leak should have been fixed in 1.1.1.6, though (see xerial/snappy-java#92). Author: Josh Rosen Closes #3287 from JoshRosen/SPARK-4419 and squashes the following commits: 5d6f4cc [Josh Rosen] [SPARK-4419] Upgrade snappy-java to 1.1.1.6. --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 90d3bff76bbbc..639ea22a1fda3 100644 --- a/pom.xml +++ b/pom.xml @@ -413,7 +413,7 @@ org.xerial.snappy snappy-java - 1.1.1.3 + 1.1.1.6 net.jpountz.lz4 From 84468b2e2031d646dcf035cb18947170ba326ccd Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sat, 15 Nov 2014 22:23:47 -0800 Subject: [PATCH 19/22] [SPARK-4426][SQL][Minor] The symbol of BitwiseOr is wrong, should not be '&' The symbol of BitwiseOr is defined as '&' but I think it's wrong. It should be '|'. Author: Kousuke Saruta Closes #3284 from sarutak/bitwise-or-symbol-fix and squashes the following commits: aff4be5 [Kousuke Saruta] Fixed symbol of BitwiseOr --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 8574cabc43525..d17c9553ac24e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -142,7 +142,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * A function that calculates bitwise or(|) of two numbers.
*/ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { - def symbol = "&" + def symbol = "|" override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match { case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte From 7850e0c707affd5eafd570fb43716753396cf479 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Nov 2014 00:44:15 -0800 Subject: [PATCH 20/22] [SPARK-4393] Fix memory leak in ConnectionManager ACK timeout TimerTasks; use HashedWheelTimer This patch is intended to fix a subtle memory leak in ConnectionManager's ACK timeout TimerTasks: in the old code, each TimerTask held a reference to the message being sent and a cancelled TimerTask won't necessarily be garbage-collected until it's scheduled to run, so this caused huge buildups of messages that weren't garbage collected until their timeouts expired, leading to OOMs. This patch addresses this problem by capturing only the message ID in the TimerTask instead of the whole message, and by keeping a WeakReference to the promise in the TimerTask. I've also modified this code to use Netty's HashedWheelTimer, whose performance characteristics should be better for this use-case. Thanks to cristianopris for narrowing down this issue! Author: Josh Rosen Closes #3259 from JoshRosen/connection-manager-timeout-bugfix and squashes the following commits: afcc8d6 [Josh Rosen] Address rxin's review feedback. 2a2e92d [Josh Rosen] Keep only WeakReference to promise in TimerTask; 0f0913b [Josh Rosen] Spelling fix: timout => timeout 3200c33 [Josh Rosen] Use Netty HashedWheelTimer f847dd4 [Josh Rosen] Don't capture entire message in ACK timeout task. --- .../spark/network/nio/ConnectionManager.scala | 47 ++++++++++++++----- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f198aa8564a54..df4b085d2251e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -18,13 +18,13 @@ package org.apache.spark.network.nio import java.io.IOException +import java.lang.ref.WeakReference import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} -import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ @@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps import com.google.common.base.Charsets.UTF_8 +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} import org.apache.spark._ import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} @@ -77,7 +78,8 @@ private[nio] class ConnectionManager( } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -139,7 +141,10 @@ private[nio] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, 
Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -899,22 +904,41 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). + val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { + messageStatuses.remove(messageId).foreach { s => val e = new IOException("sendMessageReliably failed because ack " + s"was not received within $ackTimeout sec") - if (!promise.tryFailure(e)) { - logWarning("Ignore error because promise is completed", e) + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) } - }) + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() + timeoutTaskHandle.cancel() s match { case scala.util.Failure(e) => // Indicates a failure where we either never sent or never got ACK'd @@ -943,7 +967,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -953,7 +976,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close() From cb6bd83a91d9b4a227dc6467255231869c1820e2 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 16 Nov 2014 14:26:41 -0800 Subject: [PATCH 21/22] [SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types SPARK-4407 was detected while working on SPARK-4309. Merged these two into a single PR since 1.2.0 RC is approaching. 
[Review on Reviewable](https://reviewable.io/reviews/apache/spark/3178) Author: Cheng Lian Closes #3178 from liancheng/date-for-thriftserver and squashes the following commits: 6f71d0b [Cheng Lian] Makes toHiveString static 26fa955 [Cheng Lian] Fixes complex type support in Hive 0.13.1 shim a92882a [Cheng Lian] Updates HiveShim for 0.13.1 73f442b [Cheng Lian] Adds Date support for HiveThriftServer2 (Hive 0.12.0) --- .../thriftserver/HiveThriftServer2Suite.scala | 90 +++++++++---- .../spark/sql/hive/thriftserver/Shim12.scala | 11 +- .../spark/sql/hive/thriftserver/Shim13.scala | 29 ++-- .../apache/spark/sql/hive/HiveContext.scala | 127 ++++++++---------- 4 files changed, 142 insertions(+), 115 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index bba29b2bdca4d..23d12cbff3495 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.ServerSocket -import java.sql.{DriverManager, Statement} +import java.sql.{Date, DriverManager, Statement} import java.util.concurrent.TimeoutException +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} @@ -51,6 +52,15 @@ import org.apache.spark.sql.hive.HiveShim class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) + object TestData { + def getTestDataFilePath(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") + } + + val smallKv = getTestDataFilePath("small_kv.txt") + val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") + } + def randomListeningPort = { // Let the system to choose a random available port to avoid collision with other parallel // builds. @@ -145,12 +155,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } } - val env = Seq( - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - "SPARK_TESTING" -> "0", - // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read - // proper version information from the jar manifest. 
- "SPARK_PREPEND_CLASSES" -> "") + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + val env = Seq("SPARK_TESTING" -> "0") Process(command, None, env: _*).run(ProcessLogger( captureThriftServerOutput("stdout"), @@ -194,15 +200,12 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("Test JDBC query execution") { withJdbcStatement() { statement => - val dataFilePath = - Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - - val queries = - s"""SET spark.sql.shuffle.partitions=3; - |CREATE TABLE test(key INT, val STRING); - |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test; - |CACHE TABLE test; - """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty) + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") queries.foreach(statement.execute) @@ -216,14 +219,10 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("SPARK-3004 regression: result set containing NULL") { withJdbcStatement() { statement => - val dataFilePath = - Thread.currentThread().getContextClassLoader.getResource( - "data/files/small_kv_with_null.txt") - val queries = Seq( "DROP TABLE IF EXISTS test_null", "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") + s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") queries.foreach(statement.execute) @@ -270,13 +269,10 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("SPARK-4292 regression: result set iterator issue") { withJdbcStatement() { statement => - val dataFilePath = - Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") - val queries = Seq( "DROP TABLE IF EXISTS test_4292", "CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_4292") + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") queries.foreach(statement.execute) @@ -284,10 +280,52 @@ class HiveThriftServer2Suite extends FunSuite with Logging { Seq(238, 86, 311, 27, 165).foreach { key => resultSet.next() - assert(resultSet.getInt(1) == key) + assert(resultSet.getInt(1) === key) } statement.executeQuery("DROP TABLE IF EXISTS test_4292") } } + + test("SPARK-4309 regression: Date type support") { + withJdbcStatement() { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_date", + "CREATE TABLE test_date(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") + + queries.foreach(statement.execute) + + assertResult(Date.valueOf("2011-01-01")) { + val resultSet = statement.executeQuery( + "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") + resultSet.next() + resultSet.getDate(1) + } + } + } + + test("SPARK-4407 regression: Complex type support") { + withJdbcStatement() { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + assertResult("""{238:"val_238"}""") { + val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + + 
assertResult("""["238","val_238"]""") { + val resultSet = statement.executeQuery( + "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + } + } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index aa2e3cab72bb9..9258ad0cdf1d0 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList, Map => JMap} import scala.collection.JavaConversions._ @@ -131,14 +131,13 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) case ShortType => to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) + case DateType => + to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) case TimestampType => to.addColumnValue( ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = result - .queryExecution - .asInstanceOf[HiveContext#QueryExecution] - .toHiveString((from.get(ordinal), dataTypes(ordinal))) + val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) to.addColumnValue(ColumnValue.stringValue(hiveString)) } } @@ -163,6 +162,8 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.byteValue(null)) case ShortType => to.addColumnValue(ColumnValue.shortValue(null)) + case DateType => + to.addColumnValue(ColumnValue.dateValue(null)) case TimestampType => to.addColumnValue(ColumnValue.timestampValue(null)) case BinaryType | _: ArrayType | _: StructType | _: MapType => diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index a642478d08857..3c7f62af450d9 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import java.util.concurrent.Future import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} @@ -113,7 +113,7 @@ private[hive] class SparkExecuteStatementOperation( def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => - to += from.get(ordinal).asInstanceOf[String] + to += from.getString(ordinal) case IntegerType => to += from.getInt(ordinal) case BooleanType => @@ -123,23 +123,20 @@ private[hive] class SparkExecuteStatementOperation( case FloatType => to += from.getFloat(ordinal) case DecimalType() => - to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal + to += from.getAs[BigDecimal](ordinal).bigDecimal case LongType => to += from.getLong(ordinal) case ByteType => to += from.getByte(ordinal) case ShortType => to += from.getShort(ordinal) + case DateType => 
+ to += from.getAs[Date](ordinal) case TimestampType => - to += from.get(ordinal).asInstanceOf[Timestamp] - case BinaryType => - to += from.get(ordinal).asInstanceOf[String] - case _: ArrayType => - to += from.get(ordinal).asInstanceOf[String] - case _: StructType => - to += from.get(ordinal).asInstanceOf[String] - case _: MapType => - to += from.get(ordinal).asInstanceOf[String] + to += from.getAs[Timestamp](ordinal) + case BinaryType | _: ArrayType | _: StructType | _: MapType => + val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) + to += hiveString } } @@ -147,9 +144,9 @@ private[hive] class SparkExecuteStatementOperation( validateDefaultFetchOrientation(order) assertState(OperationState.FINISHED) setHasResultSet(true) - val reultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) if (!iter.hasNext) { - reultRowSet + resultRowSet } else { // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int val maxRows = maxRowsL.toInt @@ -166,10 +163,10 @@ private[hive] class SparkExecuteStatementOperation( } curCol += 1 } - reultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) curRow += 1 } - reultRowSet + resultRowSet } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index e88afaaf001c0..feed64fe4cd6f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -19,36 +19,27 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} -import java.util.{ArrayList => JArrayList} - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.spark.sql.catalyst.types.DecimalType -import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.TypeTag -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde2.io.TimestampWritable -import org.apache.hadoop.hive.serde2.io.DateWritable +import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} -import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.ExtractPythonUdfs -import org.apache.spark.sql.execution.QueryExecutionException -import org.apache.spark.sql.execution.{Command => PhysicalCommand} +import 
org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.execution.{ExtractPythonUdfs, QueryExecutionException, Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand import org.apache.spark.sql.sources.DataSourceStrategy @@ -136,7 +127,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { - case relation: MetastoreRelation => { + case relation: MetastoreRelation => // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -147,7 +138,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // countFileSize to count the table size. def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDir) { + val size = if (fileStatus.isDirectory) { fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum } else { fileStatus.getLen @@ -157,7 +148,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } def getFileSizeForTable(conf: HiveConf, table: Table): Long = { - val path = table.getPath() + val path = table.getPath var size: Long = 0L try { val fs = path.getFileSystem(conf) @@ -187,15 +178,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveTTable = relation.hiveQlTable.getTTable hiveTTable.setParameters(tableParameters) val tableFullName = - relation.hiveQlTable.getDbName() + "." + relation.hiveQlTable.getTableName() + relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName catalog.client.alterTable(tableFullName, new Table(hiveTTable)) } - } case otherRelation => throw new NotImplementedError( s"Analyze has only implemented for Hive tables, " + - s"but ${tableName} is a ${otherRelation.nodeName}") + s"but $tableName is a ${otherRelation.nodeName}") } } @@ -374,50 +364,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. 
*/ protected[sql] abstract class QueryExecution extends super.QueryExecution { - protected val primitiveTypes = - Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DateType, TimestampType, BinaryType) - - protected[sql] def toHiveString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString - case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString - case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") - case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - - /** Hive outputs fields of structs slightly differently than top level attributes. */ - protected def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - /** * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. @@ -435,8 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. 
- val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq - asString + result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq } override def simpleString: String = @@ -447,3 +392,49 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } } + +object HiveContext { + protected val primitiveTypes = + Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, + ShortType, DateType, TimestampType, BinaryType) + + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (d: Date, DateType) => new DateWritable(d).toString + case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString + case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** Hive outputs fields of structs slightly differently than top level attributes. */ + protected def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } +} From 45ce3273cb618d14ec4d20c4c95699634b951086 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 16 Nov 2014 15:05:04 -0800 Subject: [PATCH 22/22] Revert "[SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types" Author: Michael Armbrust Closes #3292 from marmbrus/revert4309 and squashes the following commits: 808e96e [Michael Armbrust] Revert "[SPARK-4309][SPARK-4407][SQL] Date type support for Thrift server, and fixes for complex types" --- .../thriftserver/HiveThriftServer2Suite.scala | 90 ++++--------- .../spark/sql/hive/thriftserver/Shim12.scala | 11 +- .../spark/sql/hive/thriftserver/Shim13.scala | 29 ++-- .../apache/spark/sql/hive/HiveContext.scala | 127 ++++++++++-------- 4 files changed, 115 insertions(+), 142 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index 23d12cbff3495..bba29b2bdca4d 100644 --- 
a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -19,10 +19,9 @@ package org.apache.spark.sql.hive.thriftserver import java.io.File import java.net.ServerSocket -import java.sql.{Date, DriverManager, Statement} +import java.sql.{DriverManager, Statement} import java.util.concurrent.TimeoutException -import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} @@ -52,15 +51,6 @@ import org.apache.spark.sql.hive.HiveShim class HiveThriftServer2Suite extends FunSuite with Logging { Class.forName(classOf[HiveDriver].getCanonicalName) - object TestData { - def getTestDataFilePath(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") - } - - val smallKv = getTestDataFilePath("small_kv.txt") - val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") - } - def randomListeningPort = { // Let the system to choose a random available port to avoid collision with other parallel // builds. @@ -155,8 +145,12 @@ class HiveThriftServer2Suite extends FunSuite with Logging { } } - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - val env = Seq("SPARK_TESTING" -> "0") + val env = Seq( + // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths + "SPARK_TESTING" -> "0", + // Prevents loading classes out of the assembly jar. Otherwise Utils.sparkVersion can't read + // proper version information from the jar manifest. + "SPARK_PREPEND_CLASSES" -> "") Process(command, None, env: _*).run(ProcessLogger( captureThriftServerOutput("stdout"), @@ -200,12 +194,15 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("Test JDBC query execution") { withJdbcStatement() { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + + val queries = + s"""SET spark.sql.shuffle.partitions=3; + |CREATE TABLE test(key INT, val STRING); + |LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test; + |CACHE TABLE test; + """.stripMargin.split(";").map(_.trim).filter(_.nonEmpty) queries.foreach(statement.execute) @@ -219,10 +216,14 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("SPARK-3004 regression: result set containing NULL") { withJdbcStatement() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource( + "data/files/small_kv_with_null.txt") + val queries = Seq( "DROP TABLE IF EXISTS test_null", "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null") queries.foreach(statement.execute) @@ -269,10 +270,13 @@ class HiveThriftServer2Suite extends FunSuite with Logging { test("SPARK-4292 regression: result set iterator issue") { withJdbcStatement() { statement => + val dataFilePath = + Thread.currentThread().getContextClassLoader.getResource("data/files/small_kv.txt") + val queries = Seq( "DROP TABLE IF EXISTS test_4292", 
"CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") + s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_4292") queries.foreach(statement.execute) @@ -280,52 +284,10 @@ class HiveThriftServer2Suite extends FunSuite with Logging { Seq(238, 86, 311, 27, 165).foreach { key => resultSet.next() - assert(resultSet.getInt(1) === key) + assert(resultSet.getInt(1) == key) } statement.executeQuery("DROP TABLE IF EXISTS test_4292") } } - - test("SPARK-4309 regression: Date type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_date", - "CREATE TABLE test_date(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") - - queries.foreach(statement.execute) - - assertResult(Date.valueOf("2011-01-01")) { - val resultSet = statement.executeQuery( - "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") - resultSet.next() - resultSet.getDate(1) - } - } - } - - test("SPARK-4407 regression: Complex type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - assertResult("""{238:"val_238"}""") { - val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - - assertResult("""["238","val_238"]""") { - val resultSet = statement.executeQuery( - "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - } - } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 9258ad0cdf1d0..aa2e3cab72bb9 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import java.util.{ArrayList => JArrayList, Map => JMap} import scala.collection.JavaConversions._ @@ -131,13 +131,14 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal))) case ShortType => to.addColumnValue(ColumnValue.shortValue(from.getShort(ordinal))) - case DateType => - to.addColumnValue(ColumnValue.dateValue(from(ordinal).asInstanceOf[Date])) case TimestampType => to.addColumnValue( ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp])) case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) + val hiveString = result + .queryExecution + .asInstanceOf[HiveContext#QueryExecution] + .toHiveString((from.get(ordinal), dataTypes(ordinal))) to.addColumnValue(ColumnValue.stringValue(hiveString)) } } @@ -162,8 +163,6 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.byteValue(null)) case ShortType => to.addColumnValue(ColumnValue.shortValue(null)) - case DateType => - to.addColumnValue(ColumnValue.dateValue(null)) case TimestampType => to.addColumnValue(ColumnValue.timestampValue(null)) 
case BinaryType | _: ArrayType | _: StructType | _: MapType => diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index 3c7f62af450d9..a642478d08857 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction -import java.sql.{Date, Timestamp} +import java.sql.Timestamp import java.util.concurrent.Future import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} @@ -113,7 +113,7 @@ private[hive] class SparkExecuteStatementOperation( def addNonNullColumnValue(from: SparkRow, to: ArrayBuffer[Any], ordinal: Int) { dataTypes(ordinal) match { case StringType => - to += from.getString(ordinal) + to += from.get(ordinal).asInstanceOf[String] case IntegerType => to += from.getInt(ordinal) case BooleanType => @@ -123,20 +123,23 @@ private[hive] class SparkExecuteStatementOperation( case FloatType => to += from.getFloat(ordinal) case DecimalType() => - to += from.getAs[BigDecimal](ordinal).bigDecimal + to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal case LongType => to += from.getLong(ordinal) case ByteType => to += from.getByte(ordinal) case ShortType => to += from.getShort(ordinal) - case DateType => - to += from.getAs[Date](ordinal) case TimestampType => - to += from.getAs[Timestamp](ordinal) - case BinaryType | _: ArrayType | _: StructType | _: MapType => - val hiveString = HiveContext.toHiveString((from.get(ordinal), dataTypes(ordinal))) - to += hiveString + to += from.get(ordinal).asInstanceOf[Timestamp] + case BinaryType => + to += from.get(ordinal).asInstanceOf[String] + case _: ArrayType => + to += from.get(ordinal).asInstanceOf[String] + case _: StructType => + to += from.get(ordinal).asInstanceOf[String] + case _: MapType => + to += from.get(ordinal).asInstanceOf[String] } } @@ -144,9 +147,9 @@ private[hive] class SparkExecuteStatementOperation( validateDefaultFetchOrientation(order) assertState(OperationState.FINISHED) setHasResultSet(true) - val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + val reultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) if (!iter.hasNext) { - resultRowSet + reultRowSet } else { // maxRowsL here typically maps to java.sql.Statement.getFetchSize, which is an int val maxRows = maxRowsL.toInt @@ -163,10 +166,10 @@ private[hive] class SparkExecuteStatementOperation( } curCol += 1 } - resultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) + reultRowSet.addRow(row.toArray.asInstanceOf[Array[Object]]) curRow += 1 } - resultRowSet + reultRowSet } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index feed64fe4cd6f..e88afaaf001c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -19,27 +19,36 @@ package org.apache.spark.sql.hive import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} +import java.util.{ArrayList => JArrayList} + +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import 
org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.Table import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} +import org.apache.hadoop.hive.serde2.io.TimestampWritable +import org.apache.hadoop.hive.serde2.io.DateWritable import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators} +import org.apache.spark.sql.catalyst.analysis.{OverrideCatalog, OverrideFunctionRegistry} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types.DecimalType -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.execution.{ExtractPythonUdfs, QueryExecutionException, Command => PhysicalCommand} +import org.apache.spark.sql.execution.ExtractPythonUdfs +import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand import org.apache.spark.sql.sources.DataSourceStrategy @@ -127,7 +136,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val relation = EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) relation match { - case relation: MetastoreRelation => + case relation: MetastoreRelation => { // This method is mainly based on // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) // in Hive 0.13 (except that we do not use fs.getContentSummary). @@ -138,7 +147,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // countFileSize to count the table size. def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { + val size = if (fileStatus.isDir) { fs.listStatus(path).map(status => calculateTableSize(fs, status.getPath)).sum } else { fileStatus.getLen @@ -148,7 +157,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } def getFileSizeForTable(conf: HiveConf, table: Table): Long = { - val path = table.getPath + val path = table.getPath() var size: Long = 0L try { val fs = path.getFileSystem(conf) @@ -178,14 +187,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveTTable = relation.hiveQlTable.getTTable hiveTTable.setParameters(tableParameters) val tableFullName = - relation.hiveQlTable.getDbName + "." + relation.hiveQlTable.getTableName + relation.hiveQlTable.getDbName() + "." 
+ relation.hiveQlTable.getTableName() catalog.client.alterTable(tableFullName, new Table(hiveTTable)) } + } case otherRelation => throw new NotImplementedError( s"Analyze has only implemented for Hive tables, " + - s"but $tableName is a ${otherRelation.nodeName}") + s"but ${tableName} is a ${otherRelation.nodeName}") } } @@ -364,6 +374,50 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected val primitiveTypes = + Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, + ShortType, DateType, TimestampType, BinaryType) + + protected[sql] def toHiveString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_,_], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "NULL" + case (d: Date, DateType) => new DateWritable(d).toString + case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString + case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + + /** Hive outputs fields of structs slightly differently than top level attributes. */ + protected def toHiveStructString(a: (Any, DataType)): String = a match { + case (struct: Row, StructType(fields)) => + struct.zip(fields).map { + case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" + }.mkString("{", ",", "}") + case (seq: Seq[_], ArrayType(typ, _)) => + seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") + case (map: Map[_, _], MapType(kType, vType, _)) => + map.map { + case (key, value) => + toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) + }.toSeq.sorted.mkString("{", ",", "}") + case (null, _) => "null" + case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString + case (other, tpe) if primitiveTypes contains tpe => other.toString + } + /** * Returns the result as a hive compatible sequence of strings. For native commands, the * execution is simply passed back to Hive. @@ -381,7 +435,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. 
- result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq + val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq + asString } override def simpleString: String = @@ -392,49 +447,3 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } } - -object HiveContext { - protected val primitiveTypes = - Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DateType, TimestampType, BinaryType) - - protected[sql] def toHiveString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "NULL" - case (d: Date, DateType) => new DateWritable(d).toString - case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString - case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") - case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString - HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } - - /** Hive outputs fields of structs slightly differently than top level attributes. */ - protected def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => - struct.zip(fields).map { - case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" - }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => - seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_, _], MapType(kType, vType, _)) => - map.map { - case (key, value) => - toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) - }.toSeq.sorted.mkString("{", ",", "}") - case (null, _) => "null" - case (s: String, StringType) => "\"" + s + "\"" - case (decimal, DecimalType()) => decimal.toString - case (other, tpe) if primitiveTypes contains tpe => other.toString - } -}
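
A note on usage, outside the patches themselves: the status tracker API introduced earlier in this series (SparkStatusTracker on SparkContext, JavaSparkStatusTracker on JavaSparkContext) can be polled from Scala in much the same way that JavaStatusTrackerDemo polls it from Java. The sketch below is illustrative only and is not part of any patch above; the object name, master setting, polling interval, and the sample job are invented for the example, and it assumes a Spark build that already contains these patches.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // brings in the async actions such as collectAsync()

object StatusTrackerSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical settings, chosen only for this sketch.
    val conf = new SparkConf().setMaster("local[2]").setAppName("StatusTrackerSketch")
    val sc = new SparkContext(conf)

    // Kick off a job asynchronously so its progress can be polled while it runs.
    val jobFuture = sc.parallelize(1 to 100000, 4).map(identity).groupBy(identity).collectAsync()

    while (!jobFuture.isCompleted) {
      Thread.sleep(1000)
      // Every lookup may legitimately come back empty while the listener catches up,
      // hence the Option-based for comprehension.
      for {
        jobId     <- jobFuture.jobIds.lastOption
        jobInfo   <- sc.statusTracker.getJobInfo(jobId)
        stageId   <- jobInfo.stageIds().headOption
        stageInfo <- sc.statusTracker.getStageInfo(stageId)
      } {
        println(s"Job $jobId [${jobInfo.status()}]: ${stageInfo.numTasks()} tasks total, " +
          s"${stageInfo.numActiveTasks()} active, ${stageInfo.numCompletedTasks()} complete")
      }
    }

    sc.stop()
  }
}

As the class-level comments in SparkStatusTracker note, the API is only weakly consistent, which is why every lookup in the sketch goes through Option instead of assuming the job or stage information is already available.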