diff --git a/mixing/mixpool/errors.go b/mixing/mixpool/errors.go index b6249d2f9..642605aed 100644 --- a/mixing/mixpool/errors.go +++ b/mixing/mixpool/errors.go @@ -8,6 +8,7 @@ import ( "errors" "github.com/decred/dcrd/chaincfg/chainhash" + "github.com/decred/dcrd/wire" ) // RuleError represents a mixpool rule violation. @@ -31,16 +32,18 @@ func (e *RuleError) Unwrap() error { } type bannableError struct { - s string + s string + services wire.ServiceFlag } func (e *bannableError) Error() string { return e.s } -func newBannableError(s string) error { +func newBannableError(s string, services wire.ServiceFlag) error { return &bannableError{ - s: s, + s: s, + services: services, } } @@ -48,54 +51,54 @@ func newBannableError(s string) error { var ( // ErrChangeDust is returned by AcceptMessage if a pair request's // change amount is dust. - ErrChangeDust = newBannableError("change output is dust") + ErrChangeDust = newBannableError("change output is dust", 0) // ErrLowInput is returned by AcceptMessage when not enough input value // is provided by a pair request to cover the mixed output, any change // output, and the minimum required fee. - ErrLowInput = newBannableError("not enough input value, or too low fee") + ErrLowInput = newBannableError("not enough input value, or too low fee", 0) // ErrInvalidMessageCount is returned by AcceptMessage if a // pair request contains an invalid message count. - ErrInvalidMessageCount = newBannableError("message count must be positive") + ErrInvalidMessageCount = newBannableError("message count must be positive", 0) // ErrInvalidScript is returned by AcceptMessage if a pair request // contains an invalid script. - ErrInvalidScript = newBannableError("invalid script") + ErrInvalidScript = newBannableError("invalid script", 0) // ErrInvalidSessionID is returned by AcceptMessage if the message // contains an invalid session id. - ErrInvalidSessionID = newBannableError("invalid session ID") + ErrInvalidSessionID = newBannableError("invalid session ID", 0) // ErrInvalidSignature is returned by AcceptMessage if the message is // not properly signed for the claimed identity. - ErrInvalidSignature = newBannableError("invalid message signature") + ErrInvalidSignature = newBannableError("invalid message signature", 0) // ErrInvalidTotalMixAmount is returned by AcceptMessage if a pair // request contains the product of the message count and mix amount // that exceeds the total input value. - ErrInvalidTotalMixAmount = newBannableError("invalid total mix amount") + ErrInvalidTotalMixAmount = newBannableError("invalid total mix amount", 0) // ErrInvalidUTXOProof is returned by AcceptMessage if a pair request // fails to prove ownership of each utxo. - ErrInvalidUTXOProof = newBannableError("invalid UTXO ownership proof") + ErrInvalidUTXOProof = newBannableError("invalid UTXO ownership proof", wire.SFNodeNetwork) // ErrMissingUTXOs is returned by AcceptMessage if a pair request // message does not reference any previous UTXOs. - ErrMissingUTXOs = newBannableError("pair request contains no UTXOs") + ErrMissingUTXOs = newBannableError("pair request contains no UTXOs", 0) // ErrPeerPositionOutOfBounds is returned by AcceptMessage if the // position of a peer's own PR is outside of the possible bounds of // the previously seen messages. - ErrPeerPositionOutOfBounds = newBannableError("peer position cannot be a valid seen PRs index") + ErrPeerPositionOutOfBounds = newBannableError("peer position cannot be a valid seen PRs index", 0) ) -// IsBannable returns whether the error condition is such that the peer who -// sent the message should be immediately banned for malicious or buggy -// behavior. -func IsBannable(err error) bool { +// IsBannable returns whether the error condition is such that the peer with +// capabilities defined by services who sent the message should be immediately +// banned for malicious or buggy behavior. +func IsBannable(err error, services wire.ServiceFlag) bool { var be *bannableError - return errors.As(err, &be) + return errors.As(err, &be) && be.services&services == be.services } var ( diff --git a/server.go b/server.go index 982df113f..673b37efa 100644 --- a/server.go +++ b/server.go @@ -1734,7 +1734,7 @@ func (sp *serverPeer) onMixMessage(msg mixing.Message) { nil, nil, nil, mixHashes) return } - if mixpool.IsBannable(err) { + if mixpool.IsBannable(err, sp.Services()) { reason := fmt.Sprintf("sent malformed mix message: %s", err) sp.server.BanPeer(sp, reason) }