-
Notifications
You must be signed in to change notification settings - Fork 2
refactor: async Handshake network support #434
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ import ( | |
| "fmt" | ||
| "io" | ||
| "net" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/blinklabs-io/cdnsd/internal/handshake" | ||
|
|
@@ -31,6 +32,16 @@ type Peer struct { | |
| address string | ||
| conn net.Conn | ||
| networkMagic uint32 | ||
| mu sync.Mutex | ||
| sendMu sync.Mutex | ||
| hasConnected bool | ||
| doneCh chan struct{} | ||
| errorCh chan error | ||
| handshakeCh chan Message | ||
| headersCh chan Message | ||
| blockCh chan Message | ||
| addrCh chan Message | ||
| proofCh chan Message | ||
| } | ||
|
|
||
| // NewPeer returns a new Peer using an existing connection (if provided) and the specified network magic. If a connection is provided, | ||
|
|
@@ -39,10 +50,15 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) { | |
| p := &Peer{ | ||
| conn: conn, | ||
| networkMagic: networkMagic, | ||
| doneCh: make(chan struct{}), | ||
| errorCh: make(chan error, 5), | ||
| } | ||
| if conn != nil { | ||
| p.conn = conn | ||
| p.address = conn.RemoteAddr().String() | ||
| if err := p.handshake(); err != nil { | ||
| p.hasConnected = true | ||
| if err := p.setupConnection(); err != nil { | ||
| _ = p.conn.Close() | ||
| return nil, err | ||
| } | ||
| } | ||
|
|
@@ -51,34 +67,69 @@ func NewPeer(conn net.Conn, networkMagic uint32) (*Peer, error) { | |
|
|
||
| // Connect establishes a connection with a peer and performs the handshake process | ||
| func (p *Peer) Connect(address string) error { | ||
| p.mu.Lock() | ||
| defer p.mu.Unlock() | ||
| if p.conn != nil { | ||
| return errors.New("connection already established") | ||
| } | ||
| if p.hasConnected { | ||
| return errors.New("peer cannot be reused after disconnect") | ||
| } | ||
| var err error | ||
| p.conn, err = net.DialTimeout("tcp", address, dialTimeout) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| p.address = address | ||
| if err := p.handshake(); err != nil { | ||
| p.hasConnected = true | ||
| if err := p.setupConnection(); err != nil { | ||
| _ = p.conn.Close() | ||
| return err | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // Close closes an active connection with a network peer | ||
| func (p *Peer) Close() error { | ||
| p.mu.Lock() | ||
| defer p.mu.Unlock() | ||
| if p.conn == nil { | ||
| return errors.New("connection is not established") | ||
| } | ||
| if err := p.conn.Close(); err != nil { | ||
| return err | ||
| } | ||
| p.conn = nil | ||
| close(p.doneCh) | ||
| return nil | ||
| } | ||
|
|
||
| // ErrorChan returns the async error channel | ||
| func (p *Peer) ErrorChan() <-chan error { | ||
| return p.errorCh | ||
| } | ||
|
|
||
| // setupConnection runs the initial handshake and starts the receive loop | ||
| func (p *Peer) setupConnection() error { | ||
| // Init channels for async messages | ||
| p.handshakeCh = make(chan Message, 10) | ||
| p.headersCh = make(chan Message, 10) | ||
| p.blockCh = make(chan Message, 10) | ||
| p.addrCh = make(chan Message, 10) | ||
| p.proofCh = make(chan Message, 10) | ||
| // Start receive loop | ||
| go p.recvLoop() | ||
| // Start handshake | ||
| if err := p.handshake(); err != nil { | ||
| return err | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // sendMessage encodes and sends a message with the given type and payload | ||
| func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error { | ||
| p.sendMu.Lock() | ||
| defer p.sendMu.Unlock() | ||
| if p.conn == nil { | ||
| return errors.New("connection is not established") | ||
| } | ||
|
|
@@ -97,37 +148,67 @@ func (p *Peer) sendMessage(msgType uint8, msgPayload Message) error { | |
| return nil | ||
| } | ||
|
|
||
| // receiveMessage receives and decodes messages from the active connection | ||
| func (p *Peer) receiveMessage() (Message, error) { | ||
| headerBuf := make([]byte, messageHeaderLength) | ||
| if _, err := io.ReadFull(p.conn, headerBuf); err != nil { | ||
| return nil, fmt.Errorf("read header: %w", err) | ||
| } | ||
| header := new(msgHeader) | ||
| if err := header.Decode(headerBuf); err != nil { | ||
| return nil, fmt.Errorf("header decode: %w", err) | ||
| } | ||
| if header.NetworkMagic != p.networkMagic { | ||
| return nil, fmt.Errorf("invalid network magic: %d", header.NetworkMagic) | ||
| } | ||
| if header.PayloadLength > messageMaxPayloadLength { | ||
| return nil, errors.New("payload is too large") | ||
| } | ||
| payload := make([]byte, header.PayloadLength) | ||
| if _, err := io.ReadFull(p.conn, payload); err != nil { | ||
| return nil, fmt.Errorf("read payload: %w", err) | ||
| } | ||
| msg, err := decodeMessage(header, payload) | ||
| if err != nil { | ||
| // Discard unsupported messages and try to get another message | ||
| // This is a bit of a hack | ||
| var unsupportedErr UnsupportedMessageTypeError | ||
| if errors.As(err, &unsupportedErr) { | ||
| return p.receiveMessage() | ||
| // recvLoop receives and decodes messages from the active connection | ||
| func (p *Peer) recvLoop() { | ||
| err := func() error { | ||
| // Assign to local var to avoid nil deref panic on shutdown | ||
| conn := p.conn | ||
| for { | ||
| headerBuf := make([]byte, messageHeaderLength) | ||
| if _, err := io.ReadFull(conn, headerBuf); err != nil { | ||
| return fmt.Errorf("read header: %w", err) | ||
| } | ||
| header := new(msgHeader) | ||
| if err := header.Decode(headerBuf); err != nil { | ||
| return fmt.Errorf("header decode: %w", err) | ||
| } | ||
| if header.NetworkMagic != p.networkMagic { | ||
| return fmt.Errorf("invalid network magic: %d", header.NetworkMagic) | ||
| } | ||
| if header.PayloadLength > messageMaxPayloadLength { | ||
| return errors.New("payload is too large") | ||
| } | ||
| payload := make([]byte, header.PayloadLength) | ||
| if _, err := io.ReadFull(conn, payload); err != nil { | ||
| return fmt.Errorf("read payload: %w", err) | ||
| } | ||
| msg, err := decodeMessage(header, payload) | ||
| if err != nil { | ||
| // Discard unsupported messages and try to get another message | ||
| // This is a bit of a hack | ||
| var unsupportedErr UnsupportedMessageTypeError | ||
| if errors.As(err, &unsupportedErr) { | ||
| continue | ||
| } | ||
| return fmt.Errorf("decode message: %w", err) | ||
| } | ||
| if err := p.handleMessage(msg); err != nil { | ||
| return fmt.Errorf("handle message: %w", err) | ||
| } | ||
| } | ||
| return nil, err | ||
| }() | ||
| if err != nil { | ||
| p.errorCh <- err | ||
| _ = p.Close() | ||
| } | ||
| return msg, nil | ||
| } | ||
|
|
||
| func (p *Peer) handleMessage(msg Message) error { | ||
| switch msg.(type) { | ||
| case *MsgVersion, *MsgVerack: | ||
| p.handshakeCh <- msg | ||
| case *MsgAddr: | ||
| p.addrCh <- msg | ||
| case *MsgHeaders: | ||
| p.headersCh <- msg | ||
| case *MsgBlock: | ||
| p.blockCh <- msg | ||
| case *MsgProof: | ||
| p.proofCh <- msg | ||
| default: | ||
| return fmt.Errorf("unknown message type: %T", msg) | ||
agaffney marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Receiving valid request messages (e.g., getaddr/getheaders/getdata/getproof) now triggers handleMessage’s default error, and recvLoop closes the connection as soon as it sees those messages. This causes unnecessary disconnects from compliant peers instead of ignoring or responding to the requests. Prompt for AI agents |
||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // handshake performs the handshake process, which involves exchanging Version messages with the network peer | ||
|
|
@@ -157,20 +238,28 @@ func (p *Peer) handshake() error { | |
| return err | ||
| } | ||
| // Wait for Verack response | ||
| msg, err := p.receiveMessage() | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if _, ok := msg.(*MsgVerack); !ok { | ||
| return fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.handshakeCh: | ||
| if _, ok := msg.(*MsgVerack); !ok { | ||
| return fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| case err := <-p.errorCh: | ||
| return fmt.Errorf("handshake failed: %w", err) | ||
| case <-time.After(1 * time.Second): | ||
| return errors.New("handshake timed out") | ||
| } | ||
| // Wait for Version from peer | ||
| msg, err = p.receiveMessage() | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if _, ok := msg.(*MsgVersion); !ok { | ||
| return fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.handshakeCh: | ||
| if _, ok := msg.(*MsgVersion); !ok { | ||
| return fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| case err := <-p.errorCh: | ||
| return fmt.Errorf("handshake failed: %w", err) | ||
| case <-p.doneCh: | ||
| return errors.New("connection has shut down") | ||
| case <-time.After(1 * time.Second): | ||
| return errors.New("handshake timed out") | ||
| } | ||
| // Send Verack | ||
| if err := p.sendMessage(MessageVerack, nil); err != nil { | ||
|
|
@@ -185,15 +274,18 @@ func (p *Peer) GetPeers() ([]NetAddress, error) { | |
| return nil, err | ||
| } | ||
| // Wait for Addr response | ||
| msg, err := p.receiveMessage() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| msgAddr, ok := msg.(*MsgAddr) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.addrCh: | ||
| msgAddr, ok := msg.(*MsgAddr) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| return msgAddr.Peers, nil | ||
| case <-p.doneCh: | ||
| return nil, errors.New("connection has shut down") | ||
| case <-time.After(5 * time.Second): | ||
| return nil, errors.New("timed out") | ||
| } | ||
| return msgAddr.Peers, nil | ||
| } | ||
|
|
||
| // GetHeaders requests a list of headers from the network peer | ||
|
|
@@ -206,15 +298,18 @@ func (p *Peer) GetHeaders(locator [][32]byte, stopHash [32]byte) ([]*handshake.B | |
| return nil, err | ||
| } | ||
| // Wait for Headers response | ||
| msg, err := p.receiveMessage() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| msgHeaders, ok := msg.(*MsgHeaders) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.headersCh: | ||
| msgHeaders, ok := msg.(*MsgHeaders) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| return msgHeaders.Headers, nil | ||
| case <-p.doneCh: | ||
| return nil, errors.New("connection has shut down") | ||
| case <-time.After(5 * time.Second): | ||
| return nil, errors.New("timed out") | ||
| } | ||
| return msgHeaders.Headers, nil | ||
| } | ||
|
|
||
| // GetProof requests a proof for a domain name from the network peer | ||
|
|
@@ -228,15 +323,18 @@ func (p *Peer) GetProof(name string, rootHash [32]byte) (*handshake.Proof, error | |
| return nil, err | ||
| } | ||
| // Wait for Proof response | ||
| msg, err := p.receiveMessage() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| msgProof, ok := msg.(*MsgProof) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.proofCh: | ||
| msgProof, ok := msg.(*MsgProof) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| return msgProof.Proof, nil | ||
| case <-p.doneCh: | ||
| return nil, errors.New("connection has shut down") | ||
| case <-time.After(5 * time.Second): | ||
| return nil, errors.New("timed out") | ||
| } | ||
| return msgProof.Proof, nil | ||
| } | ||
|
|
||
| // GetBlock requests the specified block from the network peer | ||
|
|
@@ -253,13 +351,16 @@ func (p *Peer) GetBlock(hash [32]byte) (*handshake.Block, error) { | |
| return nil, err | ||
| } | ||
| // Wait for Block response | ||
| msg, err := p.receiveMessage() | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| msgBlock, ok := msg.(*MsgBlock) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| select { | ||
| case msg := <-p.blockCh: | ||
| msgBlock, ok := msg.(*MsgBlock) | ||
| if !ok { | ||
| return nil, fmt.Errorf("unexpected message: %T", msg) | ||
| } | ||
| return msgBlock.Block, nil | ||
| case <-p.doneCh: | ||
| return nil, errors.New("connection has shut down") | ||
| case <-time.After(5 * time.Second): | ||
| return nil, errors.New("timed out") | ||
| } | ||
| return msgBlock.Block, nil | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Close now nils
p.connwhile the recvLoop goroutine still dereferences it, so closing a peer mid-stream can panic when the loop tries to read from a nil connection.Prompt for AI agents