/
wtwire_test.go
188 lines (165 loc) · 5.2 KB
/
wtwire_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package wtwire_test
import (
"bytes"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/watchtower/wtwire"
)
// TestWatchtowerWireProtocol uses the testing/quick package to create a series
// of fuzz tests to attempt to break a primary scenario which is implemented as
// property based testing scenario.
func TestWatchtowerWireProtocol(t *testing.T) {
	t.Parallel()

	// mainScenario is the property checked for every registered wire
	// message: a message written with wtwire.WriteMessage must fit within
	// its MaxPayloadLength, and reading it back with wtwire.ReadMessage
	// must yield a deeply-equal value. quick.Check searches for an input
	// that makes this return false; such an input violates our model of
	// the system. Failures are reported with t.Errorf followed by a false
	// return (instead of t.Fatalf) so that quick.Check can still report
	// the offending input in its CheckError.
	mainScenario := func(msg wtwire.Message) bool {
		// Serialize the message into a fresh buffer.
		var b bytes.Buffer
		if _, err := wtwire.WriteMessage(&b, msg, 0); err != nil {
			t.Errorf("unable to write msg: %v", err)
			return false
		}

		// The serialized payload (excluding the 2 bytes for the
		// message type) must not exceed the declared maximum.
		payloadLen := uint32(b.Len()) - 2
		if maxLen := msg.MaxPayloadLength(0); payloadLen > maxLen {
			t.Errorf("msg payload constraint violated: %v > %v",
				payloadLen, maxLen)
			return false
		}

		// Deserialize from the same buffer and require equality with
		// the original message.
		newMsg, err := wtwire.ReadMessage(&b, 0)
		if err != nil {
			t.Errorf("unable to read msg: %v", err)
			return false
		}

		if !reflect.DeepEqual(msg, newMsg) {
			t.Errorf("messages don't match after re-encoding: %v "+
				"vs %v", spew.Sdump(msg), spew.Sdump(newMsg))
			return false
		}

		return true
	}

	// The EncryptedBlob of a StateUpdate shares its payload budget with
	// the message's other fields, so generated blobs are bounded by what
	// is left after those fields rather than by MaxMessagePayload alone.
	maxBlobLen := maxStateUpdateBlobLen(t)

	type typeGenFunc func([]reflect.Value, *rand.Rand)

	// customTypeGen maps a message type to a function that randomly
	// generates a value of that type. These are needed for types that are
	// too complex for testing/quick to generate automatically. Every random
	// draw goes through the supplied r so that a run is fully determined by
	// the seed of the quick.Config that drives it.
	customTypeGen := map[wtwire.MessageType]typeGenFunc{
		wtwire.MsgSessionInit: func(v []reflect.Value, r *rand.Rand) {
			req := &wtwire.SessionInit{
				Version:      uint16(r.Int31()),
				MaxUpdates:   uint16(r.Int31()),
				RewardRate:   uint32(r.Int63()),
				SweepFeeRate: lnwallet.SatPerVByte(r.Int63()),
			}

			v[0] = reflect.ValueOf(*req)
		},
		wtwire.MsgSessionInitReply: func(v []reflect.Value, r *rand.Rand) {
			req := &wtwire.SessionInitReply{
				Code: wtwire.SessionInitCode(r.Int31()),
			}

			dataLen := r.Int31n(
				wtwire.MaxSessionInitReplyDataLength,
			)
			req.Data = make([]byte, dataLen)
			if _, err := r.Read(req.Data); err != nil {
				t.Fatalf("unable to generate data: %v",
					err)
			}

			v[0] = reflect.ValueOf(*req)
		},
		wtwire.MsgStateUpdate: func(v []reflect.Value, r *rand.Rand) {
			req := &wtwire.StateUpdate{
				SeqNum:      uint16(r.Int31()),
				IsComplete:  uint8(r.Int31()),
				LastApplied: uint16(r.Int31()),
			}

			if _, err := r.Read(req.Hint[:]); err != nil {
				t.Fatalf("unable to generate breach hint: %v",
					err)
			}

			// Inclusive upper bound: a blob of exactly maxBlobLen
			// bytes still fits within the payload limit.
			blobLen := r.Int31n(maxBlobLen + 1)
			req.EncryptedBlob = make([]byte, blobLen)
			if _, err := r.Read(req.EncryptedBlob); err != nil {
				t.Fatalf("unable to generate encrypted blob: %v",
					err)
			}

			v[0] = reflect.ValueOf(*req)
		},
		wtwire.MsgStateUpdateReply: func(v []reflect.Value, r *rand.Rand) {
			req := &wtwire.StateUpdateReply{
				Code:        wtwire.StateUpdateCode(r.Int31()),
				LastApplied: uint16(r.Int31()),
			}

			v[0] = reflect.ValueOf(*req)
		},
	}

	// quick.Check scans the input space of the function under test, so
	// each entry wraps mainScenario in a function whose parameter is the
	// concrete message type to iterate over.
	tests := []struct {
		msgType  wtwire.MessageType
		scenario interface{}
	}{
		{
			msgType: wtwire.MsgSessionInit,
			scenario: func(m wtwire.SessionInit) bool {
				return mainScenario(&m)
			},
		},
		{
			msgType: wtwire.MsgSessionInitReply,
			scenario: func(m wtwire.SessionInitReply) bool {
				return mainScenario(&m)
			},
		},
		{
			msgType: wtwire.MsgStateUpdate,
			scenario: func(m wtwire.StateUpdate) bool {
				return mainScenario(&m)
			},
		},
		{
			msgType: wtwire.MsgStateUpdateReply,
			scenario: func(m wtwire.StateUpdateReply) bool {
				return mainScenario(&m)
			},
		},
	}

	for _, test := range tests {
		// Give every message type its own logged seed so a failing run
		// can be replayed by hard-coding that seed here.
		seed := time.Now().UnixNano()
		config := &quick.Config{
			Rand: rand.New(rand.NewSource(seed)),
		}

		// If the type has a custom generator, use it instead of the
		// reflection-based default from testing/quick.
		if valueGen, ok := customTypeGen[test.msgType]; ok {
			config.Values = valueGen
		}

		t.Logf("Running fuzz tests for msgType=%v", test.msgType)
		t.Logf("Using seed %d for msgType=%v", seed, test.msgType)
		if err := quick.Check(test.scenario, config); err != nil {
			t.Fatalf("fuzz checks for msg=%v failed: %v",
				test.msgType, err)
		}
	}
}

// maxStateUpdateBlobLen returns the largest EncryptedBlob length a generated
// StateUpdate may carry while its encoded payload stays within both the
// message's MaxPayloadLength and wtwire.MaxMessagePayload. The space used by
// the other fields (including any length prefix the blob carries) is measured
// by encoding a StateUpdate with an empty blob, so the wire layout is not
// hard-coded here. The test is aborted if no room for a blob remains.
func maxStateUpdateBlobLen(t *testing.T) int32 {
	t.Helper()

	var b bytes.Buffer
	empty := &wtwire.StateUpdate{}
	if _, err := wtwire.WriteMessage(&b, empty, 0); err != nil {
		t.Fatalf("unable to write empty state update: %v", err)
	}

	// Everything after the 2-byte message type is payload that is already
	// spoken for before any blob bytes are added.
	overhead := uint32(b.Len()) - 2

	limit := empty.MaxPayloadLength(0)
	if maxMsg := uint32(wtwire.MaxMessagePayload); limit > maxMsg {
		limit = maxMsg
	}
	if limit <= overhead {
		t.Fatalf("state update overhead %d leaves no room for a "+
			"blob within %d bytes", overhead, limit)
	}

	return int32(limit - overhead)
}
// init seeds the package-level math/rand source with the current Unix time in
// whole seconds, so that package-level rand calls made by this test package
// vary between runs.
// NOTE(review): rand.Seed is deprecated since Go 1.20 (the global source is
// seeded automatically); confirm this is still needed for the Go version used.
func init() {
	seed := time.Now().Unix()
	rand.Seed(seed)
}