Skip to content

Commit

Permalink
Add multi-lock support for atomic multi-lock. (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhallionOhbibi committed Nov 24, 2023
1 parent 788a79b commit cd93b2b
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 63 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v0.10.0

- Feature: New ObtainMulti method to acquire multiple locks atomically [#70](https://github.com/bsm/redislock/pull/70)

## v0.9.4

- Fix: allow to re-obtain locks with tokens [#71](https://github.com/bsm/redislock/pull/71)
Expand Down
2 changes: 2 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ func Example() {
lock, err := locker.Obtain(ctx, "my-key", 100*time.Millisecond, nil)
if err == redislock.ErrNotObtained {
fmt.Println("Could not obtain lock!")
return
} else if err != nil {
log.Fatalln(err)
return
}

// Don't forget to defer Release.
Expand Down
40 changes: 40 additions & 0 deletions obtain.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
-- obtain.lua: arguments => [value, tokenLen, ttl]
-- Obtain.lua try to set provided keys's with value and ttl if they do not exists.
-- Keys can be overriden if they already exists and the correct value+tokenLen is provided.

local function pexpire(ttl)
-- Update keys ttls.
for _, key in ipairs(KEYS) do
redis.call("pexpire", key, ttl)
end
end

-- canOverrideLock check either or not the provided token match
-- previously set lock's tokens.
local function canOverrideKeys()
local offset = tonumber(ARGV[2])

for _, key in ipairs(KEYS) do
if redis.call("getrange", key, 0, offset-1) ~= string.sub(ARGV[1], 1, offset) then
return false
end
end
return true
end

-- Prepare mset arguments.
local setArgs = {}
for _, key in ipairs(KEYS) do
table.insert(setArgs, key)
table.insert(setArgs, ARGV[1])
end

if redis.call("msetnx", unpack(setArgs)) ~= 1 then
if canOverrideKeys() == false then
return false
end
redis.call("mset", unpack(setArgs))
end

pexpire(ARGV[3])
return redis.status_reply("OK")
21 changes: 21 additions & 0 deletions pttl.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
-- pttl.lua: => Arguments: [value]
-- pttl.lua returns provided keys's ttls if all their values match the input.

-- Check all keys values matches provided input.
local values = redis.call("mget", unpack(KEYS))
for i, _ in ipairs(KEYS) do
if values[i] ~= ARGV[1] then
return false
end
end

-- Find and return shortest TTL among keys.
local minTTL = 0
for _, key in ipairs(KEYS) do
local ttl = redis.call("pttl", key)
-- Note: ttl < 0 probably means the key no longer exists.
if ttl > 0 and (minTTL == 0 or ttl < minTTL) then
minTTL = ttl
end
end
return minTTL
105 changes: 68 additions & 37 deletions redislock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package redislock
import (
"context"
"crypto/rand"
_ "embed"
"encoding/base64"
"errors"
"io"
Expand All @@ -14,16 +15,23 @@ import (
"github.com/redis/go-redis/v9"
)

//go:embed release.lua
var luaReleaseScript string

//go:embed refresh.lua
var luaRefeshScript string

//go:embed pttl.lua
var luaPTTLScript string

//go:embed obtain.lua
var luaObtainScript string

var (
luaRefresh = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pexpire", KEYS[1], ARGV[2]) else return 0 end`)
luaRelease = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) else return 0 end`)
luaPTTL = redis.NewScript(`if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("pttl", KEYS[1]) else return -3 end`)
luaObtain = redis.NewScript(`
if redis.call("set", KEYS[1], ARGV[1], "NX", "PX", ARGV[3]) then return redis.status_reply("OK") end
local offset = tonumber(ARGV[2])
if redis.call("getrange", KEYS[1], 0, offset-1) == string.sub(ARGV[1], 1, offset) then return redis.call("set", KEYS[1], ARGV[1], "PX", ARGV[3]) end
`)
luaRefresh = redis.NewScript(luaRefeshScript)
luaRelease = redis.NewScript(luaReleaseScript)
luaPTTL = redis.NewScript(luaPTTLScript)
luaObtain = redis.NewScript(luaObtainScript)
)

var (
Expand Down Expand Up @@ -54,8 +62,15 @@ func New(client RedisClient) *Client {
// Obtain tries to obtain a new lock using a key with the given TTL.
// May return ErrNotObtained if not successful.
func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt *Options) (*Lock, error) {
token := opt.getToken()
return c.ObtainMulti(ctx, []string{key}, ttl, opt)
}

// ObtainMulti tries to obtain new locks using keys with the given TTL.
// If any of requested key are already locked, no additional keys are
// locked and ErrNotObtained is returned.
// May return ErrNotObtained if not successful.
func (c *Client) ObtainMulti(ctx context.Context, keys []string, ttl time.Duration, opt *Options) (*Lock, error) {
token := opt.getToken()
// Create a random token
if token == "" {
var err error
Expand All @@ -77,11 +92,11 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt

var ticker *time.Ticker
for {
ok, err := c.obtain(ctx, key, value, len(token), ttlVal)
ok, err := c.obtain(ctx, keys, value, len(token), ttlVal)
if err != nil {
return nil, err
} else if ok {
return &Lock{Client: c, key: key, value: value, tokenLen: len(token)}, nil
return &Lock{Client: c, keys: keys, value: value, tokenLen: len(token)}, nil
}

backoff := retry.NextBackoff()
Expand All @@ -104,11 +119,12 @@ func (c *Client) Obtain(ctx context.Context, key string, ttl time.Duration, opt
}
}

func (c *Client) obtain(ctx context.Context, key, value string, tokenLen int, ttlVal string) (bool, error) {
_, err := luaObtain.Run(ctx, c.client, []string{key}, value, tokenLen, ttlVal).Result()
if err == redis.Nil {
return false, nil
} else if err != nil {
func (c *Client) obtain(ctx context.Context, keys []string, value string, tokenLen int, ttlVal string) (bool, error) {
_, err := luaObtain.Run(ctx, c.client, keys, value, tokenLen, ttlVal).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return false, nil
}
return false, err
}
return true, nil
Expand All @@ -133,7 +149,7 @@ func (c *Client) randomToken() (string, error) {
// Lock represents an obtained, distributed lock.
type Lock struct {
*Client
key string
keys []string
value string
tokenLen int
}
Expand All @@ -143,9 +159,20 @@ func Obtain(ctx context.Context, client RedisClient, key string, ttl time.Durati
return New(client).Obtain(ctx, key, ttl, opt)
}

// ObtainMulti is a short-cut for New(...).ObtainMulti(...).
func ObtainMulti(ctx context.Context, client RedisClient, keys []string, ttl time.Duration, opt *Options) (*Lock, error) {
return New(client).ObtainMulti(ctx, keys, ttl, opt)
}

// Key returns the redis key used by the lock.
// If the lock hold multiple key, only the first is returned.
func (l *Lock) Key() string {
return l.key
return l.keys[0]
}

// Keys returns the redis keys used by the lock.
func (l *Lock) Keys() []string {
return l.keys
}

// Token returns the token value set by the lock.
Expand All @@ -159,14 +186,18 @@ func (l *Lock) Metadata() string {
}

// TTL returns the remaining time-to-live. Returns 0 if the lock has expired.
// In case lock is holding multiple keys, TTL returns the min ttl among thoses keys.
func (l *Lock) TTL(ctx context.Context) (time.Duration, error) {
res, err := luaPTTL.Run(ctx, l.client, []string{l.key}, l.value).Result()
if err == redis.Nil {
return 0, nil
} else if err != nil {
if l == nil {
return 0, ErrLockNotHeld
}
res, err := luaPTTL.Run(ctx, l.client, l.keys, l.value).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return 0, nil
}
return 0, err
}

if num := res.(int64); num > 0 {
return time.Duration(num) * time.Millisecond, nil
}
Expand All @@ -176,14 +207,18 @@ func (l *Lock) TTL(ctx context.Context) (time.Duration, error) {
// Refresh extends the lock with a new TTL.
// May return ErrNotObtained if refresh is unsuccessful.
func (l *Lock) Refresh(ctx context.Context, ttl time.Duration, opt *Options) error {
if l == nil {
return ErrNotObtained
}
ttlVal := strconv.FormatInt(int64(ttl/time.Millisecond), 10)
status, err := luaRefresh.Run(ctx, l.client, []string{l.key}, l.value, ttlVal).Result()
_, err := luaRefresh.Run(ctx, l.client, l.keys, l.value, ttlVal).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrNotObtained
}
return err
} else if status == int64(1) {
return nil
}
return ErrNotObtained
return nil
}

// Release manually releases the lock.
Expand All @@ -192,17 +227,13 @@ func (l *Lock) Release(ctx context.Context) error {
if l == nil {
return ErrLockNotHeld
}

res, err := luaRelease.Run(ctx, l.client, []string{l.key}, l.value).Result()
if err == redis.Nil {
return ErrLockNotHeld
} else if err != nil {
_, err := luaRelease.Run(ctx, l.client, l.keys, l.value).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrLockNotHeld
}
return err
}

if i, ok := res.(int64); !ok || i != 1 {
return ErrLockNotHeld
}
return nil
}

Expand Down
Loading

0 comments on commit cd93b2b

Please sign in to comment.