From ce9ec06dc32cba54fc6956e280c66d66627a5602 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 7 Jul 2023 16:41:59 -0400 Subject: [PATCH 1/5] =?UTF-8?q?don=E2=80=99t=20force=20TableRead=20to=20us?= =?UTF-8?q?e=20TableValue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../is/hail/expr/ir/GenericTableValue.scala | 4 +- .../is/hail/expr/ir/StringTableReader.scala | 8 ++-- .../main/scala/is/hail/expr/ir/TableIR.scala | 37 +++++++++++++------ .../ir/lowering/LowerDistributedSort.scala | 12 +++--- .../hail/expr/ir/lowering/LowerTableIR.scala | 14 +------ .../expr/ir/lowering/RVDToTableStage.scala | 8 ++-- .../is/hail/io/avro/AvroTableReader.scala | 7 ++-- .../main/scala/is/hail/io/bgen/LoadBgen.scala | 8 ++-- .../scala/is/hail/io/plink/LoadPlink.scala | 6 +-- .../main/scala/is/hail/io/vcf/LoadVCF.scala | 8 ++-- .../scala/is/hail/expr/ir/PruneSuite.scala | 2 +- .../scala/is/hail/expr/ir/TableIRSuite.scala | 2 +- .../lowering/LowerDistributedSortSuite.scala | 2 +- 13 files changed, 59 insertions(+), 59 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala index b009fe92ab7..a8694210e2d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala @@ -164,8 +164,8 @@ class GenericTableValue( assert(contextType.hasField("partitionIndex")) assert(contextType.fieldType("partitionIndex") == TInt32) - var ltrCoercer: LoweredTableReaderCoercer = _ - def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = { + private var ltrCoercer: LoweredTableReaderCoercer = _ + private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = { if (ltrCoercer == null) { ltrCoercer = LoweredTableReader.makeCoercer( ctx, diff --git a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala index 6ad351100d8..e18e538acd5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala @@ -155,7 +155,7 @@ class StringTableReader( override def pathsUsed: Seq[String] = params.files - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val fs = ctx.fs val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ, params.filePerPartition) @@ -168,10 +168,8 @@ class StringTableReader( ) } - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { - val ts = lower(ctx, requestedType) - val (broadCastRow, rvd) = TableStageToRVD.apply(ctx, ts) - TableValue(ctx, requestedType, broadCastRow, rvd) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + TableExecuteIntermediate(_lower(ctx, requestedType)) } override def partitionCounts: Option[IndexedSeq[Long]] = None diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 1d6325b619d..0a662ab48dc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -429,7 +429,7 @@ object LoweredTableReader { tableStage, keyType.fieldNames.map(f => SortField(f, Ascending)), RTable(rowRType, globRType, FastSeq()) - ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) + )._lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) } } } @@ -481,7 +481,7 @@ trait TableReaderWithExtraUID extends TableReader { abstract class TableReader { def pathsUsed: Seq[String] - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue + def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate def partitionCounts: Option[IndexedSeq[Long]] @@ -506,7 +506,20 @@ abstract class TableReader { def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") - def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage = + if (dropRows) { + val globals = lowerGlobals(ctx, requestedType.globalType) + + TableStage( + globals, + RVDPartitioner.empty(ctx, requestedType.keyType), + TableStageDependency.none, + MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), + (_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType))) + } else + _lower(ctx, requestedType) + + def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented") } @@ -1409,8 +1422,8 @@ class TableNativeReader( VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path) .typedCodecSpec.encodedType.decodedPType(requestedType.globalType))) - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = - TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = + TableExecuteIntermediate(_lower(ctx, requestedType)) override def toJValue: JValue = { implicit val formats: Formats = DefaultFormats @@ -1440,7 +1453,7 @@ class TableNativeReader( 0) } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = spec.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1525,8 +1538,8 @@ case class TableNativeZippedReader( (t, mk) } - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = - TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = + TableExecuteIntermediate(_lower(ctx, requestedType)) override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val globalsSpec = specLeft.globalsSpec @@ -1539,7 +1552,7 @@ case class TableNativeZippedReader( 0) } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = specLeft.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1612,7 +1625,7 @@ case class TableFromBlockMatrixNativeReader( override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalStruct.empty(required = true)) - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { val rowsRDD = new BlockMatrixReadRowBlockedRDD( ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata, maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes) @@ -1622,7 +1635,7 @@ case class TableFromBlockMatrixNativeReader( val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct] val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD)) - TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd) + TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)) } override def toJValue: JValue = { @@ -1666,7 +1679,7 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends } protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = - new TableValueIntermediate(tr.apply(ctx, typ, dropRows)) + tr.toExecuteIntermediate(ctx, typ, dropRows) } case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) extends TableIR { diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 81c35b43e62..15dd8b02814 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -70,9 +70,9 @@ object LowerDistributedSort { override def partitionCounts: Option[IndexedSeq[Long]] = None - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { assert(!dropRows) - TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx) + TableExecuteIntermediate(_lower(ctx, requestedType)) } override def isDistinctlyKeyed: Boolean = false // FIXME: No default value @@ -94,7 +94,7 @@ object LowerDistributedSort { override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { TableStage( globals = globals, partitioner = partitioner.coarsen(requestedType.key.length), @@ -612,9 +612,9 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def partitionCounts: Option[IndexedSeq[Long]] = None - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { assert(!dropRows) - TableExecuteIntermediate(lower(ctx, requestedType)).asTableValue(ctx) + TableExecuteIntermediate(_lower(ctx, requestedType)) } override def isDistinctlyKeyed: Boolean = false // FIXME: No default value @@ -636,7 +636,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val contextData = { var filesCount: Long = 0 diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index f4a84aade17..072958d3c7f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -510,7 +510,7 @@ class TableStage( rightWithPartNums, SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)), rightTableRType) - val sorted = sortedReader.lower(ctx, sortedReader.fullType) + val sorted = sortedReader.lower(ctx, sortedReader.fullType, false) assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key)) val newRightPartitioner = new RVDPartitioner( ctx.stateManager, @@ -785,17 +785,7 @@ object LowerTableIR { val lowered: TableStage = tir match { case TableRead(typ, dropRows, reader) => - if (dropRows) { - val globals = reader.lowerGlobals(ctx, typ.globalType) - - TableStage( - globals, - RVDPartitioner.empty(ctx, typ.keyType), - TableStageDependency.none, - MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), - (_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType))) - } else - reader.lower(ctx, typ) + reader.lower(ctx, typ, dropRows) case TableParallelize(rowsAndGlobal, nPartitions) => val nPartitionsAdj = nPartitions.getOrElse(16) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index d4adf6b0d1d..5457184cd98 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -28,7 +28,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader override def partitionCounts: Option[IndexedSeq[Long]] = None - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { assert(!dropRows) val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong]( ctx, FastIndexedSeq(), FastIndexedSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType)) @@ -43,10 +43,10 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader requestedType.rowType)) val fsBc = ctx.fsBc - TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) => + TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) => val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion) it.map { elt => partF(ctx.r, elt) } - }) + })) } override def isDistinctlyKeyed: Boolean = false @@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { RVDToTableStage(rvd, globals) .upcast(ctx, requestedType) } diff --git a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala index 877976a383f..9f1d2624f2d 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala @@ -44,12 +44,11 @@ class AvroTableReader( def renderShort(): String = defaultRender() - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { - val ts = lower(ctx, requestedType) - new TableStageIntermediate(ts).asTableValue(ctx) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + TableExecuteIntermediate(_lower(ctx, requestedType)) } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = MakeStruct(FastIndexedSeq()) val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) => MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64))) diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index dc7e4dba05c..dfa55e805d0 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -479,15 +479,15 @@ class MatrixBGENReader( override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true)) - def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = { + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - val _lc = lower(ctx, requestedType) + val _lc = _lower(ctx, requestedType) val lc = if (dropRows) _lc.copy(partitioner = _lc.partitioner.copy(rangeBounds = Array[Interval]()), contexts = StreamTake(_lc.contexts, 0)) else _lc - TableExecuteIntermediate(lc).asTableValue(ctx) + TableExecuteIntermediate(lc) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = { @@ -520,7 +520,7 @@ class MatrixBGENReader( case _ => false } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) variants match { diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index eb4fc633815..8451bc531b7 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -485,8 +485,8 @@ class MatrixPLINKReader( body) } - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = - executeGeneric(ctx).toTableValue(ctx, requestedType) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = + TableExecuteIntermediate(_lower(ctx, requestedType)) override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName) @@ -494,7 +494,7 @@ class MatrixPLINKReader( Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row]) } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params) override def toJValue: JValue = { diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index e2cc73bf1c0..d99226e61d7 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -8,7 +8,7 @@ import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.streams.StreamProducer -import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, PartitionReader, TableValue} +import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, TableExecuteIntermediate, PartitionReader, TableValue} import is.hail.io.fs.{FS, FileStatus} import is.hail.io.tabix._ import is.hail.io.vcf.LoadVCF.{getHeaderLines, parseHeader} @@ -1962,11 +1962,11 @@ class MatrixVCFReader( .apply(globals)) } - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params) - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = - executeGeneric(ctx, dropRows).toTableValue(ctx, requestedType) + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = + TableExecuteIntermediate(_lower(ctx, requestedType)) override def toJValue: JValue = { implicit val formats: Formats = DefaultFormats diff --git a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala index 334d4262139..84e954db474 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala @@ -114,7 +114,7 @@ class PruneSuite extends HailSuite { def pathsUsed: IndexedSeq[String] = FastSeq() - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ??? + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ??? def partitionCounts: Option[IndexedSeq[Long]] = ??? diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index 7dad5ce6e70..cf18bd9ce1c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -826,7 +826,7 @@ class TableIRSuite extends HailSuite { def pathsUsed: Seq[String] = FastSeq() - override def apply(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableValue = ??? + override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ??? override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4)) diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index a7af487cbd9..cad8bbb4c63 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite { val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) - .lower(ctx, myTable.typ.copy(key = FastIndexedSeq())) + ._lower(ctx, myTable.typ.copy(key = FastIndexedSeq())) val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 From 5e8ecddaec0f43c1d975d915ef10ca105186bd20 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 10 Jul 2023 08:46:08 -0400 Subject: [PATCH 2/5] cleanup --- .../is/hail/expr/ir/StringTableReader.scala | 4 ---- .../main/scala/is/hail/expr/ir/TableIR.scala | 17 ++++++++--------- .../expr/ir/lowering/LowerDistributedSort.scala | 10 ---------- .../scala/is/hail/io/avro/AvroTableReader.scala | 4 ---- .../main/scala/is/hail/io/bgen/LoadBgen.scala | 11 ----------- .../main/scala/is/hail/io/plink/LoadPlink.scala | 3 --- .../src/main/scala/is/hail/io/vcf/LoadVCF.scala | 3 --- .../test/scala/is/hail/expr/ir/PruneSuite.scala | 3 ++- .../scala/is/hail/expr/ir/TableIRSuite.scala | 4 ++-- 9 files changed, 12 insertions(+), 47 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala index e18e538acd5..9477c8dbd88 100644 --- a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala @@ -168,10 +168,6 @@ class StringTableReader( ) } - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - TableExecuteIntermediate(_lower(ctx, requestedType)) - } - override def partitionCounts: Option[IndexedSeq[Long]] = None override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 0a662ab48dc..0d1f7a2b43e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -481,7 +481,10 @@ trait TableReaderWithExtraUID extends TableReader { abstract class TableReader { def pathsUsed: Seq[String] - def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate + def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + assert(!dropRows) + TableExecuteIntermediate(_lower(ctx, requestedType)) + } def partitionCounts: Option[IndexedSeq[Long]] @@ -519,8 +522,7 @@ abstract class TableReader { } else _lower(ctx, requestedType) - def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented") + def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage } object TableNativeReader { @@ -1422,9 +1424,6 @@ class TableNativeReader( VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path) .typedCodecSpec.encodedType.decodedPType(requestedType.globalType))) - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = - TableExecuteIntermediate(_lower(ctx, requestedType)) - override def toJValue: JValue = { implicit val formats: Formats = DefaultFormats decomposeWithName(params, "TableNativeReader") @@ -1538,9 +1537,6 @@ case class TableNativeZippedReader( (t, mk) } - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = - TableExecuteIntermediate(_lower(ctx, requestedType)) - override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val globalsSpec = specLeft.globalsSpec val globalsPath = specLeft.globalsComponent.absolutePath(pathLeft) @@ -1638,6 +1634,9 @@ case class TableFromBlockMatrixNativeReader( TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)) } + def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented") + override def toJValue: JValue = { decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats) } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 15dd8b02814..d5e86f2c12c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -70,11 +70,6 @@ object LowerDistributedSort { override def partitionCounts: Option[IndexedSeq[Long]] = None - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - assert(!dropRows) - TableExecuteIntermediate(_lower(ctx, requestedType)) - } - override def isDistinctlyKeyed: Boolean = false // FIXME: No default value def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { @@ -612,11 +607,6 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def partitionCounts: Option[IndexedSeq[Long]] = None - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - assert(!dropRows) - TableExecuteIntermediate(_lower(ctx, requestedType)) - } - override def isDistinctlyKeyed: Boolean = false // FIXME: No default value def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { diff --git a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala index 9f1d2624f2d..959d35b6148 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala @@ -44,10 +44,6 @@ class AvroTableReader( def renderShort(): String = defaultRender() - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - TableExecuteIntermediate(_lower(ctx, requestedType)) - } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = MakeStruct(FastIndexedSeq()) val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) => diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index dfa55e805d0..88409b981b5 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -479,17 +479,6 @@ class MatrixBGENReader( override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true)) - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { - - val _lc = _lower(ctx, requestedType) - val lc = if (dropRows) - _lc.copy(partitioner = _lc.partitioner.copy(rangeBounds = Array[Interval]()), - contexts = StreamTake(_lc.contexts, 0)) - else _lc - - TableExecuteIntermediate(lc) - } - override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = { requestedGlobalType.fieldOption(LowerMatrixIR.colsFieldName) match { case Some(f) => diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index 8451bc531b7..e58a592aa9d 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -485,9 +485,6 @@ class MatrixPLINKReader( body) } - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = - TableExecuteIntermediate(_lower(ctx, requestedType)) - override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val tt = fullMatrixType.toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName) val subset = tt.globalType.valueSubsetter(requestedGlobalsType) diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index d99226e61d7..e854d41aec3 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1965,9 +1965,6 @@ class MatrixVCFReader( override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params) - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = - TableExecuteIntermediate(_lower(ctx, requestedType)) - override def toJValue: JValue = { implicit val formats: Formats = DefaultFormats decomposeWithName(params, "MatrixVCFReader") diff --git a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala index 84e954db474..b13ef7c9a1e 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala @@ -3,6 +3,7 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.backend.ExecuteContext import is.hail.expr.Nat +import is.hail.expr.ir.lowering.TableStage import is.hail.methods.{ForceCountMatrixTable, ForceCountTable} import is.hail.rvd.RVD import is.hail.types._ @@ -114,7 +115,7 @@ class PruneSuite extends HailSuite { def pathsUsed: IndexedSeq[String] = FastSeq() - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ??? + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? def partitionCounts: Option[IndexedSeq[Long]] = ??? diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index cf18bd9ce1c..df955fc0555 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -6,7 +6,7 @@ import is.hail.annotations.SafeNDArray import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.TestUtils._ -import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} +import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR, TableStage} import is.hail.methods.ForceCountTable import is.hail.rvd.RVDPartitioner import is.hail.types._ @@ -826,7 +826,7 @@ class TableIRSuite extends HailSuite { def pathsUsed: Seq[String] = FastSeq() - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = ??? + override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4)) From 56cb6e2d1c9a3f1993e8f1741777d3972942878a Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 10 Jul 2023 09:17:54 -0400 Subject: [PATCH 3/5] delete GenericTableValue toTableValue code path --- .../is/hail/expr/ir/GenericTableValue.scala | 64 ------------------- 1 file changed, 64 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala index a8694210e2d..8ea524a2fd1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala @@ -122,26 +122,6 @@ class PartitionIteratorLongReader( } } -class GenericTableValueRDDPartition( - val index: Int, - val context: Any -) extends Partition - -class GenericTableValueRDD( - @transient val contexts: IndexedSeq[Any], - body: (Region, HailClassLoader, Any) => Iterator[Long] -) extends RDD[RVDContext => Iterator[Long]](SparkBackend.sparkContext("GenericTableValueRDD"), Nil) { - def getPartitions: Array[Partition] = contexts.zipWithIndex.map { case (c, i) => - new GenericTableValueRDDPartition(i, c) - }.toArray - - def compute(split: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] = { - Iterator.single { (rvdCtx: RVDContext) => - body(rvdCtx.region, theHailClassLoaderForSparkWorkers, split.asInstanceOf[GenericTableValueRDDPartition].context) - } - } -} - abstract class LoweredTableReaderCoercer { def coerce(ctx: ExecuteContext, globals: IR, @@ -210,48 +190,4 @@ class GenericTableValue( requestedBody) } } - - def toContextRDD(fs: FS, requestedRowType: TStruct): ContextRDD[Long] = { - val localBody = body(requestedRowType) - ContextRDD(new GenericTableValueRDD(contexts, localBody(_, _, fs, _))) - } - - private[this] var rvdCoercer: RVDCoercer = _ - - def getRVDCoercer(ctx: ExecuteContext): RVDCoercer = { - if (rvdCoercer == null) { - rvdCoercer = RVD.makeCoercer( - ctx, - RVDType(bodyPType(fullTableType.rowType), fullTableType.key), - 1, - toContextRDD(ctx.fs, fullTableType.keyType)) - } - rvdCoercer - } - - def toTableValue(ctx: ExecuteContext, requestedType: TableType): TableValue = { - val requestedRowType = requestedType.rowType - val requestedRowPType = bodyPType(requestedType.rowType) - val crdd = toContextRDD(ctx.fs, requestedRowType) - - val rvd = partitioner match { - case Some(partitioner) => - RVD( - RVDType(requestedRowPType, fullTableType.key), - partitioner, - crdd) - case None if requestedType.key.isEmpty => - RVD( - RVDType(requestedRowPType, fullTableType.key), - RVDPartitioner.unkeyed(ctx.stateManager, contexts.length), - crdd) - case None => - getRVDCoercer(ctx).coerce(RVDType(requestedRowPType, fullTableType.key), crdd) - } - - TableValue(ctx, - requestedType, - BroadcastRow(ctx, globals(requestedType.globalType), requestedType.globalType), - rvd) - } } From 9f2051c82c3617125e18087e7561fe647b5b4e5b Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 10 Jul 2023 14:49:32 -0400 Subject: [PATCH 4/5] make _lower protected --- hail/src/main/scala/is/hail/expr/ir/TableIR.scala | 4 ++-- .../is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 0d1f7a2b43e..c93979bca60 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -429,7 +429,7 @@ object LoweredTableReader { tableStage, keyType.fieldNames.map(f => SortField(f, Ascending)), RTable(rowRType, globRType, FastSeq()) - )._lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) + ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), dropRows = false) } } } @@ -522,7 +522,7 @@ abstract class TableReader { } else _lower(ctx, requestedType) - def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage + protected def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage } object TableNativeReader { diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index cad8bbb4c63..7c956e88441 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite { val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) - ._lower(ctx, myTable.typ.copy(key = FastIndexedSeq())) + .lower(ctx, myTable.typ.copy(key = FastIndexedSeq()), dropRows = false) val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 From 3b817ba6d798cf566ee6a41474a479952fa78fd0 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 10 Jul 2023 15:27:32 -0400 Subject: [PATCH 5/5] revert _lower refactoring --- .../is/hail/expr/ir/StringTableReader.scala | 7 +++-- .../main/scala/is/hail/expr/ir/TableIR.scala | 31 ++++++------------- .../ir/lowering/LowerDistributedSort.scala | 4 +-- .../hail/expr/ir/lowering/LowerTableIR.scala | 14 +++++++-- .../expr/ir/lowering/RVDToTableStage.scala | 2 +- .../is/hail/io/avro/AvroTableReader.scala | 7 +++-- .../main/scala/is/hail/io/bgen/LoadBgen.scala | 2 +- .../scala/is/hail/io/plink/LoadPlink.scala | 2 +- .../main/scala/is/hail/io/vcf/LoadVCF.scala | 2 +- .../scala/is/hail/expr/ir/PruneSuite.scala | 4 ++- .../scala/is/hail/expr/ir/TableIRSuite.scala | 4 ++- .../lowering/LowerDistributedSortSuite.scala | 2 +- 12 files changed, 45 insertions(+), 36 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala index 9477c8dbd88..61865ed5449 100644 --- a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala @@ -3,7 +3,7 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.functions.StringFunctions -import is.hail.expr.ir.lowering.{TableStage, TableStageDependency, TableStageToRVD} +import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency, TableStageToRVD} import is.hail.expr.ir.streams.StreamProducer import is.hail.io.fs.{FS, FileStatus} import is.hail.rvd.RVDPartitioner @@ -155,7 +155,7 @@ class StringTableReader( override def pathsUsed: Seq[String] = params.files - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val fs = ctx.fs val lines = GenericLines.read(fs, fileStatuses, None, None, params.minPartitions, params.forceBGZ, params.forceGZ, params.filePerPartition) @@ -168,6 +168,9 @@ class StringTableReader( ) } + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + override def partitionCounts: Option[IndexedSeq[Long]] = None override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index c93979bca60..ffff0b15c4a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -429,7 +429,7 @@ object LoweredTableReader { tableStage, keyType.fieldNames.map(f => SortField(f, Ascending)), RTable(rowRType, globRType, FastSeq()) - ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), dropRows = false) + ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) } } } @@ -483,7 +483,7 @@ abstract class TableReader { def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { assert(!dropRows) - TableExecuteIntermediate(_lower(ctx, requestedType)) + TableExecuteIntermediate(lower(ctx, requestedType)) } def partitionCounts: Option[IndexedSeq[Long]] @@ -506,23 +506,9 @@ abstract class TableReader { StringEscapeUtils.escapeString(JsonMethods.compact(toJValue)) } - def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR - final def lower(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableStage = - if (dropRows) { - val globals = lowerGlobals(ctx, requestedType.globalType) - - TableStage( - globals, - RVDPartitioner.empty(ctx, requestedType.keyType), - TableStageDependency.none, - MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), - (_: Ref) => MakeStream(FastIndexedSeq(), TStream(requestedType.rowType))) - } else - _lower(ctx, requestedType) - - protected def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage + def lower(ctx: ExecuteContext, requestedType: TableType): TableStage } object TableNativeReader { @@ -1452,7 +1438,7 @@ class TableNativeReader( 0) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = spec.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1548,7 +1534,7 @@ case class TableNativeZippedReader( 0) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) val rowsSpec = specLeft.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) @@ -1634,9 +1620,12 @@ case class TableFromBlockMatrixNativeReader( TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)) } - def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lower not implemented") + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + override def toJValue: JValue = { decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats) } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index d5e86f2c12c..950faa9084d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -89,7 +89,7 @@ object LowerDistributedSort { override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { TableStage( globals = globals, partitioner = partitioner.coarsen(requestedType.key.length), @@ -626,7 +626,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val contextData = { var filesCount: Long = 0 diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 072958d3c7f..f4a84aade17 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -510,7 +510,7 @@ class TableStage( rightWithPartNums, SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)), rightTableRType) - val sorted = sortedReader.lower(ctx, sortedReader.fullType, false) + val sorted = sortedReader.lower(ctx, sortedReader.fullType) assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key)) val newRightPartitioner = new RVDPartitioner( ctx.stateManager, @@ -785,7 +785,17 @@ object LowerTableIR { val lowered: TableStage = tir match { case TableRead(typ, dropRows, reader) => - reader.lower(ctx, typ, dropRows) + if (dropRows) { + val globals = reader.lowerGlobals(ctx, typ.globalType) + + TableStage( + globals, + RVDPartitioner.empty(ctx, typ.keyType), + TableStageDependency.none, + MakeStream(FastIndexedSeq(), TStream(TStruct.empty)), + (_: Ref) => MakeStream(FastIndexedSeq(), TStream(typ.rowType))) + } else + reader.lower(ctx, typ) case TableParallelize(rowsAndGlobal, nPartitions) => val nPartitionsAdj = nPartitions.getOrElse(16) diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index 5457184cd98..017f6488b11 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -68,7 +68,7 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { RVDToTableStage(rvd, globals) .upcast(ctx, requestedType) } diff --git a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala index 959d35b6148..53a4a773581 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala @@ -2,7 +2,7 @@ package is.hail.io.avro import is.hail.backend.ExecuteContext import is.hail.expr.ir._ -import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} +import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency} import is.hail.rvd.RVDPartitioner import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required} import is.hail.types.virtual._ @@ -44,7 +44,7 @@ class AvroTableReader( def renderShort(): String = defaultRender() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = MakeStruct(FastIndexedSeq()) val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) => MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64))) @@ -59,6 +59,9 @@ class AvroTableReader( } ) } + + override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = + throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") } object AvroTableReader { diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 88409b981b5..5d76f90f8e9 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -509,7 +509,7 @@ class MatrixBGENReader( case _ => false } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = lowerGlobals(ctx, requestedType.globalType) variants match { diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index e58a592aa9d..5ffbc1f43be 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -491,7 +491,7 @@ class MatrixPLINKReader( Literal(requestedGlobalsType, subset(globals).asInstanceOf[Row]) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "PLINK file", params) override def toJValue: JValue = { diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index e854d41aec3..cf5801a3bfb 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1962,7 +1962,7 @@ class MatrixVCFReader( .apply(globals)) } - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = executeGeneric(ctx).toTableStage(ctx, requestedType, "VCF", params) override def toJValue: JValue = { diff --git a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala index b13ef7c9a1e..8e2e460e567 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala @@ -115,7 +115,9 @@ class PruneSuite extends HailSuite { def pathsUsed: IndexedSeq[String] = FastSeq() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + + override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ??? def partitionCounts: Option[IndexedSeq[Long]] = ??? diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index df955fc0555..9d7b0c9b73f 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -826,7 +826,9 @@ class TableIRSuite extends HailSuite { def pathsUsed: Seq[String] = FastSeq() - override def _lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? + + override def lowerGlobals(ctx: ExecuteContext, requestedType: TStruct): IR = ??? override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastIndexedSeq(1, 2, 3, 4)) diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 7c956e88441..a7af487cbd9 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -42,7 +42,7 @@ class LowerDistributedSortSuite extends HailSuite { val stage = LowerTableIR.applyTable(myTable, DArrayLowering.All, ctx, analyses) val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) - .lower(ctx, myTable.typ.copy(key = FastIndexedSeq()), dropRows = false) + .lower(ctx, myTable.typ.copy(key = FastIndexedSeq())) val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2