[ADAM-1047] Support generating predicates from ReferenceRegions.
fnothaft committed Jul 7, 2017
1 parent 8572fb7 commit bd6df07
Showing 6 changed files with 82 additions and 3 deletions.
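In short: this commit adds ReferenceRegion.toPredicate and ReferenceRegion.createPredicate, which translate genomic regions into Parquet filter2 predicates that ADAM's loadParquet* methods can push down into the scan. A minimal end-to-end sketch (annotation, not part of the commit), assuming an ADAMContext-enabled SparkContext named sc and an existing Parquet directory of alignments; the loadParquetAlignments call and its optPredicate parameter are exercised in AlignmentRecordRDDSuite below:

    import org.bdgenomics.adam.models.ReferenceRegion
    import org.bdgenomics.adam.rdd.ADAMContext._

    // Push the overlap filter for 1:100000000-200000000 into the Parquet
    // scan instead of filtering after the load. The path is illustrative.
    val region = ReferenceRegion("1", 100000000L, 200000000L)
    val reads = sc.loadParquetAlignments("reads.adam",
      optPredicate = Some(region.toPredicate))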
@@ -19,6 +19,8 @@ package org.bdgenomics.adam.models
 
 import com.esotericsoftware.kryo.io.{ Input, Output }
 import com.esotericsoftware.kryo.{ Kryo, Serializer }
+import org.apache.parquet.filter2.dsl.Dsl._
+import org.apache.parquet.filter2.predicate.FilterPredicate
 import org.bdgenomics.formats.avro._
 import org.bdgenomics.utils.interval.array.Interval
 import scala.math.{ max, min }
@@ -298,6 +300,21 @@ object ReferenceRegion {
   def apply(coverage: Coverage): ReferenceRegion = {
     new ReferenceRegion(coverage.contigName, coverage.start, coverage.end)
   }
+
+  /**
+   * Creates a predicate that filters records overlapping one or more regions.
+   *
+   * @param regions The regions to filter on.
+   * @return Returns a predicate that can be pushed into Parquet files and
+   *   that keeps all records overlapping one or more of the regions.
+   */
+  def createPredicate(regions: ReferenceRegion*): FilterPredicate = {
+    require(regions.nonEmpty,
+      "Cannot create a predicate from an empty set of regions.")
+    regions.toIterable
+      .map(_.toPredicate)
+      .reduce(_ || _)
+  }
 }
 
 /**
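Usage sketch for createPredicate (annotation, not part of the diff): the per-region predicates are reduced with ||, so a record survives the pushdown if it overlaps any one of the regions. The region values are taken from the GenotypeRDDSuite change below:

    import org.apache.parquet.filter2.predicate.FilterPredicate
    import org.bdgenomics.adam.models.ReferenceRegion

    // A single disjunctive filter covering three loci on chromosome 1.
    val predicate: FilterPredicate = ReferenceRegion.createPredicate(
      ReferenceRegion("1", 14399L, 14400L),
      ReferenceRegion("1", 752720L, 757721L),
      ReferenceRegion("1", 752790L, 752793L))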
@@ -666,6 +683,18 @@ case class ReferenceRegion(
     first ++ second
   }
+
+  /**
+   * Generates a predicate that can be pushed down into Parquet files.
+   *
+   * @return A predicate that selects records that overlap this genomic
+   *   region.
+   */
+  def toPredicate: FilterPredicate = {
+    ((LongColumn("end") > start) &&
+      (LongColumn("start") <= end) &&
+      (BinaryColumn("contigName") === referenceName))
+  }
 
   override def hashCode: Int = {
     val nameHashCode = 37 + referenceName.hashCode
     val strandHashCode = strand.ordinal()
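A note on toPredicate's semantics (annotation, not part of the diff): ReferenceRegion is a half-open interval, so a record overlaps [start, end) exactly when record.end > region.start and record.start < region.end; the <= on start in the diff is marginally looser, also admitting a record that starts exactly at region.end. Expanded by hand for ReferenceRegion("chr1", 1L, 9L), the region used in the CoverageRDDSuite change below (assumes the filter2 DSL imports added above):

    import org.apache.parquet.filter2.dsl.Dsl._
    import org.apache.parquet.filter2.predicate.FilterPredicate

    // Hand-expanded equivalent of ReferenceRegion("chr1", 1L, 9L).toPredicate:
    // keep records on chr1 that end after position 1 and start at or before 9.
    val overlapsChr1: FilterPredicate =
      (LongColumn("end") > 1L) &&
        (LongColumn("start") <= 9L) &&
        (BinaryColumn("contigName") === "chr1")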
@@ -33,6 +33,12 @@ import scala.collection.JavaConversions._
 
 class ReferenceRegionSuite extends FunSuite {
 
+  test("cannot create an empty predicate") {
+    intercept[IllegalArgumentException] {
+      ReferenceRegion.createPredicate()
+    }
+  }
+
   test("contains(: ReferenceRegion)") {
     assert(region("chr0", 10, 100).contains(region("chr0", 50, 70)))
     assert(region("chr0", 10, 100).contains(region("chr0", 10, 100)))
Expand Up @@ -17,7 +17,6 @@
*/
package org.bdgenomics.adam.rdd.feature

import org.apache.parquet.filter2.dsl.Dsl._
import org.bdgenomics.adam.models.{
ReferenceRegion,
Coverage,
@@ -80,7 +79,7 @@ class CoverageRDDSuite extends ADAMFunSuite {
     coverageRDD.save(outputFile, false, false)
 
     val region = ReferenceRegion("chr1", 1, 9)
-    val predicate = ((LongColumn("end") >= region.start) && (LongColumn("start") <= region.end) && (BinaryColumn("contigName") === region.referenceName))
+    val predicate = region.toPredicate
     val coverage = sc.loadParquetCoverage(outputFile, Some(predicate))
     assert(coverage.rdd.count == 1)
   }
@@ -487,7 +487,12 @@ class AlignmentRecordRDDSuite extends ADAMFunSuite {
     val reads: AlignmentRecordRDD = sc.loadAlignments(inputPath)
     val outputPath = tmpLocation()
     reads.saveAsParquet(TestSaveArgs(outputPath))
-    assert(new File(outputPath).exists())
+    val unfilteredReads = sc.loadAlignments(outputPath)
+    assert(unfilteredReads.rdd.count === 20)
+    val regionToLoad = ReferenceRegion("1", 100000000L, 200000000L)
+    val filteredReads = sc.loadParquetAlignments(outputPath,
+      optPredicate = Some(regionToLoad.toPredicate))
+    assert(filteredReads.rdd.count === 8)
   }
 
   sparkTest("save as SAM format") {
@@ -17,6 +17,7 @@
  */
 package org.bdgenomics.adam.rdd.variant
 
+import org.bdgenomics.adam.models.ReferenceRegion
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.util.ADAMFunSuite
 
@@ -31,6 +32,27 @@ class GenotypeRDDSuite extends ADAMFunSuite {
     assert(union.samples.size === 4)
   }
 
+  sparkTest("round trip to parquet") {
+    val genotypes = sc.loadGenotypes(testFile("small.vcf"))
+    val outputPath = tmpLocation()
+    genotypes.saveAsParquet(outputPath)
+    val unfilteredGenotypes = sc.loadGenotypes(outputPath)
+    assert(unfilteredGenotypes.rdd.count === 18)
+
+    val predicate = ReferenceRegion.createPredicate(ReferenceRegion("1", 14399L, 14400L),
+      ReferenceRegion("1", 752720L, 757721L),
+      ReferenceRegion("1", 752790L, 752793L))
+    val filteredGenotypes = sc.loadParquetGenotypes(outputPath,
+      optPredicate = Some(predicate))
+    filteredGenotypes.rdd.foreach(println)
+    assert(filteredGenotypes.rdd.count === 9)
+    val starts = filteredGenotypes.rdd.map(_.getStart).distinct.collect.toSet
+    assert(starts.size === 3)
+    assert(starts(14396L))
+    assert(starts(752720L))
+    assert(starts(752790L))
+  }
+
   sparkTest("use broadcast join to pull down genotypes mapped to targets") {
     val genotypesPath = testFile("small.vcf")
     val targetsPath = testFile("small.1.bed")
@@ -17,6 +17,7 @@
  */
 package org.bdgenomics.adam.rdd.variant
 
+import org.bdgenomics.adam.models.ReferenceRegion
 import org.bdgenomics.adam.rdd.ADAMContext._
 import org.bdgenomics.adam.util.ADAMFunSuite
 
@@ -30,6 +31,23 @@ class VariantRDDSuite extends ADAMFunSuite {
     assert(union.sequences.size === (variant1.sequences.size + variant2.sequences.size))
   }
 
+  sparkTest("round trip to parquet") {
+    val variants = sc.loadVariants(testFile("small.vcf"))
+    val outputPath = tmpLocation()
+    variants.saveAsParquet(outputPath)
+    val unfilteredVariants = sc.loadVariants(outputPath)
+    assert(unfilteredVariants.rdd.count === 6)
+
+    val predicate = ReferenceRegion.toEnd("1", 752720L).toPredicate
+    val filteredVariants = sc.loadParquetVariants(outputPath,
+      optPredicate = Some(predicate))
+    assert(filteredVariants.rdd.count === 2)
+    val starts = filteredVariants.rdd.map(_.getStart).distinct.collect.toSet
+    assert(starts.size === 2)
+    assert(starts(752720L))
+    assert(starts(752790L))
+  }
+
   sparkTest("use broadcast join to pull down variants mapped to targets") {
     val variantsPath = testFile("small.vcf")
     val targetsPath = testFile("small.1.bed")
