Merge pull request #4 from devin-petersohn/intersection
Intersection implementation
devin-petersohn committed Apr 14, 2017
2 parents 1a2c3f5 + 33b9b35 commit ef06144
Showing 9 changed files with 190 additions and 54 deletions.
@@ -3,6 +3,7 @@ package org.bdgenomics.lime.cli
 import org.apache.spark.SparkContext
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.rdd.ADAMSaveAnyArgs
+import org.bdgenomics.lime.set_theory.DistributedIntersection
 import org.bdgenomics.utils.cli._
 import org.kohsuke.args4j.Argument

@@ -39,8 +40,6 @@ object Intersection extends BDGCommandCompanion
     val leftGenomicRDD = sc.loadBed(args.leftInput)
     val rightGenomicRDD = sc.loadBed(args.rightInput)

-    leftGenomicRDD.shuffleRegionJoin(rightGenomicRDD)
-    leftGenomicRDD.rdd.collect.foreach(println)
     }
   }
 }
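The two deleted lines were a placeholder that ran a shuffle join and then printed the left input unchanged. This hunk only removes them; going by the newly added DistributedIntersection import and the pattern of the Merge command below, the eventual wiring presumably resembles the following sketch. This is hypothetical and not part of this diff; flattenRddByRegions() and partitionMap are the names the Merge command uses, and both inputs are assumed to be repartitioned consistently.

// Hypothetical sketch, not part of this commit: wiring the CLI to the new
// primitive in the style of the Merge command further down.
DistributedIntersection(leftGenomicRDD.flattenRddByRegions(),
  rightGenomicRDD.flattenRddByRegions(),
  leftGenomicRDD.partitionMap.get)
  .compute()
  .collect.foreach(println)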
@@ -10,7 +10,7 @@ private object LimeMain
 }

 private class LimeMain(args: Array[String]) extends Logging {
-  private def commands: List[BDGCommandCompanion] = List(Intersection, Sort)
+  private def commands: List[BDGCommandCompanion] = List(Intersection, Sort, Merge)

   private def printVersion() {
     println("Version 0")
7 changes: 1 addition & 6 deletions lime-cli/src/main/scala/org/bdgenomics/lime/cli/Merge.scala
@@ -3,14 +3,10 @@ package org.bdgenomics.lime.cli
 import org.apache.spark.SparkContext
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.rdd.ADAMSaveAnyArgs
-import org.bdgenomics.adam.rdd.feature.FeatureRDD
 import org.bdgenomics.lime.set_theory.DistributedMerge
 import org.bdgenomics.utils.cli._
 import org.kohsuke.args4j.Argument

-/**
- * Created by DevinPetersohn on 4/6/17.
- */
 object Merge extends BDGCommandCompanion {
   val commandName = "merge"
   val commandDescription = "Merges the regions in a file."
@@ -40,8 +36,7 @@ object Merge extends BDGCommandCompanion
       DistributedMerge(genomicRdd.flattenRddByRegions(),
         genomicRdd.partitionMap.get)
         .compute()
-        .collect
-        .foreach(println)
+        .collect.foreach(println)
     }
   }


This file was deleted.

@@ -0,0 +1,83 @@
package org.bdgenomics.lime.set_theory

import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.models.ReferenceRegion
import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

sealed abstract class Intersection[T: ClassTag, U: ClassTag] extends SetTheoryBetweenCollections[T, U, T, U] {

  val threshold: Long

  def primitive(currRegion: ReferenceRegion,
                tempRegion: ReferenceRegion,
                minimumOverlap: Long = 0L): ReferenceRegion = {

    currRegion.intersection(tempRegion, minimumOverlap)
  }

  def condition(firstRegion: ReferenceRegion,
                secondRegion: ReferenceRegion,
                minimumOverlap: Long = 0L): Boolean = {

    firstRegion.overlapsBy(secondRegion).exists(_ >= threshold)
  }
}

case class DistributedIntersection[T: ClassTag, U: ClassTag](leftRdd: RDD[(ReferenceRegion, T)],
                                                             rightRdd: RDD[(ReferenceRegion, U)],
                                                             partitionMap: Array[Option[(ReferenceRegion, ReferenceRegion)]],
                                                             threshold: Long = 0L) extends Intersection[T, U] {

  private val cache: ListBuffer[(ReferenceRegion, U)] = ListBuffer.empty[(ReferenceRegion, U)]

  def compute(): RDD[(ReferenceRegion, (T, U))] = {
    leftRdd.zipPartitions(rightRdd)(sweep)
  }

  private def sweep(leftIter: Iterator[(ReferenceRegion, T)],
                    rightIter: Iterator[(ReferenceRegion, U)]): Iterator[(ReferenceRegion, (T, U))] = {

    makeIterator(leftIter.buffered, rightIter.buffered)
  }

  private def makeIterator(left: BufferedIterator[(ReferenceRegion, T)],
                           right: BufferedIterator[(ReferenceRegion, U)]): Iterator[(ReferenceRegion, (T, U))] = {

    def advanceCache(until: ReferenceRegion) = {
      while (right.hasNext && (right.head._1.compareTo(until) <= 0 ||
        right.head._1.covers(until))) {

        cache += right.next
      }
    }

    def pruneCache(to: ReferenceRegion) = {
      cache.trimStart({
        val index = cache.indexWhere(f => !(f._1.compareTo(to) < 0 && !f._1.covers(to)))
        if (index <= 0) {
          0
        } else {
          index
        }
      })
    }

    left.flatMap(f => {
      val (currentRegion, _) = f
      advanceCache(currentRegion)
      pruneCache(currentRegion)
      processHits(f)
    })
  }

  private def processHits(current: (ReferenceRegion, T)): Iterator[(ReferenceRegion, (T, U))] = {

    val (currentRegion, _) = current
    cache.filter(f => f._1.overlapsBy(currentRegion).exists(_ >= threshold))
      .map(g => {
        (currentRegion.intersection(g._1, threshold), (current._2, g._2))
      }).iterator
  }

}
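The new file implements a partition-local sweep: compute() zips the co-partitioned left and right RDDs, and makeIterator walks the sorted left side while buffering right-side regions in a cache. advanceCache admits every right region that starts at or before (or covers) the current left region, pruneCache drops cached regions that end strictly before it, and processHits emits the intersection with every cached region that overlaps by at least threshold. The following self-contained sketch is my illustration, not repository code: it replays the same cache discipline on plain half-open (start, end) pairs in place of ReferenceRegion, assuming both inputs are sorted by start.

// Single-partition illustration of the sweep (not from the repo).
object SweepSketch {
  type Region = (Long, Long) // half-open (start, end)

  def overlapsBy(a: Region, b: Region): Long =
    math.min(a._2, b._2) - math.max(a._1, b._1)

  def main(args: Array[String]): Unit = {
    val left = Iterator((100L, 200L), (150L, 300L), (400L, 500L))
    val right = Iterator((120L, 160L), (180L, 250L), (420L, 430L)).buffered
    val cache = scala.collection.mutable.ListBuffer.empty[Region]

    val hits = left.flatMap { cur =>
      // advance: admit right regions that start before the current end
      while (right.hasNext && right.head._1 < cur._2) cache += right.next()
      // prune: evict cached regions that end at or before the current start
      cache --= cache.filter(_._2 <= cur._1)
      // emit the intersection with every cached region that still overlaps
      cache.filter(overlapsBy(_, cur) > 0)
        .map(r => (math.max(r._1, cur._1), math.min(r._2, cur._2)))
    }
    hits.foreach(println)
    // (120,160), (180,200), (150,160), (180,250), (420,430)
  }
}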
@@ -5,16 +5,21 @@ import org.bdgenomics.adam.models.ReferenceRegion
 import scala.reflect.ClassTag

 sealed abstract class Merge[T: ClassTag] extends SetTheoryWithSingleCollection[T] {
-  def primitive(currRegion: ReferenceRegion,
-                tempRegion: ReferenceRegion,
+  def primitive(firstRegion: ReferenceRegion,
+                secondRegion: ReferenceRegion,
                 distanceThreshold: Long = 0L): ReferenceRegion = {

-    currRegion.merge(tempRegion, distanceThreshold)
+    firstRegion.merge(secondRegion, distanceThreshold)
   }
+
+  def condition(firstRegion: ReferenceRegion,
+                secondRegion: ReferenceRegion,
+                distanceThreshold: Long = 0L): Boolean = {
+
+    firstRegion.overlaps(secondRegion, distanceThreshold)
+  }
 }

 case class DistributedMerge[T: ClassTag](rddToCompute: RDD[(ReferenceRegion, T)],
                                          partitionMap: Array[Option[(ReferenceRegion, ReferenceRegion)]],
-                                         distanceThreshold: Long = 0L) extends Merge[T] {
-
-}
+                                         threshold: Long = 0L) extends Merge[T]
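Merge now supplies both halves of the template method that SetTheory defines below: condition decides whether two regions interact, primitive combines them, and DistributedMerge shrinks to a bare case class that inherits compute(). In interval terms, this is my paraphrase of the two ReferenceRegion calls, assuming overlaps(r, d) means "overlap, or sit within distance d" and merge takes the hull:

// Interval paraphrase of the two hooks (illustration only, not repo code).
type Region = (Long, Long) // (start, end)

def condition(a: Region, b: Region, d: Long = 0L): Boolean =
  a._1 <= b._2 + d && b._1 <= a._2 + d // overlap, or a gap of at most d

def primitive(a: Region, b: Region, d: Long = 0L): Region =
  (math.min(a._1, b._1), math.max(a._2, b._2)) // hull of the two

// condition((100L, 200L), (205L, 250L), 10L) is true (gap of 5),
// so they merge to (100, 250).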
@@ -3,32 +3,37 @@ package org.bdgenomics.lime.set_theory
 import org.apache.spark.rdd.RDD
 import org.bdgenomics.adam.models.ReferenceRegion
 import org.bdgenomics.lime.util.Partitioners.ReferenceRegionRangePartitioner
+import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer
 import scala.reflect.ClassTag

-abstract class SetTheory[T: ClassTag] extends Serializable {
+abstract class SetTheory extends Serializable {
   val partitionMap: Array[Option[(ReferenceRegion, ReferenceRegion)]]
-  val distanceThreshold: Long
+  val threshold: Long

   def primitive(firstRegion: ReferenceRegion,
                 secondRegion: ReferenceRegion,
                 distanceThreshold: Long = 0L): ReferenceRegion

+  def condition(firstRegion: ReferenceRegion,
+                secondRegion: ReferenceRegion,
+                distanceThreshold: Long = 0L): Boolean
+
 }

-abstract class SetTheoryBetweenCollections[T: ClassTag, U: ClassTag, RT, RU] extends SetTheory[T] {
+abstract class SetTheoryBetweenCollections[T: ClassTag, U: ClassTag, RT, RU] extends SetTheory {
   val leftRdd: RDD[(ReferenceRegion, T)]
   val rightRdd: RDD[(ReferenceRegion, U)]

   def compute(): RDD[(ReferenceRegion, (RT, RU))]
 }

-abstract class SetTheoryWithSingleCollection[T: ClassTag] extends SetTheory[T] {
+abstract class SetTheoryWithSingleCollection[T: ClassTag] extends SetTheory {
   val rddToCompute: RDD[(ReferenceRegion, T)]

   def compute(): RDD[(ReferenceRegion, Iterable[T])] = {
-    val localComputed = localCompute(rddToCompute.map(f => (f._1, Iterable(f._2))), distanceThreshold)
-    externalCompute(localComputed, distanceThreshold, 2)
+    val localComputed = localCompute(rddToCompute.map(f => (f._1, Iterable(f._2))), threshold)
+    externalCompute(localComputed, partitionMap, threshold, 2)
   }

   protected def localCompute(rdd: RDD[(ReferenceRegion, Iterable[T])], distanceThreshold: Long): RDD[(ReferenceRegion, Iterable[T])] = {
@@ -42,7 +47,7 @@ abstract class SetTheoryWithSingleCollection[T: ClassTag] extends SetTheory[T] {
       while (iter.hasNext) {
         val tempTuple = iter.next()
         val tempRegion = tempTuple._1
-        if (currRegion.isNearby(tempRegion, distanceThreshold)) {
+        if (condition(currRegion, tempRegion, distanceThreshold)) {
           currRegion = primitive(currRegion, tempRegion, distanceThreshold)
           tempValueListBuffer ++= currTuple._2
         } else {
@@ -58,36 +63,84 @@ abstract class SetTheoryWithSingleCollection[T: ClassTag] extends SetTheory[T] {
     })
   }

-  protected def externalCompute(rdd: RDD[(ReferenceRegion, Iterable[T])],
-                                distanceThreshold: Long, round: Int): RDD[(ReferenceRegion, Iterable[T])] = {
+  /**
+   * Computes the primitives between partitions.
+   *
+   * @param rdd The rdd to compute on.
+   * @param partitionMap The current partition map for rdd.
+   * @param distanceThreshold The distance threshold for the condition and primitive.
+   * @param round The current round of computation in the recursion tree. Increments by a factor of 2 each round.
+   * @return The computed rdd for this round.
+   */
+  @tailrec private def externalCompute(rdd: RDD[(ReferenceRegion, Iterable[T])],
+                                       partitionMap: Array[Option[(ReferenceRegion, ReferenceRegion)]],
+                                       distanceThreshold: Long,
+                                       round: Int): RDD[(ReferenceRegion, Iterable[T])] = {

-    if (round == partitionMap.length) {
+    if (round > partitionMap.length) {
       return rdd
     }

     val partitionedRdd = rdd.mapPartitionsWithIndex((idx, iter) => {
       val indexWithinParent = idx % round
-      val partnerPartitionBounds =
+      val partnerPartition = {
+        var i = 1
         if (idx > 0) {
-          partitionMap(idx - 1).get
+          while (partitionMap(idx - i).isEmpty) {
+            i += 1
+          }
+          idx - i
         } else {
-          partitionMap(idx).get
+          idx
         }
+      }
+
+      val partnerPartitionBounds = partitionMap(partnerPartition)

       iter.map(f => {
         val (region, value) = f
         if (indexWithinParent == round / 2 &&
-          (region.covers(partnerPartitionBounds._2, distanceThreshold) ||
-            region.compareTo(partnerPartitionBounds._2) <= 0)) {
+          (region.covers(partnerPartitionBounds.get._2, distanceThreshold) ||
+            region.compareTo(partnerPartitionBounds.get._2) <= 0)) {

-          ((region, idx - 1), value)
+          ((region, partnerPartition), value)
         } else {
           ((region, idx), value)
         }
       })
     }).repartitionAndSortWithinPartitions(new ReferenceRegionRangePartitioner(partitionMap.length))
       .map(f => (f._1._1, f._2))

-    externalCompute(localCompute(partitionedRdd, distanceThreshold), distanceThreshold, round * 2)
+    val newPartitionMap = partitionedRdd.mapPartitions(iter => {
+      getRegionBoundsFromPartition(iter)
+    }).collect
+
+    externalCompute(localCompute(partitionedRdd, distanceThreshold), newPartitionMap, distanceThreshold, round * 2)
   }
+
+  /**
+   * Gets the partition bounds from a ReferenceRegion keyed Iterator
+   *
+   * @param iter The data on a given partition. ReferenceRegion keyed
+   * @return The bounds of the ReferenceRegions on that partition, in an Iterator
+   */
+  private def getRegionBoundsFromPartition(iter: Iterator[(ReferenceRegion, Iterable[T])]): Iterator[Option[(ReferenceRegion, ReferenceRegion)]] = {
+    if (iter.isEmpty) {
+      // This means that there is no data on the partition, so we have no bounds
+      Iterator(None)
+    } else {
+      val firstRegion = iter.next
+      val lastRegion =
+        if (iter.hasNext) {
+          // we have to make sure we get the full bounds of this partition, this
+          // includes any extremely long regions. we include the firstRegion for
+          // the case that the first region is extremely long
+          (iter ++ Iterator(firstRegion)).maxBy(f => (f._1.referenceName, f._1.end, f._1.start))
+        } else {
+          // only one record on this partition, so this is the extent of the bounds
+          firstRegion
+        }
+      Iterator(Some((firstRegion._1, lastRegion._1)))
+    }
+  }
 }
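externalCompute stitches partition-local results together in a logarithmic number of rounds: in round r, every partition whose index is r/2 modulo r re-keys the records that reach down to its partner's upper bound onto that partner (the nearest non-empty partition to its left), everything is re-sorted, localCompute merges again, the partition map is rebuilt via getRegionBoundsFromPartition, and the round doubles until it exceeds the partition count. Below is a standalone sketch of one round's bookkeeping, my illustration only, with Long bounds standing in for the (ReferenceRegion, ReferenceRegion) pairs:

// Round-2 bookkeeping on a 4-partition layout where partition 2 is empty.
val partitionMap: Array[Option[(Long, Long)]] =
  Array(Some((0L, 10L)), Some((10L, 20L)), None, Some((30L, 40L)))
val round = 2

for (idx <- partitionMap.indices) {
  val indexWithinParent = idx % round
  // walk left past empty partitions, exactly as externalCompute does
  val partnerPartition =
    if (idx > 0) {
      var i = 1
      while (partitionMap(idx - i).isEmpty) i += 1
      idx - i
    } else idx
  val shipsLeft = indexWithinParent == round / 2
  println(s"partition $idx -> partner $partnerPartition, ships boundary records: $shipsLeft")
}
// partition 0 -> partner 0, ships boundary records: false
// partition 1 -> partner 0, ships boundary records: true
// partition 2 -> partner 1, ships boundary records: false
// partition 3 -> partner 1, ships boundary records: true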
20 changes: 20 additions & 0 deletions lime-core/src/test/resources/cpg_20merge.bed
@@ -0,0 +1,20 @@
chr1 28735 29810 CpG:_116
chr1 29800 29820 CpG:_30
chr1 29815 29830 CpG:_29
chr1 29825 29840 CpG:_84
chr1 29835 29850 CpG:_99
chr1 29845 29860 CpG:_94
chr1 29855 29870 CpG:_171
chr1 29865 29880 CpG:_60
chr1 29875 29890 CpG:_115
chr1 29885 29900 CpG:_28
chr1 29895 29910 CpG:_24
chr1 29905 29920 CpG:_50
chr1 29915 29930 CpG:_83
chr1 29925 29940 CpG:_153
chr1 29935 29950 CpG:_16
chr1 29945 29960 CpG:_257
chr1 29955 29970 CpG:_178
chr1 29965 29980 CpG:_246
chr1 29975 29990 CpG:_18
chr1 29985 30000 CpG:_615
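It is worth noting why this fixture merges to a single region: every record begins before its predecessor ends (CpG:_30 starts at 29800, inside CpG:_116's 28735-29810, and each later record starts 5 bases before the previous end), so the intervals chain together. With a distance threshold of 0, the expected output of DistributedMerge over this file is the single region chr1:28735-30000.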
@@ -8,10 +8,9 @@ import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.lime.LimeFunSuite

 class MergeSuite extends LimeFunSuite {
-  sparkTest("test local merge") {
-    val genomicRdd = sc.loadBed(resourcesFile("/cpg.bed")).repartitionAndSort()
-
+  sparkTest("test local merge when all data merges to a single region") {
+    val genomicRdd = sc.loadBed(resourcesFile("/cpg_20merge.bed")).repartitionAndSort()
     val x = DistributedMerge(genomicRdd.flattenRddByRegions, genomicRdd.partitionMap.get).compute()
     x.collect.foreach(println)
   }

 }
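The renamed test still only prints the result. A hypothetical tightening (not in this commit) that would assert the single merged region the test name promises, placed inside the sparkTest block and assuming the BED coordinates carry through to ReferenceRegion unchanged:

// Hypothetical assertion, not part of this commit.
val merged = x.collect
assert(merged.length == 1)
assert(merged.head._1 == ReferenceRegion("chr1", 28735L, 30000L))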
