Skip to content

Commit

Permalink
revert _lower refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Jul 10, 2023
1 parent 9f2051c commit 3b817ba
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 36 deletions.
7 changes: 5 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.functions.StringFunctions
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency, TableStageToRVD}
import is.hail.expr.ir.streams.StreamProducer
import is.hail.io.fs.{FS, FileStatus}
import is.hail.rvd.RVDPartitioner
Expand Down Expand Up @@ -155,7 +155,7 @@ class StringTableReader(

override def pathsUsed: Seq[String] = params.files

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val fs = ctx.fs
val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ,
params.filePerPartition)
Expand All @@ -168,6 +168,9 @@ class StringTableReader(
)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")

override def partitionCounts: Option[IndexedSeq[Long]] = None

override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq =
Expand Down
31 changes: 10 additions & 21 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ object LoweredTableReader {
tableStage,
keyType.fieldNames.map(f => SortField(f, Ascending)),
RTable(rowRType, globRType, FastSeq())
).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), dropRows = false)
).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]))
}
}
}
Expand Down Expand Up @@ -483,7 +483,7 @@ abstract class TableReader {

def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = {
assert(!dropRows)
TableExecuteIntermediate(_lower(ctx, requestedType))
TableExecuteIntermediate(lower(ctx, requestedType))
}

def partitionCounts: Option[IndexedSeq[Long]]
Expand All @@ -506,23 +506,9 @@ abstract class TableReader {
StringEscapeUtils.escapeString(JsonMethods.compact(toJValue))
}

def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR

final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage =
if (dropRows) {
val globals = lowerGlobals(ctx, requestedType.globalType)

TableStage(
globals,
RVDPartitioner.empty(ctx, requestedType.keyType),
TableStageDependency.none,
MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
(_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType)))
} else
_lower(ctx, requestedType)

protected def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage
def lower(ctx: ExecuteContext, requestedType: TableType): TableStage
}

object TableNativeReader {
Expand Down Expand Up @@ -1452,7 +1438,7 @@ class TableNativeReader(
0)
}

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = lowerGlobals(ctx, requestedType.globalType)
val rowsSpec = spec.rowsSpec
val specPart = rowsSpec.partitioner(ctx.stateManager)
Expand Down Expand Up @@ -1548,7 +1534,7 @@ case class TableNativeZippedReader(
0)
}

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = lowerGlobals(ctx, requestedType.globalType)
val rowsSpec = specLeft.rowsSpec
val specPart = rowsSpec.partitioner(ctx.stateManager)
Expand Down Expand Up @@ -1634,9 +1620,12 @@ case class TableFromBlockMatrixNativeReader(
TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd))
}

def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented")

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")

override def toJValue: JValue = {
decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ object LowerDistributedSort {
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
TableStage(
globals = globals,
partitioner = partitioner.coarsen(requestedType.key.length),
Expand Down Expand Up @@ -626,7 +626,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {

val contextData = {
var filesCount: Long = 0
Expand Down
14 changes: 12 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ class TableStage(
rightWithPartNums,
SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)),
rightTableRType)
val sorted = sortedReader.lower(ctx, sortedReader.fullType, false)
val sorted = sortedReader.lower(ctx, sortedReader.fullType)
assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key))
val newRightPartitioner = new RVDPartitioner(
ctx.stateManager,
Expand Down Expand Up @@ -785,7 +785,17 @@ object LowerTableIR {

val lowered: TableStage = tir match {
case TableRead(typ, dropRows, reader) =>
reader.lower(ctx, typ, dropRows)
if (dropRows) {
val globals = reader.lowerGlobals(ctx, typ.globalType)

TableStage(
globals,
RVDPartitioner.empty(ctx, typ.keyType),
TableStageDependency.none,
MakeStream(FastIndexedSeq(), TStream(TStruct.empty)),
(_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType)))
} else
reader.lower(ctx, typ)

case TableParallelize(rowsAndGlobal, nPartitions) =>
val nPartitionsAdj = nPartitions.getOrElse(16)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader
override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
PruneDeadFields.upcast(ctx, globals, requestedGlobalsType)

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
RVDToTableStage(rvd, globals)
.upcast(ctx, requestedType)
}
Expand Down
7 changes: 5 additions & 2 deletions hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package is.hail.io.avro

import is.hail.backend.ExecuteContext
import is.hail.expr.ir._
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency}
import is.hail.rvd.RVDPartitioner
import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required}
import is.hail.types.virtual._
Expand Down Expand Up @@ -44,7 +44,7 @@ class AvroTableReader(

def renderShort(): String = defaultRender()

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
val globals = MakeStruct(FastIndexedSeq())
val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) =>
MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64)))
Expand All @@ -59,6 +59,9 @@ class AvroTableReader(
}
)
}

override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR =
throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented")
}

object AvroTableReader {
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ class MatrixBGENReader(
case _ => false
}

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = {

val globals = lowerGlobals(ctx, requestedType.globalType)
variants match {
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/io/plink/LoadPlink.scala
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ class MatrixPLINKReader(
Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row])
}

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params)

override def toJValue: JValue = {
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1962,7 +1962,7 @@ class MatrixVCFReader(
.apply(globals))
}

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage =
executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params)

override def toJValue: JValue = {
Expand Down
4 changes: 3 additions & 1 deletion hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ class PruneSuite extends HailSuite {

def pathsUsed: IndexedSeq[String] = FastSeq()

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

def partitionCounts: Option[IndexedSeq[Long]] = ???

Expand Down
4 changes: 3 additions & 1 deletion hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,9 @@ class TableIRSuite extends HailSuite {

def pathsUsed: Seq[String] = FastSeq()

override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???
override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ???

override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ???

override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite {
val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses)

val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt)
.lower(ctx, myTable.typ.copy(key = FastIndexedSeq()), dropRows = false)
.lower(ctx, myTable.typ.copy(key = FastIndexedSeq()))
val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten

val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2
Expand Down

0 comments on commit 3b817ba

Please sign in to comment.