Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ linters:
- mirror
- modernize
- nakedret
- noctx
- nolintlint
- nosprintfhostport
- nilnesserr
Expand Down
4 changes: 3 additions & 1 deletion e2e/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ func TestCagentAPI_ListSessions(t *testing.T) {
},
}

resp, err := client.Get("http://localhost/api/sessions")
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, "http://localhost/api/sessions", http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

Expand Down
12 changes: 6 additions & 6 deletions pkg/fake/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func TestAPIKeyHeaderUpdater(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(tt.envKey, tt.envValue)

req, err := http.NewRequest(http.MethodPost, "https://example.com", http.NoBody)
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://example.com", http.NoBody)
require.NoError(t, err)

APIKeyHeaderUpdater(tt.host, req)
Expand All @@ -72,7 +72,7 @@ func TestAPIKeyHeaderUpdater(t *testing.T) {
}

func TestAPIKeyHeaderUpdater_UnknownHost(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "https://example.com", http.NoBody)
req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://example.com", http.NoBody)
require.NoError(t, err)

APIKeyHeaderUpdater("https://unknown.host.com", req)
Expand Down Expand Up @@ -222,7 +222,7 @@ func TestStreamCopy_ContextCancellation(t *testing.T) {

// Create an echo context with a request that has a cancelable context
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
rec := &readerFromRecorder{httptest.NewRecorder()}
ctx, cancel := context.WithCancel(t.Context())
req = req.WithContext(ctx)
Expand Down Expand Up @@ -258,7 +258,7 @@ func TestStreamCopy_NormalCompletion(t *testing.T) {

// Create an echo context with a wrapped recorder
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
rec := &readerFromRecorder{httptest.NewRecorder()}
c := e.NewContext(req, rec)

Expand All @@ -279,7 +279,7 @@ func TestSimulatedStreamCopy_SSEEvents(t *testing.T) {

// Create an echo context
e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

Expand Down Expand Up @@ -340,7 +340,7 @@ func TestSimulatedStreamCopy_ContextCancellation(t *testing.T) {
}

e := echo.New()
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody)
rec := httptest.NewRecorder()
ctx, cancel := context.WithCancel(t.Context())
req = req.WithContext(ctx)
Expand Down
2 changes: 1 addition & 1 deletion pkg/httpclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func doRequest(t *testing.T, opts ...Opt) http.Header {
defer srv.Close()

client := NewHTTPClient(t.Context(), opts...)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)

resp, err := client.Do(req)
Expand Down
15 changes: 8 additions & 7 deletions pkg/model/provider/anthropic/wrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import (

// makeTestAnthropicError creates an *anthropic.Error with the given status code and
// optional Retry-After header value for testing.
func makeTestAnthropicError(statusCode int, retryAfterValue string) *anthropic.Error {
func makeTestAnthropicError(t *testing.T, statusCode int, retryAfterValue string) *anthropic.Error {
t.Helper()
header := http.Header{}
if retryAfterValue != "" {
header.Set("Retry-After", retryAfterValue)
Expand All @@ -26,7 +27,7 @@ func makeTestAnthropicError(statusCode int, retryAfterValue string) *anthropic.E
resp.StatusCode = statusCode
resp.Header = header
// anthropic.Error.Error() dereferences Request, so we must provide a non-nil one.
req, _ := http.NewRequest(http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.anthropic.com/v1/messages", http.NoBody)
return &anthropic.Error{
StatusCode: statusCode,
Response: resp,
Expand All @@ -53,7 +54,7 @@ func TestWrapAnthropicError(t *testing.T) {

t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "")
apiErr := makeTestAnthropicError(t, 429, "")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -65,7 +66,7 @@ func TestWrapAnthropicError(t *testing.T) {

t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "20")
apiErr := makeTestAnthropicError(t, 429, "20")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -75,7 +76,7 @@ func TestWrapAnthropicError(t *testing.T) {

t.Run("500 wraps with correct status code", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(500, "")
apiErr := makeTestAnthropicError(t, 500, "")
result := wrapAnthropicError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -85,7 +86,7 @@ func TestWrapAnthropicError(t *testing.T) {

t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "15")
apiErr := makeTestAnthropicError(t, 429, "15")
result := wrapAnthropicError(apiErr)
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
assert.False(t, retryable)
Expand All @@ -95,7 +96,7 @@ func TestWrapAnthropicError(t *testing.T) {

t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
t.Parallel()
apiErr := makeTestAnthropicError(429, "5")
apiErr := makeTestAnthropicError(t, 429, "5")
wrapped := fmt.Errorf("stream error: %w", wrapAnthropicError(apiErr))
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(wrapped)
assert.False(t, retryable)
Expand Down
4 changes: 3 additions & 1 deletion pkg/model/provider/bedrock/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ func TestBearerTokenTransport(t *testing.T) {

// Make a request through the transport
client := &http.Client{Transport: transport}
resp, err := client.Get(server.URL)
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, server.URL, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

Expand Down
4 changes: 3 additions & 1 deletion pkg/model/provider/oaistream/adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ func newTestStream(t *testing.T, sseData string) *ssestream.Stream[openai.ChatCo
}))
t.Cleanup(srv.Close)

resp, err := http.Get(srv.URL) //nolint:gosec,bodyclose // body is closed by the stream
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req) //nolint:bodyclose // body is closed by the stream
require.NoError(t, err)
return ssestream.NewStream[openai.ChatCompletionChunk](ssestream.NewDecoder(resp), nil)
}
Expand Down
15 changes: 8 additions & 7 deletions pkg/model/provider/oaistream/wrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ import (

// makeTestOpenAIError creates an *openai.Error with the given status code and
// optional Retry-After header value for testing.
func makeTestOpenAIError(statusCode int, retryAfterValue string) *openaisdk.Error {
func makeTestOpenAIError(t *testing.T, statusCode int, retryAfterValue string) *openaisdk.Error {
t.Helper()
header := http.Header{}
if retryAfterValue != "" {
header.Set("Retry-After", retryAfterValue)
Expand All @@ -26,7 +27,7 @@ func makeTestOpenAIError(statusCode int, retryAfterValue string) *openaisdk.Erro
resp.StatusCode = statusCode
resp.Header = header
// openai.Error.Error() dereferences Request, so we must provide a non-nil one.
req, _ := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody)
req, _ := http.NewRequestWithContext(t.Context(), http.MethodPost, "https://api.openai.com/v1/chat/completions", http.NoBody)
return &openaisdk.Error{
StatusCode: statusCode,
Response: resp,
Expand All @@ -53,7 +54,7 @@ func TestWrapOpenAIError(t *testing.T) {

t.Run("429 without Retry-After wraps with zero RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "")
apiErr := makeTestOpenAIError(t, 429, "")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -65,7 +66,7 @@ func TestWrapOpenAIError(t *testing.T) {

t.Run("429 with Retry-After header sets RetryAfter", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "30")
apiErr := makeTestOpenAIError(t, 429, "30")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -75,7 +76,7 @@ func TestWrapOpenAIError(t *testing.T) {

t.Run("500 wraps with correct status code", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(500, "")
apiErr := makeTestOpenAIError(t, 500, "")
result := WrapOpenAIError(apiErr)
var se *modelerrors.StatusError
require.ErrorAs(t, result, &se)
Expand All @@ -84,7 +85,7 @@ func TestWrapOpenAIError(t *testing.T) {

t.Run("wrapped error is classified correctly by ClassifyModelError", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(429, "10")
apiErr := makeTestOpenAIError(t, 429, "10")
result := WrapOpenAIError(apiErr)
retryable, rateLimited, retryAfter := modelerrors.ClassifyModelError(result)
assert.False(t, retryable)
Expand All @@ -94,7 +95,7 @@ func TestWrapOpenAIError(t *testing.T) {

t.Run("wrapped in fmt.Errorf still classified correctly", func(t *testing.T) {
t.Parallel()
apiErr := makeTestOpenAIError(500, "")
apiErr := makeTestOpenAIError(t, 500, "")
wrapped := fmt.Errorf("stream error: %w", WrapOpenAIError(apiErr))
retryable, rateLimited, _ := modelerrors.ClassifyModelError(wrapped)
assert.True(t, retryable)
Expand Down
7 changes: 2 additions & 5 deletions pkg/rag/strategy/bm25_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func (d *bm25DB) createSchema() error {
d.metadataTable,
d.tablePrefix, d.metadataTable)

_, err := d.db.Exec(schema)
_, err := d.db.ExecContext(context.Background(), schema)
return err
}

Expand Down Expand Up @@ -208,10 +208,7 @@ func (d *bm25DB) DeleteFileMetadata(ctx context.Context, sourcePath string) erro
}

func (d *bm25DB) Close() error {
if _, err := d.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)"); err != nil {
slog.Warn("Failed to checkpoint WAL before close", "error", err)
}
return d.db.Close()
return sqliteutil.CheckpointAndClose(d.db)
}

// ensureDir creates the parent directory for a file path if it doesn't exist
Expand Down
7 changes: 2 additions & 5 deletions pkg/rag/strategy/chunked_embeddings_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (d *chunkedVectorDB) createSchema() error {
);
`, d.filesTable, d.tablePrefix, d.filesTable, d.chunksTable, d.filesTable)

_, err := d.db.Exec(schema)
_, err := d.db.ExecContext(context.Background(), schema)
return err
}

Expand Down Expand Up @@ -244,8 +244,5 @@ func (d *chunkedVectorDB) DeleteFileMetadata(ctx context.Context, sourcePath str
}

func (d *chunkedVectorDB) Close() error {
if _, err := d.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)"); err != nil {
slog.Warn("Failed to checkpoint WAL before close", "error", err)
}
return d.db.Close()
return sqliteutil.CheckpointAndClose(d.db)
}
9 changes: 3 additions & 6 deletions pkg/rag/strategy/semantic_embeddings_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ func (d *semanticVectorDB) createSchema() error {
);
`, d.filesTable, d.tablePrefix, d.filesTable, d.chunksTable, d.filesTable)

if _, err := d.db.Exec(schema); err != nil {
if _, err := d.db.ExecContext(context.Background(), schema); err != nil {
return err
}

// Migration for existing databases that don't have embedding_input column
_, _ = d.db.Exec(fmt.Sprintf(`ALTER TABLE %s ADD COLUMN embedding_input TEXT`, d.chunksTable))
_, _ = d.db.ExecContext(context.Background(), fmt.Sprintf(`ALTER TABLE %s ADD COLUMN embedding_input TEXT`, d.chunksTable))

return nil
}
Expand Down Expand Up @@ -254,8 +254,5 @@ func (d *semanticVectorDB) DeleteFileMetadata(ctx context.Context, sourcePath st
}

func (d *semanticVectorDB) Close() error {
if _, err := d.db.Exec("PRAGMA wal_checkpoint(TRUNCATE)"); err != nil {
slog.Warn("Failed to checkpoint WAL before close", "error", err)
}
return d.db.Close()
return sqliteutil.CheckpointAndClose(d.db)
}
4 changes: 3 additions & 1 deletion pkg/remote/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ func TestNewTransport_WorksWithoutDesktopProxy(t *testing.T) {

// Make a simple HTTP request to verify the transport works
client := &http.Client{Transport: transport}
resp, err := client.Get(server.URL)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, http.NoBody)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()

Expand Down
2 changes: 1 addition & 1 deletion pkg/session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ func TestNewSQLiteSessionStore_RejectsNewerDatabase(t *testing.T) {
// Inject a future migration into the database to simulate a newer version
db, err := sql.Open("sqlite", dbPath)
require.NoError(t, err)
_, err = db.Exec(
_, err = db.ExecContext(t.Context(),
"INSERT INTO migrations (id, name, description, applied_at) VALUES (?, ?, ?, ?)",
9999, "9999_future_migration", "Added by a newer version", "2099-01-01T00:00:00Z")
require.NoError(t, err)
Expand Down
24 changes: 19 additions & 5 deletions pkg/skills/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ import (
"github.com/docker/docker-agent/pkg/remote"
)

// remoteHTTPTimeout caps each HTTP request made to a remote skills source.
const remoteHTTPTimeout = 30 * time.Second

// httpGet performs a GET request using the standard remote transport so that
// Docker Desktop proxy/SSL settings are honoured. The returned response body
// must be closed by the caller.
func httpGet(ctx context.Context, url string) (*http.Response, error) {
client := &http.Client{
Timeout: remoteHTTPTimeout,
Transport: remote.NewTransport(ctx),
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, fmt.Errorf("creating request for %s: %w", url, err)
}
return client.Do(req)
}

type diskCache struct {
baseDir string
}
Expand Down Expand Up @@ -70,11 +88,7 @@ func (c *diskCache) Get(baseURL, skillName, filePath string) (string, bool) {
func (c *diskCache) FetchAndStore(ctx context.Context, baseURL, skillName, filePath, fileURL string) (string, error) {
slog.Debug("Fetching remote skill file", "url", fileURL)

httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: remote.NewTransport(ctx),
}
resp, err := httpClient.Get(fileURL)
resp, err := httpGet(ctx, fileURL)
if err != nil {
return "", fmt.Errorf("fetching %s: %w", fileURL, err)
}
Expand Down
8 changes: 1 addition & 7 deletions pkg/skills/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import (
"path/filepath"
"regexp"
"strings"
"time"

"github.com/docker/docker-agent/pkg/paths"
"github.com/docker/docker-agent/pkg/remote"
)

// remoteIndex represents the index.json served at /.well-known/skills/index.json
Expand Down Expand Up @@ -45,11 +43,7 @@ func loadRemoteSkillsWithCache(ctx context.Context, baseURL string, cache *diskC

slog.Debug("Fetching remote skills index", "url", indexURL)

httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: remote.NewTransport(ctx),
}
resp, err := httpClient.Get(indexURL)
resp, err := httpGet(ctx, indexURL)
if err != nil {
slog.Warn("Failed to fetch remote skills index", "url", indexURL, "error", err)
return nil
Expand Down
Loading
Loading