diff --git a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
index c5a1d74..7c207a3 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/pipelines/ShardedBAMWriting.java
@@ -16,13 +16,13 @@
 import com.google.api.services.genomics.model.Read;
 import com.google.api.services.storage.Storage;
 import com.google.api.services.storage.Storage.Objects.Compose;
+import com.google.api.services.storage.model.Bucket;
 import com.google.api.services.storage.model.ComposeRequest;
 import com.google.api.services.storage.model.ComposeRequest.SourceObjects;
 import com.google.api.services.storage.model.StorageObject;
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
-import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
 import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.Default;
@@ -66,19 +66,16 @@
 import htsjdk.samtools.util.BlockCompressedStreamConstants;
 import htsjdk.samtools.util.StringLineReader;
 
-import java.io.BufferedWriter;
 import java.io.IOException;
 import java.io.OutputStream;
-import java.io.OutputStreamWriter;
-import java.io.Serializable;
 import java.io.StringWriter;
-import java.io.Writer;
 import java.nio.channels.Channels;
 import java.security.GeneralSecurityException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.List;
 import java.util.logging.Logger;
 
 /**
@@ -86,6 +83,7 @@
  */
 public class ShardedBAMWriting {
   private static final Logger LOG = Logger.getLogger(ShardedBAMWriting.class.getName());
+  private static final int MAX_RETRIES_FOR_WRITING_A_SHARD = 4;
   private static ShardedBAMWritingOptions options;
   private static Pipeline p;
   private static GenomicsFactory.OfflineAuth auth;
@@ -97,6 +95,12 @@ public static interface ShardedBAMWritingOptions extends GenomicsDatasetOptions,
     String getBAMFilePath();
     void setBAMFilePath(String filePath);
+
+    @Description("Loci per writing shard")
+    @Default.Long(10000)
+    long getLociPerWritingShard();
+
+    void setLociPerWritingShard(long lociPerShard);
   }
 
   public static void main(String[] args) throws GeneralSecurityException, IOException {
@@ -110,7 +114,7 @@ public static void main(String[] args) throws GeneralSecurityException, IOExcept
     // Register coders
     DataflowWorkarounds.registerGenomicsCoders(p);
     DataflowWorkarounds.registerCoder(p, Contig.class, CONTIG_CODER);
-    // Get contigs
+    // Process options
     contigs = Contig.parseContigsFromCommandLine(options.getReferences());
     // Get header info
     final HeaderInfo headerInfo = getHeader();
@@ -137,7 +141,7 @@ public HeaderInfo(SAMFileHeader header, Contig firstShard) {
       this.firstShard = firstShard;
     }
   }
-  
+
   private static HeaderInfo getHeader() throws IOException {
     HeaderInfo result = null;
@@ -177,7 +181,7 @@ public int compare(Contig o1, Contig o2) {
       SAMRecord record = recordIterator.next();
       final int alignmentStart = record.getAlignmentStart();
       if (firstShard == null && alignmentStart > firstContig.start && alignmentStart < firstContig.end) {
-        firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart);
+        firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart, options.getLociPerWritingShard());
         LOG.info("Determined first shard to be " + firstShard);
         result = new HeaderInfo(header, firstShard);
       }
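
Review note: the shard-key arithmetic patched in above is easy to sanity-check in isolation. Below is a minimal, self-contained sketch of the same rounding logic, now parameterized by lociPerShard instead of the old LOCI_PER_SHARD constant; the ShardMathSketch class and its nested Contig are hypothetical stand-ins, not part of the patch.

// Standalone sketch of the shard-boundary math used by shardFromAlignmentStart.
public class ShardMathSketch {
  // Hypothetical stand-in for the pipeline's Contig class.
  static final class Contig {
    final String referenceName;
    final long start;
    final long end;
    Contig(String referenceName, long start, long end) {
      this.referenceName = referenceName;
      this.start = start;
      this.end = end;
    }
  }

  // Same computation as the patched method: round the alignment start down
  // to the nearest multiple of lociPerShard.
  static Contig shardFromAlignmentStart(String referenceName, long alignmentStart, long lociPerShard) {
    final long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
    return new Contig(referenceName, shardStart, shardStart + lociPerShard);
  }

  public static void main(String[] args) {
    // With the default of 10000 loci per shard, alignment start 123456 lands
    // in the half-open shard [120000, 130000).
    final Contig shard = shardFromAlignmentStart("chr1", 123456, 10000);
    System.out.println(shard.referenceName + ":" + shard.start + "-" + shard.end);
  }
}
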
@@ -226,7 +230,9 @@ public static class KeyReadsFn extends DoFn<Read, KV<Contig, Read>> {
     public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c)
         throws Exception {
       final Read read = c.element();
-      c.output(KV.of(shardKeyForRead(read), read));
+      c.output(KV.of(shardKeyForRead(read,
+          c.getPipelineOptions().as(ShardedBAMWritingOptions.class).getLociPerWritingShard()),
+          read));
     }
   }
@@ -273,9 +279,7 @@ public HeaderInfo apply(String str) throws Exception {
         }
       });
 
-  static final long LOCI_PER_SHARD = 10000;
-
-  static Contig shardKeyForRead(Read read) {
+  static Contig shardKeyForRead(Read read, long lociPerShard) {
     String referenceName = null;
     Long alignmentStart = null;
     if (read.getAlignment() != null) {
@@ -295,12 +299,12 @@ static Contig shardKeyForRead(Read read) {
       referenceName = "*";
       alignmentStart = new Long(0);
     }
-    return shardFromAlignmentStart(referenceName, alignmentStart);
+    return shardFromAlignmentStart(referenceName, alignmentStart, lociPerShard);
   }
 
-  static Contig shardFromAlignmentStart(String referenceName, long alignmentStart) {
-    final long shardStart = (alignmentStart / LOCI_PER_SHARD) * LOCI_PER_SHARD;
-    return new Contig(referenceName, shardStart, shardStart + LOCI_PER_SHARD);
+  static Contig shardFromAlignmentStart(String referenceName, long alignmentStart, long lociPerShard) {
+    final long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
+    return new Contig(referenceName, shardStart, shardStart + lociPerShard);
   }
 
   public static TupleTag<KV<Contig, Iterable<Read>>> SHARDED_READS_TAG = new TupleTag<>();
@@ -372,18 +376,33 @@ public void processElement(DoFn<KV<Contig, Iterable<Read>>, String>.ProcessConte
       LOG.info("Writing non-first shard " + shardContig);
     }
-    final String writeResult = writeShard(headerInfo.header,
-        shardContig, reads,
-        c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
-        isFirstShard);
-    c.output(writeResult);
+    int numRetriesLeft = MAX_RETRIES_FOR_WRITING_A_SHARD;
+    boolean done = false;
+    do {
+      try {
+        final String writeResult = writeShard(headerInfo.header,
+            shardContig, reads,
+            c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
+            isFirstShard);
+        c.output(writeResult);
+        done = true;
+      } catch (IOException iox) {
+        LOG.warning("Write shard failed for " + shardContig + ": " + iox.getMessage());
+        if (--numRetriesLeft <= 0) {
+          LOG.warning("No more retries - failing the task for " + shardContig);
+          throw iox;
+        }
+      }
+    } while (!done);
     LOG.info("Finished writing " + shardContig);
   }
 
   String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads,
       ShardedBAMWritingOptions options, boolean isFirstShard) throws IOException {
     final String outputFileName = options.getOutput();
-    final String shardName = outputFileName + "-" + shardContig;
+    final String shardName = outputFileName + "-" + shardContig.referenceName +
+        ":" + String.format("%012d", shardContig.start) + "-" +
+        String.format("%012d", shardContig.end);
     LOG.info("Writing shard file " + shardName);
 
     final OutputStream outputStream = Channels.newOutputStream(
@@ -396,9 +415,14 @@ String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads
     final BAMBlockWriter bw = new BAMBlockWriter(new TruncatedOutputStream(
         outputStream, BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length),
         null /*file*/);
-    bw.setSortOrder(header.getSortOrder(), true /*presorted*/);
+    // If reads are unsorted then we do not care about their order;
+    // otherwise we need to sort them as we write.
+    final boolean treatReadsAsPresorted =
+        header.getSortOrder() == SAMFileHeader.SortOrder.unsorted;
+    bw.setSortOrder(header.getSortOrder(), treatReadsAsPresorted);
     bw.setHeader(header);
     if (isFirstShard) {
+      LOG.info("First shard - writing header to " + shardName);
       bw.writeHeader(header);
     }
     for (Read read : reads) {
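
Review note: the bounded retry around writeShard above is a small, reusable pattern. A self-contained sketch of just that pattern follows; FlakyAction and runWithRetries are illustrative names, not part of the patch.

import java.io.IOException;

public class RetrySketch {
  interface FlakyAction {
    String run() throws IOException;
  }

  // Mirrors the do/while in the patch: retry on IOException, rethrow the last
  // failure once the retry budget is exhausted.
  static String runWithRetries(FlakyAction action, int maxRetries) throws IOException {
    int numRetriesLeft = maxRetries;
    while (true) {
      try {
        return action.run();  // success: stop retrying
      } catch (IOException iox) {
        if (--numRetriesLeft <= 0) {
          throw iox;  // out of retries: fail, as the patched DoFn does
        }
      }
    }
  }

  public static void main(String[] args) throws IOException {
    final int[] attempts = {0};
    final String result = runWithRetries(() -> {
      if (++attempts[0] < 3) {
        throw new IOException("transient failure");
      }
      return "written after " + attempts[0] + " attempts";
    }, 4);
    System.out.println(result);
  }
}

One design point worth noting: the patch retries on any IOException, including non-transient ones, and relies on the retry budget (MAX_RETRIES_FOR_WRITING_A_SHARD) to bound the cost of a persistent failure.
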
@@ -438,11 +462,6 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
         .build()
         .objects();
 
-    final GcsPath destPath = GcsPath.fromUri(dest);
-
-    StorageObject destination = new StorageObject()
-        .setContentType("application/octet-stream");
-
     ArrayList<String> sortedShardsNames = Lists.newArrayList(shards);
     Collections.sort(sortedShardsNames);
 
@@ -455,22 +474,65 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
     os.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
     os.close();
     sortedShardsNames.add(eofFileName);
-    // list of files to concatenate
-    ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
-    for (String shard : sortedShardsNames) {
-      final GcsPath shardPath = GcsPath.fromUri(shard);
-      LOG.info("Adding object " + shardPath);
-      sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );
+
+    int stageNumber = 0;
+    while (sortedShardsNames.size() > 32) {
+      LOG.info("Have " + sortedShardsNames.size() +
+          " shards: must combine in groups of 32");
+      final ArrayList<String> combinedShards = Lists.newArrayList();
+      for (int idx = 0; idx < sortedShardsNames.size(); idx += 32) {
+        final int endIdx = Math.min(idx + 32, sortedShardsNames.size());
+        final List<String> combinableShards = sortedShardsNames.subList(
+            idx, endIdx);
+        final String intermediateCombineResultName = dest + "-" +
+            String.format("%02d", stageNumber) + "-" +
+            String.format("%02d", idx);
+        final String combineResult = composeAndCleanupShards(storage,
+            combinableShards, intermediateCombineResultName);
+        combinedShards.add(combineResult);
+      }
+      sortedShardsNames = combinedShards;
+      stageNumber++;
     }
-
-    final ComposeRequest composeRequest = new ComposeRequest()
-        .setDestination(destination)
-        .setSourceObjects(sourceObjects);
-    final Compose compose = storage.compose(
-        destPath.getBucket(), destPath.getObject(), composeRequest);
-    final String combineResult = compose.execute().toString();
-    LOG.info("Combine result is " + combineResult);
+
+    LOG.info("Combining a final group of " + sortedShardsNames.size() + " shards");
+    final String combineResult = composeAndCleanupShards(storage,
+        sortedShardsNames, dest);
+
     return combineResult;
   }
 }
+
+  static String composeAndCleanupShards(Storage.Objects storage,
+      List<String> shardNames, String dest) throws IOException {
+    LOG.info("Combining shards into " + dest);
+
+    final GcsPath destPath = GcsPath.fromUri(dest);
+
+    StorageObject destination = new StorageObject()
+        .setContentType("application/octet-stream");
+
+    ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
+    for (String shard : shardNames) {
+      final GcsPath shardPath = GcsPath.fromUri(shard);
+      LOG.info("Adding shard " + shardPath);
+      sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );
+    }
+
+    final ComposeRequest composeRequest = new ComposeRequest()
+        .setDestination(destination)
+        .setSourceObjects(sourceObjects);
+    final Compose compose = storage.compose(
+        destPath.getBucket(), destPath.getObject(), composeRequest);
+    final StorageObject result = compose.execute();
+    final String combineResult = GcsPath.fromObject(result).toString();
+    LOG.info("Combine result is " + combineResult);
+    for (SourceObjects sourceObject : sourceObjects) {
+      final String shardToDelete = sourceObject.getName();
+      LOG.info("Cleaning up shard " + shardToDelete);
+      storage.delete(destPath.getBucket(), shardToDelete).execute();
+    }
+
+    return combineResult;
+  }
 }
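
Review note: the staged combine above exists because the GCS objects.compose API accepts at most 32 source components per request, so more than 32 shards have to be composed in rounds until at most 32 remain. Below is a standalone sketch of just the grouping logic; combineInStages and the composer callback are illustrative stand-ins for combineShards and composeAndCleanupShards.

import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

public class StagedComposeSketch {
  // GCS compose limit: at most 32 source objects per request.
  static final int MAX_COMPOSE_COMPONENTS = 32;

  static String combineInStages(List<String> shards, String dest,
      BiFunction<List<String>, String, String> composer) {
    List<String> remaining = shards;
    int stageNumber = 0;
    while (remaining.size() > MAX_COMPOSE_COMPONENTS) {
      final List<String> combined = new ArrayList<>();
      for (int idx = 0; idx < remaining.size(); idx += MAX_COMPOSE_COMPONENTS) {
        final int endIdx = Math.min(idx + MAX_COMPOSE_COMPONENTS, remaining.size());
        // Intermediate object names encode the stage and group, as in the patch.
        final String intermediateName = dest + "-" +
            String.format("%02d", stageNumber) + "-" + String.format("%02d", idx);
        combined.add(composer.apply(remaining.subList(idx, endIdx), intermediateName));
      }
      remaining = combined;
      stageNumber++;
    }
    return composer.apply(remaining, dest);  // final compose into the destination
  }

  public static void main(String[] args) {
    final List<String> shards = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
      shards.add("gs://bucket/output-shard-" + i);
    }
    // A fake composer that just reports group sizes:
    // 100 shards -> 4 intermediates -> 1 final object.
    final String result = combineInStages(shards, "gs://bucket/output",
        (group, name) -> {
          System.out.println(name + " <= " + group.size() + " objects");
          return name;
        });
    System.out.println("final: " + result);
  }
}
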
diff --git a/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java b/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
index 1a4f6dc..70e01cb 100644
--- a/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
+++ b/src/main/java/com/google/cloud/genomics/dataflow/utils/TruncatedOutputStream.java
@@ -46,16 +46,18 @@ public void write(byte[] data, int offset, int length) throws IOException {
       // We have more than bytesToTruncate to write, so clear the buffer
       // completely, and write all but bytesToTruncate directly to the stream.
       os.write(buf, 0, count);
-      os.write(data, offset, length - bytesToTruncate);
-      System.arraycopy(data, offset + length - bytesToTruncate, buf, 0, bytesToTruncate);
+      final int bytesToWriteDirectly = length - bytesToTruncate;
+      os.write(data, offset, bytesToWriteDirectly);
+      System.arraycopy(data, offset + bytesToWriteDirectly, buf, 0, bytesToTruncate);
       count = bytesToTruncate;
     } else {
       // Need this many of the current bytes to stay in the buffer to ensure we
       // have at least bytesToTruncate.
-      int keepInBuffer = bytesToTruncate - length;
+      final int keepInBuffer = bytesToTruncate - length;
       // Write the rest to the stream.
-      os.write(buf, 0, count - keepInBuffer);
-      System.arraycopy(buf, count - keepInBuffer, buf, 0, keepInBuffer);
+      final int bytesToDumpFromBuffer = count - keepInBuffer;
+      os.write(buf, 0, bytesToDumpFromBuffer);
+      System.arraycopy(buf, bytesToDumpFromBuffer, buf, 0, keepInBuffer);
       System.arraycopy(data, offset, buf, keepInBuffer, length);
       count = bytesToTruncate;
     }
@@ -74,9 +76,10 @@ public void close() throws IOException {
   }
 
   private void flushBuffer() throws IOException {
+    final int bytesWeCanSafelyWrite = count - bytesToTruncate;
     if (count > bytesToTruncate) {
-      os.write(buf, 0, count - bytesToTruncate);
-      System.arraycopy(buf, count - bytesToTruncate, buf, 0, bytesToTruncate);
+      os.write(buf, 0, bytesWeCanSafelyWrite);
+      System.arraycopy(buf, bytesWeCanSafelyWrite, buf, 0, bytesToTruncate);
       count = bytesToTruncate;
     }
   }
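
Review note: TruncatedOutputStream is what makes the compose-based concatenation produce a valid BAM. Each shard writer withholds the trailing BGZF EOF marker (EMPTY_GZIP_BLOCK) from its output, and combineShards appends a single EOF block as the last composed object. Below is a usage sketch, assuming (as the buffering logic above implies) that the stream simply never emits the final bytesToTruncate bytes written through it.

import com.google.cloud.genomics.dataflow.utils.TruncatedOutputStream;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

public class TruncationDemo {
  public static void main(String[] args) throws IOException {
    // Stand-in for BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK; only the
    // length matters for this demonstration.
    final byte[] fakeEofBlock = {(byte) 0x1f, (byte) 0x8b, 0x08, 0x04};
    final ByteArrayOutputStream sink = new ByteArrayOutputStream();

    // Withhold the final fakeEofBlock.length bytes from the sink.
    final TruncatedOutputStream out =
        new TruncatedOutputStream(sink, fakeEofBlock.length);
    out.write("BAM-RECORDS".getBytes("US-ASCII"));
    out.write(fakeEofBlock);  // the writer always appends its EOF marker...
    out.close();

    // ...but the sink should receive only the payload, with the EOF marker
    // truncated away.
    System.out.println(sink.toString("US-ASCII"));  // expected: BAM-RECORDS
  }
}
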