diff --git a/bitset.go b/bitset.go index 642cba6..8fb9e9f 100644 --- a/bitset.go +++ b/bitset.go @@ -927,12 +927,16 @@ func (b *BitSet) Any() bool { // IsSuperSet returns true if this is a superset of the other set func (b *BitSet) IsSuperSet(other *BitSet) bool { - for i, e := other.NextSet(0); e; i, e = other.NextSet(i + 1) { - if !b.Test(i) { + l := other.wordCount() + if b.wordCount() < l { + l = b.wordCount() + } + for i, word := range other.set[:l] { + if b.set[i]&word != word { return false } } - return true + return popcntSlice(other.set[l:]) == 0 } // IsStrictSuperSet returns true if this is a strict superset of the other set diff --git a/bitset_benchmark_test.go b/bitset_benchmark_test.go index 93a6b33..98a1b52 100644 --- a/bitset_benchmark_test.go +++ b/bitset_benchmark_test.go @@ -8,6 +8,7 @@ package bitset import ( "bytes" + "fmt" "math/rand" "testing" ) @@ -542,3 +543,77 @@ func BenchmarkBitsetReadWrite(b *testing.B) { buffer.Reset() } } + +func BenchmarkIsSuperSet(b *testing.B) { + new := func(len int, density float64) *BitSet { + r := rand.New(rand.NewSource(42)) + bs := New(uint(len)) + for i := 0; i < len; i++ { + bs.SetTo(uint(i), r.Float64() < density) + } + return bs + } + + bench := func(name string, lenS, lenSS int, density float64, overrideS, overrideSS map[int]bool, f func(*BitSet, *BitSet) bool) { + s := new(lenS, density) + ss := new(lenSS, density) + + for i, v := range overrideS { + s.SetTo(uint(i), v) + } + for i, v := range overrideSS { + ss.SetTo(uint(i), v) + } + + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = f(ss, s) + } + }) + } + + f := func(ss, s *BitSet) bool { + return ss.IsSuperSet(s) + } + fStrict := func(ss, s *BitSet) bool { + return ss.IsStrictSuperSet(s) + } + + for _, len := range []int{1, 10, 100, 1000, 10000, 100000} { + density := 0.5 + bench(fmt.Sprintf("equal, len=%d", len), + len, len, density, nil, nil, f) + bench(fmt.Sprintf("equal, len=%d, strict", len), + len, len, density, nil, nil, fStrict) + } + + for _, density := range []float64{0, 0.05, 0.2, 0.8, 0.95, 1} { + len := 10000 + bench(fmt.Sprintf("equal, density=%.2f", density), + len, len, density, nil, nil, f) + bench(fmt.Sprintf("equal, density=%.2f, strict", density), + len, len, density, nil, nil, fStrict) + } + + for _, diff := range []int{0, 100, 1000, 9999} { + len := 10000 + density := 0.5 + overrideS := map[int]bool{diff: true} + overrideSS := map[int]bool{diff: false} + bench(fmt.Sprintf("subset, len=%d, diff=%d", len, diff), + len, len, density, overrideS, overrideSS, f) + bench(fmt.Sprintf("subset, len=%d, diff=%d, strict", len, diff), + len, len, density, overrideS, overrideSS, fStrict) + } + + for _, diff := range []int{0, 100, 1000, 9999} { + len := 10000 + density := 0.5 + overrideS := map[int]bool{diff: false} + overrideSS := map[int]bool{diff: true} + bench(fmt.Sprintf("superset, len=%d, diff=%d", len, diff), + len, len, density, overrideS, overrideSS, f) + bench(fmt.Sprintf("superset, len=%d, diff=%d, strict", len, diff), + len, len, density, overrideS, overrideSS, fStrict) + } +} diff --git a/bitset_test.go b/bitset_test.go index f9534e2..3dae32e 100644 --- a/bitset_test.go +++ b/bitset_test.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "math" + "math/rand" "strconv" "testing" ) @@ -1153,60 +1154,53 @@ func TestComplement(t *testing.T) { } func TestIsSuperSet(t *testing.T) { - a := New(500) - b := New(300) - c := New(200) - - // Setup bitsets - // a and b overlap - // only c is (strict) super set - for i := uint(0); i < 100; i++ { - a.Set(i) - } - for i := uint(50); i < 150; i++ { - b.Set(i) - } - for i := uint(0); i < 200; i++ { - c.Set(i) - } + test := func(name string, lenS, lenSS int, overrideS, overrideSS map[int]bool, want, wantStrict bool) { + t.Run(name, func(t *testing.T) { + s := New(uint(lenS)) + ss := New(uint(lenSS)) + + l := lenS + if lenSS < lenS { + l = lenSS + } - if a.IsSuperSet(b) { - t.Errorf("IsSuperSet fails") - } - if a.IsSuperSet(c) { - t.Errorf("IsSuperSet fails") - } - if b.IsSuperSet(a) { - t.Errorf("IsSuperSet fails") - } - if b.IsSuperSet(c) { - t.Errorf("IsSuperSet fails") - } - if !c.IsSuperSet(a) { - t.Errorf("IsSuperSet fails") - } - if !c.IsSuperSet(b) { - t.Errorf("IsSuperSet fails") - } + r := rand.New(rand.NewSource(42)) + for i := 0; i < l; i++ { + bit := r.Intn(2) == 1 + s.SetTo(uint(i), bit) + ss.SetTo(uint(i), bit) + } - if a.IsStrictSuperSet(b) { - t.Errorf("IsStrictSuperSet fails") - } - if a.IsStrictSuperSet(c) { - t.Errorf("IsStrictSuperSet fails") - } - if b.IsStrictSuperSet(a) { - t.Errorf("IsStrictSuperSet fails") - } - if b.IsStrictSuperSet(c) { - t.Errorf("IsStrictSuperSet fails") - } - if !c.IsStrictSuperSet(a) { - t.Errorf("IsStrictSuperSet fails") - } - if !c.IsStrictSuperSet(b) { - t.Errorf("IsStrictSuperSet fails") + for i, v := range overrideS { + s.SetTo(uint(i), v) + } + for i, v := range overrideSS { + ss.SetTo(uint(i), v) + } + + if got := ss.IsSuperSet(s); got != want { + t.Errorf("IsSuperSet() = %v, want %v", got, want) + } + if got := ss.IsStrictSuperSet(s); got != wantStrict { + t.Errorf("IsStrictSuperSet() = %v, want %v", got, wantStrict) + } + }) } + + test("empty", 0, 0, nil, nil, true, false) + test("empty vs non-empty", 0, 100, nil, nil, true, false) + test("non-empty vs empty", 100, 0, nil, nil, true, false) + test("equal", 100, 100, nil, nil, true, false) + + test("set is shorter, subset", 100, 200, map[int]bool{50: true}, map[int]bool{50: false}, false, false) + test("set is shorter, equal", 100, 200, nil, nil, true, false) + test("set is shorter, superset", 100, 200, map[int]bool{50: false}, map[int]bool{50: true}, true, true) + test("set is shorter, neither", 100, 200, map[int]bool{50: true}, map[int]bool{50: false, 150: true}, false, false) + + test("set is longer, subset", 200, 100, map[int]bool{50: true}, map[int]bool{50: false}, false, false) + test("set is longer, equal", 200, 100, nil, nil, true, false) + test("set is longer, superset", 200, 100, nil, map[int]bool{150: true}, true, true) + test("set is longer, neither", 200, 100, map[int]bool{50: false, 150: true}, map[int]bool{50: true}, false, false) } func TestDumpAsBits(t *testing.T) {