Skip to content

Commit

Permalink
[compiler] don't force TableRead to produce a TableValue (#13229)
Browse files Browse the repository at this point in the history
Currently `TableRead.execute` always produces a
`TableValueIntermediate`, even though almost all `TableReader`s are
lowerable, and so could produce a `TableStageIntermediate` instead.

This PR refactors `TableReader` to allow producing a
`TableStageIntermediate` in most cases, and to make it clearer which
readers still need to be lowered (only
`TableFromBlockMatrixNativeReader`, `MatrixVCFReader`, and
`MatrixPLINKReader`). It also deletes some now dead code.
  • Loading branch information
patrick-schultz committed Jul 11, 2023
1 parent f48a4d1 commit 479ef56
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 126 deletions.
68 changes: 2 additions & 66 deletions hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,6 @@ class PartitionIteratorLongReader(
}
}

// Spark Partition pairing the partition's index with an opaque, reader-specific
// context value (erased to Any here; it is handed back to the row-producing
// `body` closure in GenericTableValueRDD.compute).
class GenericTableValueRDDPartition(
val index: Int,
val context: Any
) extends Partition

// RDD whose elements are functions RVDContext => Iterator[Long]: one partition
// per entry in `contexts`, with `body` producing the partition's Long values
// (presumably region offsets of rows — confirm at call sites) given a Region,
// a class loader, and that partition's context value.
class GenericTableValueRDD(
@transient val contexts: IndexedSeq[Any],
body: (Region, HailClassLoader, Any) => Iterator[Long]
) extends RDD[RVDContext => Iterator[Long]](SparkBackend.sparkContext("GenericTableValueRDD"), Nil) {
// One GenericTableValueRDDPartition per context, indexed positionally.
def getPartitions: Array[Partition] = contexts.zipWithIndex.map { case (c, i) =>
new GenericTableValueRDDPartition(i, c)
}.toArray

// Defers evaluation until the caller supplies an RVDContext; rows are then
// allocated in that context's region on the worker.
def compute(split: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] = {
Iterator.single { (rvdCtx: RVDContext) =>
body(rvdCtx.region, theHailClassLoaderForSparkWorkers, split.asInstanceOf[GenericTableValueRDDPartition].context)
}
}
}

abstract class LoweredTableReaderCoercer {
def coerce(ctx: ExecuteContext,
globals: IR,
Expand All @@ -164,8 +144,8 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

var ltrCoercer: LoweredTableReaderCoercer = _
def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
private var ltrCoercer: LoweredTableReaderCoercer = _
private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
ctx,
Expand Down Expand Up @@ -210,48 +190,4 @@ class GenericTableValue(
requestedBody)
}
}

// Builds a ContextRDD of Long row values for the requested row type by
// specializing `body` to that type and partially applying the filesystem
// handle. NOTE(review): `fs` is captured in the closure shipped to workers —
// assumes FS is serializable; confirm.
def toContextRDD(fs: FS, requestedRowType: TStruct): ContextRDD[Long] = {
val localBody = body(requestedRowType)
ContextRDD(new GenericTableValueRDD(contexts, localBody(_, _, fs, _)))
}

// Lazily-built, cached coercer used to impose the table's key ordering on
// unkeyed/unsorted data. NOTE(review): unsynchronized check-then-set — not
// thread-safe if this GenericTableValue is shared across threads.
private[this] var rvdCoercer: RVDCoercer = _

// Returns the cached RVDCoercer, building it on first use from a single-key
// sample pass (`1`) over a context RDD of the full table's key type.
def getRVDCoercer(ctx: ExecuteContext): RVDCoercer = {
if (rvdCoercer == null) {
rvdCoercer = RVD.makeCoercer(
ctx,
RVDType(bodyPType(fullTableType.rowType), fullTableType.key),
1,
toContextRDD(ctx.fs, fullTableType.keyType))
}
rvdCoercer
}

// Materializes this generic value as a TableValue for `requestedType`.
// The RVD is assembled one of three ways depending on what is known about
// partitioning:
//   1. a partitioner was supplied — trust it directly;
//   2. no partitioner but the requested key is empty — treat partitions as
//      unkeyed, one per context;
//   3. no partitioner and a key is requested — run the (cached) coercer to
//      sort/shuffle the data into key order.
def toTableValue(ctx: ExecuteContext, requestedType: TableType): TableValue = {
val requestedRowType = requestedType.rowType
val requestedRowPType = bodyPType(requestedType.rowType)
val crdd = toContextRDD(ctx.fs, requestedRowType)

val rvd = partitioner match {
case Some(partitioner) =>
RVD(
RVDType(requestedRowPType, fullTableType.key),
partitioner,
crdd)
case None if requestedType.key.isEmpty =>
RVD(
RVDType(requestedRowPType, fullTableType.key),
RVDPartitioner.unkeyed(ctx.stateManager, contexts.length),
crdd)
case None =>
getRVDCoercer(ctx).coerce(RVDType(requestedRowPType, fullTableType.key), crdd)
}

// Globals are evaluated for the requested global type and broadcast.
TableValue(ctx,
requestedType,
BroadcastRow(ctx, globals(requestedType.globalType), requestedType.globalType),
rvd)
}
}
9 changes: 3 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.functions.StringFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.streams.StreamProducer
import is.hail.io.fs.{FS, FileStatus}
import is.hail.rvd.RVDPartitioner
Expand Down Expand Up @@ -168,11 +168,8 @@ class StringTableReader(
)
}

// Materializes the table by lowering to a TableStage and then converting the
// stage to a broadcast-globals row plus an RVD.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
val ts = lower(ctx, requestedType)
val (broadCastRow, rvd) = TableStageToRVD.apply(ctx, ts)
TableValue(ctx, requestedType, broadCastRow, rvd)
}
// Globals-only lowering is not supported for this reader; callers must go
// through the full `lower` path instead.
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")

override def partitionCounts: Option[IndexedSeq[Long]] = None

Expand Down
29 changes: 15 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,10 @@ trait TableReaderWithExtraUID extends TableReader {
abstract class TableReader {
def pathsUsed: Seq[String]

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue
def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
TableExecuteIntermediate(lower(ctx, requestedType))
}

def partitionCounts: Option[IndexedSeq[Long]]

Expand All @@ -503,11 +506,9 @@ abstract class TableReader {
StringEscapeUtils.escapeString(JsonMethods.compact(toJValue))
}

def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR

def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")
def lower(ctx: ExecuteContext, requestedType: TableType): TableStage
}

object TableNativeReader {
Expand Down Expand Up @@ -1409,9 +1410,6 @@ class TableNativeReader(
VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path)
.typedCodecSpec.encodedType.decodedPType(requestedType.globalType)))

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def toJValue: JValue = {
implicit val formats: Formats = DefaultFormats
decomposeWithName(params, "TableNativeReader")
Expand Down Expand Up @@ -1525,9 +1523,6 @@ case class TableNativeZippedReader(
(t, mk)
}

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val globalsSpec = specLeft.globalsSpec
val globalsPath = specLeft.globalsComponent.absolutePath(pathLeft)
Expand Down Expand Up @@ -1612,7 +1607,7 @@ case class TableFromBlockMatrixNativeReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PCanonicalStruct.empty(required = true))

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
val rowsRDD = new BlockMatrixReadRowBlockedRDD(
ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata,
maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes)
Expand All @@ -1622,9 +1617,15 @@ case class TableFromBlockMatrixNativeReader(

val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct]
val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD))
TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)
TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd))
}

// This reader is not lowerable: it executes eagerly via its
// toExecuteIntermediate override, so both lowering entry points fail loudly.
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")

override def toJValue: JValue = {
decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats)
}
Expand Down Expand Up @@ -1666,7 +1667,7 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends
}

protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate =
new TableValueIntermediate(tr.apply(ctx, typ, dropRows))
tr.toExecuteIntermediate(ctx, typ, dropRows)
}

case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ object LowerDistributedSort {

override def partitionCounts: Option[IndexedSeq[Long]] = None

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
assert(!dropRows)
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand Down Expand Up @@ -612,11 +607,6 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec

override def partitionCounts: Option[IndexedSeq[Long]] = None

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
assert(!dropRows)
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader

override def partitionCounts: Option[IndexedSeq[Long]] = None

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong](
ctx, FastIndexedSeq(), FastIndexedSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType))
Expand All @@ -43,10 +43,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
requestedType.rowType))

val fsBc = ctx.fsBc
TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion)
it.map { elt => partF(ctx.r, elt) }
})
}))
}

override def isDistinctlyKeyed: Boolean = false
Expand Down
10 changes: 4 additions & 6 deletions hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package is.hail.io.avro

import is.hail.backend.ExecuteContext
import is.hail.expr.ir._
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency}
import is.hail.rvd.RVDPartitioner
import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required}
import is.hail.types.virtual._
Expand Down Expand Up @@ -44,11 +44,6 @@ class AvroTableReader(

def renderShort(): String = defaultRender()

// Materializes the Avro table by lowering to a TableStage and converting the
// stage intermediate to a TableValue. dropRows is ignored here — NOTE(review):
// confirm callers never pass dropRows = true for this reader.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
val ts = lower(ctx, requestedType)
new TableStageIntermediate(ts).asTableValue(ctx)
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = MakeStruct(FastIndexedSeq())
val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) =>
Expand All @@ -64,6 +59,9 @@ class AvroTableReader(
}
)
}

// Globals-only lowering is unsupported; `lower` builds the (empty-struct)
// globals inline instead.
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
}

object AvroTableReader {
Expand Down
11 changes: 0 additions & 11 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,17 +479,6 @@ class MatrixBGENReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true))

// Materializes the BGEN table. When dropRows is set, the lowered stage is
// emptied rather than skipped: the partitioner keeps no range bounds and the
// contexts stream is truncated to zero, yielding a table with globals only.
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {

val _lc = lower(ctx, requestedType)
val lc = if (dropRows)
_lc.copy(partitioner = _lc.partitioner.copy(rangeBounds = Array[Interval]()),
contexts = StreamTake(_lc.contexts, 0))
else _lc

TableExecuteIntermediate(lc).asTableValue(ctx)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = {
requestedGlobalType.fieldOption(LowerMatrixIR.colsFieldName) match {
case Some(f) =>
Expand Down
3 changes: 0 additions & 3 deletions hail/src/main/scala/is/hail/io/plink/LoadPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,6 @@ class MatrixPLINKReader(
body)
}

// Materializes via the GenericTableValue path rather than lowering.
// NOTE(review): dropRows is ignored — confirm callers never pass true here.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
executeGeneric(ctx).toTableValue(ctx, requestedType)

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName)
val subset = tt.globalType.valueSubsetter(requestedGlobalsType)
Expand Down
5 changes: 1 addition & 4 deletions hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager}
import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.lowering.TableStage
import is.hail.expr.ir.streams.StreamProducer
import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, PartitionReader, TableValue}
import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, TableExecuteIntermediate, PartitionReader, TableValue}
import is.hail.io.fs.{FS, FileStatus}
import is.hail.io.tabix._
import is.hail.io.vcf.LoadVCF.{getHeaderLines, parseHeader}
Expand Down Expand Up @@ -1965,9 +1965,6 @@ class MatrixVCFReader(
// Both entry points route through executeGeneric: `lower` converts the generic
// value to a TableStage, while the eager `apply` (which honors dropRows)
// converts it straight to a TableValue.
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params)

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
executeGeneric(ctx, dropRows).toTableValue(ctx, requestedType)

override def toJValue: JValue = {
implicit val formats: Formats = DefaultFormats
decomposeWithName(params, "MatrixVCFReader")
Expand Down
5 changes: 4 additions & 1 deletion hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.HailSuite
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.lowering.TableStage
import is.hail.methods.{ForceCountMatrixTable, ForceCountTable}
import is.hail.rvd.RVD
import is.hail.types._
Expand Down Expand Up @@ -114,7 +115,9 @@ class PruneSuite extends HailSuite {

def pathsUsed: IndexedSeq[String] = FastSeq()

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

def partitionCounts: Option[IndexedSeq[Long]] = ???

Expand Down
6 changes: 4 additions & 2 deletions hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import is.hail.annotations.SafeNDArray
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.TestUtils._
import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR}
import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR, TableStage}
import is.hail.methods.ForceCountTable
import is.hail.rvd.RVDPartitioner
import is.hail.types._
Expand Down Expand Up @@ -826,7 +826,9 @@ class TableIRSuite extends HailSuite {

def pathsUsed: Seq[String] = FastSeq()

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4))

Expand Down

0 comments on commit 479ef56

Please sign in to comment.