Skip to content

Commit

Permalink
Allow to re-obtain atomically with custom tokens (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
dim committed Jul 27, 2023
1 parent 06ee54a commit 788a79b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.9.4

- Fix: allow to re-obtain locks with tokens [#71](https://github.com/bsm/redislock/pull/71)

## v0.9.3

- Feature: allow custom lock tokens [#66](https://github.com/bsm/redislock/pull/66)
Expand Down
20 changes: 16 additions & 4 deletions redislock.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ var (
luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
luaRelease = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`)
luaPTTL = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pttl", KEYS[1]) else return -3 end`)
luaObtain = redis.NewScript(`
if redis.call("set", KEYS[1], ARGV[1], "NX", "PX", ARGV[3]) then return redis.status_reply("OK") end
local offset = tonumber(ARGV[2])
if redis.call("getrange", KEYS[1], 0, offset-1) == string.sub(ARGV[1], 1, offset) then return redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[3]) end
`)
)

var (
Expand All @@ -31,7 +37,6 @@ var (
// RedisClient is a minimal client interface.
type RedisClient interface {
redis.Scripter
SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd
}

// Client wraps a redis client.
Expand Down Expand Up @@ -60,6 +65,7 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt
}

value := token + opt.getMetadata()
ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10)
retry := opt.getRetryStrategy()

// make sure we don't retry forever
Expand All @@ -71,7 +77,7 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt

var ticker *time.Ticker
for {
ok, err := c.obtain(ctx, key, value, ttl)
ok, err := c.obtain(ctx, key, value, len(token), ttlVal)
if err != nil {
return nil, err
} else if ok {
Expand All @@ -98,8 +104,14 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt
}
}

func (c *Client) obtain(ctx context.Context, key, value string, ttl time.Duration) (bool, error) {
return c.client.SetNX(ctx, key, value, ttl).Result()
func (c *Client) obtain(ctx context.Context, key, value string, tokenLen int, ttlVal string) (bool, error) {
_, err := luaObtain.Run(ctx, c.client, []string{key}, value, tokenLen, ttlVal).Result()
if err == redis.Nil {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}

func (c *Client) randomToken() (string, error) {
Expand Down
31 changes: 26 additions & 5 deletions redislock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,39 @@ func TestObtain_custom_token(t *testing.T) {
rc := redis.NewClient(redisOpts)
defer teardown(t, rc)

lock, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Token: "foo", Metadata: "bar"})
// obtain lock
lock1, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Token: "foo", Metadata: "bar"})
if err != nil {
t.Fatal(err)
}
defer lock.Release(ctx)
defer lock1.Release(ctx)

if exp, got := "foo", lock.Token(); exp != got {
t.Fatalf("expected %v, got %v", exp, got)
if exp, got := "foo", lock1.Token(); exp != got {
t.Errorf("expected %v, got %v", exp, got)
}
if exp, got := "bar", lock1.Metadata(); exp != got {
t.Errorf("expected %v, got %v", exp, got)
}
if exp, got := "bar", lock.Metadata(); exp != got {

// try to obtain again
_, err = Obtain(ctx, rc, lockKey, time.Hour, nil)
if exp, got := ErrNotObtained, err; !errors.Is(got, exp) {
t.Fatalf("expected %v, got %v", exp, got)
}

// allow to re-obtain lock if token is known
lock2, err := Obtain(ctx, rc, lockKey, time.Hour, &Options{Token: "foo", Metadata: "baz"})
if err != nil {
t.Fatal(err)
}
defer lock2.Release(ctx)

if exp, got := "foo", lock2.Token(); exp != got {
t.Errorf("expected %v, got %v", exp, got)
}
if exp, got := "baz", lock2.Metadata(); exp != got {
t.Errorf("expected %v, got %v", exp, got)
}
}

func TestObtain_retry_success(t *testing.T) {
Expand Down

0 comments on commit 788a79b

Please sign in to comment.