forked from smallnest/rpcx
/
gateway.go
125 lines (105 loc) · 3.06 KB
/
gateway.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package server
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/julienschmidt/httprouter"
"github.com/smallnest/rpcx/log"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/share"
"github.com/soheilhy/cmux"
)
// startGateway multiplexes ln so that HTTP/1 connections are served by the
// HTTP API gateway (in a background goroutine) and every other connection is
// delivered through the returned listener for normal rpcx handling.
//
// The gateway is only started for the "tcp", "tcp4" and "tcp6" networks; for
// any other network a message is logged and ln is returned unchanged.
func (s *Server) startGateway(network string, ln net.Listener) net.Listener {
	switch network {
	case "tcp", "tcp4", "tcp6":
		// supported: fall through to the multiplexer setup below
	default:
		log.Infof("network is not tcp/tcp4/tcp6 so can not start gateway")
		return ln
	}

	mux := cmux.New(ln)
	// The HTTP/1 matcher is registered before the catch-all so HTTP traffic
	// is claimed first and only the remainder reaches the rpcx listener.
	httpListener := mux.Match(cmux.HTTP1Fast())
	rpcxListener := mux.Match(cmux.Any())

	go s.startHTTP1APIGateway(httpListener)
	go mux.Serve()

	return rpcxListener
}
// startHTTP1APIGateway serves the HTTP/1 gateway on ln, blocking until the
// server stops. POST, GET and PUT requests on any path are routed to
// handleGatewayRequest, with the path captured as the "servicePath" parameter.
// The error that stops the server is logged.
func (s *Server) startHTTP1APIGateway(ln net.Listener) {
	router := httprouter.New()

	registrars := []func(string, httprouter.Handle){router.POST, router.GET, router.PUT}
	for _, register := range registrars {
		register("/*servicePath", s.handleGatewayRequest)
	}

	err := http.Serve(ln, router)
	if err != nil {
		log.Errorf("error in gateway Serve: %s", err)
	}
}
// handleGatewayRequest converts an HTTP request into an rpcx request, runs it
// through the server, and writes the rpcx response back over HTTP.
//
// The service path comes from the XServicePath header when the client set it,
// otherwise from the "servicePath" route parameter with its leading "/"
// removed. The XVersion, XMessageID, XServicePath, XServiceMethod and
// XSerializeType values are echoed on the response headers.
//
// Outcomes:
//   - the request cannot be converted: 400, error details in the
//     XMessageStatusType / XErrorMessage headers, plus every X-RPCX-* request
//     header echoed back;
//   - s.auth rejects the request: 401, error details in headers;
//   - s.handleRequest fails: 500, error details in headers;
//   - otherwise: the response payload is the body, and the response metadata
//     (including metadata stored in the context under share.ResMetaDataKey)
//     is URL-encoded into the XMeta header.
func (s *Server) handleGatewayRequest(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
	if r.Header.Get(XServicePath) == "" {
		r.Header.Set(XServicePath, strings.TrimPrefix(params.ByName("servicePath"), "/"))
	}
	servicePath := r.Header.Get(XServicePath)
	wh := w.Header()

	req, err := HTTPRequest2RpcxRequest(r)
	// Only return a real message to the pool; the converter may hand back
	// nil together with an error (NOTE(review): confirm against converter).
	if req != nil {
		defer protocol.FreeMsg(req)
	}

	// Echo the identifying request headers so the caller can correlate the response.
	wh.Set(XVersion, r.Header.Get(XVersion))
	wh.Set(XMessageID, r.Header.Get(XMessageID))
	wh.Set(XServicePath, servicePath)
	wh.Set(XServiceMethod, r.Header.Get(XServiceMethod))
	wh.Set(XSerializeType, r.Header.Get(XSerializeType))

	if err != nil {
		copyGatewayRpcxHeaders(wh, r.Header)
		wh.Set(XMessageStatusType, "Error")
		wh.Set(XErrorMessage, err.Error())
		// Without an explicit status the failure would be sent as 200 OK.
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	ctx := context.WithValue(context.Background(), StartRequestContextKey, time.Now().UnixNano())
	if err = s.auth(ctx, req); err != nil {
		s.Plugins.DoPreWriteResponse(ctx, req)
		wh.Set(XMessageStatusType, "Error")
		wh.Set(XErrorMessage, err.Error())
		w.WriteHeader(http.StatusUnauthorized)
		s.Plugins.DoPostWriteResponse(ctx, req, req.Clone(), err)
		return
	}

	// Handlers may add response metadata through the context; it is merged
	// into the response below.
	resMetadata := make(map[string]string)
	newCtx := context.WithValue(context.WithValue(ctx, share.ReqMetaDataKey, req.Metadata),
		share.ResMetaDataKey, resMetadata)

	res, err := s.handleRequest(newCtx, req)
	if res != nil {
		defer protocol.FreeMsg(res)
	}
	if err != nil {
		log.Warnf("rpcx: failed to handle gateway request: %v", err)
		wh.Set(XMessageStatusType, "Error")
		wh.Set(XErrorMessage, err.Error())
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	s.Plugins.DoPreWriteResponse(newCtx, req)

	// Copy metadata set through the context into the response, overriding
	// any entries with the same key.
	if len(resMetadata) > 0 {
		if res.Metadata == nil {
			res.Metadata = resMetadata
		} else {
			for k, v := range resMetadata {
				res.Metadata[k] = v
			}
		}
	}

	meta := make(url.Values, len(res.Metadata))
	for k, v := range res.Metadata {
		meta.Add(k, v)
	}
	wh.Set(XMeta, meta.Encode())

	// Report the outcome of writing the body (nil on success) to plugins
	// instead of silently discarding it.
	_, err = w.Write(res.Payload)
	s.Plugins.DoPostWriteResponse(newCtx, req, res, err)
}

// copyGatewayRpcxHeaders copies the first value of every header in src whose
// key starts with "X-RPCX-" (compared case-insensitively) into dst. The match
// must ignore case because net/http canonicalizes incoming header keys, so
// "X-RPCX-Meta" arrives as "X-Rpcx-Meta".
func copyGatewayRpcxHeaders(dst, src http.Header) {
	const prefix = "X-RPCX-"
	for k, v := range src {
		if len(v) == 0 || len(k) < len(prefix) {
			continue
		}
		if strings.EqualFold(k[:len(prefix)], prefix) {
			dst.Set(k, v[0])
		}
	}
}