/
prepare_response.go
121 lines (111 loc) · 2.68 KB
/
prepare_response.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package prepared
import (
	"encoding/json"
	"errors"

	"github.com/didi/sharingan/replayer-agent/utils/protocol/pmysql/command"
	"github.com/didi/sharingan/replayer-agent/utils/protocol/pmysql/common"
	"github.com/modern-go/parse"
	"github.com/modern-go/parse/model"
)
// PrepareResponse holds the fields that DecodePrepareResponse extracts from
// the server's response to a prepare statement. Field names and order are
// part of the JSON produced by String.
type PrepareResponse struct {
	StatementID  int // 4 bytes on the wire
	ColumnNumber int // 2 bytes on the wire; number of column definitions that follow
	ParamNumber  int // 2 bytes on the wire; number of parameter definitions that follow
	// A single 0x00 filler byte sits here on the wire; it is checked while
	// decoding but not stored.
	Warnings int // 2 bytes on the wire
	// ColumnDef holds one decoded column definition per ColumnNumber, in
	// wire order; it stays nil when ColumnNumber is 0.
	ColumnDef []command.Columndef
}
// String renders the response as JSON. If marshaling fails, the marshal
// error's message is returned in place of the JSON text.
//
// Requires the "encoding/json" import, which the file previously lacked.
func (s *PrepareResponse) String() string {
	data, err := json.Marshal(s)
	if err != nil {
		return err.Error()
	}
	return string(data)
}
// Map converts the response into a model.Map with the keys "statement_id",
// "column_num", "param_num", "warnings" and "column_def". The "column_def"
// entry is a model.List holding the Map form of each element of ColumnDef,
// in order; it is an empty (non-nil) list when ColumnDef is empty.
func (s *PrepareResponse) Map() model.Map {
	r := make(model.Map)
	r["statement_id"] = s.StatementID
	r["column_num"] = s.ColumnNumber
	r["param_num"] = s.ParamNumber
	r["warnings"] = s.Warnings

	// Size the list from the slice actually being iterated: ColumnNumber is an
	// exported field that may disagree with ColumnDef, and a negative value
	// would make make() panic.
	colDef := make(model.List, 0, len(s.ColumnDef))
	for _, def := range s.ColumnDef {
		colDef = append(colDef, def.Map())
	}
	r["column_def"] = colDef
	return r
}
// errNoPrepareResponsePacket is returned by the decoders in this file when
// the input does not have the expected prepare-response layout.
var errNoPrepareResponsePacket = errors.New("not prepare response")
// DecodePrepareResponse decodes the response to a prepare statement from src.
//
// The first packet is read as: a 0x00 status byte, a 4-byte statement ID, a
// 2-byte column count, a 2-byte parameter count, a 0x00 filler byte and a
// 2-byte warning count. When ParamNumber > 0 it is followed by ParamNumber
// parameter-definition packets and an EOF packet; when ColumnNumber > 0 it is
// followed by ColumnNumber column-definition packets and an EOF packet.
//
// It returns errNoPrepareResponsePacket when the first packet's length is 0
// or when the status or filler byte does not match 0x00 (src.Expect1(0)
// fails). Otherwise it returns any error reported by src, by the integer and
// column-definition decoders, or by readParamDef. On error the returned
// response is always nil.
func DecodePrepareResponse(src *parse.Source) (*PrepareResponse, error) {
	pkLen, _ := common.GetPacketHeader(src)
	if err := src.Error(); err != nil {
		return nil, err
	}
	// Short-circuit keeps the original order: the status byte is only
	// consumed when the packet is non-empty.
	if pkLen == 0 || !src.Expect1(0) {
		return nil, errNoPrepareResponsePacket
	}

	var err error
	resp := new(PrepareResponse)
	if resp.StatementID, err = common.GetIntN(src.ReadN(4), 4); err != nil {
		return nil, err
	}
	if resp.ColumnNumber, err = common.GetIntN(src.ReadN(2), 2); err != nil {
		return nil, err
	}
	if resp.ParamNumber, err = common.GetIntN(src.ReadN(2), 2); err != nil {
		return nil, err
	}
	// Filler byte between the parameter count and the warning count.
	if !src.Expect1(0) {
		return nil, errNoPrepareResponsePacket
	}
	// Previously this error was assigned and then silently overwritten.
	if resp.Warnings, err = common.GetIntN(src.ReadN(2), 2); err != nil {
		return nil, err
	}

	if resp.ParamNumber > 0 {
		// Parameter definitions are validated and skipped, not stored.
		for i := 0; i < resp.ParamNumber; i++ {
			if err := readParamDef(src); err != nil {
				return nil, err
			}
		}
		common.ReadEOFPacket(src)
		if err := src.Error(); err != nil {
			return nil, err
		}
	}

	if resp.ColumnNumber > 0 {
		for i := 0; i < resp.ColumnNumber; i++ {
			colDef, err := command.DecodeColumnDef(src)
			if err != nil {
				return nil, err
			}
			resp.ColumnDef = append(resp.ColumnDef, colDef)
		}
		common.ReadEOFPacket(src)
		if err := src.Error(); err != nil {
			return nil, err
		}
	}

	// Never hand back a partially decoded response alongside an error.
	if err := src.Error(); err != nil {
		return nil, err
	}
	return resp, nil
}
// readParamDef consumes one parameter-definition packet from src without
// keeping its contents. The packet body must start with the four bytes
// 0x03 'd' 'e' 'f' (presumably the length-prefixed "def" catalog of a MySQL
// column-definition packet; 0x03 is len("def")). The rest of the packet is
// skipped.
//
// It returns errNoPrepareResponsePacket when the packet is shorter than four
// bytes or does not start with that prefix, and otherwise any error reported
// by src.
func readParamDef(src *parse.Source) error {
	pkLen, _ := common.GetPacketHeader(src)
	if err := src.Error(); err != nil {
		return err
	}
	// The prefix needs 4 bytes. A shorter packet cannot be valid, and it would
	// make the remaining-length read below negative.
	if pkLen < 4 {
		return errNoPrepareResponsePacket
	}

	catalog := src.ReadN(4)
	if err := src.Error(); err != nil {
		return err
	}
	// Comparing as a string is bounds-safe even if fewer than 4 bytes came
	// back, unlike indexing catalog[0..3].
	if string(catalog) != "\x03def" {
		return errNoPrepareResponsePacket
	}

	// Skip the remainder of the definition packet.
	src.ReadN(pkLen - 4)
	return src.Error()
}