Skip to content

Commit

Permalink
support binaryType in UnsafeRow
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 19, 2015
1 parent 54976e5 commit 447dea0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ public static int calculateBitSetWidthInBytes(int numFields) {
*/
public static final Set<DataType> readableFieldTypes;

// TODO: support DecimalType
static {
settableFieldTypes = Collections.unmodifiableSet(
new HashSet<DataType>(
Expand All @@ -111,7 +112,8 @@ public static int calculateBitSetWidthInBytes(int numFields) {
// We support get() on a superset of the types for which we support set():
final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
Arrays.asList(new DataType[]{
StringType
StringType,
BinaryType
}));
_readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
Expand Down Expand Up @@ -221,11 +223,6 @@ public void setFloat(int ordinal, float value) {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}

@Override
public void setString(int ordinal, String value) {
throw new UnsupportedOperationException();
}

@Override
public int size() {
return numFields;
Expand All @@ -249,6 +246,8 @@ public Object get(int i) {
return null;
} else if (dataType == StringType) {
return getUTF8String(i);
} else if (dataType == BinaryType) {
return getBinary(i);
} else {
throw new UnsupportedOperationException();
}
Expand Down Expand Up @@ -311,21 +310,23 @@ public double getDouble(int i) {
}

public UTF8String getUTF8String(int i) {
return UTF8String.fromBytes(getBinary(i));
}

public byte[] getBinary(int i) {
assertIndexIsValid(i);
final UTF8String str = new UTF8String();
final long offsetToStringSize = getLong(i);
final int stringSizeInBytes =
(int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
final byte[] strBytes = new byte[stringSizeInBytes];
final long offsetAndSize = getLong(i);
final int offset = (int)(offsetAndSize >> 32);
final int size = (int)(offsetAndSize & ((1L << 32) - 1));
final byte[] bytes = new byte[size];
PlatformDependent.copyMemory(
baseObject,
baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data
strBytes,
baseOffset + offset,
bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
stringSizeInBytes
size
);
str.set(strBytes);
return str;
return bytes;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
Expand Down Expand Up @@ -122,6 +120,7 @@ private object UnsafeColumnWriter {
case FloatType => FloatUnsafeColumnWriter
case DoubleType => DoubleUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
case BinaryType => BinaryUnsafeColumnWriter
case DateType => IntUnsafeColumnWriter
case TimestampType => LongUnsafeColumnWriter
case t =>
Expand All @@ -141,6 +140,7 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter

private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
// Primitives don't write to the variable-length region:
Expand Down Expand Up @@ -238,27 +238,53 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
def getSize(source: InternalRow, column: Int): Int = {
val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}

override def write(
source: InternalRow,
target: UnsafeRow,
column: Int,
appendCursor: Int): Int = {
val value = source.get(column).asInstanceOf[UTF8String]
val value = source.get(column).asInstanceOf[UTF8String].getBytes
val baseObject = target.getBaseObject
val baseOffset = target.getBaseOffset
val numBytes = value.getBytes.length
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
val numBytes = value.length
PlatformDependent.copyMemory(
value.getBytes,
value,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + appendCursor + 8,
baseOffset + appendCursor,
numBytes
)
target.setLong(column, appendCursor)
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
}

private class BinaryUnsafeColumnWriter private() extends UnsafeColumnWriter {
def getSize(source: InternalRow, column: Int): Int = {
val numBytes = source.getAs[Array[Byte]](column).length
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}

override def write(
source: InternalRow,
target: UnsafeRow,
column: Int,
appendCursor: Int): Int = {
val value = source.getAs[Array[Byte]](column)
val baseObject = target.getBaseObject
val baseOffset = target.getBaseOffset
val numBytes = value.length
PlatformDependent.copyMemory(
value,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + appendCursor,
numBytes
)
target.setLong(column, (appendCursor.toLong << 32) | numBytes.toLong)
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import java.util.Arrays
import org.scalatest.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods

Expand Down Expand Up @@ -52,19 +52,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.getInt(2) should be (2)
}

test("basic conversion with primitive and string types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
test("basic conversion with primitive, string and binary types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType)
val converter = new UnsafeRowConverter(fieldTypes)

val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
row.setString(2, "World")
row.update(2, "World".getBytes)

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 3) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
Expand All @@ -73,7 +73,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
unsafeRow.getString(2) should be ("World")
unsafeRow.getBinary(2) should be ("World".getBytes)
}

test("basic conversion with primitive, string, date and timestamp types") {
Expand All @@ -88,7 +88,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {

val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8))
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)
Expand Down

0 comments on commit 447dea0

Please sign in to comment.