Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 176 additions & 75 deletions internal/handshake/protocol/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net"
"sync"
"time"

"github.com/blinklabs-io/cdnsd/internal/handshake"
Expand All @@ -31,6 +32,16 @@ type Peer struct {
address string
conn net.Conn
networkMagic uint32
mu sync.Mutex
sendMu sync.Mutex
hasConnected bool
doneCh chan struct{}
errorCh chan error
handshakeCh chan Message
headersCh chan Message
blockCh chan Message
addrCh chan Message
proofCh chan Message
}

// NewPeer returns a new Peer using an existing connection (if provided) and the specified network magic. If a connection is provided,
Expand All @@ -39,10 +50,15 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
p := &Peer{
conn: conn,
networkMagic: networkMagic,
doneCh: make(chan struct{}),
errorCh: make(chan error, 5),
}
if conn != nil {
p.conn = conn
p.address = conn.RemoteAddr().String()
if err := p.handshake(); err != nil {
p.hasConnected = true
if err := p.setupConnection(); err != nil {
_ = p.conn.Close()
return nil, err
}
}
Expand All @@ -51,34 +67,69 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {

// Connect establishes a connection with a peer and performs the handshake process
func (p *Peer) Connect(address string) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.conn != nil {
return errors.New("connection already established")
}
if p.hasConnected {
return errors.New("peer cannot be reused after disconnect")
}
var err error
p.conn, err = net.DialTimeout("tcp", address, dialTimeout)
if err != nil {
return err
}
p.address = address
if err := p.handshake(); err != nil {
p.hasConnected = true
if err := p.setupConnection(); err != nil {
_ = p.conn.Close()
return err
}
return nil
}

// Close closes an active connection with a network peer
func (p *Peer) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
if p.conn == nil {
return errors.New("connection is not established")
}
if err := p.conn.Close(); err != nil {
return err
}
p.conn = nil
Copy link

@cubic-dev-ai cubic-dev-ai bot Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Close now nils p.conn while the recvLoop goroutine still dereferences it, so closing a peer mid-stream can panic when the loop tries to read from a nil connection.

Prompt for AI agents
Address the following comment on internal/handshake/protocol/peer.go at line 102:

<comment>Close now nils `p.conn` while the recvLoop goroutine still dereferences it, so closing a peer mid-stream can panic when the loop tries to read from a nil connection.</comment>

<file context>
@@ -51,34 +67,69 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
 	if err := p.conn.Close(); err != nil {
 		return err
 	}
+	p.conn = nil
+	close(p.doneCh)
+	return nil
</file context>
Fix with Cubic

close(p.doneCh)
return nil
}

// ErrorChan returns the async error channel
func (p *Peer) ErrorChan() <-chan error {
return p.errorCh
}

// setupConnection runs the initial handshake and starts the receive loop
func (p *Peer) setupConnection() error {
// Init channels for async messages
p.handshakeCh = make(chan Message, 10)
p.headersCh = make(chan Message, 10)
p.blockCh = make(chan Message, 10)
p.addrCh = make(chan Message, 10)
p.proofCh = make(chan Message, 10)
// Start receive loop
go p.recvLoop()
// Start handshake
if err := p.handshake(); err != nil {
return err
}
return nil
}

// sendMessage encodes and sends a message with the given type and payload
func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
if p.conn == nil {
return errors.New("connection is not established")
}
Expand All @@ -97,37 +148,67 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
return nil
}

// receiveMessage receives and decodes messages from the active connection
func (p *Peer) receiveMessage() (Message, error) {
headerBuf := make([]byte, messageHeaderLength)
if _, err := io.ReadFull(p.conn, headerBuf); err != nil {
return nil, fmt.Errorf("read header: %w", err)
}
header := new(msgHeader)
if err := header.Decode(headerBuf); err != nil {
return nil, fmt.Errorf("header decode: %w", err)
}
if header.NetworkMagic != p.networkMagic {
return nil, fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
}
if header.PayloadLength > messageMaxPayloadLength {
return nil, errors.New("payload is too large")
}
payload := make([]byte, header.PayloadLength)
if _, err := io.ReadFull(p.conn, payload); err != nil {
return nil, fmt.Errorf("read payload: %w", err)
}
msg, err := decodeMessage(header, payload)
if err != nil {
// Discard unsupported messages and try to get another message
// This is a bit of a hack
var unsupportedErr UnsupportedMessageTypeError
if errors.As(err, &unsupportedErr) {
return p.receiveMessage()
// recvLoop receives and decodes messages from the active connection
func (p *Peer) recvLoop() {
err := func() error {
// Assign to local var to avoid nil deref panic on shutdown
conn := p.conn
for {
headerBuf := make([]byte, messageHeaderLength)
if _, err := io.ReadFull(conn, headerBuf); err != nil {
return fmt.Errorf("read header: %w", err)
}
header := new(msgHeader)
if err := header.Decode(headerBuf); err != nil {
return fmt.Errorf("header decode: %w", err)
}
if header.NetworkMagic != p.networkMagic {
return fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
}
if header.PayloadLength > messageMaxPayloadLength {
return errors.New("payload is too large")
}
payload := make([]byte, header.PayloadLength)
if _, err := io.ReadFull(conn, payload); err != nil {
return fmt.Errorf("read payload: %w", err)
}
msg, err := decodeMessage(header, payload)
if err != nil {
// Discard unsupported messages and try to get another message
// This is a bit of a hack
var unsupportedErr UnsupportedMessageTypeError
if errors.As(err, &unsupportedErr) {
continue
}
return fmt.Errorf("decode message: %w", err)
}
if err := p.handleMessage(msg); err != nil {
return fmt.Errorf("handle message: %w", err)
}
}
return nil, err
}()
if err != nil {
p.errorCh <- err
_ = p.Close()
}
return msg, nil
}

func (p *Peer) handleMessage(msg Message) error {
switch msg.(type) {
case *MsgVersion, *MsgVerack:
p.handshakeCh <- msg
case *MsgAddr:
p.addrCh <- msg
case *MsgHeaders:
p.headersCh <- msg
case *MsgBlock:
p.blockCh <- msg
case *MsgProof:
p.proofCh <- msg
default:
return fmt.Errorf("unknown message type: %T", msg)
Copy link

@cubic-dev-ai cubic-dev-ai bot Nov 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Receiving valid request messages (e.g., getaddr/getheaders/getdata/getproof) now triggers handleMessage’s default error, and recvLoop closes the connection as soon as it sees those messages. This causes unnecessary disconnects from compliant peers instead of ignoring or responding to the requests.

Prompt for AI agents
Address the following comment on internal/handshake/protocol/peer.go at line 209:

<comment>Receiving valid request messages (e.g., getaddr/getheaders/getdata/getproof) now triggers handleMessage’s default error, and recvLoop closes the connection as soon as it sees those messages. This causes unnecessary disconnects from compliant peers instead of ignoring or responding to the requests.</comment>

<file context>
@@ -97,37 +148,67 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
+	case *MsgProof:
+		p.proofCh &lt;- msg
+	default:
+		return fmt.Errorf(&quot;unknown message type: %T&quot;, msg)
+	}
+	return nil
</file context>
Fix with Cubic

}
return nil
}

// handshake performs the handshake process, which involves exchanging Version messages with the network peer
Expand Down Expand Up @@ -157,20 +238,28 @@ func (p *Peer) handshake() error {
return err
}
// Wait for Verack response
msg, err := p.receiveMessage()
if err != nil {
return err
}
if _, ok := msg.(*MsgVerack); !ok {
return fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.handshakeCh:
if _, ok := msg.(*MsgVerack); !ok {
return fmt.Errorf("unexpected message: %T", msg)
}
case err := <-p.errorCh:
return fmt.Errorf("handshake failed: %w", err)
case <-time.After(1 * time.Second):
return errors.New("handshake timed out")
}
// Wait for Version from peer
msg, err = p.receiveMessage()
if err != nil {
return err
}
if _, ok := msg.(*MsgVersion); !ok {
return fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.handshakeCh:
if _, ok := msg.(*MsgVersion); !ok {
return fmt.Errorf("unexpected message: %T", msg)
}
case err := <-p.errorCh:
return fmt.Errorf("handshake failed: %w", err)
case <-p.doneCh:
return errors.New("connection has shut down")
case <-time.After(1 * time.Second):
return errors.New("handshake timed out")
}
// Send Verack
if err := p.sendMessage(MessageVerack, nil); err != nil {
Expand All @@ -185,15 +274,18 @@ func (p *Peer) GetPeers() ([]NetAddress, error) {
return nil, err
}
// Wait for Addr response
msg, err := p.receiveMessage()
if err != nil {
return nil, err
}
msgAddr, ok := msg.(*MsgAddr)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.addrCh:
msgAddr, ok := msg.(*MsgAddr)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
}
return msgAddr.Peers, nil
case <-p.doneCh:
return nil, errors.New("connection has shut down")
case <-time.After(5 * time.Second):
return nil, errors.New("timed out")
}
return msgAddr.Peers, nil
}

// GetHeaders requests a list of headers from the network peer
Expand All @@ -206,15 +298,18 @@ func (p *Peer) GetHeaders(locator [][32]byte, stopHash [32]byte) ([]*handshake.B
return nil, err
}
// Wait for Headers response
msg, err := p.receiveMessage()
if err != nil {
return nil, err
}
msgHeaders, ok := msg.(*MsgHeaders)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.headersCh:
msgHeaders, ok := msg.(*MsgHeaders)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
}
return msgHeaders.Headers, nil
case <-p.doneCh:
return nil, errors.New("connection has shut down")
case <-time.After(5 * time.Second):
return nil, errors.New("timed out")
}
return msgHeaders.Headers, nil
}

// GetProof requests a proof for a domain name from the network peer
Expand All @@ -228,15 +323,18 @@ func (p *Peer) GetProof(name string, rootHash [32]byte) (*handshake.Proof, error
return nil, err
}
// Wait for Proof response
msg, err := p.receiveMessage()
if err != nil {
return nil, err
}
msgProof, ok := msg.(*MsgProof)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.proofCh:
msgProof, ok := msg.(*MsgProof)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
}
return msgProof.Proof, nil
case <-p.doneCh:
return nil, errors.New("connection has shut down")
case <-time.After(5 * time.Second):
return nil, errors.New("timed out")
}
return msgProof.Proof, nil
}

// GetBlock requests the specified block from the network peer
Expand All @@ -253,13 +351,16 @@ func (p *Peer) GetBlock(hash [32]byte) (*handshake.Block, error) {
return nil, err
}
// Wait for Block response
msg, err := p.receiveMessage()
if err != nil {
return nil, err
}
msgBlock, ok := msg.(*MsgBlock)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
select {
case msg := <-p.blockCh:
msgBlock, ok := msg.(*MsgBlock)
if !ok {
return nil, fmt.Errorf("unexpected message: %T", msg)
}
return msgBlock.Block, nil
case <-p.doneCh:
return nil, errors.New("connection has shut down")
case <-time.After(5 * time.Second):
return nil, errors.New("timed out")
}
return msgBlock.Block, nil
}
Loading