Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] don't force TableRead to produce a TableValue #13229

Merged
merged 5 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 2 additions & 66 deletions hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,26 +122,6 @@ class PartitionIteratorLongReader(
}
}

// Spark Partition that carries an opaque per-partition context value.
// `index` is the partition's position in the RDD; `context` is handed back
// to the reader body when the partition is computed (see GenericTableValueRDD).
class GenericTableValueRDDPartition(
val index: Int,
val context: Any
) extends Partition

// RDD whose elements are functions from an RVDContext to an iterator of Longs
// (presumably off-heap row offsets allocated in the context's region — confirm
// against the callers). One partition is created per entry in `contexts`, and
// that context value is threaded through to `body` at compute time.
// `contexts` is @transient: only the driver needs it to build partitions; each
// worker receives its own context inside its GenericTableValueRDDPartition.
class GenericTableValueRDD(
@transient val contexts: IndexedSeq[Any],
body: (Region, HailClassLoader, Any) => Iterator[Long]
) extends RDD[RVDContext => Iterator[Long]](SparkBackend.sparkContext("GenericTableValueRDD"), Nil) {
// One partition per context, indexed by its position in `contexts`.
def getPartitions: Array[Partition] = contexts.zipWithIndex.map { case (c, i) =>
new GenericTableValueRDDPartition(i, c)
}.toArray

// Produces a single element: a closure that, given an RVDContext, runs `body`
// with the context's region, the worker-side class loader, and this
// partition's stashed context value.
def compute(split: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] = {
Iterator.single { (rvdCtx: RVDContext) =>
body(rvdCtx.region, theHailClassLoaderForSparkWorkers, split.asInstanceOf[GenericTableValueRDDPartition].context)
}
}
}

abstract class LoweredTableReaderCoercer {
def coerce(ctx: ExecuteContext,
globals: IR,
Expand All @@ -164,8 +144,8 @@ class GenericTableValue(
assert(contextType.hasField("partitionIndex"))
assert(contextType.fieldType("partitionIndex") == TInt32)

var ltrCoercer: LoweredTableReaderCoercer = _
def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
private var ltrCoercer: LoweredTableReaderCoercer = _
private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
if (ltrCoercer == null) {
ltrCoercer = LoweredTableReader.makeCoercer(
ctx,
Expand Down Expand Up @@ -210,48 +190,4 @@ class GenericTableValue(
requestedBody)
}
}

// Wraps this table's partition contexts in a GenericTableValueRDD and lifts
// it into a ContextRDD. The row-producing body is specialized once to the
// requested row type; the file system handle is captured and supplied on
// every invocation alongside the per-partition region, class loader, and
// context value.
def toContextRDD(fs: FS, requestedRowType: TStruct): ContextRDD[Long] = {
val produceRows = body(requestedRowType)
ContextRDD(new GenericTableValueRDD(
contexts,
(region, hcl, partitionContext) => produceRows(region, hcl, fs, partitionContext)))
}

// Lazily-built, cached RVD coercer (null until first requested).
// NOTE(review): initialization is not synchronized — assumes single-threaded
// access from the driver; confirm if that invariant holds.
private[this] var rvdCoercer: RVDCoercer = _

// Builds (on first call) and returns the coercer used to impose the full
// table key on an unkeyed/unsorted ContextRDD. The coercer is constructed
// against the full table type's key fields and cached for reuse.
def getRVDCoercer(ctx: ExecuteContext): RVDCoercer = {
if (rvdCoercer == null) {
rvdCoercer = RVD.makeCoercer(
ctx,
RVDType(bodyPType(fullTableType.rowType), fullTableType.key),
1,
toContextRDD(ctx.fs, fullTableType.keyType))
}
rvdCoercer
}

// Materializes this generic table as an executed TableValue for the requested
// (possibly pruned) table type. The RVD is assembled one of three ways:
//   - a known partitioner: use it directly;
//   - no partitioner but no key requested: treat the table as unkeyed, one
//     partition per context;
//   - no partitioner and a key requested: sort/coerce via the cached coercer.
def toTableValue(ctx: ExecuteContext, requestedType: TableType): TableValue = {
val requestedRowType = requestedType.rowType
val requestedRowPType = bodyPType(requestedType.rowType)
val crdd = toContextRDD(ctx.fs, requestedRowType)

val rvd = partitioner match {
case Some(partitioner) =>
RVD(
RVDType(requestedRowPType, fullTableType.key),
partitioner,
crdd)
case None if requestedType.key.isEmpty =>
RVD(
RVDType(requestedRowPType, fullTableType.key),
RVDPartitioner.unkeyed(ctx.stateManager, contexts.length),
crdd)
case None =>
getRVDCoercer(ctx).coerce(RVDType(requestedRowPType, fullTableType.key), crdd)
}

// Globals are evaluated at the requested global type and broadcast.
TableValue(ctx,
requestedType,
BroadcastRow(ctx, globals(requestedType.globalType), requestedType.globalType),
rvd)
}
}
8 changes: 1 addition & 7 deletions hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class StringTableReader(

override def pathsUsed: Seq[String] = params.files

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val fs = ctx.fs
val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ,
params.filePerPartition)
Expand All @@ -168,12 +168,6 @@ class StringTableReader(
)
}

// Executes eagerly by lowering to a TableStage and converting the stage back
// to an in-memory RVD + broadcast globals.
// NOTE(review): `dropRows` is ignored here — confirm callers never pass true.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
val ts = lower(ctx, requestedType)
val (broadCastRow, rvd) = TableStageToRVD.apply(ctx, ts)
TableValue(ctx, requestedType, broadCastRow, rvd)
}

override def partitionCounts: Option[IndexedSeq[Long]] = None

override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
Expand Down
42 changes: 27 additions & 15 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ object LoweredTableReader {
tableStage,
keyType.fieldNames.map(f => SortField(f, Ascending)),
RTable(rowRType, globRType, FastSeq())
).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]))
)._lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]))
}
}
}
Expand Down Expand Up @@ -481,7 +481,10 @@ trait TableReaderWithExtraUID extends TableReader {
abstract class TableReader {
def pathsUsed: Seq[String]

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue
def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
TableExecuteIntermediate(_lower(ctx, requestedType))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this interface change. One thing I'd suggest is to remove the TableExecuteIntermediate companion object and construct TableStageIntermediate and TableValueIntermediate directly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not a bad change, but given that it will all be going away when everything is lowered, I don't think there's much benefit.

}

def partitionCounts: Option[IndexedSeq[Long]]

Expand All @@ -506,8 +509,20 @@ abstract class TableReader {
def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")

def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")
// Public lowering entry point. When `dropRows` is set, short-circuits to an
// empty TableStage (empty partitioner, zero contexts, an empty row stream)
// while still lowering the globals; otherwise defers to the subclass hook.
final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage =
if (dropRows) {
val globals = lowerGlobals(ctx, requestedType.globalType)

TableStage(
globals,
RVDPartitioner.empty(ctx, requestedType.keyType),
TableStageDependency.none,
MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
(_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType)))
} else
_lower(ctx, requestedType)

// Subclass customization point: lower this reader to a TableStage for the
// requested type, assuming rows are NOT dropped. Called only from `lower`.
def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage
Copy link
Collaborator

@ehigham ehigham Jul 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what this change adds other than a template pattern which I generally avoid whenever possible. I preferred keeping this logic in LowerTableIr.

How is _lower different to lower? Who else uses lower with dropRows other than TableRead? _lower seems like something that shouldn't be public and calling it directly seems like a mistake.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a pattern we use occasionally. _lower should really be protected, I'll make that change, and it is only called directly in lower, which is final. So _lower is the customization point for subclasses.

Copy link
Collaborator

@ehigham ehigham Jul 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I was trying to understand is why do we need _lower at all when TableRead seems to be the only place that uses dropRows?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I understand now. That's fair. Addressed.

}

object TableNativeReader {
Expand Down Expand Up @@ -1409,9 +1424,6 @@ class TableNativeReader(
VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path)
.typedCodecSpec.encodedType.decodedPType(requestedType.globalType)))

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def toJValue: JValue = {
implicit val formats: Formats = DefaultFormats
decomposeWithName(params, "TableNativeReader")
Expand Down Expand Up @@ -1440,7 +1452,7 @@ class TableNativeReader(
0)
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = lowerGlobals(ctx, requestedType.globalType)
val rowsSpec = spec.rowsSpec
val specPart = rowsSpec.partitioner(ctx.stateManager)
Expand Down Expand Up @@ -1525,9 +1537,6 @@ case class TableNativeZippedReader(
(t, mk)
}

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val globalsSpec = specLeft.globalsSpec
val globalsPath = specLeft.globalsComponent.absolutePath(pathLeft)
Expand All @@ -1539,7 +1548,7 @@ case class TableNativeZippedReader(
0)
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = lowerGlobals(ctx, requestedType.globalType)
val rowsSpec = specLeft.rowsSpec
val specPart = rowsSpec.partitioner(ctx.stateManager)
Expand Down Expand Up @@ -1612,7 +1621,7 @@ case class TableFromBlockMatrixNativeReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PCanonicalStruct.empty(required = true))

override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
val rowsRDD = new BlockMatrixReadRowBlockedRDD(
ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata,
maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes)
Expand All @@ -1622,9 +1631,12 @@ case class TableFromBlockMatrixNativeReader(

val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct]
val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD))
TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)
TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd))
}

// Lowering is not supported for block-matrix-backed tables; execution goes
// through toExecuteIntermediate's RDD path instead.
def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")

override def toJValue: JValue = {
decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats)
}
Expand Down Expand Up @@ -1666,7 +1678,7 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends
}

protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate =
new TableValueIntermediate(tr.apply(ctx, typ, dropRows))
tr.toExecuteIntermediate(ctx, typ, dropRows)
}

case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ object LowerDistributedSort {

override def partitionCounts: Option[IndexedSeq[Long]] = None

// Eager execution: lower to a TableStage and force it to a TableValue.
// Dropping rows is not supported by this reader, hence the assertion.
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
assert(!dropRows)
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand All @@ -94,7 +89,7 @@ object LowerDistributedSort {
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
TableStage(
globals = globals,
partitioner = partitioner.coarsen(requestedType.key.length),
Expand Down Expand Up @@ -612,11 +607,6 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec

override def partitionCounts: Option[IndexedSeq[Long]] = None

// Eager execution: lower to a TableStage and force it to a TableValue.
// Dropping rows is not supported by this reader, hence the assertion.
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
assert(!dropRows)
TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
}

override def isDistinctlyKeyed: Boolean = false // FIXME: No default value

def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = {
Expand All @@ -636,7 +626,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {

val contextData = {
var filesCount: Long = 0
Expand Down
14 changes: 2 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class TableStage(
rightWithPartNums,
SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)),
rightTableRType)
val sorted = sortedReader.lower(ctx, sortedReader.fullType)
val sorted = sortedReader.lower(ctx, sortedReader.fullType, false)
assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key))
val newRightPartitioner = new RVDPartitioner(
ctx.stateManager,
Expand Down Expand Up @@ -785,17 +785,7 @@ object LowerTableIR {

val lowered: TableStage = tir match {
case TableRead(typ, dropRows, reader) =>
if (dropRows) {
val globals = reader.lowerGlobals(ctx, typ.globalType)

TableStage(
globals,
RVDPartitioner.empty(ctx, typ.keyType),
TableStageDependency.none,
MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
(_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType)))
} else
reader.lower(ctx, typ)
reader.lower(ctx, typ, dropRows)

case TableParallelize(rowsAndGlobal, nPartitions) =>
val nPartitionsAdj = nPartitions.getOrElse(16)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader

override def partitionCounts: Option[IndexedSeq[Long]] = None

def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong](
ctx, FastIndexedSeq(), FastIndexedSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType))
Expand All @@ -43,10 +43,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
requestedType.rowType))

val fsBc = ctx.fsBc
TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion)
it.map { elt => partF(ctx.r, elt) }
})
}))
}

override def isDistinctlyKeyed: Boolean = false
Expand All @@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
RVDToTableStage(rvd, globals)
.upcast(ctx, requestedType)
}
Expand Down
7 changes: 1 addition & 6 deletions hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ class AvroTableReader(

def renderShort(): String = defaultRender()

// Eager execution: lower to a TableStage, then force the stage to a
// TableValue. NOTE(review): `dropRows` is ignored here — confirm callers
// never pass true for Avro reads.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
val ts = lower(ctx, requestedType)
new TableStageIntermediate(ts).asTableValue(ctx)
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = MakeStruct(FastIndexedSeq())
val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) =>
MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64)))
Expand Down
13 changes: 1 addition & 12 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -479,17 +479,6 @@ class MatrixBGENReader(
override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true))

// Eager execution via lowering. When `dropRows` is set, the lowered stage is
// emptied in place: the partitioner loses all range bounds and the context
// stream is truncated to zero elements, yielding a rowless table with the
// same globals.
def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {

val _lc = lower(ctx, requestedType)
val lc = if (dropRows)
_lc.copy(partitioner = _lc.partitioner.copy(rangeBounds = Array[Interval]()),
contexts = StreamTake(_lc.contexts, 0))
else _lc

TableExecuteIntermediate(lc).asTableValue(ctx)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = {
requestedGlobalType.fieldOption(LowerMatrixIR.colsFieldName) match {
case Some(f) =>
Expand Down Expand Up @@ -520,7 +509,7 @@ class MatrixBGENReader(
case _ => false
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {

val globals = lowerGlobals(ctx, requestedType.globalType)
variants match {
Expand Down
5 changes: 1 addition & 4 deletions hail/src/main/scala/is/hail/io/plink/LoadPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,16 +485,13 @@ class MatrixPLINKReader(
body)
}

// Eager execution: delegate to the generic table value built from the PLINK
// files. NOTE(review): `dropRows` is ignored here — confirm callers never
// pass true.
override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
executeGeneric(ctx).toTableValue(ctx, requestedType)

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName)
val subset = tt.globalType.valueSubsetter(requestedGlobalsType)
Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row])
}

override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params)

override def toJValue: JValue = {
Expand Down
Loading