Add NativeReaderOptions to abstract indexed read options
chrisvittal committed Jun 30, 2019
1 parent 9412c0c commit 6249db3
Showing 7 changed files with 69 additions and 133 deletions.
9 changes: 5 additions & 4 deletions hail/python/hail/ir/matrix_reader.py
@@ -51,10 +51,11 @@ def render(self, r):
'path': self.path}
if self.intervals is not None:
assert self._interval_type is not None
reader['intervals'] = {
"value": self._interval_type._convert_to_json(self.intervals),
"pointType": self._interval_type.element_type.point_type._parsable_string(),
"filter": self.filter_intervals,
reader['options'] = {
'name': 'NativeReaderOptions',
'intervals': self._interval_type._convert_to_json(self.intervals),
'intervalPointType': self._interval_type.element_type.point_type._parsable_string(),
'filterIntervals': self.filter_intervals,
}
return escape_str(json.dumps(reader))

9 changes: 5 additions & 4 deletions hail/python/hail/ir/table_reader.py
@@ -48,10 +48,11 @@ def render(self):
'path': self.path}
if self.intervals is not None:
assert self._interval_type is not None
reader['intervals'] = {
"value": self._interval_type._convert_to_json(self.intervals),
"pointType": self._interval_type.element_type.point_type._parsable_string(),
"filter": self.filter_intervals,
reader['options'] = {
'name': 'NativeReaderOptions',
'intervals': self._interval_type._convert_to_json(self.intervals),
'intervalPointType': self._interval_type.element_type.point_type._parsable_string(),
'filterIntervals': self.filter_intervals,
}
return escape_str(json.dumps(reader))
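
Both the matrix and table renderers now nest the interval settings under a single 'options' object whose 'name' field tells the Scala side which serializer to dispatch to. A sketch of the payload a table reader with integer-point intervals might emit (the path, interval values, and exact interval/type encodings are illustrative assumptions, not taken from the commit):

    import org.json4s.jackson.JsonMethods

    // Illustrative payload only; the nested "options" object is what the new
    // NativeReaderOptionsSerializer (added below in this commit) deserializes.
    val payload = JsonMethods.parse("""
      {
        "name": "TableNativeReader",
        "path": "/illustrative/path/table.ht",
        "options": {
          "name": "NativeReaderOptions",
          "intervals": [{"start": 5, "end": 10, "includeStart": true, "includeEnd": false}],
          "intervalPointType": "Int32",
          "filterIntervals": false
        }
      }""")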

6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/HailContext.scala
@@ -760,7 +760,7 @@ class HailContext private(
val nPartitions = partFiles.length
val localFS = bcFS
val (keyType, annotationType) = indexSpec.types
indexSpec.offsetField.map { f =>
indexSpec.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64())
}
@@ -857,11 +857,11 @@
val mkIndexReader = indexSpecRows.map { indexSpec =>
val idxPath = indexSpec.relPath
val (keyType, annotationType) = indexSpec.types
indexSpec.offsetField.map { f =>
indexSpec.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64())
}
indexSpecEntries.get.offsetField.map { f =>
indexSpecEntries.get.offsetField.foreach { f =>
require(annotationType.asInstanceOf[TStruct].hasField(f))
require(annotationType.asInstanceOf[TStruct].fieldType(f) == TInt64())
}
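
The map-to-foreach changes above are intent fixes rather than behavior fixes: calling map on an Option purely for its side effect builds an Option[Unit] that is immediately thrown away, whereas foreach says the effect is the point. A self-contained sketch of the difference (the field name is hypothetical):

    // Both lines run the check only when the Option is defined, but `map`
    // wraps the discarded Unit result in a fresh Option.
    val offsetField: Option[String] = Some("rowOffset")  // hypothetical field name
    val wrapped: Option[Unit] = offsetField.map { f => require(f.nonEmpty) }
    offsetField.foreach { f => require(f.nonEmpty) }     // side effect only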
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/backend/LowerTableIR.scala
@@ -99,7 +99,7 @@ object LowerTableIR {
val globalRef = genUID()

reader match {
case r@TableNativeReader(path, None, _, _, _) =>
case r@TableNativeReader(path, None, _) =>
val globalsPath = r.spec.globalsComponent.absolutePath(path)
val globalsSpec = AbstractRVDSpec.read(HailContext.get, globalsPath)
val gPath = AbstractRVDSpec.partPath(globalsPath, globalsSpec.partFiles.head)
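
The pattern shrinks from five binders to three because the intervals, interval point type, and filter flag now travel as one `options: Option[NativeReaderOptions]` parameter; matching `None` there still means "no interval options requested". A hedged sketch of the new shape (the helper is hypothetical, not part of the commit):

    // Hypothetical helper: this lowering rule only fires for native readers
    // that carry no NativeReaderOptions.
    def lowerableReader(reader: TableReader): Boolean = reader match {
      case TableNativeReader(_, None, _) => true  // path, options = None, spec
      case _ => false
    }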
50 changes: 7 additions & 43 deletions hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
@@ -91,7 +91,7 @@ object MatrixReader {
implicit val formats: Formats = RelationalSpec.formats + ShortTypeHints(
List(classOf[MatrixNativeReader], classOf[MatrixRangeReader], classOf[MatrixVCFReader],
classOf[MatrixBGENReader], classOf[MatrixPLINKReader], classOf[MatrixGENReader],
classOf[TextInputFilterAndReplace])) + new MatrixNativeReaderSerializer()
classOf[TextInputFilterAndReplace])) + new NativeReaderOptionsSerializer()
}

trait MatrixReader {
@@ -147,45 +147,9 @@ abstract class MatrixHybridReader extends TableReader with MatrixReader {
}
}

class MatrixNativeReaderSerializer() extends CustomSerializer[MatrixNativeReader](
format =>
({ case jObj: JObject =>
implicit val fmt = format
val path = (jObj \ "path").extract[String]
val intervalPointType = (jObj \ "intervals" \ "pointType").extractOpt[String].map { tstring =>
IRParser.parseType(tstring)
}
val jIntervals = (jObj \ "intervals" \ "value").toOption
if (intervalPointType.isDefined) require(jIntervals.isDefined)
val filterIntervals = (jObj \ "intervals" \ "filter").extractOpt[Boolean].getOrElse(false)
val intervals = jIntervals.map { jv =>
val intType = TArray(TInterval(intervalPointType.get))
JSONAnnotationImpex.importAnnotation(jv, intType).asInstanceOf[IndexedSeq[Interval]]
}
MatrixNativeReader(path, intervals, intervalPointType, filterIntervals)
}, { case reader: MatrixNativeReader =>
implicit val fmt = format
val intType = reader.intervalPointType.map { pt => TArray(TInterval(pt)) }
val obj = JObject(
JField("name", JString(reader.getClass.getSimpleName)),
JField("path", JString(reader.path)))
if (reader.intervalPointType.isEmpty)
obj
else {
val intervalsJson: JObject = ("intervals" ->
("pointType" -> reader.intervalPointType.map { t => t.parsableString() }) ~
("value" -> reader.intervals.map(JSONAnnotationImpex.exportAnnotation(_, intType.get))) ~
("filter" -> reader.filterIntervals))
obj.merge(intervalsJson)
}
})
)

case class MatrixNativeReader(
path: String,
intervals: Option[IndexedSeq[Interval]] = None,
intervalPointType: Option[Type] = None,
filterIntervals: Boolean = false,
options: Option[NativeReaderOptions] = None,
_spec: AbstractMatrixTableSpec = null
) extends MatrixReader {
lazy val spec: AbstractMatrixTableSpec = Option(_spec).getOrElse(
@@ -204,10 +168,12 @@ case class MatrixNativeReader(

def fullMatrixType: MatrixType = spec.matrix_type

private def intervals = options.map(_.intervals)

if (intervals.nonEmpty && !spec.indexed(path))
fatal("""`intervals` specified on an unindexed matrix table.
|This matrix table was written using an older version of hail
|rewrite the matrix in order to create an index to proceed""" )
|rewrite the matrix in order to create an index to proceed""".stripMargin)

override def lower(mr: MatrixRead): TableIR = {
val rowsPath = path + "/rows"
@@ -216,7 +182,7 @@ case class MatrixNativeReader(

if (mr.dropCols) {
val tt = TableType(mr.typ.rowType, mr.typ.rowKey, mr.typ.globalType)
val trdr: TableReader = TableNativeReader(rowsPath, intervals, intervalPointType, _spec = spec.rowsTableSpec(rowsPath))
val trdr: TableReader = TableNativeReader(rowsPath, options, _spec = spec.rowsTableSpec(rowsPath))
var tr: TableIR = TableRead(tt, mr.dropRows, trdr)
tr = TableMapGlobals(
tr,
@@ -236,9 +202,7 @@ case class MatrixNativeReader(
val trdr = TableNativeZippedReader(
rowsPath,
entriesPath,
intervals,
intervalPointType,
filterIntervals,
options,
spec.rowsTableSpec(rowsPath),
spec.entriesTableSpec(entriesPath))
var tr: TableIR = TableRead(tt, mr.dropRows, trdr)
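
After the consolidation, lower() simply forwards the reader's own options bundle to TableNativeReader and TableNativeZippedReader instead of re-plumbing three separate arguments. A construction sketch (the path, the Interval/FastIndexedSeq calls, and the Int32 point type are assumptions for illustration):

    // Sketch: one NativeReaderOptions bundle, shared verbatim by the matrix
    // reader and the table readers it lowers to.
    val opts = NativeReaderOptions(
      intervals = FastIndexedSeq(Interval(5, 10, true, false)),  // assumed Interval ctor
      intervalPointType = TInt32(),
      filterIntervals = true)
    val mReader = MatrixNativeReader("/illustrative/data.mt", Some(opts))
    val rowsReader = TableNativeReader("/illustrative/data.mt/rows", Some(opts))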
38 changes: 38 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/NativeReaderOptions.scala
@@ -0,0 +1,38 @@
package is.hail.expr.ir

import is.hail.annotations._
import is.hail.expr.types.virtual._
import is.hail.expr.JSONAnnotationImpex
import is.hail.utils._
import org.json4s.{Formats, ShortTypeHints, CustomSerializer, JObject}
import org.json4s.JsonAST.{JArray, JInt, JNull, JString, JField, JNothing}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

class NativeReaderOptionsSerializer() extends CustomSerializer[NativeReaderOptions](
format =>
({ case jObj: JObject =>
implicit val fmt = format
val filterIntervals = (jObj \ "filterIntervals").extract[Boolean]
val intervalPointType = IRParser.parseType((jObj \ "intervalPointType").extract[String])
val intervals = {
val jv = jObj \ "intervals"
val ty = TArray(TInterval(intervalPointType))
JSONAnnotationImpex.importAnnotation(jv, ty).asInstanceOf[IndexedSeq[Interval]]
}
NativeReaderOptions(intervals, intervalPointType, filterIntervals)
}, { case opts: NativeReaderOptions =>
implicit val fmt = format
val ty = TArray(TInterval(opts.intervalPointType))
(("name" -> opts.getClass.getSimpleName) ~
("intervals" -> JSONAnnotationImpex.exportAnnotation(opts.intervals, ty)) ~
("intervalPointType" -> opts.intervalPointType.parsableString()) ~
("filterIntervals" -> opts.filterIntervals))
})
)

case class NativeReaderOptions(
intervals: IndexedSeq[Interval],
intervalPointType: Type,
filterIntervals: Boolean = false
)
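
Because both the MatrixReader and TableReader Formats now register this single serializer, one round trip exercises the whole scheme. A minimal sketch, assuming DefaultFormats, the hail types used above, and structural equality on Interval:

    import org.json4s.{DefaultFormats, Formats}
    import org.json4s.jackson.Serialization

    implicit val formats: Formats = DefaultFormats + new NativeReaderOptionsSerializer()
    val opts = NativeReaderOptions(
      intervals = FastIndexedSeq(Interval(5, 10, true, false)),  // assumed Interval ctor
      intervalPointType = TInt32())
    val json = Serialization.write(opts)
    // json carries "name", "intervals", "intervalPointType", "filterIntervals"
    val restored = Serialization.read[NativeReaderOptions](json)
    assert(restored == opts)  // holds if Interval equality is structural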
88 changes: 10 additions & 78 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
@@ -91,7 +91,7 @@ object TableReader {
classOf[TextTableReader],
classOf[TextInputFilterAndReplace],
classOf[TableFromBlockMatrixNativeReader])
) + new TableNativeReaderSerializer() + new TableNativeZippedReaderSerializer()
) + new NativeReaderOptionsSerializer()
}

abstract class TableReader {
@@ -102,45 +102,9 @@ abstract class TableReader {
def fullType: TableType
}

class TableNativeReaderSerializer() extends CustomSerializer[TableNativeReader](
format =>
({ case jObj: JObject =>
implicit val fmt = format
val path = (jObj \ "path").extract[String]
val intervalPointType = (jObj \ "intervals" \ "pointType").extractOpt[String].map { tstring =>
IRParser.parseType(tstring)
}
val jIntervals = (jObj \ "intervals" \ "value").toOption
if (intervalPointType.isDefined) require(jIntervals.isDefined)
val filterIntervals = (jObj \ "intervals" \ "filter").extractOpt[Boolean].getOrElse(false)
val intervals = jIntervals.map { jv =>
val intType = TArray(TInterval(intervalPointType.get))
JSONAnnotationImpex.importAnnotation(jv, intType).asInstanceOf[IndexedSeq[Interval]]
}
TableNativeReader(path, intervals, intervalPointType, filterIntervals)
}, { case reader: TableNativeReader =>
implicit val fmt = format
val intType = reader.intervalPointType.map { pt => TArray(TInterval(pt)) }
val obj = JObject(
JField("name", JString(reader.getClass.getSimpleName)),
JField("path", JString(reader.path)))
if (reader.intervalPointType.isEmpty)
obj
else {
val intervalsJson: JObject = ("intervals" ->
("pointType" -> reader.intervalPointType.map { t => t.parsableString() }) ~
("value" -> reader.intervals.map(JSONAnnotationImpex.exportAnnotation(_, intType.get))) ~
("filter" -> reader.filterIntervals))
obj.merge(intervalsJson)
}
})
)

case class TableNativeReader(
path: String,
intervals: Option[IndexedSeq[Interval]] = None,
intervalPointType: Option[Type] = None,
filterIntervals: Boolean = false,
options: Option[NativeReaderOptions] = None,
var _spec: AbstractTableSpec = null
) extends TableReader {
lazy val spec = if (_spec != null)
@@ -155,10 +119,13 @@ case class TableNativeReader(

def fullType: TableType = spec.table_type

private lazy val filterIntervals = options.map(_.filterIntervals).getOrElse(false)
private def intervals = options.map(_.intervals)

if (intervals.nonEmpty && !spec.indexed(path))
fatal("""`intervals` specified on an unindexed table.
|This table was written using an older version of hail
|rewrite the table in order to create an index to proceed""" )
|rewrite the table in order to create an index to proceed""".stripMargin)

def apply(tr: TableRead): TableValue = {
val hc = HailContext.get
@@ -183,48 +150,10 @@ case class TableNativeReader(
}
}

class TableNativeZippedReaderSerializer() extends CustomSerializer[TableNativeZippedReader](
format =>
({ case jObj: JObject =>
implicit val fmt = format
val pathLeft = (jObj \ "pathLeft").extract[String]
val pathRight = (jObj \ "pathLeft").extract[String]
val intervalPointType = (jObj \ "intervals" \ "pointType").extractOpt[String].map { tstring =>
IRParser.parseType(tstring)
}
val jIntervals = (jObj \ "intervals" \ "value").toOption
if (intervalPointType.isDefined) require(jIntervals.isDefined)
val filterIntervals = (jObj \ "intervals" \ "filter").extractOpt[Boolean].getOrElse(false)
val intervals = jIntervals.map { jv =>
val intType = TArray(TInterval(intervalPointType.get))
JSONAnnotationImpex.importAnnotation(jv, intType).asInstanceOf[IndexedSeq[Interval]]
}
TableNativeZippedReader(pathLeft, pathRight, intervals, intervalPointType, filterIntervals)
}, { case reader: TableNativeZippedReader =>
implicit val fmt = format
val intType = reader.intervalPointType.map { pt => TArray(TInterval(pt)) }
val obj = JObject(
JField("name", JString(reader.getClass.getSimpleName)),
JField("pathLeft", JString(reader.pathLeft)),
JField("pathRight", JString(reader.pathRight)))
if (reader.intervalPointType.isEmpty)
obj
else {
val intervalsJson: JObject = ("intervals" ->
("pointType" -> reader.intervalPointType.map { t => t.parsableString() }) ~
("value" -> reader.intervals.map(JSONAnnotationImpex.exportAnnotation(_, intType.get))) ~
("filter" -> reader.filterIntervals))
obj.merge(intervalsJson)
}
})
)

case class TableNativeZippedReader(
pathLeft: String,
pathRight: String,
intervals: Option[IndexedSeq[Interval]] = None,
intervalPointType: Option[Type] = None,
filterIntervals: Boolean = false,
options: Option[NativeReaderOptions] = None,
var _specLeft: AbstractTableSpec = null,
var _specRight: AbstractTableSpec = null
) extends TableReader {
@@ -236,6 +165,9 @@ case class TableNativeZippedReader(
lazy val specLeft = if (_specLeft != null) _specLeft else getSpec(pathLeft)
lazy val specRight = if (_specRight != null) _specRight else getSpec(pathRight)

private lazy val filterIntervals = options.map(_.filterIntervals).getOrElse(false)
private def intervals = options.map(_.intervals)

require((specLeft.table_type.rowType.fieldNames ++ specRight.table_type.rowType.fieldNames).areDistinct())
require(specRight.table_type.key.isEmpty)
require(specLeft.partitionCounts sameElements specRight.partitionCounts)
