Skip to content

Commit

Permalink
More complete docs, adding edge test case
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-petersohn committed Jun 13, 2017
1 parent a4d196e commit 564bc7d
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,6 @@ object ReferenceRegion {
implicit def orderingForPositions = RegionOrdering
implicit def orderingForOptionalPositions = OptionalRegionOrdering

/**
* Creates an empty ReferenceRegion.
*
* @return An empty ReferenceRegion.
*/
private[adam] val empty: ReferenceRegion = ReferenceRegion("", 0L, 0L)

/**
* Creates a reference region that starts at the beginning of a contig.
*
Expand Down
164 changes: 102 additions & 62 deletions adam-core/src/main/scala/org/bdgenomics/adam/rdd/settheory/Closest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,96 +32,132 @@ import scala.reflect.ClassTag
* @tparam RX The resulting type of the right after the operation.
*/
sealed trait Closest[T, U <: GenomicRDD[T, U], X, Y <: GenomicRDD[X, Y], RT, RX]
extends SetTheoryBetweenCollections[T, U, X, Y, RT, RX]
with SetTheoryPrimitive {

var currentClosest: ReferenceRegion = ReferenceRegion.empty
extends SetTheoryBetweenCollections[T, U, X, Y, RT, RX] {

override protected def condition(firstRegion: ReferenceRegion,
secondRegion: ReferenceRegion,
cache: SetTheoryCache[X, RT, RX],
threshold: Long = 0L): Boolean = {
firstRegion.unstrandedDistance(currentClosest)

// we must maintain this invariant throughout the computation
cache.closest.isDefined &&
// we want to identify all the regions that share the same distance as our
// current closest.
firstRegion.unstrandedDistance(cache.closest.get)
.exists(_ == firstRegion.unstrandedDistance(secondRegion).getOrElse(Long.MaxValue))
}

override protected def pruneCacheCondition(cachedRegion: ReferenceRegion,
to: ReferenceRegion): Boolean = {
if (cachedRegion.referenceName != to.referenceName) {
true
} else {
to: ReferenceRegion,
cache: SetTheoryCache[X, RT, RX]): Boolean = {

// we must maintain this invariant throughout the computation
cache.closest.isDefined &&
// we want to prune in the case that the unstranded distance between the
// current query region is greater than our current closest
cachedRegion.referenceName == to.referenceName &&
to.unstrandedDistance(cachedRegion).get >
to.unstrandedDistance(currentClosest).getOrElse(Long.MaxValue)
}
to.unstrandedDistance(cache.closest.get).getOrElse(Long.MaxValue)
}

override protected def advanceCacheCondition(candidateRegion: ReferenceRegion,
until: ReferenceRegion): Boolean = {
until: ReferenceRegion,
cache: SetTheoryCache[X, RT, RX]): Boolean = {

// if our current closest isn't on the same reference name, we don't
// consider it the closest, thus we have no current closest
if (cache.closest.isDefined &&
cache.closest.get.referenceName != candidateRegion.referenceName) {

cache.closest = None
}

// if the reference names don't match, we don't consider them the closest,
// unless we have no current closest
if (candidateRegion.referenceName != until.referenceName &&
cache.closest.isDefined) {

if (candidateRegion.referenceName != until.referenceName) {
false
} else if (until.referenceName != currentClosest.referenceName ||
// current closest must be set if there is no current closest. This
// prevents us from dropping results when we don't have any records of that
// reference name in the dataset. otherwise, we set current closest if it
// is closer than our current
} else if (cache.closest.isEmpty ||
until.referenceName != cache.closest.get.referenceName ||
until.unstrandedDistance(candidateRegion).get <=
until.unstrandedDistance(currentClosest).getOrElse(Long.MaxValue)) {

currentClosest = candidateRegion
until.unstrandedDistance(cache.closest.get).getOrElse(Long.MaxValue)) {
// this object can be short lived, but the overhead should be low for
// options
cache.closest = Some(candidateRegion)
true
} else {
// we reach this on the region immediately after we have passed the
// closest region
false
}
}

override protected def prepare()(implicit tTag: ClassTag[T], xtag: ClassTag[X]): (RDD[(ReferenceRegion, T)], RDD[(ReferenceRegion, X)]) = {
override protected def prepare()(
implicit tTag: ClassTag[T], xtag: ClassTag[X]): (RDD[(ReferenceRegion, T)], RDD[(ReferenceRegion, X)]) = {

val (preparedLeftRdd, partitionMap) = {
if (leftRdd.optPartitionMap.isDefined) {
(leftRdd.flattenRddByRegions, leftRdd.optPartitionMap.get)
(leftRdd.flattenRddByRegions(), leftRdd.optPartitionMap.get)
} else {
val sortedLeft = leftRdd.sortLexicographically(storePartitionMap = true)
(sortedLeft.flattenRddByRegions, sortedLeft.optPartitionMap.get)
(sortedLeft.flattenRddByRegions(), sortedLeft.optPartitionMap.get)
}
}

val adjustedPartitionMapWithIndex = partitionMap
// the zipWithIndex gives us the destination partition ID
.zipWithIndex
.filter(_._1.nonEmpty)
.map(f => (f._1.get, f._2))
.map(g => {
.map(f => (f._1.get, f._2)).map(g => {
// first region for the bound
val rr = g._1._1
// second region for the bound
val secondrr = g._1._2
// in the case where we span multiple referenceNames
if (g._1._1.referenceName != g._1._2.referenceName) {
if (rr.referenceName != g._1._2.referenceName) {
// create a ReferenceRegion that goes to the end of the chromosome
(ReferenceRegion(
g._1._1.referenceName,
g._1._1.start,
g._1._1.end),
g._2)
(ReferenceRegion(rr.referenceName, rr.start, rr.end), g._2)
} else {
// otherwise we just have the ReferenceRegion span from partition
// start to end
(ReferenceRegion(
g._1._1.referenceName,
g._1._1.start,
g._1._2.end),
g._2)
(ReferenceRegion(rr.referenceName, rr.start, secondrr.end), g._2)
}
})

// we use an interval array to quickly look up the destination partitions
val partitionMapIntervals = IntervalArray(
adjustedPartitionMapWithIndex,
adjustedPartitionMapWithIndex.maxBy(_._1.width)._1.width,
sorted = true)

val assignedRightRdd = {
val firstPass = rightRdd.flattenRddByRegions.mapPartitions(iter => {
iter.flatMap(f => {
val rangeOfHits = partitionMapIntervals.get(f._1, requireOverlap = false)
rangeOfHits.map(g => ((f._1, g._2), f._2))
})
}, preservesPartitioning = true)

val assignedRightRdd: RDD[((ReferenceRegion, Int), X)] = {
// copartitioning for the closest is tricky, and requires that we handle
// unique edge cases, described below.
// the first pass gives us the initial destination partitions.
val firstPass = rightRdd.flattenRddByRegions()
.mapPartitions(iter => {
iter.flatMap(f => {
val rangeOfHits = partitionMapIntervals.get(f._1, requireOverlap = false)
rangeOfHits.map(g => ((f._1, g._2), f._2))
})
}, preservesPartitioning = true)

// we have to find the partitions that don't have right data going there
// so we can send the flanking partitions' data there
val partitionsWithoutData =
partitionMap.indices.filterNot(firstPass.map(_._1._2).distinct().collect.contains)

// this gives us a list of partitions that are sending copies of their
// data and the number of nodes to send to. a negative number of nodes
// indicates that the data needs to be sent to lower numbered nodes, a
// positive number indicates that the data needs to be sent to higher
// numbered nodes. the way this is written, it will handle an arbitrary
// run of empty partitions.
val partitionsToSend = partitionsWithoutData.foldLeft(List.empty[List[Int]])((b, a) => {
if (b.isEmpty) {
List(List(a))
Expand All @@ -130,29 +166,35 @@ sealed trait Closest[T, U <: GenomicRDD[T, U], X, Y <: GenomicRDD[X, Y], RT, RX]
} else {
b.:+(List(a))
}
// we end up getting all the data from both flanking nodes. we use the
// length here so we know how many destinations we have resulting from
// runs of empty partitions.
}).flatMap(f => List((f.head - 1, f.length), (f.last + 1, -1 * f.length)))

firstPass.flatMap(f => {
val index = partitionsToSend.indexWhere(_._1 == f._1._2)
if (index < 0) {
List(f)
} else {
if (partitionsToSend(index)._2 < 0) {
(partitionsToSend(index)._2 to 0)
.map(g => ((f._1._1, f._1._2 + g), f._2))
// extract the destinations for this data record
val destinations = partitionsToSend.filter(g => g._1 == f._1._2)
// we use an inclusive range to specify all destinations
val duplicatedRecords = {
if (destinations.length == 1) {
// the data is only going to lower numbered nodes
if (destinations.head._2 < 0) {
destinations.head._2 to 0
// the data is only going to higher numbered nodes
} else {
0 to destinations.head._2
}
// the data is going to higher and lower numbered nodes
} else if (destinations.length == 2) {
destinations.last._2 to destinations.head._2
// the data is only going to its original destination
} else {
(0 to partitionsToSend(index)._2)
.map(g => ((f._1._1, f._1._2 + g), f._2)) ++ {
if (index == partitionsToSend.lastIndexWhere(_._1 == f._1._2)) {
List()
} else {
val endIndex = partitionsToSend.lastIndexWhere(_._1 == f._1._2)
(partitionsToSend(endIndex)._2 to -1)
.map(g => ((f._1._1, f._1._2 + g), f._2))
}
}
0 to 0
}
}
// add the destination
}.map(g => ((f._1._1, f._1._2 + g), f._2))

duplicatedRecords
})
}

Expand All @@ -172,16 +214,14 @@ sealed trait Closest[T, U <: GenomicRDD[T, U], X, Y <: GenomicRDD[X, Y], RT, RX]
*
* @param leftRdd The left RDD.
* @param rightRdd The right RDD.
* @param optPartitionMap An optional partition map defining the left RDD
* partition bounds.
* @param threshold The maximum distance allowed for the closest.
* @param optPartitions Optionally sets the number of partitions for the join.
* @tparam T The type of the left records.
* @tparam U The type of the right records.
*/
case class ShuffleClosestRegion[T, U <: GenomicRDD[T, U], X, Y <: GenomicRDD[X, Y]](
protected val leftRdd: GenomicRDD[T, U],
protected val rightRdd: GenomicRDD[X, Y],
protected val optPartitionMap: Option[Array[Option[(ReferenceRegion, ReferenceRegion)]]],
protected val threshold: Long = Long.MaxValue,
protected val optPartitions: Option[Int] = None)
extends Closest[T, U, X, Y, T, Iterable[X]]
Expand Down
Loading

0 comments on commit 564bc7d

Please sign in to comment.