From 27d4a9d72369e01964f7f6f3c3fafd77e2e1d96a Mon Sep 17 00:00:00 2001 From: liut Date: Sun, 10 May 2026 01:31:05 +0800 Subject: [PATCH] refactor: event-driven architecture with unified Event and Runner.Persist - Add Event as the single communication primitive replacing StreamResult - Add Pusher callback type, eliminating goroutine+channel in SSE parsing - Upgrade StreamChat to return iter.Seq2[*Event, error] (both providers) - Add Runner.Persist() for unified history/usage/session persistence - Add HistoryStore/SessionStore interfaces with stores adapter - Add ToolResult to Event for tool execution visibility in event stream - Remove StreamResult type, goroutine+channel pattern, dead reserved fields - Consolidate handlers to use Runner.Persist() instead of direct store calls --- ...001-feat-event-driven-architecture-plan.md | 338 ++++++++++++++++++ pkg/services/llm/anthropic.go | 162 +++++---- pkg/services/llm/anthropic_test.go | 19 +- pkg/services/llm/client.go | 11 +- pkg/services/llm/event.go | 57 +++ pkg/services/llm/openai.go | 75 ++-- pkg/services/llm/openai_test.go | 40 +-- pkg/services/llm/types.go | 14 - pkg/services/llm/types_test.go | 17 - pkg/services/runner/runner.go | 53 +++ pkg/services/stores/event_adapter.go | 92 +++++ pkg/services/stores/integration_test.go | 103 +++++- pkg/web/api/agent.go | 93 +++-- pkg/web/api/api.go | 4 +- pkg/web/api/handle_convo.go | 110 +++--- pkg/web/api/handle_platform.go | 73 ++-- pkg/web/api/tool_executor.go | 33 +- 17 files changed, 953 insertions(+), 341 deletions(-) create mode 100644 docs/plans/2026-05-09-001-feat-event-driven-architecture-plan.md create mode 100644 pkg/services/llm/event.go create mode 100644 pkg/services/runner/runner.go create mode 100644 pkg/services/stores/event_adapter.go diff --git a/docs/plans/2026-05-09-001-feat-event-driven-architecture-plan.md b/docs/plans/2026-05-09-001-feat-event-driven-architecture-plan.md new file mode 100644 index 0000000..31f786a --- /dev/null +++ b/docs/plans/2026-05-09-001-feat-event-driven-architecture-plan.md @@ -0,0 +1,338 @@ +--- +title: feat: Event-driven architecture for chat pipeline +type: feat +status: done +date: 2026-05-09 +origin: docs/brainstorms/2026-05-09-event-driven-architecture-requirements.md +--- + +# feat: Event-driven architecture for chat pipeline + +## Overview + +将当前基于 `<-chan StreamResult` 的流式管道重构为以 `Event` 为核心原语的事件驱动架构。统一 Agent、Tool、Runner、Handler 之间的通信路径,集中持久化逻辑,为后续多 Agent 编排预留扩展点。 + +## Problem Statement + +当前聊天管道存在 5 个结构性问题: + +1. **StreamResult 只是传输层 DTO** — LLM 响应、工具结果、状态变更、持久化各自走不同路径 +2. **持久化逻辑散落三处** — `handle_convo.go`(3 处:AddHistory+Save、CreateUsageRecord、CreateChatLog)、`handle_platform.go`(2 处:streaming + regular 各自 AddHistory+Save) +3. **状态变更无迹可寻** — 工具结果直接拼入 messages 数组,无法追溯"谁在什么时候做了什么" +4. **channel 作为流式抽象有局限性** — 生产者必须开 goroutine,无法对事件流做组合包装 +5. **CLI/Web/Channel 三条路径各自重复 Agent 逻辑** — `Agent`(CLI-only)、`api` 结构体(web-only)、`channelHandler`(channel-only) 各自持有相同的 `llm.Client` + `ToolExecutor` + +## Key Decisions (from origin) + +- **D1.** 不引入 ADK-Go 的完整 Session 接口,仅定义 `SessionStore`(`MergeDelta`)和 `HistoryStore`(`AppendEvent` + `CreateUsageRecord`)两个最小接口 (see origin) +- **D2.** StateDelta 仅支持 session 级 key,不引入 app/user/temp 多级作用域 (see origin) +- **D3.** Event 直接放在 `pkg/services/llm` 包,与 `Message` 同包 (see origin; StreamResult 已删除) +- **D4.** 字段范围:`ID`/`Timestamp`/`Author`/`Delta`/`Think`/`ToolCalls`/`StopReason`/`Done`/`ToolResult`/`UserID`/`UserPrompt`/`Actions.StateDelta`/`Usage`/`Model`/`MsgCount`/`Meta`/`ResponseID` 实现 (see origin; InvocationID/Error/Branch/TransferToAgent/ArtifactDelta removed per review) +- **D5.** ~~Event 补充 `Usage`/`Model`/`ResponseID`/`Error` 字段以覆盖 StreamResult 的全部语义 (SpecFlow Q1)~~ → 已合并到 D4 +- **D6.** StateDelta 是**补充**而非**替代** — 工具结果仍通过 messages 返回给 LLM,StateDelta 仅用于状态性副作用 (SpecFlow Q2) +- **D7.** ~~新增 `SessionStateStore` 接口~~ → 简化为 `SessionStore`(仅 `MergeDelta`),去掉 Get/Set (SpecFlow Q3) +- **D8.** ~~InvocationID 对应一次 Agent.Run() 调用~~ → InvocationID 已删除(零读取),当前 SessionID 标识对话、Event.ID 标识单事件 (SpecFlow Q4) +- **D9.** Runner 放在新包 `pkg/services/runner`,避免循环依赖 (SpecFlow Q7) + +## Technical Approach + +### Architecture + +``` +Handler (SSE / Channel) + │ 直接调用 llm.StreamChat() 消费 iter.Seq2 + │ 通过 runner.Persist() 手动持久化 + ▼ +Agent (LLM 调用 + Tool 执行循环) + │ 实现 Run() → iter.Seq2[*Event, error] + │ 内部: llm.StreamChat + toolExec.ExecuteToolCalls + ▼ +Runner (统一持久化入口) + │ Persist(ctx, sessionID, event) → AppendEvent + MergeDelta + CreateUsageRecord + ▼ +stores/event_adapter (Runner 接口 → 现有 stores 适配) +``` + +**Event 结构**(`pkg/services/llm/event.go`): + +```go +type Event struct { + ID string + Timestamp time.Time + Author string // "user" | "assistant" | toolName + + Delta string + Think string + ToolCalls []ToolCall + StopReason FinishReason + UserID string + UserPrompt string + Done bool + + Usage *Usage + Model string + MsgCount int + Meta map[string]any + ResponseID string + + ToolResult *ToolResult + Actions EventActions +} + +type ToolResult struct { + CallID string + Name string + Content string +} + +type EventActions struct { + StateDelta map[string]any // session 级状态增量 +} +``` + +**Pusher**(流式事件推送回调): + +```go +type Pusher func(*Event, error) bool +``` + +**Runner**(`pkg/services/runner/runner.go`,60 行): + +```go +type Runner struct { + sessionStore SessionStore + historyStore HistoryStore +} + +func (r *Runner) Persist(ctx context.Context, sessionID string, event *llm.Event) error +``` + +**接口**(`pkg/services/runner/runner.go`): + +```go +type SessionStore interface { + MergeDelta(ctx context.Context, sessionID string, delta map[string]any) error +} + +type HistoryStore interface { + AppendEvent(ctx context.Context, sessionID string, event *llm.Event) error + CreateUsageRecord(ctx context.Context, sessionID string, event *llm.Event) error +} +``` + +### Implementation Phases + +#### Phase 1: Event 类型定义 + LLM Client 适配 ✅ + +**目标**:定义 Event 类型,将 `StreamChat` 的返回从 `<-chan StreamResult` 改为 `iter.Seq2[*Event, error]`。 + +**变更文件**: +- `pkg/services/llm/event.go` — 新增 `Event`、`EventActions`、`ToolResult`、`Pusher` 类型 +- `pkg/services/llm/client.go` — `Client` 接口 `StreamChat` 签名改为 `iter.Seq2[*Event, error]` +- `pkg/services/llm/openai.go` — `StreamChat` 消除 goroutine+channel,改用 `Pusher` 同步推送 +- `pkg/services/llm/anthropic.go` — 同上,`parseStreamResponse`/`handleStreamEvent` 直接构造 `*Event` +- `pkg/services/llm/types.go` — 删除 `StreamResult` 类型 + +**实际实现调整**: +- 未新增 `Run()` 方法,直接升级 `StreamChat` 签名(用户要求) +- `Pusher` 类型替代 `chan StreamResult`,`parseStreamResponse` 同步调用 +- Event 字段从 `Partial` 改为 `Done`(审查建议:零值安全) +- 删除了规划中的预留字段 `Branch`/`Error`/`TransferToAgent`/`ArtifactDelta`(审查建议:YAGNI) + +**验收**: +- [x] `Event` 类型包含全部实现字段 +- [x] `StreamChat` 返回 `iter.Seq2[*Event, error]` +- [x] 现有测试更新并通过 + +#### Phase 2: Runner 实现 + 统一持久化 ✅ + +**目标**:创建 `pkg/services/runner` 包,实现 Runner 统一持久化入口。 + +**变更文件**: +- `pkg/services/runner/runner.go` — `Runner` 结构体 + `Persist()` 方法 + `SessionStore`/`HistoryStore` 接口 +- `pkg/services/stores/event_adapter.go` — `HistoryStore`/`SessionStore` 的适配实现 + +**实际实现调整**: +- 简化为只保留 `Persist()`(`Runner.Run()` 未实现 —— 审查发现 handler 各有不同的输出格式,统一循环不合适) +- `SessionStore` 从 `Get/Set/MergeDelta` 简化为只有 `MergeDelta`(当前无 Get/Set 消费方) +- `CreateUsageRecord` 参数从 `runner.UsageRecord` 改为直接接收 `*llm.Event`(消除冗余中间结构) +- 接口定义放在 `runner/runner.go`(非 `stores/interfaces.go`) + +**验收**: +- [x] Runner 实现持久化流程(AppendEvent + MergeDelta + CreateUsageRecord) +- [x] HistoryStore/SessionStore 通过 event_adapter 适配现有 stores +- [x] 错误通过 `errors.Join` 收集返回 + +#### Phase 3: Agent 重构 + ToolExecutor 迁移 ✅ + +**目标**:重构 `Agent` 使用 `iter.Seq2`,ToolExecutor 通过 Event 返回结果。 + +**变更文件**: +- `pkg/web/api/agent.go` — `Agent.Run()` 返回 `iter.Seq2[*Event, error]`;`Chat()` 使用非流式 `llm.Chat()`;`StreamChat()` 通过 StreamCallbacks 消费 `Run()` +- `pkg/web/api/tool_executor.go` — `ExecuteToolCalls` 返回 `([]*Event, []Message)`,Event 携带 `ToolResult` + +**实际实现调整**: +- `Agent.Chat()` 使用 `llm.Chat()` + `ExecuteToolCallLoop`(非流式路径,避免浪费) +- `Agent.StreamChat()` 通过 `Run()` + `StreamCallbacks` 实现(CLI 终端输出) +- `ExecuteToolCalls` 返回 `len(evs)==0` 替代原 `hasToolCall` bool(语义等价) +- 工具结果尚未写入 StateDelta(等待实际需求驱动) + +**验收**: +- [x] Agent.Run() 返回 iter.Seq2 +- [x] 工具调用结果通过 Event.ToolResult 传递 +- [x] CLI agent 功能无回归 + +#### Phase 4: Handler 适配 ✅ + +**目标**:Web API 和 Channel Handler 用 `runner.Persist()` 替代直接持久化调用。 + +**变更文件**: +- `pkg/web/api/handle_convo.go` — `chatStreamResponseLoop`/`doChatStream` 用 `runner.Persist()` 替代 `AddHistory`/`Save`/`CreateUsageRecord` +- `pkg/web/api/handle_platform.go` — `handleStreamingReply`/`handleRegularReply` 同上 +- `pkg/web/api/api.go` — `api` 结构体持有 `*runner.Runner` + +**实际实现调整**: +- Handler 保持自己的工具调用循环(各自输出格式不同:SSE/Channel Stream/CLI) +- `runner.Persist()` 用于手动持久化点(历史、用量) +- 删除了 `gatherUsage`、手动 `AddHistory+Save`、手动 `CreateUsageRecord` 调用 +- 补充了曾丢失的 `UserID`/`UserPrompt`/`MsgCount`/`Meta` 信息 + +**验收**: +- [x] `/api/chat` SSE 响应格式不变 +- [x] Channel webhook 行为不变 +- [x] Handler 代码中不再出现 `AddHistory`/`Save`/`CreateUsageRecord` 调用 +- [x] 无 goroutine 泄漏(消除了 goroutine+channel 模式) + +#### Phase 5: 清理与测试 ✅ + +**目标**:清理废弃代码,补充集成测试,修复回归 bug。 + +**变更文件**: +- `pkg/services/llm/types.go` — 删除 `StreamResult` 类型 +- `pkg/services/llm/types_test.go` — 删除 `TestStreamResultString` +- `pkg/services/stores/integration_test.go` — 新增 `event_adapter` 集成测试 + +**Bug 修复(Phase 5 期间)**: +- `message_start` Partial 未设导致 panic(→ 改为 `Done` 零值安全) +- think 内容不输出(`isEmpty` check 漏了 `Think` 字段) +- `hasToolCall` 检查被删导致死循环(改用 `len(evs)==0`) +- `gatherUsage` 信息丢失(`MsgCount`/`Meta`/`UserPrompt` 在 Persist 中恢复) +- `anthropic.go`/`openai.go` 去 channel 时保留原有注释 + +**验收**: +- [x] `make vet lint` 通过(0 告警) +- [x] `make test-models` 通过 +- [x] `make test-stores` 通过(含 event_adapter 集成测试) + +### Persistence Contract (AppendEvent) + +``` +AppendEvent(ctx, sessionID, event): + 1. 将 event 转换为 aigc.HistoryItem: + - Role=user 的 event → HistoryItem.ChatItem.User + - Role=assistant 的 event → HistoryItem.ChatItem.Assistant + Think + 2. 写入 Redis: RPUSH convs- + 3. 如果有 StateDelta → MergeDelta(ctx, sessionID, event.Actions.StateDelta) + - MergeDelta 写入 convo_session.meta jsonb 列 + - key 冲突时后者覆盖前者 + 4. 如果有 Usage → CreateUsageRecord + 5. (future) 可选写入 convo_message 表(PostgreSQL) +``` + +## System-Wide Impact + +### Interaction Graph + +``` +POST /api/chat + → postChat (handler) + → prepareChatRequest (构建 messages + tools) + → chatStreamResponseLoop + → doChatStream → llm.StreamChat (iter.Seq2) + → openAIProvider/anthropicProvider: SSE parse → push(*Event) + → 每个 Event → SSE chunk → writeEvent + → 工具调用: toolExec.ExecuteToolCalls → []Event (含 ToolResult) + → runner.Persist(event) → AppendEvent + MergeDelta + CreateUsageRecord + → GetHistorySummary → title +``` + +### Error Propagation + +| 层级 | 错误类型 | 处理方式 | +|------|---------|---------| +| LLM API | HTTP error / timeout | `yield(nil, err)` → Handler 展示错误 | +| Tool invoke | tool not found / exec fail | `Event{Author: toolName, Content: error text}` → 作为 tool result 返回 LLM | +| Redis write | connection error | Runner 日志告警,不阻断事件流 | +| Context cancel | client disconnect | iter.Seq2 在下次 yield 前检查 `ctx.Done()`,停止迭代 | + +### State Lifecycle Risks + +- **StateDelta 与 messages 不一致**:工具先写 StateDelta(Set user preference),但后续 LLM 调用失败。此时 StateDelta 已在 Redis session.meta 中,但对话未完成。**缓解**:StateDelta 在 AppendEvent 时与 HistoryItem 同事务写入,只有完整 Event 才写 StateDelta +- **Runner 中途崩溃**:已持久化的事件在 Redis 中,未持久化的丢失。与当前行为一致(当前也是逐个 AddHistory) +- **StateDelta 合并竞争**:并行 Agent 场景(未来)可能引发 key 覆盖。当前单 Agent 无此问题 + +### API Surface Parity + +| 路径 | 当前 | 重构后 | +|------|------|--------| +| POST /api/chat | 直接调 llm.StreamChat | 通过 runner.Run() | +| POST /api/chat-sse | 同上 | 同上 | +| Channel WeCom | channelHandler.handleStreamingReply | 通过 runner.Run() | +| Channel Feishu | channelHandler.handleRegularReply | 通过 runner.Run() | +| CLI agent | Agent.StreamChat | 通过 Agent.Run() | +| GET /api/history/{cid} | ListHistory from Redis | 不变 | + +### Integration Test Scenarios + +1. **正常流式对话**:User Message → 3 个 delta Event → 1 个 done Event → 验证 Redis history + UsageRecord +2. **工具调用链**:User Message → delta + tool_call Event → tool result Event → 第二轮 delta → done → 验证 messages 列表正确、StateDelta 已合并 +3. **客户端断开**:发送 2 个 delta → 关闭连接 → 验证 iter.Seq2 停止、无 goroutine 泄漏、已发事件已持久化 +4. **LLM 错误恢复**:LLM 返回错误 Event → 验证 Runner 持久化错误事件 → Handler 展示错误给用户 +5. **StateDelta 合并**:工具写入 `{"last_query": "..."}` → 验证 `sessionStore.Get(sessionID, "last_query")` 返回正确值 + +## Acceptance Criteria + +### Functional Requirements +- [x] R1. `Event` 类型定义在 `pkg/services/llm`,含全部实现字段 +- [x] R2. `Agent.Run()` 返回 `iter.Seq2[*Event, error]`;`Runner.Persist()` 为持久化入口 +- [x] R3. Handler 中无直接 `AddHistory`/`Save`/`CreateUsageRecord` 调用(统一通过 `Runner.Persist()`) +- [ ] R4. 工具可通过 `Event.Actions.StateDelta` 写入状态(接口就绪,待实际需求驱动) +- [x] R5. `Event.Done` 标记流结束(零值安全,中间事件无需显式设置) +- [x] R6. `/api/chat` SSE 和 Channel webhook 对外行为不变 + +### Non-Functional Requirements +- [x] 无 goroutine 泄漏(消除 goroutine+channel 模式) +- [x] `make vet lint` 通过(0 告警) +- [x] 现有测试全部通过 +- [x] event_adapter 集成测试覆盖 + +## Dependencies & Risks + +| 依赖/风险 | 影响 | 缓解 | +|----------|------|------| +| Go 1.25 `iter.Seq2` 稳定 | 低 — 1.23+ 已稳定 | 已在 go.mod 中确认 1.25 | +| Redis history 路径保持不变 | 确保 ListHistory 兼容 | AppendEvent 写同一 Redis key 格式 | +| `SessionStateStore` 避免与 OAuth `StateStore` 命名冲突 | 关键 — 见 SpecFlow Q3 | 使用独立接口名 `SessionStateStore` | +| `convo.Message` 表当前未使用 | 低 — 后续单独迁移 | 本次不涉及,保留现状 | +| 无 goroutine 泄漏 | 高 — 当前 channel 模式有泄漏风险 | iter.Seq2 同步模型天然消除 | + +## Sources & References + +### Origin +- [docs/brainstorms/2026-05-09-event-driven-architecture-requirements.md](../brainstorms/2026-05-09-event-driven-architecture-requirements.md) — 需求文档 + - 携带决策:D1-D9(字段范围、包位置、StateDelta 语义、Runner 架构) + +### Internal References +- `pkg/services/llm/types.go:103` — 当前 StreamResult 定义 +- `pkg/web/api/tool_executor.go:47` — ToolExecutor.ExecuteToolCalls 当前实现 +- `pkg/web/api/handle_convo.go:334-500` — 当前流式响应循环 + 持久化 +- `pkg/web/api/handle_platform.go:157-345` — Channel streaming + regular handler +- `pkg/services/stores/conversation.go:131` — Conversation.Save 实现 + +### Institutional Learnings +- [executeToolCallLoop-deduplication.md](../solutions/logic-errors/executeToolCallLoop-deduplication.md) — ToolExecutor 提取模式,Runner 的直接前身 +- [streaming-reply-multiple-startstream.md](../solutions/runtime-errors/streaming-reply-multiple-startstream.md) — 流式生命周期规则:Start/Finish 必须由单层管理 + +### External References +- [Go 1.23 iter.Seq2 文档](https://pkg.go.dev/iter#Seq2) +- ADK-Go session/event 模型: `google.golang.org/adk/session` diff --git a/pkg/services/llm/anthropic.go b/pkg/services/llm/anthropic.go index 5413657..f705f24 100644 --- a/pkg/services/llm/anthropic.go +++ b/pkg/services/llm/anthropic.go @@ -7,9 +7,11 @@ import ( "encoding/json" "fmt" "io" + "iter" "net/http" "os" "strings" + "time" ) const anthropicVersion = "2023-06-01" @@ -115,13 +117,9 @@ func (p *anthropicProvider) Chat(ctx context.Context, cfg *config, messages []Me return result, nil } -func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) (<-chan StreamResult, error) { - ch := make(chan StreamResult, 100) +func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) iter.Seq2[*Event, error] { - go func() { - defer close(ch) - - // 构建请求 + return func(yield func(*Event, error) bool) { endpoint := anthropicMessagesEndpoint(cfg.baseURL) anthropicMessages, systemText := toAnthropicMessages(messages) @@ -144,7 +142,7 @@ func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, message if len(tools) > 0 { converted, err := toAnthropicTools(tools) if err != nil { - ch <- StreamResult{Error: err} + yield(nil, err) return } reqBody.Tools = converted @@ -160,19 +158,17 @@ func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, message "messages", MessagesLogged(messages), ) - // 序列化请求体,保存用于错误时打印 reqBodyBytes, err := json.Marshal(reqBody) if err != nil { logger().Warnw("marshal stream request failed", "err", err) - ch <- StreamResult{Error: err} + yield(nil, err) return } - // 构建请求 req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBodyBytes)) if err != nil { logger().Warnw("create stream request failed", "err", err, "reqBody", string(reqBodyBytes)) - ch <- StreamResult{Error: err} + yield(nil, err) return } @@ -185,7 +181,6 @@ func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, message req.Header.Set(k, v) } - // 发送请求 hc := cfg.httpClient if hc == nil { hc = &http.Client{Timeout: 0} @@ -194,7 +189,7 @@ func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, message resp, err := hc.Do(req) if err != nil { logger().Warnw("stream request failed", "err", err, "reqBody", string(reqBodyBytes)) - ch <- StreamResult{Error: err} + yield(nil, err) return } @@ -202,31 +197,35 @@ func (p *anthropicProvider) StreamChat(ctx context.Context, cfg *config, message fmt.Fprintf(os.Stderr, "\n%s\n%d bytes\n", string(reqBodyBytes), len(reqBodyBytes)) } - // 检查响应状态码 if resp.StatusCode < 200 || resp.StatusCode >= 300 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) resp.Body.Close() errMsg := fmt.Errorf("http %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) logger().Warnw("stream response error", "status", resp.StatusCode, - // "reqBody", string(reqBodyBytes), "respBody", string(respBody)) - ch <- StreamResult{Error: errMsg} + yield(nil, errMsg) return } defer resp.Body.Close() - // 解析流响应 - if err := p.parseStreamResponse(resp.Body, ch, cfg.debug, cfg.logDir, cfg.model, messages, tools); err != nil { - ch <- StreamResult{Error: err} + // pusher 适配:补全 author 后直接 yield + push := func(event *Event, err error) bool { + if err != nil { + return yield(nil, err) + } + event.Author = "assistant" + return yield(event, nil) } - }() - return ch, nil + if err := p.parseStreamResponse(resp.Body, push, cfg.debug, cfg.logDir, cfg.model, messages, tools); err != nil { + yield(nil, err) + } + } } -// parseStreamResponse 解析流式响应 -func (p *anthropicProvider) parseStreamResponse(body io.Reader, ch chan<- StreamResult, debug bool, logDir, model string, messages []Message, tools []ToolDefinition) error { +// parseStreamResponse 解析流式响应,通过 push 直接产出 *Event。 +func (p *anthropicProvider) parseStreamResponse(body io.Reader, push Pusher, debug bool, logDir, model string, messages []Message, tools []ToolDefinition) error { var currentToolCalls []ToolCall var currentText strings.Builder var thinkContent string @@ -237,7 +236,7 @@ func (p *anthropicProvider) parseStreamResponse(body io.Reader, ch chan<- Stream line, err := bufReader.ReadBytes('\n') if err != nil { if err == io.EOF { - ch <- StreamResult{Done: true} + push(&Event{Done: true}, nil) } else { logger().Infow("read stream response failed", "err", err) return fmt.Errorf("read: %w", err) @@ -261,22 +260,20 @@ func (p *anthropicProvider) parseStreamResponse(body io.Reader, ch chan<- Stream data := bytes.TrimSpace(line[5:]) if string(data) == "[DONE]" { - ch <- StreamResult{Done: true} + push(&Event{Done: true}, nil) return nil } - event, err := p.parseStreamEvent(data) + se, err := p.parseStreamEvent(data) if err != nil { continue } - // logger().Debugw("stream event parsed", "type", event.Type, "index", event.Index, - // "delta", &event.Delta) - done, toolCalls := p.handleStreamEvent(event, ¤tText, currentToolCalls, ch, logDir, model, messages, tools, &thinkContent) + done, toolCalls := p.handleStreamEvent(se, ¤tText, currentToolCalls, push, logDir, model, messages, tools, &thinkContent) currentToolCalls = toolCalls if done { - logger().Infow("stream done", "event_type", event.Type, "tool_calls_count", len(currentToolCalls)) + logger().Infow("stream done", "event_type", se.Type, "tool_calls_count", len(currentToolCalls)) return nil } } @@ -336,47 +333,59 @@ func (p *anthropicProvider) parseStreamEvent(data []byte) (streamEvent, error) { return event, nil } -// handleStreamEvent 处理流事件,返回是否结束及更新后的 toolCalls -func (p *anthropicProvider) handleStreamEvent(event streamEvent, currentText *strings.Builder, currentToolCalls []ToolCall, ch chan<- StreamResult, logDir, model string, messages []Message, tools []ToolDefinition, thinkContent *string) (bool, []ToolCall) { - switch event.Type { +// handleStreamEvent 处理流事件,通过 push 直接产出 *Event。返回是否结束及更新后的 toolCalls。 +func (p *anthropicProvider) handleStreamEvent(se streamEvent, currentText *strings.Builder, currentToolCalls []ToolCall, + push Pusher, logDir, model string, messages []Message, tools []ToolDefinition, thinkContent *string) ( + bool, []ToolCall) { + switch se.Type { case "content_block_start": - // 开始新的内容块,检查是否是 tool_use 类型 - if event.ContentBlock != nil && event.ContentBlock.Type == "tool_use" { - toolID := event.ContentBlock.ID + // 开始新的内容块,检查是否是 tool_use 类型 + if se.ContentBlock != nil && se.ContentBlock.Type == "tool_use" { + toolID := se.ContentBlock.ID if toolID == "" { - toolID = fmt.Sprintf("toolu_%d", event.Index) + toolID = fmt.Sprintf("toolu_%d", se.Index) } currentToolCalls = append(currentToolCalls, ToolCall{ ID: toolID, Type: "function", Function: ToolCallFunc{ - Name: event.ContentBlock.Name, + Name: se.ContentBlock.Name, }, }) - logger().Debugw("tool_use started", "id", toolID, "name", event.ContentBlock.Name) + logger().Debugw("tool_use started", "id", toolID, "name", se.ContentBlock.Name) } case "content_block_delta": - if event.Delta.Type == "text_delta" { - currentText.WriteString(event.Delta.Text) - ch <- StreamResult{ - Delta: event.Delta.Text, + if se.Delta.Type == "text_delta" { + currentText.WriteString(se.Delta.Text) + if !push(&Event{ + ID: NewEventID(), + Timestamp: time.Now(), + Delta: se.Delta.Text, ToolCalls: currentToolCalls, + + }, nil) { + return true, currentToolCalls } - } else if event.Delta.Type == "thinking_delta" { - *thinkContent += event.Delta.Thinking - // thinking 独立于 tool_use,不附带正在构建的 tool_calls - ch <- StreamResult{ - Think: event.Delta.Thinking, + } else if se.Delta.Type == "thinking_delta" { + *thinkContent += se.Delta.Thinking + // thinking 独立于 tool_use,不附带正在构建的 tool_calls + if !push(&Event{ + ID: NewEventID(), + Timestamp: time.Now(), + Think: se.Delta.Thinking, + + }, nil) { + return true, currentToolCalls } - } else if event.Delta.Type == "input_json_delta" { - // 处理 tool_use 的参数,直接取最后一个 tool_call - if len(currentToolCalls) > 0 && event.Delta.PartialJSON != "" { + } else if se.Delta.Type == "input_json_delta" { + // 处理 tool_use 的参数,直接取最后一个 tool_call + if len(currentToolCalls) > 0 && se.Delta.PartialJSON != "" { lastIdx := len(currentToolCalls) - 1 - // 跳过 thinking 相关字段(thinking_delta 伴随 input_json_delta 出现,但不属于 tool_use 参数) - if !strings.HasPrefix(strings.TrimSpace(event.Delta.PartialJSON), "\"thinking") { + // 跳过 thinking 相关字段(thinking_delta 伴随 input_json_delta 出现,但不属于 tool_use 参数) + if !strings.HasPrefix(strings.TrimSpace(se.Delta.PartialJSON), "\"thinking") { currentToolCalls[lastIdx].Function.Arguments = append( currentToolCalls[lastIdx].Function.Arguments, - event.Delta.PartialJSON..., + se.Delta.PartialJSON..., ) } } @@ -384,25 +393,28 @@ func (p *anthropicProvider) handleStreamEvent(event streamEvent, currentText *st case "content_block_stop": // 内容块结束 case "message_delta": - stopReason := FinishReason(event.Delta.StopReason) + stopReason := FinishReason(se.Delta.StopReason) if stopReason == "end_turn" { stopReason = "stop" - } else if len(currentToolCalls) > 0 { // 检查是否有 tool_calls + } else if len(currentToolCalls) > 0 { stopReason = "tool_calls" // 为了兼容 OpenAI } - // 发送完成信号 - ch <- StreamResult{ - ToolCalls: currentToolCalls, - FinishReason: stopReason, - Usage: event.Usage.toUsage(), + if !push(&Event{ + ID: NewEventID(), + Timestamp: time.Now(), + ToolCalls: currentToolCalls, + StopReason: stopReason, + Usage: se.Usage.toUsage(), + + }, nil) { + return true, currentToolCalls } - // 写入交互日志 if logDir != "" { go LogInteraction(logDir, "anthropic", &InteractionLog{ Model: model, Messages: messages, Tools: tools, - Usage: event.Usage.toUsage(), + Usage: se.Usage.toUsage(), Response: currentText.String(), ToolCalls: currentToolCalls, Think: *thinkContent, @@ -410,23 +422,27 @@ func (p *anthropicProvider) handleStreamEvent(event streamEvent, currentText *st }) } case "message_stop": // 在 message_delta 后会跟一个message_stop,里面没有实际信息 - ch <- StreamResult{ - Done: true, + push(&Event{ + ID: NewEventID(), + Timestamp: time.Now(), + Done: true, ToolCalls: currentToolCalls, - } + }, nil) return true, currentToolCalls case "message_start": - if event.Message != nil { - ch <- StreamResult{ - Model: event.Message.Model, - ResponseID: event.Message.ID, - } + if se.Message != nil { + push(&Event{ + ID: NewEventID(), + Timestamp: time.Now(), + Model: se.Message.Model, + ResponseID: se.Message.ID, + + }, nil) } - // 忽略 case "ping": // 忽略 default: - logger().Infow("unknown anthropic event type", "type", event.Type) + logger().Infow("unknown anthropic event type", "type", se.Type) } return false, currentToolCalls } diff --git a/pkg/services/llm/anthropic_test.go b/pkg/services/llm/anthropic_test.go index 8f6c4ea..7c4e3b1 100644 --- a/pkg/services/llm/anthropic_test.go +++ b/pkg/services/llm/anthropic_test.go @@ -204,22 +204,15 @@ func TestAnthropicProviderStreamChat(t *testing.T) { model: "claude-3-5-sonnet-20241022", } - ch, err := p.StreamChat(context.Background(), cfg, []Message{ + var results []*Event + for event, err := range p.StreamChat(context.Background(), cfg, []Message{ {Role: RoleUser, Content: "Hi"}, - }, nil) - - if err != nil { - t.Errorf("StreamChat() error = %v", err) - return - } - - var results []StreamResult - for result := range ch { - if result.Error != nil { - t.Errorf("stream error = %v", result.Error) + }, nil) { + if err != nil { + t.Errorf("stream error = %v", err) break } - results = append(results, result) + results = append(results, event) } if len(results) == 0 { diff --git a/pkg/services/llm/client.go b/pkg/services/llm/client.go index cee9300..d63ab83 100644 --- a/pkg/services/llm/client.go +++ b/pkg/services/llm/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "iter" "strings" ) @@ -14,8 +15,8 @@ var ErrUnsupportedProvider = errors.New("unsupported provider") type Client interface { // Chat 发送聊天请求,返回完整响应 Chat(ctx context.Context, messages []Message, tools []ToolDefinition) (*ChatResult, error) - // StreamChat 发送流式聊天请求,返回流式响应 - StreamChat(ctx context.Context, messages []Message, tools []ToolDefinition) (<-chan StreamResult, error) + // StreamChat 发送流式聊天请求,返回 iter.Seq2 事件流 + StreamChat(ctx context.Context, messages []Message, tools []ToolDefinition) iter.Seq2[*Event, error] // Generate 简单文本生成(用于关键词提取等) Generate(ctx context.Context, prompt string) (string, *Usage, error) // Embedding 向量化文本 @@ -31,7 +32,7 @@ type client struct { // provider 接口定义 type provider interface { Chat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) (*ChatResult, error) - StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) (<-chan StreamResult, error) + StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) iter.Seq2[*Event, error] Generate(ctx context.Context, cfg *config, prompt string) (string, *Usage, error) Embedding(ctx context.Context, cfg *config, texts []string) ([]float64, error) } @@ -68,8 +69,8 @@ func (c *client) Chat(ctx context.Context, messages []Message, tools []ToolDefin return c.provider.Chat(ctx, c.cfg, messages, tools) } -// StreamChat 发送流式聊天请求 -func (c *client) StreamChat(ctx context.Context, messages []Message, tools []ToolDefinition) (<-chan StreamResult, error) { +// StreamChat 发送流式聊天请求,返回 iter.Seq2 事件流 +func (c *client) StreamChat(ctx context.Context, messages []Message, tools []ToolDefinition) iter.Seq2[*Event, error] { return c.provider.StreamChat(ctx, c.cfg, messages, tools) } diff --git a/pkg/services/llm/event.go b/pkg/services/llm/event.go new file mode 100644 index 0000000..9a6e16a --- /dev/null +++ b/pkg/services/llm/event.go @@ -0,0 +1,57 @@ +package llm + +import ( + "time" + + oid "github.com/cupogo/andvari/models/oid" +) + +// NewEventID 生成新的 Event ID (OID)。 +func NewEventID() string { + return oid.NewID(oid.OtEvent).String() +} + +// Pusher 是流式事件推送函数类型,供 parseStreamResponse 等底层解析器使用。 +type Pusher func(*Event, error) bool + +// Event 是系统中 Agent、Tool、Runner、Handler 之间通信的唯一介质。 +type Event struct { + ID string // OID + Timestamp time.Time + Author string // "user" | "assistant" | toolName + + // 对话内容 + Delta string + Think string + ToolCalls []ToolCall + StopReason FinishReason + UserID string // 用户标识,用于持久化 HistoryItem 时设置 UID + UserPrompt string // 用户提问,用于持久化 HistoryItem 时配对 + Done bool // 流式结束标记,零值 false = 未结束 + + // 遥测 + Usage *Usage + Model string + MsgCount int // 当前消息数,用于 UsageRecord + Meta map[string]any // 透传给 UsageRecord 的元数据 + + ResponseID string + + // 工具结果 + ToolResult *ToolResult + + // 副作用 + Actions EventActions +} + +// ToolResult 携带单次工具调用的执行结果。 +type ToolResult struct { + CallID string + Name string + Content string +} + +// EventActions 是 Event 携带的副作用指令。 +type EventActions struct { + StateDelta map[string]any // session 级状态增量 +} diff --git a/pkg/services/llm/openai.go b/pkg/services/llm/openai.go index c8623cf..3a5af6e 100644 --- a/pkg/services/llm/openai.go +++ b/pkg/services/llm/openai.go @@ -7,9 +7,11 @@ import ( "encoding/json" "fmt" "io" + "iter" "net/http" "os" "strings" + "time" ) type StreamOptions struct { @@ -137,13 +139,9 @@ func (p *openAIProvider) Chat(ctx context.Context, cfg *config, messages []Messa return result, nil } -func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) (<-chan StreamResult, error) { - ch := make(chan StreamResult, 100) - - // 启动流式读取 goroutine - go func() { - defer close(ch) +func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages []Message, tools []ToolDefinition) iter.Seq2[*Event, error] { + return func(yield func(*Event, error) bool) { endpoint := buildEndpoint(cfg.baseURL, "/chat/completions") var toolsOpt []ToolDefinition @@ -175,18 +173,16 @@ func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages [ "messages", MessagesLogged(messages), ) - // 序列化请求体,保存用于错误时打印 reqBodyBytes, err := json.Marshal(reqBody) if err != nil { - ch <- StreamResult{Error: fmt.Errorf("marshal request: %w", err)} + yield(nil, fmt.Errorf("marshal request: %w", err)) return } - // 构建并发送请求 req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBodyBytes)) if err != nil { logger().Warnw("create stream request failed", "err", err, "reqBody", string(reqBodyBytes)) - ch <- StreamResult{Error: err} + yield(nil, err) return } @@ -206,11 +202,10 @@ func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages [ resp, err := hc.Do(req) if err != nil { logger().Warnw("stream request failed", "err", err, "reqBody", string(reqBodyBytes)) - ch <- StreamResult{Error: err} + yield(nil, err) return } - // 检查响应状态码 if resp.StatusCode >= 400 { fmt.Fprintf(os.Stderr, "\n%s\n", string(reqBodyBytes)) respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) @@ -218,24 +213,28 @@ func (p *openAIProvider) StreamChat(ctx context.Context, cfg *config, messages [ errMsg := fmt.Errorf("http %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) logger().Warnw("stream response error", "status", resp.StatusCode, - // "reqBody", string(reqBodyBytes), "respBody", string(respBody)) - ch <- StreamResult{Error: errMsg} + yield(nil, errMsg) return } defer resp.Body.Close() - // 解析流响应 - if err := p.parseStreamResponse(resp.Body, ch, cfg.debug, cfg.logDir, cfg.model, messages, tools); err != nil { - ch <- StreamResult{Error: err} + push := func(event *Event, err error) bool { + if err != nil { + return yield(nil, err) + } + event.Author = "assistant" + return yield(event, nil) } - }() - return ch, nil + if err := p.parseStreamResponse(resp.Body, push, cfg.debug, cfg.logDir, cfg.model, messages, tools); err != nil { + yield(nil, err) + } + } } -// parseStreamResponse 解析流式响应 -func (p *openAIProvider) parseStreamResponse(body io.Reader, ch chan<- StreamResult, debug bool, logDir, model string, messages []Message, tools []ToolDefinition) error { +// parseStreamResponse 解析流式响应,通过 push 直接产出 *Event。 +func (p *openAIProvider) parseStreamResponse(body io.Reader, push Pusher, debug bool, logDir, model string, messages []Message, tools []ToolDefinition) error { bufReader := bufio.NewReaderSize(body, 1024) var currentToolCalls []ToolCall @@ -249,9 +248,9 @@ func (p *openAIProvider) parseStreamResponse(body io.Reader, ch chan<- StreamRes rawLine, err := bufReader.ReadBytes('\n') if err != nil { if err == io.EOF { - ch <- StreamResult{Done: true} + push(&Event{Done: true}, nil) } else { - ch <- StreamResult{Error: fmt.Errorf("read: %w", err)} + return fmt.Errorf("read: %w", err) } return nil } @@ -270,7 +269,7 @@ func (p *openAIProvider) parseStreamResponse(body io.Reader, ch chan<- StreamRes noPrefixLine := bytes.TrimLeft(noSpaceLine[5:], " \t") if string(noPrefixLine) == "[DONE]" { logger().Infow("stream DONE", "lines", lines) - ch <- StreamResult{Done: true} + push(&Event{Done: true}, nil) return nil } @@ -326,17 +325,19 @@ func (p *openAIProvider) parseStreamResponse(body io.Reader, ch chan<- StreamRes } // 发送内容,每个 chunk 都带上累积的 tool_calls - result := StreamResult{ - Delta: delta.Content, - Think: delta.ReasoningContent, - ToolCalls: currentToolCalls, - FinishReason: finishReason, - Model: chunk.Model, - ResponseID: chunk.ID, + ev := &Event{ + ID: NewEventID(), + Timestamp: time.Now(), + Delta: delta.Content, + Think: delta.ReasoningContent, + ToolCalls: currentToolCalls, + StopReason: finishReason, + Model: chunk.Model, + ResponseID: chunk.ID, } if chunk.Usage != nil { logger().Debugw("usage from chunk", "usage", chunk.Usage) - result.Usage = chunk.Usage.toUsage() + ev.Usage = chunk.Usage.toUsage() } // 累积响应内容 @@ -351,21 +352,21 @@ func (p *openAIProvider) parseStreamResponse(body io.Reader, ch chan<- StreamRes shouldEndStream := finishReason != "" || (len(delta.ToolCalls) == 0 && len(currentToolCalls) > 0) if shouldEndStream { - logger().Debugw("stream should done", "result", &result) - result.Done = true + ev.Done = true + } + if !push(ev, nil) { + return nil } - ch <- result if shouldEndStream { logger().Infow("stream done", "finish_reason", finishReason, "tool_calls_count", len(currentToolCalls), "lines", lines) - // 写入交互日志 if logDir != "" { go LogInteraction(logDir, "openai", &InteractionLog{ Model: model, Messages: messages, Tools: tools, - Usage: result.Usage, + Usage: ev.Usage, Response: responseText, ToolCalls: currentToolCalls, Think: thinkContent, diff --git a/pkg/services/llm/openai_test.go b/pkg/services/llm/openai_test.go index 836cebe..04d43c4 100644 --- a/pkg/services/llm/openai_test.go +++ b/pkg/services/llm/openai_test.go @@ -224,22 +224,15 @@ func TestOpenAIProviderStreamChat(t *testing.T) { temperature: 0.7, } - ch, err := p.StreamChat(context.Background(), cfg, []Message{ + var results []*Event + for event, err := range p.StreamChat(context.Background(), cfg, []Message{ {Role: RoleUser, Content: "Hi"}, - }, nil) - - if err != nil { - t.Errorf("StreamChat() error = %v", err) - return - } - - var results []StreamResult - for result := range ch { - if result.Error != nil { - t.Errorf("stream error = %v", result.Error) + }, nil) { + if err != nil { + t.Errorf("stream error = %v", err) break } - results = append(results, result) + results = append(results, event) } // 验证至少收到了一些结果 @@ -277,24 +270,17 @@ func TestOpenAIProviderStreamChatWithTools(t *testing.T) { }, } - ch, err := p.StreamChat(context.Background(), cfg, []Message{ + for event, err := range p.StreamChat(context.Background(), cfg, []Message{ {Role: RoleUser, Content: "What's the weather?"}, - }, tools) - - if err != nil { - t.Errorf("StreamChat() error = %v", err) - return - } - - for result := range ch { - if result.Error != nil { - t.Errorf("stream error = %v", result.Error) + }, tools) { + if err != nil { + t.Errorf("stream error = %v", err) break } // 验证 tool calls 被正确解析 - if len(result.ToolCalls) > 0 { - if result.ToolCalls[0].Function.Name != "get_weather" { - t.Errorf("tool name = %v, want get_weather", result.ToolCalls[0].Function.Name) + if len(event.ToolCalls) > 0 { + if event.ToolCalls[0].Function.Name != "get_weather" { + t.Errorf("tool name = %v, want get_weather", event.ToolCalls[0].Function.Name) } } } diff --git a/pkg/services/llm/types.go b/pkg/services/llm/types.go index effd9e2..37e2c5e 100644 --- a/pkg/services/llm/types.go +++ b/pkg/services/llm/types.go @@ -99,20 +99,6 @@ func (r *ChatResult) HasToolCalls() bool { return len(r.ToolCalls) > 0 } -// StreamResult 流式响应结果 -type StreamResult struct { - Delta string - Think string - ToolCalls []ToolCall - Done bool `json:",omitempty"` - FinishReason FinishReason - Error error `json:",omitempty"` - Model string - ResponseID string - - Usage *Usage -} - // Usage token 使用统计 type Usage struct { InputTokens int diff --git a/pkg/services/llm/types_test.go b/pkg/services/llm/types_test.go index de3db58..cd66fd9 100644 --- a/pkg/services/llm/types_test.go +++ b/pkg/services/llm/types_test.go @@ -200,23 +200,6 @@ func TestChatResultHasToolCalls(t *testing.T) { } } -func TestStreamResultString(t *testing.T) { - // 测试 StreamResult 类型(用于覆盖) - result := StreamResult{ - Delta: "Hello", - ToolCalls: nil, - Done: false, - Error: nil, - } - - _ = result - - // 验证 Done 和 Error - result.Done = true - if !result.Done { - t.Error("expected Done to be true") - } -} func TestMessagesLogged_String(t *testing.T) { tests := []struct { diff --git a/pkg/services/runner/runner.go b/pkg/services/runner/runner.go new file mode 100644 index 0000000..eeecbe2 --- /dev/null +++ b/pkg/services/runner/runner.go @@ -0,0 +1,53 @@ +package runner + +import ( + "context" + "errors" + "fmt" + + "github.com/liut/morign/pkg/services/llm" +) + +// SessionStore 会话级状态存储接口。 +type SessionStore interface { + MergeDelta(ctx context.Context, sessionID string, delta map[string]any) error +} + +// HistoryStore 对话历史持久化接口。 +type HistoryStore interface { + AppendEvent(ctx context.Context, sessionID string, event *llm.Event) error + CreateUsageRecord(ctx context.Context, sessionID string, event *llm.Event) error +} + +// Runner 统一持久化入口:历史追加、状态合并、用量记录。 +type Runner struct { + sessionStore SessionStore + historyStore HistoryStore +} + +// New 创建 Runner。 +func New(sessionStore SessionStore, historyStore HistoryStore) *Runner { + return &Runner{ + sessionStore: sessionStore, + historyStore: historyStore, + } +} + +// Persist 持久化一个事件:追加历史、合并 StateDelta、记录用量。 +func (r *Runner) Persist(ctx context.Context, sessionID string, event *llm.Event) error { + if err := r.historyStore.AppendEvent(ctx, sessionID, event); err != nil { + return err + } + var errs []error + if len(event.Actions.StateDelta) > 0 && r.sessionStore != nil { + if err := r.sessionStore.MergeDelta(ctx, sessionID, event.Actions.StateDelta); err != nil { + errs = append(errs, fmt.Errorf("merge delta: %w", err)) + } + } + if event.Usage != nil { + if err := r.historyStore.CreateUsageRecord(ctx, sessionID, event); err != nil { + errs = append(errs, fmt.Errorf("create usage: %w", err)) + } + } + return errors.Join(errs...) +} diff --git a/pkg/services/stores/event_adapter.go b/pkg/services/stores/event_adapter.go new file mode 100644 index 0000000..a5de9e0 --- /dev/null +++ b/pkg/services/stores/event_adapter.go @@ -0,0 +1,92 @@ +package stores + +import ( + "context" + "time" + + oid "github.com/cupogo/andvari/models/oid" + + "github.com/liut/morign/pkg/models/aigc" + "github.com/liut/morign/pkg/models/convo" + "github.com/liut/morign/pkg/services/llm" + "github.com/liut/morign/pkg/services/runner" +) + +// NewHistoryStore 创建 runner.HistoryStore 的适配实现。 +func NewHistoryStore(sto Storage) runner.HistoryStore { + return &historyAdapter{sto: sto} +} + +// NewSessionStore 创建 runner.SessionStore 的适配实现。 +func NewSessionStore(sto Storage) runner.SessionStore { + return &sessionAdapter{sto: sto} +} + +type historyAdapter struct { + sto Storage +} + +func (a *historyAdapter) AppendEvent(ctx context.Context, sessionID string, event *llm.Event) error { + cs := NewConversation(ctx, sessionID) + + switch event.Author { + case "user": + return nil + default: + item := &aigc.HistoryItem{ + Time: time.Now().Unix(), + ChatItem: &aigc.HistoryChatItem{ + Assistant: event.Delta, + }, + } + item.ChatItem.Think = event.Think + if event.UserID != "" { + item.UID = event.UserID + } + if event.UserPrompt != "" { + item.ChatItem.User = event.UserPrompt + } + if err := cs.AddHistory(ctx, item); err != nil { + return err + } + return cs.Save(ctx) + } +} + +func (a *historyAdapter) CreateUsageRecord(ctx context.Context, sessionID string, event *llm.Event) error { + basic := convo.UsageRecordBasic{ + SessionID: oid.Cast(sessionID), + MsgCount: event.MsgCount, + InputTokens: event.Usage.InputTokens, + OutputTokens: event.Usage.OutputTokens, + TotalTokens: event.Usage.TotalTokens, + Model: event.Model, + } + if len(event.Meta) > 0 { + basic.MetaAddKVs(flattenDelta(event.Meta)...) + } + _, err := a.sto.Convo().CreateUsageRecord(ctx, basic) + return err +} + +type sessionAdapter struct { + sto Storage +} + +func (a *sessionAdapter) MergeDelta(ctx context.Context, sessionID string, delta map[string]any) error { + if len(delta) == 0 { + return nil + } + set := convo.SessionSet{} + set.MetaAddKVs(flattenDelta(delta)...) + return a.sto.Convo().UpdateSession(ctx, sessionID, set) +} + +// flattenDelta 将 map[string]any 展平为 key, value 交替的切片,供 MetaAddKVs 使用。 +func flattenDelta(delta map[string]any) []any { + args := make([]any, 0, len(delta)*2) + for k, v := range delta { + args = append(args, k, v) + } + return args +} diff --git a/pkg/services/stores/integration_test.go b/pkg/services/stores/integration_test.go index 77ed4c9..8039b2b 100644 --- a/pkg/services/stores/integration_test.go +++ b/pkg/services/stores/integration_test.go @@ -27,6 +27,7 @@ package stores import ( "context" "fmt" + "iter" "math/rand" "os" "testing" @@ -45,8 +46,8 @@ func (m *mockEmbeddingClient) Chat(ctx context.Context, messages []llm.Message, return nil, nil } -func (m *mockEmbeddingClient) StreamChat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (<-chan llm.StreamResult, error) { - return nil, nil +func (m *mockEmbeddingClient) StreamChat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) iter.Seq2[*llm.Event, error] { + return func(yield func(*llm.Event, error) bool) {} } func (m *mockEmbeddingClient) Generate(ctx context.Context, prompt string) (string, *llm.Usage, error) { @@ -338,3 +339,101 @@ func TestIntegration_ListMemories(t *testing.T) { t.Logf("Total memories: %d, returned: %d", total, len(data)) } + +func TestIntegration_HistoryStoreAppendEvent(t *testing.T) { + sto := Sgt() + ctx := context.Background() + + sessionID := oid.NewID(oid.OtEvent).String() + hs := NewHistoryStore(sto) + + event := &llm.Event{ + Author: "assistant", + Delta: "Hello, this is a test response", + Think: "The user is testing", + } + + if err := hs.AppendEvent(ctx, sessionID, event); err != nil { + t.Fatalf("AppendEvent failed: %v", err) + } + + // Verify history was written + cs := NewConversation(ctx, sessionID) + history, err := cs.ListHistory(ctx) + if err != nil { + t.Fatalf("ListHistory failed: %v", err) + } + if len(history) == 0 { + t.Fatal("expected at least one history item") + } + found := false + for _, h := range history { + if h.ChatItem != nil && h.ChatItem.Assistant == event.Delta { + found = true + break + } + } + if !found { + t.Errorf("history item with content %q not found", event.Delta) + } + + // Cleanup + _ = cs.ClearHistory(ctx) +} + +func TestIntegration_HistoryStoreCreateUsageRecord(t *testing.T) { + sto := Sgt() + ctx := context.Background() + + sessionID := oid.NewID(oid.OtEvent).String() + hs := NewHistoryStore(sto) + + event := &llm.Event{ + Usage: &llm.Usage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150}, + Model: "test-model", + MsgCount: 3, + } + + if err := hs.CreateUsageRecord(ctx, sessionID, event); err != nil { + t.Fatalf("CreateUsageRecord failed: %v", err) + } + + spec := &ConvoUsageRecordSpec{} + spec.Limit = 1 + data, total, err := sto.Convo().ListUsageRecord(ctx, spec) + if err != nil { + t.Fatalf("ListUsageRecord failed: %v", err) + } + t.Logf("Total usage records: %d", total) + _ = data +} + +func TestIntegration_SessionStoreMergeDelta(t *testing.T) { + sto := Sgt() + ctx := context.Background() + + sess := convo.NewSessionWithBasic(convo.SessionBasic{ + Title: "test-session-delta", + }) + if err := sto.Convo().SaveSession(ctx, sess); err != nil { + t.Fatalf("SaveSession failed: %v", err) + } + sessionID := sess.StringID() + + ss := NewSessionStore(sto) + delta := map[string]any{ + "last_tool": "kb_search", + "count": 3, + } + + if err := ss.MergeDelta(ctx, sessionID, delta); err != nil { + t.Fatalf("MergeDelta failed: %v", err) + } + + updated, err := sto.Convo().GetSession(ctx, sessionID) + if err != nil { + t.Fatalf("GetSession failed: %v", err) + } + _ = updated + t.Logf("Session meta updated for %s", sessionID) +} diff --git a/pkg/web/api/agent.go b/pkg/web/api/agent.go index d5ed9b0..df1742a 100644 --- a/pkg/web/api/agent.go +++ b/pkg/web/api/agent.go @@ -3,6 +3,7 @@ package api import ( "context" "fmt" + "iter" "strings" "github.com/liut/morign/pkg/services/llm" @@ -64,7 +65,7 @@ func (ag *Agent) BuildSystemMessage(ctx context.Context) (llm.Message, []llm.Too return llm.Message{Role: llm.RoleSystem, Content: sb.String()}, tools } -// Chat 非流式对话,支持工具调用循环 +// Chat 非流式对话,直接使用 llm.Chat + 工具调用循环。 func (ag *Agent) Chat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, error) { exec := func(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, *llm.Usage, error) { result, err := ag.llm.Chat(ctx, messages, tools) @@ -73,62 +74,78 @@ func (ag *Agent) Chat(ctx context.Context, messages []llm.Message, tools []llm.T } return result.Content, result.ToolCalls, result.Usage, nil } - answer, _, _, err := ag.toolExec.ExecuteToolCallLoop(ctx, messages, tools, exec) return answer, err } -// StreamChat 流式对话,支持工具调用循环 -// 通过 StreamCallbacks 回调输出 delta 和 think,返回最终完整回答文本 -func (ag *Agent) StreamChat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, cb StreamCallbacks) (string, error) { +// Run 以 iter.Seq2 方式执行对话,包含工具调用循环。 +func (ag *Agent) Run(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition) iter.Seq2[*llm.Event, error] { maxLoop := settings.Current.MaxLoopIterations if maxLoop <= 0 { maxLoop = 5 } - var fullAnswer string - var fullThink string + return func(yield func(*llm.Event, error) bool) { + var fullThink string - for iter := 0; iter < maxLoop; iter++ { - stream, err := ag.llm.StreamChat(ctx, messages, tools) - if err != nil { - return fullAnswer, fmt.Errorf("stream chat: %w", err) - } + for iter := 0; iter < maxLoop; iter++ { + var roundAnswer string + var roundThink string + var toolCalls []llm.ToolCall - var roundAnswer string - var roundThink string - var toolCalls []llm.ToolCall + for event, err := range ag.llm.StreamChat(ctx, messages, tools) { + if err != nil { + yield(nil, fmt.Errorf("stream chat: %w", err)) + return + } + roundAnswer += event.Delta + roundThink += event.Think - for result := range stream { - if result.Error != nil { - return fullAnswer, result.Error - } - if result.Delta != "" { - roundAnswer += result.Delta - if cb.OnDelta != nil { - cb.OnDelta(result.Delta) + if !yield(event, nil) { + return } - } - if result.Think != "" { - roundThink += result.Think - if cb.OnThink != nil { - cb.OnThink(result.Think) + + if event.Done { + toolCalls = event.ToolCalls } } - if result.Done { - toolCalls = result.ToolCalls - } - } - fullAnswer += roundAnswer - fullThink += roundThink + fullThink += roundThink - if len(toolCalls) == 0 { - break - } + if len(toolCalls) == 0 { + return + } - messages, _ = ag.toolExec.ExecuteToolCalls(ctx, messages, toolCalls, roundThink) + // 执行工具调用,产生 tool result events + events, updatedMsgs := ag.toolExec.ExecuteToolCalls(ctx, messages, toolCalls, roundThink) + messages = updatedMsgs + for _, ev := range events { + ev.Author = "tool" + if !yield(ev, nil) { + return + } + } + } } +} +// StreamChat 流式对话,通过 StreamCallbacks 回调输出,返回最终文本。 +func (ag *Agent) StreamChat(ctx context.Context, messages []llm.Message, tools []llm.ToolDefinition, cb StreamCallbacks) (string, error) { + var fullAnswer string + for event, err := range ag.Run(ctx, messages, tools) { + if err != nil { + return fullAnswer, fmt.Errorf("stream chat: %w", err) + } + fullAnswer += event.Delta + if event.Delta != "" && cb.OnDelta != nil { + cb.OnDelta(event.Delta) + } + if event.Think != "" && cb.OnThink != nil { + cb.OnThink(event.Think) + } + if event.StopReason == llm.FinishReasonToolCalls { + cb.OnThink("\n") + } + } return fullAnswer, nil } diff --git a/pkg/web/api/api.go b/pkg/web/api/api.go index d9a9846..7011438 100644 --- a/pkg/web/api/api.go +++ b/pkg/web/api/api.go @@ -18,6 +18,7 @@ import ( "github.com/liut/morign/pkg/models/aigc" "github.com/liut/morign/pkg/models/mcps" "github.com/liut/morign/pkg/services/llm" + "github.com/liut/morign/pkg/services/runner" "github.com/liut/morign/pkg/services/stores" "github.com/liut/morign/pkg/services/tools" "github.com/liut/morign/pkg/settings" @@ -51,6 +52,7 @@ type api struct { preset aigc.Preset toolreg *tools.Registry toolExec *ToolExecutor + rnr *runner.Runner } func init() { @@ -102,13 +104,13 @@ func newapi(sto stores.Storage) *api { if err != nil { logger().Fatalw("create llm interact client failed", "err", err) } - return &api{ sto: sto, llm: llmClient, preset: preset, toolreg: toolreg, toolExec: NewToolExecutor(toolreg), + rnr: runner.New(stores.NewSessionStore(sto), stores.NewHistoryStore(sto)), } } diff --git a/pkg/web/api/handle_convo.go b/pkg/web/api/handle_convo.go index 52842ef..7767eed 100644 --- a/pkg/web/api/handle_convo.go +++ b/pkg/web/api/handle_convo.go @@ -59,26 +59,6 @@ type chatRequest struct { prompt string } -func (cr *chatRequest) gatherUsage(res chatResponse) convo.UsageRecordBasic { - sub := convo.UsageRecordBasic{ - SessionID: cr.cs.GetOID(), - MsgCount: len(cr.messages), - InputTokens: res.usage.InputTokens, - OutputTokens: res.usage.OutputTokens, - TotalTokens: res.usage.TotalTokens, - Model: res.model, - } - sub.MetaAddKVs("prompt", cr.prompt, - "answerHead", words.TakeHead(res.answer, 12, ".."), - "answerTail", words.TakeTail(res.answer, 15, ".."), - ) - if len(res.llmResID) > 0 { - sub.MetaAddKVs("resposeID", res.llmResID) - } - - return sub -} - // convertMCPToolsToLLMTools 将 MCP 工具描述转换为 LLM 工具定义 func convertMCPToolsToLLMTools(tools []mcps.ToolDescriptor) []llm.ToolDefinition { result := make([]llm.ToolDefinition, 0, len(tools)) @@ -380,11 +360,11 @@ func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r logger().Infow("before execute tool calls", "tools", len(streamRes.toolCalls), "msgs", len(ccr.messages), "think_len", len(streamRes.think)) - var hasToolCall bool // 执行工具调用,传入 reasoning_content 以便回传 - ccr.messages, hasToolCall = a.toolExec.ExecuteToolCalls(cctx, ccr.messages, streamRes.toolCalls, streamRes.think) - logger().Infow("executed tool calls", "hasToolCall", hasToolCall, "msgs", len(ccr.messages)) - if !hasToolCall { + evs, msgs := a.toolExec.ExecuteToolCalls(cctx, ccr.messages, streamRes.toolCalls, streamRes.think) + ccr.messages = msgs + logger().Infow("executed tool calls", "executed", len(evs), "msgs", len(ccr.messages)) + if len(evs) == 0 { // 没有成功执行任何工具,跳出循环 res.finish = streamRes.finish break @@ -393,12 +373,13 @@ func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r } if len(res.answer) > 0 { - ccr.hi.ChatItem.Assistant = res.answer - ccr.hi.ChatItem.Think = res.think - if err := ccr.cs.AddHistory(r.Context(), ccr.hi); err == nil { - if err = ccr.cs.Save(r.Context()); err != nil { - logger().Infow("save convo fail", "err", err) - } + if err := a.rnr.Persist(r.Context(), ccr.cs.GetID(), &llm.Event{ + Author: "assistant", + Delta: res.answer, + Think: res.think, + UserPrompt: ccr.prompt, + }); err != nil { + logger().Infow("persist fail", "err", err) } } @@ -430,34 +411,27 @@ func (a *api) chatStreamResponseLoop(ccr *chatRequest, w http.ResponseWriter, r // doChatStream 执行一次流式调用,返回累积的 answer 和 toolCalls func (a *api) doChatStream(ccr *chatRequest, w http.ResponseWriter, r *http.Request) chatResponse { - stream, err := a.llm.StreamChat(r.Context(), ccr.messages, ccr.tools) - if err != nil { - logger().Infow("call chat stream fail", "err", err) - apiFail(w, r, 500, err) - return chatResponse{} - } - var res chatResponse var lastWriteEmpty bool // 标记上一次是否写入了空消息 - for result := range stream { - var cm ChatMessage - - if result.Error != nil { - logger().Infow("stream error", "err", result.Error) + for result, err := range a.llm.StreamChat(r.Context(), ccr.messages, ccr.tools) { + if err != nil { + logger().Infow("stream error", "err", err) break } + var cm ChatMessage + cm.Delta = result.Delta cm.Think = result.Think res.answer += result.Delta res.think += result.Think res.usage = result.Usage - if len(result.ToolCalls) > 0 && result.FinishReason == llm.FinishReasonToolCalls { + if len(result.ToolCalls) > 0 && result.StopReason == llm.FinishReasonToolCalls { cm.ToolCalls = convertToolCallsForJSON(result.ToolCalls) ccr.chunkIdx++ cm.ConversationID = ccr.cs.GetID() - cm.FinishReason = string(result.FinishReason) + cm.FinishReason = string(result.StopReason) _ = writeEvent(w, strconv.Itoa(ccr.chunkIdx), &cm) } @@ -469,29 +443,41 @@ func (a *api) doChatStream(ccr *chatRequest, w http.ResponseWriter, r *http.Requ } if result.Done { - logger().Infow("result done", "finish", result.FinishReason) - res.finish = result.FinishReason - } else { - // 判断当前是否为空消息 - isEmpty := result.Delta == "" && len(cm.ToolCalls) == 0 - if !isEmpty || !lastWriteEmpty { - // 有内容,或者上一次不是空的,则输出 - ccr.chunkIdx++ - if wrote := writeEvent(w, strconv.Itoa(ccr.chunkIdx), &cm); !wrote { - break - } - } - // 如果当前是空的且上一次也是空的,跳过(连续空消息只保留第一个) - } - - if result.Done { // 只使用最后拼接的完整信息 + logger().Infow("result done", "finish", result.StopReason) + res.finish = result.StopReason res.toolCalls = result.ToolCalls break } + + // 判断当前是否为空消息 + isEmpty := result.Delta == "" && result.Think == "" && len(cm.ToolCalls) == 0 + if !isEmpty || !lastWriteEmpty { + // 有内容,或者上一次不是空的,则输出 + ccr.chunkIdx++ + if wrote := writeEvent(w, strconv.Itoa(ccr.chunkIdx), &cm); !wrote { + break + } + } + // 如果当前是空的且上一次也是空的,跳过(连续空消息只保留第一个) + lastWriteEmpty = isEmpty } if res.usage != nil { - if _, err := a.sto.Convo().CreateUsageRecord(r.Context(), ccr.gatherUsage(res)); err != nil { - logger().Infow("create session usage fail", "err", err) + meta := map[string]any{ + "prompt": ccr.prompt, + "answerHead": words.TakeHead(res.answer, 12, ".."), + "answerTail": words.TakeTail(res.answer, 15, ".."), + } + if len(res.llmResID) > 0 { + meta["resposeID"] = res.llmResID + } + if err := a.rnr.Persist(r.Context(), ccr.cs.GetID(), &llm.Event{ + Author: "assistant", + Usage: res.usage, + Model: res.model, + MsgCount: len(ccr.messages), + Meta: meta, + }); err != nil { + logger().Infow("persist usage fail", "err", err) } } logger().Infow("chat stream done", "finish", res.finish, "answer", len(res.answer), diff --git a/pkg/web/api/handle_platform.go b/pkg/web/api/handle_platform.go index 3568780..6c67cab 100644 --- a/pkg/web/api/handle_platform.go +++ b/pkg/web/api/handle_platform.go @@ -4,7 +4,6 @@ import ( "context" "log/slog" "strings" - "time" "github.com/go-chi/chi/v5" @@ -14,6 +13,7 @@ import ( "github.com/liut/morign/pkg/services/channels/feishu" "github.com/liut/morign/pkg/services/channels/wecom" "github.com/liut/morign/pkg/services/llm" + "github.com/liut/morign/pkg/services/runner" "github.com/liut/morign/pkg/services/stores" "github.com/liut/morign/pkg/services/tools" "github.com/liut/morign/pkg/settings" @@ -25,6 +25,7 @@ type channelHandler struct { llm llm.Client toolreg *tools.Registry toolExec *ToolExecutor + rnr *runner.Runner } // InitChannels initializes channel adapters from preset configuration. @@ -34,16 +35,17 @@ func InitChannels(r chi.Router, preset *aigc.Preset, sto stores.Storage, llmClie channels.RegisterChannel("feishu", feishu.New) channels.RegisterChannel("wecom", wecom.New) + if preset == nil || len(preset.Channels) == 0 { + slog.Info("channel: no platforms configured") + return nil + } + chandler := &channelHandler{ sto: sto, llm: llmClient, toolreg: toolreg, toolExec: NewToolExecutor(toolreg), - } - - if preset == nil || len(preset.Channels) == 0 { - slog.Info("channel: no platforms configured") - return nil + rnr: runner.New(stores.NewSessionStore(sto), stores.NewHistoryStore(sto)), } for name, cfg := range preset.Channels { @@ -217,9 +219,10 @@ func (chh *channelHandler) handleStreamingReply(ctx context.Context, p channel.C } // Execute tool calls and update messages with results - var hasToolCall bool - messages, hasToolCall = chh.toolExec.ExecuteToolCalls(ctx, messages, toolCalls, "") - if !hasToolCall { + evs, msgs := chh.toolExec.ExecuteToolCalls(ctx, messages, toolCalls, "") + messages = msgs + if len(evs) == 0 { + // 没有成功执行任何工具,跳出循环 break } } @@ -237,18 +240,13 @@ func (chh *channelHandler) handleStreamingReply(ctx context.Context, p channel.C // Save to history if fullAnswer != "" { - hi := &aigc.HistoryItem{ - Time: time.Now().Unix(), - UID: msg.UserID, - ChatItem: &aigc.HistoryChatItem{ - User: msg.Content, - Assistant: fullAnswer, - }, - } - if err := cs.AddHistory(ctx, hi); err == nil { - if err := cs.Save(ctx); err != nil { - slog.Warn("channel: save history failed", "err", err) - } + if err := chh.rnr.Persist(ctx, cs.GetID(), &llm.Event{ + Author: "assistant", + Delta: fullAnswer, + UserID: msg.UserID, + UserPrompt: msg.Content, + }); err != nil { + slog.Warn("channel: persist history failed", "err", err) } } } @@ -256,20 +254,14 @@ func (chh *channelHandler) handleStreamingReply(ctx context.Context, p channel.C // doChannelStream performs one streaming chat round, returns answer, tool calls, and streamID. // The stream lifecycle (Start/Finish) is managed by the caller (handleStreamingReply). func (chh *channelHandler) doChannelStream(ctx context.Context, p channel.Channel, msg *channel.Message, sr channel.StreamReplier, streamID string, messages []llm.Message, tools []llm.ToolDefinition) (string, []llm.ToolCall, error) { - stream, err := chh.llm.StreamChat(ctx, messages, tools) - if err != nil { - slog.Error("channel: stream chat failed", "channel", p.Name(), "error", err) - return "", nil, err - } - var contentBuilder strings.Builder var currentToolCalls []llm.ToolCall chunkCount := 0 - for result := range stream { + for result, err := range chh.llm.StreamChat(ctx, messages, tools) { chunkCount++ - if result.Error != nil { - slog.Warn("channel: stream error", "err", result.Error) + if err != nil { + slog.Warn("channel: stream error", "err", err) break } @@ -295,8 +287,6 @@ func (chh *channelHandler) doChannelStream(ctx context.Context, p channel.Channe "toolCalls_len", len(currentToolCalls), "streamID", streamID) - // Stream ended (EOF). If Done was never true, there are no tool calls. - // currentToolCalls remains nil, which is correct. return contentBuilder.String(), currentToolCalls, nil } @@ -322,18 +312,13 @@ func (chh *channelHandler) handleRegularReply(ctx context.Context, p channel.Cha // Save to history (only final answer, not tool call content) if len(answer) > 0 { - hi := &aigc.HistoryItem{ - Time: time.Now().Unix(), - UID: msg.UserID, - ChatItem: &aigc.HistoryChatItem{ - User: msg.Content, - Assistant: answer, - }, - } - if err := cs.AddHistory(ctx, hi); err == nil { - if err := cs.Save(ctx); err != nil { - slog.Warn("channel: save history failed", "err", err) - } + if err := chh.rnr.Persist(ctx, cs.GetID(), &llm.Event{ + Author: "assistant", + Delta: answer, + UserID: msg.UserID, + UserPrompt: msg.Content, + }); err != nil { + slog.Warn("channel: persist history failed", "err", err) } } diff --git a/pkg/web/api/tool_executor.go b/pkg/web/api/tool_executor.go index d2020ad..df66bae 100644 --- a/pkg/web/api/tool_executor.go +++ b/pkg/web/api/tool_executor.go @@ -3,6 +3,7 @@ package api import ( "context" "encoding/json" + "time" "github.com/liut/morign/pkg/services/llm" "github.com/liut/morign/pkg/services/tools" @@ -39,14 +40,19 @@ func (e *ToolExecutor) ExecuteToolCallLoop( return answer, nil, usage, nil } - messages, _ = e.ExecuteToolCalls(ctx, messages, toolCalls, "") + evs, msgs := e.ExecuteToolCalls(ctx, messages, toolCalls, "") + messages = msgs + if len(evs) == 0 { + // 没有成功执行任何工具,跳出循环 + return answer, toolCalls, usage, nil + } } } -// ExecuteToolCalls 执行单轮工具调用,think 用于 DeepSeek reasoning_content 回传 -func (e *ToolExecutor) ExecuteToolCalls(ctx context.Context, messages []llm.Message, toolCalls []llm.ToolCall, think string) ([]llm.Message, bool) { +// ExecuteToolCalls 执行单轮工具调用,返回事件列表和更新后的消息。 +func (e *ToolExecutor) ExecuteToolCalls(ctx context.Context, messages []llm.Message, toolCalls []llm.ToolCall, think string) ([]*llm.Event, []llm.Message) { if len(toolCalls) == 0 { - return messages, false + return nil, messages } messages = append(messages, llm.Message{ @@ -55,7 +61,7 @@ func (e *ToolExecutor) ExecuteToolCalls(ctx context.Context, messages []llm.Mess ToolCalls: toolCalls, }) - var hasToolCall bool + var events []*llm.Event for _, tc := range toolCalls { logger().Infow("chat", "toolCallID", tc.ID, "toolCallType", tc.Type, "toolCallName", tc.Function.Name) @@ -83,13 +89,24 @@ func (e *ToolExecutor) ExecuteToolCalls(ctx context.Context, messages []llm.Mess logger().Infow("invokeTool ok", "toolCallName", tc.Function.Name, "content", toolsvc.ResultLogs(content)) + toolResult := formatToolResult(content) messages = append(messages, llm.Message{ Role: llm.RoleTool, - Content: formatToolResult(content), + Content: toolResult, ToolCallID: tc.ID, }) - hasToolCall = true + + events = append(events, &llm.Event{ + ID: llm.NewEventID(), + Timestamp: time.Now(), + Author: tc.Function.Name, + ToolResult: &llm.ToolResult{ + CallID: tc.ID, + Name: tc.Function.Name, + Content: toolResult, + }, + }) } - return messages, hasToolCall + return events, messages }