Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] don't force TableRead to produce a TableValue #13229

Merged
merged 5 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 2 additions & 66 deletions hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,6 @@ class PartitionIteratorLongReader(
}
}

/** Spark partition that carries its own `index` together with an opaque per-partition `context` value. */
class GenericTableValueRDDPartition(val index: Int, val context: Any) extends Partition

/**
  * RDD with one [[GenericTableValueRDDPartition]] per element of `contexts`.
  *
  * Each partition computes to a single-element iterator holding a function from an
  * `RVDContext` to the row iterator that `body` produces for that partition's context.
  *
  * @param contexts per-partition context values, in partition order (marked `@transient`)
  * @param body     called with the `RVDContext`'s region, `theHailClassLoaderForSparkWorkers`
  *                 and the partition's context
  */
class GenericTableValueRDD(
  @transient val contexts: IndexedSeq[Any],
  body: (Region, HailClassLoader, Any) => Iterator[Long]
) extends RDD[RVDContext => Iterator[Long]](SparkBackend.sparkContext("GenericTableValueRDD"), Nil) {

  /** One partition per context; the partition index is the context's position in `contexts`. */
  def getPartitions: Array[Partition] =
    Array.tabulate[Partition](contexts.length) { i =>
      new GenericTableValueRDDPartition(i, contexts(i))
    }

  /** Yields exactly one function; `body` only runs once that function is applied to an `RVDContext`. */
  def compute(split: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] =
    Iterator.single[RVDContext => Iterator[Long]] { rvdCtx =>
      body(
        rvdCtx.region,
        theHailClassLoaderForSparkWorkers,
        split.asInstanceOf[GenericTableValueRDDPartition].context)
    }
}

abstract class LoweredTableReaderCoercer {
def coerce(ctx: ExecuteContext,
globals: IR,
Expand All @@ -164,8 +144,8 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

var ltrCoercer: LoweredTableReaderCoercer = _
def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
private var ltrCoercer: LoweredTableReaderCoercer = _
private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
ctx,
Expand Down Expand Up @@ -210,48 +190,4 @@ class GenericTableValue(
requestedBody)
}
}

/**
  * Wraps this value's partition contexts in a `ContextRDD` of rows.
  *
  * The row-reading function is obtained from `body(requestedRowType)`; `fs` is bound into it
  * so that the resulting `GenericTableValueRDD` body only needs the region, class loader and
  * partition context.
  */
def toContextRDD(fs: FS, requestedRowType: TStruct): ContextRDD[Long] = {
  val readPartition = body(requestedRowType)
  val rowsRDD = new GenericTableValueRDD(
    contexts,
    (region, hcl, partitionContext) => readPartition(region, hcl, fs, partitionContext))
  ContextRDD(rowsRDD)
}

// Cached result of RVD.makeCoercer; stays null until getRVDCoercer is first called.
private[this] var rvdCoercer: RVDCoercer = _

/**
  * Returns the `RVDCoercer` for this value, building it on the first call and reusing it afterwards.
  *
  * The coercer is built from the full row type's `bodyPType` and key, over a `ContextRDD`
  * that reads only the key fields (`fullTableType.keyType`).
  */
def getRVDCoercer(ctx: ExecuteContext): RVDCoercer = {
  if (rvdCoercer == null) {
    rvdCoercer = RVD.makeCoercer(
      ctx,
      RVDType(bodyPType(fullTableType.rowType), fullTableType.key),
      1,
      toContextRDD(ctx.fs, fullTableType.keyType)
    )
  }

  rvdCoercer
}

/**
  * Materializes this generic value as a `TableValue` of `requestedType`.
  *
  * Rows come from `toContextRDD` over the requested row type. The RVD is built in one of three ways:
  *  - with the known `partitioner`, when one is defined;
  *  - with an unkeyed partitioner sized by `contexts.length`, when there is no partitioner and
  *    the requested key is empty;
  *  - otherwise by coercing the rows through `getRVDCoercer`.
  *
  * Globals are broadcast from `globals(requestedType.globalType)`.
  */
def toTableValue(ctx: ExecuteContext, requestedType: TableType): TableValue = {
  val rowType = requestedType.rowType
  val rowPType = bodyPType(requestedType.rowType)
  val rows = toContextRDD(ctx.fs, rowType)

  // A `def` so the RVDType is constructed at the same point in each branch as before.
  def requestedRVDType: RVDType = RVDType(rowPType, fullTableType.key)

  val rvd = partitioner match {
    case Some(knownPartitioner) =>
      RVD(requestedRVDType, knownPartitioner, rows)
    case None =>
      if (requestedType.key.isEmpty)
        RVD(requestedRVDType, RVDPartitioner.unkeyed(ctx.stateManager, contexts.length), rows)
      else
        getRVDCoercer(ctx).coerce(requestedRVDType, rows)
  }

  val globalType = requestedType.globalType
  val broadcastGlobals = BroadcastRow(ctx, globals(globalType), globalType)
  TableValue(ctx, requestedType, broadcastGlobals, rvd)
}
}
9 changes: 3 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.functions.StringFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.streams.StreamProducer
import is.hail.io.fs.{FS, FileStatus}
import is.hail.rvd.RVDPartitioner
Expand Down Expand Up @@ -168,11 +168,8 @@ class StringTableReader(
)
}

/**
  * Produces a `TableValue` by lowering to a `TableStage` and converting that stage with
  * `TableStageToRVD`. `dropRows` is not consulted here.
  */
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  val (globalsRow, rowsRVD) = TableStageToRVD(ctx, lower(ctx, requestedType))
  TableValue(ctx, requestedType, globalsRow, rowsRVD)
}
/** Not supported by this reader: always throws `LowererUnsupportedOperation` naming the runtime class. */
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
  val unsupported = s"${ getClass.getSimpleName }.lowerGlobals not implemented"
  throw new LowererUnsupportedOperation(unsupported)
}

override def partitionCounts: Option[IndexedSeq[Long]] = None

Expand Down
29 changes: 15 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,10 @@ trait TableReaderWithExtraUID extends TableReader {
abstract class TableReader {
def pathsUsed: Seq[String]

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue
/**
  * Default execution path: wraps the stage returned by `lower` in a `TableExecuteIntermediate`.
  *
  * Asserts that `dropRows` is false; this default does not handle dropping rows.
  */
def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
  assert(!dropRows)
  val stage = lower(ctx, requestedType)
  TableExecuteIntermediate(stage)
}

def partitionCounts: Option[IndexedSeq[Long]]

Expand All @@ -503,11 +506,9 @@ abstract class TableReader {
StringEscapeUtils.escapeString(JsonMethods.compact(toJValue))
}

def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR

def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")
def lower(ctx: ExecuteContext, requestedType: TableType): TableStage
}

object TableNativeReader {
Expand Down Expand Up @@ -1409,9 +1410,6 @@ class TableNativeReader(
VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path)
.typedCodecSpec.encodedType.decodedPType(requestedType.globalType)))

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def toJValue: JValue = {
implicit val formats: Formats = DefaultFormats
decomposeWithName(params, "TableNativeReader")
Expand Down Expand Up @@ -1525,9 +1523,6 @@ case class TableNativeZippedReader(
(t, mk)
}

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val globalsSpec = specLeft.globalsSpec
val globalsPath = specLeft.globalsComponent.absolutePath(pathLeft)
Expand Down Expand Up @@ -1612,7 +1607,7 @@ case class TableFromBlockMatrixNativeReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PCanonicalStruct.empty(required = true))

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
val rowsRDD = new BlockMatrixReadRowBlockedRDD(
ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata,
maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes)
Expand All @@ -1622,9 +1617,15 @@ case class TableFromBlockMatrixNativeReader(

val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct]
val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD))
TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)
TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd))
}

/** Not supported by this reader: always throws `LowererUnsupportedOperation` naming the runtime class. */
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
  val unsupported = s"${ getClass.getSimpleName }.lower not implemented"
  throw new LowererUnsupportedOperation(unsupported)
}

/** Not supported by this reader: always throws `LowererUnsupportedOperation` naming the runtime class. */
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
  val unsupported = s"${ getClass.getSimpleName }.lowerGlobals not implemented"
  throw new LowererUnsupportedOperation(unsupported)
}

/** Serializes `params` under the name "TableFromBlockMatrixNativeReader", passing `TableReader.formats` explicitly. */
override def toJValue: JValue =
  decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats)
Expand Down Expand Up @@ -1666,7 +1667,7 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends
}

protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate =
new TableValueIntermediate(tr.apply(ctx, typ, dropRows))
tr.toExecuteIntermediate(ctx, typ, dropRows)
}

case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ object LowerDistributedSort {

override def partitionCounts: Option[IndexedSeq[Long]] = None

/**
  * Lowers to a `TableStage` and converts it to a `TableValue`.
  * Asserts that `dropRows` is false.
  */
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  assert(!dropRows)
  val stage = lower(ctx, requestedType)
  TableExecuteIntermediate(stage).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand Down Expand Up @@ -612,11 +607,6 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec

override def partitionCounts: Option[IndexedSeq[Long]] = None

/**
  * Lowers to a `TableStage` and converts it to a `TableValue`.
  * Asserts that `dropRows` is false.
  */
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  assert(!dropRows)
  val stage = lower(ctx, requestedType)
  TableExecuteIntermediate(stage).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader

override def partitionCounts: Option[IndexedSeq[Long]] = None

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong](
ctx, FastIndexedSeq(), FastIndexedSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType))
Expand All @@ -43,10 +43,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
requestedType.rowType))

val fsBc = ctx.fsBc
TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion)
it.map { elt => partF(ctx.r, elt) }
})
}))
}

override def isDistinctlyKeyed: Boolean = false
Expand Down
10 changes: 4 additions & 6 deletions hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package is.hail.io.avro

import is.hail.backend.ExecuteContext
import is.hail.expr.ir._
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency}
import is.hail.rvd.RVDPartitioner
import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required}
import is.hail.types.virtual._
Expand Down Expand Up @@ -44,11 +44,6 @@ class AvroTableReader(

def renderShort(): String = defaultRender()

/**
  * Lowers to a `TableStage` and converts it to a `TableValue` through a `TableStageIntermediate`.
  * `dropRows` is not consulted here.
  */
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  val intermediate = new TableStageIntermediate(lower(ctx, requestedType))
  intermediate.asTableValue(ctx)
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = MakeStruct(FastIndexedSeq())
val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) =>
Expand All @@ -64,6 +59,9 @@ class AvroTableReader(
}
)
}

/** Not supported by this reader: always throws `LowererUnsupportedOperation` naming the runtime class. */
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
  val unsupported = s"${ getClass.getSimpleName }.lowerGlobals not implemented"
  throw new LowererUnsupportedOperation(unsupported)
}
}

object AvroTableReader {
Expand Down
11 changes: 0 additions & 11 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,17 +479,6 @@ class MatrixBGENReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true))

/**
  * Lowers to a `TableStage` and converts it to a `TableValue`.
  *
  * When `dropRows` is set, the lowered stage is replaced by a copy whose partitioner has no
  * range bounds and whose contexts stream is cut to zero elements (`StreamTake(..., 0)`).
  */
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  val fullStage = lower(ctx, requestedType)

  val stage =
    if (!dropRows)
      fullStage
    else
      fullStage.copy(
        partitioner = fullStage.partitioner.copy(rangeBounds = Array[Interval]()),
        contexts = StreamTake(fullStage.contexts, 0))

  TableExecuteIntermediate(stage).asTableValue(ctx)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = {
requestedGlobalType.fieldOption(LowerMatrixIR.colsFieldName) match {
case Some(f) =>
Expand Down
3 changes: 0 additions & 3 deletions hail/src/main/scala/is/hail/io/plink/LoadPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,9 +485,6 @@ class MatrixPLINKReader(
body)
}

/** Builds the generic table value with `executeGeneric(ctx)` and converts it to a `TableValue`; `dropRows` is not consulted here. */
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  val generic = executeGeneric(ctx)
  generic.toTableValue(ctx, requestedType)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName)
val subset = tt.globalType.valueSubsetter(requestedGlobalsType)
Expand Down
5 changes: 1 addition & 4 deletions hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager}
import is.hail.expr.JSONAnnotationImpex
import is.hail.expr.ir.lowering.TableStage
import is.hail.expr.ir.streams.StreamProducer
import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, PartitionReader, TableValue}
import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, TableExecuteIntermediate, PartitionReader, TableValue}
import is.hail.io.fs.{FS, FileStatus}
import is.hail.io.tabix._
import is.hail.io.vcf.LoadVCF.{getHeaderLines, parseHeader}
Expand Down Expand Up @@ -1965,9 +1965,6 @@ class MatrixVCFReader(
/** Builds the generic table value with `executeGeneric(ctx)` and lowers it to a `TableStage`, passing "VCF" and `params` through. */
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
  val generic = executeGeneric(ctx)
  generic.toTableStage(ctx, requestedType, "VCF", params)
}

/** Builds the generic table value with `executeGeneric(ctx, dropRows)` and converts it to a `TableValue`. */
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
  val generic = executeGeneric(ctx, dropRows)
  generic.toTableValue(ctx, requestedType)
}

override def toJValue: JValue = {
implicit val formats: Formats = DefaultFormats
decomposeWithName(params, "MatrixVCFReader")
Expand Down
5 changes: 4 additions & 1 deletion hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.HailSuite
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.lowering.TableStage
import is.hail.methods.{ForceCountMatrixTable, ForceCountTable}
import is.hail.rvd.RVD
import is.hail.types._
Expand Down Expand Up @@ -114,7 +115,9 @@ class PruneSuite extends HailSuite {

def pathsUsed: IndexedSeq[String] = FastSeq()

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

def partitionCounts: Option[IndexedSeq[Long]] = ???

Expand Down
6 changes: 4 additions & 2 deletions hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import is.hail.annotations.SafeNDArray
import is.hail.backend.ExecuteContext
import is.hail.expr.Nat
import is.hail.expr.ir.TestUtils._
import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR}
import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR, TableStage}
import is.hail.methods.ForceCountTable
import is.hail.rvd.RVDPartitioner
import is.hail.types._
Expand Down Expand Up @@ -826,7 +826,9 @@ class TableIRSuite extends HailSuite {

def pathsUsed: Seq[String] = FastSeq()

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4))

Expand Down