[query] Add reservoir sample aggregator #12812

Merged: 3 commits, Mar 24, 2023
Changes from 2 commits
5 changes: 3 additions & 2 deletions hail/python/hail/expr/aggregators/__init__.py
@@ -1,7 +1,7 @@
from .aggregators import approx_cdf, approx_quantiles, approx_median, collect, collect_as_set, count, count_where, \
counter, any, all, take, _densify, min, max, sum, array_sum, ndarray_sum, mean, stats, product, fraction, \
hardy_weinberg_test, explode, filter, inbreeding, call_stats, info_score, \
-hist, linreg, corr, group_by, downsample, array_agg, _prev_nonnull, _impute_type, fold
+hist, linreg, corr, group_by, downsample, array_agg, _prev_nonnull, _impute_type, fold, _reservoir_sample

__all__ = [
'approx_cdf',
@@ -39,5 +39,6 @@
'array_agg',
'_prev_nonnull',
'_impute_type',
-'fold'
+'fold',
+'_reservoir_sample'
]
5 changes: 5 additions & 0 deletions hail/python/hail/expr/aggregators/aggregators.py
@@ -1406,6 +1406,11 @@ def downsample(x, y, label=None, n_divisions=500) -> ArrayExpression:
init_op_args=[n_divisions])


@typecheck(expr=expr_any, n=expr_int32)
def _reservoir_sample(expr, n):
    return _agg_func('ReservoirSample', [expr], tarray(expr.dtype), [n])


@typecheck(gp=expr_array(expr_float64))
def info_score(gp) -> StructExpression:
r"""Compute the IMPUTE information score.
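For orientation, a minimal usage sketch of the new aggregator (`_reservoir_sample` is private/experimental, as the leading underscore signals; this mirrors the tests below rather than any documented API):

import hail as hl

ht = hl.utils.range_table(100_000)
# Uniform random sample of 1000 values of `idx`; element order is arbitrary.
sample = ht.aggregate(hl.agg._reservoir_sample(ht.idx, 1000))
assert len(sample) == 1000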
2 changes: 2 additions & 0 deletions hail/python/hail/ir/register_aggregators.py
@@ -51,6 +51,8 @@ def register_aggregators():

register_aggregator('Take', (dtype('int32'),), (dtype('?in'),), dtype('array<?in>'))

register_aggregator('ReservoirSample', (dtype('int32'),), (dtype('?in'),), dtype('array<?in>'))

register_aggregator('TakeBy', (dtype('int32'),), (dtype('?in'), dtype('?key'),), dtype('array<?in>'))

downsample_aggregator_type = dtype('array<tuple(float64, float64, array<str>)>')
25 changes: 24 additions & 1 deletion hail/python/test/hail/expr/test_expr.py
@@ -4070,4 +4070,27 @@ def test_locus_addition():

    assert hl.eval((loc + 10) == hl.locus('1', 15, reference_genome='GRCh37'))
    assert hl.eval((loc - 10) == hl.locus('1', 1, reference_genome='GRCh37'))
-    assert hl.eval((loc + 2_000_000_000) == hl.locus('1', len_1, reference_genome='GRCh37'))
+    assert hl.eval((loc + 2_000_000_000) == hl.locus('1', len_1, reference_genome='GRCh37'))


def test_reservoir_sampling_pointer_type():
    ht = hl.utils.range_table(100000, 1)
    sample = ht.aggregate(hl.agg._reservoir_sample(hl.str(ht.idx), 1000))
    assert all(str(int(x)) == x for x in sample)


def test_reservoir_sampling():
    ht = hl.Table._generate(hl.literal([(1, 10), (10, 100), (100, 1000), (1000, 10000), (10000, 100000)]),
                            hl.struct(),
                            lambda ctx, _: hl.range(ctx[0], ctx[1]).map(lambda i: hl.struct(idx=i)),
                            5)

    sample_sizes = [99, 811, 900, 1000, 3333]
    (stats, samples) = ht.aggregate((hl.agg.stats(ht.idx),
                                     tuple([hl.sorted(hl.agg._reservoir_sample(ht.idx, size)) for size in sample_sizes])))

    sample_variance = stats['stdev'] ** 2
    sample_mean = stats['mean']

    for sample, sample_size in zip(samples, sample_sizes):
        mean = np.mean(sample)
        expected_stdev = math.sqrt(sample_variance / sample_size)
        assert abs(mean - sample_mean) / expected_stdev < 4, (sample_size, abs(mean - sample_mean) / expected_stdev)
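For context on the final assertion: if the reservoir is a uniform sample of size n from a population with variance sigma^2, the central limit theorem gives the sample mean a standard error of

\mathrm{SE}(\bar{x}) = \sqrt{\sigma^2 / n}

which is exactly the `expected_stdev` computed above. Under unbiased sampling, a deviation beyond 4 standard errors has probability roughly 6e-5, so a failure points at a biased sampler rather than bad luck.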
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/AggOp.scala
@@ -12,6 +12,8 @@ object AggSignature {
AggSignature(Collect(), FastSeq(), FastSeq(requestedType.asInstanceOf[TArray].elementType))
case AggSignature(Take(), Seq(n), Seq(_)) =>
AggSignature(Take(), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType))
case AggSignature(ReservoirSample(), Seq(n), Seq(_)) =>
AggSignature(ReservoirSample(), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType))
case AggSignature(TakeBy(reverse), Seq(n), Seq(_, k)) =>
AggSignature(TakeBy(reverse), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType, k))
case AggSignature(PrevNonnull(), Seq(), Seq(_)) =>
@@ -44,6 +46,7 @@ final case class Min() extends AggOp
final case class Product() extends AggOp
final case class Sum() extends AggOp
final case class Take() extends AggOp
final case class ReservoirSample() extends AggOp
final case class Densify() extends AggOp
final case class TakeBy(so: SortOrder = Ascending) extends AggOp
final case class Group() extends AggOp
@@ -69,6 +72,7 @@ object AggOp {
case "min" | "Min" => Min()
case "count" | "Count" => Count()
case "take" | "Take" => Take()
case "ReservoirSample" | "Take" => ReservoirSample()
case "densify" | "Densify" => Densify()
case "takeBy" | "TakeBy" => TakeBy()
case "callStats" | "CallStats" => CallStats()
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Exists.scala
@@ -93,7 +93,7 @@ object ContainsAggIntermediate {

object AggIsCommutative {
def apply(op: AggOp): Boolean = op match {
-case Take() | Collect() | PrevNonnull() | TakeBy(_) => false
+case Take() | Collect() | PrevNonnull() | TakeBy(_) | ReservoirSample() => false
case _ => true
}
}
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -712,6 +712,9 @@ object IRParser {
case "TakeStateSig" =>
val pt = vtwr_expr(it)
TakeStateSig(pt)
case "ReservoirSampleStateSig" =>
val pt = vtwr_expr(it)
ReservoirSampleStateSig(pt)
case "DensifyStateSig" =>
val pt = vtwr_expr(it)
DensifyStateSig(pt)
5 changes: 5 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala
@@ -32,6 +32,7 @@ object AggStateSig {
case Min() | Max() => TypedStateSig(seqVTypes.head.setRequired(false))
case Count() => TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))
case Take() => TakeStateSig(seqVTypes.head)
case ReservoirSample() => ReservoirSampleStateSig(seqVTypes.head)
case Densify() => DensifyStateSig(seqVTypes.head)
case TakeBy(reverse) =>
val Seq(vt, kt) = seqVTypes
@@ -61,6 +62,7 @@ object AggStateSig {
case TypedStateSig(vt) => new TypedRegionBackedAggState(vt, cb)
case DownsampleStateSig(labelType) => new DownsampleState(cb, labelType)
case TakeStateSig(vt) => new TakeRVAS(vt, cb)
case ReservoirSampleStateSig(vt) => new ReservoirSampleRVAS(vt, cb)
case DensifyStateSig(vt) => new DensifyState(vt, cb)
case TakeByStateSig(vt, kt, so) => new TakeByRVAS(vt, kt, cb, so)
case CollectStateSig(pt) => new CollectAggState(pt, cb)
@@ -86,6 +88,7 @@ case class TypedStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None)
case class DownsampleStateSig(labelType: VirtualTypeWithReq) extends AggStateSig(Array(labelType), None)
case class TakeStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None)
case class TakeByStateSig(vt: VirtualTypeWithReq, kt: VirtualTypeWithReq, so: SortOrder) extends AggStateSig(Array(vt, kt), None)
case class ReservoirSampleStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None)
case class DensifyStateSig(vt: VirtualTypeWithReq) extends AggStateSig(Array(vt), None)
case class CollectStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None)
case class CollectAsSetStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None)
@@ -345,6 +348,7 @@ object Extract {
case AggSignature(Max(), _, Seq(t)) => t
case AggSignature(Count(), _, _) => TInt64
case AggSignature(Take(), _, Seq(t)) => TArray(t)
case AggSignature(ReservoirSample(), _, Seq(t)) => TArray(t)
case AggSignature(CallStats(), _, _) => CallStatsState.resultPType.virtualType
case AggSignature(TakeBy(_), _, Seq(value, key)) => TArray(value)
case AggSignature(PrevNonnull(), _, Seq(t)) => t
@@ -370,6 +374,7 @@ object Extract {
case PhysicalAggSig(Count(), TypedStateSig(_)) => CountAggregator
case PhysicalAggSig(Take(), TakeStateSig(t)) => new TakeAggregator(t)
case PhysicalAggSig(TakeBy(_), TakeByStateSig(v, k, _)) => new TakeByAggregator(v, k)
case PhysicalAggSig(ReservoirSample(), ReservoirSampleStateSig(t)) => new ReservoirSampleAggregator(t)
case PhysicalAggSig(Densify(), DensifyStateSig(v)) => new DensifyAggregator(v)
case PhysicalAggSig(CallStats(), CallStatsStateSig()) => new CallStatsAggregator()
case PhysicalAggSig(Collect(), CollectStateSig(t)) => new CollectAggregator(t)
@@ -0,0 +1,247 @@
package is.hail.expr.ir.agg

import is.hail.annotations.Region
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode}
import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer}
import is.hail.types.VirtualTypeWithReq
import is.hail.types.physical._
import is.hail.types.physical.stypes.EmitType
import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue}
import is.hail.types.virtual.{TInt32, Type}
import is.hail.utils._

class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) extends AggregatorState {
val eltPType = eltType.canonicalPType

private val r: ThisFieldRef[Region] = kb.genFieldThisRef[Region]()
val region: Value[Region] = r
private val rand = kb.genFieldThisRef[java.util.Random]()

val builder = new StagedArrayBuilder(eltPType, kb, region)
val storageType: PCanonicalTuple = PCanonicalTuple(true, PInt32Required, PInt64Required, PInt64Required, builder.stateType)
val maxSize = kb.genFieldThisRef[Int]()
val seenSoFar = kb.genFieldThisRef[Long]()
private val garbage = kb.genFieldThisRef[Long]()
private val maxSizeOffset: Code[Long] => Code[Long] = storageType.loadField(_, 0)
private val elementsSeenOffset: Code[Long] => Code[Long] = storageType.loadField(_, 1)
private val garbageOffset: Code[Long] => Code[Long] = storageType.loadField(_, 2)
private val builderStateOffset: Code[Long] => Code[Long] = storageType.loadField(_, 3)

def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = {
cb += region.getNewRegion(regionSize)
}

def createState(cb: EmitCodeBuilder): Unit = {
cb.assign(rand, Code.newInstance[java.util.Random])
cb.ifx(region.isNull, {
cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))
})
}

override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = {
regionLoader(cb, r)
cb.assign(maxSize, Region.loadInt(maxSizeOffset(src)))
cb.assign(seenSoFar, Region.loadLong(elementsSeenOffset(src)))
cb.assign(garbage, Region.loadLong(garbageOffset(src)))
builder.loadFrom(cb, builderStateOffset(src))
}

override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = {
cb.ifx(region.isValid,
{
regionStorer(cb, region)
cb += region.invalidate()
cb += Region.storeInt(maxSizeOffset(dest), maxSize)
cb += Region.storeLong(elementsSeenOffset(dest), seenSoFar)
cb += Region.storeLong(garbageOffset(dest), garbage)
builder.storeTo(cb, builderStateOffset(dest))
})
}

def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = {
{ (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) =>
cb += ob.writeInt(maxSize)
cb += ob.writeLong(seenSoFar)
builder.serialize(codec)(cb, ob)
}
}

def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = {
{ (cb: EmitCodeBuilder, ib: Value[InputBuffer]) =>
cb.assign(maxSize, ib.readInt())
cb.assign(seenSoFar, ib.readLong())
cb.assign(garbage, 0L)
builder.deserialize(codec)(cb, ib)
}
}

def init(cb: EmitCodeBuilder, _maxSize: Code[Int]): Unit = {
cb.assign(maxSize, _maxSize)
cb.assign(seenSoFar, 0L)
cb.assign(garbage, 0L)
builder.initialize(cb)
}

def gc(cb: EmitCodeBuilder): Unit = {
cb.invokeVoid(cb.emb.ecb.getOrGenEmitMethod("reservoir_sample_gc",
(this, "gc"), FastIndexedSeq(), UnitInfo) { mb =>
mb.voidWithBuilder { cb =>
cb.ifx(garbage > (maxSize.toL * 2L + 1024L), {
val oldRegion = mb.newLocal[Region]("old_region")
cb.assign(oldRegion, region)
cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))
builder.reallocateData(cb)
cb.assign(garbage, 0L)
cb += oldRegion.invoke[Unit]("invalidate")
})
}
})
}

def seqOp(cb: EmitCodeBuilder, elt: EmitCode): Unit = {
val eltVal = cb.memoize(elt)
cb.assign(seenSoFar, seenSoFar + 1)
cb.ifx(builder.size < maxSize,
eltVal.toI(cb)
.consume(cb,
builder.setMissing(cb),
sc => builder.append(cb, sc)),
{
// swaps the next element into the reservoir with probability (k / n), where
// k is the reservoir size and n is the number of elements seen so far (including current)
cb.ifx(rand.invoke[Double]("nextDouble") * seenSoFar.toD <= maxSize.toD, {
val idxToSwap = cb.memoize(rand.invoke[Int, Int]("nextInt", maxSize))
builder.overwrite(cb, eltVal, idxToSwap)
cb.assign(garbage, garbage + 1L)
gc(cb)
})
})
}
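// Aside: the branch above is the classic unweighted reservoir update
// ("Algorithm R"). A minimal unstaged sketch of the same rule, with
// illustrative names (not part of the PR); here `seenSoFar` already
// counts the current element:
//
//   def reservoirUpdate[A](reservoir: ArrayBuffer[A], maxSize: Int,
//                          seenSoFar: Long, elt: A, rng: Random): Unit =
//     if (reservoir.size < maxSize) reservoir += elt
//     else if (rng.nextDouble() * seenSoFar <= maxSize)
//       reservoir(rng.nextInt(maxSize)) = elt // prob maxSize / seenSoFar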

def dump(cb: EmitCodeBuilder, prefix: String): Unit = {
cb.println(s"> dumping reservoir: $prefix with size=", maxSize.toS,", seen=", seenSoFar.toS)
val j = cb.newLocal[Int]("j", 0)
cb.whileLoop(j < builder.size, {
cb.println(" j=", j.toS, ", elt=", cb.strValue(builder.loadElement(cb, j)))
cb.assign(j, j + 1)
})

}

def combine(cb: EmitCodeBuilder, other: ReservoirSampleRVAS): Unit = {
val j = cb.newLocal[Int]("j")
cb.ifx(other.builder.size < maxSize, {

cb.assign(j, 0)
cb.whileLoop(j < other.builder.size, {
seqOp(cb, cb.memoize(other.builder.loadElement(cb, j)))
cb.assign(j, j + 1)
})
}, {
cb.ifx(builder.size < maxSize, {
cb.assign(j, 0)
cb.whileLoop(j < builder.size, {
other.seqOp(cb, cb.memoize(builder.loadElement(cb, j)))
cb.assign(j, j + 1)
})

cb.assign(seenSoFar, other.seenSoFar)
cb.assign(garbage, other.garbage)
Collaborator: assigned garbage twice

val tmpRegion = cb.newLocal[Region]("tmpRegion", region)
cb.assign(r, other.region)
cb.assign(other.r, tmpRegion)
cb += tmpRegion.invoke[Unit]("invalidate")
builder.cloneFrom(cb, other.builder)

}, {
val newBuilder = new StagedArrayBuilder(eltPType, kb, region)
newBuilder.initializeWithCapacity(cb, maxSize)

val totalWeightLeft = cb.newLocal("totalWeightLeft", seenSoFar.toD)
val totalWeightRight = cb.newLocal("totalWeightRight", other.seenSoFar.toD)

val leftSize = cb.newLocal[Int]("leftSize", builder.size)
val rightSize = cb.newLocal[Int]("rightSize", other.builder.size)

cb.assign(j, 0)
cb.whileLoop(j < maxSize, {
val x = cb.memoize(rand.invoke[Double]("nextDouble"))
cb.ifx(x * (totalWeightLeft + totalWeightRight) <= totalWeightLeft, {
Collaborator: I think the probabilities need to change as you start pulling items out of the two sides. I think it should be

if (x * (leftSize * totalWeightLeft + rightSize * totalWeightRight) <= leftSize * totalWeightLeft)

Another possibility is to modify the left builder in place, using a weighted generalization of the seqOp:

weightSoFar = totalWeightLeft
rightWeight = totalWeightRight / rightSize
for (j in 0..right.size)
  weightSoFar += rightWeight
  if (left.size < maxSize)
    left.append(right[j])
  else
    if (randDouble() * weightSoFar < rightWeight * maxSize)
      swap right[j] into random position in left

The unweighted sampler maintains the invariant that at any time, the probability any item seen so far is in the sample (P(x in S)) is maxSize / seenSoFar. The weighted generalization makes that maxSize * weight(x) / weightSoFar, where weightSoFar is the sum of the weights of all items seen so far.

For the combOp, if we just union the two samples together, but give each item from the left the weight totalWeightLeft / leftSize, and similarly for the right, then after the weighted sampler runs, the probability any item from the left is in the result is

(leftSize / totalWeightLeft) * (maxSize * (totalWeightLeft / leftSize) / totalWeight)
=
maxSize / totalWeight

I'm pretty sure this handles all cases where one or both sides aren't full as well.
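A runnable Python rendering of the weighted seqOp sketched above (illustrative names, assuming each right-hand element carries weight totalWeightRight / rightSize):

import random

def weighted_merge(left, total_weight_left, right, total_weight_right, max_size):
    # Replay `right` into `left`; each element is swapped in with
    # probability max_size * right_weight / weight_so_far.
    weight_so_far = total_weight_left
    right_weight = total_weight_right / len(right)
    for elt in right:
        weight_so_far += right_weight
        if len(left) < max_size:
            left.append(elt)
        elif random.random() * weight_so_far < right_weight * max_size:
            left[random.randrange(max_size)] = elt
    return left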


val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", leftSize))
builder.loadElement(cb, idxToSample).toI(cb).consume(cb,
newBuilder.setMissing(cb),
newBuilder.append(cb, _, false))
cb.assign(leftSize, leftSize - 1)
cb.assign(totalWeightLeft, totalWeightLeft - 1)
cb.ifx(idxToSample < leftSize, {
builder.overwrite(cb, cb.memoize(builder.loadElement(cb, leftSize)), idxToSample, false)
})
}, {
val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", rightSize))
other.builder.loadElement(cb, idxToSample).toI(cb).consume(cb,
newBuilder.setMissing(cb),
newBuilder.append(cb, _, true))
cb.assign(rightSize, rightSize - 1)
cb.assign(totalWeightRight, totalWeightRight - 1)
cb.ifx(idxToSample < rightSize, {
other.builder.overwrite(cb, cb.memoize(other.builder.loadElement(cb, rightSize)), idxToSample, false)
})
})
cb.assign(j, j + 1)
})
builder.cloneFrom(cb, newBuilder)
cb.assign(seenSoFar, seenSoFar + other.seenSoFar)
cb.assign(garbage, garbage + leftSize.toL)
gc(cb)
})
})
}

def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray): SIndexablePointerValue = {
resType.constructFromElements(cb, region, builder.size, deepCopy = true) { (cb, idx) =>
builder.loadElement(cb, idx).toI(cb)
}
}

def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = {
cb.assign(maxSize, Region.loadInt(maxSizeOffset(src)))
cb.assign(seenSoFar, Region.loadLong(elementsSeenOffset(src)))
cb.assign(garbage, Region.loadLong(garbageOffset(src)))
builder.copyFrom(cb, builderStateOffset(src))
}
}

class ReservoirSampleAggregator(typ: VirtualTypeWithReq) extends StagedAggregator {
type State = ReservoirSampleRVAS

private val pt = typ.canonicalPType
val resultPType: PCanonicalArray = PCanonicalArray(pt)
val resultEmitType: EmitType = EmitType(SIndexablePointer(resultPType), true)
val initOpTypes: Seq[Type] = Array(TInt32)
val seqOpTypes: Seq[Type] = Array(typ.t)

protected def _initOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, init: Array[EmitCode]): Unit = {
assert(init.length == 1)
val Array(sizeTriplet) = init
sizeTriplet.toI(cb)
.consume(cb,
cb += Code._fatal[Unit](s"argument 'n' for 'hl.agg.reservoir_sample' may not be missing"),
sc => state.init(cb, sc.asInt.value)
)
}

protected def _seqOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, seq: Array[EmitCode]): Unit = {
val Array(elt: EmitCode) = seq
state.seqOp(cb, elt)
}

protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ReservoirSampleRVAS, other: ReservoirSampleRVAS): Unit = state.combine(cb, other)

protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = {
// deepCopy is handled by state.resultArray
IEmitCode.present(cb, state.resultArray(cb, region, resultPType))
}
}