diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c56069f --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +*.test \ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..ca1f5e2 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,7 @@ +language: go +go: 1.1 +install: + - export PATH=$PATH:$HOME/gopath/bin + - go get -v github.com/mailgun/go-redis-server +script: + - go test github.com/mailgun/go-redis-server diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..7776e70 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +coverage: + gocov test | gocov report +annotate: + FILENAME=$(shell uuidgen) + gocov test > /tmp/--go-test-server-coverage.json + gocov annotate /tmp/--go-test-server-coverage.json $(fn) diff --git a/README.md b/README.md index 4b775c1..b4325ac 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ -# go-redis-server: an implementation of the redis server protocol +[![Build Status](https://travis-ci.org/mailgun/go-redis-server.png)](https://travis-ci.org/mailgun/go-redis-server) + +Redis server protocol library +============================= There are plenty of good client implementations of the redis protocol, but not many *server* implementations. @@ -6,4 +9,31 @@ go-redis-server is a helper library for building server software capable of spea an alternate implementation of redis, a custom proxy to redis, or even a completely different backend capable of "masquerading" its API as a redis database. + +Sample code +------------ + +```go +type MyHandler struct { + values map[string][]byte +} + +func (h *MyHandler) GET(key string) ([]byte, error) { + v, _ := h.values[key] + return v, nil +} + +func (h *MyHandler) SET(key string, value []byte) error { + h.values[key] = value + return nil +} + +func main() { + handler, _ := redis.NewAutoHandler(&MyHandler{}) + server := &redis.Server{Handler: handler, Addr: ":6389"} + server.ListenAndServe() +} + +``` + Copyright (c) dotCloud 2013 diff --git a/auto.go b/auto.go new file mode 100644 index 0000000..6dc6b92 --- /dev/null +++ b/auto.go @@ -0,0 +1,179 @@ +package redis + +import ( + "errors" + "fmt" + "reflect" +) + +type AutoHandler interface { + GET(key string) ([]byte, error) + SET(key string, value []byte) error + HMSET(key string, values *map[string][]byte) error + HGETALL(key string) (*map[string][]byte, error) + HGET(hash string, key string) ([]byte, error) + BRPOP(key string, params ...[]byte) ([][]byte, error) + SUBSCRIBE(channel string, channels ...[]byte) (*ChannelWriter, error) +} + +func NewAutoHandler(autoHandler AutoHandler) (*Handler, error) { + handler := &Handler{} + + rh := reflect.TypeOf(autoHandler) + for i := 0; i < rh.NumMethod(); i += 1 { + method := rh.Method(i) + handlerFn, err := createHandlerFn(autoHandler, &method) + if err != nil { + return nil, err + } + handler.Register(method.Name, handlerFn) + } + return handler, nil +} + +func createHandlerFn(autoHandler AutoHandler, method *reflect.Method) (HandlerFn, error) { + errorType := reflect.TypeOf(createHandlerFn).Out(1) + mtype := method.Func.Type() + checkers, err := createCheckers(method) + if err != nil { + return nil, err + } + + // Check output + if mtype.NumOut() == 0 { + return nil, errors.New("Not enough return values") + } + if mtype.NumOut() > 2 { + return nil, errors.New("Too many return values") + } + if t := mtype.Out(mtype.NumOut() - 1); t != errorType { + return nil, errors.New( + fmt.Sprintf("Last return value must be an error (not %s)", t)) + } + + return handlerFn(autoHandler, method, checkers) +} + +func handlerFn(autoHandler AutoHandler, method *reflect.Method, checkers []CheckerFn) (HandlerFn, error) { + return func(request *Request) (ReplyWriter, error) { + input := []reflect.Value{reflect.ValueOf(autoHandler)} + for _, checker := range checkers { + value, reply := checker(request) + if reply != nil { + return reply, nil + } + input = append(input, value) + } + + var result []reflect.Value + if method.Func.Type().IsVariadic() { + result = method.Func.CallSlice(input) + } else { + result = method.Func.Call(input) + } + + var ret interface{} + if ierr := result[len(result)-1].Interface(); ierr != nil { + // Last return value is an error, wrap it to redis error + err := ierr.(error) + // convert to redis error reply + return NewError(err.Error()), nil + } + if len(result) > 1 { + ret = result[0].Interface() + return createReply(ret) + } + return &StatusReply{code: "OK"}, nil + }, nil +} + +func createReply(val interface{}) (ReplyWriter, error) { + switch val := val.(type) { + case [][]byte: + return &MultiBulkReply{values: val}, nil + case string: + return &BulkReply{value: []byte(val)}, nil + case []byte: + return &BulkReply{value: val}, nil + case *map[string][]byte: + return MultiBulkFromMap(val), nil + case int: + return &IntegerReply{number: val}, nil + case *ChannelWriter: + return val, nil + default: + return nil, errors.New(fmt.Sprintf("Unsupported type: %s", val)) + } +} + +func createCheckers(method *reflect.Method) ([]CheckerFn, error) { + checkers := []CheckerFn{} + mtype := method.Func.Type() + for i := 1; i < mtype.NumIn(); i += 1 { + switch mtype.In(i) { + case reflect.TypeOf(""): + checkers = append(checkers, stringChecker(i-1)) + case reflect.TypeOf([]byte{}): + checkers = append(checkers, byteChecker(i-1)) + case reflect.TypeOf([][]byte{}): + checkers = append(checkers, byteSliceChecker(i-1)) + case reflect.TypeOf(&map[string][]byte{}): + if i != mtype.NumIn()-1 { + return nil, errors.New("Map should be the last argument") + } + checkers = append(checkers, mapChecker(i-1)) + case reflect.TypeOf(1): + checkers = append(checkers, intChecker(i-1)) + default: + return nil, errors.New( + fmt.Sprintf("Argument %d: wrong type %s", i, mtype.In(i))) + } + } + return checkers, nil +} + +type CheckerFn func(request *Request) (reflect.Value, ReplyWriter) + +func stringChecker(index int) CheckerFn { + return func(request *Request) (reflect.Value, ReplyWriter) { + v, err := request.GetString(index) + if err != nil { + return reflect.ValueOf(""), err + } + return reflect.ValueOf(v), nil + } +} + +func byteChecker(index int) CheckerFn { + return func(request *Request) (reflect.Value, ReplyWriter) { + err := request.ExpectArgument(index) + if err != nil { + return reflect.ValueOf([]byte{}), err + } + return reflect.ValueOf(request.args[index]), nil + } +} + +func byteSliceChecker(index int) CheckerFn { + return func(request *Request) (reflect.Value, ReplyWriter) { + if !request.HasArgument(index) { + return reflect.ValueOf([][]byte{}), nil + } else { + return reflect.ValueOf(request.args[index:]), nil + } + } +} + +func mapChecker(index int) CheckerFn { + return func(request *Request) (reflect.Value, ReplyWriter) { + m, err := request.GetMap(index) + return reflect.ValueOf(m), err + } +} + +func intChecker(index int) CheckerFn { + return func(request *Request) (reflect.Value, ReplyWriter) { + m, err := request.GetInteger(index) + return reflect.ValueOf(m), err + } +} diff --git a/auto_test.go b/auto_test.go new file mode 100644 index 0000000..5314efc --- /dev/null +++ b/auto_test.go @@ -0,0 +1,227 @@ +package redis + +import ( + "testing" +) + +type Hash struct { + values map[string][]byte +} + +type TestHandler struct { + values map[string][]byte + hashValues map[string]Hash +} + +func NewHandler() *TestHandler { + return &TestHandler{ + values: make(map[string][]byte), + hashValues: make(map[string]Hash), + } +} + +func (h *TestHandler) GET(key string) ([]byte, error) { + v, _ := h.values[key] + return v, nil +} + +func (h *TestHandler) SET(key string, value []byte) error { + h.values[key] = value + return nil +} + +func (h *TestHandler) HMSET(key string, values *map[string][]byte) error { + _, exists := h.hashValues[key] + if !exists { + h.hashValues[key] = Hash{values: make(map[string][]byte)} + } + hash := h.hashValues[key] + for name, val := range *values { + hash.values[name] = val + } + return nil +} + +func (h *TestHandler) HGET(hash string, key string) ([]byte, error) { + hs, exists := h.hashValues[hash] + if !exists { + return nil, nil + } + val, _ := hs.values[key] + return val, nil +} + +func (h *TestHandler) HGETALL(hash string) (*map[string][]byte, error) { + hs, exists := h.hashValues[hash] + if !exists { + return nil, nil + } + return &hs.values, nil +} + +func (h *TestHandler) BRPOP(key string, params ...[]byte) ([][]byte, error) { + params = append(params, []byte(key)) + return params, nil +} + +func (h *TestHandler) SUBSCRIBE(channel string, channels ...[]byte) (*ChannelWriter, error) { + output := make(chan [][]byte) + writer := &ChannelWriter{ + FirstReply: [][]byte{ + []byte("subscribe"), + []byte(channel), + []byte("1"), + }, + Channel: output, + } + go func() { + output <- [][]byte{ + []byte("message"), + []byte(channel), + []byte("yo"), + } + close(output) + }() + return writer, nil +} + +func TestAutoHandler(t *testing.T) { + h, err := NewAutoHandler(NewHandler()) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + expected := []struct { + request *Request + expected []string + }{ + { + request: &Request{ + name: "GET", + args: [][]byte{[]byte("key")}, + }, + expected: []string{"$-1\r\n"}, + }, + { + request: &Request{ + name: "SET", + args: [][]byte{ + []byte("key"), + []byte("value"), + }, + }, + expected: []string{"+OK\r\n"}, + }, + { + request: &Request{ + name: "GET", + args: [][]byte{[]byte("key")}, + }, + expected: []string{"$5\r\nvalue\r\n"}, + }, + { + request: &Request{ + name: "HGET", + args: [][]byte{ + []byte("key"), + []byte("prop1"), + }, + }, + expected: []string{"$-1\r\n"}, + }, + { + request: &Request{ + name: "HMSET", + args: [][]byte{ + []byte("key"), + []byte("prop1"), + []byte("value1"), + []byte("prop2"), + []byte("value2"), + }, + }, + expected: []string{"+OK\r\n"}, + }, + { + request: &Request{ + name: "HGET", + args: [][]byte{ + []byte("key"), + []byte("prop1"), + }, + }, + expected: []string{"$6\r\nvalue1\r\n"}, + }, + { + request: &Request{ + name: "HGET", + args: [][]byte{ + []byte("key"), + []byte("prop2"), + }, + }, + expected: []string{"$6\r\nvalue2\r\n"}, + }, + { + request: &Request{ + name: "HGETALL", + args: [][]byte{ + []byte("key"), + }, + }, + expected: []string{ + "*4\r\n$5\r\nprop1\r\n$6\r\nvalue1\r\n$5\r\nprop2\r\n$6\r\nvalue2\r\n", + "*4\r\n$5\r\nprop2\r\n$6\r\nvalue2\r\n$5\r\nprop1\r\n$6\r\nvalue1\r\n", + }, + }, + { + request: &Request{ + name: "BRPOP", + args: [][]byte{ + []byte("key"), + }, + }, + expected: []string{ + "*1\r\n$3\r\nkey\r\n", + }, + }, + { + request: &Request{ + name: "BRPOP", + args: [][]byte{ + []byte("key1"), + []byte("key2"), + }, + }, + expected: []string{ + "*2\r\n$4\r\nkey2\r\n$4\r\nkey1\r\n", + }, + }, + { + request: &Request{ + name: "SUBSCRIBE", + args: [][]byte{ + []byte("foo"), + }, + }, + expected: []string{ + "*3\r\n$9\r\nsubscribe\r\n$3\r\nfoo\r\n$1\r\n1\r\n*3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$2\r\nyo\r\n", + }, + }, + } + for _, v := range expected { + reply, err := ApplyString(h, v.request) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + match := false + for _, expected := range v.expected { + if reply == expected { + match = true + break + } + } + if match == false { + t.Fatalf("Eexpected one of %q, got: %q", v.expected, reply) + } + } +} diff --git a/debug.go b/debug.go new file mode 100644 index 0000000..7fd252f --- /dev/null +++ b/debug.go @@ -0,0 +1,27 @@ +package redis + +import ( + "fmt" + "os" + "runtime" + "strings" +) + +// Debug function, if the debug flag is set, then display. Do nothing otherwise +// If Docker is in damon mode, also send the debug info on the socket +// Convenience debug function, courtesy of http://github.com/dotcloud/docker +func Debugf(format string, a ...interface{}) { + if os.Getenv("DEBUG") != "" { + + // Retrieve the stack infos + _, file, line, ok := runtime.Caller(1) + if !ok { + file = "" + line = -1 + } else { + file = file[strings.LastIndex(file, "/")+1:] + } + + fmt.Fprintf(os.Stderr, fmt.Sprintf("[%d] [debug] %s:%d %s\n", os.Getpid(), file, line, format), a...) + } +} diff --git a/handler.go b/handler.go new file mode 100644 index 0000000..a705031 --- /dev/null +++ b/handler.go @@ -0,0 +1,34 @@ +package redis + +import ( + "strings" +) + +type HandlerFn func(r *Request) (ReplyWriter, error) + +type Handler struct { + methods map[string]HandlerFn +} + +func Apply(h *Handler, r *Request) (ReplyWriter, error) { + fn, exists := h.methods[strings.ToLower(r.name)] + if !exists { + return methodNotSupported(), nil + } + return fn(r) +} + +func ApplyString(h *Handler, r *Request) (string, error) { + reply, err := Apply(h, r) + if err != nil { + return "", err + } + return ReplyToString(reply) +} + +func (h *Handler) Register(name string, fn HandlerFn) { + if h.methods == nil { + h.methods = make(map[string]HandlerFn) + } + h.methods[strings.ToLower(name)] = fn +} diff --git a/handler_test.go b/handler_test.go new file mode 100644 index 0000000..36d9284 --- /dev/null +++ b/handler_test.go @@ -0,0 +1,31 @@ +package redis + +import ( + "strings" + "testing" +) + +func TestEmptyHandler(t *testing.T) { + reply, err := ApplyString(&Handler{}, &Request{}) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if !strings.Contains(reply, "-ERROR") { + t.Fatalf("Eexpected error reply, got: %s", err) + } +} + +func TestCustomHandler(t *testing.T) { + h := &Handler{} + h.Register("GET", func(r *Request) (ReplyWriter, error) { + return &BulkReply{value: []byte("42")}, nil + }) + reply, err := ApplyString(h, &Request{name: "gEt"}) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + expected := "$2\r\n42\r\n" + if reply != expected { + t.Fatalf("Eexpected reply %q, got: %q", expected, reply) + } +} diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..88d1d14 --- /dev/null +++ b/parser.go @@ -0,0 +1,92 @@ +package redis + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "strings" +) + +func parseRequest(r *bufio.Reader) (*Request, error) { + // first line of redis request should be: + // *CRLF + line, err := r.ReadString('\n') + if err != nil { + return nil, err + } + + // note that this line also protects us from negative integers + var argsCount int + if _, err := fmt.Sscanf(line, "*%d\r", &argsCount); err != nil { + return nil, malformed("*", line) + } + + // All next lines are pairs of: + //$ CR LF + // CR LF + // first argument is a command name, so just convert + firstArg, err := readArgument(r) + if err != nil { + return nil, err + } + + args := make([][]byte, argsCount-1) + for i := 0; i < argsCount-1; i += 1 { + if args[i], err = readArgument(r); err != nil { + return nil, err + } + } + + return &Request{name: strings.ToLower(string(firstArg)), args: args}, nil +} + +func readArgument(r *bufio.Reader) ([]byte, error) { + + line, err := r.ReadString('\n') + if err != nil { + return nil, malformed("$", line) + } + var argSize int + if _, err := fmt.Sscanf(line, "$%d\r", &argSize); err != nil { + return nil, malformed("$", line) + } + + // I think int is safe here as the max length of request + // should be less then max int value? + data, err := ioutil.ReadAll(io.LimitReader(r, int64(argSize))) + if err != nil { + return nil, err + } + + if len(data) != argSize { + return nil, malformedLength(argSize, len(data)) + } + + // Now check for trailing CR + if b, err := r.ReadByte(); err != nil || b != '\r' { + return nil, malformedMissingCRLF() + } + + // And LF + if b, err := r.ReadByte(); err != nil || b != '\n' { + return nil, malformedMissingCRLF() + } + + return data, nil +} + +func malformed(expected string, got string) error { + return fmt.Errorf( + "Mailformed request:'%s does not match %s\\r\\n'", got, expected) +} + +func malformedLength(expected int, got int) error { + return fmt.Errorf( + "Mailformed request: argument length '%d does not match %d\\r\\n'", + got, expected) +} + +func malformedMissingCRLF() error { + return fmt.Errorf("Mailformed request: line should end with \\r\\n") +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..170126b --- /dev/null +++ b/parser_test.go @@ -0,0 +1,87 @@ +package redis + +import ( + "bufio" + "bytes" + "io/ioutil" + "testing" +) + +func TestReader(t *testing.T) { + v, err := ioutil.ReadAll(r("Hello!")) + if err != nil { + t.Fatalf("Should have read it actually") + } + if string(v) != "Hello!" { + t.Fatalf("Expected %s here, got %s", "Hello!", v) + } +} + +func TestParseBadRequests(t *testing.T) { + requests := []string{ + //empty lines, + "", "\r\n", "\r\r", "\n\n", + // malformed start line, + "hello\r\n", " \r\n", "*hello\r\n", "*-100\r\n", + //malformed arguments count + "*3\r\nhi", "*3\r\nhi\r\n", "*4\r\n$1", "*4\r\n$1\r", "*4\r\n$1\n", + //Corrupted argument size + "*2\r\n$3\r\ngEt\r\n$what?\r\nx\r\n", + //mismatched arguments count + "*4\r\n$3\r\ngEt\r\n$1\r\nx\r\n", + //missing trailing \r\n + "*2\r\n$3\r\ngEt\r\n$1\r\nx", + //missing trailing \r\n + "*2\r\n$3\r\ngEt\r\n$1\r\nx\r", + //lied about argument length \r\n + "*2\r\n$3\r\ngEt\r\n$100\r\nx\r\n", + } + for _, v := range requests { + _, err := parseRequest(r(v)) + if err == nil { + t.Fatalf("Expected eror %s", v) + } + } +} + +func TestSucess(t *testing.T) { + expected := []struct { + r Request + s string + }{ + {Request{name: "a"}, "*1\r\n$1\r\na\r\n"}, + {Request{name: "get"}, "*1\r\n$3\r\ngEt\r\n"}, + {Request{name: "get", args: b("x")}, "*2\r\n$3\r\ngEt\r\n$1\r\nx\r\n"}, + {Request{name: "set", args: b("mykey", "myvalue")}, "*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$7\r\nmyvalue\r\n"}, + } + + for _, p := range expected { + request, err := parseRequest(r(p.s)) + if err != nil { + t.Fatalf("Un xxpected eror %s when parsting", err, p.s) + } + if request.name != p.r.name { + t.Fatalf("Expected command %s, got %s", err, p.r.name, request.name) + } + if len(request.args) != len(p.r.args) { + t.Fatalf("Args length mismatch %s, got %s", err, p.r.args, request.args) + } + for i := 0; i < len(request.args); i += 1 { + if !bytes.Equal(request.args[i], p.r.args[i]) { + t.Fatalf("Expected args %s, got %s", err, p.r.args, request.args) + } + } + } +} + +func b(args ...string) [][]byte { + arr := make([][]byte, len(args)) + for i := 0; i < len(args); i += 1 { + arr[i] = []byte(args[i]) + } + return arr +} + +func r(request string) *bufio.Reader { + return bufio.NewReader(bytes.NewReader([]byte(request))) +} diff --git a/reply.go b/reply.go new file mode 100644 index 0000000..1adcb2b --- /dev/null +++ b/reply.go @@ -0,0 +1,149 @@ +package redis + +import ( + "bytes" + "errors" + "io" + "strconv" +) + +type StatusReply struct { + code string +} + +type ReplyWriter interface { + WriteTo(w io.Writer) (int, error) +} + +func (r *StatusReply) WriteTo(w io.Writer) (int, error) { + Debugf("Status") + return w.Write([]byte("+" + r.code + "\r\n")) +} + +type ErrorReply struct { + code string + message string +} + +func (r *ErrorReply) WriteTo(w io.Writer) (int, error) { + return w.Write([]byte("-" + r.code + " " + r.message + "\r\n")) +} + +type IntegerReply struct { + number int +} + +func (r *IntegerReply) WriteTo(w io.Writer) (int, error) { + return w.Write([]byte(":" + strconv.Itoa(r.number) + "\r\n")) +} + +type BulkReply struct { + value []byte +} + +func writeBytes(value []byte, w io.Writer) (int, error) { + //it's a NullBulkReply + if value == nil { + return w.Write([]byte("$-1\r\n")) + } + + wrote, err := w.Write([]byte("$" + strconv.Itoa(len(value)) + "\r\n")) + if err != nil { + return wrote, err + } + wroteBytes, err := w.Write(value) + if err != nil { + return wrote + wroteBytes, err + } + wroteCrLf, err := w.Write([]byte("\r\n")) + return wrote + wroteBytes + wroteCrLf, err +} + +func (r *BulkReply) WriteTo(w io.Writer) (int, error) { + return writeBytes(r.value, w) +} + +//for nil reply in multi bulk just set []byte as nil +type MultiBulkReply struct { + values [][]byte +} + +func MultiBulkFromMap(m *map[string][]byte) *MultiBulkReply { + values := make([][]byte, len(*m)*2) + i := 0 + for key, val := range *m { + values[i] = []byte(key) + values[i+1] = val + i += 2 + } + return &MultiBulkReply{values: values} +} + +func writeMultiBytes(values [][]byte, w io.Writer) (int, error) { + if values == nil { + return 0, errors.New("Nil in multi bulk replies are not ok") + } + wrote, err := w.Write([]byte("*" + strconv.Itoa(len(values)) + "\r\n")) + if err != nil { + return wrote, err + } + for _, v := range values { + wroteBytes, err := writeBytes(v, w) + if err != nil { + return wrote + wroteBytes, err + } + wrote += wroteBytes + } + return wrote, err +} + +func (r *MultiBulkReply) WriteTo(w io.Writer) (int, error) { + return writeMultiBytes(r.values, w) +} + +func NewError(message string) *ErrorReply { + return &ErrorReply{code: "ERROR", message: message} +} + +func methodNotSupported() ReplyWriter { + return NewError("Method is not supported") +} + +func ReplyToString(r ReplyWriter) (string, error) { + var b bytes.Buffer + _, err := r.WriteTo(&b) + if err != nil { + return "ERROR!", err + } + return b.String(), nil +} + +type ChannelWriter struct { + FirstReply [][]byte + Channel chan [][]byte +} + +func (c *ChannelWriter) WriteTo(w io.Writer) (int, error) { + totalBytes, err := writeMultiBytes(c.FirstReply, w) + if err != nil { + return totalBytes, err + } + + for { + select { + case reply := <-c.Channel: + if reply == nil { + return totalBytes, nil + } else { + wroteBytes, err := writeMultiBytes(reply, w) + // FIXME: obvious overflow here, + // Just ignore? Who cares? + totalBytes += wroteBytes + if err != nil { + return totalBytes, err + } + } + } + } + return totalBytes, nil +} diff --git a/reply_test.go b/reply_test.go new file mode 100644 index 0000000..25b60ac --- /dev/null +++ b/reply_test.go @@ -0,0 +1,37 @@ +package redis + +import ( + "bytes" + "testing" +) + +func TestWriteStatus(t *testing.T) { + replies := []struct { + reply ReplyWriter + expected string + }{ + {&StatusReply{code: "OK"}, "+OK\r\n"}, + {&IntegerReply{number: 42}, ":42\r\n"}, + {&ErrorReply{code: "ERROR", message: "Something went wrong"}, "-ERROR Something went wrong\r\n"}, + {&BulkReply{}, "$-1\r\n"}, + {&BulkReply{[]byte{'h', 'e', 'l', 'l', 'o'}}, "$5\r\nhello\r\n"}, + {&MultiBulkReply{[][]byte{[]byte{'h', 'e', 'l', 'l', 'o'}}}, "*1\r\n$5\r\nhello\r\n"}, + {&MultiBulkReply{[][]byte{[]byte{'h', 'e', 'l', 'l', 'o'}, []byte{'h', 'i'}}}, "*2\r\n$5\r\nhello\r\n$2\r\nhi\r\n"}, + {&MultiBulkReply{[][]byte{nil, []byte{'h', 'e', 'l', 'l', 'o'}, nil, []byte{'h', 'i'}}}, "*4\r\n$-1\r\n$5\r\nhello\r\n$-1\r\n$2\r\nhi\r\n"}, + {MultiBulkFromMap(&map[string][]byte{"hello": []byte("there"), "how": []byte("are you")}), "*4\r\n$5\r\nhello\r\n$5\r\nthere\r\n$3\r\nhow\r\n$7\r\nare you\r\n"}, + } + for _, p := range replies { + var b bytes.Buffer + n, err := p.reply.WriteTo(&b) + if err != nil { + t.Fatalf("Oops, unexpected %s", err) + } + val := b.String() + if val != p.expected { + t.Fatalf("Oops, expected %q, got %q instead", p.expected, val) + } + if n != len(p.expected) { + t.Fatalf("Expected to write %d bytes, wrote %d instead", len(p.expected), n) + } + } +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..ccc484d --- /dev/null +++ b/request.go @@ -0,0 +1,69 @@ +package redis + +import ( + "strconv" +) + +type Request struct { + name string + args [][]byte +} + +func (request *Request) HasArgument(index int) bool { + return len(request.args) >= index+1 +} + +func (request *Request) ExpectArgument(index int) ReplyWriter { + if !request.HasArgument(index) { + return NewError("Not enough arguments") + } + return nil +} + +func (request *Request) GetString(index int) (string, ReplyWriter) { + if reply := request.ExpectArgument(index); reply != nil { + return "", reply + } + return string(request.args[index]), nil +} + +func (request *Request) GetInteger(index int) (int, ReplyWriter) { + if reply := request.ExpectArgument(index); reply != nil { + return -1, reply + } + i, err := strconv.Atoi(string(request.args[index])) + if err != nil { + return -1, NewError("Expected integer") + } + return i, nil +} + +func (request *Request) GetPositiveInteger(index int) (int, ReplyWriter) { + i, reply := request.GetInteger(index) + if reply != nil { + return -1, reply + } + if i < 0 { + return -1, NewError("Expected positive integer") + } + return i, nil +} + +func (request *Request) GetMap(index int) (*map[string][]byte, ReplyWriter) { + count := len(request.args) - index + if count <= 0 { + return nil, NewError("Expected at least one key val pair") + } + if count%2 != 0 { + return nil, NewError("Got uneven number of key val pairs") + } + values := make(map[string][]byte) + for i := index; i < len(request.args); i += 2 { + key, reply := request.GetString(i) + if reply != nil { + return nil, reply + } + values[key] = request.args[i+1] + } + return &values, nil +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..bff4011 --- /dev/null +++ b/request_test.go @@ -0,0 +1,154 @@ +package redis + +import ( + "bytes" + "testing" +) + +func TestRequestExpectArgument(t *testing.T) { + r := &Request{name: "Hi", args: [][]byte{}} + for i := 0; i < 10; i += 1 { + reply := r.ExpectArgument(i) + if reply == nil { + t.Fatalf("Expected error reply, got nil") + } + } + r = &Request{name: "Hi", args: [][]byte{[]byte{'h', 'i'}}} + reply := r.ExpectArgument(0) + if reply != nil { + t.Fatalf("Expected nil reply, got %s", reply) + } + + reply = r.ExpectArgument(1) + if reply == nil { + t.Fatalf("Expected error reply, got nil") + } +} + +func TestRequestGetString(t *testing.T) { + s := "Hello, World!" + r := &Request{name: "Hi", args: [][]byte{[]byte(s)}} + val, reply := r.GetString(0) + if reply != nil { + t.Fatalf("Expected nil reply, got %s", reply) + } + if val != s { + t.Fatalf("Expected %s, got %s", s, val) + } + val, reply = r.GetString(5) + if reply == nil { + t.Fatalf("Expected reply, got nil") + } +} + +func TestRequestGetInteger(t *testing.T) { + invalid := []*Request{ + &Request{name: "Hi", args: [][]byte{}}, + &Request{name: "Hi", args: [][]byte{[]byte{'h', 'i'}}}, + } + for _, request := range invalid { + _, reply := request.GetInteger(0) + if reply == nil { + t.Fatalf("Expected error reply, got nil") + } + } + + valid := []struct { + request *Request + index int + number int + }{ + {&Request{name: "Hi", args: [][]byte{[]byte{'1'}}}, 0, 1}, + {&Request{name: "Hi", args: [][]byte{[]byte{'1'}, []byte("42")}}, 1, 42}, + {&Request{name: "Hi", args: [][]byte{[]byte{'1'}, []byte("-1043")}}, 1, -1043}, + } + for _, v := range valid { + number, reply := v.request.GetInteger(v.index) + if reply != nil { + t.Fatalf("Expected nil reply, got %s", reply) + } + if v.number != number { + t.Fatalf("Expected %d reply, got %d", number, v.number) + } + number, reply = v.request.GetPositiveInteger(v.index) + if v.number > 0 { + if reply != nil { + t.Fatalf("Expected nil reply, got %s", reply) + } + } else { + if reply == nil { + t.Fatalf("Expected error reply, got %s", reply) + } + } + } +} + +func TestRequestGetMap(t *testing.T) { + invalid := []struct { + request *Request + index int + }{ + {&Request{name: "Hi", args: [][]byte{}}, 0}, + {&Request{name: "Hi", args: [][]byte{}}, 100}, + {&Request{name: "Hi", args: [][]byte{[]byte{'h', 'i'}}}, 0}, + {&Request{name: "Hi", args: [][]byte{[]byte{'h', 'i'}, []byte{'h', 'i'}, []byte{'h', 'i'}}}, 0}, + } + for _, v := range invalid { + _, reply := v.request.GetMap(v.index) + if reply == nil { + t.Fatalf("Expected error reply, got nil for %s %d", v.request, v.index) + } + } + + valid := []struct { + request *Request + index int + expected map[string][]byte + }{ + { + request: &Request{ + name: "Hi", + args: [][]byte{ + []byte{'h', 'i'}, + []byte{'y', 'o'}, + }, + }, + index: 0, + expected: map[string][]byte{"hi": []byte("yo")}, + }, + { + request: &Request{ + name: "Hi", + args: [][]byte{ + []byte("hi"), + []byte("yo"), + []byte("key"), + []byte("value"), + }, + }, + index: 0, + expected: map[string][]byte{"hi": []byte("yo"), "key": []byte("value")}, + }, + } + for _, v := range valid { + m, reply := v.request.GetMap(v.index) + if reply != nil { + t.Fatalf("Expected nil reply, got %s for %s %d", reply, v.request, v.index) + } + if !mapsEqual(*m, v.expected) { + t.Fatalf("Expected %s got %s for %s", v.expected, v.request, m) + } + } +} + +func mapsEqual(a map[string][]byte, b map[string][]byte) bool { + if len(a) != len(b) { + return false + } + for key, val := range a { + if !bytes.Equal(b[key], val) { + return false + } + } + return true +} diff --git a/server.go b/server.go index b908c43..717e757 100644 --- a/server.go +++ b/server.go @@ -1,38 +1,47 @@ package redis import ( - "os" - "errors" - "fmt" - "reflect" "bufio" + "fmt" "io" - "io/ioutil" - "runtime" - "strings" + "net" ) -type Handler interface { - GET(key string) (*string, error) - SET(key, value string) error -} - -type DummyHandler struct { +type Server struct { + Addr string // TCP address to listen on, ":6389" if empty + Handler *Handler // handler to invoke } -func (h *DummyHandler) GET(key string) (*string, error) { - result := "42" - return &result, nil +func (srv *Server) ListenAndServe() error { + addr := srv.Addr + if addr == "" { + addr = ":6389" + } + l, e := net.Listen("tcp", addr) + if e != nil { + return e + } + return srv.Serve(l) } -func (h *DummyHandler) SET(key, value string) error { - return nil +// Serve accepts incoming connections on the Listener l, creating a +// new service goroutine for each. The service goroutines read requests and +// then call srv.Handler to reply to them. +func (srv *Server) Serve(l net.Listener) error { + defer l.Close() + for { + rw, err := l.Accept() + if err != nil { + return err + } + go Serve(rw, srv.Handler) + } } // Serve starts a new redis session, using `conn` as a transport. // It reads commands using the redis protocol, passes them to `handler`, // and returns the result. -func Serve(conn io.ReadWriteCloser, handler Handler) (err error) { +func Serve(conn io.ReadWriteCloser, handler *Handler) (err error) { defer func() { if err != nil { fmt.Fprintf(conn, "-%s\n", err) @@ -41,184 +50,18 @@ func Serve(conn io.ReadWriteCloser, handler Handler) (err error) { }() reader := bufio.NewReader(conn) for { - // FIXME: commit the current container before each command - Debugf("Reading command...") - var nArg int - line, err := reader.ReadString('\r') + request, err := parseRequest(reader) if err != nil { return err } - Debugf("line == '%s'", line) - if len(line) < 1 || line[len(line) - 1] != '\r' { - return fmt.Errorf("Malformed request: doesn't start with '*\\r\\n'. %s", err) - } - line = line[:len(line) - 1] - if _, err := fmt.Sscanf(line, "*%d", &nArg); err != nil { - return fmt.Errorf("Malformed request: '%s' doesn't start with '*'. %s", line, err) - } - Debugf("nArg = %d", nArg) - nl := make([]byte, 1) - if _, err := reader.Read(nl); err != nil { + reply, err := Apply(handler, request) + if err != nil { return err - } else if nl[0] != '\n' { - return fmt.Errorf("Malformed request: expected '%x', got '%x'", '\n', nl[0]) } - var ( - opName string - opArgs []string - ) - for i:=0; i\\r\\n'. %s", err) - } - line = line[:len(line) - 1] - if _, err := fmt.Sscanf(line, "$%d", &argSize); err != nil { - return fmt.Errorf("Malformed request: '%s' doesn't start with '$'. %s", line, err) - } - Debugf("argSize= %d", argSize) - nl := make([]byte, 1) - if _, err := reader.Read(nl); err != nil { - return err - } else if nl[0] != '\n' { - return fmt.Errorf("Malformed request: expected '%x', got '%x'", '\n', nl[0]) - } - - - // Read arg data - argData, err := ioutil.ReadAll(io.LimitReader(reader, argSize + 2)) - if err != nil { - return err - } else if n := int64(len(argData)); n < argSize + 2 { - return fmt.Errorf("Malformed request: argument data #%d doesn't match declared size (expected %d bytes (%d + \r\n), read %d)", i, argSize + 2, argSize, n) - } else if string(argData[len(argData) - 2:]) != "\r\n" { - return fmt.Errorf("Malformed request: argument #%d doesn't end with \\r\\n", i) - } - arg := string(argData[:len(argData) - 2]) - Debugf("arg = %s", arg) - if i == 0 { - opName = strings.ToLower(arg) - } else { - opArgs = append(opArgs, arg) - } - } - result, err := Apply(handler, opName, opArgs...) + _, err = reply.WriteTo(conn) if err != nil { return err } - fmt.Fprintf(conn, "+%s\n", result) } return nil } - -// ApplyString calls `Apply` and returns the result as a string pointer. -func ApplyString(handler Handler, cmd string, args ... string) (*string, error) { - IResult, err := Apply(handler, cmd, args...) - if err != nil { - return nil, err - } - if IResult == nil { - return nil, err - } - result, isString := IResult.(*string) - if !isString { - return nil, fmt.Errorf("Result is not a string") - } - return result, nil -} - -// Apply parses and executes a redis command -func Apply(handler Handler, cmd string, args ... string) (interface{}, error) { - method, exists := reflect.TypeOf(handler).MethodByName(strings.ToUpper(cmd)) - if !exists { - return nil, errors.New(fmt.Sprintf("%s: no such command", cmd)) - } - Debugf("Method = %v", method) - if err := checkMethodSignature(&method, len(args)); err != nil { - return nil, err - } - input := []reflect.Value{reflect.ValueOf(handler)} - var result []reflect.Value - mType := method.Func.Type() - if mType.IsVariadic() { - for i:=0; i 1 { - ret = result[0].Interface() - } - return ret, err -} - -func checkMethodSignature(method *reflect.Method, nArgs int) error { - errorType := reflect.TypeOf(checkMethodSignature).Out(0) - mtype := method.Func.Type() - // Check input - if mtype.IsVariadic() && mtype.In(mtype.NumIn() - 1) != reflect.TypeOf([]string{}) { - return errors.New("Variadic argument is not []string") - } - if nArgs < mtype.NumIn() - 1 { - return errors.New("Not enough arguments") - } - if nArgs > mtype.NumIn() - 1 { - return errors.New("Too many arguments") - } - for i:=1; i 2 { - return errors.New("Too many return values") - } - if t := mtype.Out(mtype.NumOut() - 1); t != errorType { - return errors.New(fmt.Sprintf("Last return value must be an error (not %s)", t)) - } - return nil -} - -// Debug function, if the debug flag is set, then display. Do nothing otherwise -// If Docker is in damon mode, also send the debug info on the socket -// Convenience debug function, courtesy of http://github.com/dotcloud/docker -func Debugf(format string, a ...interface{}) { - if os.Getenv("DEBUG") != "" { - - // Retrieve the stack infos - _, file, line, ok := runtime.Caller(1) - if !ok { - file = "" - line = -1 - } else { - file = file[strings.LastIndex(file, "/")+1:] - } - - fmt.Fprintf(os.Stderr, fmt.Sprintf("[%d] [debug] %s:%d %s\n", os.Getpid(), file, line, format), a...) - } -} - diff --git a/server_test.go b/server_test.go index cda57d7..de4d524 100644 --- a/server_test.go +++ b/server_test.go @@ -4,26 +4,5 @@ import ( "testing" ) -type TestHandler struct { - DummyHandler -} - -func (h *TestHandler) GET(key string) (*string, error) { - res := "you asked for the key " + key - return &res, nil -} - -func TestApplyGET(t *testing.T) { - t_input := "foo" - t_output := "you asked for the key " + t_input - result, err := ApplyString(new(TestHandler), "GET", t_input) - if err != nil { - t.Fatal(err) - } - if result == nil { - t.Fatalf("Expected '%s', got nil string", t_output) - } - if *result != t_output { - t.Fatal("Expected '%s', got '%s'", t_output, *result) - } +func TestServer(t *testing.T) { }