-
Notifications
You must be signed in to change notification settings - Fork 204
/
grpcproxy.go
87 lines (75 loc) · 2.44 KB
/
grpcproxy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package grpcproxy
import (
"context"
"net/http"
"github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/lavanet/lava/utils"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// ProxyCallBack handles a single proxied gRPC call.
//
// It receives the call's context, the full method name with its leading '/'
// removed (for example "pkg.Service/Method"), and the raw request payload.
// It returns the raw response payload and header metadata to attach to the
// reply. When err is non-nil, the handler returns it to the gRPC client
// instead of sending a response.
type ProxyCallBack = func(ctx context.Context, method string, reqBody []byte) (respBody []byte, header metadata.MD, err error)
// NewGRPCProxy builds a gRPC server that forwards every call, regardless of
// service, to cb using the raw-bytes codec. It also builds an HTTP server
// that serves that gRPC server through a grpc-web wrapper behind an h2c
// handler.
//
// Every HTTP response carries permissive CORS headers. A GET request whose
// path equals healthCheckPath is answered with an empty 200 response instead
// of being routed to gRPC.
//
// The returned error is currently always nil.
//
// NOTE(review): the returned http.Server has no ReadHeaderTimeout (or other
// timeouts) configured; callers exposing it publicly should set one to avoid
// slow-header (Slowloris) exhaustion (gosec G112).
func NewGRPCProxy(cb ProxyCallBack, healthCheckPath string) (*grpc.Server, *http.Server, error) {
	grpcServer := grpc.NewServer(
		grpc.UnknownServiceHandler(makeProxyFunc(cb)),
		grpc.ForceServerCodec(RawBytesCodec{}),
	)
	webWrapped := grpcweb.WrapServer(grpcServer)

	isHealthProbe := func(r *http.Request) bool {
		return r.Method == http.MethodGet && r.URL.Path == healthCheckPath
	}

	serveHTTP := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// CORS headers go on every response, health probes included.
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type,x-grpc-web")

		if !isHealthProbe(r) {
			webWrapped.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte{})
	})

	httpServer := &http.Server{
		// h2c intercepts cleartext HTTP/2 traffic (as used by native gRPC
		// clients) and passes everything else to serveHTTP.
		Handler: h2c.NewHandler(serveHTTP, &http2.Server{}),
	}
	return grpcServer, httpServer, nil
}
// makeProxyFunc adapts callBack into a grpc.StreamHandler for use with
// grpc.UnknownServiceHandler.
//
// Every call is treated as unary. The handler:
//   - reads exactly one request message as raw bytes;
//   - passes the message to callBack with the method name (leading '/'
//     removed, e.g. "pkg.Service/Method") and the stream's context;
//   - writes back callBack's header metadata and response bytes.
//
// Errors from receiving, from callBack, or from sending are returned as-is.
// callBack only sees request headers through stream.Context(); they are not
// passed to it explicitly.
func makeProxyFunc(callBack ProxyCallBack) grpc.StreamHandler {
	return func(srv interface{}, stream grpc.ServerStream) error {
		methodName, ok := grpc.MethodFromServerStream(stream)
		if !ok {
			return status.Error(codes.Unavailable, "unable to get method name")
		}
		// Full method names normally look like "/pkg.Service/Method". Strip the
		// leading '/' only when it is actually present. Blind slicing would
		// panic on an empty name and drop a real character otherwise.
		if len(methodName) > 0 && methodName[0] == '/' {
			methodName = methodName[1:]
		}

		var reqBytes []byte
		if err := stream.RecvMsg(&reqBytes); err != nil {
			return err
		}

		respBytes, md, err := callBack(stream.Context(), methodName, reqBytes)
		if err != nil {
			return err
		}

		// Headers are best-effort. If SetHeader rejects the metadata (gRPC may
		// refuse invalid keys), the response is still sent without those
		// headers rather than failing the whole call.
		_ = stream.SetHeader(md)
		return stream.SendMsg(respBytes)
	}
}
// RawBytesCodec is a pass-through gRPC codec: it performs no serialization.
// Outgoing messages must be []byte and are sent as-is (Marshal). Incoming
// messages are stored into a *[]byte (Unmarshal). NewGRPCProxy forces it on
// its gRPC server so the proxy handler works with opaque request and
// response bytes.
type RawBytesCodec struct{}
// Marshal returns v unchanged when it is a []byte, so payloads leave the
// proxy exactly as the callback produced them. Any other type is rejected
// with an error built by utils.LavaFormatError.
func (RawBytesCodec) Marshal(v interface{}) ([]byte, error) {
	if raw, isRaw := v.([]byte); isRaw {
		return raw, nil
	}
	return nil, utils.LavaFormatError("cannot encode type", nil, utils.Attribute{Key: "v", Value: v})
}
// Unmarshal stores a private copy of the raw wire bytes into v.
//
// v must be a non-nil *[]byte. Any other destination, including a nil
// pointer, is rejected with an error built by utils.LavaFormatError.
//
// The bytes are copied rather than aliased because gRPC does not promise
// that data stays valid once Unmarshal returns: receive buffers may be
// pooled and reused. The proxy handler keeps using the decoded request
// after RecvMsg returns, when it invokes its callback.
func (RawBytesCodec) Unmarshal(data []byte, v interface{}) error {
	bufferPtr, ok := v.(*[]byte)
	if !ok || bufferPtr == nil {
		return utils.LavaFormatError("cannot decode into type", nil, utils.Attribute{Key: "v", Value: v})
	}
	buf := make([]byte, len(data))
	copy(buf, data)
	*bufferPtr = buf
	return nil
}
// Name returns the identifier of this codec, as required by the gRPC codec
// interface that grpc.ForceServerCodec accepts.
func (RawBytesCodec) Name() string {
	const codecName = "lava/grpc-proxy-codec"
	return codecName
}
// String reports the same identifier as Name.
func (c RawBytesCodec) String() string {
	return c.Name()
}