Skip to content

Commit

Permalink
reuse buffer for string
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jun 23, 2015
1 parent 39f09ca commit e8a1649
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,21 @@ public final class UnsafeFixedWidthAggregationMap {
* map, we copy this buffer and use it as the value.
*/
private final byte[] emptyBuffer;
private final InternalRow emptyRow;

/**
* Whether can the empty aggregation buffer be reuse without calling `initProj` or not.
* An empty row used as the input to `initProjection`.
*/
private static final InternalRow emptyRow = new GenericRow();

/**
* Whether the empty aggregation buffer can be reused without calling `initProjection`.
*/
private boolean reuseEmptyBuffer;

/**
* The projection used to initialize the emptyBuffer
*/
private final Function1<InternalRow, InternalRow> initProj;
private final Function1<InternalRow, InternalRow> initProjection;

/**
* Encodes grouping keys or buffers as UnsafeRows.
Expand Down Expand Up @@ -90,34 +94,33 @@ public final class UnsafeFixedWidthAggregationMap {
/**
* Create a new UnsafeFixedWidthAggregationMap.
*
* @param initProj the default value for new keys (a "zero" of the agg. function)
* @param initProjection the default value for new keys (a "zero" of the agg. function)
* @param keyConverter the converter of the grouping key, used for row conversion.
* @param bufferConverter the converter of the aggregation buffer, used for row conversion.
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
*/
public UnsafeFixedWidthAggregationMap(
Function1<InternalRow, InternalRow> initProj,
Function1<InternalRow, InternalRow> initProjection,
UnsafeRowConverter keyConverter,
UnsafeRowConverter bufferConverter,
TaskMemoryManager memoryManager,
int initialCapacity,
boolean enablePerfMetrics) {
this.initProj = initProj;
this.initProjection = initProjection;
this.keyConverter = keyConverter;
this.bufferConverter = bufferConverter;
this.enablePerfMetrics = enablePerfMetrics;

this.emptyRow = new GenericRow();
this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
this.keyPool = new UniqueObjectPool(100);
this.bufferPool = new ObjectPool(initialCapacity);

InternalRow initRow = initProj.apply(emptyRow);
InternalRow initRow = initProjection.apply(emptyRow);
this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
int writtenLength =
bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
int writtenLength = bufferConverter.writeRow(
initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
// re-use the empty buffer only when there is no object saved in pool.
reuseEmptyBuffer = bufferPool.size() == 0;
Expand Down Expand Up @@ -150,9 +153,9 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
// empty aggregation buffer into the map:
if (!reuseEmptyBuffer) {
// There are some objects referenced by emptyBuffer, so generate a new one
InternalRow initRow = initProj.apply(emptyRow);
InternalRow initRow = initProjection.apply(emptyRow);
bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
bufferPool);
bufferPool);
}
loc.putNewKey(
groupingKeyConversionScratchSpace,
Expand All @@ -170,7 +173,7 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
address.getBaseObject(),
address.getBaseOffset(),
bufferConverter.numFields(),
bufferPool
bufferPool
);
return currentBuffer;
}
Expand Down Expand Up @@ -215,7 +218,7 @@ public MapEntry next() {
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
bufferConverter.numFields(),
bufferPool
bufferPool
);
return entry;
}
Expand Down Expand Up @@ -243,7 +246,7 @@ public void printPerfMetrics() {
System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
System.out.println("Number of objects in keys: " + keyPool.size());
System.out.println("Number of unique objects in keys: " + keyPool.size());
System.out.println("Number of objects in buffers: " + bufferPool.size());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
* primitive types, such as long, double, or int, we store the value directly in the word. For
* fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
* base address of the row) that points to the beginning of the variable-length field, and length
* (they are combined into a long). For other objects, we
* (they are combined into a long). Other objects are stored in a pool, and their indexes are
* held in the word.
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
*/
Expand Down Expand Up @@ -114,24 +115,51 @@ private void setNotNullAt(int i) {
@Override
public void update(int i, Object value) {
if (value == null) {
// There will be some garbage left in pool
setNullAt(i);
if (!isNullAt(i)) {
long idx = getLong(i);
if (idx <= 0) {
pool.replace((int)-idx, null);
} else {
// there will be some garbage left (UTF8String or byte[])
}
setNullAt(i);
}
return;
}

if (isNullAt(i)) {
int idx = pool.put(value);
PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), (long)-idx);
setLong(i, (long)-idx);
} else {
long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
long v = getLong(i);
if (v <= 0) {
int idx = (int)-v;
pool.replace(idx, value);
} else {
// old object are UTF8String or Binary, the space will be wasted
// TODO: try to re-use the buffer
int idx = pool.put(value);
PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), (long)-idx);
// the old object is a UTF8String or byte[]; try to reuse its space
boolean is_string = (v >> 62) > 0;
int offset = (int)((v >> 31) & Integer.MAX_VALUE);
int size = (int)(v & Integer.MAX_VALUE);
byte[] bytes;
if (value instanceof UTF8String) {
bytes = ((UTF8String)value).getBytes();
} else {
bytes = (byte[]) value;
}
if (bytes.length <= size) {
PlatformDependent.copyMemory(
bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + offset,
bytes.length);
long flag = is_string ? 1L << 62 : 0L;
setLong(i, flag | (((long)offset) << 31) | (long)bytes.length);
} else {
// The new value does not fit in the old buffer, so store it in the pool instead
int idx = pool.put(value);
setLong(i, (long)-idx);
}
}
}
setNotNullAt(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.getLong(0) should be (0)
unsafeRow.getLong(1) should be (1)
unsafeRow.getInt(2) should be (2)

unsafeRow.setLong(1, 3)
unsafeRow.getLong(1) should be (3)
unsafeRow.setInt(2, 4)
unsafeRow.getInt(2) should be (4)
}

test("basic conversion with primitive, string and binary types") {
Expand All @@ -71,10 +76,25 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
numBytesWritten should be (sizeRequired)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
val pool = new ObjectPool(10)
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
unsafeRow.get(2) should be ("World".getBytes)

unsafeRow.update(1, UTF8String.fromString("World"))
unsafeRow.getString(1) should be ("World")
assert(pool.size === 0)
unsafeRow.update(1, UTF8String.fromString("Hello World"))
unsafeRow.getString(1) should be ("Hello World")
assert(pool.size === 1)

unsafeRow.update(2, "World".getBytes)
unsafeRow.get(2) should be ("World".getBytes)
assert(pool.size === 1)
unsafeRow.update(2, "Hello World".getBytes)
unsafeRow.get(2) should be ("Hello World".getBytes)
assert(pool.size === 2)
}

test("basic conversion with primitive, decimal and array") {
Expand All @@ -92,12 +112,19 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
numBytesWritten should be (sizeRequired)
assert(pool.size === 2)

val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
unsafeRow.getLong(0) should be (0)
unsafeRow.get(1) should be (Decimal(1))
unsafeRow.get(2) should be (Array(2))

unsafeRow.update(1, Decimal(2))
unsafeRow.get(1) should be (Decimal(2))
unsafeRow.update(2, Array(3, 4))
unsafeRow.get(2) should be (Array(3, 4))
assert(pool.size === 2)
}

test("basic conversion with primitive, string, date and timestamp types") {
Expand Down Expand Up @@ -126,6 +153,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
// Timestamp is represented as Long in unsafeRow
DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
(Timestamp.valueOf("2015-05-08 08:10:25"))

unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22")))
DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("2015-06-22"))

}

test("null handling") {
Expand Down

0 comments on commit e8a1649

Please sign in to comment.