diff --git a/internal/assistant/anthropic.go b/internal/assistant/anthropic.go
index bfc83b3..db92c50 100644
--- a/internal/assistant/anthropic.go
+++ b/internal/assistant/anthropic.go
@@ -10,54 +10,119 @@ import (
"github.com/omarluq/librecode/internal/database"
"github.com/omarluq/librecode/internal/model"
+ "github.com/omarluq/librecode/internal/tool"
)
func (client *HTTPCompletionClient) completeAnthropic(
ctx context.Context,
request *CompletionRequest,
) (*CompletionResult, error) {
- payload := anthropicPayload(request)
- endpoint := joinEndpoint(request.Model.BaseURL, "/v1/messages")
- content, err := client.postJSON(ctx, endpoint, anthropicHeaders(request), payload)
- if err != nil {
- return nil, err
+ state := anthropicLoopState{
+ messages: anthropicMessages(request.Messages),
+ endpoint: joinEndpoint(request.Model.BaseURL, "/v1/messages"),
+ result: &CompletionResult{Text: "", Thinking: nil, ToolEvents: nil, Usage: model.EmptyTokenUsage()},
}
- var response struct {
- Error providerError `json:"error"`
- Usage map[string]any `json:"usage"`
- Content []struct {
- Type string `json:"type"`
- Text string `json:"text"`
- } `json:"content"`
+ for range maxToolIterations {
+ finished, err := client.advanceAnthropicLoop(ctx, request, &state)
+ if err != nil {
+ return nil, err
+ }
+ if finished {
+ return state.result, nil
+ }
}
- if err := json.Unmarshal(content, &response); err != nil {
- return nil, oops.In("assistant").Code("anthropic_decode").Wrapf(err, "decode anthropic response")
+
+ return nil, toolIterationLimitError()
+}
+
+type anthropicLoopState struct {
+ result *CompletionResult
+ endpoint string
+ messages []map[string]any
+}
+
+func (client *HTTPCompletionClient) advanceAnthropicLoop(
+ ctx context.Context,
+ request *CompletionRequest,
+ state *anthropicLoopState,
+) (bool, error) {
+ payload := anthropicPayload(request, state.messages)
+ providerResult, err := client.requestAnthropic(ctx, state.endpoint, request, payload)
+ if err != nil {
+ return false, err
}
- if response.Error.Message != "" {
- return nil, providerErrorToOops("anthropic_error", &response.Error)
+ state.result.Usage = mergeUsage(state.result.Usage, providerResult.Usage)
+ if err := validateToolCalls(providerResult.ToolCalls); err != nil {
+ return false, err
}
- parts := make([]string, 0, len(response.Content))
- for _, block := range response.Content {
- if block.Type == jsonTextKey && block.Text != "" {
- parts = append(parts, block.Text)
+ if len(providerResult.ToolCalls) == 0 {
+ if fallback := textToolCallsFromText(providerResult.Text); len(fallback) > 0 {
+ providerResult.ToolCalls = fallback
+ } else {
+ return finishTextResult(state.result, providerResult.Text, "anthropic_empty")
}
}
- text := strings.TrimSpace(strings.Join(parts, "\n"))
- if text == "" {
- return nil, oops.In("assistant").Code("anthropic_empty").Errorf("provider returned an empty response")
+ events := executeAnthropicToolCalls(ctx, request, providerResult.ToolCalls)
+ state.result.ToolEvents = append(state.result.ToolEvents, events...)
+ if err := appendAnthropicToolConversation(state, providerResult, events); err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func executeAnthropicToolCalls(
+ ctx context.Context,
+ request *CompletionRequest,
+ calls []toolCall,
+) []ToolEvent {
+ _, events := executeToolCalls(
+ ctx,
+ request.CWD,
+ calls,
+ request.OnEvent,
+ request.OnToolCall,
+ request.OnToolResult,
+ )
+
+ return events
+}
+
+func appendAnthropicToolConversation(
+ state *anthropicLoopState,
+ providerResult *providerResult,
+ events []ToolEvent,
+) error {
+ if hasTextFallbackToolCalls(providerResult.ToolCalls) {
+ state.messages = append(
+ state.messages,
+ map[string]any{jsonRoleKey: jsonAssistantRole, jsonContentKey: providerResult.Text},
+ map[string]any{jsonRoleKey: jsonUserRole, jsonContentKey: textToolResultPrompt(events)},
+ )
+ return nil
+ }
+ toolResultMessage, err := anthropicToolResultMessage(providerResult.ToolCalls, events)
+ if err != nil {
+ return err
}
+ state.messages = append(
+ state.messages,
+ anthropicAssistantToolMessage(providerResult.ToolCalls),
+ toolResultMessage,
+ )
- return textCompletionResult(text, usageFromObject(response.Usage)), nil
+ return nil
}
-func anthropicPayload(request *CompletionRequest) map[string]any {
+func anthropicPayload(request *CompletionRequest, messages []map[string]any) map[string]any {
// Anthropic's recent Claude models reject temperature when thinking/adaptive
// reasoning is available. Match production agent clients by omitting
// temperature unless/until librecode exposes an explicit user setting.
payload := map[string]any{
jsonModelKey: request.Model.ID,
"max_tokens": minPositive(request.Model.MaxTokens, 4096),
- "messages": anthropicMessages(request.Messages),
+ "messages": messages,
+ "tools": anthropicTools(),
}
if usesAnthropicOAuth(request) {
payload["system"] = anthropicOAuthSystemPrompt(request.SystemPrompt)
@@ -212,25 +277,149 @@ func appendAnthropicBeta(existing string, values ...string) string {
return strings.Join(output, ",")
}
-func anthropicMessages(messages []database.MessageEntity) []map[string]string {
- output := []map[string]string{}
+func (client *HTTPCompletionClient) requestAnthropic(
+ ctx context.Context,
+ endpoint string,
+ request *CompletionRequest,
+ payload map[string]any,
+) (*providerResult, error) {
+ content, err := client.postJSON(ctx, endpoint, anthropicHeaders(request), payload)
+ if err != nil {
+ return nil, err
+ }
+
+ return parseAnthropicResult(content)
+}
+
+func parseAnthropicResult(content []byte) (*providerResult, error) {
+ var response struct {
+ Error providerError `json:"error"`
+ Usage map[string]any `json:"usage"`
+ Content []struct {
+ Type string `json:"type"`
+ Text string `json:"text"`
+ Input any `json:"input"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ } `json:"content"`
+ }
+ if err := json.Unmarshal(content, &response); err != nil {
+ return nil, oops.In("assistant").Code("anthropic_decode").Wrapf(err, "decode anthropic response")
+ }
+ if response.Error.Message != "" {
+ return nil, providerErrorToOops("anthropic_error", &response.Error)
+ }
+ parts := make([]string, 0, len(response.Content))
+ calls := make([]toolCall, 0, len(response.Content))
+ for _, block := range response.Content {
+ switch block.Type {
+ case jsonTextKey:
+ if block.Text != "" {
+ parts = append(parts, block.Text)
+ }
+ case anthropicToolUseType:
+ calls = append(calls, anthropicToolCall(block.ID, block.Name, block.Input))
+ }
+ }
+
+ return &providerResult{
+ Text: strings.TrimSpace(strings.Join(parts, "\n")),
+ OutputItems: nil,
+ Thinking: nil,
+ ToolCalls: calls,
+ Usage: usageFromObject(response.Usage),
+ }, nil
+}
+
+func anthropicToolCall(id, name string, input any) toolCall {
+ arguments, argumentsJSON := anthropicToolArguments(input)
+
+ return toolCall{Arguments: arguments, ID: id, Name: name, ArgumentsJSON: argumentsJSON, TextFallback: false}
+}
+
+func anthropicToolArguments(input any) (arguments map[string]any, argumentsJSON string) {
+ arguments = map[string]any{}
+ payload, err := json.Marshal(input)
+ if err != nil {
+ return arguments, "{}"
+ }
+ if len(payload) == 0 || string(payload) == "null" {
+ return arguments, "{}"
+ }
+ if err := json.Unmarshal(payload, &arguments); err != nil {
+ return map[string]any{}, string(payload)
+ }
+
+ return arguments, string(payload)
+}
+
+func anthropicAssistantToolMessage(calls []toolCall) map[string]any {
+ blocks := make([]map[string]any, 0, len(calls))
+ for _, call := range calls {
+ blocks = append(blocks, map[string]any{
+ jsonTypeKey: anthropicToolUseType,
+ "id": call.ID,
+ jsonToolNameKey: call.Name,
+ "input": call.Arguments,
+ })
+ }
+
+ return map[string]any{jsonRoleKey: jsonAssistantRole, jsonContentKey: blocks}
+}
+
+func anthropicToolResultMessage(calls []toolCall, events []ToolEvent) (map[string]any, error) {
+ if len(events) != len(calls) {
+ return nil, oops.In("assistant").
+ Code("anthropic_tool_message_mismatch").
+ With("calls", len(calls)).
+ With("events", len(events)).
+ Errorf("build Anthropic tool result message: mismatched tool calls and results")
+ }
+ blocks := make([]map[string]any, 0, len(events))
+ for index, event := range events {
+ blocks = append(blocks, map[string]any{
+ jsonTypeKey: anthropicToolResultType,
+ "tool_use_id": calls[index].ID,
+ jsonContentKey: toolOutputText(event.Result, event.DetailsJSON),
+ })
+ }
+
+ return map[string]any{jsonRoleKey: jsonUserRole, jsonContentKey: blocks}, nil
+}
+
+func anthropicMessages(messages []database.MessageEntity) []map[string]any {
+ output := []map[string]any{}
for _, message := range messages {
role, ok := anthropicRole(message.Role)
if !ok || message.Content == "" {
continue
}
- output = append(output, map[string]string{jsonRoleKey: role, jsonContentKey: message.Content})
+ output = append(output, map[string]any{jsonRoleKey: role, jsonContentKey: message.Content})
}
return output
}
+func anthropicTools() []map[string]any {
+ definitions := tool.AllDefinitions()
+ tools := make([]map[string]any, 0, len(definitions))
+ for _, definition := range definitions {
+ tools = append(tools, map[string]any{
+ jsonToolNameKey: string(definition.Name),
+ jsonDescriptionKey: definition.Description,
+ jsonInputSchemaKey: toolParameterSchema(definition.Name),
+ })
+ }
+
+ return tools
+}
+
func anthropicRole(role database.Role) (string, bool) {
switch role {
case database.RoleUser:
return jsonUserRole, true
case database.RoleAssistant:
- return "assistant", true
+ return jsonAssistantRole, true
case database.RoleToolResult,
database.RoleThinking,
database.RoleCustom,
diff --git a/internal/assistant/anthropic_internal_test.go b/internal/assistant/anthropic_internal_test.go
index 03ad0ee..05d5248 100644
--- a/internal/assistant/anthropic_internal_test.go
+++ b/internal/assistant/anthropic_internal_test.go
@@ -11,7 +11,7 @@ import (
func TestAnthropicPayloadOmitsTemperature(t *testing.T) {
t.Parallel()
- payload := anthropicPayload(testCompletionRequestAuth("anthropic-claude", "subscription-access-token"))
+ payload := anthropicPayload(testCompletionRequestAuth("anthropic-claude", "subscription-access-token"), nil)
assert.NotContains(t, payload, "temperature")
assert.Equal(t, "", payload[jsonModelKey])
@@ -25,7 +25,7 @@ func TestAnthropicPayloadUsesStructuredSystemPrompt(t *testing.T) {
request := testCompletionRequestAuth("sk-ant-api03-secret")
request.SystemPrompt = anthropicTestSystemPrompt
- payload := anthropicPayload(request)
+ payload := anthropicPayload(request, nil)
assert.Equal(t, []map[string]any{anthropicSystemText(anthropicTestSystemPrompt)}, payload["system"])
}
@@ -35,7 +35,7 @@ func TestAnthropicOAuthPayloadAddsClaudeCodeIdentity(t *testing.T) {
request := testCompletionRequestAuth("anthropic-claude", "sk-ant-oat-secret")
request.SystemPrompt = anthropicTestSystemPrompt
- payload := anthropicPayload(request)
+ payload := anthropicPayload(request, nil)
systemBlocks, ok := payload["system"].([]map[string]any)
assert.True(t, ok)
@@ -51,7 +51,7 @@ func TestAnthropicPayloadAddsBudgetThinking(t *testing.T) {
request.Model.ID = "claude-sonnet-4-5"
request.Model.Reasoning = true
request.ThinkingLevel = thinkingLow
- payload := anthropicPayload(request)
+ payload := anthropicPayload(request, nil)
assert.Equal(t, map[string]any{
jsonTypeKey: "enabled",
@@ -67,7 +67,7 @@ func TestAnthropicPayloadDisablesThinkingWhenOff(t *testing.T) {
request.Model.ID = "claude-opus-4-7"
request.Model.Reasoning = true
request.ThinkingLevel = thinkingOff
- payload := anthropicPayload(request)
+ payload := anthropicPayload(request, nil)
assert.Equal(t, map[string]any{jsonTypeKey: "disabled"}, payload[jsonThinkingKey])
assert.NotContains(t, payload, "output_config")
@@ -80,7 +80,7 @@ func TestAnthropicPayloadAddsAdaptiveThinking(t *testing.T) {
request.Model.ID = "claude-opus-4-7"
request.Model.Reasoning = true
request.ThinkingLevel = thinkingXHigh
- payload := anthropicPayload(request)
+ payload := anthropicPayload(request, nil)
assert.Equal(t, map[string]any{
jsonTypeKey: "adaptive",
diff --git a/internal/assistant/anthropic_tools_test.go b/internal/assistant/anthropic_tools_test.go
new file mode 100644
index 0000000..680b2fa
--- /dev/null
+++ b/internal/assistant/anthropic_tools_test.go
@@ -0,0 +1,114 @@
+//nolint:testpackage // Tests exercise unexported Anthropic tool helpers.
+package assistant
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/omarluq/librecode/internal/model"
+)
+
+const testAnthropicToolUseID = "toolu_1"
+
+func TestParseAnthropicResultExtractsNativeToolUse(t *testing.T) {
+ t.Parallel()
+
+ content := []byte(`{
+ "content": [
+ {"type":"tool_use","id":"toolu_1","name":"read","input":{"path":"README.md"}}
+ ],
+ "usage": {"input_tokens": 12, "output_tokens": 3}
+ }`)
+
+ result, err := parseAnthropicResult(content)
+ require.NoError(t, err)
+ require.Len(t, result.ToolCalls, 1)
+ assert.Equal(t, testAnthropicToolUseID, result.ToolCalls[0].ID)
+ assert.Equal(t, "read", result.ToolCalls[0].Name)
+ assert.Equal(t, "README.md", result.ToolCalls[0].Arguments[jsonPathKey])
+ assert.Equal(t, 12, result.Usage.InputTokens)
+ assert.Equal(t, 3, result.Usage.OutputTokens)
+}
+
+func TestAnthropicPayloadIncludesTools(t *testing.T) {
+ t.Parallel()
+
+ request := testCompletionRequestAuth("sk-ant-api03-secret")
+ payload := anthropicPayload(request, nil)
+
+ tools, ok := payload["tools"].([]map[string]any)
+ require.True(t, ok)
+ require.NotEmpty(t, tools)
+ encoded, err := json.Marshal(tools)
+ require.NoError(t, err)
+ assert.Contains(t, string(encoded), `"input_schema"`)
+ assert.Contains(t, string(encoded), `"read"`)
+}
+
+func TestAnthropicToolResultMessageUsesToolUseID(t *testing.T) {
+ t.Parallel()
+
+ message, err := anthropicToolResultMessage(
+ []toolCall{{
+ Arguments: nil,
+ ID: testAnthropicToolUseID,
+ Name: jsonReadToolName,
+ ArgumentsJSON: `{}`,
+ TextFallback: false,
+ }},
+ []ToolEvent{{Name: jsonReadToolName, ArgumentsJSON: `{}`, DetailsJSON: "", Result: "ok", Error: ""}},
+ )
+
+ require.NoError(t, err)
+ blocks, ok := message[jsonContentKey].([]map[string]any)
+ require.True(t, ok)
+ require.Len(t, blocks, 1)
+ assert.Equal(t, testAnthropicToolUseID, blocks[0]["tool_use_id"])
+ assert.Equal(t, "ok", blocks[0][jsonContentKey])
+}
+
+func TestAnthropicToolResultMessageRejectsMismatchedCallsAndEvents(t *testing.T) {
+ t.Parallel()
+
+ message, err := anthropicToolResultMessage(
+ []toolCall{{
+ Arguments: nil,
+ ID: testAnthropicToolUseID,
+ Name: jsonReadToolName,
+ ArgumentsJSON: `{}`,
+ TextFallback: false,
+ }},
+ nil,
+ )
+
+ require.Error(t, err)
+ assert.Nil(t, message)
+ assert.Contains(t, err.Error(), "mismatched tool calls and results")
+}
+
+func TestAppendAnthropicToolConversationRejectsMismatchedNativeResults(t *testing.T) {
+ t.Parallel()
+
+ state := &anthropicLoopState{result: nil, endpoint: "", messages: nil}
+ result := &providerResult{
+ Text: "",
+ OutputItems: nil,
+ Thinking: nil,
+ ToolCalls: []toolCall{{
+ Arguments: nil,
+ ID: testAnthropicToolUseID,
+ Name: jsonReadToolName,
+ ArgumentsJSON: `{}`,
+ TextFallback: false,
+ }},
+ Usage: model.EmptyTokenUsage(),
+ }
+
+ err := appendAnthropicToolConversation(state, result, nil)
+
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "mismatched tool calls and results")
+}
diff --git a/internal/assistant/client.go b/internal/assistant/client.go
index 0c40a00..98ed278 100644
--- a/internal/assistant/client.go
+++ b/internal/assistant/client.go
@@ -3,7 +3,6 @@ package assistant
import (
"context"
"net/http"
- "strings"
"time"
"github.com/samber/oops"
@@ -31,6 +30,8 @@ const (
jsonObjectType = "object"
jsonToolNameKey = "name"
jsonToolParamsKey = "parameters"
+ jsonInputSchemaKey = "input_schema"
+ jsonArgumentsKey = "arguments"
jsonCallIDKey = "call_id"
jsonOutputKey = "output"
jsonOutputTokensKey = "output_tokens"
@@ -44,9 +45,21 @@ const (
jsonDisplayKey = "display"
jsonUsageKey = "usage"
jsonUserRole = "user"
+ jsonAssistantRole = "assistant"
+ jsonToolRole = "tool"
+ jsonCommandKey = "command"
+ jsonReadToolName = "read"
+ jsonBashToolName = "bash"
+ jsonEditToolName = "edit"
+ jsonWriteToolName = "write"
+ jsonGrepToolName = "grep"
+ jsonOldTextKey = "oldText"
+ jsonNewTextKey = "newText"
functionToolType = "function"
functionCallType = "function_call"
functionCallOutputType = "function_call_output"
+ anthropicToolUseType = "tool_use"
+ anthropicToolResultType = "tool_result"
reasoningEffortKey = "effort"
thinkingOff = "off"
thinkingLow = "low"
@@ -105,6 +118,7 @@ type toolCall struct {
ID string
Name string
ArgumentsJSON string
+ TextFallback bool
}
type providerResult struct {
@@ -150,7 +164,3 @@ func (client *HTTPCompletionClient) Complete(
Errorf("provider api is not implemented")
}
}
-
-func textCompletionResult(text string, usage model.TokenUsage) *CompletionResult {
- return &CompletionResult{Text: strings.TrimSpace(text), Thinking: nil, ToolEvents: nil, Usage: usage}
-}
diff --git a/internal/assistant/errors_extra.go b/internal/assistant/errors_extra.go
new file mode 100644
index 0000000..e849f46
--- /dev/null
+++ b/internal/assistant/errors_extra.go
@@ -0,0 +1,7 @@
+package assistant
+
+import "github.com/samber/oops"
+
+func emptyProviderResponseError(code string) error {
+ return oops.In("assistant").Code(code).Errorf("provider returned an empty response")
+}
diff --git a/internal/assistant/openai_chat.go b/internal/assistant/openai_chat.go
index 8dafea0..b8333ae 100644
--- a/internal/assistant/openai_chat.go
+++ b/internal/assistant/openai_chat.go
@@ -8,32 +8,141 @@ import (
"github.com/samber/oops"
"github.com/omarluq/librecode/internal/database"
+ "github.com/omarluq/librecode/internal/model"
)
func (client *HTTPCompletionClient) completeOpenAIChat(
ctx context.Context,
request *CompletionRequest,
) (*CompletionResult, error) {
+ state := openAIChatLoopState{
+ messages: openAIChatMessages(request),
+ endpoint: joinEndpoint(request.Model.BaseURL, "/chat/completions"),
+ result: &CompletionResult{Text: "", Thinking: nil, ToolEvents: nil, Usage: model.EmptyTokenUsage()},
+ }
+ for range maxToolIterations {
+ finished, err := client.advanceOpenAIChatLoop(ctx, request, &state)
+ if err != nil {
+ return nil, err
+ }
+ if finished {
+ return state.result, nil
+ }
+ }
+
+ return nil, toolIterationLimitError()
+}
+
+type openAIChatLoopState struct {
+ result *CompletionResult
+ endpoint string
+ messages []map[string]any
+}
+
+func (client *HTTPCompletionClient) advanceOpenAIChatLoop(
+ ctx context.Context,
+ request *CompletionRequest,
+ state *openAIChatLoopState,
+) (bool, error) {
+ payload := openAIChatPayload(request, state.messages)
+ content, err := client.postJSON(ctx, state.endpoint, openAIHeaders(request), payload)
+ if err != nil {
+ return false, err
+ }
+ providerResult, err := parseOpenAIChatResult(content)
+ if err != nil {
+ return false, err
+ }
+ state.result.Usage = mergeUsage(state.result.Usage, providerResult.Usage)
+ if err := validateToolCalls(providerResult.ToolCalls); err != nil {
+ return false, err
+ }
+ if len(providerResult.ToolCalls) == 0 {
+ if fallback := textToolCallsFromText(providerResult.Text); len(fallback) > 0 {
+ providerResult.ToolCalls = fallback
+ } else {
+ return finishTextResult(state.result, providerResult.Text, "openai_chat_empty")
+ }
+ }
+ events := executeOpenAIChatToolCalls(ctx, request, providerResult.ToolCalls)
+ state.result.ToolEvents = append(state.result.ToolEvents, events...)
+ if err := appendOpenAIChatToolConversation(state, providerResult, events); err != nil {
+ return false, err
+ }
+
+ return false, nil
+}
+
+func executeOpenAIChatToolCalls(
+ ctx context.Context,
+ request *CompletionRequest,
+ calls []toolCall,
+) []ToolEvent {
+ _, events := executeToolCalls(
+ ctx,
+ request.CWD,
+ calls,
+ request.OnEvent,
+ request.OnToolCall,
+ request.OnToolResult,
+ )
+
+ return events
+}
+
+func appendOpenAIChatToolConversation(state *openAIChatLoopState, result *providerResult, events []ToolEvent) error {
+ if hasTextFallbackToolCalls(result.ToolCalls) {
+ state.messages = append(
+ state.messages,
+ map[string]any{jsonRoleKey: jsonAssistantRole, jsonContentKey: result.Text},
+ map[string]any{jsonRoleKey: jsonUserRole, jsonContentKey: textToolResultPrompt(events)},
+ )
+ return nil
+ }
+ toolMessages, err := openAIChatToolMessages(result.ToolCalls, events)
+ if err != nil {
+ return err
+ }
+ state.messages = append(
+ state.messages,
+ openAIChatAssistantToolMessage(result),
+ )
+ state.messages = append(state.messages, toolMessages...)
+
+ return nil
+}
+
+func openAIChatPayload(request *CompletionRequest, messages []map[string]any) map[string]any {
payload := map[string]any{
- jsonModelKey: request.Model.ID,
- "messages": openAIChatMessages(request),
- "stream": false,
- "temperature": 0.2,
+ jsonModelKey: request.Model.ID,
+ "messages": messages,
+ "stream": false,
+ "temperature": 0.2,
+ "tools": openAIChatTools(),
+ jsonToolChoiceKey: "auto",
}
if request.Model.Reasoning && request.ThinkingLevel != "" && request.ThinkingLevel != thinkingOff {
payload["reasoning_effort"] = request.ThinkingLevel
}
- endpoint := joinEndpoint(request.Model.BaseURL, "/chat/completions")
- content, err := client.postJSON(ctx, endpoint, openAIHeaders(request), payload)
- if err != nil {
- return nil, err
- }
+
+ return payload
+}
+
+func parseOpenAIChatResult(content []byte) (*providerResult, error) {
var response struct {
Error providerError `json:"error"`
Usage map[string]any `json:"usage"`
Choices []struct {
Message struct {
- Content string `json:"content"`
+ Content string `json:"content"`
+ ToolCalls []struct {
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Function struct {
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ } `json:"function"`
+ } `json:"tool_calls"`
} `json:"message"`
} `json:"choices"`
}
@@ -43,35 +152,95 @@ func (client *HTTPCompletionClient) completeOpenAIChat(
if response.Error.Message != "" {
return nil, providerErrorToOops("openai_chat_error", &response.Error)
}
- if len(response.Choices) == 0 || strings.TrimSpace(response.Choices[0].Message.Content) == "" {
- return nil, oops.In("assistant").Code("openai_chat_empty").Errorf("provider returned an empty response")
+ if len(response.Choices) == 0 {
+ return nil, emptyProviderResponseError("openai_chat_empty")
+ }
+ message := response.Choices[0].Message
+ calls := make([]toolCall, 0, len(message.ToolCalls))
+ for _, call := range message.ToolCalls {
+ if call.Type != "" && call.Type != functionToolType {
+ continue
+ }
+ calls = append(calls, toolCall{
+ Arguments: toolArgumentsFromJSON(call.Function.Arguments),
+ ID: call.ID,
+ Name: call.Function.Name,
+ ArgumentsJSON: call.Function.Arguments,
+ TextFallback: false,
+ })
}
- return textCompletionResult(response.Choices[0].Message.Content, usageFromObject(response.Usage)), nil
+ return &providerResult{
+ Text: strings.TrimSpace(message.Content),
+ OutputItems: nil,
+ Thinking: nil,
+ ToolCalls: calls,
+ Usage: usageFromObject(response.Usage),
+ }, nil
}
-func openAIChatMessages(request *CompletionRequest) []map[string]string {
- messages := []map[string]string{}
+func openAIChatMessages(request *CompletionRequest) []map[string]any {
+ messages := []map[string]any{}
if request.SystemPrompt != "" {
- messages = append(messages, map[string]string{jsonRoleKey: "system", jsonContentKey: request.SystemPrompt})
+ messages = append(messages, map[string]any{jsonRoleKey: "system", jsonContentKey: request.SystemPrompt})
}
for _, message := range request.Messages {
role, ok := openAIRole(message.Role)
if !ok || message.Content == "" {
continue
}
- messages = append(messages, map[string]string{jsonRoleKey: role, jsonContentKey: message.Content})
+ messages = append(messages, map[string]any{jsonRoleKey: role, jsonContentKey: message.Content})
}
return messages
}
+func openAIChatAssistantToolMessage(result *providerResult) map[string]any {
+ toolCalls := make([]map[string]any, 0, len(result.ToolCalls))
+ for _, call := range result.ToolCalls {
+ toolCalls = append(toolCalls, map[string]any{
+ "id": call.ID,
+ jsonTypeKey: functionToolType,
+ "function": map[string]any{
+ jsonToolNameKey: call.Name,
+ jsonArgumentsKey: call.ArgumentsJSON,
+ },
+ })
+ }
+
+ return map[string]any{
+ jsonRoleKey: jsonAssistantRole,
+ jsonContentKey: result.Text,
+ "tool_calls": toolCalls,
+ }
+}
+
+func openAIChatToolMessages(calls []toolCall, events []ToolEvent) ([]map[string]any, error) {
+ if len(events) != len(calls) {
+ return nil, oops.In("assistant").
+ Code("openai_chat_tool_message_mismatch").
+ With("calls", len(calls)).
+ With("events", len(events)).
+ Errorf("build OpenAI chat tool messages: mismatched tool calls and results")
+ }
+ messages := make([]map[string]any, 0, len(events))
+ for index, event := range events {
+ messages = append(messages, map[string]any{
+ jsonRoleKey: jsonToolRole,
+ "tool_call_id": calls[index].ID,
+ jsonContentKey: toolOutputText(event.Result, event.DetailsJSON),
+ })
+ }
+
+ return messages, nil
+}
+
func openAIRole(role database.Role) (string, bool) {
switch role {
case database.RoleUser:
return jsonUserRole, true
case database.RoleAssistant:
- return "assistant", true
+ return jsonAssistantRole, true
case database.RoleToolResult,
database.RoleThinking,
database.RoleCustom,
diff --git a/internal/assistant/openai_responses.go b/internal/assistant/openai_responses.go
index 3a9eb0f..2d10737 100644
--- a/internal/assistant/openai_responses.go
+++ b/internal/assistant/openai_responses.go
@@ -151,11 +151,11 @@ func statelessResponseOutputItems(items []any) []any {
continue
}
stateless = append(stateless, map[string]any{
- jsonTypeKey: functionCallType,
- jsonCallIDKey: stringValue(object[jsonCallIDKey]),
- jsonToolNameKey: stringValue(object[jsonToolNameKey]),
- "arguments": stringValue(object["arguments"]),
- "status": "completed",
+ jsonTypeKey: functionCallType,
+ jsonCallIDKey: stringValue(object[jsonCallIDKey]),
+ jsonToolNameKey: stringValue(object[jsonToolNameKey]),
+ jsonArgumentsKey: stringValue(object[jsonArgumentsKey]),
+ "status": "completed",
})
}
@@ -228,7 +228,7 @@ func toolCallsFromOutput(output []any) []toolCall {
if !ok || stringValue(object[jsonTypeKey]) != functionCallType {
continue
}
- argumentsJSON := stringValue(object["arguments"])
+ argumentsJSON := stringValue(object[jsonArgumentsKey])
arguments := map[string]any{}
if strings.TrimSpace(argumentsJSON) != "" {
if err := json.Unmarshal([]byte(argumentsJSON), &arguments); err != nil {
@@ -240,6 +240,7 @@ func toolCallsFromOutput(output []any) []toolCall {
ID: firstNonEmptyString(object[jsonCallIDKey], object["id"]),
Name: firstNonEmptyString(object[jsonToolNameKey], object["function"]),
ArgumentsJSON: argumentsJSON,
+ TextFallback: false,
})
}
diff --git a/internal/assistant/provider_tool_calls_test.go b/internal/assistant/provider_tool_calls_test.go
new file mode 100644
index 0000000..0b59363
--- /dev/null
+++ b/internal/assistant/provider_tool_calls_test.go
@@ -0,0 +1,193 @@
+//nolint:testpackage // Tests exercise provider-specific unexported tool-loop helpers.
+package assistant
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCompleteOpenAIChatExecutesNativeToolCalls(t *testing.T) {
+ t.Parallel()
+
+ var requests []map[string]any
+ server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+ var payload map[string]any
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&payload))
+ requests = append(requests, payload)
+ writer.Header().Set("Content-Type", "application/json")
+ if len(requests) == 1 {
+ writeTestProviderResponse(t, writer, openAIChatReadToolResponse())
+ return
+ }
+ writeTestProviderResponse(t, writer, `{"choices":[{"message":{"content":"done"}}]}`)
+ }))
+ defer server.Close()
+
+ request := testCompletionRequestAuth("sk-test")
+ request.CWD = testRepoRoot(t)
+ request.Model.BaseURL = server.URL
+
+ result, err := NewHTTPCompletionClient().completeOpenAIChat(context.Background(), request)
+ require.NoError(t, err)
+ require.Equal(t, "done", result.Text)
+ require.Len(t, result.ToolEvents, 1)
+ assert.Equal(t, jsonReadToolName, result.ToolEvents[0].Name)
+ assert.Contains(t, result.ToolEvents[0].Result, "librecode")
+ require.Len(t, requests, 2)
+ tools, ok := requests[0]["tools"].([]any)
+ require.True(t, ok)
+ assert.NotEmpty(t, tools)
+ messages, ok := requests[1]["messages"].([]any)
+ require.True(t, ok)
+ assert.True(t, containsRoleMessage(messages, jsonToolRole))
+}
+
+func TestCompleteAnthropicExecutesTextToolUseFallback(t *testing.T) {
+ t.Parallel()
+
+ var requests []map[string]any
+ server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+ var payload map[string]any
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&payload))
+ requests = append(requests, payload)
+ writer.Header().Set("Content-Type", "application/json")
+ if len(requests) == 1 {
+ writeTestProviderResponse(t, writer, anthropicTextReadToolResponse())
+ return
+ }
+ writeTestProviderResponse(t, writer, `{"content":[{"type":"text","text":"done"}]}`)
+ }))
+ defer server.Close()
+
+ request := testCompletionRequestAuth("sk-ant-api03-secret")
+ request.CWD = testRepoRoot(t)
+ request.Model.BaseURL = server.URL
+
+ result, err := NewHTTPCompletionClient().completeAnthropic(context.Background(), request)
+ require.NoError(t, err)
+ require.Equal(t, "done", result.Text)
+ require.Len(t, result.ToolEvents, 1)
+ assert.Equal(t, jsonReadToolName, result.ToolEvents[0].Name)
+ assert.Contains(t, result.ToolEvents[0].Result, "librecode")
+ require.Len(t, requests, 2)
+ messages, ok := requests[1]["messages"].([]any)
+ require.True(t, ok)
+ assert.True(t, containsAssistantTextToolUsePrompt(messages))
+ assert.True(t, containsUserToolResultPrompt(messages))
+}
+
+func openAIChatReadToolResponse() string {
+ return `{
+ "choices":[{"message":{"tool_calls":[{"id":"call_1","type":"function","function":{
+ "name":"read",
+ "arguments":"{\"path\":\"README.md\"}"
+ }}]}}]
+ }`
+}
+
+func anthropicTextReadToolResponse() string {
+ return `{ "content":[{"type":"text","text":"` + anthropicTextReadToolMarkup() + `"}] }`
+}
+
+func anthropicTextReadToolMarkup() string {
+ return "ReadREADME.md"
+}
+
+func containsRoleMessage(messages []any, role string) bool {
+ for _, message := range messages {
+ object, ok := message.(map[string]any)
+ if ok && object[jsonRoleKey] == role {
+ return true
+ }
+ }
+
+ return false
+}
+
+func containsAssistantTextToolUsePrompt(messages []any) bool {
+ for _, message := range messages {
+ object, ok := message.(map[string]any)
+ if !ok || object[jsonRoleKey] != jsonAssistantRole {
+ continue
+ }
+ content, ok := object[jsonContentKey].(string)
+ if ok && strings.Contains(content, "") {
+ return true
+ }
+ }
+
+ return false
+}
+
+func containsUserToolResultPrompt(messages []any) bool {
+ for _, message := range messages {
+ object, ok := message.(map[string]any)
+ if !ok || object[jsonRoleKey] != jsonUserRole {
+ continue
+ }
+ content, ok := object[jsonContentKey].(string)
+ if ok && strings.HasPrefix(content, "Tool result for read") {
+ return true
+ }
+ }
+
+ return false
+}
+
+func writeTestProviderResponse(t *testing.T, writer http.ResponseWriter, response string) {
+ t.Helper()
+ _, err := writer.Write([]byte(response))
+ require.NoError(t, err)
+}
+
+func testRepoRoot(t *testing.T) string {
+ t.Helper()
+ cwd, err := os.Getwd()
+ require.NoError(t, err)
+
+ return filepath.Clean(filepath.Join(cwd, "..", ".."))
+}
+
+func TestCompleteOpenAIChatExecutesTextToolUseFallback(t *testing.T) {
+ t.Parallel()
+
+ var requests []map[string]any
+ server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+ var payload map[string]any
+ require.NoError(t, json.NewDecoder(request.Body).Decode(&payload))
+ requests = append(requests, payload)
+ writer.Header().Set("Content-Type", "application/json")
+ if len(requests) == 1 {
+ content, err := json.Marshal(anthropicTextReadToolMarkup())
+ require.NoError(t, err)
+ writeTestProviderResponse(t, writer, `{"choices":[{"message":{"content":`+string(content)+`}}]}`)
+ return
+ }
+ writeTestProviderResponse(t, writer, `{"choices":[{"message":{"content":"done"}}]}`)
+ }))
+ defer server.Close()
+
+ request := testCompletionRequestAuth("sk-test")
+ request.CWD = testRepoRoot(t)
+ request.Model.BaseURL = server.URL
+
+ result, err := NewHTTPCompletionClient().completeOpenAIChat(context.Background(), request)
+ require.NoError(t, err)
+ require.Equal(t, "done", result.Text)
+ require.Len(t, result.ToolEvents, 1)
+ assert.Equal(t, jsonReadToolName, result.ToolEvents[0].Name)
+ require.Len(t, requests, 2)
+ messages, ok := requests[1]["messages"].([]any)
+ require.True(t, ok)
+ assert.True(t, containsAssistantTextToolUsePrompt(messages))
+ assert.True(t, containsUserToolResultPrompt(messages))
+}
diff --git a/internal/assistant/text_tool_calls.go b/internal/assistant/text_tool_calls.go
new file mode 100644
index 0000000..6a79593
--- /dev/null
+++ b/internal/assistant/text_tool_calls.go
@@ -0,0 +1,176 @@
+package assistant
+
+import (
+ "encoding/json"
+ "html"
+ "regexp"
+ "strconv"
+ "strings"
+)
+
+const (
+ textToolNameField = "tool_name"
+ textToolOldTextKey = "old_text"
+ textToolNewTextKey = "new_text"
+ textToolFilePathField = "file_path"
+)
+
+var (
+ textToolUsePattern = regexp.MustCompile(`(?is)]*>(.*?)`)
+ textToolTagPattern = regexp.MustCompile(`(?is)<([a-zA-Z][a-zA-Z0-9_-]*)\b[^>]*>(.*?)[a-zA-Z][a-zA-Z0-9_-]*>`)
+)
+
+func textToolCallsFromText(text string) []toolCall {
+ matches := textToolUsePattern.FindAllStringSubmatch(text, -1)
+ if len(matches) == 0 {
+ return nil
+ }
+
+ calls := make([]toolCall, 0, len(matches))
+ for index, match := range matches {
+ fields := textToolFields(match[1])
+ name := normalizeTextToolName(firstTextToolField(fields, textToolNameField, jsonToolNameKey, jsonToolRole))
+ if name == "" {
+ continue
+ }
+ arguments := textToolArguments(name, fields)
+ argumentsJSON := encodeToolArguments(arguments)
+ calls = append(calls, toolCall{
+ Arguments: arguments,
+ ID: textToolCallID(index),
+ Name: name,
+ ArgumentsJSON: argumentsJSON,
+ TextFallback: true,
+ })
+ }
+
+ return calls
+}
+
+func textToolFields(content string) map[string]string {
+ fields := map[string]string{}
+ for _, match := range textToolTagPattern.FindAllStringSubmatch(content, -1) {
+ key := normalizeTextToolKey(match[1])
+ if key == "" || key == anthropicToolUseType {
+ continue
+ }
+ fields[key] = strings.TrimSpace(html.UnescapeString(match[2]))
+ }
+
+ return fields
+}
+
+func normalizeTextToolName(name string) string {
+ normalized := normalizeTextToolKey(name)
+ switch normalized {
+ case jsonReadToolName:
+ return jsonReadToolName
+ case jsonBashToolName, "shell", "sh", jsonCommandKey:
+ return jsonBashToolName
+ case jsonEditToolName, "replace":
+ return jsonEditToolName
+ case jsonWriteToolName, "create":
+ return jsonWriteToolName
+ case jsonGrepToolName, "search":
+ return jsonGrepToolName
+ case "find":
+ return "find"
+ case "ls", "list", "list_dir", "list_directory":
+ return "ls"
+ default:
+ return ""
+ }
+}
+
+func normalizeTextToolKey(value string) string {
+ value = strings.TrimSpace(strings.ToLower(value))
+ value = strings.ReplaceAll(value, "-", "_")
+ value = strings.ReplaceAll(value, " ", "_")
+
+ return value
+}
+
+func textToolArguments(name string, fields map[string]string) map[string]any {
+ arguments := map[string]any{}
+ for key, value := range fields {
+ if key == textToolNameField || key == jsonToolNameKey || key == jsonToolRole {
+ continue
+ }
+ arguments[textToolArgumentName(name, key)] = value
+ }
+
+ return arguments
+}
+
+func textToolArgumentName(toolName, fieldName string) string {
+ switch fieldName {
+ case textToolFilePathField, "filepath", "file", "filename":
+ return "path"
+ case textToolOldTextKey:
+ return jsonOldTextKey
+ case textToolNewTextKey:
+ return jsonNewTextKey
+ case "allow_ignored":
+ return "allowIgnored"
+ case "ignore_case":
+ return "ignoreCase"
+ default:
+ if toolName == jsonBashToolName && fieldName == "cmd" {
+ return jsonCommandKey
+ }
+ return fieldName
+ }
+}
+
+func firstTextToolField(fields map[string]string, names ...string) string {
+ for _, name := range names {
+ if value := strings.TrimSpace(fields[normalizeTextToolKey(name)]); value != "" {
+ return value
+ }
+ }
+
+ return ""
+}
+
+func encodeToolArguments(arguments map[string]any) string {
+ if len(arguments) == 0 {
+ return "{}"
+ }
+ encoded, err := json.Marshal(arguments)
+ if err != nil {
+ return "{}"
+ }
+
+ return string(encoded)
+}
+
+func textToolCallID(index int) string {
+ return "text_tool_call_" + strconv.Itoa(index+1)
+}
+
+func hasTextFallbackToolCalls(calls []toolCall) bool {
+ for _, call := range calls {
+ if call.TextFallback {
+ return true
+ }
+ }
+
+ return false
+}
+
+func textToolResultPrompt(events []ToolEvent) string {
+ parts := make([]string, 0, len(events))
+ for _, event := range events {
+ label := "Tool result for " + event.Name
+ body := strings.TrimSpace(event.Result)
+ if event.Error != "" {
+ body = strings.TrimSpace(event.Error)
+ }
+ if body == "" {
+ body = "(tool returned no text output)"
+ }
+ parts = append(parts, label+":\n"+body)
+ }
+
+ return strings.Join(parts, "\n\n")
+}
diff --git a/internal/assistant/text_tool_calls_test.go b/internal/assistant/text_tool_calls_test.go
new file mode 100644
index 0000000..4bc22ff
--- /dev/null
+++ b/internal/assistant/text_tool_calls_test.go
@@ -0,0 +1,145 @@
+//nolint:testpackage // Tests exercise unexported text fallback tool-call helpers.
+package assistant
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestTextToolCallsFromTextParsesXMLStyleToolUse(t *testing.T) {
+ t.Parallel()
+
+ text := `
+Read
+/tmp/README.md
+`
+
+ calls := textToolCallsFromText(text)
+ require.Len(t, calls, 1)
+ assert.Equal(t, "read", calls[0].Name)
+ assert.Equal(t, "/tmp/README.md", calls[0].Arguments[jsonPathKey])
+ assert.Equal(t, `{"path":"/tmp/README.md"}`, calls[0].ArgumentsJSON)
+ assert.True(t, calls[0].TextFallback)
+}
+
+func TestTextToolCallsFromTextMapsCommonFields(t *testing.T) {
+ t.Parallel()
+
+ text := `shellpwd`
+
+ calls := textToolCallsFromText(text)
+ require.Len(t, calls, 1)
+ assert.Equal(t, "bash", calls[0].Name)
+ assert.Equal(t, "pwd", calls[0].Arguments["command"])
+}
+
+func TestTextToolCallsFromTextIgnoresUnknownTools(t *testing.T) {
+ t.Parallel()
+
+ calls := textToolCallsFromText(`unknown`)
+
+ assert.Empty(t, calls)
+}
+
+func TestTextToolCallsFromTextParsesMultipleAndEscapedValues(t *testing.T) {
+ t.Parallel()
+
+ text := `
+Read
+README.md
+
+
+bash
+printf "hello"
+`
+
+ calls := textToolCallsFromText(text)
+ require.Len(t, calls, 2)
+ assert.Equal(t, "text_tool_call_1", calls[0].ID)
+ assert.Equal(t, "read", calls[0].Name)
+ assert.Equal(t, "README.md", calls[0].Arguments[jsonPathKey])
+ assert.Equal(t, "text_tool_call_2", calls[1].ID)
+ assert.Equal(t, "bash", calls[1].Name)
+ assert.Equal(t, `printf "hello"`, calls[1].Arguments[jsonCommandKey])
+}
+
+func TestTextToolCallsFromTextMapsToolNamesAndArguments(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ markup string
+ expectedTool string
+ expectedKey string
+ expectedValue string
+ }{
+ {
+ name: "write content",
+ markup: writeTextToolMarkup(),
+ expectedTool: jsonWriteToolName,
+ expectedKey: jsonPathKey,
+ expectedValue: "out.txt",
+ },
+ {
+ name: "edit old text",
+ markup: editTextToolMarkup(),
+ expectedTool: jsonEditToolName,
+ expectedKey: jsonOldTextKey,
+ expectedValue: "old",
+ },
+ {
+ name: "grep pattern",
+ markup: grepTextToolMarkup(),
+ expectedTool: jsonGrepToolName,
+ expectedKey: jsonPatternKey,
+ expectedValue: "TODO",
+ },
+ {
+ name: "ls path",
+ markup: `list_directory.`,
+ expectedTool: "ls",
+ expectedKey: jsonPathKey,
+ expectedValue: ".",
+ },
+ }
+ for _, testCase := range tests {
+ t.Run(testCase.name, func(t *testing.T) {
+ t.Parallel()
+
+ calls := textToolCallsFromText(testCase.markup)
+
+ require.Len(t, calls, 1)
+ assert.Equal(t, testCase.expectedTool, calls[0].Name)
+ assert.Equal(t, testCase.expectedValue, calls[0].Arguments[testCase.expectedKey])
+ })
+ }
+}
+
+func writeTextToolMarkup() string {
+ return `create` +
+ `out.txthello`
+}
+
+func editTextToolMarkup() string {
+ return `replace` +
+ `oldnew`
+}
+
+func grepTextToolMarkup() string {
+ return `search` +
+ `TODOtrue`
+}
+
+func TestTextToolResultPromptUsesErrorsAndEmptyFallback(t *testing.T) {
+ t.Parallel()
+
+ prompt := textToolResultPrompt([]ToolEvent{
+ {Name: "read", ArgumentsJSON: `{}`, DetailsJSON: "", Result: "", Error: "missing file"},
+ {Name: "bash", ArgumentsJSON: `{}`, DetailsJSON: "", Result: " ", Error: ""},
+ })
+
+ assert.Contains(t, prompt, "Tool result for read:\nmissing file")
+ assert.Contains(t, prompt, "Tool result for bash:\n(tool returned no text output)")
+}
diff --git a/internal/assistant/tool_loop.go b/internal/assistant/tool_loop.go
index c5efe72..f0b842b 100644
--- a/internal/assistant/tool_loop.go
+++ b/internal/assistant/tool_loop.go
@@ -48,7 +48,12 @@ func executeToolCalls(
Text: call.Name,
})
if onToolCall != nil {
- onToolCall(ctx, ToolCallEvent(call))
+ onToolCall(ctx, ToolCallEvent{
+ Arguments: call.Arguments,
+ ID: call.ID,
+ Name: call.Name,
+ ArgumentsJSON: call.ArgumentsJSON,
+ })
}
result, err := registry.Execute(ctx, call.Name, call.Arguments)
resultText := result.Text()
@@ -88,6 +93,16 @@ func executeToolCalls(
return outputs, events
}
+func finishTextResult(result *CompletionResult, text, emptyCode string) (bool, error) {
+ trimmed := strings.TrimSpace(text)
+ if trimmed == "" {
+ return false, emptyProviderResponseError(emptyCode)
+ }
+ result.Text = trimmed
+
+ return true, nil
+}
+
func emitStreamEvent(onEvent func(StreamEvent), event StreamEvent) {
if onEvent != nil {
onEvent(event)
diff --git a/internal/assistant/tool_loop_test.go b/internal/assistant/tool_loop_test.go
index e7f81c1..e0f19b7 100644
--- a/internal/assistant/tool_loop_test.go
+++ b/internal/assistant/tool_loop_test.go
@@ -18,11 +18,11 @@ func TestValidateToolCallsRejectsMissingFields(t *testing.T) {
}{
{
name: "missing id",
- call: toolCall{Arguments: nil, ID: "", Name: "read", ArgumentsJSON: ""},
+ call: toolCall{Arguments: nil, ID: "", Name: jsonReadToolName, ArgumentsJSON: "", TextFallback: false},
},
{
name: "missing name",
- call: toolCall{Arguments: nil, ID: "call-1", Name: "", ArgumentsJSON: ""},
+ call: toolCall{Arguments: nil, ID: "call-1", Name: "", ArgumentsJSON: "", TextFallback: false},
},
}
for _, tt := range tests {
@@ -49,6 +49,7 @@ func TestExecuteToolCallsInvokesCallbacksAndStreamsEvents(t *testing.T) {
ID: "call-1",
Name: "read",
ArgumentsJSON: `{"path":"missing.txt"}`,
+ TextFallback: false,
}},
func(event StreamEvent) {
streamEvents = append(streamEvents, event)
@@ -88,3 +89,31 @@ func TestEncodeToolDetailsReturnsEmptyForInvalidDetails(t *testing.T) {
encoded := encodeToolDetails(map[string]any{"bad": func() {}})
assert.Empty(t, encoded)
}
+
+func TestOpenAIChatToolMessagesRejectsMismatchedCallsAndEvents(t *testing.T) {
+ t.Parallel()
+
+ messages, err := openAIChatToolMessages(
+ []toolCall{{Arguments: nil, ID: "call_1", Name: jsonReadToolName, ArgumentsJSON: `{}`, TextFallback: false}},
+ nil,
+ )
+
+ require.Error(t, err)
+ assert.Nil(t, messages)
+ assert.Contains(t, err.Error(), "mismatched tool calls and results")
+}
+
+func TestOpenAIChatToolMessagesUsesCallIDs(t *testing.T) {
+ t.Parallel()
+
+ messages, err := openAIChatToolMessages(
+ []toolCall{{Arguments: nil, ID: "call_1", Name: jsonReadToolName, ArgumentsJSON: `{}`, TextFallback: false}},
+ []ToolEvent{{Name: jsonReadToolName, ArgumentsJSON: `{}`, DetailsJSON: "", Result: "ok", Error: ""}},
+ )
+
+ require.NoError(t, err)
+ require.Len(t, messages, 1)
+ assert.Equal(t, jsonToolRole, messages[0][jsonRoleKey])
+ assert.Equal(t, "call_1", messages[0]["tool_call_id"])
+ assert.Equal(t, "ok", messages[0][jsonContentKey])
+}
diff --git a/internal/assistant/tool_schema.go b/internal/assistant/tool_schema.go
index 59c6357..bdcc216 100644
--- a/internal/assistant/tool_schema.go
+++ b/internal/assistant/tool_schema.go
@@ -1,9 +1,15 @@
package assistant
import (
+ "encoding/json"
+
+ "github.com/samber/oops"
+
"github.com/omarluq/librecode/internal/tool"
)
+const maxToolIterations = 8
+
func responseTools() []map[string]any {
definitions := tool.AllDefinitions()
tools := make([]map[string]any, 0, len(definitions))
@@ -20,6 +26,39 @@ func responseTools() []map[string]any {
return tools
}
+func openAIChatTools() []map[string]any {
+ definitions := tool.AllDefinitions()
+ tools := make([]map[string]any, 0, len(definitions))
+ for _, definition := range definitions {
+ tools = append(tools, map[string]any{
+ jsonTypeKey: functionToolType,
+ "function": map[string]any{
+ jsonToolNameKey: string(definition.Name),
+ jsonDescriptionKey: definition.Description,
+ jsonToolParamsKey: toolParameterSchema(definition.Name),
+ },
+ })
+ }
+
+ return tools
+}
+
+func toolArgumentsFromJSON(argumentsJSON string) map[string]any {
+ arguments := map[string]any{}
+ if argumentsJSON == "" {
+ return arguments
+ }
+ if err := json.Unmarshal([]byte(argumentsJSON), &arguments); err != nil {
+ return map[string]any{}
+ }
+
+ return arguments
+}
+
+func toolIterationLimitError() error {
+ return oops.In("assistant").Code("tool_iteration_limit").Errorf("tool iteration limit reached")
+}
+
func toolParameterSchema(name tool.Name) map[string]any {
var schema map[string]any
switch name {
@@ -65,10 +104,10 @@ func bashToolSchema() map[string]any {
return map[string]any{
jsonTypeKey: jsonObjectType,
jsonPropertiesKey: map[string]any{
- "command": stringSchema("Bash command to execute in the current workspace."),
- "timeout": numberSchema("Optional timeout in seconds."),
+ jsonCommandKey: stringSchema("Bash command to execute in the current workspace."),
+ "timeout": numberSchema("Optional timeout in seconds."),
},
- jsonRequiredKey: []string{"command"},
+ jsonRequiredKey: []string{jsonCommandKey},
}
}
@@ -89,12 +128,12 @@ func editItemsSchema() map[string]any {
"items": map[string]any{
jsonTypeKey: jsonObjectType,
jsonPropertiesKey: map[string]any{
- "oldText": stringSchema(
+ jsonOldTextKey: stringSchema(
"Exact text to replace. Must match a unique, non-overlapping region.",
),
- "newText": stringSchema("Replacement text."),
+ jsonNewTextKey: stringSchema("Replacement text."),
},
- jsonRequiredKey: []string{"oldText", "newText"},
+ jsonRequiredKey: []string{jsonOldTextKey, jsonNewTextKey},
},
}
}