Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Auto detect text files and normalize to LF
* text=auto eol=lf

# Go source files
*.go text eol=lf

# Shell scripts
*.sh text eol=lf

# Windows specific files
*.bat text eol=crlf
*.cmd text eol=crlf
*.ps1 text eol=crlf

# Binary files
*.exe binary
*.dll binary
*.so binary
*.dylib binary
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.ico binary
*.pdf binary
*.zip binary
*.tar binary
*.gz binary
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ require (
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
golang.org/x/oauth2 v0.32.0
golang.org/x/sys v0.37.0
golang.org/x/term v0.36.0
google.golang.org/genai v1.31.0
modernc.org/sqlite v1.39.1
Expand Down Expand Up @@ -116,7 +117,6 @@ require (
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/time v0.14.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250826171959-ef028d996bc1 // indirect
Expand Down
7 changes: 0 additions & 7 deletions pkg/runtime/remote_runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,6 @@ func (r *RemoteRuntime) CurrentAgent() *agent.Agent {
return agent.New(r.currentAgent, fmt.Sprintf("Remote agent: %s", r.currentAgent))
}

// StopPendingProcesses stops all pending tool operations for the remote runtime.
// It forwards ctx to the team's StopToolSets and returns that call's error unchanged.
func (r *RemoteRuntime) StopPendingProcesses(ctx context.Context) error {
// For remote runtime, stop the team's toolsets
// NOTE(review): presumably this is what kills processes spawned by shell tools;
// that depends on each toolset's Stop implementation — confirm.
return r.team.StopToolSets(ctx)
}

// RunStream starts the agent's interaction loop and returns a channel of events
func (r *RemoteRuntime) RunStream(ctx context.Context, sess *session.Session) <-chan Event {
slog.Debug("Starting remote runtime stream", "agent", r.currentAgent, "session_id", r.sessionID)
Expand Down
6 changes: 0 additions & 6 deletions pkg/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ type ElicitationRequestHandler func(ctx context.Context, message string, schema
type Runtime interface {
// CurrentAgent returns the currently active agent
CurrentAgent() *agent.Agent
// StopPendingProcesses stops all pending tool operations (e.g., running shell commands)
StopPendingProcesses(ctx context.Context) error
// RunStream starts the agent's interaction loop and returns a channel of events
RunStream(ctx context.Context, sess *session.Session) <-chan Event
// Run starts the agent's interaction loop and returns the final messages
Expand Down Expand Up @@ -180,10 +178,6 @@ func (r *runtime) CurrentAgent() *agent.Agent {
return current
}

// StopPendingProcesses stops all pending tool operations by delegating to the
// team's StopToolSets with the given ctx; the error from that call is returned as is.
func (r *runtime) StopPendingProcesses(ctx context.Context) error {
return r.team.StopToolSets(ctx)
}

// registerDefaultTools registers the default tool handlers
func (r *runtime) registerDefaultTools() {
slog.Debug("Registering default tools")
Expand Down
32 changes: 0 additions & 32 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,6 @@ func New(sessionStore session.Store, runConfig config.RuntimeConfig, teams map[s
group.POST("/sessions/:id/resume", s.resumeSession)
// Create a new session and run an agent loop
group.POST("/sessions", s.createSession)
// Stop a running session
group.POST("/sessions/:id/stop", s.stopSession)
// Delete a session
group.DELETE("/sessions/:id", s.deleteSession)

Expand Down Expand Up @@ -963,36 +961,6 @@ func (s *Server) resumeSession(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"message": "session resumed"})
}

// stopSession is the handler registered for the "/sessions/:id/stop" route.
// It cancels the running agent loop of the session named by the "id" path
// parameter and then asks the session's runtime to stop pending tool work.
//
// Responses:
//   - 200 with {"message": "session stopped successfully"} when a cancel
//     function was registered for the session (even if stopping the tools fails).
//   - 404 with {"error": "no active session found"} when nothing was running.
func (s *Server) stopSession(c echo.Context) error {
	id := c.Param("id")

	// Resolve the runtime up front; it is only used after the cancel step.
	sessionRuntime, hasRuntime := s.runtimes[id]

	// Cancel and unregister under the lock, then release it before any
	// potentially slow tool shutdown.
	s.cancelsMu.Lock()
	cancelRun, running := s.runtimeCancels[id]
	if running {
		slog.Info("Stopping session execution", "session_id", id)
		cancelRun()
		delete(s.runtimeCancels, id)
	}
	s.cancelsMu.Unlock()

	if !running {
		slog.Debug("No active runtime found for session", "session_id", id)
		return c.JSON(http.StatusNotFound, map[string]string{"error": "no active session found"})
	}

	if hasRuntime {
		// A failure here is logged only: the session itself has been
		// cancelled, so the request is still reported as successful.
		stopErr := sessionRuntime.StopPendingProcesses(c.Request().Context())
		if stopErr != nil {
			slog.Error("Failed to stop pending tools for session", "session_id", id, "error", stopErr)
		}
	}

	return c.JSON(http.StatusOK, map[string]string{"message": "session stopped successfully"})
}

func (s *Server) deleteSession(c echo.Context) error {
sessionID := c.Param("id")

Expand Down
10 changes: 9 additions & 1 deletion pkg/tools/builtin/cmd_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@ import (
"syscall"
)

// processGroup is the Unix variant of the per-command process-group handle.
// It is intentionally empty: kill addresses the group through the negated PID
// of the process, so there is nothing to store here.
type processGroup struct {
// Unix doesn't need to store handles, process group is managed by kernel
}

// platformSpecificSysProcAttr returns the process attributes used when
// spawning a command on Unix. Setpgid is enabled so the command is placed in
// its own process group, which kill later signals as a whole.
func platformSpecificSysProcAttr() *syscall.SysProcAttr {
	attr := new(syscall.SysProcAttr)
	attr.Setpgid = true
	return attr
}

func kill(proc *os.Process) error {
// createProcessGroup returns the (empty) Unix process-group handle for proc.
// Nothing is created here and proc is not inspected; the function never fails.
func createProcessGroup(proc *os.Process) (*processGroup, error) {
	group := new(processGroup)
	return group, nil
}

// kill asks every process in the group led by proc to terminate by sending
// SIGTERM to the negated PID of proc.
//
// pg is accepted for signature parity with the Windows implementation and is
// not used on Unix.
//
// It returns os.ErrInvalid when proc is nil or has a non-positive PID, and
// otherwise whatever syscall.Kill reports.
func kill(proc *os.Process, pg *processGroup) error {
	// Never let a nil process or a non-positive PID reach syscall.Kill:
	// a PID of 0 would signal the caller's own process group, and a negative
	// PID would, once negated, address an unrelated process.
	if proc == nil || proc.Pid <= 0 {
		return os.ErrInvalid
	}
	return syscall.Kill(-proc.Pid, syscall.SIGTERM)
}
60 changes: 59 additions & 1 deletion pkg/tools/builtin/cmd_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,70 @@ package builtin
import (
"os"
"syscall"
"unsafe"

"golang.org/x/sys/windows"
)

// processGroup holds the Windows handles that tie a spawned command to a job
// object, so that all processes in the job can be terminated together by kill.
type processGroup struct {
// jobHandle is the job object created by createProcessGroup with
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE set.
jobHandle windows.Handle
// processHandle is the handle opened on the command's process and assigned
// to the job.
processHandle windows.Handle
}

// platformSpecificSysProcAttr returns nil on Windows: no process attributes
// are set at spawn time. Grouping is done after the process has started, by
// createProcessGroup assigning it to a job object.
func platformSpecificSysProcAttr() *syscall.SysProcAttr {
return nil
}

func kill(proc *os.Process) error {
// createProcessGroup places the already-started process proc into a new
// Windows job object configured with JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE.
//
// On success it returns a processGroup owning both the job handle and the
// handle opened on proc; the caller is responsible for passing it to kill.
// On any failure every handle opened so far is closed and the underlying
// Windows error is returned unchanged.
func createProcessGroup(proc *os.Process) (pg *processGroup, err error) {
	job, err := windows.CreateJobObject(nil, nil)
	if err != nil {
		return nil, err
	}
	// Release the job on every error path below. Deferred calls run in
	// reverse order, so the process handle (if any) is closed before the job.
	defer func() {
		if err != nil {
			_ = windows.CloseHandle(job)
		}
	}()

	var limits windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION
	limits.BasicLimitInformation.LimitFlags = windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
	_, err = windows.SetInformationJobObject(
		job,
		windows.JobObjectExtendedLimitInformation,
		uintptr(unsafe.Pointer(&limits)),
		uint32(unsafe.Sizeof(limits)),
	)
	if err != nil {
		return nil, err
	}

	// PROCESS_SET_QUOTA and PROCESS_TERMINATE are the rights requested for
	// the handle that is assigned to the job.
	const access = windows.PROCESS_SET_QUOTA | windows.PROCESS_TERMINATE
	procHandle, err := windows.OpenProcess(access, false, uint32(proc.Pid))
	if err != nil {
		return nil, err
	}
	defer func() {
		if err != nil {
			_ = windows.CloseHandle(procHandle)
		}
	}()

	if err = windows.AssignProcessToJobObject(job, procHandle); err != nil {
		return nil, err
	}

	return &processGroup{jobHandle: job, processHandle: procHandle}, nil
}

// kill terminates proc and, when pg is non-nil, every process in its job.
//
// The handles held by pg are closed, which triggers
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE on the job object, and proc.Kill is then
// called as a fallback. The returned error is the result of proc.Kill.
//
// pg is consumed by this call: its handle fields are reset to zero, so calling
// kill again with the same pg only performs the proc.Kill fallback.
func kill(proc *os.Process, pg *processGroup) error {
	if pg != nil {
		// Close handles to trigger JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
		// which will terminate all processes in the job. Close errors are
		// ignored on purpose: this is best-effort and proc.Kill follows.
		//
		// Each field is zeroed once closed so a repeated kill can never
		// call CloseHandle on a value the OS may have handed out again.
		if pg.processHandle != 0 {
			_ = windows.CloseHandle(pg.processHandle)
			pg.processHandle = 0
		}
		if pg.jobHandle != 0 {
			_ = windows.CloseHandle(pg.jobHandle)
			pg.jobHandle = 0
		}
	}

	// Also call Kill on the process as a fallback
	return proc.Kill()
}
92 changes: 35 additions & 57 deletions pkg/tools/builtin/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"os"
"os/exec"
"runtime"
"sync"
"strings"

"github.com/docker/cagent/pkg/tools"
)
Expand All @@ -25,8 +25,6 @@ type shellHandler struct {
shell string
shellArgsPrefix []string
env []string
mu sync.Mutex
processes []*os.Process
}

type RunShellArgs struct {
Expand All @@ -40,74 +38,67 @@ func (h *shellHandler) RunShell(ctx context.Context, toolCall tools.ToolCall) (*
return nil, fmt.Errorf("invalid arguments: %w", err)
}

cmd := exec.CommandContext(ctx, h.shell, append(h.shellArgsPrefix, params.Cmd)...)
cmd := exec.Command(h.shell, append(h.shellArgsPrefix, params.Cmd)...)
cmd.Env = h.env
if params.Cwd != "" {
cmd.Dir = params.Cwd
} else {
// Use the current working directory; avoid PWD on Windows (may be MSYS-style like /c/...)
if wd, err := os.Getwd(); err == nil {
cmd.Dir = wd
}
}

// Set up process group for proper cleanup
// On Unix: create new process group so we can kill the entire tree
cmd.SysProcAttr = platformSpecificSysProcAttr()

// Note: On Windows, we would set CreationFlags, but that requires
// platform-specific code in a _windows.go file

// Capture output using buffers
var outBuf, errBuf bytes.Buffer
var outBuf bytes.Buffer
cmd.Stdout = &outBuf
cmd.Stderr = &errBuf
cmd.Stderr = &outBuf

// Start the command so we can track it
if err := cmd.Start(); err != nil {
return &tools.ToolCallResult{
Output: fmt.Sprintf("Error starting command: %s", err),
}, nil
}

// Track the process for cleanup
h.mu.Lock()
h.processes = append(h.processes, cmd.Process)
h.mu.Unlock()

// Remove from tracking once complete
defer func() {
h.mu.Lock()
for i, p := range h.processes {
if p != nil && p.Pid == cmd.Process.Pid {
h.processes = append(h.processes[:i], h.processes[i+1:]...)
break
}
}
h.mu.Unlock()
}()

// Wait for the command to complete and get the result
err := cmd.Wait()

// Combine stdout and stderr
output := outBuf.String() + errBuf.String()

pg, err := createProcessGroup(cmd.Process)
if err != nil {
return &tools.ToolCallResult{
Output: fmt.Sprintf("Error executing command: %s\nOutput: %s", err, output),
Output: fmt.Sprintf("Error creating process group: %s", err),
}, nil
}

if output == "" {
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()

select {
case <-ctx.Done():
if cmd.Process != nil {
_ = kill(cmd.Process, pg)
}
return &tools.ToolCallResult{
Output: "<no output>",
Output: "Command cancelled",
}, nil
}
case err := <-done:
output := outBuf.String()

return &tools.ToolCallResult{
Output: output,
}, nil
if err != nil {
return &tools.ToolCallResult{
Output: fmt.Sprintf("Error executing command: %s\nOutput: %s", err, output),
}, nil
}

if strings.TrimSpace(output) == "" {
return &tools.ToolCallResult{
Output: "<no output>",
}, nil
}

return &tools.ToolCallResult{
Output: fmt.Sprintf("Output: %s", output),
}, nil
}
}

func NewShellTool(env []string) *ShellTool {
Expand Down Expand Up @@ -236,18 +227,5 @@ func (t *ShellTool) Start(context.Context) error {
}

func (t *ShellTool) Stop(context.Context) error {
t.handler.mu.Lock()
defer t.handler.mu.Unlock()

// Kill all tracked processes
for _, proc := range t.handler.processes {
if proc != nil {
_ = kill(proc)
}
}

// Clear the processes list
t.handler.processes = nil

return nil
}
Loading