BQSR updates #213

Merged: 7 commits, Apr 22, 2014
+16,414 −554
@@ -78,7 +78,7 @@ class Transform(protected val args: TransformArgs) extends ADAMSparkCommand[Tran
log.info("Recalibrating base qualities")
val variants: RDD[RichADAMVariant] = sc.adamVCFLoad(args.knownSnpsFile).map(_.variant)
val knownSnps = SnpTable(variants)
adamRecords = adamRecords.adamBQSR(knownSnps)
adamRecords = adamRecords.adamBQSR(sc.broadcast(knownSnps))
}
if (args.locallyRealign) {
@@ -40,6 +40,7 @@ import java.io.File
import java.util.logging.Level
import org.apache.avro.specific.SpecificRecord
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
@@ -181,7 +182,7 @@ class ADAMRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends ADAMSequenceDictionar
MarkDuplicates(rdd)
}
def adamBQSR(knownSnps: SnpTable): RDD[ADAMRecord] = {
def adamBQSR(knownSnps: Broadcast[SnpTable]): RDD[ADAMRecord] = {

fnothaft (Member) commented on Apr 9, 2014:

IMO, it would be preferable to have:

def adamBQSR(knownSnps: SnpTable): RDD[ADAMRecord] = {
  BaseQualityRecalibration(rdd, rdd.context.broadcast(knownSnps))
}

I think this is a bit cleaner; I'd be glad to hear comments.

jey (Contributor) commented on Apr 11, 2014:

I did it this way so that the variable would only be broadcast once. In practical terms this is moot right now, since we don't reuse it outside of BQSR, but it makes clear that the SnpTable should not be re-broadcast in the future. I'm open to either approach.

BaseQualityRecalibration(rdd, knownSnps)
}
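For illustration only (not part of this diff), the two preferences above could coexist by keeping the Broadcast-based method as the primitive and adding a convenience overload that broadcasts internally; callers that reuse the SnpTable broadcast it once themselves, while one-off callers get the shorter form:

def adamBQSR(knownSnps: Broadcast[SnpTable]): RDD[ADAMRecord] =
  BaseQualityRecalibration(rdd, knownSnps)

// Hypothetical convenience overload: broadcasts the table on each call, so it
// is only appropriate when the SnpTable is not reused elsewhere.
def adamBQSR(knownSnps: SnpTable): RDD[ADAMRecord] =
  adamBQSR(rdd.context.broadcast(knownSnps))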
@@ -26,6 +26,7 @@ import org.bdgenomics.adam.util.QualityScore
import org.apache.spark.Logging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
/**
* Base Quality Score Recalibration
@@ -40,55 +41,90 @@ import org.apache.spark.rdd.RDD
* assign adjusted quality scores.
*/
class BaseQualityRecalibration(
val reads: RDD[DecadentRead],
val input: RDD[DecadentRead],
val knownSnps: Broadcast[SnpTable])
extends Serializable with Logging {
// Additional covariates to use when computing the correction
// TODO: parameterize
val covariates = CovariateSpace(
new DinucCovariate)
val covariates = CovariateSpace(new CycleCovariate, new DinucCovariate)
// Bases with quality less than this will be skipped and left alone
// TODO: parameterize
val minAcceptableQuality = QualityScore(5)
val observed: ObservationTable = {
// debug flags
val dumpObservationTable = false
val enableVisitLogging = false
val dataset: RDD[(CovariateKey, Residue)] = {
def shouldIncludeRead(read: DecadentRead) =
read.isCanonicalRecord &&
read.alignmentQuality.exists(_ > QualityScore.zero) &&
read.passedQualityChecks
reads.
filter(shouldIncludeRead).flatMap(observe).
def shouldIncludeResidue(residue: Residue) =
residue.quality > QualityScore.zero &&
residue.isRegularBase &&
!residue.isInsertion &&
!knownSnps.value.isMasked(residue)
def observe(read: DecadentRead): Seq[(CovariateKey, Residue)] =
covariates(read).zip(read.residues).
filter { case (key, residue) => shouldIncludeResidue(residue) }
input.filter(shouldIncludeRead).flatMap(observe)
}
if (enableVisitLogging) {
input.cache
dataset.cache
dumpVisits("bqsr-visits.dump")
}
val observed: ObservationTable = {
dataset.
map { case (key, residue) => (key, Observation(residue.isSNP)) }.
aggregate(ObservationAccumulator(covariates))(_ += _, _ ++= _).result
}
if (dumpObservationTable) {
println(observed.toCSV)
}
val result: RDD[ADAMRecord] = {
val recalibrator = Recalibrator(observed, minAcceptableQuality)
reads.map(recalibrator)
input.map(recalibrator)
}
// Compute observation table for a single read
private def observe(read: DecadentRead): Seq[(CovariateKey, Observation)] = {
def shouldIncludeResidue(residue: Residue) =
residue.quality > QualityScore.zero &&
residue.isRegularBase &&
!residue.isInsertion &&
!knownSnps.value.isMasked(residue)
private def dumpVisits(filename: String) = {

fnothaft (Member) commented on Apr 9, 2014:

What is a visit? Can you add docs?

massie (Member) commented on Apr 21, 2014:

I'm pretty sure Jey's talking about the visitor pattern here.

def readId(read: DecadentRead): String =
read.name +
(if (read.isNegativeRead) "-" else "+") +
(if (read.record.getFirstOfPair) "1" else "") +
(if (read.record.getSecondOfPair) "2" else "")
// Compute keys and filter out skipped residues
val keys: Seq[(CovariateKey, Residue)] =
covariates(read).zip(read.sequence).filter(x => shouldIncludeResidue(x._2))
val readLengths =
input.map(read => (readId(read), read.residues.length)).collectAsMap
// Construct result
keys.map { case (key, residue) => (key, Observation(residue.isSNP)) }
val visited = dataset.
map { case (key, residue) => (readId(residue.read), Seq(residue.offset)) }.
reduceByKeyLocally((left, right) => left ++ right)
val outf = new java.io.File(filename)
val writer = new java.io.PrintWriter(outf)
visited.foreach {
case (readName, visited) =>
val length = readLengths(readName)
val buf = Array.fill[Char](length)('O')
visited.foreach { idx => buf(idx) = 'X' }
writer.println(readName + "\t" + String.valueOf(buf))
}
writer.close
}
}
object BaseQualityRecalibration {
def apply(rdd: RDD[ADAMRecord], knownSnps: SnpTable): RDD[ADAMRecord] = {
val sc = rdd.context
new BaseQualityRecalibration(cloy(rdd), sc.broadcast(knownSnps)).result
}
def apply(rdd: RDD[ADAMRecord], knownSnps: Broadcast[SnpTable]): RDD[ADAMRecord] =
new BaseQualityRecalibration(cloy(rdd), knownSnps).result
}
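Regarding the question above about what a "visit" is: judging from dumpVisits, a visit appears to be a residue that passes the read and residue filters and lands in the observation dataset. A self-contained sketch of the dump format, with made-up values (one line per read: readId, a tab, then 'X' at visited offsets and 'O' elsewhere):

// For a 10-base read named "read1" on the forward strand, first of pair
// (readId "read1+1"), with offsets 2, 3 and 7 observed:
val length = 10
val visitedOffsets = Seq(2, 3, 7)
val buf = Array.fill[Char](length)('O')
visitedOffsets.foreach { idx => buf(idx) = 'X' }
println("read1+1" + "\t" + String.valueOf(buf))   // prints "read1+1\tOOXXOOOXOO"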
@@ -96,12 +96,12 @@ class CovariateSpace(val extras: IndexedSeq[Covariate]) extends Serializable {
val extraVals = extras.map(cov => {
val result = cov(read)
// Each covariate must return a value per Residue
assert(result.size == read.sequence.size)
assert(result.size == read.residues.size)
result
})
// Construct the CovariateKeys
read.sequence.zipWithIndex.map {
read.residues.zipWithIndex.map {
case (residue, residueIdx) =>
val residueExtras = extraVals.map(_(residueIdx))
new CovariateKey(read.readGroup, residue.quality, residueExtras)
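Sketch with made-up values (not from this diff): with the CycleCovariate and DinucCovariate installed, the key for, say, the third residue of a forward, first-of-pair read in read group "rg1", reported at Q30, combines the read group, the reported quality, and one Option per extra covariate, roughly:

// Ordering of the dinucleotide pair is assumed here purely for illustration.
new CovariateKey("rg1", QualityScore(30), IndexedSeq(Some(3), Some(('A', 'C'))))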
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2014 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.bdgenomics.adam.rdd.recalibration
import org.bdgenomics.adam.rich.DecadentRead
import org.bdgenomics.adam.rich.DecadentRead._
// This is based on the CycleCovariate in GATK 1.6.
class CycleCovariate extends AbstractCovariate[Int] {
def compute(read: DecadentRead): Seq[Option[Int]] = {
val (initial, increment) = initialization(read)
Range(0, read.residues.length).map(pos => Some(initial + increment * pos))
}
// Returns (initialValue, increment)
private def initialization(read: DecadentRead): (Int, Int) = {
if (!read.isNegativeRead) {
if (read.isSecondOfPair) {
(-1, -1)
} else {
(1, 1)
}
} else {
if (read.isSecondOfPair) {
(-read.residues.length, 1)
} else {
(read.residues.length, -1)
}
}
}
override def csvFieldName: String = "Cycle"
override def equals(other: Any) = other match {
case that: CycleCovariate => true
case _ => false
}
override def hashCode = 0x83EFAB61
}
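A quick worked example (not part of the diff) of the cycle values this covariate assigns to a 4-residue read under each strand/pair combination, following the initialization logic above:

val length = 4
def cycles(initial: Int, increment: Int): Seq[Int] =
  Range(0, length).map(pos => initial + increment * pos)

cycles(1, 1)        // forward, first of pair:   Seq(1, 2, 3, 4)
cycles(-1, -1)      // forward, second of pair:  Seq(-1, -2, -3, -4)
cycles(length, -1)  // reverse, first of pair:   Seq(4, 3, 2, 1)
cycles(-length, 1)  // reverse, second of pair:  Seq(-4, -3, -2, -1)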
@@ -22,7 +22,7 @@ import org.bdgenomics.adam.rich.DecadentRead._
// TODO: should inherit from something like AbstractCovariate[(DNABase, DNABase)]
class DinucCovariate extends AbstractCovariate[(Char, Char)] {
def compute(read: DecadentRead): Seq[Option[(Char, Char)]] = {
val sequence = read.sequence.map(_.base)
val sequence = read.residues.map(_.base)
if (read.isNegativeRead) {
/* Use the reverse-complement of the sequence to get back the original
* sequence as it was read by the sequencing machine. The sequencer
@@ -30,11 +30,6 @@ import scala.collection.mutable
class Observation(val total: Long, val mismatches: Long) extends Serializable {
require(mismatches >= 0 && mismatches <= total)
/**
* Whether to emulate GATK's calculations.

fnothaft (Member) commented on Apr 9, 2014:

What changes if you do/don't emulate GATK? Can you add docs?

massie (Member) commented on Apr 21, 2014:

This was mainly used for debugging. This change makes us as concordant with GATK as possible without inheriting their bugs.

jey (Contributor) commented on Apr 22, 2014:

Yup, exactly. The old code was trying to use the same logic as GATK 1.6, but the Bayesian approach is both more principled/correct and now gets >99.99% concordance with GATK.

I'm not sure why we get lower concordance when using GATK 1.6's logic, but I'm guessing it's because of two effects: a) their complicated logic basically reduces to the same as ours, and b) I didn't reproduce some aspect of their formula accurately.

*/
val emulateGATK: Boolean = true
def this(that: Observation) = this(that.total, that.mismatches)
def +(that: Observation) =
@@ -44,7 +39,7 @@ class Observation(val total: Long, val mismatches: Long) extends Serializable {
* Empirically estimated probability of a mismatch.
*/
def empiricalErrorProbability: Double =
if (!emulateGATK) bayesianErrorProbability else gatkErrorProbability
bayesianErrorProbability
/**
* Empirically estimated probability of a mismatch, as a QualityScore.
@@ -63,21 +58,8 @@ class Observation(val total: Long, val mismatches: Long) extends Serializable {
def bayesianErrorProbability: Double = bayesianErrorProbability(1, 1)
def bayesianErrorProbability(a: Double, b: Double): Double = (a + mismatches) / (a + b + total)
// GATK 1.6 uses the following, which they describe as "Yates's correction". However,
// it doesn't match the Wikipedia entry which describes a different formula used
// for correction of chi-squared independence tests on contingency tables.
// TODO: Figure out this discrepancy.
def gatkErrorProbability: Double = gatkErrorProbability(1)
def gatkErrorProbability(smoothing: Double): Double = {
val errProb = (mismatches + smoothing) / (total + smoothing)
Seq(QualityScore(50).errorProbability, errProb + 0.0001).max
}
// Format as string compatible with GATK's CSV output
def toCSV: Seq[String] = Seq(total.toString, mismatches.toString, empiricalQualityForCSV.phred.toString)
def empiricalQualityForCSV: QualityScore =
if (!emulateGATK) empiricalQuality else QualityScore.fromErrorProbability(gatkErrorProbability(0))
def toCSV: Seq[String] = Seq(total.toString, mismatches.toString, empiricalQuality.phred.toString)
override def toString: String =
"%s / %s (%s)".format(mismatches, total, empiricalQuality)
@@ -117,7 +99,7 @@ object Aggregate {
val empty: Aggregate = new Aggregate(0, 0, 0)
def apply(key: CovariateKey, value: Observation) =
new Aggregate(value.total, value.mismatches, key.quality.errorProbability * value.mismatches)
new Aggregate(value.total, value.mismatches, key.quality.errorProbability * value.total)
}
/**
@@ -37,7 +37,7 @@ class Recalibrator(val table: RecalibrationTable, val minAcceptableQuality: Qual
}
def computeQual(read: DecadentRead): Seq[QualityScore] = {
val origQuals = read.sequence.map(_.quality)
val origQuals = read.residues.map(_.quality)
val newQuals = table(read)
origQuals.zip(newQuals).map {
case (origQ, newQ) =>
@@ -64,6 +64,11 @@ class RecalibrationTable(
val extraTables: IndexedSeq[Map[(String, QualityScore, Option[Covariate#Value]), Aggregate]])
extends (DecadentRead => Seq[QualityScore]) with Serializable {
// TODO: parameterize?
val maxQualScore = QualityScore(50)
val maxLogP = log(maxQualScore.errorProbability)
def apply(read: DecadentRead): Seq[QualityScore] =
covariates(read).map(lookup)
@@ -73,12 +78,17 @@ class RecalibrationTable(
val qualityDelta = computeQualityDelta(key, residueLogP + globalDelta)
val extrasDelta = computeExtrasDelta(key, residueLogP + globalDelta + qualityDelta)
val correctedLogP = residueLogP + globalDelta + qualityDelta + extrasDelta
QualityScore.fromErrorProbability(exp(correctedLogP))
qualityFromLogP(correctedLogP)
}
def qualityFromLogP(logP: Double): QualityScore = {
val boundedLogP = math.min(0.0, math.max(maxLogP, logP))
QualityScore.fromErrorProbability(exp(boundedLogP))
}
def computeGlobalDelta(key: CovariateKey): Double = {
globalTable.get(key.readGroup).
map(bucket => log(bucket.reportedErrorProbability) - log(bucket.empiricalErrorProbability)).
map(bucket => log(bucket.empiricalErrorProbability) - log(bucket.reportedErrorProbability)).
getOrElse(0.0)
}
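For illustration (not part of the diff), the clamp in qualityFromLogP keeps the corrected log error probability within [log(1e-5), 0], so recalibrated qualities are capped at Q50 on the high end and at an error probability of 1 on the low end:

import math.{ exp, log, max, min }

val maxLogP = log(1e-5)                      // error probability of QualityScore(50)
def boundedErrorProbability(logP: Double): Double =
  exp(min(0.0, max(maxLogP, logP)))

boundedErrorProbability(log(0.001))          // ≈ 0.001  (Q30, unchanged)
boundedErrorProbability(log(1e-8))           // ≈ 1e-5   (clamped to Q50)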
@@ -72,19 +72,20 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
require(record.referencePositions.length == record.getSequence.length)
/**
* A "residue" is an individual monomer of a polymeric chain, such as DNA.
* In biochemistry and molecular biology, a "residue" refers to a specific
* monomer within a polymeric chain, such as DNA.
*/
class Residue private[DecadentRead] (val position: Int) {
class Residue private[DecadentRead] (val offset: Int) {
def read = DecadentRead.this
/**
* Nucleotide at this position.
* Nucleotide at this offset.
*
* TODO: Return values of meaningful type, e.g. `DNABase' or `Deoxyribonucleotide'.
* TODO: Return values of meaningful type, e.g. `DNABase'.
*/
def base: Char = read.baseSequence(position)
def base: Char = read.baseSequence(offset)
def quality = QualityScore(record.qualityScores(position))
def quality = QualityScore(record.qualityScores(offset))
def isRegularBase: Boolean = base match {
case 'A' => true
@@ -96,17 +97,17 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
}
def isMismatch(includeInsertions: Boolean = true): Boolean =
assumingAligned(record.isMismatchAtReadOffset(position).getOrElse(includeInsertions))
assumingAligned(record.isMismatchAtReadOffset(offset).getOrElse(includeInsertions))
def isSNP: Boolean = isMismatch(false)
def isInsertion: Boolean =
assumingAligned(record.isMismatchAtReadOffset(position).isEmpty)
assumingAligned(record.isMismatchAtReadOffset(offset).isEmpty)
def referenceLocationOption: Option[ReferenceLocation] =
assumingAligned(
record.readOffsetToReferencePosition(position).
map(refOffset => new ReferenceLocation(record.getReferenceName.toString, refOffset)))
def referenceLocationOption: Option[ReferenceLocation] = assumingAligned {
record.readOffsetToReferencePosition(offset).
map(refOffset => new ReferenceLocation(record.getReferenceName.toString, refOffset))
}
def referenceLocation: ReferenceLocation =
referenceLocationOption.getOrElse(
@@ -117,7 +118,9 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
private lazy val baseSequence: String = record.getSequence.toString
lazy val sequence: IndexedSeq[Residue] = Range(0, baseSequence.length).map(new Residue(_))
lazy val residues: IndexedSeq[Residue] = Range(0, baseSequence.length).map(new Residue(_))
def name: String = record.getReadName
def isAligned: Boolean = record.getReadMapped
@@ -136,6 +139,12 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
def isDuplicate: Boolean = record.getDuplicateRead
def isPaired: Boolean = record.getReadPaired

fnothaft (Member) commented on Apr 9, 2014:

Why do we keep these values around, since they can be accessed directly from the read?

jey (Contributor) commented on Apr 22, 2014:

The intention here is to provide a unified clean interface to the underlying data, as opposed to having to know when to use a DecadentRead method and when to drop down to the ADAMRecord.

def isFirstOfPair: Boolean = isPaired && !record.getSecondOfPair
def isSecondOfPair: Boolean = isPaired && record.getSecondOfPair
def isNegativeRead: Boolean = record.getReadNegativeStrand
// Is this the most representative record for this read?
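As a usage sketch of the unified interface jey describes (a hypothetical helper, not in the diff), callers can stay at the DecadentRead/Residue level rather than reaching into the underlying ADAMRecord:

def highQualitySnpOffsets(read: DecadentRead): Seq[Int] =
  read.residues.
    filter(r => r.isRegularBase && r.isSNP && r.quality > QualityScore(20)).
    map(_.offset)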