From b77012ef2ecd83d674377dd3faba1119cb3c3b8e Mon Sep 17 00:00:00 2001
From: Aurora Gaffney
Date: Fri, 21 Nov 2025 17:58:56 -0500
Subject: [PATCH] refactor: async Handshake network support

Signed-off-by: Aurora Gaffney
---
Review notes (this section is ignored by git am):
- NewPeer: close the local conn on setup failure. recvLoop's Close() can set
  p.conn to nil concurrently, which made p.conn.Close() a nil dereference.
  The redundant "p.conn = conn" assignment is dropped.
- Close: always clear p.conn and close doneCh, even if conn.Close() fails.
- handleMessage: give up on delivery once doneCh is closed instead of
  blocking recvLoop forever on a full channel.
- handshake: the Verack wait now also watches doneCh, like the Version wait.
- Known gaps, not addressed here: sendMessage reads p.conn under sendMu while
  Close writes it under mu; a response that arrives after a request timed out
  stays buffered and is handed to the next request of the same type.
- The post-image hash on the index line predates these changes.

 internal/handshake/protocol/peer.go | 267 ++++++++++++++++++++--------
 1 file changed, 189 insertions(+), 78 deletions(-)

diff --git a/internal/handshake/protocol/peer.go b/internal/handshake/protocol/peer.go
index 7229534..a809892 100644
--- a/internal/handshake/protocol/peer.go
+++ b/internal/handshake/protocol/peer.go
@@ -13,6 +13,7 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"sync"
 	"time"
 
 	"github.com/blinklabs-io/cdnsd/internal/handshake"
@@ -31,6 +32,16 @@ type Peer struct {
 	address      string
 	conn         net.Conn
 	networkMagic uint32
+	mu           sync.Mutex
+	sendMu       sync.Mutex
+	hasConnected bool
+	doneCh       chan struct{}
+	errorCh      chan error
+	handshakeCh  chan Message
+	headersCh    chan Message
+	blockCh      chan Message
+	addrCh       chan Message
+	proofCh      chan Message
 }
 
 // NewPeer returns a new Peer using an existing connection (if provided) and the specified network magic. If a connection is provided,
@@ -39,10 +50,15 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
 	p := &Peer{
 		conn:         conn,
 		networkMagic: networkMagic,
+		doneCh:       make(chan struct{}),
+		errorCh:      make(chan error, 5),
 	}
 	if conn != nil {
 		p.address = conn.RemoteAddr().String()
-		if err := p.handshake(); err != nil {
+		p.hasConnected = true
+		if err := p.setupConnection(); err != nil {
+			// Close via the local conn: recvLoop may already have cleared p.conn
+			_ = conn.Close()
 			return nil, err
 		}
 	}
@@ -51,16 +67,23 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) {
 
 // Connect establishes a connection with a peer and performs the handshake process
 func (p *Peer) Connect(address string) error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
 	if p.conn != nil {
 		return errors.New("connection already established")
 	}
+	if p.hasConnected {
+		return errors.New("peer cannot be reused after disconnect")
+	}
 	var err error
 	p.conn, err = net.DialTimeout("tcp", address, dialTimeout)
 	if err != nil {
 		return err
 	}
 	p.address = address
-	if err := p.handshake(); err != nil {
+	p.hasConnected = true
+	if err := p.setupConnection(); err != nil {
+		_ = p.conn.Close()
 		return err
 	}
 	return nil
@@ -68,17 +91,45 @@ func (p *Peer) Connect(address string) error {
 
 // Close closes an active connection with a network peer
 func (p *Peer) Close() error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
 	if p.conn == nil {
 		return errors.New("connection is not established")
 	}
-	if err := p.conn.Close(); err != nil {
-		return err
-	}
+	// Always mark the peer as shut down, even when the underlying Close fails,
+	// so that waiters on doneCh are released and p.conn is never left dangling
+	err := p.conn.Close()
+	p.conn = nil
+	close(p.doneCh)
+	return err
+}
+
+// ErrorChan returns the async error channel
+func (p *Peer) ErrorChan() <-chan error {
+	return p.errorCh
+}
+
+// setupConnection starts the receive loop and runs the initial handshake
+func (p *Peer) setupConnection() error {
+	// Init channels for async messages
+	p.handshakeCh = make(chan Message, 10)
+	p.headersCh = make(chan Message, 10)
+	p.blockCh = make(chan Message, 10)
+	p.addrCh = make(chan Message, 10)
+	p.proofCh = make(chan Message, 10)
+	// Start receive loop
+	go p.recvLoop()
+	// Start handshake
+	if err := p.handshake(); err != nil {
+		return err
+	}
 	return nil
 }
 
 // sendMessage encodes and sends a message with the given type and payload
 func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
+	p.sendMu.Lock()
+	defer p.sendMu.Unlock()
 	if p.conn == nil {
 		return errors.New("connection is not established")
 	}
@@ -97,37 +148,75 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error {
 	return nil
 }
 
-// receiveMessage receives and decodes messages from the active connection
-func (p *Peer) receiveMessage() (Message, error) {
-	headerBuf := make([]byte, messageHeaderLength)
-	if _, err := io.ReadFull(p.conn, headerBuf); err != nil {
-		return nil, fmt.Errorf("read header: %w", err)
-	}
-	header := new(msgHeader)
-	if err := header.Decode(headerBuf); err != nil {
-		return nil, fmt.Errorf("header decode: %w", err)
-	}
-	if header.NetworkMagic != p.networkMagic {
-		return nil, fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
-	}
-	if header.PayloadLength > messageMaxPayloadLength {
-		return nil, errors.New("payload is too large")
-	}
-	payload := make([]byte, header.PayloadLength)
-	if _, err := io.ReadFull(p.conn, payload); err != nil {
-		return nil, fmt.Errorf("read payload: %w", err)
-	}
-	msg, err := decodeMessage(header, payload)
-	if err != nil {
-		// Discard unsupported messages and try to get another message
-		// This is a bit of a hack
-		var unsupportedErr UnsupportedMessageTypeError
-		if errors.As(err, &unsupportedErr) {
-			return p.receiveMessage()
+// recvLoop receives and decodes messages from the active connection
+func (p *Peer) recvLoop() {
+	err := func() error {
+		// Assign to local var to avoid nil deref panic on shutdown
+		conn := p.conn
+		for {
+			headerBuf := make([]byte, messageHeaderLength)
+			if _, err := io.ReadFull(conn, headerBuf); err != nil {
+				return fmt.Errorf("read header: %w", err)
+			}
+			header := new(msgHeader)
+			if err := header.Decode(headerBuf); err != nil {
+				return fmt.Errorf("header decode: %w", err)
+			}
+			if header.NetworkMagic != p.networkMagic {
+				return fmt.Errorf("invalid network magic: %d", header.NetworkMagic)
+			}
+			if header.PayloadLength > messageMaxPayloadLength {
+				return errors.New("payload is too large")
+			}
+			payload := make([]byte, header.PayloadLength)
+			if _, err := io.ReadFull(conn, payload); err != nil {
+				return fmt.Errorf("read payload: %w", err)
+			}
+			msg, err := decodeMessage(header, payload)
+			if err != nil {
+				// Discard unsupported messages and try to get another message
+				// This is a bit of a hack
+				var unsupportedErr UnsupportedMessageTypeError
+				if errors.As(err, &unsupportedErr) {
+					continue
+				}
+				return fmt.Errorf("decode message: %w", err)
+			}
+			if err := p.handleMessage(msg); err != nil {
+				return fmt.Errorf("handle message: %w", err)
+			}
 		}
-		return nil, err
+	}()
+	if err != nil {
+		p.errorCh <- err
+		_ = p.Close()
 	}
-	return msg, nil
+}
+
+// handleMessage routes a decoded message to the channel for its type
+func (p *Peer) handleMessage(msg Message) error {
+	var ch chan Message
+	switch msg.(type) {
+	case *MsgVersion, *MsgVerack:
+		ch = p.handshakeCh
+	case *MsgAddr:
+		ch = p.addrCh
+	case *MsgHeaders:
+		ch = p.headersCh
+	case *MsgBlock:
+		ch = p.blockCh
+	case *MsgProof:
+		ch = p.proofCh
+	default:
+		return fmt.Errorf("unknown message type: %T", msg)
+	}
+	// Don't block forever on a full channel once the peer has been closed
+	select {
+	case ch <- msg:
+		return nil
+	case <-p.doneCh:
+		return errors.New("connection has shut down")
+	}
 }
 
 // handshake performs the handshake process, which involves exchanging Version messages with the network peer
@@ -157,20 +246,30 @@ func (p *Peer) handshake() error {
 		return err
 	}
 	// Wait for Verack response
-	msg, err := p.receiveMessage()
-	if err != nil {
-		return err
-	}
-	if _, ok := msg.(*MsgVerack); !ok {
-		return fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.handshakeCh:
+		if _, ok := msg.(*MsgVerack); !ok {
+			return fmt.Errorf("unexpected message: %T", msg)
+		}
+	case err := <-p.errorCh:
+		return fmt.Errorf("handshake failed: %w", err)
+	case <-p.doneCh:
+		return errors.New("connection has shut down")
+	case <-time.After(1 * time.Second):
+		return errors.New("handshake timed out")
 	}
 	// Wait for Version from peer
-	msg, err = p.receiveMessage()
-	if err != nil {
-		return err
-	}
-	if _, ok := msg.(*MsgVersion); !ok {
-		return fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.handshakeCh:
+		if _, ok := msg.(*MsgVersion); !ok {
+			return fmt.Errorf("unexpected message: %T", msg)
+		}
+	case err := <-p.errorCh:
+		return fmt.Errorf("handshake failed: %w", err)
+	case <-p.doneCh:
+		return errors.New("connection has shut down")
+	case <-time.After(1 * time.Second):
+		return errors.New("handshake timed out")
 	}
 	// Send Verack
 	if err := p.sendMessage(MessageVerack, nil); err != nil {
@@ -185,15 +284,18 @@ func (p *Peer) GetPeers() ([]NetAddress, error) {
 		return nil, err
 	}
 	// Wait for Addr response
-	msg, err := p.receiveMessage()
-	if err != nil {
-		return nil, err
-	}
-	msgAddr, ok := msg.(*MsgAddr)
-	if !ok {
-		return nil, fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.addrCh:
+		msgAddr, ok := msg.(*MsgAddr)
+		if !ok {
+			return nil, fmt.Errorf("unexpected message: %T", msg)
+		}
+		return msgAddr.Peers, nil
+	case <-p.doneCh:
+		return nil, errors.New("connection has shut down")
+	case <-time.After(5 * time.Second):
+		return nil, errors.New("timed out")
 	}
-	return msgAddr.Peers, nil
 }
 
 // GetHeaders requests a list of headers from the network peer
@@ -206,15 +308,18 @@ func (p *Peer) GetHeaders(locator [][32]byte, stopHash [32]byte) ([]*handshake.B
 		return nil, err
 	}
 	// Wait for Headers response
-	msg, err := p.receiveMessage()
-	if err != nil {
-		return nil, err
-	}
-	msgHeaders, ok := msg.(*MsgHeaders)
-	if !ok {
-		return nil, fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.headersCh:
+		msgHeaders, ok := msg.(*MsgHeaders)
+		if !ok {
+			return nil, fmt.Errorf("unexpected message: %T", msg)
+		}
+		return msgHeaders.Headers, nil
+	case <-p.doneCh:
+		return nil, errors.New("connection has shut down")
+	case <-time.After(5 * time.Second):
+		return nil, errors.New("timed out")
 	}
-	return msgHeaders.Headers, nil
 }
 
 // GetProof requests a proof for a domain name from the network peer
@@ -228,15 +333,18 @@ func (p *Peer) GetProof(name string, rootHash [32]byte) (*handshake.Proof, error
 		return nil, err
 	}
 	// Wait for Proof response
-	msg, err := p.receiveMessage()
-	if err != nil {
-		return nil, err
-	}
-	msgProof, ok := msg.(*MsgProof)
-	if !ok {
-		return nil, fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.proofCh:
+		msgProof, ok := msg.(*MsgProof)
+		if !ok {
+			return nil, fmt.Errorf("unexpected message: %T", msg)
+		}
+		return msgProof.Proof, nil
+	case <-p.doneCh:
+		return nil, errors.New("connection has shut down")
+	case <-time.After(5 * time.Second):
+		return nil, errors.New("timed out")
 	}
-	return msgProof.Proof, nil
 }
 
 // GetBlock requests the specified block from the network peer
@@ -253,13 +361,16 @@ func (p *Peer) GetBlock(hash [32]byte) (*handshake.Block, error) {
 		return nil, err
 	}
 	// Wait for Block response
-	msg, err := p.receiveMessage()
-	if err != nil {
-		return nil, err
-	}
-	msgBlock, ok := msg.(*MsgBlock)
-	if !ok {
-		return nil, fmt.Errorf("unexpected message: %T", msg)
+	select {
+	case msg := <-p.blockCh:
+		msgBlock, ok := msg.(*MsgBlock)
+		if !ok {
+			return nil, fmt.Errorf("unexpected message: %T", msg)
+		}
+		return msgBlock.Block, nil
+	case <-p.doneCh:
+		return nil, errors.New("connection has shut down")
+	case <-time.After(5 * time.Second):
+		return nil, errors.New("timed out")
 	}
-	return msgBlock.Block, nil
 }