diff --git a/CHANGELOG.md b/CHANGELOG.md index f65bc9f..08bcdf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## v0.10.0 + + - Feature: New ObtainMulti method to acquire multiple locks atomically [#70](https://github.com/bsm/redislock/pull/70) + ## v0.9.4 - Fix: allow to re-obtain locks with tokens [#71](https://github.com/bsm/redislock/pull/71) diff --git a/example_test.go b/example_test.go index cd7bd2c..7aaac4c 100644 --- a/example_test.go +++ b/example_test.go @@ -27,8 +27,10 @@ func Example() { lock, err := locker.Obtain(ctx, "my-key", 100*time.Millisecond, nil) if err == redislock.ErrNotObtained { fmt.Println("Could not obtain lock!") + return } else if err != nil { log.Fatalln(err) + return } // Don't forget to defer Release. diff --git a/obtain.lua b/obtain.lua new file mode 100644 index 0000000..11e55d6 --- /dev/null +++ b/obtain.lua @@ -0,0 +1,40 @@ +-- obtain.lua: arguments => [value, tokenLen, ttl] +-- Obtain.lua try to set provided keys's with value and ttl if they do not exists. +-- Keys can be overriden if they already exists and the correct value+tokenLen is provided. + +local function pexpire(ttl) + -- Update keys ttls. + for _, key in ipairs(KEYS) do + redis.call("pexpire", key, ttl) + end +end + +-- canOverrideLock check either or not the provided token match +-- previously set lock's tokens. +local function canOverrideKeys() + local offset = tonumber(ARGV[2]) + + for _, key in ipairs(KEYS) do + if redis.call("getrange", key, 0, offset-1) ~= string.sub(ARGV[1], 1, offset) then + return false + end + end + return true +end + +-- Prepare mset arguments. +local setArgs = {} +for _, key in ipairs(KEYS) do + table.insert(setArgs, key) + table.insert(setArgs, ARGV[1]) +end + +if redis.call("msetnx", unpack(setArgs)) ~= 1 then + if canOverrideKeys() == false then + return false + end + redis.call("mset", unpack(setArgs)) +end + +pexpire(ARGV[3]) +return redis.status_reply("OK") \ No newline at end of file diff --git a/pttl.lua b/pttl.lua new file mode 100644 index 0000000..0756001 --- /dev/null +++ b/pttl.lua @@ -0,0 +1,21 @@ +-- pttl.lua: => Arguments: [value] +-- pttl.lua returns provided keys's ttls if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Find and return shortest TTL among keys. +local minTTL = 0 +for _, key in ipairs(KEYS) do + local ttl = redis.call("pttl", key) + -- Note: ttl < 0 probably means the key no longer exists. + if ttl > 0 and (minTTL == 0 or ttl < minTTL) then + minTTL = ttl + end +end +return minTTL \ No newline at end of file diff --git a/redislock.go b/redislock.go index 9f12a03..67561f6 100644 --- a/redislock.go +++ b/redislock.go @@ -3,6 +3,7 @@ package redislock import ( "context" "crypto/rand" + _ "embed" "encoding/base64" "errors" "io" @@ -14,16 +15,23 @@ import ( "github.com/redis/go-redis/v9" ) +//go:embed release.lua +var luaReleaseScript string + +//go:embed refresh.lua +var luaRefeshScript string + +//go:embed pttl.lua +var luaPTTLScript string + +//go:embed obtain.lua +var luaObtainScript string + var ( - luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`) - luaRelease = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`) - luaPTTL = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pttl", KEYS[1]) else return -3 end`) - luaObtain = redis.NewScript(` -if redis.call("set", KEYS[1], ARGV[1], "NX", "PX", ARGV[3]) then return redis.status_reply("OK") end - -local offset = tonumber(ARGV[2]) -if redis.call("getrange", KEYS[1], 0, offset-1) == string.sub(ARGV[1], 1, offset) then return redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[3]) end -`) + luaRefresh = redis.NewScript(luaRefeshScript) + luaRelease = redis.NewScript(luaReleaseScript) + luaPTTL = redis.NewScript(luaPTTLScript) + luaObtain = redis.NewScript(luaObtainScript) ) var ( @@ -54,8 +62,15 @@ func New(client RedisClient) *Client { // Obtain tries to obtain a new lock using a key with the given TTL. // May return ErrNotObtained if not successful. func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt *Options) (*Lock, error) { - token := opt.getToken() + return c.ObtainMulti(ctx, []string{key}, ttl, opt) +} +// ObtainMulti tries to obtain new locks using keys with the given TTL. +// If any of requested key are already locked, no additional keys are +// locked and ErrNotObtained is returned. +// May return ErrNotObtained if not successful. +func (c *Client) ObtainMulti(ctx context.Context, keys []string, ttl time.Duration, opt *Options) (*Lock, error) { + token := opt.getToken() // Create a random token if token == "" { var err error @@ -77,11 +92,11 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt var ticker *time.Ticker for { - ok, err := c.obtain(ctx, key, value, len(token), ttlVal) + ok, err := c.obtain(ctx, keys, value, len(token), ttlVal) if err != nil { return nil, err } else if ok { - return &Lock{Client: c, key: key, value: value, tokenLen: len(token)}, nil + return &Lock{Client: c, keys: keys, value: value, tokenLen: len(token)}, nil } backoff := retry.NextBackoff() @@ -104,11 +119,12 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt } } -func (c *Client) obtain(ctx context.Context, key, value string, tokenLen int, ttlVal string) (bool, error) { - _, err := luaObtain.Run(ctx, c.client, []string{key}, value, tokenLen, ttlVal).Result() - if err == redis.Nil { - return false, nil - } else if err != nil { +func (c *Client) obtain(ctx context.Context, keys []string, value string, tokenLen int, ttlVal string) (bool, error) { + _, err := luaObtain.Run(ctx, c.client, keys, value, tokenLen, ttlVal).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return false, nil + } return false, err } return true, nil @@ -133,7 +149,7 @@ func (c *Client) randomToken() (string, error) { // Lock represents an obtained, distributed lock. type Lock struct { *Client - key string + keys []string value string tokenLen int } @@ -143,9 +159,20 @@ func Obtain(ctx context.Context, client RedisClient, key string, ttl time.Durati return New(client).Obtain(ctx, key, ttl, opt) } +// ObtainMulti is a short-cut for New(...).ObtainMulti(...). +func ObtainMulti(ctx context.Context, client RedisClient, keys []string, ttl time.Duration, opt *Options) (*Lock, error) { + return New(client).ObtainMulti(ctx, keys, ttl, opt) +} + // Key returns the redis key used by the lock. +// If the lock hold multiple key, only the first is returned. func (l *Lock) Key() string { - return l.key + return l.keys[0] +} + +// Keys returns the redis keys used by the lock. +func (l *Lock) Keys() []string { + return l.keys } // Token returns the token value set by the lock. @@ -159,14 +186,18 @@ func (l *Lock) Metadata() string { } // TTL returns the remaining time-to-live. Returns 0 if the lock has expired. +// In case lock is holding multiple keys, TTL returns the min ttl among thoses keys. func (l *Lock) TTL(ctx context.Context) (time.Duration, error) { - res, err := luaPTTL.Run(ctx, l.client, []string{l.key}, l.value).Result() - if err == redis.Nil { - return 0, nil - } else if err != nil { + if l == nil { + return 0, ErrLockNotHeld + } + res, err := luaPTTL.Run(ctx, l.client, l.keys, l.value).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return 0, nil + } return 0, err } - if num := res.(int64); num > 0 { return time.Duration(num) * time.Millisecond, nil } @@ -176,14 +207,18 @@ func (l *Lock) TTL(ctx context.Context) (time.Duration, error) { // Refresh extends the lock with a new TTL. // May return ErrNotObtained if refresh is unsuccessful. func (l *Lock) Refresh(ctx context.Context, ttl time.Duration, opt *Options) error { + if l == nil { + return ErrNotObtained + } ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10) - status, err := luaRefresh.Run(ctx, l.client, []string{l.key}, l.value, ttlVal).Result() + _, err := luaRefresh.Run(ctx, l.client, l.keys, l.value, ttlVal).Result() if err != nil { + if errors.Is(err, redis.Nil) { + return ErrNotObtained + } return err - } else if status == int64(1) { - return nil } - return ErrNotObtained + return nil } // Release manually releases the lock. @@ -192,17 +227,13 @@ func (l *Lock) Release(ctx context.Context) error { if l == nil { return ErrLockNotHeld } - - res, err := luaRelease.Run(ctx, l.client, []string{l.key}, l.value).Result() - if err == redis.Nil { - return ErrLockNotHeld - } else if err != nil { + _, err := luaRelease.Run(ctx, l.client, l.keys, l.value).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return ErrLockNotHeld + } return err } - - if i, ok := res.(int64); !ok || i != 1 { - return ErrLockNotHeld - } return nil } diff --git a/redislock_test.go b/redislock_test.go index 8fab416..51f1332 100644 --- a/redislock_test.go +++ b/redislock_test.go @@ -3,7 +3,10 @@ package redislock_test import ( "context" "errors" + "fmt" "math/rand" + "path/filepath" + "runtime" "sync" "sync/atomic" "testing" @@ -13,8 +16,6 @@ import ( "github.com/redis/go-redis/v9" ) -const lockKey = "__bsm_redislock_unit_test__" - var redisOpts = &redis.Options{ Network: "tcp", Addr: "127.0.0.1:6379", @@ -22,9 +23,10 @@ var redisOpts = &redis.Options{ } func TestClient(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) // init client client := New(rc) @@ -63,20 +65,22 @@ func TestClient(t *testing.T) { } func TestObtain(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock := quickObtain(t, rc, time.Hour) + lock := quickObtain(t, rc, lockKey, time.Hour) if err := lock.Release(ctx); err != nil { t.Fatal(err) } } func TestObtain_metadata(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) meta := "my-data" lock, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Metadata: meta}) @@ -91,9 +95,10 @@ func TestObtain_metadata(t *testing.T) { } func TestObtain_custom_token(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) // obtain lock lock1, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Token: "foo", Metadata: "bar"}) @@ -131,12 +136,13 @@ func TestObtain_custom_token(t *testing.T) { } func TestObtain_retry_success(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) // obtain for 20ms - lock1 := quickObtain(t, rc, 20*time.Millisecond) + lock1 := quickObtain(t, rc, lockKey, 20*time.Millisecond) defer lock1.Release(ctx) // lock again with linar retry - 3x for 20ms @@ -150,12 +156,13 @@ func TestObtain_retry_success(t *testing.T) { } func TestObtain_retry_failure(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) // obtain for 50ms - lock1 := quickObtain(t, rc, 50*time.Millisecond) + lock1 := quickObtain(t, rc, lockKey, 50*time.Millisecond) defer lock1.Release(ctx) // lock again with linar retry - 2x for 5ms @@ -168,9 +175,10 @@ func TestObtain_retry_failure(t *testing.T) { } func TestObtain_concurrent(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) numLocks := int32(0) numThreads := 100 @@ -207,11 +215,12 @@ func TestObtain_concurrent(t *testing.T) { } func TestLock_Refresh(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock := quickObtain(t, rc, time.Hour) + lock := quickObtain(t, rc, lockKey, time.Hour) defer lock.Release(ctx) // check TTL @@ -227,11 +236,12 @@ func TestLock_Refresh(t *testing.T) { } func TestLock_Refresh_expired(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock := quickObtain(t, rc, 5*time.Millisecond) + lock := quickObtain(t, rc, lockKey, 5*time.Millisecond) defer lock.Release(ctx) // try releasing @@ -242,11 +252,12 @@ func TestLock_Refresh_expired(t *testing.T) { } func TestLock_Release_expired(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock := quickObtain(t, rc, 5*time.Millisecond) + lock := quickObtain(t, rc, lockKey, 5*time.Millisecond) defer lock.Release(ctx) // try releasing @@ -257,11 +268,12 @@ func TestLock_Release_expired(t *testing.T) { } func TestLock_Release_not_own(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock := quickObtain(t, rc, time.Hour) + lock := quickObtain(t, rc, lockKey, time.Hour) defer lock.Release(ctx) // manually transfer ownership @@ -269,6 +281,15 @@ func TestLock_Release_not_own(t *testing.T) { t.Fatal(err) } + cmd := rc.Get(ctx, lockKey) + v, err := cmd.Result() + if err != nil { + t.Fatal(err) + } + if v != "ABCD" { + t.Fatalf("expected %v, got %v", "ABCD", v) + } + // try releasing if exp, got := ErrLockNotHeld, lock.Release(ctx); !errors.Is(got, exp) { t.Fatalf("expected %v, got %v", exp, got) @@ -276,11 +297,12 @@ func TestLock_Release_not_own(t *testing.T) { } func TestLock_Release_not_held(t *testing.T) { + lockKey := getLockKey() ctx := context.Background() rc := redis.NewClient(redisOpts) - defer teardown(t, rc) + defer teardown(t, rc, lockKey) - lock1 := quickObtain(t, rc, time.Hour) + lock1 := quickObtain(t, rc, lockKey, time.Hour) defer lock1.Release(ctx) lock2, err := Obtain(context.Background(), rc, lockKey, time.Minute, nil) @@ -295,7 +317,77 @@ func TestLock_Release_not_held(t *testing.T) { } } -func quickObtain(t *testing.T, rc *redis.Client, ttl time.Duration) *Lock { +func TestLock_ObtainMulti(t *testing.T) { + lockKeys := []string{ + getLockKey() + "_MultiLock_1", + getLockKey() + "_MultiLock_2", + getLockKey() + "_MultiLock_3", + getLockKey() + "_MultiLock_4", + } + ctx := context.Background() + rc := redis.NewClient(redisOpts) + defer teardown(t, rc, lockKeys...) + + lockKey1 := lockKeys[0] + lockKey2 := lockKeys[1] + lockKey3 := lockKeys[2] + lockKey4 := lockKeys[3] + + // 1. Obtain lock 1 and 2 + lock12, err := ObtainMulti(ctx, rc, []string{lockKey1, lockKey2}, time.Hour, nil) + if err != nil { + t.Fatal(err) + } + + // 2. Obtain lock 3 and 4 + lock34, err := ObtainMulti(ctx, rc, []string{lockKey3, lockKey4}, time.Hour, nil) + if err != nil { + t.Fatal(err) + } + + // 3. Try to obtain lock 2 and 3 + _, err = ObtainMulti(ctx, rc, []string{lockKey2, lockKey3}, time.Hour, nil) + // Expect it to fail since lock 2 and 3 are already locked. + if !errors.Is(err, ErrNotObtained) { + t.Fatalf("expected ErrNotObtained, got %s.", err) + } + + // 4. Release lock 1 and 2 + lock12.Release(ctx) + + // 5. Obtain lock 1 + lock1, err := ObtainMulti(ctx, rc, []string{lockKey1}, time.Hour, nil) + // Expected to succeed since lock 1 was released (along with lock 2) + if err != nil { + t.Fatal(err) + } + defer lock1.Release(ctx) + + // 6. Try to obtain lock 2 and 3 (again) + _, err = ObtainMulti(ctx, rc, []string{lockKey2, lockKey3}, time.Hour, nil) + // Expect it to fail since lock 3 is still locked. + if !errors.Is(err, ErrNotObtained) { + t.Fatalf("expected ErrNotObtained, got %s.", err) + } + + // 7. Release lock 3 and 4 + lock34.Release(ctx) + + // 8. Try to obtain lock 2 and 3 (again) + lock23, err := ObtainMulti(ctx, rc, []string{lockKey2, lockKey3}, time.Hour, nil) + // Expect it to succeed since lock 2 and 3 are available. + if err != nil { + t.Fatal(err) + } + + defer lock23.Release(ctx) +} + +func getLockKey() string { + return fmt.Sprintf("__bsm_redislock_%s_%d__", getCallingFunctionName(1), time.Now().UnixNano()) +} + +func quickObtain(t *testing.T, rc *redis.Client, lockKey string, ttl time.Duration) *Lock { t.Helper() lock, err := Obtain(context.Background(), rc, lockKey, ttl, nil) @@ -322,13 +414,25 @@ func assertTTL(t *testing.T, lock *Lock, exp time.Duration) { } } -func teardown(t *testing.T, rc *redis.Client) { +func teardown(t *testing.T, rc *redis.Client, lockKeys ...string) { t.Helper() - if err := rc.Del(context.Background(), lockKey).Err(); err != nil { - t.Fatal(err) + for _, lockKey := range lockKeys { + if err := rc.Del(context.Background(), lockKey).Err(); err != nil { + t.Fatal(err) + } } if err := rc.Close(); err != nil { t.Fatal(err) } } + +func getCallingFunctionName(skipFrameCount int) string { + fpc, _, _, _ := runtime.Caller(1 + skipFrameCount) + funcName := "unknown" + fun := runtime.FuncForPC(fpc) + if fun != nil { + _, funcName = filepath.Split(fun.Name()) + } + return funcName +} diff --git a/refresh.lua b/refresh.lua new file mode 100644 index 0000000..a7ca443 --- /dev/null +++ b/refresh.lua @@ -0,0 +1,17 @@ +-- refresh.lua: => Arguments: [value, ttl] +-- refresh.lua refreshes provided keys's ttls if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Update keys ttls. +for _, key in ipairs(KEYS) do + redis.call("pexpire", key, ARGV[2]) +end + +return redis.status_reply("OK") \ No newline at end of file diff --git a/release.lua b/release.lua new file mode 100644 index 0000000..da611d9 --- /dev/null +++ b/release.lua @@ -0,0 +1,15 @@ +-- release.lua: => Arguments: [value] +-- Release.lua deletes provided keys if all their values match the input. + +-- Check all keys values matches provided input. +local values = redis.call("mget", unpack(KEYS)) +for i, _ in ipairs(KEYS) do + if values[i] ~= ARGV[1] then + return false + end +end + +-- Delete keys. +redis.call("del", unpack(KEYS)) + +return redis.status_reply("OK") \ No newline at end of file