Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions wire/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ const (
// is received.
ErrPayloadChecksum

// ErrTrailingBytes is returned when a message is received that is valid
// enough to fully decode, but also contains additional trailing bytes.
ErrTrailingBytes

// ErrTooManyAddrs is returned when an address list exceeds the maximum
// allowed.
ErrTooManyAddrs
Expand Down Expand Up @@ -182,6 +186,7 @@ var errorCodeStrings = map[ErrorCode]string{
ErrMalformedCmd: "ErrMalformedCmd",
ErrUnknownCmd: "ErrUnknownCmd",
ErrPayloadChecksum: "ErrPayloadChecksum",
ErrTrailingBytes: "ErrTrailingBytes",
ErrTooManyAddrs: "ErrTooManyAddrs",
ErrTooManyTxs: "ErrTooManyTxs",
ErrMsgInvalidForPVer: "ErrMsgInvalidForPVer",
Expand Down
1 change: 1 addition & 0 deletions wire/error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestMessageErrorCodeStringer(t *testing.T) {
{ErrMalformedCmd, "ErrMalformedCmd"},
{ErrUnknownCmd, "ErrUnknownCmd"},
{ErrPayloadChecksum, "ErrPayloadChecksum"},
{ErrTrailingBytes, "ErrTrailingBytes"},
{ErrTooManyAddrs, "ErrTooManyAddrs"},
{ErrTooManyTxs, "ErrTooManyTxs"},
{ErrMsgInvalidForPVer, "ErrMsgInvalidForPVer"},
Expand Down
7 changes: 7 additions & 0 deletions wire/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,13 @@ func ReadMessageN(r io.Reader, pver uint32, dcrnet CurrencyNet) (int, Message, [
return totalBytes, nil, nil, err
}

// Reject messages that did not consume the full payload.
if buf.Len() > 0 {
msg := fmt.Sprintf("message payload has %d unconsumed trailing "+
"bytes", buf.Len())
return totalBytes, nil, nil, messageError(op, ErrTrailingBytes, msg)
}

return totalBytes, msg, payload, nil
}

Expand Down
239 changes: 105 additions & 134 deletions wire/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,10 @@ func TestReadMessageWireErrors(t *testing.T) {
pver := ProtocolVersion
dcrnet := MainNet

// Wire encoded bytes for main and testnet networks magic identifiers.
// Wire encoded bytes for testnet magic identifier.
testNetBytes := makeHeader(TestNet3, "", 0, 0)

// Wire encoded bytes for a message that exceeds max overall message
// length.
// Wire encoded bytes for a message that exceeds max overall message length.
mpl := uint32(MaxMessagePayload)
exceedMaxPayloadBytes := makeHeader(dcrnet, "getaddr", mpl+1, 0)

Expand Down Expand Up @@ -238,151 +237,123 @@ func TestReadMessageWireErrors(t *testing.T) {
// contained in the message. Claim there is two, but don't provide
// them. At the same time, forge the header fields so the message is
// otherwise accurate.
badMessageBytes := makeHeader(dcrnet, "addr", 1, 0xeaadc31c)
badMessageBytes := makeHeader(dcrnet, "addr", 1, 0xab37af49)
badMessageBytes = append(badMessageBytes, 0x2)

// Wire encoded bytes for a message which the header claims has 15k
// bytes of data to discard.
discardBytes := makeHeader(dcrnet, "bogus", 15*1024, 0)
// Wire encoded bytes for a message that is valid, but contains additional
// trailing bytes and header fields that are forged so the message is
// otherwise accurate.
payloadSize := uint32(len(testBlockBytes))
trailingBytes := makeHeader(dcrnet, "block", payloadSize+1, 0xb5ec24b8)
trailingBytes = append(trailingBytes, testBlockBytes...)
trailingBytes = append(trailingBytes, 0x01)

tests := []struct {
buf []byte // Wire encoding
pver uint32 // Protocol version for wire encoding
dcrnet CurrencyNet // Decred network for wire encoding
max int // Max size of fixed buffer to induce errors
readErr error // Expected read error
bytes int // Expected num bytes read
}{
// Latest protocol version with intentional read errors.

// Short header. [0]
{
[]byte{},
pver,
dcrnet,
0,
io.EOF,
0,
},

// Wrong network. Want MainNet, but giving TestNet. [1]
{
testNetBytes,
pver,
dcrnet,
len(testNetBytes),
&MessageError{},
24,
},

// Exceed max overall message payload length. [2]
{
exceedMaxPayloadBytes,
pver,
dcrnet,
len(exceedMaxPayloadBytes),
&MessageError{},
24,
},

// Invalid UTF-8 command. [3]
{
badCommandBytes,
pver,
dcrnet,
len(badCommandBytes),
&MessageError{},
24,
},

// Valid, but unsupported command. [4]
{
unsupportedCommandBytes,
pver,
dcrnet,
len(unsupportedCommandBytes),
&MessageError{},
24,
},

// Exceed max allowed payload for a message of a specific type. [5]
{
exceedTypePayloadBytes,
pver,
dcrnet,
len(exceedTypePayloadBytes),
&MessageError{},
24,
},

// Message with a payload shorter than the header indicates. [6]
{
shortPayloadBytes,
pver,
dcrnet,
len(shortPayloadBytes),
io.EOF,
24,
},

name string // Test description
buf []byte // Wire encoding
pver uint32 // Protocol version for wire encoding
dcrnet CurrencyNet // Decred network for wire encoding
max int // Max size of fixed buffer to induce errors
err error // Expected read error
bytes int // Expected num bytes read
}{{
name: "short header",
buf: nil,
pver: pver,
dcrnet: dcrnet,
max: 0,
err: io.EOF,
bytes: 0,
}, {
name: "wrong network, want mainnet, giving testnet",
buf: testNetBytes,
pver: pver,
dcrnet: dcrnet,
max: len(testNetBytes),
err: ErrWrongNetwork,
bytes: len(testNetBytes),
}, {
name: "exceed max overall message payload length",
buf: exceedMaxPayloadBytes,
pver: pver,
dcrnet: dcrnet,
max: len(exceedMaxPayloadBytes),
err: ErrPayloadTooLarge,
bytes: len(exceedMaxPayloadBytes),
}, {
name: "invalid utf-8 command",
buf: badCommandBytes,
pver: pver,
dcrnet: dcrnet,
max: len(badCommandBytes),
err: ErrMalformedCmd,
bytes: len(badCommandBytes),
}, {
name: "valid, but unsupported command",
buf: unsupportedCommandBytes,
pver: pver,
dcrnet: dcrnet,
max: len(unsupportedCommandBytes),
err: ErrUnknownCmd,
bytes: len(unsupportedCommandBytes),
}, {
name: "exceed max allowed payload for a message of a specific type",
buf: exceedTypePayloadBytes,
pver: pver,
dcrnet: dcrnet,
max: len(exceedTypePayloadBytes),
err: ErrPayloadTooLarge,
bytes: len(exceedTypePayloadBytes),
}, {
name: "payload shorter than the header indicates",
buf: shortPayloadBytes,
pver: pver,
dcrnet: dcrnet,
max: len(shortPayloadBytes),
err: io.EOF,
bytes: len(shortPayloadBytes),
}, {
// Message with a bad checksum. [7]
{
badChecksumBytes,
pver,
dcrnet,
len(badChecksumBytes),
&MessageError{},
26,
},

// Message with a valid header, but wrong format. [8]
{
badMessageBytes,
AddrV2Version - 1,
dcrnet,
len(badMessageBytes),
&MessageError{},
25,
},

// 15k bytes of data to discard. [9]
{
discardBytes,
pver,
dcrnet,
len(discardBytes),
&MessageError{},
24,
},
}
name: "bad checksum",
buf: badChecksumBytes,
pver: pver,
dcrnet: dcrnet,
max: len(badChecksumBytes),
err: ErrPayloadChecksum,
bytes: len(badChecksumBytes),
}, {
name: "valid header, but wrong message body format",
buf: badMessageBytes,
pver: AddrV2Version - 1,
dcrnet: dcrnet,
max: len(badMessageBytes),
err: io.EOF,
bytes: len(badMessageBytes),
}, {
name: "valid header and message with extra trailing bytes",
buf: trailingBytes,
pver: pver,
dcrnet: dcrnet,
max: len(trailingBytes),
err: ErrTrailingBytes,
bytes: len(trailingBytes),
}}

t.Logf("Running %d tests", len(tests))
for i, test := range tests {
for _, test := range tests {
// Decode from wire format.
r := newFixedReader(test.max, test.buf)
nr, _, _, err := ReadMessageN(r, test.pver, test.dcrnet)
if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) {
t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
"want: %T", i, err, err, test.readErr)
if !errors.Is(err, test.err) {
t.Errorf("%q: wrong error: got %v <%[2]T>, want: %v <%[3]T>",
test.name, err, test.err)
continue
}

// Ensure the number of bytes written match the expected value.
// Ensure the number of bytes read matches the expected value.
if nr != test.bytes {
t.Errorf("ReadMessage #%d unexpected num bytes read - "+
"got %d, want %d", i, nr, test.bytes)
}

// For errors which are not of type MessageError, check them for
// equality.
var merr *MessageError
if !errors.As(err, &merr) {
if !errors.Is(err, test.readErr) {
t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
"want: %v <%T>", i, err, err,
test.readErr, test.readErr)
continue
}
t.Errorf("%q: unexpected num bytes read - got %d, want %d",
test.name, nr, test.bytes)
}
}
}
Expand Down