Rewrite of MarkDuplicates which seems to improve performance #380

Merged · 4 commits · Sep 18, 2014
Changes from all commits
@@ -25,40 +25,42 @@ import org.bdgenomics.formats.avro.AlignmentRecord

 private[rdd] object MarkDuplicates extends Serializable {

-  def markReads(buckets: Seq[SingleReadBucket], areDups: Boolean): Seq[SingleReadBucket] = {
-    for (bucket <- buckets; read <- bucket.primaryMapped ++ bucket.secondaryMapped) {
-      read.setDuplicateRead(areDups)
-    }
-    for (bucket <- buckets; read <- bucket.unmapped) {
+  private def markReadsInBucket(bucket: SingleReadBucket, primaryAreDups: Boolean, secondaryAreDups: Boolean) {
+    bucket.primaryMapped.foreach(read => {
+      read.setDuplicateRead(primaryAreDups)
+    })
+    bucket.secondaryMapped.foreach(read => {
+      read.setDuplicateRead(secondaryAreDups)
+    })
+    bucket.unmapped.foreach(read => {
       read.setDuplicateRead(false)
-    }
-    buckets
+    })
   }

   // Calculates the sum of the phred scores that are greater than or equal to 15
   def score(record: AlignmentRecord): Int = {
     record.qualityScores.filter(15 <=).sum
   }

-  def scoreAndMarkReads(buckets: Seq[SingleReadBucket]): Seq[SingleReadBucket] = {
-    val scoredBuckets = buckets.map(p => (p.primaryMapped.map(score).sum, p))
-    val sortedBuckets = scoredBuckets.sortBy(_._1)(Ordering[Int].reverse)
+  private def scoreBucket(bucket: SingleReadBucket): Int = {
+    bucket.primaryMapped.map(score).sum
+  }

-    for (((score, bucket), i) <- sortedBuckets.zipWithIndex) {
-      for (read <- bucket.primaryMapped) {
-        read.setDuplicateRead(i != 0)
-      }
-      for (read <- bucket.secondaryMapped) {
-        read.setDuplicateRead(true)
-      }
-      for (read <- bucket.unmapped) {
-        read.setDuplicateRead(false)
+  private def markReads(reads: Iterable[(ReferencePositionPair, SingleReadBucket)], areDups: Boolean) {
+    markReads(reads, primaryAreDups = areDups, secondaryAreDups = areDups, ignore = None)
+  }
+
+  private def markReads(reads: Iterable[(ReferencePositionPair, SingleReadBucket)], primaryAreDups: Boolean, secondaryAreDups: Boolean,
+                        ignore: Option[(ReferencePositionPair, SingleReadBucket)] = None) {
+    reads.foreach(read => {
+      if (ignore.isEmpty || read != ignore.get) {
+        markReadsInBucket(read._2, primaryAreDups, secondaryAreDups)
       }
-    }
-    buckets
+    })
   }

   def apply(rdd: RDD[AlignmentRecord]): RDD[AlignmentRecord] = {

     // Group by library and left position
     def leftPositionAndLibrary(p: (ReferencePositionPair, SingleReadBucket)): (Option[ReferencePositionWithOrientation], CharSequence) = {
       (p._1.read1refPos, p._2.allReads.head.getRecordGroupLibrary)
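Note on the `score` function above: it keeps only phred values of 15 or greater before summing, and `filter(15 <=)` is point-free Scala for `filter(x => 15 <= x)`. A minimal standalone sketch of just that rule, with made-up quality values rather than ADAM's `AlignmentRecord`:

```scala
// Standalone illustration of the score() rule: sum only phred values >= 15.
object ScoreSketch extends App {
  val phred = Seq(10, 14, 15, 20, 30) // hypothetical base qualities
  val total = phred.filter(15 <=).sum // keeps 15, 20, 30
  println(total)                      // prints 65
}
```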
@@ -71,47 +73,62 @@ private[rdd] object MarkDuplicates extends Serializable {

     rdd.adamSingleReadBuckets().keyBy(ReferencePositionPair(_)).groupBy(leftPositionAndLibrary)
       .flatMap(kv => {
-        val ((leftPos, library), readsByLeftPos) = kv
-
-        val buckets = leftPos match {
+        val leftPos: Option[ReferencePositionWithOrientation] = kv._1._1
+        val readsAtLeftPos: Iterable[(ReferencePositionPair, SingleReadBucket)] = kv._2
+
+        leftPos match {

           // These are all unmapped reads. There is no way to determine if they are duplicates
           case None =>
-            markReads(readsByLeftPos.toSeq.unzip._2, areDups = false)
+            markReads(readsAtLeftPos, areDups = false)

           // These reads have their left position mapped
           case Some(leftPosWithOrientation) =>
-            // Group the reads by their right position
-            val readsByRightPos = readsByLeftPos.groupBy(rightPosition)
-            // Find any reads with no right position
-            val fragments = readsByRightPos.get(None)
-            // Check if we have any pairs (reads with a right position)
-            val hasPairs = readsByRightPos.keys.exists(_.isDefined)
-
-            if (hasPairs) {
-              // Since we have pairs, mark all fragments as duplicates
-              val processedFrags = if (fragments.isDefined) {
-                markReads(fragments.get.toSeq.unzip._2, areDups = true)
+
+            val readsByRightPos = readsAtLeftPos.groupBy(rightPosition)
+
+            val groupCount = readsByRightPos.size
+
+            readsByRightPos.foreach(e => {
+
+              val rightPos = e._1
+              val reads = e._2
+
+              val groupIsFragments = rightPos.isEmpty
+
+              // We have no pairs (only fragments) if the current group is a group of fragments
+              // and there is only one group in total
+              val onlyFragments = groupIsFragments && groupCount == 1
+
+              // If there are only fragments then score the fragments. Otherwise, if there are not only
+              // fragments (there are pairs as well) mark all fragments as duplicates.
+              // If the group does not contain fragments (it contains pairs) then always score it.
+              if (onlyFragments || !groupIsFragments) {
+                // Find the highest-scoring read and mark it as not a duplicate. Mark all the other reads in this group as duplicates.
+                val highestScoringRead = reads.max(ScoreOrdering)
+                markReadsInBucket(highestScoringRead._2, primaryAreDups = false, secondaryAreDups = true)
+                markReads(reads, primaryAreDups = true, secondaryAreDups = true, ignore = Some(highestScoringRead))
               } else {
-                Seq.empty
+                markReads(reads, areDups = true)
               }
-
-              val processedPairs = for (
-                buckets <- (readsByRightPos - None).values;
-                processedPair <- scoreAndMarkReads(buckets.toSeq.unzip._2)
-              ) yield processedPair
-
-              processedPairs ++ processedFrags
-
-            } else if (fragments.isDefined) {
-              // No pairs. Score the fragments.
-              scoreAndMarkReads(fragments.get.toSeq.unzip._2)
-            } else {
-              Seq.empty
-            }
+            })
         }

-        buckets.flatMap(_.allReads)
+        readsAtLeftPos.flatMap(read => { read._2.allReads })

       })

   }

+  private object ScoreOrdering extends Ordering[(ReferencePositionPair, SingleReadBucket)] {
+    override def compare(x: (ReferencePositionPair, SingleReadBucket), y: (ReferencePositionPair, SingleReadBucket)): Int = {
+      // This is safe because scores are Ints
+      scoreBucket(x._2) - scoreBucket(y._2)
+    }
+  }
+
 }

Review comment from a project member on the `onlyFragments` check above: "Nifty! Nice approach."
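The core of the rewrite is the per-group decision inside the `Some(leftPosWithOrientation)` branch: a group of pairs is always scored, a fragments-only group is scored only when it is the sole group at that left position, and fragments that coexist with pairs are all marked duplicates. A toy, self-contained restatement of just that rule (plain Scala, not ADAM's types; `Option[Int]` stands in for the optional right position and the group values are placeholders):

```scala
// Toy restatement of the duplicate-marking decision for the groups at one left position.
// None = a fragment group (no right position); Some(pos) = a pair group.
object DecisionSketch extends App {
  def decide(groups: Map[Option[Int], Seq[String]]): Map[Option[Int], String] = {
    val groupCount = groups.size
    groups.map {
      case (rightPos, _) =>
        val groupIsFragments = rightPos.isEmpty
        val onlyFragments = groupIsFragments && groupCount == 1
        val verdict =
          if (onlyFragments || !groupIsFragments) "score group: keep best read, mark rest duplicates"
          else "fragments alongside pairs: mark all duplicates"
        rightPos -> verdict
    }
  }

  println(decide(Map(None -> Seq("f1", "f2"))))                           // fragments only: scored
  println(decide(Map(None -> Seq("f1"), Some(210) -> Seq("p1", "p2"))))   // mixed: fragments all dups
}
```

One further note on `ScoreOrdering`: it compares buckets by subtracting scores rather than calling `Int.compare`, which is fine here because bucket scores are non-negative sums of phred values, nowhere near the range where the subtraction could overflow.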

@@ -180,4 +180,16 @@ class MarkDuplicatesSuite extends SparkFunSuite {
     assert(MarkDuplicates.score(record) == 2000)
   }

+  sparkTest("read pairs that cross chromosomes") {
+    val poorPairs = for (
+      i <- 0 until 10;
+      read <- createPair("ref0", 10, "ref1", 210, avgPhredScore = 20, readName = "poor%d".format(i))
+    ) yield read
+    val bestPair = createPair("ref0", 10, "ref1", 210, avgPhredScore = 30, readName = "best")
+    val marked = markDuplicates(bestPair ++ poorPairs: _*)
+    val (dups, nonDups) = marked.partition(_.getDuplicateRead)
+    assert(nonDups.size == 2 && nonDups.forall(p => p.getReadName.toString == "best"))
+    assert(dups.forall(p => p.getReadName.startsWith("poor")))
+  }
+
 }
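For orientation, the entry point is `MarkDuplicates.apply`, which takes and returns an `RDD[AlignmentRecord]` with duplicate flags set. A hedged usage sketch follows; because the object is `private[rdd]`, the snippet assumes placement inside ADAM's `org.bdgenomics.adam.rdd` package, and the wrapper object is illustrative, not part of this diff:

```scala
// Hedged usage sketch: MarkDuplicates(reads) is the entry point shown in this diff;
// the package placement is an assumption forced by the private[rdd] modifier.
package org.bdgenomics.adam.rdd

import org.apache.spark.rdd.RDD
import org.bdgenomics.formats.avro.AlignmentRecord

object DedupeSketch {
  def run(reads: RDD[AlignmentRecord]): RDD[AlignmentRecord] =
    MarkDuplicates(reads) // groups reads by library/position and flags duplicates
}
```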