diff --git a/docs/plans/2026-03-11-v02-impl-progress.md b/docs/plans/2026-03-11-v02-impl-progress.md index 4369810..190c28a 100644 --- a/docs/plans/2026-03-11-v02-impl-progress.md +++ b/docs/plans/2026-03-11-v02-impl-progress.md @@ -13,7 +13,7 @@ | 1 | `feat/v02-tool-registry` | `ToolRegistry` + `ToolDefinition` + `ToolResult` interfaces | ✅ Done | | 2 | `feat/v02-tool-definitions` | Migrate existing tools to `definitions/` format | ✅ Done | | 3 | `feat/v02-safety-gate` | `SafetyGate` + config schema additions | ✅ Done | -| 4 | — | `ToolExecutor` (ties registry + safety) | ⬜ Not started | +| 4 | `feat/v02-tool-executor` | `ToolExecutor` (ties registry + safety) | ✅ Done | | 5 | — | Wire into `LocalAgent` + `Orchestrator` | ⬜ Not started | | 6 | — | New tools (`search_code`, `list_files`) | ⬜ Not started | @@ -70,3 +70,15 @@ - `always_confirm` takes precedence over `auto_approve` - `"."` in `allowed_write_paths` means project root (cwd) - Config defaults to sensible values — read tools auto-approved, write paths restricted to project + +## PR 4: ToolExecutor + +**Files:** +- `src/tools/executor.ts` — `ToolExecutor` class, `ToolCall` interface +- `src/tools/executor.test.ts` — 7 unit tests + +**Behavior:** +- `execute(call)` — validate args → check safety → run handler → return result +- `executeParallel(calls)` — runs multiple tool calls concurrently via `Promise.all` +- Write-category tools have their path checked against `SafetyGate.checkWritePath()` +- Handler errors are caught and returned as `ToolResult` failures (never throws) diff --git a/src/tools/executor.test.ts b/src/tools/executor.test.ts new file mode 100644 index 0000000..db637ba --- /dev/null +++ b/src/tools/executor.test.ts @@ -0,0 +1,120 @@ +import { describe, it, expect, vi } from 'vitest' +import { ToolExecutor } from './executor' +import { ToolRegistry } from './registry' +import { SafetyGate } from './safety-gate' +import type { ToolDefinition } from './registry' + +function makeTool(overrides: 
Partial = {}): ToolDefinition { + return { + name: 'test_tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: { path: { type: 'string', description: 'file path' } }, + required: ['path'], + }, + handler: vi.fn(async () => ({ success: true, output: 'ok' })), + category: 'read', + ...overrides, + } +} + +function makeExecutor(tools: ToolDefinition[] = []) { + const registry = new ToolRegistry() + for (const tool of tools) registry.register(tool) + const gate = new SafetyGate({ + always_confirm: [], + auto_approve: ['test_tool', 'tool_a', 'tool_b'], + allowed_write_paths: ['.'], + }) + return new ToolExecutor(registry, gate) +} + +describe('ToolExecutor', () => { + describe('execute', () => { + it('dispatches to the correct handler and returns result', async () => { + const tool = makeTool() + const executor = makeExecutor([tool]) + const result = await executor.execute({ tool: 'test_tool', args: { path: '/foo' } }) + expect(result.success).toBe(true) + expect(result.output).toBe('ok') + expect(tool.handler).toHaveBeenCalledWith({ path: '/foo' }) + }) + + it('returns failure for unknown tool', async () => { + const executor = makeExecutor() + const result = await executor.execute({ tool: 'nonexistent', args: {} }) + expect(result.success).toBe(false) + expect(result.error).toContain('unknown tool') + }) + + it('returns failure when required args are missing', async () => { + const executor = makeExecutor([makeTool()]) + const result = await executor.execute({ tool: 'test_tool', args: {} }) + expect(result.success).toBe(false) + expect(result.error).toContain('path') + }) + + it('returns failure when safety gate blocks write path', async () => { + const writeTool = makeTool({ + name: 'write_file', + category: 'write', + }) + const registry = new ToolRegistry() + registry.register(writeTool) + const gate = new SafetyGate({ + always_confirm: [], + auto_approve: [], + allowed_write_paths: ['src'], + }) + const executor = new 
ToolExecutor(registry, gate) + const result = await executor.execute({ + tool: 'write_file', + args: { path: '/etc/passwd' }, + }) + expect(result.success).toBe(false) + expect(result.error).toContain('outside allowed') + }) + + it('catches handler errors and returns failure', async () => { + const tool = makeTool({ + handler: async () => { throw new Error('boom') }, + }) + const executor = makeExecutor([tool]) + const result = await executor.execute({ tool: 'test_tool', args: { path: '/foo' } }) + expect(result.success).toBe(false) + expect(result.error).toContain('boom') + }) + }) + + describe('executeParallel', () => { + it('runs multiple tool calls concurrently', async () => { + const toolA = makeTool({ name: 'tool_a', handler: async () => ({ success: true, output: 'a' }) }) + const toolB = makeTool({ name: 'tool_b', handler: async () => ({ success: true, output: 'b' }) }) + const executor = makeExecutor([toolA, toolB]) + + const results = await executor.executeParallel([ + { tool: 'tool_a', args: { path: '1' } }, + { tool: 'tool_b', args: { path: '2' } }, + ]) + + expect(results).toHaveLength(2) + expect(results[0].output).toBe('a') + expect(results[1].output).toBe('b') + }) + + it('returns individual failures without blocking others', async () => { + const toolA = makeTool({ name: 'tool_a', handler: async () => ({ success: true, output: 'a' }) }) + const toolB = makeTool({ name: 'tool_b', handler: async () => { throw new Error('fail') } }) + const executor = makeExecutor([toolA, toolB]) + + const results = await executor.executeParallel([ + { tool: 'tool_a', args: { path: '1' } }, + { tool: 'tool_b', args: { path: '2' } }, + ]) + + expect(results[0].success).toBe(true) + expect(results[1].success).toBe(false) + }) + }) +}) diff --git a/src/tools/executor.ts b/src/tools/executor.ts new file mode 100644 index 0000000..00fe5c2 --- /dev/null +++ b/src/tools/executor.ts @@ -0,0 +1,49 @@ +import { ToolRegistry } from './registry' +import { SafetyGate } from 
'./safety-gate'
import type { ToolResult } from './registry'

/**
 * A single request to run a registered tool.
 */
export interface ToolCall {
  /** Name used to look the tool up in the registry. */
  tool: string
  /** Arguments passed to the registry validator and then to the tool handler. */
  args: Record<string, unknown>
  /** Optional free-form rationale; the executor itself never reads it. */
  reason?: string
}

/**
 * Runs tool calls through a fixed pipeline:
 * registry lookup → argument validation → write-path safety check → handler.
 *
 * Lookup, validation, safety and handler failures are all reported as a
 * `ToolResult` with `success: false` rather than by rejecting.
 */
export class ToolExecutor {
  /**
   * @param registry - Source of tool definitions and argument validation.
   * @param safetyGate - Consulted for the `path` argument of write-category tools.
   */
  constructor(
    private readonly registry: ToolRegistry,
    private readonly safetyGate: SafetyGate,
  ) {}

  /**
   * Executes one tool call.
   *
   * @param call - The tool name and its arguments.
   * @returns The handler's result on success, otherwise a failure result whose
   *   `error` describes which stage rejected the call.
   */
  async execute(call: ToolCall): Promise<ToolResult> {
    // 1. Look up tool
    const tool = this.registry.get(call.tool)
    if (!tool) {
      return { success: false, output: '', error: `unknown tool: '${call.tool}'` }
    }

    // 2. Validate args
    const validation = this.registry.validate(call.tool, call.args)
    if (!validation.valid) {
      return { success: false, output: '', error: validation.errors.join('; ') }
    }

    // 3. Safety check — write path restriction for write-category tools.
    // Presence is tested with `!= null` rather than truthiness so that an
    // empty string (or any other falsy value) cannot slip past the gate.
    if (tool.category === 'write') {
      const targetPath = call.args.path
      if (targetPath != null) {
        if (typeof targetPath !== 'string') {
          // Refuse instead of casting: the gate is only meaningful for strings.
          return {
            success: false,
            output: '',
            error: `invalid 'path' argument for tool '${call.tool}': expected a string`,
          }
        }
        const pathCheck = this.safetyGate.checkWritePath(targetPath)
        if (!pathCheck.allowed) {
          return { success: false, output: '', error: pathCheck.reason }
        }
      }
    }

    // 4. Execute handler
    try {
      return await tool.handler(call.args)
    } catch (err: unknown) {
      // Handlers may throw non-Error values (strings, undefined, ...); reading
      // `.message` on those would lose the detail or throw a second time.
      const message = err instanceof Error ? err.message : String(err)
      return { success: false, output: '', error: message }
    }
  }

  /**
   * Executes several tool calls concurrently.
   *
   * @param calls - Calls to run; each goes through {@link ToolExecutor.execute}.
   * @returns One result per call, in the same order as `calls`.
   */
  async executeParallel(calls: ToolCall[]): Promise<ToolResult[]> {
    return Promise.all(calls.map(call => this.execute(call)))
  }
}