@@ -28,11 +28,14 @@
import com.google.cloud.dataflow.sdk.options.Default;
import com.google.cloud.dataflow.sdk.options.Description;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.Sum;
+import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
import com.google.cloud.dataflow.sdk.transforms.View;
import com.google.cloud.dataflow.sdk.util.GcsUtil;
import com.google.cloud.dataflow.sdk.util.Transport;
@@ -238,13 +241,33 @@ public static PCollection<KV<Contig, Iterable<Read>>> shard(PCollection<Read> re
}

public static class KeyReadsFn extends DoFn<Read, KV<Contig,Read>> {
+    private Aggregator<Integer, Integer> readCountAggregator;
+    private Aggregator<Integer, Integer> unmappedReadCountAggregator;
+    private long lociPerShard;
+
+    public KeyReadsFn() {
+      readCountAggregator = createAggregator("Keyed reads", new SumIntegerFn());
+      unmappedReadCountAggregator = createAggregator("Keyed unmapped reads", new SumIntegerFn());
+    }
+
+    @Override
+    public void startBundle(Context c) {
+      lociPerShard = c.getPipelineOptions()
+          .as(ShardedBAMWritingOptions.class)
+          .getLociPerWritingShard();
+    }
@Override
public void processElement(DoFn<Read, KV<Contig, Read>>.ProcessContext c)
throws Exception {
final Read read = c.element();
-      c.output(KV.of(shardKeyForRead(read,
-          c.getPipelineOptions().as(ShardedBAMWritingOptions.class).getLociPerWritingShard()),
-          read));
+      c.output(
+          KV.of(
+              shardKeyForRead(read, lociPerShard),
+              read));
+      readCountAggregator.addValue(1);
+      if (isUnmapped(read)) {
+        unmappedReadCountAggregator.addValue(1);
+      }
}
}
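
The pattern introduced here — instantiate Aggregators in the DoFn constructor, bump them in processElement — is how the pre-Beam Dataflow SDK surfaced custom counters in the monitoring UI. Note also that lociPerShard is now read once per bundle in startBundle instead of once per element. A minimal sketch of the counting pattern, with illustrative names (Dataflow SDK 1.x API):

```java
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;

// Illustrative only: a pass-through DoFn that counts its elements.
class CountingFn<T> extends DoFn<T, T> {
  private final Aggregator<Integer, Integer> seen;

  public CountingFn() {
    // Must be created at construction time; shows up as a running sum.
    seen = createAggregator("Elements seen", new SumIntegerFn());
  }

  @Override
  public void processElement(ProcessContext c) {
    seen.addValue(1);
    c.output(c.element());
  }
}
```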

@@ -291,6 +314,17 @@ public HeaderInfo apply(String str) throws Exception {
}
});

+  static boolean isUnmapped(Read read) {
+    if (read.getAlignment() == null || read.getAlignment().getPosition() == null) {
+      return true;
+    }
+    final String reference = read.getAlignment().getPosition().getReferenceName();
+    if (reference == null || reference.isEmpty() || reference.equals("*")) {
+      return true;
+    }
+    return false;
+  }
+
static Contig shardKeyForRead(Read read, long lociPerShard) {
String referenceName = null;
Long alignmentStart = null;
@@ -301,14 +335,17 @@ static Contig shardKeyForRead(Read read, long lociPerShard) {
}
}
// If this read is unmapped but its mate is mapped, group them together
-    if (referenceName == null || alignmentStart == null) {
+    if (referenceName == null || referenceName.isEmpty() ||
+        referenceName.equals("*") || alignmentStart == null) {
if (read.getNextMatePosition() != null) {
referenceName = read.getNextMatePosition().getReferenceName();
alignmentStart = read.getNextMatePosition().getPosition();
}
}
-    if (referenceName == null || alignmentStart == null) {
+    if (referenceName == null || referenceName.isEmpty()) {
referenceName = "*";
+    }
+    if (alignmentStart == null) {
alignmentStart = new Long(0);
}
return shardFromAlignmentStart(referenceName, alignmentStart, lociPerShard);
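
To summarize the keying logic after this change: a mapped read is bucketed by its alignment start; an unmapped read with a mapped mate inherits the mate's coordinates so the pair lands in the same shard; anything else falls into a catch-all "*" shard at position 0 ("*" is the SAM placeholder for "no reference", which is why isUnmapped above treats it as unmapped too). A small worked example of the bucketing, assuming shardFromAlignmentStart (not shown in this diff) floors the start to a multiple of lociPerShard:

```java
public class ShardKeyExample {
  public static void main(String[] args) {
    // Assumption: shardFromAlignmentStart buckets by integer division.
    long lociPerShard = 10_000L;
    long alignmentStart = 123_456L;
    long shardStart = (alignmentStart / lociPerShard) * lociPerShard;
    System.out.println(shardStart); // 120000: reads in [120000, 130000) share a key
  }
}
```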
@@ -363,9 +400,13 @@ public static PCollection<String> write(PCollection<KV<Contig, Iterable<Read>>>
public static class WriteShardFn extends DoFn<KV<Contig, Iterable<Read>>, String> {
final PCollectionView<HeaderInfo> headerView;
Storage.Objects storage;
+    Aggregator<Integer, Integer> readCountAggregator;
+    Aggregator<Integer, Integer> unmappedReadCountAggregator;

public WriteShardFn(final PCollectionView<HeaderInfo> headerView) {
this.headerView = headerView;
+      readCountAggregator = createAggregator("Written reads", new SumIntegerFn());
+      unmappedReadCountAggregator = createAggregator("Written unmapped reads", new SumIntegerFn());
}

@Override
@@ -422,6 +463,7 @@ String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads
.create(GcsPath.fromUri(shardName),
"application/octet-stream"));
int count = 0;
+      int countUnmapped = 0;
// Use a TruncatedOutputStream to avoid writing the empty gzip block that
// indicates EOF.
final BAMBlockWriter bw = new BAMBlockWriter(new TruncatedOutputStream(
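
For context on the TruncatedOutputStream used above: a complete BAM file ends with one fixed empty BGZF block that readers treat as an EOF marker. Because these shards are later concatenated into a single BAM, each shard writer must drop its own trailing marker, or readers would stop at the first shard boundary. A sketch, assuming htsjdk's constant for that block:

```java
import htsjdk.samtools.util.BlockCompressedStreamConstants;

public class BamEofMarker {
  public static void main(String[] args) {
    // The BGZF EOF marker every complete BAM ends with; shard writers
    // truncate it so only the final concatenated file carries one.
    System.out.println(
        BlockCompressedStreamConstants.EMPTY_GZIP_BLOCK.length + " bytes"); // 28 bytes
  }
}
```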
@@ -439,11 +481,20 @@ String writeShard(SAMFileHeader header, Contig shardContig, Iterable<Read> reads
}
for (Read read : reads) {
SAMRecord samRecord = ReadUtils.makeSAMRecord(read, header);
+        if (samRecord.getReadUnmappedFlag()) {
+          if (!samRecord.getMateUnmappedFlag()) {
+            samRecord.setReferenceName(samRecord.getMateReferenceName());
+            samRecord.setAlignmentStart(samRecord.getMateAlignmentStart());
+          }
+          countUnmapped++;
+        }
bw.addAlignment(samRecord);
count++;
}
bw.close();
LOG.info("Wrote " + count + " reads into " + shardName);
LOG.info("Wrote " + count + " reads, " + countUnmapped + " umapped, into " + shardName);
readCountAggregator.addValue(count);
unmappedReadCountAggregator.addValue(countUnmapped);
return shardName;
}
}
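
The unmapped-read handling added to writeShard follows the SAM convention that an unmapped read with a mapped mate is stored at the mate's coordinates, keeping the pair adjacent in a coordinate-sorted BAM. A self-contained sketch of that convention using htsjdk (values illustrative):

```java
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;

public class PlaceUnmappedMate {
  public static void main(String[] args) {
    SAMRecord rec = new SAMRecord(new SAMFileHeader());
    rec.setReadUnmappedFlag(true);
    rec.setMateUnmappedFlag(false);
    rec.setMateReferenceName("chr1");
    rec.setMateAlignmentStart(123456);
    // Same move as in writeShard: park the unmapped read at its mate's position.
    if (rec.getReadUnmappedFlag() && !rec.getMateUnmappedFlag()) {
      rec.setReferenceName(rec.getMateReferenceName());
      rec.setAlignmentStart(rec.getMateAlignmentStart());
    }
    System.out.println(rec.getReferenceName() + ":" + rec.getAlignmentStart());
  }
}
```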
@@ -20,11 +20,13 @@
import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.coders.SerializableCoder;
import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.transforms.ParDo;
import com.google.cloud.dataflow.sdk.transforms.View;
+import com.google.cloud.dataflow.sdk.transforms.Sum.SumIntegerFn;
import com.google.cloud.dataflow.sdk.util.Transport;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PCollectionTuple;
@@ -52,10 +54,20 @@ public static class ReadFn extends DoFn<BAMShard, Read> {
GenomicsFactory.OfflineAuth auth;
Storage.Objects storage;
ReaderOptions options;
+    Aggregator<Integer, Integer> recordCountAggregator;
+    Aggregator<Integer, Integer> readCountAggregator;
+    Aggregator<Integer, Integer> skippedStartCountAggregator;
+    Aggregator<Integer, Integer> skippedEndCountAggregator;
+    Aggregator<Integer, Integer> skippedRefMismatchAggregator;

public ReadFn(GenomicsFactory.OfflineAuth auth, ReaderOptions options) {
this.auth = auth;
this.options = options;
+      recordCountAggregator = createAggregator("Processed records", new SumIntegerFn());
+      readCountAggregator = createAggregator("Reads generated", new SumIntegerFn());
+      skippedStartCountAggregator = createAggregator("Skipped start", new SumIntegerFn());
+      skippedEndCountAggregator = createAggregator("Skipped end", new SumIntegerFn());
+      skippedRefMismatchAggregator = createAggregator("Ref mismatch", new SumIntegerFn());
}

@Override
@@ -65,8 +77,13 @@ public void startBundle(DoFn<BAMShard, Read>.Context c) throws IOException {

@Override
public void processElement(ProcessContext c) throws java.lang.Exception {
-      (new Reader(storage, options, c.element(), c))
-          .process();
+      final Reader reader = new Reader(storage, options, c.element(), c);
+      reader.process();
+      recordCountAggregator.addValue(reader.recordsProcessed);
+      skippedStartCountAggregator.addValue(reader.recordsBeforeStart);
+      skippedEndCountAggregator.addValue(reader.recordsAfterEnd);
+      skippedRefMismatchAggregator.addValue(reader.mismatchedSequence);
+      readCountAggregator.addValue(reader.readsGenerated);
}
}
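
Folding the Reader's totals into the Aggregators only after process() returns means each counter gets a single addValue call per shard rather than one per record, which presumably keeps counter updates out of the hot per-record loop.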

@@ -56,10 +56,11 @@ enum Filter {

Filter filter;

-  int recordsBeforeStart = 0;
-  int recordsAfterEnd = 0;
-  int mismatchedSequence = 0;
-  int recordsProcessed = 0;
+  public int recordsBeforeStart = 0;
+  public int recordsAfterEnd = 0;
+  public int mismatchedSequence = 0;
+  public int recordsProcessed = 0;
+  public int readsGenerated = 0;

public Reader(Objects storageClient, ReaderOptions options, BAMShard shard, DoFn<BAMShard, Read>.ProcessContext c) {
super();
@@ -160,6 +161,7 @@ boolean passesFilter(SAMRecord record) {
}

void processRecord(SAMRecord record) {
+    recordsProcessed++;
if (!passesFilter(record)) {
mismatchedSequence++;
return;
@@ -168,17 +170,17 @@ void processRecord(SAMRecord record) {
recordsBeforeStart++;
return;
}
-    if (record.getAlignmentStart() >= shard.contig.end) {
+    if (record.getAlignmentStart() > shard.contig.end) {
recordsAfterEnd++;
return;
}
c.output(ReadUtils.makeRead(record));
-    recordsProcessed++;
+    readsGenerated++;
}

void dumpStats() {
timer.stop();
LOG.info("Processed " + recordsProcessed +
LOG.info("Processed " + recordsProcessed + " outputted " + readsGenerated +
" in " + timer +
". Speed: " + (recordsProcessed*1000)/timer.elapsed(TimeUnit.MILLISECONDS) + " reads/sec"
+ ", filtered out by reference and mapping " + mismatchedSequence
@@ -219,7 +221,7 @@ public static Iterable<Read> readSequentiallyForTesting(Objects storageClient,
recordsBeforeStart++;
continue;
}
-      if (record.getAlignmentStart() >= contig.end) {
+      if (record.getAlignmentStart() > contig.end) {
recordsAfterEnd++;
continue;
}