Skip to content
This repository has been archived by the owner on Oct 29, 2023. It is now read-only.

Commit

Permalink
Merge pull request #129 from iliat/sharded-bam-writer
Browse files Browse the repository at this point in the history
Fixes to sharding
  • Loading branch information
iliat committed Aug 9, 2015
2 parents 7f47146 + 2d55cc9 commit 97fd20b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import com.google.api.services.genomics.model.Read;
import com.google.api.services.storage.Storage;
import com.google.api.services.storage.Storage.Objects.Compose;
import com.google.api.services.storage.model.Bucket;
import com.google.api.services.storage.model.ComposeRequest;
import com.google.api.services.storage.model.ComposeRequest.SourceObjects;
import com.google.api.services.storage.model.StorageObject;
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.options.Default;
Expand Down Expand Up @@ -66,26 +66,24 @@
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.StringLineReader;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.channels.Channels;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.logging.Logger;

/**
* Demonstrates sharded BAM writing
*/
public class ShardedBAMWriting {
private static final Logger LOG = Logger.getLogger(ShardedBAMWriting.class.getName());
private static final int MAX_RETRIES_FOR_WRITING_A_SHARD = 4;
private static ShardedBAMWritingOptions options;
private static Pipeline p;
private static GenomicsFactory.OfflineAuth auth;
Expand All @@ -97,6 +95,12 @@ public static interface ShardedBAMWritingOptions extends GenomicsDatasetOptions,
String getBAMFilePath();

void setBAMFilePath(String filePath);

@Description("Loci per writing shard")
@Default.Long(10000)
long getLociPerWritingShard();

void setLociPerWritingShard(long lociPerShard);
}

public static void main(String[] args) throws GeneralSecurityException, IOException {
Expand All @@ -110,7 +114,7 @@ public static void main(String[] args) throws GeneralSecurityException, IOExcept
// Register coders
DataflowWorkarounds.registerGenomicsCoders(p);
DataflowWorkarounds.registerCoder(p, Contig.class, CONTIG_CODER);
// Get contigs
// Process options
contigs = Contig.parseContigsFromCommandLine(options.getReferences());
// Get header info
final HeaderInfo headerInfo = getHeader();
Expand All @@ -137,7 +141,7 @@ public HeaderInfo(SAMFileHeader header, Contig firstShard) {
this.firstShard = firstShard;
}
}

private static HeaderInfo getHeader() throws IOException {
HeaderInfo result = null;

Expand Down Expand Up @@ -177,7 +181,7 @@ public int compare(Contig o1, Contig o2) {
SAMRecord record = recordIterator.next();
final int alignmentStart = record.getAlignmentStart();
if (firstShard == null && alignmentStart > firstContig.start && alignmentStart < firstContig.end) {
firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart);
firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart, options.getLociPerWritingShard());
LOG.info("Determined first shard to be " + firstShard);
result = new HeaderInfo(header, firstShard);
}
Expand Down Expand Up @@ -226,7 +230,9 @@ public static class KeyReadsFn extends DoFn<Read, KV<Contig,Read>> {
public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c)
    throws Exception {
  // Emits the incoming read exactly once, keyed by the shard it falls into.
  // The shard size is taken from the pipeline options available on the
  // process context rather than from a compile-time constant.
  final Read read = c.element();
  final long lociPerShard =
      c.getPipelineOptions().as(ShardedBAMWritingOptions.class).getLociPerWritingShard();
  c.output(KV.of(shardKeyForRead(read, lociPerShard), read));
}
}

Expand Down Expand Up @@ -273,9 +279,7 @@ public HeaderInfo apply(String str) throws Exception {
}
});

static final long LOCI_PER_SHARD = 10000;

static Contig shardKeyForRead(Read read) {
static Contig shardKeyForRead(Read read, long lociPerShard) {
String referenceName = null;
Long alignmentStart = null;
if (read.getAlignment() != null) {
Expand All @@ -295,12 +299,12 @@ static Contig shardKeyForRead(Read read) {
referenceName = "*";
alignmentStart = new Long(0);
}
return shardFromAlignmentStart(referenceName, alignmentStart);
return shardFromAlignmentStart(referenceName, alignmentStart, lociPerShard);
}

static Contig shardFromAlignmentStart(String referenceName, long alignmentStart) {
final long shardStart = (alignmentStart / LOCI_PER_SHARD) * LOCI_PER_SHARD;
return new Contig(referenceName, shardStart, shardStart + LOCI_PER_SHARD);
static Contig shardFromAlignmentStart(String referenceName, long alignmentStart, long lociPerShard) {
final long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
return new Contig(referenceName, shardStart, shardStart + lociPerShard);
}

public static TupleTag<KV<Contig, Iterable<Read>>> SHARDED_READS_TAG = new TupleTag<>();
Expand Down Expand Up @@ -372,18 +376,33 @@ public void processElement(DoFn<KV<Contig, Iterable<Read>>, String>.ProcessConte
LOG.info("Writing non-first shard " + shardContig);
}

final String writeResult = writeShard(headerInfo.header,
shardContig, reads,
c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
isFirstShard);
c.output(writeResult);
int numRetriesLeft = MAX_RETRIES_FOR_WRITING_A_SHARD;
boolean done = false;
do {
try {
final String writeResult = writeShard(headerInfo.header,
shardContig, reads,
c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
isFirstShard);
c.output(writeResult);
done = true;
} catch (IOException iox) {
LOG.warning("Write shard failed for " + shardContig + ": " + iox.getMessage());
if (--numRetriesLeft <= 0) {
LOG.warning("No more retries - failing the task for " + shardContig);
throw iox;
}
}
} while (!done);
LOG.info("Finished writing " + shardContig);
}

String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads,
ShardedBAMWritingOptions options, boolean isFirstShard) throws IOException {
final String outputFileName = options.getOutput();
final String shardName = outputFileName + "-" + shardContig;
final String shardName = outputFileName + "-" + shardContig.referenceName
+ ":" + String.format("%012d", shardContig.start) + "-" +
String.format("%012d", shardContig.end);
LOG.info("Writing shard file " + shardName);
final OutputStream outputStream =
Channels.newOutputStream(
Expand All @@ -396,9 +415,14 @@ String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads
final BAMBlockWriter bw = new BAMBlockWriter(new TruncatedOutputStream(
outputStream, BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length),
null /*file*/);
bw.setSortOrder(header.getSortOrder(), true /*presorted*/);
// If the header says the reads are unsorted, their order does not matter, so
// they can be treated as presorted; otherwise we need to sort them as we write.
final boolean treatReadsAsPresorted =
header.getSortOrder() == SAMFileHeader.SortOrder.unsorted;
bw.setSortOrder(header.getSortOrder(), treatReadsAsPresorted);
bw.setHeader(header);
if (isFirstShard) {
LOG.info("First shard - writing header to " + shardName);
bw.writeHeader(header);
}
for (Read read : reads) {
Expand Down Expand Up @@ -438,11 +462,6 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
.build()
.objects();

final GcsPath destPath = GcsPath.fromUri(dest);

StorageObject destination = new StorageObject()
.setContentType("application/octet-stream");

ArrayList<String> sortedShardsNames = Lists.newArrayList(shards);
Collections.sort(sortedShardsNames);

Expand All @@ -455,22 +474,65 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
os.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
os.close();
sortedShardsNames.add(eofFileName);
// list of files to concatenate
ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
for (String shard : sortedShardsNames) {
final GcsPath shardPath = GcsPath.fromUri(shard);
LOG.info("Adding object " + shardPath);
sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );

int stageNumber = 0;
while (sortedShardsNames.size() > 32) {
LOG.info("Have " + sortedShardsNames.size() +
" shards: must combine in groups 32");
final ArrayList<String> combinedShards = Lists.newArrayList();
for (int idx = 0; idx < sortedShardsNames.size(); idx += 32) {
final int endIdx = Math.min(idx + 32, sortedShardsNames.size());
final List<String> combinableShards = sortedShardsNames.subList(
idx, endIdx);
final String intermediateCombineResultName = dest + "-" +
String.format("%02d",stageNumber) + "-" +
String.format("%02d",idx);
final String combineResult = composeAndCleanupShards(storage,
combinableShards, intermediateCombineResultName);
combinedShards.add(combineResult);
}
sortedShardsNames = combinedShards;
stageNumber++;
}

final ComposeRequest composeRequest = new ComposeRequest()
.setDestination(destination)
.setSourceObjects(sourceObjects);
final Compose compose = storage.compose(
destPath.getBucket(), destPath.getObject(), composeRequest);
final String combineResult = compose.execute().toString();
LOG.info("Combine result is " + combineResult);

LOG.info("Combining a final group of " + sortedShardsNames.size() + " shards");
final String combineResult = composeAndCleanupShards(storage,
sortedShardsNames, dest);

return combineResult;
}
}

/**
 * Composes the given shards into a single object at {@code dest} and then
 * deletes the shards.
 *
 * <p>Shards are added to the compose request in list order. Only the object
 * part of each shard URI is used; both the compose and the deletes are issued
 * against the bucket of {@code dest}.
 *
 * @param storage object-level storage client used for the compose and delete requests
 * @param shardNames URIs of the shards to combine
 * @param dest URI of the object to create
 * @return string form of the path built from the object returned by the compose request
 * @throws IOException propagated from the compose or delete requests
 */
static String composeAndCleanupShards(Storage.Objects storage,
    List<String> shardNames, String dest) throws IOException {
  LOG.info("Combining shards into " + dest);

  final GcsPath target = GcsPath.fromUri(dest);
  final String bucket = target.getBucket();

  // Collect the components to concatenate, preserving the caller's order.
  final List<SourceObjects> components = new ArrayList<>();
  for (String shardUri : shardNames) {
    final GcsPath shardPath = GcsPath.fromUri(shardUri);
    LOG.info("Adding shard " + shardPath);
    final SourceObjects component = new SourceObjects();
    component.setName(shardPath.getObject());
    components.add(component);
  }

  final StorageObject destinationMetadata = new StorageObject();
  destinationMetadata.setContentType("application/octet-stream");

  final ComposeRequest request = new ComposeRequest();
  request.setDestination(destinationMetadata);
  request.setSourceObjects(components);

  final StorageObject composed =
      storage.compose(bucket, target.getObject(), request).execute();
  final String combineResult = GcsPath.fromObject(composed).toString();
  LOG.info("Combine result is " + combineResult);

  // The shards are no longer needed once the compose request has returned.
  for (SourceObjects component : components) {
    final String shardToDelete = component.getName();
    LOG.info("Cleaning up shard " + shardToDelete);
    storage.delete(bucket, shardToDelete).execute();
  }

  return combineResult;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,18 @@ public void write(byte[] data, int offset, int length) throws IOException {
// We have more than bytesToTruncate to write, so clear the buffer
// completely, and write all but bytesToTruncate directly to the stream.
os.write(buf, 0, count);
os.write(data, offset, length - bytesToTruncate);
System.arraycopy(data, offset + length - bytesToTruncate, buf, 0, bytesToTruncate);
final int bytesToWriteDrirectly = length - bytesToTruncate;
os.write(data, offset, bytesToWriteDrirectly);
System.arraycopy(data, offset + bytesToWriteDrirectly, buf, 0, bytesToTruncate);
count = bytesToTruncate;
} else {
// Need this many of the current bytes to stay in the buffer to ensure we
// have at least bytesToTruncate.
int keepInBuffer = bytesToTruncate - length;
final int keepInBuffer = bytesToTruncate - length;
// Write the rest to the stream.
os.write(buf, 0, count - keepInBuffer);
System.arraycopy(buf, count - keepInBuffer, buf, 0, keepInBuffer);
final int bytesToDumpFromBuffer = count - keepInBuffer;
os.write(buf, 0, bytesToDumpFromBuffer);
System.arraycopy(buf, bytesToDumpFromBuffer, buf, 0, keepInBuffer);
System.arraycopy(data, offset, buf, keepInBuffer, length);
count = bytesToTruncate;
}
Expand All @@ -74,9 +76,10 @@ public void close() throws IOException {
}

/**
 * Writes every buffered byte except the last {@code bytesToTruncate} to the
 * underlying stream, then moves those retained bytes to the front of the
 * buffer. Does nothing when the buffer holds no more than
 * {@code bytesToTruncate} bytes.
 *
 * @throws IOException if writing to the underlying stream fails
 */
private void flushBuffer() throws IOException {
  if (count > bytesToTruncate) {
    // Write the prefix and shift the tail exactly once; repeating either step
    // would duplicate output and overwrite the retained tail bytes.
    final int bytesWeCanSafelyWrite = count - bytesToTruncate;
    os.write(buf, 0, bytesWeCanSafelyWrite);
    System.arraycopy(buf, bytesWeCanSafelyWrite, buf, 0, bytesToTruncate);
    count = bytesToTruncate;
  }
}
Expand Down

0 comments on commit 97fd20b

Please sign in to comment.