Super-messy WIP on external sort

JoshRosen committed May 4, 2015
1 parent 595923a commit 5e100b2

Showing 8 changed files with 663 additions and 77 deletions.
org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

@@ -17,6 +17,9 @@
 
 package org.apache.spark.shuffle.unsafe;
 
+import org.apache.spark.*;
+import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger;
+import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
 import scala.Option;
 import scala.Product2;
 import scala.reflect.ClassTag;
@@ -30,10 +33,6 @@
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
 
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkEnv;
-import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -54,7 +53,6 @@
 // IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
 public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
 
-  private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
@@ -70,9 +68,6 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private final int fileBufferSize;
   private MapStatus mapStatus = null;
 
-  private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
    * and then call stop() with success = false if they get an exception, we want to make sure
@@ -109,39 +104,20 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) {
     }
   }
 
-  private void ensureSpaceInDataPage(long requiredSpace) throws Exception {
-    final long spaceInCurrentPage;
-    if (currentPage != null) {
-      spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
-    } else {
-      spaceInCurrentPage = 0;
-    }
-    if (requiredSpace > PAGE_SIZE) {
-      // TODO: throw a more specific exception?
-      throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
-        PAGE_SIZE + ")");
-    } else if (requiredSpace > spaceInCurrentPage) {
-      currentPage = memoryManager.allocatePage(PAGE_SIZE);
-      currentPagePosition = currentPage.getBaseOffset();
-      allocatedPages.add(currentPage);
-    }
-  }
-
   private void freeMemory() {
-    final Iterator<MemoryBlock> iter = allocatedPages.iterator();
-    while (iter.hasNext()) {
-      memoryManager.freePage(iter.next());
-      iter.remove();
-    }
+    // TODO: free sorter memory
   }
 
-  private Iterator<RecordPointerAndKeyPrefix> sortRecords(
-      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
-    final UnsafeSorter sorter = new UnsafeSorter(
+  private Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortRecords(
+      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
+    final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
       memoryManager,
+      SparkEnv$.MODULE$.get().shuffleMemoryManager(),
+      SparkEnv$.MODULE$.get().blockManager(),
       RECORD_COMPARATOR,
       PREFIX_COMPARATOR,
-      4096 // Initial size (TODO: tune this!)
+      4096, // Initial size (TODO: tune this!)
+      SparkEnv$.MODULE$.get().conf()
     );
 
     final byte[] serArray = new byte[SER_BUFFER_SIZE];
@@ -161,30 +137,16 @@ private Iterator<RecordPointerAndKeyPrefix> sortRecords(
 
       final int serializedRecordSize = serByteBuffer.position();
       assert (serializedRecordSize > 0);
-      // Need 4 bytes to store the record length.
-      ensureSpaceInDataPage(serializedRecordSize + 4);
-
-      final long recordAddress =
-        memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-      final Object baseObject = currentPage.getBaseObject();
-      PlatformDependent.UNSAFE.putInt(baseObject, currentPagePosition, serializedRecordSize);
-      currentPagePosition += 4;
-      PlatformDependent.copyMemory(
-        serArray,
-        PlatformDependent.BYTE_ARRAY_OFFSET,
-        baseObject,
-        currentPagePosition,
-        serializedRecordSize);
-      currentPagePosition += serializedRecordSize;
 
-      sorter.insertRecord(recordAddress, partitionId);
+      sorter.insertRecord(
+        serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
     }
 
     return sorter.getSortedIterator();
   }
 
   private long[] writeSortedRecordsToFile(
-      Iterator<RecordPointerAndKeyPrefix> sortedRecords) throws IOException {
+      Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortedRecords) throws IOException {
     final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
     final ShuffleBlockId blockId =
       new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
@@ -195,7 +157,7 @@ private long[] writeSortedRecordsToFile(
 
     final byte[] arr = new byte[SER_BUFFER_SIZE];
     while (sortedRecords.hasNext()) {
-      final RecordPointerAndKeyPrefix recordPointer = sortedRecords.next();
+      final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next();
       final int partition = (int) recordPointer.keyPrefix;
       assert (partition >= currentPartition);
       if (partition != currentPartition) {
@@ -209,17 +171,14 @@
           blockManager.getDiskWriter(blockId, outputFile, serializer, fileBufferSize, writeMetrics);
       }
 
-      final Object baseObject = memoryManager.getPage(recordPointer.recordPointer);
-      final long baseOffset = memoryManager.getOffsetInPage(recordPointer.recordPointer);
-      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
       PlatformDependent.copyMemory(
-        baseObject,
-        baseOffset + 4,
+        recordPointer.baseObject,
+        recordPointer.baseOffset + 4,
         arr,
         PlatformDependent.BYTE_ARRAY_OFFSET,
-        recordLength);
+        recordPointer.recordLength);
       assert (writer != null); // To suppress an IntelliJ warning
-      writer.write(arr, 0, recordLength);
+      writer.write(arr, 0, recordPointer.recordLength);
       // TODO: add a test that detects whether we leave this call out:
       writer.recordWritten();
     }
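Taken together, the changes above replace the writer's hand-rolled page management (PAGE_SIZE, currentPage, ensureSpaceInDataPage) with an UnsafeExternalSorter that owns its own memory and can spill, and they rely on the sort-key prefix passed to insertRecord being the partition id: sorting by prefix groups records by partition, which is what lets writeSortedRecordsToFile emit one contiguous chunk per partition. Below is a self-contained sketch of that grouping idea in plain Java; Rec and its fields are hypothetical stand-ins, not Spark APIs.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Sketch: tag each serialized record with its partition id as the sort-key
// prefix, sort by that prefix, and the records come out grouped by partition,
// ready to be written out as one contiguous chunk per partition.
public class PartitionPrefixSortSketch {

  // Stand-in for a sorted record: keyPrefix plays the role of the partition
  // id and payload plays the role of the serialized record bytes.
  static final class Rec {
    final long keyPrefix;
    final byte[] payload;
    Rec(long keyPrefix, byte[] payload) {
      this.keyPrefix = keyPrefix;
      this.payload = payload;
    }
  }

  public static void main(String[] args) {
    List<Rec> recs = new ArrayList<>(Arrays.asList(
      new Rec(2, "c".getBytes()),
      new Rec(0, "a".getBytes()),
      new Rec(1, "b".getBytes()),
      new Rec(0, "d".getBytes())));
    // Compare prefixes only; ties within a partition can stay in any order.
    recs.sort(Comparator.comparingLong((Rec r) -> r.keyPrefix));
    int currentPartition = -1;
    for (Rec r : recs) {
      final int partition = (int) r.keyPrefix;
      if (partition != currentPartition) {
        // This is where the real writer commits the previous partition's
        // disk writer and opens a new one.
        System.out.println("-- partition " + partition);
        currentPartition = partition;
      }
      System.out.println(new String(r.payload));
    }
  }
}

Sorting by partition id alone is all the shuffle writer needs; the order of records within a partition is irrelevant, which is why a cheap 8-byte prefix comparison suffices on this path.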
org/apache/spark/unsafe/sort/UnsafeExternalSortSpillMerger.java (new file)

@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.unsafe.sort;

import java.util.Comparator;
import java.util.Iterator;
import java.util.PriorityQueue;

import static org.apache.spark.unsafe.sort.UnsafeSorter.*;

public final class UnsafeExternalSortSpillMerger {

  private final PriorityQueue<MergeableIterator> priorityQueue;

  public static abstract class MergeableIterator {
    public abstract boolean hasNext();

    public abstract void advanceRecord();

    public abstract long getPrefix();

    public abstract Object getBaseObject();

    public abstract long getBaseOffset();

    // Needed so the merged iterator below can populate
    // RecordAddressAndKeyPrefix.recordLength for its callers.
    public abstract int getRecordLength();
  }

  public static final class RecordAddressAndKeyPrefix {
    public Object baseObject;
    public long baseOffset;
    public int recordLength;
    public long keyPrefix;
  }

  public UnsafeExternalSortSpillMerger(
      final RecordComparator recordComparator,
      final UnsafeSorter.PrefixComparator prefixComparator) {
    final Comparator<MergeableIterator> comparator = new Comparator<MergeableIterator>() {

      @Override
      public int compare(MergeableIterator left, MergeableIterator right) {
        // Compare the cheap 8-byte prefixes first; fall back to a full
        // record comparison only when the prefixes tie.
        final int prefixComparisonResult =
          prefixComparator.compare(left.getPrefix(), right.getPrefix());
        if (prefixComparisonResult == 0) {
          return recordComparator.compare(
            left.getBaseObject(), left.getBaseOffset(),
            right.getBaseObject(), right.getBaseOffset());
        } else {
          return prefixComparisonResult;
        }
      }
    };
    priorityQueue = new PriorityQueue<MergeableIterator>(10, comparator);
  }

  public void addSpill(MergeableIterator spillReader) {
    // Readers are assumed to arrive positioned at their first record, since
    // the queue's comparator immediately reads their current prefix.
    priorityQueue.add(spillReader);
  }

  public Iterator<RecordAddressAndKeyPrefix> getSortedIterator() {
    return new Iterator<RecordAddressAndKeyPrefix>() {

      private MergeableIterator spillReader;
      private final RecordAddressAndKeyPrefix record = new RecordAddressAndKeyPrefix();

      @Override
      public boolean hasNext() {
        // spillReader is null until the first call to next().
        return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
      }

      @Override
      public RecordAddressAndKeyPrefix next() {
        if (spillReader != null && spillReader.hasNext()) {
          // Move the reader we returned from last time past its consumed
          // record before re-inserting it into the queue.
          spillReader.advanceRecord();
          priorityQueue.add(spillReader);
        }
        spillReader = priorityQueue.poll();
        record.baseObject = spillReader.getBaseObject();
        record.baseOffset = spillReader.getBaseOffset();
        record.recordLength = spillReader.getRecordLength();
        record.keyPrefix = spillReader.getPrefix();
        return record;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

}
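For reference, here is the same heap-based k-way merge that UnsafeExternalSortSpillMerger implements, as a self-contained sketch; Cursor and the long keys are stand-ins for spill readers and their key prefixes, not Spark APIs. Each sorted source keeps one cursor in a priority queue ordered by the cursor's current head; poll the smallest, emit it, advance that cursor, and re-insert it if it still has records.

import java.util.Arrays;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

// Sketch: k-way merge of sorted spills via a priority queue of cursors,
// ordered by each cursor's current head, mirroring the structure of
// UnsafeExternalSortSpillMerger.getSortedIterator().
public class SpillMergeSketch {

  // Stand-in for a spill reader: head mirrors getPrefix(), and advancing
  // the underlying iterator mirrors advanceRecord().
  static final class Cursor {
    long head;
    final Iterator<Long> rest;
    Cursor(Iterator<Long> rest) {
      this.rest = rest;
      this.head = rest.next(); // position at the first record
    }
  }

  public static void main(String[] args) {
    List<List<Long>> spills = Arrays.asList(
      Arrays.asList(1L, 4L, 7L),
      Arrays.asList(2L, 3L, 9L),
      Arrays.asList(0L, 5L, 6L, 8L));

    final PriorityQueue<Cursor> queue =
      new PriorityQueue<>(Comparator.comparingLong((Cursor c) -> c.head));
    for (List<Long> spill : spills) {
      if (!spill.isEmpty()) {
        queue.add(new Cursor(spill.iterator())); // like addSpill()
      }
    }

    final StringBuilder out = new StringBuilder();
    while (!queue.isEmpty()) {
      final Cursor smallest = queue.poll(); // globally smallest head
      out.append(smallest.head).append(' ');
      if (smallest.rest.hasNext()) {
        smallest.head = smallest.rest.next(); // advance, then re-insert
        queue.add(smallest);
      }
    }
    System.out.println(out.toString().trim()); // 0 1 2 3 4 5 6 7 8 9
  }
}

The merge is O(n log k) for n records across k spills, and it is why MergeableIterator exposes its current record (getPrefix/getBaseObject/getBaseOffset) separately from advancing: the queue's comparator must be able to peek without consuming.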