diff --git a/cmd/root/chat.go b/cmd/root/chat.go new file mode 100644 index 000000000..ec1b0c23c --- /dev/null +++ b/cmd/root/chat.go @@ -0,0 +1,98 @@ +package root + +import ( + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/docker/docker-agent/pkg/chatserver" + "github.com/docker/docker-agent/pkg/cli" + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/telemetry" +) + +type chatFlags struct { + agentName string + listenAddr string + corsOrigin string + apiKey string + apiKeyEnv string + maxRequestSize int64 + requestTimeout time.Duration + conversationsMaxItems int + conversationTTL time.Duration + maxIdleRuntimes int + runConfig config.RuntimeConfig +} + +func newChatCmd() *cobra.Command { + var flags chatFlags + + cmd := &cobra.Command{ + Use: "chat <agent-file>|<agent-ref>", + Short: "Start an agent as an OpenAI-compatible chat completions server", + Long: `Start an HTTP server that exposes the agent through an OpenAI-compatible +API at /v1/chat/completions and /v1/models. This lets tools that already +speak OpenAI's chat protocol (such as Open WebUI) drive a docker-agent +agent without any custom integration.`, + Example: ` docker-agent serve chat ./agent.yaml + docker-agent serve chat ./team.yaml --agent reviewer + docker-agent serve chat agentcatalog/pirate --listen 127.0.0.1:9090`, + Args: cobra.ExactArgs(1), + RunE: flags.runChatCommand, + } + + cmd.Flags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to expose (all agents if not specified)") + cmd.Flags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8083", "Address to listen on") + cmd.Flags().StringVar(&flags.corsOrigin, "cors-origin", "", "Allowed CORS origin (e.g. 
https://example.com); empty disables CORS entirely") + cmd.Flags().StringVar(&flags.apiKey, "api-key", "", "Required Bearer token clients must present (Authorization: Bearer <token>); empty disables auth") + cmd.Flags().StringVar(&flags.apiKeyEnv, "api-key-env", "", "Read the API key from this environment variable instead of the command line") + cmd.Flags().Int64Var(&flags.maxRequestSize, "max-request-size", 1<<20, "Maximum request body size in bytes (default 1 MiB)") + cmd.Flags().DurationVar(&flags.requestTimeout, "request-timeout", 5*time.Minute, "Per-request timeout (covers model + tool calls + streaming)") + cmd.Flags().IntVar(&flags.conversationsMaxItems, "conversations-max", 0, "Cache up to N conversations server-side, keyed by X-Conversation-Id (0 disables; clients must resend full history)") + cmd.Flags().DurationVar(&flags.conversationTTL, "conversation-ttl", 30*time.Minute, "Idle TTL after which a cached conversation is evicted") + cmd.Flags().IntVar(&flags.maxIdleRuntimes, "max-idle-runtimes", 4, "Maximum number of idle runtimes pooled per agent (0 disables pooling)") + addRuntimeConfigFlags(cmd, &flags.runConfig) + + return cmd +} + +func (f *chatFlags) runChatCommand(cmd *cobra.Command, args []string) (commandErr error) { + ctx := cmd.Context() + telemetry.TrackCommand(ctx, "serve", append([]string{"chat"}, args...)) + defer func() { // do not inline this defer so that commandErr is not resolved early + telemetry.TrackCommandError(ctx, "serve", append([]string{"chat"}, args...), commandErr) + }() + + out := cli.NewPrinter(cmd.OutOrStdout()) + agentFilename := args[0] + + ln, cleanup, err := newListener(ctx, f.listenAddr) + if err != nil { + return err + } + defer cleanup() + + out.Println("Listening on", ln.Addr().String()) + out.Println("OpenAI-compatible chat completions endpoint: http://" + ln.Addr().String() + "/v1/chat/completions") + + apiKey := f.apiKey + if f.apiKeyEnv != "" { + if v := os.Getenv(f.apiKeyEnv); v != "" { + apiKey = v + } + } + + 
return chatserver.Run(ctx, agentFilename, chatserver.Options{ + AgentName: f.agentName, + RunConfig: &f.runConfig, + CORSOrigin: f.corsOrigin, + APIKey: apiKey, + MaxRequestBytes: f.maxRequestSize, + RequestTimeout: f.requestTimeout, + ConversationsMaxSessions: f.conversationsMaxItems, + ConversationTTL: f.conversationTTL, + MaxIdleRuntimes: f.maxIdleRuntimes, + }, ln) +} diff --git a/cmd/root/serve.go b/cmd/root/serve.go index df541cecd..13a10d836 100644 --- a/cmd/root/serve.go +++ b/cmd/root/serve.go @@ -13,8 +13,9 @@ func newServeCmd() *cobra.Command { cmd.AddCommand(newA2ACmd()) cmd.AddCommand(newACPCmd()) - cmd.AddCommand(newMCPCmd()) cmd.AddCommand(newAPICmd()) + cmd.AddCommand(newChatCmd()) + cmd.AddCommand(newMCPCmd()) return cmd } diff --git a/e2e/binary/binary_test.go b/e2e/binary/binary_test.go index 1c8c823c2..4d7130404 100644 --- a/e2e/binary/binary_test.go +++ b/e2e/binary/binary_test.go @@ -54,12 +54,13 @@ func TestAutoComplete(t *testing.T) { res, err := Exec(binDir+"/docker-agent", "__complete", "serve", "") require.NoError(t, err) props := lines(res.Stdout) - require.Greater(t, len(props), 4) + require.Greater(t, len(props), 5) require.Contains(t, props[0], "a2a") require.Contains(t, props[0], "Start an agent as an A2A") require.Contains(t, props[1], "acp") require.Contains(t, props[2], "api") - require.Contains(t, props[3], "mcp") + require.Contains(t, props[3], "chat") + require.Contains(t, props[4], "mcp") }) t.Run("cli plugin auto-complete docker agent", func(t *testing.T) { diff --git a/examples/chat/main.go b/examples/chat/main.go new file mode 100644 index 000000000..5efba5bcd --- /dev/null +++ b/examples/chat/main.go @@ -0,0 +1,174 @@ +// A very, very basic chat client for `docker agent serve chat`. +// +// PR #2510 (`feat: add docker agent serve chat command`) exposes any +// docker-agent agent through an OpenAI-compatible HTTP server. 
The whole +// point of that feature is that any tool already speaking OpenAI's +// /v1/chat/completions protocol can drive a docker-agent agent without +// custom integration. This example demonstrates exactly that: it uses the +// official github.com/openai/openai-go SDK, only repointed at the local +// chat server, to run an interactive REPL against an agent. +// +// Prerequisites: +// +// # Start an agent in chat mode (in another terminal): +// ./bin/docker-agent serve chat ./examples/42.yaml +// # It listens on http://127.0.0.1:8083 by default. +// +// Then run this client: +// +// go run ./examples/chat +// # or, to pin a specific agent in a multi-agent team: +// go run ./examples/chat -model root +// # or, to point at a different server: +// go run ./examples/chat -base http://127.0.0.1:9090/v1 +// +// Type a message and press <Enter>. Type "exit" (or send EOF with ^D) to +// quit. +package main + +import ( + "bufio" + "context" + "errors" + "flag" + "fmt" + "io" + "log" + "os" + "os/signal" + "strings" + "syscall" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" +) + +func main() { + baseURL := flag.String("base", "http://127.0.0.1:8083/v1", "Base URL of the docker-agent chat server") + model := flag.String("model", "", "Agent name to talk to (defaults to the team's default agent)") + stream := flag.Bool("stream", true, "Stream the agent's response token-by-token") + flag.Parse() + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + err := run(ctx, *baseURL, *model, *stream) + cancel() + if err != nil && !errors.Is(err, context.Canceled) { + log.Fatal(err) + } +} + +func run(ctx context.Context, baseURL, model string, stream bool) error { + // The chat server doesn't validate API keys, but the OpenAI SDK + // requires *some* string to be passed. 
+ client := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithAPIKey("not-needed"), + ) + + // Ask the server which agents are exposed and pick a default model + // when the user didn't pin one. This also doubles as a connectivity + // check. + if model == "" { + picked, err := pickDefaultModel(ctx, &client) + if err != nil { + return fmt.Errorf("listing models: %w", err) + } + model = picked + } + fmt.Printf("Connected to %s — chatting with %q. Type \"exit\" to quit.\n", baseURL, model) + + // History keeps the conversation going across turns. The chat server + // is stateless: it builds a fresh session per request from whatever + // messages the client sends, so it's the client's job to remember. + var history []openai.ChatCompletionMessageParamUnion + + in := bufio.NewScanner(os.Stdin) + in.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for { + fmt.Print("\n> ") + if !in.Scan() { + if err := in.Err(); err != nil { + return err + } + fmt.Println() + return nil // EOF + } + userInput := strings.TrimSpace(in.Text()) + if userInput == "" { + continue + } + if userInput == "exit" || userInput == "quit" { + return nil + } + + history = append(history, openai.UserMessage(userInput)) + + reply, err := chat(ctx, &client, model, history, stream) + if err != nil { + return err + } + history = append(history, openai.AssistantMessage(reply)) + } +} + +// pickDefaultModel queries /v1/models and returns the first agent name +// the server advertises. +func pickDefaultModel(ctx context.Context, client *openai.Client) (string, error) { + page, err := client.Models.List(ctx) + if err != nil { + return "", err + } + if len(page.Data) == 0 { + return "", errors.New("server exposes no models") + } + return page.Data[0].ID, nil +} + +// chat sends the conversation to the server, prints the assistant's reply +// to stdout (streaming if requested) and returns the final assembled +// content so the caller can append it to the history. 
+func chat( + ctx context.Context, + client *openai.Client, + model string, + history []openai.ChatCompletionMessageParamUnion, + stream bool, +) (string, error) { + params := openai.ChatCompletionNewParams{ + Model: model, + Messages: history, + } + + if !stream { + resp, err := client.Chat.Completions.New(ctx, params) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", errors.New("server returned no choices") + } + content := resp.Choices[0].Message.Content + fmt.Println(content) + return content, nil + } + + s := client.Chat.Completions.NewStreaming(ctx, params) + var b strings.Builder + for s.Next() { + chunk := s.Current() + if len(chunk.Choices) == 0 { + continue + } + delta := chunk.Choices[0].Delta.Content + if delta == "" { + continue + } + fmt.Print(delta) + b.WriteString(delta) + } + if err := s.Err(); err != nil && !errors.Is(err, io.EOF) { + return "", err + } + fmt.Println() + return b.String(), nil +} diff --git a/pkg/chatserver/agent.go b/pkg/chatserver/agent.go new file mode 100644 index 000000000..752286cdf --- /dev/null +++ b/pkg/chatserver/agent.go @@ -0,0 +1,262 @@ +package chatserver + +import ( + "context" + "errors" + "fmt" + "slices" + "strings" + + "github.com/docker/docker-agent/pkg/chat" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/tools" +) + +// agentPolicy decides which agent in a team is exposed by the server and +// which one to run for a given request. It is built once at startup and is +// read-only thereafter, so it's safe to share across goroutines. +type agentPolicy struct { + // exposed is the list of agent names advertised on /v1/models. + exposed []string + // fallback is used when the request's "model" field doesn't match any + // exposed agent (so we don't fail when clients hard-code "gpt-4"). 
+ fallback string +} + +// newAgentPolicy validates the requested agent name against the team and +// returns the selection policy. If agentName is empty, every agent in the +// team is exposed and the team's default agent is used as fallback. +// Otherwise only that one agent is exposed and used. +func newAgentPolicy(t *team.Team, agentName string) (agentPolicy, error) { + if agentName != "" { + if !slices.Contains(t.AgentNames(), agentName) { + return agentPolicy{}, fmt.Errorf("agent %q not found", agentName) + } + return agentPolicy{exposed: []string{agentName}, fallback: agentName}, nil + } + a, err := t.DefaultAgent() + if err != nil { + return agentPolicy{}, fmt.Errorf("resolving default agent: %w", err) + } + return agentPolicy{exposed: t.AgentNames(), fallback: a.Name()}, nil +} + +// pick returns the agent name to use for a request. The "model" field is +// honoured when it matches an exposed agent; otherwise we silently fall +// back, mirroring how OpenAI's API behaves with unknown model strings on +// some compatible servers. +func (p agentPolicy) pick(model string) string { + if model != "" && slices.Contains(p.exposed, model) { + return model + } + return p.fallback +} + +// buildSession converts an OpenAI-style message history into a docker-agent +// session. System messages are added as system context, prior user/ +// assistant/tool turns are replayed verbatim so the agent sees the full +// conversation, and the latest user message becomes the prompt. +// +// Tool approval and non-interactive mode are forced on: this is a headless +// HTTP endpoint, there's no human in the loop to approve anything. +// +// Returns nil when the history contains no usable user message, in which +// case the caller should reject the request. 
+func buildSession(messages []ChatCompletionMessage) *session.Session { + sess := session.New( + session.WithToolsApproved(true), + session.WithNonInteractive(true), + ) + + hasUser := false + for _, m := range messages { + role := strings.ToLower(strings.TrimSpace(m.Role)) + if len(m.Parts) > 0 && (role == "" || (role != "system" && role != "assistant" && role != "tool")) { + // Multi-part content: route through chat.MultiContent so the + // runtime/provider sees image parts. Only user-style messages + // support images today. + parts := convertParts(m.Parts) + if len(parts) == 0 { + continue + } + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: m.Content, + MultiContent: parts, + }}) + hasUser = true + continue + } + + content := m.Content + if strings.TrimSpace(content) == "" { + continue + } + switch role { + case "system": + sess.AddMessage(session.SystemMessage(content)) + case "assistant": + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleAssistant, + Content: content, + }}) + case "tool": + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleTool, + Content: content, + ToolCallID: m.ToolCallID, + }}) + default: + // user, developer, or any other role: feed it to the agent + // as user input rather than rejecting the request. + sess.AddMessage(session.UserMessage(content)) + hasUser = true + } + } + + if !hasUser { + return nil + } + return sess +} + +// convertParts maps the chatserver wire shape to chat.MessagePart so +// images and (future) other typed parts reach the runtime intact. +// Unknown part types are dropped; an empty result tells the caller to +// skip the message entirely. 
+func convertParts(in []ContentPart) []chat.MessagePart { + out := make([]chat.MessagePart, 0, len(in)) + for _, p := range in { + switch p.Type { + case "text": + if strings.TrimSpace(p.Text) == "" { + continue + } + out = append(out, chat.MessagePart{ + Type: chat.MessagePartTypeText, + Text: p.Text, + }) + case "image_url": + if p.ImageURL == nil || p.ImageURL.URL == "" { + continue + } + out = append(out, chat.MessagePart{ + Type: chat.MessagePartTypeImageURL, + ImageURL: &chat.MessageImageURL{ + URL: p.ImageURL.URL, + Detail: chat.ImageURLDetail(p.ImageURL.Detail), + }, + }) + } + } + return out +} + +// appendLatestUser walks msgs from the end and appends only the last +// user-role message into sess. Used by conversation continuation, where +// the session already contains the full prior history and we just need +// to inject what the client just said. +func appendLatestUser(sess *session.Session, msgs []ChatCompletionMessage) { + for i := len(msgs) - 1; i >= 0; i-- { + m := msgs[i] + role := strings.ToLower(strings.TrimSpace(m.Role)) + // Treat any non-system/assistant/tool role as user (matches + // buildSession's policy). + if role == "system" || role == "assistant" || role == "tool" { + continue + } + parts := convertParts(m.Parts) + if len(parts) > 0 { + sess.AddMessage(&session.Message{Message: chat.Message{ + Role: chat.MessageRoleUser, + Content: m.Content, + MultiContent: parts, + }}) + return + } + content := strings.TrimSpace(m.Content) + if content == "" { + continue + } + sess.AddMessage(session.UserMessage(m.Content)) + return + } +} + +// agentEmit collects the side-effect callbacks invoked by runAgentLoop as +// it drives the runtime. All callbacks are optional; nil means "ignore +// this kind of event". +type agentEmit struct { + // onContent fires for every assistant text delta from the model. + onContent func(string) + // onToolCall fires when the agent dispatches a tool. 
Called once per + // tool, with the tool already populated with its arguments. + onToolCall func(ToolCallReference) +} + +// runAgentLoop drives the runtime to completion, forwarding events to +// the supplied callbacks. +// +// The session is built with ToolsApproved=true and NonInteractive=true, +// which means the runtime auto-approves tool calls and auto-stops on +// max-iterations. The handler cases below are intentionally kept as +// defence-in-depth: if those session settings ever drift, this handler +// still won't hang the request. Elicitation is the exception — the +// runtime always blocks until we respond, so its case is required for +// correctness, not just defence. +// +// All ErrorEvents seen in the run are joined into the returned error so +// callers can see the full picture; the loop keeps draining until the +// stream closes so the runtime can shut down cleanly. +func runAgentLoop(ctx context.Context, rt runtime.Runtime, sess *session.Session, emit agentEmit) error { + var runErrs []error + toolIndex := 0 + for ev := range rt.RunStream(ctx, sess) { + switch e := ev.(type) { + case *runtime.AgentChoiceEvent: + if emit.onContent != nil { + emit.onContent(e.Content) + } + case *runtime.ToolCallEvent: + if emit.onToolCall != nil { + emit.onToolCall(ToolCallReference{ + Index: toolIndex, + ID: e.ToolCall.ID, + Type: string(e.ToolCall.Type), + Function: ToolCallFunction{Name: e.ToolCall.Function.Name, Arguments: e.ToolCall.Function.Arguments}, + }) + toolIndex++ + } + case *runtime.ToolCallConfirmationEvent: + // Defensive: should never fire while ToolsApproved=true. + rt.Resume(ctx, runtime.ResumeApprove()) + case *runtime.ElicitationRequestEvent: + // Required: the runtime blocks until we respond, regardless + // of NonInteractive. Decline so the tool call fails fast. 
+ _ = rt.ResumeElicitation(ctx, tools.ElicitationActionDecline, nil) + case *runtime.MaxIterationsReachedEvent: + // Defensive: in non-interactive mode the runtime already + // stops on its own and this Resume is dropped. + rt.Resume(ctx, runtime.ResumeReject("")) + case *runtime.ErrorEvent: + runErrs = append(runErrs, errors.New(e.Error)) + } + } + return errors.Join(runErrs...) +} + +// sessionUsage extracts approximate token usage from a completed session, +// returning nil when nothing is known so we can omit the field entirely +// rather than reporting zeroes. +func sessionUsage(sess *session.Session) *ChatCompletionUsage { + if sess.InputTokens == 0 && sess.OutputTokens == 0 { + return nil + } + return &ChatCompletionUsage{ + PromptTokens: sess.InputTokens, + CompletionTokens: sess.OutputTokens, + TotalTokens: sess.InputTokens + sess.OutputTokens, + } +} diff --git a/pkg/chatserver/conversation_lock.go b/pkg/chatserver/conversation_lock.go new file mode 100644 index 000000000..c5c9a1141 --- /dev/null +++ b/pkg/chatserver/conversation_lock.go @@ -0,0 +1,48 @@ +package chatserver + +import "sync" + +// conversationLockSet ensures only one in-flight request at a time per +// conversation id. Concurrent requests sharing an id would otherwise share +// the same `*session.Session` (the cache hands out the same pointer to every +// caller for that id), and two concurrent runtime.RunStream calls on one +// session interleave message appends and produce garbled transcripts. +// +// We reject the second request with 409 Conflict instead of serialising it, +// for two reasons: it surfaces the misuse to the client immediately, and it +// keeps the handler's resource cost bounded (no queue, no waiting goroutines). 
+type conversationLockSet struct { + mu sync.Mutex + active map[string]struct{} +} + +func newConversationLockSet() *conversationLockSet { + return &conversationLockSet{active: make(map[string]struct{})} +} + +// tryAcquire returns true when id was not already in flight. The caller +// must call release when the request finishes. Empty id is a no-op (and +// returns true) so callers without a conversation id don't need a guard. +func (l *conversationLockSet) tryAcquire(id string) bool { + if l == nil || id == "" { + return true + } + l.mu.Lock() + defer l.mu.Unlock() + if _, ok := l.active[id]; ok { + return false + } + l.active[id] = struct{}{} + return true +} + +// release marks id as no longer in flight. Safe to call when id is the +// empty string or l is nil. +func (l *conversationLockSet) release(id string) { + if l == nil || id == "" { + return + } + l.mu.Lock() + delete(l.active, id) + l.mu.Unlock() +} diff --git a/pkg/chatserver/conversation_lock_test.go b/pkg/chatserver/conversation_lock_test.go new file mode 100644 index 000000000..039721ecf --- /dev/null +++ b/pkg/chatserver/conversation_lock_test.go @@ -0,0 +1,70 @@ +package chatserver + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConversationLockSet_AcquireRelease(t *testing.T) { + l := newConversationLockSet() + assert.True(t, l.tryAcquire("a"), "first acquire should succeed") + assert.False(t, l.tryAcquire("a"), "second acquire on the same id should fail") + l.release("a") + assert.True(t, l.tryAcquire("a"), "acquire after release should succeed") + l.release("a") +} + +func TestConversationLockSet_DifferentIDsDontBlock(t *testing.T) { + l := newConversationLockSet() + assert.True(t, l.tryAcquire("a")) + assert.True(t, l.tryAcquire("b"), "different ids should not block each other") + l.release("a") + l.release("b") +} + +func TestConversationLockSet_EmptyIDIsNoop(t *testing.T) { + l := newConversationLockSet() + // Empty id is the "no 
conversation" path: tryAcquire must always + // succeed and release must be safe. + assert.True(t, l.tryAcquire("")) + assert.True(t, l.tryAcquire("")) + l.release("") +} + +func TestConversationLockSet_NilIsNoop(t *testing.T) { + var l *conversationLockSet + assert.True(t, l.tryAcquire("a")) + l.release("a") // must not panic +} + +func TestConversationLockSet_RaceFreeUnderConcurrency(t *testing.T) { + // Run the race detector over a hot loop. The lock set's invariant — + // "at most one acquired ID at a time" — must hold. + l := newConversationLockSet() + const goroutines = 50 + const iters = 200 + + var maxConcurrent int32 + var current int32 + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + for range iters { + if l.tryAcquire("hot") { + n := atomic.AddInt32(¤t, 1) + if n > atomic.LoadInt32(&maxConcurrent) { + atomic.StoreInt32(&maxConcurrent, n) + } + atomic.AddInt32(¤t, -1) + l.release("hot") + } + } + }) + } + wg.Wait() + assert.LessOrEqual(t, atomic.LoadInt32(&maxConcurrent), int32(1), + "at most one holder of the same id at a time") +} diff --git a/pkg/chatserver/conversations.go b/pkg/chatserver/conversations.go new file mode 100644 index 000000000..68d28c052 --- /dev/null +++ b/pkg/chatserver/conversations.go @@ -0,0 +1,137 @@ +package chatserver + +import ( + "sync" + "time" + + "github.com/docker/docker-agent/pkg/session" +) + +// conversationStore keeps long-lived sessions keyed by the +// `X-Conversation-Id` header so clients don't have to resend the full +// conversation history on every turn. +// +// It's an LRU with a TTL: entries past `ttl` since their last use are +// considered expired and lazily evicted on Get. When the store would +// grow past `maxEntries`, the least-recently-used entry is evicted on +// Put. Both eviction paths are O(n) since this cache is small (typical +// `maxEntries` ≤ a few hundred); a doubly-linked-list LRU would be +// strictly faster but the extra code is rarely worth it. 
+// +// All operations are safe for concurrent use. +type conversationStore struct { + mu sync.Mutex + items map[string]*conversationEntry + maxEntries int + ttl time.Duration + now func() time.Time // injectable for tests +} + +type conversationEntry struct { + sess *session.Session + lastUsed time.Time +} + +// newConversationStore returns a store that holds at most maxEntries +// sessions and forgets entries that have been idle for more than ttl. +// Either bound can be zero/negative to disable that bound. A store with +// both bounds disabled is functionally a regular map; a store with +// maxEntries == 0 is disabled and Get always misses. +func newConversationStore(maxEntries int, ttl time.Duration) *conversationStore { + return &conversationStore{ + items: make(map[string]*conversationEntry), + maxEntries: maxEntries, + ttl: ttl, + now: time.Now, + } +} + +// Get returns the stored session for id and refreshes its last-used +// timestamp. Misses return nil. The store is disabled when maxEntries +// <= 0, in which case Get always misses. +func (c *conversationStore) Get(id string) *session.Session { + if c == nil || c.maxEntries <= 0 || id == "" { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + e, ok := c.items[id] + if !ok { + return nil + } + if c.expired(e) { + delete(c.items, id) + return nil + } + e.lastUsed = c.now() + return e.sess +} + +// Put stores sess under id and evicts the least-recently-used entry if +// the store is over capacity. Has no effect when the store is disabled. +func (c *conversationStore) Put(id string, sess *session.Session) { + if c == nil || c.maxEntries <= 0 || id == "" || sess == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + now := c.now() + c.items[id] = &conversationEntry{sess: sess, lastUsed: now} + + // Drop expired neighbours in the same critical section so callers + // don't accumulate dead weight on long-running stores. 
+ for k, v := range c.items { + if c.expired(v) { + delete(c.items, k) + } + } + for len(c.items) > c.maxEntries { + c.evictOldestLocked() + } +} + +// Delete removes id from the store, if present. Useful for clients that +// want to explicitly close out a conversation. +func (c *conversationStore) Delete(id string) { + if c == nil || id == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + delete(c.items, id) +} + +// Len returns the current number of cached conversations. Mostly useful +// for tests and metrics. +func (c *conversationStore) Len() int { + if c == nil { + return 0 + } + c.mu.Lock() + defer c.mu.Unlock() + return len(c.items) +} + +func (c *conversationStore) expired(e *conversationEntry) bool { + if c.ttl <= 0 { + return false + } + return c.now().Sub(e.lastUsed) > c.ttl +} + +// evictOldestLocked removes the oldest entry. Caller holds c.mu. +func (c *conversationStore) evictOldestLocked() { + var oldestKey string + var oldestTime time.Time + first := true + for k, v := range c.items { + if first || v.lastUsed.Before(oldestTime) { + oldestKey = k + oldestTime = v.lastUsed + first = false + } + } + if !first { + delete(c.items, oldestKey) + } +} diff --git a/pkg/chatserver/conversations_eviction_test.go b/pkg/chatserver/conversations_eviction_test.go new file mode 100644 index 000000000..59338368a --- /dev/null +++ b/pkg/chatserver/conversations_eviction_test.go @@ -0,0 +1,51 @@ +package chatserver + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/session" +) + +// TestConversationStore_RestoreAfterEviction tests that a conversation +// can be stored back after it's been evicted from the cache. 
+func TestConversationStore_RestoreAfterEviction(t *testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(2, time.Hour) + c.now = func() time.Time { return now } + + // Store a conversation + sess1 := session.New() + sess1.AddMessage(session.UserMessage("first")) + c.Put("conv-1", sess1) + + // Retrieve it (simulating a request starting) + retrieved := c.Get("conv-1") + require.NotNil(t, retrieved) + require.Same(t, sess1, retrieved) + + // Simulate the request processing (updating the session) + retrieved.AddMessage(session.UserMessage("updated")) + + // Manually evict the conversation (simulating LRU eviction) + c.mu.Lock() + delete(c.items, "conv-1") + c.mu.Unlock() + + // Verify it's gone + assert.Nil(t, c.Get("conv-1"), "conv-1 should be evicted") + + // Now the request completes and stores the updated session back + // This should work even though conv-1 was evicted + now = now.Add(time.Second) + c.Put("conv-1", retrieved) + + // Verify the updated session is stored + final := c.Get("conv-1") + require.NotNil(t, final) + assert.Same(t, retrieved, final) + assert.Equal(t, "updated", final.GetLastUserMessageContent()) +} diff --git a/pkg/chatserver/conversations_test.go b/pkg/chatserver/conversations_test.go new file mode 100644 index 000000000..038cc0758 --- /dev/null +++ b/pkg/chatserver/conversations_test.go @@ -0,0 +1,90 @@ +package chatserver + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/session" +) + +func TestConversationStore_Disabled(t *testing.T) { + c := newConversationStore(0, time.Hour) + c.Put("a", session.New()) + assert.Nil(t, c.Get("a")) + assert.Equal(t, 0, c.Len()) +} + +func TestConversationStore_PutGet(t *testing.T) { + c := newConversationStore(8, time.Hour) + s := session.New() + c.Put("a", s) + + got := c.Get("a") + require.NotNil(t, got) + assert.Same(t, s, got) +} + +func TestConversationStore_TTL(t 
*testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(8, time.Minute) + c.now = func() time.Time { return now } + + c.Put("a", session.New()) + assert.NotNil(t, c.Get("a")) + + now = now.Add(2 * time.Minute) + assert.Nil(t, c.Get("a"), "entry should be expired") + assert.Equal(t, 0, c.Len(), "expired entry should be evicted on Get miss") +} + +func TestConversationStore_LRUEviction(t *testing.T) { + now := time.Unix(1_000_000, 0) + c := newConversationStore(2, time.Hour) + c.now = func() time.Time { return now } + + c.Put("a", session.New()) + now = now.Add(time.Second) + c.Put("b", session.New()) + now = now.Add(time.Second) + // Touch "a" so it becomes the most-recently-used. + require.NotNil(t, c.Get("a")) + now = now.Add(time.Second) + c.Put("c", session.New()) + + // "b" was the LRU when capacity was exceeded, so it should be the + // one that got evicted. + assert.Nil(t, c.Get("b")) + assert.NotNil(t, c.Get("a")) + assert.NotNil(t, c.Get("c")) +} + +func TestConversationStore_Delete(t *testing.T) { + c := newConversationStore(8, time.Hour) + c.Put("a", session.New()) + c.Delete("a") + assert.Nil(t, c.Get("a")) +} + +func TestAppendLatestUser(t *testing.T) { + sess := session.New() + appendLatestUser(sess, []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + {Role: "user", Content: "first"}, + {Role: "assistant", Content: "ack"}, + {Role: "user", Content: "second"}, + {Role: "tool", Content: "tool result", ToolCallID: "x"}, + }) + assert.Equal(t, "second", sess.GetLastUserMessageContent()) +} + +func TestAppendLatestUser_NoUserMessage(t *testing.T) { + sess := session.New() + appendLatestUser(sess, []ChatCompletionMessage{ + {Role: "system", Content: "be helpful"}, + {Role: "assistant", Content: "ack"}, + }) + assert.Empty(t, sess.GetLastUserMessageContent()) +} diff --git a/pkg/chatserver/handlers_test.go b/pkg/chatserver/handlers_test.go new file mode 100644 index 000000000..ef1f6144d --- /dev/null +++ 
b/pkg/chatserver/handlers_test.go @@ -0,0 +1,127 @@ +package chatserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestServer builds a server with a fake policy for tests that don't run +// the agent loop. Handlers that touch s.team will panic — those code paths +// are exercised by integration tests, not here. +func newTestServer(exposed ...string) (*server, *echo.Echo) { + if len(exposed) == 0 { + exposed = []string{"root"} + } + srv := &server{ + policy: agentPolicy{exposed: exposed, fallback: exposed[0]}, + conversationLocks: newConversationLockSet(), + } + e := echo.New() + return srv, e +} + +func TestHandleModels(t *testing.T) { + srv, e := newTestServer("root", "reviewer") + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/models", http.NoBody) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleModels(c)) + require.Equal(t, http.StatusOK, rec.Code) + + var got ModelsResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, "list", got.Object) + require.Len(t, got.Data, 2) + + ids := []string{got.Data[0].ID, got.Data[1].ID} + assert.ElementsMatch(t, []string{"root", "reviewer"}, ids) + for _, m := range got.Data { + assert.Equal(t, "docker-agent", m.OwnedBy) + // openai.Model carries a typed `Object constant.Model` field that + // always serialises to "model". Ensure the wire shape is stable. 
+ assert.Equal(t, openai.Model{}.Object.Default(), m.Object) + } +} + +func TestHandleChatCompletions_RejectsBadJSON(t *testing.T) { + srv, e := newTestServer() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader("not json")) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "invalid_request_error") +} + +func TestHandleChatCompletions_RejectsEmptyMessages(t *testing.T) { + srv, e := newTestServer() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", + strings.NewReader(`{"messages":[]}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var got ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, "invalid_request_error", got.Error.Type) + assert.Contains(t, got.Error.Message, "at least one message") +} + +func TestHandleChatCompletions_RejectsHistoryWithoutUser(t *testing.T) { + srv, e := newTestServer() + + body := `{"messages":[{"role":"system","content":"be helpful"}]}` + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, srv.handleChatCompletions(c)) + assert.Equal(t, http.StatusBadRequest, rec.Code) + assert.Contains(t, rec.Body.String(), "no user message") +} + +func TestWriteError_ShapeAndType(t *testing.T) { + cases := []struct { + name string + status int + message string + wantType string + }{ + {"client error", http.StatusBadRequest, 
"bad input", "invalid_request_error"}, + {"server error", http.StatusInternalServerError, "boom", "internal_error"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/", http.NoBody) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + require.NoError(t, writeError(c, tc.status, tc.message)) + assert.Equal(t, tc.status, rec.Code) + + var got ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + assert.Equal(t, tc.message, got.Error.Message) + assert.Equal(t, tc.wantType, got.Error.Type) + }) + } +} diff --git a/pkg/chatserver/openapi.go b/pkg/chatserver/openapi.go new file mode 100644 index 000000000..de35467c9 --- /dev/null +++ b/pkg/chatserver/openapi.go @@ -0,0 +1,23 @@ +package chatserver + +import ( + _ "embed" + "net/http" + + "github.com/labstack/echo/v4" +) + +// openAPISpec is the static OpenAPI 3.1 document describing the chat +// completions API. Embedding the JSON keeps the schema diffable and +// tractable to review, and means we don't pay a generation step on every +// build. +// +//go:embed openapi.json +var openAPISpec []byte + +// handleOpenAPI serves the static OpenAPI document. It is exempted from +// the bearer-auth middleware (see bearerAuthMiddleware) so tooling that +// wants to introspect the API can do so without credentials. 
+func (s *server) handleOpenAPI(c echo.Context) error { + return c.Blob(http.StatusOK, "application/json", openAPISpec) +} diff --git a/pkg/chatserver/openapi.json b/pkg/chatserver/openapi.json new file mode 100644 index 000000000..e34ded003 --- /dev/null +++ b/pkg/chatserver/openapi.json @@ -0,0 +1,306 @@ +{ + "openapi": "3.1.0", + "info": { + "title": "docker-agent chat completions", + "summary": "OpenAI-compatible HTTP API exposing a docker-agent agent.", + "description": "Implements a small subset of OpenAI's REST API (chat completions and models) so any tool that already speaks OpenAI's protocol can drive a docker-agent agent without a custom integration.", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://127.0.0.1:8083", + "description": "Default loopback bind" + } + ], + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "description": "Static token configured via --api-key. When --api-key is not set the server is unauthenticated." + } + }, + "schemas": { + "Model": { + "type": "object", + "required": ["id", "object", "owned_by"], + "properties": { + "id": { "type": "string", "description": "Agent name." 
}, + "object": { "type": "string", "const": "model" }, + "created": { "type": "integer", "format": "int64" }, + "owned_by": { "type": "string", "const": "docker-agent" } + } + }, + "ModelsResponse": { + "type": "object", + "required": ["object", "data"], + "properties": { + "object": { "type": "string", "const": "list" }, + "data": { + "type": "array", + "items": { "$ref": "#/components/schemas/Model" } + } + } + }, + "ChatCompletionMessage": { + "type": "object", + "required": ["role"], + "properties": { + "role": { + "type": "string", + "enum": ["system", "user", "assistant", "tool", "developer"] + }, + "content": { + "description": "Either a plain string or an array of typed content parts (text or image_url).", + "oneOf": [ + { "type": "string" }, + { + "type": "array", + "items": { "$ref": "#/components/schemas/ContentPart" } + } + ] + }, + "name": { "type": "string" }, + "tool_call_id": { "type": "string" }, + "tool_calls": { + "type": "array", + "items": { "$ref": "#/components/schemas/ToolCallReference" } + } + } + }, + "ContentPart": { + "type": "object", + "required": ["type"], + "properties": { + "type": { "type": "string", "enum": ["text", "image_url"] }, + "text": { "type": "string" }, + "image_url": { + "type": "object", + "required": ["url"], + "properties": { + "url": { "type": "string" }, + "detail": { "type": "string", "enum": ["auto", "low", "high"] } + } + } + } + }, + "ToolCallReference": { + "type": "object", + "required": ["function"], + "properties": { + "index": { "type": "integer" }, + "id": { "type": "string" }, + "type": { "type": "string", "const": "function" }, + "function": { + "type": "object", + "required": ["name"], + "properties": { + "name": { "type": "string" }, + "arguments": { + "type": "string", + "description": "JSON-encoded arguments object." 
+ } + } + } + } + }, + "ChatCompletionRequest": { + "type": "object", + "required": ["messages"], + "properties": { + "model": { + "type": "string", + "description": "Agent name to invoke. Defaults to the team's default agent when missing or unknown." + }, + "messages": { + "type": "array", + "minItems": 1, + "items": { "$ref": "#/components/schemas/ChatCompletionMessage" } + }, + "stream": { "type": "boolean", "default": false }, + "temperature": { + "type": "number", + "minimum": 0, + "maximum": 2, + "description": "Validated; full runtime plumbing is in progress." + }, + "top_p": { + "type": "number", + "exclusiveMinimum": 0, + "maximum": 1 + }, + "max_tokens": { "type": "integer", "minimum": 1 }, + "stop": { + "oneOf": [ + { "type": "string" }, + { "type": "array", "items": { "type": "string" } } + ] + } + } + }, + "ChatCompletionChoice": { + "type": "object", + "required": ["index", "message"], + "properties": { + "index": { "type": "integer" }, + "message": { "$ref": "#/components/schemas/ChatCompletionMessage" }, + "finish_reason": { + "type": "string", + "enum": ["stop", "tool_calls", "error", "length"] + } + } + }, + "ChatCompletionUsage": { + "type": "object", + "properties": { + "prompt_tokens": { "type": "integer", "format": "int64" }, + "completion_tokens": { "type": "integer", "format": "int64" }, + "total_tokens": { "type": "integer", "format": "int64" } + } + }, + "ChatCompletionResponse": { + "type": "object", + "required": ["id", "object", "created", "model", "choices"], + "properties": { + "id": { "type": "string" }, + "object": { "type": "string", "const": "chat.completion" }, + "created": { "type": "integer", "format": "int64" }, + "model": { "type": "string" }, + "choices": { + "type": "array", + "items": { "$ref": "#/components/schemas/ChatCompletionChoice" } + }, + "usage": { "$ref": "#/components/schemas/ChatCompletionUsage" } + } + }, + "ErrorResponse": { + "type": "object", + "required": ["error"], + "properties": { + "error": { + "type": 
"object", + "required": ["message", "type"], + "properties": { + "message": { "type": "string" }, + "type": { + "type": "string", + "enum": ["invalid_request_error", "internal_error"] + }, + "code": { "type": "string" } + } + } + } + } + } + }, + "security": [{ "bearerAuth": [] }], + "paths": { + "/v1/models": { + "get": { + "summary": "List the agents this server exposes.", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ModelsResponse" } + } + } + }, + "401": { + "description": "Missing or invalid bearer token (only when --api-key is set).", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + } + } + } + }, + "/v1/chat/completions": { + "post": { + "summary": "Create a chat completion.", + "description": "Set `stream: true` to receive Server-Sent Events instead of a single JSON response. The optional `X-Conversation-Id` request header reuses a server-side session across turns when --conversations-max is non-zero.", + "parameters": [ + { + "in": "header", + "name": "X-Conversation-Id", + "schema": { "type": "string" }, + "description": "Stable identifier used to look up a cached session." + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ChatCompletionRequest" } + } + } + }, + "responses": { + "200": { + "description": "OK. 
Either a JSON ChatCompletion or a `text/event-stream` of `chat.completion.chunk` events.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ChatCompletionResponse" } + }, + "text/event-stream": { + "schema": { "type": "string" } + } + } + }, + "400": { + "description": "Bad request (malformed JSON, missing user message, invalid sampling parameters).", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, + "401": { + "description": "Missing or invalid bearer token.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, + "413": { "description": "Request body exceeds --max-request-size." }, + "409": { + "description": "Another request with the same X-Conversation-Id is in flight. Retry sequentially.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + }, + "500": { + "description": "Agent execution failed.", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/ErrorResponse" } + } + } + } + } + } + }, + "/openapi.json": { + "get": { + "summary": "Returns this OpenAPI document.", + "security": [], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { "type": "object" } + } + } + } + } + } + } + } +} diff --git a/pkg/chatserver/openapi_test.go b/pkg/chatserver/openapi_test.go new file mode 100644 index 000000000..291ea4bd9 --- /dev/null +++ b/pkg/chatserver/openapi_test.go @@ -0,0 +1,45 @@ +package chatserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOpenAPIEndpoint(t *testing.T) { + srv, _ := newTestServer("root") + r := newRouter(srv, Options{}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/openapi.json", http.NoBody) + rec 
:= httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) + + var doc map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &doc)) + assert.Equal(t, "3.1.0", doc["openapi"], "OpenAPI version") + paths, ok := doc["paths"].(map[string]any) + require.True(t, ok) + assert.Contains(t, paths, "/v1/chat/completions") + assert.Contains(t, paths, "/v1/models") +} + +func TestOpenAPIEndpoint_BypassesAuth(t *testing.T) { + // /openapi.json must be reachable without a bearer token even when + // --api-key is set, so introspection tooling works against locked- + // down deployments. + srv, _ := newTestServer("root") + r := newRouter(srv, Options{APIKey: "secret"}) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/openapi.json", http.NoBody) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) +} diff --git a/pkg/chatserver/runtime_pool.go b/pkg/chatserver/runtime_pool.go new file mode 100644 index 000000000..d79f03448 --- /dev/null +++ b/pkg/chatserver/runtime_pool.go @@ -0,0 +1,107 @@ +package chatserver + +import ( + "errors" + "sync" + + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/team" +) + +// runtimePool keeps a small set of `runtime.Runtime` instances ready for +// reuse, keyed by agent name. Building a runtime is non-trivial (it +// resolves the agent's tools, creates per-agent hook executors, sets up +// channels for resume/elicitation), so reusing the work across requests +// is a real latency win for hot paths. +// +// Concurrency model: a single runtime is *not* safe for concurrent +// RunStream calls (its resume/elicitation channels are per-runtime +// state). The pool therefore hands out a runtime to one caller at a +// time. Callers Get → use → Put back. When the pool is empty a fresh +// runtime is built. 
+// +// `maxIdle` bounds the number of idle runtimes per agent. Returning a +// runtime to a full pool is a no-op; it simply gets garbage collected. +type runtimePool struct { + team *team.Team + maxIdle int + + mu sync.Mutex + idle map[string]chan runtime.Runtime +} + +// errInvalidRuntime is returned when Get is invoked on a nil pool, in +// which case no runtime can be acquired at all. Errors from +// runtime.New are not wrapped in this sentinel; they are returned to +// the caller unchanged. +var errInvalidRuntime = errors.New("failed to acquire runtime") + +func newRuntimePool(t *team.Team, maxIdle int) *runtimePool { + if maxIdle < 0 { + maxIdle = 0 + } + return &runtimePool{ + team: t, + maxIdle: maxIdle, + idle: make(map[string]chan runtime.Runtime), + } +} + +// Get returns a ready-to-use runtime for the given agent, either +// recycled from the pool or freshly created. +func (p *runtimePool) Get(agent string) (runtime.Runtime, error) { + if p == nil { + return nil, errInvalidRuntime + } + if rt := p.takeIdle(agent); rt != nil { + return rt, nil + } + rt, err := runtime.New(p.team, runtime.WithCurrentAgent(agent)) + if err != nil { + return nil, err + } + return rt, nil +} + +// Put hands a finished runtime back to the pool. If the agent's idle +// slot is full the runtime is discarded (not closed: the team owns the +// underlying toolsets). The runtime must not be used by the caller +// after Put returns. +func (p *runtimePool) Put(agent string, rt runtime.Runtime) { + if p == nil || rt == nil || p.maxIdle == 0 { + return + } + ch := p.channelFor(agent) + select { + case ch <- rt: + default: + // pool full: drop on the floor. The team owns the toolsets, + // so nothing leaks; the runtime itself is ordinary garbage.
+ } +} + +func (p *runtimePool) takeIdle(agent string) runtime.Runtime { + p.mu.Lock() + ch, ok := p.idle[agent] + p.mu.Unlock() + if !ok { + return nil + } + select { + case rt := <-ch: + return rt + default: + return nil + } +} + +func (p *runtimePool) channelFor(agent string) chan runtime.Runtime { + p.mu.Lock() + defer p.mu.Unlock() + ch, ok := p.idle[agent] + if !ok { + ch = make(chan runtime.Runtime, p.maxIdle) + p.idle[agent] = ch + } + return ch +} diff --git a/pkg/chatserver/runtime_pool_test.go b/pkg/chatserver/runtime_pool_test.go new file mode 100644 index 000000000..b69774289 --- /dev/null +++ b/pkg/chatserver/runtime_pool_test.go @@ -0,0 +1,26 @@ +package chatserver + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRuntimePool_DisabledIsNotCached(t *testing.T) { + p := newRuntimePool(nil, 0) + + // Put with maxIdle=0 must be a no-op (we don't have a runtime to put, + // but the channel-for behaviour itself shouldn't allocate). + p.Put("root", nil) + assert.Empty(t, p.idle, "no per-agent channels should be allocated when pooling is disabled") +} + +func TestRuntimePool_NegativeCapTreatedAsZero(t *testing.T) { + p := newRuntimePool(nil, -1) + assert.Equal(t, 0, p.maxIdle) +} + +func TestRuntimePool_takeIdleNoChannel(t *testing.T) { + p := newRuntimePool(nil, 4) + assert.Nil(t, p.takeIdle("anything")) +} diff --git a/pkg/chatserver/server.go b/pkg/chatserver/server.go new file mode 100644 index 000000000..928e56636 --- /dev/null +++ b/pkg/chatserver/server.go @@ -0,0 +1,619 @@ +// Package chatserver implements an OpenAI-compatible HTTP server that exposes +// docker-agent agents through the /v1/chat/completions and /v1/models +// endpoints. +// +// The goal is to let any tool that already speaks OpenAI's chat protocol +// (e.g. Open WebUI, custom shell scripts using the openai SDK) drive a +// docker-agent agent without needing to know about docker-agent's own +// protocol. 
+// +// On types: we deliberately don't reuse the request/response structs from +// github.com/openai/openai-go/v3. The SDK is built around its internal +// `apijson` encoder; with stdlib `encoding/json` those types serialize +// every field and produce noisy responses. `apijson` lives under +// `internal/`, so we can't borrow it. `openai.Model` is the one type that +// round-trips cleanly with stdlib json, so we reuse it for /v1/models. +package chatserver + +import ( + "context" + "crypto/subtle" + "encoding/json" + "errors" + "fmt" + "log/slog" + "math" + "net" + "net/http" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/openai/openai-go/v3" + + "github.com/docker/docker-agent/pkg/config" + "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/session" + "github.com/docker/docker-agent/pkg/team" + "github.com/docker/docker-agent/pkg/teamloader" +) + +// Options configures the chat completions server. Future improvements +// (auth, conversations, etc.) extend this struct rather than the Run +// signature so callers stay stable. +type Options struct { + // AgentName pins the single agent to expose. Empty exposes every + // agent in the team and uses the team's default as the fallback. + AgentName string + // RunConfig is the runtime configuration used to load the team. + RunConfig *config.RuntimeConfig + // CORSOrigin is the allowed value for the Access-Control-Allow-Origin + // header. When empty, the CORS middleware is not registered at all + // (the server never emits any Access-Control-* response header). + // + // Multiple values can be provided separated by commas. Each entry is + // either a literal origin (matched exactly), the wildcard "*", or a + // pattern starting with "~" interpreted as a Go regular expression + // against the request's Origin header. 
Examples: + // + // "https://app.example.com" + // "https://app.example.com,https://staging.example.com" + // "~^https://[a-z0-9-]+\\.example\\.com$" + CORSOrigin string + // APIKey, if non-empty, is the static bearer token clients must + // present in the `Authorization` header (`Authorization: Bearer X`). + // Empty disables authentication; once set, every request to /v1/* is + // rejected with 401 unless it carries the matching token. + // /v1/models is also protected so an unauthenticated client can't + // fingerprint the server. + APIKey string + // MaxRequestBytes caps the size of an incoming request body. Zero + // means use the package default (1 MiB). + MaxRequestBytes int64 + // RequestTimeout caps how long a single chat completion is allowed to + // run. Zero means use the package default (5 minutes). The cap covers + // model calls, tool calls, and SSE streaming combined. + RequestTimeout time.Duration + // ConversationsMaxSessions, when > 0, enables the X-Conversation-Id + // header: clients can pass a stable id to reuse the same session + // across requests instead of re-sending the full message history + // every turn. This is the size of the in-memory LRU cache. + ConversationsMaxSessions int + // ConversationTTL is how long a cached conversation may be idle + // before it's evicted. Zero means use the package default + // (30 minutes). + ConversationTTL time.Duration + // MaxIdleRuntimes bounds the number of idle runtimes pooled per + // agent. Building a runtime resolves tools and sets up channels; + // keeping a small pool of warm runtimes avoids paying that cost on + // every request. Zero disables pooling (a fresh runtime is built + // for every request, the original behaviour). 
+ MaxIdleRuntimes int +} + +const ( + defaultMaxRequestBytes int64 = 1 << 20 // 1 MiB + defaultRequestTimeout time.Duration = 5 * time.Minute + defaultConversationTTL time.Duration = 30 * time.Minute +) + +// Run starts an OpenAI-compatible HTTP server on the given listener and +// blocks until ctx is cancelled or the server fails. The team is loaded +// once from agentFilename and shared across requests; every chat completion +// request gets a fresh session. +func Run(ctx context.Context, agentFilename string, opts Options, ln net.Listener) error { + slog.Debug("Starting chat completions server", "agent", agentFilename, "addr", ln.Addr()) + + t, err := loadTeam(ctx, agentFilename, opts.RunConfig) + if err != nil { + return err + } + defer func() { + if err := t.StopToolSets(ctx); err != nil { + slog.Error("Failed to stop tool sets", "error", err) + } + }() + + policy, err := newAgentPolicy(t, opts.AgentName) + if err != nil { + return err + } + + httpServer := &http.Server{ + Handler: newRouter(&server{ + team: t, + policy: policy, + conversations: newConversationStore(opts.ConversationsMaxSessions, conversationTTL(opts)), + conversationLocks: newConversationLockSet(), + runtimes: newRuntimePool(t, opts.MaxIdleRuntimes), + }, opts), + ReadHeaderTimeout: 30 * time.Second, + } + return serve(ctx, httpServer, ln) +} + +func conversationTTL(opts Options) time.Duration { + if opts.ConversationTTL > 0 { + return opts.ConversationTTL + } + return defaultConversationTTL +} + +// loadTeam resolves and loads the team referenced by agentFilename. 
+func loadTeam(ctx context.Context, agentFilename string, runConfig *config.RuntimeConfig) (*team.Team, error) { + src, err := config.Resolve(agentFilename, nil) + if err != nil { + return nil, err + } + t, err := teamloader.Load(ctx, src, runConfig) + if err != nil { + return nil, fmt.Errorf("failed to load agents: %w", err) + } + return t, nil +} + +// serve runs httpServer on ln until ctx is cancelled, then triggers a +// graceful shutdown. +func serve(ctx context.Context, httpServer *http.Server, ln net.Listener) error { + errCh := make(chan error, 1) + go func() { errCh <- httpServer.Serve(ln) }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return httpServer.Shutdown(shutdownCtx) + case err := <-errCh: + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + } +} + +// server is concurrent-safe: every request creates its own session and +// runtime, so the only shared state is the team (whose toolsets are +// independently safe to call) and the optional conversation cache. +type server struct { + team *team.Team + policy agentPolicy + conversations *conversationStore + conversationLocks *conversationLockSet + runtimes *runtimePool +} + +func newRouter(s *server, opts Options) http.Handler { + e := echo.New() + e.HideBanner = true + e.HidePort = true + + maxBytes := opts.MaxRequestBytes + if maxBytes <= 0 { + maxBytes = defaultMaxRequestBytes + } + timeout := opts.RequestTimeout + if timeout <= 0 { + timeout = defaultRequestTimeout + } + + e.Use(middleware.RequestLogger()) + e.Use(middleware.BodyLimit(strconv.FormatInt(maxBytes, 10))) + e.Use(requestTimeoutMiddleware(timeout)) + + // NOTE: echo applies e.Use middleware to every route regardless of + // registration order; /openapi.json is exempt from auth via the explicit + // path check inside bearerAuthMiddleware, not by being registered first.
+ e.GET("/openapi.json", s.handleOpenAPI) + + if opts.APIKey != "" { + e.Use(bearerAuthMiddleware(opts.APIKey)) + } + if opts.CORSOrigin != "" { + cfg, err := corsMiddlewareConfig(opts.CORSOrigin) + if err != nil { + // Bad config is reported via the request log. The middleware + // is simply not registered, which is the safest default. + slog.Error("Invalid --cors-origin, CORS disabled", "error", err) + } else { + e.Use(middleware.CORSWithConfig(cfg)) + } + } + + e.GET("/v1/models", s.handleModels) + e.POST("/v1/chat/completions", s.handleChatCompletions) + return e +} + +// requestTimeoutMiddleware caps each request's lifetime. Streaming +// handlers honour the timeout via c.Request().Context(). +func requestTimeoutMiddleware(d time.Duration) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + ctx, cancel := context.WithTimeout(c.Request().Context(), d) + defer cancel() + c.SetRequest(c.Request().WithContext(ctx)) + return next(c) + } + } +} + +// corsMiddlewareConfig parses a comma-separated --cors-origin value into an +// echo middleware.CORSConfig. Each entry is one of: +// +// - the literal "*" wildcard; +// - a regex when prefixed with "~" (compiled and matched against the +// request's Origin header); +// - a literal origin matched verbatim. +// +// Returns an error when no entry parses successfully, in which case the +// caller leaves the middleware unregistered. 
+func corsMiddlewareConfig(spec string) (middleware.CORSConfig, error) { + var literals []string + var patterns []*regexp.Regexp + for raw := range strings.SplitSeq(spec, ",") { + entry := strings.TrimSpace(raw) + if entry == "" { + continue + } + if rest, ok := strings.CutPrefix(entry, "~"); ok { + re, err := regexp.Compile(rest) + if err != nil { + return middleware.CORSConfig{}, fmt.Errorf("invalid CORS regex %q: %w", rest, err) + } + patterns = append(patterns, re) + continue + } + if err := validateCORSOrigin(entry); err != nil { + return middleware.CORSConfig{}, err + } + literals = append(literals, entry) + } + if len(literals) == 0 && len(patterns) == 0 { + return middleware.CORSConfig{}, errors.New("no usable CORS origins") + } + + cfg := middleware.CORSConfig{ + AllowOrigins: literals, + AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, + AllowHeaders: []string{"Authorization", "Content-Type", "Accept"}, + MaxAge: 86400, + } + if len(patterns) > 0 { + cfg.AllowOriginFunc = func(origin string) (bool, error) { + for _, re := range patterns { + if re.MatchString(origin) { + return true, nil + } + } + return false, nil + } + } + return cfg, nil +} + +// validateCORSOrigin sanity-checks a literal origin entry. The aim is to +// reject obvious typos early ("http//foo.com", "https://foo.com/bar") +// rather than to be a full URL parser — the echo middleware will still +// do its own matching at request time. 
+func validateCORSOrigin(o string) error { + if o == "*" { + return nil + } + u, err := url.Parse(o) + if err != nil { + return fmt.Errorf("invalid CORS origin %q: %w", o, err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("invalid CORS origin %q: scheme must be http or https", o) + } + if u.Host == "" { + return fmt.Errorf("invalid CORS origin %q: missing host", o) + } + if u.Path != "" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("invalid CORS origin %q: must not include path, query, or fragment", o) + } + return nil +} + +// bearerAuthMiddleware enforces the static `Authorization: Bearer ` +// header. CORS preflight requests (OPTIONS) are exempted so that browsers +// can negotiate before sending the auth header. +// +// The expected token is captured by closure rather than read per-request, +// and the comparison uses subtle.ConstantTimeCompare so timing observation +// can't reveal valid prefixes. +func bearerAuthMiddleware(expected string) echo.MiddlewareFunc { + exp := []byte(expected) + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if c.Request().Method == http.MethodOptions { + return next(c) + } + // Schema introspection is always reachable so tooling can + // discover the API without credentials. 
+ if c.Path() == "/openapi.json" { + return next(c) + } + got, ok := strings.CutPrefix(c.Request().Header.Get("Authorization"), "Bearer ") + if !ok || subtle.ConstantTimeCompare([]byte(got), exp) != 1 { + return writeError(c, http.StatusUnauthorized, "missing or invalid bearer token") + } + return next(c) + } + } +} + +func (s *server) handleModels(c echo.Context) error { + data := make([]openai.Model, 0, len(s.policy.exposed)) + for _, name := range s.policy.exposed { + data = append(data, openai.Model{ID: name, OwnedBy: "docker-agent"}) + } + return c.JSON(http.StatusOK, ModelsResponse{Object: "list", Data: data}) +} + +func (s *server) handleChatCompletions(c echo.Context) error { + var req ChatCompletionRequest + if err := json.NewDecoder(c.Request().Body).Decode(&req); err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) + } + if len(req.Messages) == 0 { + return writeError(c, http.StatusBadRequest, "at least one message is required") + } + if err := validateSamplingParams(&req); err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) + } + + conversationID := c.Request().Header.Get("X-Conversation-Id") + if !s.conversationLocks.tryAcquire(conversationID) { + return writeError(c, http.StatusConflict, "another request is already in flight for this conversation id") + } + defer s.conversationLocks.release(conversationID) + + sess := s.resolveSession(conversationID, req.Messages) + if sess == nil { + return writeError(c, http.StatusBadRequest, "no user message provided") + } + + agentName := s.policy.pick(req.Model) + rt, err := s.runtimes.Get(agentName) + if err != nil { + return writeError(c, http.StatusInternalServerError, fmt.Sprintf("failed to acquire runtime: %v", err)) + } + defer s.runtimes.Put(agentName, rt) + + // Echo back the requested model verbatim when set, so clients matching + // on the model field stay happy. Otherwise expose the actual agent. 
+ model := agentName + if req.Model != "" { + model = req.Model + } + + if req.Stream { + err := s.streamChatCompletion(c, rt, sess, model) + s.maybeStoreConversation(conversationID, sess) + return err + } + err = s.chatCompletion(c, rt, sess, model) + s.maybeStoreConversation(conversationID, sess) + return err +} + +// resolveSession decides whether to start fresh or continue an existing +// conversation. When X-Conversation-Id is set and we have an existing +// session for it, we append only the latest user message from the +// request (the prior history is already in the session). Otherwise we +// build a brand-new session from the full request history. +func (s *server) resolveSession(id string, msgs []ChatCompletionMessage) *session.Session { + if id != "" { + if existing := s.conversations.Get(id); existing != nil { + appendLatestUser(existing, msgs) + return existing + } + } + return buildSession(msgs) +} + +// maybeStoreConversation inserts the session into the cache after a +// run. We always insert to handle the case where the conversation was +// evicted while the request was in flight. +func (s *server) maybeStoreConversation(id string, sess *session.Session) { + if id == "" || s.conversations == nil { + return + } + // Always Put, even for existing conversations, to handle eviction + // during request processing. Put refreshes the lastUsed timestamp + // and ensures the updated session is stored. + s.conversations.Put(id, sess) +} + +// chatCompletion runs the agent to completion and replies with one +// non-streaming OpenAI ChatCompletion object. 
+func (s *server) chatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error {
+	var toolCalls []ToolCallReference
+	emit := agentEmit{
+		onToolCall: func(tc ToolCallReference) {
+			toolCalls = append(toolCalls, tc)
+		},
+	}
+	if err := runAgentLoop(c.Request().Context(), rt, sess, emit); err != nil {
+		return writeError(c, http.StatusInternalServerError, fmt.Sprintf("agent execution failed: %v", err))
+	}
+
+	// finish_reason is always "stop": tools already ran server-side, so
+	// from the client's perspective the completion ended normally.
+	return c.JSON(http.StatusOK, ChatCompletionResponse{
+		ID:      newChatID(),
+		Object:  "chat.completion",
+		Created: time.Now().Unix(),
+		Model:   model,
+		Choices: []ChatCompletionChoice{{
+			Index: 0,
+			Message: ChatCompletionMessage{
+				Role:      "assistant",
+				Content:   sess.GetLastAssistantMessageContent(),
+				ToolCalls: toolCalls,
+			},
+			FinishReason: "stop",
+		}},
+		Usage: sessionUsage(sess),
+	})
+}
+
+// streamChatCompletion runs the agent and streams its response back to the
+// client as Server-Sent Events in OpenAI's chat.completion.chunk format.
+//
+// The error return is reserved for future use (e.g. surfacing a write
+// failure to the request logger). Today every error is converted into an
+// in-band SSE error event, so the function always returns nil.
+func (s *server) streamChatCompletion(c echo.Context, rt runtime.Runtime, sess *session.Session, model string) error { //nolint:unparam // see comment
+	stream := newSSEStream(c.Response(), newChatID(), model)
+
+	// Initial "role: assistant" delta so clients can start rendering.
+	stream.send(ChatCompletionStreamDelta{Role: "assistant"}, "")
+
+	emit := agentEmit{
+		onContent: func(content string) {
+			if content != "" {
+				stream.send(ChatCompletionStreamDelta{Content: content}, "")
+			}
+		},
+		onToolCall: func(tc ToolCallReference) {
+			// Surface tool calls to the client using OpenAI's exact wire
+			// shape: a single delta carrying the full tool_call entry.
+			// (OpenAI streams arguments token-by-token; we have them all
+			// at once, so one chunk per call is enough.) Tools still run
+			// server-side — this is purely for client visibility.
+			stream.send(ChatCompletionStreamDelta{ToolCalls: []ToolCallReference{tc}}, "")
+		},
+	}
+	runErr := runAgentLoop(c.Request().Context(), rt, sess, emit)
+	if runErr != nil {
+		// Emit a structured error envelope (OpenAI streams use a regular
+		// `data:` line carrying an `error` object, then close the stream
+		// with finish_reason "error" instead of "stop"). Clients matching
+		// on the OpenAI protocol can therefore distinguish a model error
+		// from a normal completion.
+		stream.sendError(runErr)
+		stream.send(ChatCompletionStreamDelta{}, "error")
+	} else {
+		stream.send(ChatCompletionStreamDelta{}, "stop")
+	}
+	stream.done()
+	return nil
+}
+
+// sseStream writes OpenAI-style chat.completion.chunk events to a response.
+// It centralises SSE bookkeeping (headers, JSON encoding, flushing,
+// terminator) so the handler can focus on what to emit.
+type sseStream struct {
+	w       http.ResponseWriter
+	id      string
+	model   string
+	created int64
+}
+
+// newSSEStream writes the SSE response headers and returns a stream bound
+// to one chat completion id. The Connection header is an HTTP/1.1 hint;
+// HTTP/2 transports ignore it.
+func newSSEStream(w http.ResponseWriter, id, model string) *sseStream {
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+	w.WriteHeader(http.StatusOK)
+	return &sseStream{w: w, id: id, model: model, created: time.Now().Unix()}
+}
+
+// send emits one chat.completion.chunk carrying delta and, when non-empty,
+// finishReason. Marshal failures are silently dropped: the stream is
+// best-effort and a half-written event would be worse than none.
+func (s *sseStream) send(delta ChatCompletionStreamDelta, finishReason string) {
+	chunk := ChatCompletionStreamResponse{
+		ID:      s.id,
+		Object:  "chat.completion.chunk",
+		Created: s.created,
+		Model:   s.model,
+		Choices: []ChatCompletionStreamChoice{{
+			Index:        0,
+			Delta:        delta,
+			FinishReason: finishReason,
+		}},
+	}
+	data, err := json.Marshal(chunk)
+	if err != nil {
+		return
+	}
+	_, _ = fmt.Fprintf(s.w, "data: %s\n\n", data)
+	// Flush after every event so the client sees deltas immediately
+	// instead of whenever the transport buffer happens to fill.
+	if f, ok := s.w.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+// done writes the OpenAI sentinel terminator that ends the stream.
+func (s *sseStream) done() {
+	_, _ = fmt.Fprint(s.w, "data: [DONE]\n\n")
+	if f, ok := s.w.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+// sendError emits an OpenAI-style error envelope as a separate SSE event
+// alongside the chunked deltas. Real OpenAI streams use this shape when a
+// run fails mid-flight, e.g. a content filter trips: the message arrives
+// in its own `data:` line carrying an `error` object before the stream
+// terminates.
+func (s *sseStream) sendError(err error) {
+	envelope := ErrorResponse{Error: ErrorDetail{
+		Message: err.Error(),
+		Type:    "internal_error",
+	}}
+	data, marshalErr := json.Marshal(envelope)
+	if marshalErr != nil {
+		return
+	}
+	_, _ = fmt.Fprintf(s.w, "data: %s\n\n", data)
+	if f, ok := s.w.(http.Flusher); ok {
+		f.Flush()
+	}
+}
+
+// newChatID returns a fresh OpenAI-style chat completion id.
+func newChatID() string { return "chatcmpl-" + uuid.NewString() }
+
+// writeError writes an OpenAI-style error envelope.
+func writeError(c echo.Context, status int, message string) error {
+	return c.JSON(status, ErrorResponse{Error: ErrorDetail{
+		Message: message,
+		Type:    errTypeFor(status),
+	}})
+}
+
+// errTypeFor maps an HTTP status to the OpenAI error `type` string:
+// 5xx is internal_error, everything else invalid_request_error.
+func errTypeFor(status int) string {
+	if status >= 500 {
+		return "internal_error"
+	}
+	return "invalid_request_error"
+}
+
+// validateSamplingParams range-checks the OpenAI sampling fields. Even when
+// we don't yet plumb them all the way through to the model, validating up
+// front lets clients learn about typos / out-of-range values immediately
+// instead of getting an opaque provider error several seconds later.
+func validateSamplingParams(req *ChatCompletionRequest) error {
+	if req.Temperature != nil {
+		t := *req.Temperature
+		// NaN is explicitly rejected: it would pass neither comparison
+		// below yet is not a usable sampling value.
+		if math.IsNaN(t) || t < 0 || t > 2 {
+			return fmt.Errorf("temperature must be in [0, 2], got %g", t)
+		}
+	}
+	if req.TopP != nil {
+		p := *req.TopP
+		// top_p of exactly 0 is rejected (no probability mass to sample).
+		if math.IsNaN(p) || p <= 0 || p > 1 {
+			return fmt.Errorf("top_p must be in (0, 1], got %g", p)
+		}
+	}
+	if req.MaxTokens != nil && *req.MaxTokens <= 0 {
+		return fmt.Errorf("max_tokens must be > 0, got %d", *req.MaxTokens)
+	}
+	if slices.Contains(req.Stop, "") {
+		return errors.New("stop sequences must not be empty strings")
+	}
+	return nil
+}
diff --git a/pkg/chatserver/server_test.go b/pkg/chatserver/server_test.go
new file mode 100644
index 000000000..e4df5074b
--- /dev/null
+++ b/pkg/chatserver/server_test.go
@@ -0,0 +1,517 @@
+package chatserver
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/labstack/echo/v4"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/docker/docker-agent/pkg/chat"
+	"github.com/docker/docker-agent/pkg/session"
+)
+
+// TestBuildSession_RequiresUserMessage pins that buildSession returns nil
+// unless at least one non-blank user message is present.
+func TestBuildSession_RequiresUserMessage(t *testing.T) {
+	tests := []struct {
+		name     string
+		messages []ChatCompletionMessage
+		wantNil  bool
+	}{
+		{
+			name:    "empty list",
+			wantNil: true,
+		},
+		{
+			name: "only system messages",
+			messages: []ChatCompletionMessage{
+				{Role: "system", Content: "be helpful"},
+			},
+			wantNil: true,
+		},
+		{
+			name: "blank user message is ignored",
+			messages: []ChatCompletionMessage{
+				{Role: "user", Content: " "},
+			},
+			wantNil: true,
+		},
+		{
+			name: "valid user message",
+			messages: []ChatCompletionMessage{
+				{Role: "user", Content: "hello"},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			sess := buildSession(tc.messages)
+			if tc.wantNil {
+				assert.Nil(t, sess)
+				return
+			}
+			require.NotNil(t, sess)
+			assert.True(t, sess.ToolsApproved)
+			assert.True(t, sess.NonInteractive)
+		})
+	}
+}
+
+func TestBuildSession_PreservesHistory(t *testing.T) {
+	sess := buildSession([]ChatCompletionMessage{
+		{Role: "system", Content: "you are a docker agent"},
+		{Role: "user", Content: "hello"},
+		{Role: "assistant", Content: "hi there"},
+		{Role: "user", Content: "how are you?"},
+	})
+	require.NotNil(t, sess)
+
+	// GetAllMessages omits system messages.
+	all := sess.GetAllMessages()
+	require.Len(t, all, 3)
+
+	roles := make([]chat.MessageRole, len(all))
+	for i, m := range all {
+		roles[i] = m.Message.Role
+	}
+	assert.Equal(t, []chat.MessageRole{
+		chat.MessageRoleUser,
+		chat.MessageRoleAssistant,
+		chat.MessageRoleUser,
+	}, roles)
+
+	assert.Equal(t, "how are you?", sess.GetLastUserMessageContent())
+	assert.Equal(t, "hi there", sess.GetLastAssistantMessageContent())
+}
+
+func TestBuildSession_PreservesToolMessage(t *testing.T) {
+	sess := buildSession([]ChatCompletionMessage{
+		{Role: "user", Content: "compute 2+2"},
+		{Role: "assistant", Content: ""}, // dropped: empty content
+		{Role: "tool", Content: "4", ToolCallID: "call_1"},
+	})
+	require.NotNil(t, sess)
+
+	all := sess.GetAllMessages()
+	require.Len(t, all, 2)
+
+	last := all[len(all)-1].Message
+	assert.Equal(t, chat.MessageRoleTool, last.Role)
+	assert.Equal(t, "4", last.Content)
+	assert.Equal(t, "call_1", last.ToolCallID)
+}
+
+func TestBuildSession_UnknownRoleTreatedAsUser(t *testing.T) {
+	sess := buildSession([]ChatCompletionMessage{
+		{Role: "developer", Content: "do this"},
+	})
+	require.NotNil(t, sess)
+
+	all := sess.GetAllMessages()
+	require.Len(t, all, 1)
+	assert.Equal(t, chat.MessageRoleUser, all[0].Message.Role)
+	assert.Equal(t, "do this", all[0].Message.Content)
+}
+
+func TestSessionUsage_OmitsZero(t *testing.T) {
+	sess := session.New()
+	assert.Nil(t, sessionUsage(sess))
+
+	sess.InputTokens = 5
+	sess.OutputTokens = 7
+	usage := sessionUsage(sess)
+	require.NotNil(t, usage)
+	assert.Equal(t, int64(5), usage.PromptTokens)
+	assert.Equal(t, int64(7), usage.CompletionTokens)
+	assert.Equal(t, int64(12), usage.TotalTokens)
+}
+
+func TestAgentPolicy_Pick(t *testing.T) {
+	p := agentPolicy{exposed: []string{"root", "reviewer"}, fallback: "root"}
+
+	assert.Equal(t, "reviewer", p.pick("reviewer"))
+	assert.Equal(t, "root", p.pick("root"))
+	assert.Equal(t, "root", p.pick(""), "empty model falls back")
+	assert.Equal(t, "root", p.pick("gpt-4"), "unknown model falls back")
+}
+
+func TestErrTypeFor(t *testing.T) {
+	assert.Equal(t, "invalid_request_error", errTypeFor(400))
+	assert.Equal(t, "invalid_request_error", errTypeFor(404))
+	assert.Equal(t, "internal_error", errTypeFor(500))
+	assert.Equal(t, "internal_error", errTypeFor(502))
+}
+
+func TestNewRouter_CORSDisabledByDefault(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{})
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+	req.Header.Set("Origin", "https://example.com")
+	req.Header.Set("Access-Control-Request-Method", "GET")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"),
+		"no CORS header should be emitted when no origin is configured")
+}
+
+func TestNewRouter_CORSAllowsConfiguredOrigin(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{CORSOrigin: "https://example.com"})
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+	req.Header.Set("Origin", "https://example.com")
+	req.Header.Set("Access-Control-Request-Method", "GET")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.Equal(t, "https://example.com", rec.Header().Get("Access-Control-Allow-Origin"))
+}
+
+func TestCorsMiddlewareConfig(t *testing.T) {
+	cases := []struct {
+		name    string
+		spec    string
+		wantErr bool
+	}{
+		{name: "single literal", spec: "https://app.example.com"},
+		{name: "comma list", spec: "https://a.example.com, https://b.example.com"},
+		{name: "regex", spec: `~^https://[a-z]+\.example\.com$`},
+		{name: "wildcard", spec: "*"},
+		{name: "mixed", spec: `https://a.example.com,~^https://b\.example\.com$`},
+		{name: "empty entries collapse", spec: ", , https://x.com,,"},
+
+		{name: "missing scheme", spec: "app.example.com", wantErr: true},
+		{name: "with path", spec: "https://example.com/api", wantErr: true},
+		{name: "with query", spec: "https://example.com?x=1", wantErr: true},
+		{name: "bad regex", spec: "~[", wantErr: true},
+		{name: "all blanks", spec: ", , ,", wantErr: true},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			_, err := corsMiddlewareConfig(tc.spec)
+			if tc.wantErr {
+				assert.Error(t, err)
+				return
+			}
+			assert.NoError(t, err)
+		})
+	}
+}
+
+func TestNewRouter_CORSAllowList(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{CORSOrigin: "https://a.example.com,https://b.example.com"})
+
+	cases := []struct {
+		origin string
+		want   string // expected Access-Control-Allow-Origin
+	}{
+		{"https://a.example.com", "https://a.example.com"},
+		{"https://b.example.com", "https://b.example.com"},
+		{"https://evil.example.com", ""},
+	}
+	for _, tc := range cases {
+		t.Run(tc.origin, func(t *testing.T) {
+			req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+			req.Header.Set("Origin", tc.origin)
+			req.Header.Set("Access-Control-Request-Method", "GET")
+			rec := httptest.NewRecorder()
+			r.ServeHTTP(rec, req)
+			assert.Equal(t, tc.want, rec.Header().Get("Access-Control-Allow-Origin"))
+		})
+	}
+}
+
+func TestNewRouter_CORSRegex(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{CORSOrigin: `~^https://[a-z]+\.example\.com$`})
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+	req.Header.Set("Origin", "https://staging.example.com")
+	req.Header.Set("Access-Control-Request-Method", "GET")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.Equal(t, "https://staging.example.com", rec.Header().Get("Access-Control-Allow-Origin"))
+
+	// A non-matching origin must not get the header.
+	req2 := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+	req2.Header.Set("Origin", "https://evil.attacker.com")
+	req2.Header.Set("Access-Control-Request-Method", "GET")
+	rec2 := httptest.NewRecorder()
+	r.ServeHTTP(rec2, req2)
+	assert.Empty(t, rec2.Header().Get("Access-Control-Allow-Origin"))
+}
+
+func TestBearerAuthMiddleware(t *testing.T) {
+	cases := []struct {
+		name       string
+		header     string
+		wantStatus int
+	}{
+		{"missing", "", http.StatusUnauthorized},
+		{"wrong scheme", "Basic abc", http.StatusUnauthorized},
+		{"wrong token", "Bearer wrong", http.StatusUnauthorized},
+		{"correct token", "Bearer secret", http.StatusOK},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			srv, _ := newTestServer("root")
+			r := newRouter(srv, Options{APIKey: "secret"})
+
+			req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/models", http.NoBody)
+			if tc.header != "" {
+				req.Header.Set("Authorization", tc.header)
+			}
+			rec := httptest.NewRecorder()
+			r.ServeHTTP(rec, req)
+
+			assert.Equal(t, tc.wantStatus, rec.Code)
+		})
+	}
+}
+
+func TestHandleChatCompletions_RejectsConcurrentSameConversation(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{})
+
+	// Pre-acquire the conversation lock to simulate an in-flight
+	// request. The next request with the same id must get 409.
+	require.True(t, srv.conversationLocks.tryAcquire("conv-x"))
+	defer srv.conversationLocks.release("conv-x")
+
+	body := `{"messages":[{"role":"user","content":"hi"}]}`
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", strings.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("X-Conversation-Id", "conv-x")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.Equal(t, http.StatusConflict, rec.Code)
+	assert.Contains(t, rec.Body.String(), "another request is already in flight")
+}
+
+func TestBearerAuthMiddleware_AllowsCORSPreflight(t *testing.T) {
+	// CORS preflight must succeed without an Authorization header.
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{APIKey: "secret", CORSOrigin: "https://example.com"})
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodOptions, "/v1/models", http.NoBody)
+	req.Header.Set("Origin", "https://example.com")
+	req.Header.Set("Access-Control-Request-Method", "GET")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.NotEqual(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestNewRouter_RejectsOversizedBody(t *testing.T) {
+	srv, _ := newTestServer("root")
+	r := newRouter(srv, Options{MaxRequestBytes: 16})
+
+	body := strings.NewReader(`{"messages":[{"role":"user","content":"this body is far longer than sixteen bytes"}]}`)
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/chat/completions", body)
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+	r.ServeHTTP(rec, req)
+
+	assert.Equal(t, http.StatusRequestEntityTooLarge, rec.Code)
+}
+
+func TestValidateSamplingParams(t *testing.T) {
+	f := func(v float64) *float64 { return &v }
+	i := func(v int64) *int64 { return &v }
+
+	cases := []struct {
+		name    string
+		req     ChatCompletionRequest
+		wantErr string
+	}{
+		{name: "all empty"},
+		{name: "valid", req: ChatCompletionRequest{
+			Temperature: f(0.7), TopP: f(0.95), MaxTokens: i(256),
+			Stop: StopSequences{"\n\n", "END"},
+		}},
+		{name: "temp negative", req: ChatCompletionRequest{Temperature: f(-0.1)}, wantErr: "temperature"},
+		{name: "temp too high", req: ChatCompletionRequest{Temperature: f(2.5)}, wantErr: "temperature"},
+		{name: "topp zero", req: ChatCompletionRequest{TopP: f(0)}, wantErr: "top_p"},
+		{name: "topp too high", req: ChatCompletionRequest{TopP: f(1.5)}, wantErr: "top_p"},
+		{name: "max_tokens zero", req: ChatCompletionRequest{MaxTokens: i(0)}, wantErr: "max_tokens"},
+		{name: "empty stop", req: ChatCompletionRequest{Stop: StopSequences{""}}, wantErr: "stop"},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			err := validateSamplingParams(&tc.req)
+			if tc.wantErr == "" {
+				assert.NoError(t, err)
+				return
+			}
+			require.Error(t, err)
+			assert.Contains(t, err.Error(), tc.wantErr)
+		})
+	}
+}
+
+func TestChatCompletionMessage_UnmarshalContentString(t *testing.T) {
+	var m ChatCompletionMessage
+	require.NoError(t, json.Unmarshal([]byte(`{"role":"user","content":"hello"}`), &m))
+	assert.Equal(t, "user", m.Role)
+	assert.Equal(t, "hello", m.Content)
+	assert.Empty(t, m.Parts)
+}
+
+func TestChatCompletionMessage_UnmarshalContentParts(t *testing.T) {
+	var m ChatCompletionMessage
+	input := `{
+		"role":"user",
+		"content":[
+			{"type":"text","text":"What is in this picture?"},
+			{"type":"image_url","image_url":{"url":"https://example.com/x.png","detail":"high"}}
+		]
+	}`
+	require.NoError(t, json.Unmarshal([]byte(input), &m))
+	assert.Equal(t, "user", m.Role)
+	require.Len(t, m.Parts, 2)
+	assert.Equal(t, "text", m.Parts[0].Type)
+	assert.Equal(t, "image_url", m.Parts[1].Type)
+	require.NotNil(t, m.Parts[1].ImageURL)
+	assert.Equal(t, "https://example.com/x.png", m.Parts[1].ImageURL.URL)
+	// Flat text is pre-computed for callers that don't care about parts.
+	assert.Equal(t, "What is in this picture?", m.Content)
+}
+
+func TestChatCompletionMessage_RoundTripText(t *testing.T) {
+	in := ChatCompletionMessage{Role: "assistant", Content: "hi there"}
+	raw, err := json.Marshal(in)
+	require.NoError(t, err)
+	assert.JSONEq(t, `{"role":"assistant","content":"hi there"}`, string(raw))
+}
+
+func TestChatCompletionMessage_RoundTripParts(t *testing.T) {
+	in := ChatCompletionMessage{
+		Role: "user",
+		Parts: []ContentPart{
+			{Type: "text", Text: "hi"},
+			{Type: "image_url", ImageURL: &ContentImageURL{URL: "http://x/y"}},
+		},
+	}
+	raw, err := json.Marshal(in)
+	require.NoError(t, err)
+	assert.JSONEq(t, `{"role":"user","content":[{"type":"text","text":"hi"},{"type":"image_url","image_url":{"url":"http://x/y"}}]}`, string(raw))
+}
+
+func TestBuildSession_AcceptsImageParts(t *testing.T) {
+	sess := buildSession([]ChatCompletionMessage{{
+		Role: "user",
+		Parts: []ContentPart{
+			{Type: "text", Text: "What is this?"},
+			{Type: "image_url", ImageURL: &ContentImageURL{URL: "https://example.com/x.png"}},
+		},
+	}})
+	require.NotNil(t, sess)
+
+	all := sess.GetAllMessages()
+	require.Len(t, all, 1)
+	last := all[0].Message
+	assert.Equal(t, chat.MessageRoleUser, last.Role)
+	require.Len(t, last.MultiContent, 2)
+	assert.Equal(t, chat.MessagePartTypeText, last.MultiContent[0].Type)
+	assert.Equal(t, chat.MessagePartTypeImageURL, last.MultiContent[1].Type)
+	require.NotNil(t, last.MultiContent[1].ImageURL)
+	assert.Equal(t, "https://example.com/x.png", last.MultiContent[1].ImageURL.URL)
+}
+
+func TestStopSequences_UnmarshalJSON(t *testing.T) {
+	cases := []struct {
+		name string
+		json string
+		want []string
+		err  bool
+	}{
+		{"null", `null`, nil, false},
+		{"single string", `"END"`, []string{"END"}, false},
+		{"array", `["a", "b"]`, []string{"a", "b"}, false},
+		{"empty array", `[]`, []string{}, false},
+		{"number invalid", `42`, nil, true},
+	}
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			var got StopSequences
+			err := got.UnmarshalJSON([]byte(tc.json))
+			if tc.err {
+				assert.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tc.want, []string(got))
+		})
+	}
+}
+
+func TestSSEStream_ToolCallDelta(t *testing.T) {
+	rec := httptest.NewRecorder()
+	s := newSSEStream(rec, "chatcmpl-x", "root")
+	s.send(ChatCompletionStreamDelta{ToolCalls: []ToolCallReference{{
+		Index: 0,
+		ID:    "call_1",
+		Type:  "function",
+		Function: ToolCallFunction{
+			Name:      "search",
+			Arguments: `{"q":"docker"}`,
+		},
+	}}}, "")
+
+	body := rec.Body.String()
+	assert.Contains(t, body, `"tool_calls":[`)
+	assert.Contains(t, body, `"id":"call_1"`)
+	assert.Contains(t, body, `"name":"search"`)
+	assert.Contains(t, body, `"arguments":"{\"q\":\"docker\"}"`)
+}
+
+func TestSSEStream_SendError(t *testing.T) {
+	rec := httptest.NewRecorder()
+	s := newSSEStream(rec, "chatcmpl-x", "root")
+	s.sendError(errors.New("model exploded"))
+	s.send(ChatCompletionStreamDelta{}, "error")
+	s.done()
+
+	body := rec.Body.String()
+	// One error envelope.
+	assert.Contains(t, body, `"error":{"message":"model exploded"`)
+	// One terminating chunk with finish_reason=error (instead of stop).
+	assert.Contains(t, body, `"finish_reason":"error"`)
+	// And the OpenAI sentinel.
+	assert.Contains(t, body, "data: [DONE]")
+}
+
+func TestRequestTimeoutMiddleware_AppliesDeadline(t *testing.T) {
+	e := echo.New()
+	e.Use(requestTimeoutMiddleware(5 * time.Millisecond))
+
+	var gotErr error
+	e.GET("/sleep", func(c echo.Context) error {
+		select {
+		case <-c.Request().Context().Done():
+			gotErr = c.Request().Context().Err()
+			return c.String(http.StatusOK, "ok")
+		case <-time.After(time.Second):
+			return c.String(http.StatusOK, "too slow")
+		}
+	})
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/sleep", http.NoBody)
+	rec := httptest.NewRecorder()
+	e.ServeHTTP(rec, req)
+
+	require.Error(t, gotErr)
+	assert.ErrorIs(t, gotErr, context.DeadlineExceeded)
+}
diff --git a/pkg/chatserver/types.go b/pkg/chatserver/types.go
new file mode 100644
index 000000000..bbf3a4e06
--- /dev/null
+++ b/pkg/chatserver/types.go
@@ -0,0 +1,289 @@
+package chatserver
+
+import (
+	"encoding/json"
+	"errors"
+	"strings"
+
+	"github.com/openai/openai-go/v3"
+)
+
+// This file declares the OpenAI-compatible request/response types used by
+// /v1/chat/completions and /v1/models. We hand-roll most of them instead of
+// borrowing from github.com/openai/openai-go/v3 because the SDK's response
+// structs are deserialised through its internal `apijson` package and don't
+// have `omitempty` JSON tags; marshalling them with stdlib `encoding/json`
+// produces noisy responses full of empty audio/tool_call/refusal
+// placeholders. `openai.Model` round-trips cleanly with stdlib json, so
+// /v1/models reuses it.
+
+// --- Request --------------------------------------------------------------
+
+// ChatCompletionRequest is the body of a /v1/chat/completions call. We
+// declare every field commonly sent by OpenAI clients so they are accepted
+// without surprise. Whether each field is *acted on* is documented inline.
+type ChatCompletionRequest struct {
+	Model    string                  `json:"model"`
+	Messages []ChatCompletionMessage `json:"messages"`
+	Stream   bool                    `json:"stream,omitempty"`
+
+	// Temperature is parsed and range-checked but not yet plumbed through
+	// to the runtime/model layer (no per-request override exists today).
+	// Set on the agent's YAML configuration to control sampling.
+	Temperature *float64 `json:"temperature,omitempty"`
+	// TopP is parsed and range-checked but not yet plumbed through.
+	TopP *float64 `json:"top_p,omitempty"`
+	// MaxTokens is the maximum number of tokens the model may generate in
+	// the response. Parsed and validated; runtime plumbing is tracked for
+	// a follow-up.
+	MaxTokens *int64 `json:"max_tokens,omitempty"`
+	// Stop is one or more substrings that, if produced, end generation.
+	// Accepted as either a single string or an array of strings, matching
+	// the OpenAI schema. Validated; not yet enforced.
+	Stop StopSequences `json:"stop,omitempty"`
+}
+
+// StopSequences is a JSON-flexible field that accepts either a single
+// string or an array of strings. OpenAI's API uses both shapes
+// interchangeably; clients in the wild send both.
+type StopSequences []string
+
+// UnmarshalJSON decodes `null` to nil, a bare string to a one-element
+// slice, and an array to the slice itself; any other JSON kind is an
+// error. Dispatch on the first byte is safe because data is a single
+// JSON value with no leading whitespace.
+func (s *StopSequences) UnmarshalJSON(data []byte) error {
+	if len(data) == 0 || string(data) == "null" {
+		*s = nil
+		return nil
+	}
+	switch data[0] {
+	case '"':
+		var one string
+		if err := json.Unmarshal(data, &one); err != nil {
+			return err
+		}
+		*s = []string{one}
+		return nil
+	case '[':
+		var many []string
+		if err := json.Unmarshal(data, &many); err != nil {
+			return err
+		}
+		*s = many
+		return nil
+	default:
+		return errors.New("stop must be a string or array of strings")
+	}
+}
+
+// ChatCompletionMessage is a single message in the conversation.
+//
+// On the wire OpenAI accepts message content in two shapes: either a
+// plain string (`"content": "hello"`) or an array of typed parts
+// (`"content": [{"type":"text",...}, {"type":"image_url",...}]`).
+// Both shapes are accepted on the request side; the response always
+// uses the string form for text-only content and the parts form when
+// images or other non-text content are present. The custom JSON
+// (un)marshallers below preserve that union without forcing every Go
+// caller to deal with it.
+type ChatCompletionMessage struct {
+	Role string `json:"role"`
+	// Content is the text content of the message. Populated whether the
+	// wire format used a string or a parts array (the parts' text values
+	// are concatenated).
+	Content string `json:"-"`
+	// Parts holds the original typed parts when the wire format used an
+	// array. Empty when the wire format was a plain string.
+	Parts []ContentPart `json:"-"`
+
+	Name       string              `json:"name,omitempty"`
+	ToolCallID string              `json:"tool_call_id,omitempty"`
+	ToolCalls  []ToolCallReference `json:"tool_calls,omitempty"`
+}
+
+// ContentPart mirrors one entry in OpenAI's typed-parts array. Today the
+// server understands `text` and `image_url` parts; unknown types are
+// preserved in the request payload but ignored when building the
+// session, so future part types degrade gracefully.
+type ContentPart struct {
+	Type     string           `json:"type"`
+	Text     string           `json:"text,omitempty"`
+	ImageURL *ContentImageURL `json:"image_url,omitempty"`
+}
+
+// ContentImageURL carries an image part. URL may be a regular http(s)
+// URL or a data URL (`data:image/png;base64,...`).
+type ContentImageURL struct {
+	URL    string `json:"url"`
+	Detail string `json:"detail,omitempty"`
+}
+
+// jsonMessageEnvelope is the on-the-wire form of ChatCompletionMessage.
+// It exists so we can run the union-shape decoding for `content` without
+// duplicating every other field.
+type jsonMessageEnvelope struct {
+	Role       string              `json:"role"`
+	Content    json.RawMessage     `json:"content,omitempty"`
+	Name       string              `json:"name,omitempty"`
+	ToolCallID string              `json:"tool_call_id,omitempty"`
+	ToolCalls  []ToolCallReference `json:"tool_calls,omitempty"`
+}
+
+// UnmarshalJSON accepts either a string `content` field or an array of
+// typed parts (OpenAI's multimodal shape).
+func (m *ChatCompletionMessage) UnmarshalJSON(data []byte) error {
+	var env jsonMessageEnvelope
+	if err := json.Unmarshal(data, &env); err != nil {
+		return err
+	}
+	m.Role = env.Role
+	m.Name = env.Name
+	m.ToolCallID = env.ToolCallID
+	m.ToolCalls = env.ToolCalls
+
+	if len(env.Content) == 0 || string(env.Content) == "null" {
+		return nil
+	}
+	switch env.Content[0] {
+	case '"':
+		return json.Unmarshal(env.Content, &m.Content)
+	case '[':
+		if err := json.Unmarshal(env.Content, &m.Parts); err != nil {
+			return err
+		}
+		// Pre-compute the flat text so callers that don't care about
+		// images can keep using m.Content. Text parts are joined with a
+		// single space between them.
+		var buf strings.Builder
+		for _, p := range m.Parts {
+			if p.Type == "text" {
+				if buf.Len() > 0 {
+					buf.WriteByte(' ')
+				}
+				buf.WriteString(p.Text)
+			}
+		}
+		m.Content = buf.String()
+		return nil
+	default:
+		return errors.New("content must be a string or array of parts")
+	}
+}
+
+// MarshalJSON emits the parts array when present, otherwise a plain
+// string. Tool/role/name/tool_call_id round-trip verbatim. Note that a
+// message with both empty Content and no Parts omits `content` entirely.
+func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
+	env := jsonMessageEnvelope{
+		Role:       m.Role,
+		Name:       m.Name,
+		ToolCallID: m.ToolCallID,
+		ToolCalls:  m.ToolCalls,
+	}
+	switch {
+	case len(m.Parts) > 0:
+		raw, err := json.Marshal(m.Parts)
+		if err != nil {
+			return nil, err
+		}
+		env.Content = raw
+	case m.Content != "":
+		raw, err := json.Marshal(m.Content)
+		if err != nil {
+			return nil, err
+		}
+		env.Content = raw
+	}
+	return json.Marshal(env)
+}
+
+// ToolCallReference mirrors OpenAI's `tool_calls` entry. The server fills
+// it in on the *response* side so clients can introspect what tools the
+// agent invoked. Tools are still executed server-side; this is purely
+// informational.
+type ToolCallReference struct {
+	// Index is the position of the tool call in the assistant message.
+	// In streaming mode multiple chunks targeting the same Index are
+	// concatenated by the client.
+	// NOTE(review): with omitempty an Index of 0 is dropped from the
+	// wire, while OpenAI's streaming chunks always carry `index` —
+	// confirm downstream clients default a missing index to 0.
+	Index int `json:"index,omitempty"`
+	// ID matches what is later echoed back as ToolCallID on `tool` role
+	// messages — useful when correlating tool calls with their results.
+	ID string `json:"id,omitempty"`
+	// Type is always "function" today; OpenAI reserves the field for
+	// future expansion.
+	Type string `json:"type,omitempty"`
+	// Function carries the tool's name and JSON-encoded arguments.
+	Function ToolCallFunction `json:"function"`
+}
+
+// ToolCallFunction mirrors OpenAI's nested tool function descriptor.
+type ToolCallFunction struct {
+	Name      string `json:"name,omitempty"`
+	Arguments string `json:"arguments,omitempty"`
+}
+
+// --- Non-streaming response -----------------------------------------------
+
+// ChatCompletionResponse is the body of a non-streaming completion reply.
+type ChatCompletionResponse struct {
+	ID      string                 `json:"id"`
+	Object  string                 `json:"object"`
+	Created int64                  `json:"created"`
+	Model   string                 `json:"model"`
+	Choices []ChatCompletionChoice `json:"choices"`
+	Usage   *ChatCompletionUsage   `json:"usage,omitempty"`
+}
+
+// ChatCompletionChoice is one completion alternative; this server always
+// emits exactly one, at index 0.
+type ChatCompletionChoice struct {
+	Index        int                   `json:"index"`
+	Message      ChatCompletionMessage `json:"message"`
+	FinishReason string                `json:"finish_reason"`
+}
+
+// ChatCompletionUsage reports approximate token counts. Best-effort: when
+// the underlying provider doesn't report usage we omit the field entirely.
+type ChatCompletionUsage struct {
+	PromptTokens     int64 `json:"prompt_tokens"`
+	CompletionTokens int64 `json:"completion_tokens"`
+	TotalTokens      int64 `json:"total_tokens"`
+}
+
+// --- Streaming response ---------------------------------------------------
+
+// ChatCompletionStreamResponse is one SSE chunk emitted when the client
+// requests stream: true.
+type ChatCompletionStreamResponse struct {
+	ID      string                       `json:"id"`
+	Object  string                       `json:"object"`
+	Created int64                        `json:"created"`
+	Model   string                       `json:"model"`
+	Choices []ChatCompletionStreamChoice `json:"choices"`
+}
+
+// ChatCompletionStreamChoice is one choice within a streaming chunk;
+// FinishReason is empty on all but the terminating chunk.
+type ChatCompletionStreamChoice struct {
+	Index        int                       `json:"index"`
+	Delta        ChatCompletionStreamDelta `json:"delta"`
+	FinishReason string                    `json:"finish_reason,omitempty"`
+}
+
+// ChatCompletionStreamDelta carries the incremental payload of one chunk:
+// a role announcement, a content fragment, or a tool call.
+type ChatCompletionStreamDelta struct {
+	Role      string              `json:"role,omitempty"`
+	Content   string              `json:"content,omitempty"`
+	ToolCalls []ToolCallReference `json:"tool_calls,omitempty"`
+}
+
+// --- Models endpoint ------------------------------------------------------
+
+// ModelsResponse is the body returned by /v1/models. Each agent in the team
+// is exposed as one entry.
+type ModelsResponse struct {
+	Object string         `json:"object"`
+	Data   []openai.Model `json:"data"`
+}
+
+// --- Errors ---------------------------------------------------------------
+
+// ErrorResponse is the OpenAI-style error envelope returned on 4xx/5xx.
+type ErrorResponse struct {
+	Error ErrorDetail `json:"error"`
+}
+
+// ErrorDetail carries the human-readable message and machine-matchable
+// type/code of one error.
+type ErrorDetail struct {
+	Message string `json:"message"`
+	Type    string `json:"type"`
+	Code    string `json:"code,omitempty"`
+}