
Commit ccca14f

Merge 9713727 into 2c93584
2 parents 2c93584 + 9713727
laserson committed Oct 7, 2015
Showing 19 changed files with 392 additions and 120 deletions.
@@ -3,16 +3,16 @@
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.SerializableFunction;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.reference.ReferenceSequenceFile;
import htsjdk.samtools.reference.ReferenceSequenceFileFactory;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceTwoBitSource;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.dataflow.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;

@@ -35,7 +35,13 @@ public class ReferenceDataflowSource implements ReferenceSource, Serializable {
public ReferenceDataflowSource(final PipelineOptions pipelineOptions, final String referenceURL,
final SerializableFunction<GATKRead, SimpleInterval> referenceWindowFunction) {
Utils.nonNull(referenceWindowFunction);
if (isFasta(referenceURL)) {
if (referenceURL.endsWith(".2bit")) {
try {
referenceSource = new ReferenceTwoBitSource(pipelineOptions, referenceURL);
} catch (IOException e) {
throw new UserException("Failed to create a ReferenceTwoBitSource object: " + e.getMessage());
}
} else if (isFasta(referenceURL)) {
if (BucketUtils.isHadoopUrl(referenceURL)) {
referenceSource = new ReferenceHadoopSource(referenceURL);
} else {
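Not part of this diff: a minimal sketch of how a caller exercises the new extension-based dispatch. The reference paths are hypothetical, and RefWindowFunctions.IDENTITY_FUNCTION is assumed as the window function (see the RefWindowFunctions reference later in this commit).

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory;
import org.broadinstitute.hellbender.engine.dataflow.datasources.ReferenceDataflowSource;
import org.broadinstitute.hellbender.engine.dataflow.datasources.RefWindowFunctions;

public class ReferenceDataflowSourceSketch {
    public static void main(String[] args) {
        PipelineOptions options = PipelineOptionsFactory.create();
        // A ".2bit" URL now routes to ReferenceTwoBitSource (hypothetical path).
        ReferenceDataflowSource twoBit = new ReferenceDataflowSource(
                options, "ref.2bit", RefWindowFunctions.IDENTITY_FUNCTION);
        // FASTA URLs still take the existing isFasta() branch (hypothetical path).
        ReferenceDataflowSource fasta = new ReferenceDataflowSource(
                options, "ref.fasta", RefWindowFunctions.IDENTITY_FUNCTION);
    }
}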
@@ -10,7 +10,7 @@
/**
* Internal interface to load a reference sequence.
*/
interface ReferenceSource {
public interface ReferenceSource {
ReferenceBases getReferenceBases(PipelineOptions pipelineOptions, SimpleInterval interval) throws IOException;
SAMSequenceDictionary getReferenceSequenceDictionary(SAMSequenceDictionary optReadSequenceDictionaryToMatch) throws IOException;
}
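For orientation, a hypothetical in-memory implementation of the now-public interface (a sketch, not part of this commit; the class name and fields are illustrative):

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import htsjdk.samtools.SAMSequenceDictionary;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;

class InMemoryReferenceSource implements ReferenceSource {
    private final ReferenceBases bases;
    private final SAMSequenceDictionary dictionary;

    InMemoryReferenceSource(final ReferenceBases bases, final SAMSequenceDictionary dictionary) {
        this.bases = bases;
        this.dictionary = dictionary;
    }

    @Override
    public ReferenceBases getReferenceBases(final PipelineOptions pipelineOptions, final SimpleInterval interval) {
        // Serve the requested window out of the bases held in memory.
        return bases.getSubset(interval);
    }

    @Override
    public SAMSequenceDictionary getReferenceSequenceDictionary(final SAMSequenceDictionary optReadSequenceDictionaryToMatch) {
        return dictionary;
    }
}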
@@ -1,22 +1,16 @@
package org.broadinstitute.hellbender.engine.spark;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.broadinstitute.hellbender.engine.dataflow.datasources.ReadContextData;
import org.broadinstitute.hellbender.engine.dataflow.datasources.ReferenceDataflowSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;
import org.broadinstitute.hellbender.utils.variant.Variant;
import scala.Tuple2;

import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/**
* AddContextDataToRead pairs reference bases and overlapping variants with each GATKRead in the RDD input.
* The variants are obtained from a local file (later a GCS Bucket). The reference bases come from the Google Genomics API.
@@ -32,81 +26,20 @@
public class AddContextDataToReadSpark {
public static JavaPairRDD<GATKRead, ReadContextData> add(
final JavaRDD<GATKRead> reads, final ReferenceDataflowSource referenceDataflowSource,
final JavaRDD<Variant> variants) {

// This transform can currently only handle mapped reads
final JavaRDD<Variant> variants, final JoinStrategy joinStrategy) {
// TODO: this static method should not be filtering the unmapped reads. To be addressed in another issue.
JavaRDD<GATKRead> mappedReads = reads.filter(read -> ReadFilterLibrary.MAPPED.test(read));

// Join Reads and Variants, Reads and ReferenceBases
JavaPairRDD<GATKRead, Iterable<Variant>> readiVariants = JoinReadsWithVariants.join(mappedReads, variants);
JavaPairRDD<GATKRead, ReferenceBases> readRefBases = JoinReadsWithRefBases.addBases(referenceDataflowSource, mappedReads);

// For testing we want to know that the reads from the KVs coming back from JoinReadsWithVariants.Join
// and JoinReadsWithRefBases.Pair are the same reads from "reads".
boolean assertsEnabled = false;
assert assertsEnabled = true; // Intentional side-effect!!!
// Now assertsEnabled is set to the correct value
if (assertsEnabled) {
assertSameReads(mappedReads, readRefBases, readiVariants);
}

JavaPairRDD<GATKRead, Tuple2<Iterable<Iterable<Variant>>, Iterable<ReferenceBases>>> cogroup = readiVariants.cogroup(readRefBases);
return cogroup.mapToPair(in -> {
ReadContextData readContextData = null;
try {
List<Variant> lVariants = makeListFromIterableIterable(in._2()._1());

ReferenceBases refBases = Iterables.getOnlyElement(in._2()._2());
readContextData = new ReadContextData(refBases, lVariants);
} catch (NoSuchElementException e) {
throw new GATKException.ShouldNeverReachHereException(e);
}
return new Tuple2<>(in._1(), readContextData);
});
}

private static <T> List<T> makeListFromIterableIterable(Iterable<Iterable<T>> iterables) {
List<Iterable<T>> listIterableT = Lists.newArrayList(iterables);
List<T> listT = Lists.newArrayList();
if (!listIterableT.isEmpty()) {
final Iterable<T> iterableT = Iterables.getOnlyElement(iterables);
// It's possible for the iterableT to contain only a null T, we don't
// want to include that.
final Iterator<T> iterator = iterableT.iterator();
if (iterator.hasNext()) {
final T next = iterator.next();
if (next != null) {
listT = Lists.newArrayList(iterableT);
}
}
// Join Reads and Variants
JavaPairRDD<GATKRead, Iterable<Variant>> withVariants = JoinReadsWithVariants.join(mappedReads, variants);
// Join Reads with ReferenceBases
JavaPairRDD<GATKRead, Tuple2<Iterable<Variant>, ReferenceBases>> withVariantsWithRef;
if (joinStrategy.equals(JoinStrategy.BROADCAST)) {
withVariantsWithRef = BroadcastJoinReadsWithRefBases.addBases(referenceDataflowSource, withVariants);
} else if (joinStrategy.equals(JoinStrategy.SHUFFLE)) {
withVariantsWithRef = ShuffleJoinReadsWithRefBases.addBases(referenceDataflowSource, withVariants);
} else {
throw new UserException("Unknown JoinStrategy");
}
return listT;
return withVariantsWithRef.mapToPair(in -> new Tuple2<>(in._1(), new ReadContextData(in._2()._2(), in._2()._1())));
}


private static void assertSameReads(final JavaRDD<GATKRead> reads,
final JavaPairRDD<GATKRead, ReferenceBases> readRefBases,
final JavaPairRDD<GATKRead, Iterable<Variant>> readiVariants) {

// We want to verify that the reads are the same for each collection and that there are no duplicates
// in any collection.

// Collect all reads (with potential duplicates) in allReads. We expect there to be 3x the unique reads.
// Verify that there are 3x the distinct reads and the reads count for each collection match
// the distinct reads count.
// We should also check that the reference bases and variants are correctly paired with the reads. See
// issue (#873).
JavaRDD<GATKRead> refBasesReads = readRefBases.keys();
JavaRDD<GATKRead> variantsReads = readiVariants.keys();
JavaRDD<GATKRead> allReads = reads.union(refBasesReads).union(variantsReads);
long allReadsCount = allReads.count();
long distinctReads = allReads.distinct().count();

assert 3 * distinctReads == allReadsCount;
assert distinctReads == reads.count();
assert distinctReads == refBasesReads.count();
assert distinctReads == variantsReads.count();
}

}
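A hedged call-site sketch for the reworked entry point; reads, variants, and referenceDataflowSource are assumed to be built earlier in the pipeline.

// The join strategy is now an explicit argument rather than an implementation detail.
JavaPairRDD<GATKRead, ReadContextData> readsWithContext =
        AddContextDataToReadSpark.add(reads, referenceDataflowSource, variants, JoinStrategy.BROADCAST);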

@@ -0,0 +1,59 @@
package org.broadinstitute.hellbender.engine.spark;

import com.google.cloud.dataflow.sdk.transforms.SerializableFunction;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.engine.dataflow.datasources.ReferenceDataflowSource;
import org.broadinstitute.hellbender.engine.spark.datasources.ReferenceTwoBitSource;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;
import scala.Tuple2;

/**
* Joins an RDD of GATKReads to reference data using a broadcast strategy.
*
* The ReferenceDataflowSource is broadcast using Spark's Broadcast variable mechanism. The reads are then mapped
* over and a reference query is executed on each read. This makes sense for ReferenceDataflowSource implementations
* that contain the reference data in memory (e.g., ReferenceTwoBitSource), but will likely be much slower for
* implementations that have to query other resources for the reference sequences.
*/
public class BroadcastJoinReadsWithRefBases {

/**
* Joins each read of an RDD<GATKRead> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param reads The reads for which to extract reference sequence information
* @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object
*/
public static JavaPairRDD<GATKRead, ReferenceBases> addBases(final ReferenceDataflowSource referenceDataflowSource,
final JavaRDD<GATKRead> reads) {
JavaSparkContext ctx = new JavaSparkContext(reads.context());
Broadcast<ReferenceDataflowSource> bReferenceSource = ctx.broadcast(referenceDataflowSource);
return reads.mapToPair(read -> {
SimpleInterval interval = bReferenceSource.getValue().getReferenceWindowFunction().apply(read);
return new Tuple2<>(read, bReferenceSource.getValue().getReferenceBases(null, interval));
});
}

/**
* Joins each read of an RDD<GATKRead, T> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param keyedByRead The read-keyed RDD for which to extract reference sequence information
* @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object and the value
*/
public static <T> JavaPairRDD<GATKRead, Tuple2<T, ReferenceBases>> addBases(final ReferenceDataflowSource referenceDataflowSource,
final JavaPairRDD<GATKRead, T> keyedByRead) {
JavaSparkContext ctx = new JavaSparkContext(keyedByRead.context());
Broadcast<ReferenceDataflowSource> bReferenceSource = ctx.broadcast(referenceDataflowSource);
return keyedByRead.mapToPair(pair -> {
SimpleInterval interval = bReferenceSource.getValue().getReferenceWindowFunction().apply(pair._1());
return new Tuple2<>(pair._1(), new Tuple2<>(pair._2(), bReferenceSource.getValue().getReferenceBases(null, interval)));
});
}
}
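A minimal usage sketch, assuming reads is an existing JavaRDD<GATKRead> and the reference source holds its data in memory (e.g., a 2bit reference), which is when broadcasting pays off:

// Broadcast the reference once; each task then queries it locally per read.
JavaPairRDD<GATKRead, ReferenceBases> readsWithBases =
        BroadcastJoinReadsWithRefBases.addBases(referenceDataflowSource, reads);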
@@ -65,23 +65,20 @@ public static JavaPairRDD<GATKRead, Iterable<Variant>> join(

JavaPairRDD<VariantShard, Variant> variantsWShards = pairVariantsWithVariantShards(variants);

// generate read-variant pairs; however, the reads are replicated for each overlapping pair
JavaPairRDD<GATKRead, Variant> allPairs = pairReadsWithVariants(readsWShards, variantsWShards);


final JavaPairRDD<GATKRead, ? extends Iterable<Variant>> gatkReadHashSetJavaPairRDD = allPairs.aggregateByKey(new HashSet<>(), (Set<Variant> vs, Variant v) -> {
// We group together all variants for each unique GATKRead. As we combine through the variants,
// they are added to per-key HashSets that get continually merged together.
return allPairs.aggregateByKey(new HashSet<>(), (vs, v) -> {
if (v != null) { // pairReadsWithVariants can produce null variant
vs.add(v);
((HashSet<Variant>) vs).add(v);
}
return vs;
}, (Set<Variant> vs1, Set<Variant> vs2) -> {
vs1.addAll(vs2);
}, (vs1, vs2) -> {
((HashSet<Variant>) vs1).addAll((HashSet<Variant>) vs2);
return vs1;
});

@SuppressWarnings("unchecked")
final JavaPairRDD<GATKRead, Iterable<Variant>> gatkReadIterableJavaPairRDD = (JavaPairRDD<GATKRead, Iterable<Variant>>)gatkReadHashSetJavaPairRDD;
return gatkReadIterableJavaPairRDD;

}

private static JavaPairRDD<VariantShard, GATKRead> pairReadsWithVariantShards(final JavaRDD<GATKRead> reads) {
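The aggregateByKey refactor above follows the standard Spark pattern for collecting a per-key set; a self-contained sketch of the same pattern on plain strings (all names and data illustrative, not from this commit):

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class AggregateByKeySketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "aggregateByKeySketch");
        JavaPairRDD<String, String> pairs = ctx.parallelizePairs(Arrays.asList(
                new Tuple2<>("read1", "variantA"),
                new Tuple2<>("read1", "variantB"),
                new Tuple2<>("read2", "variantA")));
        // Seed each key with an empty set, add values within a partition,
        // then union the partial sets across partitions.
        JavaPairRDD<String, Set<String>> grouped = pairs.aggregateByKey(
                (Set<String>) new HashSet<String>(),
                (set, v) -> { set.add(v); return set; },
                (s1, s2) -> { s1.addAll(s2); return s1; });
        grouped.collect().forEach(t -> System.out.println(t._1() + " -> " + t._2()));
        ctx.stop();
    }
}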
@@ -0,0 +1,16 @@
package org.broadinstitute.hellbender.engine.spark;

/**
* Possible join strategies when using Spark
*/
public enum JoinStrategy {
/**
* Use a broadcast join strategy, where one side of the join is collected into memory and broadcast to all workers.
*/
BROADCAST,

/**
* Use a shuffle join strategy, where both sides of join are shuffled across the workers.
*/
SHUFFLE
}
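A small sketch of surfacing the enum as a tool option; the argument wiring (joinStrategyArg) is illustrative and not part of this commit.

// Map a user-supplied string onto the enum; default to BROADCAST when absent.
JoinStrategy strategy = (joinStrategyArg == null)
        ? JoinStrategy.BROADCAST
        : JoinStrategy.valueOf(joinStrategyArg.toUpperCase());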
@@ -51,9 +51,18 @@
* The reference bases paired with each read can be customized by passing in a reference window function
* inside the {@link ReferenceDataflowSource} argument to {@link #addBases}. See {@link org.broadinstitute.hellbender.engine.dataflow.datasources.RefWindowFunctions} for examples.
*/
public class JoinReadsWithRefBases {
public class ShuffleJoinReadsWithRefBases {

/**
* Joins each read of an RDD<GATKRead> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param reads The reads for which to extract reference sequence information
* @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object
*/
public static JavaPairRDD<GATKRead, ReferenceBases> addBases(final ReferenceDataflowSource referenceDataflowSource,
final JavaRDD<GATKRead> reads) {
// TODO: reimplement this method by calling out to the more complex version?
SerializableFunction<GATKRead, SimpleInterval> windowFunction = referenceDataflowSource.getReferenceWindowFunction();

JavaPairRDD<ReferenceShard, GATKRead> shardRead = reads.mapToPair(gatkRead -> {
@@ -80,4 +89,41 @@ public static JavaPairRDD<GATKRead, ReferenceBases> addBases(final ReferenceData
return out;
});
}
}

/**
* Joins each read of an RDD<GATKRead, T> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param keyedByRead The read-keyed RDD for which to extract reference sequence information
* @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object and the value
*/
public static <T> JavaPairRDD<GATKRead, Tuple2<T, ReferenceBases>> addBases(final ReferenceDataflowSource referenceDataflowSource,
final JavaPairRDD<GATKRead, T> keyedByRead) {
SerializableFunction<GATKRead, SimpleInterval> windowFunction = referenceDataflowSource.getReferenceWindowFunction();

JavaPairRDD<ReferenceShard, Tuple2<GATKRead, T>> shardRead = keyedByRead.mapToPair(pair -> {
ReferenceShard shard = ReferenceShard.getShardNumberFromInterval(windowFunction.apply(pair._1()));
return new Tuple2<>(shard, pair);
});

JavaPairRDD<ReferenceShard, Iterable<Tuple2<GATKRead, T>>> shardiRead = shardRead.groupByKey();

return shardiRead.flatMapToPair(in -> {
List<Tuple2<GATKRead, Tuple2<T, ReferenceBases>>> out = Lists.newArrayList();
Iterable<Tuple2<GATKRead, T>> iReads = in._2();

// Apply the reference window function to each read to produce a set of intervals representing
// the desired reference bases for each read.
final List<SimpleInterval> readWindows = StreamSupport.stream(iReads.spliterator(), false).map(pair -> windowFunction.apply(pair._1())).collect(Collectors.toList());

SimpleInterval interval = SimpleInterval.getSpanningInterval(readWindows);
// TODO: don't we need to support GCS PipelineOptions?
ReferenceBases bases = referenceDataflowSource.getReferenceBases(null, interval);
for (Tuple2<GATKRead, T> p : iReads) {
final ReferenceBases subset = bases.getSubset(windowFunction.apply(p._1()));
out.add(new Tuple2<>(p._1(), new Tuple2<>(p._2(), subset)));
}
return out;
});
}
}
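For symmetry with the broadcast variant, a hedged call-site sketch; the shuffle join is the fallback when the reference is too large to broadcast to every worker.

// Reads are shuffled by ReferenceShard instead of broadcasting the reference.
JavaPairRDD<GATKRead, ReferenceBases> readsWithBases =
        ShuffleJoinReadsWithRefBases.addBases(referenceDataflowSource, reads);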
@@ -0,0 +1,24 @@
package org.broadinstitute.hellbender.engine.spark.datasources;

import org.bdgenomics.utils.io.ByteArrayByteAccess;

/**
 * A version of org.bdgenomics.utils.io.ByteArrayByteAccess that makes no copy of the byte array used for initialization.
 * However, DirectFullByteArrayByteAccess.readFully can only return a reference to the full underlying byte array.
 * Therefore, callers must take care not to mutate the underlying data.
*/
class DirectFullByteArrayByteAccess extends ByteArrayByteAccess {
private static final long serialVersionUID = 1L;

DirectFullByteArrayByteAccess(byte[] bytes) {
super(bytes);
}

@Override
public byte[] readFully(long offset, int length) {
if ((offset != 0) || (length != this.length())) {
throw new IllegalArgumentException("readFully can only return a reference to the full underlying byte array");
}
return this.bytes();
}
}
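A short sketch of the contract readFully now enforces (same-package access assumed, since the class is package-private; the data values are illustrative):

byte[] data = new byte[]{'A', 'C', 'G', 'T'};
DirectFullByteArrayByteAccess access = new DirectFullByteArrayByteAccess(data);
// Full-range reads return the backing array itself; no copy is made.
byte[] full = access.readFully(0, data.length);
// Partial reads are rejected: access.readFully(1, 2) throws IllegalArgumentException.
// Because no copy is made, mutating 'data' also changes what 'full' observes.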
