diff --git a/cmd/root/mcp.go b/cmd/root/mcp.go index c17c4d008..4180ba67a 100644 --- a/cmd/root/mcp.go +++ b/cmd/root/mcp.go @@ -8,6 +8,7 @@ import ( "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/mcp" + "github.com/docker/docker-agent/pkg/runregistry" "github.com/docker/docker-agent/pkg/telemetry" ) @@ -38,7 +39,7 @@ func newMCPCmd() *cobra.Command { cmd.PersistentFlags().StringVarP(&flags.agentName, "agent", "a", "", "Name of the agent to run (all agents if not specified)") cmd.PersistentFlags().BoolVar(&flags.http, "http", false, "Use streaming HTTP transport instead of stdio") cmd.PersistentFlags().StringVarP(&flags.listenAddr, "listen", "l", "127.0.0.1:8081", "Address to listen on") - cmd.PersistentFlags().StringVar(&flags.attach, "attach", "", "Attach to a running TUI run by pid (or empty for the most recent)") + cmd.PersistentFlags().StringVar(&flags.attach, "attach", "", "Attach to a running TUI run by pid, address, or session id (or empty for the most recent)") cmd.PersistentFlags().Lookup("attach").NoOptDefVal = "latest" cmd.PersistentFlags().StringVar(&flags.runConfig.MCPToolName, "tool-name", "", "Override the MCP tool identifier clients call (defaults to agent name); only valid when exposing a single agent") cmd.PersistentFlags().DurationVar(&flags.runConfig.MCPKeepAlive, "mcp-keepalive", 0, "Interval between MCP keep-alive pings (e.g. 
30s); 0 disables keep-alive") @@ -81,7 +82,7 @@ func (f *mcpFlags) runAttach(ctx context.Context) error { if target == "latest" { target = "" } - rec, err := resolveTarget(target) + rec, err := runregistry.Find(target) if err != nil { return err } diff --git a/cmd/root/proto.go b/cmd/root/proto.go index 0e328b374..8c5175c13 100644 --- a/cmd/root/proto.go +++ b/cmd/root/proto.go @@ -7,12 +7,12 @@ import ( "encoding/json" "fmt" "io" - "net/http" "sync" "github.com/spf13/cobra" "github.com/docker/docker-agent/pkg/api" + "github.com/docker/docker-agent/pkg/runregistry" "github.com/docker/docker-agent/pkg/runtime" "github.com/docker/docker-agent/pkg/telemetry" ) @@ -41,7 +41,7 @@ Output objects always carry a "type" field.`, RunE: flags.run, } - cmd.Flags().StringVar(&flags.target, "to", "", "Target run pid (defaults to the most recent live run)") + cmd.Flags().StringVar(&flags.target, "to", "", targetFlagUsage) return cmd } @@ -51,7 +51,7 @@ func (f *protoFlags) run(cmd *cobra.Command, args []string) (commandErr error) { telemetry.TrackCommand(ctx, "proto", args) defer func() { telemetry.TrackCommandError(ctx, "proto", args, commandErr) }() - rec, err := resolveTarget(f.target) + rec, err := runregistry.Find(f.target) if err != nil { return err } @@ -116,75 +116,65 @@ func readCommands(ctx context.Context, in io.Reader, client *runtime.Client, ses continue } - if err := dispatchProto(ctx, client, sessionID, req, w); err != nil { + handled, err := dispatchProto(ctx, client, sessionID, req, w) + if err != nil { w.send(map[string]any{"id": req.ID, "type": "error", "message": err.Error()}) continue } - w.send(map[string]any{"id": req.ID, "type": "ack"}) + // Skip the generic ack when the dispatcher already produced a + // typed reply (e.g. transcript), to avoid two responses per + // request. 
+ if !handled { + w.send(map[string]any{"id": req.ID, "type": "ack"}) + } } return scanner.Err() } -func dispatchProto(ctx context.Context, client *runtime.Client, sessionID string, req protoRequest, w *protoWriter) error { +// dispatchProto routes a request to the runtime client. The bool return is +// true when dispatchProto already produced a typed reply on w; in that case +// the caller MUST NOT emit a generic "ack". +func dispatchProto(ctx context.Context, client *runtime.Client, sessionID string, req protoRequest, w *protoWriter) (bool, error) { switch req.Type { case "send": - return client.SteerSession(ctx, sessionID, []api.Message{{Content: req.Message}}) + return false, client.SteerSession(ctx, sessionID, []api.Message{{Content: req.Message}}) case "followup": - return client.FollowUpSession(ctx, sessionID, []api.Message{{Content: req.Message}}) + return false, client.FollowUpSession(ctx, sessionID, []api.Message{{Content: req.Message}}) case "resume": decision := req.Decision if decision == "" { decision = "approve" } - return client.ResumeSession(ctx, sessionID, decision, req.Reason, req.ToolName) + return false, client.ResumeSession(ctx, sessionID, decision, req.Reason, req.ToolName) case "interrupt": - return client.ResumeSession(ctx, sessionID, "reject", req.Reason, req.ToolName) + return false, client.ResumeSession(ctx, sessionID, "reject", req.Reason, req.ToolName) case "transcript": sess, err := client.GetSession(ctx, sessionID) if err != nil { - return err + return false, err } messages := sess.Messages if req.Limit > 0 && len(messages) > req.Limit { messages = messages[len(messages)-req.Limit:] } w.send(map[string]any{"id": req.ID, "type": "transcript", "title": sess.Title, "messages": messages}) - return nil + return true, nil default: - return fmt.Errorf("unknown request type %q", req.Type) + return false, fmt.Errorf("unknown request type %q", req.Type) } } +// streamEvents forwards every SSE event received from the run's control +// plane as 
a {"type":"event","event":<payload>} envelope on w. func streamEvents(ctx context.Context, addr, sessionID string, w *protoWriter) { - url := addr + "/api/sessions/" + sessionID + "/events" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + body, err := openEventStream(ctx, addr, sessionID) if err != nil { return } - resp, err := http.DefaultClient.Do(req) - if err != nil || resp.StatusCode >= 400 { - if resp != nil { - resp.Body.Close() - } - return - } - defer resp.Body.Close() + defer body.Close() - scanner := bufio.NewScanner(resp.Body) - scanner.Buffer(make([]byte, 64*1024), 1024*1024) - for scanner.Scan() { - if ctx.Err() != nil { - return - } - line := scanner.Bytes() - after, ok := bytes.CutPrefix(line, []byte("data: ")) - if !ok { - continue - } - var event any - if err := json.Unmarshal(after, &event); err != nil { - continue - } - w.send(map[string]any{"type": "event", "event": event}) - } + _ = readEventStream(ctx, body, func(payload json.RawMessage) error { + w.send(map[string]any{"type": "event", "event": payload}) + return nil + }) } diff --git a/cmd/root/proto_test.go b/cmd/root/proto_test.go index 6a19398df..de856d019 100644 --- a/cmd/root/proto_test.go +++ b/cmd/root/proto_test.go @@ -51,17 +51,20 @@ func TestProtoDispatch_RoutesRequestsToHTTPClient(t *testing.T) { ctx := t.Context() cases := []struct { - req protoRequest - want string + req protoRequest + want string + wantHandled bool }{ - {protoRequest{Type: "send", Message: "hi"}, "POST /api/sessions/s1/steer"}, - {protoRequest{Type: "followup", Message: "later"}, "POST /api/sessions/s1/followup"}, - {protoRequest{Type: "resume", Decision: "approve"}, "POST /api/sessions/s1/resume"}, - {protoRequest{Type: "interrupt"}, "POST /api/sessions/s1/resume"}, - {protoRequest{Type: "transcript"}, "GET /api/sessions/s1"}, + {protoRequest{Type: "send", Message: "hi"}, "POST /api/sessions/s1/steer", false}, + {protoRequest{Type: "followup", Message: "later"}, "POST 
/api/sessions/s1/followup", false}, + {protoRequest{Type: "resume", Decision: "approve"}, "POST /api/sessions/s1/resume", false}, + {protoRequest{Type: "interrupt"}, "POST /api/sessions/s1/resume", false}, + {protoRequest{Type: "transcript"}, "GET /api/sessions/s1", true}, } for _, c := range cases { - require.NoError(t, dispatchProto(ctx, client, "s1", c.req, w)) + handled, err := dispatchProto(ctx, client, "s1", c.req, w) + require.NoError(t, err) + assert.Equal(t, c.wantHandled, handled, "handled flag for %s", c.req.Type) } rec.mu.Lock() @@ -81,7 +84,8 @@ func TestProtoDispatch_UnknownType(t *testing.T) { out := &bytes.Buffer{} w := newProtoWriter(out) - err := dispatchProto(t.Context(), nil, "s1", protoRequest{Type: "nope"}, w) + handled, err := dispatchProto(t.Context(), nil, "s1", protoRequest{Type: "nope"}, w) require.Error(t, err) + assert.False(t, handled) assert.Contains(t, err.Error(), "unknown request type") } diff --git a/cmd/root/run_event_hooks.go b/cmd/root/run_event_hooks.go index 29ac48a1a..07df4be88 100644 --- a/cmd/root/run_event_hooks.go +++ b/cmd/root/run_event_hooks.go @@ -13,6 +13,7 @@ import ( "github.com/docker/docker-agent/pkg/app" "github.com/docker/docker-agent/pkg/runtime" + "github.com/docker/docker-agent/pkg/shellpath" ) type onEventHook struct { @@ -66,10 +67,46 @@ func withEventHooks(hooks []onEventHook) app.Opt { } } +// maxHookOutput caps the diagnostic output we keep for a failed on-event +// hook. Large enough to be useful, small enough that a chatty or runaway +// command can't push unbounded data into the agent's heap. +const maxHookOutput = 4 * 1024 + func runEventHook(command string, payload []byte) { - cmd := exec.CommandContext(context.Background(), "sh", "-c", command) + shell, argsPrefix := shellpath.DetectShell() + // Hooks are detached from the app context on purpose: a hook still + // flushing the last event when the user exits the TUI should be allowed + // to finish. 
Each invocation receives a single event on stdin, processes + // it, and exits; the spawning goroutine ends with the subprocess. + cmd := exec.CommandContext(context.Background(), shell, append(argsPrefix, command)...) cmd.Stdin = bytes.NewReader(payload) + var out boundedBuffer + cmd.Stdout = &out + cmd.Stderr = &out if err := cmd.Run(); err != nil { - slog.Warn("on-event hook failed", "command", command, "error", err) + slog.Warn("on-event hook failed", "command", command, "error", err, "output", strings.TrimSpace(out.String())) } } + +// boundedBuffer captures up to maxHookOutput bytes from a hook subprocess +// and silently discards the rest. It implements only io.Writer so it can be +// assigned to exec.Cmd's Stdout/Stderr without forcing exec to buffer the +// full output internally. +type boundedBuffer struct { + buf bytes.Buffer +} + +func (b *boundedBuffer) Write(p []byte) (int, error) { + if remaining := maxHookOutput - b.buf.Len(); remaining > 0 { + if len(p) > remaining { + b.buf.Write(p[:remaining]) + } else { + b.buf.Write(p) + } + } + return len(p), nil +} + +func (b *boundedBuffer) String() string { + return b.buf.String() +} diff --git a/cmd/root/run_event_hooks_test.go b/cmd/root/run_event_hooks_test.go index 3cd6670da..559f02935 100644 --- a/cmd/root/run_event_hooks_test.go +++ b/cmd/root/run_event_hooks_test.go @@ -1,6 +1,8 @@ package root import ( + "bytes" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -27,3 +29,28 @@ func TestParseOnEventFlags_BadFormat(t *testing.T) { assert.Error(t, err, "expected error for %q", s) } } + +func TestBoundedBuffer_CapsAtMaxHookOutput(t *testing.T) { + var b boundedBuffer + + n, err := b.Write(bytes.Repeat([]byte("a"), maxHookOutput-3)) + require.NoError(t, err) + assert.Equal(t, maxHookOutput-3, n) + + // A write that straddles the cap is fully accepted from the caller's + // perspective (so io.Copy doesn't error) but only the bytes up to the + // cap are retained. 
+ n, err = b.Write([]byte("bbbbbb")) + require.NoError(t, err) + assert.Equal(t, 6, n) + + // Subsequent writes past the cap are silently discarded. + n, err = b.Write([]byte("ccccc")) + require.NoError(t, err) + assert.Equal(t, 5, n) + + got := b.String() + assert.Len(t, got, maxHookOutput) + assert.True(t, strings.HasPrefix(got, strings.Repeat("a", maxHookOutput-3))) + assert.True(t, strings.HasSuffix(got, "bbb")) +} diff --git a/cmd/root/send.go b/cmd/root/send.go index b5cf50492..b64fc1404 100644 --- a/cmd/root/send.go +++ b/cmd/root/send.go @@ -1,10 +1,8 @@ package root import ( - "errors" "fmt" "io" - "strconv" "strings" "github.com/spf13/cobra" @@ -30,21 +28,26 @@ func newSendCmd() *cobra.Command { Long: `Send a message to the most recent docker-agent run that exposes a control plane (started with run --listen). Use - to read the message from stdin. -The target can be selected explicitly with --to ; otherwise the -most recent live run is used.`, +The target can be selected explicitly with --to, accepting a pid, an address +(http://host:port), or a session id. Without --to, the most recent live run +is used.`, Example: ` docker-agent send "summarize the diff" echo "and now write tests" | docker-agent send - - docker-agent send --to 12345 "hello"`, + docker-agent send --to 12345 "hello" + docker-agent send --to http://127.0.0.1:8765 "hi"`, GroupID: "advanced", Args: cobra.ExactArgs(1), RunE: flags.run, } - cmd.Flags().StringVar(&flags.target, "to", "", "Target run pid (defaults to the most recent live run)") + cmd.Flags().StringVar(&flags.target, "to", "", targetFlagUsage) cmd.Flags().BoolVar(&flags.followUp, "followup", false, "Queue as an end-of-turn follow-up instead of mid-turn steering") return cmd } +// targetFlagUsage is the canonical help text for --to across send/watch/proto. 
+const targetFlagUsage = "Target run pid, address (http://host:port), or session id (defaults to the most recent live run)" + func (f *sendFlags) run(cmd *cobra.Command, args []string) (commandErr error) { ctx := cmd.Context() telemetry.TrackCommand(ctx, "send", args) @@ -55,7 +58,7 @@ func (f *sendFlags) run(cmd *cobra.Command, args []string) (commandErr error) { return err } - rec, err := resolveTarget(f.target) + rec, err := runregistry.Find(f.target) if err != nil { return err } @@ -90,34 +93,3 @@ func readMessage(stdin io.Reader, arg string) (string, error) { } return strings.TrimRight(string(buf), "\n"), nil } - -// resolveTarget returns the registry record matching --to. An empty target -// resolves to the most recent live run; a numeric target is matched by pid. -func resolveTarget(target string) (runregistry.Record, error) { - if target == "" { - rec, ok, err := runregistry.Latest() - if err != nil { - return runregistry.Record{}, err - } - if !ok { - return runregistry.Record{}, errors.New("no live docker-agent run found; start one with: docker-agent run --listen 127.0.0.1:0") - } - return rec, nil - } - - pid, err := strconv.Atoi(target) - if err != nil { - return runregistry.Record{}, fmt.Errorf("--to must be a pid, got %q", target) - } - - records, err := runregistry.List() - if err != nil { - return runregistry.Record{}, err - } - for _, r := range records { - if r.PID == pid { - return r, nil - } - } - return runregistry.Record{}, fmt.Errorf("no live run with pid %d", pid) -} diff --git a/cmd/root/send_test.go b/cmd/root/send_test.go index d84cdaab0..0e8bee18b 100644 --- a/cmd/root/send_test.go +++ b/cmd/root/send_test.go @@ -13,16 +13,19 @@ import ( "github.com/docker/docker-agent/pkg/runregistry" ) -func TestResolveTarget_NoLiveRun(t *testing.T) { +// These smoke tests exercise the send command's reliance on +// runregistry.Find. The richer behaviour of Find itself (pid, addr, session +// id, ambiguity) lives in the runregistry package tests. 
+ +func TestSendUsesRunregistryFind_NoLiveRun(t *testing.T) { paths.SetDataDir(t.TempDir()) t.Cleanup(func() { paths.SetDataDir("") }) - _, err := resolveTarget("") - require.Error(t, err) - assert.Contains(t, err.Error(), "no live docker-agent run") + _, err := runregistry.Find("") + require.ErrorIs(t, err, runregistry.ErrNoRun) } -func TestResolveTarget_LatestWhenEmpty(t *testing.T) { +func TestSendUsesRunregistryFind_LatestWhenEmpty(t *testing.T) { paths.SetDataDir(t.TempDir()) t.Cleanup(func() { paths.SetDataDir("") }) @@ -32,12 +35,12 @@ func TestResolveTarget_LatestWhenEmpty(t *testing.T) { require.NoError(t, err) defer cleanup() - rec, err := resolveTarget("") + rec, err := runregistry.Find("") require.NoError(t, err) assert.Equal(t, "s1", rec.SessionID) } -func TestResolveTarget_ByPID(t *testing.T) { +func TestSendUsesRunregistryFind_ByPID(t *testing.T) { paths.SetDataDir(t.TempDir()) t.Cleanup(func() { paths.SetDataDir("") }) @@ -47,13 +50,7 @@ func TestResolveTarget_ByPID(t *testing.T) { require.NoError(t, err) defer cleanup() - rec, err := resolveTarget(strconv.Itoa(os.Getpid())) + rec, err := runregistry.Find(strconv.Itoa(os.Getpid())) require.NoError(t, err) assert.Equal(t, "matched", rec.SessionID) } - -func TestResolveTarget_NonNumericTo(t *testing.T) { - _, err := resolveTarget("not-a-pid") - require.Error(t, err) - assert.Contains(t, err.Error(), "must be a pid") -} diff --git a/cmd/root/sse.go b/cmd/root/sse.go new file mode 100644 index 000000000..1d407b4a4 --- /dev/null +++ b/cmd/root/sse.go @@ -0,0 +1,68 @@ +package root + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// maxErrorBodyBytes caps how much of an error response body we read into +// memory. SSE error replies should be tiny; this just defends the client +// against a misbehaving server that streams an unbounded error payload. 
+const maxErrorBodyBytes = 4 * 1024 + +// openEventStream connects to the SSE event stream of a session running on +// addr and returns the response body. Callers are responsible for closing +// the body. The body produces standard text/event-stream output with one +// JSON payload per "data:" line. +func openEventStream(ctx context.Context, addr, sessionID string) (io.ReadCloser, error) { + url := addr + "/api/sessions/" + sessionID + "/events" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return nil, err + } + req.Header.Set("Accept", "text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("connecting to %s: %w", url, err) + } + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes)) + _ = resp.Body.Close() + return nil, fmt.Errorf("server returned %s: %s", resp.Status, strings.TrimSpace(string(body))) + } + return resp.Body, nil +} + +// readEventStream reads SSE "data:" lines from r and invokes onEvent with +// each raw JSON payload. The function returns when ctx is cancelled (with +// ctx.Err), the stream ends (nil), or onEvent returns an error. +// +// Payloads are passed through as json.RawMessage so callers can either +// forward the bytes verbatim or re-decode them into a typed value without +// paying a redundant unmarshal/marshal round-trip. +func readEventStream(ctx context.Context, r io.Reader, onEvent func(json.RawMessage) error) error { + scanner := bufio.NewScanner(r) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + for scanner.Scan() { + if err := ctx.Err(); err != nil { + return err + } + after, ok := bytes.CutPrefix(scanner.Bytes(), []byte("data: ")) + if !ok { + continue + } + // Copy because the scanner reuses its underlying buffer. + payload := append(json.RawMessage(nil), after...) 
+ if err := onEvent(payload); err != nil { + return err + } + } + return scanner.Err() +} diff --git a/cmd/root/sse_test.go b/cmd/root/sse_test.go new file mode 100644 index 000000000..c93242bb1 --- /dev/null +++ b/cmd/root/sse_test.go @@ -0,0 +1,101 @@ +package root + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReadEventStream_DecodesDataLines(t *testing.T) { + body := strings.NewReader("data: {\"a\":1}\n\ndata: {\"a\":2}\n\n: ignored comment\n") + + var got []string + err := readEventStream(t.Context(), body, func(p json.RawMessage) error { + got = append(got, string(p)) + return nil + }) + require.NoError(t, err) + assert.Equal(t, []string{`{"a":1}`, `{"a":2}`}, got) +} + +func TestReadEventStream_StopsOnHandlerError(t *testing.T) { + body := strings.NewReader("data: {\"a\":1}\n\ndata: {\"a\":2}\n\n") + want := assert.AnError + + count := 0 + err := readEventStream(t.Context(), body, func(p json.RawMessage) error { + count++ + return want + }) + require.ErrorIs(t, err, want) + assert.Equal(t, 1, count) +} + +func TestOpenEventStream_PropagatesHTTPError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("nope")) + })) + defer srv.Close() + + _, err := openEventStream(t.Context(), srv.URL, "missing") + require.Error(t, err) + assert.Contains(t, err.Error(), "404") +} + +func TestOpenEventStream_StreamsSuccess(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"hello\":\"world\"}\n\n")) + })) + defer srv.Close() + + body, err := openEventStream(t.Context(), srv.URL, "s1") + require.NoError(t, err) + defer body.Close() + + var got []string + err = 
readEventStream(t.Context(), body, func(p json.RawMessage) error { + got = append(got, string(p)) + return nil + }) + require.NoError(t, err) + assert.Equal(t, []string{`{"hello":"world"}`}, got) +} + +// TestOpenEventStream_CapsErrorBody guards against a misbehaving server +// pushing an unbounded error body into client memory. +func TestOpenEventStream_CapsErrorBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(strings.Repeat("x", 10*maxErrorBodyBytes))) + })) + defer srv.Close() + + _, err := openEventStream(t.Context(), srv.URL, "s1") + require.Error(t, err) + // The error message embeds at most maxErrorBodyBytes of the body. + assert.LessOrEqual(t, len(err.Error()), maxErrorBodyBytes+256) +} + +// TestReadEventStream_ReturnsCtxErr verifies the helper surfaces ctx +// cancellation so callers can distinguish it from a clean stream end. +func TestReadEventStream_ReturnsCtxErr(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + body := strings.NewReader("data: {}\n\ndata: {}\n\n") + + calls := 0 + err := readEventStream(ctx, body, func(json.RawMessage) error { + calls++ + cancel() + return nil + }) + require.ErrorIs(t, err, context.Canceled) + assert.Equal(t, 1, calls) +} diff --git a/cmd/root/watch.go b/cmd/root/watch.go index 5f8725acb..7a858fc21 100644 --- a/cmd/root/watch.go +++ b/cmd/root/watch.go @@ -1,16 +1,15 @@ package root import ( - "bufio" - "bytes" "context" + "encoding/json" + "errors" "fmt" - "io" - "net/http" "github.com/spf13/cobra" "github.com/docker/docker-agent/pkg/cli" + "github.com/docker/docker-agent/pkg/runregistry" "github.com/docker/docker-agent/pkg/telemetry" ) @@ -28,13 +27,14 @@ func newWatchCmd() *cobra.Command { exposes a control plane (started with run --listen) and print each event as one JSON line on stdout.`, Example: ` docker-agent watch - docker-agent watch --to 12345 | jq`, + 
docker-agent watch --to 12345 | jq + docker-agent watch --to http://127.0.0.1:8765`, GroupID: "advanced", Args: cobra.NoArgs, RunE: flags.run, } - cmd.Flags().StringVar(&flags.target, "to", "", "Target run pid (defaults to the most recent live run)") + cmd.Flags().StringVar(&flags.target, "to", "", targetFlagUsage) return cmd } @@ -43,50 +43,28 @@ func (f *watchFlags) run(cmd *cobra.Command, args []string) (commandErr error) { telemetry.TrackCommand(ctx, "watch", args) defer func() { telemetry.TrackCommandError(ctx, "watch", args, commandErr) }() - rec, err := resolveTarget(f.target) + rec, err := runregistry.Find(f.target) if err != nil { return err } - url := rec.Addr + "/api/sessions/" + rec.SessionID + "/events" - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + body, err := openEventStream(ctx, rec.Addr, rec.SessionID) if err != nil { return err } - req.Header.Set("Accept", "text/event-stream") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("connecting to %s: %w", url, err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("server returned %s: %s", resp.Status, string(body)) - } + defer body.Close() out := cli.NewPrinter(cmd.OutOrStdout()) out.Println("Watching", rec.Addr, "(session", rec.SessionID+")") - return printSSE(ctx, resp.Body, cmd.OutOrStdout()) -} - -func printSSE(ctx context.Context, body io.Reader, out io.Writer) error { - scanner := bufio.NewScanner(body) - scanner.Buffer(make([]byte, 64*1024), 1024*1024) - for scanner.Scan() { - if ctx.Err() != nil { - return nil - } - line := scanner.Bytes() - after, ok := bytes.CutPrefix(line, []byte("data: ")) - if !ok { - continue - } - if _, err := fmt.Fprintln(out, string(after)); err != nil { - return err - } + stdout := cmd.OutOrStdout() + err = readEventStream(ctx, body, func(payload json.RawMessage) error { + _, err := fmt.Fprintln(stdout, string(payload)) + return err + 
}) + // Ctrl+C is the normal way to stop watching; don't treat it as failure. + if errors.Is(err, context.Canceled) { + return nil } - return scanner.Err() + return err } diff --git a/pkg/runregistry/registry.go b/pkg/runregistry/registry.go index 2910f42bd..bdbad399c 100644 --- a/pkg/runregistry/registry.go +++ b/pkg/runregistry/registry.go @@ -11,6 +11,7 @@ import ( "io/fs" "os" "path/filepath" + "slices" "strconv" "strings" "syscall" @@ -38,7 +39,8 @@ func Dir() string { // // The registry directory is created with 0o700 so other local users cannot // enumerate live PIDs/addresses by listing it. Individual records are still -// written with 0o600 for the same reason. +// written with 0o600 for the same reason. Writes go through a sibling temp +// file + rename so concurrent readers never see torn JSON. func Write(rec Record) (func(), error) { if err := os.MkdirAll(Dir(), 0o700); err != nil { return nil, fmt.Errorf("creating run registry dir: %w", err) @@ -49,13 +51,45 @@ func Write(rec Record) (func(), error) { if err != nil { return nil, err } - if err := os.WriteFile(path, buf, 0o600); err != nil { + if err := writeAtomic(path, buf, 0o600); err != nil { return nil, err } return func() { _ = os.Remove(path) }, nil } +// writeAtomic writes data to path through a sibling temp file + rename so +// readers never observe a partially-written file. 
+func writeAtomic(path string, data []byte, perm os.FileMode) error { + dir, name := filepath.Split(path) + tmp, err := os.CreateTemp(dir, name+".*") + if err != nil { + return err + } + tmpName := tmp.Name() + cleanup := func() { _ = os.Remove(tmpName) } + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + cleanup() + return err + } + if err := tmp.Chmod(perm); err != nil { + _ = tmp.Close() + cleanup() + return err + } + if err := tmp.Close(); err != nil { + cleanup() + return err + } + if err := os.Rename(tmpName, path); err != nil { + cleanup() + return err + } + return nil +} + // List returns every record currently registered. Stale records (whose pid is // no longer alive) are skipped and best-effort removed. func List() ([]Record, error) { @@ -96,13 +130,90 @@ func Latest() (Record, bool, error) { if err != nil || len(records) == 0 { return Record{}, false, err } - latest := records[0] - for _, r := range records[1:] { - if r.StartedAt.After(latest.StartedAt) { - latest = r + latest := slices.MaxFunc(records, func(a, b Record) int { + return a.StartedAt.Compare(b.StartedAt) + }) + return latest, true, nil +} + +// ErrNoRun is returned when no live run can be found that satisfies the +// caller's request (empty registry, or no record matches the target). +var ErrNoRun = errors.New("no live docker-agent run found; start one with: docker-agent run --listen 127.0.0.1:0") + +// Find resolves a target reference to a single live record. +// +// An empty target returns the most recently started run. A numeric target is +// matched by PID; a target starting with "http://" or "https://" is matched +// against record addresses; anything else is matched as a (possibly partial) +// session ID. PID and address matches are exact. Session-ID matching prefers +// exact equality and only falls back to substring matching when no record +// matches exactly; ambiguous substring matches return an error so callers +// don't act on the wrong session. 
+func Find(target string) (Record, error) { + target = strings.TrimSpace(target) + if target == "" { + rec, ok, err := Latest() + if err != nil { + return Record{}, err + } + if !ok { + return Record{}, ErrNoRun } + return rec, nil + } + + records, err := List() + if err != nil { + return Record{}, err + } + if len(records) == 0 { + return Record{}, ErrNoRun + } + + if pid, err := strconv.Atoi(target); err == nil { + for _, r := range records { + if r.PID == pid { + return r, nil + } + } + return Record{}, fmt.Errorf("no live run with pid %d: %w", pid, ErrNoRun) + } + + if strings.HasPrefix(target, "http://") || strings.HasPrefix(target, "https://") { + want := strings.TrimRight(target, "/") + for _, r := range records { + if strings.TrimRight(r.Addr, "/") == want { + return r, nil + } + } + return Record{}, fmt.Errorf("no live run at %s: %w", target, ErrNoRun) + } + + // Prefer an exact session-id match: an unambiguous full id must always + // resolve, even when other ids contain it as a substring. + for _, r := range records { + if r.SessionID == target { + return r, nil + } + } + var matches []Record + for _, r := range records { + if strings.Contains(r.SessionID, target) { + matches = append(matches, r) + } + } + switch len(matches) { + case 0: + return Record{}, fmt.Errorf("no live run matches %q (pid, http URL, or session id): %w", target, ErrNoRun) + case 1: + return matches[0], nil + default: + ids := make([]string, 0, len(matches)) + for _, r := range matches { + ids = append(ids, r.SessionID) + } + return Record{}, fmt.Errorf("ambiguous target %q: matches sessions %s", target, strings.Join(ids, ", ")) } - return latest, true, nil } // pidAlive reports whether the given pid corresponds to a live process. 
diff --git a/pkg/runregistry/registry_test.go b/pkg/runregistry/registry_test.go index d83651352..c47f745df 100644 --- a/pkg/runregistry/registry_test.go +++ b/pkg/runregistry/registry_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "testing" "time" @@ -93,6 +94,102 @@ func TestLatest_PicksMostRecent(t *testing.T) { assert.Equal(t, "new", rec.SessionID) } +func TestFind(t *testing.T) { + withTempDataDir(t) + + pid := os.Getpid() + writeRecord(t, "1.json", Record{PID: pid, Addr: "http://127.0.0.1:1111", SessionID: "alpha", StartedAt: time.Now().Add(-time.Hour)}) + writeRecord(t, "2.json", Record{PID: pid, Addr: "http://127.0.0.1:2222", SessionID: "beta", StartedAt: time.Now()}) + + t.Run("empty target returns latest", func(t *testing.T) { + rec, err := Find("") + require.NoError(t, err) + assert.Equal(t, "beta", rec.SessionID) + }) + + t.Run("by pid", func(t *testing.T) { + rec, err := Find(strconv.Itoa(pid)) + require.NoError(t, err) + assert.Equal(t, pid, rec.PID) + }) + + t.Run("by addr", func(t *testing.T) { + rec, err := Find("http://127.0.0.1:1111") + require.NoError(t, err) + assert.Equal(t, "alpha", rec.SessionID) + }) + + t.Run("by addr trims trailing slash", func(t *testing.T) { + rec, err := Find("http://127.0.0.1:2222/") + require.NoError(t, err) + assert.Equal(t, "beta", rec.SessionID) + }) + + t.Run("by session id exact", func(t *testing.T) { + rec, err := Find("alpha") + require.NoError(t, err) + assert.Equal(t, "alpha", rec.SessionID) + }) + + t.Run("unknown pid errors", func(t *testing.T) { + _, err := Find("999999999") + require.Error(t, err) + assert.Contains(t, err.Error(), "no live run with pid") + assert.ErrorIs(t, err, ErrNoRun) + }) + + t.Run("unknown addr errors", func(t *testing.T) { + _, err := Find("http://nope") + require.Error(t, err) + assert.Contains(t, err.Error(), "no live run at") + assert.ErrorIs(t, err, ErrNoRun) + }) + + t.Run("unknown session id errors", func(t *testing.T) { + _, err := Find("zzz") + 
require.Error(t, err) + assert.Contains(t, err.Error(), "no live run matches") + assert.ErrorIs(t, err, ErrNoRun) + }) +} + +func TestFind_AmbiguousSessionID(t *testing.T) { + withTempDataDir(t) + + pid := os.Getpid() + writeRecord(t, "1.json", Record{PID: pid, Addr: "http://a", SessionID: "shared-1", StartedAt: time.Now()}) + writeRecord(t, "2.json", Record{PID: pid, Addr: "http://b", SessionID: "shared-2", StartedAt: time.Now()}) + + _, err := Find("shared") + require.Error(t, err) + assert.Contains(t, err.Error(), "ambiguous") +} + +// TestFind_ExactMatchBeatsSubstring guards against a regression where an +// exact session-id match was reported as ambiguous because a longer id +// contained it as a substring. +func TestFind_ExactMatchBeatsSubstring(t *testing.T) { + withTempDataDir(t) + + pid := os.Getpid() + writeRecord(t, "1.json", Record{PID: pid, Addr: "http://a", SessionID: "abc", StartedAt: time.Now()}) + writeRecord(t, "2.json", Record{PID: pid, Addr: "http://b", SessionID: "abcd", StartedAt: time.Now()}) + + rec, err := Find("abc") + require.NoError(t, err) + assert.Equal(t, "abc", rec.SessionID) +} + +func TestFind_EmptyRegistry(t *testing.T) { + withTempDataDir(t) + + _, err := Find("") + require.ErrorIs(t, err, ErrNoRun) + + _, err = Find("123") + require.ErrorIs(t, err, ErrNoRun) +} + func withTempDataDir(t *testing.T) { t.Helper() paths.SetDataDir(t.TempDir())