don’t force TableRead to use TableValue
patrick-schultz committed Jul 7, 2023
commit 424b5f3 (parent b7cc5f3)
Showing 13 changed files with 58 additions and 58 deletions.
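At a high level, this commit reshapes the TableReader contract: the old apply(...): TableValue entry point, which forced every reader to materialize a TableValue, becomes toExecuteIntermediate(...): TableExecuteIntermediate, and the old overridable lower is split into a final lower(ctx, requestedType, dropRows) wrapper plus an overridable _lower. The following is a minimal, self-contained Scala sketch of that shape using stand-in types; it mirrors the signatures in the TableIR.scala hunks below but is not the actual Hail code.

// Stand-in types; the real ones live in is.hail.expr.ir and is.hail.expr.ir.lowering.
trait ExecuteContext
trait TableType
trait TableStage
trait TableExecuteIntermediate

abstract class TableReaderSketch {
  // New execution entry point: the reader chooses what kind of intermediate to return.
  def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate

  // dropRows is now handled once, here, rather than at each lowering call site.
  final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage =
    if (dropRows) emptyStage(ctx, requestedType) else _lower(ctx, requestedType)

  // Readers that support lowering override only this.
  def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
    throw new UnsupportedOperationException(s"${getClass.getSimpleName}._lower not implemented")

  // Stand-in for the empty-partitioner TableStage the real code builds when dropRows is set.
  private def emptyStage(ctx: ExecuteContext, requestedType: TableType): TableStage =
    new TableStage {}
}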
4 changes: 2 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala
@@ -164,8 +164,8 @@ class GenericTableValue(
   assert(contextType.hasField("partitionIndex"))
   assert(contextType.fieldType("partitionIndex") == TInt32)
 
-  var ltrCoercer: LoweredTableReaderCoercer = _
-  def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
+  private var ltrCoercer: LoweredTableReaderCoercer = _
+  private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = {
     if (ltrCoercer == null) {
       ltrCoercer = LoweredTableReader.makeCoercer(
         ctx,
8 changes: 3 additions & 5 deletions hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala
@@ -155,7 +155,7 @@ class StringTableReader(
 
   override def pathsUsed: Seq[String] = params.files
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     val fs = ctx.fs
     val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ,
       params.filePerPartition)
@@ -168,10 +168,8 @@
     )
   }
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
-    val ts = lower(ctx, requestedType)
-    val (broadCastRow, rvd) = TableStageToRVD.apply(ctx, ts)
-    TableValue(ctx, requestedType, broadCastRow, rvd)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
+    TableExecuteIntermediate(_lower(ctx, requestedType))
   }
 
   override def partitionCounts: Option[IndexedSeq[Long]] = None
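A note on the TableExecuteIntermediate(...) calls that appear throughout this diff: the constructor is passed a TableStage here, but a TableValue in TableFromBlockMatrixNativeReader and RVDTableReader below, so the companion apply presumably has overloads that wrap each argument in the matching intermediate (the stage-backed and value-backed variants, TableStageIntermediate and TableValueIntermediate, are both visible elsewhere in this diff). A rough, hypothetical sketch of that assumed shape, with stub types:

// Stub types standing in for is.hail.expr.ir.lowering.TableStage and is.hail.expr.ir.TableValue.
final case class StageStub()
final case class ValueStub()

sealed trait IntermediateSketch
final case class StageIntermediateSketch(ts: StageStub) extends IntermediateSketch  // analogous to TableStageIntermediate
final case class ValueIntermediateSketch(tv: ValueStub) extends IntermediateSketch  // analogous to TableValueIntermediate

object IntermediateSketch {
  // Overloads mirroring the two call patterns seen in this diff.
  def apply(ts: StageStub): IntermediateSketch = StageIntermediateSketch(ts)
  def apply(tv: ValueStub): IntermediateSketch = ValueIntermediateSketch(tv)
}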
37 changes: 25 additions & 12 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Expand Up @@ -429,7 +429,7 @@ object LoweredTableReader {
tableStage,
keyType.fieldNames.map(f => SortField(f, Ascending)),
RTable(rowRType, globRType, FastSeq())
).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]))
)._lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]))
}
}
}
@@ -481,7 +481,7 @@ trait TableReaderWithExtraUID extends TableReader {
 abstract class TableReader {
   def pathsUsed: Seq[String]
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue
+  def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate
 
   def partitionCounts: Option[IndexedSeq[Long]]
 
@@ -506,7 +506,20 @@ abstract class TableReader {
   def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
     throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
 
-  def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
+  final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage =
+    if (dropRows) {
+      val globals = lowerGlobals(ctx, requestedType.globalType)
+
+      TableStage(
+        globals,
+        RVDPartitioner.empty(ctx, requestedType.keyType),
+        TableStageDependency.none,
+        MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
+        (_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType)))
+    } else
+      _lower(ctx, requestedType)
+
+  def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
     throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")
 }

@@ -1409,8 +1422,8 @@ class TableNativeReader(
     VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path)
       .typedCodecSpec.encodedType.decodedPType(requestedType.globalType)))
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
-    TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate =
+    TableExecuteIntermediate(_lower(ctx, requestedType))
 
   override def toJValue: JValue = {
     implicit val formats: Formats = DefaultFormats
@@ -1440,7 +1453,7 @@ class TableNativeReader(
       0)
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     val globals = lowerGlobals(ctx, requestedType.globalType)
     val rowsSpec = spec.rowsSpec
     val specPart = rowsSpec.partitioner(ctx.stateManager)
@@ -1525,8 +1538,8 @@ case class TableNativeZippedReader(
     (t, mk)
   }
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
-    TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate =
+    TableExecuteIntermediate(_lower(ctx, requestedType))
 
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
     val globalsSpec = specLeft.globalsSpec
@@ -1539,7 +1552,7 @@ case class TableNativeZippedReader(
       0)
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     val globals = lowerGlobals(ctx, requestedType.globalType)
     val rowsSpec = specLeft.rowsSpec
     val specPart = rowsSpec.partitioner(ctx.stateManager)
@@ -1612,7 +1625,7 @@ case class TableFromBlockMatrixNativeReader(
   override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
     VirtualTypeWithReq(PCanonicalStruct.empty(required = true))
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
     val rowsRDD = new BlockMatrixReadRowBlockedRDD(
       ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata,
       maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes)
@@ -1622,7 +1635,7 @@
 
     val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct]
     val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD))
-    TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)
+    TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd))
   }
 
   override def toJValue: JValue = {
@@ -1666,7 +1679,7 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends
   }
 
   protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate =
-    new TableValueIntermediate(tr.apply(ctx, typ, dropRows))
+    tr.toExecuteIntermediate(ctx, typ, dropRows)
 }
 
 case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR {
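The net effect of the TableRead.execute hunk above is that reading a table no longer forces materialization: the reader decides whether to hand back a stage-backed or value-backed intermediate, and only consumers that actually need an RVD-backed TableValue convert. A hedged sketch of the consumer side, with stub types; asTableValue is taken from the call sites this commit removes, everything else is illustrative only.

trait Ctx
trait Typ
trait Value
trait Intermediate {
  def asTableValue(ctx: Ctx): Value  // conversion used by the call sites this commit removes
}
trait Reader {
  def toExecuteIntermediate(ctx: Ctx, typ: Typ, dropRows: Boolean): Intermediate
}

object ReadFlowSketch {
  // TableRead.execute now simply forwards the reader's intermediate; nothing is materialized eagerly.
  def executeRead(reader: Reader, ctx: Ctx, typ: Typ, dropRows: Boolean): Intermediate =
    reader.toExecuteIntermediate(ctx, typ, dropRows)

  // A consumer that genuinely needs a TableValue converts at the last moment.
  def legacyConsumer(reader: Reader, ctx: Ctx, typ: Typ): Value =
    executeRead(reader, ctx, typ, dropRows = false).asTableValue(ctx)
}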
@@ -70,9 +70,9 @@ object LowerDistributedSort {
 
   override def partitionCounts: Option[IndexedSeq[Long]] = None
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
     assert(!dropRows)
-    TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
+    TableExecuteIntermediate(_lower(ctx, requestedType))
   }
 
   override def isDistinctlyKeyed: Boolean = false // FIXME: No default value
@@ -94,7 +94,7 @@
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
     PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     TableStage(
       globals = globals,
       partitioner = partitioner.coarsen(requestedType.key.length),
@@ -612,9 +612,9 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec
 
   override def partitionCounts: Option[IndexedSeq[Long]] = None
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
     assert(!dropRows)
-    TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx)
+    TableExecuteIntermediate(_lower(ctx, requestedType))
   }
 
   override def isDistinctlyKeyed: Boolean = false // FIXME: No default value
@@ -636,7 +636,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
     PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
 
     val contextData = {
       var filesCount: Long = 0
12 changes: 1 addition & 11 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
@@ -785,17 +785,7 @@ object LowerTableIR {
 
     val lowered: TableStage = tir match {
       case TableRead(typ, dropRows, reader) =>
-        if (dropRows) {
-          val globals = reader.lowerGlobals(ctx, typ.globalType)
-
-          TableStage(
-            globals,
-            RVDPartitioner.empty(ctx, typ.keyType),
-            TableStageDependency.none,
-            MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
-            (_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType)))
-        } else
-          reader.lower(ctx, typ)
+        reader.lower(ctx, typ, dropRows)
 
       case TableParallelize(rowsAndGlobal, nPartitions) =>
        val nPartitionsAdj = nPartitions.getOrElse(16)
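Note that the branch deleted here is essentially the same empty-stage construction that now lives in the final TableReader.lower (see the @@ -506,7 +506,20 @@ hunk in TableIR.scala above): globals come from lowerGlobals, the partitioner is RVDPartitioner.empty, and both the context and row streams are empty MakeStream nodes. The dropRows behavior is therefore unchanged; only its location moves, so any caller of reader.lower(ctx, typ, dropRows) gets it for free.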
@@ -28,7 +28,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
 
   override def partitionCounts: Option[IndexedSeq[Long]] = None
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
     assert(!dropRows)
     val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong](
       ctx, FastIndexedSeq(), FastIndexedSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType))
@@ -43,10 +43,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
       requestedType.rowType))
 
     val fsBc = ctx.fsBc
-    TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
+    TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) =>
       val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion)
       it.map { elt => partF(ctx.r, elt) }
-    })
+    }))
   }
 
   override def isDistinctlyKeyed: Boolean = false
@@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
     PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     RVDToTableStage(rvd, globals)
       .upcast(ctx, requestedType)
   }
7 changes: 3 additions & 4 deletions hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala
@@ -44,12 +44,11 @@ class AvroTableReader(
 
   def renderShort(): String = defaultRender()
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
-    val ts = lower(ctx, requestedType)
-    new TableStageIntermediate(ts).asTableValue(ctx)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
+    TableExecuteIntermediate(_lower(ctx, requestedType))
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
     val globals = MakeStruct(FastIndexedSeq())
     val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) =>
       MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64)))
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
@@ -479,15 +479,15 @@ class MatrixBGENReader(
   override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
     VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true))
 
-  def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = {
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
 
-    val _lc = lower(ctx, requestedType)
+    val _lc = _lower(ctx, requestedType)
     val lc = if (dropRows)
       _lc.copy(partitioner = _lc.partitioner.copy(rangeBounds = Array[Interval]()),
         contexts = StreamTake(_lc.contexts, 0))
     else _lc
 
-    TableExecuteIntermediate(lc).asTableValue(ctx)
+    TableExecuteIntermediate(lc)
   }
 
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = {
@@ -520,7 +520,7 @@ class MatrixBGENReader(
     case _ => false
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
 
     val globals = lowerGlobals(ctx, requestedType.globalType)
     variants match {
6 changes: 3 additions & 3 deletions hail/src/main/scala/is/hail/io/plink/LoadPlink.scala
@@ -485,16 +485,16 @@ class MatrixPLINKReader(
       body)
   }
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
-    executeGeneric(ctx).toTableValue(ctx, requestedType)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate =
+    TableExecuteIntermediate(_lower(ctx, requestedType))
 
   override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = {
     val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName)
     val subset = tt.globalType.valueSubsetter(requestedGlobalsType)
     Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row])
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
     executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params)
 
   override def toJValue: JValue = {
8 changes: 4 additions & 4 deletions hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala
@@ -8,7 +8,7 @@ import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager}
 import is.hail.expr.JSONAnnotationImpex
 import is.hail.expr.ir.lowering.TableStage
 import is.hail.expr.ir.streams.StreamProducer
-import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, PartitionReader, TableValue}
+import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, TableExecuteIntermediate, PartitionReader, TableValue}
 import is.hail.io.fs.{FS, FileStatus}
 import is.hail.io.tabix._
 import is.hail.io.vcf.LoadVCF.{getHeaderLines, parseHeader}
@@ -1962,11 +1962,11 @@ class MatrixVCFReader(
       .apply(globals))
   }
 
-  override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
+  override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
     executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params)
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue =
-    executeGeneric(ctx, dropRows).toTableValue(ctx, requestedType)
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate =
+    TableExecuteIntermediate(_lower(ctx, requestedType))
 
   override def toJValue: JValue = {
     implicit val formats: Formats = DefaultFormats
2 changes: 1 addition & 1 deletion hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala
@@ -114,7 +114,7 @@ class PruneSuite extends HailSuite {
 
   def pathsUsed: IndexedSeq[String] = FastSeq()
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ???
 
   def partitionCounts: Option[IndexedSeq[Long]] = ???
 
2 changes: 1 addition & 1 deletion hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala
@@ -826,7 +826,7 @@ class TableIRSuite extends HailSuite {
 
   def pathsUsed: Seq[String] = FastSeq()
 
-  override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ???
+  override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ???
 
   override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4))
 
@@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite {
     val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses)
 
     val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt)
-      .lower(ctx, myTable.typ.copy(key = FastIndexedSeq()))
+      ._lower(ctx, myTable.typ.copy(key = FastIndexedSeq()))
     val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten
 
     val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2
