diff --git a/adapter/distribution_server.go b/adapter/distribution_server.go index 1d1d1bb0..d0abb3fe 100644 --- a/adapter/distribution_server.go +++ b/adapter/distribution_server.go @@ -139,7 +139,7 @@ func (s *DistributionServer) SplitRange(ctx context.Context, req *pb.SplitRangeR s.mu.Lock() defer s.mu.Unlock() - if err := s.verifyCatalogLeader(); err != nil { + if err := s.verifyCatalogLeader(ctx); err != nil { return nil, err } @@ -191,7 +191,7 @@ func (s *DistributionServer) pinReadTS(ts uint64) *kv.ActiveTimestampToken { return s.readTracker.Pin(ts) } -func (s *DistributionServer) verifyCatalogLeader() error { +func (s *DistributionServer) verifyCatalogLeader(ctx context.Context) error { if s.coordinator == nil { return grpcStatusError(codes.FailedPrecondition, errDistributionCoordinatorRequired.Error()) } @@ -199,7 +199,7 @@ func (s *DistributionServer) verifyCatalogLeader() error { if !s.coordinator.IsLeaderForKey(key) { return grpcStatusError(codes.FailedPrecondition, errDistributionNotLeader.Error()) } - if err := s.coordinator.VerifyLeaderForKey(key); err != nil { + if err := s.coordinator.VerifyLeaderForKey(ctx, key); err != nil { return grpcStatusErrorf(codes.FailedPrecondition, "verify catalog leader: %v", err) } return nil diff --git a/adapter/distribution_server_test.go b/adapter/distribution_server_test.go index 96b3cd8b..030116c3 100644 --- a/adapter/distribution_server_test.go +++ b/adapter/distribution_server_test.go @@ -713,7 +713,7 @@ func (s *distributionCoordinatorStub) IsLeader() bool { return s.leader } -func (s *distributionCoordinatorStub) VerifyLeader() error { +func (s *distributionCoordinatorStub) VerifyLeader(context.Context) error { if !s.leader { return kv.ErrLeaderNotFound } @@ -728,7 +728,7 @@ func (s *distributionCoordinatorStub) IsLeaderForKey(_ []byte) bool { return s.leader } -func (s *distributionCoordinatorStub) VerifyLeaderForKey(_ []byte) error { +func (s *distributionCoordinatorStub) VerifyLeaderForKey(_ context.Context, _ []byte) error { if !s.leader { return kv.ErrLeaderNotFound } diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index 00a18d98..a61de03f 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -377,7 +377,7 @@ func (d *DynamoDBServer) serveDynamoLeaderHealthz(w http.ResponseWriter, r *http return } - if isVerifiedDynamoLeader(d.coordinator) { + if isVerifiedDynamoLeader(r.Context(), d.coordinator) { writeDynamoHealthBody(w, r, http.StatusOK, "ok\n") return } @@ -385,11 +385,11 @@ func (d *DynamoDBServer) serveDynamoLeaderHealthz(w http.ResponseWriter, r *http writeDynamoHealthBody(w, r, http.StatusServiceUnavailable, "not leader\n") } -func isVerifiedDynamoLeader(coordinator kv.Coordinator) bool { +func isVerifiedDynamoLeader(ctx context.Context, coordinator kv.Coordinator) bool { if coordinator == nil || !coordinator.IsLeader() { return false } - return coordinator.VerifyLeader() == nil + return coordinator.VerifyLeader(ctx) == nil } func writeDynamoHealthMethod(w http.ResponseWriter, r *http.Request) bool { diff --git a/adapter/dynamodb_admin.go b/adapter/dynamodb_admin.go index 046c2040..63b05897 100644 --- a/adapter/dynamodb_admin.go +++ b/adapter/dynamodb_admin.go @@ -202,7 +202,7 @@ func (d *DynamoDBServer) AdminCreateTable(ctx context.Context, principal AdminPr if !principal.Role.canWrite() { return nil, ErrAdminForbidden } - if !isVerifiedDynamoLeader(d.coordinator) { + if !isVerifiedDynamoLeader(ctx, d.coordinator) { return nil, ErrAdminNotLeader } legacy, err := buildLegacyCreateTableInput(in) @@ -249,7 
+249,7 @@ func (d *DynamoDBServer) AdminDeleteTable(ctx context.Context, principal AdminPr if !principal.Role.canWrite() { return ErrAdminForbidden } - if !isVerifiedDynamoLeader(d.coordinator) { + if !isVerifiedDynamoLeader(ctx, d.coordinator) { return ErrAdminNotLeader } if strings.TrimSpace(name) == "" { diff --git a/adapter/dynamodb_test.go b/adapter/dynamodb_test.go index 1bfaaf57..ebeaa638 100644 --- a/adapter/dynamodb_test.go +++ b/adapter/dynamodb_test.go @@ -1831,8 +1831,8 @@ func (w *testCoordinatorWrapper) IsLeader() bool { return w.inner.IsLeader() } -func (w *testCoordinatorWrapper) VerifyLeader() error { - return w.inner.VerifyLeader() +func (w *testCoordinatorWrapper) VerifyLeader(ctx context.Context) error { + return w.inner.VerifyLeader(ctx) } func (w *testCoordinatorWrapper) RaftLeader() string { @@ -1843,8 +1843,8 @@ func (w *testCoordinatorWrapper) IsLeaderForKey(key []byte) bool { return w.inner.IsLeaderForKey(key) } -func (w *testCoordinatorWrapper) VerifyLeaderForKey(key []byte) error { - return w.inner.VerifyLeaderForKey(key) +func (w *testCoordinatorWrapper) VerifyLeaderForKey(ctx context.Context, key []byte) error { + return w.inner.VerifyLeaderForKey(ctx, key) } func (w *testCoordinatorWrapper) RaftLeaderForKey(key []byte) string { diff --git a/adapter/internal.go b/adapter/internal.go index 1e36854d..e973de6b 100644 --- a/adapter/internal.go +++ b/adapter/internal.go @@ -49,7 +49,7 @@ func (i *Internal) Forward(ctx context.Context, req *pb.ForwardRequest) (*pb.For }, errors.WithStack(err) } - r, err := i.transactionManager.Commit(req.Requests) + r, err := i.transactionManager.Commit(ctx, req.Requests) if err != nil { return &pb.ForwardResponse{ Success: false, diff --git a/adapter/redis.go b/adapter/redis.go index cddb548b..fdeec2c6 100644 --- a/adapter/redis.go +++ b/adapter/redis.go @@ -1287,7 +1287,14 @@ func (r *RedisServer) keys(conn redcon.Conn, cmd redcon.Command) { pattern := cmd.Args[1] if r.coordinator.IsLeader() { - if err := r.coordinator.VerifyLeader(); err != nil { + // Per-call ctx with redisDispatchTimeout instead of the + // long-lived handlerContext: a stalled VerifyLeader on KEYS + // must not pin the command handler indefinitely. The same + // bound the rest of the dispatch path (sadd, set, …) uses; + // see Codex P1 review on PR #749. 
+ ctx, cancel := context.WithTimeout(r.handlerContext(), redisDispatchTimeout) + defer cancel() + if err := r.coordinator.VerifyLeader(ctx); err != nil { conn.WriteError(err.Error()) return } @@ -3254,7 +3261,7 @@ func (r *RedisServer) rangeList(key []byte, startRaw, endRaw []byte) ([]string, return nil, wrongTypeError() } - if err := r.coordinator.VerifyLeaderForKey(key); err != nil { + if err := r.coordinator.VerifyLeaderForKey(r.handlerContext(), key); err != nil { return nil, errors.WithStack(err) } @@ -3510,7 +3517,7 @@ func (r *RedisServer) readValueAt(key []byte, readTS uint64) ([]byte, error) { } if r.coordinator.IsLeaderForKey(key) { - if err := r.coordinator.VerifyLeaderForKey(key); err != nil { + if err := r.coordinator.VerifyLeaderForKey(r.handlerContext(), key); err != nil { return nil, errors.WithStack(err) } v, err := r.store.GetAt(context.Background(), key, readTS) diff --git a/adapter/redis_compat_commands.go b/adapter/redis_compat_commands.go index 4746ac7c..f3e2b1f1 100644 --- a/adapter/redis_compat_commands.go +++ b/adapter/redis_compat_commands.go @@ -1039,7 +1039,7 @@ func (r *RedisServer) dbsize(conn redcon.Conn, _ redcon.Command) { conn.WriteInt(size) return } - if err := r.coordinator.VerifyLeader(); err != nil { + if err := r.coordinator.VerifyLeader(r.handlerContext()); err != nil { conn.WriteError(err.Error()) return } @@ -1144,7 +1144,11 @@ func (r *RedisServer) flushDatabase(conn redcon.Conn, all bool) { defer cancel() if err := r.retryRedisWrite(ctx, func() error { - if err := r.coordinator.VerifyLeader(); err != nil { + // Use the per-call ctx with redisDispatchTimeout, NOT + // handlerContext (the long-lived server baseCtx). FLUSHDB's + // retry budget already lives in ctx; routing it to + // VerifyLeader keeps the whole command bounded. 
+ if err := r.coordinator.VerifyLeader(ctx); err != nil { return fmt.Errorf("verify leader: %w", err) } diff --git a/adapter/redis_compat_helpers.go b/adapter/redis_compat_helpers.go index d3ba8103..5965671a 100644 --- a/adapter/redis_compat_helpers.go +++ b/adapter/redis_compat_helpers.go @@ -860,7 +860,7 @@ func (r *RedisServer) doGetAt(key []byte, readTS uint64, verify bool) ([]byte, e } if r.coordinator.IsLeaderForKey(routingKey) { if verify { - if err := r.coordinator.VerifyLeaderForKey(routingKey); err != nil { + if err := r.coordinator.VerifyLeaderForKey(r.handlerContext(), routingKey); err != nil { return nil, errors.WithStack(err) } } diff --git a/adapter/redis_hello_test.go b/adapter/redis_hello_test.go index 5c6fd15d..b7d205bc 100644 --- a/adapter/redis_hello_test.go +++ b/adapter/redis_hello_test.go @@ -98,12 +98,12 @@ type helloTestCoordinator struct { func (c *helloTestCoordinator) Dispatch(context.Context, *kv.OperationGroup[kv.OP]) (*kv.CoordinateResponse, error) { return &kv.CoordinateResponse{}, nil } -func (c *helloTestCoordinator) IsLeader() bool { return c.isLeader } -func (c *helloTestCoordinator) VerifyLeader() error { return nil } -func (c *helloTestCoordinator) RaftLeader() string { return "" } -func (c *helloTestCoordinator) IsLeaderForKey([]byte) bool { return c.isLeader } -func (c *helloTestCoordinator) VerifyLeaderForKey([]byte) error { return nil } -func (c *helloTestCoordinator) RaftLeaderForKey([]byte) string { return "" } +func (c *helloTestCoordinator) IsLeader() bool { return c.isLeader } +func (c *helloTestCoordinator) VerifyLeader(context.Context) error { return nil } +func (c *helloTestCoordinator) RaftLeader() string { return "" } +func (c *helloTestCoordinator) IsLeaderForKey([]byte) bool { return c.isLeader } +func (c *helloTestCoordinator) VerifyLeaderForKey(context.Context, []byte) error { return nil } +func (c *helloTestCoordinator) RaftLeaderForKey([]byte) string { return "" } func (c *helloTestCoordinator) Clock() *kv.HLC { if c.clock == nil { c.clock = kv.NewHLC() diff --git a/adapter/redis_info_test.go b/adapter/redis_info_test.go index 51732aa9..5db46cf9 100644 --- a/adapter/redis_info_test.go +++ b/adapter/redis_info_test.go @@ -20,12 +20,12 @@ type infoTestCoordinator struct { func (c *infoTestCoordinator) Dispatch(context.Context, *kv.OperationGroup[kv.OP]) (*kv.CoordinateResponse, error) { return &kv.CoordinateResponse{}, nil } -func (c *infoTestCoordinator) IsLeader() bool { return c.isLeader } -func (c *infoTestCoordinator) VerifyLeader() error { return nil } -func (c *infoTestCoordinator) RaftLeader() string { return c.raftLeader } -func (c *infoTestCoordinator) IsLeaderForKey([]byte) bool { return c.isLeader } -func (c *infoTestCoordinator) VerifyLeaderForKey([]byte) error { return nil } -func (c *infoTestCoordinator) RaftLeaderForKey([]byte) string { return c.raftLeader } +func (c *infoTestCoordinator) IsLeader() bool { return c.isLeader } +func (c *infoTestCoordinator) VerifyLeader(context.Context) error { return nil } +func (c *infoTestCoordinator) RaftLeader() string { return c.raftLeader } +func (c *infoTestCoordinator) IsLeaderForKey([]byte) bool { return c.isLeader } +func (c *infoTestCoordinator) VerifyLeaderForKey(context.Context, []byte) error { return nil } +func (c *infoTestCoordinator) RaftLeaderForKey([]byte) string { return c.raftLeader } func (c *infoTestCoordinator) Clock() *kv.HLC { if c.clock == nil { c.clock = kv.NewHLC() diff --git a/adapter/redis_keys_pattern_test.go b/adapter/redis_keys_pattern_test.go index 
75010edf..f893fe01 100644 --- a/adapter/redis_keys_pattern_test.go +++ b/adapter/redis_keys_pattern_test.go @@ -29,7 +29,7 @@ func (s *stubAdapterCoordinator) IsLeader() bool { return true } -func (s *stubAdapterCoordinator) VerifyLeader() error { +func (s *stubAdapterCoordinator) VerifyLeader(context.Context) error { s.verifyCalls.Add(1) return s.verifyLeaderErr } @@ -45,7 +45,7 @@ func (s *stubAdapterCoordinator) IsLeaderForKey([]byte) bool { return true } -func (s *stubAdapterCoordinator) VerifyLeaderForKey([]byte) error { +func (s *stubAdapterCoordinator) VerifyLeaderForKey(context.Context, []byte) error { return nil } diff --git a/adapter/redis_retry_test.go b/adapter/redis_retry_test.go index d17742b9..c48884e1 100644 --- a/adapter/redis_retry_test.go +++ b/adapter/redis_retry_test.go @@ -56,7 +56,7 @@ func (c *retryOnceCoordinator) IsLeader() bool { return true } -func (c *retryOnceCoordinator) VerifyLeader() error { +func (c *retryOnceCoordinator) VerifyLeader(context.Context) error { return nil } @@ -68,7 +68,7 @@ func (c *retryOnceCoordinator) IsLeaderForKey([]byte) bool { return true } -func (c *retryOnceCoordinator) VerifyLeaderForKey([]byte) error { +func (c *retryOnceCoordinator) VerifyLeaderForKey(context.Context, []byte) error { return nil } diff --git a/adapter/s3.go b/adapter/s3.go index b8c2822c..8637b438 100644 --- a/adapter/s3.go +++ b/adapter/s3.go @@ -2297,12 +2297,12 @@ func (s *S3Server) maybeProxyToLeader(w http.ResponseWriter, r *http.Request) bo } var leader string if len(key) > 0 { - if s.coordinator.IsLeaderForKey(key) && s.coordinator.VerifyLeaderForKey(key) == nil { + if s.coordinator.IsLeaderForKey(key) && s.coordinator.VerifyLeaderForKey(r.Context(), key) == nil { return false } leader = s.coordinator.RaftLeaderForKey(key) } else { - if s.coordinator.IsLeader() && s.coordinator.VerifyLeader() == nil { + if s.coordinator.IsLeader() && s.coordinator.VerifyLeader(r.Context()) == nil { return false } leader = s.coordinator.RaftLeader() @@ -2420,7 +2420,7 @@ func (s *S3Server) serveS3LeaderHealthz(w http.ResponseWriter, r *http.Request) return true } status, body := http.StatusOK, "ok" - if !s.isVerifiedS3Leader() { + if !s.isVerifiedS3Leader(r.Context()) { status, body = http.StatusServiceUnavailable, "not leader" } w.WriteHeader(status) @@ -2430,11 +2430,11 @@ func (s *S3Server) serveS3LeaderHealthz(w http.ResponseWriter, r *http.Request) return true } -func (s *S3Server) isVerifiedS3Leader() bool { +func (s *S3Server) isVerifiedS3Leader(ctx context.Context) bool { if s.coordinator == nil || !s.coordinator.IsLeader() { return false } - return s.coordinator.VerifyLeader() == nil + return s.coordinator.VerifyLeader(ctx) == nil } // prepareStreamingPutBody wraps r.Body for aws-chunked framed uploads. 
When diff --git a/adapter/s3_admin.go b/adapter/s3_admin.go index 722b2384..dfaaa9ad 100644 --- a/adapter/s3_admin.go +++ b/adapter/s3_admin.go @@ -226,7 +226,7 @@ func (s *S3Server) AdminCreateBucket(ctx context.Context, principal AdminPrincip if !principal.Role.canWrite() { return nil, ErrAdminForbidden } - if !s.isVerifiedS3Leader() { + if !s.isVerifiedS3Leader(ctx) { return nil, ErrAdminNotLeader } if err := validateS3BucketName(name); err != nil { @@ -314,7 +314,7 @@ func (s *S3Server) AdminPutBucketAcl(ctx context.Context, principal AdminPrincip if !principal.Role.canWrite() { return ErrAdminForbidden } - if !s.isVerifiedS3Leader() { + if !s.isVerifiedS3Leader(ctx) { return ErrAdminNotLeader } acl = adminCanonicalACL(acl) @@ -406,7 +406,7 @@ func (s *S3Server) AdminDeleteBucket(ctx context.Context, principal AdminPrincip if !principal.Role.canWrite() { return ErrAdminForbidden } - if !s.isVerifiedS3Leader() { + if !s.isVerifiedS3Leader(ctx) { return ErrAdminNotLeader } diff --git a/adapter/s3_test.go b/adapter/s3_test.go index 457a540b..cfc595a9 100644 --- a/adapter/s3_test.go +++ b/adapter/s3_test.go @@ -696,7 +696,7 @@ func (c *followerS3Coordinator) IsLeader() bool { return false } -func (c *followerS3Coordinator) VerifyLeader() error { +func (c *followerS3Coordinator) VerifyLeader(context.Context) error { return kv.ErrLeaderNotFound } @@ -729,7 +729,7 @@ func (c *routeAwareS3Coordinator) IsLeaderForKey(key []byte) bool { return c.localForKey(key) } -func (c *routeAwareS3Coordinator) VerifyLeaderForKey(key []byte) error { +func (c *routeAwareS3Coordinator) VerifyLeaderForKey(_ context.Context, key []byte) error { if c.IsLeaderForKey(key) { return nil } diff --git a/adapter/sqs.go b/adapter/sqs.go index acff4ff0..0e96179c 100644 --- a/adapter/sqs.go +++ b/adapter/sqs.go @@ -498,18 +498,18 @@ func (s *SQSServer) serveSQSLeaderHealthz(w http.ResponseWriter, r *http.Request if !writeSQSHealthMethod(w, r) { return } - if isVerifiedSQSLeader(s.coordinator) { + if isVerifiedSQSLeader(r.Context(), s.coordinator) { writeSQSHealthBody(w, r, http.StatusOK, "ok\n") return } writeSQSHealthBody(w, r, http.StatusServiceUnavailable, "not leader\n") } -func isVerifiedSQSLeader(coordinator kv.Coordinator) bool { +func isVerifiedSQSLeader(ctx context.Context, coordinator kv.Coordinator) bool { if coordinator == nil || !coordinator.IsLeader() { return false } - return coordinator.VerifyLeader() == nil + return coordinator.VerifyLeader(ctx) == nil } func writeSQSHealthMethod(w http.ResponseWriter, r *http.Request) bool { diff --git a/adapter/sqs_admin.go b/adapter/sqs_admin.go index 0788ff31..370dd615 100644 --- a/adapter/sqs_admin.go +++ b/adapter/sqs_admin.go @@ -105,7 +105,7 @@ func (s *SQSServer) AdminDeleteQueue(ctx context.Context, principal AdminPrincip if !principal.Role.canWrite() { return ErrAdminForbidden } - if !isVerifiedSQSLeader(s.coordinator) { + if !isVerifiedSQSLeader(ctx, s.coordinator) { return ErrAdminNotLeader } if strings.TrimSpace(name) == "" { diff --git a/docs/design/2026_05_10_proposed_kv_ctx_plumbing.md b/docs/design/2026_05_10_proposed_kv_ctx_plumbing.md new file mode 100644 index 00000000..af94d936 --- /dev/null +++ b/docs/design/2026_05_10_proposed_kv_ctx_plumbing.md @@ -0,0 +1,164 @@ +# 2026-05-10 — Plumb caller context through kv write & VerifyLeader paths + +Status: proposed + +## Problem + +PR #745 capped `verifyLeaderEngine` (`kv/raft_engine.go`) at 5 s as an +incident hotfix: every caller without an upstream context — `LeaderProxy.Commit/Abort`, 
+`Coordinate.VerifyLeader`, `ShardedCoordinator.VerifyLeader[ForKey]`, the S3 / SQS / admin +`/healthz/leader` handlers — used `context.Background()` and so blocked +indefinitely whenever a ReadIndex round-trip stalled. Goroutine pile-up +collapsed the leader (the 2026-05-08 incident: 20 K goroutines, 1870 % CPU, OOM). + +The 5 s deadline is a defense-in-depth bound. It is not the right answer +for callers that already hold a request context with its own deadline: + +- The Redis / DynamoDB / S3 / SQS dispatch path enters via + `ShardedCoordinator.Dispatch(ctx, …)` and threads `ctx` through + `dispatchTxn`, but the call lands in `g.Txn.Commit(reqs)` — a + `Transactional` method whose interface drops `ctx` on the floor. +- `LeaderProxy.Commit` then calls `verifyLeaderEngine(p.engine)` (no + ctx). The 5 s safety bound applies, but a client whose own deadline + expired 2 s in still pays the full 5 s. +- The healthz handlers have `r.Context()` but the leader-probe interface + (`LeaderProbe.IsVerifiedLeader() bool`) drops it. Caddy's per-probe + budget cannot reach the verify call. + +A second, smaller hazard lives at `kv/transaction.go:152`: +`proposer.Propose(context.Background(), b)`. Same shape as the original +verifyLeaderEngine bug, just on the propose path instead of the verify +path. + +## Goals + +1. Pass the caller's `context.Context` end-to-end through the kv write + path: dispatch → `Transactional.Commit/Abort` → `TransactionManager` / + `LeaderProxy` → `verifyLeaderEngine` and `proposer.Propose`. +2. Pass the request context through the leader-probe path: HTTP handler → + `LeaderProbe.IsVerifiedLeader(ctx)` → `Coordinate.VerifyLeader(ctx)` / + `ShardedCoordinator.VerifyLeader[ForKey](ctx)` → engine. +3. Keep PR #745's 5 s bound on the **no-ctx** call site (`verifyLeaderEngine()` + with no argument) as defense-in-depth. The bound stays as the safety net for + internal callers that genuinely cannot inherit a deadline (lock resolver, + HLC lease) and for any such caller added later, so the regression cannot + recur. + +## Non-goals + +- Changing the wire-level deadline of any RPC. Existing client deadlines + are preserved unchanged; this PR only stops dropping them. +- Eliminating `verifyLeaderTimeout`. It stays as the no-ctx fallback's + bound. + +## Surface change + +**Interface signatures (kv-internal, no external API):** + +```go +// kv/transaction.go +type Transactional interface { + Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) + Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) +} +``` + +Implementations updated to take `ctx`: + +- `*TransactionManager` — passes ctx into `applyRequests`, which passes + to `proposer.Propose(ctx, …)` (see the sketch after this list). Replaces the + `proposer.Propose(context.Background(), …)` at the existing + `transaction.go:152`. +- `*LeaderProxy` — passes ctx into `verifyLeaderEngineCtx(ctx, …)` and + into `forwardWithRetry(ctx, …)`. The deadline-budget arithmetic in + `forwardWithRetry` already respects the parent ctx, so no logic + change — only the seed parent shifts from `context.Background()` to + the caller's ctx. +- `*leaseRefreshingTxn` — pure pass-through wrapper. +- `*ShardRouter` — pass-through. + +**Caller plumbing:** + +- `ShardedCoordinator.dispatchSingleShardTxn` gains a `ctx` parameter; + the 6 internal callsites of `g.Txn.Commit(...)` plumb ctx in. +- `applyTxnResolution` (`kv/shard_store.go`) gains `ctx`; called from + `LockResolver.resolveExpiredLock`, which already holds a per-cycle + ctx.
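+
+To make the write-path change concrete, here is a minimal sketch of the
+post-change shape. The method names match this PR's surface, but the
+bodies and the `encodeRequests` helper are illustrative stand-ins, not
+the actual implementation:
+
+```go
+// Sketch only: the caller's ctx now rides from dispatch into the raft
+// proposal instead of being replaced by context.Background().
+func (tm *TransactionManager) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
+	return tm.applyRequests(ctx, reqs)
+}
+
+func (tm *TransactionManager) applyRequests(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
+	b, err := encodeRequests(reqs) // hypothetical marshalling helper
+	if err != nil {
+		return nil, err
+	}
+	// Was proposer.Propose(context.Background(), b) at transaction.go:152;
+	// the caller's deadline now bounds the raft proposal round-trip.
+	return tm.proposer.Propose(ctx, b)
+}
+```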
+**Verify-leader surface:** + +```go +func (c *Coordinate) VerifyLeader(ctx context.Context) error +func (c *ShardedCoordinator) VerifyLeader(ctx context.Context) error +func (c *ShardedCoordinator) VerifyLeaderForKey(ctx context.Context, key []byte) error +``` + +The `Coordinator` interface methods themselves take ctx — no +parallel `Ctx`-suffixed variants. The 5 s safety bound is now +internal to `verifyLeaderEngineCtx`: when the caller's ctx has no +deadline (the Redis server's long-lived `handlerContext`, background +loops, …), the helper applies `verifyLeaderTimeout` as a wrapper. +Callers with a tighter deadline keep theirs because +`context.WithTimeout` picks the earlier of the two expirations. + +**LeaderProbe (`internal/admin/router.go`):** + +```go +type LeaderProbe interface { + IsVerifiedLeader(ctx context.Context) bool +} +``` + +The `main_admin.go` implementation calls `coordinate.VerifyLeader(ctx)`. +The admin `/admin/healthz/leader` handler passes `r.Context()`. + +**Adapter healthz (`adapter/s3.go`, `adapter/sqs.go`, `adapter/dynamodb.go`):** + +`isVerifiedS3Leader(ctx)` / `isVerifiedSQSLeader(ctx, coordinator)` / +`isVerifiedDynamoLeader(ctx, coordinator)` take ctx and pass it to +`VerifyLeader(ctx)`. HTTP handlers feed `r.Context()`. + +## Behaviour + +For callers that already had a deadline upstream: + +- A Redis client `BLPOP timeout=2s` whose dispatch lands on a slow + ReadIndex now fails after **2 s** (its own deadline), not 5 s. +- A Caddy active health probe with a 1 s budget likewise fails after + 1 s, not 5 s. + +For internal background callers without an upstream deadline: + +- LockResolver, the HLC lease, etc. continue to hit + `verifyLeaderEngine()` (the no-arg variant), which still wraps with + `context.WithTimeout(context.Background(), verifyLeaderTimeout)`. + PR #745's 5 s bound stays as their safety net. + +For misuse cases: + +- A future code path that calls the engine's `VerifyLeader` directly + with `context.Background()` bypasses both the wrapper and the 5 s + bound; this is the same exposure the ecosystem accepts in general + (passing Background is a code smell, and the linter flags it). Any + caller that reaches `verifyLeaderEngineCtx`, even with a + deadline-less ctx, still gets the `verifyLeaderTimeout` bound via + the fallback wrapper described above. + +## Self-review checklist (kept brief; expanded in the PR body) + +1. Data loss — no proposal-path change beyond ctx; `Propose(ctx, …)` + semantics on cancellation match upstream raftengine, which already + handles `ctx.Err()` as a transient `errProposalCancelled`. +2. Concurrency — ctx is value-passed, not shared mutable state. +3. Performance — no extra round-trip; same number of calls. `WithTimeout` + in the no-ctx wrapper is unchanged. +4. Data consistency — verify is a freshness check, not a write path; + shorter deadlines just surface ErrLeaderNotFound earlier. +5. Test coverage — the interface change ripples through 3 test stubs + (`stubTransactional`, `scriptedTransactional`, `fakeTM`); each + gains a `ctx context.Context` parameter that is currently + unused but available for future tests asserting cancel propagation. + +## Rollout + +Single PR, follow-up to the merged #745 / #746 / #747. No design-deferred +milestones; all four layers (`Transactional`, `Coordinate.VerifyLeader`, +`LeaderProbe`, healthz handlers) ship together because the value only +materializes end-to-end.
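+
+## Appendix: deadline selection
+
+The earlier-deadline behavior leaned on throughout (`context.WithTimeout`
+keeps the sooner of the caller's deadline and the wrapper's) is standard
+library behavior, not something this PR implements. A self-contained
+sketch for reference:
+
+```go
+package main
+
+import (
+	"context"
+	"fmt"
+	"time"
+)
+
+func main() {
+	// A caller that already holds a tight budget (e.g. a 1 s probe).
+	parent, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+
+	// Wrapping with a looser 5 s bound (the verifyLeaderTimeout shape)
+	// cannot extend it: the child keeps the earlier expiration.
+	child, cancel2 := context.WithTimeout(parent, 5*time.Second)
+	defer cancel2()
+
+	d, _ := child.Deadline()
+	fmt.Println(time.Until(d).Round(time.Second)) // prints 1s, not 5s
+}
+```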
diff --git a/internal/admin/router.go b/internal/admin/router.go index 2c86ab93..a7803d86 100644 --- a/internal/admin/router.go +++ b/internal/admin/router.go @@ -1,6 +1,7 @@ package admin import ( + "context" "errors" "io" "io/fs" @@ -59,15 +60,15 @@ const ( // the operational "503 not leader" state. Mirrors the S3/DynamoDB // /healthz/leader contract. type LeaderProbe interface { - IsVerifiedLeader() bool + IsVerifiedLeader(ctx context.Context) bool } // LeaderProbeFunc is a convenience adapter for wiring a plain function // without defining an interface implementation. Mirrors ClusterInfoFunc. -type LeaderProbeFunc func() bool +type LeaderProbeFunc func(ctx context.Context) bool // IsVerifiedLeader implements LeaderProbe. -func (f LeaderProbeFunc) IsVerifiedLeader() bool { return f() } +func (f LeaderProbeFunc) IsVerifiedLeader(ctx context.Context) bool { return f(ctx) } // APIHandler is the bridge between the router and all JSON API endpoints. // Everything under /admin/api/v1/ resolves through it; individual endpoint @@ -263,7 +264,7 @@ func (rt *Router) serveLeaderHealth(w http.ResponseWriter, r *http.Request) { return } status, body := http.StatusOK, "ok\n" - if !rt.leader.IsVerifiedLeader() { + if !rt.leader.IsVerifiedLeader(r.Context()) { status, body = http.StatusServiceUnavailable, "not leader\n" } w.Header().Set("Content-Type", "text/plain; charset=utf-8") diff --git a/internal/admin/router_test.go b/internal/admin/router_test.go index 75518226..d7c2c00d 100644 --- a/internal/admin/router_test.go +++ b/internal/admin/router_test.go @@ -1,6 +1,7 @@ package admin import ( + "context" "net/http" "net/http/httptest" "strings" @@ -71,7 +72,7 @@ func TestRouter_HealthzRejectsPost(t *testing.T) { // /healthz/leader contract so a multi-protocol load balancer // sees identical semantics. func TestRouter_HealthzLeader_ReturnsOKWhenLeader(t *testing.T) { - r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func() bool { return true })) + r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func(context.Context) bool { return true })) req := httptest.NewRequest(http.MethodGet, "/admin/healthz/leader", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) @@ -88,7 +89,7 @@ func TestRouter_HealthzLeader_ReturnsOKWhenLeader(t *testing.T) { // rotation when it loses leadership; the body string is informative // for operators reading curl output. func TestRouter_HealthzLeader_Returns503WhenNotLeader(t *testing.T) { - r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func() bool { return false })) + r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func(context.Context) bool { return false })) req := httptest.NewRequest(http.MethodGet, "/admin/healthz/leader", nil) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) @@ -102,14 +103,14 @@ func TestRouter_HealthzLeader_Returns503WhenNotLeader(t *testing.T) { // healthz HEAD test. The status code must still indicate the // leader state; only the body is suppressed. 
func TestRouter_HealthzLeader_HeadOmitsBody(t *testing.T) { - rLeader := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func() bool { return true })) + rLeader := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func(context.Context) bool { return true })) req := httptest.NewRequest(http.MethodHead, "/admin/healthz/leader", nil) rec := httptest.NewRecorder() rLeader.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "", rec.Body.String()) - rFollower := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func() bool { return false })) + rFollower := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func(context.Context) bool { return false })) req = httptest.NewRequest(http.MethodHead, "/admin/healthz/leader", nil) rec = httptest.NewRecorder() rFollower.ServeHTTP(rec, req) @@ -123,7 +124,7 @@ func TestRouter_HealthzLeader_HeadOmitsBody(t *testing.T) { // §6.5.5 — load balancers and synthetic-monitor tools key off this // header to discover supported verbs. func TestRouter_HealthzLeader_RejectsPost(t *testing.T) { - r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func() bool { return true })) + r := NewRouterWithLeaderProbe(nil, nil, LeaderProbeFunc(func(context.Context) bool { return true })) req := httptest.NewRequest(http.MethodPost, "/admin/healthz/leader", strings.NewReader("")) rec := httptest.NewRecorder() r.ServeHTTP(rec, req) @@ -148,7 +149,7 @@ func TestRouter_405_AllowHeader(t *testing.T) { {"asset", "/admin/assets/app.js"}, {"spa", "/admin/somewhere"}, } - r := NewRouterWithLeaderProbe(nil, newTestStatic(), LeaderProbeFunc(func() bool { return true })) + r := NewRouterWithLeaderProbe(nil, newTestStatic(), LeaderProbeFunc(func(context.Context) bool { return true })) for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -185,7 +186,7 @@ func TestRouter_HealthzLeader_NilProbeReturns404(t *testing.T) { // balancer probing the path to see HTML 200 forever and never // detect a leadership change. 
func TestRouter_HealthzLeader_NotSwallowedBySPA(t *testing.T) { - probe := LeaderProbeFunc(func() bool { return false }) + probe := LeaderProbeFunc(func(context.Context) bool { return false }) r := NewRouterWithLeaderProbe(nil, newTestStatic(), probe) req := httptest.NewRequest(http.MethodGet, "/admin/healthz/leader", nil) rec := httptest.NewRecorder() diff --git a/kv/coordinator.go b/kv/coordinator.go index 6a18899b..752c1c5f 100644 --- a/kv/coordinator.go +++ b/kv/coordinator.go @@ -191,11 +191,11 @@ var _ Coordinator = (*Coordinate)(nil) type Coordinator interface { Dispatch(ctx context.Context, reqs *OperationGroup[OP]) (*CoordinateResponse, error) IsLeader() bool - VerifyLeader() error + VerifyLeader(ctx context.Context) error LinearizableRead(ctx context.Context) (uint64, error) RaftLeader() string IsLeaderForKey(key []byte) bool - VerifyLeaderForKey(key []byte) error + VerifyLeaderForKey(ctx context.Context, key []byte) error RaftLeaderForKey(key []byte) string Clock() *HLC } @@ -462,9 +462,9 @@ func (c *Coordinate) dispatchOnce(ctx context.Context, reqs *OperationGroup[OP]) var resp *CoordinateResponse var err error if reqs.IsTxn { - resp, err = c.dispatchTxn(reqs.Elems, reqs.ReadKeys, reqs.StartTS, reqs.CommitTS) + resp, err = c.dispatchTxn(ctx, reqs.Elems, reqs.ReadKeys, reqs.StartTS, reqs.CommitTS) } else { - resp, err = c.dispatchRaw(reqs.Elems) + resp, err = c.dispatchRaw(ctx, reqs.Elems) } c.refreshLeaseAfterDispatch(resp, err, dispatchStart, expectedGen) return resp, err @@ -613,8 +613,8 @@ func (c *Coordinate) IsLeaderAcceptingWrites() bool { return isLeaderAcceptingWrites(c.engine) } -func (c *Coordinate) VerifyLeader() error { - return verifyLeaderEngine(c.engine) +func (c *Coordinate) VerifyLeader(ctx context.Context) error { + return verifyLeaderEngineCtx(ctx, c.engine) } // RaftLeader returns the current leader's address as known by this node. @@ -673,8 +673,8 @@ func (c *Coordinate) IsLeaderForKey(_ []byte) bool { return c.IsLeader() } -func (c *Coordinate) VerifyLeaderForKey(_ []byte) error { - return c.VerifyLeader() +func (c *Coordinate) VerifyLeaderForKey(ctx context.Context, _ []byte) error { + return c.VerifyLeader(ctx) } func (c *Coordinate) RaftLeaderForKey(_ []byte) string { @@ -798,7 +798,7 @@ func (c *Coordinate) nextStartTS() uint64 { return c.clock.Next() } -func (c *Coordinate) dispatchTxn(reqs []*Elem[OP], readKeys [][]byte, startTS uint64, commitTS uint64) (*CoordinateResponse, error) { +func (c *Coordinate) dispatchTxn(ctx context.Context, reqs []*Elem[OP], readKeys [][]byte, startTS uint64, commitTS uint64) (*CoordinateResponse, error) { if len(readKeys) > maxReadKeys { return nil, errors.WithStack(ErrInvalidRequest) } @@ -828,7 +828,7 @@ func (c *Coordinate) dispatchTxn(reqs []*Elem[OP], readKeys [][]byte, startTS ui // and FSM application. The adapter's validateReadSet is kept as a fast // path to fail early without a Raft round-trip, but the FSM check is // the authoritative, serializable validation. 
- r, err := c.transactionManager.Commit([]*pb.Request{ + r, err := c.transactionManager.Commit(ctx, []*pb.Request{ onePhaseTxnRequest(startTS, commitTS, primary, reqs, readKeys), }) if err != nil { @@ -840,7 +840,7 @@ func (c *Coordinate) dispatchTxn(reqs []*Elem[OP], readKeys [][]byte, startTS ui }, nil } -func (c *Coordinate) dispatchRaw(req []*Elem[OP]) (*CoordinateResponse, error) { +func (c *Coordinate) dispatchRaw(ctx context.Context, req []*Elem[OP]) (*CoordinateResponse, error) { muts := make([]*pb.Mutation, 0, len(req)) for _, elem := range req { muts = append(muts, elemToMutation(elem)) @@ -853,7 +853,7 @@ func (c *Coordinate) dispatchRaw(req []*Elem[OP]) (*CoordinateResponse, error) { Mutations: muts, }} - r, err := c.transactionManager.Commit(logs) + r, err := c.transactionManager.Commit(ctx, logs) if err != nil { return nil, errors.WithStack(err) } diff --git a/kv/coordinator_leader_test.go b/kv/coordinator_leader_test.go index 1df197a1..05e0cce9 100644 --- a/kv/coordinator_leader_test.go +++ b/kv/coordinator_leader_test.go @@ -1,6 +1,7 @@ package kv import ( + "context" "testing" "github.com/bootjp/elastickv/store" @@ -15,12 +16,12 @@ func TestCoordinateVerifyLeader_LeaderReturnsNil(t *testing.T) { t.Cleanup(stop) c := NewCoordinatorWithEngine(&stubTransactional{}, r) - require.NoError(t, c.VerifyLeader()) + require.NoError(t, c.VerifyLeader(context.Background())) } func TestCoordinateVerifyLeader_NilRaftReturnsLeaderNotFound(t *testing.T) { t.Parallel() c := &Coordinate{} - require.ErrorIs(t, c.VerifyLeader(), ErrLeaderNotFound) + require.ErrorIs(t, c.VerifyLeader(context.Background()), ErrLeaderNotFound) } diff --git a/kv/coordinator_retry_test.go b/kv/coordinator_retry_test.go index 3378fa5e..3d963708 100644 --- a/kv/coordinator_retry_test.go +++ b/kv/coordinator_retry_test.go @@ -54,7 +54,7 @@ type scriptedTransactional struct { onCommit func(call uint64) // optional hook invoked inside Commit } -func (s *scriptedTransactional) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (s *scriptedTransactional) Commit(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) { idx := s.commits.Add(1) - 1 s.reqs = append(s.reqs, reqs) if s.onCommit != nil { @@ -66,7 +66,7 @@ func (s *scriptedTransactional) Commit(reqs []*pb.Request) (*TransactionResponse return &TransactionResponse{CommitIndex: idx + 1}, nil } -func (s *scriptedTransactional) Abort([]*pb.Request) (*TransactionResponse, error) { +func (s *scriptedTransactional) Abort(context.Context, []*pb.Request) (*TransactionResponse, error) { return &TransactionResponse{}, nil } diff --git a/kv/coordinator_txn_test.go b/kv/coordinator_txn_test.go index 4385c878..552fb495 100644 --- a/kv/coordinator_txn_test.go +++ b/kv/coordinator_txn_test.go @@ -1,6 +1,7 @@ package kv import ( + "context" "testing" pb "github.com/bootjp/elastickv/proto" @@ -12,13 +13,13 @@ type stubTransactional struct { reqs [][]*pb.Request } -func (s *stubTransactional) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (s *stubTransactional) Commit(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) { s.commits++ s.reqs = append(s.reqs, reqs) return &TransactionResponse{}, nil } -func (s *stubTransactional) Abort(_ []*pb.Request) (*TransactionResponse, error) { +func (s *stubTransactional) Abort(_ context.Context, _ []*pb.Request) (*TransactionResponse, error) { return &TransactionResponse{}, nil } @@ -34,7 +35,7 @@ func TestCoordinateDispatchTxn_RejectsNonMonotonicCommitTS(t *testing.T) { startTS := 
^uint64(0) c.clock.Observe(startTS) - _, err := c.dispatchTxn([]*Elem[OP]{ + _, err := c.dispatchTxn(context.Background(), []*Elem[OP]{ {Op: Put, Key: []byte("k"), Value: []byte("v")}, }, nil, startTS, 0) require.ErrorIs(t, err, ErrTxnCommitTSRequired) @@ -50,7 +51,7 @@ func TestCoordinateDispatchTxn_RejectsMissingPrimaryKey(t *testing.T) { clock: NewHLC(), } - _, err := c.dispatchTxn([]*Elem[OP]{ + _, err := c.dispatchTxn(context.Background(), []*Elem[OP]{ {Op: Put, Key: nil, Value: []byte("v")}, }, nil, 1, 0) require.ErrorIs(t, err, ErrTxnPrimaryKeyRequired) @@ -67,7 +68,7 @@ func TestCoordinateDispatchTxn_UsesOnePhaseRequest(t *testing.T) { } startTS := uint64(10) - _, err := c.dispatchTxn([]*Elem[OP]{ + _, err := c.dispatchTxn(context.Background(), []*Elem[OP]{ {Op: Put, Key: []byte("b"), Value: []byte("v1")}, {Op: Del, Key: []byte("x")}, }, nil, startTS, 0) @@ -105,7 +106,7 @@ func TestCoordinateDispatchTxn_UsesProvidedCommitTS(t *testing.T) { startTS := uint64(10) commitTS := uint64(25) - _, err := c.dispatchTxn([]*Elem[OP]{ + _, err := c.dispatchTxn(context.Background(), []*Elem[OP]{ {Op: Put, Key: []byte("k"), Value: []byte("v")}, }, nil, startTS, commitTS) require.NoError(t, err) @@ -127,7 +128,7 @@ func TestCoordinateDispatchTxn_PassesReadKeysToRaftEntry(t *testing.T) { } readKeys := [][]byte{[]byte("rk1"), []byte("rk2")} - _, err := c.dispatchTxn([]*Elem[OP]{ + _, err := c.dispatchTxn(context.Background(), []*Elem[OP]{ {Op: Put, Key: []byte("k"), Value: []byte("v")}, }, readKeys, 10, 0) require.NoError(t, err) diff --git a/kv/leader_proxy.go b/kv/leader_proxy.go index 53b81210..2428c1bf 100644 --- a/kv/leader_proxy.go +++ b/kv/leader_proxy.go @@ -39,26 +39,30 @@ func NewLeaderProxyWithEngine(engine raftengine.Engine, opts ...TransactionOptio } } -func (p *LeaderProxy) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (p *LeaderProxy) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) { if !isLeaderEngine(p.engine) { - return p.forwardWithRetry(reqs) + return p.forwardWithRetry(ctx, reqs) } // Verify leadership with a quorum to avoid accepting writes on a stale leader. - if err := verifyLeaderEngine(p.engine); err != nil { - return p.forwardWithRetry(reqs) - } - return p.tm.Commit(reqs) + // The caller's ctx (via verifyLeaderEngineCtx) bounds the ReadIndex + // round-trip; verifyLeaderEngine's no-arg variant remains as the + // background-caller fallback (#745) but is no longer hit on the + // dispatch hot path. + if err := verifyLeaderEngineCtx(ctx, p.engine); err != nil { + return p.forwardWithRetry(ctx, reqs) + } + return p.tm.Commit(ctx, reqs) } -func (p *LeaderProxy) Abort(reqs []*pb.Request) (*TransactionResponse, error) { +func (p *LeaderProxy) Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) { if !isLeaderEngine(p.engine) { - return p.forwardWithRetry(reqs) + return p.forwardWithRetry(ctx, reqs) } // Verify leadership with a quorum to avoid accepting aborts on a stale leader. 
- if err := verifyLeaderEngine(p.engine); err != nil { - return p.forwardWithRetry(reqs) + if err := verifyLeaderEngineCtx(ctx, p.engine); err != nil { + return p.forwardWithRetry(ctx, reqs) } - return p.tm.Abort(reqs) + return p.tm.Abort(ctx, reqs) } // forwardWithRetry attempts to forward to the leader, re-fetching the @@ -86,7 +90,7 @@ func (p *LeaderProxy) Abort(reqs []*pb.Request) (*TransactionResponse, error) { // that second bound, a single forward() could run for the full 5s RPC // timeout AFTER the budget expired, pushing total latency well past // leaderProxyRetryBudget. -func (p *LeaderProxy) forwardWithRetry(reqs []*pb.Request) (*TransactionResponse, error) { +func (p *LeaderProxy) forwardWithRetry(callerCtx context.Context, reqs []*pb.Request) (*TransactionResponse, error) { if len(reqs) == 0 { return &TransactionResponse{}, nil } @@ -95,8 +99,10 @@ func (p *LeaderProxy) forwardWithRetry(reqs []*pb.Request) (*TransactionResponse // Parent context carries the retry deadline so forward()'s per-call // timeout (derived via context.WithTimeout(parentCtx, ...)) can // never extend past it — context.WithTimeout picks the earlier of - // the two expirations. - parentCtx, cancelParent := context.WithDeadline(context.Background(), deadline) + // the two expirations. callerCtx is the dispatch handler's own + // context, so a Redis client whose deadline expires before the + // retry budget exits early without waiting out the full 5 s. + parentCtx, cancelParent := context.WithDeadline(callerCtx, deadline) defer cancelParent() var lastErr error diff --git a/kv/leader_proxy_test.go b/kv/leader_proxy_test.go index 92a26360..84112747 100644 --- a/kv/leader_proxy_test.go +++ b/kv/leader_proxy_test.go @@ -81,7 +81,7 @@ func TestLeaderProxy_CommitLocalWhenLeader(t *testing.T) { }, }, } - resp, err := p.Commit(reqs) + resp, err := p.Commit(context.Background(), reqs) require.NoError(t, err) require.NotNil(t, resp) require.Greater(t, resp.CommitIndex, uint64(0)) @@ -135,7 +135,7 @@ func TestLeaderProxy_ForwardsWhenFollower(t *testing.T) { }, } - resp, err := p.Commit(reqs) + resp, err := p.Commit(context.Background(), reqs) require.NoError(t, err) require.Equal(t, uint64(123), resp.CommitIndex) @@ -245,7 +245,7 @@ func TestLeaderProxy_ForwardsAfterLeaderPublishes(t *testing.T) { eng.setLeader(lis.Addr().String()) }() - resp, err := p.Commit(reqs) + resp, err := p.Commit(context.Background(), reqs) elapsed := time.Since(start) require.NoError(t, err) require.Equal(t, uint64(42), resp.CommitIndex) @@ -289,7 +289,7 @@ func TestLeaderProxy_FailsAfterLeaderBudgetElapses(t *testing.T) { } start := time.Now() - _, err := p.Commit(reqs) + _, err := p.Commit(context.Background(), reqs) elapsed := time.Since(start) require.Error(t, err) require.ErrorIs(t, err, ErrLeaderNotFound) diff --git a/kv/leader_routed_store.go b/kv/leader_routed_store.go index 9ecc8bfa..a97595d9 100644 --- a/kv/leader_routed_store.go +++ b/kv/leader_routed_store.go @@ -53,7 +53,7 @@ func (s *LeaderRoutedStore) leaderFenceTS(ctx context.Context, key []byte) (bool if !s.coordinator.IsLeaderForKey(key) { return false, 0 } - return s.coordinator.VerifyLeaderForKey(key) == nil, 0 + return s.coordinator.VerifyLeaderForKey(ctx, key) == nil, 0 } // leaderOKForKey returns whether the local store is authoritative for key. 
diff --git a/kv/leader_routed_store_test.go b/kv/leader_routed_store_test.go index b30a9a44..1d83ead8 100644 --- a/kv/leader_routed_store_test.go +++ b/kv/leader_routed_store_test.go @@ -29,7 +29,7 @@ func (s *stubLeaderCoordinator) IsLeader() bool { return s.isLeader } -func (s *stubLeaderCoordinator) VerifyLeader() error { +func (s *stubLeaderCoordinator) VerifyLeader(context.Context) error { return s.verify } @@ -41,7 +41,7 @@ func (s *stubLeaderCoordinator) IsLeaderForKey([]byte) bool { return s.isLeader } -func (s *stubLeaderCoordinator) VerifyLeaderForKey([]byte) error { +func (s *stubLeaderCoordinator) VerifyLeaderForKey(context.Context, []byte) error { return s.verify } diff --git a/kv/lock_resolver.go b/kv/lock_resolver.go index 015d343a..d5a02599 100644 --- a/kv/lock_resolver.go +++ b/kv/lock_resolver.go @@ -155,13 +155,13 @@ func (lr *LockResolver) resolveExpiredLock(ctx context.Context, g *ShardGroup, u switch status { case txnStatusCommitted: - return applyTxnResolution(g, pb.Phase_COMMIT, lock.StartTS, commitTS, lock.PrimaryKey, [][]byte{userKey}) + return applyTxnResolution(ctx, g, pb.Phase_COMMIT, lock.StartTS, commitTS, lock.PrimaryKey, [][]byte{userKey}) case txnStatusRolledBack: abortTS := abortTSFrom(lock.StartTS, commitTS) if abortTS <= lock.StartTS { return nil // cannot represent abort timestamp, skip } - return applyTxnResolution(g, pb.Phase_ABORT, lock.StartTS, abortTS, lock.PrimaryKey, [][]byte{userKey}) + return applyTxnResolution(ctx, g, pb.Phase_ABORT, lock.StartTS, abortTS, lock.PrimaryKey, [][]byte{userKey}) case txnStatusPending: // Lock is expired but primary is still pending — the primary's // tryAbortExpiredPrimary inside primaryTxnStatus should have diff --git a/kv/lock_resolver_test.go b/kv/lock_resolver_test.go index 9d37771c..33e26606 100644 --- a/kv/lock_resolver_test.go +++ b/kv/lock_resolver_test.go @@ -44,7 +44,7 @@ func setupLockResolverEnv(t *testing.T) (*LockResolver, *ShardStore, map[uint64] // prepareLock writes a PREPARE request (which creates a lock) for a key. func prepareLock(t *testing.T, g *ShardGroup, startTS uint64, key, primaryKey, value []byte, lockTTLms uint64) { t.Helper() - _, err := g.Txn.Commit([]*pb.Request{{ + _, err := g.Txn.Commit(context.Background(), []*pb.Request{{ IsTxn: true, Phase: pb.Phase_PREPARE, Ts: startTS, @@ -63,7 +63,7 @@ func prepareLock(t *testing.T, g *ShardGroup, startTS uint64, key, primaryKey, v // commitPrimary writes a COMMIT record for a transaction's primary key. func commitPrimary(t *testing.T, g *ShardGroup, startTS, commitTS uint64, primaryKey []byte) { t.Helper() - _, err := g.Txn.Commit([]*pb.Request{{ + _, err := g.Txn.Commit(context.Background(), []*pb.Request{{ IsTxn: true, Phase: pb.Phase_COMMIT, Ts: startTS, @@ -156,7 +156,7 @@ func TestLockResolver_ResolvesExpiredRolledBackLock(t *testing.T) { prepareLock(t, groups[2], startTS, secondaryKey, primaryKey, []byte("v2"), 0) // Abort the primary. 
- _, err := groups[1].Txn.Commit([]*pb.Request{{ + _, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{{ IsTxn: true, Phase: pb.Phase_ABORT, Ts: startTS, diff --git a/kv/raft_engine.go b/kv/raft_engine.go index f896dc40..4eedebe8 100644 --- a/kv/raft_engine.go +++ b/kv/raft_engine.go @@ -55,6 +55,20 @@ func verifyLeaderEngineCtx(ctx context.Context, engine raftengine.LeaderView) er if engine == nil { return errors.WithStack(ErrLeaderNotFound) } + // Defense-in-depth: if the caller's context carries no deadline (the + // Redis server's long-lived handlerContext, gemini-flagged background + // loops, or any future caller that passes context.Background), wrap + // with verifyLeaderTimeout so a stalled ReadIndex still surfaces + // fail-fast — same bound the no-arg verifyLeaderEngine wrapper has + // provided since #745. Callers that already set a tighter deadline + // (Redis dispatch ctx with redisDispatchTimeout, healthz + // r.Context()) keep theirs: context.WithTimeout picks the earlier of + // the two expirations. + if _, ok := ctx.Deadline(); !ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, verifyLeaderTimeout) + defer cancel() + } return errors.WithStack(engine.VerifyLeader(ctx)) } diff --git a/kv/shard_router.go b/kv/shard_router.go index deaba662..4f6cd094 100644 --- a/kv/shard_router.go +++ b/kv/shard_router.go @@ -152,16 +152,16 @@ func (s *ShardRouter) Register(group uint64, tm Transactional, st store.MVCCStor s.groups[group] = &routerGroup{tm: tm, store: st} } -func (s *ShardRouter) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (s *ShardRouter) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) { return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) { - return g.tm.Commit(rs) + return g.tm.Commit(ctx, rs) }) } // Abort dispatches aborts to the correct raft group. -func (s *ShardRouter) Abort(reqs []*pb.Request) (*TransactionResponse, error) { +func (s *ShardRouter) Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) { return s.process(reqs, func(g *routerGroup, rs []*pb.Request) (*TransactionResponse, error) { - return g.tm.Abort(rs) + return g.tm.Abort(ctx, rs) }) } diff --git a/kv/shard_router_partition_test.go b/kv/shard_router_partition_test.go index c5fd86fa..da24ef38 100644 --- a/kv/shard_router_partition_test.go +++ b/kv/shard_router_partition_test.go @@ -69,7 +69,7 @@ func TestShardRouter_PartitionResolverWins(t *testing.T) { reqs := []*pb.Request{ {IsTxn: false, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("resolver-key"), Value: []byte("v")}}}, } - resp, err := router.Commit(reqs) + resp, err := router.Commit(context.Background(), reqs) require.NoError(t, err) require.NotNil(t, resp) // Verify: the request landed on group 42's fake txn, not 1's. @@ -102,7 +102,7 @@ func TestShardRouter_PartitionResolverFallsThrough(t *testing.T) { router.Register(2, &fakeTxn{id: 2, sink: &sink}, s2) // "b" is in the engine's [a, m) range → group 1. - resp1, err1 := router.Commit([]*pb.Request{ + resp1, err1 := router.Commit(context.Background(), []*pb.Request{ {IsTxn: false, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v")}}}, }) require.NoError(t, err1) @@ -111,7 +111,7 @@ func TestShardRouter_PartitionResolverFallsThrough(t *testing.T) { "engine [a,m) range must route to group 1") // "x" is in the engine's [m, ∞) range → group 2. 
- resp2, err2 := router.Commit([]*pb.Request{ + resp2, err2 := router.Commit(context.Background(), []*pb.Request{ {IsTxn: false, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v")}}}, }) require.NoError(t, err2) @@ -138,7 +138,7 @@ func TestShardRouter_NilPartitionResolverIsNoOp(t *testing.T) { // With no resolver installed, the engine's default route owns // the request — group 7 dispatches. - resp, err := router.Commit([]*pb.Request{ + resp, err := router.Commit(context.Background(), []*pb.Request{ {IsTxn: false, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("any"), Value: []byte("v")}}}, }) require.NoError(t, err) @@ -173,7 +173,7 @@ func TestShardRouter_ResolverSeesRawKeyNotNormalized(t *testing.T) { router.Register(1, &fakeTxn{id: 1, sink: &sink}, store.NewMVCCStore()) router.Register(42, &fakeTxn{id: 42, sink: &sink}, store.NewMVCCStore()) - resp, err := router.Commit([]*pb.Request{ + resp, err := router.Commit(context.Background(), []*pb.Request{ {IsTxn: false, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: rawKey, Value: []byte("v")}}}, }) require.NoError(t, err) @@ -307,13 +307,15 @@ type fakeTxn struct { sink *atomic.Uint64 } -func (f *fakeTxn) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (f *fakeTxn) Commit(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) { + _ = reqs if f.sink != nil { f.sink.Store(f.id) } return &TransactionResponse{CommitIndex: 1}, nil } -func (f *fakeTxn) Abort(reqs []*pb.Request) (*TransactionResponse, error) { +func (f *fakeTxn) Abort(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) { + _ = reqs return &TransactionResponse{}, nil } diff --git a/kv/shard_router_test.go b/kv/shard_router_test.go index 0de6c0ce..b288364c 100644 --- a/kv/shard_router_test.go +++ b/kv/shard_router_test.go @@ -45,7 +45,7 @@ func TestShardRouterCommit(t *testing.T) { {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}}, } - if _, err := router.Commit(reqs); err != nil { + if _, err := router.Commit(ctx, reqs); err != nil { t.Fatalf("commit: %v", err) } @@ -84,7 +84,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) { req := []*pb.Request{ {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("b"), Value: []byte("v1")}}}, } - if _, err := router.Commit(req); err != nil { + if _, err := router.Commit(ctx, req); err != nil { t.Fatalf("commit group1: %v", err) } v, err := router.Get(ctx, []byte("b")) @@ -102,7 +102,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) { req = []*pb.Request{ {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}}, } - if _, err := router.Commit(req); err != nil { + if _, err := router.Commit(ctx, req); err != nil { t.Fatalf("commit group2: %v", err) } v, err = router.Get(ctx, []byte("x")) @@ -119,7 +119,7 @@ func TestShardRouterSplitAndMerge(t *testing.T) { req = []*pb.Request{ {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("z"), Value: []byte("v3")}}}, } - if _, err := router.Commit(req); err != nil { + if _, err := router.Commit(ctx, req); err != nil { t.Fatalf("commit after merge: %v", err) } v, err = router.Get(ctx, []byte("z")) @@ -134,7 +134,8 @@ type fakeTM struct { abortCalls int } -func (f *fakeTM) Commit(reqs []*pb.Request) (*TransactionResponse, error) { +func (f *fakeTM) Commit(_ context.Context, reqs []*pb.Request) 
(*TransactionResponse, error) { + _ = reqs f.commitCalls++ if f.commitErr { return nil, fmt.Errorf("commit fail") @@ -142,12 +143,14 @@ func (f *fakeTM) Commit(reqs []*pb.Request) (*TransactionResponse, error) { return &TransactionResponse{}, nil } -func (f *fakeTM) Abort(reqs []*pb.Request) (*TransactionResponse, error) { +func (f *fakeTM) Abort(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) { + _ = reqs f.abortCalls++ return &TransactionResponse{}, nil } func TestShardRouterCommitFailure(t *testing.T) { + ctx := context.Background() e := distribution.NewEngine() e.UpdateRoute([]byte("a"), []byte("m"), 1) e.UpdateRoute([]byte("m"), nil, 2) @@ -164,7 +167,7 @@ func TestShardRouterCommitFailure(t *testing.T) { {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("x"), Value: []byte("v2")}}}, } - if _, err := router.Commit(reqs); err == nil { + if _, err := router.Commit(ctx, reqs); err == nil { t.Fatalf("expected error") } @@ -178,6 +181,7 @@ func TestShardRouterCommitFailure(t *testing.T) { } func TestShardRouterRoutesListKeys(t *testing.T) { + ctx := context.Background() e := distribution.NewEngine() e.UpdateRoute([]byte("a"), []byte("m"), 1) e.UpdateRoute([]byte("m"), nil, 2) @@ -194,7 +198,7 @@ func TestShardRouterRoutesListKeys(t *testing.T) { {IsTxn: false, Phase: pb.Phase_NONE, Mutations: []*pb.Mutation{{Op: pb.Op_PUT, Key: listMetaKey, Value: []byte("v")}}}, } - if _, err := router.Commit(reqs); err != nil { + if _, err := router.Commit(ctx, reqs); err != nil { t.Fatalf("commit: %v", err) } if ok.commitCalls != 1 { diff --git a/kv/shard_store.go b/kv/shard_store.go index 419c028e..34ceb33e 100644 --- a/kv/shard_store.go +++ b/kv/shard_store.go @@ -578,7 +578,7 @@ func (s *ShardStore) resolveTxnLockForKey(ctx context.Context, g *ShardGroup, ke } switch status { case txnStatusCommitted: - return applyTxnResolution(g, pb.Phase_COMMIT, lock.StartTS, commitTS, lock.PrimaryKey, [][]byte{key}) + return applyTxnResolution(ctx, g, pb.Phase_COMMIT, lock.StartTS, commitTS, lock.PrimaryKey, [][]byte{key}) case txnStatusRolledBack: abortTS := abortTSFrom(lock.StartTS, commitTS) if abortTS <= lock.StartTS { @@ -587,7 +587,7 @@ func (s *ShardStore) resolveTxnLockForKey(ctx context.Context, g *ShardGroup, ke // Prevents violating the FSM invariant resolveTS > startTS (fsm.go:258). 
return NewTxnLockedErrorWithDetail(key, "timestamp overflow") } - return applyTxnResolution(g, pb.Phase_ABORT, lock.StartTS, abortTS, lock.PrimaryKey, [][]byte{key}) + return applyTxnResolution(ctx, g, pb.Phase_ABORT, lock.StartTS, abortTS, lock.PrimaryKey, [][]byte{key}) case txnStatusPending: return NewTxnLockedError(key) default: @@ -652,7 +652,7 @@ func (s *ShardStore) resolveScanLocks(ctx context.Context, g *ShardGroup, kvs [] if err != nil { return nil, err } - if err := applyScanLockResolutions(g, plan); err != nil { + if err := applyScanLockResolutions(ctx, g, plan); err != nil { return nil, err } return s.materializeScanLockResults(ctx, g, ts, plan.items) @@ -902,10 +902,10 @@ func prefixScanEnd(prefix []byte) []byte { return nil } -func applyScanLockResolutions(g *ShardGroup, plan *scanLockPlan) error { +func applyScanLockResolutions(ctx context.Context, g *ShardGroup, plan *scanLockPlan) error { for _, txnKey := range plan.batchOrder { batch := plan.resolutionBatches[txnKey] - if err := applyTxnResolution(g, batch.phase, batch.startTS, batch.resolveTS, batch.primaryKey, batch.keys); err != nil { + if err := applyTxnResolution(ctx, g, batch.phase, batch.startTS, batch.resolveTS, batch.primaryKey, batch.keys); err != nil { return err } } @@ -1013,7 +1013,7 @@ func txnLockExpired(lock txnLock) bool { } func (s *ShardStore) expiredPrimaryTxnStatus(ctx context.Context, primaryKey []byte, startTS uint64) (txnStatus, uint64, error) { - aborted, err := s.tryAbortExpiredPrimary(primaryKey, startTS) + aborted, err := s.tryAbortExpiredPrimary(ctx, primaryKey, startTS) if err != nil { return s.statusAfterAbortFailure(ctx, primaryKey, startTS) } @@ -1075,7 +1075,7 @@ func (s *ShardStore) loadTxnLock(ctx context.Context, primaryKey []byte) (txnLoc return lock, true, nil } -func (s *ShardStore) tryAbortExpiredPrimary(primaryKey []byte, startTS uint64) (bool, error) { +func (s *ShardStore) tryAbortExpiredPrimary(ctx context.Context, primaryKey []byte, startTS uint64) (bool, error) { pg, ok := s.groupForKey(primaryKey) if !ok || pg == nil || pg.Txn == nil { return false, nil @@ -1090,13 +1090,13 @@ func (s *ShardStore) tryAbortExpiredPrimary(primaryKey []byte, startTS uint64) ( // Prevents violating the FSM invariant resolveTS > startTS (fsm.go:258). 
return false, nil } - if err := applyTxnResolution(pg, pb.Phase_ABORT, startTS, abortTS, primaryKey, [][]byte{primaryKey}); err != nil { + if err := applyTxnResolution(ctx, pg, pb.Phase_ABORT, startTS, abortTS, primaryKey, [][]byte{primaryKey}); err != nil { return false, err } return true, nil } -func applyTxnResolution(g *ShardGroup, phase pb.Phase, startTS, commitTS uint64, primaryKey []byte, keys [][]byte) error { +func applyTxnResolution(ctx context.Context, g *ShardGroup, phase pb.Phase, startTS, commitTS uint64, primaryKey []byte, keys [][]byte) error { if g == nil || g.Txn == nil { return errors.WithStack(store.ErrNotSupported) } @@ -1110,7 +1110,7 @@ func applyTxnResolution(g *ShardGroup, phase pb.Phase, startTS, commitTS uint64, for _, k := range keys { muts = append(muts, &pb.Mutation{Op: pb.Op_PUT, Key: k}) } - _, err := g.Txn.Commit([]*pb.Request{{IsTxn: true, Phase: phase, Ts: startTS, Mutations: muts}}) + _, err := g.Txn.Commit(ctx, []*pb.Request{{IsTxn: true, Phase: phase, Ts: startTS, Mutations: muts}}) return errors.WithStack(err) } diff --git a/kv/shard_store_txn_lock_test.go b/kv/shard_store_txn_lock_test.go index e5654645..8c50ff3a 100644 --- a/kv/shard_store_txn_lock_test.go +++ b/kv/shard_store_txn_lock_test.go @@ -105,7 +105,7 @@ func TestShardStoreGetAt_ReturnsTxnLockedForPendingLock(t *testing.T) { startTS := uint64(1) key := []byte("k") - _, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)}) + _, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)}) require.NoError(t, err) _, err = shardStore.GetAt(ctx, key, ^uint64(0)) @@ -125,9 +125,9 @@ func TestShardStoreGetAt_ReturnsTxnLockedForPendingCrossShardTxn(t *testing.T) { primaryKey := []byte("b") secondaryKey := []byte("x") - _, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)}) + _, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)}) require.NoError(t, err) - _, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)}) + _, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)}) require.NoError(t, err) _, err = shardStore.GetAt(ctx, primaryKey, ^uint64(0)) @@ -153,9 +153,9 @@ func TestShardStoreGetAt_ResolvesCommittedSecondaryLock(t *testing.T) { primaryKey := []byte("b") // group 1 secondaryKey := []byte("x") - _, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)}) + _, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)}) require.NoError(t, err) - _, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)}) + _, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)}) require.NoError(t, err) commitPrimary := &pb.Request{ @@ -167,7 +167,7 @@ func TestShardStoreGetAt_ResolvesCommittedSecondaryLock(t *testing.T) { {Op: pb.Op_PUT, Key: primaryKey}, }, } - _, err = groups[1].Txn.Commit([]*pb.Request{commitPrimary}) + _, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitPrimary}) require.NoError(t, err) // Reading the secondary key should resolve it based on the primary 
diff --git a/kv/shard_store_txn_lock_test.go b/kv/shard_store_txn_lock_test.go
index e5654645..8c50ff3a 100644
--- a/kv/shard_store_txn_lock_test.go
+++ b/kv/shard_store_txn_lock_test.go
@@ -105,7 +105,7 @@ func TestShardStoreGetAt_ReturnsTxnLockedForPendingLock(t *testing.T) {
 	startTS := uint64(1)
 	key := []byte("k")
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
 	require.NoError(t, err)
 
 	_, err = shardStore.GetAt(ctx, key, ^uint64(0))
@@ -125,9 +125,9 @@ func TestShardStoreGetAt_ReturnsTxnLockedForPendingCrossShardTxn(t *testing.T) {
 	primaryKey := []byte("b")
 	secondaryKey := []byte("x")
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
 	require.NoError(t, err)
-	_, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
+	_, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
 	require.NoError(t, err)
 
 	_, err = shardStore.GetAt(ctx, primaryKey, ^uint64(0))
@@ -153,9 +153,9 @@ func TestShardStoreGetAt_ResolvesCommittedSecondaryLock(t *testing.T) {
 	primaryKey := []byte("b") // group 1
 	secondaryKey := []byte("x")
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
 	require.NoError(t, err)
-	_, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
+	_, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
 	require.NoError(t, err)
 
 	commitPrimary := &pb.Request{
@@ -167,7 +167,7 @@ func TestShardStoreGetAt_ResolvesCommittedSecondaryLock(t *testing.T) {
 			{Op: pb.Op_PUT, Key: primaryKey},
 		},
 	}
-	_, err = groups[1].Txn.Commit([]*pb.Request{commitPrimary})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitPrimary})
 	require.NoError(t, err)
 
 	// Reading the secondary key should resolve it based on the primary
 	// commit record.
@@ -189,9 +189,9 @@ func TestShardStoreScanAt_ResolvesCommittedCrossShardTxn(t *testing.T) {
 	primaryKey := []byte("b")
 	secondaryKey := []byte("x")
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
 	require.NoError(t, err)
-	_, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
+	_, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
 	require.NoError(t, err)
 
 	commitPrimary := &pb.Request{
@@ -203,7 +203,7 @@ func TestShardStoreScanAt_ResolvesCommittedCrossShardTxn(t *testing.T) {
 			{Op: pb.Op_PUT, Key: primaryKey},
 		},
 	}
-	_, err = groups[1].Txn.Commit([]*pb.Request{commitPrimary})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitPrimary})
 	require.NoError(t, err)
 
 	kvs, err := shardStore.ScanAt(ctx, []byte("a"), []byte("z"), 100, commitTS)
@@ -243,7 +243,7 @@ func TestShardStoreScanAt_ReturnsTxnLockedForPendingLock(t *testing.T) {
 	require.NoError(t, st1.PutAt(ctx, key, []byte("old"), 1, 0))
 
 	startTS := uint64(2)
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
 	require.NoError(t, err)
 
 	_, err = shardStore.ScanAt(ctx, []byte(""), nil, 100, ^uint64(0))
@@ -271,7 +271,7 @@ func TestShardStoreScanAt_ReturnsTxnLockedForPendingLockWithoutCommittedValue(t
 	key := []byte("k")
 	startTS := uint64(1)
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, key, []byte("v"), key)})
 	require.NoError(t, err)
 
 	// User-key range does not include raw !txn|lock|... keys, so lock-only
@@ -316,7 +316,7 @@ func TestShardStoreScanAt_ReturnsTxnLockedWhenPendingLockExceedsUserLimit(t *tes
 			{Op: pb.Op_PUT, Key: committedSecondary, Value: []byte("va")},
 		},
 	}
-	_, err := groups[1].Txn.Commit([]*pb.Request{prepareCommitted})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{prepareCommitted})
 	require.NoError(t, err)
 
 	commitCommittedPrimary := &pb.Request{
 		IsTxn: true,
@@ -327,13 +327,13 @@ func TestShardStoreScanAt_ReturnsTxnLockedWhenPendingLockExceedsUserLimit(t *tes
 			{Op: pb.Op_PUT, Key: committedPrimary},
 		},
 	}
-	_, err = groups[1].Txn.Commit([]*pb.Request{commitCommittedPrimary})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitCommittedPrimary})
 	require.NoError(t, err)
 
 	// Create a later pending lock-only write that must block the scan.
 	pendingPrimary := []byte("b")
 	pendingStartTS := uint64(4)
-	_, err = groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(pendingStartTS, pendingPrimary, []byte("vb"), pendingPrimary)})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(pendingStartTS, pendingPrimary, []byte("vb"), pendingPrimary)})
 	require.NoError(t, err)
 
 	// limit=1 should not hide pending locks after one resolved lock.
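These tests keep rebuilding the same prepare-then-commit-primary request pair, with the struct literals folded at the hunk boundaries. A test-style helper sketch of that shape, restricted to the fields the visible hunks actually use (any commit-timestamp metadata the real literals carry is elided here, since the hunks do not show it; the helper name is hypothetical):

```go
// Hypothetical test helper, not in the repo: the COMMIT request sent
// to the primary shard group after a prepare.
func makeCommitPrimaryRequest(startTS uint64, primaryKey []byte) *pb.Request {
	return &pb.Request{
		IsTxn: true,
		Phase: pb.Phase_COMMIT,
		Ts:    startTS,
		Mutations: []*pb.Mutation{
			{Op: pb.Op_PUT, Key: primaryKey},
		},
	}
}
```

Usage mirrors the tests above: prepare on each shard group, Commit this request on the primary's group, then read the secondary key and let resolution consult the primary record.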
@@ -359,7 +359,7 @@ func TestShardStoreScanAt_ResolvesCommittedSecondaryLocks(t *testing.T) {
 	require.NoError(t, groups[2].Store.PutAt(ctx, secondaryKey1, []byte("old2"), 1, 0))
 	require.NoError(t, groups[2].Store.PutAt(ctx, secondaryKey2, []byte("old3"), 1, 0))
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
 	require.NoError(t, err)
 
 	prepareMeta := &pb.Mutation{
@@ -377,7 +377,7 @@ func TestShardStoreScanAt_ResolvesCommittedSecondaryLocks(t *testing.T) {
 			{Op: pb.Op_PUT, Key: secondaryKey2, Value: []byte("v3")},
 		},
 	}
-	_, err = groups[2].Txn.Commit([]*pb.Request{prepareSecondary})
+	_, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{prepareSecondary})
 	require.NoError(t, err)
 
 	commitPrimary := &pb.Request{
@@ -389,7 +389,7 @@ func TestShardStoreScanAt_ResolvesCommittedSecondaryLocks(t *testing.T) {
 			{Op: pb.Op_PUT, Key: primaryKey},
 		},
 	}
-	_, err = groups[1].Txn.Commit([]*pb.Request{commitPrimary})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitPrimary})
 	require.NoError(t, err)
 
 	kvs, err := shardStore.ScanAt(ctx, []byte("w"), nil, 100, commitTS)
@@ -417,9 +417,9 @@ func TestShardStoreScanAt_ResolvesCommittedSecondaryLockWithoutCommittedValue(t
 	primaryKey := []byte("b")
 	secondaryKey := []byte("x")
 
-	_, err := groups[1].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
+	_, err := groups[1].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, primaryKey, []byte("v1"), primaryKey)})
 	require.NoError(t, err)
-	_, err = groups[2].Txn.Commit([]*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
+	_, err = groups[2].Txn.Commit(context.Background(), []*pb.Request{makePrepareRequest(startTS, secondaryKey, []byte("v2"), primaryKey)})
 	require.NoError(t, err)
 
 	commitPrimary := &pb.Request{
@@ -431,7 +431,7 @@ func TestShardStoreScanAt_ResolvesCommittedSecondaryLockWithoutCommittedValue(t
 			{Op: pb.Op_PUT, Key: primaryKey},
 		},
 	}
-	_, err = groups[1].Txn.Commit([]*pb.Request{commitPrimary})
+	_, err = groups[1].Txn.Commit(context.Background(), []*pb.Request{commitPrimary})
 	require.NoError(t, err)
 
 	kvs, err := shardStore.ScanAt(ctx, []byte("x"), []byte("z"), 100, commitTS)
diff --git a/kv/sharded_coordinator.go b/kv/sharded_coordinator.go
index 54c24115..a88b173d 100644
--- a/kv/sharded_coordinator.go
+++ b/kv/sharded_coordinator.go
@@ -42,10 +42,10 @@ type leaseRefreshingTxn struct {
 	g *ShardGroup
 }
 
-func (t *leaseRefreshingTxn) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *leaseRefreshingTxn) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	start := monoclock.Now()
 	expectedGen := t.g.lease.generation()
-	resp, err := t.inner.Commit(reqs)
+	resp, err := t.inner.Commit(ctx, reqs)
 	if err != nil {
 		// Only invalidate on errors that actually signal a leadership
 		// change. Write-conflicts, validation errors, and deadline
@@ -65,10 +65,10 @@ func (t *leaseRefreshingTxn) Commit(reqs []*pb.Request) (*TransactionResponse, e
 	return resp, nil
 }
 
-func (t *leaseRefreshingTxn) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *leaseRefreshingTxn) Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	start := monoclock.Now()
 	expectedGen := t.g.lease.generation()
-	resp, err := t.inner.Abort(reqs)
+	resp, err := t.inner.Abort(ctx, reqs)
 	if err != nil {
 		if isLeadershipLossError(err) {
 			t.g.lease.invalidate()
@@ -246,7 +246,7 @@ func (c *ShardedCoordinator) Dispatch(ctx context.Context, reqs *OperationGroup[
 	// span multiple shards (or be nil, meaning "all keys"). Broadcast the
 	// operation to every shard group so each FSM scans locally.
 	if hasDelPrefixElem(reqs.Elems) {
-		return c.dispatchDelPrefixBroadcast(reqs.IsTxn, reqs.Elems)
+		return c.dispatchDelPrefixBroadcast(ctx, reqs.IsTxn, reqs.Elems)
 	}
 
 	if reqs.IsTxn && reqs.StartTS == 0 {
@@ -271,7 +271,7 @@ func (c *ShardedCoordinator) Dispatch(ctx context.Context, reqs *OperationGroup[
 		return nil, err
 	}
 
-	r, err := c.router.Commit(logs)
+	r, err := c.router.Commit(ctx, logs)
 	if err != nil {
 		return nil, errors.WithStack(err)
 	}
@@ -304,7 +304,7 @@ func validateDelPrefixOnly(elems []*Elem[OP]) error {
 // to every shard group. Each element becomes a separate pb.Request (the FSM's
 // extractDelPrefix processes only the first DEL_PREFIX mutation per request).
 // All requests are batched into a single Commit call per shard group.
-func (c *ShardedCoordinator) dispatchDelPrefixBroadcast(isTxn bool, elems []*Elem[OP]) (*CoordinateResponse, error) {
+func (c *ShardedCoordinator) dispatchDelPrefixBroadcast(ctx context.Context, isTxn bool, elems []*Elem[OP]) (*CoordinateResponse, error) {
 	if isTxn {
 		return nil, errors.Wrap(ErrInvalidRequest, "DEL_PREFIX not supported in transactions")
 	}
@@ -323,12 +323,12 @@ func (c *ShardedCoordinator) dispatchDelPrefixBroadcast(isTxn bool, elems []*Ele
 		})
 	}
 
-	return c.broadcastToAllGroups(requests)
+	return c.broadcastToAllGroups(ctx, requests)
 }
 
 // broadcastToAllGroups sends the same set of requests to every shard group in
 // parallel and returns the maximum commit index.
-func (c *ShardedCoordinator) broadcastToAllGroups(requests []*pb.Request) (*CoordinateResponse, error) {
+func (c *ShardedCoordinator) broadcastToAllGroups(ctx context.Context, requests []*pb.Request) (*CoordinateResponse, error) {
 	var (
 		maxIndex atomic.Uint64
 		firstErr error
@@ -339,7 +339,7 @@ func (c *ShardedCoordinator) broadcastToAllGroups(requests []*pb.Request) (*Coor
 		wg.Add(1)
 		go func(g *ShardGroup) {
 			defer wg.Done()
-			r, err := g.Txn.Commit(requests)
+			r, err := g.Txn.Commit(ctx, requests)
 			if err != nil {
 				errMu.Lock()
 				if firstErr == nil {
@@ -390,7 +390,7 @@ func (c *ShardedCoordinator) dispatchTxn(ctx context.Context, startTS uint64, co
 		// If any read key belongs to a different shard the 2PC path is required
 		// so that validateReadOnlyShards can issue a linearizable read barrier,
 		// preserving SSI.
-		return c.dispatchSingleShardTxn(startTS, commitTS, primaryKey, gids[0], elems, readKeys)
+		return c.dispatchSingleShardTxn(ctx, startTS, commitTS, primaryKey, gids[0], elems, readKeys)
 	}
 
 	// Multi-shard path: group read keys by shard now. The result is passed
@@ -407,13 +407,21 @@ func (c *ShardedCoordinator) dispatchTxn(ctx context.Context, startTS uint64, co
 		return nil, err
 	}
 
-	primaryGid, maxIndex, err := c.commitPrimaryTxn(startTS, primaryKey, grouped, commitTS)
+	primaryGid, maxIndex, err := c.commitPrimaryTxn(ctx, startTS, primaryKey, grouped, commitTS)
 	if err != nil {
-		c.abortPreparedTxn(startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+		// abortPreparedTxn must run even when ctx was the reason
+		// commitPrimaryTxn failed; otherwise prewrite intents on
+		// every prepared shard linger until LockResolver picks them
+		// up at a future tick (up to a lease window later, via an
+		// expensive keyspace scan). Detach cancellation but cap with
+		// verifyLeaderTimeout so a hung Abort cannot leak.
+		cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), verifyLeaderTimeout)
+		c.abortPreparedTxn(cleanupCtx, startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+		cancel()
 		return nil, errors.WithStack(err)
 	}
 
-	maxIndex = c.commitSecondaryTxns(startTS, primaryGid, primaryKey, grouped, gids, commitTS, maxIndex)
+	maxIndex = c.commitSecondaryTxns(ctx, startTS, primaryGid, primaryKey, grouped, gids, commitTS, maxIndex)
 
 	return &CoordinateResponse{CommitIndex: maxIndex}, nil
 }
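This detached-cancellation cleanup is the one pattern the PR repeats in three places (here, the prewrite loop below, and commitSequential in kv/transaction.go). Distilled into a standalone form, assuming verifyLeaderTimeout stays the cleanup bound and that this sits in package kv; the helper name is illustrative:

```go
// Generic shape of the cleanup used after a failed commit: drop the
// caller's cancellation (it has most likely already fired) but keep a
// hard upper bound so the cleanup itself cannot hang. WithoutCancel
// also preserves ctx values, so trace IDs still flow into the abort.
func runDetachedCleanup(ctx context.Context, cleanup func(context.Context)) {
	cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), verifyLeaderTimeout)
	defer cancel()
	cleanup(cleanupCtx)
}
```

Note that context.WithoutCancel requires Go 1.21; on older toolchains the closest equivalent is a fresh context.WithTimeout(context.Background(), ...) at the cost of losing the parent's values.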
@@ -443,14 +451,14 @@ func (c *ShardedCoordinator) allReadKeysInShard(readKeys [][]byte, gid uint64) b
 	return true
 }
 
-func (c *ShardedCoordinator) dispatchSingleShardTxn(startTS, commitTS uint64, primaryKey []byte, gid uint64, elems []*Elem[OP], readKeys [][]byte) (*CoordinateResponse, error) {
+func (c *ShardedCoordinator) dispatchSingleShardTxn(ctx context.Context, startTS, commitTS uint64, primaryKey []byte, gid uint64, elems []*Elem[OP], readKeys [][]byte) (*CoordinateResponse, error) {
 	g, err := c.txnGroupForID(gid)
 	if err != nil {
 		return nil, err
 	}
 	// ReadKeys are included in the Raft log entry so the FSM validates
 	// read-write conflicts atomically under applyMu.
-	resp, err := g.Txn.Commit([]*pb.Request{
+	resp, err := g.Txn.Commit(ctx, []*pb.Request{
 		onePhaseTxnRequest(startTS, commitTS, primaryKey, elems, readKeys),
 	})
 	if err != nil {
@@ -483,8 +491,16 @@ func (c *ShardedCoordinator) prewriteTxn(ctx context.Context, startTS, commitTS
 			Mutations: append([]*pb.Mutation{prepareMeta}, grouped[gid]...),
 			ReadKeys:  groupedReadKeys[gid],
 		}
-		if _, err := g.Txn.Commit([]*pb.Request{req}); err != nil {
-			c.abortPreparedTxn(startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+		if _, err := g.Txn.Commit(ctx, []*pb.Request{req}); err != nil {
+			// Same WithoutCancel pattern as dispatchTxn's
+			// commitPrimaryTxn-failure cleanup: a cancelled ctx is
+			// the most likely cause of Commit failing here, and the
+			// abort MUST still go through to release the intents we
+			// already wrote on prior shards. Otherwise LockResolver
+			// is left holding them.
+			cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), verifyLeaderTimeout)
+			c.abortPreparedTxn(cleanupCtx, startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+			cancel()
 			return nil, errors.WithStack(err)
 		}
 		prepared = append(prepared, preparedGroup{gid: gid, keys: keyMutations(grouped[gid])})
@@ -494,14 +510,20 @@
 	// but no mutations in this transaction). Without this, a concurrent
 	// write to a read-only shard would go undetected.
 	if err := c.validateReadOnlyShards(ctx, groupedReadKeys, gids, startTS); err != nil {
-		c.abortPreparedTxn(startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+		// Same reasoning as the prepare-loop cleanup above: the
+		// validation read fence may have failed because ctx
+		// expired, so the abort needs detached cancellation to
+		// avoid stranding intents.
+		cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), verifyLeaderTimeout)
+		c.abortPreparedTxn(cleanupCtx, startTS, primaryKey, prepared, abortTSFrom(startTS, commitTS))
+		cancel()
 		return nil, err
 	}
 
 	return prepared, nil
 }
 
-func (c *ShardedCoordinator) commitPrimaryTxn(startTS uint64, primaryKey []byte, grouped map[uint64][]*pb.Mutation, commitTS uint64) (uint64, uint64, error) {
+func (c *ShardedCoordinator) commitPrimaryTxn(ctx context.Context, startTS uint64, primaryKey []byte, grouped map[uint64][]*pb.Mutation, commitTS uint64) (uint64, uint64, error) {
 	primaryGid := c.engineGroupIDForKey(primaryKey)
 	if primaryGid == 0 {
 		return 0, 0, errors.WithStack(ErrInvalidRequest)
@@ -521,7 +543,7 @@ func (c *ShardedCoordinator) commitPrimaryTxn(startTS uint64, primaryKey []byte,
 		Mutations: append([]*pb.Mutation{meta}, keys...),
 	}
 
-	r, err := g.Txn.Commit([]*pb.Request{req})
+	r, err := g.Txn.Commit(ctx, []*pb.Request{req})
 	if err != nil {
 		return primaryGid, 0, errors.WithStack(err)
 	}
@@ -531,7 +553,7 @@ func (c *ShardedCoordinator) commitPrimaryTxn(startTS uint64, primaryKey []byte,
 	return primaryGid, r.CommitIndex, nil
 }
 
-func (c *ShardedCoordinator) commitSecondaryTxns(startTS uint64, primaryGid uint64, primaryKey []byte, grouped map[uint64][]*pb.Mutation, gids []uint64, commitTS uint64, maxIndex uint64) uint64 {
+func (c *ShardedCoordinator) commitSecondaryTxns(ctx context.Context, startTS uint64, primaryGid uint64, primaryKey []byte, grouped map[uint64][]*pb.Mutation, gids []uint64, commitTS uint64, maxIndex uint64) uint64 {
 	// Secondary commits are best-effort. If a shard is unavailable after the
 	// primary commits, read-time lock resolution will commit the remaining
 	// secondaries based on the primary commit record. Retry a few times to
@@ -551,7 +573,7 @@ func (c *ShardedCoordinator) commitSecondaryTxns(startTS uint64, primaryGid uint
 			Ts:        startTS,
 			Mutations: append([]*pb.Mutation{meta}, keyMutations(grouped[gid])...),
 		}
-		r, err := commitSecondaryWithRetry(g, req)
+		r, err := commitSecondaryWithRetry(ctx, g, req)
 		if err != nil {
 			c.logger().Warn("txn secondary commit failed",
 				slog.Uint64("gid", gid),
@@ -569,13 +591,13 @@ func (c *ShardedCoordinator) commitSecondaryTxns(startTS uint64, primaryGid uint
 	return maxIndex
 }
 
-func commitSecondaryWithRetry(g *ShardGroup, req *pb.Request) (*TransactionResponse, error) {
+func commitSecondaryWithRetry(ctx context.Context, g *ShardGroup, req *pb.Request) (*TransactionResponse, error) {
 	if g == nil || g.Txn == nil || req == nil {
 		return nil, errors.WithStack(ErrInvalidRequest)
 	}
 	var lastErr error
 	for attempt := range txnSecondaryCommitRetryAttempts {
-		resp, err := g.Txn.Commit([]*pb.Request{req})
+		resp, err := g.Txn.Commit(ctx, []*pb.Request{req})
 		if err == nil {
 			return resp, nil
 		}
@@ -594,7 +616,7 @@ func (c *ShardedCoordinator) logger() *slog.Logger {
 	return slog.Default()
 }
 
-func (c *ShardedCoordinator) abortPreparedTxn(startTS uint64, primaryKey []byte, prepared []preparedGroup, abortTS uint64) {
+func (c *ShardedCoordinator) abortPreparedTxn(ctx context.Context, startTS uint64, primaryKey []byte, prepared []preparedGroup, abortTS uint64) {
 	if len(prepared) == 0 {
 		return
 	}
@@ -614,7 +636,7 @@ func (c *ShardedCoordinator) abortPreparedTxn(startTS uint64, primaryKey []byte,
 			Ts:        startTS,
 			Mutations: append([]*pb.Mutation{meta}, pg.keys...),
 		}
-		if _, err := g.Txn.Commit([]*pb.Request{req}); err != nil {
+		if _, err := g.Txn.Commit(ctx, []*pb.Request{req}); err != nil {
 			if errors.Is(err, ErrTxnAlreadyCommitted) {
 				continue
 			}
@@ -722,12 +744,12 @@ func (c *ShardedCoordinator) IsLeader() bool {
 	return isLeaderEngine(engineForGroup(g))
 }
 
-func (c *ShardedCoordinator) VerifyLeader() error {
+func (c *ShardedCoordinator) VerifyLeader(ctx context.Context) error {
 	g, ok := c.groups[c.defaultGroup]
 	if !ok {
 		return errors.WithStack(ErrLeaderNotFound)
 	}
-	return verifyLeaderEngine(engineForGroup(g))
+	return verifyLeaderEngineCtx(ctx, engineForGroup(g))
 }
 
 func (c *ShardedCoordinator) RaftLeader() string {
@@ -754,12 +776,12 @@ func (c *ShardedCoordinator) IsLeaderForKey(key []byte) bool {
 	return isLeaderEngine(engineForGroup(g))
 }
 
-func (c *ShardedCoordinator) VerifyLeaderForKey(key []byte) error {
+func (c *ShardedCoordinator) VerifyLeaderForKey(ctx context.Context, key []byte) error {
 	g, ok := c.groupForKey(key)
 	if !ok {
 		return errors.WithStack(ErrLeaderNotFound)
 	}
-	return verifyLeaderEngine(engineForGroup(g))
+	return verifyLeaderEngineCtx(ctx, engineForGroup(g))
 }
 
 func (c *ShardedCoordinator) RaftLeaderForKey(key []byte) string {
diff --git a/kv/sharded_coordinator_abort_test.go b/kv/sharded_coordinator_abort_test.go
index c6d553a8..faea19fd 100644
--- a/kv/sharded_coordinator_abort_test.go
+++ b/kv/sharded_coordinator_abort_test.go
@@ -18,11 +18,11 @@ type failingTransactional struct {
 	err error
 }
 
-func (f *failingTransactional) Commit([]*pb.Request) (*TransactionResponse, error) {
+func (f *failingTransactional) Commit(context.Context, []*pb.Request) (*TransactionResponse, error) {
 	return nil, f.err
 }
 
-func (f *failingTransactional) Abort([]*pb.Request) (*TransactionResponse, error) {
+func (f *failingTransactional) Abort(context.Context, []*pb.Request) (*TransactionResponse, error) {
 	return nil, f.err
 }
 
@@ -97,7 +97,7 @@ func TestAbortPreparedTxn_DoesNotWarnWhenTxnAlreadyCommitted(t *testing.T) {
 		},
 	}
 
-	coord.abortPreparedTxn(10, []byte("pk"), []preparedGroup{{
+	coord.abortPreparedTxn(context.Background(), 10, []byte("pk"), []preparedGroup{{
 		gid:  1,
 		keys: []*pb.Mutation{{Op: pb.Op_PUT, Key: []byte("pk")}},
 	}}, 20)
diff --git a/kv/sharded_coordinator_leader_test.go b/kv/sharded_coordinator_leader_test.go
index 201a899b..dd0521da 100644
--- a/kv/sharded_coordinator_leader_test.go
+++ b/kv/sharded_coordinator_leader_test.go
@@ -1,6 +1,7 @@
 package kv
 
 import (
+	"context"
 	"testing"
 
 	"github.com/bootjp/elastickv/distribution"
@@ -24,8 +25,8 @@ func TestShardedCoordinatorVerifyLeader_LeaderReturnsNil(t *testing.T) {
 	}
 
 	coord := NewShardedCoordinator(engine, groups, 1, NewHLC(), NewShardStore(engine, groups))
-	require.NoError(t, coord.VerifyLeader())
-	require.NoError(t, coord.VerifyLeaderForKey([]byte("b")))
+	require.NoError(t, coord.VerifyLeader(context.Background()))
+	require.NoError(t, coord.VerifyLeaderForKey(context.Background(), []byte("b")))
 }
 
 func TestShardedCoordinatorVerifyLeader_MissingGroup(t *testing.T) {
@@ -34,6 +35,6 @@ func TestShardedCoordinatorVerifyLeader_MissingGroup(t *testing.T) {
 	engine := distribution.NewEngine()
 	coord := NewShardedCoordinator(engine, map[uint64]*ShardGroup{}, 1, NewHLC(), nil)
 
-	require.ErrorIs(t, coord.VerifyLeader(), ErrLeaderNotFound)
-	require.ErrorIs(t, coord.VerifyLeaderForKey([]byte("k")), ErrLeaderNotFound)
+	require.ErrorIs(t, coord.VerifyLeader(context.Background()), ErrLeaderNotFound)
+	require.ErrorIs(t, coord.VerifyLeaderForKey(context.Background(), []byte("k")), ErrLeaderNotFound)
 }
diff --git a/kv/sharded_coordinator_sampler_test.go b/kv/sharded_coordinator_sampler_test.go
index ea4d85a7..2fc9576a 100644
--- a/kv/sharded_coordinator_sampler_test.go
+++ b/kv/sharded_coordinator_sampler_test.go
@@ -216,7 +216,7 @@ func TestShardedCoordinatorSkipsObserveForLeadershipChecks(t *testing.T) {
 	key := []byte("k")
 	_ = coord.IsLeaderForKey(key)
-	_ = coord.VerifyLeaderForKey(key)
+	_ = coord.VerifyLeaderForKey(context.Background(), key)
 	_ = coord.RaftLeaderForKey(key)
 
 	require.Empty(t, rec.snapshot(), "leadership checks must not produce read samples")
diff --git a/kv/sharded_coordinator_txn_test.go b/kv/sharded_coordinator_txn_test.go
index 6eaf567c..fbe359c9 100644
--- a/kv/sharded_coordinator_txn_test.go
+++ b/kv/sharded_coordinator_txn_test.go
@@ -21,7 +21,7 @@ type recordingTransactional struct {
 	errs []error
 }
 
-func (s *recordingTransactional) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
+func (s *recordingTransactional) Commit(_ context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -39,7 +39,7 @@ func (s *recordingTransactional) Commit(reqs []*pb.Request) (*TransactionRespons
 	return &TransactionResponse{}, nil
 }
 
-func (s *recordingTransactional) Abort(_ []*pb.Request) (*TransactionResponse, error) {
+func (s *recordingTransactional) Abort(_ context.Context, _ []*pb.Request) (*TransactionResponse, error) {
 	return &TransactionResponse{}, nil
 }
 
@@ -262,7 +262,7 @@ func TestCommitSecondaryWithRetry_RetriesAndSucceeds(t *testing.T) {
 		},
 	}
 
-	resp, err := commitSecondaryWithRetry(&ShardGroup{Txn: txn}, &pb.Request{
+	resp, err := commitSecondaryWithRetry(context.Background(), &ShardGroup{Txn: txn}, &pb.Request{
 		IsTxn: true,
 		Phase: pb.Phase_COMMIT,
 		Ts:    7,
@@ -288,7 +288,7 @@ func TestCommitSecondaryWithRetry_ExhaustsRetries(t *testing.T) {
 		},
 	}
 
-	_, err := commitSecondaryWithRetry(&ShardGroup{Txn: txn}, &pb.Request{
+	_, err := commitSecondaryWithRetry(context.Background(), &ShardGroup{Txn: txn}, &pb.Request{
 		IsTxn: true,
 		Phase: pb.Phase_COMMIT,
 		Ts:    9,
diff --git a/kv/sharded_lease_test.go b/kv/sharded_lease_test.go
index b06feb3b..8ab447aa 100644
--- a/kv/sharded_lease_test.go
+++ b/kv/sharded_lease_test.go
@@ -119,20 +119,20 @@ func TestShardedCoordinator_LeaseRefreshingTxn_SkipsWhenCommitIndexZero(t *testi
 	require.False(t, g1.lease.valid(monoclock.Now()))
 
 	// Commit with empty input returns success with CommitIndex=0.
-	_, err := g1.Txn.Commit(nil)
+	_, err := g1.Txn.Commit(context.Background(), nil)
 	require.NoError(t, err)
 	require.False(t, g1.lease.valid(monoclock.Now()),
 		"lease must NOT be refreshed when no Raft commit happened")
 
 	// Same for Abort.
-	_, err = g1.Txn.Abort(nil)
+	_, err = g1.Txn.Abort(context.Background(), nil)
 	require.NoError(t, err)
 	require.False(t, g1.lease.valid(monoclock.Now()))
 
 	// A response with CommitIndex > 0 refreshes the lease.
 	realResp := &TransactionResponse{CommitIndex: 42}
 	txn.inner = &fixedTransactional{response: realResp}
-	_, err = g1.Txn.Commit(nil)
+	_, err = g1.Txn.Commit(context.Background(), nil)
 	require.NoError(t, err)
 	require.True(t, g1.lease.valid(monoclock.Now()),
 		"lease must be refreshed after a real Raft commit")
@@ -145,11 +145,11 @@ type fixedTransactional struct {
 	response *TransactionResponse
 }
 
-func (f *fixedTransactional) Commit(_ []*pb.Request) (*TransactionResponse, error) {
+func (f *fixedTransactional) Commit(_ context.Context, _ []*pb.Request) (*TransactionResponse, error) {
 	return f.response, nil
 }
 
-func (f *fixedTransactional) Abort(_ []*pb.Request) (*TransactionResponse, error) {
+func (f *fixedTransactional) Abort(_ context.Context, _ []*pb.Request) (*TransactionResponse, error) {
 	return f.response, nil
 }
diff --git a/kv/transaction.go b/kv/transaction.go
index 531eea9e..91769f1e 100644
--- a/kv/transaction.go
+++ b/kv/transaction.go
@@ -93,9 +93,17 @@ func (t *TransactionManager) Close() {
 	})
 }
 
+// Transactional is the kv-internal interface that fronts the raft propose
+// path. Implementations (TransactionManager, LeaderProxy, ShardRouter,
+// leaseRefreshingTxn) thread the caller's context end-to-end so a Redis /
+// gRPC / S3 / SQS handler's deadline reaches Propose / VerifyLeader without
+// being silently dropped to context.Background. See PR #748 / design doc
+// 2026_05_10_proposed_kv_ctx_plumbing.md for the rationale; the prior
+// signatures lived behind `verifyLeaderEngine`'s 5 s safety bound (#745),
+// which is preserved as the no-ctx defense-in-depth fallback.
 type Transactional interface {
-	Commit(reqs []*pb.Request) (*TransactionResponse, error)
-	Abort(reqs []*pb.Request) (*TransactionResponse, error)
+	Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error)
+	Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error)
 }
 
 type TransactionResponse struct {
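With ctx in the interface, wrappers compose the same way leaseRefreshingTxn does above. A minimal sketch of one more such decorator, hypothetical and not part of this PR, just to show the threading (assumes package kv; the observer callback is an assumption):

```go
// timingTxn forwards ctx untouched to the inner Transactional and
// observes latency per call; illustration only, not repo code.
type timingTxn struct {
	inner Transactional
	obs   func(method string, d time.Duration, err error)
}

func (t *timingTxn) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
	start := time.Now()
	resp, err := t.inner.Commit(ctx, reqs)
	t.obs("commit", time.Since(start), err)
	return resp, err
}

func (t *timingTxn) Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
	start := time.Now()
	resp, err := t.inner.Abort(ctx, reqs)
	t.obs("abort", time.Since(start), err)
	return resp, err
}
```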
@@ -138,7 +146,7 @@ func prependByte(prefix byte, data []byte) []byte {
 // HashiCorp Raft delivers FSM responses via ApplyFuture.Response(), not Error(),
 // so we must inspect the response to avoid silently treating failed writes as
 // successes.
-func applyRequests(proposer raftengine.Proposer, reqs []*pb.Request, proposalObserver ProposalObserver) (uint64, []error, error) {
+func applyRequests(ctx context.Context, proposer raftengine.Proposer, reqs []*pb.Request, proposalObserver ProposalObserver) (uint64, []error, error) {
 	b, err := marshalRaftCommand(reqs)
 	if err != nil {
 		return 0, nil, errors.WithStack(err)
@@ -149,7 +157,7 @@ func applyRequests(proposer raftengine.Proposer, reqs []*pb.Request, proposalObs
 		return 0, nil, errors.WithStack(ErrLeaderNotFound)
 	}
 
-	result, err := proposer.Propose(context.Background(), b)
+	result, err := proposer.Propose(ctx, b)
 	if err != nil {
 		recordProposalFailure(proposalObserver)
 		return 0, nil, errors.WithStack(err)
@@ -190,14 +198,14 @@ func recordProposalFailure(observer ProposalObserver) {
 	}
 }
 
-func (t *TransactionManager) Commit(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *TransactionManager) Commit(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	if len(reqs) == 0 {
 		return &TransactionResponse{}, nil
 	}
 	if hasTxnRequests(reqs) {
-		return t.commitSequential(reqs)
+		return t.commitSequential(ctx, reqs)
 	}
-	return t.commitRaw(reqs)
+	return t.commitRaw(ctx, reqs)
 }
 
 func hasTxnRequests(reqs []*pb.Request) bool {
@@ -209,11 +217,11 @@ func hasTxnRequests(reqs []*pb.Request) bool {
 	return false
 }
 
-func (t *TransactionManager) commitSequential(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *TransactionManager) commitSequential(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	commitIndex, err := func() (uint64, error) {
 		commitIndex := uint64(0)
 		for _, req := range reqs {
-			idx, results, err := applyRequests(t.proposer, []*pb.Request{req}, t.proposalObserver)
+			idx, results, err := applyRequests(ctx, t.proposer, []*pb.Request{req}, t.proposalObserver)
 			if err != nil {
 				return 0, err
 			}
@@ -232,7 +240,17 @@ func (t *TransactionManager) commitSequential(reqs []*pb.Request) (*TransactionR
 	// transactional requests do not leave intents behind, so they do not need
 	// abort cleanup on failure.
 	if needsTxnCleanup(reqs) {
-		_, _err := t.Abort(reqs)
+		// Use a cleanup ctx that survives the original ctx's
+		// cancellation: the upstream commit very likely failed
+		// because ctx was cancelled / hit its deadline, and Abort
+		// MUST still go through to release intents; otherwise
+		// locks linger until LockResolver picks them up at a
+		// future tick. context.WithoutCancel detaches deadline
+		// and cancellation; we re-bound with verifyLeaderTimeout
+		// so a hung Abort cannot stall this failure path forever.
+		cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), verifyLeaderTimeout)
+		_, _err := t.Abort(cleanupCtx, reqs)
+		cancel()
 		if _err != nil {
 			return nil, errors.WithStack(errors.CombineErrors(err, _err))
 		}
@@ -265,7 +283,7 @@ func needsTxnCleanup(reqs []*pb.Request) bool {
 	return false
 }
 
-func (t *TransactionManager) commitRaw(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *TransactionManager) commitRaw(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	item := &rawCommitItem{
 		reqs: reqs,
 		done: make(chan rawCommitResult, 1),
@@ -294,11 +312,23 @@ func (t *TransactionManager) commitRaw(reqs []*pb.Request) (*TransactionResponse
 		go t.flushRawPending()
 	}
 
-	res := <-item.done
-	if res.err != nil {
-		return nil, res.err
+	// Wait under the caller's ctx so a deadline expiring while batched
+	// commits are pending lets the caller exit without blocking on a
+	// busy queue. The proposal itself is driven by flushRawPending in
+	// a separate goroutine using context.Background; that is the
+	// intentional batched-commit model: many callers' ctxs map to one
+	// batch propose, so no single ctx can bound it. Caller cancellation
+	// here only abandons the wait; the propose still completes and
+	// other waiters in the same batch get their results normally.
+	select {
+	case res := <-item.done:
+		if res.err != nil {
+			return nil, res.err
+		}
+		return &TransactionResponse{CommitIndex: res.commitIndex}, nil
+	case <-ctx.Done():
+		return nil, errors.WithStack(ctx.Err())
 	}
-	return &TransactionResponse{CommitIndex: res.commitIndex}, nil
 }
 
 func (t *TransactionManager) flushRawPending() {
@@ -374,7 +404,11 @@ func (t *TransactionManager) applyRawBatch(batch []*rawCommitItem) {
 	}
 	offsets = append(offsets, len(reqs))
 
-	idx, results, err := applyRequests(t.proposer, reqs, t.proposalObserver)
+	// The batched-commit goroutine cannot inherit any single caller's
+	// ctx; see the commitRaw comment. Use Background here; per-caller
+	// cancellation is honoured at the wait site in commitRaw via select
+	// on item.done vs ctx.Done.
+	idx, results, err := applyRequests(context.Background(), t.proposer, reqs, t.proposalObserver)
 	if err != nil {
 		for _, item := range batch {
 			item.done <- rawCommitResult{err: err}
@@ -403,7 +437,7 @@ func combineApplyErrors(errs []error) error {
 	return errors.WithStack(combined)
 }
 
-func (t *TransactionManager) Abort(reqs []*pb.Request) (*TransactionResponse, error) {
+func (t *TransactionManager) Abort(ctx context.Context, reqs []*pb.Request) (*TransactionResponse, error) {
 	var abortReqs []*pb.Request
 	for _, req := range reqs {
 		if abortReq := abortRequestFor(req); abortReq != nil {
@@ -413,7 +447,7 @@ func (t *TransactionManager) Abort(reqs []*pb.Request) (*TransactionResponse, er
 
 	var commitIndex uint64
 	for _, req := range abortReqs {
-		idx, results, err := applyRequests(t.proposer, []*pb.Request{req}, t.proposalObserver)
+		idx, results, err := applyRequests(ctx, t.proposer, []*pb.Request{req}, t.proposalObserver)
 		if err != nil {
 			return nil, err
 		}
diff --git a/kv/transaction_batch_test.go b/kv/transaction_batch_test.go
index 31f37e05..0c49ee32 100644
--- a/kv/transaction_batch_test.go
+++ b/kv/transaction_batch_test.go
@@ -121,13 +121,13 @@ func TestTransactionManagerBatchesConcurrentRawCommits(t *testing.T) {
 	go func() {
 		defer wg.Done()
 		<-start
-		resp, err := tm.Commit(req1)
+		resp, err := tm.Commit(context.Background(), req1)
 		results <- result{resp: resp, err: err}
 	}()
 	go func() {
 		defer wg.Done()
 		<-start
-		resp, err := tm.Commit(req2)
+		resp, err := tm.Commit(context.Background(), req2)
 		results <- result{resp: resp, err: err}
 	}()
 	close(start)
@@ -167,7 +167,7 @@ func TestApplyRequestsCountsProposalFailureOnRaftApplyError(t *testing.T) {
 		},
 	}}
 
-	_, _, err := applyRequests(r, reqs, observer)
+	_, _, err := applyRequests(context.Background(), r, reqs, observer)
 	require.Error(t, err)
 	require.Equal(t, 1, observer.FailureCount())
 }
@@ -186,7 +186,7 @@ func TestApplyRequestsDoesNotCountBusinessErrorAsProposalFailure(t *testing.T) {
 		},
 	}}
 
-	_, results, err := applyRequests(r, reqs, observer)
+	_, results, err := applyRequests(context.Background(), r, reqs, observer)
 	require.NoError(t, err)
 	require.Len(t, results, 1)
 	require.ErrorIs(t, results[0], ErrInvalidRequest)
@@ -217,7 +217,7 @@ func TestApplyRequestsWithEtcdEngineKeepsKVCommandSemantics(t *testing.T) {
 		},
 	}}
 
-	commitIndex, results, err := applyRequests(engine, goodReqs, observer)
+	commitIndex, results, err := applyRequests(context.Background(), engine, goodReqs, observer)
 	require.NoError(t, err)
 	require.NotZero(t, commitIndex)
 	require.Len(t, results, 1)
@@ -235,7 +235,7 @@ func TestApplyRequestsWithEtcdEngineKeepsKVCommandSemantics(t *testing.T) {
 		},
 	}}
 
-	_, results, err = applyRequests(engine, badReqs, observer)
+	_, results, err = applyRequests(context.Background(), engine, badReqs, observer)
 	require.NoError(t, err)
 	require.Len(t, results, 1)
 	require.ErrorIs(t, results[0], ErrInvalidRequest)
diff --git a/main_admin.go b/main_admin.go
index 1c7ba313..08ab47e3 100644
--- a/main_admin.go
+++ b/main_admin.go
@@ -166,15 +166,17 @@ func newAdminLeaderProbe(coordinate kv.Coordinator) admin.LeaderProbe {
 	if coordinate == nil {
 		return nil
 	}
-	return admin.LeaderProbeFunc(func() bool {
+	return admin.LeaderProbeFunc(func(ctx context.Context) bool {
 		if !coordinate.IsLeader() {
 			return false
 		}
-		// VerifyLeader is the same ReadIndex round-trip lease reads
-		// use; under the hood it carries an engine-bounded deadline,
-		// so a stalled cluster surfaces 503 here on its own without
-		// the probe needing an outer timeout.
-		return coordinate.VerifyLeader() == nil
+		// VerifyLeader receives the request ctx (PR #748): a Caddy probe
+		// or browser preflight that sets its own deadline now bounds the
+		// ReadIndex round-trip, instead of falling back to
+		// verifyLeaderEngine's no-arg 5s safety net (#745). The 5s bound
+		// remains as defense-in-depth for callers without an upstream
+		// ctx (lock resolver, HLC lease tick).
+		return coordinate.VerifyLeader(ctx) == nil
 	})
 }
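The commitRaw contract described above (caller cancellation abandons only the wait, never the shared propose) reduces to a small reusable shape. A standalone, runnable sketch with hypothetical names, independent of the repo's types:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// batchItem is what each caller enqueues; done is buffered so the
// flusher can deliver a result without blocking on absent waiters.
type batchItem struct {
	done chan error
}

// awaitBatched mirrors commitRaw's wait site: the background flusher
// owns the propose under context.Background, while each caller waits
// under its own ctx and may abandon the wait without aborting the batch.
func awaitBatched(ctx context.Context, item *batchItem) error {
	select {
	case err := <-item.done:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	item := &batchItem{done: make(chan error, 1)}
	go func() { // stand-in for flushRawPending
		time.Sleep(10 * time.Millisecond)
		item.done <- nil
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
	defer cancel()
	if err := awaitBatched(ctx, item); errors.Is(err, context.DeadlineExceeded) {
		fmt.Println("caller gave up waiting; the batch still completes")
	}
}
```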