From e8a1649d1fcb142f560423c06a16e24794b471e1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 23 Jun 2015 14:35:43 -0700 Subject: [PATCH] reuse buffer for string --- .../UnsafeFixedWidthAggregationMap.java | 33 +++++++------ .../sql/catalyst/expressions/UnsafeRow.java | 46 +++++++++++++++---- .../expressions/UnsafeRowConverterSuite.scala | 33 ++++++++++++- 3 files changed, 87 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 21dc082507b47..9bfe2e4785981 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -39,17 +39,21 @@ public final class UnsafeFixedWidthAggregationMap { * map, we copy this buffer and use it as the value. */ private final byte[] emptyBuffer; - private final InternalRow emptyRow; /** - * Whether can the empty aggregation buffer be reuse without calling `initProj` or not. + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericRow(); + + /** + * Whether the empty aggregation buffer can be reused without calling `initProjection`. */ private boolean reuseEmptyBuffer; /** * The projection used to initialize the emptyBuffer */ - private final Function1 initProj; + private final Function1 initProjection; /** * Encodes grouping keys or buffers as UnsafeRows. @@ -90,7 +94,7 @@ public final class UnsafeFixedWidthAggregationMap { /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param initProj the default value for new keys (a "zero" of the agg. function) + * @param initProjection the default value for new keys (a "zero" of the agg. 
function) * @param keyConverter the converter of the grouping key, used for row conversion. * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. @@ -98,26 +102,25 @@ public final class UnsafeFixedWidthAggregationMap { * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - Function1 initProj, + Function1 initProjection, UnsafeRowConverter keyConverter, UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.initProj = initProj; + this.initProjection = initProjection; this.keyConverter = keyConverter; this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - this.emptyRow = new GenericRow(); this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); this.keyPool = new UniqueObjectPool(100); this.bufferPool = new ObjectPool(initialCapacity); - InternalRow initRow = initProj.apply(emptyRow); + InternalRow initRow = initProjection.apply(emptyRow); this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; - int writtenLength = - bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; // re-use the empty buffer only when there is no object saved in pool. 
reuseEmptyBuffer = bufferPool.size() == 0; @@ -150,9 +153,9 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { // empty aggregation buffer into the map: if (!reuseEmptyBuffer) { // There is some objects referenced by emptyBuffer, so generate a new one - InternalRow initRow = initProj.apply(emptyRow); + InternalRow initRow = initProjection.apply(emptyRow); bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - bufferPool); + bufferPool); } loc.putNewKey( groupingKeyConversionScratchSpace, @@ -170,7 +173,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { address.getBaseObject(), address.getBaseOffset(), bufferConverter.numFields(), - bufferPool + bufferPool ); return currentBuffer; } @@ -215,7 +218,7 @@ public MapEntry next() { valueAddress.getBaseObject(), valueAddress.getBaseOffset(), bufferConverter.numFields(), - bufferPool + bufferPool ); return entry; } @@ -243,7 +246,7 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); - System.out.println("Number of objects in keys: " + keyPool.size()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); System.out.println("Number of objects in buffers: " + bufferPool.size()); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 81797428b26d0..3191740bfe4ed 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -38,7 +38,8 @@ * primitive types, such as long, double, or int, we store the value directly in the word. 
For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). For other objects, we + * (they are combined into a long). For other objects, they are stored in a pool, and their + * indexes are held in the word. * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -114,24 +115,51 @@ private void setNotNullAt(int i) { @Override public void update(int i, Object value) { if (value == null) { - // There will be some garbage left in pool - setNullAt(i); + if (!isNullAt(i)) { + long idx = getLong(i); + if (idx <= 0) { + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } return; } if (isNullAt(i)) { int idx = pool.put(value); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), (long)-idx); + setLong(i, (long)-idx); } else { - long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + long v = getLong(i); if (v <= 0) { int idx = (int)-v; pool.replace(idx, value); } else { - // old object are UTF8String or Binary, the space will be wasted - // TODO: try to re-use the buffer - int idx = pool.put(value); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), (long)-idx); + // old objects are UTF8String or byte[], try to reuse the space + boolean is_string = (v >> 62) > 0; + int offset = (int)((v >> 31) & Integer.MAX_VALUE); + int size = (int)(v & Integer.MAX_VALUE); + byte[] bytes; + if (value instanceof UTF8String) { + bytes = ((UTF8String)value).getBytes(); + } else { + bytes = (byte[]) value; + } + if (bytes.length <= size) { + PlatformDependent.copyMemory( + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + bytes.length); + long flag = is_string ? 
1L << 62 : 0L; + setLong(i, flag | (((long)offset) << 31) | (long)bytes.length); + } else { + // Can not fit in the buffer + int idx = pool.put(value); + setLong(i, (long)-idx); + } } } setNotNullAt(i); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 760c6ae17291b..cd26fa8cf8615 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -51,6 +51,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { unsafeRow.getLong(0) should be (0) unsafeRow.getLong(1) should be (1) unsafeRow.getInt(2) should be (2) + + unsafeRow.setLong(1, 3) + unsafeRow.getLong(1) should be (3) + unsafeRow.setInt(2, 4) + unsafeRow.getInt(2) should be (4) } test("basic conversion with primitive, string and binary types") { @@ -71,10 +76,25 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { numBytesWritten should be (sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) unsafeRow.getLong(0) should be (0) unsafeRow.getString(1) should be ("Hello") unsafeRow.get(2) should be ("World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + unsafeRow.getString(1) should be ("World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + unsafeRow.getString(1) should be ("Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + unsafeRow.get(2) should be ("World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + unsafeRow.get(2) 
should be ("Hello World".getBytes) + assert(pool.size === 2) } test("basic conversion with primitive, decimal and array") { @@ -92,12 +112,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) numBytesWritten should be (sizeRequired) + assert(pool.size === 2) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) unsafeRow.getLong(0) should be (0) unsafeRow.get(1) should be (Decimal(1)) unsafeRow.get(2) should be (Array(2)) + + unsafeRow.update(1, Decimal(2)) + unsafeRow.get(1) should be (Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + unsafeRow.get(2) should be (Array(3, 4)) + assert(pool.size === 2) } test("basic conversion with primitive, string, date and timestamp types") { @@ -126,6 +153,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("2015-06-22")) + } test("null handling") {