From 16ec811a2f5f2e35151a023bb2bd7d364ec21d1b Mon Sep 17 00:00:00 2001 From: "Guillaume J. Charmes" Date: Thu, 8 Aug 2013 16:34:23 -0700 Subject: [PATCH] Make the code more idiomatic --- auto.go | 12 +++++------ error.go | 23 ++++++++++++++++++++++ reply.go | 58 ++++++++++++++++++++++-------------------------------- request.go | 34 ++++++++++++++++---------------- 4 files changed, 69 insertions(+), 58 deletions(-) create mode 100644 error.go diff --git a/auto.go b/auto.go index 41fa41e..8a40375 100644 --- a/auto.go +++ b/auto.go @@ -6,6 +6,8 @@ import ( "reflect" ) +type CheckerFn func(request *Request) (reflect.Value, ReplyWriter) + type AutoHandler interface { GET(key string) ([]byte, error) SET(key string, value []byte) error @@ -49,8 +51,7 @@ func createHandlerFn(autoHandler AutoHandler, method *reflect.Method) (HandlerFn return nil, errors.New("Too many return values") } if t := mtype.Out(mtype.NumOut() - 1); t != errorType { - return nil, errors.New( - fmt.Sprintf("Last return value must be an error (not %s)", t)) + return nil, fmt.Errorf("Last return value must be an error (not %s)", t) } return handlerFn(autoHandler, method, checkers) @@ -104,7 +105,7 @@ func createReply(val interface{}) (ReplyWriter, error) { case *ChannelWriter: return val, nil default: - return nil, errors.New(fmt.Sprintf("Unsupported type: %s", val)) + return nil, fmt.Errorf("Unsupported type: %s", val) } } @@ -127,15 +128,12 @@ func createCheckers(method *reflect.Method) ([]CheckerFn, error) { case reflect.TypeOf(1): checkers = append(checkers, intChecker(i-1)) default: - return nil, errors.New( - fmt.Sprintf("Argument %d: wrong type %s", i, mtype.In(i))) + return nil, fmt.Errorf("Argument %d: wrong type %s", i, mtype.In(i)) } } return checkers, nil } -type CheckerFn func(request *Request) (reflect.Value, ReplyWriter) - func stringChecker(index int) CheckerFn { return func(request *Request) (reflect.Value, ReplyWriter) { v, err := request.GetString(index) diff --git a/error.go b/error.go new file mode 100644 index 0000000..d8d5a3d --- /dev/null +++ b/error.go @@ -0,0 +1,23 @@ +package redis + +import ( + "io" +) + +type ErrorReply struct { + code string + message string +} + +func (er *ErrorReply) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write([]byte("-" + er.code + " " + er.message + "\r\n")) + return int64(n), err +} + +func (er *ErrorReply) Error() string { + return "-" + er.code + " " + er.message + "\r\n" +} + +func NewError(message string) *ErrorReply { + return &ErrorReply{code: "ERROR", message: message} +} diff --git a/reply.go b/reply.go index 1adcb2b..1e44cbd 100644 --- a/reply.go +++ b/reply.go @@ -7,59 +7,51 @@ import ( "strconv" ) +type ReplyWriter io.WriterTo + type StatusReply struct { code string } -type ReplyWriter interface { - WriteTo(w io.Writer) (int, error) -} - -func (r *StatusReply) WriteTo(w io.Writer) (int, error) { +func (r *StatusReply) WriteTo(w io.Writer) (int64, error) { Debugf("Status") - return w.Write([]byte("+" + r.code + "\r\n")) -} - -type ErrorReply struct { - code string - message string -} - -func (r *ErrorReply) WriteTo(w io.Writer) (int, error) { - return w.Write([]byte("-" + r.code + " " + r.message + "\r\n")) + n, err := w.Write([]byte("+" + r.code + "\r\n")) + return int64(n), err } type IntegerReply struct { number int } -func (r *IntegerReply) WriteTo(w io.Writer) (int, error) { - return w.Write([]byte(":" + strconv.Itoa(r.number) + "\r\n")) +func (r *IntegerReply) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write([]byte(":" + strconv.Itoa(r.number) + "\r\n")) + return int64(n), err } type BulkReply struct { value []byte } -func writeBytes(value []byte, w io.Writer) (int, error) { +func writeBytes(value []byte, w io.Writer) (int64, error) { //it's a NullBulkReply if value == nil { - return w.Write([]byte("$-1\r\n")) + n, err := w.Write([]byte("$-1\r\n")) + return int64(n), err } wrote, err := w.Write([]byte("$" + strconv.Itoa(len(value)) + "\r\n")) if err != nil { - return wrote, err + return int64(wrote), err } wroteBytes, err := w.Write(value) if err != nil { - return wrote + wroteBytes, err + return int64(wrote + wroteBytes), err } wroteCrLf, err := w.Write([]byte("\r\n")) - return wrote + wroteBytes + wroteCrLf, err + return int64(wrote + wroteBytes + wroteCrLf), err } -func (r *BulkReply) WriteTo(w io.Writer) (int, error) { +func (r *BulkReply) WriteTo(w io.Writer) (int64, error) { return writeBytes(r.value, w) } @@ -79,38 +71,36 @@ func MultiBulkFromMap(m *map[string][]byte) *MultiBulkReply { return &MultiBulkReply{values: values} } -func writeMultiBytes(values [][]byte, w io.Writer) (int, error) { +func writeMultiBytes(values [][]byte, w io.Writer) (int64, error) { if values == nil { return 0, errors.New("Nil in multi bulk replies are not ok") } wrote, err := w.Write([]byte("*" + strconv.Itoa(len(values)) + "\r\n")) if err != nil { - return wrote, err + return int64(wrote), err } + wrote64 := int64(wrote) for _, v := range values { wroteBytes, err := writeBytes(v, w) if err != nil { - return wrote + wroteBytes, err + return wrote64 + wroteBytes, err } - wrote += wroteBytes + wrote64 += wroteBytes } - return wrote, err + return wrote64, err } -func (r *MultiBulkReply) WriteTo(w io.Writer) (int, error) { +func (r *MultiBulkReply) WriteTo(w io.Writer) (int64, error) { return writeMultiBytes(r.values, w) } -func NewError(message string) *ErrorReply { - return &ErrorReply{code: "ERROR", message: message} -} - func methodNotSupported() ReplyWriter { return NewError("Method is not supported") } func ReplyToString(r ReplyWriter) (string, error) { var b bytes.Buffer + _, err := r.WriteTo(&b) if err != nil { return "ERROR!", err @@ -123,7 +113,7 @@ type ChannelWriter struct { Channel chan [][]byte } -func (c *ChannelWriter) WriteTo(w io.Writer) (int, error) { +func (c *ChannelWriter) WriteTo(w io.Writer) (int64, error) { totalBytes, err := writeMultiBytes(c.FirstReply, w) if err != nil { return totalBytes, err diff --git a/request.go b/request.go index ccc484d..6d03e3f 100644 --- a/request.go +++ b/request.go @@ -9,37 +9,37 @@ type Request struct { args [][]byte } -func (request *Request) HasArgument(index int) bool { - return len(request.args) >= index+1 +func (r *Request) HasArgument(index int) bool { + return len(r.args) >= index+1 } -func (request *Request) ExpectArgument(index int) ReplyWriter { - if !request.HasArgument(index) { +func (r *Request) ExpectArgument(index int) ReplyWriter { + if !r.HasArgument(index) { return NewError("Not enough arguments") } return nil } -func (request *Request) GetString(index int) (string, ReplyWriter) { - if reply := request.ExpectArgument(index); reply != nil { +func (r *Request) GetString(index int) (string, ReplyWriter) { + if reply := r.ExpectArgument(index); reply != nil { return "", reply } - return string(request.args[index]), nil + return string(r.args[index]), nil } -func (request *Request) GetInteger(index int) (int, ReplyWriter) { - if reply := request.ExpectArgument(index); reply != nil { +func (r *Request) GetInteger(index int) (int, ReplyWriter) { + if reply := r.ExpectArgument(index); reply != nil { return -1, reply } - i, err := strconv.Atoi(string(request.args[index])) + i, err := strconv.Atoi(string(r.args[index])) if err != nil { return -1, NewError("Expected integer") } return i, nil } -func (request *Request) GetPositiveInteger(index int) (int, ReplyWriter) { - i, reply := request.GetInteger(index) +func (r *Request) GetPositiveInteger(index int) (int, ReplyWriter) { + i, reply := r.GetInteger(index) if reply != nil { return -1, reply } @@ -49,8 +49,8 @@ func (request *Request) GetPositiveInteger(index int) (int, ReplyWriter) { return i, nil } -func (request *Request) GetMap(index int) (*map[string][]byte, ReplyWriter) { - count := len(request.args) - index +func (r *Request) GetMap(index int) (*map[string][]byte, ReplyWriter) { + count := len(r.args) - index if count <= 0 { return nil, NewError("Expected at least one key val pair") } @@ -58,12 +58,12 @@ func (request *Request) GetMap(index int) (*map[string][]byte, ReplyWriter) { return nil, NewError("Got uneven number of key val pairs") } values := make(map[string][]byte) - for i := index; i < len(request.args); i += 2 { - key, reply := request.GetString(i) + for i := index; i < len(r.args); i += 2 { + key, reply := r.GetString(i) if reply != nil { return nil, reply } - values[key] = request.args[i+1] + values[key] = r.args[i+1] } return &values, nil }