diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d310b33..5d3b979 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,14 +14,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: '1.19' + go-version: '1.26' - name: Run tests run: go test -v ./... diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 379d900..b2ee6c3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,15 +11,15 @@ jobs: strategy: matrix: os: [ubuntu-latest, macos-latest] - go-version: ['1.19', '1.20', '1.21'] + go-version: ['1.26'] runs-on: ${{ matrix.os }} steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} @@ -39,22 +39,22 @@ jobs: - name: Run tests run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... - - name: Upload coverage - if: matrix.os == 'ubuntu-latest' && matrix.go-version == '1.21' - uses: codecov/codecov-action@v3 - with: - file: ./coverage.txt + # - name: Upload coverage + # if: matrix.os == 'ubuntu-latest' && matrix.go-version == '1.26' + # uses: codecov/codecov-action@v3 + # with: + # file: ./coverage.txt # lint: # runs-on: ubuntu-latest # steps: # - name: Checkout code - # uses: actions/checkout@v4 + # uses: actions/checkout@v5 # - name: Set up Go - # uses: actions/setup-go@v5 + # uses: actions/setup-go@v6 # with: - # go-version: '1.21' + # go-version: '1.26' # - name: Run golangci-lint # uses: golangci/golangci-lint-action@v3 @@ -85,12 +85,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: - go-version: '1.21' + go-version: '1.26' - name: Build env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b453db..b7cc2d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Built-in tools: file operations, shell execution, code search - Permission confirmation system with risk classification - Multi-LLM provider support (Anthropic Claude, OpenAI GPT) -- Project context awareness (AICODER.md, Git status, dependencies) +- Project context awareness (.AICODER.md, Git status, dependencies) - 11 slash commands: /help, /clear, /history, /undo, /diff, /commit, /cost, /model, /config, /init, /exit - Session management with snapshots and undo - Token usage tracking and cost estimation diff --git a/README.md b/README.md index 812c807..c77bb58 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ - **内置工具集**:文件读写编辑、Shell 命令执行、全局代码搜索 - **权限确认机制**:危险操作前要求用户确认,支持黑名单和白名单 - **多 LLM 提供商**:Anthropic Claude、OpenAI GPT(兼容任何 OpenAI 格式端点) -- **项目上下文感知**:自动读取 `AICODER.md`、Git 状态、项目依赖 +- **项目上下文感知**:自动读取 `.AICODER.md`、Git 状态、项目依赖 - **斜杠命令**:`/diff`、`/undo`、`/commit`、`/cost` 等 11 个内置命令 - **纯 Go 标准库**:无外部依赖,单二进制,跨平台 @@ -113,7 +113,7 @@ $ aicoder | `/cost` | 查看 Token 用量和费用估算 | | `/model [name]` | 查看或切换 AI 模型 | | `/config` | 查看当前配置 | -| `/init` | 在当前目录生成 AICODER.md 模板 | +| `/init` | 在当前目录生成 .AICODER.md 模板 | | `/exit` | 退出程序 | --- @@ -143,9 +143,9 @@ $ aicoder --- -## AICODER.md 项目配置 +## .AICODER.md 项目配置 -在项目根目录创建 `AICODER.md`(或运行 `/init`),AI 会在每次会话开始时自动加载它作为项目级系统提示词: +在项目根目录创建 `.AICODER.md`(或运行 `/init`),AI 会在每次会话开始时自动加载它作为项目级系统提示词: ```markdown # 项目说明 diff --git a/cmd/root.go b/cmd/root.go index 19f3e74..72f7176 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -10,13 +10,14 @@ import ( "strings" "syscall" - "github.com/chzyer/readline" "github.com/iminders/aicoder/internal/agent" "github.com/iminders/aicoder/internal/config" "github.com/iminders/aicoder/internal/llm" anthropicprovider "github.com/iminders/aicoder/internal/llm/anthropic" + deepseekprovider "github.com/iminders/aicoder/internal/llm/deepseek" openaiprovider "github.com/iminders/aicoder/internal/llm/openai" "github.com/iminders/aicoder/internal/logger" + "github.com/iminders/aicoder/internal/skills" "github.com/iminders/aicoder/internal/slash" "github.com/iminders/aicoder/internal/ui" "github.com/iminders/aicoder/pkg/version" @@ -25,6 +26,7 @@ import ( _ "github.com/iminders/aicoder/internal/tools/filesystem" _ "github.com/iminders/aicoder/internal/tools/search" _ "github.com/iminders/aicoder/internal/tools/shell" + ) // flags holds CLI flag values. @@ -106,6 +108,11 @@ func Execute() { // Init logger logger.Init(flags.verbose) + // Load skills (built-ins + user custom) + if err := skills.Load(); err != nil { + logger.Warn("skill load error: %v", err) + } + // Build provider provider, err := buildProvider(cfg) if err != nil { @@ -170,88 +177,8 @@ func runOneShot(a *agent.Agent, prompt string) { } func runInteractive(a *agent.Agent, cfg *config.Config) { - slashHandler := slash.NewHandler(a.Session(), cfg) - - // Setup readline with tab completion - completer := readline.NewPrefixCompleter() - for _, cmd := range slash.AllCommands() { - completer.Children = append(completer.Children, readline.PcItem(cmd.Name)) - } - - rl, err := readline.NewEx(&readline.Config{ - Prompt: "\033[1;34m> \033[0m", - HistoryFile: getHistoryFile(), - AutoComplete: completer, - InterruptPrompt: "^C", - EOFPrompt: "exit", - }) - if err != nil { - // Fallback to basic reader if readline fails - runInteractiveBasic(a, cfg) - return - } - defer rl.Close() - - // Setup signal handling for Ctrl+C (interrupt current task, not exit) - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt) - - var cancelCurrent context.CancelFunc - - go func() { - for range sigCh { - if cancelCurrent != nil { - fmt.Println("\n\033[33m[任务已中断,输入新的指令继续]\033[0m") - cancelCurrent() - cancelCurrent = nil - } - } - }() - - for { - line, err := rl.Readline() - if err != nil { - if err == readline.ErrInterrupt { - if cancelCurrent != nil { - cancelCurrent() - cancelCurrent = nil - } - continue - } else if err == io.EOF { - fmt.Println("\n再见!") - break - } - continue - } - input := strings.TrimSpace(line) - if input == "" { - continue - } - - // Handle slash commands - if strings.HasPrefix(input, "/") { - handled, shouldExit := slashHandler.Handle(input) - if shouldExit { - fmt.Println("再见!") - return - } - if handled { - continue - } - } - - // Run agent - ctx, cancel := context.WithCancel(context.Background()) - cancelCurrent = cancel - - ui.PrintDivider() - if err := a.Run(ctx, input); err != nil && ctx.Err() == nil { - ui.PrintError(err.Error()) - } - cancel() - cancelCurrent = nil - ui.PrintDivider() - } + // Use simple interactive mode (bubbletea TUI conflicts with agent's streaming output) + runInteractiveBasic(a, cfg) } func isPipeInput() bool { @@ -276,8 +203,15 @@ func buildProvider(cfg *config.Config) (llm.Provider, error) { return nil, fmt.Errorf("未找到 OpenAI API Key,请设置 OPENAI_API_KEY 环境变量") } return openaiprovider.New(apiKey, cfg.BaseURL, cfg.Model), nil + case "deepseek": + // For local deployments, API key is optional + // If baseURL is set to localhost/127.0.0.1, allow empty API key + if apiKey == "" && (cfg.BaseURL == "" || (!strings.Contains(cfg.BaseURL, "localhost") && !strings.Contains(cfg.BaseURL, "127.0.0.1"))) { + return nil, fmt.Errorf("未找到 DeepSeek API Key,请设置 DEEPSEEK_API_KEY 环境变量") + } + return deepseekprovider.New(apiKey, cfg.BaseURL, cfg.Model), nil default: - return nil, fmt.Errorf("不支持的 provider: %s (支持: anthropic, openai)", cfg.Provider) + return nil, fmt.Errorf("不支持的 provider: %s (支持: anthropic, openai, deepseek)", cfg.Provider) } } @@ -315,15 +249,38 @@ func runInteractiveBasic(a *agent.Agent, cfg *config.Config) { for { fmt.Print("\033[1;34m> \033[0m") - line, err := reader.ReadString('\n') - if err != nil { - if err == io.EOF { - fmt.Println("\n再见!") + + // Read input with multi-line support + // Use Ctrl+D (EOF) to submit, or empty line after content + var inputLines []string + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + if len(inputLines) > 0 { + // Submit accumulated input + break + } + fmt.Println("\n再见!") + return + } + continue + } + + trimmed := strings.TrimSpace(line) + + // If empty line and we have content, submit + if trimmed == "" && len(inputLines) > 0 { break } - continue + + // If not empty, accumulate + if trimmed != "" { + inputLines = append(inputLines, line) + } } - input := strings.TrimSpace(line) + + input := strings.TrimSpace(strings.Join(inputLines, "")) if input == "" { continue } @@ -336,6 +293,23 @@ func runInteractiveBasic(a *agent.Agent, cfg *config.Config) { return } if handled { + // Check if /skill queued a skill run + if a.Session().PendingSkillName != "" { + skillName := a.Session().PendingSkillName + prompt := a.Session().PendingPrompt + a.Session().PendingSkillName = "" + a.Session().PendingPrompt = "" + + ctx, cancel := context.WithCancel(context.Background()) + cancelCurrent = cancel + ui.PrintDivider() + if err := a.RunWithSkillByName(ctx, prompt, skillName); err != nil && ctx.Err() == nil { + ui.PrintError(err.Error()) + } + cancel() + cancelCurrent = nil + ui.PrintDivider() + } continue } } diff --git a/docs/PHASE6_7_COMPLETION.md b/docs/PHASE6_7_COMPLETION.md index 4c03bb7..ffab8c1 100644 --- a/docs/PHASE6_7_COMPLETION.md +++ b/docs/PHASE6_7_COMPLETION.md @@ -28,7 +28,7 @@ | `/cost` | ✅ | 展示 Token 消耗和费用估算 | | `/model [name]` | ✅ | 查看或热切换 AI 模型 | | `/config [set key value]` | ✅ | 查看或修改配置并持久化 | -| `/init` | ✅ | 生成 AICODER.md 模板 | +| `/init` | ✅ | 生成 .AICODER.md 模板 | | `/exit`, `/quit`, `/q` | ✅ | 优雅退出程序 | #### 3. Tab 补全功能 (`internal/slash/completion.go`) 🆕 diff --git a/docs/PHASE6_7_SUMMARY.md b/docs/PHASE6_7_SUMMARY.md index 3691fb7..d735e36 100644 --- a/docs/PHASE6_7_SUMMARY.md +++ b/docs/PHASE6_7_SUMMARY.md @@ -62,8 +62,8 @@ This document summarizes the implementation of Phase 6 (Slash Commands) and Phas - Validates values (e.g., provider must be anthropic/openai) - Persists changes to user config file -10. **`/init`** - Initialize AICODER.md template - - Creates AICODER.md in current directory +10. **`/init`** - Initialize .AICODER.md template + - Creates .AICODER.md in current directory - Includes sections: Project Description, Code Standards, Common Commands, Notes - Adds version and timestamp footer diff --git a/docs/arch.md b/docs/arch.md index 7a9d0c4..af486a8 100644 --- a/docs/arch.md +++ b/docs/arch.md @@ -104,7 +104,7 @@ aicoder/ │ │ ├── collector.go # 上下文收集主入口 │ │ ├── git.go # Git 信息(status/diff/log) │ │ ├── project.go # 项目语言/依赖检测 -│ │ ├── aicoder_md.go # AICODER.md 加载与解析 +│ │ ├── aicoder_md.go # .AICODER.md 加载与解析 │ │ └── summarizer.go # 目录结构摘要生成 │ │ │ ├── session/ # 会话管理 @@ -659,7 +659,7 @@ type FileSnapshot struct { **实现状态:** ✅ 已完成 **收集的信息:** -- AICODER.md 内容 +- .AICODER.md 内容 - Git 状态 (分支、修改、最近提交) - 项目类型检测 (Go/Node.js/Python/Rust/Java/Ruby) - 项目根目录 @@ -667,7 +667,7 @@ type FileSnapshot struct { **系统提示词构成:** ``` 基础角色定义 -+ AICODER.md (项目说明) ++ .AICODER.md (项目说明) + 项目环境信息 + Git 状态 ``` @@ -753,7 +753,7 @@ type FileSnapshot struct { **待实现:** - ⏳ 多 Agent 并行任务 - ⏳ Web Dashboard -- ⏳ AICODER.md 模板市场 +- ⏳ .AICODER.md 模板市场 - ⏳ VS Code 插件 ### 8.4 技术债务和改进空间 diff --git a/docs/code-guide.md b/docs/code-guide.md index 34452d6..594d729 100644 --- a/docs/code-guide.md +++ b/docs/code-guide.md @@ -488,7 +488,7 @@ func Collect() (*ProjectContext, error) 收集内容: 1. 查找项目根目录 -2. 读取 AICODER.md +2. 读取 .AICODER.md 3. 收集 Git 信息 4. 检测项目类型 @@ -500,7 +500,7 @@ func (pc *ProjectContext) BuildSystemPrompt() string 组合: - 基础角色定义 -- AICODER.md 内容 +- .AICODER.md 内容 - 项目环境信息 - Git 状态 diff --git a/docs/prd.md b/docs/prd.md index f74dc8b..74e6c25 100644 --- a/docs/prd.md +++ b/docs/prd.md @@ -130,16 +130,16 @@ AI 能够调用一组内置工具来自主完成任务,工具调用前须向 | `/cost` | 查看当前会话已消耗的 Token 及费用估算 | | `/model` | 切换使用的 AI 模型 | | `/config` | 查看或修改配置项 | -| `/init` | 在当前项目初始化 `aicoder` 配置(生成 `AICODER.md`) | +| `/init` | 在当前项目初始化 `aicoder` 配置(生成 `.AICODER.md`) | | `/exit` | 退出程序 | --- ### 2.4 项目上下文感知(Project Context) -#### 2.4.1 AICODER.md +#### 2.4.1 .AICODER.md -项目根目录下的 `AICODER.md` 文件作为**项目级系统提示词**,AI 在每次会话开始时自动加载: +项目根目录下的 `.AICODER.md` 文件作为**项目级系统提示词**,AI 在每次会话开始时自动加载: ```markdown # 项目说明 @@ -394,7 +394,7 @@ go install github.com/iminders/aicoder@latest - [ ] 多 Agent 并行任务 - [ ] 可视化 Web Dashboard(会话历史、费用分析) -- [ ] 团队共享 `AICODER.md` 模板市场 +- [ ] 团队共享 `.AICODER.md` 模板市场 - [ ] IDE 插件(VS Code)集成本 CLI --- diff --git a/docs/todo.md b/docs/todo.md index 2f8c217..942232c 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -339,7 +339,7 @@ - [ ] 多 Agent 并行任务(任务拆解 + 子 Agent 协同) - [ ] Web Dashboard(会话历史可视化、Token 用量分析) -- [ ] `AICODER.md` 模板市场(团队共享最佳实践) +- [ ] `.AICODER.md` 模板市场(团队共享最佳实践) - [ ] VS Code 插件(内嵌本 CLI,提供侧边栏 UI) - [ ] 企业版:SSO 集成、操作审计日志导出、私有部署 diff --git a/go.mod b/go.mod index 69dd7c6..4ae9283 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,12 @@ module github.com/iminders/aicoder -go 1.19 +go 1.24.0 require ( github.com/charmbracelet/bubbletea v0.25.0 github.com/charmbracelet/glamour v0.6.0 github.com/charmbracelet/lipgloss v0.9.1 - github.com/chzyer/readline v1.5.1 + github.com/mattn/go-isatty v0.0.20 ) require ( @@ -17,21 +17,20 @@ require ( github.com/dlclark/regexp2 v1.4.0 // indirect github.com/gorilla/css v1.0.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-localereader v0.0.1 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect github.com/microcosm-cc/bluemonday v1.0.21 // indirect - github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.15.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect - github.com/rivo/uniseg v0.2.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/yuin/goldmark v1.5.2 // indirect github.com/yuin/goldmark-emoji v1.0.1 // indirect golang.org/x/net v0.0.0-20221002022538-bcab6841153b // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.36.0 // indirect golang.org/x/term v0.6.0 // indirect golang.org/x/text v0.3.8 // indirect ) diff --git a/go.sum b/go.sum index bac4741..ae54981 100644 --- a/go.sum +++ b/go.sum @@ -11,12 +11,6 @@ github.com/charmbracelet/glamour v0.6.0 h1:wi8fse3Y7nfcabbbDuwolqTqMQPMnVPeZhDM2 github.com/charmbracelet/glamour v0.6.0/go.mod h1:taqWV4swIMMbWALc0m7AfE9JkPSU8om2538k9ITBxOc= github.com/charmbracelet/lipgloss v0.9.1 h1:PNyd3jvaJbg4jRHKWXnCj1akQm4rh8dbEzN1p/u1KWg= github.com/charmbracelet/lipgloss v0.9.1/go.mod h1:1mPmG4cxScwUQALAAnacHaigiiHB9Pmr+v1VEawJl6I= -github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= -github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= -github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= -github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= -github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY= github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -29,33 +23,34 @@ github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= -github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/microcosm-cc/bluemonday v1.0.21 h1:dNH3e4PSyE4vNX+KlRGHT5KrSvjeUkoNPwEORjffHJg= github.com/microcosm-cc/bluemonday v1.0.21/go.mod h1:ytNkv4RrDrLJ2pqlsSI46O6IVXmZOBBD4SaJyDwwTkM= -github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34= -github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.13.0/go.mod h1:sP1+uffeLaEYpyOTb8pLCUctGcGLnoFjSn4YJK5e2bc= -github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= -github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -66,16 +61,15 @@ github.com/yuin/goldmark-emoji v1.0.1 h1:ctuWEyzGBwiucEqxzwe0SOYDXPAucOrE9NQC18W github.com/yuin/goldmark-emoji v1.0.1/go.mod h1:2w1E6FEWLcDQkoTE+7HU6QF1F6SLlNGjRIBbIZQFqkQ= golang.org/x/net v0.0.0-20221002022538-bcab6841153b h1:6e93nYa3hNqAvLr0pD4PN1fFS+gKzp2zAXqrnTCstqU= golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 890bbec..e6f6790 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -3,6 +3,7 @@ package agent import ( "bufio" "context" + "encoding/json" "fmt" "os" "strings" @@ -13,8 +14,10 @@ import ( "github.com/iminders/aicoder/internal/llm" "github.com/iminders/aicoder/internal/logger" "github.com/iminders/aicoder/internal/session" + "github.com/iminders/aicoder/internal/skills" "github.com/iminders/aicoder/internal/tools" "github.com/iminders/aicoder/internal/ui" + "github.com/mattn/go-isatty" ) // Agent orchestrates the full LLM ↔ tool loop. @@ -55,10 +58,45 @@ func New(cfg *config.Config, provider llm.Provider) *Agent { func (a *Agent) Session() *session.Session { return a.sess } // Run processes one user turn and drives the Agent Loop until the model stops. +// It auto-detects applicable skills and injects their prompts. func (a *Agent) Run(ctx context.Context, userInput string) error { + // Auto-detect skill from input + if skill := skills.Global.Detect(userInput); skill != nil { + return a.RunWithSkill(ctx, userInput, skill) + } + return a.run(ctx, userInput) +} + +// RunWithSkill runs with a specific skill injected as additional context. +func (a *Agent) RunWithSkill(ctx context.Context, userInput string, skill *skills.Skill) error { + ui.PrintInfo(fmt.Sprintf("🎯 技能激活: %s — %s", skill.Name, skill.Description)) + if skill.OutputFile != "" { + ui.PrintInfo(fmt.Sprintf("📄 建议输出文件: %s", skill.OutputFile)) + } + + // Inject skill prompt as a system-level context message for this turn only. + // We add it as a user message prefix so it doesn't pollute the system prompt. + augmented := fmt.Sprintf("[技能上下文: %s]\n%s\n\n---\n用户需求:%s", + skill.Name, skill.Prompt, userInput) + return a.run(ctx, augmented) +} + +// RunWithSkillByName looks up a skill by name and delegates to RunWithSkill. +func (a *Agent) RunWithSkillByName(ctx context.Context, userInput, skillName string) error { + skill := skills.Global.Get(skillName) + if skill == nil { + // Skill not found — run without it but warn + ui.PrintWarn(fmt.Sprintf("Skill %q not found, running without skill context", skillName)) + return a.run(ctx, userInput) + } + return a.RunWithSkill(ctx, userInput, skill) +} + +// Run processes one user turn and drives the Agent Loop until the model stops. +func (a *Agent) run(ctx context.Context, userInput string) error { a.sess.AppendMessage(session.TextMessage("user", userInput)) - for iteration := 0; iteration < 20; iteration++ { + for iteration := 0; iteration < 50; iteration++ { // Build the LLM request req := &llm.Request{ Model: a.cfg.Model, @@ -76,6 +114,11 @@ func (a *Agent) Run(ctx context.Context, userInput string) error { // Consume the stream var toolCalls []llm.ToolUseBlock var textBuf strings.Builder + var thinkingActive bool + var thinkingChars int + + // Check if stdout is a terminal + isTTY := isatty.IsTerminal(os.Stdout.Fd()) for event := range eventCh { select { @@ -86,8 +129,33 @@ func (a *Agent) Run(ctx context.Context, userInput string) error { } switch event.Type { case "text_delta": + // If we were thinking, clear the thinking indicator + if thinkingActive && isTTY { + fmt.Print("\r\033[K") // Clear line + } + thinkingActive = false + thinkingChars = 0 a.renderer.Write(event.Delta) textBuf.WriteString(event.Delta) + case "thinking_delta": + // Show thinking progress without newline (only in TTY mode) + thinkingChars += len(event.Delta) + if !thinkingActive { + thinkingActive = true + } + // Only update indicator in TTY mode to avoid cluttering pipe output + if isTTY { + fmt.Printf("\r\033[90m[Thinking... %d chars]\033[0m", thinkingChars) + } + case "thinking_done": + // DeepSeek R1 thinking complete ( detected) + if thinkingActive && isTTY { + fmt.Print("\r\033[K") // Clear the thinking line + } + thinkingActive = false + thinkingChars = 0 + logger.Debug("thinking complete: %d chars", len(event.Delta)) + // Don't write thinking content to output case "tool_use_end": if event.ToolUse != nil { toolCalls = append(toolCalls, *event.ToolUse) @@ -163,6 +231,8 @@ func (a *Agent) executeToolCall(ctx context.Context, tc llm.ToolUseBlock) sessio // Execute fmt.Printf("\033[36m⚙ 执行 %s...\033[0m\n", tc.Name) + printToolCallPath(tc.Input) + start := time.Now() result, err := tool.Execute(ctx, tc.Input) elapsed := time.Since(start) @@ -186,6 +256,27 @@ func (a *Agent) executeToolCall(ctx context.Context, tc llm.ToolUseBlock) sessio } } +func printToolCallPath(raw json.RawMessage) { + var result struct { + Path *string `json:"path"` + Command *string `json:"command"` + } + + if err := json.Unmarshal(raw, &result); err != nil { + fmt.Println("解析失败") + return + } + + if result.Path != nil { + fmt.Printf("%s\n", *result.Path) + } + + if result.Command != nil { + fmt.Printf("$ %s\n", *result.Command) + } + +} + func buildAssistantMessage(text string, toolCalls []llm.ToolUseBlock) session.Message { msg := session.Message{Role: "assistant"} if text != "" { diff --git a/internal/config/config.go b/internal/config/config.go index fa943ad..769eb00 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -136,6 +136,8 @@ func APIKey(provider string) string { return os.Getenv("ANTHROPIC_API_KEY") case "openai": return os.Getenv("OPENAI_API_KEY") + case "deepseek": + return os.Getenv("DEEPSEEK_API_KEY") case "google": return os.Getenv("GOOGLE_API_KEY") default: diff --git a/internal/context/PHASE4_SUMMARY.md b/internal/context/PHASE4_SUMMARY.md index b9ad7b2..5e3627d 100644 --- a/internal/context/PHASE4_SUMMARY.md +++ b/internal/context/PHASE4_SUMMARY.md @@ -131,7 +131,7 @@ aicoder/ ### 项目上下文收集 (`internal/context/`) - [x] `git.go` - Git 信息采集 ⭐ - [x] `project.go` - 项目类型检测 ⭐ -- [x] `aicoder_md.go` - AICODER.md 加载 ⭐ +- [x] `aicoder_md.go` - .AICODER.md 加载 ⭐ - [x] `summarizer.go` - 目录结构摘要 ⭐ **新增** - [x] `collector.go` - 上下文组合和系统提示生成 ⭐ diff --git a/internal/context/collector.go b/internal/context/collector.go index 5efb945..1a3dcc6 100644 --- a/internal/context/collector.go +++ b/internal/context/collector.go @@ -11,7 +11,7 @@ import ( // ProjectContext holds auto-discovered project information. type ProjectContext struct { RootDir string - AICoderMD string // contents of AICODER.md if found + AICoderMD string // contents of .AICODER.md if found GitInfo string ProjectInfo string DirectoryTree string // directory structure summary @@ -55,7 +55,7 @@ func (c *ProjectContext) SystemPrompt() string { func findAICoderMD(startDir string) string { dir := startDir for { - path := filepath.Join(dir, "AICODER.md") + path := filepath.Join(dir, ".AICODER.md") if data, err := os.ReadFile(path); err == nil { return string(data) } diff --git a/internal/context/summarizer.go b/internal/context/summarizer.go index ac8d1af..784eb45 100644 --- a/internal/context/summarizer.go +++ b/internal/context/summarizer.go @@ -77,7 +77,8 @@ func SummarizeDirectory(rootPath string, config *SummarizerConfig) string { // 渲染为字符串 var sb strings.Builder - sb.WriteString("项目结构:\n") + // sb.WriteString("项目结构(当前根目录为" + filepath.Base(rootPath) + "):\n") + sb.WriteString("项目名称为:" + filepath.Base(rootPath) + ",项目结构:\n") totalFiles := 0 renderTree(&sb, root, "", true, &totalFiles) @@ -245,10 +246,11 @@ func renderTree(sb *strings.Builder, node *DirectoryNode, prefix string, isLast *totalFiles++ } } - } else { - // 根节点 - sb.WriteString(node.Name + "/\n") } + // else { + // 根节点 + // sb.WriteString(node.Name + "/\n") + // } // 递归渲染子节点 for i, child := range node.Children { @@ -310,7 +312,7 @@ func GetImportantFiles(rootPath string) []string { "Gemfile", ".gitignore", ".env.example", - "AICODER.md", + ".AICODER.md", } var files []string diff --git a/internal/llm/deepseek/provider.go b/internal/llm/deepseek/provider.go new file mode 100644 index 0000000..5ad9e5e --- /dev/null +++ b/internal/llm/deepseek/provider.go @@ -0,0 +1,542 @@ +package deepseek + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/iminders/aicoder/internal/llm" + "github.com/iminders/aicoder/internal/logger" + "github.com/iminders/aicoder/internal/session" +) + +// Provider implements the LLM provider interface for DeepSeek API. +// DeepSeek R1 uses OpenAI-compatible API but has special handling for: +// 1. Reasoning tokens wrapped in ... tags +// 2. Token usage may not be provided in streaming responses +type Provider struct { + apiKey string + baseURL string + model string + client *http.Client +} + +// New creates a new DeepSeek provider. +func New(apiKey, baseURL, model string) *Provider { + if baseURL == "" { + baseURL = "https://api.deepseek.com" + } + if model == "" { + model = "deepseek-reasoner" + } + return &Provider{ + apiKey: apiKey, + baseURL: baseURL, + model: model, + client: &http.Client{Timeout: 3000 * time.Second}, + } +} + +// Stream sends a streaming request to DeepSeek API. +func (p *Provider) Stream(ctx context.Context, req *llm.Request) (<-chan llm.StreamEvent, error) { + // Type assert messages from RawMsgs + messages, ok := req.RawMsgs.([]session.Message) + if !ok { + return nil, fmt.Errorf("deepseek provider expects []session.Message") + } + + apiMessages := convertMessages(messages) + + payload := map[string]interface{}{ + "model": p.model, + "messages": apiMessages, + "stream": true, + "max_tokens": req.MaxTokens, + "temperature": 1.0, + } + + // Add tools if provided + if len(req.Tools) > 0 { + payload["tools"] = convertTools(req.Tools) + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/v1/chat/completions", bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + logger.Debug("DeepSeek request to: %s", p.baseURL+"/v1/chat/completions") + logger.Debug("DeepSeek model: %s", p.model) + + httpReq.Header.Set("Content-Type", "application/json") + // Only set Authorization header if API key is provided (not needed for local deployments) + if p.apiKey != "" { + httpReq.Header.Set("Authorization", "Bearer "+p.apiKey) + logger.Debug("DeepSeek: using API key authentication") + } else { + logger.Debug("DeepSeek: no API key (local deployment mode)") + } + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("http request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + logger.Debug("DeepSeek API error response: %s", string(body)) + return nil, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body)) + } + + ch := make(chan llm.StreamEvent, 10) + + go p.readStream(resp.Body, ch, messages) + + return ch, nil +} + +// readStream reads SSE stream from DeepSeek API. +func (p *Provider) readStream(body io.ReadCloser, ch chan<- llm.StreamEvent, messages []session.Message) { + defer close(ch) + defer body.Close() + + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, 512*1024), 512*1024) + + var textBuffer strings.Builder + var toolCalls []llm.ToolUseBlock + var inThinking bool + var thinkingBuffer strings.Builder + var usageReceived bool + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + ToolCalls []struct { + Index int `json:"index"` + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + logger.Debug("parse chunk error: %v", err) + continue + } + + // Handle usage if provided + if chunk.Usage != nil { + usageReceived = true + ch <- llm.StreamEvent{ + Type: "usage", + Input: chunk.Usage.PromptTokens, + Output: chunk.Usage.CompletionTokens, + } + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + + // Handle text content with thinking detection + if choice.Delta.Content != "" { + content := choice.Delta.Content + + // Check for thinking tags + if strings.Contains(content, "") { + inThinking = true + // Find where starts and split content + idx := strings.Index(content, "") + if idx > 0 { + // Emit content before + beforeThink := content[:idx] + textBuffer.WriteString(beforeThink) + ch <- llm.StreamEvent{ + Type: "text_delta", + Delta: beforeThink, + } + } + // Start collecting thinking content (skip the tag itself) + afterTag := content[idx+7:] // len("") = 7 + + // Check if is also in this chunk + if strings.Contains(afterTag, "") { + endIdx := strings.Index(afterTag, "") + thinkingContent := afterTag[:endIdx] + thinkingBuffer.WriteString(thinkingContent) + + // Emit thinking_delta for the content + if len(thinkingContent) > 0 { + ch <- llm.StreamEvent{ + Type: "thinking_delta", + Delta: thinkingContent, + } + } + + // Immediately end thinking + inThinking = false + ch <- llm.StreamEvent{ + Type: "thinking_done", + Delta: thinkingBuffer.String(), + } + thinkingBuffer.Reset() + + // Process content after + afterThink := afterTag[endIdx+8:] // len("") = 8 + if len(afterThink) > 0 { + textBuffer.WriteString(afterThink) + ch <- llm.StreamEvent{ + Type: "text_delta", + Delta: afterThink, + } + } + } else { + // No yet, just accumulate + thinkingBuffer.WriteString(afterTag) + if len(afterTag) > 0 { + ch <- llm.StreamEvent{ + Type: "thinking_delta", + Delta: afterTag, + } + } + } + continue + } + + if inThinking { + // Check if this chunk contains + if strings.Contains(content, "") { + idx := strings.Index(content, "") + // Add content before to thinking buffer + beforeEnd := content[:idx] + if len(beforeEnd) > 0 { + thinkingBuffer.WriteString(beforeEnd) + ch <- llm.StreamEvent{ + Type: "thinking_delta", + Delta: beforeEnd, + } + } + + // End thinking immediately + inThinking = false + ch <- llm.StreamEvent{ + Type: "thinking_done", + Delta: thinkingBuffer.String(), + } + thinkingBuffer.Reset() + + // Process content after + afterThink := content[idx+8:] // len("") = 8 + if len(afterThink) > 0 { + textBuffer.WriteString(afterThink) + ch <- llm.StreamEvent{ + Type: "text_delta", + Delta: afterThink, + } + } + } else { + // Still thinking, accumulate + thinkingBuffer.WriteString(content) + ch <- llm.StreamEvent{ + Type: "thinking_delta", + Delta: content, + } + } + } else { + // Normal content, not thinking + textBuffer.WriteString(content) + ch <- llm.StreamEvent{ + Type: "text_delta", + Delta: content, + } + } + } + + // Handle tool calls + if len(choice.Delta.ToolCalls) > 0 { + for _, tc := range choice.Delta.ToolCalls { + if tc.Index >= len(toolCalls) { + // Initialize new tool call + inputJSON, _ := json.Marshal(map[string]string{"_args": ""}) + toolCalls = append(toolCalls, llm.ToolUseBlock{ + ID: tc.ID, + Name: tc.Function.Name, + Input: inputJSON, + }) + } + if tc.Function.Arguments != "" { + // Accumulate arguments as JSON string + var currentInput map[string]string + json.Unmarshal(toolCalls[tc.Index].Input, ¤tInput) + if currentInput == nil { + currentInput = map[string]string{} + } + currentInput["_args"] += tc.Function.Arguments + toolCalls[tc.Index].Input, _ = json.Marshal(currentInput) + } + } + } + + // Handle finish + if choice.FinishReason == "tool_calls" { + // Parse accumulated tool arguments + for i := range toolCalls { + var tempInput map[string]string + json.Unmarshal(toolCalls[i].Input, &tempInput) + if argsStr, ok := tempInput["_args"]; ok && argsStr != "" { + var args map[string]interface{} + if err := json.Unmarshal([]byte(argsStr), &args); err == nil { + toolCalls[i].Input, _ = json.Marshal(args) + } + } + ch <- llm.StreamEvent{ + Type: "tool_use_end", + ToolUse: &toolCalls[i], + } + } + } else if choice.FinishReason == "stop" || choice.FinishReason == "length" { + ch <- llm.StreamEvent{Type: "done"} + } + } + + // If no usage received, estimate tokens client-side + if !usageReceived { + inputTokens := estimateInputTokens(messages) + outputTokens := estimateTokens(textBuffer.String()) + logger.Debug("DeepSeek: no usage from API, estimating: in=%d out=%d", inputTokens, outputTokens) + ch <- llm.StreamEvent{ + Type: "usage", + Input: inputTokens, + Output: outputTokens, + } + } + + if err := scanner.Err(); err != nil { + ch <- llm.StreamEvent{ + Type: "error", + Err: fmt.Errorf("stream read error: %w", err), + } + } +} + +// estimateTokens estimates token count for text. +// Uses similar logic to OpenAI provider. +func estimateTokens(text string) int { + if text == "" { + return 0 + } + + cjkCount := 0 + totalRunes := 0 + + for _, r := range text { + totalRunes++ + // Check if character is CJK + if (r >= 0x4E00 && r <= 0x9FFF) || // CJK Unified Ideographs + (r >= 0x3400 && r <= 0x4DBF) || // CJK Extension A + (r >= 0x20000 && r <= 0x2A6DF) || // CJK Extension B + (r >= 0x2A700 && r <= 0x2B73F) || // CJK Extension C + (r >= 0x2B740 && r <= 0x2B81F) || // CJK Extension D + (r >= 0x2B820 && r <= 0x2CEAF) || // CJK Extension E + (r >= 0x3000 && r <= 0x303F) || // CJK Symbols + (r >= 0xFF00 && r <= 0xFFEF) { // Fullwidth Forms + cjkCount++ + } + } + + cjkRatio := float64(cjkCount) / float64(totalRunes) + + if cjkRatio > 0.5 { + // Mostly CJK: ~1.5 chars per token + return int(float64(totalRunes) / 1.5) + } else { + // Mostly English: ~4 chars per token + return len(text) / 4 + } +} + +// estimateInputTokens estimates input token count. +func estimateInputTokens(messages []session.Message) int { + total := 0 + for _, msg := range messages { + // Add 4 tokens for message structure + total += 4 + + for _, content := range msg.Content { + if content.Type == "text" { + total += estimateTokens(content.Text) + } else if content.Type == "tool_use" { + // Estimate tool_use as JSON + data, _ := json.Marshal(content) + total += estimateTokens(string(data)) + } else if content.Type == "tool_result" { + total += estimateTokens(content.Text) + } + } + } + return total +} + +// convertMessages converts session messages to DeepSeek API format (OpenAI-compatible). +func convertMessages(messages []session.Message) []map[string]interface{} { + var result []map[string]interface{} + + for _, msg := range messages { + // Handle different message types + if msg.Role == "user" { + // Check if this is a tool result message + hasToolResult := false + for _, c := range msg.Content { + if c.Type == "tool_result" { + hasToolResult = true + break + } + } + + if hasToolResult { + // Convert tool results to OpenAI format (role: "tool") + for _, c := range msg.Content { + if c.Type == "tool_result" { + result = append(result, map[string]interface{}{ + "role": "tool", + "tool_call_id": c.ToolUseID, + "content": c.Text, + }) + } + } + } else { + // Regular user message + var textParts []string + for _, c := range msg.Content { + if c.Type == "text" { + textParts = append(textParts, c.Text) + } + } + if len(textParts) > 0 { + result = append(result, map[string]interface{}{ + "role": "user", + "content": strings.Join(textParts, "\n"), + }) + } + } + } else if msg.Role == "assistant" { + // Check if this message has tool calls + var toolCalls []map[string]interface{} + var textContent string + + for _, c := range msg.Content { + if c.Type == "text" { + textContent = c.Text + } else if c.Type == "tool_use" { + // Convert to OpenAI tool_calls format + argsJSON, _ := json.Marshal(c.Input) + toolCalls = append(toolCalls, map[string]interface{}{ + "id": c.ID, + "type": "function", + "function": map[string]interface{}{ + "name": c.Name, + "arguments": string(argsJSON), + }, + }) + } + } + + m := map[string]interface{}{ + "role": "assistant", + } + + if textContent != "" { + m["content"] = textContent + } else { + m["content"] = "" // OpenAI requires content field + } + + if len(toolCalls) > 0 { + m["tool_calls"] = toolCalls + } + + result = append(result, m) + } else { + // System or other roles - simple text + var textParts []string + for _, c := range msg.Content { + if c.Type == "text" { + textParts = append(textParts, c.Text) + } + } + if len(textParts) > 0 { + result = append(result, map[string]interface{}{ + "role": msg.Role, + "content": strings.Join(textParts, "\n"), + }) + } + } + } + + return result +} + +// convertTools converts session tools to DeepSeek API format. +func convertTools(tools []llm.ToolSchema) []map[string]interface{} { + var result []map[string]interface{} + + for _, tool := range tools { + result = append(result, map[string]interface{}{ + "type": "function", + "function": map[string]interface{}{ + "name": tool.Name, + "description": tool.Description, + "parameters": tool.InputSchema, + }, + }) + } + + return result +} + +// Name returns the provider name. +func (p *Provider) Name() string { + return "deepseek" +} + +// CurrentModel returns the current model name. +func (p *Provider) CurrentModel() string { + return p.model +} diff --git a/internal/llm/interface.go b/internal/llm/interface.go index 0261414..f135def 100644 --- a/internal/llm/interface.go +++ b/internal/llm/interface.go @@ -14,7 +14,7 @@ type ToolSchema struct { // StreamEvent is emitted by Provider.Stream for each piece of the response. type StreamEvent struct { - Type string // text_delta | tool_use_start | tool_use_delta | tool_use_end | usage | done | error + Type string // text_delta | tool_use_start | tool_use_delta | tool_use_end | thinking_delta | thinking_done | usage | done | error Delta string ToolUse *ToolUseBlock Input int // input token count (on usage event) diff --git a/internal/logger/logger.go b/internal/logger/logger.go index f174a2d..8bfa8be 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -33,31 +33,31 @@ func Init(verboseMode bool) { func Info(format string, args ...any) { msg := fmt.Sprintf(format, args...) if fileLogger != nil { - fileLogger.Printf("[INFO] " + msg) + fileLogger.Print("[INFO] " + msg) } } func Debug(format string, args ...any) { msg := fmt.Sprintf(format, args...) if fileLogger != nil { - fileLogger.Printf("[DEBUG] " + msg) + fileLogger.Print("[DEBUG] " + msg) } if verbose { - fmt.Fprintf(os.Stderr, "\033[90m[debug] "+msg+"\033[0m\n") + fmt.Fprintf(os.Stderr, "\033[90m[debug] %s\033[0m\n", msg) } } func Error(format string, args ...any) { msg := fmt.Sprintf(format, args...) if fileLogger != nil { - fileLogger.Printf("[ERROR] " + msg) + fileLogger.Print("[ERROR] " + msg) } - fmt.Fprintf(os.Stderr, "\033[31m[error] "+msg+"\033[0m\n") + fmt.Fprintf(os.Stderr, "\033[31m[error] %s\033[0m\n", msg) } func Warn(format string, args ...any) { msg := fmt.Sprintf(format, args...) if fileLogger != nil { - fileLogger.Printf("[WARN] " + msg) + fileLogger.Print("[WARN] " + msg) } } diff --git a/internal/session/session.go b/internal/session/session.go index 399ab0f..0546ec6 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -79,6 +79,11 @@ type Session struct { Snapshots []FileSnapshot Usage TokenUsage Model string + + // PendingSkill / PendingPrompt are set by /skill + // and consumed by the interactive loop to trigger RunWithSkill. + PendingSkillName string + PendingPrompt string } // New creates a fresh session. diff --git a/internal/skills/builtin/apidoc.md b/internal/skills/builtin/apidoc.md new file mode 100644 index 0000000..b0c76f6 --- /dev/null +++ b/internal/skills/builtin/apidoc.md @@ -0,0 +1,88 @@ +--- +name: apidoc +aliases: [API文档, 接口文档, REST文档, OpenAPI, Swagger, api doc] +description: 生成 REST API 接口文档(OpenAPI 3.0 格式) +triggers: + - API文档 + - 接口文档 + - openapi + - swagger + - REST.*文档 + - 写.*接口文档 +output_file: openapi.yaml +--- + +# Skill: API 接口文档 + +你现在是一名 API 设计专家,请生成规范的 API 文档。 + +## 输出格式 + +同时生成两份文档: + +### 1. OpenAPI 3.0 YAML(机器可读) + +```yaml +openapi: 3.0.3 +info: + title: API 名称 + version: 1.0.0 + description: | + API 说明 + contact: + name: 团队名 +paths: + /resource: + get: + summary: 简短描述 + operationId: uniqueId + tags: [分组] + parameters: [] + responses: + '200': + description: 成功 + content: + application/json: + schema: + $ref: '#/components/schemas/Resource' + example: {} + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' +components: + schemas: {} + responses: {} + securitySchemes: {} +``` + +### 2. Markdown 可读版 + +每个接口包含: +- **接口名称与描述** +- **HTTP 方法 + 路径** +- **认证要求** +- **请求参数表**(参数名 | 位置 | 类型 | 必须 | 说明 | 示例) +- **请求体示例**(JSON) +- **响应码说明表**(状态码 | 含义 | 示例) +- **成功响应示例** +- **错误响应示例** +- **cURL 调用示例** + +## 设计规范 + +- 路径使用小写 kebab-case,如 `/user-profiles` +- 资源使用复数名词,如 `/users` 而非 `/user` +- 分页参数统一:`page`、`page_size`、`total` +- 错误响应统一格式:`{"error": {"code": "...", "message": "...", "details": []}}` +- 时间字段使用 ISO 8601:`2024-01-01T00:00:00Z` +- ID 字段说明类型(UUID/自增整数/ULID) + +## 操作方式 + +1. 先读取源码(路由定义、handler、model 文件) +2. 提取所有端点和数据结构 +3. 生成 OpenAPI YAML 文件 +4. 生成人可读的 Markdown 文档 +5. 校验 YAML 语法正确性 +6. 可调用tool function write_file diff --git a/internal/skills/builtin/arch.md b/internal/skills/builtin/arch.md new file mode 100644 index 0000000..fb31f45 --- /dev/null +++ b/internal/skills/builtin/arch.md @@ -0,0 +1,79 @@ +--- +name: arch +aliases: [架构设计, 系统设计, 架构文档, architecture, system design] +description: 编写技术架构设计文档(Architecture Design Document) +triggers: + - 架构设计 + - 系统设计 + - architecture design + - 技术方案 + - 设计文档 + - arch doc +output_file: arch.md +--- + +# Skill: 技术架构设计文档 + +你现在是一名资深软件架构师,请按照以下结构生成一份完整的架构设计文档。 + +## 必须包含的章节 + +1. **架构概览** + - 系统目标与设计原则(如:高可用、可扩展、安全优先) + - 架构风格(微服务/单体/Serverless/事件驱动等) + - 关键技术决策及其理由(ADR 格式) + +2. **系统上下文(Context Diagram)** + - 系统边界:哪些在范围内,哪些是外部系统 + - 主要用户类型和外部依赖 + - 用 ASCII 图或 Mermaid 图表示 + +3. **核心模块划分** + - 模块列表:名称、职责、对外接口 + - 模块间依赖关系(避免循环依赖) + - 完整目录结构(精确到文件级) + +4. **数据架构** + - 数据模型(核心实体及关系) + - 存储选型(关系型/NoSQL/缓存/消息队列)及理由 + - 数据流向(CRUD 操作如何流经系统) + +5. **接口设计** + - 对外 API 规范(REST/gRPC/GraphQL) + - 内部服务间通信协议 + - 关键接口的请求/响应示例 + +6. **关键技术选型** + - 技术栈表格(组件 | 选型 | 版本 | 选型理由 | 备选方案) + +7. **部署架构** + - 部署环境(云/私有/混合) + - 容器化/编排方案 + - 网络拓扑、负载均衡、CDN + +8. **安全架构** + - 认证与授权模型 + - 数据加密(传输/存储) + - 敏感操作审计 + +9. **可观测性** + - 日志规范(结构化日志字段定义) + - 指标(核心 SLI/SLO) + - 链路追踪方案 + +10. **演进路线** + - v1 → v2 的关键架构变更点 + - 已知技术债及偿还计划 + +## 图表规范 + +- 所有架构图优先使用 Mermaid(可在 GitHub/GitLab 直接渲染) +- 图表必须有标题和简短说明 +- 每个模块用方框,外部系统用圆角方框 + +## 输出要求 + +- 可调用tool function write_file +- 完整 Markdown,含所有 Mermaid 图 +- 对每个关键决策说明"为什么选它,为什么不选备选方案" +- 末尾附:术语表、参考资料 diff --git a/internal/skills/builtin/codedoc.md b/internal/skills/builtin/codedoc.md new file mode 100644 index 0000000..5cdf47e --- /dev/null +++ b/internal/skills/builtin/codedoc.md @@ -0,0 +1,82 @@ +--- +name: codedoc +aliases: [代码文档, 代码说明, 代码注释, code documentation, code doc] +description: 为代码生成完整的技术文档和注释 +triggers: + - 代码(文档|说明|注释) + - 给.*写.*注释 + - 写.*文档.*代码 + - code documentation + - document.*code + - API.*文档 +output_file: "" +--- + +# Skill: 代码文档生成 + +你现在是一名技术文档专家,请为给定的代码生成完整、准确的文档。 + +## 文档覆盖范围 + +### 1. 文件级文档 +- 文件职责(一句话概括) +- 包/模块说明 +- 关键依赖说明 + +### 2. 函数/方法文档(每个公开函数必须有) +``` +// FunctionName 简短描述(动词开头)。 +// +// 详细描述(可选,超过一行时使用)。 +// +// Parameters: +// - param1: 含义、取值范围、是否可为空 +// - param2: 含义 +// +// Returns: +// - 返回值含义 +// - 错误情况 +// +// Example: +// result, err := FunctionName(arg1, arg2) +// +// Note: 特殊行为、副作用、并发安全性 +``` + +### 3. 类型/结构体文档 +- 类型用途 +- 每个字段的含义、单位、约束 +- 零值/默认值语义 + +### 4. 常量/枚举文档 +- 每个值的含义 +- 使用场景 + +### 5. 算法说明(复杂逻辑必须有) +- 算法名称和时间/空间复杂度 +- 核心步骤编号说明 +- 边界条件处理 + +## 注释风格规范 + +按语言自动选择: +- **Go**: GoDoc 格式 +- **Python**: Google Style Docstring +- **TypeScript/JavaScript**: JSDoc +- **Java/Kotlin**: JavaDoc +- **Rust**: `///` rustdoc +- **C/C++**: Doxygen + +## 额外输出(根据文件复杂度决定) + +- **README 段落**:该模块的使用示例 +- **流程图**:复杂业务逻辑的 Mermaid 流程图 +- **接口说明表**:公开 API 汇总表 + +## 操作方式 + +1. 先用 `read_file` 读取目标文件 +2. 分析代码结构,识别所有公开符号 +3. 为每个符号生成文档注释 +4. 用 `edit_file` 将注释写入原文件 +5. 输出文档变更摘要 diff --git a/internal/skills/builtin/debug.md b/internal/skills/builtin/debug.md new file mode 100644 index 0000000..a6a730a --- /dev/null +++ b/internal/skills/builtin/debug.md @@ -0,0 +1,93 @@ +--- +name: debug +aliases: [调试, debug, 排查问题, 找bug, troubleshoot] +description: 系统性问题排查与调试 +triggers: + - 调试 + - debug + - 排查.*问题 + - 找.*bug + - 为什么.*报错 + - troubleshoot + - 报错了 + - 有个bug + - 有.*bug +output_file: "" +--- + +# Skill: 系统性调试 + +你现在是一名调试专家,请系统性地分析和定位问题。 + +## 调试方法论(5 Whys + 二分法) + +### Step 1: 信息收集(不做假设) +- 完整的错误信息/堆栈跟踪 +- 复现步骤(能稳定复现 vs 偶发) +- 环境信息(OS、语言版本、依赖版本) +- 最近的变更(什么时候开始出现?) +- 用 `search_files` 搜索相关错误关键词 + +### Step 2: 假设生成 +列出所有可能的根因(脑暴,不过滤),按可能性排序 + +### Step 3: 逐一验证(从最可能开始) +- 用 `run_command` 执行诊断命令 +- 添加临时日志(最小改动) +- 二分法缩小问题范围 + +### Step 4: 根因确认 +- 确认能稳定复现 +- 确认修复后问题消失 +- 确认没有引入新问题 + +### Step 5: 修复与预防 +- 最小化修复(只改必要的) +- 添加回归测试(防止复现) +- 记录根因和修复方式 + +## 常用诊断命令模板 + +```bash +# 查看进程状态 +ps aux | grep + +# 查看端口占用 +lsof -i : +netstat -tlnp | grep + +# 查看系统日志 +journalctl -u -n 100 --no-pager + +# 检查磁盘/内存 +df -h && free -h + +# Go 程序调试 +go run -race . # 竞态检测 +GODEBUG=gctrace=1 ./app # GC 追踪 +dlv debug . # 交互式调试器 +``` + +## 调试输出格式 + +``` +## 问题诊断报告 + +**问题描述**: ... +**复现率**: 必现/偶发(X%) +**影响范围**: ... + +### 根因分析 +1. 现象: ... +2. 第一层原因: ...(Why?) +3. 第二层原因: ...(Why?) +4. 根因: ... + +### 修复方案 +**立即修复**: [最小改动,最快恢复] +**根本修复**: [彻底解决] +**预防措施**: [防止复发] + +### 回归测试 +[验证修复的测试步骤] +``` diff --git a/internal/skills/builtin/devplan.md b/internal/skills/builtin/devplan.md new file mode 100644 index 0000000..af77777 --- /dev/null +++ b/internal/skills/builtin/devplan.md @@ -0,0 +1,80 @@ +--- +name: devplan +aliases: [开发计划, 开发排期, 迭代计划, sprint plan, development plan, todo] +description: 生成详细的开发计划与任务分解文档 +triggers: + - 开发计划 + - 开发排期 + - 迭代计划 + - sprint + - 任务分解 + - todo.*计划 + - development plan +output_file: todo.md +--- + +# Skill: 开发计划文档 + +你现在是一名技术项目经理,请按照以下结构生成一份详细的开发计划文档。 + +## 必须包含的章节 + +1. **项目概况** + - 目标版本(如 v1.0 MVP) + - 总工期与团队规模 + - 关键里程碑(Milestone)日期表 + +2. **阶段划分(Phase)** + 每个 Phase 包含: + - 阶段名称与目标 + - 工期(周数) + - **完成标准**(DoD: Definition of Done,可验收的具体标准) + +3. **任务清单(细粒度)** + 每个任务包含: + - [ ] 任务描述(动词开头,如"实现 xxx"、"编写 xxx") + - 所属模块 + - 优先级(P0/P1/P2) + - 预估工时 + - 负责人(如知道) + - 依赖任务 + +4. **技术风险登记表** + | 风险 | 可能性 | 影响 | 缓解措施 | 触发条件 | + |---|---|---|---|---| + +5. **依赖与阻塞项** + - 外部依赖(第三方 API、硬件、其他团队) + - 已知阻塞项及解决负责人 + +6. **测试计划** + - 单元测试覆盖率目标 + - 集成测试场景列表 + - 性能测试基准 + +7. **发布检查清单(Release Checklist)** + - [ ] 所有 P0 测试通过 + - [ ] 文档更新完毕 + - [ ] 安全扫描通过 + - [ ] 性能基准达标 + - … + +## 任务状态图例 + +``` +⬜ 未开始 🔵 进行中 ✅ 已完成 🔶 阻塞中 ❌ 取消 +``` + +## 工时估算规范 + +- 最小粒度:0.5 天(4h) +- 每个任务不超过 3 天,超过则继续拆分 +- 预留 20% buffer 用于联调和意外 + +## 输出要求 + +- 完整 Markdown,任务使用 `- [ ]` checkbox 格式 +- 按 Phase 组织,每 Phase 有完成标准 +- 末尾附风险登记表和发布清单 +- 如工期信息不足,给出合理估算并标注假设条件 +- 可调用tool function write_file diff --git a/internal/skills/builtin/prd.md b/internal/skills/builtin/prd.md new file mode 100644 index 0000000..bb6a934 --- /dev/null +++ b/internal/skills/builtin/prd.md @@ -0,0 +1,47 @@ +--- +name: prd +aliases: [需求文档, 产品需求, 需求说明, product requirements, requirements doc] +description: 编写专业的产品需求说明文档(PRD) +triggers: + - 写.*需求(文档|说明) + - 产品需求 + - PRD + - product requirement + - 需求文档 +output_file: prd.md +--- + +# Skill: 产品需求文档(PRD) + +你现在是一名资深产品经理,请按照以下结构生成一份专业的 PRD 文档。 + +## 必须包含的章节 + +1. **产品概述** — 背景、目标用户、产品定位(一句话定位) +2. **问题陈述** — 用户痛点、现有方案的不足 +3. **核心功能需求** — 按优先级(P0/P1/P2)列出,每条需求含: + - 用户故事(As a... I want... So that...) + - 验收标准(Acceptance Criteria) + - 排除范围(Out of Scope) +4. **非功能性需求** — 性能、安全、可用性、兼容性 +5. **用户体验设计要点** — 关键流程、UI 规范 +6. **数据需求** — 需要收集/存储/分析的数据 +7. **技术约束** — 已知的技术限制和依赖 +8. **成功指标(KPI)** — 可量化的成功标准 +9. **里程碑计划** — 分阶段交付计划(MVP → GA) +10. **风险与假设** — 主要风险及缓解措施 + +## 写作规范 + +- 使用第三人称,避免"我觉得"等主观表述 +- 每条需求必须可测试、可验收 +- 优先级标注:P0=必须有,P1=重要,P2=锦上添花 +- 数据尽量量化(如"响应时间 < 200ms"而非"响应速度快") +- 输出为标准 Markdown,可直接保存为 `prd.md` + +## 输出要求 + +- 可调用tool function write_file +- 生成完整的 Markdown 文档 +- 文档末尾注明版本、作者(如已知)、日期 +- 如果信息不足,在对应章节标注 `[待补充]` 并列出需要确认的问题 diff --git a/internal/skills/builtin/refactor.md b/internal/skills/builtin/refactor.md new file mode 100644 index 0000000..280d46a --- /dev/null +++ b/internal/skills/builtin/refactor.md @@ -0,0 +1,63 @@ +--- +name: refactor +aliases: [重构, 代码重构, refactor, clean code, 代码优化] +description: 系统性代码重构:识别坏味道、制定方案、安全执行 +triggers: + - 重构 + - refactor + - 代码优化 + - clean.*code + - 改善代码 +output_file: "" +--- + +# Skill: 代码重构 + +你现在是一名代码质量专家,请系统性地分析并重构代码。 + +## 重构流程(必须严格遵守顺序) + +### Phase 1: 分析(不修改任何文件) +1. 用 `read_file` 和 `search_files` 全面了解代码结构 +2. 识别代码坏味道(Bad Smells): + - 过长函数(> 30 行) + - 过深嵌套(> 3 层) + - 重复代码(DRY 违反) + - 神秘数字(Magic Numbers) + - 过长参数列表(> 4 个参数) + - 数据泥团(Data Clumps) + - 过度耦合(Feature Envy) + - 注释掉的死代码 +3. 输出分析报告(不修改代码) + +### Phase 2: 制定方案 +列出具体重构操作,每条包含: +- 重构手法名称(如:Extract Method、Rename Variable) +- 影响范围(哪些文件/函数) +- 预期收益 +- 风险评估 +- **必须:先有测试再重构(如无测试,先补测试)** + +### Phase 3: 执行(最小步骤,每步可验证) +1. 每次只做一个独立的重构操作 +2. 修改后立即运行测试(`run_command "go test ./..."`) +3. 测试失败则立即回滚(`/undo`) +4. 测试通过才进行下一个重构 + +### Phase 4: 验收 +- 运行全量测试,确保无回归 +- 对比重构前后的代码复杂度指标 +- 更新受影响的文档和注释 + +## 重构原则 + +- **小步快跑**:每次改动最小化,保持可回滚 +- **测试保障**:重构不改变外部行为,测试是保障 +- **命名优先**:好名字比好注释更重要 +- **SOLID 原则**:单一职责、开闭、里氏替换、接口隔离、依赖倒置 + +## 禁止行为 + +- ❌ 重构同时修复 Bug(分开做) +- ❌ 重构同时添加新功能(分开做) +- ❌ 在没有测试的情况下做大规模重构 diff --git a/internal/skills/builtin/review.md b/internal/skills/builtin/review.md new file mode 100644 index 0000000..409fb90 --- /dev/null +++ b/internal/skills/builtin/review.md @@ -0,0 +1,81 @@ +--- +name: review +aliases: [代码审查, code review, 审查代码, PR review] +description: 对代码或 PR diff 进行专业的代码审查 +triggers: + - 代码审查 + - code review + - review.*代码 + - 审查.*代码 + - PR review +output_file: "" +--- + +# Skill: 代码审查(Code Review) + +你现在是一名资深工程师,请对代码进行全面、专业、建设性的审查。 + +## 审查维度(按优先级) + +### 🔴 P0 — 必须修复(阻塞合并) +- **正确性**:逻辑错误、边界条件遗漏、并发问题 +- **安全性**:SQL 注入、XSS、权限漏洞、敏感信息硬编码 +- **数据完整性**:缺少事务、竞态条件、数据丢失风险 + +### 🟡 P1 — 强烈建议修复 +- **性能**:N+1 查询、不必要的全表扫描、内存泄漏 +- **错误处理**:未处理的错误、panic 风险、不完整的回滚 +- **测试**:缺少关键测试、测试覆盖率不足 + +### 🔵 P2 — 建议优化 +- **可读性**:命名不清、注释缺失、函数过长 +- **可维护性**:重复代码、硬编码魔法值、过度耦合 +- **一致性**:与现有代码风格不符 + +### ⚪ 建议(可选) +- 更优雅的实现方式 +- 可以复用的已有工具/库 + +## 审查报告格式 + +``` +## 代码审查报告 + +**文件**: xxx.go +**审查人**: AI Assistant +**审查时间**: YYYY-MM-DD + +### 总结 +[一段话总体评价,正面肯定 + 关键问题] + +### 问题列表 + +#### [P0] 并发安全问题 (line 42) +**问题**: `map` 在多 goroutine 下读写未加锁 +**风险**: 运行时 panic +**建议**: +```go +// 修改前 +m[key] = value + +// 修改后 +mu.Lock() +m[key] = value +mu.Unlock() +``` + +[按 P0 → P1 → P2 顺序列出所有问题] + +### 亮点 +[值得表扬的好代码] + +### 建议 +[整体性改进建议] +``` + +## 操作方式 + +1. 读取所有相关文件(或分析 `git diff` 输出) +2. 系统性分析每个审查维度 +3. 生成结构化审查报告 +4. 对 P0 问题,可直接用 `edit_file` 提供修复建议(不自动应用,征求用户同意) diff --git a/internal/skills/builtin/testplan.md b/internal/skills/builtin/testplan.md new file mode 100644 index 0000000..c292ab4 --- /dev/null +++ b/internal/skills/builtin/testplan.md @@ -0,0 +1,76 @@ +--- +name: testplan +aliases: [测试计划, 测试方案, 测试用例, test plan, test cases] +description: 生成测试计划、测试用例和测试报告模板 +triggers: + - 测试计划 + - 测试方案 + - 测试用例 + - test plan + - 写测试 + - 生成测试 +output_file: testplan.md +--- + +# Skill: 测试计划与用例 + +你现在是一名资深测试工程师,请生成全面的测试文档。 + +## 测试策略金字塔 + +``` + [E2E 测试] ← 少量,覆盖核心用户路径 + [集成测试] ← 适量,覆盖模块间交互 + [单元测试] ← 大量,覆盖所有业务逻辑 +``` + +## 必须包含的测试类型 + +### 1. 单元测试用例 +每个测试用例包含: +- 测试目标(函数/方法名) +- 前置条件(Arrange) +- 执行步骤(Act) +- 断言(Assert) +- 边界值:空值、最大值、最小值、非法输入 +- 异常路径:错误处理、超时、网络异常 + +### 2. 集成测试场景 +- 核心业务流程的完整路径测试 +- 数据库/缓存/消息队列的读写验证 +- 第三方服务的 Mock 测试 + +### 3. 性能测试基准 +| 指标 | 目标值 | 测试工具 | +|---|---|---| +| P50 响应时间 | < Xms | wrk/k6 | +| P99 响应时间 | < Xms | | +| QPS | > X | | +| 错误率 | < 0.1% | | + +### 4. 安全测试清单 +- [ ] SQL 注入 +- [ ] XSS +- [ ] CSRF +- [ ] 权限越权 +- [ ] 敏感信息泄露 +- [ ] 速率限制 + +## 测试用例格式 + +``` +TC-001: [模块名] 正常场景描述 +前置条件: xxx +输入: xxx +预期输出: xxx +优先级: P0 +自动化: 是/否 +``` + +## 操作方式 + +1. 读取源码,分析所有公开函数和业务逻辑 +2. 识别关键路径和边界条件 +3. 生成对应语言的测试代码框架(Go/pytest/Jest 等) +4. 用 `write_file` 写入测试文件 +5. 运行已有测试套件验证(`run_command`) diff --git a/internal/skills/skill.go b/internal/skills/skill.go new file mode 100644 index 0000000..6bbe685 --- /dev/null +++ b/internal/skills/skill.go @@ -0,0 +1,276 @@ +// Package skills implements the aicoder skill system. +// Skills are Markdown files with a YAML front-matter header that define +// specialized AI personas and structured output guidance. They are embedded +// into the binary at compile time and can also be loaded from +// ~/.aicoder/skills/ for user-defined skills. +package skills + +import ( + "bufio" + "bytes" + "embed" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +//go:embed builtin/*.md +var builtinFS embed.FS + +// Skill represents one loaded skill definition. +type Skill struct { + // Metadata (from YAML front-matter) + Name string // canonical name, e.g. "prd" + Aliases []string // alternative trigger phrases + Description string // one-line description + Triggers []string // regex patterns for auto-detection + OutputFile string // suggested output filename (empty = no file) + + // Content (everything after the front-matter) + Prompt string // full skill prompt injected as additional context + + // Compiled matchers + compiled []*regexp.Regexp +} + +// Matches returns true if the input string triggers this skill. +func (s *Skill) Matches(input string) bool { + lower := strings.ToLower(input) + // Check explicit name / aliases + for _, alias := range append([]string{s.Name}, s.Aliases...) { + if strings.Contains(lower, strings.ToLower(alias)) { + return true + } + } + // Check regex triggers + for _, re := range s.compiled { + if re.MatchString(input) { + return true + } + } + return false +} + +// ─── Registry ───────────────────────────────────────────────────────────────── + +// Registry holds all loaded skills. +type Registry struct { + skills []*Skill + byName map[string]*Skill +} + +// Global is the default registry, populated at startup. +var Global = &Registry{byName: map[string]*Skill{}} + +// Load loads all built-in skills plus any user skills from ~/.aicoder/skills/. +func Load() error { + if err := loadBuiltins(); err != nil { + return fmt.Errorf("load built-in skills: %w", err) + } + loadUserSkills() // best-effort; errors are silently ignored + return nil +} + +func loadBuiltins() error { + entries, err := builtinFS.ReadDir("builtin") + if err != nil { + return err + } + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".md") { + continue + } + data, err := builtinFS.ReadFile("builtin/" + e.Name()) + if err != nil { + return err + } + skill, err := parseSkill(data) + if err != nil { + return fmt.Errorf("parse %s: %w", e.Name(), err) + } + Global.register(skill) + } + return nil +} + +func loadUserSkills() { + home, err := os.UserHomeDir() + if err != nil { + return + } + dir := filepath.Join(home, ".aicoder", "skills") + entries, err := os.ReadDir(dir) + if err != nil { + return // directory doesn't exist yet — normal + } + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(e.Name(), ".md") { + continue + } + data, err := os.ReadFile(filepath.Join(dir, e.Name())) + if err != nil { + continue + } + skill, err := parseSkill(data) + if err != nil { + continue + } + skill.Name = "user:" + skill.Name // namespace user skills + Global.register(skill) + } +} + +func (r *Registry) register(s *Skill) { + r.skills = append(r.skills, s) + r.byName[s.Name] = s + // Also index by aliases + for _, alias := range s.Aliases { + key := strings.ToLower(alias) + if _, exists := r.byName[key]; !exists { + r.byName[key] = s + } + } +} + +// All returns all registered skills. +func (r *Registry) All() []*Skill { return r.skills } + +// Get returns a skill by name (case-insensitive), or nil. +func (r *Registry) Get(name string) *Skill { + if s, ok := r.byName[strings.ToLower(name)]; ok { + return s + } + // Partial match on name + lower := strings.ToLower(name) + for _, s := range r.skills { + if strings.HasPrefix(s.Name, lower) { + return s + } + } + return nil +} + +// Detect finds the best-matching skill for a user input string. +// Returns nil if no skill matches. +func (r *Registry) Detect(input string) *Skill { + for _, s := range r.skills { + if s.Matches(input) { + return s + } + } + return nil +} + +// ─── Parser ─────────────────────────────────────────────────────────────────── + +// parseSkill parses a Markdown file with YAML-like front-matter. +// Front-matter is delimited by --- lines. +func parseSkill(data []byte) (*Skill, error) { + scanner := bufio.NewScanner(bytes.NewReader(data)) + s := &Skill{} + inFrontMatter := false + frontMatterDone := false + var bodyLines []string + lineNum := 0 + + for scanner.Scan() { + line := scanner.Text() + lineNum++ + + if lineNum == 1 && line == "---" { + inFrontMatter = true + continue + } + if inFrontMatter && line == "---" { + inFrontMatter = false + frontMatterDone = true + continue + } + if inFrontMatter { + parseFrontMatterLine(s, line) + continue + } + if frontMatterDone { + bodyLines = append(bodyLines, line) + } + } + + if s.Name == "" { + return nil, fmt.Errorf("skill is missing 'name' field in front-matter") + } + + s.Prompt = strings.Join(bodyLines, "\n") + + // Compile regex triggers + for _, pattern := range s.Triggers { + re, err := regexp.Compile("(?i)" + pattern) + if err == nil { + s.compiled = append(s.compiled, re) + } + } + + return s, nil +} + +// parseFrontMatterLine handles one key: value line from the front-matter. +func parseFrontMatterLine(s *Skill, line string) { + // Handle list items: " - value" + trimmed := strings.TrimSpace(line) + + // Key: value + if idx := strings.Index(trimmed, ":"); idx > 0 { + key := strings.TrimSpace(trimmed[:idx]) + val := strings.TrimSpace(trimmed[idx+1:]) + + switch key { + case "name": + s.Name = val + case "description": + s.Description = val + case "output_file": + s.OutputFile = strings.Trim(val, `"`) + case "aliases", "triggers": + // Inline list: [a, b, c] + if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") { + items := splitList(val[1 : len(val)-1]) + switch key { + case "aliases": + s.Aliases = items + case "triggers": + s.Triggers = items + } + } + // Multi-line list handled by list item case below + } + return + } + + // List item: " - value" + if strings.HasPrefix(trimmed, "- ") { + // We need context from the previous key — track with a simple heuristic: + // if the skill has more triggers than aliases, we're in triggers; else aliases + val := strings.TrimSpace(strings.TrimPrefix(trimmed, "- ")) + val = strings.Trim(val, `"'`) + // Assign to whichever list was last being populated + // Heuristic: triggers tend to have regex chars + if strings.ContainsAny(val, ".*+?()[]") || len(s.Triggers) < len(s.Aliases) { + s.Triggers = append(s.Triggers, val) + } else { + s.Aliases = append(s.Aliases, val) + } + } +} + +func splitList(s string) []string { + var result []string + for _, part := range strings.Split(s, ",") { + v := strings.TrimSpace(part) + v = strings.Trim(v, `"'`) + if v != "" { + result = append(result, v) + } + } + return result +} diff --git a/internal/skills/skill_integration_test.go b/internal/skills/skill_integration_test.go new file mode 100644 index 0000000..0f71617 --- /dev/null +++ b/internal/skills/skill_integration_test.go @@ -0,0 +1,199 @@ +package skills + +import ( + "strings" + "testing" +) + +// TestAllBuiltinTriggersDetected verifies every built-in skill can be detected +// by at least one of its documented trigger phrases. +func TestAllBuiltinTriggersDetected(t *testing.T) { + if err := Load(); err != nil { + t.Fatal(err) + } + + cases := []struct { + phrase string + wantName string + }{ + // prd + {"帮我写产品需求文档", "prd"}, + {"请生成一份PRD", "prd"}, + {"product requirement document", "prd"}, + {"需求说明书怎么写", "prd"}, + // arch + {"帮我做架构设计", "arch"}, + {"请写一份系统设计文档", "arch"}, + {"architecture design for microservice", "arch"}, + {"技术方案怎么写", "arch"}, + // devplan + {"制定开发计划", "devplan"}, + {"帮我做开发排期", "devplan"}, + {"sprint planning", "devplan"}, + {"任务分解一下", "devplan"}, + // codedoc + {"给这个文件写代码注释", "codedoc"}, + {"帮我写代码文档", "codedoc"}, + {"code documentation for this module", "codedoc"}, + // apidoc + {"生成API文档", "apidoc"}, + {"写一份接口文档", "apidoc"}, + {"openapi spec", "apidoc"}, + {"swagger documentation", "apidoc"}, + // testplan + {"写测试计划", "testplan"}, + {"生成测试用例", "testplan"}, + {"test plan for this feature", "testplan"}, + // refactor + {"帮我重构这段代码", "refactor"}, + {"refactor this function", "refactor"}, + {"代码优化一下", "refactor"}, + // review + {"code review一下", "review"}, + {"帮我做代码审查", "review"}, + {"review this PR", "review"}, + // debug + {"帮我debug这个问题", "debug"}, + {"程序报错了怎么排查", "debug"}, + {"troubleshoot this error", "debug"}, + {"这里有个bug", "debug"}, + } + + for _, c := range cases { + t.Run(c.phrase, func(t *testing.T) { + s := Global.Detect(c.phrase) + if s == nil { + t.Errorf("Detect(%q) = nil, want %q", c.phrase, c.wantName) + return + } + if s.Name != c.wantName { + t.Errorf("Detect(%q) = %q, want %q", c.phrase, s.Name, c.wantName) + } + }) + } +} + +// TestSkillPromptsNotEmpty verifies every skill has a meaningful prompt. +func TestSkillPromptsNotEmpty(t *testing.T) { + if err := Load(); err != nil { + t.Fatal(err) + } + for _, s := range Global.All() { + if len(strings.TrimSpace(s.Prompt)) < 100 { + t.Errorf("skill %q prompt too short (%d chars)", s.Name, len(s.Prompt)) + } + } +} + +// TestSkillOutputFiles verifies output_file values are sane. +func TestSkillOutputFiles(t *testing.T) { + if err := Load(); err != nil { + t.Fatal(err) + } + withFile := map[string]string{ + "prd": "prd.md", + "arch": "arch.md", + "devplan": "todo.md", + "apidoc": "openapi.yaml", + "testplan": "testplan.md", + } + for name, want := range withFile { + s := Global.Get(name) + if s == nil { + t.Errorf("skill %q not found", name) + continue + } + if s.OutputFile != want { + t.Errorf("skill %q: output_file = %q, want %q", name, s.OutputFile, want) + } + } + // These skills edit files in-place, no output_file + noFile := []string{"codedoc", "refactor", "review", "debug"} + for _, name := range noFile { + s := Global.Get(name) + if s == nil { + t.Errorf("skill %q not found", name) + continue + } + if s.OutputFile != "" { + t.Errorf("skill %q should have empty output_file, got %q", name, s.OutputFile) + } + } +} + +// TestNoFalsePositives verifies common non-skill inputs don't match any skill. +func TestNoFalsePositives(t *testing.T) { + if err := Load(); err != nil { + t.Fatal(err) + } + inputs := []string{ + "hello", + "今天天气怎么样", + "what time is it", + "ls -la", + "帮我看看这段代码", + "解释一下这个函数", + "git status", + "how does TCP work", + } + for _, input := range inputs { + s := Global.Detect(input) + if s != nil { + t.Errorf("Detect(%q) = %q, want no match (false positive)", input, s.Name) + } + } +} + +// TestRegistryAll verifies All() returns all 9 built-in skills. +func TestRegistryAll(t *testing.T) { + // Reset and reload + Global = &Registry{byName: map[string]*Skill{}} + if err := Load(); err != nil { + t.Fatal(err) + } + all := Global.All() + if len(all) < 9 { + t.Errorf("expected at least 9 built-in skills, got %d", len(all)) + } +} + +// TestUserSkillCreation simulates user skill file parsing. +func TestUserSkillCreation(t *testing.T) { + raw := []byte(`--- +name: mycompany-pr +aliases: [PR模板, pull request] +description: 生成符合公司规范的 PR 描述 +triggers: + - PR模板 + - pull request.*描述 +output_file: "" +--- + +# 公司 PR 规范 + +## PR 标题格式 +[类型] 简短描述 + +## 必填章节 +- **背景**: 为什么要做这个改动 +- **方案**: 怎么做的 +- **测试**: 如何验证 +- **风险**: 可能影响什么 +`) + s, err := parseSkill(raw) + if err != nil { + t.Fatalf("parseSkill failed: %v", err) + } + if s.Name != "mycompany-pr" { + t.Errorf("unexpected name: %s", s.Name) + } + if len(s.Aliases) < 2 { + t.Errorf("expected 2 aliases, got %d", len(s.Aliases)) + } + if !strings.Contains(s.Prompt, "PR 规范") { + t.Error("prompt should contain skill content") + } + if s.OutputFile != "" { + t.Errorf("expected empty output_file, got %q", s.OutputFile) + } +} diff --git a/internal/skills/skill_test.go b/internal/skills/skill_test.go new file mode 100644 index 0000000..5ff5ac3 --- /dev/null +++ b/internal/skills/skill_test.go @@ -0,0 +1,156 @@ +package skills + +import ( + "strings" + "testing" +) + +func TestLoadBuiltins(t *testing.T) { + r := &Registry{byName: map[string]*Skill{}} + if err := loadBuiltins(); err != nil { + t.Fatalf("loadBuiltins failed: %v", err) + } + // Verify Global registry has been populated + if err := Load(); err != nil { + t.Fatal(err) + } + all := Global.All() + if len(all) == 0 { + t.Fatal("expected at least one built-in skill") + } + _ = r +} + +func TestBuiltinSkillNames(t *testing.T) { + _ = Load() + expectedNames := []string{"prd", "arch", "devplan", "codedoc", "apidoc", "testplan", "refactor", "review", "debug"} + for _, name := range expectedNames { + s := Global.Get(name) + if s == nil { + t.Errorf("expected skill %q to be registered", name) + continue + } + if s.Description == "" { + t.Errorf("skill %q has empty description", name) + } + if s.Prompt == "" { + t.Errorf("skill %q has empty prompt", name) + } + } +} + +func TestSkillDetect(t *testing.T) { + _ = Load() + cases := []struct { + input string + expected string + }{ + {"帮我写一个产品需求文档", "prd"}, + {"请生成架构设计文档", "arch"}, + {"制定开发计划", "devplan"}, + {"给这段代码写文档注释", "codedoc"}, + {"生成API文档", "apidoc"}, + {"写测试计划", "testplan"}, + {"帮我重构这个函数", "refactor"}, + {"review一下这段代码", "review"}, + {"这个程序报错了,帮我debug", "debug"}, + } + for _, c := range cases { + s := Global.Detect(c.input) + if s == nil { + t.Errorf("Detect(%q) = nil, want %q", c.input, c.expected) + continue + } + if s.Name != c.expected { + t.Errorf("Detect(%q) = %q, want %q", c.input, s.Name, c.expected) + } + } +} + +func TestSkillGet(t *testing.T) { + _ = Load() + s := Global.Get("prd") + if s == nil { + t.Fatal("expected to find prd skill") + } + if s.Name != "prd" { + t.Errorf("unexpected name: %s", s.Name) + } +} + +func TestSkillGetCaseInsensitive(t *testing.T) { + _ = Load() + if Global.Get("PRD") == nil { + t.Error("expected case-insensitive Get to work") + } +} + +func TestSkillGetPartialMatch(t *testing.T) { + _ = Load() + // "dev" should match "devplan" + s := Global.Get("dev") + if s == nil { + t.Error("expected partial match on 'dev'") + } +} + +func TestParseSkill(t *testing.T) { + raw := []byte(`--- +name: myskill +aliases: [alias1, alias2] +description: A test skill +triggers: + - test.*skill + - myskill +output_file: out.md +--- + +# My Skill Content + +This is the skill prompt. +`) + s, err := parseSkill(raw) + if err != nil { + t.Fatalf("parseSkill failed: %v", err) + } + if s.Name != "myskill" { + t.Errorf("unexpected name: %s", s.Name) + } + if s.Description != "A test skill" { + t.Errorf("unexpected description: %s", s.Description) + } + if s.OutputFile != "out.md" { + t.Errorf("unexpected output_file: %s", s.OutputFile) + } + if !strings.Contains(s.Prompt, "My Skill Content") { + t.Errorf("prompt missing content: %s", s.Prompt) + } + if len(s.Aliases) < 2 { + t.Errorf("expected 2 aliases, got %d: %v", len(s.Aliases), s.Aliases) + } +} + +func TestSkillMatches(t *testing.T) { + s := &Skill{ + Name: "test", + Aliases: []string{"测试技能"}, + } + if !s.Matches("这是一个测试技能") { + t.Error("expected alias to match") + } + if !s.Matches("test skill") { + t.Error("expected name to match") + } + if s.Matches("completely unrelated") { + t.Error("expected no match") + } +} + +func TestDetectNoMatch(t *testing.T) { + _ = Load() + // Something completely unrelated + s := Global.Detect("the weather is nice today 今天天气很好") + if s != nil { + t.Errorf("expected no match, got skill %q", s.Name) + } +} diff --git a/internal/slash/commands.go b/internal/slash/commands.go index 7bae8f7..7ee374b 100644 --- a/internal/slash/commands.go +++ b/internal/slash/commands.go @@ -10,6 +10,7 @@ import ( "github.com/iminders/aicoder/internal/config" "github.com/iminders/aicoder/internal/session" + "github.com/iminders/aicoder/internal/skills" "github.com/iminders/aicoder/internal/ui" "github.com/iminders/aicoder/pkg/diff" "github.com/iminders/aicoder/pkg/version" @@ -17,12 +18,32 @@ import ( // Handler processes a slash command string. Returns true if the program should exit. type Handler struct { - sess *session.Session - cfg *config.Config + sess *session.Session + cfg *config.Config + printer func(...interface{}) // For printing output (can be tea.Program.Println or fmt.Println) } func NewHandler(sess *session.Session, cfg *config.Config) *Handler { - return &Handler{sess: sess, cfg: cfg} + return &Handler{ + sess: sess, + cfg: cfg, + printer: func(args ...interface{}) { fmt.Println(args...) }, // Default to fmt.Println + } +} + +// SetPrinter sets the print function (use tea.Program.Println for TUI mode) +func (h *Handler) SetPrinter(printer func(...interface{})) { + h.printer = printer +} + +// println is a helper that uses the configured printer +func (h *Handler) println(args ...interface{}) { + h.printer(args...) +} + +// printf is a helper for formatted printing +func (h *Handler) printf(format string, args ...interface{}) { + h.printer(fmt.Sprintf(format, args...)) } // Handle dispatches a slash command. Returns (handled, shouldExit). @@ -63,6 +84,8 @@ func (h *Handler) Handle(input string) (handled bool, shouldExit bool) { h.cmdSave() case "/tools": h.cmdTools() + case "/skill", "/skills": + return h.cmdSkill(args) default: ui.PrintWarn(fmt.Sprintf("未知命令: %s (输入 /help 查看所有命令)", cmd)) } @@ -83,13 +106,17 @@ func (h *Handler) cmdHelp() { │ /cost │ 查看 Token 用量和费用估算 │ │ /model [m] │ 查看或切换 AI 模型 │ │ /config │ 查看当前配置 │ -│ /init │ 在当前目录初始化 AICODER.md │ +│ /init │ 在当前目录初始化 .AICODER.md │ │ /sessions │ 列出历史会话 │ │ /save │ 手动保存当前会话 │ │ /tools │ 列出所有可用工具 │ +│ /skill list │ 列出所有内置 Skill │ +│ /skill <名称> │ 显示 Skill 详情 │ +│ /skill <名称> <提示> │ 以指定 Skill 模式运行 │ +│ /skill new <名称> │ 创建自定义 Skill 模板 │ │ /exit │ 退出程序 │ └───────────────┴──────────────────────────────────────────────────┘` - fmt.Println(help) + h.printer(help) } func (h *Handler) cmdClear() { @@ -103,7 +130,7 @@ func (h *Handler) cmdHistory() { ui.PrintInfo("暂无对话历史") return } - fmt.Printf("\033[1m对话历史 (%d 条消息):\033[0m\n", len(msgs)) + h.printf("\033[1m对话历史 (%d 条消息):\033[0m\n", len(msgs)) ui.PrintDivider() for i, m := range msgs { if m.Role == "system" { @@ -130,7 +157,7 @@ func (h *Handler) cmdHistory() { break } } - fmt.Printf("%s[%d] %s %s\033[0m\n", color, i, icon, preview) + h.printf("%s[%d] %s %s\033[0m\n", color, i, icon, preview) } ui.PrintDivider() } @@ -154,13 +181,13 @@ func (h *Handler) cmdDiff() { ui.PrintInfo("本次会话暂无文件变更") return } - fmt.Printf("\033[1m本次会话文件变更 (%d 个文件):\033[0m\n", len(changes)) + h.printf("\033[1m本次会话文件变更 (%d 个文件):\033[0m\n", len(changes)) ui.PrintDivider() for path, after := range changes { before, _ := os.ReadFile(path) d := diff.ColorDiff(string(before), string(after), path) if d != "" { - fmt.Print(d) + h.printer(d) } } } @@ -186,7 +213,7 @@ func (h *Handler) cmdCommit(args []string) { return } ui.PrintSuccess("已提交: " + msg) - fmt.Println(string(out)) + h.println(string(out)) } func (h *Handler) cmdCost() { @@ -194,25 +221,25 @@ func (h *Handler) cmdCost() { model := h.sess.Model est := usage.CostEstimate(model) ui.PrintDivider() - fmt.Printf(" \033[1m模型:\033[0m %s\n", model) - fmt.Printf(" \033[1m输入 tokens:\033[0m %d\n", usage.InputTokens) - fmt.Printf(" \033[1m输出 tokens:\033[0m %d\n", usage.OutputTokens) - fmt.Printf(" \033[1m费用估算:\033[0m $%.4f USD\n", est) + h.printf(" \033[1m模型:\033[0m %s\n", model) + h.printf(" \033[1m输入 tokens:\033[0m %d\n", usage.InputTokens) + h.printf(" \033[1m输出 tokens:\033[0m %d\n", usage.OutputTokens) + h.printf(" \033[1m费用估算:\033[0m $%.4f USD\n", est) ui.PrintDivider() } func (h *Handler) cmdModel(args []string) { if len(args) == 0 { - fmt.Printf("当前模型: \033[1m%s\033[0m\n", h.sess.Model) - fmt.Println("可用模型示例:") + h.printf("当前模型: \033[1m%s\033[0m\n", h.sess.Model) + h.println("可用模型示例:") models := []string{ "claude-opus-4-5", "claude-sonnet-4-5", "claude-haiku-4-5-20251001", "gpt-4o", "gpt-4o-mini", } for _, m := range models { - fmt.Printf(" - %s\n", m) + h.printf(" - %s\n", m) } - fmt.Println("用法: /model ") + h.println("用法: /model ") return } newModel := args[0] @@ -223,26 +250,26 @@ func (h *Handler) cmdModel(args []string) { func (h *Handler) cmdConfig(args []string) { _ = args - fmt.Printf("\033[1m当前配置:\033[0m\n") + h.printf("\033[1m当前配置:\033[0m\n") ui.PrintDivider() - fmt.Printf(" provider: %s\n", h.cfg.Provider) - fmt.Printf(" model: %s\n", h.cfg.Model) - fmt.Printf(" maxTokens: %d\n", h.cfg.MaxTokens) - fmt.Printf(" autoApprove: %v\n", h.cfg.AutoApprove) - fmt.Printf(" autoApproveReads: %v\n", h.cfg.AutoApproveReads) - fmt.Printf(" backupOnWrite: %v\n", h.cfg.BackupOnWrite) - fmt.Printf(" theme: %s\n", h.cfg.Theme) - fmt.Printf(" language: %s\n", h.cfg.Language) + h.printf(" provider: %s\n", h.cfg.Provider) + h.printf(" model: %s\n", h.cfg.Model) + h.printf(" maxTokens: %d\n", h.cfg.MaxTokens) + h.printf(" autoApprove: %v\n", h.cfg.AutoApprove) + h.printf(" autoApproveReads: %v\n", h.cfg.AutoApproveReads) + h.printf(" backupOnWrite: %v\n", h.cfg.BackupOnWrite) + h.printf(" theme: %s\n", h.cfg.Theme) + h.printf(" language: %s\n", h.cfg.Language) if h.cfg.Proxy != "" { - fmt.Printf(" proxy: %s\n", h.cfg.Proxy) + h.printf(" proxy: %s\n", h.cfg.Proxy) } ui.PrintDivider() } func (h *Handler) cmdInit() { - path := filepath.Join("AICODER.md") + path := filepath.Join(".AICODER.md") if _, err := os.Stat(path); err == nil { - ui.PrintWarn("AICODER.md 已存在,跳过初始化") + ui.PrintWarn(".AICODER.md 已存在,跳过初始化") return } template := fmt.Sprintf(`# 项目说明 @@ -257,14 +284,19 @@ func (h *Handler) cmdInit() { # 注意事项 +# 工具使用规范 + +允许使用web_search工具进行联网搜索 +允许git clone 到third_party目录, 但禁止直接修改第三方代码 + _由 aicoder v%s 生成于 %s_ `, version.Version, time.Now().Format("2006-01-02")) if err := os.WriteFile(path, []byte(template), 0644); err != nil { - ui.PrintError("创建 AICODER.md 失败: " + err.Error()) + ui.PrintError("创建 .AICODER.md 失败: " + err.Error()) return } - ui.PrintSuccess("已创建 AICODER.md,请编辑它来描述您的项目") + ui.PrintSuccess("已创建 .AICODER.md,请编辑它来描述您的项目") } func (h *Handler) cmdSessions() { home, err := os.UserHomeDir() @@ -279,7 +311,7 @@ func (h *Handler) cmdSessions() { return } - fmt.Printf("\033[1m历史会话 (%d 个):\033[0m\n", len(entries)) + h.printf("\033[1m历史会话 (%d 个):\033[0m\n", len(entries)) ui.PrintDivider() // Show most recent 20, newest first @@ -300,10 +332,10 @@ func (h *Handler) cmdSessions() { if info != nil { modTime = info.ModTime().Format("2006-01-02 15:04") } - fmt.Printf(" \033[36m%s\033[0m %s %s\n", name[:min(len(name), 20)], modTime, size) + h.printf(" \033[36m%s\033[0m %s %s\n", name[:min(len(name), 20)], modTime, size) } ui.PrintDivider() - fmt.Println(" 提示:会话文件保存在", dir) + h.println(" 提示:会话文件保存在", dir) } func (h *Handler) cmdSave() { @@ -317,7 +349,7 @@ func (h *Handler) cmdSave() { func (h *Handler) cmdTools() { // Import tools package to list all registered tools // We use a type assertion via the session's known tool names - fmt.Printf("\033[1m已注册工具:\033[0m\n") + h.printf("\033[1m已注册工具:\033[0m\n") ui.PrintDivider() // Tool metadata is stored in the global registry; we query it via the session @@ -332,19 +364,174 @@ func (h *Handler) cmdTools() { {"run_command", "中", "执行 Shell 命令"}, {"run_background", "中", "后台启动长时进程"}, {"grep_search", "低", "全目录正则搜索"}, - {"web_search", "低", "联网搜索(需 TAVILY_API_KEY)"}, + {"web_search", "低", "联网搜索"}, } for _, r := range rows { riskColor := "\033[32m" if r.risk == "中" { riskColor = "\033[33m" } if r.risk == "高" { riskColor = "\033[31m" } - fmt.Printf(" %-18s %s[%s]\033[0m %s\n", r.name, riskColor, r.risk, r.desc) + h.printf(" %-18s %s[%s]\033[0m %s\n", r.name, riskColor, r.risk, r.desc) } ui.PrintDivider() - fmt.Println(" MCP 工具以 __ 格式列出(连接后可见)") + h.println(" MCP 工具以 __ 格式列出(连接后可见)") } func min(a, b int) int { if a < b { return a } return b } + + + +// cmdSkill handles: /skill list | /skill | /skill | /skill new +// It returns (handled bool, shouldExit bool) so the caller can optionally +// hand off to the agent with a skill override. +func (h *Handler) cmdSkill(args []string) (bool, bool) { + + if len(args) == 0 || args[0] == "list" { + h.printSkillList() + return true, false + } + + if args[0] == "new" { + if len(args) < 2 { + ui.PrintWarn("用法: /skill new <名称>") + return true, false + } + h.createUserSkill(args[1]) + return true, false + } + + // /skill [optional prompt...] + skillName := args[0] + sk := skills.Global.Get(skillName) + if sk == nil { + ui.PrintError(fmt.Sprintf("未找到 Skill %q,输入 /skill list 查看所有可用 Skill", skillName)) + return true, false + } + + if len(args) == 1 { + // Show skill details + h.printSkillDetail(sk) + return true, false + } + + // /skill — signal caller to run agent with this skill + // We store the pending skill+prompt in session metadata and return a special + // sentinel so the interactive loop can handle it. + prompt := strings.Join(args[1:], " ") + fmt.Printf("\033[90m[Skill %q 已激活,正在处理: %s]\033[0m\n", sk.Name, prompt) + // Inject skill directly — store on session for the loop to pick up + h.sess.PendingSkillName = sk.Name + h.sess.PendingPrompt = prompt + return true, false +} + +func (h *Handler) printSkillList() { + all := skills.Global.All() + fmt.Printf("\033[1m内置 Skill (%d 个):\033[0m\n", len(all)) + ui.PrintDivider() + for _, s := range all { + tag := "\033[34m[内置]\033[0m" + if strings.HasPrefix(s.Name, "user:") { + tag = "\033[32m[自定义]\033[0m" + } + outFile := "" + if s.OutputFile != "" { + outFile = fmt.Sprintf(" \033[90m→ %s\033[0m", s.OutputFile) + } + fmt.Printf(" %-12s %s %s%s\n", s.Name, tag, s.Description, outFile) + } + ui.PrintDivider() + fmt.Println(" 用法: /skill <名称> <你的需求描述>") + fmt.Println(" 示例: /skill prd 电商平台用户评价系统") + fmt.Println(" 自动触发: 直接描述需求,aicoder 会自动匹配合适的 Skill") +} + +func (h *Handler) printSkillDetail(sk *skills.Skill) { + fmt.Printf("\n\033[1m🎯 Skill: %s\033[0m\n", sk.Name) + ui.PrintDivider() + fmt.Printf(" \033[1m描述:\033[0m %s\n", sk.Description) + if len(sk.Aliases) > 0 { + fmt.Printf(" \033[1m别名:\033[0m %s\n", strings.Join(sk.Aliases, ", ")) + } + if len(sk.Triggers) > 0 { + fmt.Printf(" \033[1m触发词:\033[0m %s\n", strings.Join(sk.Triggers[:min(3, len(sk.Triggers))], " | ")) + } + if sk.OutputFile != "" { + fmt.Printf(" \033[1m输出文件:\033[0m %s\n", sk.OutputFile) + } + ui.PrintDivider() + // Show first 10 lines of the prompt as preview + lines := strings.Split(sk.Prompt, "\n") + preview := lines + truncated := false + if len(lines) > 12 { + preview = lines[:12] + truncated = true + } + fmt.Println("\033[90m" + strings.Join(preview, "\n") + "\033[0m") + if truncated { + fmt.Printf("\033[90m... (共 %d 行) ...\033[0m\n", len(lines)) + } + ui.PrintDivider() + fmt.Printf(" 运行: /skill %s <你的需求描述>\n\n", sk.Name) +} + +func (h *Handler) createUserSkill(name string) { + home, err := os.UserHomeDir() + if err != nil { + ui.PrintError("无法获取 home 目录: " + err.Error()) + return + } + dir := filepath.Join(home, ".aicoder", "skills") + if err := os.MkdirAll(dir, 0700); err != nil { + ui.PrintError("创建目录失败: " + err.Error()) + return + } + path := filepath.Join(dir, name+".md") + if _, err := os.Stat(path); err == nil { + ui.PrintWarn(fmt.Sprintf("Skill %q 已存在: %s", name, path)) + return + } + + template := fmt.Sprintf(`--- +name: %s +aliases: [] +description: 在这里填写 Skill 的一句话描述 +triggers: + - 触发关键词1 + - 触发关键词2 +output_file: output.md +--- + +# Skill: %s + +在这里描述这个 Skill 的职责和使用场景。 + +## 输出结构 + +1. **章节一** — 说明 +2. **章节二** — 说明 + +## 写作规范 + +- 规范1 +- 规范2 + +## 操作方式 + +1. 先用工具收集信息 +2. 按结构生成文档 +3. 保存到输出文件 + +_由 aicoder v%s 生成于 %s_ +`, name, name, version.Version, time.Now().Format("2006-01-02")) + + if err := os.WriteFile(path, []byte(template), 0644); err != nil { + ui.PrintError("创建 Skill 失败: " + err.Error()) + return + } + ui.PrintSuccess(fmt.Sprintf("已创建自定义 Skill 模板: %s", path)) + fmt.Println(" 请编辑该文件,然后重启 aicoder 或输入 /skill list 刷新") +} diff --git a/internal/slash/completion.go b/internal/slash/completion.go index fd56037..35008b7 100644 --- a/internal/slash/completion.go +++ b/internal/slash/completion.go @@ -21,7 +21,7 @@ func AllCommands() []CommandInfo { {"/cost", "查看 Token 用量和费用估算", "/cost"}, {"/model", "查看或切换 AI 模型", "/model [name]"}, {"/config", "查看或修改配置", "/config [set key value]"}, - {"/init", "在当前目录初始化 AICODER.md", "/init"}, + {"/init", "在当前目录初始化 .AICODER.md", "/init"}, {"/sessions", "列出历史会话", "/sessions"}, {"/save", "保存当前会话到文件", "/save [filename]"}, {"/tools", "列出可用工具", "/tools"}, diff --git a/internal/tools/filesystem/tools.go b/internal/tools/filesystem/tools.go index b701db4..b3c7213 100644 --- a/internal/tools/filesystem/tools.go +++ b/internal/tools/filesystem/tools.go @@ -72,19 +72,27 @@ func (t *WriteFileTool) Execute(_ context.Context, raw json.RawMessage) (*tools. if err := json.Unmarshal(raw, &in); err != nil { return &tools.Result{IsError: true, Content: err.Error()}, nil } + if in.Path == "" { + return &tools.Result{IsError: true, Content: "path cannot be empty"}, nil + } if err := checkSandbox(in.Path); err != nil { return &tools.Result{IsError: true, Content: err.Error()}, nil } - // Snapshot before state - var before []byte - before, _ = os.ReadFile(in.Path) if err := os.MkdirAll(filepath.Dir(in.Path), 0755); err != nil { return &tools.Result{IsError: true, Content: err.Error()}, nil } + + // Snapshot before state + var before []byte + if _, err := os.Stat(in.Path); err == nil { + before, _ = os.ReadFile(in.Path) + } + if err := os.WriteFile(in.Path, []byte(in.Content), 0644); err != nil { return &tools.Result{IsError: true, Content: err.Error()}, nil } + if SnapshotFunc != nil { SnapshotFunc("write_file", fmt.Sprintf("%d", time.Now().UnixNano()), in.Path, before, []byte(in.Content)) } @@ -98,7 +106,7 @@ type EditFileTool struct{} func (t *EditFileTool) Name() string { return "edit_file" } func (t *EditFileTool) Risk() tools.RiskLevel { return tools.RiskMedium } func (t *EditFileTool) Description() string { - return "Edit a file by replacing an exact old_string with new_string. old_string must match exactly once." + return "Edit a file by replacing an exact old_string with new_string if it exist. old_string must match exactly once." } func (t *EditFileTool) Schema() json.RawMessage { return json.RawMessage(`{ diff --git a/internal/tools/search/web_search.go b/internal/tools/search/web_search.go index 8be48f6..5fbd596 100644 --- a/internal/tools/search/web_search.go +++ b/internal/tools/search/web_search.go @@ -6,10 +6,8 @@ import ( "fmt" "io" "net/http" - "net/url" "os" "strings" - "time" "github.com/iminders/aicoder/internal/tools" ) @@ -19,7 +17,7 @@ type WebSearchTool struct{} func (t *WebSearchTool) Name() string { return "web_search" } func (t *WebSearchTool) Risk() tools.RiskLevel { return tools.RiskLow } func (t *WebSearchTool) Description() string { - return "Search the web using a search API. Requires SEARCH_API_KEY and SEARCH_ENGINE_ID environment variables to be set. Returns top search results with titles, snippets, and URLs." + return "Search the web using Tavily API. Returns top search results with titles, snippets, and URLs." } func (t *WebSearchTool) Schema() json.RawMessage { @@ -46,12 +44,13 @@ type searchResult struct { Snippet string `json:"snippet"` } -type googleSearchResponse struct { - Items []struct { +type tavilySearchResponse struct { + Results []struct { Title string `json:"title"` - Link string `json:"link"` - Snippet string `json:"snippet"` - } `json:"items"` + URL string `json:"url"` + Content string `json:"content"` + Score float64 `json:"score"` + } `json:"results"` } func (t *WebSearchTool) Execute(ctx context.Context, raw json.RawMessage) (*tools.Result, error) { @@ -72,24 +71,21 @@ func (t *WebSearchTool) Execute(ctx context.Context, raw json.RawMessage) (*tool } // 检查环境变量 - apiKey := os.Getenv("SEARCH_API_KEY") - engineID := os.Getenv("SEARCH_ENGINE_ID") + apiKey := os.Getenv("TAVILY_API_KEY") - if apiKey == "" || engineID == "" { + if apiKey == "" { return &tools.Result{ IsError: true, - Content: "Web search is not configured. Please set SEARCH_API_KEY and SEARCH_ENGINE_ID environment variables.\n\n" + - "To use Google Custom Search:\n" + - "1. Get API key from: https://developers.google.com/custom-search/v1/overview\n" + - "2. Create search engine at: https://programmablesearchengine.google.com/\n" + - "3. Set environment variables:\n" + - " export SEARCH_API_KEY=\"your-api-key\"\n" + - " export SEARCH_ENGINE_ID=\"your-engine-id\"", + Content: "Web search is not configured. Please set TAVILY_API_KEY environment variable.\n\n" + + "To use Tavily Search:\n" + + "1. Get API key from: https://tavily.com/\n" + + "2. Set environment variable:\n" + + " export TAVILY_API_KEY=\"tvly-your-api-key\"", }, nil } // 执行搜索 - results, err := t.googleSearch(ctx, in.Query, in.NumResults, in.Language, apiKey, engineID) + results, err := t.tavilySearch(ctx, in.Query, in.NumResults, apiKey) if err != nil { return &tools.Result{IsError: true, Content: fmt.Sprintf("Search failed: %v", err)}, nil } @@ -117,29 +113,32 @@ func (t *WebSearchTool) Execute(ctx context.Context, raw json.RawMessage) (*tool }, nil } -// googleSearch 使用 Google Custom Search API 进行搜索 -func (t *WebSearchTool) googleSearch(ctx context.Context, query string, numResults int, language, apiKey, engineID string) ([]searchResult, error) { - // 构建 API URL - baseURL := "https://www.googleapis.com/customsearch/v1" - params := url.Values{} - params.Set("key", apiKey) - params.Set("cx", engineID) - params.Set("q", query) - params.Set("num", fmt.Sprintf("%d", numResults)) - params.Set("lr", "lang_"+language) +// tavilySearch 使用 Tavily Search API 进行搜索 +func (t *WebSearchTool) tavilySearch(ctx context.Context, query string, numResults int, apiKey string) ([]searchResult, error) { + // 构建请求体 + requestBody := map[string]interface{}{ + "api_key": apiKey, + "query": query, + "max_results": numResults, + "search_depth": "basic", + "include_answer": false, + } - apiURL := baseURL + "?" + params.Encode() + bodyJSON, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } // 创建 HTTP 请求 - req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + req, err := http.NewRequestWithContext(ctx, "POST", "https://api.tavily.com/search", strings.NewReader(string(bodyJSON))) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } - // 设置超时 - client := &http.Client{ - Timeout: 30 * time.Second, - } + req.Header.Set("Content-Type", "application/json") + + // 使用不带超时的client,完全依赖context控制超时 + client := &http.Client{} // 发送请求 resp, err := client.Do(req) @@ -155,18 +154,18 @@ func (t *WebSearchTool) googleSearch(ctx context.Context, query string, numResul } // 解析响应 - var searchResp googleSearchResponse + var searchResp tavilySearchResponse if err := json.NewDecoder(resp.Body).Decode(&searchResp); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } // 转换结果 - results := make([]searchResult, 0, len(searchResp.Items)) - for _, item := range searchResp.Items { + results := make([]searchResult, 0, len(searchResp.Results)) + for _, item := range searchResp.Results { results = append(results, searchResult{ Title: item.Title, - Link: item.Link, - Snippet: item.Snippet, + Link: item.URL, + Snippet: item.Content, }) } diff --git a/internal/tools/search/web_search_test.go b/internal/tools/search/web_search_test.go index 27baa49..782bf3f 100644 --- a/internal/tools/search/web_search_test.go +++ b/internal/tools/search/web_search_test.go @@ -10,8 +10,7 @@ import ( func TestWebSearchTool_NoConfig(t *testing.T) { // 确保环境变量未设置 - os.Unsetenv("SEARCH_API_KEY") - os.Unsetenv("SEARCH_ENGINE_ID") + os.Unsetenv("TAVILY_API_KEY") tool := &WebSearchTool{} @@ -34,8 +33,8 @@ func TestWebSearchTool_NoConfig(t *testing.T) { t.Errorf("Expected configuration error message, got: %s", result.Content) } - if !strings.Contains(result.Content, "SEARCH_API_KEY") { - t.Errorf("Expected mention of SEARCH_API_KEY, got: %s", result.Content) + if !strings.Contains(result.Content, "TAVILY_API_KEY") { + t.Errorf("Expected mention of TAVILY_API_KEY, got: %s", result.Content) } } @@ -172,11 +171,10 @@ func TestWebSearchTool_ToolInterface(t *testing.T) { // TestWebSearchTool_Integration 是集成测试,需要真实的 API 凭证 // 默认跳过,可以通过设置环境变量来运行 func TestWebSearchTool_Integration(t *testing.T) { - apiKey := os.Getenv("SEARCH_API_KEY") - engineID := os.Getenv("SEARCH_ENGINE_ID") + apiKey := os.Getenv("TAVILY_API_KEY") - if apiKey == "" || engineID == "" { - t.Skip("Skipping integration test: SEARCH_API_KEY or SEARCH_ENGINE_ID not set") + if apiKey == "" { + t.Skip("Skipping integration test: TAVILY_API_KEY not set") } tool := &WebSearchTool{} diff --git a/internal/ui/app.go b/internal/ui/app.go index 617f213..e8b1fca 100644 --- a/internal/ui/app.go +++ b/internal/ui/app.go @@ -37,6 +37,7 @@ type Model struct { currentOutput strings.Builder inputBuffer string statusText string + completions []string // Current autocomplete suggestions // Metadata modelName string @@ -44,8 +45,10 @@ type Model struct { tokens int // Callbacks - onSubmit func(string) error - onCancel func() + onSubmit func(string) error + onCancel func() + onSlashCmd func(string) (handled bool, shouldExit bool) + slashCommands []string // For autocomplete } // Message represents a chat message. @@ -84,7 +87,8 @@ func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.markdown != nil { m.markdown.SetWidth(msg.Width - 4) } - return m, nil + // Clear screen on resize to avoid artifacts + return m, tea.ClearScreen case tea.KeyMsg: return m.handleKeyPress(msg) @@ -121,10 +125,26 @@ func (m *Model) handleKeyPress(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.quitting = true return m, tea.Quit - case tea.KeyEnter: + case tea.KeyCtrlD: + // Ctrl+D submits the input if m.state == StateIdle && m.inputBuffer != "" { - input := m.inputBuffer + input := strings.TrimSpace(m.inputBuffer) m.inputBuffer = "" + + // Check for slash commands + if strings.HasPrefix(input, "/") && m.onSlashCmd != nil { + // Execute slash command (will use tea.Println for output) + handled, shouldExit := m.onSlashCmd(input) + if shouldExit { + m.quitting = true + return m, tea.Quit + } + if handled { + // After slash command, just return to keep TUI running + return m, nil + } + } + m.messages = append(m.messages, Message{ Role: "user", Content: input, @@ -138,15 +158,41 @@ func (m *Model) handleKeyPress(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } return m, nil + case tea.KeyEnter: + // Enter adds a newline to support multi-line input + if m.state == StateIdle { + m.inputBuffer += "\n" + } + return m, nil + + case tea.KeyTab: + // Tab for slash command autocomplete + if m.state == StateIdle && strings.HasPrefix(m.inputBuffer, "/") { + m.autocompleteSlashCommand() + } + return m, nil + case tea.KeyBackspace: if len(m.inputBuffer) > 0 { m.inputBuffer = m.inputBuffer[:len(m.inputBuffer)-1] + // Update completions if editing a slash command + if strings.HasPrefix(m.inputBuffer, "/") { + m.updateCompletions() + } else { + m.completions = nil + } } return m, nil case tea.KeyRunes: if m.state == StateIdle { m.inputBuffer += string(msg.Runes) + // Update completions if typing a slash command + if strings.HasPrefix(m.inputBuffer, "/") { + m.updateCompletions() + } else { + m.completions = nil + } } return m, nil } @@ -162,38 +208,77 @@ func (m *Model) View() string { var b strings.Builder - // Status bar + // Status bar (always at top) b.WriteString(m.renderStatusBar()) b.WriteString("\n\n") + // Calculate available height for content + // Reserve space for: status bar (3 lines), input area (4 lines), margins + availableHeight := m.height - 7 + if availableHeight < 5 { + availableHeight = 5 + } + + // Collect all content lines + var contentLines []string + // Messages for _, msg := range m.messages { - b.WriteString(m.renderMessage(msg)) - b.WriteString("\n") + rendered := m.renderMessage(msg) + lines := strings.Split(rendered, "\n") + contentLines = append(contentLines, lines...) } // Current streaming output if m.currentOutput.Len() > 0 { - b.WriteString(m.theme.InfoStyle.Render("Assistant: ")) - b.WriteString(m.currentOutput.String()) + output := m.theme.InfoStyle.Render("Assistant: ") + m.currentOutput.String() + lines := strings.Split(output, "\n") + contentLines = append(contentLines, lines...) + } + + // Show only the last N lines that fit in the available height + startLine := 0 + if len(contentLines) > availableHeight { + startLine = len(contentLines) - availableHeight + } + + for i := startLine; i < len(contentLines); i++ { + b.WriteString(contentLines[i]) b.WriteString("\n") } - // Spinner for thinking/streaming state + // Spinner for thinking/streaming state (on same line) if m.state == StateThinking { - b.WriteString(m.spinner.View()) - b.WriteString(" Thinking...\n") + b.WriteString(m.spinner.View() + " Thinking...") } else if m.state == StateStreaming { - b.WriteString(m.spinner.View()) - b.WriteString(" Streaming...\n") + b.WriteString(m.spinner.View() + " Streaming...") } - // Input prompt + // Input prompt (always at bottom) if m.state == StateIdle { b.WriteString("\n") b.WriteString(m.theme.PromptStyle.Render("> ")) b.WriteString(m.inputBuffer) b.WriteString(lipgloss.NewStyle().Foreground(m.theme.Muted).Render("_")) + b.WriteString("\n") + + // Show completions if available + if len(m.completions) > 0 { + b.WriteString(m.theme.SubtitleStyle.Render("Completions: ")) + for i, comp := range m.completions { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(m.theme.InfoStyle.Render(comp)) + if i >= 4 { // Limit to 5 suggestions + b.WriteString(fmt.Sprintf(" ... (%d more)", len(m.completions)-5)) + break + } + } + b.WriteString("\n") + } + + b.WriteString(m.theme.SubtitleStyle.Render("Press Ctrl+D to submit, Enter for new line, Tab for autocomplete")) } return b.String() @@ -251,6 +336,81 @@ func (m *Model) SetOnCancel(fn func()) { m.onCancel = fn } +// SetOnSlashCommand sets the callback for slash commands. +func (m *Model) SetOnSlashCommand(fn func(string) (bool, bool)) { + m.onSlashCmd = fn +} + +// SetSlashCommands sets the list of available slash commands for autocomplete. +func (m *Model) SetSlashCommands(commands []string) { + m.slashCommands = commands +} + +// autocompleteSlashCommand attempts to autocomplete the current slash command. +func (m *Model) autocompleteSlashCommand() { + if len(m.slashCommands) == 0 { + return + } + + input := strings.TrimSpace(m.inputBuffer) + if !strings.HasPrefix(input, "/") { + return + } + + // Find matching commands + var matches []string + for _, cmd := range m.slashCommands { + if strings.HasPrefix(cmd, input) { + matches = append(matches, cmd) + } + } + + // If exactly one match, complete it + if len(matches) == 1 { + m.inputBuffer = matches[0] + " " + m.completions = nil + } else if len(matches) > 1 { + // Find common prefix + commonPrefix := matches[0] + for _, match := range matches[1:] { + for i := 0; i < len(commonPrefix) && i < len(match); i++ { + if commonPrefix[i] != match[i] { + commonPrefix = commonPrefix[:i] + break + } + } + } + if len(commonPrefix) > len(input) { + m.inputBuffer = commonPrefix + } + m.completions = matches + } +} + +// updateCompletions updates the list of completion suggestions. +func (m *Model) updateCompletions() { + if len(m.slashCommands) == 0 { + m.completions = nil + return + } + + input := strings.TrimSpace(m.inputBuffer) + if !strings.HasPrefix(input, "/") { + m.completions = nil + return + } + + // Find matching commands + var matches []string + for _, cmd := range m.slashCommands { + if strings.HasPrefix(cmd, input) { + matches = append(matches, cmd) + } + } + + m.completions = matches +} + // AddMessage adds a message to the history. func (m *Model) AddMessage(role, content string) { m.messages = append(m.messages, Message{ diff --git a/scripts/proxy_tool.py b/scripts/proxy_tool.py new file mode 100644 index 0000000..8f37289 --- /dev/null +++ b/scripts/proxy_tool.py @@ -0,0 +1,65 @@ +from flask import Flask, request, Response +import requests +import json + +app = Flask(__name__) + +# 目标:你本地运行 DeepSeek R1 的真实端口 +TARGET_SERVER = "http://127.0.0.1:10002" + +@app.route('/v1/chat/completions', methods=['POST']) +@app.route('/chat/completions', methods=['POST']) +def proxy_deepseek(): + req_data = request.get_json(silent=True) + + + print("="*120 + "\n") + print(json.dumps(req_data, indent=4, ensure_ascii=False)) + + # --- 打印逻辑保持不变 --- + print(f"\n🚀 [Intercepted] Model: {req_data.get('model')}") + + # 1. 清洗 Headers:移除会导致冲突的传输层 Header + excluded_headers = ['content-encoding', 'content-length', 'transfer-encoding', 'connection', 'host'] + headers = { + k: v for k, v in request.headers + if k.lower() not in excluded_headers + } + + # 2. 转发请求 + is_stream = req_data.get('stream', False) if req_data else False + + try: + resp = requests.post( + f"{TARGET_SERVER}{request.full_path}", + json=req_data, + headers=headers, + stream=is_stream, + timeout=300 # DeepSeek R1 推理较慢,建议增加超时 + ) + + # 3. 构造响应并清洗返回的 Headers + def generate(): + for chunk in resp.iter_content(chunk_size=None): # chunk_size=None 保持原始分块 + yield chunk + + # 同样移除返回时的冲突 Header + resp_headers = [ + (k, v) for k, v in resp.raw.headers.items() + if k.lower() not in excluded_headers + ] + + return Response( + generate() if is_stream else resp.content, + status=resp.status_code, + headers=resp_headers + ) + except Exception as e: + print(f"❌ Proxy Error: {e}") + return jsonify({"error": str(e)}), 500 + +if __name__ == '__main__': + # 启动代理,监听 10004 + print("Proxy is running on http://127.0.0.1:10004") + print(f"Forwarding to DeepSeek R1 at {TARGET_SERVER}") + app.run(port=10004, debug=False)