Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hail] new aggregator path for TableAggregateByKey #7195

Merged
merged 2 commits into from Oct 7, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
100 changes: 100 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Expand Up @@ -1476,6 +1476,106 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
val prev = child.execute(ctx)
val prevRVD = prev.rvd

if (HailContext.getFlag("newaggs") != null) {
try {
val res = genUID()
val extracted = agg.Extract(expr, res)

val (_, makeInit) = ir.CompileWithAggregators2[Long, Unit](
extracted.aggs,
"global", prev.globals.t,
extracted.init)

val (_, makeSeq) = ir.CompileWithAggregators2[Long, Long, Unit](
extracted.aggs,
"global", prev.globals.t,
"row", prev.rvd.rowPType,
extracted.seqPerElt)

val valueIR = Let(res, extracted.results, extracted.postAggIR)
val keyType = PType.canonical(prev.typ.keyType).asInstanceOf[PStruct]
val key = Ref(genUID(), keyType.virtualType)
val value = Ref(genUID(), valueIR.typ)
val (rowType: PStruct, makeRow) = ir.CompileWithAggregators2[Long, Long, Long](
extracted.aggs,
"global", prev.globals.t,
key.name, keyType,
Let(value.name, valueIR,
InsertFields(key, typ.valueType.fieldNames.map(n => n -> GetField(value, n)))))
assert(rowType.virtualType == typ.rowType, s"$rowType, ${ typ.rowType }")

val localChildRowType = prevRVD.rowPType
val keyIndices = prev.typ.keyFieldIdx
val keyOrd = prevRVD.typ.kRowOrd
val globalsBc = prev.globals.broadcast

val newRVDType = prevRVD.typ.copy(rowType = rowType)

val newRVD = prevRVD
.repartition(prevRVD.partitioner.strictify)
.boundary
.mapPartitionsWithIndex(newRVDType, { (i, ctx, it) =>
val partRegion = ctx.freshRegion

val globalsOff = globalsBc.value.readRegionValue(partRegion)

val initialize = makeInit(i, partRegion)
val sequence = makeSeq(i, partRegion)
val newRowF = makeRow(i, partRegion)

val aggRegion = ctx.freshRegion

new Iterator[RegionValue] {
var isEnd = false
var current: RegionValue = _
val rowKey: WritableRegionValue = WritableRegionValue(keyType, ctx.freshRegion)
val consumerRegion: Region = ctx.region
val newRV = RegionValue(consumerRegion)

// Reports whether another input row is available, buffering it in `current`
// as a one-row lookahead. Once the underlying iterator `it` has been found
// empty, `isEnd` latches and every later call returns false.
def hasNext: Boolean = {
  if (!isEnd && current == null) {
    // Nothing buffered yet: either pull the next row or mark the end.
    if (it.hasNext)
      current = it.next()
    else
      isEnd = true
  }
  !isEnd
}

// Produces one output row for the run of consecutive input rows whose key
// compares equal (per keyOrd) to the key of the row currently buffered in `current`.
// Throws NoSuchElementException when hasNext reports no more input.
def next(): RegionValue = {
if (!hasNext)
throw new java.util.NoSuchElementException()

// Capture the key fields (keyIndices) of the group's first row into rowKey,
// which is backed by its own region (see the rowKey declaration) — presumably
// a copy, so the key can still be compared after `current` moves on; confirm
// against WritableRegionValue.setSelect.
rowKey.setSelect(localChildRowType, keyIndices, current)
val region = current.region

// Start each key group from a cleared aggregation region with newly created
// aggregator state, then run the compiled init code against the globals.
aggRegion.clear()
initialize.newAggState(aggRegion)
initialize(region, globalsOff, false)
// Hand the initialized state over to the per-row sequencing function.
sequence.setAggState(aggRegion, initialize.getAggOffset())

// Run the sequencing function on each row of the group. Setting `current` to
// null makes the hasNext call in the loop condition fetch the next input row;
// if that row's key differs, it stays buffered in `current` and becomes the
// first row of the group handled by the following call to next().
do {
val region = current.region
sequence(region,
globalsOff, false,
current.offset, false)
current = null
} while (hasNext && keyOrd.equiv(rowKey.value, current))
// Pass the accumulated state to the result function; it is invoked with
// consumerRegion, the globals and the saved key, and the offset it returns
// becomes the offset of the output row.
newRowF.setAggState(aggRegion, sequence.getAggOffset())
newRV.setOffset(newRowF(consumerRegion, globalsOff, false, rowKey.offset, false))
// The same newRV instance is returned on every call, with its offset updated.
newRV
}
}
})

return prev.copy(rvd = newRVD, typ = typ)

} catch {
case e: agg.UnsupportedExtraction =>
log.info(s"couldn't lower TableAggregate: $e")
}
}

val (rvAggs, makeInit, makeSeq, aggResultType, postAggIR) = ir.CompileWithAggregators[Long, Long, Long](
"global", prev.globals.t,
"global", prev.globals.t,
Expand Down