-
Notifications
You must be signed in to change notification settings - Fork 9
/
handler.go
119 lines (103 loc) · 2.96 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package httpjson
import (
"context"
"encoding/json"
"errors"
"net/http"
"reflect"
)
// ErrorWriter writes err to the response w.
// ctx is the context of the request being served.
type ErrorWriter func(ctx context.Context, w http.ResponseWriter, err error)
// DefaultResponse is the JSON body written on success when the wrapped
// function returns nothing, or returns only an error that is nil.
var DefaultResponse json.RawMessage = []byte(`{"message":"ok"}`)
// handler is an http.Handler that calls a function for each request.
// It uses the signature of the function to decide how to interpret the
// request: whether to pass the request context, and which type (if any)
// to decode the request body into. The signature also determines how the
// function's results are turned into a response body or an error.
type handler struct {
	fv      reflect.Value // the wrapped function
	inType  reflect.Type  // type of the decoded-body parameter; nil if the function takes none
	hasCtx  bool          // whether the first parameter receives the request context
	errFunc ErrorWriter   // writes body-decoding errors and errors returned by fv
}
// Handler wraps the function f in an http.Handler.
// The package documentation lists the signatures f may have; an error is
// returned when f's signature is not supported. Whenever decoding a
// request body fails or f returns a non-nil error, the resulting handler
// reports it through errFunc.
func Handler(f interface{}, errFunc ErrorWriter) (http.Handler, error) {
	fn := reflect.ValueOf(f)
	hasCtx, inType, err := funcInputType(fn)
	if err != nil {
		return nil, err
	}
	return &handler{
		fv:      fn,
		inType:  inType,
		hasCtx:  hasCtx,
		errFunc: errFunc,
	}, nil
}
// ServeHTTP decodes the request body (when the wrapped function takes an
// input parameter), calls the wrapped function, and writes its result or
// its error to w.
//
// When the function takes a context, the context it receives is derived
// from req.Context() and additionally carries req and w under reqKey and
// respKey. Errors from decoding the body, and non-nil errors returned by
// the function, are passed to h.errFunc. Otherwise the result is written
// with status 200. A function with no results, or whose only result is an
// error, gets DefaultResponse as its body.
func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	var args []reflect.Value
	if h.hasCtx {
		ctx := req.Context()
		ctx = context.WithValue(ctx, reqKey, req)
		ctx = context.WithValue(ctx, respKey, w)
		args = append(args, reflect.ValueOf(ctx))
	}
	if h.inType != nil {
		inPtr := reflect.New(h.inType)
		if err := Read(req.Body, inPtr.Interface()); err != nil {
			h.errFunc(req.Context(), w, err)
			return
		}
		args = append(args, inPtr.Elem())
	}

	rv := h.fv.Call(args)
	ft := h.fv.Type()

	var (
		res interface{}
		err error
	)
	switch n := len(rv); {
	case n == 0:
		res = &DefaultResponse
	case n == 1 && !ft.Out(0).Implements(errorType):
		res = rv[0].Interface()
	case n == 1:
		// The only result is an error; its value can still be nil.
		res = &DefaultResponse
		err = resultError(rv[0])
	case n == 2:
		res = rv[0].Interface()
		err = resultError(rv[1])
	}
	if err != nil {
		h.errFunc(req.Context(), w, err)
		return
	}
	Write(req.Context(), w, http.StatusOK, res)
}

// resultError converts a function result whose static type implements
// error into an error value.
//
// A nil value of any nillable kind yields nil. This covers a nil pointer
// of a concrete error type such as *MyError. Converting that pointer with
// Interface().(error) would produce a non-nil error interface holding a
// nil pointer, and a successful call would be reported as a failure.
func resultError(v reflect.Value) error {
	switch v.Kind() {
	case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
		if v.IsNil() {
			return nil
		}
	}
	err, _ := v.Interface().(error)
	return err
}
// errorType and contextType are the reflect.Types of the error and
// context.Context interfaces. They are used to classify a wrapped
// function's parameters and results.
var (
	errorType   = reflect.TypeOf((*error)(nil)).Elem()
	contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
)

// funcInputType checks that fv holds a function Handler can call, and
// reports how to call it.
//
// Accepted signatures have:
//   - an optional leading parameter of an interface type implementing
//     context.Context;
//   - at most one further parameter;
//   - at most two results, where the second result, if present, must
//     implement error.
//
// hasCtx reports whether the first parameter receives the request context.
// t is the type of the remaining parameter, or nil if there is none.
// A non-nil err describes why the signature was rejected.
func funcInputType(fv reflect.Value) (hasCtx bool, t reflect.Type, err error) {
	if !fv.IsValid() {
		// reflect.ValueOf(nil): calling fv.Type() would panic.
		return false, nil, errors.New("need nonvariadic func, got nil")
	}
	ft := fv.Type()
	if ft.Kind() != reflect.Func || ft.IsVariadic() {
		return false, nil, errors.New("need nonvariadic func in " + ft.String())
	}
	if fv.IsNil() {
		// A nil func value would panic on every Call.
		return false, nil, errors.New("need non-nil func in " + ft.String())
	}

	off := 0 // number of leading context parameters: 0 or 1
	if ft.NumIn() >= 1 && ft.In(0).Implements(contextType) {
		// The context handed to the function comes from context.WithValue,
		// whose concrete type is unexported. Only an interface-typed
		// parameter can receive it; a concrete type that merely implements
		// context.Context (e.g. a struct embedding it) would make every
		// call panic, so reject it up front.
		if ft.In(0).Kind() != reflect.Interface {
			return false, nil, errors.New("context param must be an interface type in " + ft.String())
		}
		hasCtx = true
		off = 1
	}

	switch n := ft.NumIn(); {
	case n > off+1:
		return false, nil, errors.New("too many params in " + ft.String())
	case n == off+1:
		t = ft.In(n - 1)
	}

	switch n := ft.NumOut(); {
	case n > 2:
		return false, nil, errors.New("need at most two return values in " + ft.String())
	case n == 2 && !ft.Out(1).Implements(errorType):
		return false, nil, errors.New("second return value must be error in " + ft.String())
	}
	return hasCtx, t, nil
}