Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,3 +507,65 @@ func TestE2E_StatsItems(t *testing.T) {
must.Positive(t, data[0].Number)
must.Positive(t, data[0].MemRequested)
}

func TestE2E_CAS(t *testing.T) {
t.Parallel()

address, done := memctest.LaunchTCP(t, nil)
t.Cleanup(done)

c := New([]string{address})
defer ignore.Close(c)

t.Run("success", func(t *testing.T) {
err := Set(c, "key1", "value1")
must.NoError(t, err)

v, cas, verr := Gets[string](c, "key1")
must.NoError(t, verr)
must.Eq(t, "value1", v)
must.Positive(t, uint64(cas))

err = CompareAndSwap(c, "key1", cas, "value1.updated")
must.NoError(t, err)

v, err = Get[string](c, "key1")
must.NoError(t, err)
must.Eq(t, "value1.updated", v)
})

t.Run("conflict", func(t *testing.T) {
err := Set(c, "key2", "original")
must.NoError(t, err)

_, cas1, verr := Gets[string](c, "key2")
must.NoError(t, verr)

_, _, verr = Gets[string](c, "key2")
must.NoError(t, verr)

err = CompareAndSwap(c, "key2", cas1, "first-update")
must.NoError(t, err)

err = CompareAndSwap(c, "key2", cas1, "stale-update")
must.ErrorIs(t, err, ErrConflict)

v, err := Get[string](c, "key2")
must.NoError(t, err)
must.Eq(t, "first-update", v)
})

t.Run("not found", func(t *testing.T) {
err := Set(c, "key3", "value3")
must.NoError(t, err)

_, cas, verr := Gets[string](c, "key3")
must.NoError(t, verr)

err = Delete(c, "key3")
must.NoError(t, err)

err = CompareAndSwap(c, "key3", cas, "newvalue")
must.ErrorIs(t, err, ErrNotFound)
})
}
8 changes: 4 additions & 4 deletions iopool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ type Buffer struct {

func newBuffer(conn Connection) *Buffer {
return &Buffer{
bufio.NewReader(conn),
bufio.NewWriter(conn),
conn,
new(atomic.Bool),
Reader: bufio.NewReader(conn),
Writer: bufio.NewWriter(conn),
Closer: conn,
failure: new(atomic.Bool),
}
}

Expand Down
176 changes: 175 additions & 1 deletion verbs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ var (
ErrCommandIssue = errors.New("memc: got command error response")
)

// CAS represents a Compare-And-Swap token used for optimistic locking.
// It is returned by Gets and must be provided to CompareAndSwap to atomically update a value.
type CAS uint64

// Options contains configuration parameters that may be applied when executing
// a verb like Get, Set, etc.
type Options struct {
Expand Down Expand Up @@ -437,6 +441,88 @@ func Add[T any](c *Client, key string, item T, opts ...Option) error {
})
}

// CompareAndSwap will store the item using the given key, but only if the CAS
// token matches the current value's CAS token. This provides atomic
// compare-and-swap functionality for optimistic locking.
//
// If the key does not exist, ErrNotFound is returned.
//
// If the CAS token does not match (meaning the value was modified since it was
// retrieved with Gets), ErrConflict is returned.
//
// Uses Client c to connect to a memcached instance, and automatically handles
// connection pooling and reuse.
//
// One or more Option(s) may be applied to configure things such as the value
// expiration TTL or its associated flags.
func CompareAndSwap[T any](c *Client, key string, cas CAS, item T, opts ...Option) error {
if err := check(key); err != nil {
return err
}

options := &Options{
expiration: c.expiration,
flags: 0,
}

for _, opt := range opts {
opt(options)
}

return c.do(key, func(conn *iopool.Buffer) error {
encoding, encerr := encode(item)
if encerr != nil {
return encerr
}

expiration, experr := c.seconds(options.expiration)
if experr != nil {
return experr
}

// write the header components with CAS token
if _, err := fmt.Fprintf(
conn,
"cas %s %d %d %d %d\r\n",
key, options.flags, expiration, len(encoding), cas,
); err != nil {
return err
}

// write the payload
if _, err := conn.Write(encoding); err != nil {
return err
}

// write clrf
if _, err := io.WriteString(conn, "\r\n"); err != nil {
return err
}

// flush the buffer
if err := conn.Flush(); err != nil {
return err
}

// read response
line, lerr := conn.ReadSlice('\n')
if lerr != nil {
return lerr
}

switch string(line) {
case "STORED\r\n":
return nil
case "NOT_FOUND\r\n":
return ErrNotFound
case "EXISTS\r\n":
return ErrConflict
default:
return fmt.Errorf("memc: unexpected response to cas: %q", string(line))
}
})
}

// Get the value associated with the given key.
//
// Uses Client c to connect to a memcached instance, and automatically handles
Expand Down Expand Up @@ -472,6 +558,51 @@ func Get[T any](c *Client, key string) (T, error) {
return result, err
}

// Gets the value associated with the given key, along with its CAS token.
//
// The CAS token can be used with CompareAndSwap to atomically update the value,
// providing optimistic locking. If the value has been modified since it was
// retrieved, CompareAndSwap will return an ErrConflict error.
//
// Uses Client c to connect to a memcached instance, and automatically handles
// connection pooling and reuse.
func Gets[T any](c *Client, key string) (T, CAS, error) {
var result T
var casToken CAS

if err := check(key); err != nil {
return result, 0, err
}

err := c.do(key, func(conn *iopool.Buffer) error {
// write the header components
if _, err := fmt.Fprintf(conn, "gets %s\r\n", key); err != nil {
return err
}

// flush the connection, forcing bytes over the wire
if err := conn.Flush(); err != nil {
return err
}

// read the response payload with CAS token
payload, cas, err := getPayloadWithCAS(conn.Reader)
if err != nil {
return err
}

result, err = decode[T](payload)
if err != nil {
return err
}

casToken = CAS(cas)
return nil
})

return result, casToken, err
}

func getPayload(r *bufio.Reader) ([]byte, error) {
b, err := r.ReadSlice('\n')
if err != nil {
Expand All @@ -483,7 +614,6 @@ func getPayload(r *bufio.Reader) ([]byte, error) {
return nil, ErrCacheMiss
}

// TODO: does not handle CAS value for now
expect := "VALUE %s %d %d\r\n"
var (
key string
Expand Down Expand Up @@ -515,6 +645,50 @@ func getPayload(r *bufio.Reader) ([]byte, error) {
return payload, err
}

func getPayloadWithCAS(r *bufio.Reader) ([]byte, uint64, error) {
b, err := r.ReadSlice('\n')
if err != nil {
return nil, 0, err
}

// key was not found, is a cache miss
if string(b) == "END\r\n" {
return nil, 0, ErrCacheMiss
}

// handle CAS value - format is "VALUE key flags bytes cas\r\n"
expect := "VALUE %s %d %d %d\r\n"
var (
key string
flags int
size int
cas uint64
)

// scan the header line, giving us a payload size and CAS token
if _, err = fmt.Sscanf(string(b), expect, &key, &flags, &size, &cas); err != nil {
return nil, 0, err
}

// read the data into our payload
payload := make([]byte, size+2) // including \r\n
if _, err = io.ReadFull(r, payload); err != nil {
return nil, 0, err
}
payload = payload[0:size] // chop \r\n

// read the trailing line ("END\r\n")
b, err = r.ReadSlice('\n')
if err != nil {
return nil, 0, err
}
if string(b) != "END\r\n" {
return nil, 0, unexpected(b)
}

return payload, cas, nil
}

// Delete will remove the value associated with key from memcached.
//
// Uses Client c to connect to a memcached instance, and automatically handles
Expand Down
Loading