diff --git a/memcache/memcache.go b/memcache/memcache.go index b2ebacd2..9bf694ec 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "io" + "math" "net" "strconv" "strings" @@ -541,17 +542,52 @@ func parseGetResponse(r *bufio.Reader, cb func(*Item)) error { // scanGetResponseLine populates it and returns the declared size of the item. // It does not read the bytes of the item. func scanGetResponseLine(line []byte, it *Item) (size int, err error) { - pattern := "VALUE %s %d %d %d\r\n" - dest := []interface{}{&it.Key, &it.Flags, &size, &it.CasID} - if bytes.Count(line, space) == 3 { - pattern = "VALUE %s %d %d\r\n" - dest = dest[:3] - } - n, err := fmt.Sscanf(string(line), pattern, dest...) - if err != nil || n != len(dest) { + errf := func(line []byte) (int, error) { return -1, fmt.Errorf("memcache: unexpected line in get response: %q", line) } - return size, nil + if !bytes.HasPrefix(line, []byte("VALUE ")) || !bytes.HasSuffix(line, []byte("\r\n")) { + return errf(line) + } + s := string(line[6 : len(line)-2]) + var rest string + var found bool + it.Key, rest, found = cut(s, ' ') + if !found { + return errf(line) + } + val, rest, found := cut(rest, ' ') + if !found { + return errf(line) + } + flags64, err := strconv.ParseUint(val, 10, 32) + if err != nil { + return errf(line) + } + it.Flags = uint32(flags64) + val, rest, found = cut(rest, ' ') + size64, err := strconv.ParseUint(val, 10, 32) + if err != nil { + return errf(line) + } + if size64 > math.MaxInt { // Can happen if int is 32-bit + return errf(line) + } + if !found { // final CAS ID is optional. + return int(size64), nil + } + it.CasID, err = strconv.ParseUint(rest, 10, 64) + if err != nil { + return errf(line) + } + return int(size64), nil +} + +// Similar to strings.Cut in Go 1.18, but sep can only be 1 byte. +func cut(s string, sep byte) (before, after string, found bool) { + if i := strings.IndexByte(s, sep); i >= 0 { + return s[:i], s[i+1:], true + } + return s, "", false } // Set writes the given item, unconditionally. diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 4b60a15d..a0fa746d 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -424,3 +424,81 @@ func BenchmarkOnItem(b *testing.B) { c.onItem(&item, dummyFn) } } + +func BenchmarkScanGetResponseLine(b *testing.B) { + line := []byte("VALUE foobar1234 0 4096 1234\r\n") + var it Item + for i := 0; i < b.N; i++ { + _, err := scanGetResponseLine(line, &it) + if err != nil { + b.Fatal(err) + } + } +} + +func TestScanGetResponseLine(t *testing.T) { + tests := []struct { + name string + line string + wantKey string + wantFlags uint32 + wantCasid uint64 + wantSize int + wantErr bool + }{ + {name: "blank", line: "", + wantErr: true}, + {name: "malformed1", line: "VALU foobar1234 1 4096\r\n", + wantErr: true}, + {name: "malformed2", line: "VALUEfoobar1234 1 4096\r\n", + wantErr: true}, + {name: "malformed3", line: "VALUE foobar1234 14096\r\n", + wantErr: true}, + {name: "malformed4", line: "VALUE foobar123414096\r\n", + wantErr: true}, + {name: "no-eol", line: "VALUE foobar1234 1 4096", + wantErr: true}, + {name: "basic", line: "VALUE foobar1234 1 4096\r\n", + wantKey: "foobar1234", wantFlags: 1, wantSize: 4096}, + {name: "casid", line: "VALUE foobar1234 1 4096 1234\r\n", + wantKey: "foobar1234", wantFlags: 1, wantSize: 4096, wantCasid: 1234}, + {name: "flags-max-uint32", line: "VALUE key 4294967295 1\r\n", + wantKey: "key", wantFlags: 4294967295, wantSize: 1}, + {name: "flags-overflow", line: "VALUE key 4294967296 1\r\n", + wantErr: true}, + {name: "size-max-uint32", line: "VALUE key 1 2147483647\r\n", + wantKey: "key", wantFlags: 1, wantSize: 2147483647}, + {name: "size-overflow", line: "VALUE key 1 4294967296\r\n", + wantErr: true}, + {name: "casid-overflow", line: "VALUE key 1 4096 18446744073709551616\r\n", + wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Item + gotSize, err := scanGetResponseLine([]byte(tt.line), &got) + if tt.wantErr { + if err == nil { + t.Errorf("scanGetResponseLine() should have returned error") + } + return + } + if err != nil { + t.Errorf("scanGetResponseLine() returned error %s", err) + return + } + if got.Key != tt.wantKey { + t.Errorf("key = %v, want %v", got.Key, tt.wantKey) + } + if got.Flags != tt.wantFlags { + t.Errorf("flags = %v, want %v", got.Flags, tt.wantFlags) + } + if got.CasID != tt.wantCasid { + t.Errorf("flags = %v, want %v", got.CasID, tt.wantCasid) + } + if gotSize != tt.wantSize { + t.Errorf("size = %v, want %v", gotSize, tt.wantSize) + } + }) + } +}