[ADAM-1358] Refactor BQSR to improve performance and legibility.
Resolves bigdatagenomics#1358.

* Added instrumentation to BQSR.
* Changed SnpTable to remove the RichVariant conversion and to use the VariantRDD API.
* Refactored SnpTable to eliminate the costly per-residue masked site lookup.
* Restructured the core of SnpTable around an array to improve GC performance.
  Additionally, wrote a custom serializer to improve serialization performance.
* Added a test suite for SnpTable to test table creation.
* Refactored SnpTable to use an IntervalArray-like approach. This approach
  improves masked site lookup performance by 50%.
* Added tests to SnpTableSuite to cover the lookup case, and re-enabled tests in
  BaseQualityRecalibrationSuite.
* Added unit test coverage to the covariates.
* Revert "[ADAM-775] Allow all IUPAC codes in BQSR".
  This reverts commit 207eeba.
* Pulled the Seq allocation for the base check out into an immutable set.
* Rewrote the dinuc covariate, for a 50% improvement in performance.
* Rewrote the main BQSR aggregate as a reduce by key (see the sketch after
  this list).
* Added tests for the recalibrator and the recalibration table.
* Major refactoring of the BQSR tables.
* Started factoring out the QualityScore class.
* Refactored CovariateKey to reduce its size in memory.
* Eliminated `org.bdgenomics.adam.rich.DecadentRead` (partially resolves bigdatagenomics#577).
* Refactored CovariateKey to store the record group ID instead of the record group name.
* Removed `org.bdgenomics.adam.models.QualityScore`.
* Split multi-class files into one class per file (excepting private classes) to improve navigability.
* Scaladoc all the recalibrators! You get a scaladoc! And you get a scaladoc!
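
Of these, the reduce-by-key rewrite is the most structural change: instead of
folding observations into a mutable accumulator with aggregate, each read base
emits a (key, observation) pair and Spark sums the pairs per key. A minimal
sketch of the pattern, with hypothetical stand-in types (ADAM's real
CovariateKey and Observation carry more fields):

import org.apache.spark.rdd.RDD

// hypothetical stand-ins for ADAM's covariate key and observation types
case class Key(readGroup: Int, quality: Char, cycle: Int, dinuc: (Char, Char))
case class Obs(bases: Long, mismatches: Long) {
  // observations combine by summing counts, so they form a commutative
  // monoid, which is all that reduceByKey requires
  def +(that: Obs): Obs = Obs(bases + that.bases, mismatches + that.mismatches)
}

// one (key, observation) pair per observed read base, summed per key
def observationTable(observations: RDD[(Key, Obs)]): Map[Key, Obs] = {
  observations.reduceByKey(_ + _)
    .collectAsMap()
    .toMap
}

Because reduceByKey combines map-side before shuffling, shuffle volume is
bounded by the number of distinct covariate keys per partition rather than by
the number of read bases.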
fnothaft committed Mar 20, 2017
1 parent cf39e6c commit 51a8ce2
Showing 30 changed files with 1,870 additions and 1,204 deletions.
28 changes: 14 additions & 14 deletions adam-cli/src/main/scala/org/bdgenomics/adam/cli/Transform.scala
@@ -59,10 +59,10 @@ class TransformArgs extends Args4jBase with ADAMSaveAnyArgs with ParquetArgs {
var markDuplicates: Boolean = false
@Args4jOption(required = false, name = "-recalibrate_base_qualities", usage = "Recalibrate the base quality scores (ILLUMINA only)")
var recalibrateBaseQualities: Boolean = false
@Args4jOption(required = false, name = "-min_acceptable_quality", usage = "Minimum acceptable quality for recalibrating a base in a read. Default is 5.")
var minAcceptableQuality: Int = 5
@Args4jOption(required = false, name = "-stringency", usage = "Stringency level for various checks; can be SILENT, LENIENT, or STRICT. Defaults to LENIENT")
var stringency: String = "LENIENT"
@Args4jOption(required = false, name = "-dump_observations", usage = "Local path to dump BQSR observations to. Outputs CSV format.")
var observationsPath: String = null
@Args4jOption(required = false, name = "-known_snps", usage = "Sites-only VCF giving location of known SNPs")
var knownSnpsFile: String = null
@Args4jOption(required = false, name = "-realign_indels", usage = "Locally realign indels present in reads.")
@@ -218,25 +218,25 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
log.info("Recalibrating base qualities")

// bqsr is a two-pass algorithm, so cache the rdd if requested
if (args.cache) {
rdd.rdd.persist(sl)
val optSl = if (args.cache) {
Some(sl)
} else {
None
}

// create the known sites table, if a known sites file is available
val knownSnps: SnpTable = createKnownSnpsTable(rdd.rdd.context)
val broadcastedSnps = BroadcastingKnownSnps.time {
rdd.rdd.context.broadcast(knownSnps)
}

// run bqsr
val bqsredRdd = rdd.recalibateBaseQualities(
rdd.rdd.context.broadcast(knownSnps),
Option(args.observationsPath),
stringency
broadcastedSnps,
args.minAcceptableQuality,
optSl
)

// if we cached the input, unpersist it, as it is never reused after bqsr
if (args.cache) {
rdd.rdd.unpersist()
}

bqsredRdd
} else {
rdd
@@ -477,7 +477,7 @@ class Transform(protected val args: TransformArgs) extends BDGSparkCommand[Trans
isSorted = args.sortReads || args.sortLexicographically)
}

private def createKnownSnpsTable(sc: SparkContext): SnpTable = CreateKnownSnpsTable.time {
Option(args.knownSnpsFile).fold(SnpTable())(f => SnpTable(sc.loadVariants(f).rdd.map(new RichVariant(_))))
private def createKnownSnpsTable(sc: SparkContext): SnpTable = {
Option(args.knownSnpsFile).fold(SnpTable())(f => SnpTable(sc.loadVariants(f)))
}
}
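
The broadcast handle threaded into the recalibration call above follows the
standard Spark pattern: the known-SNPs table is shipped to each executor once,
and tasks dereference the read-only copy instead of capturing the table in
their closures. A self-contained sketch of that pattern (illustrative names,
not ADAM's API):

import org.apache.spark.{ SparkConf, SparkContext }

object BroadcastSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext(
      new SparkConf().setAppName("broadcast-sketch").setMaster("local[2]"))

    // a small lookup table, standing in for the known-SNPs table
    val mask = Set(100L, 200L, 300L)

    // broadcast once; every executor receives one read-only copy
    val broadcastMask = sc.broadcast(mask)

    // tasks call .value to dereference the broadcast
    val masked = sc.parallelize(Seq(50L, 100L, 250L, 300L))
      .filter(pos => broadcastMask.value.contains(pos))
      .collect()

    println(masked.toList) // List(100, 300)
    sc.stop()
  }
}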
adam-core/src/main/scala/org/bdgenomics/adam/instrumentation/Timers.scala
@@ -49,13 +49,24 @@ object Timers extends Metrics {

// Recalibrate Base Qualities
val BQSRInDriver = timer("Base Quality Recalibration")
val CreateKnownSnpsTable = timer("Create Known SNPs Table")
val CreatingKnownSnpsTable = timer("Creating Known SNPs Table")
val CollectingSnps = timer("Collecting SNPs")
val BroadcastingKnownSnps = timer("Broadcasting known SNPs")
val ComputeCovariates = timer("Compute Covariates")
val ObservingRead = timer("Observing covariates per read")
val ReadCovariates = timer("Computing covariates per read base")
val ComputingDinucCovariate = timer("Computing dinuc covariate")
val ComputingCycleCovariate = timer("Computing cycle covariate")
val ReadResidues = timer("Splitting read into residues")
val CheckingForMask = timer("Checking if residue is masked")
val ObservationAccumulatorComb = timer("Observation Accumulator: comb")
val ObservationAccumulatorSeq = timer("Observation Accumulator: seq")
val RecalibrateRead = timer("Recalibrate Read")
val ComputeQualityScore = timer("Compute Quality Score")
val GetExtraValues = timer("Get Extra Values")
val CreatingRecalibrationTable = timer("Creating recalibration table")
val InvertingRecalibrationTable = timer("Inverting recalibration table")
val QueryingRecalibrationTable = timer("Querying recalibration table")

// Realign Indels
val RealignIndelsInDriver = timer("Realign Indels")
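
Each val above is a named timer from the bdg-utils metrics package; wrapping a
block in .time attributes the block's elapsed wall-clock time to that timer
when instrumentation is enabled, and the block's result is returned either
way. The call shape, as used throughout this commit (a sketch with a trivial
body):

import org.bdgenomics.adam.instrumentation.Timers._

// any expression can be timed; the value of the block is passed through
def timedSum(xs: Seq[Long]): Long = ComputeCovariates.time {
  xs.sum
}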

This file was deleted.

186 changes: 146 additions & 40 deletions adam-core/src/main/scala/org/bdgenomics/adam/models/SnpTable.scala
@@ -17,52 +17,114 @@
*/
package org.bdgenomics.adam.models

import org.bdgenomics.adam.rich.RichVariant
import org.bdgenomics.adam.rich.DecadentRead._
import org.bdgenomics.utils.misc.Logging
import com.esotericsoftware.kryo.io.{ Input, Output }
import com.esotericsoftware.kryo.{ Kryo, Serializer }
import org.apache.spark.rdd.MetricsContext._
import org.apache.spark.rdd.RDD
import scala.collection.immutable._
import scala.collection.mutable
import org.bdgenomics.adam.instrumentation.Timers._
import org.bdgenomics.adam.rdd.variant.VariantRDD
import org.bdgenomics.utils.misc.Logging
import scala.annotation.tailrec
import scala.math.{ max, min }

/**
* A table containing all of the SNPs in a known variation dataset.
*
* @param table A map between a contig name and a set containing all coordinates
* where a point variant is known to exist.
* @param indices A map of contig names to the (first, last) index in the
* site array that contain data from this contig.
* @param sites An array containing positions that have masked SNPs. Sorted by
* contig name and then position.
*/
class SnpTable(private val table: Map[String, Set[Long]]) extends Serializable with Logging {
log.info("SNP table has %s contigs and %s entries".format(
table.size,
table.values.map(_.size).sum))
class SnpTable private[models] (
private[models] val indices: Map[String, (Int, Int)],
private[models] val sites: Array[Long]) extends Serializable with Logging {

/**
* Is there a known SNP at the reference location of this Residue?
*/
def isMasked(residue: Residue): Boolean =
contains(residue.referencePosition)
// precomputed per-contig initial search step: the smallest power of two
// that is at least half the length of the contig's slice of sites
private val midpoints: Map[String, Int] = {
@tailrec def pow2ceil(length: Int, i: Int = 1): Int = {
if (2 * i >= length) {
i
} else {
pow2ceil(length, 2 * i)
}
}

/**
* Is there a known SNP at the given reference location?
*/
def contains(location: ReferencePosition): Boolean = {
val bucket = table.get(location.referenceName)
if (bucket.isEmpty) unknownContigWarning(location.referenceName)
bucket.exists(_.contains(location.pos))
indices.mapValues(p => {
val (start, end) = p
pow2ceil(end - start + 1)
})
}

private val unknownContigs = new mutable.HashSet[String]
// gallops through the contig's slice of the sites array, halving the step
// after each probe; returns the index of a site the region overlaps, if any
@tailrec private def binarySearch(rr: ReferenceRegion,
offset: Int,
length: Int,
step: Int,
idx: Int = 0): Option[Int] = {
if (length == 0) {
None
} else if (rr.start <= sites(offset + idx) && rr.end > sites(offset + idx)) {
// if we've satisfied this last condition, then the read is overlapping the
// current index and we have a hit
Some(offset + idx)
} else if (step == 0) {
None
} else {
val stepIdx = idx + step
val nextIdx: Int = if (stepIdx >= length ||
rr.end <= sites(offset + stepIdx)) {
idx
} else {
stepIdx
}
binarySearch(rr, offset, length, step / 2, nextIdx)
}
}

// after a hit, walks toward lower indices, collecting every site that
// still overlaps the region
@tailrec private def extendForward(rr: ReferenceRegion,
offset: Int,
idx: Int,
list: List[Long] = List.empty): List[Long] = {
if (idx < offset) {
list
} else {
if (rr.start > sites(idx)) {
list
} else {
extendForward(rr, offset, idx - 1, sites(idx) :: list)
}
}
}

private def unknownContigWarning(contig: String) = {
// This is synchronized to avoid a data race. Multiple threads may
// race to update `unknownContigs`, e.g. when running with a Spark
// master of `local[N]`.
synchronized {
if (!unknownContigs.contains(contig)) {
unknownContigs += contig
log.warn("Contig has no entries in known SNPs table: %s".format(contig))
// after a hit, walks toward higher indices, collecting every site that
// still overlaps the region, and folds in the forward walk's results
@tailrec private def extendBackwards(rr: ReferenceRegion,
end: Int,
idx: Int,
list: List[Long]): Set[Long] = {
if (idx > end) {
list.toSet
} else {
if (rr.end <= sites(idx)) {
list.toSet
} else {
extendBackwards(rr, end, idx + 1, sites(idx) :: list)
}
}
}

/**
* Finds all masked sites overlapped by a region.
*
* @param rr The region to query.
* @return The start positions of all known SNPs that the region overlaps,
*   or an empty set if the contig is not in the table.
*/
private[adam] def maskedSites(rr: ReferenceRegion): Set[Long] = CheckingForMask.time {
val optRange = indices.get(rr.referenceName)

optRange.flatMap(range => {
val (offset, end) = range
val optIdx = binarySearch(rr, offset, end - offset + 1, midpoints(rr.referenceName))

optIdx.map(idx => {
extendBackwards(rr, end, idx + 1, extendForward(rr, offset, idx))
.map(_.toLong)
})
}).getOrElse(Set.empty)
}
}
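
The layout the class above implements is easiest to see on a toy example:
every masked position lives in one array sorted by contig and then position,
`indices` maps each contig to the (first, last) slice of that array it owns,
and a query binary-searches the slice for one overlapping site before
extending outward in both directions. A standalone sketch of the layout with
a simpler, linear-scan equivalent of the lookup (plain collections, not
ADAM's private constructor):

// all masked positions in one array, sorted by (contig, position)
val sites: Array[Long] = Array(100L, 200L, 350L, // chr1
  10L, 50L) // chr2
val indices: Map[String, (Int, Int)] = Map("chr1" -> (0, 2), "chr2" -> (3, 4))

// untuned equivalent of maskedSites for the half-open query [start, end)
def maskedSites(contig: String, start: Long, end: Long): Set[Long] = {
  indices.get(contig).map {
    case (first, last) =>
      sites.slice(first, last + 1)
        .filter(pos => pos >= start && pos < end)
        .toSet
  }.getOrElse(Set.empty)
}

assert(maskedSites("chr1", 150L, 360L) == Set(200L, 350L))
assert(maskedSites("chr2", 0L, 10L) == Set.empty) // end is exclusive
assert(maskedSites("chr3", 0L, 100L) == Set.empty) // unknown contig

The real implementation replaces the linear filter with the galloping binary
search plus the two extension walks, which is where the reported 50% lookup
improvement comes from.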

/**
@@ -76,20 +138,64 @@ object SnpTable {
* @return An empty SNP table.
*/
def apply(): SnpTable = {
new SnpTable(Map[String, Set[Long]]())
new SnpTable(Map.empty,
Array.empty)
}

/**
* Creates a SNP Table from an RDD of RichVariants.
* Creates a SNP Table from a VariantRDD.
*
* @param variants The variants to populate the table from.
* @return Returns a new SNPTable containing the input variants.
*/
def apply(variants: RDD[RichVariant]): SnpTable = {
val positions = variants.map(variant => (variant.variant.getContigName,
variant.variant.getStart)).collect()
val table = new mutable.HashMap[String, mutable.HashSet[Long]]
positions.foreach(tup => table.getOrElseUpdate(tup._1, { new mutable.HashSet[Long] }) += tup._2)
new SnpTable(table.mapValues(_.toSet).toMap)
def apply(variants: VariantRDD): SnpTable = CreatingKnownSnpsTable.time {
val (indices, positions) = CollectingSnps.time {
val sortedVariants = variants.sort
.rdd
.cache()

val contigIndices = sortedVariants.map(_.getContigName)
.zipWithIndex
.mapValues(v => (v.toInt, v.toInt))
.reduceByKeyLocally((p1, p2) => {
(min(p1._1, p2._1), max(p1._2, p2._2))
}).toMap
val sites = sortedVariants.map(_.getStart: Long).collect()

// unpersist the cached variants
sortedVariants.unpersist()

(contigIndices, sites)
}
new SnpTable(indices, positions)
}
}
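
The index construction in apply above is a reusable trick: sort, zip each key
with its position, and reduce to per-key (min, max) ranges. In isolation,
against any sorted RDD of keys (real Spark API, hypothetical helper name):

import org.apache.spark.rdd.RDD
import scala.math.{ max, min }

// per-key (first, last) index ranges over a sorted sequence of keys
def keyRanges(sortedKeys: RDD[String]): Map[String, (Int, Int)] = {
  sortedKeys.zipWithIndex
    .mapValues(i => (i.toInt, i.toInt))
    .reduceByKeyLocally((p1, p2) => (min(p1._1, p2._1), max(p1._2, p2._2)))
    .toMap
}

reduceByKeyLocally returns a local Map at the driver, which fits here because
the number of contigs stays small even when the number of variants does not.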

private[adam] class SnpTableSerializer extends Serializer[SnpTable] {

def write(kryo: Kryo, output: Output, obj: SnpTable) {
output.writeInt(obj.indices.size)
obj.indices.foreach(kv => {
val (contigName, (lowerBound, upperBound)) = kv
output.writeString(contigName)
output.writeInt(lowerBound)
output.writeInt(upperBound)
})
output.writeInt(obj.sites.length)
obj.sites.foreach(output.writeLong(_))
}

def read(kryo: Kryo, input: Input, klazz: Class[SnpTable]): SnpTable = {
val indicesSize = input.readInt()
val indices = new Array[(String, (Int, Int))](indicesSize)
(0 until indicesSize).foreach(i => {
indices(i) = (input.readString(), (input.readInt(), input.readInt()))
})
val sitesSize = input.readInt()
val sites = new Array[Long](sitesSize)
(0 until sitesSize).foreach(i => {
sites(i) = input.readLong()
})
new SnpTable(indices.toMap, sites)
}
}
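
The serializer writes the table as flat primitives (contig count, per-contig
name and bounds, then the site count and sites), avoiding generic Kryo
reflection over the Map and Array. Registering it follows the usual Kryo
pattern; this sketch assumes a registrator living inside the
org.bdgenomics.adam package, since SnpTableSerializer is private[adam]
(ADAM's actual registration point is not shown in this diff):

import com.esotericsoftware.kryo.Kryo
import org.bdgenomics.adam.models.{ SnpTable, SnpTableSerializer }

def registerSnpTable(kryo: Kryo) {
  kryo.register(classOf[SnpTable], new SnpTableSerializer)
}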