From 0176f95bcc1b38a4bd691907d6b8c8a8349b2b21 Mon Sep 17 00:00:00 2001
From: Simon Ferquel's Clanker
Date: Mon, 20 Apr 2026 11:36:59 +0000
Subject: [PATCH] fix(#2457): retry MCP toolsets after tool calls within the same turn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When an MCP server is configured but unavailable at session start, the agent
now retries automatically after every tool-call batch — within the same user
turn — without requiring a new user message.

## Core changes

### MCP double-watcher race (pkg/tools/mcp/mcp.go)

Added watcherAlive bool to Toolset. Toolset.Start() only spawns
go watchConnection(...) when !watcherAlive. The goroutine clears the flag on
all exit paths via defer. This prevents reprobe() from spawning a second
watcher while an existing one is mid-backoff (ts.started==false but goroutine
alive), which would cause racing doStart() calls and unsafe close/recreate of
ts.restarted.

### Failure deduplication + recovery notices (pkg/tools/startable.go)

ShouldReportFailure() returns true exactly once per failure streak,
suppressing repeated 'start failed' warnings on every retry.
ConsumeRecovery() returns true exactly once when a previously-failed toolset
successfully starts, triggering a 'now available' warning. Both surface via
WarningEvent -> notification.WarningCmd() (persistent TUI notifications that
stay until dismissed).

### Reprobe after each tool-call batch (pkg/runtime/loop.go)

reprobe() is called after every tool-call batch. It re-runs
ensureToolSetsAreStarted() without emitting MCPInitStarted/Finished events
(no TUI spinner flicker), emits any pending warnings, and emits a ToolsetInfo
event when new tools appear. The updated tool list is picked up by the
top-of-loop getTools() on the next iteration, so the model sees new tools in
its very next response within the same user turn.

### TUI (pkg/agent/agent.go, pkg/runtime/loop.go, pkg/runtime/event.go) DrainWarnings() now includes both failure and recovery messages. WarningEvent used for all toolset lifecycle notifications. ## Tests - pkg/tools/startable_test.go: ShouldReportFailure/ConsumeRecovery behaviour (one warning per streak, recovery fires once, Stop resets) - pkg/agent/agent_test.go: TestAgentReProbeEmitsWarningThenNotice, TestAgentNoDuplicateStartWarnings - pkg/runtime/runtime_test.go: TestReprobe_NewToolsAvailableAfterToolCall, TestReprobe_NoChangeMeansNoExtraEvents Fixes #2457 Assisted-By: docker-agent --- pkg/agent/agent.go | 19 +++- pkg/agent/agent_test.go | 86 ++++++++++++++ pkg/runtime/loop.go | 94 +++++++++++++--- pkg/runtime/runtime_test.go | 219 +++++++++++++++++++++++++++++++++++- pkg/tools/mcp/mcp.go | 27 +++-- pkg/tools/startable.go | 61 +++++++++- pkg/tools/startable_test.go | 121 ++++++++++++++++++++ 7 files changed, 596 insertions(+), 31 deletions(-) diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 8dfd60619..ae6d82179 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -260,11 +260,24 @@ func (a *Agent) ToolSets() []tools.ToolSet { func (a *Agent) ensureToolSetsAreStarted(ctx context.Context) { for _, toolSet := range a.toolsets { if err := toolSet.Start(ctx); err != nil { - desc := tools.DescribeToolSet(toolSet) - slog.Warn("Toolset start failed; skipping", "agent", a.Name(), "toolset", desc, "error", err) - a.addToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) + // Only warn on the first failure in a streak; suppress duplicate + // warnings for subsequent retries that also fail. 
+ if toolSet.ShouldReportFailure() { + desc := tools.DescribeToolSet(toolSet) + slog.Warn("Toolset start failed; will retry on next turn", "agent", a.Name(), "toolset", desc, "error", err) + a.addToolWarning(fmt.Sprintf("%s start failed: %v", desc, err)) + } else { + desc := tools.DescribeToolSet(toolSet) + slog.Debug("Toolset still unavailable; retrying next turn", "agent", a.Name(), "toolset", desc, "error", err) + } continue } + // Emit a one-time notice when a previously-failed toolset recovers. + if toolSet.ConsumeRecovery() { + desc := tools.DescribeToolSet(toolSet) + slog.Info("Toolset now available", "agent", a.Name(), "toolset", desc) + a.addToolWarning(desc + " is now available") + } } } diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index ee3e221c0..bbd4d292d 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -44,6 +44,34 @@ func (s *stubToolSet) Tools(context.Context) ([]tools.Tool, error) { return s.tools, nil } +// flappyToolSet is a ToolSet+Startable that returns a scripted sequence of +// errors from Start(). nil in the sequence means success. 
+type flappyToolSet struct { + errs []error + callIdx int + stubs []tools.Tool +} + +var ( + _ tools.ToolSet = (*flappyToolSet)(nil) + _ tools.Startable = (*flappyToolSet)(nil) +) + +func (f *flappyToolSet) Start(_ context.Context) error { + if f.callIdx >= len(f.errs) { + return nil + } + err := f.errs[f.callIdx] + f.callIdx++ + return err +} + +func (f *flappyToolSet) Stop(_ context.Context) error { return nil } + +func (f *flappyToolSet) Tools(_ context.Context) ([]tools.Tool, error) { + return f.stubs, nil +} + func TestAgentTools(t *testing.T) { tests := []struct { name string @@ -210,3 +238,61 @@ func TestModelOverride_ConcurrentAccess(t *testing.T) { <-done // If we got here without a race condition panic, the test passes } + +// TestAgentReProbeEmitsWarningThenNotice verifies the full retry lifecycle: +// turn 1 fails → warning emitted; turn 2 succeeds → notice emitted; tools available. +func TestAgentReProbeEmitsWarningThenNotice(t *testing.T) { + t.Parallel() + + errBoom := errors.New("server unavailable") + stub := &flappyToolSet{ + errs: []error{errBoom, nil}, + stubs: []tools.Tool{{Name: "mcp_ping", Parameters: map[string]any{}}}, + } + a := New("root", "test", WithToolSets(stub)) + + // Turn 1: start fails → 1 warning, 0 tools. + got, err := a.Tools(t.Context()) + require.NoError(t, err) + assert.Empty(t, got, "turn 1: no tools while toolset is unavailable") + warnings := a.DrainWarnings() + require.Len(t, warnings, 1, "turn 1: exactly one warning expected") + assert.Contains(t, warnings[0], "start failed") + + // Turn 2: start succeeds → 1 recovery warning, tools available. 
+ got, err = a.Tools(t.Context()) + require.NoError(t, err) + assert.Len(t, got, 1, "turn 2: tool should be available after recovery") + recovery := a.DrainWarnings() + require.Len(t, recovery, 1, "turn 2: exactly one recovery warning expected") + assert.Contains(t, recovery[0], "now available", "turn 2: recovery warning must mention availability") +} + +// TestAgentNoDuplicateStartWarnings verifies that repeated failures generate +// only one warning (on the first failure), not one per retry. +func TestAgentNoDuplicateStartWarnings(t *testing.T) { + t.Parallel() + + errBoom := errors.New("server unavailable") + stub := &flappyToolSet{ + errs: []error{errBoom, errBoom, errBoom}, + stubs: []tools.Tool{{Name: "mcp_ping", Parameters: map[string]any{}}}, + } + a := New("root", "test", WithToolSets(stub)) + + // Turn 1: first failure → warning. + _, err := a.Tools(t.Context()) + require.NoError(t, err) + warnings := a.DrainWarnings() + require.Len(t, warnings, 1, "turn 1: exactly one warning on first failure") + + // Turn 2: repeated failure → no new warning. + _, err = a.Tools(t.Context()) + require.NoError(t, err) + assert.Empty(t, a.DrainWarnings(), "turn 2: no duplicate warning on repeated failure") + + // Turn 3: still failing → still no new warning. 
+ _, err = a.Tools(t.Context()) + require.NoError(t, err) + assert.Empty(t, a.DrainWarnings(), "turn 3: no duplicate warning on repeated failure") +} diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index f76921d26..bebcc0d7f 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -100,7 +100,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c r.emitAgentWarnings(a, chanSend(events)) r.configureToolsetHandlers(a, events) - agentTools, err := r.getTools(ctx, a, sessionSpan, events) + agentTools, err := r.getTools(ctx, a, sessionSpan, events, true) if err != nil { events <- Error(fmt.Sprintf("failed to get tools: %v", err)) return @@ -163,7 +163,7 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c r.emitAgentWarnings(a, chanSend(events)) r.configureToolsetHandlers(a, events) - agentTools, err := r.getTools(ctx, a, sessionSpan, events) + agentTools, err := r.getTools(ctx, a, sessionSpan, events, true) if err != nil { events <- Error(fmt.Sprintf("failed to get tools: %v", err)) return @@ -382,6 +382,20 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c r.processToolCalls(ctx, sess, res.Calls, agentTools, events) + // Re-probe toolsets after tool calls: an install/setup tool call may + // have made a previously-unavailable LSP or MCP connectable. reprobe() + // calls ensureToolSetsAreStarted, emits recovery notices, and updates + // the TUI tool-count immediately. + // + // The new tools are picked up by the next iteration's getTools() call + // at the top of this loop, so the model sees them on its very next + // response — within the same user turn, without requiring a new user + // message. reprobe's return value is intentionally discarded here; + // the top-of-loop getTools() is the authoritative source. 
+ if len(res.Calls) > 0 { + r.reprobe(ctx, sess, a, agentTools, sessionSpan, events) + } + // Check for degenerate tool call loops if loopDetector.record(res.Calls) { toolName := "unknown" @@ -575,17 +589,14 @@ func (r *LocalRuntime) compactIfNeeded( r.Summarize(ctx, sess, "", events) } -// getTools executes tool retrieval with automatic OAuth handling -func (r *LocalRuntime) getTools(ctx context.Context, a *agent.Agent, sessionSpan trace.Span, events chan Event) ([]tools.Tool, error) { - shouldEmitMCPInit := len(a.ToolSets()) > 0 - if shouldEmitMCPInit { +// getTools executes tool retrieval with automatic OAuth handling. +// emitLifecycleEvents controls whether MCPInitStarted/Finished are emitted; +// pass false when calling from reprobe to avoid spurious TUI spinner flicker. +func (r *LocalRuntime) getTools(ctx context.Context, a *agent.Agent, sessionSpan trace.Span, events chan Event, emitLifecycleEvents bool) ([]tools.Tool, error) { + if emitLifecycleEvents && len(a.ToolSets()) > 0 { events <- MCPInitStarted(a.Name()) + defer func() { events <- MCPInitFinished(a.Name()) }() } - defer func() { - if shouldEmitMCPInit { - events <- MCPInitFinished(a.Name()) - } - }() agentTools, err := a.Tools(ctx) if err != nil { @@ -616,15 +627,15 @@ func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Even } } -// emitAgentWarnings drains and emits any agent initialization warnings. +// emitAgentWarnings drains and emits any pending toolset warnings as persistent +// TUI notifications. Both start failures and recovery notices are emitted as +// warnings so they remain visible until the user dismisses them. 
func (r *LocalRuntime) emitAgentWarnings(a *agent.Agent, send func(Event)) { warnings := a.DrainWarnings() - if len(warnings) == 0 { - return + if len(warnings) > 0 { + slog.Warn("Tool setup partially failed; continuing", "agent", a.Name(), "warnings", warnings) + send(Warning(formatToolWarning(a, warnings), a.Name())) } - - slog.Warn("Tool setup partially failed; continuing", "agent", a.Name(), "warnings", warnings) - send(Warning(formatToolWarning(a, warnings), a.Name())) } func formatToolWarning(a *agent.Agent, warnings []string) string { @@ -669,3 +680,52 @@ func chanSend(ch chan Event) func(Event) { } } } + +// reprobe re-runs ensureToolSetsAreStarted after a batch of tool calls. +// If new tools became available (by name-set diff), it emits recovery notices +// and a ToolsetInfo event to update the TUI immediately. The new tools will be +// picked up by the next iteration's getTools() call at the top of the loop. +// +// reprobe deliberately does NOT return the new tool list: the top-of-loop +// getTools() is the single authoritative source for agentTools each iteration. +func (r *LocalRuntime) reprobe( + ctx context.Context, + sess *session.Session, + a *agent.Agent, + currentTools []tools.Tool, + sessionSpan trace.Span, + events chan Event, +) { + updated, err := r.getTools(ctx, a, sessionSpan, events, false) + if err != nil { + slog.Warn("reprobe: getTools failed", "agent", a.Name(), "error", err) + return + } + updated = filterExcludedTools(updated, sess.ExcludedTools) + + // Emit any pending warnings/notices that getTools just generated. + r.emitAgentWarnings(a, chanSend(events)) + + // Compute added tools by comparing name-sets (not just counts), so we + // correctly handle a toolset that replaced one tool with another. 
+ prev := make(map[string]struct{}, len(currentTools)) + for _, t := range currentTools { + prev[t.Name] = struct{}{} + } + var added []string + for _, t := range updated { + if _, exists := prev[t.Name]; !exists { + added = append(added, t.Name) + } + } + + if len(added) == 0 { + return + } + + slog.Info("New tools available after toolset re-probe", + "agent", a.Name(), "added", added) + + // Emit updated tool count to the TUI immediately. + chanSend(events)(ToolsetInfo(len(updated), false, a.Name())) +} diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index a4590ac8c..c1adbc84a 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -135,6 +135,17 @@ func (b *streamBuilder) AddStopWithUsage(input, output int64) *streamBuilder { return b } +func (b *streamBuilder) AddToolCallStopWithUsage(input, output int64) *streamBuilder { + b.responses = append(b.responses, chat.MessageStreamResponse{ + Choices: []chat.MessageStreamChoice{{ + Index: 0, + FinishReason: chat.FinishReasonToolCalls, + }}, + Usage: &chat.Usage{InputTokens: input, OutputTokens: output}, + }) + return b +} + func (b *streamBuilder) Build() *mockStream { return &mockStream{responses: b.responses} } type mockProvider struct { @@ -763,7 +774,7 @@ func TestGetTools_WarningHandling(t *testing.T) { sessionSpan := trace.SpanFromContext(t.Context()) // First call - tools1, err := rt.getTools(t.Context(), root, sessionSpan, events) + tools1, err := rt.getTools(t.Context(), root, sessionSpan, events, true) require.NoError(t, err) require.Len(t, tools1, tt.wantToolCount) @@ -1981,3 +1992,209 @@ func TestRunStream_EmptyMessages_SendUserMessage(t *testing.T) { } require.NotEmpty(t, events) } + +// recordingProvider wraps a sequence of mock streams and records the tools +// passed to each CreateChatCompletionStream call. 
+type recordingProvider struct {
+	id      string        // value returned by ID()
+	streams []*mockStream // scripted streams, handed out one per call in order
+	callIdx int           // index of the next entry in streams to hand out
+
+	mu            sync.Mutex     // guards callIdx and recordedCalls
+	recordedCalls [][]tools.Tool // tools passed on each call
+}
+
+func (r *recordingProvider) ID() string { return r.id }
+
+func (r *recordingProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, toolList []tools.Tool) (chat.MessageStream, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// Record a copy of the tool list passed on this call (not just the names).
+	r.recordedCalls = append(r.recordedCalls, append([]tools.Tool{}, toolList...))
+
+	if r.callIdx >= len(r.streams) { // script exhausted: answer with a bare stop
+		return newStreamBuilder().AddStopWithUsage(1, 1).Build(), nil
+	}
+	s := r.streams[r.callIdx]
+	r.callIdx++
+	return s, nil
+}
+
+func (r *recordingProvider) BaseConfig() base.Config { return base.Config{} }
+func (r *recordingProvider) MaxTokens() int          { return 0 }
+
+// flappyRuntimeToolSet is a ToolSet+Startable whose Start() fails for the
+// first failUntil calls and succeeds on every later call; Tools() returns
+// nothing until a Start() has succeeded, then returns exactly newTool.
+type flappyRuntimeToolSet struct {
+	mu        sync.Mutex // guards attempts
+	attempts  int        // number of Start() calls made so far
+	failUntil int        // fail while attempts <= failUntil
+	newTool   tools.Tool // the single tool exposed once Start() has succeeded
+}
+
+func (f *flappyRuntimeToolSet) Start(_ context.Context) error {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	f.attempts++ // counted before the check, so call number failUntil+1 is the first success
+	if f.attempts <= f.failUntil {
+		return errors.New("server unavailable")
+	}
+	return nil
+}
+
+func (f *flappyRuntimeToolSet) Stop(_ context.Context) error { return nil }
+
+func (f *flappyRuntimeToolSet) Tools(_ context.Context) ([]tools.Tool, error) {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+	if f.attempts <= f.failUntil { // no successful Start() yet
+		return nil, nil // expose no tools, and report no error
+	}
+	return []tools.Tool{f.newTool}, nil
+}
+
+// TestReprobe_NewToolsAvailableAfterToolCall verifies that when a toolset
+// fails to start initially but succeeds after a tool call runs (simulating
+// an install step), the reprobe mechanism surfaces the new tool to the model
+// on its very next response — within the same user turn.
+func TestReprobe_NewToolsAvailableAfterToolCall(t *testing.T) { + t.Parallel() + + mcpTool := tools.Tool{Name: "mcp_hello", Parameters: map[string]any{}} + installTool := tools.Tool{ + Name: "install_mcp", + Parameters: map[string]any{}, + Handler: func(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { + return tools.ResultSuccess("installed"), nil + }, + } + + // Turn 1: model calls install_mcp and keeps going (FinishReasonToolCall → loop continues). + // Turn 2: model sees mcp_hello in its tool list and stops. + turn1 := newStreamBuilder(). + AddToolCallName("call_1", "install_mcp"). + AddToolCallArguments("call_1", `{}`). + AddToolCallStopWithUsage(5, 5). + Build() + turn2 := newStreamBuilder(). + AddContent("MCP is now available"). + AddStopWithUsage(3, 3). + Build() + + flappy := &flappyRuntimeToolSet{newTool: mcpTool, failUntil: 2} + installTS := newStubToolSet(nil, []tools.Tool{installTool}, nil) + + prov := &recordingProvider{ + id: "test/mock-model", + streams: []*mockStream{turn1, turn2}, + } + + root := agent.New("root", "test", + agent.WithModel(prov), + agent.WithToolSets(installTS, flappy), + ) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + rt.registerDefaultTools() + + sess := session.New(session.WithUserMessage("Install and use MCP")) + sess.Title = "reprobe test" + sess.ToolsApproved = true + + evCh := rt.RunStream(t.Context(), sess) + var events []Event + for ev := range evCh { + events = append(events, ev) + } + + prov.mu.Lock() + defer prov.mu.Unlock() + + require.GreaterOrEqual(t, len(prov.recordedCalls), 2, "expected at least 2 model calls") + + // First model call: only install_mcp available (mcp_hello not yet). 
+ call1Names := toolNames(prov.recordedCalls[0]) + assert.Contains(t, call1Names, "install_mcp", "turn 1 must include install_mcp") + assert.NotContains(t, call1Names, "mcp_hello", "turn 1 must NOT include mcp_hello before install") + + // Second model call: mcp_hello must be visible. + call2Names := toolNames(prov.recordedCalls[1]) + assert.Contains(t, call2Names, "mcp_hello", "turn 2 must include mcp_hello after reprobe") + + // A ToolsetInfo event with the new count must have been emitted during reprobe. + var toolsetInfoCounts []int + for _, ev := range events { + if ti, ok := ev.(*ToolsetInfoEvent); ok { + toolsetInfoCounts = append(toolsetInfoCounts, ti.AvailableTools) + } + } + assert.Contains(t, toolsetInfoCounts, 2, "ToolsetInfo with count=2 expected after reprobe") +} + +// TestReprobe_NoChangeMeansNoExtraEvents verifies that reprobe is a no-op +// (no extra ToolsetInfo events, no panics) when no new tools appear after +// a tool call. +func TestReprobe_NoChangeMeansNoExtraEvents(t *testing.T) { + t.Parallel() + + staticTool := tools.Tool{ + Name: "do_thing", + Parameters: map[string]any{}, + Handler: func(_ context.Context, _ tools.ToolCall) (*tools.ToolCallResult, error) { + return tools.ResultSuccess("done"), nil + }, + } + + stream1 := newStreamBuilder(). + AddToolCallName("c1", "do_thing"). + AddToolCallArguments("c1", `{}`). + AddStopWithUsage(5, 5). 
+ Build() + + prov := &recordingProvider{ + id: "test/mock-model", + streams: []*mockStream{stream1}, + } + + ts := newStubToolSet(nil, []tools.Tool{staticTool}, nil) + root := agent.New("root", "test", agent.WithModel(prov), agent.WithToolSets(ts)) + tm := team.New(team.WithAgents(root)) + + rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{})) + require.NoError(t, err) + rt.registerDefaultTools() + + sess := session.New(session.WithUserMessage("Do the thing")) + sess.Title = "no-change reprobe test" + sess.ToolsApproved = true + + evCh := rt.RunStream(t.Context(), sess) + var events []Event + for ev := range evCh { + events = append(events, ev) + } + + // Count ToolsetInfo events — reprobe should NOT emit an extra one. + var counts []int + for _, ev := range events { + if ti, ok := ev.(*ToolsetInfoEvent); ok { + counts = append(counts, ti.AvailableTools) + } + } + // All counts should be 1 (the static tool). + for _, c := range counts { + assert.Equal(t, 1, c, "unexpected ToolsetInfo count — reprobe emitted extra event when tools unchanged") + } +} + +func toolNames(ts []tools.Tool) []string { + names := make([]string, len(ts)) + for i, t := range ts { + names[i] = t.Name + } + return names +} diff --git a/pkg/tools/mcp/mcp.go b/pkg/tools/mcp/mcp.go index d8a27fa2b..db971988d 100644 --- a/pkg/tools/mcp/mcp.go +++ b/pkg/tools/mcp/mcp.go @@ -51,6 +51,7 @@ type Toolset struct { mu sync.Mutex started bool stopping bool // true when Stop() has been called + watcherAlive bool // true while the watchConnection goroutine is running // Cached tools and prompts, invalidated via MCP notifications. // cacheGen is bumped on each invalidation so that a concurrent @@ -178,12 +179,16 @@ func (ts *Toolset) Start(ctx context.Context) error { ts.started = true - // Spawn the connection watcher only on the initial Start. - // Restarts from within watchConnection call doStart directly - // and must NOT spawn an additional watcher goroutine. 
- // Use WithoutCancel so the watcher outlives the caller's context; - // the only way to stop it is via Stop() setting ts.stopping. - go ts.watchConnection(context.WithoutCancel(ctx)) + // Spawn the connection watcher only when no watcher is already running. + // A watcher goroutine survives across restarts (it loops inside + // watchConnection); reprobe may call Start() while that goroutine is + // mid-restart with started==false, and we must not spawn a second one. + if !ts.watcherAlive { + ts.watcherAlive = true + // Use WithoutCancel so the watcher outlives the caller's context; + // the only way to stop it is via Stop() setting ts.stopping. + go ts.watchConnection(context.WithoutCancel(ctx)) + } return nil } @@ -281,9 +286,15 @@ func (ts *Toolset) doStart(ctx context.Context) error { // watchConnection monitors the MCP server connection and auto-restarts it // if the server dies unexpectedly (i.e. we didn't call Stop()). -// Only one watchConnection goroutine exists per Toolset; it is spawned by -// Start() and loops across restarts without spawning additional goroutines. +// Exactly one watchConnection goroutine exists per Toolset while ts.watcherAlive +// is true; it is spawned by Start() and cleared on exit. func (ts *Toolset) watchConnection(ctx context.Context) { + defer func() { + ts.mu.Lock() + ts.watcherAlive = false + ts.mu.Unlock() + }() + for { err := ts.mcpClient.Wait() diff --git a/pkg/tools/startable.go b/pkg/tools/startable.go index 6f1609aa7..e8e4ee45b 100644 --- a/pkg/tools/startable.go +++ b/pkg/tools/startable.go @@ -30,11 +30,24 @@ func DescribeToolSet(ts ToolSet) string { // StartableToolSet wraps a ToolSet with lazy, single-flight start semantics. // This is the canonical way to manage toolset lifecycle. +// +// Failure and recovery tracking: +// - freshFailure is set to true on the first Start() failure in a streak +// (i.e. when hasEverFailed transitions false→true). 
It is consumed by +// ShouldReportFailure() which returns true exactly once per streak. +// - hasEverFailed stays true for the duration of the failure streak. +// - pendingRecovery is set to true on the first successful Start() after a +// failure streak. It is consumed by ConsumeRecovery(). +// - ConsumeRecovery() also resets hasEverFailed, so the next failure streak +// generates a fresh warning. type StartableToolSet struct { ToolSet - mu sync.Mutex - started bool + mu sync.Mutex + started bool + hasEverFailed bool // true for the duration of a failure streak + freshFailure bool // true only for the first failure in a streak; consumed by ShouldReportFailure + pendingRecovery bool // true when a recovery notice is pending; consumed by ConsumeRecovery } // NewStartable wraps a ToolSet for lazy initialization. @@ -64,9 +77,20 @@ func (s *StartableToolSet) Start(ctx context.Context) error { if startable, ok := As[Startable](s.ToolSet); ok { if err := startable.Start(ctx); err != nil { + // Only set freshFailure on the very first failure in a streak so + // that repeated failed retries don't each emit a new warning. + if !s.hasEverFailed { + s.hasEverFailed = true + s.freshFailure = true + } return err } } + + // Successful start: if this followed a failure streak, signal recovery. + if s.hasEverFailed { + s.pendingRecovery = true + } s.started = true return nil } @@ -78,12 +102,45 @@ func (s *StartableToolSet) Stop(ctx context.Context) error { defer s.mu.Unlock() s.started = false + s.hasEverFailed = false + s.freshFailure = false + s.pendingRecovery = false if startable, ok := As[Startable](s.ToolSet); ok { return startable.Stop(ctx) } return nil } +// ShouldReportFailure returns true the first time Start() fails in a new +// failure streak — i.e. when hasEverFailed transitions from false to true. +// It returns false for all subsequent failures in the same streak, preventing +// repeated "start failed" warnings from flooding the user. 
It is safe to call +// even when Start() did not return an error (it will return false). +func (s *StartableToolSet) ShouldReportFailure() bool { + s.mu.Lock() + defer s.mu.Unlock() + if !s.freshFailure { + return false + } + s.freshFailure = false + return true +} + +// ConsumeRecovery returns true exactly once after a Start() that succeeded +// following a previously-reported failure streak. Calling it also resets +// hasEverFailed and freshFailure so that a future failure generates a fresh warning. +func (s *StartableToolSet) ConsumeRecovery() bool { + s.mu.Lock() + defer s.mu.Unlock() + if !s.pendingRecovery { + return false + } + s.pendingRecovery = false + s.hasEverFailed = false + s.freshFailure = false + return true +} + // Unwrap returns the underlying ToolSet. func (s *StartableToolSet) Unwrap() ToolSet { return s.ToolSet diff --git a/pkg/tools/startable_test.go b/pkg/tools/startable_test.go index 565f44e24..b796459bf 100644 --- a/pkg/tools/startable_test.go +++ b/pkg/tools/startable_test.go @@ -2,6 +2,7 @@ package tools_test import ( "context" + "errors" "testing" "gotest.tools/v3/assert" @@ -21,6 +22,34 @@ type stubToolSet struct{} func (s *stubToolSet) Tools(context.Context) ([]tools.Tool, error) { return nil, nil } +// flappyToolSet implements ToolSet + Startable with a scripted sequence of errors. +// Each call to Start() consumes the next error from errs; nil means success. 
+type flappyToolSet struct { + errs []error + callIdx int + startups int // number of successful Start() calls +} + +func (f *flappyToolSet) Tools(context.Context) ([]tools.Tool, error) { + return []tools.Tool{{Name: "flappy_tool"}}, nil +} + +func (f *flappyToolSet) Start(_ context.Context) error { + if f.callIdx < len(f.errs) { + err := f.errs[f.callIdx] + f.callIdx++ + if err != nil { + return err + } + } + f.startups++ + return nil +} + +func (f *flappyToolSet) Stop(_ context.Context) error { + return nil +} + func TestDescribeToolSet_UsesDescriber(t *testing.T) { t.Parallel() @@ -57,3 +86,95 @@ func TestDescribeToolSet_UnwrapsStartableAndFallsBackToTypeName(t *testing.T) { wrapped := tools.NewStartable(inner) assert.Check(t, is.Equal(tools.DescribeToolSet(wrapped), "*tools_test.stubToolSet")) } + +// TestStartableToolSet_ShouldReportFailure_OncePerStreak verifies that +// ShouldReportFailure returns true exactly once per failure streak, +// suppressing duplicate warnings on repeated retries. +func TestStartableToolSet_ShouldReportFailure_OncePerStreak(t *testing.T) { + t.Parallel() + + errBoom := errors.New("boom") + f := &flappyToolSet{errs: []error{errBoom, errBoom, nil}} + s := tools.NewStartable(f) + + // Turn 1: first failure — should report. + err := s.Start(t.Context()) + assert.Check(t, err != nil, "expected error on turn 1") + assert.Check(t, is.Equal(s.ShouldReportFailure(), true), "turn 1: first failure should be reported") + assert.Check(t, is.Equal(s.ShouldReportFailure(), false), "turn 1: second call must return false") + + // Turn 2: second failure in same streak — must NOT report again. + err = s.Start(t.Context()) + assert.Check(t, err != nil, "expected error on turn 2") + assert.Check(t, is.Equal(s.ShouldReportFailure(), false), "turn 2: duplicate failure must not report") + + // Turn 3: success — ConsumeRecovery fires exactly once. 
+ err = s.Start(t.Context()) + assert.Check(t, err == nil, "expected success on turn 3") + assert.Check(t, is.Equal(s.ConsumeRecovery(), true), "turn 3: recovery must be signalled") + assert.Check(t, is.Equal(s.ConsumeRecovery(), false), "turn 3: recovery must fire only once") +} + +// TestStartableToolSet_NoRecoveryWithoutPriorFailure verifies that +// ConsumeRecovery returns false when Start succeeds on the very first try. +func TestStartableToolSet_NoRecoveryWithoutPriorFailure(t *testing.T) { + t.Parallel() + + f := &flappyToolSet{errs: []error{nil}} + s := tools.NewStartable(f) + + err := s.Start(t.Context()) + assert.Check(t, err == nil) + assert.Check(t, is.Equal(s.ShouldReportFailure(), false), "no failure: ShouldReportFailure must be false") + assert.Check(t, is.Equal(s.ConsumeRecovery(), false), "no prior failure: ConsumeRecovery must be false") +} + +// TestStartableToolSet_RecoveryThenFailureWarnsAgain verifies that after a full +// fail→report→recover cycle, a subsequent new failure generates a fresh warning. +func TestStartableToolSet_RecoveryThenFailureWarnsAgain(t *testing.T) { + t.Parallel() + + errBoom := errors.New("boom") + f := &flappyToolSet{errs: []error{errBoom, nil, errBoom}} + s := tools.NewStartable(f) + + // Cycle 1: fail then recover. + err := s.Start(t.Context()) + assert.Check(t, err != nil) + assert.Check(t, is.Equal(s.ShouldReportFailure(), true)) + + err = s.Start(t.Context()) + assert.Check(t, err == nil) + assert.Check(t, is.Equal(s.ConsumeRecovery(), true)) + + // Now stop so we can start again (resets started flag). + assert.Check(t, s.Stop(t.Context()) == nil) + + // Cycle 2: new failure — must warn again. + err = s.Start(t.Context()) + assert.Check(t, err != nil) + assert.Check(t, is.Equal(s.ShouldReportFailure(), true), "fresh failure after recovery must warn") +} + +// TestStartableToolSet_StopResetsFailureState verifies that after a failure streak, +// an explicit Stop() clears all tracking so the next failure warns again. 
+func TestStartableToolSet_StopResetsFailureState(t *testing.T) {
+	t.Parallel()
+
+	errBoom := errors.New("boom")
+	f := &flappyToolSet{errs: []error{errBoom, errBoom}} // both scripted Start() calls fail
+	s := tools.NewStartable(f)
+
+	// First failure opens a streak; ShouldReportFailure consumes the one-shot flag.
+	err := s.Start(t.Context())
+	assert.Check(t, err != nil)
+	assert.Check(t, is.Equal(s.ShouldReportFailure(), true))
+
+	// Stop clears started, hasEverFailed, freshFailure and pendingRecovery.
+	assert.Check(t, s.Stop(t.Context()) == nil)
+
+	// The next failure therefore starts a new streak and must warn again.
+	err = s.Start(t.Context())
+	assert.Check(t, err != nil)
+	assert.Check(t, is.Equal(s.ShouldReportFailure(), true), "failure after Stop must produce fresh warning")
+}