Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
}

func NewRecorder(logger slog.Logger, tracer trace.Tracer, clientFn func() (Recorder, error)) Recorder {
return recorder.NewRecorder(logger, tracer, clientFn)
return recorder.NewWrappedRecorder(logger, tracer, clientFn)
}
3 changes: 2 additions & 1 deletion circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,11 @@ func TestExecute_OnStateChange(t *testing.T) {

// Trip circuit
w := httptest.NewRecorder()
cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
err := cbs.Execute(endpoint, model, w, func(rw http.ResponseWriter) error {
rw.WriteHeader(http.StatusTooManyRequests)
return nil
})
assert.NoError(t, err)

// Verify state change callback was called with correct parameters
assert.Len(t, stateChanges, 1)
Expand Down
5 changes: 3 additions & 2 deletions intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,9 @@ func (d *dumper) dumpRequest(req *http.Request) error {
if err != nil {
return xerrors.Errorf("write request header terminator: %w", err)
}
buf.Write(prettyBody)
buf.WriteByte('\n')
// bytes.Buffer writes to in-memory storage and never return errors.
_, _ = buf.Write(prettyBody)
_ = buf.WriteByte('\n')

return os.WriteFile(dumpPath, buf.Bytes(), 0o600)
}
Expand Down
2 changes: 1 addition & 1 deletion intercept/apidump/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (s *streamingBodyDumper) init() {
// Write headers first.
if _, err := s.file.Write(s.headerData); err != nil {
s.initErr = xerrors.Errorf("write headers: %w", err)
s.file.Close()
_ = s.file.Close() // best-effort cleanup on header write failure
s.file = nil
}
})
Expand Down
8 changes: 5 additions & 3 deletions intercept/apidump/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
// Create a pipe to simulate streaming
pr, pw := io.Pipe()
go func() {
defer pw.Close() //nolint:revive // error handled via pipe read side
for _, chunk := range chunks {
pw.Write([]byte(chunk))
if _, err := pw.Write([]byte(chunk)); err != nil {
return
}
}
pw.Close()
}()

resp, err := middleware(req, func(r *http.Request) (*http.Response, error) {
Expand All @@ -65,7 +67,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
for {
n, err := resp.Body.Read(buf)
if n > 0 {
receivedData.Write(buf[:n])
_, _ = receivedData.Write(buf[:n]) // bytes.Buffer.Write never fails
}
if err == io.EOF {
break
Expand Down
7 changes: 4 additions & 3 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,11 @@ func (i *StreamingInterception) marshalErr(err error) ([]byte, error) {
}

func (*StreamingInterception) encodeForStream(payload []byte) []byte {
// bytes.Buffer writes to in-memory storage and never return errors.
var buf bytes.Buffer
buf.WriteString("data: ")
buf.Write(payload)
buf.WriteString("\n\n")
_, _ = buf.WriteString("data: ")
_, _ = buf.Write(payload)
_, _ = buf.WriteString("\n\n")
return buf.Bytes()
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ var bedrockSupportedBetaFlags = map[string]bool{
type interceptionBase struct {
id uuid.UUID
providerName string
reqPayload MessagesRequestPayload
reqPayload RequestPayload

cfg aibconfig.Anthropic
bedrockCfg *aibconfig.AWSBedrock
Expand Down
4 changes: 2 additions & 2 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,10 +763,10 @@ func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
}
}

func mustMessagesPayload(t *testing.T, requestBody string) MessagesRequestPayload {
func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload {
t.Helper()

payload, err := NewMessagesRequestPayload([]byte(requestBody))
payload, err := NewRequestPayload([]byte(requestBody))
require.NoError(t, err)

return payload
Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type BlockingInterception struct {

func NewBlockingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
Expand Down
40 changes: 20 additions & 20 deletions intercept/messages/reqpayload.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,35 +82,35 @@ var (
}
)

// MessagesRequestPayload is raw JSON bytes of an Anthropic Messages API request.
// RequestPayload is raw JSON bytes of an Anthropic Messages API request.
// Methods provide package-specific reads and rewrites while preserving the
// original body for upstream pass-through.
type MessagesRequestPayload []byte
type RequestPayload []byte

func NewMessagesRequestPayload(raw []byte) (MessagesRequestPayload, error) {
func NewRequestPayload(raw []byte) (RequestPayload, error) {
if len(bytes.TrimSpace(raw)) == 0 {
return nil, xerrors.New("messages empty request body")
}
if !json.Valid(raw) {
return nil, xerrors.New("messages invalid JSON request body")
}

return MessagesRequestPayload(raw), nil
return RequestPayload(raw), nil
}

func (p MessagesRequestPayload) Stream() bool {
func (p RequestPayload) Stream() bool {
v := gjson.GetBytes(p, messagesReqPathStream)
if !v.IsBool() {
return false
}
return v.Bool()
}

func (p MessagesRequestPayload) model() string {
func (p RequestPayload) model() string {
return gjson.GetBytes(p, messagesReqPathModel).Str
}

func (p MessagesRequestPayload) correlatingToolCallID() *string {
func (p RequestPayload) correlatingToolCallID() *string {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.IsArray() {
return nil
Expand Down Expand Up @@ -147,7 +147,7 @@ func (p MessagesRequestPayload) correlatingToolCallID() *string {
// lastUserPrompt returns the prompt text from the last user message. If no prompt
// is found, it returns empty string, false, nil. Unexpected shapes are treated as
// unsupported and do not fail the request path.
func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
func (p RequestPayload) lastUserPrompt() (string, bool, error) {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.Exists() || messages.Type == gjson.Null {
return "", false, nil
Expand Down Expand Up @@ -195,7 +195,7 @@ func (p MessagesRequestPayload) lastUserPrompt() (string, bool, error) {
return "", false, nil
}

func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam) (MessagesRequestPayload, error) {
func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) {
if len(injected) == 0 {
return p, nil
}
Expand All @@ -221,7 +221,7 @@ func (p MessagesRequestPayload) injectTools(injected []anthropic.ToolUnionParam)
return p.set(messagesReqPathTools, allTools)
}

func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPayload, error) {
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
toolChoice := gjson.GetBytes(p, messagesReqPathToolChoice)

// If no tool_choice was defined, assume auto.
Expand Down Expand Up @@ -258,7 +258,7 @@ func (p MessagesRequestPayload) disableParallelToolCalls() (MessagesRequestPaylo
}
}

func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (MessagesRequestPayload, error) {
func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) {
if len(newMessages) == 0 {
return p, nil
}
Expand All @@ -285,11 +285,11 @@ func (p MessagesRequestPayload) appendedMessages(newMessages []anthropic.Message
return p.set(messagesReqPathMessages, allMessages)
}

func (p MessagesRequestPayload) withModel(model string) (MessagesRequestPayload, error) {
func (p RequestPayload) withModel(model string) (RequestPayload, error) {
return p.set(messagesReqPathModel, model)
}

func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
func (p RequestPayload) messages() ([]json.RawMessage, error) {
messages := gjson.GetBytes(p, messagesReqPathMessages)
if !messages.Exists() || messages.Type == gjson.Null {
return nil, nil
Expand All @@ -301,7 +301,7 @@ func (p MessagesRequestPayload) messages() ([]json.RawMessage, error) {
return p.resultToRawMessage(messages.Array()), nil
}

func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
func (p RequestPayload) tools() ([]json.RawMessage, error) {
tools := gjson.GetBytes(p, messagesReqPathTools)
if !tools.Exists() || tools.Type == gjson.Null {
return nil, nil
Expand All @@ -313,7 +313,7 @@ func (p MessagesRequestPayload) tools() ([]json.RawMessage, error) {
return p.resultToRawMessage(tools.Array()), nil
}

func (MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
// gjson.Result conversion to json.RawMessage is needed because
// gjson.Result does not implement json.Marshaler — would
// serialize its struct fields instead of the raw JSON it represents.
Expand All @@ -326,7 +326,7 @@ func (MessagesRequestPayload) resultToRawMessage(items []gjson.Result) []json.Ra

// convertAdaptiveThinkingForBedrock converts thinking.type "adaptive" to "enabled" with a calculated budget_tokens
// conversion is needed for Bedrock models that does not support the "adaptive" thinking.type
func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesRequestPayload, error) {
func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) {
thinkingType := gjson.GetBytes(p, messagesReqPathThinkingType)
if thinkingType.String() != constAdaptive {
return p, nil
Expand Down Expand Up @@ -377,7 +377,7 @@ func (p MessagesRequestPayload) convertAdaptiveThinkingForBedrock() (MessagesReq
// removed when the corresponding flag is absent from the Anthropic-Beta header.
// Model-specific beta flags must already be filtered from the header before
// calling this method (see filterBedrockBetaFlags).
func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Header) (MessagesRequestPayload, error) {
func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) {
var payloadMap map[string]any
if err := json.Unmarshal(p, &payloadMap); err != nil {
return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err)
Expand All @@ -400,13 +400,13 @@ func (p MessagesRequestPayload) removeUnsupportedBedrockFields(headers http.Head
if err != nil {
return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err)
}
return MessagesRequestPayload(result), nil
return RequestPayload(result), nil
}

func (p MessagesRequestPayload) set(path string, value any) (MessagesRequestPayload, error) {
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
out, err := sjson.SetBytes(p, path, value)
if err != nil {
return p, xerrors.Errorf("set %s: %w", path, err)
}
return MessagesRequestPayload(out), nil
return RequestPayload(out), nil
}
22 changes: 11 additions & 11 deletions intercept/messages/reqpayload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/coder/aibridge/utils"
)

func TestNewMessagesRequestPayload(t *testing.T) {
func TestNewRequestPayload(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -42,20 +42,20 @@ func TestNewMessagesRequestPayload(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

payload, err := NewMessagesRequestPayload(testCase.requestBody)
payload, err := NewRequestPayload(testCase.requestBody)
if testCase.expectError {
require.Error(t, err)
require.Nil(t, payload)
return
}

require.NoError(t, err)
require.Equal(t, MessagesRequestPayload(testCase.requestBody), payload)
require.Equal(t, RequestPayload(testCase.requestBody), payload)
})
}
}

func TestMessagesRequestPayloadStream(t *testing.T) {
func TestRequestPayloadStream(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestMessagesRequestPayloadStream(t *testing.T) {
}
}

func TestMessagesRequestPayloadModel(t *testing.T) {
func TestRequestPayloadModel(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -132,7 +132,7 @@ func TestMessagesRequestPayloadModel(t *testing.T) {
}
}

func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
func TestRequestPayloadLastUserPrompt(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestMessagesRequestPayloadLastUserPrompt(t *testing.T) {
}
}

func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
func TestRequestPayloadCorrelatingToolCallID(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestMessagesRequestPayloadCorrelatingToolCallID(t *testing.T) {
}
}

func TestMessagesRequestPayloadInjectTools(t *testing.T) {
func TestRequestPayloadInjectTools(t *testing.T) {
t.Parallel()

payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}],"tools":[{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`)
Expand All @@ -291,7 +291,7 @@ func TestMessagesRequestPayloadInjectTools(t *testing.T) {
require.Equal(t, "ephemeral", toolItems[1].Get("cache_control.type").String())
}

func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
func TestRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestMessagesRequestPayloadConvertAdaptiveThinkingForBedrock(t *testing.T) {
}
}

func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
func TestRequestPayloadDisableParallelToolCalls(t *testing.T) {
t.Parallel()

testCases := []struct {
Expand Down Expand Up @@ -451,7 +451,7 @@ func TestMessagesRequestPayloadDisableParallelToolCalls(t *testing.T) {
}
}

func TestMessagesRequestPayloadAppendedMessages(t *testing.T) {
func TestRequestPayloadAppendedMessages(t *testing.T) {
t.Parallel()

payload := mustMessagesPayload(t, `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hello"}]}`)
Expand Down
15 changes: 8 additions & 7 deletions intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type StreamingInterception struct {

func NewStreamingInterceptor(
id uuid.UUID,
reqPayload MessagesRequestPayload,
reqPayload RequestPayload,
providerName string,
cfg config.Anthropic,
bedrockCfg *config.AWSBedrock,
Expand Down Expand Up @@ -573,13 +573,14 @@ func (i *StreamingInterception) pingPayload() []byte {
}

func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte {
// bytes.Buffer writes to in-memory storage and never return errors.
var buf bytes.Buffer
buf.WriteString("event: ")
buf.WriteString(typ)
buf.WriteString("\n")
buf.WriteString("data: ")
buf.Write(payload)
buf.WriteString("\n\n")
_, _ = buf.WriteString("event: ")
_, _ = buf.WriteString(typ)
_, _ = buf.WriteString("\n")
_, _ = buf.WriteString("data: ")
_, _ = buf.Write(payload)
_, _ = buf.WriteString("\n\n")
return buf.Bytes()
}

Expand Down
2 changes: 1 addition & 1 deletion intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type responsesInterceptionBase struct {
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
authHeaderName string
reqPayload ResponsesRequestPayload
reqPayload RequestPayload

cfg config.OpenAI
recorder recorder.Recorder
Expand Down
Loading
Loading