First passing test for ExternalSorter.
JoshRosen committed May 4, 2015
1 parent 5e100b2 commit 2776aca
Showing 8 changed files with 209 additions and 123 deletions.
UnsafeExternalSortSpillMerger.java
@@ -30,7 +30,7 @@ public final class UnsafeExternalSortSpillMerger {
public static abstract class MergeableIterator {
public abstract boolean hasNext();

public abstract void advanceRecord();
public abstract void loadNextRecord();

public abstract long getPrefix();

@@ -68,6 +68,9 @@ public int compare(MergeableIterator left, MergeableIterator right) {
}

public void addSpill(MergeableIterator spillReader) {
if (spillReader.hasNext()) {
spillReader.loadNextRecord();
}
priorityQueue.add(spillReader);
}

@@ -79,17 +82,18 @@ public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {

@Override
public boolean hasNext() {
return spillReader.hasNext() || !priorityQueue.isEmpty();
return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
}

@Override
public RecordAddressAndKeyPrefix next() {
if (spillReader != null) {
if (spillReader.hasNext()) {
spillReader.loadNextRecord();
priorityQueue.add(spillReader);
}
}
spillReader = priorityQueue.poll();
spillReader = priorityQueue.remove();
record.baseObject = spillReader.getBaseObject();
record.baseOffset = spillReader.getBaseOffset();
record.keyPrefix = spillReader.getPrefix();
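The merge logic above carries a subtle contract: the priority queue orders readers by the key prefix of their current record, so a reader must have loadNextRecord() called once before it can enter the queue, which is exactly what the new guard in addSpill() enforces. next() re-inserts the previously returned reader lazily, at the start of the following call, so the record handed to the caller stays valid until the caller asks for the next one; and the switch from poll() to remove() makes an unexpectedly empty queue fail fast with NoSuchElementException instead of a later null dereference. A minimal, self-contained sketch of the same k-way merge pattern over sorted long arrays; Cursor and its methods are illustrative stand-ins for MergeableIterator, not code from this commit:

import java.util.PriorityQueue;

// Illustrative stand-in for MergeableIterator: a cursor over one sorted run.
final class Cursor {
  final long[] values;
  int pos = -1;
  Cursor(long[] values) { this.values = values; }
  boolean hasNext() { return pos + 1 < values.length; }
  void loadNext() { pos++; }               // like loadNextRecord()
  long current() { return values[pos]; }   // like getPrefix()
}

public class KWayMergeSketch {
  public static void main(String[] args) {
    // The comparator reads each cursor's *current* value, so every cursor
    // must be primed with loadNext() before it is inserted.
    PriorityQueue<Cursor> queue =
        new PriorityQueue<>((a, b) -> Long.compare(a.current(), b.current()));
    for (long[] run : new long[][] {{1, 4, 7}, {2, 5}, {3, 6, 8}}) {
      Cursor c = new Cursor(run);
      if (c.hasNext()) {
        c.loadNext();    // prime before insertion, as addSpill() now does
        queue.add(c);
      }
    }
    while (!queue.isEmpty()) {
      Cursor min = queue.poll();
      System.out.print(min.current() + " ");  // prints: 1 2 3 4 5 6 7 8
      if (min.hasNext()) {
        min.loadNext();
        queue.add(min);  // re-insert keyed on its new current value
      }
    }
  }
}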
UnsafeExternalSorter.java
@@ -38,7 +38,6 @@
public final class UnsafeExternalSorter {

private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this

private final PrefixComparator prefixComparator;
private final RecordComparator recordComparator;
@@ -92,6 +91,7 @@ private void openSorter() {
public void spill() throws IOException {
final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSize, writeMetrics);
spillWriters.add(spillWriter);
final Iterator<RecordPointerAndKeyPrefix> sortedRecords = sorter.getSortedIterator();
while (sortedRecords.hasNext()) {
final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
@@ -110,8 +110,11 @@ private void freeMemory() {
final Iterator<MemoryBlock> iter = allocatedPages.iterator();
while (iter.hasNext()) {
memoryManager.freePage(iter.next());
shuffleMemoryManager.release(PAGE_SIZE);
iter.remove();
}
currentPage = null;
currentPagePosition = -1;
}

private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
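Two pieces of bookkeeping land in this hunk: spill() now registers each spill writer the moment it is created, so the later merge and cleanup phases cannot miss a spill file, and freeMemory() returns each page's quota to the shuffle memory manager while also resetting currentPage and currentPagePosition so that nothing keeps writing into a freed page. A compact sketch of that paired acquire/release accounting, using illustrative names (PageTracker, bytesGranted) rather than Spark's APIs:

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

final class PageTracker {
  static final int PAGE_SIZE = 1024 * 1024;
  private final List<long[]> allocatedPages = new ArrayList<>();
  private long bytesGranted = 0;   // stand-in for the shuffleMemoryManager ledger
  private long[] currentPage = null;
  private int currentPagePosition = -1;

  long[] allocatePage() {
    bytesGranted += PAGE_SIZE;             // acquire quota before the page
    currentPage = new long[PAGE_SIZE / 8];
    currentPagePosition = 0;
    allocatedPages.add(currentPage);
    return currentPage;
  }

  void freeAll() {
    Iterator<long[]> iter = allocatedPages.iterator();
    while (iter.hasNext()) {
      iter.next();                         // drop the page ...
      bytesGranted -= PAGE_SIZE;           // ... and return exactly its quota
      iter.remove();
    }
    currentPage = null;                    // a stale cursor must not outlive its page
    currentPagePosition = -1;
    assert bytesGranted == 0 : "leaked shuffle memory: " + bytesGranted;
  }
}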
UnsafeSorter.java
@@ -169,7 +169,8 @@ public void remove() {

public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() {
sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
return new UnsafeExternalSortSpillMerger.MergeableIterator() {
UnsafeExternalSortSpillMerger.MergeableIterator iter =
new UnsafeExternalSortSpillMerger.MergeableIterator() {

private int position = 0;
private Object baseObject;
@@ -182,12 +183,12 @@ public boolean hasNext() {
}

@Override
public void advanceRecord() {
public void loadNextRecord() {
final long recordPointer = sortBuffer[position];
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer);
keyPrefix = sortBuffer[position + 1];
position += 2;
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer);
}

@Override
@@ -205,5 +206,6 @@ public long getBaseOffset() {
return baseOffset;
}
};
return iter;
}
}
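The iterator above walks a flat long[] in which every record occupies two adjacent slots, the encoded record pointer followed by its key prefix, which is why position advances by 2 and the record count is sortBufferInsertPosition / 2. A small sketch of that packed layout; the class and field names are illustrative:

public class PackedPairBufferSketch {
  private final long[] buf = new long[1024];
  private int insertPos = 0;  // next free slot, counted in longs

  void insertRecord(long recordPointer, long keyPrefix) {
    buf[insertPos] = recordPointer;    // slot 2i: where the record lives
    buf[insertPos + 1] = keyPrefix;    // slot 2i+1: cheap comparison key
    insertPos += 2;
  }

  int numRecords() {
    return insertPos / 2;  // mirrors sortBufferInsertPosition / 2 above
  }

  void dump() {
    for (int pos = 0; pos < insertPos; pos += 2) {
      // the real code decodes buf[pos] through memoryManager.getPage and
      // getOffsetInPage; here we just print the raw slot values
      System.out.println(Long.toHexString(buf[pos]) + " prefix=" + buf[pos + 1]);
    }
  }
}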
UnsafeSorterSpillReader.java
@@ -18,15 +18,9 @@
package org.apache.spark.unsafe.sort;

import com.google.common.io.ByteStreams;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.JavaSerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
import org.apache.spark.unsafe.PlatformDependent;
import scala.Tuple2;

import java.io.*;

@@ -39,18 +33,19 @@ public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger
private long keyPrefix;
private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
private final Object baseObject = arr;
private int nextRecordLength;
private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;

public UnsafeSorterSpillReader(
BlockManager blockManager,
File file,
BlockId blockId) throws IOException {
this.file = file;
assert (file.length() > 0);
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
this.in = blockManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
assert (file.length() > 0);
advanceRecord();
nextRecordLength = din.readInt();
}

@Override
@@ -59,21 +54,19 @@ public boolean hasNext() {
}

@Override
public void advanceRecord() {
public void loadNextRecord() {
try {
final int recordLength = din.readInt();
if (recordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
keyPrefix = din.readLong();
ByteStreams.readFully(in, arr, 0, nextRecordLength);
nextRecordLength = din.readInt();
if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
in.close();
in = null;
return;
din = null;
}
keyPrefix = din.readLong();
ByteStreams.readFully(in, arr, 0, recordLength);

} catch (Exception e) {
PlatformDependent.throwException(e);
}
throw new IllegalStateException();
}

@Override
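The reader's rewrite changes the framing discipline: the constructor primes nextRecordLength, and each loadNextRecord() consumes the current record's key prefix and payload and then prefetches the following record's length. That read-ahead is what allows hasNext() to be a cheap field test and lets the EOF marker close the stream exactly once instead of relying on a sentinel return. A self-contained round-trip sketch of this framing; the EOF_MARKER value of -1 is an assumption here, not taken from the commit:

import java.io.*;

public class ReadAheadFramingSketch {
  static final int EOF_MARKER = -1;  // assumed sentinel; any invalid length works

  public static void main(String[] args) throws IOException {
    // Write: each record is [length:int][keyPrefix:long][payload bytes],
    // terminated by a sentinel length.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    DataOutputStream out = new DataOutputStream(bytes);
    for (String s : new String[] {"ab", "cdef"}) {
      byte[] payload = s.getBytes("UTF-8");
      out.writeInt(payload.length);
      out.writeLong(s.hashCode());
      out.write(payload);
    }
    out.writeInt(EOF_MARKER);

    // Read: prime the first length up front, as the constructor now does.
    DataInputStream in = new DataInputStream(
        new ByteArrayInputStream(bytes.toByteArray()));
    int nextLen = in.readInt();
    while (nextLen != EOF_MARKER) {    // hasNext() reduces to this field test
      long prefix = in.readLong();
      byte[] buf = new byte[nextLen];
      in.readFully(buf);
      System.out.println(prefix + " -> " + new String(buf, "UTF-8"));
      nextLen = in.readInt();          // prefetch the next record's length
    }
  }
}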
UnsafeSorterSpillWriter.java
@@ -18,18 +18,20 @@
package org.apache.spark.unsafe.sort;

import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.DeserializationStream;
import org.apache.spark.serializer.JavaSerializerInstance;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockObjectWriter;
import org.apache.spark.storage.TempLocalBlockId;
import org.apache.spark.unsafe.PlatformDependent;
import scala.Tuple2;
import scala.reflect.ClassTag;

import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.nio.ByteBuffer;

public final class UnsafeSorterSpillWriter {

@@ -51,7 +53,47 @@ public UnsafeSorterSpillWriter(
this.file = spilledFileInfo._2();
this.blockId = spilledFileInfo._1();
// Dummy serializer:
final SerializerInstance ser = new JavaSerializerInstance(0, false, null);
final SerializerInstance ser = new SerializerInstance() {
@Override
public SerializationStream serializeStream(OutputStream s) {
return new SerializationStream() {
@Override
public void flush() {

}

@Override
public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
return null;
}

@Override
public void close() {

}
};
}

@Override
public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
return null;
}

@Override
public DeserializationStream deserializeStream(InputStream s) {
return null;
}

@Override
public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
return null;
}

@Override
public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
return null;
}
};
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics);
dos = new DataOutputStream(writer);
}
@@ -61,14 +103,14 @@ public void write(
long baseOffset,
int recordLength,
long keyPrefix) throws IOException {
dos.writeInt(recordLength);
dos.writeLong(keyPrefix);
PlatformDependent.copyMemory(
baseObject,
baseOffset + 4,
arr,
PlatformDependent.BYTE_ARRAY_OFFSET,
recordLength);
dos.writeInt(recordLength);
dos.writeLong(keyPrefix);
writer.write(arr, 0, recordLength);
// TODO: add a test that detects whether we leave this call out:
writer.recordWritten();
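Two details stand out in this file. The anonymous SerializerInstance is a null object: getDiskWriter() requires a serializer, but the spill writer frames and copies raw bytes itself, so none of the serializer's methods are ever exercised. And write() shows the on-disk record layout (length, then key prefix, then payload), where the payload copy starts at baseOffset + 4, skipping what appears to be a 4-byte in-memory length word at the head of each record. A sketch of that extraction over a plain byte[] page; the helper name and the big-endian decoding of the length word are illustrative assumptions, since the real code copies through PlatformDependent.copyMemory:

import java.io.DataOutputStream;
import java.io.IOException;

public class SpillWriteSketch {
  // Lift one record out of a simulated page and re-frame it for disk as
  // [length][keyPrefix][payload], as the diff's write(...) does.
  static void writeRecord(byte[] page, int recordStart, long keyPrefix,
                          DataOutputStream dos) throws IOException {
    // decode the record's in-memory length word (big-endian for simplicity)
    int recordLength = ((page[recordStart] & 0xff) << 24)
        | ((page[recordStart + 1] & 0xff) << 16)
        | ((page[recordStart + 2] & 0xff) << 8)
        | (page[recordStart + 3] & 0xff);
    dos.writeInt(recordLength);
    dos.writeLong(keyPrefix);
    // skip the 4-byte length word, like copyMemory(baseObject, baseOffset + 4, ...)
    dos.write(page, recordStart + 4, recordLength);
  }
}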
BlockObjectWriter.scala
@@ -223,7 +223,13 @@ private[spark] class DiskBlockObjectWriter(
}
}

override def write(b: Int): Unit = throw new UnsupportedOperationException()
override def write(b: Int): Unit = {
if (!initialized) {
open()
}

bs.write(b)
}

override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!initialized) {
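This small Scala change matters because UnsafeSorterSpillWriter wraps its DiskBlockObjectWriter in a DataOutputStream, and DataOutputStream.writeInt emits its four bytes through the single-byte write(int) overload, which previously threw UnsupportedOperationException. The fix delegates to the buffered stream with the same lazy-open shape as the existing array overload. A generic Java sketch of that shape; LazyFileOutputStream and its file handling are illustrative:

import java.io.*;

public class LazyFileOutputStream extends OutputStream {
  private final File file;
  private OutputStream bs;            // buffered sink, created on first write
  private boolean initialized = false;

  public LazyFileOutputStream(File file) { this.file = file; }

  private void open() throws IOException {
    bs = new BufferedOutputStream(new FileOutputStream(file));
    initialized = true;
  }

  @Override
  public void write(int b) throws IOException {
    if (!initialized) {
      open();                          // open lazily, as DiskBlockObjectWriter does
    }
    bs.write(b);                       // DataOutputStream.writeInt lands here
  }

  @Override
  public void write(byte[] kvBytes, int offs, int len) throws IOException {
    if (!initialized) {
      open();
    }
    bs.write(kvBytes, offs, len);
  }

  @Override
  public void close() throws IOException {
    if (initialized) {
      bs.close();
    }
  }
}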