Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cmd/root/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/docker/docker-agent/pkg/config"
"github.com/docker/docker-agent/pkg/mcp"
"github.com/docker/docker-agent/pkg/runregistry"
"github.com/docker/docker-agent/pkg/telemetry"
)

Expand Down Expand Up @@ -38,7 +39,7 @@ func newMCPCmd() *cobra.Command {
cmd.PersistentFlags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to run (all agents if not specified)")
cmd.PersistentFlags().BoolVar(&flags.http, "http", false, "Use streaming HTTP transport instead of stdio")
cmd.PersistentFlags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8081", "Address to listen on")
cmd.PersistentFlags().StringVar(&flags.attach, "attach", "", "Attach to a running TUI run by pid (or empty for the most recent)")
cmd.PersistentFlags().StringVar(&flags.attach, "attach", "", "Attach to a running TUI run by pid, address, or session id (or empty for the most recent)")
cmd.PersistentFlags().Lookup("attach").NoOptDefVal = "latest"
cmd.PersistentFlags().StringVar(&flags.runConfig.MCPToolName, "tool-name", "", "Override the MCP tool identifier clients call (defaults to agent name); only valid when exposing a single agent")
cmd.PersistentFlags().DurationVar(&flags.runConfig.MCPKeepAlive, "mcp-keepalive", 0, "Interval between MCP keep-alive pings (e.g. 30s); 0 disables keep-alive")
Expand Down Expand Up @@ -81,7 +82,7 @@ func (f *mcpFlags) runAttach(ctx context.Context) error {
if target == "latest" {
target = ""
}
rec, err := resolveTarget(target)
rec, err := runregistry.Find(target)
if err != nil {
return err
}
Expand Down
70 changes: 30 additions & 40 deletions cmd/root/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"sync"

"github.com/spf13/cobra"

"github.com/docker/docker-agent/pkg/api"
"github.com/docker/docker-agent/pkg/runregistry"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/telemetry"
)
Expand Down Expand Up @@ -41,7 +41,7 @@ Output objects always carry a "type" field.`,
RunE: flags.run,
}

cmd.Flags().StringVar(&flags.target, "to", "", "Target run pid (defaults to the most recent live run)")
cmd.Flags().StringVar(&flags.target, "to", "", targetFlagUsage)
return cmd
}

Expand All @@ -51,7 +51,7 @@ func (f *protoFlags) run(cmd *cobra.Command, args []string) (commandErr error) {
telemetry.TrackCommand(ctx, "proto", args)
defer func() { telemetry.TrackCommandError(ctx, "proto", args, commandErr) }()

rec, err := resolveTarget(f.target)
rec, err := runregistry.Find(f.target)
if err != nil {
return err
}
Expand Down Expand Up @@ -116,75 +116,65 @@ func readCommands(ctx context.Context, in io.Reader, client *runtime.Client, ses
continue
}

if err := dispatchProto(ctx, client, sessionID, req, w); err != nil {
handled, err := dispatchProto(ctx, client, sessionID, req, w)
if err != nil {
w.send(map[string]any{"id": req.ID, "type": "error", "message": err.Error()})
continue
}
w.send(map[string]any{"id": req.ID, "type": "ack"})
// Skip the generic ack when the dispatcher already produced a
// typed reply (e.g. transcript), to avoid two responses per
// request.
if !handled {
w.send(map[string]any{"id": req.ID, "type": "ack"})
}
}
return scanner.Err()
}

func dispatchProto(ctx context.Context, client *runtime.Client, sessionID string, req protoRequest, w *protoWriter) error {
// dispatchProto routes a request to the runtime client. The bool return is
// true when dispatchProto already produced a typed reply on w; in that case
// the caller MUST NOT emit a generic "ack".
func dispatchProto(ctx context.Context, client *runtime.Client, sessionID string, req protoRequest, w *protoWriter) (bool, error) {
switch req.Type {
case "send":
return client.SteerSession(ctx, sessionID, []api.Message{{Content: req.Message}})
return false, client.SteerSession(ctx, sessionID, []api.Message{{Content: req.Message}})
case "followup":
return client.FollowUpSession(ctx, sessionID, []api.Message{{Content: req.Message}})
return false, client.FollowUpSession(ctx, sessionID, []api.Message{{Content: req.Message}})
case "resume":
decision := req.Decision
if decision == "" {
decision = "approve"
}
return client.ResumeSession(ctx, sessionID, decision, req.Reason, req.ToolName)
return false, client.ResumeSession(ctx, sessionID, decision, req.Reason, req.ToolName)
case "interrupt":
return client.ResumeSession(ctx, sessionID, "reject", req.Reason, req.ToolName)
return false, client.ResumeSession(ctx, sessionID, "reject", req.Reason, req.ToolName)
case "transcript":
sess, err := client.GetSession(ctx, sessionID)
if err != nil {
return err
return false, err
}
messages := sess.Messages
if req.Limit > 0 && len(messages) > req.Limit {
messages = messages[len(messages)-req.Limit:]
}
w.send(map[string]any{"id": req.ID, "type": "transcript", "title": sess.Title, "messages": messages})
return nil
return true, nil
default:
return fmt.Errorf("unknown request type %q", req.Type)
return false, fmt.Errorf("unknown request type %q", req.Type)
}
}

// streamEvents forwards every SSE event received from the run's control
// plane as a {"type":"event","event":<raw json>} envelope on w.
func streamEvents(ctx context.Context, addr, sessionID string, w *protoWriter) {
url := addr + "/api/sessions/" + sessionID + "/events"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
body, err := openEventStream(ctx, addr, sessionID)
if err != nil {
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil || resp.StatusCode >= 400 {
if resp != nil {
resp.Body.Close()
}
return
}
defer resp.Body.Close()
defer body.Close()

scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
for scanner.Scan() {
if ctx.Err() != nil {
return
}
line := scanner.Bytes()
after, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
continue
}
var event any
if err := json.Unmarshal(after, &event); err != nil {
continue
}
w.send(map[string]any{"type": "event", "event": event})
}
_ = readEventStream(ctx, body, func(payload json.RawMessage) error {
w.send(map[string]any{"type": "event", "event": payload})
return nil
})
}
22 changes: 13 additions & 9 deletions cmd/root/proto_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,20 @@ func TestProtoDispatch_RoutesRequestsToHTTPClient(t *testing.T) {
ctx := t.Context()

cases := []struct {
req protoRequest
want string
req protoRequest
want string
wantHandled bool
}{
{protoRequest{Type: "send", Message: "hi"}, "POST /api/sessions/s1/steer"},
{protoRequest{Type: "followup", Message: "later"}, "POST /api/sessions/s1/followup"},
{protoRequest{Type: "resume", Decision: "approve"}, "POST /api/sessions/s1/resume"},
{protoRequest{Type: "interrupt"}, "POST /api/sessions/s1/resume"},
{protoRequest{Type: "transcript"}, "GET /api/sessions/s1"},
{protoRequest{Type: "send", Message: "hi"}, "POST /api/sessions/s1/steer", false},
{protoRequest{Type: "followup", Message: "later"}, "POST /api/sessions/s1/followup", false},
{protoRequest{Type: "resume", Decision: "approve"}, "POST /api/sessions/s1/resume", false},
{protoRequest{Type: "interrupt"}, "POST /api/sessions/s1/resume", false},
{protoRequest{Type: "transcript"}, "GET /api/sessions/s1", true},
}
for _, c := range cases {
require.NoError(t, dispatchProto(ctx, client, "s1", c.req, w))
handled, err := dispatchProto(ctx, client, "s1", c.req, w)
require.NoError(t, err)
assert.Equal(t, c.wantHandled, handled, "handled flag for %s", c.req.Type)
}

rec.mu.Lock()
Expand All @@ -81,7 +84,8 @@ func TestProtoDispatch_UnknownType(t *testing.T) {
out := &bytes.Buffer{}
w := newProtoWriter(out)

err := dispatchProto(t.Context(), nil, "s1", protoRequest{Type: "nope"}, w)
handled, err := dispatchProto(t.Context(), nil, "s1", protoRequest{Type: "nope"}, w)
require.Error(t, err)
assert.False(t, handled)
assert.Contains(t, err.Error(), "unknown request type")
}
41 changes: 39 additions & 2 deletions cmd/root/run_event_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/docker/docker-agent/pkg/app"
"github.com/docker/docker-agent/pkg/runtime"
"github.com/docker/docker-agent/pkg/shellpath"
)

type onEventHook struct {
Expand Down Expand Up @@ -66,10 +67,46 @@ func withEventHooks(hooks []onEventHook) app.Opt {
}
}

// maxHookOutput is the upper bound, in bytes, on the diagnostic output
// retained for a failed on-event hook (4 KiB). It is big enough to carry a
// useful error message while preventing a noisy or runaway command from
// growing the agent's heap without limit.
const maxHookOutput = 4 << 10

func runEventHook(command string, payload []byte) {
cmd := exec.CommandContext(context.Background(), "sh", "-c", command)
shell, argsPrefix := shellpath.DetectShell()
// Hooks are detached from the app context on purpose: a hook still
// flushing the last event when the user exits the TUI should be allowed
// to finish. Each invocation receives a single event on stdin, processes
// it, and exits; the spawning goroutine ends with the subprocess.
cmd := exec.CommandContext(context.Background(), shell, append(argsPrefix, command)...)
cmd.Stdin = bytes.NewReader(payload)
var out boundedBuffer
cmd.Stdout = &out
cmd.Stderr = &out
if err := cmd.Run(); err != nil {
slog.Warn("on-event hook failed", "command", command, "error", err)
slog.Warn("on-event hook failed", "command", command, "error", err, "output", strings.TrimSpace(out.String()))
}
}

// boundedBuffer is an io.Writer that keeps at most maxHookOutput bytes of
// what is written to it and drops everything after that without reporting
// an error. It deliberately exposes nothing beyond Write (plus String for
// reading the result back) so it can be plugged into exec.Cmd's
// Stdout/Stderr fields.
type boundedBuffer struct {
	buf bytes.Buffer
}

// Write stores as much of p as still fits under the maxHookOutput cap and
// discards the remainder. It always reports len(p) bytes written with a nil
// error, so the producer never observes a short write once the cap is hit.
func (b *boundedBuffer) Write(p []byte) (int, error) {
	room := maxHookOutput - b.buf.Len()
	if room <= 0 {
		// Cap already reached: swallow the data but claim success.
		return len(p), nil
	}

	kept := p
	if len(kept) > room {
		kept = kept[:room]
	}
	b.buf.Write(kept)
	return len(p), nil
}

// String returns everything retained so far as a string.
func (b *boundedBuffer) String() string {
	return b.buf.String()
}
27 changes: 27 additions & 0 deletions cmd/root/run_event_hooks_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package root

import (
"bytes"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -27,3 +29,28 @@ func TestParseOnEventFlags_BadFormat(t *testing.T) {
assert.Error(t, err, "expected error for %q", s)
}
}

func TestBoundedBuffer_CapsAtMaxHookOutput(t *testing.T) {
var b boundedBuffer

n, err := b.Write(bytes.Repeat([]byte("a"), maxHookOutput-3))
require.NoError(t, err)
assert.Equal(t, maxHookOutput-3, n)

// A write that straddles the cap is fully accepted from the caller's
// perspective (so io.Copy doesn't error) but only the bytes up to the
// cap are retained.
n, err = b.Write([]byte("bbbbbb"))
require.NoError(t, err)
assert.Equal(t, 6, n)

// Subsequent writes past the cap are silently discarded.
n, err = b.Write([]byte("ccccc"))
require.NoError(t, err)
assert.Equal(t, 5, n)

got := b.String()
assert.Len(t, got, maxHookOutput)
assert.True(t, strings.HasPrefix(got, strings.Repeat("a", maxHookOutput-3)))
assert.True(t, strings.HasSuffix(got, "bbb"))
}
48 changes: 10 additions & 38 deletions cmd/root/send.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package root

import (
"errors"
"fmt"
"io"
"strconv"
"strings"

"github.com/spf13/cobra"
Expand All @@ -30,21 +28,26 @@ func newSendCmd() *cobra.Command {
Long: `Send a message to the most recent docker-agent run that exposes a control
plane (started with run --listen). Use - to read the message from stdin.

The target can be selected explicitly with --to <addr|pid>; otherwise the
most recent live run is used.`,
The target can be selected explicitly with --to, accepting a pid, an address
(http://host:port), or a session id. Without --to, the most recent live run
is used.`,
Example: ` docker-agent send "summarize the diff"
echo "and now write tests" | docker-agent send -
docker-agent send --to 12345 "hello"`,
docker-agent send --to 12345 "hello"
docker-agent send --to http://127.0.0.1:8765 "hi"`,
GroupID: "advanced",
Args: cobra.ExactArgs(1),
RunE: flags.run,
}

cmd.Flags().StringVar(&flags.target, "to", "", "Target run pid (defaults to the most recent live run)")
cmd.Flags().StringVar(&flags.target, "to", "", targetFlagUsage)
cmd.Flags().BoolVar(&flags.followUp, "followup", false, "Queue as an end-of-turn follow-up instead of mid-turn steering")
return cmd
}

// targetFlagUsage is the canonical help text for --to across send/watch/proto.
// Sharing one constant keeps the description of the accepted target forms
// (pid, address, or session id) and of the default (the most recent live
// run) identical wherever the flag is registered.
const targetFlagUsage = "Target run pid, address (http://host:port), or session id (defaults to the most recent live run)"

func (f *sendFlags) run(cmd *cobra.Command, args []string) (commandErr error) {
ctx := cmd.Context()
telemetry.TrackCommand(ctx, "send", args)
Expand All @@ -55,7 +58,7 @@ func (f *sendFlags) run(cmd *cobra.Command, args []string) (commandErr error) {
return err
}

rec, err := resolveTarget(f.target)
rec, err := runregistry.Find(f.target)
if err != nil {
return err
}
Expand Down Expand Up @@ -90,34 +93,3 @@ func readMessage(stdin io.Reader, arg string) (string, error) {
}
return strings.TrimRight(string(buf), "\n"), nil
}

// resolveTarget maps the value of --to onto a run registry record.
//
// An empty target selects the most recent live run reported by
// runregistry.Latest. Any other target must parse as an integer pid and is
// matched against the records returned by runregistry.List. An error is
// returned when no live run exists, when the target is not numeric, when no
// record has the requested pid, or when the registry itself fails.
func resolveTarget(target string) (runregistry.Record, error) {
	var none runregistry.Record

	if target == "" {
		latest, found, err := runregistry.Latest()
		switch {
		case err != nil:
			return none, err
		case !found:
			return none, errors.New("no live docker-agent run found; start one with: docker-agent run --listen 127.0.0.1:0")
		}
		return latest, nil
	}

	wantPID, convErr := strconv.Atoi(target)
	if convErr != nil {
		// The conversion error is intentionally replaced by a message that
		// names the flag and echoes the offending value.
		return none, fmt.Errorf("--to must be a pid, got %q", target)
	}

	live, err := runregistry.List()
	if err != nil {
		return none, err
	}
	for _, candidate := range live {
		if candidate.PID == wantPID {
			return candidate, nil
		}
	}
	return none, fmt.Errorf("no live run with pid %d", wantPID)
}
Loading
Loading