[ADAM-528][ADAM-533] Adds new RegionJoin impl that is shuffle-based #534

Merged
merged 1 commit on Jan 9, 2015
@@ -27,7 +27,7 @@ import org.bdgenomics.adam.models.{ SequenceDictionary, ReferenceRegion }
import org.bdgenomics.adam.projections.Projection
import org.bdgenomics.adam.projections.AlignmentRecordField._
import org.bdgenomics.adam.rdd.ADAMContext._
-import org.bdgenomics.adam.rdd.RegionJoin
+import org.bdgenomics.adam.rdd.BroadcastRegionJoin
import org.bdgenomics.adam.rich.ReferenceMappingContext._
import org.bdgenomics.formats.avro.AlignmentRecord
import scala.io._
@@ -90,8 +90,8 @@ class CalculateDepth(protected val args: CalculateDepthArgs) extends ADAMSparkCo
val variantNames = vcf.collect().toMap

val joinedRDD: RDD[(ReferenceRegion, AlignmentRecord)] =
-if (args.cartesian) RegionJoin.cartesianFilter(variantPositions, mappedRDD)
-else RegionJoin.partitionAndJoin(sc, variantPositions, mappedRDD)
+if (args.cartesian) BroadcastRegionJoin.cartesianFilter(variantPositions, mappedRDD)
+else BroadcastRegionJoin.partitionAndJoin(sc, variantPositions, mappedRDD)

val depths: RDD[(ReferenceRegion, Int)] =
joinedRDD.map { case (region, record) => (region, 1) }.reduceByKey(_ + _).sortByKey()
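For reference, the join output above is then reduced to per-region read counts. Below is a minimal, self-contained sketch of that aggregation step on a local SparkContext; the tuple keys stand in for ReferenceRegion, and everything in this snippet is hypothetical illustration rather than part of the patch:

```scala
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.SparkContext._

object DepthCountSketch extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("depth-sketch"))

  // Hypothetical join output: one (region, read) pair per read overlapping a variant site.
  val joined = sc.parallelize(Seq(
    (("chr1", 100L, 101L), "read1"),
    (("chr1", 100L, 101L), "read2"),
    (("chr1", 200L, 201L), "read3")))

  // Same pattern as CalculateDepth: tag each pair with 1, then sum per region.
  val depths = joined.map { case (region, _) => (region, 1) }.reduceByKey(_ + _).sortByKey()

  depths.collect().foreach(println) // ((chr1,100,101),2) then ((chr1,200,201),1)
  sc.stop()
}
```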
@@ -165,7 +165,7 @@ object ReferencePosition {
/**
* The UNMAPPED value is a convenience value, which can be used to indicate a position
* which is not located anywhere along the reference sequences (see, e.g. its use in
-* GenomicRegionPartitioner).
+* GenomicPositionPartitioner).
*/
val UNMAPPED = new ReferencePosition("", -1)

@@ -31,10 +31,10 @@ import scala.reflect.ClassTag
* Different implementations will have different performance characteristics -- and new implementations
* will likely be added in the future, see the notes to each individual method for more details.
*/
-object RegionJoin {
+object BroadcastRegionJoin {

/**
-* Performs a region join between two RDDs.
+* Performs a region join between two RDDs (broadcast join).
*
* This implementation first _collects_ the left-side RDD; therefore, if the left-side RDD is large
* or otherwise idiosyncratic in a spatial sense (i.e. contains a set of regions whose unions overlap
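Given the collect-and-broadcast strategy described in the comment above, a minimal usage sketch may help; it mirrors the CalculateDepth call site earlier in this diff. The wrapper object and function are hypothetical, the implicit ReferenceMappings come from ReferenceMappingContext._, and the left-side RDD should be small enough to collect on the driver:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.ReferenceRegion
import org.bdgenomics.adam.rdd.BroadcastRegionJoin
import org.bdgenomics.adam.rich.ReferenceMappingContext._
import org.bdgenomics.formats.avro.AlignmentRecord

object RegionJoinSketch {
  // Hypothetical wrapper: join a small set of regions against mapped reads.
  def overlapReads(sc: SparkContext,
                   regions: RDD[ReferenceRegion],
                   reads: RDD[AlignmentRecord]): RDD[(ReferenceRegion, AlignmentRecord)] = {
    // The left side (regions) is collected and broadcast, so keep it small.
    BroadcastRegionJoin.partitionAndJoin(sc, regions, reads)
  }
}
```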
Expand Down
Expand Up @@ -17,12 +17,12 @@
*/
package org.bdgenomics.adam.rdd

-import org.bdgenomics.adam.models.{ ReferencePosition, SequenceDictionary }
+import org.bdgenomics.adam.models.{ ReferenceRegion, ReferenceMapping, ReferencePosition, SequenceDictionary }
import org.apache.spark.{ Logging, Partitioner }
import scala.math._

/**
-* GenomicRegionPartitioner partitions ReferencePosition objects into separate, spatially-coherent
+* GenomicPositionPartitioner partitions ReferencePosition objects into separate, spatially-coherent
* regions of the genome.
*
* This can be used to organize genomic data for computation that is spatially distributed (e.g. GATK and Queue's
@@ -35,9 +35,9 @@ import scala.math._
* @param seqLengths a map relating sequence-name to length and indicating the set and length of all extant
* sequences in the genome.
*/
-case class GenomicRegionPartitioner(numParts: Int, seqLengths: Map[String, Long]) extends Partitioner with Logging {
+case class GenomicPositionPartitioner(numParts: Int, seqLengths: Map[String, Long]) extends Partitioner with Logging {

-log.info("Have genomic region partitioner with " + numParts + " partitions, and sequences:")
+log.info("Have genomic position partitioner with " + numParts + " partitions, and sequences:")
seqLengths.foreach(kv => log.info("Contig " + kv._1 + " with length " + kv._2))

val names: Seq[String] = seqLengths.keys.toSeq.sortWith(_ < _)
@@ -80,7 +80,7 @@ case class GenomicRegionPartitioner(numParts: Int, seqLengths: Map[String, Long]
case refpos: ReferencePosition => getPart(refpos.referenceName, refpos.pos)

// only ReferencePosition values are partitioned using this partitioner
-case _ => throw new IllegalArgumentException("Only ReferencePosition values can be partitioned by GenomicRegionPartitioner")
+case _ => throw new IllegalArgumentException("Only ReferencePosition values can be partitioned by GenomicPositionPartitioner")
}
}

@@ -90,11 +90,41 @@ case class GenomicRegionPartitioner(numParts: Int, seqLengths: Map[String, Long]

}

-object GenomicRegionPartitioner {
+object GenomicPositionPartitioner {

-def apply(numParts: Int, seqDict: SequenceDictionary): GenomicRegionPartitioner =
-GenomicRegionPartitioner(numParts, extractLengthMap(seqDict))
+def apply(numParts: Int, seqDict: SequenceDictionary): GenomicPositionPartitioner =
+GenomicPositionPartitioner(numParts, extractLengthMap(seqDict))

def extractLengthMap(seqDict: SequenceDictionary): Map[String, Long] =
Map(seqDict.records.toSeq.map(rec => (rec.name.toString, rec.length)): _*)
}
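As a hedged sketch of how GenomicPositionPartitioner might be applied, the helper below repartitions position-keyed records via the apply(numParts, seqDict) factory above. The helper object, its name, and its inputs are illustrative assumptions; note that keys must be ReferencePosition values, per the getPartition match shown earlier:

```scala
import scala.reflect.ClassTag
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.{ ReferencePosition, SequenceDictionary }
import org.bdgenomics.adam.rdd.GenomicPositionPartitioner

object PositionPartitionSketch {
  // Hypothetical helper: co-locate records keyed by genomic position so that
  // nearby positions land in the same Spark partition.
  def repartitionByPosition[V: ClassTag](rdd: RDD[(ReferencePosition, V)],
                                         seqDict: SequenceDictionary,
                                         numParts: Int): RDD[(ReferencePosition, V)] =
    rdd.partitionBy(GenomicPositionPartitioner(numParts, seqDict))
}
```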

+case class GenomicRegionPartitioner[T: ReferenceMapping](partitionSize: Long, seqLengths: Map[String, Long], start: Boolean = true) extends Partitioner with Logging {
+  private val names: Seq[String] = seqLengths.keys.toSeq.sortWith(_ < _)
+  private val lengths: Seq[Long] = names.map(seqLengths(_))
+  private val parts: Seq[Int] = lengths.map(v => round(ceil(v.toDouble / partitionSize)).toInt)
+  private val cumulParts: Map[String, Int] = Map(names.zip(parts.scan(0)(_ + _)): _*)
+
+  private def extractReferenceRegion(k: T)(implicit tMapping: ReferenceMapping[T]): ReferenceRegion = {
+    tMapping.getReferenceRegion(k)
+  }
+
+  private def computePartition(refReg: ReferenceRegion): Int = {
+    val pos = if (start) refReg.start else (refReg.end - 1)
+    (cumulParts(refReg.referenceName) + pos / partitionSize).toInt
+  }
+
+  override def numPartitions: Int = parts.sum
+
+  override def getPartition(key: Any): Int = {
+    key match {
+      case mappable: T => computePartition(extractReferenceRegion(mappable))
+      case _ => throw new IllegalArgumentException("Only ReferenceMappable values can be partitioned by GenomicRegionPartitioner")
+    }
+  }
+}
+
+object GenomicRegionPartitioner {
+  def apply[T: ReferenceMapping](partitionSize: Long, seqDict: SequenceDictionary): GenomicRegionPartitioner[T] =
+    GenomicRegionPartitioner(partitionSize, GenomicPositionPartitioner.extractLengthMap(seqDict))
+}
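To make the partition arithmetic above concrete, here is a small worked example under assumed inputs (the contig lengths, partition size, and regions are made up, and it relies on the implicit ReferenceMapping[ReferenceRegion] from ReferenceMappingContext, as used in CalculateDepth earlier in this diff). With partitionSize = 1000 and contigs chr1 (2500 bp) and chr2 (1000 bp), parts is Seq(3, 1) and cumulParts maps chr1 to 0 and chr2 to 3, so a region starting at chr1:2200 lands in partition 0 + 2200 / 1000 = 2:

```scala
import org.bdgenomics.adam.models.ReferenceRegion
import org.bdgenomics.adam.rdd.GenomicRegionPartitioner
import org.bdgenomics.adam.rich.ReferenceMappingContext._

object RegionPartitionSketch extends App {
  // Hypothetical contig lengths; roughly 1000 bases of reference per partition.
  val partitioner = GenomicRegionPartitioner[ReferenceRegion](
    partitionSize = 1000L,
    seqLengths = Map("chr1" -> 2500L, "chr2" -> 1000L))

  // ceil(2500 / 1000) = 3 partitions for chr1, ceil(1000 / 1000) = 1 for chr2.
  assert(partitioner.numPartitions == 4)

  // chr1 starts at cumulative offset 0, so start position 2200 maps to 0 + 2200 / 1000 = 2.
  assert(partitioner.getPartition(ReferenceRegion("chr1", 2200L, 2300L)) == 2)

  // chr2 starts at cumulative offset 3, so any chr2 region maps to partition 3.
  assert(partitioner.getPartition(ReferenceRegion("chr2", 10L, 60L)) == 3)
}
```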