Skip to content

Commit

Permalink
Add Arb.choose that accepts weighted arbs #1499
Browse files Browse the repository at this point in the history
  • Loading branch information
sksamuel committed Jun 7, 2020
1 parent 70155cd commit 8416627
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 33 deletions.
Expand Up @@ -12,12 +12,22 @@ import kotlin.random.nextInt
*/
fun <T> Arb.Companion.element(collection: Collection<T>): Arb<T> = Arb.create { collection.random(it.random) }

/**
* Alias for [element]
*/
fun <T> Arb.Companion.of(collection: Collection<T>): Arb<T> = element(collection)

/**
* Returns an [Arb] whose values are chosen randomly from those in the supplied collection.
* May not cover all items. If you want an exhaustive selection from the list, see [Exhaustive.collection]
*/
fun <T> Arb.Companion.element(vararg collection: T): Arb<T> = Arb.create { collection.random(it.random) }

/**
* Alias for [element]
*/
fun <T> Arb.Companion.of(vararg collection: T): Arb<T> = element(*collection)

/**
* Returns an [Arb] whose of values are a set of values generated by the given element generator.
* The size of each set is determined randomly within the specified [range].
Expand Down
Expand Up @@ -3,6 +3,7 @@ package io.kotest.property.arbitrary
import io.kotest.property.Arb
import io.kotest.property.Gen
import io.kotest.property.Sample
import kotlin.jvm.JvmName

/**
* Returns a stream of values based on weights:
Expand Down Expand Up @@ -38,6 +39,63 @@ fun <A : Any> Arb.Companion.choose(a: Pair<Int, A>, b: Pair<Int, A>, vararg cs:
}
}

/**
* An alias to [choose] to aid in discoverability for those used to Haskell's QuickCheck.
*/
fun <A : Any> Arb.Companion.frequency(
a: Pair<Int, A>,
b: Pair<Int, A>,
vararg cs: Pair<Int, A>
): Arb<A> = choose(a, b, *cs)

/**
* Returns a stream of values based on weights:
*
* Arb.choose(1 to arbA, 2 to arbB) will generate a value from arbA 33% of the time
* and from arbB 66% of the time.
*
* @throws IllegalArgumentException If any negative weight is given or only
* weights of zero are given.
*/
@JvmName("chooseArbs")
fun <A : Any> Arb.Companion.choose(a: Pair<Int, Arb<A>>, b: Pair<Int, Arb<A>>, vararg cs: Pair<Int, Arb<A>>): Arb<A> {
val allPairs = listOf(a, b) + cs
val weights = allPairs.map { it.first }
require(weights.all { it >= 0 }) { "Negative weights not allowed" }
require(weights.any { it > 0 }) { "At least one weight must be greater than zero" }

// The algorithm for pick is a migration of
// the algorithm from Haskell QuickCheck
// http://hackage.haskell.org/package/QuickCheck
// See function frequency in the package Test.QuickCheck
tailrec fun pick(n: Int, l: List<Pair<Int, Iterator<A>>>): Iterator<A> {
val (w, e) = l.first()
return if (n <= w) e
else pick(n - w, l.drop(1))
}

return arb { rs ->
// we must open up an iter stream for each arb
val allIters = allPairs.map { (weight, arb) -> weight to arb.values(rs).map { it.value }.iterator() }
generateSequence {
val total = weights.sum()
val n = rs.random.nextInt(1, total + 1)
val arb = pick(n, allIters)
arb.next()
}
}
}

/**
* An alias to [choose] to aid in discoverability for those used to Haskell's QuickCheck.
*/
@JvmName("frequencyArbs")
fun <A : Any> Arb.Companion.frequency(
a: Pair<Int, Arb<A>>,
b: Pair<Int, Arb<A>>,
vararg cs: Pair<Int, Arb<A>>
): Arb<A> = choose(a, b, *cs)

/**
* Generates random permutations of a list.
*/
Expand Down
Expand Up @@ -9,46 +9,91 @@ import io.kotest.matchers.shouldBe
import io.kotest.property.arbitrary.choose
import io.kotest.data.row
import io.kotest.property.Arb
import io.kotest.property.arbitrary.constant
import io.kotest.property.arbitrary.single
import io.kotest.property.random

class ChooseTest : FunSpec({

test("weighted should honour seed") {
val seedListA = Arb.choose(1 to 'A', 3 to 'B', 4 to 'C', 5 to 'D').values(684658365846L.random()).take(500).toList().map { it.value }
val seedListB = Arb.choose(1 to 'A', 3 to 'B', 4 to 'C', 5 to 'D').values(684658365846L.random()).take(500).toList().map { it.value }
seedListA shouldBe seedListB
}

test("weighted should generate expected values in correct ratios according to weights") {
forAll(
row(listOf(1 to 'A', 1 to 'B'), mapOf('A' to 0.5, 'B' to 0.5)),
row(listOf(1 to 'A', 3 to 'B', 1 to 'C'), mapOf('A' to 0.2, 'B' to 0.6, 'C' to 0.2)),
row(listOf(1 to 'A', 3 to 'C', 1 to 'C'), mapOf('A' to 0.2, 'C' to 0.8)),
row(listOf(1 to 'A', 3 to 'B', 1 to 'C', 4 to 'D'), mapOf('A' to 0.11, 'B' to 0.33, 'C' to 0.11, 'D' to 0.44))
) { weightPairs, expectedRatiosMap ->
val genCount = 100000
val chooseGen = Arb.choose(weightPairs[0], weightPairs[1], *weightPairs.drop(2).toTypedArray())
val actualCountsMap = (1..genCount).map { chooseGen.single() }.groupBy { it }.map { (k, v) -> k to v.count() }
val actualRatiosMap = actualCountsMap.map { (k, v) -> k to (v.toDouble() / genCount) }.toMap()

actualRatiosMap.keys shouldBe expectedRatiosMap.keys

actualRatiosMap.forEach { (k, actualRatio) ->
actualRatio shouldBe (expectedRatiosMap[k] as Double plusOrMinus 0.02)
test("Arb.choose should honour seed") {
val seedListA =
Arb.choose(1 to 'A', 3 to 'B', 4 to 'C', 5 to 'D').values(684658365846L.random()).take(500).toList()
.map { it.value }
val seedListB =
Arb.choose(1 to 'A', 3 to 'B', 4 to 'C', 5 to 'D').values(684658365846L.random()).take(500).toList()
.map { it.value }
seedListA shouldBe seedListB
}

test("Arb.choose for values should generate expected values in correct ratios according to weights") {
forAll(
row(listOf(1 to 'A', 1 to 'B'), mapOf('A' to 0.5, 'B' to 0.5)),
row(listOf(1 to 'A', 3 to 'B', 1 to 'C'), mapOf('A' to 0.2, 'B' to 0.6, 'C' to 0.2)),
row(listOf(1 to 'A', 3 to 'C', 1 to 'C'), mapOf('A' to 0.2, 'C' to 0.8)),
row(listOf(1 to 'A', 3 to 'B', 1 to 'C', 4 to 'D'), mapOf('A' to 0.11, 'B' to 0.33, 'C' to 0.11, 'D' to 0.44))
) { weightPairs, expectedRatiosMap ->
val genCount = 100000
val chooseGen = Arb.choose(weightPairs[0], weightPairs[1], *weightPairs.drop(2).toTypedArray())
val actualCountsMap = (1..genCount).map { chooseGen.single() }.groupBy { it }.map { (k, v) -> k to v.count() }
val actualRatiosMap = actualCountsMap.map { (k, v) -> k to (v.toDouble() / genCount) }.toMap()

actualRatiosMap.keys shouldBe expectedRatiosMap.keys

actualRatiosMap.forEach { (k, actualRatio) ->
actualRatio shouldBe (expectedRatiosMap[k] as Double plusOrMinus 0.02)
}
}
}
}
}

test("Arb.choose should not accept negative weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(-1 to 'A', 1 to 'B') }
}

test("Arb.choose should not accept all zero weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(0 to 'A', 0 to 'B') }
}

test("Arb.choose should accept weights if at least one is non-zero") {
shouldNotThrow<Exception> { Arb.choose(0 to 'A', 0 to 'B', 1 to 'C') }
}

test("Arb.choose(arbs) should generate expected values in correct ratios according to weights") {
val arbA = Arb.constant('A')
val arbB = Arb.constant('B')
val arbC = Arb.constant('C')
val arbD = Arb.constant('D')
forAll(
row(listOf(1 to arbA, 1 to arbB), mapOf('A' to 0.5, 'B' to 0.5)),
row(listOf(1 to arbA, 3 to arbB, 1 to arbC), mapOf('A' to 0.2, 'B' to 0.6, 'C' to 0.2)),
row(listOf(1 to arbA, 3 to arbC, 1 to arbC), mapOf('A' to 0.2, 'C' to 0.8)),
row(
listOf(1 to arbA, 3 to arbB, 1 to arbC, 4 to arbD),
mapOf('A' to 0.11, 'B' to 0.33, 'C' to 0.11, 'D' to 0.44)
)
) { weightPairs, expectedRatiosMap ->
val genCount = 100000
val chooseGen = Arb.choose(weightPairs[0], weightPairs[1], *weightPairs.drop(2).toTypedArray())
val actualCountsMap = (1..genCount).map { chooseGen.single() }.groupBy { it }.map { (k, v) -> k to v.count() }
val actualRatiosMap = actualCountsMap.map { (k, v) -> k to (v.toDouble() / genCount) }.toMap()

actualRatiosMap.keys shouldBe expectedRatiosMap.keys

actualRatiosMap.forEach { (k, actualRatio) ->
actualRatio shouldBe (expectedRatiosMap[k] as Double plusOrMinus 0.02)
}
}
}

test("Arb.choose(arbs) should not accept all zero weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(0 to Arb.constant('A'), 0 to Arb.constant('B')) }
}

test("weighted should not accept negative weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(-1 to 'A', 1 to 'B') }
}
test("Arb.choose(arbs) should not accept negative weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(-1 to Arb.constant('A'), 1 to Arb.constant('B')) }
}

test("weighted should not accept all zero weights") {
shouldThrow<IllegalArgumentException> { Arb.choose(0 to 'A', 0 to 'B') }
}
test("Arb.choose(arbs) should accept weights if at least one is non-zero") {
shouldNotThrow<Exception> { Arb.choose(0 to Arb.constant('A'), 0 to Arb.constant('B'), 1 to Arb.constant('C')) }
}

test("weighted should accept weights if at least one is non-zero") {
shouldNotThrow<Exception> { Arb.choose(0 to 'A', 0 to 'B', 1 to 'C') }
}
})

0 comments on commit 8416627

Please sign in to comment.