Skip to content
This repository was archived by the owner on Oct 29, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import com.google.api.services.genomics.model.Read;
import com.google.api.services.storage.Storage;
import com.google.api.services.storage.Storage.Objects.Compose;
import com.google.api.services.storage.model.Bucket;
import com.google.api.services.storage.model.ComposeRequest;
import com.google.api.services.storage.model.ComposeRequest.SourceObjects;
import com.google.api.services.storage.model.StorageObject;
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.DelegateCoder;
import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.options.Default;
Expand Down Expand Up @@ -66,26 +66,24 @@
import htsjdk.samtools.util.BlockCompressedStreamConstants;
import htsjdk.samtools.util.StringLineReader;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.channels.Channels;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.logging.Logger;

/**
* Demonstrates sharded BAM writing
*/
public class ShardedBAMWriting {
private static final Logger LOG = Logger.getLogger(ShardedBAMWriting.class.getName());
private static final int MAX_RETRIES_FOR_WRITING_A_SHARD = 4;
private static ShardedBAMWritingOptions options;
private static Pipeline p;
private static GenomicsFactory.OfflineAuth auth;
Expand All @@ -97,6 +95,12 @@ public static interface ShardedBAMWritingOptions extends GenomicsDatasetOptions,
String getBAMFilePath();

void setBAMFilePath(String filePath);

@Description("Loci per writing shard")
@Default.Long(10000)
long getLociPerWritingShard();

void setLociPerWritingShard(long lociPerShard);
}

public static void main(String[] args) throws GeneralSecurityException, IOException {
Expand All @@ -110,7 +114,7 @@ public static void main(String[] args) throws GeneralSecurityException, IOExcept
// Register coders
DataflowWorkarounds.registerGenomicsCoders(p);
DataflowWorkarounds.registerCoder(p, Contig.class, CONTIG_CODER);
// Get contigs
// Process options
contigs = Contig.parseContigsFromCommandLine(options.getReferences());
// Get header info
final HeaderInfo headerInfo = getHeader();
Expand All @@ -137,7 +141,7 @@ public HeaderInfo(SAMFileHeader header, Contig firstShard) {
this.firstShard = firstShard;
}
}

private static HeaderInfo getHeader() throws IOException {
HeaderInfo result = null;

Expand Down Expand Up @@ -177,7 +181,7 @@ public int compare(Contig o1, Contig o2) {
SAMRecord record = recordIterator.next();
final int alignmentStart = record.getAlignmentStart();
if (firstShard == null && alignmentStart > firstContig.start && alignmentStart < firstContig.end) {
firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart);
firstShard = shardFromAlignmentStart(firstContig.referenceName, alignmentStart, options.getLociPerWritingShard());
LOG.info("Determined first shard to be " + firstShard);
result = new HeaderInfo(header, firstShard);
}
Expand Down Expand Up @@ -226,7 +230,9 @@ public static class KeyReadsFn extends DoFn<Read, KV<Contig,Read>> {
/**
 * Keys a single input read by the writing shard it falls into and emits it.
 *
 * <p>The shard width is taken from the pipeline options
 * ({@code ShardedBAMWritingOptions.getLociPerWritingShard()}), and exactly one
 * {@code KV<Contig, Read>} is emitted per input read. Emitting the read a second
 * time under a key computed with the fixed default width would duplicate every
 * read downstream and could place the copy in a different shard.
 *
 * @param c process context supplying the input read, the pipeline options and
 *     the output receiver
 * @throws Exception if computing the shard key or emitting the element fails
 */
public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c)
    throws Exception {
  final Read read = c.element();
  // Resolve the configured shard width once, then key the read with it.
  final long lociPerShard =
      c.getPipelineOptions().as(ShardedBAMWritingOptions.class).getLociPerWritingShard();
  c.output(KV.of(shardKeyForRead(read, lociPerShard), read));
}
}

Expand Down Expand Up @@ -273,9 +279,7 @@ public HeaderInfo apply(String str) throws Exception {
}
});

static final long LOCI_PER_SHARD = 10000;

static Contig shardKeyForRead(Read read) {
static Contig shardKeyForRead(Read read, long lociPerShard) {
String referenceName = null;
Long alignmentStart = null;
if (read.getAlignment() != null) {
Expand All @@ -295,12 +299,12 @@ static Contig shardKeyForRead(Read read) {
referenceName = "*";
alignmentStart = new Long(0);
}
return shardFromAlignmentStart(referenceName, alignmentStart);
return shardFromAlignmentStart(referenceName, alignmentStart, lociPerShard);
}

static Contig shardFromAlignmentStart(String referenceName, long alignmentStart) {
final long shardStart = (alignmentStart / LOCI_PER_SHARD) * LOCI_PER_SHARD;
return new Contig(referenceName, shardStart, shardStart + LOCI_PER_SHARD);
static Contig shardFromAlignmentStart(String referenceName, long alignmentStart, long lociPerShard) {
final long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
return new Contig(referenceName, shardStart, shardStart + lociPerShard);
}

public static TupleTag<KV<Contig, Iterable<Read>>> SHARDED_READS_TAG = new TupleTag<>();
Expand Down Expand Up @@ -372,18 +376,33 @@ public void processElement(DoFn<KV<Contig, Iterable<Read>>, String>.ProcessConte
LOG.info("Writing non-first shard " + shardContig);
}

final String writeResult = writeShard(headerInfo.header,
shardContig, reads,
c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
isFirstShard);
c.output(writeResult);
int numRetriesLeft = MAX_RETRIES_FOR_WRITING_A_SHARD;
boolean done = false;
do {
try {
final String writeResult = writeShard(headerInfo.header,
shardContig, reads,
c.getPipelineOptions().as(ShardedBAMWritingOptions.class),
isFirstShard);
c.output(writeResult);
done = true;
} catch (IOException iox) {
LOG.warning("Write shard failed for " + shardContig + ": " + iox.getMessage());
if (--numRetriesLeft <= 0) {
LOG.warning("No more retries - failing the task for " + shardContig);
throw iox;
}
}
} while (!done);
LOG.info("Finished writing " + shardContig);
}

String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads,
ShardedBAMWritingOptions options, boolean isFirstShard) throws IOException {
final String outputFileName = options.getOutput();
final String shardName = outputFileName + "-" + shardContig;
final String shardName = outputFileName + "-" + shardContig.referenceName
+ ":" + String.format("%012d", shardContig.start) + "-" +
String.format("%012d", shardContig.end);
LOG.info("Writing shard file " + shardName);
final OutputStream outputStream =
Channels.newOutputStream(
Expand All @@ -396,9 +415,14 @@ String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads
final BAMBlockWriter bw = new BAMBlockWriter(new TruncatedOutputStream(
outputStream, BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length),
null /*file*/);
bw.setSortOrder(header.getSortOrder(), true /*presorted*/);
// If reads are unsorted then we do not care about their order;
// otherwise we need to sort them as we write.
final boolean treatReadsAsPresorted =
header.getSortOrder() == SAMFileHeader.SortOrder.unsorted;
bw.setSortOrder(header.getSortOrder(), treatReadsAsPresorted);
bw.setHeader(header);
if (isFirstShard) {
LOG.info("First shard - writing header to " + shardName);
bw.writeHeader(header);
}
for (Read read : reads) {
Expand Down Expand Up @@ -438,11 +462,6 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
.build()
.objects();

final GcsPath destPath = GcsPath.fromUri(dest);

StorageObject destination = new StorageObject()
.setContentType("application/octet-stream");

ArrayList<String> sortedShardsNames = Lists.newArrayList(shards);
Collections.sort(sortedShardsNames);

Expand All @@ -455,22 +474,65 @@ static String combineShards(ShardedBAMWritingOptions options, String dest,
os.write(BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK);
os.close();
sortedShardsNames.add(eofFileName);
// list of files to concatenate
ArrayList<SourceObjects> sourceObjects = new ArrayList<SourceObjects>();
for (String shard : sortedShardsNames) {
final GcsPath shardPath = GcsPath.fromUri(shard);
LOG.info("Adding object " + shardPath);
sourceObjects.add( new SourceObjects().setName(shardPath.getObject()) );

int stageNumber = 0;
while (sortedShardsNames.size() > 32) {
LOG.info("Have " + sortedShardsNames.size() +
" shards: must combine in groups 32");
final ArrayList<String> combinedShards = Lists.newArrayList();
for (int idx = 0; idx < sortedShardsNames.size(); idx += 32) {
final int endIdx = Math.min(idx + 32, sortedShardsNames.size());
final List<String> combinableShards = sortedShardsNames.subList(
idx, endIdx);
final String intermediateCombineResultName = dest + "-" +
String.format("%02d",stageNumber) + "-" +
String.format("%02d",idx);
final String combineResult = composeAndCleanupShards(storage,
combinableShards, intermediateCombineResultName);
combinedShards.add(combineResult);
}
sortedShardsNames = combinedShards;
stageNumber++;
}

final ComposeRequest composeRequest = new ComposeRequest()
.setDestination(destination)
.setSourceObjects(sourceObjects);
final Compose compose = storage.compose(
destPath.getBucket(), destPath.getObject(), composeRequest);
final String combineResult = compose.execute().toString();
LOG.info("Combine result is " + combineResult);

LOG.info("Combining a final group of " + sortedShardsNames.size() + " shards");
final String combineResult = composeAndCleanupShards(storage,
sortedShardsNames, dest);

return combineResult;
}
}

/**
 * Composes the given shard objects into a single object at {@code dest} and
 * then deletes the shard objects.
 *
 * <p>Only the object name of each shard is passed to the compose and delete
 * requests; the bucket is always the one parsed from {@code dest}. Shards are
 * composed in the order in which {@code shardNames} lists them, and are deleted
 * only after the compose request has returned.
 *
 * @param storage storage objects client used to issue the compose and delete requests
 * @param shardNames URIs of the shard objects to combine, in concatenation order
 * @param dest URI of the object to create
 * @return the path of the composed object, as reported by the compose result
 * @throws IOException if the compose request or any delete request fails
 */
static String composeAndCleanupShards(Storage.Objects storage,
    List<String> shardNames, String dest) throws IOException {
  LOG.info("Combining shards into " + dest);

  final GcsPath target = GcsPath.fromUri(dest);
  final String bucket = target.getBucket();

  // Gather the object names of all shards, preserving the caller's order.
  final List<SourceObjects> sources = new ArrayList<SourceObjects>();
  for (String shardUri : shardNames) {
    final GcsPath shardPath = GcsPath.fromUri(shardUri);
    LOG.info("Adding shard " + shardPath);
    sources.add(new SourceObjects().setName(shardPath.getObject()));
  }

  final ComposeRequest request = new ComposeRequest()
      .setDestination(new StorageObject().setContentType("application/octet-stream"))
      .setSourceObjects(sources);
  final StorageObject composed =
      storage.compose(bucket, target.getObject(), request).execute();
  final String combineResult = GcsPath.fromObject(composed).toString();
  LOG.info("Combine result is " + combineResult);

  // The shards are now part of the composed object, so remove the originals.
  for (SourceObjects source : sources) {
    final String shardToDelete = source.getName();
    LOG.info("Cleaning up shard " + shardToDelete);
    storage.delete(bucket, shardToDelete).execute();
  }

  return combineResult;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,18 @@ public void write(byte[] data, int offset, int length) throws IOException {
// We have more than bytesToTruncate to write, so clear the buffer
// completely, and write all but bytesToTruncate directly to the stream.
os.write(buf, 0, count);
os.write(data, offset, length - bytesToTruncate);
System.arraycopy(data, offset + length - bytesToTruncate, buf, 0, bytesToTruncate);
final int bytesToWriteDrirectly = length - bytesToTruncate;
os.write(data, offset, bytesToWriteDrirectly);
System.arraycopy(data, offset + bytesToWriteDrirectly, buf, 0, bytesToTruncate);
count = bytesToTruncate;
} else {
// Need this many of the current bytes to stay in the buffer to ensure we
// have at least bytesToTruncate.
int keepInBuffer = bytesToTruncate - length;
final int keepInBuffer = bytesToTruncate - length;
// Write the rest to the stream.
os.write(buf, 0, count - keepInBuffer);
System.arraycopy(buf, count - keepInBuffer, buf, 0, keepInBuffer);
final int bytesToDumpFromBuffer = count - keepInBuffer;
os.write(buf, 0, bytesToDumpFromBuffer);
System.arraycopy(buf, bytesToDumpFromBuffer, buf, 0, keepInBuffer);
System.arraycopy(data, offset, buf, keepInBuffer, length);
count = bytesToTruncate;
}
Expand All @@ -74,9 +76,10 @@ public void close() throws IOException {
}

/**
 * Writes all buffered bytes except the trailing {@code bytesToTruncate} to the
 * underlying stream and keeps that tail at the front of the buffer.
 *
 * <p>Does nothing when the buffer holds {@code bytesToTruncate} bytes or fewer.
 * The flushable prefix is written exactly once; writing it again, or shifting
 * the buffer a second time, would duplicate output and corrupt the retained tail.
 *
 * @throws IOException if writing to the underlying stream fails
 */
private void flushBuffer() throws IOException {
  if (count > bytesToTruncate) {
    final int bytesWeCanSafelyWrite = count - bytesToTruncate;
    os.write(buf, 0, bytesWeCanSafelyWrite);
    // Slide the retained tail down to the start of the buffer.
    System.arraycopy(buf, bytesWeCanSafelyWrite, buf, 0, bytesToTruncate);
    count = bytesToTruncate;
  }
}
Expand Down