Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package aibridge
import (
"encoding/json"
"errors"
"strings"

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/param"
Expand Down Expand Up @@ -51,22 +50,23 @@ func (b *MessageNewParamsWrapper) LastUserPrompt() (*string, error) {
return nil, errors.New("no messages")
}

var userMessage string
for i := len(b.Messages) - 1; i >= 0; i-- {
m := b.Messages[i]
if m.Role != anthropic.MessageParamRoleUser {
continue
}
if len(m.Content) == 0 {
continue
}
// We only care if the last message was issued by a user.
msg := b.Messages[len(b.Messages)-1]
if msg.Role != anthropic.MessageParamRoleUser {
return nil, nil
}

for j := len(m.Content) - 1; j >= 0; j-- {
if textContent := m.Content[j].GetText(); textContent != nil {
userMessage = *textContent
}
if len(msg.Content) == 0 {
return nil, nil
}

return utils.PtrTo(strings.TrimSpace(userMessage)), nil
// Walk backwards on "user"-initiated message content. Clients often inject
// content ahead of the actual prompt to provide context to the model,
// so the last item in the slice is most likely the user's prompt.
for i := len(msg.Content) - 1; i >= 0; i-- {
// Only text content is supported currently.
if textContent := msg.Content[i].GetText(); textContent != nil {
return textContent, nil
}
}

Expand Down
172 changes: 172 additions & 0 deletions anthropic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package aibridge_test

import (
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/coder/aibridge"
"github.com/stretchr/testify/require"
)

func TestAnthropicLastUserPrompt(t *testing.T) {
t.Parallel()

tests := []struct {
name string
wrapper *aibridge.MessageNewParamsWrapper
expected string
expectError bool
errorMsg string
}{
{
name: "nil struct",
expectError: true,
errorMsg: "nil struct",
},
{
name: "no messages",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{},
},
},
expectError: true,
errorMsg: "no messages",
},
{
name: "last message not from user",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("user message"),
},
},
{
Role: anthropic.MessageParamRoleAssistant,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("assistant message"),
},
},
},
},
},
},
{
name: "last user message with empty content",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{},
},
},
},
},
},
{
name: "last user message with single text content",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("Hello, world!"),
},
},
},
},
},
expected: "Hello, world!",
},
{
name: "last user message with multiple content blocks - text at end",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewImageBlockBase64("image/png", "base64data"),
anthropic.NewTextBlock("First text"),
anthropic.NewImageBlockBase64("image/jpeg", "moredata"),
anthropic.NewTextBlock("Last text"),
},
},
},
},
},
expected: "Last text",
},
{
name: "last user message with only non-text content",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewImageBlockBase64("image/png", "base64data"),
anthropic.NewImageBlockBase64("image/jpeg", "moredata"),
},
},
},
},
},
},
{
name: "multiple messages with last being user",
wrapper: &aibridge.MessageNewParamsWrapper{
MessageNewParams: anthropic.MessageNewParams{
Messages: []anthropic.MessageParam{
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("First user message"),
},
},
{
Role: anthropic.MessageParamRoleAssistant,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("Assistant response"),
},
},
{
Role: anthropic.MessageParamRoleUser,
Content: []anthropic.ContentBlockParamUnion{
anthropic.NewTextBlock("Second user message"),
},
},
},
},
},
expected: "Second user message",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := tt.wrapper.LastUserPrompt()

if tt.expectError {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMsg)
require.Nil(t, result)
} else {
require.NoError(t, err)
// Check pointer equality - both nil or both non-nil
if tt.expected == "" {
require.Nil(t, result)
} else {
require.NotNil(t, result)
// The result should point to the same string from the content block
require.Equal(t, tt.expected, *result)
}
}
})
}
}
2 changes: 1 addition & 1 deletion bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ func TestSimple(t *testing.T) {

// Then: I expect the prompt to have been tracked.
require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked")
assert.Equal(t, "how many angels can dance on the head of a pin", recorderClient.userPrompts[0].Prompt)
assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin")

// Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider.
// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting
Expand Down
30 changes: 17 additions & 13 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package aibridge
import (
"encoding/json"
"errors"
"strings"

"github.com/anthropics/anthropic-sdk-go/shared"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
Expand Down Expand Up @@ -56,22 +55,27 @@ func (c *ChatCompletionNewParamsWrapper) LastUserPrompt() (*string, error) {
return nil, errors.New("no messages")
}

var msg *openai.ChatCompletionUserMessageParam
for i := len(c.Messages) - 1; i >= 0; i-- {
m := c.Messages[i]
if m.OfUser != nil {
msg = m.OfUser
break
}
// We only care if the last message was issued by a user.
msg := c.Messages[len(c.Messages)-1]
if msg.OfUser == nil {
return nil, nil
}

if msg == nil {
return nil, nil
if msg.OfUser.Content.OfString.String() != "" {
return utils.PtrTo(msg.OfUser.Content.OfString.String()), nil
}

// Walk backwards on "user"-initiated message content. Clients often inject
// content ahead of the actual prompt to provide context to the model,
// so the last item in the slice is most likely the user's prompt.
for i := len(msg.OfUser.Content.OfArrayOfContentParts) - 1; i >= 0; i-- {
// Only text content is supported currently.
if textContent := msg.OfUser.Content.OfArrayOfContentParts[i].OfText; textContent != nil {
return &textContent.Text, nil
}
}

return utils.PtrTo(strings.TrimSpace(
msg.Content.OfString.String(),
)), nil
return nil, nil
}

func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
Expand Down
Loading