/
helper.go
133 lines (122 loc) · 2.91 KB
/
helper.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package wrpc
import (
"context"
"fmt"
"time"
. "github.com/kong/go-wrpc/wrpc/internal/wrpc"
)
// deadlineFromCtx returns the deadline of ctx as Unix seconds truncated to
// uint32. If ctx is nil, or carries no deadline, the deadline used is
// defaultTimeout from the current time.
func deadlineFromCtx(ctx context.Context) uint32 {
	var (
		when time.Time
		set  bool
	)
	if ctx != nil {
		when, set = ctx.Deadline()
	}
	if !set {
		when = time.Now().Add(defaultTimeout)
	}
	return uint32(when.Unix())
}
// rpcMessageOpts holds the inputs createRPCMessage copies into an RPC-type
// PayloadV1.
type rpcMessageOpts struct {
	// svcID and rpcID are converted to uint32 for SvcId / RpcId;
	// validateRPCMessage rejects zero for either.
	svcID ID
	rpcID ID
	// ack is zero for requests and non-zero for responses (see the
	// deadline checks in validateRPCMessage).
	ack uint32
	// seq must be non-zero for validateMessage to accept the message.
	seq uint32
	// encoding is copied to PayloadEncoding; payload becomes the single
	// element of Payloads.
	encoding Encoding
	payload  []byte
}
// errorMessageOpts holds the inputs createErrorMessage copies into an
// ERROR-type PayloadV1.
//
// NOTE(review): svcID and rpcID are plain uint32 here, while rpcMessageOpts
// uses the ID type for the same fields.
type errorMessageOpts struct {
	svcID uint32
	rpcID uint32
	ack   uint32
	seq   uint32
	// error supplies the Description text sent to the peer.
	error error
}
// createErrorMessage builds a version-1 WebsocketPayload of type
// MESSAGE_TYPE_ERROR, copying svcID, rpcID, seq and ack from opts.
//
// The error's text becomes Error.Description. A nil opts.error yields an
// empty Description instead of a nil-pointer panic.
func createErrorMessage(opts errorMessageOpts) *WebsocketPayload {
	description := ""
	if opts.error != nil {
		description = opts.error.Error()
	}
	return &WebsocketPayload{
		Version: 1,
		Payload: &PayloadV1{
			Mtype: MessageType_MESSAGE_TYPE_ERROR,
			Error: &Error{
				// XXX: every error is reported as ERROR_TYPE_UNSPECIFIED;
				// no finer-grained mapping from opts.error is done yet.
				Etype:       ErrorType_ERROR_TYPE_UNSPECIFIED,
				Description: description,
			},
			SvcId: opts.svcID,
			RpcId: opts.rpcID,
			Seq:   opts.seq,
			Ack:   opts.ack,
		},
	}
}
// createRPCMessage builds a version-1 WebsocketPayload of type
// MESSAGE_TYPE_RPC from opts. opts.payload is always placed as the single
// element of Payloads, even when it is nil.
//
// NOTE(review): Deadline is not set here, yet validateRPCMessage rejects
// requests (Ack == 0) without one; callers presumably fill it in — confirm.
func createRPCMessage(opts rpcMessageOpts) *WebsocketPayload {
	body := &PayloadV1{
		Mtype:           MessageType_MESSAGE_TYPE_RPC,
		SvcId:           uint32(opts.svcID),
		RpcId:           uint32(opts.rpcID),
		Seq:             opts.seq,
		Ack:             opts.ack,
		PayloadEncoding: opts.encoding,
		Payloads:        [][]byte{opts.payload},
	}
	return &WebsocketPayload{Version: 1, Payload: body}
}
// validateMessage checks the envelope of a message and then applies the
// type-specific checks for ERROR and RPC payloads.
//
// It returns an error when m is nil, the version is not 1, the payload is
// missing, seq is 0, or the message type is neither ERROR nor RPC
// (MESSAGE_TYPE_UNSPECIFIED included).
func validateMessage(m *WebsocketPayload) error {
	if m == nil {
		return fmt.Errorf("wrpc: nil message")
	}
	if m.Version != 1 {
		return fmt.Errorf("wrpc: invalid version: %v", m.Version)
	}
	if m.Payload == nil {
		return fmt.Errorf("no payload")
	}
	if m.Payload.Seq == 0 {
		return fmt.Errorf("invalid seq(0)")
	}
	switch m.Payload.Mtype {
	case MessageType_MESSAGE_TYPE_ERROR:
		return validateErrorMessage(m.Payload)
	case MessageType_MESSAGE_TYPE_RPC:
		return validateRPCMessage(m.Payload)
	default:
		// MESSAGE_TYPE_UNSPECIFIED must land here too: Go cases never fall
		// through implicitly, so it cannot have its own empty case.
		return fmt.Errorf("invalid message type: %d", m.Payload.Mtype)
	}
}
// validateErrorMessage returns an error when an ERROR-type payload has no
// Error value set, and nil otherwise.
func validateErrorMessage(m *PayloadV1) error {
	if m.Error != nil {
		return nil
	}
	return fmt.Errorf("error message without any error")
}
// validateRPCMessage checks the RPC-specific fields of a payload:
//   - SvcId and RpcId must both be non-zero;
//   - a request (Ack == 0) must carry a non-zero Deadline, and a response
//     (Ack != 0) must not carry one;
//   - at most one payload is allowed, and when one is present its
//     PayloadEncoding must pass validateEncoding.
func validateRPCMessage(m *PayloadV1) error {
	switch {
	case m.SvcId == 0:
		return fmt.Errorf("invalid svc_id(0)")
	case m.RpcId == 0:
		return fmt.Errorf("invalid rpc_id(0)")
	}

	isRequest := m.Ack == 0
	hasDeadline := m.Deadline != 0
	if isRequest && !hasDeadline {
		return fmt.Errorf("invalid deadline(0) for request")
	}
	if !isRequest && hasDeadline {
		return fmt.Errorf("invalid deadline(%v) for response", m.Deadline)
	}

	switch n := len(m.Payloads); {
	case n > 1:
		return fmt.Errorf("unexpected number of payloads(%v) in a message", n)
	case n == 1:
		return validateEncoding(m.PayloadEncoding)
	}
	return nil
}
// validateEncoding accepts only Encoding_ENCODING_PROTO3. Every other value,
// Encoding_ENCODING_UNSPECIFIED included, yields an error.
func validateEncoding(e Encoding) error {
	if e == Encoding_ENCODING_PROTO3 {
		return nil
	}
	return fmt.Errorf("invalid encoding(%v)", e)
}