Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions buildinfo/buildinfo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
)

func TestBuildInfo(t *testing.T) {
t.Parallel()

t.Run("Version", func(t *testing.T) {
t.Parallel()
// Should return a non-empty version
version := buildinfo.Version()
assert.NotEmpty(t, version)
Expand Down
8 changes: 4 additions & 4 deletions circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool {
return DefaultIsFailure(statusCode)
}

// openErrorResponse returns the error response body when the circuit is open.
func (p *ProviderCircuitBreakers) openErrorResponse() []byte {
// openErrBody returns the error response body when the circuit is open.
func (p *ProviderCircuitBreakers) openErrBody() []byte {
if p.config.OpenErrorResponse != nil {
return p.config.OpenErrorResponse()
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds())))
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write(p.openErrorResponse())
_, _ = w.Write(p.openErrBody())
return ErrCircuitOpen
}

Expand All @@ -187,7 +187,7 @@ func (p *ProviderCircuitBreakers) Provider() string {
// OpenErrorResponse returns the error response body when the circuit is open.
// This is exposed for handlers to use when responding to rejected requests.
func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte {
return p.openErrorResponse()
return p.openErrBody()
}

// StateToGaugeValue converts gobreaker.State to a gauge value.
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestGuessClient(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

req, err := http.NewRequest(http.MethodGet, "", nil)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "", nil)
require.NoError(t, err)

req.Header.Set("User-Agent", tt.userAgent)
Expand Down
16 changes: 8 additions & 8 deletions intercept/apidump/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`)))
require.NoError(t, err)

// Add sensitive headers that should be redacted
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

// Call middleware with a response containing sensitive headers
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) {
require.NotNil(t, middleware)

originalBody := `{"messages": [{"role": "user", "content": "hello"}]}`
req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody)))
require.NoError(t, err)

var capturedBody []byte
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) {
Expand Down Expand Up @@ -281,7 +281,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

// Set all sensitive headers
Expand Down Expand Up @@ -359,7 +359,7 @@ func TestPassthroughMiddleware(t *testing.T) {

rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)

req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://api.openai.com/v1/models", nil)
require.NoError(t, err)

resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error
Expand Down Expand Up @@ -403,7 +403,7 @@ func TestPassthroughMiddleware(t *testing.T) {

rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk)

req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body)))
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer sk-secret-key-12345")
resp, err := rt.RoundTrip(req)
Expand All @@ -413,7 +413,7 @@ func TestPassthroughMiddleware(t *testing.T) {
require.NoError(t, resp.Body.Close())

// Second request should create new req/resp files
req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
req2, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body)))
require.NoError(t, err)
resp2, err := rt.RoundTrip(req2)
require.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions intercept/apidump/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

// Simulate a streaming response with multiple chunks
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) {
middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk)
require.NotNil(t, middleware)

req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`)))
require.NoError(t, err)

originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}`
Expand Down
2 changes: 2 additions & 0 deletions intercept/chatcompletions/paramswrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ func TestOpenAILastUserPrompt(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

result, err := tt.wrapper.lastUserPrompt()

if tt.expectError {
Expand Down
2 changes: 1 addition & 1 deletion intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
if interceptionErr != nil {
payload, err := i.marshalErr(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr))
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
Expand Down
2 changes: 1 addition & 1 deletion intercept/eventstream/eventstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func flush(w http.ResponseWriter) (err error) {
}

defer func() {
if r := recover(); r != nil { //nolint:revive // Intentionally swallowed; likely a broken connection.
if r := recover(); r != nil { //nolint:revive,staticcheck // Intentionally swallowed; likely a broken connection.
}
}()

Expand Down
13 changes: 13 additions & 0 deletions intercept/messages/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ func TestAWSBedrockValidation(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

base := &interceptionBase{}
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)

Expand All @@ -212,7 +214,10 @@ func TestAWSBedrockValidation(t *testing.T) {
}

func TestAccumulateUsage(t *testing.T) {
t.Parallel()

t.Run("Usage to Usage", func(t *testing.T) {
t.Parallel()
dest := &anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
Expand Down Expand Up @@ -253,6 +258,8 @@ func TestAccumulateUsage(t *testing.T) {
})

t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) {
t.Parallel()

dest := &anthropic.MessageDeltaUsage{
InputTokens: 10,
OutputTokens: 20,
Expand Down Expand Up @@ -283,6 +290,8 @@ func TestAccumulateUsage(t *testing.T) {
})

t.Run("Usage to MessageDeltaUsage", func(t *testing.T) {
t.Parallel()

dest := &anthropic.MessageDeltaUsage{
InputTokens: 10,
OutputTokens: 20,
Expand Down Expand Up @@ -317,6 +326,8 @@ func TestAccumulateUsage(t *testing.T) {
})

t.Run("MessageDeltaUsage to Usage", func(t *testing.T) {
t.Parallel()

dest := &anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
Expand Down Expand Up @@ -354,6 +365,8 @@ func TestAccumulateUsage(t *testing.T) {
})

t.Run("Nil or unsupported types", func(t *testing.T) {
t.Parallel()

// Test with nil dest
var nilUsage *anthropic.Usage
source := anthropic.Usage{InputTokens: 10}
Expand Down
2 changes: 1 addition & 1 deletion intercept/messages/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ newStream:
if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr))
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error()))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
Expand Down
2 changes: 2 additions & 0 deletions intercept/responses/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ func (mrw *mockResponseWriter) WriteHeader(statusCode int) {
}

func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
t.Parallel()

mrw := mockResponseWriter{}

respCopy := responseCopier{}
Expand Down
6 changes: 3 additions & 3 deletions internal/integrationtest/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/coder/aibridge"
"github.com/coder/aibridge/config"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/provider"
)

Expand Down Expand Up @@ -114,7 +114,7 @@ func TestAPIDump(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)

// Setup mock upstream server.
Expand Down Expand Up @@ -244,7 +244,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong)
t.Cleanup(cancel)

upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading
Loading