forked from stackrox/go-grpc-http1
-
Notifications
You must be signed in to change notification settings - Fork 0
/
response_reader.go
252 lines (213 loc) · 7.51 KB
/
response_reader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// Copyright (c) 2020 StackRox Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License
package grpcweb
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"net/http"
"net/textproto"
"strings"
"github.com/pkg/errors"
"github.com/mbanikazemi/go-grpc-http1/internal/ioutils"
)
// errExtraData is the error reported when bytes are found after the trailers frame. Its value is the number of
// such bytes that were observed, which is only a lower bound on how many there actually are.
type errExtraData int64

// Error implements the error interface.
func (e errExtraData) Error() string {
	extra := int64(e)
	return fmt.Sprintf("at least %d extra bytes after trailers frame", extra)
}
// ErrNoDecompressor is returned when a trailers frame has the compressed flag set but the response reader was
// created without a Decompressor.
var ErrNoDecompressor = errors.New("compressed message encountered, but no decompressor specified")
// Decompressor takes a ReadCloser over compressed bytes and returns a ReadCloser over the decompressed bytes.
// The response reader applies it to a trailers frame whose compressed flag is set, and closes the returned
// reader once the trailers have been parsed.
type Decompressor func(compressed io.ReadCloser) io.ReadCloser
// responseReader transcodes a gRPC-web response body into plain gRPC framing while it is being read. Regular
// message frames are passed through to the caller; the trailers frame is cut out of the stream and its headers
// are merged into *trailers.
type responseReader struct {
	// ReadCloser is the raw gRPC-web response body.
	io.ReadCloser

	// decompressor, if non-nil, is applied to a trailers frame that has the compressed flag set.
	decompressor Decompressor
	// trailers receives the parsed trailers; the pointer itself must be non-nil.
	trailers *http.Header

	// err is the first non-nil error returned from Read. It is sticky: every later Read returns it again.
	err error

	// currMessageRemaining is the number of body bytes of the current message frame not read yet. When it is 0,
	// the next bytes of the stream belong to a frame header.
	currMessageRemaining int64
	// currPartialMsgHeader holds the bytes of a frame header that has been split across reads.
	currPartialMsgHeader []byte

	// partialTrailerData holds trailers-frame bytes that were read in an earlier call but not yet parsed.
	partialTrailerData []byte

	// hasReadData records whether any byte has been read at all, and hasReadTrailers whether the trailers frame
	// has been parsed. Together they decide whether an EOF from the underlying reader is acceptable.
	hasReadData     bool
	hasReadTrailers bool
}
// NewResponseReader wraps the gRPC-web response body origResp in a reader that converts it to plain gRPC framing
// on the fly. Message frames are passed through unchanged, while the trailers frame is removed from the stream and
// its headers are merged into *trailers. A non-empty response therefore only reads to io.EOF after *trailers has
// been populated.
//
// trailers must be non-nil; if *trailers is a nil map, a new http.Header is allocated when trailers arrive.
// decompressor may be nil, in which case a compressed trailers frame makes Read fail with ErrNoDecompressor.
func NewResponseReader(origResp io.ReadCloser, trailers *http.Header, decompressor Decompressor) io.ReadCloser {
	rr := new(responseReader)
	rr.ReadCloser = origResp
	rr.trailers = trailers
	rr.decompressor = decompressor
	return rr
}
// adjustResult post-processes the (n, err) pair produced by doRead before it is handed to the caller of Read:
//   - Once the trailers have been parsed, no more data is allowed: the byte count is always zeroed, and data that
//     came back with a nil or io.EOF error is reported as errExtraData.
//   - An io.EOF after some data but before the trailers normally means the stream was truncated and becomes
//     io.ErrUnexpectedEOF. The exception is when trailer bytes are still buffered. The EOF may have arrived together
//     with the last bytes, so it is suppressed and the next Read parses the buffered trailers.
//   - Everything else, including io.EOF on a completely empty stream, is passed through.
func (r *responseReader) adjustResult(n int, err error) (int, error) {
	switch {
	case r.hasReadTrailers:
		if n > 0 && (err == nil || err == io.EOF) {
			return 0, errExtraData(n)
		}
		return 0, err
	case r.hasReadData && err == io.EOF:
		if len(r.partialTrailerData) == 0 {
			return n, io.ErrUnexpectedEOF
		}
		return n, nil
	default:
		return n, err
	}
}
// Read implements io.Reader. It delegates to doRead and filters the result through adjustResult. The first
// non-nil error is remembered in r.err, and every later call returns (0, r.err) without touching the stream.
func (r *responseReader) Read(buf []byte) (int, error) {
	if r.err == nil {
		var n int
		n, r.err = r.adjustResult(r.doRead(buf))
		return n, r.err
	}
	return 0, r.err
}
// doRead performs one read from the underlying gRPC-web stream into buf. It returns the number of leading bytes of
// buf that are regular (non-trailer) frame data.
//
// Bytes from the start of the trailers frame onwards are moved into r.partialTrailerData. They are parsed at the
// beginning of the next call, or right away via a repeated read if buf contained nothing else. Once the trailers
// have been parsed, whatever the underlying reader returns is passed through unmodified; adjustResult turns any
// such extra data into an error.
func (r *responseReader) doRead(buf []byte) (int, error) {
	if len(r.partialTrailerData) > 0 {
		err := r.readFullTrailers()
		// The buffered trailer bytes are used up by this attempt whether or not it succeeds. If they stayed
		// buffered after a failure, adjustResult would treat an io.EOF as "trailers still pending" and swallow it.
		// The same failing parse could then be retried, returning (0, nil), on every subsequent Read.
		r.partialTrailerData = nil
		if err != nil {
			if err == io.EOF {
				// A trailers frame that ends prematurely (e.g. a payload cut off mid-line) is never a clean end of
				// stream.
				err = io.ErrUnexpectedEOF
			}
			return 0, err
		}
		r.hasReadTrailers = true
	}

	n, err := r.ReadCloser.Read(buf)
	if n > 0 {
		r.hasReadData = true
	}

	if r.hasReadTrailers {
		// Nothing is allowed after the trailers frame. Pass the result through as-is; adjustResult translates any
		// extra data into an error.
		return n, err
	}

	buf = buf[:n]
	nPayload := r.consume(buf)
	if nPayload < n {
		// consume stopped at the first byte of the trailers frame; keep the rest for readFullTrailers.
		r.partialTrailerData = append(r.partialTrailerData, buf[nPayload:]...)
	}

	// Special case: the read returned only trailer data, so there is no payload to hand to the caller. Parse the
	// trailers and read again instead of returning (0, nil).
	if nPayload == 0 && len(r.partialTrailerData) > 0 {
		return r.doRead(buf)
	}

	return nPayload, err
}
// readFullTrailers parses the gRPC-web trailers frame and merges its headers into *r.trailers.
//
// The frame begins with the bytes buffered in r.partialTrailerData. If those do not cover the whole frame, it
// continues with data from the underlying reader. An error is returned in any of these cases:
//   - the 5-byte frame header cannot be read;
//   - the frame is flagged as compressed but no decompressor is configured;
//   - the payload is not a well-formed header block;
//   - fewer payload bytes were consumed than the frame header announces;
//   - r.partialTrailerData extends beyond the end of the frame. In this case the trailers have already been merged.
func (r *responseReader) readFullTrailers() error {
	src := io.MultiReader(bytes.NewReader(r.partialTrailerData), r.ReadCloser)

	// Frame header: one flags byte followed by the big-endian payload length.
	var hdr [5]byte
	if _, err := io.ReadFull(src, hdr[:]); err != nil {
		return err
	}
	payloadLen := binary.BigEndian.Uint32(hdr[1:])

	// The limit keeps parsing from reading past the end of the frame; consumed records how much of the payload
	// was actually read.
	var consumed int64
	payload := ioutils.NewCountingReader(io.LimitReader(src, int64(payloadLen)), &consumed)
	if hdr[0]&compressedFlag != 0 {
		if r.decompressor == nil {
			return ErrNoDecompressor
		}
		payload = r.decompressor(payload)
	}

	// The trailers are encoded as header lines. The frame lacks the terminating empty line (\r\n) that
	// ReadMIMEHeader needs to detect the end of the block, so append one.
	tp := textproto.NewReader(bufio.NewReader(io.MultiReader(payload, strings.NewReader("\r\n"))))
	mimeHeader, err := tp.ReadMIMEHeader()
	if err != nil {
		return err
	}
	// The payload must end right after the header block.
	switch _, peekErr := tp.R.Peek(1); {
	case peekErr == nil:
		return errors.New("incomplete read of trailers")
	case peekErr != io.EOF:
		return peekErr
	}

	// Without a decompressor, this Close cannot reach r.ReadCloser. io.LimitReader hides the underlying Close, so
	// the counting reader acts as a no-op closer. NOTE(review): whether a decompressor's Close propagates depends
	// on the Decompressor implementation.
	if err := payload.Close(); err != nil {
		return err
	}
	if consumed != int64(payloadLen) {
		return errors.Errorf("only read %d out of %d bytes from trailers frame", consumed, payloadLen)
	}
	r.populateTrailers(mimeHeader)

	// r.partialTrailerData may extend past the frame boundary. Those bytes must not be dropped silently, so report
	// them (they are still discarded).
	frameLen := int64(len(hdr)) + int64(payloadLen)
	if extra := int64(len(r.partialTrailerData)) - frameLen; extra > 0 {
		return errExtraData(extra)
	}
	return nil
}
// populateTrailers merges the parsed trailers into *r.trailers, allocating the header map first if it is nil.
// Keys are canonicalized with http.CanonicalHeaderKey. Values are appended to any already stored under the same
// key rather than replacing them.
func (r *responseReader) populateTrailers(trailers textproto.MIMEHeader) {
	dst := *r.trailers
	if dst == nil {
		dst = make(http.Header)
		*r.trailers = dst
	}
	for key, values := range trailers {
		key = http.CanonicalHeaderKey(key)
		dst[key] = append(dst[key], values...)
	}
}
// consume advances the frame parser over buf, which holds the next bytes of the gRPC-web stream. It returns how
// many leading bytes of buf are regular message data (frame headers and bodies). It stops at the first byte of a
// trailers frame, without counting it; that byte is a frame header's flags byte with trailerMessageFlag set.
// Frame headers and bodies may be split arbitrarily across calls; the parser state lives in
// r.currMessageRemaining and r.currPartialMsgHeader.
func (r *responseReader) consume(buf []byte) int {
	consumed := 0
	rest := buf
	for len(rest) != 0 {
		// First swallow whatever is left of the body of the message we are currently inside of.
		body := len(rest)
		if r.currMessageRemaining < int64(body) {
			body = int(r.currMessageRemaining)
		}
		r.currMessageRemaining -= int64(body)
		consumed += body
		rest = rest[body:]
		if len(rest) == 0 {
			break
		}

		// rest[0] starts a new frame header, unless one is already partially buffered. A set trailer flag there
		// marks the start of the trailers frame, which is left unconsumed.
		if len(r.currPartialMsgHeader) == 0 && rest[0]&trailerMessageFlag != 0 {
			break
		}

		// Accumulate as many header bytes as are available, up to a complete header.
		hdr := completeHeaderLen - len(r.currPartialMsgHeader)
		if len(rest) < hdr {
			hdr = len(rest)
		}
		r.currPartialMsgHeader = append(r.currPartialMsgHeader, rest[:hdr]...)
		consumed += hdr
		rest = rest[hdr:]

		// A complete header yields the body length of the next message; reset the buffer for the next header.
		if len(r.currPartialMsgHeader) == completeHeaderLen {
			r.currMessageRemaining = int64(binary.BigEndian.Uint32(r.currPartialMsgHeader[1:]))
			r.currPartialMsgHeader = r.currPartialMsgHeader[:0]
		}
	}
	return consumed
}