diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index f7849ebebc573..83f2a312972fb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import java.util.Arrays; import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -142,14 +141,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - // This new array will be initially zero, so there's no need to zero it out here groupingKeyConversionScratchSpace = new byte[groupingKeySize]; - } else { - // Zero out the buffer that's used to hold the current row. This is necessary in order - // to ensure that rows hash properly, since garbage data from the previous row could - // otherwise end up as padding in this row. As a performance optimization, we only zero out - // the portion of the buffer that we'll actually write to. - Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, (byte) 0); } final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( groupingKey, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index ed04d2e50ec84..bb2f2079b40f0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -47,7 +47,8 @@ * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the - * base address of the row) that points to the beginning of the variable-length field. + * base address of the row) that points to the beginning of the variable-length field, and length + * (they are combined into a long). * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -92,6 +93,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set readableFieldTypes; + // TODO: support DecimalType static { settableFieldTypes = Collections.unmodifiableSet( new HashSet( @@ -111,7 +113,8 @@ public static int calculateBitSetWidthInBytes(int numFields) { // We support get() on a superset of the types for which we support set(): final Set _readableFieldTypes = new HashSet( Arrays.asList(new DataType[]{ - StringType + StringType, + BinaryType })); _readableFieldTypes.addAll(settableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); @@ -221,11 +224,6 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } - @Override public int size() { return numFields; @@ -249,6 +247,8 @@ public Object get(int i) { return null; } else if (dataType == StringType) { return getUTF8String(i); + } else if (dataType == BinaryType) { + return getBinary(i); } else { throw new UnsupportedOperationException(); } @@ -311,19 +311,23 @@ public double getDouble(int i) { } public UTF8String getUTF8String(int i) { + return UTF8String.fromBytes(getBinary(i)); + } + + public byte[] getBinary(int i) { assertIndexIsValid(i); - final long offsetToStringSize = getLong(i); - final int stringSizeInBytes = - (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); - final byte[] strBytes = new byte[stringSizeInBytes]; + final long offsetAndSize = getLong(i); + final int offset = (int)(offsetAndSize >> 32); + final int size = (int)(offsetAndSize & ((1L << 32) - 1)); + final byte[] bytes = new byte[size]; PlatformDependent.copyMemory( baseObject, - baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data - strBytes, + baseOffset + offset, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - stringSizeInBytes + size ); - return UTF8String.fromBytes(strBytes); + return bytes; } @Override diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 72f740ecaead3..89adaf053b1a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -72,6 +70,19 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { */ def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = { unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + + if (writers.length > 0) { + // zero-out the bitset + var n = writers.length / 64 + while (n >= 0) { + PlatformDependent.UNSAFE.putLong( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset + n * 8, + 0L) + n -= 1 + } + } + var fieldNumber = 0 var appendCursor: Int = fixedLengthSize while (fieldNumber < writers.length) { @@ -122,6 +133,7 @@ private object UnsafeColumnWriter { case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter + case BinaryType => BinaryUnsafeColumnWriter case DateType => IntUnsafeColumnWriter case TimestampType => LongUnsafeColumnWriter case t => @@ -141,6 +153,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter +private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: @@ -235,10 +248,13 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr } } -private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { +private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { + + def getBytes(source: InternalRow, column: Int): Array[Byte] + def getSize(source: InternalRow, column: Int): Int = { - val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + val numBytes = getBytes(source, column).length + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } override def write( @@ -246,19 +262,33 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { target: UnsafeRow, column: Int, appendCursor: Int): Int = { - val value = source.get(column).asInstanceOf[UTF8String] - val baseObject = target.getBaseObject - val baseOffset = target.getBaseOffset - val numBytes = value.getBytes.length - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + val offset = target.getBaseOffset + appendCursor + val bytes = getBytes(source, column) + val numBytes = bytes.length + if ((numBytes & 0x07) > 0) { + // zero-out the padding bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) + } PlatformDependent.copyMemory( - value.getBytes, + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + appendCursor + 8, + target.getBaseObject, + offset, numBytes ) - target.setLong(column, appendCursor) - 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong) + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} + +private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[UTF8String](column).getBytes + } +} + +private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + def getBytes(source: InternalRow, column: Int): Array[Byte] = { + source.getAs[Array[Byte]](column) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 721ef8a22608c..d8f3351d6dff6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,8 +23,8 @@ import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.util.DateUtils +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.getInt(2) should be (2) } - test("basic conversion with primitive and string types") { - val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + test("basic conversion with primitive, string and binary types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) val converter = new UnsafeRowConverter(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setString(1, "Hello") - row.setString(2, "World") + row.update(2, "World".getBytes) val sizeRequired: Int = converter.getSizeRequirement(row) sizeRequired should be (8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) numBytesWritten should be (sizeRequired) @@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) unsafeRow.getLong(0) should be (0) unsafeRow.getString(1) should be ("Hello") - unsafeRow.getString(2) should be ("World") + unsafeRow.getBinary(2) should be ("World".getBytes) } test("basic conversion with primitive, string, date and timestamp types") { @@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(row) sizeRequired should be (8 + (8 * 4) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8)) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) numBytesWritten should be (sizeRequired)