From 3b817ba6d798cf566ee6a41474a479952fa78fd0 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 10 Jul 2023 15:27:32 -0400 Subject: [PATCH] revert _lower refactoring --- .../is/hail/expr/ir/StringTableReader.scala | 7 +++-- .../main/scala/is/hail/expr/ir/TableIR.scala | 31 ++++++------------- .../ir/lowering/LowerDistributedSort.scala | 4 +-- .../hail/expr/ir/lowering/LowerTableIR.scala | 14 +++++++-- .../expr/ir/lowering/RVDToTableStage.scala | 2 +- .../is/hail/io/avro/AvroTableReader.scala | 7 +++-- .../main/scala/is/hail/io/bgen/LoadBgen.scala | 2 +- .../scala/is/hail/io/plink/LoadPlink.scala | 2 +- .../main/scala/is/hail/io/vcf/LoadVCF.scala | 2 +- .../scala/is/hail/expr/ir/PruneSuite.scala | 4 ++- .../scala/is/hail/expr/ir/TableIRSuite.scala | 4 ++- .../lowering/LowerDistributedSortSuite.scala | 2 +- 12 files changed, 45 insertions(+), 36 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala index 9477c8dbd88..61865ed5449 100644 --- a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala @@ -3,7 +3,7 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.functions.StringFunctions -import is.hail.expr.ir.lowering.{TableStage, TableStageDependency, TableStageToRVD} +import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency, TableStageToRVD} import is.hail.expr.ir.streams.StreamProducer import is.hail.io.fs.{FS, FileStatus} import is.hail.rvd.RVDPartitioner @@ -155,7 +155,7 @@ class StringTableReader( override def pathsUsed: Seq[String] = params.files - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val fs = ctx.fs val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ, params.filePerPartition) @@ -168,6 +168,9 @@ class StringTableReader( ) } + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + override def partitionCounts: Option[IndexedSeq[Long]] = None override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index c93979bca60..ffff0b15c4a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -429,7 +429,7 @@ object LoweredTableReader { tableStage, keyType.fieldNames.map(f => SortField(f, Ascending)), RTable(rowRType, globRType, FastSeq()) - ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), dropRows = false) + ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) } } } @@ -483,7 +483,7 @@ abstract class TableReader { def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { assert(!dropRows) - TableExecuteIntermediate(_lower(ctx, requestedType)) + TableExecuteIntermediate(lower(ctx, requestedType)) } def partitionCounts: Option[IndexedSeq[Long]] @@ -506,23 +506,9 @@ abstract class TableReader { StringEscapeUtils.escapeString(JsonMethods.compact(toJValue)) } - def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR - final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage = - if (dropRows) { - val globals = lowerGlobals(ctx, requestedType.globalType) - - TableStage( - globals, - RVDPartitioner.empty(ctx, requestedType.keyType), - TableStageDependency.none, - MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), - (_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType))) - } else - _lower(ctx, requestedType) - - protected def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage + def lower(ctx: ExecuteContext, requestedType: TableType): TableStage } object TableNativeReader { @@ -1452,7 +1438,7 @@ class TableNativeReader( 0) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = spec.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1548,7 +1534,7 @@ case class TableNativeZippedReader( 0) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = specLeft.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1634,9 +1620,12 @@ case class TableFromBlockMatrixNativeReader( TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)) } - def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented") + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + override def toJValue: JValue = { decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats) } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index d5e86f2c12c..950faa9084d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -89,7 +89,7 @@ object LowerDistributedSort { override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { TableStage( globals = globals, partitioner = partitioner.coarsen(requestedType.key.length), @@ -626,7 +626,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val contextData = { var filesCount: Long = 0 diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 072958d3c7f..f4a84aade17 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -510,7 +510,7 @@ class TableStage( rightWithPartNums, SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)), rightTableRType) - val sorted = sortedReader.lower(ctx, sortedReader.fullType, false) + val sorted = sortedReader.lower(ctx, sortedReader.fullType) assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key)) val newRightPartitioner = new RVDPartitioner( ctx.stateManager, @@ -785,7 +785,17 @@ object LowerTableIR { val lowered: TableStage = tir match { case TableRead(typ, dropRows, reader) => - reader.lower(ctx, typ, dropRows) + if (dropRows) { + val globals = reader.lowerGlobals(ctx, typ.globalType) + + TableStage( + globals, + RVDPartitioner.empty(ctx, typ.keyType), + TableStageDependency.none, + MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), + (_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType))) + } else + reader.lower(ctx, typ) case TableParallelize(rowsAndGlobal, nPartitions) => val nPartitionsAdj = nPartitions.getOrElse(16) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index 5457184cd98..017f6488b11 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { RVDToTableStage(rvd, globals) .upcast(ctx, requestedType) } diff --git a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala index 959d35b6148..53a4a773581 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala @@ -2,7 +2,7 @@ package is.hail.io.avro import is.hail.backend.ExecuteContext import is.hail.expr.ir._ -import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} +import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency} import is.hail.rvd.RVDPartitioner import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required} import is.hail.types.virtual._ @@ -44,7 +44,7 @@ class AvroTableReader( def renderShort(): String = defaultRender() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = MakeStruct(FastIndexedSeq()) val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) => MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64))) @@ -59,6 +59,9 @@ class AvroTableReader( } ) } + + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") } object AvroTableReader { diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 88409b981b5..5d76f90f8e9 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -509,7 +509,7 @@ class MatrixBGENReader( case _ => false } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) variants match { diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index e58a592aa9d..5ffbc1f43be 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -491,7 +491,7 @@ class MatrixPLINKReader( Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row]) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params) override def toJValue: JValue = { diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index e854d41aec3..cf5801a3bfb 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1962,7 +1962,7 @@ class MatrixVCFReader( .apply(globals)) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params) override def toJValue: JValue = { diff --git a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala index b13ef7c9a1e..8e2e460e567 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala @@ -115,7 +115,9 @@ class PruneSuite extends HailSuite { def pathsUsed: IndexedSeq[String] = FastSeq() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + + override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ??? def partitionCounts: Option[IndexedSeq[Long]] = ??? diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index df955fc0555..9d7b0c9b73f 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -826,7 +826,9 @@ class TableIRSuite extends HailSuite { def pathsUsed: Seq[String] = FastSeq() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + + override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ??? override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4)) diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 7c956e88441..a7af487cbd9 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite { val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) - .lower(ctx, myTable.typ.copy(key = FastIndexedSeq()), dropRows = false) + .lower(ctx, myTable.typ.copy(key = FastIndexedSeq())) val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2