Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/metrics-collector.lock.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions pkg/cli/mcp_tools_privileged.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ from where the previous request stopped due to timeout.`,
mcpLog.Printf("Executing logs tool: workflow=%s, count=%d, firewall=%v, no_firewall=%v, filtered_integrity=%v, timeout=%d, command_args=%v",
args.WorkflowName, args.Count, args.Firewall, args.NoFirewall, args.FilteredIntegrity, timeoutValue, cmdArgs)

notifyProgress(ctx, req, 0, 100, "Downloading workflow logs...")

// Execute the CLI command
// Use separate stdout/stderr capture instead of CombinedOutput because:
// - Stdout contains JSON output (--json flag)
Expand Down Expand Up @@ -234,6 +236,8 @@ from where the previous request stopped due to timeout.`,
// Always write output to a file and return schema + file path
finalOutput := buildLogsFileResponse(outputStr)

notifyProgress(ctx, req, 100, 100, "Workflow logs downloaded")

return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: finalOutput},
Expand Down Expand Up @@ -353,6 +357,8 @@ Multi-run diff returns JSON describing changes between the base and each compari

cmdArgs = appendRepoFlagFromEnv(cmdArgs)

notifyProgress(ctx, req, 0, 100, "Downloading audit artifacts...")

// Execute the CLI command.
// Use separate stdout/stderr capture instead of CombinedOutput because:
// - Stdout contains JSON output (--json flag)
Expand Down Expand Up @@ -409,6 +415,8 @@ Multi-run diff returns JSON describing changes between the base and each compari
}, nil, nil
}

notifyProgress(ctx, req, 100, 100, "Audit complete")

return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: outputStr},
Expand Down Expand Up @@ -483,6 +491,10 @@ Returns JSON describing the differences between the base run and each comparison
cmdArgs = append(cmdArgs, "--artifacts", strings.Join(args.Artifacts, ","))
}

cmdArgs = appendRepoFlagFromEnv(cmdArgs)

notifyProgress(ctx, req, 0, 100, "Downloading artifacts for diff...")

cmd := execCmd(ctx, cmdArgs...)
stdout, err := cmd.Output()
outputStr := string(stdout)
Expand Down Expand Up @@ -517,6 +529,8 @@ Returns JSON describing the differences between the base run and each comparison
}, nil, nil
}

notifyProgress(ctx, req, 100, 100, "Diff complete")

return &mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{Text: outputStr}},
}, nil, nil
Expand All @@ -525,6 +539,26 @@ Returns JSON describing the differences between the base run and each comparison
return nil
}

// notifyProgress sends a progress notification to the MCP client if the request
// includes a progress token. The req, req.Params, and req.Session fields are
// checked for nil before use. Errors are silently ignored because progress
// notifications are best-effort; the tool result is not affected. If the client
// has disconnected or the notification fails for any reason, the tool continues
// executing normally.
func notifyProgress(ctx context.Context, req *mcp.CallToolRequest, progress, total float64, message string) {
if req == nil || req.Session == nil {
return
}
if token := req.Params.GetProgressToken(); token != nil {
_ = req.Session.NotifyProgress(ctx, &mcp.ProgressNotificationParams{
ProgressToken: token,
Progress: progress,
Total: total,
Message: message,
})
}
Comment on lines +542 to +559
}

// filtering out debug log lines (e.g. "workflow:script_registry Creating... +151ns").
// Console messages are identified by their prefix symbols (✗, ✓, ℹ, ⚠, etc.).
// Falls back to the last non-empty line if no console message is found.
Expand Down
227 changes: 227 additions & 0 deletions pkg/cli/mcp_tools_privileged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,51 @@ import (
"os/exec"
"slices"
"strings"
"sync"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// connectInMemoryWithProgress creates an in-memory MCP client-server connection
// that captures progress notifications. Returns the client session and a
// function to retrieve captured notifications. The returned getNotifications
// function returns a snapshot copy of all captured notifications and is safe
// to call concurrently with ongoing notification capture.
func connectInMemoryWithProgress(t *testing.T, server *mcp.Server) (*mcp.ClientSession, func() []*mcp.ProgressNotificationParams) {
t.Helper()
ctx := context.Background()
t1, t2 := mcp.NewInMemoryTransports()
_, err := server.Connect(ctx, t1, nil)
require.NoError(t, err, "server.Connect should succeed")

var mu sync.Mutex
var captured []*mcp.ProgressNotificationParams

client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "1.0"}, &mcp.ClientOptions{
ProgressNotificationHandler: func(_ context.Context, req *mcp.ProgressNotificationClientRequest) {
mu.Lock()
defer mu.Unlock()
p := *req.Params
captured = append(captured, &p)
},
})
session, err := client.Connect(ctx, t2, nil)
require.NoError(t, err, "client.Connect should succeed")
t.Cleanup(func() { session.Close() })

getNotifications := func() []*mcp.ProgressNotificationParams {
mu.Lock()
defer mu.Unlock()
result := make([]*mcp.ProgressNotificationParams, len(captured))
copy(result, captured)
return result
}
return session, getNotifications
}

// TestExtractLastConsoleMessage verifies that extractLastConsoleMessage correctly
// filters debug log lines and returns only user-facing console messages.
func TestExtractLastConsoleMessage(t *testing.T) {
Expand Down Expand Up @@ -508,3 +546,192 @@ func TestAuditDiffToolErrorEnvelopeHelperProcess(t *testing.T) {
_, _ = fmt.Fprintln(os.Stderr, "✗ failed to diff workflow runs")
os.Exit(1)
}

// TestLogsToolEmitsProgressNotifications verifies that the logs MCP tool
// sends progress notifications when a progress token is provided.
func TestLogsToolEmitsProgressNotifications(t *testing.T) {
const fakeOutput = `{"file_path":"/tmp/gh-aw/aw-mcp/logs/runs.json"}`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerLogsTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerLogsTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

params := &mcp.CallToolParams{Name: "logs", Arguments: map[string]any{}}
params.SetProgressToken("logs-progress-token")
_, err = session.CallTool(context.Background(), params)
require.NoError(t, err, "logs tool should succeed")

notifications := getNotifications()
require.GreaterOrEqual(t, len(notifications), 2, "logs tool should emit at least 2 progress notifications")

first := notifications[0]
assert.InDelta(t, float64(0), first.Progress, 0.001, "first notification should have progress=0")
assert.InDelta(t, float64(100), first.Total, 0.001, "first notification should have total=100")
assert.NotEmpty(t, first.Message, "first notification should have a message")

last := notifications[len(notifications)-1]
assert.InDelta(t, float64(100), last.Progress, 0.001, "last notification should have progress=100")
assert.InDelta(t, float64(100), last.Total, 0.001, "last notification should have total=100")
assert.NotEmpty(t, last.Message, "last notification should have a message")
}

// TestLogsToolNoProgressWithoutToken verifies that the logs MCP tool
// does not send progress notifications when no progress token is provided.
func TestLogsToolNoProgressWithoutToken(t *testing.T) {
const fakeOutput = `{"file_path":"/tmp/gh-aw/aw-mcp/logs/runs.json"}`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerLogsTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerLogsTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

// Call without setting a progress token
_, err = session.CallTool(context.Background(), &mcp.CallToolParams{
Name: "logs",
Arguments: map[string]any{},
})
require.NoError(t, err, "logs tool should succeed")

assert.Empty(t, getNotifications(), "logs tool should not emit progress notifications without a token")
}
Comment on lines +584 to +607

// TestAuditToolEmitsProgressNotifications verifies that the audit MCP tool
// sends progress notifications when a progress token is provided.
func TestAuditToolEmitsProgressNotifications(t *testing.T) {
const fakeOutput = `{"overview":{"run_id":"1234567890"}}`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerAuditTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerAuditTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

params := &mcp.CallToolParams{
Name: "audit",
Arguments: map[string]any{"run_id_or_url": "1234567890"},
}
params.SetProgressToken("audit-progress-token")
_, err = session.CallTool(context.Background(), params)
require.NoError(t, err, "audit tool should succeed")

notifications := getNotifications()
require.GreaterOrEqual(t, len(notifications), 2, "audit tool should emit at least 2 progress notifications")

first := notifications[0]
assert.InDelta(t, float64(0), first.Progress, 0.001, "first notification should have progress=0")
assert.InDelta(t, float64(100), first.Total, 0.001, "first notification should have total=100")
assert.NotEmpty(t, first.Message, "first notification should have a message")

last := notifications[len(notifications)-1]
assert.InDelta(t, float64(100), last.Progress, 0.001, "last notification should have progress=100")
assert.InDelta(t, float64(100), last.Total, 0.001, "last notification should have total=100")
assert.NotEmpty(t, last.Message, "last notification should have a message")
}

// TestAuditDiffToolEmitsProgressNotifications verifies that the audit-diff MCP
// tool sends progress notifications when a progress token is provided.
func TestAuditDiffToolEmitsProgressNotifications(t *testing.T) {
const fakeOutput = `[{"base_run_id":100,"compare_run_id":200}]`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerAuditDiffTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerAuditDiffTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

params := &mcp.CallToolParams{
Name: "audit-diff",
Arguments: map[string]any{
"base_run_id": "100",
"compare_run_ids": []string{"200"},
},
}
params.SetProgressToken("audit-diff-progress-token")
_, err = session.CallTool(context.Background(), params)
require.NoError(t, err, "audit-diff tool should succeed")

notifications := getNotifications()
require.GreaterOrEqual(t, len(notifications), 2, "audit-diff tool should emit at least 2 progress notifications")

first := notifications[0]
assert.InDelta(t, float64(0), first.Progress, 0.001, "first notification should have progress=0")
assert.InDelta(t, float64(100), first.Total, 0.001, "first notification should have total=100")
assert.NotEmpty(t, first.Message, "first notification should have a message")

last := notifications[len(notifications)-1]
assert.InDelta(t, float64(100), last.Progress, 0.001, "last notification should have progress=100")
assert.InDelta(t, float64(100), last.Total, 0.001, "last notification should have total=100")
assert.NotEmpty(t, last.Message, "last notification should have a message")
}

// TestAuditToolNoProgressWithoutToken verifies that the audit MCP tool
// does not send progress notifications when no progress token is provided.
func TestAuditToolNoProgressWithoutToken(t *testing.T) {
const fakeOutput = `{"overview":{"run_id":"1234567890"}}`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerAuditTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerAuditTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

// Call without setting a progress token
_, err = session.CallTool(context.Background(), &mcp.CallToolParams{
Name: "audit",
Arguments: map[string]any{"run_id_or_url": "1234567890"},
})
require.NoError(t, err, "audit tool should succeed")

assert.Empty(t, getNotifications(), "audit tool should not emit progress notifications without a token")
}

// TestAuditDiffToolNoProgressWithoutToken verifies that the audit-diff MCP tool
// does not send progress notifications when no progress token is provided.
func TestAuditDiffToolNoProgressWithoutToken(t *testing.T) {
const fakeOutput = `[{"base_run_id":100,"compare_run_id":200}]`

mockExecCmd := func(ctx context.Context, args ...string) *exec.Cmd {
return exec.CommandContext(ctx, "sh", "-c", `printf '%s' "$1"`, "sh", fakeOutput)
}

server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "1.0"}, nil)
err := registerAuditDiffTool(server, mockExecCmd, "", false)
require.NoError(t, err, "registerAuditDiffTool should succeed")

session, getNotifications := connectInMemoryWithProgress(t, server)

// Call without setting a progress token
_, err = session.CallTool(context.Background(), &mcp.CallToolParams{
Name: "audit-diff",
Arguments: map[string]any{
"base_run_id": "100",
"compare_run_ids": []string{"200"},
},
})
require.NoError(t, err, "audit-diff tool should succeed")

assert.Empty(t, getNotifications(), "audit-diff tool should not emit progress notifications without a token")
}