/
h2rpc.go
201 lines (173 loc) · 5.44 KB
/
h2rpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
package handler
import (
"fmt"
"io/ioutil"
"mime"
"net/http"
"strings"
"time"
"github.com/facebookgo/stack"
h2error "github.com/HailoOSS/api-proxy/errors"
"github.com/HailoOSS/api-proxy/session"
"github.com/HailoOSS/api-proxy/stats"
"github.com/HailoOSS/api-proxy/trace"
"github.com/HailoOSS/platform/client"
"github.com/HailoOSS/platform/errors"
"github.com/HailoOSS/service/auth"
inst "github.com/HailoOSS/service/instrumentation"
)
const (
protoMime = "application/x-protobuf"
formEncodedMime = "application/x-www-form-urlencoded"
defaultResponseMime = "application/json; charset=utf-8"
)
type caller func(req *client.Request, options ...client.Options) (*client.Response, errors.Error)
var (
rpcCaller caller = client.CustomReq
authScopeConstructor func() auth.Scope = auth.New
)
// rpcHandler handles inbound HTTP requests for H2 (RPC)
func rpcHandler(rw http.ResponseWriter, r *http.Request) {
start := time.Now()
success := false
defer func() {
if !success {
inst.Timing(1.0, h2_failure, time.Since(start))
inst.Counter(1.0, h2_failure, 1)
inst.Timing(1.0, h2_azFailure, time.Since(start))
inst.Counter(1.0, h2_azFailure, 1)
} else {
inst.Timing(1.0, h2_success, time.Since(start))
inst.Counter(1.0, h2_success, 1)
inst.Timing(1.0, h2_azSuccess, time.Since(start))
inst.Counter(1.0, h2_azSuccess, 1)
}
if r.URL.Path != "/" {
stats.Record("/", success, time.Since(start))
}
stats.Record(r.URL.Path, success, time.Since(start))
}()
traceInfo := trace.Start(r)
// sanity check basics - we should have a content type and requests shoudl be POSTed
if r.Method != "POST" {
err := &h2error.ApiError{
ErrorType: errors.ErrorBadRequest,
ErrorCode: "com.HailoOSS.api.rpc.postrequired",
ErrorDescription: "Requests to the RPC endpoint must be POST-ed",
ErrorContext: []string{"15"},
ErrorHttpCode: http.StatusMethodNotAllowed,
HttpHeaders: map[string]string{
"Allow": "POST",
},
ErrorMultiStack: stack.CallersMulti(0),
}
h2error.Write(rw, err, defaultResponseMime, traceInfo)
return
}
ct, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if ct == "" || err != nil {
r.Header.Set("Content-Type", formEncodedMime)
}
// decide how to respond
responseContentType := defaultResponseMime
if r.Header.Get("Content-Type") == protoMime {
responseContentType = protoMime
}
// ---
service, request, perr := httpToH2Request(r)
if perr != nil {
h2error.Write(rw, perr, responseContentType, traceInfo)
return
}
// test auth -- blanket block on com.HailoOSS.kernel.*
if perr := authorisedFor(r, service); perr != nil {
h2error.Write(rw, perr, responseContentType, traceInfo)
return
}
// add scope of session and/or trace
if traceInfo.TraceId != "" {
request.SetTraceID(traceInfo.TraceId)
request.SetTraceShouldPersist(traceInfo.PersistentTrace)
}
request.SetSessionID(session.SessionId(r))
request.SetFrom("com.HailoOSS.hailo-2-api")
request.SetRemoteAddr(r.RemoteAddr)
rsp, perr := rpcCaller(request)
if perr != nil {
h2error.Write(rw, perr, responseContentType, traceInfo)
return
}
// add any trace details to output
trace.Write(rw, traceInfo)
rw.Header().Set("Content-Type", responseContentType)
rw.WriteHeader(200)
rw.Write(rsp.Body())
success = true
}
// httpToH2Request looks at the HTTP headers to determine what content type we are fed,
// and then constructs an appropriate H2 request
func httpToH2Request(r *http.Request) (service string, req *client.Request, perr errors.Error) {
// extract request bytes, service name, endpoint name
var (
endpoint string
reqBytes []byte
)
ct := r.Header.Get("Content-Type")
switch ct {
case protoMime: // raw bytes
reqBytes, _ = ioutil.ReadAll(r.Body)
service = r.URL.Query().Get("service")
endpoint = r.URL.Query().Get("endpoint")
default: // assume JSON is posted as a form param
if err := r.ParseForm(); err != nil {
perr = errors.BadRequest("com.HailoOSS.api.rpc.parseform", "Cannot parse form data.", "15")
return
}
reqBytes = []byte(r.PostForm.Get("request"))
if len(reqBytes) == 0 {
reqBytes = []byte(`{}`)
}
service = r.PostForm.Get("service")
endpoint = r.PostForm.Get("endpoint")
}
if service == "" {
perr = errors.BadRequest("com.HailoOSS.api.rpc.missingservice", "Missing 'service' parameter.", "15")
return
}
if endpoint == "" {
perr = errors.BadRequest("com.HailoOSS.api.rpc.missingendpoint", "Missing 'endpoint' parameter.", "15")
return
}
// mint client request now
var reqErr error
switch ct {
case protoMime: // raw bytes
req, reqErr = client.NewProtoRequest(service, endpoint, reqBytes)
default: // assume JSON is posted as a form param
req, reqErr = client.NewJsonRequest(service, endpoint, reqBytes)
}
if reqErr != nil {
perr = errors.BadRequest("com.HailoOSS.api.rpc.badrequest", fmt.Sprintf("%v", reqErr))
return
}
return
}
// authorisedFor checks if we are authorised to hit this service
func authorisedFor(r *http.Request, service string) errors.Error {
// only bother if trying to hit kernel
if !strings.HasPrefix(service, "com.HailoOSS.kernel.") {
return nil
}
sessId := r.URL.Query().Get("session_id")
if sessId == "" {
sessId = r.Form.Get("session_id")
}
scope := auth.New()
if sessId != "" {
scope.RecoverSession(sessId)
}
if scope.IsAuth() && scope.AuthUser().HasRole("ADMIN") {
return nil
}
return errors.Forbidden("com.HailoOSS.api.rpc.auth", "Permission denied.", "5")
}