Skip to content

Commit

Permalink
[R4R] - {develop}: fix consensys audit issue: cs-6.34 (#1194)
Browse files Browse the repository at this point in the history
* use pointer instead of struct

* mutex struct

* fix abnormal node memory ddos issue

* change mutex from pointer to struct

---------

Co-authored-by: Raymond <6427270+wukongcheng@users.noreply.github.com>
  • Loading branch information
HaoyangLiu and wukongcheng committed Jul 2, 2023
1 parent cf0f805 commit 750da92
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 36 deletions.
47 changes: 34 additions & 13 deletions tss/node/tsslib/abnormal/abnormal.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,42 +8,63 @@ import (
"sync"
)

func NewNode(pk string, data, sig []byte) Node {
return Node{
const (
maxAbnormalNodeLength = 100
)

func NewNode(pk string, data, sig []byte) *Node {
return &Node{
Pubkey: pk,
Data: data,
Signature: sig,
}
}

func (n *Node) Equal(node Node) bool {
func (n *Node) Equal(node *Node) bool {
if node == nil {
return false
}
if n.Pubkey == node.Pubkey && bytes.Equal(n.Signature, node.Signature) {
return true
}
return false
}

func NewAbnormal(reason string, nodes []Node) Abnormal {
return Abnormal{
func NewAbnormal(reason string, nodes []*Node) *Abnormal {
abnormal := &Abnormal{
FailReason: reason,
Nodes: nodes,
AbnormalLock: &sync.RWMutex{},
Nodes: make([]*Node, 0, maxAbnormalNodeLength),
AbnormalLock: sync.RWMutex{},
}
abnormal.appendNewNodes(nodes)
return abnormal
}

func (a Abnormal) String() string {
func (a *Abnormal) String() string {
sb := strings.Builder{}
sb.WriteString("reason:" + a.FailReason + " is_unicast:" + strconv.FormatBool(a.IsUnicast) + "\n")
sb.WriteString(fmt.Sprintf("nodes:%+v\n", a.Nodes))
return sb.String()
}

func (a *Abnormal) SetAbnormal(reason string, nodes []Node, isUnicast bool) {
func (a *Abnormal) appendNewNodes(newNodes []*Node) {
if len(newNodes) > maxAbnormalNodeLength {
a.Nodes = newNodes[len(newNodes)-maxAbnormalNodeLength:]
} else if len(newNodes)+len(a.Nodes) > maxAbnormalNodeLength {
exceedAmount := len(newNodes) + len(a.Nodes) - maxAbnormalNodeLength
a.Nodes = a.Nodes[exceedAmount:]
a.Nodes = append(a.Nodes, newNodes...)
} else {
a.Nodes = append(a.Nodes, newNodes...)
}
}

func (a *Abnormal) SetAbnormal(reason string, nodes []*Node, isUnicast bool) {
a.AbnormalLock.Lock()
defer a.AbnormalLock.Unlock()
a.FailReason = reason
a.IsUnicast = isUnicast
a.Nodes = append(a.Nodes, nodes...)
a.appendNewNodes(nodes)
}

func (a *Abnormal) AlreadyAbnormal() bool {
Expand All @@ -52,8 +73,8 @@ func (a *Abnormal) AlreadyAbnormal() bool {
return len(a.Nodes) > 0
}

// AddBlameNodes add nodes to the blame list
func (a *Abnormal) AddAbnormalNodes(newNodes ...Node) {
// AddAbnormalNodes add nodes to the abnormal node list
func (a *Abnormal) AddAbnormalNodes(newNodes ...*Node) {
a.AbnormalLock.Lock()
defer a.AbnormalLock.Unlock()
for _, node := range newNodes {
Expand All @@ -65,7 +86,7 @@ func (a *Abnormal) AddAbnormalNodes(newNodes ...Node) {
}
}
if !found {
a.Nodes = append(a.Nodes, node)
a.appendNewNodes([]*Node{node})
}
}
}
15 changes: 8 additions & 7 deletions tss/node/tsslib/abnormal/manager.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package abnormal

import (
"sync"

"github.com/binance-chain/tss-lib/tss"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"sync"
)

type Manager struct {
Expand All @@ -16,26 +17,26 @@ type Manager struct {
roundMgr *RoundMgr
partyInfo *PartyInfo
PartyIDtoP2PID map[string]peer.ID
lastMsgLocker *sync.RWMutex
lastMsgLocker sync.RWMutex
lastMsg tss.Message
acceptedShares map[RoundInfo][]string
acceptShareLocker *sync.Mutex
acceptShareLocker sync.Mutex
localPartyID string
}

func NewAbnormalManager() *Manager {
Abnormal := NewAbnormal("", nil)
abnormal := NewAbnormal("", nil)
return &Manager{
logger: log.With().Str("module", "Abnormal_manager").Logger(),
partyInfo: nil,
PartyIDtoP2PID: make(map[string]peer.ID),
lastUnicastPeer: make(map[string][]peer.ID),
shareMgr: NewTssShareMgr(),
roundMgr: NewTssRoundMgr(),
Abnormal: &Abnormal,
lastMsgLocker: &sync.RWMutex{},
Abnormal: abnormal,
lastMsgLocker: sync.RWMutex{},
acceptedShares: make(map[RoundInfo][]string),
acceptShareLocker: &sync.Mutex{},
acceptShareLocker: sync.Mutex{},
}
}

Expand Down
21 changes: 11 additions & 10 deletions tss/node/tsslib/abnormal/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package abnormal
import (
"errors"
"fmt"

"github.com/binance-chain/tss-lib/tss"
mapset "github.com/deckarep/golang-set"
"github.com/libp2p/go-libp2p/core/peer"
Expand Down Expand Up @@ -42,14 +43,14 @@ func (m *Manager) tssTimeoutAbnormal(lastMessageType string, partyIDMap map[stri
return AbnormalPubKeys, nil
}

// this Abnormal Abnormals the node who cause the timeout in node sync
func (m *Manager) NodeSyncAbnormal(keys []string, onlinePeers []peer.ID) (Abnormal, error) {
Abnormal := NewAbnormal(TssSyncFail, nil)
// NodeSyncAbnormal create a SyncAbnormal object
func (m *Manager) NodeSyncAbnormal(keys []string, onlinePeers []peer.ID) (*Abnormal, error) {
abnormal := NewAbnormal(TssSyncFail, nil)
for _, item := range keys {
found := false
peerID, err := conversion2.GetPeerIDFromPubKey(item)
if err != nil {
return Abnormal, fmt.Errorf("fail to get peer id from pub key")
return nil, fmt.Errorf("fail to get peer id from pub key")
}
for _, p := range onlinePeers {
if p == peerID {
Expand All @@ -58,14 +59,14 @@ func (m *Manager) NodeSyncAbnormal(keys []string, onlinePeers []peer.ID) (Abnorm
}
}
if !found {
Abnormal.Nodes = append(Abnormal.Nodes, NewNode(item, nil, nil))
abnormal.Nodes = append(abnormal.Nodes, NewNode(item, nil, nil))
}
}
return Abnormal, nil
return abnormal, nil
}

// this Abnormal Abnormals the node who cause the timeout in unicast message
func (m *Manager) GetUnicastAbnormal(lastMsgType string) ([]Node, error) {
func (m *Manager) GetUnicastAbnormal(lastMsgType string) ([]*Node, error) {
m.lastMsgLocker.RLock()
if len(m.lastUnicastPeer) == 0 {
m.lastMsgLocker.RUnlock()
Expand All @@ -91,21 +92,21 @@ func (m *Manager) GetUnicastAbnormal(lastMsgType string) ([]Node, error) {
m.logger.Error().Err(err).Msg("fail to get the Abnormald peers")
return nil, fmt.Errorf("fail to get the Abnormald peers %w", ErrTssTimeOut)
}
var AbnormalNodes []Node
var AbnormalNodes []*Node
for _, el := range AbnormalPeers {
AbnormalNodes = append(AbnormalNodes, NewNode(el, nil, nil))
}
return AbnormalNodes, nil
}

// this Abnormal Abnormals the node who cause the timeout in broadcast message
func (m *Manager) GetBroadcastAbnormal(lastMessageType string) ([]Node, error) {
func (m *Manager) GetBroadcastAbnormal(lastMessageType string) ([]*Node, error) {
AbnormalPeers, err := m.tssTimeoutAbnormal(lastMessageType, m.partyInfo.PartyIDMap)
if err != nil {
m.logger.Error().Err(err).Msg("fail to get the Abnormald peers")
return nil, fmt.Errorf("fail to get the Abnormald peers %w", ErrTssTimeOut)
}
var AbnormalNodes []Node
var AbnormalNodes []*Node
for _, el := range AbnormalPeers {
AbnormalNodes = append(AbnormalNodes, NewNode(el, nil, nil))
}
Expand Down
11 changes: 6 additions & 5 deletions tss/node/tsslib/abnormal/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package abnormal

import (
"errors"
"github.com/binance-chain/tss-lib/tss"
"sync"

"github.com/binance-chain/tss-lib/tss"
)

const (
Expand Down Expand Up @@ -37,8 +38,8 @@ type Node struct {
}

type Abnormal struct {
FailReason string `json:"fail_reason"`
IsUnicast bool `json:"is_broadcast"`
Nodes []Node `json:"abnormal_peers,omitempty"`
AbnormalLock *sync.RWMutex
FailReason string `json:"fail_reason"`
IsUnicast bool `json:"is_broadcast"`
Nodes []*Node `json:"abnormal_peers,omitempty"`
AbnormalLock sync.RWMutex
}
2 changes: 1 addition & 1 deletion tss/node/tsslib/common/tss.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (t *TssCommon) processInvalidMsg(roundInfo string, round abnormal2.RoundInf
}
// This error indicates the share is wrong, we include this signature to prove that
// this incorrect share is from the share owner.
var blameNodes []abnormal2.Node
var blameNodes []*abnormal2.Node
var msgBody, sig []byte
for i, pk := range pubkeys {
invalidMsg := invalidMsgs[i]
Expand Down

0 comments on commit 750da92

Please sign in to comment.