diff --git a/pkg/runtime/loop.go b/pkg/runtime/loop.go index bebcc0d7f..03edea4e5 100644 --- a/pkg/runtime/loop.go +++ b/pkg/runtime/loop.go @@ -41,6 +41,12 @@ func (r *LocalRuntime) registerDefaultTools() { }) } +// appendSteerAndEmit adds a steer message to the session and emits the corresponding event. +func (r *LocalRuntime) appendSteerAndEmit(sess *session.Session, sm QueuedMessage, events chan<- Event) { + sess.AddMessage(session.UserMessage(sm.Content, sm.MultiContent...)) + events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1) +} + // finalizeEventChannel performs cleanup at the end of a RunStream goroutine: // restores the previous elicitation channel, emits the StreamStopped event, // fires hooks, and closes the events channel. @@ -294,6 +300,16 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c } } + // Drain steer messages queued while idle or before the first model call + // (covers idle-window and first-turn-miss races). + if steered := r.steerQueue.Drain(ctx); len(steered) > 0 { + messageCountBeforeSteer := len(sess.GetAllMessages()) + for _, sm := range steered { + r.appendSteerAndEmit(sess, sm, events) + } + r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeSteer, events) + } + messages := sess.GetMessages(a) slog.Debug("Retrieved messages for processing", "agent", a.Name(), "message_count", len(messages)) @@ -418,19 +434,10 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c // Record per-toolset model override for the next LLM turn. toolModelOverride = resolveToolCallModelOverride(res.Calls, agentTools) - // --- STEERING: mid-turn injection --- - // Drain ALL pending steer messages. These are urgent course- - // corrections that the model should see on the very next - // iteration, wrapped in tags. + // Drain steer messages that arrived during tool calls. 
if steered := r.steerQueue.Drain(ctx); len(steered) > 0 { for _, sm := range steered { - wrapped := fmt.Sprintf( - "\nThe user sent the following message while you were working:\n%s\n\nPlease address this in your next response while continuing with your current tasks.\n", - sm.Content, - ) - userMsg := session.UserMessage(wrapped, sm.MultiContent...) - sess.AddMessage(userMsg) - events <- UserMessage(sm.Content, sess.ID, sm.MultiContent, len(sess.Messages)-1) + r.appendSteerAndEmit(sess, sm, events) } r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) @@ -441,6 +448,15 @@ func (r *LocalRuntime) RunStream(ctx context.Context, sess *session.Session) <-c slog.Debug("Conversation stopped", "agent", a.Name()) r.executeStopHooks(ctx, sess, a, res.Content, events) + // Re-check steer queue: closes the race between the mid-loop drain and this stop. + if steered := r.steerQueue.Drain(ctx); len(steered) > 0 { + for _, sm := range steered { + r.appendSteerAndEmit(sess, sm, events) + } + r.compactIfNeeded(ctx, sess, a, m, contextLimit, messageCountBeforeTools, events) + continue + } + // --- FOLLOW-UP: end-of-turn injection --- // Pop exactly one follow-up message. Unlike steered // messages, follow-ups are plain user messages that start diff --git a/pkg/runtime/runtime_test.go b/pkg/runtime/runtime_test.go index c1adbc84a..a72c4a14c 100644 --- a/pkg/runtime/runtime_test.go +++ b/pkg/runtime/runtime_test.go @@ -5,6 +5,7 @@ import ( "errors" "io" "reflect" + "strings" "sync" "testing" @@ -2198,3 +2199,392 @@ func toolNames(ts []tools.Tool) []string { } return names } + +// messageRecordingProvider records the chat.Message slices passed to each +// CreateChatCompletionStream call so tests can inspect what the model saw. 
+type messageRecordingProvider struct {
+	id      string
+	mu      sync.Mutex
+	streams []*mockStream
+	callIdx int
+
+	recordedMessages [][]chat.Message // messages passed on each call
+}
+
+func (p *messageRecordingProvider) ID() string { return p.id }
+
+func (p *messageRecordingProvider) CreateChatCompletionStream(_ context.Context, msgs []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	snapshot := make([]chat.Message, len(msgs))
+	copy(snapshot, msgs)
+	p.recordedMessages = append(p.recordedMessages, snapshot)
+
+	if p.callIdx >= len(p.streams) {
+		// No stream configured for this call index. Return a plain stop so
+		// the caller surfaces this as a test failure via assertion rather
+		// than hanging, but also record the unexpected call so the test can
+		// detect it with require.Len / require.Equal.
+		return newStreamBuilder().AddStopWithUsage(1, 1).Build(), nil
+	}
+	s := p.streams[p.callIdx]
+	p.callIdx++
+	return s, nil
+}
+
+func (p *messageRecordingProvider) BaseConfig() base.Config { return base.Config{} }
+func (p *messageRecordingProvider) MaxTokens() int          { return 0 }
+
+// TestSteer_IdleWindowIsConsumedOnNextTurn verifies that a Steer call made
+// while no RunStream is active (i.e. in the idle window between turns) is
+// picked up by the very next RunStream iteration. Before the fix the steer
+// queue was only drained mid-loop (after tool calls), so a message enqueued
+// while idle was stranded and never seen by the model.
+func TestSteer_IdleWindowIsConsumedOnNextTurn(t *testing.T) {
+	t.Parallel()
+
+	// The model returns a plain-text stop (no tool calls) so we stay in the
+	// single-iteration path — this is the exact scenario where the old code
+	// would miss the steer message.
+	stream := newStreamBuilder().
+		AddContent("Got it").
+		AddStopWithUsage(5, 3).
+		Build()
+
+	prov := &messageRecordingProvider{
+		id:      "test/mock-model",
+		streams: []*mockStream{stream},
+	}
+
+	root := agent.New("root", "You are a test agent", agent.WithModel(prov))
+	tm := team.New(team.WithAgents(root))
+
+	rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
+	require.NoError(t, err)
+
+	// Enqueue a steer message BEFORE calling RunStream — simulating the
+	// idle-window race where a Steer call lands between two RunStream
+	// invocations.
+	err = rt.Steer(QueuedMessage{Content: "urgent: change direction"})
+	require.NoError(t, err)
+
+	sess := session.New(session.WithUserMessage("Do the task"))
+	sess.Title = "steer idle-window test"
+
+	evCh := rt.RunStream(t.Context(), sess)
+	var events []Event
+	for ev := range evCh {
+		events = append(events, ev)
+	}
+
+	// The run must complete normally (StreamStopped as the last event).
+	require.NotEmpty(t, events)
+	assert.IsType(t, &StreamStoppedEvent{}, events[len(events)-1],
+		"expected StreamStopped as the final event")
+
+	// A UserMessageEvent must have been emitted for the steer message.
+	var steerEventFound bool
+	for _, ev := range events {
+		if ue, ok := ev.(*UserMessageEvent); ok && strings.Contains(ue.Message, "urgent: change direction") {
+			steerEventFound = true
+			break
+		}
+	}
+	assert.True(t, steerEventFound, "expected a UserMessageEvent for the steer message")
+
+	// --- Session-message assertions ---
+	// Find the stored message for the steer injection and verify it was
+	// stored as a plain user message with NO system-reminder envelope.
+	var steerSessionMsg *session.Message
+	for _, item := range sess.Messages {
+		if item.IsMessage() &&
+			item.Message.Message.Role == chat.MessageRoleUser &&
+			strings.Contains(item.Message.Message.Content, "urgent: change direction") {
+			steerSessionMsg = item.Message
+			break
+		}
+	}
+	require.NotNil(t, steerSessionMsg, "expected a user-role session message containing the steer content")
+	assert.Equal(t, "urgent: change direction", steerSessionMsg.Message.Content,
+		"top-of-turn steer must be stored as plain content, not wrapped in system-reminder")
+	assert.NotContains(t, steerSessionMsg.Message.Content, "<system-reminder>",
+		"top-of-turn steer must NOT use the system-reminder envelope")
+
+	// --- Model-call assertions ---
+	// Verify the model received a message containing the raw steer content.
+	prov.mu.Lock()
+	defer prov.mu.Unlock()
+
+	require.NotEmpty(t, prov.recordedMessages, "expected at least one model call")
+	firstCallMsgs := prov.recordedMessages[0]
+
+	var foundSteer bool
+	for _, m := range firstCallMsgs {
+		if strings.Contains(m.Content, "urgent: change direction") {
+			// Also assert the model did NOT receive the system-reminder wrapper.
+			assert.NotContains(t, m.Content, "<system-reminder>",
+				"model must receive raw content, not system-reminder envelope, for top-of-turn steer")
+			foundSteer = true
+			break
+		}
+	}
+	assert.True(t, foundSteer,
+		"model should have received the steer message in its first turn; messages seen: %v",
+		firstCallMsgs)
+}
+
+// TestSteer_EmptySessionBootstrap verifies that when RunStream is started
+// with zero messages in the session but one or more messages already queued
+// via Steer, the model receives those messages as its initial context — i.e.
+// the run completes normally rather than erroring or producing a vacuous
+// response. The behaviour must be identical to a session where those messages
+// were added directly via session.WithUserMessage before the call.
+func TestSteer_EmptySessionBootstrap(t *testing.T) {
+	t.Parallel()
+
+	stream := newStreamBuilder().
+		AddContent("Hello from the model").
+		AddStopWithUsage(5, 3).
+		Build()
+
+	prov := &messageRecordingProvider{
+		id:      "test/mock-model",
+		streams: []*mockStream{stream},
+	}
+
+	root := agent.New("root", "You are a test agent", agent.WithModel(prov))
+	tm := team.New(team.WithAgents(root))
+
+	rt, err := NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
+	require.NoError(t, err)
+
+	// Enqueue before RunStream — zero messages in the session.
+	err = rt.Steer(QueuedMessage{Content: "bootstrap message"})
+	require.NoError(t, err)
+
+	// Fresh session with NO messages (SendUserMessage defaults to true but
+	// there is nothing to send yet).
+	sess := session.New()
+	sess.Title = "steer bootstrap test"
+
+	evCh := rt.RunStream(t.Context(), sess)
+	var events []Event
+	for ev := range evCh {
+		events = append(events, ev)
+	}
+
+	// The run must complete normally.
+	require.NotEmpty(t, events)
+	assert.IsType(t, &StreamStoppedEvent{}, events[len(events)-1],
+		"expected StreamStopped as the final event; got %T", events[len(events)-1])
+
+	// A UserMessageEvent must have been emitted for the steer message.
+	var steerEventFound bool
+	for _, ev := range events {
+		if ue, ok := ev.(*UserMessageEvent); ok && strings.Contains(ue.Message, "bootstrap message") {
+			steerEventFound = true
+			break
+		}
+	}
+	assert.True(t, steerEventFound,
+		"expected a UserMessageEvent for the bootstrap steer message")
+
+	// --- Session-message assertions ---
+	// The stored session message must be plain — no system-reminder envelope.
+	var bootstrapMsg *session.Message
+	for _, item := range sess.Messages {
+		if item.IsMessage() &&
+			item.Message.Message.Role == chat.MessageRoleUser &&
+			strings.Contains(item.Message.Message.Content, "bootstrap message") {
+			bootstrapMsg = item.Message
+			break
+		}
+	}
+	require.NotNil(t, bootstrapMsg, "expected a user-role session message for the bootstrap steer")
+	assert.Equal(t, "bootstrap message", bootstrapMsg.Message.Content,
+		"bootstrap steer must be stored as plain content, not wrapped in system-reminder")
+	assert.NotContains(t, bootstrapMsg.Message.Content, "<system-reminder>",
+		"bootstrap steer must NOT use the system-reminder envelope")
+
+	// --- Model-call assertions ---
+	// The model must have received exactly one call and that call must
+	// contain the raw bootstrap message (not wrapped).
+	prov.mu.Lock()
+	defer prov.mu.Unlock()
+
+	require.Len(t, prov.recordedMessages, 1,
+		"expected exactly one model call for the bootstrap turn")
+
+	firstCallMsgs := prov.recordedMessages[0]
+
+	var foundBootstrap bool
+	for _, m := range firstCallMsgs {
+		if strings.Contains(m.Content, "bootstrap message") {
+			// The model must see raw content, not the system-reminder wrapper.
+			assert.NotContains(t, m.Content, "<system-reminder>",
+				"model must receive raw content, not system-reminder envelope, for bootstrap steer")
+			foundBootstrap = true
+			break
+		}
+	}
+	assert.True(t, foundBootstrap,
+		"model must receive the bootstrap steer message as its first (and only) user turn; messages: %v",
+		firstCallMsgs)
+}
+
+// hookStream wraps a mockStream and calls onStop synchronously when it
+// returns a chunk with FinishReasonStop. This lets a test inject a Steer()
+// call at the precise moment the stream signals completion — after the stop
+// chunk is read inside tryModelWithFallback but before the mid-loop steer
+// drain runs, exercising the end-of-iteration drain at res.Stopped.
+type hookStream struct {
+	*mockStream
+
+	onStop func()
+}
+
+func (h *hookStream) Recv() (chat.MessageStreamResponse, error) {
+	resp, err := h.mockStream.Recv()
+	if err == nil && len(resp.Choices) > 0 && resp.Choices[0].FinishReason == chat.FinishReasonStop {
+		if h.onStop != nil {
+			h.onStop()
+		}
+	}
+	return resp, err
+}
+
+// steerInjectProvider is a provider whose CreateChatCompletionStream calls a
+// hook just before returning the stream. The hook is used to inject a Steer
+// message synchronously while the stream response is being prepared — this
+// simulates the narrow end-of-iteration race where a Steer() call lands after
+// the mid-loop drain but before the res.Stopped break.
+type steerInjectProvider struct {
+	id      string
+	streams []chat.MessageStream
+	callIdx int
+	onCall  func(callIdx int) // called with the current callIdx before returning
+	mu      sync.Mutex
+}
+
+func (p *steerInjectProvider) ID() string { return p.id }
+
+func (p *steerInjectProvider) CreateChatCompletionStream(_ context.Context, _ []chat.Message, _ []tools.Tool) (chat.MessageStream, error) {
+	p.mu.Lock()
+	idx := p.callIdx
+	p.callIdx++
+	var s chat.MessageStream
+	if idx < len(p.streams) {
+		s = p.streams[idx]
+	} else {
+		s = newStreamBuilder().AddStopWithUsage(1, 1).Build()
+	}
+	p.mu.Unlock()
+
+	if p.onCall != nil {
+		p.onCall(idx)
+	}
+	return s, nil
+}
+
+func (p *steerInjectProvider) BaseConfig() base.Config { return base.Config{} }
+func (p *steerInjectProvider) MaxTokens() int          { return 0 }
+
+// TestSteer_EndOfIterationRaceIsConsumedInCurrentRunStream verifies that a
+// Steer() call arriving in the narrow window between the mid-loop drain and
+// the res.Stopped break is consumed within the same RunStream invocation
+// rather than being stranded until the next call.
+//
+// The hookStream fires the injection synchronously inside Recv() when it
+// yields the FinishReasonStop chunk. At that point tryModelWithFallback has
+// not yet returned; the steer lands in the queue and is guaranteed to be
+// drained by one of the three drain points (mid-loop, end-of-iteration, or
+// top-of-next-turn). The test asserts the key invariant: consumed within
+// this RunStream (2 model calls, UserMessageEvent present).
+func TestSteer_EndOfIterationRaceIsConsumedInCurrentRunStream(t *testing.T) {
+	t.Parallel()
+
+	var rt *LocalRuntime // set after NewLocalRuntime
+
+	// Turn 1: plain-text stop. The hookStream injects a Steer() when the
+	// stop chunk is returned by Recv(), simulating a race in that window.
+	turn1Base := newStreamBuilder().
+		AddContent("Here is my response").
+		AddStopWithUsage(5, 3).
+		Build()
+	turn1 := &hookStream{
+		mockStream: turn1Base,
+		onStop: func() {
+			_ = rt.Steer(QueuedMessage{Content: "end-of-iter steer"}) // error deliberately dropped: a failed Steer shows up as the missing-event assertions below
+		},
+	}
+	// Turn 2: the loop re-entered after the steer was consumed; model acks.
+	turn2 := newStreamBuilder().
+		AddContent("Got your steer, changing direction").
+		AddStopWithUsage(5, 3).
+		Build()
+
+	prov := &steerInjectProvider{
+		id:      "test/mock-model",
+		streams: []chat.MessageStream{turn1, turn2},
+	}
+
+	root := agent.New("root", "You are a test agent", agent.WithModel(prov))
+	tm := team.New(team.WithAgents(root))
+
+	var err error
+	rt, err = NewLocalRuntime(tm, WithSessionCompaction(false), WithModelStore(mockModelStore{}))
+	require.NoError(t, err)
+
+	sess := session.New(session.WithUserMessage("Do the task"))
+	sess.Title = "steer end-of-iter race test"
+
+	evCh := rt.RunStream(t.Context(), sess)
+	var events []Event
+	for ev := range evCh {
+		events = append(events, ev)
+	}
+
+	// The run must complete normally.
+	require.NotEmpty(t, events)
+	assert.IsType(t, &StreamStoppedEvent{}, events[len(events)-1],
+		"expected StreamStopped as the final event")
+
+	// The steer message must have been emitted as a UserMessageEvent
+	// within this RunStream (not deferred to a future one).
+	var steerEventFound bool
+	for _, ev := range events {
+		if ue, ok := ev.(*UserMessageEvent); ok && strings.Contains(ue.Message, "end-of-iter steer") {
+			steerEventFound = true
+			break
+		}
+	}
+	assert.True(t, steerEventFound,
+		"expected a UserMessageEvent for the end-of-iteration steer within the same RunStream")
+
+	// The provider must have been called twice: once for the original turn
+	// and once for the follow-on turn triggered by the steer injection.
+	prov.mu.Lock()
+	defer prov.mu.Unlock()
+	assert.Equal(t, 2, prov.callIdx,
+		"expected exactly 2 model calls: original turn + steer follow-on turn")
+
+	// Find the stored session message for the steer and verify it was
+	// consumed within this RunStream.
+	var steerSessionMsg *session.Message
+	for _, item := range sess.Messages {
+		if item.IsMessage() &&
+			item.Message.Message.Role == chat.MessageRoleUser &&
+			strings.Contains(item.Message.Message.Content, "end-of-iter steer") {
+			steerSessionMsg = item.Message
+			break
+		}
+	}
+	require.NotNil(t, steerSessionMsg, "expected a session message for the end-of-iteration steer")
+	// All steer drain sites inject plain user messages; no wrapping occurs
+	// regardless of which drain (mid-loop or end-of-iteration) fires first.
+	assert.Equal(t, "end-of-iter steer", steerSessionMsg.Message.Content,
+		"end-of-iteration steer must be stored as plain content")
+	assert.NotContains(t, steerSessionMsg.Message.Content, "<system-reminder>",
+		"end-of-iteration steer must NOT use the system-reminder envelope")
+}