From 7b213b786a325d2a00d62b1b211a2aed4813d420 Mon Sep 17 00:00:00 2001 From: Josh Rickmar Date: Fri, 12 Jan 2024 21:41:14 +0000 Subject: [PATCH] wire-mix: review items --- wire/common.go | 8 + wire/common_test.go | 24 +++ wire/message_test.go | 6 +- wire/msgmixciphertexts.go | 47 +++--- wire/msgmixciphertexts_test.go | 189 +++++++++++++++++++--- wire/msgmixconfirm.go | 31 ++-- wire/msgmixconfirm_test.go | 167 ++++++++++++++++--- wire/msgmixdcnet.go | 123 +++++++------- wire/msgmixdcnet_test.go | 189 ++++++++++++++++++---- wire/msgmixkeyexchange.go | 41 +++-- wire/msgmixkeyexchange_test.go | 172 ++++++++++++++++---- wire/msgmixpairreq.go | 65 +++++--- wire/msgmixpairreq_test.go | 283 +++++++++++++++++++++++++++++---- wire/msgmixsecrets.go | 114 +++++++++---- wire/msgmixsecrets_test.go | 180 +++++++++++++++++---- wire/msgmixslotreserve.go | 47 +++--- wire/msgmixslotreserve_test.go | 198 +++++++++++++++++++---- 17 files changed, 1517 insertions(+), 367 deletions(-) diff --git a/wire/common.go b/wire/common.go index d9485f239f..59f254c5c3 100644 --- a/wire/common.go +++ b/wire/common.go @@ -517,6 +517,14 @@ func writeElement(w io.Writer, element interface{}) error { } return nil + // Mix identity + case *[33]byte: + _, err := w.Write(e[:]) + if err != nil { + return err + } + return nil + // Mix signature case *[64]byte: _, err := w.Write(e[:]) diff --git a/wire/common_test.go b/wire/common_test.go index 4901c67492..e4fdd8abfc 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -907,3 +907,27 @@ func TestRandomUint64Errors(t *testing.T) { t.Errorf("Nonce is not 0 [%v]", nonce) } } + +// repeat returns the byte slice containing count elements of the byte b. +func repeat(b byte, count int) []byte { + s := make([]byte, count) + for i := range s { + s[i] = b + } + return s +} + +// rhash returns a chainhash.Hash with all bytes set to b. +func rhash(b byte) chainhash.Hash { + var h chainhash.Hash + for i := range h { + h[i] = b + } + return h +} + +// varBytesLen returns the size required to encode l bytes as a varint +// followed by the bytes themselves. +func varBytesLen(l uint32) uint32 { + return uint32(VarIntSerializeSize(uint64(l))) + l +} diff --git a/wire/message_test.go b/wire/message_test.go index 9772c2558e..c14fd7b4e9 100644 --- a/wire/message_test.go +++ b/wire/message_test.go @@ -1,5 +1,5 @@ // Copyright (c) 2013-2016 The btcsuite developers -// Copyright (c) 2015-2023 The Decred developers +// Copyright (c) 2015-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -89,7 +89,7 @@ func TestMessage(t *testing.T) { msgMixSR := NewMsgMixSlotReserve([33]byte{}, [32]byte{}, 1, 1, [][][]byte{{{}}}, []chainhash.Hash{}) msgMixDC := NewMsgMixDCNet([33]byte{}, [32]byte{}, 1, 1, []MixVect{make(MixVect, 1)}, []chainhash.Hash{}) msgMixCM := NewMsgMixConfirm([33]byte{}, [32]byte{}, 1, 1, NewMsgTx(), []chainhash.Hash{}) - msgMixRS := NewMsgMixSecrets([33]byte{}, [32]byte{}, 1, 1, [32]byte{}, [][]byte{}, [][]byte{}) + msgMixRS := NewMsgMixSecrets([33]byte{}, [32]byte{}, 1, 1, [32]byte{}, [][]byte{}, MixVect{}) tests := []struct { in Message // Value to encode @@ -128,7 +128,7 @@ func TestMessage(t *testing.T) { {msgMixSR, msgMixSR, pver, MainNet, 165}, {msgMixDC, msgMixDC, pver, MainNet, 185}, {msgMixCM, msgMixCM, pver, MainNet, 177}, - {msgMixRS, msgMixRS, pver, MainNet, 209}, + {msgMixRS, msgMixRS, pver, MainNet, 196}, } t.Logf("Running %d tests", len(tests)) diff --git a/wire/msgmixciphertexts.go b/wire/msgmixciphertexts.go index e18a208bed..97695d71e8 100644 --- a/wire/msgmixciphertexts.go +++ b/wire/msgmixciphertexts.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -46,14 +46,14 @@ func (msg *MsgMixCiphertexts) BtcDecode(r io.Reader, pver uint32) error { return err } - // Count is of both Ciphertexts and seen KeyExchanges. + // Count is of both Ciphertexts and seen SeenKeyExchanges. count, err := ReadVarInt(r, pver) if err != nil { return err } - if count > MaxPrevMixMsgs { + if count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -117,14 +117,14 @@ func (msg *MsgMixCiphertexts) Hash() chainhash.Hash { // panic. This method is designed to work only with hashers returned by // blake256.New. func (msg *MsgMixCiphertexts) WriteHash(h hash.Hash) { - h.Reset() - writeElement(h, &msg.Signature) - msg.writeMessageNoSignature("", h, MixVersion) - sum := h.Sum(msg.hash[:0]) - if len(sum) != len(msg.hash) { + if h.Size() != chainhash.HashSize { s := fmt.Sprintf("hasher type %T has invalid Size() for chainhash.Hash", h) panic(s) } + h.Reset() + writeElement(h, &msg.Signature) + msg.writeMessageNoSignature("", h, MixVersion) + h.Sum(msg.hash[:0]) } // writeMessageNoSignature serializes all elements of the message except for @@ -136,23 +136,25 @@ func (msg *MsgMixCiphertexts) WriteHash(h hash.Hash) { func (msg *MsgMixCiphertexts) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { _, hashing := w.(hash.Hash) - err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, - msg.Run) - if err != nil { - return err - } - count := len(msg.Ciphertexts) if !hashing && count != len(msg.SeenKeyExchanges) { - msg := "differing counts of ciphertexts and seen key exchange messages" + msg := fmt.Sprintf("differing counts of ciphertexts (%d) "+ + "and seen key exchange messages (%d)", count, + len(msg.SeenKeyExchanges)) return messageError(op, ErrInvalidMsg, msg) } - if !hashing && count > MaxPrevMixMsgs { + if !hashing && count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, + msg.Run) + if err != nil { + return err + } + err = WriteVarInt(w, pver, uint64(count)) if err != nil { return err @@ -173,10 +175,10 @@ func (msg *MsgMixCiphertexts) writeMessageNoSignature(op string, w io.Writer, pv return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixCiphertexts) WriteSigned(h hash.Hash) { +func (msg *MsgMixCiphertexts) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixCiphertexts+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -190,6 +192,11 @@ func (msg *MsgMixCiphertexts) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixCiphertexts) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. return 552588 } diff --git a/wire/msgmixciphertexts_test.go b/wire/msgmixciphertexts_test.go index 6082e88781..1b18478036 100644 --- a/wire/msgmixciphertexts_test.go +++ b/wire/msgmixciphertexts_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,6 +6,8 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" @@ -13,27 +15,43 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixCTWire(t *testing.T) { - pver := MixVersion +// expectedSerializationCompare compares serialized bytes to the expected +// sequence of bytes. When got and expected are not equal, the test t will be +// errored with descriptive messages of how the two encodings are different. +// Returns true if the the serialization are equal, and false if the test +// errors. +func expectedSerializationEqual(t *testing.T, got, expected []byte) bool { + if bytes.Equal(got, expected) { + return true + } - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b + t.Errorf("encoded message differs from expected serialization") + minLen := len(expected) + if len(got) < minLen { + minLen = len(got) + } + for i := 0; i < minLen; i++ { + if b := got[i]; b != expected[i] { + t.Errorf("message differs at index %d (got 0x%x, expected 0x%x)", + i, b, expected[i]) } - return s } + if len(got) > len(expected) { + t.Errorf("serialized message contains extra bytes [%x]", + got[len(expected):]) + } + if len(expected) > len(got) { + t.Errorf("serialization prematurely ends at index %d, missing bytes [%x]", + len(got), expected[len(got):]) + } + return false +} - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixCiphertexts() *MsgMixCiphertexts { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) @@ -51,12 +69,41 @@ func TestMixCTWire(t *testing.T) { ct := NewMsgMixCiphertexts(id, sid, expiry, run, cts, seenKEs) ct.Signature = sig + return ct +} + +func TestMsgMixCiphertextsWire(t *testing.T) { + pver := MixVersion + + ct := newTestMixCiphertexts() + buf := new(bytes.Buffer) err := ct.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + // Varint count of ciphertexts and seen KEs + expected = append(expected, 0x04) + // Four ciphextexts (repeating 1047 bytes of 0x85, 0x86, 0x87, 0x88) + expected = append(expected, repeat(0x85, 1047)...) + expected = append(expected, repeat(0x86, 1047)...) + expected = append(expected, repeat(0x87, 1047)...) + expected = append(expected, repeat(0x88, 1047)...) + // Four seen KEs (repeating 32 bytes of 0x89, 0x8a, 0x8b, 0x8c) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, repeat(0x8a, 32)...) + expected = append(expected, repeat(0x8b, 32)...) + expected = append(expected, repeat(0x8c, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedCT := new(MsgMixCiphertexts) err = decodedCT.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -66,8 +113,108 @@ func TestMixCTWire(t *testing.T) { if !reflect.DeepEqual(ct, decodedCT) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedCT), spew.Sdump(ct)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedCT)) + } +} + +func TestMsgMixCiphertextsCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixCiphertexts() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixCiphertexts) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixCiphertextsMaxPayloadLength tests the results returned by +// [MsgMixCiphertexts.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixCiphertextsMaxPayloadLength(t *testing.T) { + var ct *MsgMixCiphertexts + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := ct.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + uint32(VarIntSerializeSize(MaxMixPeers)) + // Ciphextext and KE hash count + MaxMixPeers*1047 + // Ciphextexts + 32*MaxMixPeers // Key exchange hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := ct.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } } diff --git a/wire/msgmixconfirm.go b/wire/msgmixconfirm.go index 03de5d32d4..9b11a13eba 100644 --- a/wire/msgmixconfirm.go +++ b/wire/msgmixconfirm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -58,9 +58,9 @@ func (msg *MsgMixConfirm) BtcDecode(r io.Reader, pver uint32) error { if err != nil { return err } - if count > MaxPrevMixMsgs { + if count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -134,6 +134,13 @@ func (msg *MsgMixConfirm) WriteHash(h hash.Hash) { func (msg *MsgMixConfirm) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { _, hashing := w.(hash.Hash) + count := len(msg.SeenDCNets) + if !hashing && count > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + count, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) + } + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, msg.Run) if err != nil { @@ -145,13 +152,6 @@ func (msg *MsgMixConfirm) writeMessageNoSignature(op string, w io.Writer, pver u return err } - count := len(msg.SeenDCNets) - if !hashing && count > MaxPrevMixMsgs { - msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) - return messageError(op, ErrTooManyPrevMixMsgs, msg) - } - err = WriteVarInt(w, pver, uint64(count)) if err != nil { return err @@ -166,10 +166,10 @@ func (msg *MsgMixConfirm) writeMessageNoSignature(op string, w io.Writer, pver u return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixConfirm) WriteSigned(h hash.Hash) { +func (msg *MsgMixConfirm) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixConfirm+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -183,7 +183,12 @@ func (msg *MsgMixConfirm) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixConfirm) MaxPayloadLength(pver uint32) uint32 { - return 16539 + MaxBlockPayloadV3 + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 1016524 } // Pub returns the message sender's public key identity. diff --git a/wire/msgmixconfirm_test.go b/wire/msgmixconfirm_test.go index 9f95754268..74d38714db 100644 --- a/wire/msgmixconfirm_test.go +++ b/wire/msgmixconfirm_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,6 +6,8 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" @@ -13,27 +15,11 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixCMWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixConfirm() *MsgMixConfirm { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) @@ -48,12 +34,43 @@ func TestMixCMWire(t *testing.T) { cm := NewMsgMixConfirm(id, sid, expiry, run, mix, seenDCs) cm.Signature = sig + return cm +} + +func TestMsgMixConfirmWire(t *testing.T) { + pver := MixVersion + + cm := newTestMixConfirm() + buf := new(bytes.Buffer) err := cm.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + expected = append(expected, []byte{ // Mix transaction + 0x01, 0x00, 0x00, 0x00, // Version + 0x00, // Varint for number of input transactions + 0x00, // Varint for number of output transactions + 0x00, 0x00, 0x00, 0x00, // Lock time + 0x00, 0x00, 0x00, 0x00, // Expiry + 0x00, // Varint for number of input signatures + }...) + // Four seen DCs (repeating 32 bytes of 0x85, 0x86, 0x87, 0x88) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, repeat(0x88, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedCM := new(MsgMixConfirm) err = decodedCM.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -63,8 +80,108 @@ func TestMixCMWire(t *testing.T) { if !reflect.DeepEqual(cm, decodedCM) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedCM), spew.Sdump(cm)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedCM)) + } +} + +func TestMsgMixConfirmCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixConfirm() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixConfirm) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixConfirmMaxPayloadLength tests the results returned by +// [MsgMixConfirm.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixConfirmMaxPayloadLength(t *testing.T) { + var cm *MsgMixConfirm + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := cm.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + MaxBlockPayloadV3 + // Maximum transaction size + uint32(VarIntSerializeSize(MaxMixPeers)) + // DC-net count + 32*MaxMixPeers // DC-net hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := cm.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } } diff --git a/wire/msgmixdcnet.go b/wire/msgmixdcnet.go index 317a51c855..7b91c7feed 100644 --- a/wire/msgmixdcnet.go +++ b/wire/msgmixdcnet.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -44,32 +44,20 @@ func (msg *MsgMixDCNet) BtcDecode(r io.Reader, pver uint32) error { return err } - mcount, err := ReadVarInt(r, pver) + var dcnet []MixVect + err = readMixVects(op, r, pver, &dcnet) if err != nil { return err } - if mcount > MaxMixMcount { - msg := fmt.Sprintf("too many total mixed messages [count %v, max %v]", - mcount, MaxMixMcount) - return messageError(op, ErrInvalidMsg, msg) - } - - dcnet := make([]MixVect, mcount) - for i := range dcnet { - err := readMixVect(op, r, pver, &dcnet[i]) - if err != nil { - return err - } - } msg.DCNet = dcnet count, err := ReadVarInt(r, pver) if err != nil { return err } - if count > MaxPrevMixMsgs { + if count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -143,12 +131,6 @@ func (msg *MsgMixDCNet) WriteHash(h hash.Hash) { func (msg *MsgMixDCNet) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { _, hashing := w.(hash.Hash) - err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, - msg.Run) - if err != nil { - return err - } - mcount := len(msg.DCNet) if !hashing && mcount == 0 { msg := fmt.Sprintf("too few mixed messages [%v]", mcount) @@ -158,26 +140,25 @@ func (msg *MsgMixDCNet) writeMessageNoSignature(op string, w io.Writer, pver uin msg := fmt.Sprintf("too many total mixed messages [%v]", mcount) return messageError(op, ErrInvalidMsg, msg) } - err = WriteVarInt(w, pver, uint64(mcount)) - if err != nil { - return err + srcount := len(msg.SeenSlotReserves) + if !hashing && srcount > MaxMixPeers { + msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", + srcount, MaxMixPeers) + return messageError(op, ErrTooManyPrevMixMsgs, msg) } - for i := range msg.DCNet { - err := writeMixVect(w, pver, msg.DCNet[i]) - if err != nil { - return err - } + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, + msg.Run) + if err != nil { + return err } - count := len(msg.SeenSlotReserves) - if !hashing && count > MaxPrevMixMsgs { - msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) - return messageError(op, ErrTooManyPrevMixMsgs, msg) + err = writeMixVects(w, pver, msg.DCNet) + if err != nil { + return err } - err = WriteVarInt(w, pver, uint64(count)) + err = WriteVarInt(w, pver, uint64(srcount)) if err != nil { return err } @@ -191,8 +172,16 @@ func (msg *MsgMixDCNet) writeMessageNoSignature(op string, w io.Writer, pver uin return nil } -func writeMixVect(w io.Writer, pver uint32, vec MixVect) error { - err := WriteVarInt(w, pver, uint64(len(vec))) +func writeMixVects(w io.Writer, pver uint32, vecs []MixVect) error { + // Write dimensions + err := WriteVarInt(w, pver, uint64(len(vecs))) + if err != nil { + return err + } + if len(vecs) == 0 { + return nil + } + err = WriteVarInt(w, pver, uint64(len(vecs[0]))) if err != nil { return err } @@ -201,49 +190,66 @@ func writeMixVect(w io.Writer, pver uint32, vec MixVect) error { return err } - for i := range vec { - err = writeElement(w, &vec[i]) - if err != nil { - return err + // Write messages + for i := range vecs { + for j := range vecs[i] { + err = writeElement(w, &vecs[i][j]) + if err != nil { + return err + } } } return nil } -func readMixVect(op string, r io.Reader, pver uint32, vec *MixVect) error { - n, err := ReadVarInt(r, pver) +func readMixVects(op string, r io.Reader, pver uint32, vecs *[]MixVect) error { + // Read dimensions + x, err := ReadVarInt(r, pver) if err != nil { return err } - if n > MaxMixKPCount { - msg := "too many mixing peers" - return messageError(op, ErrInvalidMsg, msg) + if x == 0 { + return nil + } + y, err := ReadVarInt(r, pver) + if err != nil { + return err } msize, err := ReadVarInt(r, pver) if err != nil { return err } + + if x > MaxMixMcount || y > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } if msize != MixMsgSize { msg := fmt.Sprintf("mixed message length must be %d [got: %d]", MixMsgSize, msize) return messageError(op, ErrInvalidMsg, msg) } - *vec = make(MixVect, n) - for i := uint64(0); i < n; i++ { - err = readElement(r, &(*vec)[i]) - if err != nil { - return err + + // Read messages + *vecs = make([]MixVect, x) + for i := uint64(0); i < x; i++ { + (*vecs)[i] = make(MixVect, y) + for j := uint64(0); j < y; j++ { + err = readElement(r, &(*vecs)[i][j]) + if err != nil { + return err + } } } return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixDCNet) WriteSigned(h hash.Hash) { +func (msg *MsgMixDCNet) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixDCNet+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -257,7 +263,12 @@ func (msg *MsgMixDCNet) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixDCNet) MaxPayloadLength(pver uint32) uint32 { - return 16800911 + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. + return 20988051 } // Pub returns the message sender's public key identity. diff --git a/wire/msgmixdcnet_test.go b/wire/msgmixdcnet_test.go index 974d498534..435acbcce0 100644 --- a/wire/msgmixdcnet_test.go +++ b/wire/msgmixdcnet_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,6 +6,8 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" @@ -13,40 +15,24 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixDCWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixDCNet() *MsgMixDCNet { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) - mcount := 4 - var kpcount uint32 = 4 + const mcount = 4 + const kpcount = 4 dcnet := make([]MixVect, mcount) // will add 4x4 field numbers of incrementing repeating byte values to // dcnet, ranging from 0x85 through 0x94 b := byte(0x85) for i := 0; i < mcount; i++ { dcnet[i] = make(MixVect, kpcount) - for j := 0; j < int(kpcount); j++ { + for j := 0; j < kpcount; j++ { copy(dcnet[i][j][:], repeat(b, 32)) b++ } @@ -60,12 +46,56 @@ func TestMixDCWire(t *testing.T) { dc := NewMsgMixDCNet(id, sid, expiry, run, dcnet, seenSRs) dc.Signature = sig + return dc +} + +func TestMsgMixDCNetWire(t *testing.T) { + pver := MixVersion + + dc := newTestMixDCNet() + buf := new(bytes.Buffer) err := dc.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + // DC-net dimensions (4x4, message size of 20) + expected = append(expected, 0x04) + expected = append(expected, 0x04) + expected = append(expected, 0x14) // msize + // 16 padded messages + expected = append(expected, repeat(0x85, 20)...) + expected = append(expected, repeat(0x86, 20)...) + expected = append(expected, repeat(0x87, 20)...) + expected = append(expected, repeat(0x88, 20)...) + expected = append(expected, repeat(0x89, 20)...) + expected = append(expected, repeat(0x8a, 20)...) + expected = append(expected, repeat(0x8b, 20)...) + expected = append(expected, repeat(0x8c, 20)...) + expected = append(expected, repeat(0x8d, 20)...) + expected = append(expected, repeat(0x8e, 20)...) + expected = append(expected, repeat(0x8f, 20)...) + expected = append(expected, repeat(0x90, 20)...) + expected = append(expected, repeat(0x91, 20)...) + expected = append(expected, repeat(0x92, 20)...) + expected = append(expected, repeat(0x93, 20)...) + expected = append(expected, repeat(0x94, 20)...) + // Four seen DCs (repeating 32 bytes of 0x95, 0x96, 0x97, 0x98) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x95, 32)...) + expected = append(expected, repeat(0x96, 32)...) + expected = append(expected, repeat(0x97, 32)...) + expected = append(expected, repeat(0x98, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedDC := new(MsgMixDCNet) err = decodedDC.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -75,8 +105,111 @@ func TestMixDCWire(t *testing.T) { if !reflect.DeepEqual(dc, decodedDC) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedDC), spew.Sdump(dc)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedDC)) + } +} + +func TestMsgMixDCNetCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixDCNet() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixDCNet) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixDCNetMaxPayloadLength tests the results returned by +// [MsgMixDCNet.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixDCNetMaxPayloadLength(t *testing.T) { + var dc *MsgMixDCNet + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := dc.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count (our mcount) + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count (total) + uint32(VarIntSerializeSize(MixMsgSize)) + // Message size + MaxMixMcount*MaxMixMcount*MixMsgSize + // Padded DC-net values + uint32(VarIntSerializeSize(MaxMixPeers)) + // Slot reserve count + 32*MaxMixPeers // Slot reserve hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := dc.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } } diff --git a/wire/msgmixkeyexchange.go b/wire/msgmixkeyexchange.go index 204cb2ca7e..88e573a583 100644 --- a/wire/msgmixkeyexchange.go +++ b/wire/msgmixkeyexchange.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -13,9 +13,13 @@ import ( ) const ( - // MaxPrevMixMsgs is the maximum number previous messages of a mix run - // that may be referenced by a message. - MaxPrevMixMsgs = 512 // XXX: PNOOMA + // MaxMixKPCount is the maximum number of peers allowed together in a + // single mix. This restricts the maximum dimensions of the slot + // reservation and XOR DC-net matrices and the maximum number of + // previous messages that may be referenced by mix messages. + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixPeers = 512 ) // MsgMixKeyExchange implements the Message interface and represents a mixing key @@ -59,9 +63,9 @@ func (msg *MsgMixKeyExchange) BtcDecode(r io.Reader, pver uint32) error { if err != nil { return err } - if count > MaxPrevMixMsgs { + if count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -135,20 +139,20 @@ func (msg *MsgMixKeyExchange) WriteHash(h hash.Hash) { func (msg *MsgMixKeyExchange) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { _, hashing := w.(hash.Hash) - err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, - msg.Run, &msg.ECDH, &msg.PQPK, &msg.Commitment) - if err != nil { - return err - } - // Limit to max previous messages hashes. count := len(msg.SeenPRs) - if !hashing && count > MaxPrevMixMsgs { + if !hashing && count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } + err := writeElements(w, &msg.Identity, &msg.SessionID, msg.Expiry, + msg.Run, &msg.ECDH, &msg.PQPK, &msg.Commitment) + if err != nil { + return err + } + err = WriteVarInt(w, pver, uint64(count)) if err != nil { return err @@ -163,10 +167,10 @@ func (msg *MsgMixKeyExchange) writeMessageNoSignature(op string, w io.Writer, pv return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixKeyExchange) WriteSigned(h hash.Hash) { +func (msg *MsgMixKeyExchange) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixKeyExchange+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -180,6 +184,11 @@ func (msg *MsgMixKeyExchange) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixKeyExchange) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. return 17807 } diff --git a/wire/msgmixkeyexchange_test.go b/wire/msgmixkeyexchange_test.go index 783b05b93c..53f5044f6b 100644 --- a/wire/msgmixkeyexchange_test.go +++ b/wire/msgmixkeyexchange_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,6 +6,8 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" @@ -13,39 +15,18 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixKEWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixKeyExchange() *MsgMixKeyExchange { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) - var ecdh [33]byte - copy(ecdh[:], repeat(0x85, 33)) - - var pqpk [1218]byte - copy(pqpk[:], repeat(0x86, 1218)) - - var commitment [32]byte - copy(commitment[:], repeat(0x87, 32)) + ecdh := *(*[33]byte)(repeat(0x85, 33)) + pqpk := *(*[1218]byte)(repeat(0x86, 1218)) + commitment := *(*[32]byte)(repeat(0x87, 32)) seenPRs := make([]chainhash.Hash, 4) for b := byte(0x88); b < 0x8C; b++ { @@ -55,12 +36,38 @@ func TestMixKEWire(t *testing.T) { ke := NewMsgMixKeyExchange(id, sid, expiry, run, ecdh, pqpk, commitment, seenPRs) ke.Signature = sig + return ke +} + +func TestMsgMixKeyExchangeWire(t *testing.T) { + pver := MixVersion + + ke := newTestMixKeyExchange() + buf := new(bytes.Buffer) err := ke.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + expected = append(expected, repeat(0x85, 33)...) // ECDH public key + expected = append(expected, repeat(0x86, 1218)...) // PQ public key + expected = append(expected, repeat(0x87, 32)...) // Secrets commitment + // Four seen PRs (repeating 32 bytes of 0x88, 0x89, 0x8a, 0x8b) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, repeat(0x8a, 32)...) + expected = append(expected, repeat(0x8b, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedKE := new(MsgMixKeyExchange) err = decodedKE.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -75,3 +82,108 @@ func TestMixKEWire(t *testing.T) { t.Logf("spew: %s", spew.Sdump(decodedKE)) } } + +func TestMsgMixKeyExchangeCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixKeyExchange() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixKeyExchange) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixKeyExchangeMaxPayloadLength tests the results returned by +// [MsgMixKeyExchange.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixKeyExchangeMaxPayloadLength(t *testing.T) { + var ke *MsgMixKeyExchange + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := ke.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + 33 + // ECDH public key + 1218 + // sntrup4591761 public key + 32 + // Secrets commitment + uint32(VarIntSerializeSize(MaxMixPeers)) + // Pair request count + 32*MaxMixPeers // Pair request hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := ke.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) + } +} diff --git a/wire/msgmixpairreq.go b/wire/msgmixpairreq.go index 7e8b8300ed..4b6c74e98f 100644 --- a/wire/msgmixpairreq.go +++ b/wire/msgmixpairreq.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -20,11 +20,12 @@ const ( // MaxMixPairReqUTXOs is the maximum number of unspent transaction // outputs that may be contributed in a single mixpairreq message. - MaxMixPairReqUTXOs = 512 // XXX: PNOOMA + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixPairReqUTXOs = 512 // MaxMixPairReqUTXOScriptLen is the maximum length allowed for the // unhashed P2SH script of a UTXO ownership proof. - // XXX: might want to limit this to standard script sizes MaxMixPairReqUTXOScriptLen = 16384 // txscript.MaxScriptSize // MaxMixPairReqUTXOPubKeyLen is the maximum length allowed for the @@ -71,10 +72,15 @@ type MsgMixPairReq struct { } // Pairing returns a description of the coinjoin transaction being created. -// Different mixpairreq messages area compatible to perform a mix together if +// Different mixpairreq messages are compatible to perform a mix together if // their pairing descriptions are identical. func (msg *MsgMixPairReq) Pairing() ([]byte, error) { - w := bytes.NewBuffer(make([]byte, 0, 8+32+2+4)) + bufLen := 8 + // Mix amount + VarIntSerializeSize(uint64(len(msg.ScriptClass))) + // Script class + len(msg.ScriptClass) + + 2 + // Tx version + 4 // Locktime + w := bytes.NewBuffer(make([]byte, 0, bufLen)) err := writeElement(w, msg.MixAmount) if err != nil { @@ -110,6 +116,11 @@ func (msg *MsgMixPairReq) BtcDecode(r io.Reader, pver uint32) error { return err } + if msg.MixAmount < 0 { + msg := "mixing pair request contains negative mixed amount" + return messageError(op, ErrInvalidMsg, msg) + } + sc, err := ReadAsciiVarString(r, pver, MaxMixPairReqScriptClassLen) if err != nil { return err @@ -122,6 +133,11 @@ func (msg *MsgMixPairReq) BtcDecode(r io.Reader, pver uint32) error { return err } + if msg.InputValue < 0 { + msg := "mixing pair request contains negative input value" + return messageError(op, ErrInvalidMsg, msg) + } + count, err := ReadVarInt(r, pver) if err != nil { return err @@ -243,6 +259,27 @@ func (msg *MsgMixPairReq) WriteHash(h hash.Hash) { func (msg *MsgMixPairReq) writeMessageNoSignature(op string, w io.Writer, pver uint32) error { _, hashing := w.(hash.Hash) + // Require script class to be strict ASCII and not exceed the maximum + // length. + lenScriptClass := len(msg.ScriptClass) + if lenScriptClass > MaxMixPairReqScriptClassLen { + msg := fmt.Sprintf("script class length is too long "+ + "[len %d, max %d]", lenScriptClass, + MaxMixPairReqScriptClassLen) + return messageError(op, ErrMixPairReqScriptClassTooLong, msg) + } + if !isStrictAscii(msg.ScriptClass) { + msg := "script class string is not strict ASCII" + return messageError(op, ErrMalformedStrictString, msg) + } + + // Limit to max UTXOs per message. + count := len(msg.UTXOs) + if !hashing && count > MaxMixPairReqUTXOs { + msg := fmt.Sprintf("too many UTXOs in message [%v]", count) + return messageError(op, ErrTooManyMixPairReqUTXOs, msg) + } + err := writeElements(w, &msg.Identity, msg.Expiry, msg.MixAmount) if err != nil { return err @@ -259,13 +296,6 @@ func (msg *MsgMixPairReq) writeMessageNoSignature(op string, w io.Writer, pver u return err } - // Limit to max UTXOs per message. - count := len(msg.UTXOs) - if !hashing && count > MaxMixPairReqUTXOs { - msg := fmt.Sprintf("too many UTXOs in message [%v]", count) - return messageError(op, ErrTooManyMixPairReqUTXOs, msg) - } - err = WriteVarInt(w, pver, uint64(count)) if err != nil { return err @@ -327,10 +357,10 @@ func (msg *MsgMixPairReq) writeMessageNoSignature(op string, w io.Writer, pver u return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixPairReq) WriteSigned(h hash.Hash) { +func (msg *MsgMixPairReq) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixPairReq+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -348,9 +378,8 @@ func (msg *MsgMixPairReq) MaxPayloadLength(pver uint32) uint32 { return 0 } - // Pair requests contain a transaction, and the maximum transaction - // serialization size is limited to the max block payload. - return MaxBlockPayload + // See tests for this calculation. + return 8476336 } // Pub returns the message sender's public key identity. @@ -400,7 +429,7 @@ func NewMsgMixPairReq(identity [33]byte, expiry uint32, mixAmount int64, } if !isStrictAscii(scriptClass) { - msg := "individual initial state type is not strict ASCII" + msg := "script class string is not strict ASCII" return nil, messageError(op, ErrMalformedStrictString, msg) } diff --git a/wire/msgmixpairreq_test.go b/wire/msgmixpairreq_test.go index f6db8f0acd..36ac28f7cf 100644 --- a/wire/msgmixpairreq_test.go +++ b/wire/msgmixpairreq_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,38 +6,37 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" "github.com/davecgh/go-spew/spew" - "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixPairReqWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - rhash := func(b byte) chainhash.Hash { - var h chainhash.Hash - for i := range h { - h[i] = b - } - return h - } +type mixPairReqArgs struct { + identity [33]byte + signature [64]byte + expiry uint32 + mixAmount int64 + scriptClass string + txVersion uint16 + lockTime, messageCount uint32 + inputValue int64 + utxos []MixPairReqUTXO + change *TxOut +} - // Create a fictitious message with easily-distinguishable fields. +func (a *mixPairReqArgs) msg() (*MsgMixPairReq, error) { + return NewMsgMixPairReq(a.identity, a.expiry, a.mixAmount, a.scriptClass, + a.txVersion, a.lockTime, a.messageCount, a.inputValue, a.utxos, a.change) +} - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) +func newMixPairReqArgs() *mixPairReqArgs { + // Use easily-distinguishable fields. - var id [33]byte - copy(id[:], repeat(0x81, 33)) + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) const expiry = uint32(0x82828282) const mixAmount = int64(0x0383838383838383) @@ -74,12 +73,32 @@ func TestMixPairReqWire(t *testing.T) { pkScript := repeat(0x94, 25) change := NewTxOut(changeValue, pkScript) - pr, err := NewMsgMixPairReq(id, expiry, mixAmount, sc, txVersion, lockTime, - messageCount, inputValue, utxos, change) + return &mixPairReqArgs{ + identity: id, + signature: sig, + expiry: expiry, + mixAmount: mixAmount, + scriptClass: sc, + txVersion: txVersion, + lockTime: lockTime, + messageCount: messageCount, + inputValue: inputValue, + utxos: utxos, + change: change, + } +} + +func TestMsgMixPairReqWire(t *testing.T) { + t.Parallel() + + pver := MixVersion + + a := newMixPairReqArgs() + pr, err := a.msg() if err != nil { t.Fatal(err) } - pr.Signature = sig + pr.Signature = a.signature buf := new(bytes.Buffer) err = pr.BtcEncode(buf, pver) @@ -87,6 +106,53 @@ func TestMixPairReqWire(t *testing.T) { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 4)...) // Expiry + expected = append(expected, 0x83, 0x83, 0x83, 0x83, // Amount + 0x83, 0x83, 0x83, 0x03) + expected = append(expected, byte(len("P2PKH-secp256k1-v0"))) // Script class + expected = append(expected, []byte("P2PKH-secp256k1-v0")...) + expected = append(expected, 0x84, 0x84) // Tx version + expected = append(expected, repeat(0x85, 4)...) // Locktime + expected = append(expected, repeat(0x86, 4)...) // Message count + expected = append(expected, 0x87, 0x87, 0x87, 0x87, // Input value + 0x87, 0x87, 0x87, 0x07) + expected = append(expected, 0x02) // UTXO count + // First UTXO 8888888888888888888888888888888888888888888888888888888888888888:0x89898989 + expected = append(expected, repeat(0x88, 32)...) // Hash + expected = append(expected, repeat(0x89, 4)...) // Index + expected = append(expected, 0x0a) // Tree + expected = append(expected, 0x00) // Zero-length P2SH redeem script + expected = append(expected, 0x21) // 33-byte pubkey + expected = append(expected, repeat(0x8b, 33)...) + expected = append(expected, 0x40) // 64-byte signature + expected = append(expected, repeat(0x8c, 64)...) + // Second UTXO 8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d:0x8e8e8e8e + expected = append(expected, repeat(0x8d, 32)...) // Hash + expected = append(expected, repeat(0x8e, 4)...) // Index + expected = append(expected, 0x0f) // Tree + expected = append(expected, 0x19) // 25-byte P2SH redeem script + expected = append(expected, repeat(0x90, 25)...) + expected = append(expected, 0x21) // 33-byte pubkey + expected = append(expected, repeat(0x91, 33)...) + expected = append(expected, 0x40) // 64-byte signature + expected = append(expected, repeat(0x92, 64)...) + // Change output + expected = append(expected, 0x01) // Has change = true + expected = append(expected, []byte{ + 0x93, 0x93, 0x93, 0x93, 0x93, 0x93, 0x93, 0x13, // Amount + 0x00, 0x00, // Version + 0x19, // 25-byte Pkscript + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, 0x94, + 0x94, + }...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedPR := new(MsgMixPairReq) err = decodedPR.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -96,8 +162,165 @@ func TestMixPairReqWire(t *testing.T) { if !reflect.DeepEqual(pr, decodedPR) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedPR), spew.Sdump(pr)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedPR)) + } +} + +func TestNewMixPairReqErrs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + modArgs func(*mixPairReqArgs) + err error + }{{ + name: "LongScriptClass", + modArgs: func(a *mixPairReqArgs) { + a.scriptClass = "scriptclassthatexceedsmaximumlength" + }, + err: ErrMixPairReqScriptClassTooLong, + }, { + name: "NonAsciiScriptClass", + modArgs: func(a *mixPairReqArgs) { + a.scriptClass = string([]byte{128}) + }, + err: ErrMalformedStrictString, + }, { + name: "TooManyUTXOs", + modArgs: func(a *mixPairReqArgs) { + a.utxos = make([]MixPairReqUTXO, MaxMixPairReqUTXOs+1) + }, + err: ErrTooManyMixPairReqUTXOs, + }} + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + a := newMixPairReqArgs() + tc.modArgs(a) + _, err := a.msg() + if !errors.Is(err, tc.err) { + t.Errorf("expected error %v; got %v", tc.err, err) + } + }) + } +} + +func TestMsgMixPairReqCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + a := newMixPairReqArgs() + msg, err := a.msg() + if err != nil { + t.Fatalf("%v", err) + } + + buf := new(bytes.Buffer) + err = msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixPairReq) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixPairReqMaxPayloadLength tests the results returned by +// [MsgMixPairReq.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixPairReqMaxPayloadLength(t *testing.T) { + var pr *MsgMixPairReq + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := pr.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var maxUTXOLen uint32 = 32 + // Hash + 4 + // Index + 1 + // Tree + varBytesLen(MaxMixPairReqUTXOScriptLen) + // P2SH redeem script + varBytesLen(33) + // Pubkey + varBytesLen(64) // Signature + var maxTxOutLen uint32 = 8 + // Value + 2 + // Version + varBytesLen(16384) // PkScript (txscript.MaxScriptLen) + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 4 + // Expiry + 8 + // Amount + varBytesLen(MaxMixPairReqScriptClassLen) + // Script class + 2 + // Tx version + 4 + // Locktime + 4 + // Message count + 8 + // Input value + uint32(VarIntSerializeSize(MaxMixPairReqUTXOs)) + // UTXO count + MaxMixPairReqUTXOs*maxUTXOLen + // UTXOs + maxTxOutLen // Change output + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := pr.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } } diff --git a/wire/msgmixsecrets.go b/wire/msgmixsecrets.go index 5e79dc872a..f962e05848 100644 --- a/wire/msgmixsecrets.go +++ b/wire/msgmixsecrets.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -25,7 +25,7 @@ type MsgMixSecrets struct { Run uint32 Seed [32]byte SlotReserveMsgs [][]byte - DCNetMsgs [][]byte + DCNetMsgs MixVect // hash records the hash of the message. It is a member of the // message for convenience and performance, but is never automatically @@ -33,6 +33,70 @@ type MsgMixSecrets struct { hash chainhash.Hash } +func writeMixVect(op string, w io.Writer, pver uint32, vec MixVect) error { + if len(vec) > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } + + // Write dimensions + err := WriteVarInt(w, pver, uint64(len(vec))) + if err != nil { + return err + } + err = WriteVarInt(w, pver, MixMsgSize) + if err != nil { + return err + } + + // Write messages + for i := range vec { + err = writeElement(w, &vec[i]) + if err != nil { + return err + } + } + + return nil +} + +func readMixVect(op string, r io.Reader, pver uint32, vec *MixVect) error { + // Read dimensions + n, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if n == 0 { + *vec = MixVect{} + return nil + } + msize, err := ReadVarInt(r, pver) + if err != nil { + return err + } + + if n > MaxMixMcount { + msg := "DC-net mix vector dimensions are too large for maximum message count" + return messageError(op, ErrInvalidMsg, msg) + } + if msize != MixMsgSize { + msg := fmt.Sprintf("mixed message length must be %d [got: %d]", + MixMsgSize, msize) + return messageError(op, ErrInvalidMsg, msg) + } + + // Read messages + *vec = make(MixVect, n) + for i := uint64(0); i < n; i++ { + err = readElement(r, &(*vec)[i]) + if err != nil { + return err + } + } + + return nil +} + // BtcDecode decodes r using the Decred protocol encoding into the receiver. // This is part of the Message interface implementation. func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error { @@ -49,8 +113,7 @@ func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error { return err } - var numSRs uint64 - err = readElement(r, &numSRs) + numSRs, err := ReadVarInt(r, pver) if err != nil { return err } @@ -61,31 +124,25 @@ func (msg *MsgMixSecrets) BtcDecode(r io.Reader, pver uint32) error { } msg.SlotReserveMsgs = make([][]byte, numSRs) for i := uint64(0); i < numSRs; i++ { - sr, err := ReadVarBytes(r, pver, MaxMixFieldValLen, "SR") + sr, err := ReadVarBytes(r, pver, MaxMixFieldValLen, + "slot reservation mixed message") if err != nil { return err } msg.SlotReserveMsgs[i] = sr } - var numMs uint64 - err = readElement(r, &numMs) + var dcnetMsgs MixVect + err = readMixVect(op, r, pver, &dcnetMsgs) if err != nil { return err } - if numMs > MaxMixMcount { + if len(dcnetMsgs) > MaxMixMcount { msg := fmt.Sprintf("too many total mixed messages [count %v, max %v]", - numMs, MaxMixMcount) + len(dcnetMsgs), MaxMixMcount) return messageError(op, ErrInvalidMsg, msg) } - msg.DCNetMsgs = make([][]byte, numMs) - for i := uint64(0); i < numMs; i++ { - m, err := ReadVarBytes(r, pver, MaxMixFieldValLen, "M") - if err != nil { - return err - } - msg.DCNetMsgs[i] = m - } + msg.DCNetMsgs = dcnetMsgs return nil } @@ -151,7 +208,7 @@ func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver u return err } - err = writeElement(w, uint64(len(msg.SlotReserveMsgs))) + err = WriteVarInt(w, pver, uint64(len(msg.SlotReserveMsgs))) if err != nil { return err } @@ -162,24 +219,18 @@ func (msg *MsgMixSecrets) writeMessageNoSignature(op string, w io.Writer, pver u } } - err = writeElement(w, uint64(len(msg.DCNetMsgs))) + err = writeMixVect(op, w, pver, msg.DCNetMsgs) if err != nil { return err } - for _, m := range msg.DCNetMsgs { - err := WriteVarBytes(w, pver, m) - if err != nil { - return err - } - } return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixSecrets) WriteSigned(h hash.Hash) { +func (msg *MsgMixSecrets) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixSecrets+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -193,7 +244,12 @@ func (msg *MsgMixSecrets) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixSecrets) MaxPayloadLength(pver uint32) uint32 { - return 67769 + if pver < MixVersion { + return 0 + } + + // See tests for this calculation + return 54448 } // Pub returns the message sender's public key identity. @@ -234,7 +290,7 @@ func (msg *MsgMixSecrets) GetRun() uint32 { // Message interface using the passed parameters and defaults for the // remaining fields. func NewMsgMixSecrets(identity [33]byte, sid [32]byte, expiry uint32, run uint32, - seed [32]byte, slotReserveMsgs [][]byte, dcNetMsgs [][]byte) *MsgMixSecrets { + seed [32]byte, slotReserveMsgs [][]byte, dcNetMsgs MixVect) *MsgMixSecrets { return &MsgMixSecrets{ Identity: identity, diff --git a/wire/msgmixsecrets_test.go b/wire/msgmixsecrets_test.go index c2a0fe8cfe..21ad6fd531 100644 --- a/wire/msgmixsecrets_test.go +++ b/wire/msgmixsecrets_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,59 +6,78 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" "github.com/davecgh/go-spew/spew" ) -func TestMixRSWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixSecrets() *MsgMixSecrets { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) - var seed [32]byte - copy(seed[:], repeat(0x85, 32)) + seed := *(*[32]byte)(repeat(0x85, 32)) sr := make([][]byte, 4) for b := byte(0x86); b < 0x8A; b++ { sr[b-0x86] = repeat(b, 32) } - m := make([][]byte, 4) + m := make(MixVect, 4) for b := byte(0x8A); b < 0x8E; b++ { - m[b-0x8A] = repeat(b, 32) + copy(m[b-0x8A][:], repeat(b, 20)) } rs := NewMsgMixSecrets(id, sid, expiry, run, seed, sr, m) rs.Signature = sig + return rs +} + +func TestMsgMixSecretsWire(t *testing.T) { + pver := MixVersion + + rs := newTestMixSecrets() + buf := new(bytes.Buffer) err := rs.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + expected = append(expected, repeat(0x85, 32)...) // Seed + // Four slot reservation mixed messages (repeating 32 bytes of 0x86, 0x87, 0x88, 0x89) + expected = append(expected, 0x04) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x89, 32)...) + // Four slot reservation mixed messages (repeating 20 bytes of 0x8a, 0x8b, 0x8c, 0x8d) + expected = append(expected, 0x04, 0x14) + expected = append(expected, repeat(0x8a, 20)...) + expected = append(expected, repeat(0x8b, 20)...) + expected = append(expected, repeat(0x8c, 20)...) + expected = append(expected, repeat(0x8d, 20)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedRS := new(MsgMixSecrets) err = decodedRS.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -68,8 +87,111 @@ func TestMixRSWire(t *testing.T) { if !reflect.DeepEqual(rs, decodedRS) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedRS), spew.Sdump(rs)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedRS)) + } +} + +func TestMsgMixSecretsCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixSecrets() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixSecrets) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixSecretsMaxPayloadLength tests the results returned by +// [MsgMixSecrets.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixSecretsMaxPayloadLength(t *testing.T) { + var rs *MsgMixSecrets + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := rs.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + 32 + // Seed + uint32(VarIntSerializeSize(MaxMixMcount)) + // SR message count + MaxMixMcount*varBytesLen(MaxMixFieldValLen) + // Unpadded SR values + uint32(VarIntSerializeSize(MaxMixMcount)) + // DC-net message count + uint32(VarIntSerializeSize(MixMsgSize)) + // DC-net message size + MaxMixMcount*MixMsgSize // DC-net messages + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := rs.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } } diff --git a/wire/msgmixslotreserve.go b/wire/msgmixslotreserve.go index 5ed889df0e..e440865c6f 100644 --- a/wire/msgmixslotreserve.go +++ b/wire/msgmixslotreserve.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -16,12 +16,9 @@ const ( // MaxMixMcount is the maximum number of mixed messages that are allowed // in a single mix. This restricts the total allowed size of the slot // reservation and XOR DC-net matrices. - MaxMixMcount = 1024 // XXX: PNOOMA - - // MaxMixKPCount is the maximum number of peers allowed together in a - // single mix. This restricts the total size of the slot reservation - // and XOR DC-net matrices. - MaxMixKPCount = 512 // XXX: PNOOMA + // This value is an high estimate of what a large mix may look like, + // based on statistics from the centralized mixing server. + MaxMixMcount = 1024 // MaxMixFieldValLen is the maximum number of bytes allowed to represent // a value in the slot reservation mix bounded by the field prime. @@ -66,10 +63,6 @@ func (msg *MsgMixSlotReserve) BtcDecode(r io.Reader, pver uint32) error { if err != nil { return err } - kpcount, err := ReadVarInt(r, pver) - if err != nil { - return err - } if mcount == 0 { msg := fmt.Sprintf("too few mixed messages [%v]", mcount) return messageError(op, ErrInvalidMsg, msg) @@ -78,20 +71,25 @@ func (msg *MsgMixSlotReserve) BtcDecode(r io.Reader, pver uint32) error { msg := fmt.Sprintf("too many total mixed messages [%v]", mcount) return messageError(op, ErrInvalidMsg, msg) } - if mcount == 0 { + kpcount, err := ReadVarInt(r, pver) + if err != nil { + return err + } + if kpcount == 0 { msg := fmt.Sprintf("too few mixing peers [%v]", kpcount) return messageError(op, ErrInvalidMsg, msg) } - if kpcount > MaxMixKPCount { + if kpcount > MaxMixPeers { msg := fmt.Sprintf("too many mixing peers [count %v, max %v]", - kpcount, MaxMixKPCount) + kpcount, MaxMixPeers) return messageError(op, ErrInvalidMsg, msg) } dcmix := make([][][]byte, mcount) for i := range dcmix { dcmix[i] = make([][]byte, kpcount) for j := range dcmix[i] { - v, err := ReadVarBytes(r, pver, MaxMixFieldValLen, "fieldval") + v, err := ReadVarBytes(r, pver, MaxMixFieldValLen, + "slot reservation field value") if err != nil { return err } @@ -104,9 +102,9 @@ func (msg *MsgMixSlotReserve) BtcDecode(r io.Reader, pver uint32) error { if err != nil { return err } - if count > MaxPrevMixMsgs { + if count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -201,7 +199,7 @@ func (msg *MsgMixSlotReserve) writeMessageNoSignature(op string, w io.Writer, pv msg := fmt.Sprintf("too few mixing peers [%v]", kpcount) return messageError(op, ErrInvalidMsg, msg) } - if !hashing && kpcount > MaxMixKPCount { + if !hashing && kpcount > MaxMixPeers { msg := fmt.Sprintf("too many mixing peers [%v]", kpcount) return messageError(op, ErrInvalidMsg, msg) } @@ -232,9 +230,9 @@ func (msg *MsgMixSlotReserve) writeMessageNoSignature(op string, w io.Writer, pv } count := len(msg.SeenCiphertexts) - if !hashing && count > MaxPrevMixMsgs { + if !hashing && count > MaxMixPeers { msg := fmt.Sprintf("too many previous referenced messages [count %v, max %v]", - count, MaxPrevMixMsgs) + count, MaxMixPeers) return messageError(op, ErrTooManyPrevMixMsgs, msg) } @@ -252,10 +250,10 @@ func (msg *MsgMixSlotReserve) writeMessageNoSignature(op string, w io.Writer, pv return nil } -// WriteSigned writes a tag identifying the message data, followed by all +// WriteSignedData writes a tag identifying the message data, followed by all // message fields excluding the signature. This is the data committed to when // the message is signed. -func (msg *MsgMixSlotReserve) WriteSigned(h hash.Hash) { +func (msg *MsgMixSlotReserve) WriteSignedData(h hash.Hash) { WriteVarString(h, MixVersion, CmdMixSlotReserve+"-sig") msg.writeMessageNoSignature("", h, MixVersion) } @@ -269,6 +267,11 @@ func (msg *MsgMixSlotReserve) Command() string { // MaxPayloadLength returns the maximum length the payload can be for the // receiver. This is part of the Message interface implementation. func (msg *MsgMixSlotReserve) MaxPayloadLength(pver uint32) uint32 { + if pver < MixVersion { + return 0 + } + + // See tests for this calculation. return 17318034 } diff --git a/wire/msgmixslotreserve_test.go b/wire/msgmixslotreserve_test.go index f458b1e9ac..31134278e2 100644 --- a/wire/msgmixslotreserve_test.go +++ b/wire/msgmixslotreserve_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 The Decred developers +// Copyright (c) 2023-2024 The Decred developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. @@ -6,6 +6,8 @@ package wire import ( "bytes" + "errors" + "fmt" "reflect" "testing" @@ -13,33 +15,17 @@ import ( "github.com/decred/dcrd/chaincfg/chainhash" ) -func TestMixSRWire(t *testing.T) { - pver := MixVersion - - repeat := func(b byte, count int) []byte { - s := make([]byte, count) - for i := range s { - s[i] = b - } - return s - } - - // Create a fictitious message with easily-distinguishable fields. - - var sig [64]byte - copy(sig[:], repeat(0x80, 64)) - - var id [33]byte - copy(id[:], repeat(0x81, 33)) - - var sid [32]byte - copy(sid[:], repeat(0x82, 32)) +func newTestMixSlotReserve() *MsgMixSlotReserve { + // Use easily-distinguishable fields. + sig := *(*[64]byte)(repeat(0x80, 64)) + id := *(*[33]byte)(repeat(0x81, 33)) + sid := *(*[32]byte)(repeat(0x82, 32)) const expiry = uint32(0x83838383) const run = uint32(0x84848484) - mcount := 4 - kpcount := 4 + const mcount = 4 + const kpcount = 4 dcmix := make([][][]byte, mcount) // will add 4x4 field numbers of incrementing repeating byte values to // dcmix, ranging from 0x85 through 0x94 @@ -60,12 +46,68 @@ func TestMixSRWire(t *testing.T) { sr := NewMsgMixSlotReserve(id, sid, expiry, run, dcmix, seenCTs) sr.Signature = sig + return sr +} +func TestMsgMixSlotReserveWire(t *testing.T) { + pver := MixVersion + + sr := newTestMixSlotReserve() + buf := new(bytes.Buffer) err := sr.BtcEncode(buf, pver) if err != nil { t.Fatal(err) } + expected := make([]byte, 0, buf.Len()) + expected = append(expected, repeat(0x80, 64)...) // Signature + expected = append(expected, repeat(0x81, 33)...) // Identity + expected = append(expected, repeat(0x82, 32)...) // Session ID + expected = append(expected, repeat(0x83, 4)...) // Expiry + expected = append(expected, repeat(0x84, 4)...) // Run + // 4x4 slot reservation mixed messages (repeating 32 bytes from 0x85 through 0x94) + expected = append(expected, 0x04, 0x04) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x85, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x86, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x87, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x88, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x89, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8a, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8b, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8c, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8d, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8e, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x8f, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x90, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x91, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x92, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x93, 32)...) + expected = append(expected, 0x20) + expected = append(expected, repeat(0x94, 32)...) + // Four seen CTs (repeating 32 bytes of 0x95, 0x96, 0x97, 0x98) + expected = append(expected, 0x04) + expected = append(expected, repeat(0x95, 32)...) + expected = append(expected, repeat(0x96, 32)...) + expected = append(expected, repeat(0x97, 32)...) + expected = append(expected, repeat(0x98, 32)...) + + expectedSerializationEqual(t, buf.Bytes(), expected) + decodedSR := new(MsgMixSlotReserve) err = decodedSR.BtcDecode(bytes.NewReader(buf.Bytes()), pver) if err != nil { @@ -75,8 +117,110 @@ func TestMixSRWire(t *testing.T) { if !reflect.DeepEqual(sr, decodedSR) { t.Errorf("BtcDecode got: %s want: %s", spew.Sdump(decodedSR), spew.Sdump(sr)) - } else { - t.Logf("bytes: %x", buf.Bytes()) - t.Logf("spew: %s", spew.Sdump(decodedSR)) + } +} + +func TestMsgMixSlotReserveCrossProtocol(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + encodeVersion uint32 + decodeVersion uint32 + err error + remainingBytes int + }{{ + name: "Latest->MixVersion", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion, + }, { + name: "Latest->MixVersion-1", + encodeVersion: ProtocolVersion, + decodeVersion: MixVersion - 1, + err: ErrMsgInvalidForPVer, + }, { + name: "MixVersion->Latest", + encodeVersion: MixVersion, + decodeVersion: ProtocolVersion, + }} + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if tc.err != nil && tc.remainingBytes != 0 { + t.Errorf("invalid testcase: non-zero remaining bytes " + + "expects no decoding error") + } + + msg := newTestMixSlotReserve() + + buf := new(bytes.Buffer) + err := msg.BtcEncode(buf, tc.encodeVersion) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + + msg = new(MsgMixSlotReserve) + err = msg.BtcDecode(buf, tc.decodeVersion) + if !errors.Is(err, tc.err) { + t.Errorf("decode failed; want %v, got %v", tc.err, err) + } + if err == nil && buf.Len() != tc.remainingBytes { + t.Errorf("buffer contains unexpected remaining bytes "+ + "from encoded message: want %v bytes, got %v (hex: %[2]x)", + buf.Len(), buf.Bytes()) + } + }) + } +} + +// TestMsgMixSlotReserveMaxPayloadLength tests the results returned by +// [MsgMixSlotReserve.MaxPayloadLength] by calculating the maximum payload length. +func TestMsgMixSlotReserveMaxPayloadLength(t *testing.T) { + var sr *MsgMixSlotReserve + + // Test all protocol versions before MixVersion + for pver := uint32(0); pver < MixVersion; pver++ { + t.Run(fmt.Sprintf("pver=%d", pver), func(t *testing.T) { + got := sr.MaxPayloadLength(pver) + if got != 0 { + t.Errorf("got %d, expected %d", got, 0) + } + }) + } + + var expectedLen uint32 = 64 + // Signature + 33 + // Identity + 32 + // Session ID + 4 + // Expiry + 4 + // Run + uint32(VarIntSerializeSize(MaxMixMcount)) + // Message count + uint32(VarIntSerializeSize(MaxMixPeers)) + // Peer count + MaxMixMcount*MaxMixPeers*varBytesLen(MaxMixFieldValLen) + // Padded SR values + uint32(VarIntSerializeSize(MaxMixPeers)) + // Ciphertext count + 32*MaxMixPeers // Ciphertext hashes + + tests := []struct { + name string + pver uint32 + len uint32 + }{{ + name: "MixVersion", + pver: MixVersion, + len: expectedLen, + }, { + name: "ProtocolVersion", + pver: ProtocolVersion, + len: expectedLen, + }} + for _, tc := range tests { + t.Run(fmt.Sprintf("pver=%s", tc.name), func(t *testing.T) { + got := sr.MaxPayloadLength(tc.pver) + if got != tc.len { + t.Errorf("got %d, expected %d", got, tc.len) + } + }) } }