Merge pull request #213 from jey/bqsr-updates
BQSR updates
massie committed Apr 22, 2014
2 parents 17621dc + 6bd83d9 commit a13c758
Showing 12 changed files with 16,416 additions and 554 deletions.
@@ -78,7 +78,7 @@ class Transform(protected val args: TransformArgs) extends ADAMSparkCommand[Tran
log.info("Recalibrating base qualities")
val variants: RDD[RichADAMVariant] = sc.adamVCFLoad(args.knownSnpsFile).map(_.variant)
val knownSnps = SnpTable(variants)
adamRecords = adamRecords.adamBQSR(knownSnps)
adamRecords = adamRecords.adamBQSR(sc.broadcast(knownSnps))
}

if (args.locallyRealign) {
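
The Transform change swaps a closure-captured SnpTable for an explicit Spark broadcast, so the potentially large table of known SNPs is shipped to each executor once instead of being re-serialized into every task closure. A minimal sketch of the pattern, with a plain Map standing in for SnpTable (names are illustrative, not from this commit):

    import org.apache.spark.SparkContext

    def maskKnownSites(sc: SparkContext): Unit = {
      // Stand-in for SnpTable: known SNP positions per contig.
      val knownSites: Map[String, Set[Long]] = Map("chr1" -> Set(10000L))
      val sitesBC = sc.broadcast(knownSites)  // shipped to each executor once
      val kept = sc.parallelize(Seq(("chr1", 10000L), ("chr1", 30000L)))
        .filter { case (contig, pos) => !sitesBC.value.get(contig).exists(_.contains(pos)) }
      kept.collect().foreach(println)  // keeps only ("chr1", 30000)
    }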
@@ -40,6 +40,7 @@ import java.io.File
import java.util.logging.Level
import org.apache.avro.specific.SpecificRecord
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
@@ -181,7 +182,7 @@ class ADAMRecordRDDFunctions(rdd: RDD[ADAMRecord]) extends ADAMSequenceDictionar
MarkDuplicates(rdd)
}

def adamBQSR(knownSnps: SnpTable): RDD[ADAMRecord] = {
def adamBQSR(knownSnps: Broadcast[SnpTable]): RDD[ADAMRecord] = {
BaseQualityRecalibration(rdd, knownSnps)
}

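
With the new signature, adamBQSR takes the broadcast handle directly, so the caller broadcasts once and can reuse the handle across stages. A hedged usage sketch; it assumes the ADAM imports for ADAMRecord and SnpTable, plus the implicit conversion from RDD[ADAMRecord] to ADAMRecordRDDFunctions, are in scope:

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD

    def runBqsr(sc: SparkContext, reads: RDD[ADAMRecord], knownSnps: SnpTable): RDD[ADAMRecord] =
      reads.adamBQSR(sc.broadcast(knownSnps))  // the caller now owns the broadcast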
@@ -26,6 +26,7 @@ import org.bdgenomics.adam.util.QualityScore
import org.apache.spark.Logging
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._

/**
* Base Quality Score Recalibration
@@ -40,55 +41,92 @@ import org.apache.spark.rdd.RDD
* assign adjusted quality scores.
*/
class BaseQualityRecalibration(
val reads: RDD[DecadentRead],
val input: RDD[DecadentRead],
val knownSnps: Broadcast[SnpTable])
extends Serializable with Logging {

// Additional covariates to use when computing the correction
// TODO: parameterize
val covariates = CovariateSpace(
new DinucCovariate)
val covariates = CovariateSpace(new CycleCovariate, new DinucCovariate)

// Bases with quality less than this will be skipped and left alone
// TODO: parameterize
val minAcceptableQuality = QualityScore(5)

val observed: ObservationTable = {
// Debug: Print the ObservationTable to stdout
val dumpObservationTable = false

// Debug: Log the visited/skipped residues to bqsr-visits.dump
val enableVisitLogging = false

val dataset: RDD[(CovariateKey, Residue)] = {
def shouldIncludeRead(read: DecadentRead) =
read.isCanonicalRecord &&
read.alignmentQuality.exists(_ > QualityScore.zero) &&
read.passedQualityChecks

reads.
filter(shouldIncludeRead).flatMap(observe).
def shouldIncludeResidue(residue: Residue) =
residue.quality > QualityScore.zero &&
residue.isRegularBase &&
!residue.isInsertion &&
!knownSnps.value.isMasked(residue)

def observe(read: DecadentRead): Seq[(CovariateKey, Residue)] =
covariates(read).zip(read.residues).
filter { case (key, residue) => shouldIncludeResidue(residue) }

input.filter(shouldIncludeRead).flatMap(observe)
}

if (enableVisitLogging) {
input.cache
dataset.cache
dumpVisits("bqsr-visits.dump")
}

val observed: ObservationTable = {
dataset.
map { case (key, residue) => (key, Observation(residue.isSNP)) }.
aggregate(ObservationAccumulator(covariates))(_ += _, _ ++= _).result
}

if (dumpObservationTable) {
println(observed.toCSV)
}

val result: RDD[ADAMRecord] = {
val recalibrator = Recalibrator(observed, minAcceptableQuality)
reads.map(recalibrator)
input.map(recalibrator)
}

// Compute observation table for a single read
private def observe(read: DecadentRead): Seq[(CovariateKey, Observation)] = {
def shouldIncludeResidue(residue: Residue) =
residue.quality > QualityScore.zero &&
residue.isRegularBase &&
!residue.isInsertion &&
!knownSnps.value.isMasked(residue)
private def dumpVisits(filename: String) = {
def readId(read: DecadentRead): String =
read.name +
(if (read.isNegativeRead) "-" else "+") +
(if (read.record.getFirstOfPair) "1" else "") +
(if (read.record.getSecondOfPair) "2" else "")

val readLengths =
input.map(read => (readId(read), read.residues.length)).collectAsMap

// Compute keys and filter out skipped residues
val keys: Seq[(CovariateKey, Residue)] =
covariates(read).zip(read.sequence).filter(x => shouldIncludeResidue(x._2))
val visited = dataset.
map { case (key, residue) => (readId(residue.read), Seq(residue.offset)) }.
reduceByKeyLocally((left, right) => left ++ right)

// Construct result
keys.map { case (key, residue) => (key, Observation(residue.isSNP)) }
val outf = new java.io.File(filename)
val writer = new java.io.PrintWriter(outf)
visited.foreach {
case (readName, visited) =>
val length = readLengths(readName)
val buf = Array.fill[Char](length)('O')
visited.foreach { idx => buf(idx) = 'X' }
writer.println(readName + "\t" + String.valueOf(buf))
}
writer.close
}
}

object BaseQualityRecalibration {
def apply(rdd: RDD[ADAMRecord], knownSnps: SnpTable): RDD[ADAMRecord] = {
val sc = rdd.context
new BaseQualityRecalibration(cloy(rdd), sc.broadcast(knownSnps)).result
}
def apply(rdd: RDD[ADAMRecord], knownSnps: Broadcast[SnpTable]): RDD[ADAMRecord] =
new BaseQualityRecalibration(cloy(rdd), knownSnps).result
}
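
The refactor splits the old single observed block into stages: dataset pairs each kept residue with its CovariateKey (gated by shouldIncludeRead and shouldIncludeResidue), observed reduces those pairs into an ObservationTable with Spark's aggregate, and result maps the fitted Recalibrator back over the input. A toy version of the reduction shape, with plain counts standing in for ObservationAccumulator (`+=` folds in one observation, `++=` merges per-partition accumulators):

    import org.apache.spark.rdd.RDD

    case class Acc(total: Long = 0L, mismatches: Long = 0L) {
      def +=(isSnp: Boolean): Acc = Acc(total + 1, mismatches + (if (isSnp) 1 else 0))
      def ++=(that: Acc): Acc = Acc(total + that.total, mismatches + that.mismatches)
    }

    def tally(isSnpFlags: RDD[Boolean]): Acc =
      isSnpFlags.aggregate(Acc())(_ += _, _ ++= _)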
@@ -96,12 +96,12 @@ class CovariateSpace(val extras: IndexedSeq[Covariate]) extends Serializable {
val extraVals = extras.map(cov => {
val result = cov(read)
// Each covariate must return a value per Residue
assert(result.size == read.sequence.size)
assert(result.size == read.residues.size)
result
})

// Construct the CovariateKeys
read.sequence.zipWithIndex.map {
read.residues.zipWithIndex.map {
case (residue, residueIdx) =>
val residueExtras = extraVals.map(_(residueIdx))
new CovariateKey(read.readGroup, residue.quality, residueExtras)
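
Each extra covariate must produce exactly one value per residue (the assertion above enforces this), and the values are then zipped positionally with the residues to form CovariateKeys. The shape of that zip, sketched with plain types standing in for Residue and CovariateKey:

    val qualities = Seq(30, 30, 20)                               // reported quality per residue
    val cycle     = Seq(Some(1), Some(2), Some(3))                // CycleCovariate values
    val dinuc     = Seq(None, Some(('A', 'C')), Some(('C', 'G'))) // DinucCovariate values
    val keys = qualities.zipWithIndex.map { case (q, i) => (q, cycle(i), dinuc(i)) }
    // keys(1) == (30, Some(2), Some(('A', 'C')))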
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2014 The Regents of the University of California
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.bdgenomics.adam.rdd.recalibration

import org.bdgenomics.adam.rich.DecadentRead
import org.bdgenomics.adam.rich.DecadentRead._

// This is based on the CycleCovariate in GATK 1.6.
class CycleCovariate extends AbstractCovariate[Int] {
def compute(read: DecadentRead): Seq[Option[Int]] = {
val (initial, increment) = initialization(read)
Range(0, read.residues.length).map(pos => Some(initial + increment * pos))
}

// Returns (initialValue, increment)
private def initialization(read: DecadentRead): (Int, Int) = {
if (!read.isNegativeRead) {
if (read.isSecondOfPair) {
(-1, -1)
} else {
(1, 1)
}
} else {
if (read.isSecondOfPair) {
(-read.residues.length, 1)
} else {
(read.residues.length, -1)
}
}
}

override def csvFieldName: String = "Cycle"

override def equals(other: Any) = other match {
case that: CycleCovariate => true
case _ => false
}

override def hashCode = 0x83EFAB61
}
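
Unrolling the four strand/pair cases: cycle numbers follow the order in which the machine read the bases, run backwards for reverse-strand reads, and are negated for second-of-pair reads. A standalone mirror of the logic, for illustration only (the covariate itself wraps each value in Some):

    def cycles(length: Int, negativeStrand: Boolean, secondOfPair: Boolean): Seq[Int] = {
      val (initial, increment) = (negativeStrand, secondOfPair) match {
        case (false, false) => (1, 1)        // forward, first of pair:  1, 2, ..., N
        case (false, true)  => (-1, -1)      // forward, second of pair: -1, -2, ..., -N
        case (true, false)  => (length, -1)  // reverse, first of pair:  N, N-1, ..., 1
        case (true, true)   => (-length, 1)  // reverse, second of pair: -N, ..., -1
      }
      Range(0, length).map(pos => initial + increment * pos)
    }

    // cycles(4, negativeStrand = true, secondOfPair = true) == Seq(-4, -3, -2, -1)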

@@ -22,7 +22,7 @@ import org.bdgenomics.adam.rich.DecadentRead._
// TODO: should inherit from something like AbstractCovariate[(DNABase, DNABase)]
class DinucCovariate extends AbstractCovariate[(Char, Char)] {
def compute(read: DecadentRead): Seq[Option[(Char, Char)]] = {
val sequence = read.sequence.map(_.base)
val sequence = read.residues.map(_.base)
if (read.isNegativeRead) {
/* Use the reverse-complement of the sequence to get back the original
* sequence as it was read by the sequencing machine. The sequencer
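
For a reverse-strand read, the stored sequence is the reverse complement of what the sequencer actually emitted, so the covariate undoes that transform before pairing adjacent bases. A minimal reverse-complement helper of the kind this implies (a sketch; ambiguity codes such as 'N' pass through unchanged):

    def reverseComplement(seq: String): String =
      seq.reverseMap {
        case 'A' => 'T'
        case 'T' => 'A'
        case 'C' => 'G'
        case 'G' => 'C'
        case other => other  // e.g. 'N'
      }

    // reverseComplement("AACGT") == "ACGTT"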
@@ -30,11 +30,6 @@ import scala.collection.mutable
class Observation(val total: Long, val mismatches: Long) extends Serializable {
require(mismatches >= 0 && mismatches <= total)

/**
* Whether to emulate GATK's calculations.
*/
val emulateGATK: Boolean = true

def this(that: Observation) = this(that.total, that.mismatches)

def +(that: Observation) =
@@ -44,7 +39,7 @@ class Observation(val total: Long, val mismatches: Long) extends Serializable {
* Empirically estimated probability of a mismatch.
*/
def empiricalErrorProbability: Double =
if (!emulateGATK) bayesianErrorProbability else gatkErrorProbability
bayesianErrorProbability

/**
* Empirically estimated probability of a mismatch, as a QualityScore.
@@ -63,21 +58,8 @@ class Observation(val total: Long, val mismatches: Long) extends Serializable {
def bayesianErrorProbability: Double = bayesianErrorProbability(1, 1)
def bayesianErrorProbability(a: Double, b: Double): Double = (a + mismatches) / (a + b + total)

// GATK 1.6 uses the following, which they describe as "Yates's correction". However,
// it doesn't match the Wikipedia entry which describes a different formula used
// for correction of chi-squared independence tests on contingency tables.
// TODO: Figure out this discrepancy.
def gatkErrorProbability: Double = gatkErrorProbability(1)
def gatkErrorProbability(smoothing: Double): Double = {
val errProb = (mismatches + smoothing) / (total + smoothing)
Seq(QualityScore(50).errorProbability, errProb + 0.0001).max
}

// Format as string compatible with GATK's CSV output
def toCSV: Seq[String] = Seq(total.toString, mismatches.toString, empiricalQualityForCSV.phred.toString)

def empiricalQualityForCSV: QualityScore =
if (!emulateGATK) empiricalQuality else QualityScore.fromErrorProbability(gatkErrorProbability(0))
def toCSV: Seq[String] = Seq(total.toString, mismatches.toString, empiricalQuality.phred.toString)

override def toString: String =
"%s / %s (%s)".format(mismatches, total, empiricalQuality)
@@ -117,7 +99,7 @@ object Aggregate {
val empty: Aggregate = new Aggregate(0, 0, 0)

def apply(key: CovariateKey, value: Observation) =
new Aggregate(value.total, value.mismatches, key.quality.errorProbability * value.mismatches)
new Aggregate(value.total, value.mismatches, key.quality.errorProbability * value.total)
}

/**
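
With the GATK-emulation path removed, the empirical error rate is always the Beta-smoothed estimate bayesianErrorProbability(a, b). The Aggregate fix in the hunk above follows the same logic: the expected number of mismatches under the reported qualities is errorProbability * total, not errorProbability * mismatches. A worked instance of the estimator (a sketch of the arithmetic, not library code):

    def bayesianErrorProbability(mismatches: Long, total: Long,
                                 a: Double = 1.0, b: Double = 1.0): Double =
      (a + mismatches) / (a + b + total)

    // 3 mismatches in 100 observations: (1 + 3) / (1 + 1 + 100) ≈ 0.0392, about Q14.
    // With no observations it degrades gracefully to 1 / 2 = 0.5 instead of dividing by zero.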
@@ -37,7 +37,7 @@ class Recalibrator(val table: RecalibrationTable, val minAcceptableQuality: Qual
}

def computeQual(read: DecadentRead): Seq[QualityScore] = {
val origQuals = read.sequence.map(_.quality)
val origQuals = read.residues.map(_.quality)
val newQuals = table(read)
origQuals.zip(newQuals).map {
case (origQ, newQ) =>
@@ -64,6 +64,11 @@ class RecalibrationTable(
val extraTables: IndexedSeq[Map[(String, QualityScore, Option[Covariate#Value]), Aggregate]])
extends (DecadentRead => Seq[QualityScore]) with Serializable {

// TODO: parameterize?
val maxQualScore = QualityScore(50)

val maxLogP = log(maxQualScore.errorProbability)

def apply(read: DecadentRead): Seq[QualityScore] =
covariates(read).map(lookup)

@@ -73,12 +78,17 @@
val qualityDelta = computeQualityDelta(key, residueLogP + globalDelta)
val extrasDelta = computeExtrasDelta(key, residueLogP + globalDelta + qualityDelta)
val correctedLogP = residueLogP + globalDelta + qualityDelta + extrasDelta
QualityScore.fromErrorProbability(exp(correctedLogP))
qualityFromLogP(correctedLogP)
}

def qualityFromLogP(logP: Double): QualityScore = {
val boundedLogP = math.min(0.0, math.max(maxLogP, logP))
QualityScore.fromErrorProbability(exp(boundedLogP))
}

def computeGlobalDelta(key: CovariateKey): Double = {
globalTable.get(key.readGroup).
map(bucket => log(bucket.reportedErrorProbability) - log(bucket.empiricalErrorProbability)).
map(bucket => log(bucket.empiricalErrorProbability) - log(bucket.reportedErrorProbability)).
getOrElse(0.0)
}

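
Two fixes land here: computeGlobalDelta now adds log(empirical) − log(reported), so a read group that performs worse than its reported qualities has its scores lowered (the old sign did the opposite), and qualityFromLogP clamps the corrected probability so no base can claim better than Q50. A numeric illustration with assumed rates:

    import math.{exp, log, pow}

    // Assumed for illustration: a read group reports Q30 (1e-3) but its empirical
    // mismatch rate is 2e-3, so every residue's error probability doubles.
    val globalDelta = log(2e-3) - log(1e-3)          // ≈ +0.693 (empirical − reported)
    val residueLogP = log(pow(10, -25 / 10.0))       // a residue reported at Q25
    val corrected   = exp(residueLogP + globalDelta) // ≈ 6.3e-3, i.e. about Q22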
@@ -72,19 +72,20 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
require(record.referencePositions.length == record.getSequence.length)

/**
* A "residue" is an individual monomer of a polymeric chain, such as DNA.
* In biochemistry and molecular biology, a "residue" refers to a specific
* monomer within a polymeric chain, such as DNA.
*/
class Residue private[DecadentRead] (val position: Int) {
class Residue private[DecadentRead] (val offset: Int) {
def read = DecadentRead.this

/**
* Nucleotide at this position.
* Nucleotide at this offset.
*
* TODO: Return values of meaningful type, e.g. `DNABase' or `Deoxyribonucleotide'.
* TODO: Return values of meaningful type, e.g. `DNABase'.
*/
def base: Char = read.baseSequence(position)
def base: Char = read.baseSequence(offset)

def quality = QualityScore(record.qualityScores(position))
def quality = QualityScore(record.qualityScores(offset))

def isRegularBase: Boolean = base match {
case 'A' => true
@@ -96,17 +97,17 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {
}

def isMismatch(includeInsertions: Boolean = true): Boolean =
assumingAligned(record.isMismatchAtReadOffset(position).getOrElse(includeInsertions))
assumingAligned(record.isMismatchAtReadOffset(offset).getOrElse(includeInsertions))

def isSNP: Boolean = isMismatch(false)

def isInsertion: Boolean =
assumingAligned(record.isMismatchAtReadOffset(position).isEmpty)
assumingAligned(record.isMismatchAtReadOffset(offset).isEmpty)

def referenceLocationOption: Option[ReferenceLocation] =
assumingAligned(
record.readOffsetToReferencePosition(position).
map(refOffset => new ReferenceLocation(record.getReferenceName.toString, refOffset)))
def referenceLocationOption: Option[ReferenceLocation] = assumingAligned {
record.readOffsetToReferencePosition(offset).
map(refOffset => new ReferenceLocation(record.getReferenceName.toString, refOffset))
}

def referenceLocation: ReferenceLocation =
referenceLocationOption.getOrElse(
@@ -117,7 +118,9 @@ class DecadentRead(val record: RichADAMRecord) extends Logging {

private lazy val baseSequence: String = record.getSequence.toString

lazy val sequence: IndexedSeq[Residue] = Range(0, baseSequence.length).map(new Residue(_))
lazy val residues: IndexedSeq[Residue] = Range(0, baseSequence.length).map(new Residue(_))

def name: String = record.getReadName

def isAligned: Boolean = record.getReadMapped

@@ -136,6 +139,12 @@

def isDuplicate: Boolean = record.getDuplicateRead

def isPaired: Boolean = record.getReadPaired

def isFirstOfPair: Boolean = isPaired && !record.getSecondOfPair

def isSecondOfPair: Boolean = isPaired && record.getSecondOfPair

def isNegativeRead: Boolean = record.getReadNegativeStrand

// Is this the most representative record for this read?
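
The rename from sequence/position to residues/offset (a residue knows its offset within the read, not a genomic position), together with the new name and pair accessors, is what the dumpVisits debug output relies on. A hedged usage sketch, assuming a DecadentRead value read is in scope:

    // Offsets of residues that look like SNPs, plus a /1 or /2 mate label.
    val snpOffsets: Seq[Int] = read.residues.filter(_.isSNP).map(_.offset)
    val mateSuffix = if (read.isFirstOfPair) "/1" else if (read.isSecondOfPair) "/2" else ""
    val label = read.name + mateSuffix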
