[ADAM-1501] Compute coverage using Dataset API.
Resolves #1501.
fnothaft committed Jun 24, 2017
1 parent d5ea78a commit 7215c0d
Showing 2 changed files with 64 additions and 31 deletions.
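The change replaces the RDD-based per-base coverage computation with one written against the Spark SQL Dataset API. As a rough standalone sketch of that pattern (plain Spark, not ADAM code; Window and SiteCoverage are hypothetical stand-ins for ADAM's AlignmentWindow and Coverage), each mapped read is expanded into one-base windows carrying a count of 1, and the windows are then grouped by position and summed, so overlapping reads add up to the read depth at each locus:

// Standalone sketch of the Dataset-based coverage pattern (not ADAM code).
// Window and SiteCoverage are hypothetical stand-ins for ADAM's AlignmentWindow and Coverage.
import org.apache.spark.sql.SparkSession

case class Window(contigName: String, start: Long, end: Long)
case class SiteCoverage(contigName: String, start: Long, end: Long, count: Double)

object CoverageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("coverage-sketch").getOrCreate()
    import spark.implicits._

    // Two overlapping "reads" on the same contig: [10, 14) and [12, 16).
    val reads = Seq(Window("chr1", 10L, 14L), Window("chr1", 12L, 16L)).toDS()

    val coverage = reads
      // expand each read into one-base windows that each contribute a count of 1
      .flatMap(w => (w.start until w.end).map(pos => SiteCoverage(w.contigName, pos, pos + 1L, 1.0)))
      // sum the contributions at every locus
      .groupBy("contigName", "start", "end")
      .sum("count")
      .withColumnRenamed("sum(count)", "count")
      .as[SiteCoverage]

    // Positions 12 and 13 are covered by both reads and come out at 2.0; the rest at 1.0.
    coverage.orderBy("start").show()
    spark.stop()
  }
}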
@@ -31,7 +31,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.MetricsContext._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{ Dataset, SQLContext }
+import org.apache.spark.sql.{ Dataset, Row, SQLContext }
import org.apache.spark.storage.StorageLevel
import org.bdgenomics.adam.algorithms.consensus.{
ConsensusGenerator,
@@ -48,7 +48,11 @@ import org.bdgenomics.adam.rdd.{
JavaSaveArgs,
SAMHeaderWriter
}
-import org.bdgenomics.adam.rdd.feature.{ CoverageRDD, RDDBoundCoverageRDD }
+import org.bdgenomics.adam.rdd.feature.{
+  CoverageRDD,
+  DatasetBoundCoverageRDD,
+  RDDBoundCoverageRDD
+}
import org.bdgenomics.adam.rdd.read.realignment.RealignIndels
import org.bdgenomics.adam.rdd.read.recalibration.BaseQualityRecalibration
import org.bdgenomics.adam.rdd.fragment.FragmentRDD
@@ -218,6 +222,31 @@ case class RDDBoundAlignmentRecordRDD private[rdd] (
import sqlContext.implicits._
sqlContext.createDataset(rdd.map(AlignmentRecordProduct.fromAvro))
}

+  override def toCoverage(): CoverageRDD = {
+    val covCounts =
+      rdd.rdd
+        .filter(r => {
+          val readMapped = r.getReadMapped
+
+          // validate alignment fields
+          if (readMapped) {
+            require(r.getStart != null && r.getEnd != null && r.getContigName != null,
+              "Read was mapped but was missing alignment start/end/contig (%s).".format(r))
+          }
+
+          readMapped
+        }).flatMap(r => {
+          val t: List[Long] = List.range(r.getStart, r.getEnd)
+          t.map(n => (ReferenceRegion(r.getContigName, n, n + 1), 1))
+        }).reduceByKey(_ + _)
+        .map(r => Coverage(r._1, r._2.toDouble))
+
+    RDDBoundCoverageRDD(covCounts, sequences)
+  }
+}
+
+private case class AlignmentWindow(contigName: String, start: Long, end: Long) {
}

sealed abstract class AlignmentRecordRDD extends AvroReadGroupGenomicRDD[AlignmentRecord, AlignmentRecordProduct, AlignmentRecordRDD] {
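The override added above keeps the previous RDD-based implementation for the RDD-bound case. The same expand-then-aggregate idea on plain RDDs, sketched here with (contig, start, end) tuples standing in for mapped reads (again an illustration, not ADAM code):

// RDD flavour of the same per-base counting; (contig, start, end) tuples stand in for reads.
import org.apache.spark.SparkContext

object RddCoverageSketch {
  def perBaseCoverage(sc: SparkContext): Unit = {
    val mappedReads = sc.parallelize(Seq(("chr1", 10L, 14L), ("chr1", 12L, 16L)))

    val depth = mappedReads
      .flatMap { case (contig, start, end) =>
        // one ((contig, position), 1) pair per covered base
        (start until end).map(pos => ((contig, pos), 1))
      }
      .reduceByKey(_ + _) // read depth at each base

    depth.collect().sortBy(_._1._2).foreach(println)
  }
}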
@@ -311,25 +340,30 @@ sealed abstract class AlignmentRecordRDD extends AvroReadGroupGenomicRDD[Alignme
* @return CoverageRDD containing mapped RDD of Coverage.
*/
def toCoverage(): CoverageRDD = {
-    val covCounts =
-      rdd.rdd
-        .filter(r => {
-          val readMapped = r.getReadMapped
-
-          // validate alignment fields
-          if (readMapped) {
-            require(r.getStart != null && r.getEnd != null && r.getContigName != null,
-              "Read was mapped but was missing alignment start/end/contig (%s).".format(r))
-          }
-
-          readMapped
-        }).flatMap(r => {
-          val t: List[Long] = List.range(r.getStart, r.getEnd)
-          t.map(n => (ReferenceRegion(r.getContigName, n, n + 1), 1))
-        }).reduceByKey(_ + _)
-        .map(r => Coverage(r._1, r._2.toDouble))
-
-    RDDBoundCoverageRDD(covCounts, sequences, None)
+    import dataset.sqlContext.implicits._
+    val covCounts = dataset.toDF
+      .where($"readMapped")
+      .select($"contigName", $"start", $"end")
+      .as[AlignmentWindow]
+      .flatMap(w => {
+        val width = (w.end - w.start).toInt
+        val buffer = new Array[Coverage](width)
+        var idx = 0
+        var pos = w.start
+        while (idx < width) {
+          val lastPos = pos
+          pos += 1L
+          buffer(idx) = Coverage(w.contigName, lastPos, pos, 1.0)
+          idx += 1
+        }
+        buffer
+      }).toDF
+      .groupBy("contigName", "start", "end")
+      .sum("count")
+      .withColumnRenamed("sum(count)", "count")
+      .as[Coverage]
+
+    DatasetBoundCoverageRDD(covCounts, sequences)
}

/**
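With both paths in place, the abstract class now computes coverage through the Dataset API while the RDD-bound subclass keeps the RDD implementation, and either way callers get back a CoverageRDD. A hedged usage sketch assembled only from calls that appear in this commit and its test below ("sample.sam" is a placeholder path, and the dispatch comments are inferred from the diff rather than stated by it):

// Hedged usage sketch; assumes sc.loadAlignments is available (ADAMContext), as in the test below.
val reads = sc.loadAlignments("sample.sam")

// RDD-backed reads presumably hit the toCoverage override on RDDBoundAlignmentRecordRDD
val coverageFromRdd = reads.toCoverage()

// transformDataset forces the Dataset-backed representation, which uses the Dataset-based toCoverage
val readsDs = reads.transformDataset(ds => ds)
val coverageFromDs = readsDs.toCoverage()

// Both are CoverageRDDs, so downstream code does not change
println(coverageFromRdd.rdd.count())
println(coverageFromDs.rdd.count())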
@@ -105,18 +105,17 @@ class AlignmentRecordRDDSuite extends ADAMFunSuite {

// get pileup at position 30
val pointCoverage = reads.filterByOverlappingRegion(ReferenceRegion("artificial", 30, 31)).rdd.count
-    val coverage: CoverageRDD = reads.toCoverage()
-    assert(coverage.rdd.filter(r => r.start == 30).first.count == pointCoverage)
-  }
+    def testCoverage(coverage: CoverageRDD) {
+      assert(coverage.rdd.filter(r => r.start == 30).first.count == pointCoverage)
+    }

sparkTest("test filterByOverlappingRegions") {
val inputPath = testFile("artificial.sam")
val reads: AlignmentRecordRDD = sc.loadAlignments(inputPath)
val coverageRdd = reads.toCoverage()
testCoverage(coverageRdd)

-    // get pileup at position 30
-    val pointCoverage = reads.filterByOverlappingRegions(Array(ReferenceRegion("artificial", 30, 31)).toList).rdd.count
-    val coverage: CoverageRDD = reads.toCoverage()
-    assert(coverage.rdd.filter(r => r.start == 30).first.count == pointCoverage)
+    // test dataset path
+    val readsDs = reads.transformDataset(ds => ds)
+    val coverageDs = readsDs.toCoverage()
+    testCoverage(coverageDs)
}

sparkTest("merges adjacent records with equal coverage values") {
