diff --git a/internal/memwatch/memwatch.go b/internal/memwatch/memwatch.go
new file mode 100644
index 00000000..781beb09
--- /dev/null
+++ b/internal/memwatch/memwatch.go
@@ -0,0 +1,262 @@
+// Package memwatch provides a user-space memory watchdog that triggers
+// an orderly shutdown before the kernel OOM-killer sends SIGKILL.
+//
+// # Motivation
+//
+// elastickv runs in memory-constrained containers (e.g. 3GB RAM VMs). Go's
+// runtime is unaware of the container/host memory limit, and even with
+// GOMEMLIMIT set the process can still lose the race against the kernel
+// OOM-killer under sustained memtable/goroutine growth. A SIGKILL leaves
+// the Raft WAL potentially truncated mid-operation; a cooperative SIGTERM
+// path lets the node sync the WAL and stop raft cleanly, avoiding the
+// election storms and lease loss that follow crash-restarts.
+//
+// The watcher samples runtime/metrics at a fixed cadence. When the live
+// heap-in-use byte count crosses the configured threshold it invokes
+// OnExceed once and exits. The watcher never calls os.Exit or sends
+// signals itself; callers wire OnExceed to the existing shutdown path
+// (typically a root context.CancelFunc).
+//
+// Wiring in elastickv (see main.go):
+//
+//	ctx, cancel := context.WithCancel(context.Background())
+//	// ... build runtimes, servers, errgroup ...
+//	w := memwatch.New(memwatch.Config{
+//		ThresholdBytes: threshold,
+//		PollInterval:   pollInterval,
+//		OnExceed: func() {
+//			memoryPressureExit.Store(true) // flips exit code to 2
+//			cancel()                       // fires the same shutdown path
+//		}, // SIGTERM would use.
+//	})
+//	eg.Go(func() error { w.Start(runCtx); return nil })
+//
+// # Metric choice
+//
+// We sample `runtime/metrics` (Go 1.16+) rather than `runtime.ReadMemStats`.
+// ReadMemStats triggers a stop-the-world pause proportional to the number of
+// goroutines and heap size; at 1 s cadence that's typically negligible, but
+// at a tighter MinPollInterval (10 ms) it begins to register. runtime/metrics
+// readers are lock-free for the counters we need and do not stop the world.
+//
+// The threshold is compared against
+//
+//	/memory/classes/heap/objects:bytes + /memory/classes/heap/unused:bytes
+//
+// which is the runtime/metrics equivalent of MemStats.HeapInuse: bytes held
+// in heap spans that are currently allocated from the OS, including span
+// overhead, but excluding pages the runtime has released back. RSS from
+// /proc/self/status is more accurate but requires a read syscall on every
+// poll and is not what the Go allocator itself tracks. We deliberately do
+// NOT compare against "total heap classes" (which includes released memory
+// already returned to the OS) or "heap/objects" alone (which misses span
+// fragmentation that the OOM-killer sees).
+package memwatch
+
+import (
+	"context"
+	"log/slog"
+	"runtime/metrics"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// DefaultPollInterval is the polling cadence used when Config.PollInterval
+// is zero. One second is frequent enough to catch fast-growing memtables
+// before the kernel kills the process, and infrequent enough that even
+// aggressive log rollups don't observe the watcher as a hot sampler.
+const DefaultPollInterval = time.Second
+
+// MinPollInterval is the floor enforced by New. runtime/metrics reads are
+// cheap but a sub-10ms cadence produces no detection benefit over 10ms
+// (memory pressure does not move that fast on these VMs) and would churn
+// the ticker for no gain.
+const MinPollInterval = 10 * time.Millisecond
+
+// Config configures a Watcher.
+type Config struct {
+	// ThresholdBytes is the heap-in-use threshold in bytes. When the
+	// sampled heap-in-use crosses this value the watcher invokes OnExceed
+	// exactly once and returns. A zero value disables the watcher entirely
+	// (Start returns immediately).
+	ThresholdBytes uint64
+
+	// PollInterval is how often the metrics are sampled. Defaults to
+	// DefaultPollInterval when zero; values below MinPollInterval are
+	// clamped up to MinPollInterval.
+	PollInterval time.Duration
+
+	// OnExceed is called at most once, from the watcher's own goroutine,
+	// when the threshold is crossed. It must be non-blocking or at least
+	// must not block the caller indefinitely (the watcher returns
+	// immediately after invocation regardless). Typical implementations
+	// cancel a root context and flag a process-wide exit-code sentinel.
+	OnExceed func()
+
+	// Logger, if non-nil, receives a single structured log line when the
+	// threshold is crossed. When nil, slog.Default() is used.
+	Logger *slog.Logger
+}
+
+// Metric sample indices — kept stable so samples[] can be reused across
+// polls without reallocating or re-resolving names.
+const (
+	sampleHeapObjects = iota
+	sampleHeapUnused
+	sampleHeapReleased
+	sampleGCGoal
+	sampleCount
+)
+
+var metricNames = [sampleCount]string{
+	sampleHeapObjects:  "/memory/classes/heap/objects:bytes",
+	sampleHeapUnused:   "/memory/classes/heap/unused:bytes",
+	sampleHeapReleased: "/memory/classes/heap/released:bytes",
+	sampleGCGoal:       "/gc/heap/goal:bytes",
+}
+
+// Watcher polls process memory and fires OnExceed once, when heap-in-use
+// crosses the configured threshold. Callers get a single-shot notification
+// and are expected to initiate graceful shutdown; Watcher does not call
+// os.Exit or send signals itself.
+type Watcher struct {
+	cfg       Config
+	fired     atomic.Bool
+	started   atomic.Bool
+	doneCh    chan struct{}
+	closeOnce sync.Once
+	// samples is reused across polls; metric-name resolution happens once
+	// in New so the hot path only walks a fixed []Sample.
+	samples []metrics.Sample
+}
+
+// New constructs a Watcher from the given Config. The Watcher does not
+// start polling until Start is called.
+func New(cfg Config) *Watcher {
+	switch {
+	case cfg.PollInterval <= 0:
+		cfg.PollInterval = DefaultPollInterval
+	case cfg.PollInterval < MinPollInterval:
+		cfg.PollInterval = MinPollInterval
+	}
+	if cfg.Logger == nil {
+		cfg.Logger = slog.Default()
+	}
+	samples := make([]metrics.Sample, sampleCount)
+	for i, name := range metricNames {
+		samples[i].Name = name
+	}
+	return &Watcher{
+		cfg:     cfg,
+		doneCh:  make(chan struct{}),
+		samples: samples,
+	}
+}
+
+// Start runs the watchdog loop. It returns when ctx is cancelled, when
+// OnExceed has fired, or immediately when ThresholdBytes is zero (the
+// watcher is disabled). It is safe to call Start at most once per Watcher;
+// subsequent calls return immediately because the started guard has already
+// been set.
+func (w *Watcher) Start(ctx context.Context) {
+	if !w.started.CompareAndSwap(false, true) {
+		return
+	}
+	defer w.closeDoneOnce()
+
+	if w.cfg.ThresholdBytes == 0 {
+		// Disabled: do not even start a ticker, so an OFF-by-default
+		// deployment pays zero cost.
+		return
+	}
+
+	// Sample once before the first tick: if the process is already above
+	// the threshold at Start (crashloop-restart after OOM, large startup
+	// allocations, etc.), waiting for the first ticker cycle can let the
+	// kernel OOM-kill the process we were supposed to protect.
+	if w.checkAndMaybeFire() {
+		return
+	}
+
+	ticker := time.NewTicker(w.cfg.PollInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			if w.checkAndMaybeFire() {
+				return
+			}
+		}
+	}
+}
+
+// Done returns a channel that is closed when Start returns. Tests can use
+// it to assert the watcher goroutine actually exits (no leak) after
+// ctx cancel or OnExceed.
+func (w *Watcher) Done() <-chan struct{} {
+	return w.doneCh
+}
+
+// closeDoneOnce closes doneCh at most once across the Watcher's lifetime.
+// Per-Watcher sync.Once avoids the contention a shared package-level mutex
+// would introduce if multiple watchers coexisted.
+func (w *Watcher) closeDoneOnce() {
+	w.closeOnce.Do(func() { close(w.doneCh) })
+}
+
+// sampleUint64 reads a named Uint64 sample from w.samples after metrics.Read
+// has populated them. Returns 0 if the metric is not supported by the
+// current Go runtime (runtime/metrics guarantees no panic, just
+// KindBad). The watcher treats missing metrics as "no pressure detected";
+// the primary metrics used by the threshold check have been present since
+// Go 1.16, so this only matters for defensive correctness.
+func (w *Watcher) sampleUint64(idx int) uint64 {
+	if w.samples[idx].Value.Kind() != metrics.KindUint64 {
+		return 0
+	}
+	return w.samples[idx].Value.Uint64()
+}
+
+// checkAndMaybeFire samples runtime/metrics once, computes heap-in-use, and
+// if it is at or above the threshold and OnExceed has not already fired,
+// invokes OnExceed and returns true to signal the loop to exit.
+func (w *Watcher) checkAndMaybeFire() bool {
+	metrics.Read(w.samples)
+
+	objects := w.sampleUint64(sampleHeapObjects)
+	unused := w.sampleUint64(sampleHeapUnused)
+	// heap-in-use = allocated heap spans (live objects plus reusable free
+	// slots the runtime still owns), matching MemStats.HeapInuse.
+	heapInuse := objects + unused
+
+	if heapInuse < w.cfg.ThresholdBytes {
+		return false
+	}
+
+	// CompareAndSwap so a (hypothetical) concurrent caller cannot cause
+	// OnExceed to run twice. The watcher currently runs from one goroutine
+	// but keeping the guard explicit documents the "single-shot" contract.
+	if !w.fired.CompareAndSwap(false, true) {
+		return true
+	}
+
+	released := w.sampleUint64(sampleHeapReleased)
+	gcGoal := w.sampleUint64(sampleGCGoal)
+
+	w.cfg.Logger.Warn("memory pressure shutdown",
+		"heap_inuse_bytes", heapInuse,
+		"heap_objects_bytes", objects,
+		"heap_released_bytes", released,
+		"threshold_bytes", w.cfg.ThresholdBytes,
+		"next_gc_bytes", gcGoal,
+	)
+
+	if w.cfg.OnExceed != nil {
+		w.cfg.OnExceed()
+	}
+	return true
+}
diff --git a/internal/memwatch/memwatch_test.go b/internal/memwatch/memwatch_test.go
new file mode 100644
index 00000000..cb266e13
--- /dev/null
+++ b/internal/memwatch/memwatch_test.go
@@ -0,0 +1,158 @@
+package memwatch
+
+import (
+	"context"
+	"runtime"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+// waitForDone asserts the watcher's goroutine exits within d, returning
+// immediately on success and failing the test on timeout. Using the
+// Watcher.Done() channel avoids the goleak dep and keeps the check local.
+func waitForDone(t *testing.T, w *Watcher, d time.Duration) {
+	t.Helper()
+	select {
+	case <-w.Done():
+	case <-time.After(d):
+		t.Fatalf("watcher goroutine did not exit within %v", d)
+	}
+}
+
+// TestWatcher_FiresOnceAboveThreshold creates a watcher with a threshold
+// so low the current heap is guaranteed to exceed it, verifies OnExceed
+// fires, and verifies it fires only once even though the polling loop
+// would otherwise observe "over threshold" on every subsequent tick.
+func TestWatcher_FiresOnceAboveThreshold(t *testing.T) {
+	t.Parallel()
+
+	fired := make(chan struct{}, 8)
+	var count atomic.Int32
+	w := New(Config{
+		// 1 byte threshold: HeapInuse is always > 1B in a live program.
+		ThresholdBytes: 1,
+		PollInterval:   5 * time.Millisecond,
+		OnExceed: func() {
+			count.Add(1)
+			// non-blocking send: buffered channel so a pathological
+			// double-fire is observable via the count, not a deadlock.
+			select {
+			case fired <- struct{}{}:
+			default:
+			}
+		},
+	})
+
+	// Hold a live allocation so HeapInuse cannot collapse mid-test. We
+	// touch it after the wait below so the compiler/escape analysis keeps
+	// it on the heap for the duration of the test.
+	ballast := make([]byte, 1<<20)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go w.Start(ctx)
+
+	select {
+	case <-fired:
+	case <-time.After(2 * time.Second):
+		t.Fatalf("OnExceed was not invoked")
+	}
+
+	// Give the loop several more poll intervals to prove it does not fire
+	// again. The watcher should have already returned on first fire, so
+	// this also indirectly tests the loop-exit path.
+	time.Sleep(50 * time.Millisecond)
+	if got := count.Load(); got != 1 {
+		t.Fatalf("OnExceed fired %d times, want exactly 1", got)
+	}
+
+	// Ensure Start actually returned.
+	waitForDone(t, w, time.Second)
+
+	// Keep ballast live past the assertions.
+	runtime.KeepAlive(ballast)
+}
+
+// TestWatcher_DoesNotFireBelowThreshold runs the watcher for multiple poll
+// intervals with a threshold far above any reasonable process HeapInuse
+// and verifies OnExceed is never called.
+func TestWatcher_DoesNotFireBelowThreshold(t *testing.T) {
+	t.Parallel()
+
+	var count atomic.Int32
+	w := New(Config{
+		// 1 TiB: a Go test binary will not reach this.
+		ThresholdBytes: 1 << 40,
+		PollInterval:   5 * time.Millisecond,
+		OnExceed: func() {
+			count.Add(1)
+		},
+	})
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	go w.Start(ctx)
+
+	// Three polls' worth plus a safety margin.
+	time.Sleep(30 * time.Millisecond)
+	cancel()
+	waitForDone(t, w, time.Second)
+
+	if got := count.Load(); got != 0 {
+		t.Fatalf("OnExceed fired %d times below threshold, want 0", got)
+	}
+}
+
+// TestWatcher_StopsOnContextCancel verifies the watcher goroutine exits
+// promptly when the supplied context is cancelled, with no leak.
+func TestWatcher_StopsOnContextCancel(t *testing.T) {
+	t.Parallel()
+
+	w := New(Config{
+		ThresholdBytes: 1 << 40, // never fires
+		PollInterval:   10 * time.Millisecond,
+		// t.Error, not t.Fatalf: this callback runs on the watcher's
+		// goroutine, and FailNow/Fatal must only be called from the
+		// goroutine running the test function.
+		OnExceed: func() { t.Error("OnExceed unexpectedly fired") },
+	})
+
+	ctx, cancel := context.WithCancel(context.Background())
+	go w.Start(ctx)
+
+	// Let the goroutine park in its select.
+	time.Sleep(20 * time.Millisecond)
+	cancel()
+	waitForDone(t, w, time.Second)
+}
+
+// TestWatcher_DisabledWhenThresholdZero verifies a threshold of zero
+// disables the watcher: Start returns immediately and OnExceed is never
+// invoked, even after generous wait time.
+func TestWatcher_DisabledWhenThresholdZero(t *testing.T) {
+	t.Parallel()
+
+	var count atomic.Int32
+	w := New(Config{
+		ThresholdBytes: 0,
+		PollInterval:   1 * time.Millisecond,
+		OnExceed: func() {
+			count.Add(1)
+		},
+	})
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	go w.Start(ctx)
+	// When disabled, Start must return essentially immediately.
+	waitForDone(t, w, 200*time.Millisecond)
+
+	// Allocate something to grow the heap and confirm the (stopped)
+	// watcher does not observe it.
+	ballast := make([]byte, 4<<20)
+	time.Sleep(20 * time.Millisecond)
+	runtime.KeepAlive(ballast)
+
+	if got := count.Load(); got != 0 {
+		t.Fatalf("OnExceed fired %d times while disabled, want 0", got)
+	}
+}
diff --git a/main.go b/main.go
index 4306388a..5c587d77 100644
--- a/main.go
+++ b/main.go
@@ -4,6 +4,9 @@ import (
 	"context"
+	"errors"
 	"flag"
 	"log"
+	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
@@ -11,11 +14,13 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/bootjp/elastickv/adapter"
 	"github.com/bootjp/elastickv/distribution"
 	internalutil "github.com/bootjp/elastickv/internal"
+	"github.com/bootjp/elastickv/internal/memwatch"
 	internalraftadmin "github.com/bootjp/elastickv/internal/raftadmin"
 	"github.com/bootjp/elastickv/internal/raftengine"
 	etcdraftengine "github.com/bootjp/elastickv/internal/raftengine/etcd"
@@ -95,14 +100,96 @@ var (
 	raftDynamoMap = flag.String("raftDynamoMap", "", "Map of Raft address to DynamoDB address (raftAddr=dynamoAddr,...)")
 )
 
+// memoryPressureExit is set to true by the memwatch OnExceed callback to
+// signal that the subsequent graceful shutdown was triggered by user-space
+// OOM avoidance rather than an ordinary SIGTERM. The process exits with a
+// distinct non-zero code (exitCodeMemoryPressure) so operators reading
+// logs can distinguish this case from a crash or an ordinary stop.
+var memoryPressureExit atomic.Bool
+
+// exitCodeMemoryPressure is reported by main when memwatch triggered the
+// shutdown. It is non-zero so supervisors see a non-success exit, but
+// distinct from log.Fatalf's 1 and from os.Exit(1) in the other binaries
+// so log scraping can tell them apart.
+const exitCodeMemoryPressure = 2
+
+// memoryShutdownThresholdEnvVar configures the heap-inuse ceiling at
+// which memwatch triggers a graceful shutdown. Empty or "0" disables the
+// watchdog (the default; existing operators see no behaviour change).
+const memoryShutdownThresholdEnvVar = "ELASTICKV_MEMORY_SHUTDOWN_THRESHOLD_MB"
+
+// memoryShutdownPollIntervalEnvVar overrides memwatch's default poll
+// cadence. Accepts any time.ParseDuration string. Invalid values log a
+// warning and fall through to the default.
+const memoryShutdownPollIntervalEnvVar = "ELASTICKV_MEMORY_SHUTDOWN_POLL_INTERVAL"
+
+const bytesPerMiB = 1024 * 1024
+
 func main() {
 	flag.Parse()
 
-	if err := run(); err != nil {
+	err := run()
+	if memoryPressureExit.Load() {
+		// memwatch fired: surface exit code 2 regardless of whether run()
+		// returned a nil or an error (cancel() can cause in-flight
+		// listeners to return spurious errors during shutdown). Still
+		// log any residual error so a secondary failure during the
+		// graceful shutdown is visible in logs rather than swallowed.
+		if err != nil && !errors.Is(err, context.Canceled) {
+			slog.Warn("shutdown error after memory pressure", "error", err)
+		}
+		os.Exit(exitCodeMemoryPressure)
+	}
+	if err != nil {
 		log.Fatalf("%v", err)
 	}
 }
 
+// memwatchConfigFromEnv resolves the memwatch Config from environment
+// variables. It returns (cfg, true) when the watcher should run, or
+// (_, false) when the operator has not opted in (the default). Errors in
+// the optional poll-interval override are logged and ignored so a typo
+// cannot take the process down.
+func memwatchConfigFromEnv() (memwatch.Config, bool) {
+	raw := strings.TrimSpace(os.Getenv(memoryShutdownThresholdEnvVar))
+	if raw == "" {
+		return memwatch.Config{}, false
+	}
+	mb, err := strconv.ParseUint(raw, 10, 64)
+	if err != nil {
+		slog.Warn("invalid "+memoryShutdownThresholdEnvVar+"; watcher disabled",
+			"value", raw, "error", err)
+		return memwatch.Config{}, false
+	}
+	if mb == 0 {
+		return memwatch.Config{}, false
+	}
+	// Guard against mb * bytesPerMiB wrapping past math.MaxUint64. The
+	// value has no real use above this ceiling (the host does not have
+	// exabytes of RAM), and a wrapped value would set an absurdly low
+	// threshold that fires immediately.
+	if mb > math.MaxUint64/bytesPerMiB {
+		slog.Warn("value for "+memoryShutdownThresholdEnvVar+" would overflow uint64; watcher disabled",
+			"value_mb", mb)
+		return memwatch.Config{}, false
+	}
+
+	cfg := memwatch.Config{
+		ThresholdBytes: mb * bytesPerMiB,
+	}
+	cfg.PollInterval = memwatch.DefaultPollInterval
+	if rawInterval := strings.TrimSpace(os.Getenv(memoryShutdownPollIntervalEnvVar)); rawInterval != "" {
+		d, err := time.ParseDuration(rawInterval)
+		if err != nil || d <= 0 {
+			slog.Warn("invalid "+memoryShutdownPollIntervalEnvVar+"; using default",
+				"value", rawInterval, "error", err)
+		} else {
+			cfg.PollInterval = d
+		}
+	}
+	return cfg, true
+}
+
 func run() error {
 	cfg, engineType, bootstrapServers, bootstrap, err := resolveRuntimeInputs()
 	if err != nil {
@@ -173,6 +260,7 @@ func run() error {
 	eg.Go(func() error {
 		return runDistributionCatalogWatcher(runCtx, distCatalog, cfg.engine)
 	})
+	startMemoryWatchdog(runCtx, eg, cancel)
 	distServer := adapter.NewDistributionServer(
 		cfg.engine,
 		distCatalog,
@@ -516,6 +604,33 @@ func dispatchMonitorSources(runtimes []*raftGroupRuntime) []monitoring.DispatchSource {
 	return out
 }
 
+// startMemoryWatchdog optionally starts the memwatch goroutine. The
+// watcher is off by default; it is enabled only when the operator sets
+// ELASTICKV_MEMORY_SHUTDOWN_THRESHOLD_MB. On threshold crossing the
+// callback flips the memoryPressureExit sentinel and cancels the root
+// context, routing through the exact same shutdown path SIGTERM would
+// use (errgroup unwinds, CleanupStack runs, WAL is synced). We do NOT
+// send a signal, call os.Exit, or touch the raft engine directly here.
+func startMemoryWatchdog(ctx context.Context, eg *errgroup.Group, cancel context.CancelFunc) {
+	cfg, enabled := memwatchConfigFromEnv()
+	if !enabled {
+		return
+	}
+	cfg.OnExceed = func() {
+		memoryPressureExit.Store(true)
+		cancel()
+	}
+	w := memwatch.New(cfg)
+	slog.Info("memory watchdog enabled",
+		"threshold_bytes", cfg.ThresholdBytes,
+		"poll_interval", cfg.PollInterval,
+	)
+	eg.Go(func() error {
+		w.Start(ctx)
+		return nil
+	})
+}
+
 // startMonitoringCollectors wires up the per-tick Prometheus
 // collectors (raft dispatch, Pebble LSM, store-layer OCC conflicts)
 // on top of the running raft runtimes. Kept separate from run() so