Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions cmd/root/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package root

import (
"github.com/spf13/cobra"

"github.com/docker/docker-agent/pkg/chatserver"
"github.com/docker/docker-agent/pkg/cli"
"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/telemetry"
)

// chatFlags holds the command-line options of the "chat" command. The
// fields are bound to flags in newChatCmd and read by runChatCommand.
type chatFlags struct {
// agentName is the value of --agent/-a; it defaults to the empty string.
agentName string
// listenAddr is the value of --listen/-l, the address handed to
// newListener.
listenAddr string
// runConfig is populated by the flags that addRuntimeConfigFlags
// registers and is passed on to chatserver.Run.
runConfig config.RuntimeConfig
}

// newChatCmd builds the cobra command for "chat" (shown in the examples as
// `docker-agent serve chat`). The command takes exactly one positional
// argument, an agent file or a registry reference, and runs
// chatFlags.runChatCommand. It registers --agent/-a, --listen/-l and
// whatever flags addRuntimeConfigFlags adds.
func newChatCmd() *cobra.Command {
// flags backs every flag registered below; its method value is used as
// RunE, so the handler reads the values bound here.
var flags chatFlags

cmd := &cobra.Command{
Use: "chat <agent-file>|<registry-ref>",
Short: "Start an agent as an OpenAI-compatible chat completions server",
Long: `Start an HTTP server that exposes the agent through an OpenAI-compatible
API at /v1/chat/completions and /v1/models. This lets tools that already
speak OpenAI's chat protocol (such as Open WebUI) drive a docker-agent
agent without any custom integration.`,
Example: ` docker-agent serve chat ./agent.yaml
docker-agent serve chat ./team.yaml --agent reviewer
docker-agent serve chat agentcatalog/pirate --listen 127.0.0.1:9090`,
// Exactly one argument: runChatCommand reads args[0] unconditionally.
Args: cobra.ExactArgs(1),
RunE: flags.runChatCommand,
}

// An empty --agent means "expose all agents" (see the flag help text).
cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)")
// The default listen address is 127.0.0.1:8083.
cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on")
addRuntimeConfigFlags(cmd, &flags.runConfig)

return cmd
}

// runChatCommand is the RunE handler of the "chat" command. It reports the
// invocation to telemetry, opens a listener on the configured address,
// prints where the server can be reached and then delegates to
// chatserver.Run, returning whatever that call returns.
//
// The result is a named return value so that the deferred telemetry call
// sees the error that is actually returned to the caller.
func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandErr error) {
	ctx := cmd.Context()

	telemetry.TrackCommand(ctx, "serve", append([]string{"chat"}, args...))
	// Keep this as a closure: commandErr must be read when the function
	// exits, not when the defer statement is evaluated.
	defer func() {
		telemetry.TrackCommandError(ctx, "serve", append([]string{"chat"}, args...), commandErr)
	}()

	printer := cli.NewPrinter(cmd.OutOrStdout())
	agentRef := args[0]

	listener, closeListener, err := newListener(ctx, f.listenAddr)
	if err != nil {
		return err
	}
	defer closeListener()

	// Report the address actually bound by the listener.
	printer.Println("Listening on", listener.Addr().String())
	printer.Println("OpenAI-compatible chat completions endpoint: http://" + listener.Addr().String() + "/v1/chat/completions")

	return chatserver.Run(ctx, agentRef, f.agentName, &f.runConfig, listener)
}
3 changes: 2 additions & 1 deletion cmd/root/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ func newServeCmd() *cobra.Command {

cmd.AddCommand(newA2ACmd())
cmd.AddCommand(newACPCmd())
cmd.AddCommand(newMCPCmd())
cmd.AddCommand(newAPICmd())
cmd.AddCommand(newChatCmd())
cmd.AddCommand(newMCPCmd())

return cmd
}
5 changes: 3 additions & 2 deletions e2e/binary/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ func TestAutoComplete(t *testing.T) {
res, err := Exec(binDir+"/docker-agent", "__complete", "serve", "")
require.NoError(t, err)
props := lines(res.Stdout)
require.Greater(t, len(props), 4)
require.Greater(t, len(props), 5)
require.Contains(t, props[0], "a2a")
require.Contains(t, props[0], "Start an agent as an A2A")
require.Contains(t, props[1], "acp")
require.Contains(t, props[2], "api")
require.Contains(t, props[3], "mcp")
require.Contains(t, props[3], "chat")
require.Contains(t, props[4], "mcp")
})

t.Run("cli plugin auto-complete docker agent", func(t *testing.T) {
Expand Down
160 changes: 160 additions & 0 deletions pkg/chatserver/agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package chatserver

import (
"context"
"errors"
"fmt"
"slices"
"strings"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/session"
"github.com/docker/docker-agent/pkg/team"
"github.com/docker/docker-agent/pkg/tools"
)

// agentPolicy decides which agent in a team is exposed by the server and
// which one to run for a given request. It is built once at startup and is
// read-only thereafter, so it's safe to share across goroutines.
//
// Values are produced by newAgentPolicy and queried through pick.
type agentPolicy struct {
// exposed is the list of agent names advertised on /v1/models. It holds
// either the single requested agent or every agent name of the team.
exposed []string
// fallback is used when the request's "model" field doesn't match any
// exposed agent (so we don't fail when clients hard-code "gpt-4"). It is
// the requested agent when one was given, otherwise the team's default
// agent.
fallback string
}

// newAgentPolicy builds the agent selection policy for a team.
//
// With an empty agentName every agent of the team is exposed and the
// team's default agent becomes the fallback; an error from resolving the
// default agent is wrapped and returned. With a non-empty agentName that
// agent must be one of the team's agents, otherwise an error is returned;
// it is then both the only exposed agent and the fallback.
func newAgentPolicy(t *team.Team, agentName string) (agentPolicy, error) {
	if agentName == "" {
		defaultAgent, err := t.DefaultAgent()
		if err != nil {
			return agentPolicy{}, fmt.Errorf("resolving default agent: %w", err)
		}
		return agentPolicy{exposed: t.AgentNames(), fallback: defaultAgent.Name()}, nil
	}

	if known := slices.Contains(t.AgentNames(), agentName); !known {
		return agentPolicy{}, fmt.Errorf("agent %q not found", agentName)
	}
	return agentPolicy{exposed: []string{agentName}, fallback: agentName}, nil
}

// pick resolves the "model" field of a request to an agent name. A model
// that names one of the exposed agents is returned unchanged; an empty or
// unknown model silently resolves to the fallback agent instead of being
// rejected, so clients that hard-code some other model string still work.
func (p agentPolicy) pick(model string) string {
	if model == "" || !slices.Contains(p.exposed, model) {
		return p.fallback
	}
	return model
}

// buildSession turns an OpenAI-style message history into a docker-agent
// session, replaying the messages in order:
//
//   - messages whose content is blank after trimming are dropped;
//   - "system" messages are added as system messages;
//   - "assistant" and "tool" messages are replayed with their content
//     (tool messages also carry their tool call ID);
//   - every other role (user, developer, anything unknown) is added as a
//     user message.
//
// Roles are matched case-insensitively and ignoring surrounding spaces.
//
// The session is created with tools pre-approved and in non-interactive
// mode: this is a headless HTTP endpoint with no human available to
// approve anything.
//
// It returns nil when no message ended up being treated as user input, in
// which case the caller should reject the request.
func buildSession(messages []ChatCompletionMessage) *session.Session {
	sess := session.New(
		session.WithToolsApproved(true),
		session.WithNonInteractive(true),
	)

	sawUserInput := false
	for _, msg := range messages {
		text := msg.Content
		if strings.TrimSpace(text) == "" {
			continue
		}

		role := strings.ToLower(strings.TrimSpace(msg.Role))
		if role == "system" {
			sess.AddMessage(session.SystemMessage(text))
			continue
		}
		if role == "assistant" {
			sess.AddMessage(&session.Message{Message: chat.Message{
				Role:    chat.MessageRoleAssistant,
				Content: text,
			}})
			continue
		}
		if role == "tool" {
			sess.AddMessage(&session.Message{Message: chat.Message{
				Role:       chat.MessageRoleTool,
				Content:    text,
				ToolCallID: msg.ToolCallID,
			}})
			continue
		}

		// Anything else is handed to the agent as user input rather than
		// causing the request to be rejected.
		sess.AddMessage(session.UserMessage(text))
		sawUserInput = true
	}

	if sawUserInput {
		return sess
	}
	return nil
}

// runAgentLoop consumes every event of rt.RunStream for the session and
// returns the first error event it saw, or nil if there was none.
//
// Assistant content is passed to emit as it arrives; emit may be nil (the
// non-streaming case), in which case the content is not forwarded.
//
// Sessions are expected to come from buildSession, i.e. with tools
// approved and non-interactive mode on, so tool confirmations and
// max-iteration prompts should already be handled by the runtime. They are
// answered here anyway as defence in depth, so that a drift in those
// session settings cannot leave the request hanging. Elicitation requests
// are different: the runtime waits for an answer regardless of
// non-interactive mode, so declining them is required for correctness.
//
// After an error event the loop keeps reading so the stream is drained and
// the runtime can shut down cleanly.
func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit func(string)) error {
	var firstErr error

	for event := range rt.RunStream(ctx, sess) {
		switch typed := event.(type) {
		case *runtime.ErrorEvent:
			// Only the first failure is reported to the caller.
			if firstErr == nil {
				firstErr = errors.New(typed.Error)
			}
		case *runtime.AgentChoiceEvent:
			if emit != nil {
				emit(typed.Content)
			}
		case *runtime.ElicitationRequestEvent:
			// Mandatory: decline so the waiting tool call fails fast.
			_ = rt.ResumeElicitation(ctx, tools.ElicitationActionDecline, nil)
		case *runtime.ToolCallConfirmationEvent:
			// Not expected while tools are pre-approved; approve anyway.
			rt.Resume(ctx, runtime.ResumeApprove())
		case *runtime.MaxIterationsReachedEvent:
			// Not expected to matter in non-interactive mode, where the
			// runtime stops by itself; reject so nothing keeps waiting.
			rt.Resume(ctx, runtime.ResumeReject(""))
		}
	}

	return firstErr
}

// sessionUsage reports the token counters recorded on a session as an
// OpenAI-style usage object. When both the input and the output counters
// are zero it returns nil, so the caller can leave the field out instead
// of reporting zeroes.
func sessionUsage(sess *session.Session) *ChatCompletionUsage {
	if sess.InputTokens != 0 || sess.OutputTokens != 0 {
		usage := ChatCompletionUsage{
			PromptTokens:     sess.InputTokens,
			CompletionTokens: sess.OutputTokens,
			TotalTokens:      sess.InputTokens + sess.OutputTokens,
		}
		return &usage
	}
	return nil
}
124 changes: 124 additions & 0 deletions pkg/chatserver/handlers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package chatserver

import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/labstack/echo/v4"
"github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newTestServer returns a server that only has its agent policy set,
// together with a fresh echo instance. The given names become the exposed
// agents and the first one is the fallback; with no names, a single
// "root" agent is used. Per the original author's note, handlers that
// touch s.team will panic on such a server; those paths are left to
// integration tests.
func newTestServer(exposed ...string) (*server, *echo.Echo) {
	names := exposed
	if len(names) == 0 {
		names = []string{"root"}
	}
	policy := agentPolicy{exposed: names, fallback: names[0]}
	return &server{policy: policy}, echo.New()
}

// TestHandleModels checks that handleModels answers 200 with a "list"
// object containing one entry per exposed agent, each owned by
// "docker-agent" and carrying the default openai model object tag.
func TestHandleModels(t *testing.T) {
	srv, e := newTestServer("root", "reviewer")

	rec := httptest.NewRecorder()
	c := e.NewContext(httptest.NewRequest(http.MethodGet, "/v1/models", http.NoBody), rec)

	require.NoError(t, srv.handleModels(c))
	require.Equal(t, http.StatusOK, rec.Code)

	var resp ModelsResponse
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	assert.Equal(t, "list", resp.Object)
	require.Len(t, resp.Data, 2)

	// Order is not part of the contract, so compare the IDs as a set.
	assert.ElementsMatch(t, []string{"root", "reviewer"}, []string{resp.Data[0].ID, resp.Data[1].ID})

	for _, model := range resp.Data {
		assert.Equal(t, "docker-agent", model.OwnedBy)
		// openai.Model has a typed `Object constant.Model` field that
		// always serialises to "model"; make sure the wire shape is stable.
		assert.Equal(t, openai.Model{}.Object.Default(), model.Object)
	}
}

// TestHandleChatCompletions_RejectsBadJSON checks that a body which is not
// valid JSON is answered with 400 and an invalid_request_error payload.
func TestHandleChatCompletions_RejectsBadJSON(t *testing.T) {
	srv, e := newTestServer()

	payload := strings.NewReader("not json")
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	err := srv.handleChatCompletions(e.NewContext(req, rec))
	require.NoError(t, err)

	assert.Equal(t, http.StatusBadRequest, rec.Code)
	assert.Contains(t, rec.Body.String(), "invalid_request_error")
}

// TestHandleChatCompletions_RejectsEmptyMessages checks that a request
// with an empty "messages" array is answered with 400 and an
// invalid_request_error whose message asks for at least one message.
func TestHandleChatCompletions_RejectsEmptyMessages(t *testing.T) {
	srv, e := newTestServer()

	payload := strings.NewReader(`{"messages":[]}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	err := srv.handleChatCompletions(e.NewContext(req, rec))
	require.NoError(t, err)
	assert.Equal(t, http.StatusBadRequest, rec.Code)

	var resp ErrorResponse
	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
	assert.Equal(t, "invalid_request_error", resp.Error.Type)
	assert.Contains(t, resp.Error.Message, "at least one message")
}

// TestHandleChatCompletions_RejectsHistoryWithoutUser checks that a
// history made only of a system message is answered with 400 and a
// "no user message" explanation.
func TestHandleChatCompletions_RejectsHistoryWithoutUser(t *testing.T) {
	srv, e := newTestServer()

	payload := strings.NewReader(`{"messages":[{"role":"system","content":"be helpful"}]}`)
	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	err := srv.handleChatCompletions(e.NewContext(req, rec))
	require.NoError(t, err)

	assert.Equal(t, http.StatusBadRequest, rec.Code)
	assert.Contains(t, rec.Body.String(), "no user message")
}

// TestWriteError_ShapeAndType checks that writeError responds with the
// given status, echoes the message in the error payload and uses
// "invalid_request_error" for a 400 and "internal_error" for a 500.
func TestWriteError_ShapeAndType(t *testing.T) {
	type testCase struct {
		name     string
		status   int
		message  string
		wantType string
	}

	cases := []testCase{
		{
			name:     "client error",
			status:   http.StatusBadRequest,
			message:  "bad input",
			wantType: "invalid_request_error",
		},
		{
			name:     "server error",
			status:   http.StatusInternalServerError,
			message:  "boom",
			wantType: "internal_error",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			e := echo.New()
			rec := httptest.NewRecorder()
			c := e.NewContext(httptest.NewRequest(http.MethodGet, "/", http.NoBody), rec)

			require.NoError(t, writeError(c, tc.status, tc.message))
			assert.Equal(t, tc.status, rec.Code)

			var resp ErrorResponse
			require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
			assert.Equal(t, tc.message, resp.Error.Message)
			assert.Equal(t, tc.wantType, resp.Error.Type)
		})
	}
}
Loading
Loading