From 729c8a182b570122e04cadacc104dd39e74dac69 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 30 Sep 2025 08:46:41 +0200 Subject: [PATCH] chore: correct last user prompt detection Signed-off-by: Danny Kopping --- anthropic.go | 30 +++---- anthropic_test.go | 172 +++++++++++++++++++++++++++++++++++++ bridge_integration_test.go | 2 +- openai.go | 30 ++++--- openai_test.go | 133 ++++++++++++++++++++++++++++ 5 files changed, 338 insertions(+), 29 deletions(-) create mode 100644 anthropic_test.go create mode 100644 openai_test.go diff --git a/anthropic.go b/anthropic.go index c120ac9..66cf94e 100644 --- a/anthropic.go +++ b/anthropic.go @@ -3,7 +3,6 @@ package aibridge import ( "encoding/json" "errors" - "strings" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/param" @@ -51,22 +50,23 @@ func (b *MessageNewParamsWrapper) LastUserPrompt() (*string, error) { return nil, errors.New("no messages") } - var userMessage string - for i := len(b.Messages) - 1; i >= 0; i-- { - m := b.Messages[i] - if m.Role != anthropic.MessageParamRoleUser { - continue - } - if len(m.Content) == 0 { - continue - } + // We only care if the last message was issued by a user. + msg := b.Messages[len(b.Messages)-1] + if msg.Role != anthropic.MessageParamRoleUser { + return nil, nil + } - for j := len(m.Content) - 1; j >= 0; j-- { - if textContent := m.Content[j].GetText(); textContent != nil { - userMessage = *textContent - } + if len(msg.Content) == 0 { + return nil, nil + } - return utils.PtrTo(strings.TrimSpace(userMessage)), nil + // Walk backwards on "user"-initiated message content. Clients often inject + // content ahead of the actual prompt to provide context to the model, + // so the last item in the slice is most likely the user's prompt. + for i := len(msg.Content) - 1; i >= 0; i-- { + // Only text content is supported currently. + if textContent := msg.Content[i].GetText(); textContent != nil { + return textContent, nil } } diff --git a/anthropic_test.go b/anthropic_test.go new file mode 100644 index 0000000..93e063c --- /dev/null +++ b/anthropic_test.go @@ -0,0 +1,172 @@ +package aibridge_test + +import ( + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/coder/aibridge" + "github.com/stretchr/testify/require" +) + +func TestAnthropicLastUserPrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + wrapper *aibridge.MessageNewParamsWrapper + expected string + expectError bool + errorMsg string + }{ + { + name: "nil struct", + expectError: true, + errorMsg: "nil struct", + }, + { + name: "no messages", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{}, + }, + }, + expectError: true, + errorMsg: "no messages", + }, + { + name: "last message not from user", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("user message"), + }, + }, + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("assistant message"), + }, + }, + }, + }, + }, + }, + { + name: "last user message with empty content", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{}, + }, + }, + }, + }, + }, + { + name: "last user message with single text content", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("Hello, world!"), + }, + }, + }, + }, + }, + expected: "Hello, world!", + }, + { + name: "last user message with multiple content blocks - text at end", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewImageBlockBase64("image/png", "base64data"), + anthropic.NewTextBlock("First text"), + anthropic.NewImageBlockBase64("image/jpeg", "moredata"), + anthropic.NewTextBlock("Last text"), + }, + }, + }, + }, + }, + expected: "Last text", + }, + { + name: "last user message with only non-text content", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewImageBlockBase64("image/png", "base64data"), + anthropic.NewImageBlockBase64("image/jpeg", "moredata"), + }, + }, + }, + }, + }, + }, + { + name: "multiple messages with last being user", + wrapper: &aibridge.MessageNewParamsWrapper{ + MessageNewParams: anthropic.MessageNewParams{ + Messages: []anthropic.MessageParam{ + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("First user message"), + }, + }, + { + Role: anthropic.MessageParamRoleAssistant, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("Assistant response"), + }, + }, + { + Role: anthropic.MessageParamRoleUser, + Content: []anthropic.ContentBlockParamUnion{ + anthropic.NewTextBlock("Second user message"), + }, + }, + }, + }, + }, + expected: "Second user message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.wrapper.LastUserPrompt() + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Nil(t, result) + } else { + require.NoError(t, err) + // Check pointer equality - both nil or both non-nil + if tt.expected == "" { + require.Nil(t, result) + } else { + require.NotNil(t, result) + // The result should point to the same string from the content block + require.Equal(t, tt.expected, *result) + } + } + }) + } +} diff --git a/bridge_integration_test.go b/bridge_integration_test.go index de64532..66fb840 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -426,7 +426,7 @@ func TestSimple(t *testing.T) { // Then: I expect the prompt to have been tracked. require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked") - assert.Equal(t, "how many angels can dance on the head of a pin", recorderClient.userPrompts[0].Prompt) + assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin") // Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider. // The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting diff --git a/openai.go b/openai.go index 5434a80..b296123 100644 --- a/openai.go +++ b/openai.go @@ -3,7 +3,6 @@ package aibridge import ( "encoding/json" "errors" - "strings" "github.com/anthropics/anthropic-sdk-go/shared" "github.com/anthropics/anthropic-sdk-go/shared/constant" @@ -56,22 +55,27 @@ func (c *ChatCompletionNewParamsWrapper) LastUserPrompt() (*string, error) { return nil, errors.New("no messages") } - var msg *openai.ChatCompletionUserMessageParam - for i := len(c.Messages) - 1; i >= 0; i-- { - m := c.Messages[i] - if m.OfUser != nil { - msg = m.OfUser - break - } + // We only care if the last message was issued by a user. + msg := c.Messages[len(c.Messages)-1] + if msg.OfUser == nil { + return nil, nil } - if msg == nil { - return nil, nil + if msg.OfUser.Content.OfString.String() != "" { + return utils.PtrTo(msg.OfUser.Content.OfString.String()), nil + } + + // Walk backwards on "user"-initiated message content. Clients often inject + // content ahead of the actual prompt to provide context to the model, + // so the last item in the slice is most likely the user's prompt. + for i := len(msg.OfUser.Content.OfArrayOfContentParts) - 1; i >= 0; i-- { + // Only text content is supported currently. + if textContent := msg.OfUser.Content.OfArrayOfContentParts[i].OfText; textContent != nil { + return &textContent.Text, nil + } } - return utils.PtrTo(strings.TrimSpace( - msg.Content.OfString.String(), - )), nil + return nil, nil } func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage { diff --git a/openai_test.go b/openai_test.go new file mode 100644 index 0000000..35e52f2 --- /dev/null +++ b/openai_test.go @@ -0,0 +1,133 @@ +package aibridge_test + +import ( + "testing" + + "github.com/coder/aibridge" + "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/require" +) + +func TestOpenAILastUserPrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + wrapper *aibridge.ChatCompletionNewParamsWrapper + expected string + expectError bool + errorMsg string + }{ + { + name: "nil struct", + expectError: true, + errorMsg: "nil struct", + }, + { + name: "no messages", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{}, + }, + }, + expectError: true, + errorMsg: "no messages", + }, + { + name: "last message not from user", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("user message"), + openai.AssistantMessage("assistant message"), + }, + }, + }, + }, + { + name: "user message with string content", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello, world!"), + }, + }, + }, + expected: "Hello, world!", + }, + { + name: "user message with empty string", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(""), + }, + }, + }, + }, + { + name: "user message with array content - text at end", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{ + openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image.png", + }), + openai.TextContentPart("First text"), + openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image2.png", + }), + openai.TextContentPart("Last text"), + }), + }, + }, + }, + expected: "Last text", + }, + { + name: "user message with array content - no text", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{ + openai.ImageContentPart(openai.ChatCompletionContentPartImageImageURLParam{ + URL: "https://example.com/image.png", + }), + }), + }, + }, + }, + }, + { + name: "user message with empty array", + wrapper: &aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{}), + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := tt.wrapper.LastUserPrompt() + + if tt.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMsg) + require.Nil(t, result) + } else { + require.NoError(t, err) + if tt.expected == "" { + require.Nil(t, result) + } else { + require.NotNil(t, result) + require.Equal(t, tt.expected, *result) + } + } + }) + } +}