From fbff773dc79b754a554955df46f4d38f553c0415 Mon Sep 17 00:00:00 2001 From: Andrew Garner Date: Thu, 5 Mar 2026 17:11:24 -0600 Subject: [PATCH 1/2] Add RunContext to ssh Runner for programmatic cancellation --- ssh/combo_runner.go | 35 +++++++++----- ssh/combo_runner_test.go | 72 +++++++++++++++++++++++++++++ ssh/interactive.go | 7 ++- ssh/interfaces.go | 3 ++ ssh/non_interactive.go | 8 +++- ssh/scp.go | 8 +++- ssh/sshfakes/fake_runner.go | 84 ++++++++++++++++++++++++++++++++++ ssh/sshfakes/fake_scprunner.go | 79 ++++++++++++++++++++++++++++++++ 8 files changed, 282 insertions(+), 14 deletions(-) diff --git a/ssh/combo_runner.go b/ssh/combo_runner.go index a9a68b9a3..53f9b8eb8 100644 --- a/ssh/combo_runner.go +++ b/ssh/combo_runner.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "os" "syscall" "time" @@ -51,6 +52,10 @@ func NewComboRunner( } func (r ComboRunner) Run(connOpts ConnectionOpts, result boshdir.SSHResult, cmdFactory func(boshdir.Host, SSHArgs) boshsys.Command) error { + return r.RunContext(context.Background(), connOpts, result, cmdFactory) +} + +func (r ComboRunner) RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, cmdFactory func(boshdir.Host, SSHArgs) boshsys.Command) error { sess := r.sessionFactory(connOpts, result) sshArgs, err := sess.Start() @@ -68,7 +73,7 @@ func (r ComboRunner) Run(connOpts ConnectionOpts, result boshdir.SSHResult, cmdF ps, doneCh := r.runCmds(cmds) - return r.waitProcs(ps, doneCh, cancelCh) + return r.waitProcs(ctx, ps, doneCh, cancelCh) } type comboRunnerCmd struct { @@ -147,9 +152,11 @@ func (r ComboRunner) runCmds(cmds []comboRunnerCmd) ([]boshsys.Process, chan []b return processes, doneCh } -func (r ComboRunner) waitProcs(ps []boshsys.Process, doneCh chan []boshsys.Result, cancelCh chan struct{}) error { +func (r ComboRunner) waitProcs(ctx context.Context, ps []boshsys.Process, doneCh chan []boshsys.Result, cancelCh chan struct{}) error { r.logger.Debug(r.logTag, "Waiting for all 
processes or cancel signal") + ctxDone := ctx.Done() + for { select { case results := <-doneCh: @@ -169,16 +176,22 @@ func (r ComboRunner) waitProcs(ps []boshsys.Process, doneCh chan []boshsys.Resul case <-cancelCh: r.logger.Debug(r.logTag, "Received cancel signal") + r.terminateAll(ps) + // After terminating, doneCh will eventually fire. + + case <-ctxDone: + r.logger.Debug(r.logTag, "Context cancelled") + r.terminateAll(ps) + // Nil out so we don't re-enter this case while waiting for doneCh. + ctxDone = nil + } + } +} - for _, p := range ps { - err := p.TerminateNicely(10 * time.Second) - if err != nil { - r.logger.Error(r.logTag, "Failed to terminate with error '%s'", err.Error()) - } - } - - // Expecting that after terminating all processes - // doneCh will be signaled at some point. +func (r ComboRunner) terminateAll(ps []boshsys.Process) { + for _, p := range ps { + if err := p.TerminateNicely(10 * time.Second); err != nil { + r.logger.Error(r.logTag, "Failed to terminate with error '%s'", err) } } } diff --git a/ssh/combo_runner_test.go b/ssh/combo_runner_test.go index bebe0ecca..875b656ce 100644 --- a/ssh/combo_runner_test.go +++ b/ssh/combo_runner_test.go @@ -2,6 +2,7 @@ package ssh_test import ( "bytes" + "context" "errors" "os" "strings" @@ -376,5 +377,76 @@ var _ = Describe("ComboRunner", func() { Expect(session.FinishCallCount()).To(Equal(2)) }) }) + + Describe("context cancellation", func() { + It("terminates processes when context is cancelled", func() { + result.Hosts = []boshdir.Host{ + {Host: "127.0.0.1"}, + {Host: "127.0.0.2"}, + } + + proc1 := &fakesys.FakeProcess{ + TerminatedNicelyCallBack: func(p *fakesys.FakeProcess) { + p.WaitCh <- boshsys.Result{} + }, + } + cmdRunner.AddProcess("cmd 127.0.0.1", proc1) + + proc2 := &fakesys.FakeProcess{ + TerminatedNicelyCallBack: func(p *fakesys.FakeProcess) { + p.WaitCh <- boshsys.Result{Error: errors.New("term-err")} + }, + } + cmdRunner.AddProcess("cmd 127.0.0.2", proc2) + + ctx, cancel := 
context.WithCancel(context.Background()) + + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- comboRunner.RunContext(ctx, connOpts, result, cmdFactory) + }() + + cancel() + + var err error + Eventually(errCh).Should(Receive(&err)) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("term-err")) + + Expect(proc1.TerminatedNicely).To(BeTrue()) + Expect(proc2.TerminatedNicely).To(BeTrue()) + Expect(session.FinishCallCount()).To(Equal(1)) + }) + + It("returns without error when context is cancelled and all processes exit cleanly", func() { + result.Hosts = []boshdir.Host{ + {Host: "127.0.0.1"}, + } + + proc1 := &fakesys.FakeProcess{ + TerminatedNicelyCallBack: func(p *fakesys.FakeProcess) { + p.WaitCh <- boshsys.Result{} + }, + } + cmdRunner.AddProcess("cmd 127.0.0.1", proc1) + + ctx, cancel := context.WithCancel(context.Background()) + + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- comboRunner.RunContext(ctx, connOpts, result, cmdFactory) + }() + + cancel() + + var err error + Eventually(errCh).Should(Receive(&err)) + Expect(err).ToNot(HaveOccurred()) + + Expect(proc1.TerminatedNicely).To(BeTrue()) + }) + }) }) }) diff --git a/ssh/interactive.go b/ssh/interactive.go index 4999c09f8..b68af9c38 100644 --- a/ssh/interactive.go +++ b/ssh/interactive.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "os" bosherr "github.com/cloudfoundry/bosh-utils/errors" @@ -18,6 +19,10 @@ func NewInteractiveRunner(comboRunner ComboRunner) InteractiveRunner { } func (r InteractiveRunner) Run(connOpts ConnectionOpts, result boshdir.SSHResult, rawCmd []string) error { + return r.RunContext(context.Background(), connOpts, result, rawCmd) +} + +func (r InteractiveRunner) RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, rawCmd []string) error { if len(result.Hosts) != 1 { return bosherr.Errorf("Interactive SSH only works for a single host at a time") } @@ -39,5 +44,5 @@ func (r 
InteractiveRunner) Run(connOpts ConnectionOpts, result boshdir.SSHResult } } - return r.comboRunner.Run(connOpts, result, cmdFactory) + return r.comboRunner.RunContext(ctx, connOpts, result, cmdFactory) } diff --git a/ssh/interfaces.go b/ssh/interfaces.go index 3dff7c08a..7e0c1f2fc 100644 --- a/ssh/interfaces.go +++ b/ssh/interfaces.go @@ -1,6 +1,7 @@ package ssh import ( + "context" "io" boshdir "github.com/cloudfoundry/bosh-cli/v7/director" @@ -13,12 +14,14 @@ import ( type Runner interface { Run(ConnectionOpts, boshdir.SSHResult, []string) error + RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, cmd []string) error } //counterfeiter:generate . SCPRunner type SCPRunner interface { Run(ConnectionOpts, boshdir.SSHResult, SCPArgs) error + RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, scpArgs SCPArgs) error } type ConnectionOpts struct { diff --git a/ssh/non_interactive.go b/ssh/non_interactive.go index 0b9a72a34..3627edb5a 100644 --- a/ssh/non_interactive.go +++ b/ssh/non_interactive.go @@ -1,6 +1,8 @@ package ssh import ( + "context" + bosherr "github.com/cloudfoundry/bosh-utils/errors" boshsys "github.com/cloudfoundry/bosh-utils/system" @@ -16,6 +18,10 @@ func NewNonInteractiveRunner(comboRunner ComboRunner) NonInteractiveRunner { } func (r NonInteractiveRunner) Run(connOpts ConnectionOpts, result boshdir.SSHResult, rawCmd []string) error { + return r.RunContext(context.Background(), connOpts, result, rawCmd) +} + +func (r NonInteractiveRunner) RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, rawCmd []string) error { if len(result.Hosts) == 0 { return bosherr.Errorf("Non-interactive SSH expects at least one host") } @@ -31,5 +37,5 @@ func (r NonInteractiveRunner) Run(connOpts ConnectionOpts, result boshdir.SSHRes } } - return r.comboRunner.Run(connOpts, result, cmdFactory) + return r.comboRunner.RunContext(ctx, connOpts, result, cmdFactory) } diff --git 
a/ssh/scp.go b/ssh/scp.go index 172f61b50..8cebb2b51 100644 --- a/ssh/scp.go +++ b/ssh/scp.go @@ -1,6 +1,8 @@ package ssh import ( + "context" + boshsys "github.com/cloudfoundry/bosh-utils/system" boshdir "github.com/cloudfoundry/bosh-cli/v7/director" @@ -15,6 +17,10 @@ func NewSCPRunner(comboRunner ComboRunner) SCPRunnerImpl { } func (r SCPRunnerImpl) Run(connOpts ConnectionOpts, result boshdir.SSHResult, scpArgs SCPArgs) error { + return r.RunContext(context.Background(), connOpts, result, scpArgs) +} + +func (r SCPRunnerImpl) RunContext(ctx context.Context, connOpts ConnectionOpts, result boshdir.SSHResult, scpArgs SCPArgs) error { cmdFactory := func(host boshdir.Host, sshArgs SSHArgs) boshsys.Command { return boshsys.Command{ Name: "scp", @@ -22,5 +28,5 @@ func (r SCPRunnerImpl) Run(connOpts ConnectionOpts, result boshdir.SSHResult, sc } } - return r.comboRunner.Run(connOpts, result, cmdFactory) + return r.comboRunner.RunContext(ctx, connOpts, result, cmdFactory) } diff --git a/ssh/sshfakes/fake_runner.go b/ssh/sshfakes/fake_runner.go index d897d34fa..68cb533a7 100644 --- a/ssh/sshfakes/fake_runner.go +++ b/ssh/sshfakes/fake_runner.go @@ -2,6 +2,7 @@ package sshfakes import ( + "context" "sync" "github.com/cloudfoundry/bosh-cli/v7/director" @@ -22,6 +23,20 @@ type FakeRunner struct { runReturnsOnCall map[int]struct { result1 error } + RunContextStub func(context.Context, ssh.ConnectionOpts, director.SSHResult, []string) error + runContextMutex sync.RWMutex + runContextArgsForCall []struct { + arg1 context.Context + arg2 ssh.ConnectionOpts + arg3 director.SSHResult + arg4 []string + } + runContextReturns struct { + result1 error + } + runContextReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -94,6 +109,75 @@ func (fake *FakeRunner) RunReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeRunner) RunContext(arg1 context.Context, arg2 ssh.ConnectionOpts, arg3 
director.SSHResult, arg4 []string) error { + var arg4Copy []string + if arg4 != nil { + arg4Copy = make([]string, len(arg4)) + copy(arg4Copy, arg4) + } + fake.runContextMutex.Lock() + ret, specificReturn := fake.runContextReturnsOnCall[len(fake.runContextArgsForCall)] + fake.runContextArgsForCall = append(fake.runContextArgsForCall, struct { + arg1 context.Context + arg2 ssh.ConnectionOpts + arg3 director.SSHResult + arg4 []string + }{arg1, arg2, arg3, arg4Copy}) + stub := fake.RunContextStub + fakeReturns := fake.runContextReturns + fake.recordInvocation("RunContext", []interface{}{arg1, arg2, arg3, arg4Copy}) + fake.runContextMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRunner) RunContextCallCount() int { + fake.runContextMutex.RLock() + defer fake.runContextMutex.RUnlock() + return len(fake.runContextArgsForCall) +} + +func (fake *FakeRunner) RunContextCalls(stub func(context.Context, ssh.ConnectionOpts, director.SSHResult, []string) error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = stub +} + +func (fake *FakeRunner) RunContextArgsForCall(i int) (context.Context, ssh.ConnectionOpts, director.SSHResult, []string) { + fake.runContextMutex.RLock() + defer fake.runContextMutex.RUnlock() + argsForCall := fake.runContextArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeRunner) RunContextReturns(result1 error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = nil + fake.runContextReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) RunContextReturnsOnCall(i int, result1 error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = nil + if fake.runContextReturnsOnCall == nil { + fake.runContextReturnsOnCall = 
make(map[int]struct { + result1 error + }) + } + fake.runContextReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeRunner) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() diff --git a/ssh/sshfakes/fake_scprunner.go b/ssh/sshfakes/fake_scprunner.go index 786a9747a..841c5e057 100644 --- a/ssh/sshfakes/fake_scprunner.go +++ b/ssh/sshfakes/fake_scprunner.go @@ -2,6 +2,7 @@ package sshfakes import ( + "context" "sync" "github.com/cloudfoundry/bosh-cli/v7/director" @@ -22,6 +23,20 @@ type FakeSCPRunner struct { runReturnsOnCall map[int]struct { result1 error } + RunContextStub func(context.Context, ssh.ConnectionOpts, director.SSHResult, ssh.SCPArgs) error + runContextMutex sync.RWMutex + runContextArgsForCall []struct { + arg1 context.Context + arg2 ssh.ConnectionOpts + arg3 director.SSHResult + arg4 ssh.SCPArgs + } + runContextReturns struct { + result1 error + } + runContextReturnsOnCall map[int]struct { + result1 error + } invocations map[string][][]interface{} invocationsMutex sync.RWMutex } @@ -89,6 +104,70 @@ func (fake *FakeSCPRunner) RunReturnsOnCall(i int, result1 error) { }{result1} } +func (fake *FakeSCPRunner) RunContext(arg1 context.Context, arg2 ssh.ConnectionOpts, arg3 director.SSHResult, arg4 ssh.SCPArgs) error { + fake.runContextMutex.Lock() + ret, specificReturn := fake.runContextReturnsOnCall[len(fake.runContextArgsForCall)] + fake.runContextArgsForCall = append(fake.runContextArgsForCall, struct { + arg1 context.Context + arg2 ssh.ConnectionOpts + arg3 director.SSHResult + arg4 ssh.SCPArgs + }{arg1, arg2, arg3, arg4}) + stub := fake.RunContextStub + fakeReturns := fake.runContextReturns + fake.recordInvocation("RunContext", []interface{}{arg1, arg2, arg3, arg4}) + fake.runContextMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeSCPRunner) 
RunContextCallCount() int { + fake.runContextMutex.RLock() + defer fake.runContextMutex.RUnlock() + return len(fake.runContextArgsForCall) +} + +func (fake *FakeSCPRunner) RunContextCalls(stub func(context.Context, ssh.ConnectionOpts, director.SSHResult, ssh.SCPArgs) error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = stub +} + +func (fake *FakeSCPRunner) RunContextArgsForCall(i int) (context.Context, ssh.ConnectionOpts, director.SSHResult, ssh.SCPArgs) { + fake.runContextMutex.RLock() + defer fake.runContextMutex.RUnlock() + argsForCall := fake.runContextArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeSCPRunner) RunContextReturns(result1 error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = nil + fake.runContextReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeSCPRunner) RunContextReturnsOnCall(i int, result1 error) { + fake.runContextMutex.Lock() + defer fake.runContextMutex.Unlock() + fake.RunContextStub = nil + if fake.runContextReturnsOnCall == nil { + fake.runContextReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.runContextReturnsOnCall[i] = struct { + result1 error + }{result1} +} + func (fake *FakeSCPRunner) Invocations() map[string][][]interface{} { fake.invocationsMutex.RLock() defer fake.invocationsMutex.RUnlock() From 1850e8da080b20ecac814b65eba006c6f6f410b9 Mon Sep 17 00:00:00 2001 From: Andrew Garner Date: Thu, 5 Mar 2026 11:54:21 -0600 Subject: [PATCH 2/2] Add --stream-logs flag to run-errand for live log tailing via SSH When --stream-logs is set, the errand is started asynchronously and task events are polled to discover running instances. SSH sessions are established to each instance to tail errand log files in real time, interleaved with the standard task event output. 
Requirements: - The CLI must have SSH access to the BOSH VMs where errands run (directly or via a gateway configured with --gw-* flags). - Errands must write output to log files under /var/vcap/sys/log/. By default, the CLI tails /var/vcap/sys/log/$ERRAND/$ERRAND.{stdout,stderr}.log. Errands that write to non-standard paths will not have their output streamed unless --stream-log-path is specified. - Use --stream-log-path to override the default log location, e.g. --stream-log-path "my-job/*.log". The path is relative to /var/vcap/sys/log/ and supports shell glob and brace expansion. --- cmd/cmd.go | 9 +- cmd/errand_event_watcher.go | 175 ++++++++ cmd/errand_event_watcher_test.go | 351 +++++++++++++++ cmd/factory.go | 9 +- cmd/opts/opts.go | 5 + cmd/run_errand.go | 207 ++++++++- cmd/run_errand_test.go | 519 +++++++++++++++++++++- director/directorfakes/fake_deployment.go | 328 ++++++++++++++ director/errands.go | 81 +++- director/interfaces.go | 4 + director/task_client_request.go | 64 +++ director/tasks.go | 12 + 12 files changed, 1741 insertions(+), 23 deletions(-) create mode 100644 cmd/errand_event_watcher.go create mode 100644 cmd/errand_event_watcher_test.go diff --git a/cmd/cmd.go b/cmd/cmd.go index c3d858e4a..fd8e62abe 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -226,7 +226,14 @@ func (c Cmd) Execute() (cmdErr error) { case *RunErrandOpts: director, deployment := c.directorAndDeployment() downloader := NewUIDownloader(director, deps.Time, deps.FS, deps.UI) - return NewRunErrandCmd(deployment, downloader, deps.UI).Run(*opts) + var nonIntSSHRunner boshssh.Runner + var taskReporter boshdir.TaskReporter + if opts.StreamLogs { + sshProvider := boshssh.NewProvider(deps.CmdRunner, deps.FS, deps.UI, deps.Logger) + nonIntSSHRunner = sshProvider.NewSSHRunner(false) + taskReporter = boshuit.NewReporter(deps.UI, true) + } + return NewRunErrandCmd(deployment, downloader, deps.UI, nonIntSSHRunner, taskReporter, nil).Run(*opts) case *AttachDiskOpts: return 
NewAttachDiskCmd(c.deployment()).Run(*opts) diff --git a/cmd/errand_event_watcher.go b/cmd/errand_event_watcher.go new file mode 100644 index 000000000..918823160 --- /dev/null +++ b/cmd/errand_event_watcher.go @@ -0,0 +1,175 @@ +package cmd + +import ( + "encoding/json" + "strings" + "time" + + boshdir "github.com/cloudfoundry/bosh-cli/v7/director" +) + +type taskEvent struct { + Stage string `json:"stage"` + State string `json:"state"` + Task string `json:"task"` + Time int64 `json:"time"` + Index int `json:"index"` + Total int `json:"total"` + Progress int `json:"progress"` +} + +type ErrandEventWatcher struct { + deployment boshdir.Deployment + taskID int + pollDelay time.Duration + taskReporter boshdir.TaskReporter +} + +func NewErrandEventWatcher(deployment boshdir.Deployment, taskID int, pollDelay time.Duration) *ErrandEventWatcher { + return &ErrandEventWatcher{ + deployment: deployment, + taskID: taskID, + pollDelay: pollDelay, + } +} + +func (w *ErrandEventWatcher) WithTaskReporter(reporter boshdir.TaskReporter) *ErrandEventWatcher { + w.taskReporter = reporter + return w +} + +// Watch polls the task event stream and sends discovered instance slugs +// (e.g. "smoke-tests/abc-123") on the returned channel. The channel is closed +// when the task is no longer running. If a TaskReporter is set, event chunks +// are also fed to it for real-time formatted output. 
+func (w *ErrandEventWatcher) Watch(stopCh <-chan struct{}) <-chan string { + slugCh := make(chan string, 16) + + if w.taskReporter != nil { + w.taskReporter.TaskStarted(w.taskID) + } + + go func() { + defer close(slugCh) + + var offset int + seen := map[string]bool{} + var lastState string + + for { + select { + case <-stopCh: + return + default: + } + + chunk, newOffset, err := w.deployment.FetchTaskOutputChunk(w.taskID, offset, "event") + if err == nil && len(chunk) > 0 { + offset = newOffset + if w.taskReporter != nil { + w.taskReporter.TaskOutputChunk(w.taskID, chunk) + } + for _, slug := range parseEventChunk(chunk) { + if !seen[slug] { + seen[slug] = true + select { + case slugCh <- slug: + case <-stopCh: + return + } + } + } + } + + state, err := w.deployment.TaskState(w.taskID) + if err != nil || !isTaskRunning(state) { + if err == nil { + lastState = state + } + // One final fetch to catch any remaining events + chunk, _, err = w.deployment.FetchTaskOutputChunk(w.taskID, offset, "event") + if err == nil && len(chunk) > 0 { + if w.taskReporter != nil { + w.taskReporter.TaskOutputChunk(w.taskID, chunk) + } + for _, slug := range parseEventChunk(chunk) { + if !seen[slug] { + seen[slug] = true + select { + case slugCh <- slug: + case <-stopCh: + return + } + } + } + } + if w.taskReporter != nil { + w.taskReporter.TaskFinished(w.taskID, lastState) + } + return + } + + select { + case <-time.After(w.pollDelay): + case <-stopCh: + return + } + } + }() + + return slugCh +} + +func parseEventChunk(chunk []byte) []string { + var slugs []string + + for _, line := range strings.Split(string(chunk), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + slug := parseErrandEventLine(line) + if slug != "" { + slugs = append(slugs, slug) + } + } + + return slugs +} + +func parseErrandEventLine(line string) string { + var ev taskEvent + if err := json.Unmarshal([]byte(line), &ev); err != nil { + return "" + } + + if ev.Stage != "Running errand" || 
ev.State != "started" { + return "" + } + + return ParseInstanceSlug(ev.Task) +} + +// ParseInstanceSlug extracts "group/uuid" from a task field like "group/uuid (idx)". +func ParseInstanceSlug(task string) string { + task = strings.TrimSpace(task) + if task == "" { + return "" + } + + // Strip trailing " (N)" index suffix if present + if idx := strings.LastIndex(task, " ("); idx >= 0 { + task = task[:idx] + } + + if !strings.Contains(task, "/") { + return "" + } + + return task +} + +func isTaskRunning(state string) bool { + return state == "queued" || state == "processing" || state == "cancelling" +} diff --git a/cmd/errand_event_watcher_test.go b/cmd/errand_event_watcher_test.go new file mode 100644 index 000000000..5adfdc708 --- /dev/null +++ b/cmd/errand_event_watcher_test.go @@ -0,0 +1,351 @@ +package cmd_test + +import ( + "encoding/json" + "strings" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/cloudfoundry/bosh-cli/v7/cmd" + fakedir "github.com/cloudfoundry/bosh-cli/v7/director/directorfakes" +) + +var _ = Describe("ErrandEventWatcher", func() { + Describe("ParseInstanceSlug", func() { + It("extracts group/uuid from 'group/uuid (idx)'", func() { + Expect(cmd.ParseInstanceSlug("smoke-tests/abc-123 (0)")).To(Equal("smoke-tests/abc-123")) + }) + + It("extracts group/uuid when no index suffix", func() { + Expect(cmd.ParseInstanceSlug("smoke-tests/abc-123")).To(Equal("smoke-tests/abc-123")) + }) + + It("returns empty for group-only (no slash)", func() { + Expect(cmd.ParseInstanceSlug("smoke-tests")).To(Equal("")) + }) + + It("returns empty for empty string", func() { + Expect(cmd.ParseInstanceSlug("")).To(Equal("")) + }) + + It("handles whitespace", func() { + Expect(cmd.ParseInstanceSlug(" smoke-tests/abc-123 (2) ")).To(Equal("smoke-tests/abc-123")) + }) + }) + + Describe("Watch", func() { + var ( + deployment *fakedir.FakeDeployment + ) + + BeforeEach(func() { + deployment = &fakedir.FakeDeployment{} + }) + + makeEvent := 
func(stage, state, task string) string { + ev := map[string]any{ + "stage": stage, + "state": state, + "task": task, + "time": 1772657703, + } + b, err := json.Marshal(ev) + Expect(err).NotTo(HaveOccurred()) + return string(b) + } + + It("parses 'Running errand' started events and emits instance slugs", func() { + events := strings.Join([]string{ + makeEvent("Preparing deployment", "started", "Preparing deployment"), + makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)"), + }, "\n") + + callCount := 0 + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + callCount++ + if callCount == 1 { + return []byte(events), len(events), nil + } + return nil, offset, nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(Equal([]string{"smoke-tests/abc-123"})) + }) + + It("handles multiple instances", func() { + events := strings.Join([]string{ + makeEvent("Running errand", "started", "mysql/aaa-111 (0)"), + makeEvent("Running errand", "started", "mysql/bbb-222 (1)"), + makeEvent("Running errand", "started", "mysql/ccc-333 (2)"), + }, "\n") + + callCount := 0 + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + callCount++ + if callCount == 1 { + return []byte(events), len(events), nil + } + return nil, offset, nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(ConsistOf("mysql/aaa-111", "mysql/bbb-222", "mysql/ccc-333")) + }) + + It("ignores non-errand events", func() { + events := strings.Join([]string{ + 
makeEvent("Preparing deployment", "started", "Preparing deployment"), + makeEvent("Creating missing vms", "started", "smoke-tests/abc-123 (0)"), + makeEvent("Fetching logs", "started", "smoke-tests/abc-123 (0)"), + }, "\n") + + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + return []byte(events), len(events), nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(BeEmpty()) + }) + + It("ignores finished/failed states for Running errand", func() { + events := strings.Join([]string{ + makeEvent("Running errand", "finished", "smoke-tests/abc-123 (0)"), + makeEvent("Running errand", "failed", "smoke-tests/abc-123 (0)"), + }, "\n") + + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + return []byte(events), len(events), nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(BeEmpty()) + }) + + It("handles malformed JSON gracefully", func() { + events := strings.Join([]string{ + "this is not json", + "", + makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)"), + "{broken json", + }, "\n") + + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + return []byte(events), len(events), nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + 
Expect(slugs).To(Equal([]string{"smoke-tests/abc-123"})) + }) + + It("deduplicates slugs", func() { + events := strings.Join([]string{ + makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)"), + makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)"), + }, "\n") + + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + return []byte(events), len(events), nil + } + deployment.TaskStateReturns("done", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(Equal([]string{"smoke-tests/abc-123"})) + }) + + It("polls incrementally using offset", func() { + event1 := makeEvent("Running errand", "started", "mysql/aaa-111 (0)") + event2 := makeEvent("Running errand", "started", "mysql/bbb-222 (1)") + + callCount := 0 + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + callCount++ + switch callCount { + case 1: + return []byte(event1 + "\n"), len(event1) + 1, nil + case 2: + return []byte(event2 + "\n"), len(event1) + 1 + len(event2) + 1, nil + default: + return nil, offset, nil + } + } + + taskStateCallCount := 0 + deployment.TaskStateStub = func(id int) (string, error) { + taskStateCallCount++ + if taskStateCallCount <= 2 { + return "processing", nil + } + return "done", nil + } + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(ConsistOf("mysql/aaa-111", "mysql/bbb-222")) + Expect(deployment.FetchTaskOutputChunkCallCount()).To(BeNumerically(">=", 2)) + + _, secondOffset, _ := deployment.FetchTaskOutputChunkArgsForCall(1) + Expect(secondOffset).To(Equal(len(event1) + 1)) + }) + + It("feeds event chunks to 
TaskReporter when set", func() { + events := strings.Join([]string{ + makeEvent("Preparing deployment", "started", "Preparing deployment"), + makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)"), + }, "\n") + + callCount := 0 + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + callCount++ + if callCount == 1 { + return []byte(events), len(events), nil + } + return nil, offset, nil + } + deployment.TaskStateReturns("done", nil) + + reporter := &fakedir.FakeTaskReporter{} + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + watcher.WithTaskReporter(reporter) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + for range slugCh { + } + + Expect(reporter.TaskStartedCallCount()).To(Equal(1)) + startedID := reporter.TaskStartedArgsForCall(0) + Expect(startedID).To(Equal(42)) + + Expect(reporter.TaskOutputChunkCallCount()).To(BeNumerically(">=", 1)) + chunkID, chunkData := reporter.TaskOutputChunkArgsForCall(0) + Expect(chunkID).To(Equal(42)) + Expect(string(chunkData)).To(ContainSubstring("Running errand")) + + Expect(reporter.TaskFinishedCallCount()).To(Equal(1)) + finishedID, finishedState := reporter.TaskFinishedArgsForCall(0) + Expect(finishedID).To(Equal(42)) + Expect(finishedState).To(Equal("done")) + }) + + It("discovers slugs from the final fetch after task completes and reports them", func() { + lateEvent := makeEvent("Running errand", "started", "smoke-tests/abc-123 (0)") + + // The goroutine calls FetchTaskOutputChunk twice per iteration + // (once in the normal poll, once in the final fetch when the task + // is done). We need the data to appear only in the final fetch + // (call 3), not the normal poll (calls 1 and 2). 
+ // + // Iteration 1: fetch(1)=empty, TaskState="done" -> final fetch(2)=event + fetchCallCount := 0 + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + fetchCallCount++ + if fetchCallCount == 2 { + return []byte(lateEvent), len(lateEvent), nil + } + return nil, offset, nil + } + + deployment.TaskStateReturns("done", nil) + + reporter := &fakedir.FakeTaskReporter{} + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + watcher.WithTaskReporter(reporter) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(Equal([]string{"smoke-tests/abc-123"})) + + Expect(reporter.TaskOutputChunkCallCount()).To(Equal(1)) + chunkID, chunkData := reporter.TaskOutputChunkArgsForCall(0) + Expect(chunkID).To(Equal(42)) + Expect(string(chunkData)).To(ContainSubstring("Running errand")) + }) + + It("stops when stopCh is closed", func() { + deployment.FetchTaskOutputChunkStub = func(taskID, offset int, type_ string) ([]byte, int, error) { + return nil, offset, nil + } + deployment.TaskStateReturns("processing", nil) + + watcher := cmd.NewErrandEventWatcher(deployment, 42, 0) + stopCh := make(chan struct{}) + slugCh := watcher.Watch(stopCh) + + close(stopCh) + + var slugs []string + for s := range slugCh { + slugs = append(slugs, s) + } + + Expect(slugs).To(BeEmpty()) + }) + }) +}) diff --git a/cmd/factory.go b/cmd/factory.go index 77d9a1f4e..dda364136 100644 --- a/cmd/factory.go +++ b/cmd/factory.go @@ -94,10 +94,11 @@ func (f Factory) New(args []string) (Cmd, error) { return nil } - boshOpts.SSH.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck - boshOpts.SCP.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck - boshOpts.Logs.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck - boshOpts.Pcap.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck + boshOpts.SSH.GatewayFlags.UUIDGen = 
f.deps.UUIDGen //nolint:staticcheck + boshOpts.SCP.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck + boshOpts.Logs.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck + boshOpts.Pcap.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck + boshOpts.RunErrand.GatewayFlags.UUIDGen = f.deps.UUIDGen //nolint:staticcheck helpText := bytes.NewBufferString("") parser.WriteHelp(helpText) diff --git a/cmd/opts/opts.go b/cmd/opts/opts.go index ca3673fcf..9ab33dc4c 100644 --- a/cmd/opts/opts.go +++ b/cmd/opts/opts.go @@ -716,6 +716,11 @@ type RunErrandOpts struct { DownloadLogs bool `long:"download-logs" description:"Download logs"` LogsDirectory DirOrCWDArg `long:"logs-dir" description:"Destination directory for logs" default:"."` + StreamLogs bool `long:"stream-logs" description:"Stream errand log files via SSH while errand runs"` + StreamLogPath string `long:"stream-log-path" description:"Log file path to tail (default: /var/vcap/sys/log/$ERRAND/$ERRAND.{stdout,stderr}.log)"` + + GatewayFlags + cmd } diff --git a/cmd/run_errand.go b/cmd/run_errand.go index 87d26796c..5461f4e03 100644 --- a/cmd/run_errand.go +++ b/cmd/run_errand.go @@ -1,31 +1,69 @@ package cmd import ( - bosherr "github.com/cloudfoundry/bosh-utils/errors" - + "context" "fmt" + "os/signal" + "regexp" + "strings" + "sync" + "syscall" + "time" + + bosherr "github.com/cloudfoundry/bosh-utils/errors" - . "github.com/cloudfoundry/bosh-cli/v7/cmd/opts" //nolint:staticcheck boshdir "github.com/cloudfoundry/bosh-cli/v7/director" + boshssh "github.com/cloudfoundry/bosh-cli/v7/ssh" biui "github.com/cloudfoundry/bosh-cli/v7/ui" boshtbl "github.com/cloudfoundry/bosh-cli/v7/ui/table" + + . 
"github.com/cloudfoundry/bosh-cli/v7/cmd/opts" //nolint:staticcheck ) +var safeLogPathRe = regexp.MustCompile(`^[a-zA-Z0-9._\-/*{},]+$`) + +type ErrandContextFunc func() (context.Context, context.CancelFunc) + type RunErrandCmd struct { - deployment boshdir.Deployment - downloader Downloader - ui biui.UI + deployment boshdir.Deployment + downloader Downloader + ui biui.UI + nonIntSSHRunner boshssh.Runner + taskReporter boshdir.TaskReporter + ctxFactory ErrandContextFunc } func NewRunErrandCmd( deployment boshdir.Deployment, downloader Downloader, ui biui.UI, + nonIntSSHRunner boshssh.Runner, + taskReporter boshdir.TaskReporter, + ctxFactory ErrandContextFunc, ) RunErrandCmd { - return RunErrandCmd{deployment: deployment, downloader: downloader, ui: ui} + if ctxFactory == nil { + ctxFactory = func() (context.Context, context.CancelFunc) { + return signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + } + } + return RunErrandCmd{ + deployment: deployment, + downloader: downloader, + ui: ui, + nonIntSSHRunner: nonIntSSHRunner, + taskReporter: taskReporter, + ctxFactory: ctxFactory, + } } func (c RunErrandCmd) Run(opts RunErrandOpts) error { + if opts.StreamLogs { + return c.runWithStreaming(opts) + } + return c.runWithoutStreaming(opts) +} + +func (c RunErrandCmd) runWithoutStreaming(opts RunErrandOpts) error { results, err := c.deployment.RunErrand( opts.Args.Name, opts.KeepAlive, @@ -36,9 +74,135 @@ func (c RunErrandCmd) Run(opts RunErrandOpts) error { return err } + return c.finishErrand(opts, results) +} + +func (c RunErrandCmd) runWithStreaming(opts RunErrandOpts) error { + if c.nonIntSSHRunner == nil { + return bosherr.Errorf("SSH runner is required for --stream-logs") + } + + tailCmd, err := BuildErrandTailCmd(opts.Args.Name, opts.StreamLogPath) + if err != nil { + return err + } + + taskID, err := c.deployment.StartErrand( + opts.Args.Name, + opts.KeepAlive, + opts.WhenChanged, + opts.InstanceGroupOrInstanceSlugFlags.Slugs, 
//nolint:staticcheck + ) + if err != nil { + return err + } + + c.ui.PrintLinef("Errand started as task %d, streaming logs...", taskID) + + sshOpts, connOpts, err := opts.GatewayFlags.AsSSHOpts() //nolint:staticcheck + if err != nil { + return err + } + + // sshCtx is cancelled either when we call sshCancel (normal exit) or + // when the process receives SIGINT/SIGTERM (Ctrl+C). In both cases + // ComboRunner.waitProcs sees ctx.Done() and terminates the ssh processes. + sshCtx, sshCancel := c.ctxFactory() + defer sshCancel() + + stopCh := make(chan struct{}) + var stopOnce sync.Once + closeStop := func() { stopOnce.Do(func() { close(stopCh) }) } + defer closeStop() + + watcher := NewErrandEventWatcher(c.deployment, taskID, 2*time.Second) + if c.taskReporter != nil { + watcher.WithTaskReporter(c.taskReporter) + } + slugCh := watcher.Watch(stopCh) + + var sessions []boshdir.AllOrInstanceGroupOrInstanceSlug + var sshWg sync.WaitGroup + + // Consume slugs from the watcher, setting up SSH tails for each. + // Also select on sshCtx.Done() so Ctrl+C breaks out immediately. + for done := false; !done; { + select { + case slug, ok := <-slugCh: + if !ok { + done = true + break + } + + parts := strings.SplitN(slug, "/", 2) + if len(parts) != 2 { + continue + } + + instanceSlug := boshdir.NewAllOrInstanceGroupOrInstanceSlug(parts[0], parts[1]) + + result, setupErr := c.deployment.SetUpSSH(instanceSlug, sshOpts) + if setupErr != nil { + c.ui.PrintLinef("Warning: failed to set up SSH for %s: %s", slug, setupErr.Error()) + continue + } + + sessions = append(sessions, instanceSlug) + + sshWg.Add(1) + go func(s string) { + defer sshWg.Done() + runErr := c.nonIntSSHRunner.RunContext(sshCtx, connOpts, result, tailCmd) + if runErr != nil && sshCtx.Err() == nil { + c.ui.PrintLinef("Warning: SSH tail on %s exited: %s", s, runErr.Error()) + } + }(slug) + + case <-sshCtx.Done(): + done = true + } + } + + // Check whether a signal fired before we cancel the context ourselves. 
+ interrupted := sshCtx.Err() != nil + + // Stop the event watcher so its goroutine exits. + closeStop() + + // Cancel the SSH context to terminate all local ssh processes. + sshCancel() + sshWg.Wait() + + for _, slug := range sessions { + if cleanupErr := c.deployment.CleanUpSSH(slug, sshOpts); cleanupErr != nil { + c.ui.PrintLinef("Warning: failed to clean up SSH for %s: %s", slug, cleanupErr) + } + } + + if interrupted { + c.ui.PrintLinef("\nStreaming interrupted. Errand task %d is still running on the director.", taskID) + c.ui.PrintLinef("Use 'bosh task %d' to monitor or 'bosh cancel-task %d' to stop it.", taskID, taskID) + return nil + } + + results, err := c.deployment.WaitForErrandResult(taskID) + if err != nil { + return err + } + + // Suppress stdout/stderr in the summary since we already streamed the logs. + for i := range results { + results[i].Stdout = "" + results[i].Stderr = "" + } + + return c.finishErrand(opts, results) +} + +func (c RunErrandCmd) finishErrand(opts RunErrandOpts, results []boshdir.ErrandResult) error { errandErr := c.summarize(opts.Args.Name, results) - for _, result := range results { + for _, result := range results { if opts.DownloadLogs && len(result.LogsBlobstoreID) > 0 { err := c.downloader.Download( result.LogsBlobstoreID, @@ -50,12 +214,37 @@ func (c RunErrandCmd) Run(opts RunErrandOpts) error { return bosherr.WrapError(err, "Downloading errand logs") } } - } return errandErr } +func BuildErrandTailCmd(errandName, customPath string) ([]string, error) { + var logPath string + if customPath != "" { + if !safeLogPathRe.MatchString(customPath) { + return nil, bosherr.Errorf("--stream-log-path contains invalid characters: %q", customPath) + } + logPath = fmt.Sprintf("/var/vcap/sys/log/%s", customPath) + } else { + if !safeLogPathRe.MatchString(errandName) { + return nil, bosherr.Errorf("errand name contains invalid characters: %q", errandName) + } + logPath = fmt.Sprintf("/var/vcap/sys/log/%[1]s/%[1]s.{stdout,stderr}.log", 
errandName) + } + + tailScript := fmt.Sprintf( + `until ls %[1]s >/dev/null 2>&1;do sleep 1; done && exec tail -n 0 -F %[1]s`, + logPath, + ) + + // The script must be single-quoted because ssh concatenates all trailing + // arguments with spaces before passing them to the remote shell. Without + // quotes, "bash -c