diff --git a/ledis/t_set.go b/ledis/t_set.go index 51458bb0..ae58e3a2 100644 --- a/ledis/t_set.go +++ b/ledis/t_set.go @@ -490,7 +490,6 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64, var err error var ek []byte - var num int64 = 0 var v [][]byte switch optType { @@ -513,22 +512,21 @@ func (db *DB) sStoreGeneric(dstKey []byte, optType byte, keys ...[]byte) (int64, ek = db.sEncodeSetKey(dstKey, m) - if v, err := db.db.Get(ek); err != nil { + if _, err := db.db.Get(ek); err != nil { return 0, err - } else if v == nil { - num++ } t.Put(ek, nil) - } - if _, err = db.sIncrSize(dstKey, num); err != nil { + var num = int64(len(v)) + sk := db.sEncodeSizeKey(dstKey) + t.Put(sk, PutInt64(num)) + + if err = t.Commit(); err != nil { return 0, err } - - err = t.Commit() - return num, err + return num, nil } func (db *DB) SClear(key []byte) (int64, error) { diff --git a/ledis/t_set_test.go b/ledis/t_set_test.go index c5f65238..110a9829 100644 --- a/ledis/t_set_test.go +++ b/ledis/t_set_test.go @@ -147,7 +147,7 @@ func testUnion(db *DB, t *testing.T) { m2 := []byte("m2") m3 := []byte("m3") db.SAdd(key, m1, m2) - db.SAdd(key1, m1, m3) + db.SAdd(key1, m1, m2, m3) db.SAdd(key2, m2, m3) if _, err := db.sUnionGeneric(key, key2); err != nil { t.Fatal(err) @@ -158,11 +158,13 @@ func testUnion(db *DB, t *testing.T) { } dstkey := []byte("union_dsk") + db.SAdd(dstkey, []byte("x")) if num, err := db.SUnionStore(dstkey, key1, key2); err != nil { t.Fatal(err) } else if num != 3 { t.Fatal(num) } + if _, err := db.SMembers(dstkey); err != nil { t.Fatal(err) }