-
Notifications
You must be signed in to change notification settings - Fork 12
/
stream.go
170 lines (148 loc) · 4.67 KB
/
stream.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package httpapi
import (
"context"
"net/http"
"reflect"
"github.com/golang/protobuf/jsonpb"
"github.com/golang/protobuf/proto"
"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
"github.com/kelda/blimp/pkg/errors"
)
// StreamHandler defines how gRPC streams should be invoked.
//
// Each websocket connection served by the handler returned from Handler
// carries exactly one request (the client's first message, JSON-encoded),
// followed by any number of responses sent back by RPC.
type StreamHandler struct {
	// RequestType identifies the request message type. Handler rejects a nil
	// value and a non-pointer concrete type; only its dynamic type is used —
	// a fresh zero value of the pointed-to type is allocated for every stream
	// and the client's request is unmarshalled into it.
	RequestType proto.Message
	// RPC implements the stream. Messages it passes to the stream's
	// SendProtoMessage are relayed to the client. A non-nil error it returns
	// is sent to the client (unless the client has already disconnected)
	// before a websocket close message.
	RPC StreamRPC
}

// StreamRPC is the signature of a streaming RPC served over a websocket: it
// receives the decoded request and a stream on which to send responses.
type StreamRPC func(proto.Message, WebSocketStream) error
// WebSocketStream provides the methods of grpc.ServerStream so that typed
// stream shims do not have to define them. A shim's typed Send method should
// pass each message to SendProtoMessage, which relays it toward the websocket.
// Headers, trailers, and the raw SendMsg/RecvMsg APIs are not supported.
type WebSocketStream struct {
	// messages is the send side of the channel drained by the goroutine that
	// writes responses to the websocket.
	messages chan<- proto.Message
	// ctx bounds the stream: once it is done, SendProtoMessage fails.
	ctx context.Context
}

// SendProtoMessage should be called by a shim's typed Send method. It blocks
// until msg is handed off on the messages channel, returning nil, or until
// the stream's context is done, returning an error. If both are possible at
// once, either outcome may be chosen.
func (ws WebSocketStream) SendProtoMessage(msg proto.Message) error {
	select {
	case ws.messages <- msg:
		return nil
	case <-ws.ctx.Done():
		return errors.New("connection closed")
	}
}

// SetHeader is not supported; it always returns an error.
func (ws WebSocketStream) SetHeader(metadata.MD) error {
	return errors.New("unimplemented")
}

// SendHeader is not supported; it always returns an error.
func (ws WebSocketStream) SendHeader(metadata.MD) error {
	return errors.New("unimplemented")
}

// SetTrailer is a no-op: trailer metadata is silently discarded.
func (ws WebSocketStream) SetTrailer(metadata.MD) {}

// Context returns the context that bounds this stream.
func (ws WebSocketStream) Context() context.Context {
	return ws.ctx
}

// SendMsg is not supported; it always returns an error. Shims should call
// SendProtoMessage instead.
func (ws WebSocketStream) SendMsg(interface{}) error {
	return errors.New("unimplemented")
}

// RecvMsg is not supported; it always returns an error.
func (ws WebSocketStream) RecvMsg(interface{}) error {
	return errors.New("unimplemented")
}
// Handler returns an http.HandlerFunc that serves the stream over a
// websocket. For each request it upgrades the connection, reads the client's
// first message as the JSON-encoded request, and runs the RPC, relaying the
// RPC's messages to the client. If the RPC fails while the client is still
// connected, the error is sent to the client followed by a close message.
//
// It returns an error if RequestType is nil or its concrete type is not a
// pointer.
func (handler StreamHandler) Handler() (http.HandlerFunc, error) {
	switch {
	case handler.RequestType == nil:
		return nil, errors.New("RequestType must be set")
	case reflect.TypeOf(handler.RequestType).Kind() != reflect.Ptr:
		return nil, errors.New("RequestType's concrete type must be a pointer")
	}

	serve := func(w http.ResponseWriter, r *http.Request) {
		// The default Upgrader only accepts same-origin requests; this
		// endpoint deliberately accepts any origin.
		upgrader := &websocket.Upgrader{
			CheckOrigin: func(*http.Request) bool { return true },
		}
		wsConn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			log.WithError(err).Warn("Failed to upgrade connection")
			return
		}
		defer wsConn.Close()

		// The first message from the client is always the protobuf request.
		_, firstMessage, err := wsConn.ReadMessage()
		if err != nil {
			log.WithError(err).Warn("Failed to read client stream request")
			return
		}

		streamCtx, stopStream := context.WithCancel(context.Background())
		defer stopStream()

		// Read (and discard) anything else the client sends. The first read
		// error — a graceful Close from the client or a broken connection —
		// cancels streamCtx so that forwarding stops.
		go func() {
			defer stopStream()
			for {
				if _, _, readErr := wsConn.ReadMessage(); readErr != nil {
					return
				}
			}
		}()

		rpcErr := handler.forward(streamCtx, wsConn, string(firstMessage))
		if rpcErr == nil {
			return
		}
		// If streamCtx is already cancelled the client is gone, so there is
		// nobody to report the error to.
		if streamCtx.Err() != nil {
			return
		}

		if err := wsConn.WriteJSON(unaryHTTPResponse{Error: rpcErr}); err != nil {
			log.WithError(err).Warn("Failed to send final stream update")
		}
		if err := wsConn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil {
			log.WithError(err).Warn("Failed to send websocket close")
		}
	}
	return serve, nil
}
// forward decodes reqJSON into a fresh instance of the handler's RequestType,
// invokes the RPC with it, and writes every message the RPC sends to conn as
// a JSON-encoded unaryHTTPResponse.
//
// It returns once the RPC has returned and the writer goroutine has exited,
// so the caller may safely write to conn afterwards. The result is the RPC's
// error, or a wrapped error if the request cannot be unmarshalled.
//
// Cancelling ctx stops forwarding. A failed websocket write also stops it.
// In either case the stream's context is cancelled, so pending and future
// SendProtoMessage calls return an error instead of blocking on a channel
// that nobody reads.
func (handler StreamHandler) forward(ctx context.Context, conn *websocket.Conn, reqJSON string) error {
	// Handler guarantees that RequestType's concrete type is a pointer, so
	// Elem yields the message type to allocate.
	reqType := reflect.TypeOf(handler.RequestType).Elem()
	protoReq := reflect.New(reqType).Interface().(proto.Message)
	if err := jsonpb.UnmarshalString(reqJSON, protoReq); err != nil {
		return errors.WithContext("unmarshal request", err)
	}

	// streamCtx is shared by the stream handed to the RPC and by the writer
	// goroutine. The stream must observe the same cancellation as the writer;
	// otherwise a write failure would leave the RPC blocked forever in
	// SendProtoMessage.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Unbuffered: a send completes only once the writer has taken the message,
	// so every accepted message is written before doneForwarding is closed
	// (unless a write fails).
	messages := make(chan proto.Message)
	wsStream := WebSocketStream{messages: messages, ctx: streamCtx}

	doneForwarding := make(chan struct{})
	go func() {
		defer close(doneForwarding)
		for {
			select {
			case msg := <-messages:
				if err := conn.WriteJSON(unaryHTTPResponse{Result: msg}); err != nil {
					log.WithError(err).Warn("Failed to send stream update")
					// Nobody will read messages any more; unblock the RPC.
					cancel()
					return
				}
			case <-streamCtx.Done():
				return
			}
		}
	}()

	err := handler.RPC(protoReq, wsStream)
	cancel()
	// Block until the writer goroutine has returned. Websocket connections
	// allow only one concurrent writer, and the caller may write next.
	<-doneForwarding
	return err
}