87 changes: 87 additions & 0 deletions internal/raftengine/etcd/dispatch_report_test.go
@@ -0,0 +1,87 @@
package etcd

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
raftpb "go.etcd.io/raft/v3/raftpb"
)

const testDispatchReportTimeout = 2 * time.Second

// TestPostDispatchReport_DeliversWhenChannelHasSpace verifies that a failure
// report from the dispatch worker is delivered to the event loop through the
// dedicated channel, which is what allows the engine to call
// rawNode.ReportUnreachable / ReportSnapshot from the correct goroutine.
func TestPostDispatchReport_DeliversWhenChannelHasSpace(t *testing.T) {
t.Parallel()
e := &Engine{
dispatchReportCh: make(chan dispatchReport, 4),
closeCh: make(chan struct{}),
}
report := dispatchReport{to: 42, msgType: raftpb.MsgSnap}

e.postDispatchReport(report)

select {
case got := <-e.dispatchReportCh:
require.Equal(t, report, got)
default:
t.Fatal("expected dispatchReport to be delivered to channel")
}
}

// TestPostDispatchReport_DropsWhenChannelFull asserts the non-blocking
// contract: dispatch workers must not stall because the event loop is busy.
// The worst case is an eventually-consistent gap that raft will fix on the
// next heartbeat-driven retry, so dropping under pressure is acceptable.
func TestPostDispatchReport_DropsWhenChannelFull(t *testing.T) {
t.Parallel()
e := &Engine{
dispatchReportCh: make(chan dispatchReport, 1),
closeCh: make(chan struct{}),
}
e.dispatchReportCh <- dispatchReport{to: 1, msgType: raftpb.MsgApp}

done := make(chan struct{})
go func() {
e.postDispatchReport(dispatchReport{to: 2, msgType: raftpb.MsgApp})
close(done)
}()

ctx, cancel := context.WithTimeout(context.Background(), testDispatchReportTimeout)
defer cancel()
select {
case <-done:
case <-ctx.Done():
t.Fatal("postDispatchReport blocked while channel was full")
}
}

// TestPostDispatchReport_AbortsOnClose ensures the worker does not get stuck
// posting reports during shutdown when the channel is full.
func TestPostDispatchReport_AbortsOnClose(t *testing.T) {
t.Parallel()
e := &Engine{
dispatchReportCh: make(chan dispatchReport, 1),
closeCh: make(chan struct{}),
}
e.dispatchReportCh <- dispatchReport{to: 1, msgType: raftpb.MsgApp}
close(e.closeCh)

done := make(chan struct{})
go func() {
e.postDispatchReport(dispatchReport{to: 2, msgType: raftpb.MsgApp})
close(done)
}()

ctx, cancel := context.WithTimeout(context.Background(), testDispatchReportTimeout)
defer cancel()
select {
case <-done:
case <-ctx.Done():
t.Fatal("postDispatchReport did not abort when closeCh was signalled")
}
}
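
The three tests above pin down one select contract: deliver when there is room, drop when the buffer is full, abort on close. Distilled to a standalone sketch (not part of this PR; the generic helper and its name are hypothetical):

// Sketch only: the non-blocking post contract the tests exercise,
// generalized over the report type. Mirrors postDispatchReport in engine.go.
func postNonBlocking[T any](ch chan<- T, closeCh <-chan struct{}, v T) bool {
	select {
	case ch <- v:
		return true // event loop will drain it
	case <-closeCh:
		return false // shutting down; abort rather than block
	default:
		return false // buffer full; drop, raft retries on a later tick
	}
}

Note that when closeCh is already closed and the buffer still has space, Go picks between the two ready cases pseudo-randomly, so delivery during shutdown is best-effort; the tests above only require that the call never blocks.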
74 changes: 63 additions & 11 deletions internal/raftengine/etcd/engine.go
@@ -120,6 +120,7 @@ type Engine struct {
readCh chan readRequest
adminCh chan adminRequest
stepCh chan raftpb.Message
dispatchReportCh chan dispatchReport
peerDispatchers map[uint64]*peerQueues
perPeerQueueSize int
dispatchStopCh chan struct{}
@@ -297,6 +298,7 @@ func Open(ctx context.Context, cfg OpenConfig) (*Engine, error) {
readCh: make(chan readRequest),
adminCh: make(chan adminRequest),
stepCh: make(chan raftpb.Message, defaultMaxInflightMsg),
dispatchReportCh: make(chan dispatchReport, defaultMaxInflightMsg),
closeCh: make(chan struct{}),
doneCh: make(chan struct{}),
startedCh: make(chan struct{}),
@@ -766,6 +768,8 @@ func (e *Engine) handleEvent(tick <-chan time.Time) (bool, error) {
e.handleAdmin(req)
case msg := <-e.stepCh:
e.handleStep(msg)
case report := <-e.dispatchReportCh:
e.handleDispatchReport(report)
case result := <-e.snapshotResCh:
if err := e.handleSnapshotResult(result); err != nil {
return false, err
@@ -774,6 +778,48 @@
return true, nil
}

// dispatchReport is posted by the dispatch workers when a transport send
// to a peer fails; the engine goroutine drains these and informs etcd/raft
// via rawNode so follower Progress leaves StateReplicate / StateSnapshot on
// unreachable peers and does not silently stall.
type dispatchReport struct {
to uint64
msgType raftpb.MessageType
}

func (e *Engine) handleDispatchReport(report dispatchReport) {
if e.rawNode == nil {
return
}
// MsgSnap requires the distinct SnapshotFailure path: raft tracks
// PendingSnapshot in Progress, and only ReportSnapshot clears it.
// All other message types use ReportUnreachable, which transitions the
// peer from StateReplicate to StateProbe so the next heartbeat response
// drives a fresh sendAppend attempt.
if report.msgType == raftpb.MsgSnap {
e.rawNode.ReportSnapshot(report.to, etcdraft.SnapshotFailure)
return
}
e.rawNode.ReportUnreachable(report.to)
}

// postDispatchReport delivers a dispatch failure to the event loop without
// blocking the worker. If the channel is full (unlikely — the buffer is
// sized to MaxInflightMsg), the report is dropped and logged; this is
// acceptable because raft will retry on the next tick and we only need
// eventual consistency between transport state and Progress state.
func (e *Engine) postDispatchReport(report dispatchReport) {
select {
case e.dispatchReportCh <- report:
case <-e.closeCh:
default:
slog.Warn("etcd raft dispatch report dropped (channel full)",
"to", report.to,
"type", report.msgType.String(),
)
}
}

func (e *Engine) handleProposal(req proposalRequest) {
if err := contextErr(req.ctx); err != nil {
req.done <- proposalResult{err: err}
@@ -2132,18 +2178,24 @@ func (e *Engine) handleDispatchRequest(ctx context.Context, req dispatchRequest)
if err := req.Close(); err != nil {
slog.Error("etcd raft dispatch: failed to close request", "err", err)
}
if dispatchErr != nil && !errors.Is(dispatchErr, ctx.Err()) {
count := e.dispatchErrorCount.Add(1)
if shouldLogDispatchEvent(count) {
slog.Warn("etcd raft outbound dispatch failed",
"node_id", e.nodeID,
"to", req.msg.To,
"type", req.msg.Type.String(),
"dispatch_error_count", count,
"err", dispatchErr,
)
}
if dispatchErr == nil || errors.Is(dispatchErr, ctx.Err()) {
return
}
count := e.dispatchErrorCount.Add(1)
if shouldLogDispatchEvent(count) {
slog.Warn("etcd raft outbound dispatch failed",
"node_id", e.nodeID,
"to", req.msg.To,
"type", req.msg.Type.String(),
"dispatch_error_count", count,
"err", dispatchErr,
)
}
// Inform etcd/raft that the peer is unreachable so Progress transitions
// out of StateReplicate / StateSnapshot. Without this the leader keeps
// Progress stuck and never retries sendAppend/sendSnap for the peer,
// leaving the follower indefinitely stale even after heartbeats resume.
e.postDispatchReport(dispatchReport{to: req.msg.To, msgType: req.msg.Type})
}

func (e *Engine) stopDispatchWorkers() {
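
The Progress transition that handleDispatchReport drives is observable through the public Status API of go.etcd.io/raft/v3. A sketch, not part of this PR, assuming rn is the leader's RawNode and 2 is a peer id currently in StateReplicate:

// Sketch only: leader-side view of a ReportUnreachable call. The Progress
// map in Status is populated only on the leader.
before := rn.Status().Progress[2].State // tracker.StateReplicate
rn.ReportUnreachable(2)                 // what the event loop does per report
after := rn.Status().Progress[2].State  // tracker.StateProbe
// From StateProbe, the next heartbeat response drives a fresh sendAppend,
// which is the retry path the comment above relies on.
_, _ = before, after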
49 changes: 45 additions & 4 deletions internal/raftengine/etcd/grpc_transport.go
@@ -267,11 +267,37 @@ func (t *GRPCTransport) streamFSMSnapshot(ctx context.Context, msg raftpb.Message
if err != nil {
return errors.WithStack(err)
}
if err := sendSnapshotReaderChunks(stream, header, rc, t.chunkSize()); err != nil {
// Count payload bytes so a sender-side total can be correlated against the
// receiver-side total when a follower fails to restore. A mismatch points
// at transport truncation; a match points at a format/parsing issue.
counter := &countingReadCloser{inner: rc}
if err := sendSnapshotReaderChunks(stream, header, counter, t.chunkSize()); err != nil {
return err
}
_, err = stream.CloseAndRecv()
return errors.WithStack(err)
if _, err := stream.CloseAndRecv(); err != nil {
return errors.WithStack(err)
}
slog.Info("etcd raft snapshot stream sent",
"index", index,
"to", msg.To,
"payload_bytes", counter.n,
)
return nil
}

type countingReadCloser struct {
inner io.ReadCloser
n int64
}

func (c *countingReadCloser) Read(p []byte) (int, error) {
n, err := c.inner.Read(p)
c.n += int64(n)
return n, err //nolint:wrapcheck // preserve io.EOF sentinel identity for callers
}

func (c *countingReadCloser) Close() error {
return c.inner.Close() //nolint:wrapcheck // caller expects the underlying close error verbatim
}

// applyBridgeMode implements the Phase 1 bridge: when MemoryStorage holds a
@@ -652,6 +678,7 @@ func (t *GRPCTransport) receiveSnapshotStream(stream pb.EtcdRaft_SendSnapshotServer
_ = spool.Close()
}()

var payloadBytes int64
for {
chunk, err := stream.Recv()
if err != nil {
@@ -660,13 +687,27 @@
}
return raftpb.Message{}, errors.WithStack(err)
}
payloadBytes += int64(len(chunk.Chunk))
seen, err := appendSnapshotChunk(&metadata, spool, chunk, seenMetadata)
if err != nil {
return raftpb.Message{}, err
}
seenMetadata = seen
if chunk.Final {
return buildSnapshotMessage(metadata, spool, seenMetadata)
msg, err := buildSnapshotMessage(metadata, spool, seenMetadata)
if err != nil {
return raftpb.Message{}, err
}
index := uint64(0)
if msg.Snapshot != nil {
index = msg.Snapshot.Metadata.Index
}
slog.Info("etcd raft snapshot stream received",
"index", index,
"from", msg.From,
"payload_bytes", payloadBytes,
)
return msg, nil
}
}
}
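
The countingReadCloser contract (count exactly the bytes read, pass io.EOF through with its identity intact) is cheap to pin with a unit test. A sketch, not included in this PR, that would sit alongside the type in the same package:

// Sketch only: pins the two properties the nolint comments rely on.
func TestCountingReadCloser(t *testing.T) {
	t.Parallel()
	c := &countingReadCloser{inner: io.NopCloser(strings.NewReader("abc"))}
	n, err := io.Copy(io.Discard, c) // drains the reader to io.EOF
	require.NoError(t, err)
	_, err = c.Read(make([]byte, 1))
	require.ErrorIs(t, err, io.EOF) // sentinel identity preserved, not wrapped
	require.EqualValues(t, 3, n)
	require.EqualValues(t, 3, c.n) // matches the payload_bytes the sender logs
	require.NoError(t, c.Close())
}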