-
Notifications
You must be signed in to change notification settings - Fork 10
/
codec.go
214 lines (188 loc) · 6.07 KB
/
codec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
package codec
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"net"
"net/rpc"
"code.google.com/p/goprotobuf/proto"
"github.com/kylelemons/go-rpcgen/plugin/wire"
)
// ServerCodec implements the rpc.ServerCodec interface for generic protobufs.
// The same implementation works for all protobufs because it defers the
// decoding of a protocol buffer to the proto package and it uses a set header
// that is the same regardless of the protobuf being used for the RPC.
//
// Every frame on the wire (header or body) is a uvarint byte count followed
// by that many bytes of marshalled protobuf.
type ServerCodec struct {
	// r supplies request headers and bodies. It must be a ByteReader as well
	// as a Reader because ReadProto decodes the uvarint prefix byte by byte.
	r *bufio.Reader
	// w receives response headers and bodies, and is what Close closes.
	w io.WriteCloser
}
// ProtoReader is the input ReadProto consumes: single-byte reads for the
// uvarint length prefix and ordinary reads for the message bytes that follow.
// *bufio.Reader satisfies it.
type ProtoReader interface {
	io.ByteReader
	Read(p []byte) (n int, err error)
}
// maxMessageSize caps the length prefix ReadProto will accept. The prefix
// arrives straight from the peer. Without a cap, a few bytes of input could
// force an enormous allocation. An absurd value also makes the make call
// panic ("len out of range"), which would take the whole process down.
const maxMessageSize = 64 << 20 // 64 MiB

// ReadProto reads a uvarint size and then a protobuf from r.
// If the size read is zero, nothing more is read from r and pb is
// unmarshalled from an empty buffer.
//
// A size above maxMessageSize is rejected before any allocation. The stream
// is then positioned just after the prefix and should be treated as broken.
// Errors from r are returned unwrapped: net/rpc compares them against io.EOF
// to recognise a cleanly closed connection.
func ReadProto(r ProtoReader, pb proto.Message) error {
	size, err := binary.ReadUvarint(r)
	if err != nil {
		return err
	}
	if size > maxMessageSize {
		return fmt.Errorf("proto message size %d exceeds limit of %d bytes", size, maxMessageSize)
	}
	buf := make([]byte, size)
	if _, err := io.ReadFull(r, buf); err != nil {
		return err
	}
	return proto.Unmarshal(buf, pb)
}
// WriteProto frames pb for the wire. It marshals pb, writes the byte count as
// a uvarint, then writes the marshalled bytes. A message that marshals to
// nothing (like rpc.InvalidRequest) therefore produces just a zero size.
// Nothing is written if marshalling fails. The first write error is returned.
func WriteProto(w io.Writer, pb proto.Message) error {
	body, err := proto.Marshal(pb)
	if err != nil {
		return err
	}

	// Sized for the longest possible uvarint encoding of a uint64.
	var prefix [binary.MaxVarintLen64]byte
	prefixLen := binary.PutUvarint(prefix[:], uint64(len(body)))

	for _, chunk := range [][]byte{prefix[:prefixLen], body} {
		if _, err := w.Write(chunk); err != nil {
			return err
		}
	}
	return nil
}
// NewServerCodec builds a ServerCodec on top of conn, meant to talk to a
// ClientCodec at the other end. Requests are read through a buffered reader
// wrapping conn. Responses are written to conn itself.
func NewServerCodec(conn net.Conn) *ServerCodec {
	codec := &ServerCodec{
		r: bufio.NewReader(conn),
		w: conn,
	}
	return codec
}
// ReadRequestHeader reads the header protobuf (which is prefixed by a uvarint
// indicating its size) from the connection, decodes it, and stores the method
// name and sequence number in req. A header missing either field is rejected.
// Read errors are returned unwrapped so net/rpc can recognise io.EOF.
func (s *ServerCodec) ReadRequestHeader(req *rpc.Request) error {
	var header wire.Header
	if err := ReadProto(s.r, &header); err != nil {
		return err
	}
	// Format &header rather than the struct value. Printing the value renders
	// its pointer fields as %!s(*string=0x...) noise. The pointer lets fmt use
	// a String method if the message has one; generated protobuf types
	// typically define it on the pointer receiver.
	if header.Method == nil {
		return fmt.Errorf("header missing method: %v", &header)
	}
	if header.Seq == nil {
		return fmt.Errorf("header missing seq: %v", &header)
	}
	req.ServiceMethod = *header.Method
	req.Seq = *header.Seq
	return nil
}
// ReadRequestBody reads a uvarint from the connection and decodes that many
// subsequent bytes into the given protobuf (which should be a pointer to a
// struct that is generated by the proto package).
//
// Per the net/rpc ServerCodec contract, a nil obj means the body must be read
// and thrown away. net/rpc does this for requests it cannot dispatch.
// Skipping the body keeps the stream aligned. Otherwise the next header would
// be parsed from the middle of this body.
func (s *ServerCodec) ReadRequestBody(obj interface{}) error {
	if obj == nil {
		size, err := binary.ReadUvarint(s.r)
		if err != nil {
			return err
		}
		// Discard takes an int. Refuse sizes that would not fit, rather than
		// truncating the count and leaving the stream misaligned.
		if size > uint64(^uint(0)>>1) {
			return fmt.Errorf("request body size %d too large to discard", size)
		}
		// Discard skips the bytes without allocating a buffer for them.
		_, err = s.r.Discard(int(size))
		return err
	}
	pb, ok := obj.(proto.Message)
	if !ok {
		return fmt.Errorf("%T does not implement proto.Message", obj)
	}
	return ReadProto(s.r, pb)
}
// WriteResponse writes the header protobuf for resp and then the body, each
// prefixed with a uvarint indicating its size.
//
// net/rpc sends error responses with a placeholder reply value that is not a
// protobuf. When resp carries an error and obj is not a proto.Message, the
// header (including the error text) is still written. The body is then sent
// as a bare zero size, which a reader decodes as an empty message. Without
// this, the client would never learn about the failed call.
//
// On a successful response, a non-proto obj is a caller bug. It is reported
// without writing anything.
func (s *ServerCodec) WriteResponse(resp *rpc.Response, obj interface{}) error {
	pb, isProto := obj.(proto.Message)
	if !isProto && resp.Error == "" {
		return fmt.Errorf("%T does not implement proto.Message", obj)
	}

	header := wire.Header{
		Method: &resp.ServiceMethod,
		Seq:    &resp.Seq,
	}
	if resp.Error != "" {
		header.Error = &resp.Error
	}
	if err := WriteProto(s.w, &header); err != nil {
		return err
	}

	if !isProto {
		// A uvarint zero is the single byte 0x00, i.e. an empty body frame.
		_, err := s.w.Write([]byte{0})
		return err
	}
	return WriteProto(s.w, pb)
}
// Close shuts the codec down by closing the writer it was built with. For
// codecs made by NewServerCodec, that writer is the connection itself.
func (s *ServerCodec) Close() error {
	conn := s.w
	return conn.Close()
}
// ClientCodec implements the rpc.ClientCodec interface for generic protobufs.
// The same implementation works for all protobufs because it defers the
// encoding of a protocol buffer to the proto package and it uses a set header
// that is the same regardless of the protobuf being used for the RPC.
//
// Every frame on the wire (header or body) is a uvarint byte count followed
// by that many bytes of marshalled protobuf.
type ClientCodec struct {
	// r supplies response headers and bodies. It must be a ByteReader as well
	// as a Reader because ReadProto decodes the uvarint prefix byte by byte.
	r *bufio.Reader
	// w receives request headers and bodies, and is what Close closes.
	w io.WriteCloser
}
// NewClientCodec builds a ClientCodec on top of conn, meant to talk to a
// ServerCodec at the other end. Responses are read through a buffered reader
// wrapping conn. Requests are written to conn itself.
func NewClientCodec(conn net.Conn) *ClientCodec {
	codec := &ClientCodec{
		r: bufio.NewReader(conn),
		w: conn,
	}
	return codec
}
// WriteRequest sends the header protobuf for req (method name and sequence
// number), then the argument protobuf obj. Each is framed with a uvarint
// length prefix. If obj is not a proto.Message, nothing is written and an
// error is returned.
func (c *ClientCodec) WriteRequest(req *rpc.Request, obj interface{}) error {
	args, isProto := obj.(proto.Message)
	if !isProto {
		return fmt.Errorf("%T does not implement proto.Message", obj)
	}

	hdr := &wire.Header{Method: &req.ServiceMethod, Seq: &req.Seq}
	if err := WriteProto(c.w, hdr); err != nil {
		return err
	}
	return WriteProto(c.w, args)
}
// ReadResponseHeader reads the header protobuf (which is prefixed by a uvarint
// indicating its size) from the connection, decodes it, and stores the method
// name, sequence number and, if present, the error text in resp. A header
// missing the method or sequence number is rejected. Read errors are returned
// unwrapped so net/rpc can recognise io.EOF.
func (c *ClientCodec) ReadResponseHeader(resp *rpc.Response) error {
	var header wire.Header
	if err := ReadProto(c.r, &header); err != nil {
		return err
	}
	// Format &header rather than the struct value. Printing the value renders
	// its pointer fields as %!s(*string=0x...) noise. The pointer lets fmt use
	// a String method if the message has one; generated protobuf types
	// typically define it on the pointer receiver.
	if header.Method == nil {
		return fmt.Errorf("header missing method: %v", &header)
	}
	if header.Seq == nil {
		return fmt.Errorf("header missing seq: %v", &header)
	}
	resp.ServiceMethod = *header.Method
	resp.Seq = *header.Seq
	if header.Error != nil {
		resp.Error = *header.Error
	}
	return nil
}
// ReadResponseBody reads a uvarint from the connection and decodes that many
// subsequent bytes into the given protobuf (which should be a pointer to a
// struct that is generated by the proto package). A zero size decodes an
// empty buffer into obj.
//
// Per the net/rpc ClientCodec contract, a nil obj means the body must be read
// and thrown away. The net/rpc client does this for every error response,
// whose error text has already arrived in the header. Skipping the body keeps
// the stream aligned for the next response.
func (c *ClientCodec) ReadResponseBody(obj interface{}) error {
	if obj == nil {
		size, err := binary.ReadUvarint(c.r)
		if err != nil {
			return err
		}
		// Discard takes an int. Refuse sizes that would not fit, rather than
		// truncating the count and leaving the stream misaligned.
		if size > uint64(^uint(0)>>1) {
			return fmt.Errorf("response body size %d too large to discard", size)
		}
		// Discard skips the bytes without allocating a buffer for them.
		_, err = c.r.Discard(int(size))
		return err
	}
	pb, ok := obj.(proto.Message)
	if !ok {
		return fmt.Errorf("%T does not implement proto.Message", obj)
	}
	return ReadProto(c.r, pb)
}
// Close shuts the codec down by closing the writer it was built with. For
// codecs made by NewClientCodec, that writer is the connection itself.
func (c *ClientCodec) Close() error {
	conn := c.w
	return conn.Close()
}
// BUG: The server/client don't do a sanity check on the size of the proto
// before reading it, so it's possible to maliciously instruct the
// client/server to allocate too much memory.