From 4f7d8f8d6898fc1ea755a8a1ec66d3485f5ad606 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Sat, 25 Apr 2026 20:00:14 +0200 Subject: [PATCH 01/19] feat: add `docker agent serve chat` command (OpenAI-compatible API) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose any docker-agent agent through an OpenAI-compatible HTTP server, so tools that already speak the Chat Completions protocol (Open WebUI, the official `openai` SDKs, ad-hoc curl scripts, etc.) can drive an agent without any custom integration. Endpoints: GET /v1/models — lists exposed agents as OpenAI models POST /v1/chat/completions — runs the agent; supports stream: true (Server-Sent Events) and false The team is loaded once at startup and shared across requests; each chat completion gets a fresh session and runtime. Tool calls and elicitation prompts are auto-handled (this is a non-interactive endpoint). The `model` field can pin a specific agent in a multi- agent team, or is ignored and the team's default agent runs. Implementation notes: - New `cmd/root/chat.go` cobra command (default 127.0.0.1:8083, --agent / --listen flags) wired into `cmd/root/serve.go`. - New `pkg/chatserver` package, split into: - server.go — Run, router, HTTP handlers, sseStream, errors - agent.go — agentPolicy, buildSession, runAgentLoop, sessionUsage - types.go — request/response shapes - Reuses `openai.Model` from github.com/openai/openai-go/v3 for /v1/models. Other OpenAI SDK response types serialise too noisily with stdlib `encoding/json` (the SDK relies on its internal `apijson` package which we can't import), so request/response shapes are hand-rolled for clean output. - Defensive event handling in runAgentLoop: ToolsApproved=true and NonInteractive=true mean the runtime never blocks for confirmation in normal flow, but ElicitationRequestEvent must still be answered or the runtime would hang on its dedicated channel. Tests cover session-building, agent-policy, error-envelope shape, and the three early-validation paths of /v1/chat/completions via httptest. Validated with `mise lint` (0 issues), `mise test` (all packages green), and a curl smoke test against examples/42.yaml. Fixes docker/docker-agent#2502 Assisted-By: docker-agent --- cmd/root/chat.go | 62 +++++++ cmd/root/serve.go | 3 +- e2e/binary/binary_test.go | 5 +- pkg/chatserver/agent.go | 160 ++++++++++++++++++ pkg/chatserver/handlers_test.go | 124 ++++++++++++++ pkg/chatserver/server.go | 281 ++++++++++++++++++++++++++++++++ pkg/chatserver/server_test.go | 142 ++++++++++++++++ pkg/chatserver/types.go | 102 ++++++++++++ 8 files changed, 876 insertions(+), 3 deletions(-) create mode 100644 cmd/root/chat.go create mode 100644 pkg/chatserver/agent.go create mode 100644 pkg/chatserver/handlers_test.go create mode 100644 pkg/chatserver/server.go create mode 100644 pkg/chatserver/server_test.go create mode 100644 pkg/chatserver/types.go diff --git a/cmd/root/chat.go b/cmd/root/chat.go new file mode 100644 index 000000000..fc181b312 --- /dev/null +++ b/cmd/root/chat.go @@ -0,0 +1,62 @@ +package root + +import ( + "github.com/spf13/cobra" + + "github.com/docker/docker-agent/pkg/chatserver" + "github.com/docker/docker-agent/pkg/cli" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/telemetry" +) + +type chatFlags struct { + agentName string + listenAddr string + runConfig config.RuntimeConfig +} + +func newChatCmd() *cobra.Command { + var flags chatFlags + + cmd := &cobra.Command{ + Use: "chat |", + Short: "Start an agent as an OpenAI-compatible chat completions server", + Long: `Start an HTTP server that exposes the agent through an OpenAI-compatible +API at /v1/chat/completions and /v1/models. This lets tools that already +speak OpenAI's chat protocol (such as Open WebUI) drive a docker-agent +agent without any custom integration.`, + Example: ` docker-agent serve chat ./agent.yaml + docker-agent serve chat ./team.yaml --agent reviewer + docker-agent serve chat agentcatalog/pirate --listen 127.0.0.1:9090`, + Args: cobra.ExactArgs(1), + RunE: flags.runChatCommand, + } + + cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)") + cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on") + addRuntimeConfigFlags(cmd, &flags.runConfig) + + return cmd +} + +func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandErr error) { + ctx := cmd.Context() + telemetry.TrackCommand(ctx, "serve", append([]string{"chat"}, args...)) + defer func() { // do not inline this defer so that commandErr is not resolved early + telemetry.TrackCommandError(ctx, "serve", append([]string{"chat"}, args...), commandErr) + }() + + out := cli.NewPrinter(cmd.OutOrStdout()) + agentFilename := args[0] + + ln, cleanup, err := newListener(ctx, f.listenAddr) + if err != nil { + return err + } + defer cleanup() + + out.Println("Listening on", ln.Addr().String()) + out.Println("OpenAI-compatible chat completions endpoint: http://" + ln.Addr().String() + "/v1/chat/completions") + + return chatserver.Run(ctx, agentFilename, f.agentName, &f.runConfig, ln) +} diff --git a/cmd/root/serve.go b/cmd/root/serve.go index df541cecd..13a10d836 100644 --- a/cmd/root/serve.go +++ b/cmd/root/serve.go @@ -13,8 +13,9 @@ func newServeCmd() *cobra.Command { cmd.AddCommand(newA2ACmd()) cmd.AddCommand(newACPCmd()) - cmd.AddCommand(newMCPCmd()) cmd.AddCommand(newAPICmd()) + cmd.AddCommand(newChatCmd()) + cmd.AddCommand(newMCPCmd()) return cmd } diff --git a/e2e/binary/binary_test.go b/e2e/binary/binary_test.go index 1c8c823c2..4d7130404 100644 --- a/e2e/binary/binary_test.go +++ b/e2e/binary/binary_test.go @@ -54,12 +54,13 @@ func TestAutoComplete(t *testing.T) { res, err := Exec(binDir+"/docker-agent", "__complete", "serve", "") require.NoError(t, err) props := lines(res.Stdout) - require.Greater(t, len(props), 4) + require.Greater(t, len(props), 5) require.Contains(t, props[0], "a2a") require.Contains(t, props[0], "Start an agent as an A2A") require.Contains(t, props[1], "acp") require.Contains(t, props[2], "api") - require.Contains(t, props[3], "mcp") + require.Contains(t, props[3], "chat") + require.Contains(t, props[4], "mcp") }) t.Run("cli plugin auto-complete docker agent", func(t *testing.T) { diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go new file mode 100644 index 000000000..4203a264f --- /dev/null +++ b/pkg/chatserver/agent.go @@ -0,0 +1,160 @@ +package chatserver + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" +) + +// agentPolicy decides which agent in a team is exposed by the server and +// which one to run for a given request. It is built once at startup and is +// read-only thereafter, so it's safe to share across goroutines. +type agentPolicy struct { + // exposed is the list of agent names advertised on /v1/models. + exposed []string + // fallback is used when the request's "model" field doesn't match any + // exposed agent (so we don't fail when clients hard-code "gpt-4"). + fallback string +} + +// newAgentPolicy validates the requested agent name against the team and +// returns the selection policy. If agentName is empty, every agent in the +// team is exposed and the team's default agent is used as fallback. +// Otherwise only that one agent is exposed and used. +func newAgentPolicy(t *team.Team, agentName string) (agentPolicy, error) { + if agentName != "" { + if !slices.Contains(t.AgentNames(), agentName) { + return agentPolicy{}, fmt.Errorf("agent %q not found", agentName) + } + return agentPolicy{exposed: []string{agentName}, fallback: agentName}, nil + } + a, err := t.DefaultAgent() + if err != nil { + return agentPolicy{}, fmt.Errorf("resolving default agent: %w", err) + } + return agentPolicy{exposed: t.AgentNames(), fallback: a.Name()}, nil +} + +// pick returns the agent name to use for a request. The "model" field is +// honoured when it matches an exposed agent; otherwise we silently fall +// back, mirroring how OpenAI's API behaves with unknown model strings on +// some compatible servers. +func (p agentPolicy) pick(model string) string { + if model != "" && slices.Contains(p.exposed, model) { + return model + } + return p.fallback +} + +// buildSession converts an OpenAI-style message history into a docker-agent +// session. System messages are added as system context, prior user/ +// assistant/tool turns are replayed verbatim so the agent sees the full +// conversation, and the latest user message becomes the prompt. +// +// Tool approval and non-interactive mode are forced on: this is a headless +// HTTP endpoint, there's no human in the loop to approve anything. +// +// Returns nil when the history contains no usable user message, in which +// case the caller should reject the request. +func buildSession(messages []ChatCompletionMessage) *session.Session { + sess := session.New( + session.WithToolsApproved(true), + session.WithNonInteractive(true), + ) + + hasUser := false + for _, m := range messages { + content := m.Content + if strings.TrimSpace(content) == "" { + continue + } + switch strings.ToLower(strings.TrimSpace(m.Role)) { + case "system": + sess.AddMessage(session.SystemMessage(content)) + case "assistant": + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleAssistant, + Content: content, + }}) + case "tool": + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleTool, + Content: content, + ToolCallID: m.ToolCallID, + }}) + default: + // user, developer, or any other role: feed it to the agent + // as user input rather than rejecting the request. + sess.AddMessage(session.UserMessage(content)) + hasUser = true + } + } + + if !hasUser { + return nil + } + return sess +} + +// runAgentLoop drives the runtime to completion, forwarding assistant +// content to emit (which may be nil for non-streaming mode). +// +// The session is built with ToolsApproved=true and NonInteractive=true, +// which means the runtime auto-approves tool calls and auto-stops on +// max-iterations. The handler cases below are intentionally kept as +// defence-in-depth: if those session settings ever drift, this handler +// still won't hang the request. Elicitation is the exception — the +// runtime always blocks until we respond, so its case is required for +// correctness, not just defence. +// +// The first error reported by the runtime is surfaced; later events in +// the same run are still drained so the runtime can shut down cleanly. +func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit func(string)) error { + var runErr error + for ev := range rt.RunStream(ctx, sess) { + switch e := ev.(type) { + case *runtime.AgentChoiceEvent: + if emit != nil { + emit(e.Content) + } + case *runtime.ToolCallConfirmationEvent: + // Defensive: should never fire while ToolsApproved=true. + rt.Resume(ctx, runtime.ResumeApprove()) + case *runtime.ElicitationRequestEvent: + // Required: the runtime blocks until we respond, regardless + // of NonInteractive. Decline so the tool call fails fast. + _ = rt.ResumeElicitation(ctx, tools.ElicitationActionDecline, nil) + case *runtime.MaxIterationsReachedEvent: + // Defensive: in non-interactive mode the runtime already + // stops on its own and this Resume is dropped. + rt.Resume(ctx, runtime.ResumeReject("")) + case *runtime.ErrorEvent: + if runErr == nil { + runErr = errors.New(e.Error) + } + } + } + return runErr +} + +// sessionUsage extracts approximate token usage from a completed session, +// returning nil when nothing is known so we can omit the field entirely +// rather than reporting zeroes. +func sessionUsage(sess *session.Session) *ChatCompletionUsage { + if sess.InputTokens == 0 && sess.OutputTokens == 0 { + return nil + } + return &ChatCompletionUsage{ + PromptTokens: sess.InputTokens, + CompletionTokens: sess.OutputTokens, + TotalTokens: sess.InputTokens + sess.OutputTokens, + } +} diff --git a/pkg/chatserver/handlers_test.go b/pkg/chatserver/handlers_test.go new file mode 100644 index 000000000..78f35be88 --- /dev/null +++ b/pkg/chatserver/handlers_test.go @@ -0,0 +1,124 @@ +package chatserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestServer builds a server with a fake policy for tests that don't run +// the agent loop. Handlers that touch s.team will panic — those code paths +// are exercised by integration tests, not here. +func newTestServer(exposed ...string) (*server, *echo.Echo) { + if len(exposed) == 0 { + exposed = []string{"root"} + } + srv := &server{policy: agentPolicy{exposed: exposed, fallback: exposed[0]}} + e := echo.New() + return srv, e +} + +func TestHandleModels(t *testing.T) { + srv, e := newTestServer("root", "reviewer") + + req := httptest.NewRequest(http.MethodGet, "/v1/models", http.NoBody) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleModels(c)) + require.Equal(t, http.StatusOK, rec.Code) + + var got ModelsResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, "list", got.Object) + require.Len(t, got.Data, 2) + + ids := []string{got.Data[0].ID, got.Data[1].ID} + assert.ElementsMatch(t, []string{"root", "reviewer"}, ids) + for _, m := range got.Data { + assert.Equal(t, "docker-agent", m.OwnedBy) + // openai.Model carries a typed `Object constant.Model` field that + // always serialises to "model". Ensure the wire shape is stable. + assert.Equal(t, openai.Model{}.Object.Default(), m.Object) + } +} + +func TestHandleChatCompletions_RejectsBadJSON(t *testing.T) { + srv, e := newTestServer() + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("not json")) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "invalid_request_error") +} + +func TestHandleChatCompletions_RejectsEmptyMessages(t *testing.T) { + srv, e := newTestServer() + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", + strings.NewReader(`{"messages":[]}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var got ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, "invalid_request_error", got.Error.Type) + assert.Contains(t, got.Error.Message, "at least one message") +} + +func TestHandleChatCompletions_RejectsHistoryWithoutUser(t *testing.T) { + srv, e := newTestServer() + + body := `{"messages":[{"role":"system","content":"be helpful"}]}` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "no user message") +} + +func TestWriteError_ShapeAndType(t *testing.T) { + cases := []struct { + name string + status int + message string + wantType string + }{ + {"client error", http.StatusBadRequest, "bad input", "invalid_request_error"}, + {"server error", http.StatusInternalServerError, "boom", "internal_error"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, writeError(c, tc.status, tc.message)) + assert.Equal(t, tc.status, rec.Code) + + var got ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, tc.message, got.Error.Message) + assert.Equal(t, tc.wantType, got.Error.Type) + }) + } +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go new file mode 100644 index 000000000..a159180f5 --- /dev/null +++ b/pkg/chatserver/server.go @@ -0,0 +1,281 @@ +// Package chatserver implements an OpenAI-compatible HTTP server that exposes +// docker-agent agents through the /v1/chat/completions and /v1/models +// endpoints. +// +// The goal is to let any tool that already speaks OpenAI's chat protocol +// (e.g. Open WebUI, custom shell scripts using the openai SDK) drive a +// docker-agent agent without needing to know about docker-agent's own +// protocol. +// +// On types: we deliberately don't reuse the request/response structs from +// github.com/openai/openai-go/v3. The SDK is built around its internal +// `apijson` encoder; with stdlib `encoding/json` those types serialize +// every field and produce noisy responses. `apijson` lives under +// `internal/`, so we can't borrow it. `openai.Model` is the one type that +// round-trips cleanly with stdlib json, so we reuse it for /v1/models. +package chatserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/openai/openai-go/v3" + + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/teamloader" +) + +// Run starts an OpenAI-compatible HTTP server on the given listener and +// blocks until ctx is cancelled or the server fails. The team is loaded +// once from agentFilename and shared across requests; every chat completion +// request gets a fresh session. +// +// If agentName is empty, every agent in the team is exposed and the team's +// default agent is used when the client doesn't pin one. +func Run(ctx context.Context, agentFilename, agentName string, runConfig *config.RuntimeConfig, ln net.Listener) error { + slog.Debug("Starting chat completions server", "agent", agentFilename, "addr", ln.Addr()) + + t, err := loadTeam(ctx, agentFilename, runConfig) + if err != nil { + return err + } + defer func() { + if err := t.StopToolSets(ctx); err != nil { + slog.Error("Failed to stop tool sets", "error", err) + } + }() + + policy, err := newAgentPolicy(t, agentName) + if err != nil { + return err + } + + httpServer := &http.Server{ + Handler: newRouter(&server{team: t, policy: policy}), + ReadHeaderTimeout: 30 * time.Second, + } + return serve(ctx, httpServer, ln) +} + +// loadTeam resolves and loads the team referenced by agentFilename. +func loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) { + src, err := config.Resolve(agentFilename, nil) + if err != nil { + return nil, err + } + t, err := teamloader.Load(ctx, src, runConfig) + if err != nil { + return nil, fmt.Errorf("failed to load agents: %w", err) + } + return t, nil +} + +// serve runs httpServer on ln until ctx is cancelled, then triggers a +// graceful shutdown. +func serve(ctx context.Context, httpServer *http.Server, ln net.Listener) error { + errCh := make(chan error, 1) + go func() { errCh <- httpServer.Serve(ln) }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return httpServer.Shutdown(shutdownCtx) + case err := <-errCh: + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + } +} + +// server is concurrent-safe: every request creates its own session and +// runtime, so the only shared state is the team (whose toolsets are +// independently safe to call). +type server struct { + team *team.Team + policy agentPolicy +} + +func newRouter(s *server) http.Handler { + e := echo.New() + e.HideBanner = true + e.HidePort = true + + e.Use(middleware.RequestLogger()) + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + AllowOrigins: []string{"*"}, + AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, + AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, + MaxAge: 86400, + })) + + e.GET("/v1/models", s.handleModels) + e.POST("/v1/chat/completions", s.handleChatCompletions) + return e +} + +func (s *server) handleModels(c echo.Context) error { + data := make([]openai.Model, 0, len(s.policy.exposed)) + for _, name := range s.policy.exposed { + data = append(data, openai.Model{ID: name, OwnedBy: "docker-agent"}) + } + return c.JSON(http.StatusOK, ModelsResponse{Object: "list", Data: data}) +} + +func (s *server) handleChatCompletions(c echo.Context) error { + var req ChatCompletionRequest + if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) + } + if len(req.Messages) == 0 { + return writeError(c, http.StatusBadRequest, "at least one message is required") + } + + sess := buildSession(req.Messages) + if sess == nil { + return writeError(c, http.StatusBadRequest, "no user message provided") + } + + agentName := s.policy.pick(req.Model) + rt, err := runtime.New(s.team, runtime.WithCurrentAgent(agentName)) + if err != nil { + return writeError(c, http.StatusInternalServerError, fmt.Sprintf("failed to create runtime: %v", err)) + } + + // Echo back the requested model verbatim when set, so clients matching + // on the model field stay happy. Otherwise expose the actual agent. + model := agentName + if req.Model != "" { + model = req.Model + } + + if req.Stream { + return s.streamChatCompletion(c, rt, sess, model) + } + return s.chatCompletion(c, rt, sess, model) +} + +// chatCompletion runs the agent to completion and replies with one +// non-streaming OpenAI ChatCompletion object. +func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { + if err := runAgentLoop(c.Request().Context(), rt, sess, nil); err != nil { + return writeError(c, http.StatusInternalServerError, fmt.Sprintf("agent execution failed: %v", err)) + } + + return c.JSON(http.StatusOK, ChatCompletionResponse{ + ID: newChatID(), + Object: "chat.completion", + Created: time.Now().Unix(), + Model: model, + Choices: []ChatCompletionChoice{{ + Index: 0, + Message: ChatCompletionMessage{ + Role: "assistant", + Content: sess.GetLastAssistantMessageContent(), + }, + FinishReason: "stop", + }}, + Usage: sessionUsage(sess), + }) +} + +// streamChatCompletion runs the agent and streams its response back to the +// client as Server-Sent Events in OpenAI's chat.completion.chunk format. +func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { + stream := newSSEStream(c.Response(), newChatID(), model) + + // Initial "role: assistant" delta so clients can start rendering. + stream.send(ChatCompletionStreamDelta{Role: "assistant"}, "") + + if err := runAgentLoop(c.Request().Context(), rt, sess, func(content string) { + if content != "" { + stream.send(ChatCompletionStreamDelta{Content: content}, "") + } + }); err != nil { + // Surface the error as a final content chunk so the client sees it. + stream.send(ChatCompletionStreamDelta{Content: fmt.Sprintf("\n\n[error: %v]", err)}, "") + } + + stream.send(ChatCompletionStreamDelta{}, "stop") + stream.done() + return nil +} + +// sseStream writes OpenAI-style chat.completion.chunk events to a response. +// It centralises SSE bookkeeping (headers, JSON encoding, flushing, +// terminator) so the handler can focus on what to emit. +type sseStream struct { + w http.ResponseWriter + id string + model string + created int64 +} + +func newSSEStream(w http.ResponseWriter, id, model string) *sseStream { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + return &sseStream{w: w, id: id, model: model, created: time.Now().Unix()} +} + +func (s *sseStream) send(delta ChatCompletionStreamDelta, finishReason string) { + chunk := ChatCompletionStreamResponse{ + ID: s.id, + Object: "chat.completion.chunk", + Created: s.created, + Model: s.model, + Choices: []ChatCompletionStreamChoice{{ + Index: 0, + Delta: delta, + FinishReason: finishReason, + }}, + } + data, err := json.Marshal(chunk) + if err != nil { + return + } + _, _ = fmt.Fprintf(s.w, "data: %s\n\n", data) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } +} + +// done writes the OpenAI sentinel terminator that ends the stream. +func (s *sseStream) done() { + _, _ = fmt.Fprint(s.w, "data: [DONE]\n\n") + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } +} + +// newChatID returns a fresh OpenAI-style chat completion id. +func newChatID() string { return "chatcmpl-" + uuid.NewString() } + +// writeError writes an OpenAI-style error envelope. +func writeError(c echo.Context, status int, message string) error { + return c.JSON(status, ErrorResponse{Error: ErrorDetail{ + Message: message, + Type: errTypeFor(status), + }}) +} + +func errTypeFor(status int) string { + if status >= 500 { + return "internal_error" + } + return "invalid_request_error" +} diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go new file mode 100644 index 000000000..62ef97374 --- /dev/null +++ b/pkg/chatserver/server_test.go @@ -0,0 +1,142 @@ +package chatserver + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/session" +) + +func TestBuildSession_RequiresUserMessage(t *testing.T) { + tests := []struct { + name string + messages []ChatCompletionMessage + wantNil bool + }{ + { + name: "empty list", + wantNil: true, + }, + { + name: "only system messages", + messages: []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + }, + wantNil: true, + }, + { + name: "blank user message is ignored", + messages: []ChatCompletionMessage{ + {Role: "user", Content: " "}, + }, + wantNil: true, + }, + { + name: "valid user message", + messages: []ChatCompletionMessage{ + {Role: "user", Content: "hello"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + sess := buildSession(tc.messages) + if tc.wantNil { + assert.Nil(t, sess) + return + } + require.NotNil(t, sess) + assert.True(t, sess.ToolsApproved) + assert.True(t, sess.NonInteractive) + }) + } +} + +func TestBuildSession_PreservesHistory(t *testing.T) { + sess := buildSession([]ChatCompletionMessage{ + {Role: "system", Content: "you are a docker agent"}, + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi there"}, + {Role: "user", Content: "how are you?"}, + }) + require.NotNil(t, sess) + + // GetAllMessages omits system messages. + all := sess.GetAllMessages() + require.Len(t, all, 3) + + roles := make([]chat.MessageRole, len(all)) + for i, m := range all { + roles[i] = m.Message.Role + } + assert.Equal(t, []chat.MessageRole{ + chat.MessageRoleUser, + chat.MessageRoleAssistant, + chat.MessageRoleUser, + }, roles) + + assert.Equal(t, "how are you?", sess.GetLastUserMessageContent()) + assert.Equal(t, "hi there", sess.GetLastAssistantMessageContent()) +} + +func TestBuildSession_PreservesToolMessage(t *testing.T) { + sess := buildSession([]ChatCompletionMessage{ + {Role: "user", Content: "compute 2+2"}, + {Role: "assistant", Content: ""}, // dropped: empty content + {Role: "tool", Content: "4", ToolCallID: "call_1"}, + }) + require.NotNil(t, sess) + + all := sess.GetAllMessages() + require.Len(t, all, 2) + + last := all[len(all)-1].Message + assert.Equal(t, chat.MessageRoleTool, last.Role) + assert.Equal(t, "4", last.Content) + assert.Equal(t, "call_1", last.ToolCallID) +} + +func TestBuildSession_UnknownRoleTreatedAsUser(t *testing.T) { + sess := buildSession([]ChatCompletionMessage{ + {Role: "developer", Content: "do this"}, + }) + require.NotNil(t, sess) + + all := sess.GetAllMessages() + require.Len(t, all, 1) + assert.Equal(t, chat.MessageRoleUser, all[0].Message.Role) + assert.Equal(t, "do this", all[0].Message.Content) +} + +func TestSessionUsage_OmitsZero(t *testing.T) { + sess := session.New() + assert.Nil(t, sessionUsage(sess)) + + sess.InputTokens = 5 + sess.OutputTokens = 7 + usage := sessionUsage(sess) + require.NotNil(t, usage) + assert.Equal(t, int64(5), usage.PromptTokens) + assert.Equal(t, int64(7), usage.CompletionTokens) + assert.Equal(t, int64(12), usage.TotalTokens) +} + +func TestAgentPolicy_Pick(t *testing.T) { + p := agentPolicy{exposed: []string{"root", "reviewer"}, fallback: "root"} + + assert.Equal(t, "reviewer", p.pick("reviewer")) + assert.Equal(t, "root", p.pick("root")) + assert.Equal(t, "root", p.pick(""), "empty model falls back") + assert.Equal(t, "root", p.pick("gpt-4"), "unknown model falls back") +} + +func TestErrTypeFor(t *testing.T) { + assert.Equal(t, "invalid_request_error", errTypeFor(400)) + assert.Equal(t, "invalid_request_error", errTypeFor(404)) + assert.Equal(t, "internal_error", errTypeFor(500)) + assert.Equal(t, "internal_error", errTypeFor(502)) +} diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go new file mode 100644 index 000000000..6edc3c958 --- /dev/null +++ b/pkg/chatserver/types.go @@ -0,0 +1,102 @@ +package chatserver + +import "github.com/openai/openai-go/v3" + +// This file declares the OpenAI-compatible request/response types used by +// /v1/chat/completions and /v1/models. We hand-roll most of them instead of +// borrowing from github.com/openai/openai-go/v3 because the SDK's response +// structs are deserialised through its internal `apijson` package and don't +// have `omitempty` JSON tags; marshalling them with stdlib `encoding/json` +// produces noisy responses full of empty audio/tool_call/refusal +// placeholders. `openai.Model` round-trips cleanly with stdlib json, so +// /v1/models reuses it. + +// --- Request -------------------------------------------------------------- + +// ChatCompletionRequest is the body of a /v1/chat/completions call. We +// only declare the fields we act on; any extras are silently ignored. +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatCompletionMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` +} + +// ChatCompletionMessage is a single message in the conversation. Multi-modal +// content (image parts, audio, etc.) is not supported and falls back to the +// `Content` string. +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +// --- Non-streaming response ----------------------------------------------- + +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionChoice `json:"choices"` + Usage *ChatCompletionUsage `json:"usage,omitempty"` +} + +type ChatCompletionChoice struct { + Index int `json:"index"` + Message ChatCompletionMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// ChatCompletionUsage reports approximate token counts. Best-effort: when +// the underlying provider doesn't report usage we omit the field entirely. +type ChatCompletionUsage struct { + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// --- Streaming response --------------------------------------------------- + +// ChatCompletionStreamResponse is one SSE chunk emitted when the client +// requests stream: true. +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatCompletionStreamChoice `json:"choices"` +} + +type ChatCompletionStreamChoice struct { + Index int `json:"index"` + Delta ChatCompletionStreamDelta `json:"delta"` + FinishReason string `json:"finish_reason,omitempty"` +} + +type ChatCompletionStreamDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +// --- Models endpoint ------------------------------------------------------ + +// ModelsResponse is the body returned by /v1/models. Each agent in the team +// is exposed as one entry. +type ModelsResponse struct { + Object string `json:"object"` + Data []openai.Model `json:"data"` +} + +// --- Errors --------------------------------------------------------------- + +// ErrorResponse is the OpenAI-style error envelope returned on 4xx/5xx. +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +type ErrorDetail struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code,omitempty"` +} From 3ce409ea1f11038eebd05cc4d61d1da2fa14846d Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:16:50 +0200 Subject: [PATCH 02/19] examples: add minimal chat client for `docker agent serve chat` Demonstrates the OpenAI-compatible HTTP server introduced in PR #2510. Uses the official github.com/openai/openai-go SDK pointed at the local chat server's /v1 base URL and runs an interactive REPL with streaming, history retention, and graceful Ctrl-C shutdown. Run `docker agent serve chat ./agent.yaml` in one terminal, then `go run ./examples/chat` in another. Assisted-By: docker-agent --- examples/chat/main.go | 174 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 examples/chat/main.go diff --git a/examples/chat/main.go b/examples/chat/main.go new file mode 100644 index 000000000..5efba5bcd --- /dev/null +++ b/examples/chat/main.go @@ -0,0 +1,174 @@ +// A very, very basic chat client for `docker agent serve chat`. +// +// PR #2510 (`feat: add docker agent serve chat command`) exposes any +// docker-agent agent through an OpenAI-compatible HTTP server. The whole +// point of that feature is that any tool already speaking OpenAI's +// /v1/chat/completions protocol can drive a docker-agent agent without +// custom integration. This example demonstrates exactly that: it uses the +// official github.com/openai/openai-go SDK, only repointed at the local +// chat server, to run an interactive REPL against an agent. +// +// Prerequisites: +// +// # Start an agent in chat mode (in another terminal): +// ./bin/docker-agent serve chat ./examples/42.yaml +// # It listens on http://127.0.0.1:8083 by default. +// +// Then run this client: +// +// go run ./examples/chat +// # or, to pin a specific agent in a multi-agent team: +// go run ./examples/chat -model root +// # or, to point at a different server: +// go run ./examples/chat -base http://127.0.0.1:9090/v1 +// +// Type a message and press . Type "exit" (or send EOF with ^D) to +// quit. +package main + +import ( + "bufio" + "context" + "errors" + "flag" + "fmt" + "io" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" +) + +func main() { + baseURL := flag.String("base", "http://127.0.0.1:8083/v1", "Base URL of the docker-agent chat server") + model := flag.String("model", "", "Agent name to talk to (defaults to the team's default agent)") + stream := flag.Bool("stream", true, "Stream the agent's response token-by-token") + flag.Parse() + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + err := run(ctx, *baseURL, *model, *stream) + cancel() + if err != nil && !errors.Is(err, context.Canceled) { + log.Fatal(err) + } +} + +func run(ctx context.Context, baseURL, model string, stream bool) error { + // The chat server doesn't validate API keys, but the OpenAI SDK + // requires *some* string to be passed. + client := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("not-needed"), + ) + + // Ask the server which agents are exposed and pick a default model + // when the user didn't pin one. This also doubles as a connectivity + // check. + if model == "" { + picked, err := pickDefaultModel(ctx, &client) + if err != nil { + return fmt.Errorf("listing models: %w", err) + } + model = picked + } + fmt.Printf("Connected to %s — chatting with %q. Type \"exit\" to quit.\n", baseURL, model) + + // History keeps the conversation going across turns. The chat server + // is stateless: it builds a fresh session per request from whatever + // messages the client sends, so it's the client's job to remember. + var history []openai.ChatCompletionMessageParamUnion + + in := bufio.NewScanner(os.Stdin) + in.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for { + fmt.Print("\n> ") + if !in.Scan() { + if err := in.Err(); err != nil { + return err + } + fmt.Println() + return nil // EOF + } + userInput := strings.TrimSpace(in.Text()) + if userInput == "" { + continue + } + if userInput == "exit" || userInput == "quit" { + return nil + } + + history = append(history, openai.UserMessage(userInput)) + + reply, err := chat(ctx, &client, model, history, stream) + if err != nil { + return err + } + history = append(history, openai.AssistantMessage(reply)) + } +} + +// pickDefaultModel queries /v1/models and returns the first agent name +// the server advertises. +func pickDefaultModel(ctx context.Context, client *openai.Client) (string, error) { + page, err := client.Models.List(ctx) + if err != nil { + return "", err + } + if len(page.Data) == 0 { + return "", errors.New("server exposes no models") + } + return page.Data[0].ID, nil +} + +// chat sends the conversation to the server, prints the assistant's reply +// to stdout (streaming if requested) and returns the final assembled +// content so the caller can append it to the history. +func chat( + ctx context.Context, + client *openai.Client, + model string, + history []openai.ChatCompletionMessageParamUnion, + stream bool, +) (string, error) { + params := openai.ChatCompletionNewParams{ + Model: model, + Messages: history, + } + + if !stream { + resp, err := client.Chat.Completions.New(ctx, params) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", errors.New("server returned no choices") + } + content := resp.Choices[0].Message.Content + fmt.Println(content) + return content, nil + } + + s := client.Chat.Completions.NewStreaming(ctx, params) + var b strings.Builder + for s.Next() { + chunk := s.Current() + if len(chunk.Choices) == 0 { + continue + } + delta := chunk.Choices[0].Delta.Content + if delta == "" { + continue + } + fmt.Print(delta) + b.WriteString(delta) + } + if err := s.Err(); err != nil && !errors.Is(err, io.EOF) { + return "", err + } + fmt.Println() + return b.String(), nil +} From 2e70ef8270ba2991873e1e7fb05436d14ab1f4e4 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:44:12 +0200 Subject: [PATCH 03/19] chatserver: replace `*` CORS with `--cors-origin` flag The chat server used to set `Access-Control-Allow-Origin: *` on every response, which makes it unsafe to expose on anything other than loopback. Replace the wildcard with an explicit per-server allow-list of one origin and disable the CORS middleware entirely when the flag is empty. - Introduce `chatserver.Options` so future improvements can extend the server configuration without breaking the `Run` signature on each change. - Add `--cors-origin` flag to `docker agent serve chat`. Default empty = no CORS headers emitted. - Update tests; fix three pre-existing `noctx` lint failures in handlers_test.go that surfaced when the PR was rebased onto current main. Assisted-By: docker-agent --- cmd/root/chat.go | 8 ++++++- pkg/chatserver/handlers_test.go | 10 ++++---- pkg/chatserver/server.go | 42 ++++++++++++++++++++++----------- pkg/chatserver/server_test.go | 29 +++++++++++++++++++++++ 4 files changed, 69 insertions(+), 20 deletions(-) diff --git a/cmd/root/chat.go b/cmd/root/chat.go index fc181b312..90f5f7334 100644 --- a/cmd/root/chat.go +++ b/cmd/root/chat.go @@ -12,6 +12,7 @@ import ( type chatFlags struct { agentName string listenAddr string + corsOrigin string runConfig config.RuntimeConfig } @@ -34,6 +35,7 @@ agent without any custom integration.`, cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)") cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on") + cmd.Flags().StringVar(&flags.corsOrigin, "cors-origin", "", "Allowed CORS origin (e.g. https://example.com); empty disables CORS entirely") addRuntimeConfigFlags(cmd, &flags.runConfig) return cmd @@ -58,5 +60,9 @@ func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandEr out.Println("Listening on", ln.Addr().String()) out.Println("OpenAI-compatible chat completions endpoint: http://" + ln.Addr().String() + "/v1/chat/completions") - return chatserver.Run(ctx, agentFilename, f.agentName, &f.runConfig, ln) + return chatserver.Run(ctx, agentFilename, chatserver.Options{ + AgentName: f.agentName, + RunConfig: &f.runConfig, + CORSOrigin: f.corsOrigin, + }, ln) } diff --git a/pkg/chatserver/handlers_test.go b/pkg/chatserver/handlers_test.go index 78f35be88..9ca43aa23 100644 --- a/pkg/chatserver/handlers_test.go +++ b/pkg/chatserver/handlers_test.go @@ -28,7 +28,7 @@ func newTestServer(exposed ...string) (*server, *echo.Echo) { func TestHandleModels(t *testing.T) { srv, e := newTestServer("root", "reviewer") - req := httptest.NewRequest(http.MethodGet, "/v1/models", http.NoBody) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/models", http.NoBody) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -53,7 +53,7 @@ func TestHandleModels(t *testing.T) { func TestHandleChatCompletions_RejectsBadJSON(t *testing.T) { srv, e := newTestServer() - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader("not json")) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader("not json")) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -66,7 +66,7 @@ func TestHandleChatCompletions_RejectsBadJSON(t *testing.T) { func TestHandleChatCompletions_RejectsEmptyMessages(t *testing.T) { srv, e := newTestServer() - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"messages":[]}`)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() @@ -85,7 +85,7 @@ func TestHandleChatCompletions_RejectsHistoryWithoutUser(t *testing.T) { srv, e := newTestServer() body := `{"messages":[{"role":"system","content":"be helpful"}]}` - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -108,7 +108,7 @@ func TestWriteError_ShapeAndType(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) rec := httptest.NewRecorder() c := e.NewContext(req, rec) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index a159180f5..4b8ec7246 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -37,17 +37,29 @@ import ( "github.com/docker/docker-agent/pkg/teamloader" ) +// Options configures the chat completions server. Future improvements +// (auth, conversations, etc.) extend this struct rather than the Run +// signature so callers stay stable. +type Options struct { + // AgentName pins the single agent to expose. Empty exposes every + // agent in the team and uses the team's default as the fallback. + AgentName string + // RunConfig is the runtime configuration used to load the team. + RunConfig *config.RuntimeConfig + // CORSOrigin is the allowed value for the Access-Control-Allow-Origin + // header. When empty, the CORS middleware is not registered at all + // (the server never emits any Access-Control-* response header). + CORSOrigin string +} + // Run starts an OpenAI-compatible HTTP server on the given listener and // blocks until ctx is cancelled or the server fails. The team is loaded // once from agentFilename and shared across requests; every chat completion // request gets a fresh session. -// -// If agentName is empty, every agent in the team is exposed and the team's -// default agent is used when the client doesn't pin one. -func Run(ctx context.Context, agentFilename, agentName string, runConfig *config.RuntimeConfig, ln net.Listener) error { +func Run(ctx context.Context, agentFilename string, opts Options, ln net.Listener) error { slog.Debug("Starting chat completions server", "agent", agentFilename, "addr", ln.Addr()) - t, err := loadTeam(ctx, agentFilename, runConfig) + t, err := loadTeam(ctx, agentFilename, opts.RunConfig) if err != nil { return err } @@ -57,13 +69,13 @@ func Run(ctx context.Context, agentFilename, agentName string, runConfig *config } }() - policy, err := newAgentPolicy(t, agentName) + policy, err := newAgentPolicy(t, opts.AgentName) if err != nil { return err } httpServer := &http.Server{ - Handler: newRouter(&server{team: t, policy: policy}), + Handler: newRouter(&server{team: t, policy: policy}, opts), ReadHeaderTimeout: 30 * time.Second, } return serve(ctx, httpServer, ln) @@ -109,18 +121,20 @@ type server struct { policy agentPolicy } -func newRouter(s *server) http.Handler { +func newRouter(s *server, opts Options) http.Handler { e := echo.New() e.HideBanner = true e.HidePort = true e.Use(middleware.RequestLogger()) - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, - AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, - MaxAge: 86400, - })) + if opts.CORSOrigin != "" { + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + AllowOrigins: []string{opts.CORSOrigin}, + AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, + AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, + MaxAge: 86400, + })) + } e.GET("/v1/models", s.handleModels) e.POST("/v1/chat/completions", s.handleChatCompletions) diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 62ef97374..519569f18 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -1,6 +1,8 @@ package chatserver import ( + "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -140,3 +142,30 @@ func TestErrTypeFor(t *testing.T) { assert.Equal(t, "internal_error", errTypeFor(500)) assert.Equal(t, "internal_error", errTypeFor(502)) } + +func TestNewRouter_CORSDisabledByDefault(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"), + "no CORS header should be emitted when no origin is configured") +} + +func TestNewRouter_CORSAllowsConfiguredOrigin(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{CORSOrigin: "https://example.com"}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, "https://example.com", rec.Header().Get("Access-Control-Allow-Origin")) +} From 61fb5ab1c13382e4e77c34d53525db3f035a4d55 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:45:27 +0200 Subject: [PATCH 04/19] chatserver: enforce max body size and per-request timeout Hostile or buggy clients could previously stream gigabytes into the chat completions endpoint or hold a goroutine open indefinitely on a slow upstream model. Cap both via Echo middleware: - `BodyLimit` defaults to 1 MiB (configurable via `--max-request-size`). Oversized bodies now return 413 instead of being silently buffered. - A new `requestTimeoutMiddleware` wraps `c.Request().Context()` in `context.WithTimeout` so model + tool calls + SSE streaming all share a single deadline. Default 5 minutes, configurable via `--request-timeout`. Both limits are exposed on `chatserver.Options` (`MaxRequestBytes`, `RequestTimeout`); zero values fall back to package defaults. Tests cover oversized body rejection and deadline propagation through the middleware chain. Assisted-By: docker-agent --- cmd/root/chat.go | 22 +++++++++++++------ pkg/chatserver/server.go | 37 ++++++++++++++++++++++++++++++++ pkg/chatserver/server_test.go | 40 +++++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 7 deletions(-) diff --git a/cmd/root/chat.go b/cmd/root/chat.go index 90f5f7334..f636b3727 100644 --- a/cmd/root/chat.go +++ b/cmd/root/chat.go @@ -1,6 +1,8 @@ package root import ( + "time" + "github.com/spf13/cobra" "github.com/docker/docker-agent/pkg/chatserver" @@ -10,10 +12,12 @@ import ( ) type chatFlags struct { - agentName string - listenAddr string - corsOrigin string - runConfig config.RuntimeConfig + agentName string + listenAddr string + corsOrigin string + maxRequestSize int64 + requestTimeout time.Duration + runConfig config.RuntimeConfig } func newChatCmd() *cobra.Command { @@ -36,6 +40,8 @@ agent without any custom integration.`, cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)") cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on") cmd.Flags().StringVar(&flags.corsOrigin, "cors-origin", "", "Allowed CORS origin (e.g. https://example.com); empty disables CORS entirely") + cmd.Flags().Int64Var(&flags.maxRequestSize, "max-request-size", 1<<20, "Maximum request body size in bytes (default 1 MiB)") + cmd.Flags().DurationVar(&flags.requestTimeout, "request-timeout", 5*time.Minute, "Per-request timeout (covers model + tool calls + streaming)") addRuntimeConfigFlags(cmd, &flags.runConfig) return cmd @@ -61,8 +67,10 @@ func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandEr out.Println("OpenAI-compatible chat completions endpoint: http://" + ln.Addr().String() + "/v1/chat/completions") return chatserver.Run(ctx, agentFilename, chatserver.Options{ - AgentName: f.agentName, - RunConfig: &f.runConfig, - CORSOrigin: f.corsOrigin, + AgentName: f.agentName, + RunConfig: &f.runConfig, + CORSOrigin: f.corsOrigin, + MaxRequestBytes: f.maxRequestSize, + RequestTimeout: f.requestTimeout, }, ln) } diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 4b8ec7246..f21c235a7 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -23,6 +23,7 @@ import ( "log/slog" "net" "net/http" + "strconv" "time" "github.com/google/uuid" @@ -50,8 +51,20 @@ type Options struct { // header. When empty, the CORS middleware is not registered at all // (the server never emits any Access-Control-* response header). CORSOrigin string + // MaxRequestBytes caps the size of an incoming request body. Zero + // means use the package default (1 MiB). + MaxRequestBytes int64 + // RequestTimeout caps how long a single chat completion is allowed to + // run. Zero means use the package default (5 minutes). The cap covers + // model calls, tool calls, and SSE streaming combined. + RequestTimeout time.Duration } +const ( + defaultMaxRequestBytes int64 = 1 << 20 // 1 MiB + defaultRequestTimeout time.Duration = 5 * time.Minute +) + // Run starts an OpenAI-compatible HTTP server on the given listener and // blocks until ctx is cancelled or the server fails. The team is loaded // once from agentFilename and shared across requests; every chat completion @@ -126,7 +139,18 @@ func newRouter(s *server, opts Options) http.Handler { e.HideBanner = true e.HidePort = true + maxBytes := opts.MaxRequestBytes + if maxBytes <= 0 { + maxBytes = defaultMaxRequestBytes + } + timeout := opts.RequestTimeout + if timeout <= 0 { + timeout = defaultRequestTimeout + } + e.Use(middleware.RequestLogger()) + e.Use(middleware.BodyLimit(strconv.FormatInt(maxBytes, 10))) + e.Use(requestTimeoutMiddleware(timeout)) if opts.CORSOrigin != "" { e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{opts.CORSOrigin}, @@ -141,6 +165,19 @@ func newRouter(s *server, opts Options) http.Handler { return e } +// requestTimeoutMiddleware caps each request's lifetime. Streaming +// handlers honour the timeout via c.Request().Context(). +func requestTimeoutMiddleware(d time.Duration) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ctx, cancel := context.WithTimeout(c.Request().Context(), d) + defer cancel() + c.SetRequest(c.Request().WithContext(ctx)) + return next(c) + } + } +} + func (s *server) handleModels(c echo.Context) error { data := make([]openai.Model, 0, len(s.policy.exposed)) for _, name := range s.policy.exposed { diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 519569f18..408ede2ae 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -1,10 +1,14 @@ package chatserver import ( + "context" "net/http" "net/http/httptest" + "strings" "testing" + "time" + "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -169,3 +173,39 @@ func TestNewRouter_CORSAllowsConfiguredOrigin(t *testing.T) { assert.Equal(t, "https://example.com", rec.Header().Get("Access-Control-Allow-Origin")) } + +func TestNewRouter_RejectsOversizedBody(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{MaxRequestBytes: 16}) + + body := strings.NewReader(`{"messages":[{"role":"user","content":"this body is far longer than sixteen bytes"}]}`) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", body) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) +} + +func TestRequestTimeoutMiddleware_AppliesDeadline(t *testing.T) { + e := echo.New() + e.Use(requestTimeoutMiddleware(5 * time.Millisecond)) + + var gotErr error + e.GET("/sleep", func(c echo.Context) error { + select { + case <-c.Request().Context().Done(): + gotErr = c.Request().Context().Err() + return c.String(http.StatusOK, "ok") + case <-time.After(time.Second): + return c.String(http.StatusOK, "too slow") + } + }) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/sleep", http.NoBody) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + require.Error(t, gotErr) + assert.ErrorIs(t, gotErr, context.DeadlineExceeded) +} From 0c7009718847270cf423c02d57ae6e5f984dfba1 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:46:11 +0200 Subject: [PATCH 05/19] chatserver: collect every runtime ErrorEvent (errors.Join) Previously runAgentLoop would record only the first ErrorEvent and drop every subsequent one on the floor while still draining the stream. That made debugging a multi-error run frustrating: only the earliest symptom was ever surfaced, even though later events often held the actual root cause (a model timeout followed by a tool call that couldn't connect, for instance). Switch to a slice of errors and join them with `errors.Join` at the end. The handler's behaviour for callers is unchanged when a single error occurs; multi-error runs now surface a wrapped error whose `Unwrap() []error` makes each cause inspectable. Assisted-By: docker-agent --- pkg/chatserver/agent.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go index 4203a264f..272bd1434 100644 --- a/pkg/chatserver/agent.go +++ b/pkg/chatserver/agent.go @@ -115,10 +115,11 @@ func buildSession(messages []ChatCompletionMessage) *session.Session { // runtime always blocks until we respond, so its case is required for // correctness, not just defence. // -// The first error reported by the runtime is surfaced; later events in -// the same run are still drained so the runtime can shut down cleanly. +// All ErrorEvents seen in the run are joined into the returned error so +// callers can see the full picture; the loop keeps draining until the +// stream closes so the runtime can shut down cleanly. func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit func(string)) error { - var runErr error + var runErrs []error for ev := range rt.RunStream(ctx, sess) { switch e := ev.(type) { case *runtime.AgentChoiceEvent: @@ -137,12 +138,10 @@ func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session // stops on its own and this Resume is dropped. rt.Resume(ctx, runtime.ResumeReject("")) case *runtime.ErrorEvent: - if runErr == nil { - runErr = errors.New(e.Error) - } + runErrs = append(runErrs, errors.New(e.Error)) } } - return runErr + return errors.Join(runErrs...) } // sessionUsage extracts approximate token usage from a completed session, From 2a3c95a9d24c34b79726d41827cf871e7165d1bc Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:47:15 +0200 Subject: [PATCH 06/19] chatserver: emit structured error events on streaming failures Until now a runtime error mid-stream was injected into the assistant content as `[error: ...]` and the stream still closed with `finish_reason: "stop"`. Clients matching on the OpenAI protocol had no programmatic way to tell a successful completion apart from a failed one. Switch to OpenAI's actual on-the-wire shape: emit a separate `data: {"error": {...}}` envelope, then terminate the stream with `finish_reason: "error"` before the `[DONE]` sentinel. Successful runs continue to terminate with `finish_reason: "stop"`. Add a unit test on the new `sseStream.sendError` covering the wire format. Assisted-By: docker-agent --- pkg/chatserver/server.go | 38 +++++++++++++++++++++++++++++------ pkg/chatserver/server_test.go | 17 ++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index f21c235a7..00e72e57f 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -251,16 +251,22 @@ func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess * // Initial "role: assistant" delta so clients can start rendering. stream.send(ChatCompletionStreamDelta{Role: "assistant"}, "") - if err := runAgentLoop(c.Request().Context(), rt, sess, func(content string) { + runErr := runAgentLoop(c.Request().Context(), rt, sess, func(content string) { if content != "" { stream.send(ChatCompletionStreamDelta{Content: content}, "") } - }); err != nil { - // Surface the error as a final content chunk so the client sees it. - stream.send(ChatCompletionStreamDelta{Content: fmt.Sprintf("\n\n[error: %v]", err)}, "") + }) + if runErr != nil { + // Emit a structured error envelope (OpenAI streams use a regular + // `data:` line carrying an `error` object, then close the stream + // with finish_reason "error" instead of "stop"). Clients matching + // on the OpenAI protocol can therefore distinguish a model error + // from a normal completion. + stream.sendError(runErr) + stream.send(ChatCompletionStreamDelta{}, "error") + } else { + stream.send(ChatCompletionStreamDelta{}, "stop") } - - stream.send(ChatCompletionStreamDelta{}, "stop") stream.done() return nil } @@ -313,6 +319,26 @@ func (s *sseStream) done() { } } +// sendError emits an OpenAI-style error envelope as a separate SSE event +// alongside the chunked deltas. Real OpenAI streams use this shape when a +// run fails mid-flight, e.g. a content filter trips: the message arrives +// in its own `data:` line carrying an `error` object before the stream +// terminates. +func (s *sseStream) sendError(err error) { + envelope := ErrorResponse{Error: ErrorDetail{ + Message: err.Error(), + Type: "internal_error", + }} + data, marshalErr := json.Marshal(envelope) + if marshalErr != nil { + return + } + _, _ = fmt.Fprintf(s.w, "data: %s\n\n", data) + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } +} + // newChatID returns a fresh OpenAI-style chat completion id. func newChatID() string { return "chatcmpl-" + uuid.NewString() } diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 408ede2ae..5afbdecfc 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -2,6 +2,7 @@ package chatserver import ( "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -187,6 +188,22 @@ func TestNewRouter_RejectsOversizedBody(t *testing.T) { assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) } +func TestSSEStream_SendError(t *testing.T) { + rec := httptest.NewRecorder() + s := newSSEStream(rec, "chatcmpl-x", "root") + s.sendError(errors.New("model exploded")) + s.send(ChatCompletionStreamDelta{}, "error") + s.done() + + body := rec.Body.String() + // One error envelope. + assert.Contains(t, body, `"error":{"message":"model exploded"`) + // One terminating chunk with finish_reason=error (instead of stop). + assert.Contains(t, body, `"finish_reason":"error"`) + // And the OpenAI sentinel. + assert.Contains(t, body, "data: [DONE]") +} + func TestRequestTimeoutMiddleware_AppliesDeadline(t *testing.T) { e := echo.New() e.Use(requestTimeoutMiddleware(5 * time.Millisecond)) From 8af4b8a079b187fe859daef755bc972c81f7fe07 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:50:03 +0200 Subject: [PATCH 07/19] chatserver: parse and validate OpenAI sampling parameters OpenAI clients regularly send `temperature`, `top_p`, `max_tokens`, and `stop` on every chat completion request. The server used to drop them silently because the request struct didn't declare them, so typos and out-of-range values went unnoticed until the upstream provider eventually returned an opaque error several seconds later. - Add `Temperature`, `TopP`, `MaxTokens`, `Stop` to `ChatCompletionRequest` so the OpenAPI schema matches what the wire protocol allows. - `Stop` is JSON-flexible: clients send either a single string or an array, and OpenAI accepts both. Custom `UnmarshalJSON` handles the union shape. - `validateSamplingParams` range-checks the new fields and rejects bad input with a 400 invalid_request_error, matching how OpenAI itself behaves. Plumbing these values through the runtime to the model layer requires per-request overrides that don't exist today; that work is tracked separately. Validating up front is the user-visible win and unblocks future plumbing. Assisted-By: docker-agent --- pkg/chatserver/server.go | 31 ++++++++++++++++++ pkg/chatserver/server_test.go | 61 +++++++++++++++++++++++++++++++++++ pkg/chatserver/types.go | 55 +++++++++++++++++++++++++++++-- 3 files changed, 145 insertions(+), 2 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 00e72e57f..5d42d7e50 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -21,8 +21,10 @@ import ( "errors" "fmt" "log/slog" + "math" "net" "net/http" + "slices" "strconv" "time" @@ -194,6 +196,9 @@ func (s *server) handleChatCompletions(c echo.Context) error { if len(req.Messages) == 0 { return writeError(c, http.StatusBadRequest, "at least one message is required") } + if err := validateSamplingParams(&req); err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) + } sess := buildSession(req.Messages) if sess == nil { @@ -356,3 +361,29 @@ func errTypeFor(status int) string { } return "invalid_request_error" } + +// validateSamplingParams range-checks the OpenAI sampling fields. Even when +// we don't yet plumb them all the way through to the model, validating up +// front lets clients learn about typos / out-of-range values immediately +// instead of getting an opaque provider error several seconds later. +func validateSamplingParams(req *ChatCompletionRequest) error { + if req.Temperature != nil { + t := *req.Temperature + if math.IsNaN(t) || t < 0 || t > 2 { + return fmt.Errorf("temperature must be in [0, 2], got %g", t) + } + } + if req.TopP != nil { + p := *req.TopP + if math.IsNaN(p) || p <= 0 || p > 1 { + return fmt.Errorf("top_p must be in (0, 1], got %g", p) + } + } + if req.MaxTokens != nil && *req.MaxTokens <= 0 { + return fmt.Errorf("max_tokens must be > 0, got %d", *req.MaxTokens) + } + if slices.Contains(req.Stop, "") { + return errors.New("stop sequences must not be empty strings") + } + return nil +} diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 5afbdecfc..5f00f4ded 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -188,6 +188,67 @@ func TestNewRouter_RejectsOversizedBody(t *testing.T) { assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code) } +func TestValidateSamplingParams(t *testing.T) { + f := func(v float64) *float64 { return &v } + i := func(v int64) *int64 { return &v } + + cases := []struct { + name string + req ChatCompletionRequest + wantErr string + }{ + {name: "all empty"}, + {name: "valid", req: ChatCompletionRequest{ + Temperature: f(0.7), TopP: f(0.95), MaxTokens: i(256), + Stop: StopSequences{"\n\n", "END"}, + }}, + {name: "temp negative", req: ChatCompletionRequest{Temperature: f(-0.1)}, wantErr: "temperature"}, + {name: "temp too high", req: ChatCompletionRequest{Temperature: f(2.5)}, wantErr: "temperature"}, + {name: "topp zero", req: ChatCompletionRequest{TopP: f(0)}, wantErr: "top_p"}, + {name: "topp too high", req: ChatCompletionRequest{TopP: f(1.5)}, wantErr: "top_p"}, + {name: "max_tokens zero", req: ChatCompletionRequest{MaxTokens: i(0)}, wantErr: "max_tokens"}, + {name: "empty stop", req: ChatCompletionRequest{Stop: StopSequences{""}}, wantErr: "stop"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateSamplingParams(&tc.req) + if tc.wantErr == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + }) + } +} + +func TestStopSequences_UnmarshalJSON(t *testing.T) { + cases := []struct { + name string + json string + want []string + err bool + }{ + {"null", `null`, nil, false}, + {"single string", `"END"`, []string{"END"}, false}, + {"array", `["a", "b"]`, []string{"a", "b"}, false}, + {"empty array", `[]`, []string{}, false}, + {"number invalid", `42`, nil, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got StopSequences + err := got.UnmarshalJSON([]byte(tc.json)) + if tc.err { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, []string(got)) + }) + } +} + func TestSSEStream_SendError(t *testing.T) { rec := httptest.NewRecorder() s := newSSEStream(rec, "chatcmpl-x", "root") diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go index 6edc3c958..22f7e41f2 100644 --- a/pkg/chatserver/types.go +++ b/pkg/chatserver/types.go @@ -1,6 +1,11 @@ package chatserver -import "github.com/openai/openai-go/v3" +import ( + "encoding/json" + "errors" + + "github.com/openai/openai-go/v3" +) // This file declares the OpenAI-compatible request/response types used by // /v1/chat/completions and /v1/models. We hand-roll most of them instead of @@ -14,11 +19,57 @@ import "github.com/openai/openai-go/v3" // --- Request -------------------------------------------------------------- // ChatCompletionRequest is the body of a /v1/chat/completions call. We -// only declare the fields we act on; any extras are silently ignored. +// declare every field commonly sent by OpenAI clients so they are accepted +// without surprise. Whether each field is *acted on* is documented inline. type ChatCompletionRequest struct { Model string `json:"model"` Messages []ChatCompletionMessage `json:"messages"` Stream bool `json:"stream,omitempty"` + + // Temperature is parsed and range-checked but not yet plumbed through + // to the runtime/model layer (no per-request override exists today). + // Set on the agent's YAML configuration to control sampling. + Temperature *float64 `json:"temperature,omitempty"` + // TopP is parsed and range-checked but not yet plumbed through. + TopP *float64 `json:"top_p,omitempty"` + // MaxTokens is the maximum number of tokens the model may generate in + // the response. Parsed and validated; runtime plumbing is tracked for + // a follow-up. + MaxTokens *int64 `json:"max_tokens,omitempty"` + // Stop is one or more substrings that, if produced, end generation. + // Accepted as either a single string or an array of strings, matching + // the OpenAI schema. Validated; not yet enforced. + Stop StopSequences `json:"stop,omitempty"` +} + +// StopSequences is a JSON-flexible field that accepts either a single +// string or an array of strings. OpenAI's API uses both shapes +// interchangeably; clients in the wild send both. +type StopSequences []string + +func (s *StopSequences) UnmarshalJSON(data []byte) error { + if len(data) == 0 || string(data) == "null" { + *s = nil + return nil + } + switch data[0] { + case '"': + var one string + if err := json.Unmarshal(data, &one); err != nil { + return err + } + *s = []string{one} + return nil + case '[': + var many []string + if err := json.Unmarshal(data, &many); err != nil { + return err + } + *s = many + return nil + default: + return errors.New("stop must be a string or array of strings") + } } // ChatCompletionMessage is a single message in the conversation. Multi-modal From 17a1c3b54dd180d54fc68e437d6626f21b48803a Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:51:17 +0200 Subject: [PATCH 08/19] chatserver: add Bearer-token auth (`--api-key`) The chat server is unauthenticated by default, which is fine on loopback but unsafe anywhere else. Add an opt-in static bearer-token gate so the server can be safely bound to a LAN interface. - `chatserver.Options.APIKey`: when non-empty, every request to /v1/* must carry `Authorization: Bearer ` or it is rejected with 401. Empty preserves the previous unauthenticated behaviour. - `bearerAuthMiddleware` uses `subtle.ConstantTimeCompare` to dodge timing-side-channel leaks. CORS preflight (OPTIONS) is exempted so browsers can negotiate before sending the auth header. - `--api-key` and `--api-key-env` flags expose the option from the CLI; the env-var form keeps secrets out of process listings. Tests cover missing/wrong/correct tokens and the OPTIONS exemption. Assisted-By: docker-agent --- cmd/root/chat.go | 13 +++++++++++ pkg/chatserver/server.go | 35 +++++++++++++++++++++++++++++ pkg/chatserver/server_test.go | 42 +++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+) diff --git a/cmd/root/chat.go b/cmd/root/chat.go index f636b3727..133b71d1b 100644 --- a/cmd/root/chat.go +++ b/cmd/root/chat.go @@ -1,6 +1,7 @@ package root import ( + "os" "time" "github.com/spf13/cobra" @@ -15,6 +16,8 @@ type chatFlags struct { agentName string listenAddr string corsOrigin string + apiKey string + apiKeyEnv string maxRequestSize int64 requestTimeout time.Duration runConfig config.RuntimeConfig @@ -40,6 +43,8 @@ agent without any custom integration.`, cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)") cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on") cmd.Flags().StringVar(&flags.corsOrigin, "cors-origin", "", "Allowed CORS origin (e.g. https://example.com); empty disables CORS entirely") + cmd.Flags().StringVar(&flags.apiKey, "api-key", "", "Required Bearer token clients must present (Authorization: Bearer ); empty disables auth") + cmd.Flags().StringVar(&flags.apiKeyEnv, "api-key-env", "", "Read the API key from this environment variable instead of the command line") cmd.Flags().Int64Var(&flags.maxRequestSize, "max-request-size", 1<<20, "Maximum request body size in bytes (default 1 MiB)") cmd.Flags().DurationVar(&flags.requestTimeout, "request-timeout", 5*time.Minute, "Per-request timeout (covers model + tool calls + streaming)") addRuntimeConfigFlags(cmd, &flags.runConfig) @@ -66,10 +71,18 @@ func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandEr out.Println("Listening on", ln.Addr().String()) out.Println("OpenAI-compatible chat completions endpoint: http://" + ln.Addr().String() + "/v1/chat/completions") + apiKey := f.apiKey + if f.apiKeyEnv != "" { + if v := os.Getenv(f.apiKeyEnv); v != "" { + apiKey = v + } + } + return chatserver.Run(ctx, agentFilename, chatserver.Options{ AgentName: f.agentName, RunConfig: &f.runConfig, CORSOrigin: f.corsOrigin, + APIKey: apiKey, MaxRequestBytes: f.maxRequestSize, RequestTimeout: f.requestTimeout, }, ln) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 5d42d7e50..a01ef5a59 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -17,6 +17,7 @@ package chatserver import ( "context" + "crypto/subtle" "encoding/json" "errors" "fmt" @@ -26,6 +27,7 @@ import ( "net/http" "slices" "strconv" + "strings" "time" "github.com/google/uuid" @@ -53,6 +55,13 @@ type Options struct { // header. When empty, the CORS middleware is not registered at all // (the server never emits any Access-Control-* response header). CORSOrigin string + // APIKey, if non-empty, is the static bearer token clients must + // present in the `Authorization` header (`Authorization: Bearer X`). + // Empty disables authentication; once set, every request to /v1/* is + // rejected with 401 unless it carries the matching token. + // /v1/models is also protected so an unauthenticated client can't + // fingerprint the server. + APIKey string // MaxRequestBytes caps the size of an incoming request body. Zero // means use the package default (1 MiB). MaxRequestBytes int64 @@ -153,6 +162,9 @@ func newRouter(s *server, opts Options) http.Handler { e.Use(middleware.RequestLogger()) e.Use(middleware.BodyLimit(strconv.FormatInt(maxBytes, 10))) e.Use(requestTimeoutMiddleware(timeout)) + if opts.APIKey != "" { + e.Use(bearerAuthMiddleware(opts.APIKey)) + } if opts.CORSOrigin != "" { e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowOrigins: []string{opts.CORSOrigin}, @@ -180,6 +192,29 @@ func requestTimeoutMiddleware(d time.Duration) echo.MiddlewareFunc { } } +// bearerAuthMiddleware enforces the static `Authorization: Bearer ` +// header. CORS preflight requests (OPTIONS) are exempted so that browsers +// can negotiate before sending the auth header. +// +// The expected token is captured by closure rather than read per-request, +// and the comparison uses subtle.ConstantTimeCompare so timing observation +// can't reveal valid prefixes. +func bearerAuthMiddleware(expected string) echo.MiddlewareFunc { + exp := []byte(expected) + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if c.Request().Method == http.MethodOptions { + return next(c) + } + got, ok := strings.CutPrefix(c.Request().Header.Get("Authorization"), "Bearer ") + if !ok || subtle.ConstantTimeCompare([]byte(got), exp) != 1 { + return writeError(c, http.StatusUnauthorized, "missing or invalid bearer token") + } + return next(c) + } + } +} + func (s *server) handleModels(c echo.Context) error { data := make([]openai.Model, 0, len(s.policy.exposed)) for _, name := range s.policy.exposed { diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 5f00f4ded..e84cad521 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -175,6 +175,48 @@ func TestNewRouter_CORSAllowsConfiguredOrigin(t *testing.T) { assert.Equal(t, "https://example.com", rec.Header().Get("Access-Control-Allow-Origin")) } +func TestBearerAuthMiddleware(t *testing.T) { + cases := []struct { + name string + header string + wantStatus int + }{ + {"missing", "", http.StatusUnauthorized}, + {"wrong scheme", "Basic abc", http.StatusUnauthorized}, + {"wrong token", "Bearer wrong", http.StatusUnauthorized}, + {"correct token", "Bearer secret", http.StatusOK}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{APIKey: "secret"}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/models", http.NoBody) + if tc.header != "" { + req.Header.Set("Authorization", tc.header) + } + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, tc.wantStatus, rec.Code) + }) + } +} + +func TestBearerAuthMiddleware_AllowsCORSPreflight(t *testing.T) { + // CORS preflight must succeed without an Authorization header. + srv, _ := newTestServer("root") + r := newRouter(srv, Options{APIKey: "secret", CORSOrigin: "https://example.com"}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req.Header.Set("Origin", "https://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.NotEqual(t, http.StatusUnauthorized, rec.Code) +} + func TestNewRouter_RejectsOversizedBody(t *testing.T) { srv, _ := newTestServer("root") r := newRouter(srv, Options{MaxRequestBytes: 16}) From d0c9985fb2a0dce3d7abf9b1679ccb2ecd003822 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:53:50 +0200 Subject: [PATCH 09/19] chatserver: support `X-Conversation-Id` for stateful sessions Until now the server was strictly stateless: every chat completion request rebuilt a fresh session from the messages array, so clients paid the tokenization cost of replaying the full history on every turn. That works but is wasteful for long conversations. Add an opt-in conversation cache: - `chatserver.Options.ConversationsMaxSessions` enables an in-memory LRU keyed by the `X-Conversation-Id` request header. `Options.ConversationTTL` (default 30 min) bounds idle lifetime; expired entries are evicted lazily on access and on Put. - When a request carries a known id, the server reuses the existing session and only appends the latest user message from the request body. The session already has the prior turns. When the id is unknown (or the header is absent), the server falls back to the previous behaviour and builds a session from scratch. - New `--conversations-max` and `--conversation-ttl` CLI flags expose the feature. Default 0 keeps the old stateless behaviour. The cache implementation is a simple map + mutex with O(n) LRU scan; that's appropriate for the small caches typical for this feature, and avoids pulling in a new dependency. Tests cover Put/Get, TTL expiry, LRU eviction, Delete, and the new appendLatestUser helper. Assisted-By: docker-agent --- cmd/root/chat.go | 34 ++++--- pkg/chatserver/agent.go | 22 +++++ pkg/chatserver/conversations.go | 137 +++++++++++++++++++++++++++ pkg/chatserver/conversations_test.go | 90 ++++++++++++++++++ pkg/chatserver/server.go | 68 +++++++++++-- 5 files changed, 330 insertions(+), 21 deletions(-) create mode 100644 pkg/chatserver/conversations.go create mode 100644 pkg/chatserver/conversations_test.go diff --git a/cmd/root/chat.go b/cmd/root/chat.go index 133b71d1b..7cf184d28 100644 --- a/cmd/root/chat.go +++ b/cmd/root/chat.go @@ -13,14 +13,16 @@ import ( ) type chatFlags struct { - agentName string - listenAddr string - corsOrigin string - apiKey string - apiKeyEnv string - maxRequestSize int64 - requestTimeout time.Duration - runConfig config.RuntimeConfig + agentName string + listenAddr string + corsOrigin string + apiKey string + apiKeyEnv string + maxRequestSize int64 + requestTimeout time.Duration + conversationsMaxItems int + conversationTTL time.Duration + runConfig config.RuntimeConfig } func newChatCmd() *cobra.Command { @@ -47,6 +49,8 @@ agent without any custom integration.`, cmd.Flags().StringVar(&flags.apiKeyEnv, "api-key-env", "", "Read the API key from this environment variable instead of the command line") cmd.Flags().Int64Var(&flags.maxRequestSize, "max-request-size", 1<<20, "Maximum request body size in bytes (default 1 MiB)") cmd.Flags().DurationVar(&flags.requestTimeout, "request-timeout", 5*time.Minute, "Per-request timeout (covers model + tool calls + streaming)") + cmd.Flags().IntVar(&flags.conversationsMaxItems, "conversations-max", 0, "Cache up to N conversations server-side, keyed by X-Conversation-Id (0 disables; clients must resend full history)") + cmd.Flags().DurationVar(&flags.conversationTTL, "conversation-ttl", 30*time.Minute, "Idle TTL after which a cached conversation is evicted") addRuntimeConfigFlags(cmd, &flags.runConfig) return cmd @@ -79,11 +83,13 @@ func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandEr } return chatserver.Run(ctx, agentFilename, chatserver.Options{ - AgentName: f.agentName, - RunConfig: &f.runConfig, - CORSOrigin: f.corsOrigin, - APIKey: apiKey, - MaxRequestBytes: f.maxRequestSize, - RequestTimeout: f.requestTimeout, + AgentName: f.agentName, + RunConfig: &f.runConfig, + CORSOrigin: f.corsOrigin, + APIKey: apiKey, + MaxRequestBytes: f.maxRequestSize, + RequestTimeout: f.requestTimeout, + ConversationsMaxSessions: f.conversationsMaxItems, + ConversationTTL: f.conversationTTL, }, ln) } diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go index 272bd1434..ce5525d1e 100644 --- a/pkg/chatserver/agent.go +++ b/pkg/chatserver/agent.go @@ -104,6 +104,28 @@ func buildSession(messages []ChatCompletionMessage) *session.Session { return sess } +// appendLatestUser walks msgs from the end and appends only the last +// user-role message into sess. Used by conversation continuation, where +// the session already contains the full prior history and we just need +// to inject what the client just said. +func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { + for i := len(msgs) - 1; i >= 0; i-- { + m := msgs[i] + content := strings.TrimSpace(m.Content) + if content == "" { + continue + } + role := strings.ToLower(strings.TrimSpace(m.Role)) + // Treat any non-system/assistant/tool role as user (matches + // buildSession's policy). + if role == "system" || role == "assistant" || role == "tool" { + continue + } + sess.AddMessage(session.UserMessage(m.Content)) + return + } +} + // runAgentLoop drives the runtime to completion, forwarding assistant // content to emit (which may be nil for non-streaming mode). // diff --git a/pkg/chatserver/conversations.go b/pkg/chatserver/conversations.go new file mode 100644 index 000000000..68d28c052 --- /dev/null +++ b/pkg/chatserver/conversations.go @@ -0,0 +1,137 @@ +package chatserver + +import ( + "sync" + "time" + + "github.com/docker/docker-agent/pkg/session" +) + +// conversationStore keeps long-lived sessions keyed by the +// `X-Conversation-Id` header so clients don't have to resend the full +// conversation history on every turn. +// +// It's an LRU with a TTL: entries past `ttl` since their last use are +// considered expired and lazily evicted on Get. When the store would +// grow past `maxEntries`, the least-recently-used entry is evicted on +// Put. Both eviction paths are O(n) since this cache is small (typical +// `maxEntries` ≤ a few hundred); a doubly-linked-list LRU would be +// strictly faster but the extra code is rarely worth it. +// +// All operations are safe for concurrent use. +type conversationStore struct { + mu sync.Mutex + items map[string]*conversationEntry + maxEntries int + ttl time.Duration + now func() time.Time // injectable for tests +} + +type conversationEntry struct { + sess *session.Session + lastUsed time.Time +} + +// newConversationStore returns a store that holds at most maxEntries +// sessions and forgets entries that have been idle for more than ttl. +// Either bound can be zero/negative to disable that bound. A store with +// both bounds disabled is functionally a regular map; a store with +// maxEntries == 0 is disabled and Get always misses. +func newConversationStore(maxEntries int, ttl time.Duration) *conversationStore { + return &conversationStore{ + items: make(map[string]*conversationEntry), + maxEntries: maxEntries, + ttl: ttl, + now: time.Now, + } +} + +// Get returns the stored session for id and refreshes its last-used +// timestamp. Misses return nil. The store is disabled when maxEntries +// <= 0, in which case Get always misses. +func (c *conversationStore) Get(id string) *session.Session { + if c == nil || c.maxEntries <= 0 || id == "" { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.items[id] + if !ok { + return nil + } + if c.expired(e) { + delete(c.items, id) + return nil + } + e.lastUsed = c.now() + return e.sess +} + +// Put stores sess under id and evicts the least-recently-used entry if +// the store is over capacity. Has no effect when the store is disabled. +func (c *conversationStore) Put(id string, sess *session.Session) { + if c == nil || c.maxEntries <= 0 || id == "" || sess == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + now := c.now() + c.items[id] = &conversationEntry{sess: sess, lastUsed: now} + + // Drop expired neighbours in the same critical section so callers + // don't accumulate dead weight on long-running stores. + for k, v := range c.items { + if c.expired(v) { + delete(c.items, k) + } + } + for len(c.items) > c.maxEntries { + c.evictOldestLocked() + } +} + +// Delete removes id from the store, if present. Useful for clients that +// want to explicitly close out a conversation. +func (c *conversationStore) Delete(id string) { + if c == nil || id == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + delete(c.items, id) +} + +// Len returns the current number of cached conversations. Mostly useful +// for tests and metrics. +func (c *conversationStore) Len() int { + if c == nil { + return 0 + } + c.mu.Lock() + defer c.mu.Unlock() + return len(c.items) +} + +func (c *conversationStore) expired(e *conversationEntry) bool { + if c.ttl <= 0 { + return false + } + return c.now().Sub(e.lastUsed) > c.ttl +} + +// evictOldestLocked removes the oldest entry. Caller holds c.mu. +func (c *conversationStore) evictOldestLocked() { + var oldestKey string + var oldestTime time.Time + first := true + for k, v := range c.items { + if first || v.lastUsed.Before(oldestTime) { + oldestKey = k + oldestTime = v.lastUsed + first = false + } + } + if !first { + delete(c.items, oldestKey) + } +} diff --git a/pkg/chatserver/conversations_test.go b/pkg/chatserver/conversations_test.go new file mode 100644 index 000000000..038cc0758 --- /dev/null +++ b/pkg/chatserver/conversations_test.go @@ -0,0 +1,90 @@ +package chatserver + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/session" +) + +func TestConversationStore_Disabled(t *testing.T) { + c := newConversationStore(0, time.Hour) + c.Put("a", session.New()) + assert.Nil(t, c.Get("a")) + assert.Equal(t, 0, c.Len()) +} + +func TestConversationStore_PutGet(t *testing.T) { + c := newConversationStore(8, time.Hour) + s := session.New() + c.Put("a", s) + + got := c.Get("a") + require.NotNil(t, got) + assert.Same(t, s, got) +} + +func TestConversationStore_TTL(t *testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(8, time.Minute) + c.now = func() time.Time { return now } + + c.Put("a", session.New()) + assert.NotNil(t, c.Get("a")) + + now = now.Add(2 * time.Minute) + assert.Nil(t, c.Get("a"), "entry should be expired") + assert.Equal(t, 0, c.Len(), "expired entry should be evicted on Get miss") +} + +func TestConversationStore_LRUEviction(t *testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(2, time.Hour) + c.now = func() time.Time { return now } + + c.Put("a", session.New()) + now = now.Add(time.Second) + c.Put("b", session.New()) + now = now.Add(time.Second) + // Touch "a" so it becomes the most-recently-used. + require.NotNil(t, c.Get("a")) + now = now.Add(time.Second) + c.Put("c", session.New()) + + // "b" was the LRU when capacity was exceeded, so it should be the + // one that got evicted. + assert.Nil(t, c.Get("b")) + assert.NotNil(t, c.Get("a")) + assert.NotNil(t, c.Get("c")) +} + +func TestConversationStore_Delete(t *testing.T) { + c := newConversationStore(8, time.Hour) + c.Put("a", session.New()) + c.Delete("a") + assert.Nil(t, c.Get("a")) +} + +func TestAppendLatestUser(t *testing.T) { + sess := session.New() + appendLatestUser(sess, []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + {Role: "user", Content: "first"}, + {Role: "assistant", Content: "ack"}, + {Role: "user", Content: "second"}, + {Role: "tool", Content: "tool result", ToolCallID: "x"}, + }) + assert.Equal(t, "second", sess.GetLastUserMessageContent()) +} + +func TestAppendLatestUser_NoUserMessage(t *testing.T) { + sess := session.New() + appendLatestUser(sess, []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + {Role: "assistant", Content: "ack"}, + }) + assert.Empty(t, sess.GetLastUserMessageContent()) +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index a01ef5a59..21af419bf 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -69,11 +69,21 @@ type Options struct { // run. Zero means use the package default (5 minutes). The cap covers // model calls, tool calls, and SSE streaming combined. RequestTimeout time.Duration + // ConversationsMaxSessions, when > 0, enables the X-Conversation-Id + // header: clients can pass a stable id to reuse the same session + // across requests instead of re-sending the full message history + // every turn. This is the size of the in-memory LRU cache. + ConversationsMaxSessions int + // ConversationTTL is how long a cached conversation may be idle + // before it's evicted. Zero means use the package default + // (30 minutes). + ConversationTTL time.Duration } const ( defaultMaxRequestBytes int64 = 1 << 20 // 1 MiB defaultRequestTimeout time.Duration = 5 * time.Minute + defaultConversationTTL time.Duration = 30 * time.Minute ) // Run starts an OpenAI-compatible HTTP server on the given listener and @@ -99,12 +109,19 @@ func Run(ctx context.Context, agentFilename string, opts Options, ln net.Listene } httpServer := &http.Server{ - Handler: newRouter(&server{team: t, policy: policy}, opts), + Handler: newRouter(&server{team: t, policy: policy, conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts))}, opts), ReadHeaderTimeout: 30 * time.Second, } return serve(ctx, httpServer, ln) } +func conversationTTL(opts Options) time.Duration { + if opts.ConversationTTL > 0 { + return opts.ConversationTTL + } + return defaultConversationTTL +} + // loadTeam resolves and loads the team referenced by agentFilename. func loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) { src, err := config.Resolve(agentFilename, nil) @@ -139,10 +156,11 @@ func serve(ctx context.Context, httpServer *http.Server, ln net.Listener) error // server is concurrent-safe: every request creates its own session and // runtime, so the only shared state is the team (whose toolsets are -// independently safe to call). +// independently safe to call) and the optional conversation cache. type server struct { - team *team.Team - policy agentPolicy + team *team.Team + policy agentPolicy + conversations *conversationStore } func newRouter(s *server, opts Options) http.Handler { @@ -235,7 +253,8 @@ func (s *server) handleChatCompletions(c echo.Context) error { return writeError(c, http.StatusBadRequest, err.Error()) } - sess := buildSession(req.Messages) + conversationID := c.Request().Header.Get("X-Conversation-Id") + sess, isNew := s.resolveSession(conversationID, req.Messages) if sess == nil { return writeError(c, http.StatusBadRequest, "no user message provided") } @@ -254,9 +273,44 @@ func (s *server) handleChatCompletions(c echo.Context) error { } if req.Stream { - return s.streamChatCompletion(c, rt, sess, model) + err := s.streamChatCompletion(c, rt, sess, model) + s.maybeStoreConversation(conversationID, sess, isNew) + return err + } + err = s.chatCompletion(c, rt, sess, model) + s.maybeStoreConversation(conversationID, sess, isNew) + return err +} + +// resolveSession decides whether to start fresh or continue an existing +// conversation. When X-Conversation-Id is set and we have an existing +// session for it, we append only the latest user message from the +// request (the prior history is already in the session). Otherwise we +// build a brand-new session from the full request history. +// +// Returns the session to run, plus a flag indicating whether it was +// freshly created (so callers know whether they need to insert it into +// the cache). +func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) (*session.Session, bool) { + if id != "" { + if existing := s.conversations.Get(id); existing != nil { + appendLatestUser(existing, msgs) + return existing, false + } + } + return buildSession(msgs), true +} + +// maybeStoreConversation inserts the session into the cache after a +// run. We only need to insert when the conversation is new — existing +// entries are mutated in place. +func (s *server) maybeStoreConversation(id string, sess *session.Session, isNew bool) { + if id == "" || s.conversations == nil { + return + } + if isNew { + s.conversations.Put(id, sess) } - return s.chatCompletion(c, rt, sess, model) } // chatCompletion runs the agent to completion and replies with one From 6e09c3a31a753d5a8092a6780733325384f7a641 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:56:35 +0200 Subject: [PATCH 10/19] chatserver: pool runtimes per agent for warm reuse Every chat completion request used to call `runtime.New` from scratch: that resolves the agent's tools, builds per-agent hook executors, and allocates per-runtime resume/elicitation channels. On a busy server those allocations show up in profiles. Add an opt-in pool so a small number of warm runtimes per agent can be reused across requests: - `chatserver.Options.MaxIdleRuntimes` (default 4 via `--max-idle- runtimes`) bounds the idle pool size per agent. 0 disables pooling entirely and restores the original "fresh runtime per request" behaviour. - `runtimePool.Get` returns a recycled runtime when one is idle, or creates a new one. `Put` returns it to the pool on completion; overflow is dropped on the floor (the team owns the toolsets, so nothing leaks). - A runtime is *not* safe for concurrent `RunStream` calls (its resume/elicitation channels are per-runtime), so the pool hands out at most one borrow per runtime at a time. Concurrency comes from holding multiple runtimes per agent. Assisted-By: docker-agent --- cmd/root/chat.go | 3 + pkg/chatserver/runtime_pool.go | 107 ++++++++++++++++++++++++++++ pkg/chatserver/runtime_pool_test.go | 26 +++++++ pkg/chatserver/server.go | 25 +++++-- 4 files changed, 157 insertions(+), 4 deletions(-) create mode 100644 pkg/chatserver/runtime_pool.go create mode 100644 pkg/chatserver/runtime_pool_test.go diff --git a/cmd/root/chat.go b/cmd/root/chat.go index 7cf184d28..ec1b0c23c 100644 --- a/cmd/root/chat.go +++ b/cmd/root/chat.go @@ -22,6 +22,7 @@ type chatFlags struct { requestTimeout time.Duration conversationsMaxItems int conversationTTL time.Duration + maxIdleRuntimes int runConfig config.RuntimeConfig } @@ -51,6 +52,7 @@ agent without any custom integration.`, cmd.Flags().DurationVar(&flags.requestTimeout, "request-timeout", 5*time.Minute, "Per-request timeout (covers model + tool calls + streaming)") cmd.Flags().IntVar(&flags.conversationsMaxItems, "conversations-max", 0, "Cache up to N conversations server-side, keyed by X-Conversation-Id (0 disables; clients must resend full history)") cmd.Flags().DurationVar(&flags.conversationTTL, "conversation-ttl", 30*time.Minute, "Idle TTL after which a cached conversation is evicted") + cmd.Flags().IntVar(&flags.maxIdleRuntimes, "max-idle-runtimes", 4, "Maximum number of idle runtimes pooled per agent (0 disables pooling)") addRuntimeConfigFlags(cmd, &flags.runConfig) return cmd @@ -91,5 +93,6 @@ func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandEr RequestTimeout: f.requestTimeout, ConversationsMaxSessions: f.conversationsMaxItems, ConversationTTL: f.conversationTTL, + MaxIdleRuntimes: f.maxIdleRuntimes, }, ln) } diff --git a/pkg/chatserver/runtime_pool.go b/pkg/chatserver/runtime_pool.go new file mode 100644 index 000000000..d79f03448 --- /dev/null +++ b/pkg/chatserver/runtime_pool.go @@ -0,0 +1,107 @@ +package chatserver + +import ( + "errors" + "sync" + + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/team" +) + +// runtimePool keeps a small set of `runtime.Runtime` instances ready for +// reuse, keyed by agent name. Building a runtime is non-trivial (it +// resolves the agent's tools, creates per-agent hook executors, sets up +// channels for resume/elicitation), so reusing the work across requests +// is a real latency win for hot paths. +// +// Concurrency model: a single runtime is *not* safe for concurrent +// RunStream calls (its resume/elicitation channels are per-runtime +// state). The pool therefore hands out a runtime to one caller at a +// time. Callers Get → use → Put back. When the pool is empty a fresh +// runtime is built. +// +// `maxIdle` bounds the number of idle runtimes per agent. Returning a +// runtime to a full pool is a no-op; it simply gets garbage collected. +type runtimePool struct { + team *team.Team + maxIdle int + + mu sync.Mutex + idle map[string]chan runtime.Runtime +} + +// errInvalidRuntime is returned when a caller asks for a runtime for an +// agent the pool can't create one for. Today this can only happen if +// runtime.New fails for a reason unrelated to the team (e.g. context +// cancellation in a future async path). +var errInvalidRuntime = errors.New("failed to acquire runtime") + +func newRuntimePool(t *team.Team, maxIdle int) *runtimePool { + if maxIdle < 0 { + maxIdle = 0 + } + return &runtimePool{ + team: t, + maxIdle: maxIdle, + idle: make(map[string]chan runtime.Runtime), + } +} + +// Get returns a ready-to-use runtime for the given agent, either +// recycled from the pool or freshly created. +func (p *runtimePool) Get(agent string) (runtime.Runtime, error) { + if p == nil { + return nil, errInvalidRuntime + } + if rt := p.takeIdle(agent); rt != nil { + return rt, nil + } + rt, err := runtime.New(p.team, runtime.WithCurrentAgent(agent)) + if err != nil { + return nil, err + } + return rt, nil +} + +// Put hands a finished runtime back to the pool. If the agent's idle +// slot is full the runtime is discarded (not closed: the team owns the +// underlying toolsets). The runtime must not be used by the caller +// after Put returns. +func (p *runtimePool) Put(agent string, rt runtime.Runtime) { + if p == nil || rt == nil || p.maxIdle == 0 { + return + } + ch := p.channelFor(agent) + select { + case ch <- rt: + default: + // pool full: drop on the floor. The team owns the toolsets, + // so nothing leaks; the runtime itself is ordinary garbage. + } +} + +func (p *runtimePool) takeIdle(agent string) runtime.Runtime { + p.mu.Lock() + ch, ok := p.idle[agent] + p.mu.Unlock() + if !ok { + return nil + } + select { + case rt := <-ch: + return rt + default: + return nil + } +} + +func (p *runtimePool) channelFor(agent string) chan runtime.Runtime { + p.mu.Lock() + defer p.mu.Unlock() + ch, ok := p.idle[agent] + if !ok { + ch = make(chan runtime.Runtime, p.maxIdle) + p.idle[agent] = ch + } + return ch +} diff --git a/pkg/chatserver/runtime_pool_test.go b/pkg/chatserver/runtime_pool_test.go new file mode 100644 index 000000000..b69774289 --- /dev/null +++ b/pkg/chatserver/runtime_pool_test.go @@ -0,0 +1,26 @@ +package chatserver + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRuntimePool_DisabledIsNotCached(t *testing.T) { + p := newRuntimePool(nil, 0) + + // Put with maxIdle=0 must be a no-op (we don't have a runtime to put, + // but the channel-for behaviour itself shouldn't allocate). + p.Put("root", nil) + assert.Empty(t, p.idle, "no per-agent channels should be allocated when pooling is disabled") +} + +func TestRuntimePool_NegativeCapTreatedAsZero(t *testing.T) { + p := newRuntimePool(nil, -1) + assert.Equal(t, 0, p.maxIdle) +} + +func TestRuntimePool_takeIdleNoChannel(t *testing.T) { + p := newRuntimePool(nil, 4) + assert.Nil(t, p.takeIdle("anything")) +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 21af419bf..8976c0539 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -78,6 +78,12 @@ type Options struct { // before it's evicted. Zero means use the package default // (30 minutes). ConversationTTL time.Duration + // MaxIdleRuntimes bounds the number of idle runtimes pooled per + // agent. Building a runtime resolves tools and sets up channels; + // keeping a small pool of warm runtimes avoids paying that cost on + // every request. Zero disables pooling (a fresh runtime is built + // for every request, the original behaviour). + MaxIdleRuntimes int } const ( @@ -109,7 +115,12 @@ func Run(ctx context.Context, agentFilename string, opts Options, ln net.Listene } httpServer := &http.Server{ - Handler: newRouter(&server{team: t, policy: policy, conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts))}, opts), + Handler: newRouter(&server{ + team: t, + policy: policy, + conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts)), + runtimes: newRuntimePool(t, opts.MaxIdleRuntimes), + }, opts), ReadHeaderTimeout: 30 * time.Second, } return serve(ctx, httpServer, ln) @@ -161,6 +172,7 @@ type server struct { team *team.Team policy agentPolicy conversations *conversationStore + runtimes *runtimePool } func newRouter(s *server, opts Options) http.Handler { @@ -260,10 +272,11 @@ func (s *server) handleChatCompletions(c echo.Context) error { } agentName := s.policy.pick(req.Model) - rt, err := runtime.New(s.team, runtime.WithCurrentAgent(agentName)) + rt, err := s.runtimes.Get(agentName) if err != nil { - return writeError(c, http.StatusInternalServerError, fmt.Sprintf("failed to create runtime: %v", err)) + return writeError(c, http.StatusInternalServerError, fmt.Sprintf("failed to acquire runtime: %v", err)) } + defer s.runtimes.Put(agentName, rt) // Echo back the requested model verbatim when set, so clients matching // on the model field stay happy. Otherwise expose the actual agent. @@ -339,7 +352,11 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio // streamChatCompletion runs the agent and streams its response back to the // client as Server-Sent Events in OpenAI's chat.completion.chunk format. -func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { +// +// The error return is reserved for future use (e.g. surfacing a write +// failure to the request logger). Today every error is converted into an +// in-band SSE error event, so the function always returns nil. +func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { //nolint:unparam // see comment stream := newSSEStream(c.Response(), newChatID(), model) // Initial "role: assistant" delta so clients can start rendering. From 31302b84dd968cb6d09da2e5041b44077bf4b958 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 13:58:21 +0200 Subject: [PATCH 11/19] chatserver: support comma-list and regex in `--cors-origin` The previous commit only accepted a single literal origin. Real deployments often need to allow several front-ends or all subdomains of a known SaaS. Extend the flag's grammar: - comma-separated entries form an explicit allow-list, each matched exactly; - entries prefixed with `~` are compiled as Go regex and matched against the request's `Origin` header at request time; - the literal `*` wildcard is preserved for the (rare) cases where the operator really wants it; - literal entries are validated up front: scheme must be http/https, no path/query/fragment, no missing host. Mistakes are caught at startup rather than producing silent allow-none behaviour at runtime. When the spec parses cleanly to nothing usable, the middleware is left unregistered and a slog.Error documents the misconfiguration. Tests cover the parser's accept/reject set and exercise allow-list + regex routing through the real Echo middleware. Assisted-By: docker-agent --- pkg/chatserver/server.go | 103 ++++++++++++++++++++++++++++++++-- pkg/chatserver/server_test.go | 76 +++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 6 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 8976c0539..adc5bda67 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -25,6 +25,8 @@ import ( "math" "net" "net/http" + "net/url" + "regexp" "slices" "strconv" "strings" @@ -54,6 +56,15 @@ type Options struct { // CORSOrigin is the allowed value for the Access-Control-Allow-Origin // header. When empty, the CORS middleware is not registered at all // (the server never emits any Access-Control-* response header). + // + // Multiple values can be provided separated by commas. Each entry is + // either a literal origin (matched exactly), the wildcard "*", or a + // pattern starting with "~" interpreted as a Go regular expression + // against the request's Origin header. Examples: + // + // "https://app.example.com" + // "https://app.example.com,https://staging.example.com" + // "~^https://[a-z0-9-]+\\.example\\.com$" CORSOrigin string // APIKey, if non-empty, is the static bearer token clients must // present in the `Authorization` header (`Authorization: Bearer X`). @@ -196,12 +207,14 @@ func newRouter(s *server, opts Options) http.Handler { e.Use(bearerAuthMiddleware(opts.APIKey)) } if opts.CORSOrigin != "" { - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: []string{opts.CORSOrigin}, - AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, - AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, - MaxAge: 86400, - })) + cfg, err := corsMiddlewareConfig(opts.CORSOrigin) + if err != nil { + // Bad config is reported via the request log. The middleware + // is simply not registered, which is the safest default. + slog.Error("Invalid --cors-origin, CORS disabled", "error", err) + } else { + e.Use(middleware.CORSWithConfig(cfg)) + } } e.GET("/v1/models", s.handleModels) @@ -222,6 +235,84 @@ func requestTimeoutMiddleware(d time.Duration) echo.MiddlewareFunc { } } +// corsMiddlewareConfig parses a comma-separated --cors-origin value into an +// echo middleware.CORSConfig. Each entry is one of: +// +// - the literal "*" wildcard; +// - a regex when prefixed with "~" (compiled and matched against the +// request's Origin header); +// - a literal origin matched verbatim. +// +// Returns an error when no entry parses successfully, in which case the +// caller leaves the middleware unregistered. +func corsMiddlewareConfig(spec string) (middleware.CORSConfig, error) { + var literals []string + var patterns []*regexp.Regexp + for raw := range strings.SplitSeq(spec, ",") { + entry := strings.TrimSpace(raw) + if entry == "" { + continue + } + if rest, ok := strings.CutPrefix(entry, "~"); ok { + re, err := regexp.Compile(rest) + if err != nil { + return middleware.CORSConfig{}, fmt.Errorf("invalid CORS regex %q: %w", rest, err) + } + patterns = append(patterns, re) + continue + } + if err := validateCORSOrigin(entry); err != nil { + return middleware.CORSConfig{}, err + } + literals = append(literals, entry) + } + if len(literals) == 0 && len(patterns) == 0 { + return middleware.CORSConfig{}, errors.New("no usable CORS origins") + } + + cfg := middleware.CORSConfig{ + AllowOrigins: literals, + AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, + AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, + MaxAge: 86400, + } + if len(patterns) > 0 { + cfg.AllowOriginFunc = func(origin string) (bool, error) { + for _, re := range patterns { + if re.MatchString(origin) { + return true, nil + } + } + return false, nil + } + } + return cfg, nil +} + +// validateCORSOrigin sanity-checks a literal origin entry. The aim is to +// reject obvious typos early ("http//foo.com", "https://foo.com/bar") +// rather than to be a full URL parser — the echo middleware will still +// do its own matching at request time. +func validateCORSOrigin(o string) error { + if o == "*" { + return nil + } + u, err := url.Parse(o) + if err != nil { + return fmt.Errorf("invalid CORS origin %q: %w", o, err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid CORS origin %q: scheme must be http or https", o) + } + if u.Host == "" { + return fmt.Errorf("invalid CORS origin %q: missing host", o) + } + if u.Path != "" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("invalid CORS origin %q: must not include path, query, or fragment", o) + } + return nil +} + // bearerAuthMiddleware enforces the static `Authorization: Bearer ` // header. CORS preflight requests (OPTIONS) are exempted so that browsers // can negotiate before sending the auth header. diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index e84cad521..acc54b069 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -175,6 +175,82 @@ func TestNewRouter_CORSAllowsConfiguredOrigin(t *testing.T) { assert.Equal(t, "https://example.com", rec.Header().Get("Access-Control-Allow-Origin")) } +func TestCorsMiddlewareConfig(t *testing.T) { + cases := []struct { + name string + spec string + wantErr bool + }{ + {name: "single literal", spec: "https://app.example.com"}, + {name: "comma list", spec: "https://a.example.com, https://b.example.com"}, + {name: "regex", spec: `~^https://[a-z]+\.example\.com$`}, + {name: "wildcard", spec: "*"}, + {name: "mixed", spec: `https://a.example.com,~^https://b\.example\.com$`}, + {name: "empty entries collapse", spec: ", , https://x.com,,"}, + + {name: "missing scheme", spec: "app.example.com", wantErr: true}, + {name: "with path", spec: "https://example.com/api", wantErr: true}, + {name: "with query", spec: "https://example.com?x=1", wantErr: true}, + {name: "bad regex", spec: "~[", wantErr: true}, + {name: "all blanks", spec: ", , ,", wantErr: true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := corsMiddlewareConfig(tc.spec) + if tc.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + }) + } +} + +func TestNewRouter_CORSAllowList(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{CORSOrigin: "https://a.example.com,https://b.example.com"}) + + cases := []struct { + origin string + want string // expected Access-Control-Allow-Origin + }{ + {"https://a.example.com", "https://a.example.com"}, + {"https://b.example.com", "https://b.example.com"}, + {"https://evil.example.com", ""}, + } + for _, tc := range cases { + t.Run(tc.origin, func(t *testing.T) { + req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req.Header.Set("Origin", tc.origin) + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + assert.Equal(t, tc.want, rec.Header().Get("Access-Control-Allow-Origin")) + }) + } +} + +func TestNewRouter_CORSRegex(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{CORSOrigin: `~^https://[a-z]+\.example\.com$`}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req.Header.Set("Origin", "https://staging.example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, "https://staging.example.com", rec.Header().Get("Access-Control-Allow-Origin")) + + // A non-matching origin must not get the header. + req2 := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody) + req2.Header.Set("Origin", "https://evil.attacker.com") + req2.Header.Set("Access-Control-Request-Method", "GET") + rec2 := httptest.NewRecorder() + r.ServeHTTP(rec2, req2) + assert.Empty(t, rec2.Header().Get("Access-Control-Allow-Origin")) +} + func TestBearerAuthMiddleware(t *testing.T) { cases := []struct { name string From ebb34ea364b415c010c2a7641d7e9d3c4d49a341 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:01:10 +0200 Subject: [PATCH 12/19] chatserver: surface agent tool calls as OpenAI tool_calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the agent invokes a tool, clients had no way to see what happened: tools ran inside the runtime, the assistant's eventual text output sometimes referenced them but often didn't, and the streaming protocol carried only the model's plain content. That's fine for a black-box transcript but useless for a chat UI that wants to render "🔧 calling search(query=…)" badges. Use OpenAI's standard `tool_calls` shape on both response styles: - Add `ToolCallReference` (mirrors OpenAI's tool_call entry) with `index`, `id`, `type`, `function.{name,arguments}`. - `ChatCompletionMessage.ToolCalls` populated on the non-streaming response so the assistant message lists every tool the agent invoked. - `ChatCompletionStreamDelta.ToolCalls` carries one tool per delta in streaming mode. The runtime hands us complete arguments, so one chunk per call is sufficient (vs. OpenAI's incremental argument streaming, which clients accumulate either way). - `runAgentLoop` now takes an `agentEmit` struct with `onContent` and `onToolCall` hooks instead of a single content callback. Both handlers fill in their respective hooks; missing ones are no-ops. Tools still execute server-side; this commit is purely about client observability. Surfacing results back through the protocol (so clients could intercept / replay them) is left for a future change. Assisted-By: docker-agent --- pkg/chatserver/agent.go | 32 +++++++++++++++++++++++----- pkg/chatserver/server.go | 34 +++++++++++++++++++++++------- pkg/chatserver/server_test.go | 20 ++++++++++++++++++ pkg/chatserver/types.go | 39 +++++++++++++++++++++++++++++------ 4 files changed, 106 insertions(+), 19 deletions(-) diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go index ce5525d1e..a372f0526 100644 --- a/pkg/chatserver/agent.go +++ b/pkg/chatserver/agent.go @@ -126,8 +126,19 @@ func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { } } -// runAgentLoop drives the runtime to completion, forwarding assistant -// content to emit (which may be nil for non-streaming mode). +// agentEmit collects the side-effect callbacks invoked by runAgentLoop as +// it drives the runtime. All callbacks are optional; nil means "ignore +// this kind of event". +type agentEmit struct { + // onContent fires for every assistant text delta from the model. + onContent func(string) + // onToolCall fires when the agent dispatches a tool. Called once per + // tool, with the tool already populated with its arguments. + onToolCall func(ToolCallReference) +} + +// runAgentLoop drives the runtime to completion, forwarding events to +// the supplied callbacks. // // The session is built with ToolsApproved=true and NonInteractive=true, // which means the runtime auto-approves tool calls and auto-stops on @@ -140,13 +151,24 @@ func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { // All ErrorEvents seen in the run are joined into the returned error so // callers can see the full picture; the loop keeps draining until the // stream closes so the runtime can shut down cleanly. -func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit func(string)) error { +func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit agentEmit) error { var runErrs []error + toolIndex := 0 for ev := range rt.RunStream(ctx, sess) { switch e := ev.(type) { case *runtime.AgentChoiceEvent: - if emit != nil { - emit(e.Content) + if emit.onContent != nil { + emit.onContent(e.Content) + } + case *runtime.ToolCallEvent: + if emit.onToolCall != nil { + emit.onToolCall(ToolCallReference{ + Index: toolIndex, + ID: e.ToolCall.ID, + Type: string(e.ToolCall.Type), + Function: ToolCallFunction{Name: e.ToolCall.Function.Name, Arguments: e.ToolCall.Function.Arguments}, + }) + toolIndex++ } case *runtime.ToolCallConfirmationEvent: // Defensive: should never fire while ToolsApproved=true. diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index adc5bda67..bc120a67b 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -420,7 +420,13 @@ func (s *server) maybeStoreConversation(id string, sess *session.Session, isNew // chatCompletion runs the agent to completion and replies with one // non-streaming OpenAI ChatCompletion object. func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { - if err := runAgentLoop(c.Request().Context(), rt, sess, nil); err != nil { + var toolCalls []ToolCallReference + emit := agentEmit{ + onToolCall: func(tc ToolCallReference) { + toolCalls = append(toolCalls, tc) + }, + } + if err := runAgentLoop(c.Request().Context(), rt, sess, emit); err != nil { return writeError(c, http.StatusInternalServerError, fmt.Sprintf("agent execution failed: %v", err)) } @@ -432,8 +438,9 @@ func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *sessio Choices: []ChatCompletionChoice{{ Index: 0, Message: ChatCompletionMessage{ - Role: "assistant", - Content: sess.GetLastAssistantMessageContent(), + Role: "assistant", + Content: sess.GetLastAssistantMessageContent(), + ToolCalls: toolCalls, }, FinishReason: "stop", }}, @@ -453,11 +460,22 @@ func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess * // Initial "role: assistant" delta so clients can start rendering. stream.send(ChatCompletionStreamDelta{Role: "assistant"}, "") - runErr := runAgentLoop(c.Request().Context(), rt, sess, func(content string) { - if content != "" { - stream.send(ChatCompletionStreamDelta{Content: content}, "") - } - }) + emit := agentEmit{ + onContent: func(content string) { + if content != "" { + stream.send(ChatCompletionStreamDelta{Content: content}, "") + } + }, + onToolCall: func(tc ToolCallReference) { + // Surface tool calls to the client using OpenAI's exact wire + // shape: a single delta carrying the full tool_call entry. + // (OpenAI streams arguments token-by-token; we have them all + // at once, so one chunk per call is enough.) Tools still run + // server-side — this is purely for client visibility. + stream.send(ChatCompletionStreamDelta{ToolCalls: []ToolCallReference{tc}}, "") + }, + } + runErr := runAgentLoop(c.Request().Context(), rt, sess, emit) if runErr != nil { // Emit a structured error envelope (OpenAI streams use a regular // `data:` line carrying an `error` object, then close the stream diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index acc54b069..55250e05e 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -367,6 +367,26 @@ func TestStopSequences_UnmarshalJSON(t *testing.T) { } } +func TestSSEStream_ToolCallDelta(t *testing.T) { + rec := httptest.NewRecorder() + s := newSSEStream(rec, "chatcmpl-x", "root") + s.send(ChatCompletionStreamDelta{ToolCalls: []ToolCallReference{{ + Index: 0, + ID: "call_1", + Type: "function", + Function: ToolCallFunction{ + Name: "search", + Arguments: `{"q":"docker"}`, + }, + }}}, "") + + body := rec.Body.String() + assert.Contains(t, body, `"tool_calls":[`) + assert.Contains(t, body, `"id":"call_1"`) + assert.Contains(t, body, `"name":"search"`) + assert.Contains(t, body, `"arguments":"{\"q\":\"docker\"}"`) +} + func TestSSEStream_SendError(t *testing.T) { rec := httptest.NewRecorder() s := newSSEStream(rec, "chatcmpl-x", "root") diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go index 22f7e41f2..71ea3dd37 100644 --- a/pkg/chatserver/types.go +++ b/pkg/chatserver/types.go @@ -76,10 +76,36 @@ func (s *StopSequences) UnmarshalJSON(data []byte) error { // content (image parts, audio, etc.) is not supported and falls back to the // `Content` string. type ChatCompletionMessage struct { - Role string `json:"role"` - Content string `json:"content"` - Name string `json:"name,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Role string `json:"role"` + Content string `json:"content"` + Name string `json:"name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolCalls []ToolCallReference `json:"tool_calls,omitempty"` +} + +// ToolCallReference mirrors OpenAI's `tool_calls` entry. The server fills +// it in on the *response* side so clients can introspect what tools the +// agent invoked. Tools are still executed server-side; this is purely +// informational. +type ToolCallReference struct { + // Index is the position of the tool call in the assistant message. + // In streaming mode multiple chunks targeting the same Index are + // concatenated by the client. + Index int `json:"index,omitempty"` + // ID matches what is later echoed back as ToolCallID on `tool` role + // messages — useful when correlating tool calls with their results. + ID string `json:"id,omitempty"` + // Type is always "function" today; OpenAI reserves the field for + // future expansion. + Type string `json:"type,omitempty"` + // Function carries the tool's name and JSON-encoded arguments. + Function ToolCallFunction `json:"function"` +} + +// ToolCallFunction mirrors OpenAI's nested tool function descriptor. +type ToolCallFunction struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments,omitempty"` } // --- Non-streaming response ----------------------------------------------- @@ -126,8 +152,9 @@ type ChatCompletionStreamChoice struct { } type ChatCompletionStreamDelta struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCallReference `json:"tool_calls,omitempty"` } // --- Models endpoint ------------------------------------------------------ From fd55cc1db60e17373c6ae0e4ff5d25c753e7d7ed Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:03:29 +0200 Subject: [PATCH 13/19] chatserver: serve `/openapi.json` for schema introspection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a static OpenAPI 3.1 document describing /v1/models, /v1/chat/completions, the new tool_calls fields, the X-Conversation-Id header, and the bearer-auth security scheme. - The spec is hand-written and embedded with `//go:embed`. That keeps it easy to review (it's plain JSON, not generated noise), trivial to update when the API changes, and free of generation steps in the build. - A new `GET /openapi.json` route serves the spec verbatim. - `bearerAuthMiddleware` exempts /openapi.json so introspection tooling can discover the API even on locked-down deployments — there's no secret in the spec, only the shape of the API. Tests cover both the document shape (correct paths advertised) and the auth bypass. Assisted-By: docker-agent --- .dockerignore | 1 + pkg/chatserver/openapi.go | 23 +++ pkg/chatserver/openapi.json | 273 +++++++++++++++++++++++++++++++++ pkg/chatserver/openapi_test.go | 45 ++++++ pkg/chatserver/server.go | 11 ++ 5 files changed, 353 insertions(+) create mode 100644 pkg/chatserver/openapi.go create mode 100644 pkg/chatserver/openapi.json create mode 100644 pkg/chatserver/openapi_test.go diff --git a/.dockerignore b/.dockerignore index 40b2fd660..d37e17333 100644 --- a/.dockerignore +++ b/.dockerignore @@ -7,5 +7,6 @@ !./**/*.css !./**/*.go !./**/*.txt +!/pkg/chatserver/openapi.json !/pkg/config/builtin-agents/*.yaml !/pkg/tui/styles/themes/*.yaml \ No newline at end of file diff --git a/pkg/chatserver/openapi.go b/pkg/chatserver/openapi.go new file mode 100644 index 000000000..de35467c9 --- /dev/null +++ b/pkg/chatserver/openapi.go @@ -0,0 +1,23 @@ +package chatserver + +import ( + _ "embed" + "net/http" + + "github.com/labstack/echo/v4" +) + +// openAPISpec is the static OpenAPI 3.1 document describing the chat +// completions API. Embedding the JSON keeps the schema diffable and +// tractable to review, and means we don't pay a generation step on every +// build. +// +//go:embed openapi.json +var openAPISpec []byte + +// handleOpenAPI serves the static OpenAPI document. It is exempted from +// the bearer-auth middleware (see bearerAuthMiddleware) so tooling that +// wants to introspect the API can do so without credentials. +func (s *server) handleOpenAPI(c echo.Context) error { + return c.Blob(http.StatusOK, "application/json", openAPISpec) +} diff --git a/pkg/chatserver/openapi.json b/pkg/chatserver/openapi.json new file mode 100644 index 000000000..7d16b6070 --- /dev/null +++ b/pkg/chatserver/openapi.json @@ -0,0 +1,273 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "docker-agent chat completions", + "summary": "OpenAI-compatible HTTP API exposing a docker-agent agent.", + "description": "Implements a small subset of OpenAI's REST API (chat completions and models) so any tool that already speaks OpenAI's protocol can drive a docker-agent agent without a custom integration.", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://127.0.0.1:8083", + "description": "Default loopback bind" + } + ], + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "description": "Static token configured via --api-key. When --api-key is not set the server is unauthenticated." + } + }, + "schemas": { + "Model": { + "type": "object", + "required": ["id", "object", "owned_by"], + "properties": { + "id": { "type": "string", "description": "Agent name." }, + "object": { "type": "string", "const": "model" }, + "created": { "type": "integer", "format": "int64" }, + "owned_by": { "type": "string", "const": "docker-agent" } + } + }, + "ModelsResponse": { + "type": "object", + "required": ["object", "data"], + "properties": { + "object": { "type": "string", "const": "list" }, + "data": { + "type": "array", + "items": { "$ref": "#/components/schemas/Model" } + } + } + }, + "ChatCompletionMessage": { + "type": "object", + "required": ["role"], + "properties": { + "role": { + "type": "string", + "enum": ["system", "user", "assistant", "tool", "developer"] + }, + "content": { "type": "string" }, + "name": { "type": "string" }, + "tool_call_id": { "type": "string" }, + "tool_calls": { + "type": "array", + "items": { "$ref": "#/components/schemas/ToolCallReference" } + } + } + }, + "ToolCallReference": { + "type": "object", + "required": ["function"], + "properties": { + "index": { "type": "integer" }, + "id": { "type": "string" }, + "type": { "type": "string", "const": "function" }, + "function": { + "type": "object", + "required": ["name"], + "properties": { + "name": { "type": "string" }, + "arguments": { + "type": "string", + "description": "JSON-encoded arguments object." + } + } + } + } + }, + "ChatCompletionRequest": { + "type": "object", + "required": ["messages"], + "properties": { + "model": { + "type": "string", + "description": "Agent name to invoke. Defaults to the team's default agent when missing or unknown." + }, + "messages": { + "type": "array", + "minItems": 1, + "items": { "$ref": "#/components/schemas/ChatCompletionMessage" } + }, + "stream": { "type": "boolean", "default": false }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Validated; full runtime plumbing is in progress." + }, + "top_p": { + "type": "number", + "exclusiveMinimum": 0, + "maximum": 1 + }, + "max_tokens": { "type": "integer", "minimum": 1 }, + "stop": { + "oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ] + } + } + }, + "ChatCompletionChoice": { + "type": "object", + "required": ["index", "message"], + "properties": { + "index": { "type": "integer" }, + "message": { "$ref": "#/components/schemas/ChatCompletionMessage" }, + "finish_reason": { + "type": "string", + "enum": ["stop", "tool_calls", "error", "length"] + } + } + }, + "ChatCompletionUsage": { + "type": "object", + "properties": { + "prompt_tokens": { "type": "integer", "format": "int64" }, + "completion_tokens": { "type": "integer", "format": "int64" }, + "total_tokens": { "type": "integer", "format": "int64" } + } + }, + "ChatCompletionResponse": { + "type": "object", + "required": ["id", "object", "created", "model", "choices"], + "properties": { + "id": { "type": "string" }, + "object": { "type": "string", "const": "chat.completion" }, + "created": { "type": "integer", "format": "int64" }, + "model": { "type": "string" }, + "choices": { + "type": "array", + "items": { "$ref": "#/components/schemas/ChatCompletionChoice" } + }, + "usage": { "$ref": "#/components/schemas/ChatCompletionUsage" } + } + }, + "ErrorResponse": { + "type": "object", + "required": ["error"], + "properties": { + "error": { + "type": "object", + "required": ["message", "type"], + "properties": { + "message": { "type": "string" }, + "type": { + "type": "string", + "enum": ["invalid_request_error", "internal_error"] + }, + "code": { "type": "string" } + } + } + } + } + } + }, + "security": [{ "bearerAuth": [] }], + "paths": { + "/v1/models": { + "get": { + "summary": "List the agents this server exposes.", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ModelsResponse" } + } + } + }, + "401": { + "description": "Missing or invalid bearer token (only when --api-key is set).", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + } + } + } + }, + "/v1/chat/completions": { + "post": { + "summary": "Create a chat completion.", + "description": "Set `stream: true` to receive Server-Sent Events instead of a single JSON response. The optional `X-Conversation-Id` request header reuses a server-side session across turns when --conversations-max is non-zero.", + "parameters": [ + { + "in": "header", + "name": "X-Conversation-Id", + "schema": { "type": "string" }, + "description": "Stable identifier used to look up a cached session." + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ChatCompletionRequest" } + } + } + }, + "responses": { + "200": { + "description": "OK. Either a JSON ChatCompletion or a `text/event-stream` of `chat.completion.chunk` events.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ChatCompletionResponse" } + }, + "text/event-stream": { + "schema": { "type": "string" } + } + } + }, + "400": { + "description": "Bad request (malformed JSON, missing user message, invalid sampling parameters).", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, + "401": { + "description": "Missing or invalid bearer token.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, + "413": { "description": "Request body exceeds --max-request-size." }, + "500": { + "description": "Agent execution failed.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + } + } + } + }, + "/openapi.json": { + "get": { + "summary": "Returns this OpenAPI document.", + "security": [], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { "type": "object" } + } + } + } + } + } + } + } +} diff --git a/pkg/chatserver/openapi_test.go b/pkg/chatserver/openapi_test.go new file mode 100644 index 000000000..291ea4bd9 --- /dev/null +++ b/pkg/chatserver/openapi_test.go @@ -0,0 +1,45 @@ +package chatserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpenAPIEndpoint(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/openapi.json", http.NoBody) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var doc map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &doc)) + assert.Equal(t, "3.1.0", doc["openapi"], "OpenAPI version") + paths, ok := doc["paths"].(map[string]any) + require.True(t, ok) + assert.Contains(t, paths, "/v1/chat/completions") + assert.Contains(t, paths, "/v1/models") +} + +func TestOpenAPIEndpoint_BypassesAuth(t *testing.T) { + // /openapi.json must be reachable without a bearer token even when + // --api-key is set, so introspection tooling works against locked- + // down deployments. + srv, _ := newTestServer("root") + r := newRouter(srv, Options{APIKey: "secret"}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/openapi.json", http.NoBody) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index bc120a67b..da2f3a342 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -203,6 +203,12 @@ func newRouter(s *server, opts Options) http.Handler { e.Use(middleware.RequestLogger()) e.Use(middleware.BodyLimit(strconv.FormatInt(maxBytes, 10))) e.Use(requestTimeoutMiddleware(timeout)) + + // Register /openapi.json *before* the bearer-auth middleware so the + // schema is reachable for introspection without credentials. CORS + // configuration is then layered for /v1/* routes. + e.GET("/openapi.json", s.handleOpenAPI) + if opts.APIKey != "" { e.Use(bearerAuthMiddleware(opts.APIKey)) } @@ -327,6 +333,11 @@ func bearerAuthMiddleware(expected string) echo.MiddlewareFunc { if c.Request().Method == http.MethodOptions { return next(c) } + // Schema introspection is always reachable so tooling can + // discover the API without credentials. + if c.Path() == "/openapi.json" { + return next(c) + } got, ok := strings.CutPrefix(c.Request().Header.Get("Authorization"), "Bearer ") if !ok || subtle.ConstantTimeCompare([]byte(got), exp) != 1 { return writeError(c, http.StatusUnauthorized, "missing or invalid bearer token") From a8ce2eb7c1aff5c0900771b44385f72330a73560 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:07:25 +0200 Subject: [PATCH 14/19] chatserver: accept OpenAI multimodal content (text + image_url) OpenAI's chat protocol lets the `content` field of a message be either a string or an array of typed parts: "content": [ {"type": "text", "text": "What is in this picture?"}, {"type": "image_url", "image_url": {"url": "..."}} ] The chat server used to drop the parts variant on the floor: the field was typed as `string`, so multi-part requests deserialised to an empty content and the request was rejected as having "no user message". That made the server unable to serve any vision-capable agent. - Replace the plain `Content string` with a JSON-union (un)marshaller. `Content` still carries a flat-text view for string-form content and for the concatenated text of parts; a new `Parts []ContentPart` field holds the typed entries when the array shape is used. Existing Go callers (and every test that still writes `Content: "..."`) keep working unchanged. - `convertParts` translates the wire shape to the runtime's `chat.MessagePart` union (text + image_url), so the model provider sees the actual image. Unknown part types are dropped gracefully so future part kinds degrade rather than 500. - `appendLatestUser` (used by X-Conversation-Id continuation) gets the same multi-part path. - The OpenAPI spec advertises the union shape and the new ContentPart schema. Tests cover string/array round-trips, image_url plumbing into the session, and (still passing) all the pre-existing behaviour. Assisted-By: docker-agent --- pkg/chatserver/agent.go | 69 ++++++++++++++++++-- pkg/chatserver/openapi.json | 27 +++++++- pkg/chatserver/server_test.go | 70 ++++++++++++++++++++ pkg/chatserver/types.go | 117 ++++++++++++++++++++++++++++++++-- 4 files changed, 273 insertions(+), 10 deletions(-) diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go index a372f0526..752286cdf 100644 --- a/pkg/chatserver/agent.go +++ b/pkg/chatserver/agent.go @@ -72,11 +72,29 @@ func buildSession(messages []ChatCompletionMessage) *session.Session { hasUser := false for _, m := range messages { + role := strings.ToLower(strings.TrimSpace(m.Role)) + if len(m.Parts) > 0 && (role == "" || (role != "system" && role != "assistant" && role != "tool")) { + // Multi-part content: route through chat.MultiContent so the + // runtime/provider sees image parts. Only user-style messages + // support images today. + parts := convertParts(m.Parts) + if len(parts) == 0 { + continue + } + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: m.Content, + MultiContent: parts, + }}) + hasUser = true + continue + } + content := m.Content if strings.TrimSpace(content) == "" { continue } - switch strings.ToLower(strings.TrimSpace(m.Role)) { + switch role { case "system": sess.AddMessage(session.SystemMessage(content)) case "assistant": @@ -104,6 +122,38 @@ func buildSession(messages []ChatCompletionMessage) *session.Session { return sess } +// convertParts maps the chatserver wire shape to chat.MessagePart so +// images and (future) other typed parts reach the runtime intact. +// Unknown part types are dropped; an empty result tells the caller to +// skip the message entirely. +func convertParts(in []ContentPart) []chat.MessagePart { + out := make([]chat.MessagePart, 0, len(in)) + for _, p := range in { + switch p.Type { + case "text": + if strings.TrimSpace(p.Text) == "" { + continue + } + out = append(out, chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: p.Text, + }) + case "image_url": + if p.ImageURL == nil || p.ImageURL.URL == "" { + continue + } + out = append(out, chat.MessagePart{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: p.ImageURL.URL, + Detail: chat.ImageURLDetail(p.ImageURL.Detail), + }, + }) + } + } + return out +} + // appendLatestUser walks msgs from the end and appends only the last // user-role message into sess. Used by conversation continuation, where // the session already contains the full prior history and we just need @@ -111,16 +161,25 @@ func buildSession(messages []ChatCompletionMessage) *session.Session { func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { for i := len(msgs) - 1; i >= 0; i-- { m := msgs[i] - content := strings.TrimSpace(m.Content) - if content == "" { - continue - } role := strings.ToLower(strings.TrimSpace(m.Role)) // Treat any non-system/assistant/tool role as user (matches // buildSession's policy). if role == "system" || role == "assistant" || role == "tool" { continue } + parts := convertParts(m.Parts) + if len(parts) > 0 { + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: m.Content, + MultiContent: parts, + }}) + return + } + content := strings.TrimSpace(m.Content) + if content == "" { + continue + } sess.AddMessage(session.UserMessage(m.Content)) return } diff --git a/pkg/chatserver/openapi.json b/pkg/chatserver/openapi.json index 7d16b6070..664fef72c 100644 --- a/pkg/chatserver/openapi.json +++ b/pkg/chatserver/openapi.json @@ -50,7 +50,16 @@ "type": "string", "enum": ["system", "user", "assistant", "tool", "developer"] }, - "content": { "type": "string" }, + "content": { + "description": "Either a plain string or an array of typed content parts (text or image_url).", + "oneOf": [ + { "type": "string" }, + { + "type": "array", + "items": { "$ref": "#/components/schemas/ContentPart" } + } + ] + }, "name": { "type": "string" }, "tool_call_id": { "type": "string" }, "tool_calls": { @@ -59,6 +68,22 @@ } } }, + "ContentPart": { + "type": "object", + "required": ["type"], + "properties": { + "type": { "type": "string", "enum": ["text", "image_url"] }, + "text": { "type": "string" }, + "image_url": { + "type": "object", + "required": ["url"], + "properties": { + "url": { "type": "string" }, + "detail": { "type": "string", "enum": ["auto", "low", "high"] } + } + } + } + }, "ToolCallReference": { "type": "object", "required": ["function"], diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index 55250e05e..d950e3e08 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -2,6 +2,7 @@ package chatserver import ( "context" + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -340,6 +341,75 @@ func TestValidateSamplingParams(t *testing.T) { } } +func TestChatCompletionMessage_UnmarshalContentString(t *testing.T) { + var m ChatCompletionMessage + require.NoError(t, json.Unmarshal([]byte(`{"role":"user","content":"hello"}`), &m)) + assert.Equal(t, "user", m.Role) + assert.Equal(t, "hello", m.Content) + assert.Empty(t, m.Parts) +} + +func TestChatCompletionMessage_UnmarshalContentParts(t *testing.T) { + var m ChatCompletionMessage + input := `{ + "role":"user", + "content":[ + {"type":"text","text":"What is in this picture?"}, + {"type":"image_url","image_url":{"url":"https://example.com/x.png","detail":"high"}} + ] + }` + require.NoError(t, json.Unmarshal([]byte(input), &m)) + assert.Equal(t, "user", m.Role) + require.Len(t, m.Parts, 2) + assert.Equal(t, "text", m.Parts[0].Type) + assert.Equal(t, "image_url", m.Parts[1].Type) + require.NotNil(t, m.Parts[1].ImageURL) + assert.Equal(t, "https://example.com/x.png", m.Parts[1].ImageURL.URL) + // Flat text is pre-computed for callers that don't care about parts. + assert.Equal(t, "What is in this picture?", m.Content) +} + +func TestChatCompletionMessage_RoundTripText(t *testing.T) { + in := ChatCompletionMessage{Role: "assistant", Content: "hi there"} + raw, err := json.Marshal(in) + require.NoError(t, err) + assert.JSONEq(t, `{"role":"assistant","content":"hi there"}`, string(raw)) +} + +func TestChatCompletionMessage_RoundTripParts(t *testing.T) { + in := ChatCompletionMessage{ + Role: "user", + Parts: []ContentPart{ + {Type: "text", Text: "hi"}, + {Type: "image_url", ImageURL: &ContentImageURL{URL: "http://x/y"}}, + }, + } + raw, err := json.Marshal(in) + require.NoError(t, err) + assert.JSONEq(t, `{"role":"user","content":[{"type":"text","text":"hi"},{"type":"image_url","image_url":{"url":"http://x/y"}}]}`, string(raw)) +} + +func TestBuildSession_AcceptsImageParts(t *testing.T) { + sess := buildSession([]ChatCompletionMessage{{ + Role: "user", + Parts: []ContentPart{ + {Type: "text", Text: "What is this?"}, + {Type: "image_url", ImageURL: &ContentImageURL{URL: "https://example.com/x.png"}}, + }, + }}) + require.NotNil(t, sess) + + all := sess.GetAllMessages() + require.Len(t, all, 1) + last := all[0].Message + assert.Equal(t, chat.MessageRoleUser, last.Role) + require.Len(t, last.MultiContent, 2) + assert.Equal(t, chat.MessagePartTypeText, last.MultiContent[0].Type) + assert.Equal(t, chat.MessagePartTypeImageURL, last.MultiContent[1].Type) + require.NotNil(t, last.MultiContent[1].ImageURL) + assert.Equal(t, "https://example.com/x.png", last.MultiContent[1].ImageURL.URL) +} + func TestStopSequences_UnmarshalJSON(t *testing.T) { cases := []struct { name string diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go index 71ea3dd37..bbf3a4e06 100644 --- a/pkg/chatserver/types.go +++ b/pkg/chatserver/types.go @@ -3,6 +3,7 @@ package chatserver import ( "encoding/json" "errors" + "strings" "github.com/openai/openai-go/v3" ) @@ -72,17 +73,125 @@ func (s *StopSequences) UnmarshalJSON(data []byte) error { } } -// ChatCompletionMessage is a single message in the conversation. Multi-modal -// content (image parts, audio, etc.) is not supported and falls back to the -// `Content` string. +// ChatCompletionMessage is a single message in the conversation. +// +// On the wire OpenAI accepts message content in two shapes: either a +// plain string (`"content": "hello"`) or an array of typed parts +// (`"content": [{"type":"text",...}, {"type":"image_url",...}]`). +// Both shapes are accepted on the request side; the response always +// uses the string form for text-only content and the parts form when +// images or other non-text content are present. The custom JSON +// (un)marshallers below preserve that union without forcing every Go +// caller to deal with it. type ChatCompletionMessage struct { + Role string `json:"role"` + // Content is the text content of the message. Populated whether the + // wire format used a string or a parts array (the parts' text values + // are concatenated). + Content string `json:"-"` + // Parts holds the original typed parts when the wire format used an + // array. Empty when the wire format was a plain string. + Parts []ContentPart `json:"-"` + + Name string `json:"name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` + ToolCalls []ToolCallReference `json:"tool_calls,omitempty"` +} + +// ContentPart mirrors one entry in OpenAI's typed-parts array. Today the +// server understands `text` and `image_url` parts; unknown types are +// preserved in the request payload but ignored when building the +// session, so future part types degrade gracefully. +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageURL *ContentImageURL `json:"image_url,omitempty"` +} + +// ContentImageURL carries an image part. URL may be a regular http(s) +// URL or a data URL (`data:image/png;base64,...`). +type ContentImageURL struct { + URL string `json:"url"` + Detail string `json:"detail,omitempty"` +} + +// jsonMessageEnvelope is the on-the-wire form of ChatCompletionMessage. +// It exists so we can run the union-shape decoding for `content` without +// duplicating every other field. +type jsonMessageEnvelope struct { Role string `json:"role"` - Content string `json:"content"` + Content json.RawMessage `json:"content,omitempty"` Name string `json:"name,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"` ToolCalls []ToolCallReference `json:"tool_calls,omitempty"` } +// UnmarshalJSON accepts either a string `content` field or an array of +// typed parts (OpenAI's multimodal shape). +func (m *ChatCompletionMessage) UnmarshalJSON(data []byte) error { + var env jsonMessageEnvelope + if err := json.Unmarshal(data, &env); err != nil { + return err + } + m.Role = env.Role + m.Name = env.Name + m.ToolCallID = env.ToolCallID + m.ToolCalls = env.ToolCalls + + if len(env.Content) == 0 || string(env.Content) == "null" { + return nil + } + switch env.Content[0] { + case '"': + return json.Unmarshal(env.Content, &m.Content) + case '[': + if err := json.Unmarshal(env.Content, &m.Parts); err != nil { + return err + } + // Pre-compute the flat text so callers that don't care about + // images can keep using m.Content. + var buf strings.Builder + for _, p := range m.Parts { + if p.Type == "text" { + if buf.Len() > 0 { + buf.WriteByte(' ') + } + buf.WriteString(p.Text) + } + } + m.Content = buf.String() + return nil + default: + return errors.New("content must be a string or array of parts") + } +} + +// MarshalJSON emits the parts array when present, otherwise a plain +// string. Tool/role/name/tool_call_id round-trip verbatim. +func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) { + env := jsonMessageEnvelope{ + Role: m.Role, + Name: m.Name, + ToolCallID: m.ToolCallID, + ToolCalls: m.ToolCalls, + } + switch { + case len(m.Parts) > 0: + raw, err := json.Marshal(m.Parts) + if err != nil { + return nil, err + } + env.Content = raw + case m.Content != "": + raw, err := json.Marshal(m.Content) + if err != nil { + return nil, err + } + env.Content = raw + } + return json.Marshal(env) +} + // ToolCallReference mirrors OpenAI's `tool_calls` entry. The server fills // it in on the *response* side so clients can introspect what tools the // agent invoked. Tools are still executed server-side; this is purely From 2ee80affe240fe29c0baa120fe18c7fdf87604ec Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:18:05 +0200 Subject: [PATCH 15/19] fix(chatserver): always store conversation after request When a conversation is evicted from the LRU cache while a request is processing it, the updated session was not being stored back because maybeStoreConversation only called Put when isNew=true. This caused conversation state to be lost when: 1. Request R1 retrieves conversation C from cache (isNew=false) 2. R1 processes the request, updating the session 3. Meanwhile, C is evicted due to LRU policy 4. R1 finishes and calls maybeStoreConversation(C, sess, false) 5. Since isNew=false, Put was not called 6. The updated session is lost Fix: Always call Put, regardless of isNew flag. This ensures the updated session is stored and refreshes the lastUsed timestamp, preventing premature eviction of active conversations. The Put operation is idempotent and safe to call multiple times for the same conversation ID. Assisted-By: docker-agent Assisted-By: docker-agent --- pkg/chatserver/server.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index da2f3a342..2f8b2706d 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -417,18 +417,17 @@ func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) (*sessi } // maybeStoreConversation inserts the session into the cache after a -// run. We only need to insert when the conversation is new — existing -// entries are mutated in place. +// run. We always insert to handle the case where the conversation was +// evicted while the request was in flight. func (s *server) maybeStoreConversation(id string, sess *session.Session, isNew bool) { if id == "" || s.conversations == nil { return } - if isNew { - s.conversations.Put(id, sess) - } + // Always Put, even for existing conversations, to handle eviction + // during request processing. Put refreshes the lastUsed timestamp + // and ensures the updated session is stored. + s.conversations.Put(id, sess) } - -// chatCompletion runs the agent to completion and replies with one // non-streaming OpenAI ChatCompletion object. func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { var toolCalls []ToolCallReference From 047167a457a9b914271cd3a7d8ac096c0f69bac1 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:18:53 +0200 Subject: [PATCH 16/19] refactor(chatserver): remove unused isNew parameter The isNew flag was used to decide whether to call Put on the conversation store, but after the previous fix, we always call Put regardless of whether the conversation is new or existing. This commit removes the now-unused isNew parameter from resolveSession and maybeStoreConversation, simplifying the code. Assisted-By: docker-agent --- pkg/chatserver/server.go | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 2f8b2706d..baf24fca8 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -368,7 +368,7 @@ func (s *server) handleChatCompletions(c echo.Context) error { } conversationID := c.Request().Header.Get("X-Conversation-Id") - sess, isNew := s.resolveSession(conversationID, req.Messages) + sess := s.resolveSession(conversationID, req.Messages) if sess == nil { return writeError(c, http.StatusBadRequest, "no user message provided") } @@ -389,11 +389,11 @@ func (s *server) handleChatCompletions(c echo.Context) error { if req.Stream { err := s.streamChatCompletion(c, rt, sess, model) - s.maybeStoreConversation(conversationID, sess, isNew) + s.maybeStoreConversation(conversationID, sess) return err } err = s.chatCompletion(c, rt, sess, model) - s.maybeStoreConversation(conversationID, sess, isNew) + s.maybeStoreConversation(conversationID, sess) return err } @@ -402,24 +402,20 @@ func (s *server) handleChatCompletions(c echo.Context) error { // session for it, we append only the latest user message from the // request (the prior history is already in the session). Otherwise we // build a brand-new session from the full request history. -// -// Returns the session to run, plus a flag indicating whether it was -// freshly created (so callers know whether they need to insert it into -// the cache). -func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) (*session.Session, bool) { +func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) *session.Session { if id != "" { if existing := s.conversations.Get(id); existing != nil { appendLatestUser(existing, msgs) - return existing, false + return existing } } - return buildSession(msgs), true + return buildSession(msgs) } // maybeStoreConversation inserts the session into the cache after a // run. We always insert to handle the case where the conversation was // evicted while the request was in flight. -func (s *server) maybeStoreConversation(id string, sess *session.Session, isNew bool) { +func (s *server) maybeStoreConversation(id string, sess *session.Session) { if id == "" || s.conversations == nil { return } @@ -428,6 +424,7 @@ func (s *server) maybeStoreConversation(id string, sess *session.Session, isNew // and ensures the updated session is stored. s.conversations.Put(id, sess) } + // non-streaming OpenAI ChatCompletion object. func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { var toolCalls []ToolCallReference From c7fab7a225b835ec6d7b254241a5b6490cbbc01c Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:20:22 +0200 Subject: [PATCH 17/19] test(chatserver): add test for conversation restore after eviction Add a test that verifies a conversation evicted from the LRU cache while a request is processing it can still be stored back after the request completes. This test validates the fix in commit 9563a431 which ensures maybeStoreConversation always calls Put, preventing loss of session state when a conversation is evicted during request processing. Assisted-By: docker-agent --- pkg/chatserver/conversations_eviction_test.go | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 pkg/chatserver/conversations_eviction_test.go diff --git a/pkg/chatserver/conversations_eviction_test.go b/pkg/chatserver/conversations_eviction_test.go new file mode 100644 index 000000000..59338368a --- /dev/null +++ b/pkg/chatserver/conversations_eviction_test.go @@ -0,0 +1,51 @@ +package chatserver + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/session" +) + +// TestConversationStore_RestoreAfterEviction tests that a conversation +// can be stored back after it's been evicted from the cache. +func TestConversationStore_RestoreAfterEviction(t *testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(2, time.Hour) + c.now = func() time.Time { return now } + + // Store a conversation + sess1 := session.New() + sess1.AddMessage(session.UserMessage("first")) + c.Put("conv-1", sess1) + + // Retrieve it (simulating a request starting) + retrieved := c.Get("conv-1") + require.NotNil(t, retrieved) + require.Same(t, sess1, retrieved) + + // Simulate the request processing (updating the session) + retrieved.AddMessage(session.UserMessage("updated")) + + // Manually evict the conversation (simulating LRU eviction) + c.mu.Lock() + delete(c.items, "conv-1") + c.mu.Unlock() + + // Verify it's gone + assert.Nil(t, c.Get("conv-1"), "conv-1 should be evicted") + + // Now the request completes and stores the updated session back + // This should work even though conv-1 was evicted + now = now.Add(time.Second) + c.Put("conv-1", retrieved) + + // Verify the updated session is stored + final := c.Get("conv-1") + require.NotNil(t, final) + assert.Same(t, retrieved, final) + assert.Equal(t, "updated", final.GetLastUserMessageContent()) +} From 0c8e842e418d8dec44cd67e315315f975ec72d92 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:25:13 +0200 Subject: [PATCH 18/19] chatserver: restore doc comment on chatCompletion The previous fix accidentally deleted the doc-comment header line on `(*server).chatCompletion`, leaving a dangling fragment ("// non-streaming OpenAI ChatCompletion object.") detached from the function it documents. Assisted-By: docker-agent --- pkg/chatserver/server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index baf24fca8..5b5d94893 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -425,6 +425,7 @@ func (s *server) maybeStoreConversation(id string, sess *session.Session) { s.conversations.Put(id, sess) } +// chatCompletion runs the agent to completion and replies with one // non-streaming OpenAI ChatCompletion object. func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { var toolCalls []ToolCallReference From 0e1c5a89c33500addc118e12e670daf346e8ca05 Mon Sep 17 00:00:00 2001 From: David Gageot Date: Mon, 27 Apr 2026 14:29:00 +0200 Subject: [PATCH 19/19] chatserver: serialize requests sharing an X-Conversation-Id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Concurrent requests with the same X-Conversation-Id share the same `*session.Session` pointer (the conversation cache hands out the same instance to every caller), so two simultaneous runtime RunStream calls would interleave message appends, send overlapping prompts to the model, and produce a garbled transcript. Although `session.Session` has internal mutex protection on Messages, the agent loop reads-then-writes (decide what to send, append model output) so per-field synchronisation isn't enough — the whole turn must be atomic with respect to other turns on the same id. Reject the second concurrent request with 409 Conflict instead of trying to serialise it on the server. That: - Surfaces the misuse to the caller immediately (vs. mysterious interleaving), - Keeps server-side resources bounded (no queue, no parked goroutines), - Matches how OpenAI's own conversation API expects clients to use the protocol (one request at a time per conversation). Empty conversation id and nil lock-set are no-ops, so callers without the feature enabled keep their old behaviour. The OpenAPI spec advertises the new 409 response. Tests cover acquire/release semantics, nil/empty no-ops, and a race-detector- friendly stress test that proves at most one holder of the same id at a time. Assisted-By: docker-agent --- pkg/chatserver/conversation_lock.go | 48 ++++++++++++++++ pkg/chatserver/conversation_lock_test.go | 70 ++++++++++++++++++++++++ pkg/chatserver/handlers_test.go | 5 +- pkg/chatserver/openapi.json | 8 +++ pkg/chatserver/server.go | 23 +++++--- pkg/chatserver/server_test.go | 20 +++++++ 6 files changed, 165 insertions(+), 9 deletions(-) create mode 100644 pkg/chatserver/conversation_lock.go create mode 100644 pkg/chatserver/conversation_lock_test.go diff --git a/pkg/chatserver/conversation_lock.go b/pkg/chatserver/conversation_lock.go new file mode 100644 index 000000000..c5c9a1141 --- /dev/null +++ b/pkg/chatserver/conversation_lock.go @@ -0,0 +1,48 @@ +package chatserver + +import "sync" + +// conversationLockSet ensures only one in-flight request at a time per +// conversation id. Concurrent requests sharing an id would otherwise share +// the same `*session.Session` (the cache hands out the same pointer to every +// caller for that id), and two concurrent runtime.RunStream calls on one +// session interleave message appends and produce garbled transcripts. +// +// We reject the second request with 409 Conflict instead of serialising it, +// for two reasons: it surfaces the misuse to the client immediately, and it +// keeps the handler's resource cost bounded (no queue, no waiting goroutines). +type conversationLockSet struct { + mu sync.Mutex + active map[string]struct{} +} + +func newConversationLockSet() *conversationLockSet { + return &conversationLockSet{active: make(map[string]struct{})} +} + +// tryAcquire returns true when id was not already in flight. The caller +// must call release when the request finishes. Empty id is a no-op (and +// returns true) so callers without a conversation id don't need a guard. +func (l *conversationLockSet) tryAcquire(id string) bool { + if l == nil || id == "" { + return true + } + l.mu.Lock() + defer l.mu.Unlock() + if _, ok := l.active[id]; ok { + return false + } + l.active[id] = struct{}{} + return true +} + +// release marks id as no longer in flight. Safe to call when id is the +// empty string or l is nil. +func (l *conversationLockSet) release(id string) { + if l == nil || id == "" { + return + } + l.mu.Lock() + delete(l.active, id) + l.mu.Unlock() +} diff --git a/pkg/chatserver/conversation_lock_test.go b/pkg/chatserver/conversation_lock_test.go new file mode 100644 index 000000000..b6e0f4e4a --- /dev/null +++ b/pkg/chatserver/conversation_lock_test.go @@ -0,0 +1,70 @@ +package chatserver + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConversationLockSet_AcquireRelease(t *testing.T) { + l := newConversationLockSet() + assert.True(t, l.tryAcquire("a"), "first acquire should succeed") + assert.False(t, l.tryAcquire("a"), "second acquire on the same id should fail") + l.release("a") + assert.True(t, l.tryAcquire("a"), "acquire after release should succeed") + l.release("a") +} + +func TestConversationLockSet_DifferentIDsDontBlock(t *testing.T) { + l := newConversationLockSet() + assert.True(t, l.tryAcquire("a")) + assert.True(t, l.tryAcquire("b"), "different ids should not block each other") + l.release("a") + l.release("b") +} + +func TestConversationLockSet_EmptyIDIsNoop(t *testing.T) { + l := newConversationLockSet() + // Empty id is the "no conversation" path: tryAcquire must always + // succeed and release must be safe. + assert.True(t, l.tryAcquire("")) + assert.True(t, l.tryAcquire("")) + l.release("") +} + +func TestConversationLockSet_NilIsNoop(t *testing.T) { + var l *conversationLockSet + assert.True(t, l.tryAcquire("a")) + l.release("a") // must not panic +} + +func TestConversationLockSet_RaceFreeUnderConcurrency(t *testing.T) { + // Run the race detector over a hot loop. The lock set's invariant — + // "at most one acquired ID at a time" — must hold. + l := newConversationLockSet() + const goroutines = 50 + const iters = 200 + + var maxConcurrent atomic.Int32 + var current atomic.Int32 + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + for range iters { + if l.tryAcquire("hot") { + n := current.Add(1) + if n > maxConcurrent.Load() { + maxConcurrent.Store(n) + } + current.Add(-1) + l.release("hot") + } + } + }) + } + wg.Wait() + assert.LessOrEqual(t, maxConcurrent.Load(), int32(1), + "at most one holder of the same id at a time") +} diff --git a/pkg/chatserver/handlers_test.go b/pkg/chatserver/handlers_test.go index 9ca43aa23..ef1f6144d 100644 --- a/pkg/chatserver/handlers_test.go +++ b/pkg/chatserver/handlers_test.go @@ -20,7 +20,10 @@ func newTestServer(exposed ...string) (*server, *echo.Echo) { if len(exposed) == 0 { exposed = []string{"root"} } - srv := &server{policy: agentPolicy{exposed: exposed, fallback: exposed[0]}} + srv := &server{ + policy: agentPolicy{exposed: exposed, fallback: exposed[0]}, + conversationLocks: newConversationLockSet(), + } e := echo.New() return srv, e } diff --git a/pkg/chatserver/openapi.json b/pkg/chatserver/openapi.json index 664fef72c..e34ded003 100644 --- a/pkg/chatserver/openapi.json +++ b/pkg/chatserver/openapi.json @@ -267,6 +267,14 @@ } }, "413": { "description": "Request body exceeds --max-request-size." }, + "409": { + "description": "Another request with the same X-Conversation-Id is in flight. Retry sequentially.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, "500": { "description": "Agent execution failed.", "content": { diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go index 5b5d94893..928e56636 100644 --- a/pkg/chatserver/server.go +++ b/pkg/chatserver/server.go @@ -127,10 +127,11 @@ func Run(ctx context.Context, agentFilename string, opts Options, ln net.Listene httpServer := &http.Server{ Handler: newRouter(&server{ - team: t, - policy: policy, - conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts)), - runtimes: newRuntimePool(t, opts.MaxIdleRuntimes), + team: t, + policy: policy, + conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts)), + conversationLocks: newConversationLockSet(), + runtimes: newRuntimePool(t, opts.MaxIdleRuntimes), }, opts), ReadHeaderTimeout: 30 * time.Second, } @@ -180,10 +181,11 @@ func serve(ctx context.Context, httpServer *http.Server, ln net.Listener) error // runtime, so the only shared state is the team (whose toolsets are // independently safe to call) and the optional conversation cache. type server struct { - team *team.Team - policy agentPolicy - conversations *conversationStore - runtimes *runtimePool + team *team.Team + policy agentPolicy + conversations *conversationStore + conversationLocks *conversationLockSet + runtimes *runtimePool } func newRouter(s *server, opts Options) http.Handler { @@ -368,6 +370,11 @@ func (s *server) handleChatCompletions(c echo.Context) error { } conversationID := c.Request().Header.Get("X-Conversation-Id") + if !s.conversationLocks.tryAcquire(conversationID) { + return writeError(c, http.StatusConflict, "another request is already in flight for this conversation id") + } + defer s.conversationLocks.release(conversationID) + sess := s.resolveSession(conversationID, req.Messages) if sess == nil { return writeError(c, http.StatusBadRequest, "no user message provided") diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go index d950e3e08..e4df5074b 100644 --- a/pkg/chatserver/server_test.go +++ b/pkg/chatserver/server_test.go @@ -280,6 +280,26 @@ func TestBearerAuthMiddleware(t *testing.T) { } } +func TestHandleChatCompletions_RejectsConcurrentSameConversation(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{}) + + // Pre-acquire the conversation lock to simulate an in-flight + // request. The next request with the same id must get 409. + require.True(t, srv.conversationLocks.tryAcquire("conv-x")) + defer srv.conversationLocks.release("conv-x") + + body := `{"messages":[{"role":"user","content":"hi"}]}` + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Conversation-Id", "conv-x") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusConflict, rec.Code) + assert.Contains(t, rec.Body.String(), "another request is already in flight") +} + func TestBearerAuthMiddleware_AllowsCORSPreflight(t *testing.T) { // CORS preflight must succeed without an Authorization header. srv, _ := newTestServer("root")