diff --git a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt index fe4cff629ee..1775289a807 100644 --- a/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt +++ b/kotest-property/src/commonMain/kotlin/io/kotest/property/arbitrary/filter.kt @@ -39,4 +39,4 @@ fun Arb.filterNot(f: (A) -> Boolean): Arb = filter { !f(it) } * a particular subtype. */ @Suppress("UNCHECKED_CAST") -inline fun Arb.filterIsInstance(): Arb = filter { it is B }.map { it as B } +inline fun Arb<*>.filterIsInstance(): Arb = filter { it is B } as Arb diff --git a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt index a98f764af92..83f680923ac 100644 --- a/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt +++ b/kotest-property/src/jvmTest/kotlin/com/sksamuel/kotest/property/arbitrary/FilterTest.kt @@ -6,12 +6,15 @@ import io.kotest.core.spec.style.FunSpec import io.kotest.inspectors.forAll import io.kotest.matchers.collections.shouldContainExactly import io.kotest.matchers.collections.shouldNotBeIn +import io.kotest.matchers.should import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.beInstanceOf import io.kotest.property.Arb import io.kotest.property.EdgeConfig import io.kotest.property.RandomSource import io.kotest.property.Sample import io.kotest.property.arbitrary.filter +import io.kotest.property.arbitrary.filterIsInstance import io.kotest.property.arbitrary.int import io.kotest.property.arbitrary.map import io.kotest.property.arbitrary.of @@ -68,4 +71,11 @@ class FilterTest : FunSpec({ val result = shouldNotThrowAny { arb.single(RandomSource.seeded(1234L)) } result shouldBe 0 } + + test("Arb.filterIsInstance should only keep instances of the given type") { + val arb: Arb = Arb.of(1, "2", 3.0, "4", 5) + val filtered = arb.filterIsInstance() + val result = filtered.samples().take(100).map { it.value } + result.forAll { it should beInstanceOf(String::class) } + } })