From 5efd83f6712f96f43f72f90e3f4a4e5522d001ea Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 20 Apr 2015 18:42:50 -0700 Subject: [PATCH] [SPARK-6368][SQL] Build a specialized serializer for Exchange operator. JIRA: https://issues.apache.org/jira/browse/SPARK-6368 Author: Yin Huai Closes #5497 from yhuai/serializer2 and squashes the following commits: da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 50e0c3d [Yin Huai] When no filed is emitted to shuffle, use SparkSqlSerializer for now. 9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 6d07678 [Yin Huai] Address comments. 4273b8c [Yin Huai] Enabled SparkSqlSerializer2. 09e587a [Yin Huai] Remove TODO. 791b96a [Yin Huai] Use UTF8String. 60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2 3e09655 [Yin Huai] Use getAs for Date column. 43b9fb4 [Yin Huai] Test. 8297732 [Yin Huai] Fix test. c9373c8 [Yin Huai] Support DecimalType. 2379eeb [Yin Huai] ASF header. 39704ab [Yin Huai] Specialized serializer for Exchange. --- .../scala/org/apache/spark/sql/SQLConf.scala | 4 + .../apache/spark/sql/execution/Exchange.scala | 59 ++- .../sql/execution/SparkSqlSerializer2.scala | 421 ++++++++++++++++++ .../execution/SparkSqlSerializer2Suite.scala | 195 ++++++++ 4 files changed, 673 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 5c65f04ee8497..4fc5de7e824fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -64,6 +64,8 @@ private[spark] object SQLConf { // Set to false when debugging requires the ability to look at invalid query plans. val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" + val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2" + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean + /** * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to * a broadcast value during the physical executions of join operations. Setting this to -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 69a620e1ec929..5b2e46962cd3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf} +import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner} import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.serializer.Serializer import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType import org.apache.spark.util.MutablePair object Exchange { @@ -77,9 +79,48 @@ case class Exchange( } } - override def execute(): RDD[Row] = attachTree(this , "execute") { - lazy val sparkConf = child.sqlContext.sparkContext.getConf + @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf + + def serializer( + keySchema: Array[DataType], + valueSchema: Array[DataType], + numPartitions: Int): Serializer = { + // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out + // through write(key) and then write(value) instead of write((key, value)). Because + // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use + // it when spillToMergeableFile in ExternalSorter will be used. + // So, we will not use SparkSqlSerializer2 when + // - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater + // then the bypassMergeThreshold; or + // - newOrdering is defined. + val cannotUseSqlSerializer2 = + (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty + + // It is true when there is no field that needs to be write out. + // For now, we will not use SparkSqlSerializer2 when noField is true. + val noField = + (keySchema == null || keySchema.length == 0) && + (valueSchema == null || valueSchema.length == 0) + + val useSqlSerializer2 = + child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled. + !cannotUseSqlSerializer2 && // Safe to use Serializer2. + SparkSqlSerializer2.support(keySchema) && // The schema of key is supported. + SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported. + !noField + + val serializer = if (useSqlSerializer2) { + logInfo("Using SparkSqlSerializer2.") + new SparkSqlSerializer2(keySchema, valueSchema) + } else { + logInfo("Using SparkSqlSerializer.") + new SparkSqlSerializer(sparkConf) + } + + serializer + } + override def execute(): RDD[Row] = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. @@ -111,7 +152,10 @@ case class Exchange( } else { new ShuffledRDD[Row, Row, Row](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = expressions.map(_.dataType).toArray + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions)) + shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => @@ -134,7 +178,9 @@ case class Exchange( } else { new ShuffledRDD[Row, Null, Null](rdd, part) } - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val keySchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(keySchema, null, numPartitions)) + shuffled.map(_._1) case SinglePartition => @@ -152,7 +198,8 @@ case class Exchange( } val partitioner = new HashPartitioner(1) val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner) - shuffled.setSerializer(new SparkSqlSerializer(sparkConf)) + val valueSchema = child.output.map(_.dataType).toArray + shuffled.setSerializer(serializer(null, valueSchema, 1)) shuffled.map(_._2) case _ => sys.error(s"Exchange not implemented for $newPartitioning") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala new file mode 100644 index 0000000000000..cec97de2cd8e4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ +import java.math.{BigDecimal, BigInteger} +import java.nio.ByteBuffer +import java.sql.Timestamp + +import scala.reflect.ClassTag + +import org.apache.spark.serializer._ +import org.apache.spark.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ + +/** + * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the object passed in + * its `writeObject` are [[Product2]]. The serialization functions for the key and value of the + * [[Product2]] are constructed based on their schemata. + * The benefit of this serialization stream is that compared with general-purpose serializers like + * Kryo and Java serializer, it can significantly reduce the size of serialized and has a lower + * allocation cost, which can benefit the shuffle operation. Right now, its main limitations are: + * 1. It does not support complex types, i.e. Map, Array, and Struct. + * 2. It assumes that the objects passed in are [[Product2]]. So, it cannot be used when + * [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used because + * the objects passed in the serializer are not in the type of [[Product2]]. Also also see + * the comment of the `serializer` method in [[Exchange]] for more information on it. + */ +private[sql] class Serializer2SerializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + out: OutputStream) + extends SerializationStream with Logging { + + val rowOut = new DataOutputStream(out) + val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut) + val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut) + + def writeObject[T: ClassTag](t: T): SerializationStream = { + val kv = t.asInstanceOf[Product2[Row, Row]] + writeKey(kv._1) + writeValue(kv._2) + + this + } + + def flush(): Unit = { + rowOut.flush() + } + + def close(): Unit = { + rowOut.close() + } +} + +/** + * The corresponding deserialization stream for [[Serializer2SerializationStream]]. + */ +private[sql] class Serializer2DeserializationStream( + keySchema: Array[DataType], + valueSchema: Array[DataType], + in: InputStream) + extends DeserializationStream with Logging { + + val rowIn = new DataInputStream(new BufferedInputStream(in)) + + val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null + val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null + val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key) + val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value) + + def readObject[T: ClassTag](): T = { + readKey() + readValue() + + (key, value).asInstanceOf[T] + } + + def close(): Unit = { + rowIn.close() + } +} + +private[sql] class ShuffleSerializerInstance( + keySchema: Array[DataType], + valueSchema: Array[DataType]) + extends SerializerInstance { + + def serialize[T: ClassTag](t: T): ByteBuffer = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer): T = + throw new UnsupportedOperationException("Not supported.") + + def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = + throw new UnsupportedOperationException("Not supported.") + + def serializeStream(s: OutputStream): SerializationStream = { + new Serializer2SerializationStream(keySchema, valueSchema, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new Serializer2DeserializationStream(keySchema, valueSchema, s) + } +} + +/** + * SparkSqlSerializer2 is a special serializer that creates serialization function and + * deserialization function based on the schema of data. It assumes that values passed in + * are key/value pairs and values returned from it are also key/value pairs. + * The schema of keys is represented by `keySchema` and that of values is represented by + * `valueSchema`. + */ +private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType]) + extends Serializer + with Logging + with Serializable{ + + def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema) +} + +private[sql] object SparkSqlSerializer2 { + + final val NULL = 0 + final val NOT_NULL = 1 + + /** + * Check if rows with the given schema can be serialized with ShuffleSerializer. + */ + def support(schema: Array[DataType]): Boolean = { + if (schema == null) return true + + var i = 0 + while (i < schema.length) { + schema(i) match { + case udt: UserDefinedType[_] => return false + case array: ArrayType => return false + case map: MapType => return false + case struct: StructType => return false + case _ => + } + i += 1 + } + + return true + } + + /** + * The util function to create the serialization function based on the given schema. + */ + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { + (row: Row) => + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we write values to the underlying stream, we also first write the null byte + // first. Then, if the value is not null, we write the contents out. + + case NullType => // Write nothing. + + case BooleanType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeBoolean(row.getBoolean(i)) + } + + case ByteType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeByte(row.getByte(i)) + } + + case ShortType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeShort(row.getShort(i)) + } + + case IntegerType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getInt(i)) + } + + case LongType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeLong(row.getLong(i)) + } + + case FloatType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeFloat(row.getFloat(i)) + } + + case DoubleType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeDouble(row.getDouble(i)) + } + + case decimal: DecimalType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + out.writeInt(bytes.length) + out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) + } + + case DateType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + out.writeInt(row.getAs[Int](i)) + } + + case TimestampType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val timestamp = row.getAs[java.sql.Timestamp](i) + val time = timestamp.getTime + val nanos = timestamp.getNanos + out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value. + out.writeInt(nanos) // Write the nanoseconds part. + } + + case StringType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[UTF8String](i).getBytes + out.writeInt(bytes.length) + out.write(bytes) + } + + case BinaryType => + if (row.isNullAt(i)) { + out.writeByte(NULL) + } else { + out.writeByte(NOT_NULL) + val bytes = row.getAs[Array[Byte]](i) + out.writeInt(bytes.length) + out.write(bytes) + } + } + i += 1 + } + } + } + + /** + * The util function to create the deserialization function based on the given schema. + */ + def createDeserializationFunction( + schema: Array[DataType], + in: DataInputStream, + mutableRow: SpecificMutableRow): () => Unit = { + () => { + // If the schema is null, the returned function does nothing when it get called. + if (schema != null) { + var i = 0 + while (i < schema.length) { + schema(i) match { + // When we read values from the underlying stream, we also first read the null byte + // first. Then, if the value is not null, we update the field of the mutable row. + + case NullType => mutableRow.setNullAt(i) // Read nothing. + + case BooleanType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setBoolean(i, in.readBoolean()) + } + + case ByteType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setByte(i, in.readByte()) + } + + case ShortType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setShort(i, in.readShort()) + } + + case IntegerType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setInt(i, in.readInt()) + } + + case LongType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setLong(i, in.readLong()) + } + + case FloatType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setFloat(i, in.readFloat()) + } + + case DoubleType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.setDouble(i, in.readDouble()) + } + + case decimal: DecimalType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + // First, read in the unscaled value. + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) + } + + case DateType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + mutableRow.update(i, in.readInt()) + } + + case TimestampType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val time = in.readLong() // Read the milliseconds value. + val nanos = in.readInt() // Read the nanoseconds part. + val timestamp = new Timestamp(time) + timestamp.setNanos(nanos) + mutableRow.update(i, timestamp) + } + + case StringType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, UTF8String(bytes)) + } + + case BinaryType => + if (in.readByte() == NULL) { + mutableRow.setNullAt(i) + } else { + val length = in.readInt() + val bytes = new Array[Byte](length) + in.readFully(bytes) + mutableRow.update(i, bytes) + } + } + i += 1 + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala new file mode 100644 index 0000000000000..27f063d73a9a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.sql.{Timestamp, Date} + +import org.scalatest.{FunSuite, BeforeAndAfterAll} + +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.ShuffleDependency +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.test.TestSQLContext._ +import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest} + +class SparkSqlSerializer2DataTypeSuite extends FunSuite { + // Make sure that we will not use serializer2 for unsupported data types. + def checkSupported(dataType: DataType, isSupported: Boolean): Unit = { + val testName = + s"${if (dataType == null) null else dataType.toString} is " + + s"${if (isSupported) "supported" else "unsupported"}" + + test(testName) { + assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported) + } + } + + checkSupported(null, isSupported = true) + checkSupported(NullType, isSupported = true) + checkSupported(BooleanType, isSupported = true) + checkSupported(ByteType, isSupported = true) + checkSupported(ShortType, isSupported = true) + checkSupported(IntegerType, isSupported = true) + checkSupported(LongType, isSupported = true) + checkSupported(FloatType, isSupported = true) + checkSupported(DoubleType, isSupported = true) + checkSupported(DateType, isSupported = true) + checkSupported(TimestampType, isSupported = true) + checkSupported(StringType, isSupported = true) + checkSupported(BinaryType, isSupported = true) + checkSupported(DecimalType(10, 5), isSupported = true) + checkSupported(DecimalType.Unlimited, isSupported = true) + + // For now, ArrayType, MapType, and StructType are not supported. + checkSupported(ArrayType(DoubleType, true), isSupported = false) + checkSupported(ArrayType(StringType, false), isSupported = false) + checkSupported(MapType(IntegerType, StringType, true), isSupported = false) + checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false) + checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false) + // UDTs are not supported right now. + checkSupported(new MyDenseVectorUDT, isSupported = false) +} + +abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll { + var allColumns: String = _ + val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]] + var numShufflePartitions: Int = _ + var useSerializer2: Boolean = _ + + override def beforeAll(): Unit = { + numShufflePartitions = conf.numShufflePartitions + useSerializer2 = conf.useSqlSerializer2 + + sql("set spark.sql.useSerializer2=true") + + val supportedTypes = + Seq(StringType, BinaryType, NullType, BooleanType, + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + DateType, TimestampType) + + val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + allColumns = fields.map(_.name).mkString(",") + val schema = StructType(fields) + + // Create a RDD with all data types supported by SparkSqlSerializer2. + val rdd = + sparkContext.parallelize((1 to 1000), 10).map { i => + Row( + s"str${i}: test serializer2.", + s"binary${i}: test serializer2.".getBytes("UTF-8"), + null, + i % 2 == 0, + i.toByte, + i.toShort, + i, + Long.MaxValue - i.toLong, + (i + 0.25).toFloat, + (i + 0.75), + BigDecimal(Long.MaxValue.toString + ".12345"), + new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"), + new Date(i), + new Timestamp(i)) + } + + createDataFrame(rdd, schema).registerTempTable("shuffle") + + super.beforeAll() + } + + override def afterAll(): Unit = { + dropTempTable("shuffle") + sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions") + sql(s"set spark.sql.useSerializer2=$useSerializer2") + super.afterAll() + } + + def checkSerializer[T <: Serializer]( + executedPlan: SparkPlan, + expectedSerializerClass: Class[T]): Unit = { + executedPlan.foreach { + case exchange: Exchange => + val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]] + val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]] + val serializerNotSetMessage = + s"Expected $expectedSerializerClass as the serializer of Exchange. " + + s"However, the serializer was not set." + val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage)) + assert(serializer.getClass === expectedSerializerClass) + case _ => // Ignore other nodes. + } + } + + test("key schema and value schema are not nulls") { + val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + table("shuffle").collect()) + } + + test("value schema is null") { + val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + assert( + df.map(r => r.getString(0)).collect().toSeq === + table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq) + } + + test("no map output field") { + val df = sql(s"SELECT 1 + 1 FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer]) + } +} + +/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */ +class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite { + override def beforeAll(): Unit = { + super.beforeAll() + // Sort merge will not be triggered. + sql("set spark.sql.shuffle.partitions = 200") + } + + test("key schema is null") { + val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",") + val df = sql(s"SELECT $aggregations FROM shuffle") + checkSerializer(df.queryExecution.executedPlan, serializerClass) + checkAnswer( + df, + Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000)) + } +} + +/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */ +class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite { + + // We are expecting SparkSqlSerializer. + override val serializerClass: Class[Serializer] = + classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]] + + override def beforeAll(): Unit = { + super.beforeAll() + // To trigger the sort merge. + sql("set spark.sql.shuffle.partitions = 201") + } +}