Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,14 @@ public static bool RequiresNothing(object? message) =>

if (requireApproval)
{
// Create tool call content for approval request
// Create tool call content for approval request.
// Transport headers (e.g. Authorization) are intentionally excluded from the
// approval event: they must not cross into the externally-surfaced approval request.
McpServerToolCallContent toolCall = new(this.Id, toolName, serverLabel ?? serverUrl)
{
Arguments = arguments
};

if (headers != null)
{
toolCall.AdditionalProperties ??= [];
toolCall.AdditionalProperties.Add(headers);
}

ToolApprovalRequestContent approvalRequest = new(this.Id, toolCall);

ChatMessage requestMessage = new(ChatRole.Assistant, [approvalRequest]);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI.Workflows.Declarative.Events;
using Microsoft.Agents.AI.Workflows.Declarative.Interpreter;
using Microsoft.Agents.AI.Workflows.Declarative.Kit;
using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel;
using Microsoft.Agents.AI.Workflows.Declarative.PowerFx;
Expand Down Expand Up @@ -290,6 +292,151 @@ public async Task InvokeMcpToolExecuteWithRequireApprovalAndHeadersAsync()
await this.ExecuteTestAsync(model);
}

[Fact]
public async Task InvokeMcpToolApprovalRequestExcludesTransportHeadersAsync()
{
// Arrange
this.State.InitializeSystem();
InvokeMcpTool model = this.CreateModel(
displayName: nameof(InvokeMcpToolApprovalRequestExcludesTransportHeadersAsync),
serverUrl: TestServerUrl,
serverLabel: TestServerLabel,
toolName: TestToolName,
requireApproval: true,
headerKey: "Authorization",
headerValue: "Bearer super-secret-token");
MockMcpToolProvider mockProvider = new();
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);

ExternalInputRequest? capturedRequest = null;

// Act
await this.ExecuteAsync(
[
action,
new DelegateActionExecutor<ExternalInputRequest>(
InvokeMcpToolExecutor.Steps.ExternalInput(action.Id),
this.State,
CaptureRequestAsync)
],
isDiscrete: false);

// Assert - the approval event must not carry any transport headers (e.g. Authorization).
Assert.NotNull(capturedRequest);
ToolApprovalRequestContent approvalRequest =
capturedRequest!.AgentResponse.Messages
.SelectMany(message => message.Contents)
.OfType<ToolApprovalRequestContent>()
.Single();

AdditionalPropertiesDictionary? additionalProperties = approvalRequest.ToolCall.AdditionalProperties;
Assert.True(additionalProperties is null || additionalProperties.Count == 0);

// Defense in depth: the credential value must not appear anywhere in the serialized approval content.
string serializedApproval = System.Text.Json.JsonSerializer.Serialize(capturedRequest.AgentResponse);
Assert.DoesNotContain("super-secret-token", serializedApproval);

ValueTask CaptureRequestAsync(IWorkflowContext context, ExternalInputRequest request, CancellationToken cancellationToken)
{
capturedRequest = request;
return default;
}
}

[Fact]
public async Task InvokeMcpToolInvocationForwardsHeadersToTransportAsync()
Comment thread
peibekwe marked this conversation as resolved.
{
// Arrange
this.State.InitializeSystem();
const string HeaderKey = "Authorization";
const string HeaderValue = "Bearer super-secret-token";
InvokeMcpTool model = this.CreateModel(
displayName: nameof(InvokeMcpToolInvocationForwardsHeadersToTransportAsync),
serverUrl: TestServerUrl,
serverLabel: TestServerLabel,
toolName: TestToolName,
requireApproval: false,
headerKey: HeaderKey,
headerValue: HeaderValue);

IDictionary<string, string>? capturedHeaders = null;
Mock<IMcpToolHandler> mockProvider = new();
mockProvider
.Setup(provider => provider.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, _, _, headers, _, _) => capturedHeaders = headers)
.ReturnsAsync(new McpServerToolResultContent("mock-call-id") { Outputs = [new TextContent("ok")] });
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);

// Act
await this.ExecuteAsync(action, isDiscrete: false);

// Assert - headers remain available to the actual transport invocation.
Assert.NotNull(capturedHeaders);
Assert.True(capturedHeaders!.TryGetValue(HeaderKey, out string? forwardedValue));
Assert.Equal(HeaderValue, forwardedValue);
}

[Fact]
public async Task InvokeMcpToolApprovedCaptureResponseForwardsHeadersToTransportAsync()
{
// Arrange - exercises the post-approval CaptureResponseAsync resume path to prove the
// fix did not regress header forwarding on the path that the vulnerability actually targets.
this.State.InitializeSystem();
const string HeaderKey = "Authorization";
const string HeaderValue = "Bearer super-secret-token";
InvokeMcpTool model = this.CreateModel(
displayName: nameof(InvokeMcpToolApprovedCaptureResponseForwardsHeadersToTransportAsync),
serverUrl: TestServerUrl,
serverLabel: TestServerLabel,
toolName: TestToolName,
requireApproval: true,
headerKey: HeaderKey,
headerValue: HeaderValue);

IDictionary<string, string>? capturedHeaders = null;
Mock<IMcpToolHandler> mockProvider = new();
mockProvider
.Setup(provider => provider.InvokeToolAsync(
It.IsAny<string>(),
It.IsAny<string?>(),
It.IsAny<string>(),
It.IsAny<IDictionary<string, object?>?>(),
It.IsAny<IDictionary<string, string>?>(),
It.IsAny<string?>(),
It.IsAny<CancellationToken>()))
.Callback<string, string?, string, IDictionary<string, object?>?, IDictionary<string, string>?, string?, CancellationToken>(
(_, _, _, _, headers, _, _) => capturedHeaders = headers)
.ReturnsAsync(new McpServerToolResultContent("mock-call-id") { Outputs = [new TextContent("ok")] });
MockAgentProvider mockAgentProvider = new();
InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State);

Mock<IWorkflowContext> mockContext = new(MockBehavior.Loose);

// Build an approved response matching this action's request id.
McpServerToolCallContent toolCall = new(action.Id, TestToolName, TestServerLabel);
ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall);
ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true);
ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse]));

// Act - call CaptureResponseAsync directly so the post-approval branch actually executes.
await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None);

// Assert - headers reach the transport invocation on the approved path.
Assert.NotNull(capturedHeaders);
Assert.True(capturedHeaders!.TryGetValue(HeaderKey, out string? forwardedValue));
Assert.Equal(HeaderValue, forwardedValue);
}

[Fact]
public async Task InvokeMcpToolExecuteWithEmptyHeaderValueAsync()
{
Expand Down
Loading