diff --git a/buildinfo/buildinfo_test.go b/buildinfo/buildinfo_test.go index 390d4e9..ffb0472 100644 --- a/buildinfo/buildinfo_test.go +++ b/buildinfo/buildinfo_test.go @@ -9,7 +9,10 @@ import ( ) func TestBuildInfo(t *testing.T) { + t.Parallel() + t.Run("Version", func(t *testing.T) { + t.Parallel() // Should return a non-empty version version := buildinfo.Version() assert.NotEmpty(t, version) diff --git a/circuitbreaker/circuitbreaker.go b/circuitbreaker/circuitbreaker.go index cdba12d..4d5b4dc 100644 --- a/circuitbreaker/circuitbreaker.go +++ b/circuitbreaker/circuitbreaker.go @@ -65,8 +65,8 @@ func (p *ProviderCircuitBreakers) isFailure(statusCode int) bool { return DefaultIsFailure(statusCode) } -// openErrorResponse returns the error response body when the circuit is open. -func (p *ProviderCircuitBreakers) openErrorResponse() []byte { +// openErrBody returns the error response body when the circuit is open. +func (p *ProviderCircuitBreakers) openErrBody() []byte { if p.config.OpenErrorResponse != nil { return p.config.OpenErrorResponse() } @@ -167,7 +167,7 @@ func (p *ProviderCircuitBreakers) Execute(endpoint, model string, w http.Respons w.Header().Set("Content-Type", "application/json") w.Header().Set("Retry-After", fmt.Sprintf("%d", int64(p.config.Timeout.Seconds()))) w.WriteHeader(http.StatusServiceUnavailable) - _, _ = w.Write(p.openErrorResponse()) + _, _ = w.Write(p.openErrBody()) return ErrCircuitOpen } @@ -187,7 +187,7 @@ func (p *ProviderCircuitBreakers) Provider() string { // OpenErrorResponse returns the error response body when the circuit is open. // This is exposed for handlers to use when responding to rejected requests. func (p *ProviderCircuitBreakers) OpenErrorResponse() []byte { - return p.openErrorResponse() + return p.openErrBody() } // StateToGaugeValue converts gobreaker.State to a gauge value. diff --git a/client_test.go b/client_test.go index 5c1d101..a33f845 100644 --- a/client_test.go +++ b/client_test.go @@ -108,7 +108,7 @@ func TestGuessClient(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - req, err := http.NewRequest(http.MethodGet, "", nil) + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "", nil) require.NoError(t, err) req.Header.Set("User-Agent", tt.userAgent) diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go index f26f685..5563d41 100644 --- a/intercept/apidump/apidump_test.go +++ b/intercept/apidump/apidump_test.go @@ -39,7 +39,7 @@ func TestBridgedMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) require.NoError(t, err) // Add sensitive headers that should be redacted @@ -97,7 +97,7 @@ func TestBridgedMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Call middleware with a response containing sensitive headers @@ -167,7 +167,7 @@ func TestBridgedMiddleware_PreservesRequestBody(t *testing.T) { require.NotNil(t, middleware) originalBody := `{"messages": [{"role": "user", "content": "hello"}]}` - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody))) require.NoError(t, err) var capturedBody []byte @@ -201,7 +201,7 @@ func TestBridgedMiddleware_ModelWithSlash(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) resp3, err := middleware(req, func(r *http.Request) (*http.Response, error) { @@ -281,7 +281,7 @@ func TestBridgedMiddleware_AllSensitiveRequestHeaders(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Set all sensitive headers @@ -359,7 +359,7 @@ func TestPassthroughMiddleware(t *testing.T) { rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/models", nil) + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "https://api.openai.com/v1/models", nil) require.NoError(t, err) resp, err := rt.RoundTrip(req) //nolint:bodyclose // resp is nil on error @@ -403,7 +403,7 @@ func TestPassthroughMiddleware(t *testing.T) { rt := NewPassthroughMiddleware(inner, tmpDir, "openai", logger, clk) - req, err := http.NewRequest(http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/models", bytes.NewReader([]byte(req1Body))) require.NoError(t, err) req.Header.Set("Authorization", "Bearer sk-secret-key-12345") resp, err := rt.RoundTrip(req) @@ -413,7 +413,7 @@ func TestPassthroughMiddleware(t *testing.T) { require.NoError(t, resp.Body.Close()) // Second request should create new req/resp files - req2, err := http.NewRequest(http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) + req2, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/conversations", bytes.NewReader([]byte(req2Body))) require.NoError(t, err) resp2, err := rt.RoundTrip(req2) require.NoError(t, err) diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go index 9ab8e71..87223df 100644 --- a/intercept/apidump/streaming_test.go +++ b/intercept/apidump/streaming_test.go @@ -28,7 +28,7 @@ func TestMiddleware_StreamingResponse(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) // Simulate a streaming response with multiple chunks @@ -106,7 +106,7 @@ func TestMiddleware_PreservesResponseBody(t *testing.T) { middleware := NewBridgeMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) require.NotNil(t, middleware) - req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) require.NoError(t, err) originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}` diff --git a/intercept/chatcompletions/paramswrap_test.go b/intercept/chatcompletions/paramswrap_test.go index eac38ba..7397e22 100644 --- a/intercept/chatcompletions/paramswrap_test.go +++ b/intercept/chatcompletions/paramswrap_test.go @@ -114,6 +114,8 @@ func TestOpenAILastUserPrompt(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result, err := tt.wrapper.lastUserPrompt() if tt.expectError { diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index 2e3a72e..89d4bd3 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -261,7 +261,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re if interceptionErr != nil { payload, err := i.marshalErr(interceptionErr) if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr)) + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error())) } else if err := events.Send(streamCtx, payload); err != nil { logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) } diff --git a/intercept/eventstream/eventstream.go b/intercept/eventstream/eventstream.go index 85f911d..9baeade 100644 --- a/intercept/eventstream/eventstream.go +++ b/intercept/eventstream/eventstream.go @@ -239,7 +239,7 @@ func flush(w http.ResponseWriter) (err error) { } defer func() { - if r := recover(); r != nil { //nolint:revive // Intentionally swallowed; likely a broken connection. + if r := recover(); r != nil { //nolint:revive,staticcheck // Intentionally swallowed; likely a broken connection. } }() diff --git a/intercept/messages/base_test.go b/intercept/messages/base_test.go index ec82912..ae1ee5b 100644 --- a/intercept/messages/base_test.go +++ b/intercept/messages/base_test.go @@ -197,6 +197,8 @@ func TestAWSBedrockValidation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() + base := &interceptionBase{} opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg) @@ -212,7 +214,10 @@ func TestAWSBedrockValidation(t *testing.T) { } func TestAccumulateUsage(t *testing.T) { + t.Parallel() + t.Run("Usage to Usage", func(t *testing.T) { + t.Parallel() dest := &anthropic.Usage{ InputTokens: 10, OutputTokens: 20, @@ -253,6 +258,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.MessageDeltaUsage{ InputTokens: 10, OutputTokens: 20, @@ -283,6 +290,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("Usage to MessageDeltaUsage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.MessageDeltaUsage{ InputTokens: 10, OutputTokens: 20, @@ -317,6 +326,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("MessageDeltaUsage to Usage", func(t *testing.T) { + t.Parallel() + dest := &anthropic.Usage{ InputTokens: 10, OutputTokens: 20, @@ -354,6 +365,8 @@ func TestAccumulateUsage(t *testing.T) { }) t.Run("Nil or unsupported types", func(t *testing.T) { + t.Parallel() + // Test with nil dest var nilUsage *anthropic.Usage source := anthropic.Usage{InputTokens: 10} diff --git a/intercept/messages/streaming.go b/intercept/messages/streaming.go index dfe6acc..5ee7b96 100644 --- a/intercept/messages/streaming.go +++ b/intercept/messages/streaming.go @@ -504,7 +504,7 @@ newStream: if interceptionErr != nil { payload, err := i.marshal(interceptionErr) if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr)) + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", interceptionErr.Error())) } else if err := events.Send(streamCtx, payload); err != nil { logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) } diff --git a/intercept/responses/base_test.go b/intercept/responses/base_test.go index adf2322..cf02738 100644 --- a/intercept/responses/base_test.go +++ b/intercept/responses/base_test.go @@ -359,6 +359,8 @@ func (mrw *mockResponseWriter) WriteHeader(statusCode int) { } func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { + t.Parallel() + mrw := mockResponseWriter{} respCopy := responseCopier{} diff --git a/internal/integrationtest/apidump_test.go b/internal/integrationtest/apidump_test.go index 9fbbaf1..e05f9d0 100644 --- a/internal/integrationtest/apidump_test.go +++ b/internal/integrationtest/apidump_test.go @@ -11,7 +11,6 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/stretchr/testify/require" @@ -19,6 +18,7 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept/apidump" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" ) @@ -114,7 +114,7 @@ func TestAPIDump(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock upstream server. @@ -244,7 +244,7 @@ func TestAPIDumpPassthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/integrationtest/bridge_test.go b/internal/integrationtest/bridge_test.go index c305c9f..15623b8 100644 --- a/internal/integrationtest/bridge_test.go +++ b/internal/integrationtest/bridge_test.go @@ -10,7 +10,6 @@ import ( "slices" "strings" "testing" - "time" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/ssestream" @@ -29,6 +28,7 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" "github.com/coder/aibridge/intercept" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" @@ -78,7 +78,7 @@ func TestAnthropicMessages(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) @@ -212,7 +212,7 @@ func TestAnthropicMessagesModelThoughts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -246,7 +246,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run("invalid config", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Invalid bedrock config - missing region & base url @@ -278,7 +278,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) @@ -404,7 +404,7 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.AntSimpleBedrock) @@ -501,7 +501,7 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) @@ -577,7 +577,7 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server for multi-turn interaction. @@ -770,7 +770,7 @@ func TestSimple(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -877,7 +877,7 @@ func TestSessionIDTracking(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -1256,6 +1256,8 @@ func TestErrorHandling(t *testing.T) { // Tests that errors which occur *before* a streaming response begins, or in non-streaming requests, are handled as expected. t.Run("non-stream error", func(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -1298,7 +1300,7 @@ func TestErrorHandling(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server. Error fixtures contain raw HTTP @@ -1371,7 +1373,7 @@ func TestErrorHandling(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server. @@ -1420,7 +1422,7 @@ func TestStableRequestEncoding(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup MCP tools. @@ -1685,7 +1687,7 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup MCP tools conditionally. @@ -1849,7 +1851,7 @@ func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -1904,7 +1906,7 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Create a mock server that captures the request body sent upstream. @@ -1969,7 +1971,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -2081,7 +2083,7 @@ func TestActorHeaders(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) diff --git a/internal/integrationtest/metrics_test.go b/internal/integrationtest/metrics_test.go index bb777ea..f774b04 100644 --- a/internal/integrationtest/metrics_test.go +++ b/internal/integrationtest/metrics_test.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/prometheus/client_golang/prometheus" promtest "github.com/prometheus/client_golang/prometheus/testutil" @@ -17,6 +16,7 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/metrics" ) @@ -143,7 +143,7 @@ func TestMetrics_Interception(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -175,7 +175,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { fix := fixtures.Parse(t, fixtures.AntSimple) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) blockCh := make(chan struct{}) @@ -210,7 +210,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 1 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) // Unblock request, await completion. close(blockCh) @@ -225,7 +225,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { return promtest.ToFloat64( m.InterceptionsInflight.WithLabelValues(config.ProviderAnthropic, "claude-sonnet-4-0", "/v1/messages"), ) == 0 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) } func TestMetrics_PassthroughCount(t *testing.T) { @@ -252,7 +252,7 @@ func TestMetrics_PassthroughCount(t *testing.T) { func TestMetrics_PromptCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSimple) @@ -342,7 +342,7 @@ func TestMetrics_TokenUseCount(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -369,7 +369,7 @@ func TestMetrics_TokenUseCount(t *testing.T) { require.Eventually(t, func() bool { return promtest.ToFloat64(m.TokenUseCount.WithLabelValues( tc.expectProvider, tc.expectModel, "input", defaultActorID, string(aibridge.ClientUnknown))) > 0 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) for label, expected := range tc.expectedLabels { require.Equal(t, expected, promtest.ToFloat64(m.TokenUseCount.WithLabelValues( @@ -383,7 +383,7 @@ func TestMetrics_TokenUseCount(t *testing.T) { func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) @@ -409,7 +409,7 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { func TestMetrics_InjectedToolUseCount(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // First request returns the tool invocation, the second returns the mocked response to the tool result. @@ -436,7 +436,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // Wait until full roundtrip has completed. require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) recorder := bridgeServer.Recorder require.Len(t, recorder.ToolUsages(), 1) diff --git a/internal/integrationtest/mockmcp.go b/internal/integrationtest/mockmcp.go index df81d31..812ad2e 100644 --- a/internal/integrationtest/mockmcp.go +++ b/internal/integrationtest/mockmcp.go @@ -7,7 +7,6 @@ import ( "net/http/httptest" "sync" "testing" - "time" "github.com/mark3labs/mcp-go/client/transport" mcplib "github.com/mark3labs/mcp-go/mcp" @@ -19,6 +18,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" ) @@ -68,12 +68,12 @@ func setupMCPForTestWithName(t *testing.T, name string, tracer trace.Tracer) *mo mgr := mcp.NewServerProxyManager(map[string]mcp.ServerProxier{proxy.Name(): proxy}, tracer) t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() require.NoError(t, mgr.Shutdown(ctx)) }) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) require.NoError(t, mgr.Init(ctx)) require.NotEmpty(t, mgr.ListTools(), "mock MCP server should expose tools after init") diff --git a/internal/integrationtest/responses_test.go b/internal/integrationtest/responses_test.go index 65c307e..61c885d 100644 --- a/internal/integrationtest/responses_test.go +++ b/internal/integrationtest/responses_test.go @@ -22,6 +22,7 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/utils" @@ -335,7 +336,7 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) @@ -418,7 +419,7 @@ func TestResponsesBackgroundModeForbidden(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // request with Background mode should be rejected before it reaches upstream @@ -551,7 +552,7 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", tc.name, streaming), func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture[i]) @@ -637,7 +638,7 @@ func TestClientAndConnectionError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // tc.addr may be an intentionally invalid URL; use withCustomProvider. @@ -709,7 +710,7 @@ func TestUpstreamError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -890,7 +891,7 @@ func TestResponsesInjectedTool(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Setup mock server for multi-turn interaction. @@ -917,7 +918,7 @@ func TestResponsesInjectedTool(t *testing.T) { // Wait for both requests to be made (inner agentic loop). require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) // Verify the injected tool was invoked via MCP. invocations := mockMCP.getCallsByTool(tc.mcpToolName) @@ -1037,7 +1038,7 @@ func TestResponsesModelThoughts(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, tc.fixture) diff --git a/internal/integrationtest/setupbridge.go b/internal/integrationtest/setupbridge.go index afd403a..5bfa01b 100644 --- a/internal/integrationtest/setupbridge.go +++ b/internal/integrationtest/setupbridge.go @@ -7,7 +7,6 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/stretchr/testify/require" "github.com/tidwall/sjson" @@ -208,7 +207,7 @@ func setupInjectedToolTest( ) (*bridgeTestServer, *mockMCP, *http.Response) { t.Helper() - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) fix := fixtures.Parse(t, fixture) @@ -242,7 +241,7 @@ func setupInjectedToolTest( // Wait both requests (initial + tool call result) require.Eventually(t, func() bool { return upstream.Calls.Load() == 2 - }, time.Second*10, time.Millisecond*50) + }, testutil.WaitMedium, testutil.IntervalFast) return bridgeServer, mockMCP, resp } diff --git a/internal/integrationtest/trace_test.go b/internal/integrationtest/trace_test.go index 878118a..164c880 100644 --- a/internal/integrationtest/trace_test.go +++ b/internal/integrationtest/trace_test.go @@ -6,7 +6,6 @@ import ( "slices" "strings" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -20,6 +19,7 @@ import ( "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/tracing" ) @@ -43,6 +43,8 @@ func setupTracer(t *testing.T) (*tracetest.SpanRecorder, oteltrace.Tracer) { } func TestTraceAnthropic(t *testing.T) { + t.Parallel() + expectNonStreaming := []expectTrace{ {"Intercept", 1, codes.Unset}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -137,7 +139,9 @@ func TestTraceAnthropic(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) @@ -191,6 +195,8 @@ func TestTraceAnthropic(t *testing.T) { } func TestTraceAnthropicErr(t *testing.T) { + t.Parallel() + expectNonStream := []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -249,7 +255,9 @@ func TestTraceAnthropicErr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) @@ -422,6 +430,8 @@ func TestInjectedToolsTrace(t *testing.T) { } func TestTraceOpenAI(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -534,7 +544,9 @@ func TestTraceOpenAI(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) @@ -576,6 +588,8 @@ func TestTraceOpenAI(t *testing.T) { } func TestTraceOpenAIErr(t *testing.T) { + t.Parallel() + cases := []struct { name string fixture []byte @@ -689,7 +703,9 @@ func TestTraceOpenAIErr(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) sr, tracer := setupTracer(t) @@ -766,6 +782,8 @@ func TestTracePassthrough(t *testing.T) { } func TestNewServerProxyManagerTraces(t *testing.T) { + t.Parallel() + sr, tracer := setupTracer(t) serverName := "serverName" diff --git a/internal/testutil/timeout.go b/internal/testutil/timeout.go new file mode 100644 index 0000000..ef8b2b5 --- /dev/null +++ b/internal/testutil/timeout.go @@ -0,0 +1,21 @@ +package testutil + +import "time" + +// Shared test timeout and interval constants. +// Using named constants avoids magic numbers and makes timeout policy +// easy to adjust across the entire test suite. +const ( + // WaitLong is the default timeout for test operations that may take a while + // (e.g. integration tests with HTTP round-trips). + WaitLong = 30 * time.Second + + // WaitMedium is a timeout for moderately slow operations. + WaitMedium = 10 * time.Second + + // WaitShort is a timeout for operations expected to complete quickly. + WaitShort = 5 * time.Second + + // IntervalFast is a short polling interval for require.Eventually and similar. + IntervalFast = 50 * time.Millisecond +) diff --git a/mcp/client_info_test.go b/mcp/client_info_test.go index a48487b..4dfabc5 100644 --- a/mcp/client_info_test.go +++ b/mcp/client_info_test.go @@ -9,6 +9,8 @@ import ( ) func TestGetClientInfo(t *testing.T) { + t.Parallel() + info := mcp.GetClientInfo() assert.Equal(t, "coder/aibridge", info.Name) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5769440..b9a1430 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -9,7 +9,6 @@ import ( "slices" "strings" "testing" - "time" "go.opentelemetry.io/otel" "go.uber.org/goleak" @@ -20,6 +19,7 @@ import ( "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" mcplib "github.com/mark3labs/mcp-go/mcp" @@ -302,7 +302,7 @@ func TestToolInjectionOrder(t *testing.T) { // Setup. logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitLong) t.Cleanup(cancel) // Given: a MCP mock server offering a set of tools. diff --git a/session_test.go b/session_test.go index ec1be50..7cce9e4 100644 --- a/session_test.go +++ b/session_test.go @@ -205,7 +205,7 @@ func TestGuessSessionID(t *testing.T) { t.Parallel() body := tc.body - req, err := http.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(body)) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", strings.NewReader(body)) require.NoError(t, err) for key, value := range tc.headers { @@ -226,7 +226,7 @@ func TestGuessSessionID(t *testing.T) { func TestUnreadableBody(t *testing.T) { t.Parallel() - req, err := http.NewRequest(http.MethodPost, "http://localhost", &errReader{}) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "http://localhost", &errReader{}) require.NoError(t, err) got := guessSessionID(ClientClaudeCode, req) diff --git a/utils/concurrent_group_test.go b/utils/concurrent_group_test.go index 36ca448..6b5b788 100644 --- a/utils/concurrent_group_test.go +++ b/utils/concurrent_group_test.go @@ -18,11 +18,15 @@ func TestConcurrentGroup(t *testing.T) { t.Parallel() t.Run("no goroutines", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() require.NoError(t, cg.Wait()) }) t.Run("multiple goroutines, all ok", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() cg.Go(func() error { return nil @@ -34,6 +38,8 @@ func TestConcurrentGroup(t *testing.T) { }) t.Run("multiple goroutines, one err", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() oops := xerrors.New("oops") cg.Go(func() error { @@ -46,6 +52,8 @@ func TestConcurrentGroup(t *testing.T) { }) t.Run("multiple goroutines, multiple errs", func(t *testing.T) { + t.Parallel() + cg := utils.NewConcurrentGroup() oops := xerrors.New("oops") eek := xerrors.New("eek")