From 8534870c4351c52b4969d825ff68e5019463ff17 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 3 Jun 2026 11:21:37 +0000 Subject: [PATCH 1/4] Add mcp tool execution fix --- .../ObjectModel/InvokeMcpToolExecutor.cs | 42 ++- .../ObjectModel/InvokeMcpToolExecutorTest.cs | 244 ++++++++++++++++++ 2 files changed, 280 insertions(+), 6 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs index c4c490551a..d99ef50a7a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs @@ -27,6 +27,12 @@ internal sealed class InvokeMcpToolExecutor( WorkflowFormulaState state) : DeclarativeActionExecutor(model, state) { + /// + /// Snapshot of evaluated parameters at approval-request time. + /// Used to prevent TOCTOU attacks where state mutates during the approval window. + /// + private ApprovalSnapshot? _approvalSnapshot; + /// /// Step identifiers for the MCP tool invocation workflow. /// @@ -75,6 +81,10 @@ public static bool RequiresNothing(object? message) => if (requireApproval) { + // Snapshot the evaluated parameters to prevent TOCTOU attacks. + // If state mutates during the approval window, the approved values are used on resume. + this._approvalSnapshot = new ApprovalSnapshot(serverUrl, serverLabel, toolName, arguments, connectionName); + // Create tool call content for approval request. // Transport headers (e.g. Authorization) are intentionally excluded from the // approval event: they must not cross into the externally-surfaced approval request. @@ -137,13 +147,14 @@ public async ValueTask CaptureResponseAsync( return; } - // Approved - now invoke the tool - string serverUrl = this.GetServerUrl(); - string? serverLabel = this.GetServerLabel(); - string toolName = this.GetToolName(); - Dictionary? arguments = this.GetArguments(); + // Approved - use the snapshot from approval-request time to prevent TOCTOU attacks. + // Headers are re-evaluated (they may contain auth secrets that should not be persisted). + string serverUrl = this._approvalSnapshot?.ServerUrl ?? this.GetServerUrl(); + string? serverLabel = this._approvalSnapshot?.ServerLabel ?? this.GetServerLabel(); + string toolName = this._approvalSnapshot?.ToolName ?? this.GetToolName(); + Dictionary? arguments = this._approvalSnapshot?.Arguments ?? this.GetArguments(); Dictionary? headers = this.GetHeaders(); - string? connectionName = this.GetConnectionName(); + string? connectionName = this._approvalSnapshot?.ConnectionName ?? this.GetConnectionName(); McpServerToolResultContent resultContent = await mcpToolHandler.InvokeToolAsync( serverUrl, @@ -365,4 +376,23 @@ private bool GetAutoSendValue() return result; } + + /// + /// Stores the evaluated parameters at approval-request time so that + /// uses the values the user reviewed, + /// even if mutates during the approval window. + /// + private readonly struct ApprovalSnapshot( + string serverUrl, + string? serverLabel, + string toolName, + Dictionary? arguments, + string? connectionName) + { + public string ServerUrl { get; } = serverUrl; + public string? ServerLabel { get; } = serverLabel; + public string ToolName { get; } = toolName; + public Dictionary? Arguments { get; } = arguments; + public string? ConnectionName { get; } = connectionName; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs index b8d936dab9..1209e62ebb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs @@ -11,6 +11,7 @@ using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Agents.ObjectModel; using Microsoft.Extensions.AI; +using Microsoft.PowerFx.Types; using Moq; namespace Microsoft.Agents.AI.Workflows.Declarative.UnitTests.ObjectModel; @@ -842,6 +843,205 @@ public async Task InvokeMcpToolCaptureResponseWithApprovedAndConversationIdAsync #endregion + #region Approval Snapshot Security Tests + + /// + /// Verifies that mutating the tool name variable after approval does not change + /// which tool is actually invoked. The originally-approved tool name must be used. + /// + [Fact] + public async Task InvokeMcpToolCaptureResponseUsesApprovedToolNameNotMutatedAsync() + { + // Arrange + const string ApprovedToolName = "safe_readonly_query"; + const string MutatedToolName = "dangerous_admin_tool"; + + this.State.Set("TargetTool", FormulaValue.New(ApprovedToolName)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeMcpTool model = this.CreateModelWithVariableToolName( + displayName: nameof(InvokeMcpToolCaptureResponseUsesApprovedToolNameNotMutatedAsync), + serverUrl: TestServerUrl, + variableName: "TargetTool"); + + string? capturedToolName = null; + Mock mockProvider = new(); + mockProvider.Setup(provider => provider.InvokeToolAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny?>(), + It.IsAny?>(), + It.IsAny(), + It.IsAny())) + .Callback?, IDictionary?, string?, CancellationToken>( + (_, _, toolName, _, _, _, _) => capturedToolName = toolName) + .ReturnsAsync(new McpServerToolResultContent("capture-call-id") + { + Outputs = [new TextContent("result")] + }); + MockAgentProvider mockAgentProvider = new(); + InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContext(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate parallel branch mutating state during the approval window + this.State.Set("TargetTool", FormulaValue.New(MutatedToolName)); + this.State.Bind(); + + // User clicks approve (they saw "safe_readonly_query" in the approval UI) + McpServerToolCallContent toolCall = new(action.Id, ApprovedToolName, TestServerUrl); + ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall); + ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true); + ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse])); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved tool name must be used, not the mutated one + Assert.NotNull(capturedToolName); + Assert.Equal(ApprovedToolName, capturedToolName); + } + + /// + /// Verifies that mutating an argument variable after approval does not change + /// the arguments actually passed to the MCP tool. The originally-approved arguments must be used. + /// + [Fact] + public async Task InvokeMcpToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync() + { + // Arrange + const string ApprovedQuery = "SELECT * FROM users LIMIT 10"; + const string MutatedQuery = "DROP TABLE users CASCADE; --"; + + this.State.Set("SqlQuery", FormulaValue.New(ApprovedQuery)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeMcpTool model = this.CreateModelWithVariableArgument( + displayName: nameof(InvokeMcpToolCaptureResponseUsesApprovedArgumentsNotMutatedAsync), + serverUrl: TestServerUrl, + toolName: TestToolName, + argumentKey: "query", + variableName: "SqlQuery"); + + IDictionary? capturedArguments = null; + Mock mockProvider = new(); + mockProvider.Setup(provider => provider.InvokeToolAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny?>(), + It.IsAny?>(), + It.IsAny(), + It.IsAny())) + .Callback?, IDictionary?, string?, CancellationToken>( + (_, _, _, arguments, _, _, _) => capturedArguments = arguments) + .ReturnsAsync(new McpServerToolResultContent("capture-call-id") + { + Outputs = [new TextContent("result")] + }); + MockAgentProvider mockAgentProvider = new(); + InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContext(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate parallel branch mutating state during the approval window + this.State.Set("SqlQuery", FormulaValue.New(MutatedQuery)); + this.State.Bind(); + + // User clicks approve + McpServerToolCallContent toolCall = new(action.Id, TestToolName, TestServerUrl); + ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall); + ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true); + ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse])); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved argument must be used, not the mutated one + Assert.NotNull(capturedArguments); + Assert.Equal(ApprovedQuery, capturedArguments["query"]?.ToString()); + } + + /// + /// Verifies that mutating the server URL variable after approval does not redirect + /// the MCP tool call to a different server. The originally-approved server URL must be used. + /// + [Fact] + public async Task InvokeMcpToolCaptureResponseUsesApprovedServerUrlNotMutatedAsync() + { + // Arrange + const string ApprovedServerUrl = "https://internal-mcp.corp"; + const string MutatedServerUrl = "https://attacker.evil/steal"; + + this.State.Set("McpEndpoint", FormulaValue.New(ApprovedServerUrl)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeMcpTool model = this.CreateModelWithVariableServerUrl( + displayName: nameof(InvokeMcpToolCaptureResponseUsesApprovedServerUrlNotMutatedAsync), + variableName: "McpEndpoint", + toolName: TestToolName); + + string? capturedServerUrl = null; + Mock mockProvider = new(); + mockProvider.Setup(provider => provider.InvokeToolAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny?>(), + It.IsAny?>(), + It.IsAny(), + It.IsAny())) + .Callback?, IDictionary?, string?, CancellationToken>( + (serverUrl, _, _, _, _, _, _) => capturedServerUrl = serverUrl) + .ReturnsAsync(new McpServerToolResultContent("capture-call-id") + { + Outputs = [new TextContent("result")] + }); + MockAgentProvider mockAgentProvider = new(); + InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContext(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate parallel branch mutating state during the approval window + this.State.Set("McpEndpoint", FormulaValue.New(MutatedServerUrl)); + this.State.Bind(); + + // User clicks approve + McpServerToolCallContent toolCall = new(action.Id, TestToolName, ApprovedServerUrl); + ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall); + ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true); + ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse])); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved server URL must be used, not the mutated one + Assert.NotNull(capturedServerUrl); + Assert.Equal(ApprovedServerUrl, capturedServerUrl); + } + + private static Mock CreateMockWorkflowContext() + { + Mock mockContext = new(); + mockContext.Setup(c => c.AddEventAsync(It.IsAny(), It.IsAny())) + .Returns(ValueTask.CompletedTask); + mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(ValueTask.CompletedTask); + return mockContext; + } + + #endregion + #region CompleteAsync Tests [Fact] @@ -951,6 +1151,50 @@ private InvokeMcpTool CreateModel( return AssignParent(builder); } + private InvokeMcpTool CreateModelWithVariableToolName(string displayName, string serverUrl, string variableName) + { + InvokeMcpTool.Builder builder = new() + { + Id = this.CreateActionId(), + DisplayName = this.FormatDisplayName(displayName), + ServerUrl = new StringExpression.Builder(StringExpression.Literal(serverUrl)), + ToolName = new StringExpression.Builder( + StringExpression.Variable(PropertyPath.TopicVariable(variableName))), + RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)), + }; + return AssignParent(builder); + } + + private InvokeMcpTool CreateModelWithVariableArgument( + string displayName, string serverUrl, string toolName, string argumentKey, string variableName) + { + InvokeMcpTool.Builder builder = new() + { + Id = this.CreateActionId(), + DisplayName = this.FormatDisplayName(displayName), + ServerUrl = new StringExpression.Builder(StringExpression.Literal(serverUrl)), + ToolName = new StringExpression.Builder(StringExpression.Literal(toolName)), + RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)), + }; + builder.Arguments.Add(argumentKey, + ValueExpression.Variable(PropertyPath.TopicVariable(variableName))); + return AssignParent(builder); + } + + private InvokeMcpTool CreateModelWithVariableServerUrl(string displayName, string variableName, string toolName) + { + InvokeMcpTool.Builder builder = new() + { + Id = this.CreateActionId(), + DisplayName = this.FormatDisplayName(displayName), + ServerUrl = new StringExpression.Builder( + StringExpression.Variable(PropertyPath.TopicVariable(variableName))), + ToolName = new StringExpression.Builder(StringExpression.Literal(toolName)), + RequireApproval = new BoolExpression.Builder(BoolExpression.Literal(true)), + }; + return AssignParent(builder); + } + #endregion #region Mock MCP Tool Provider From 4e0c34d036f9228463704868cb81d98d534da5ec Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 3 Jun 2026 14:31:48 +0000 Subject: [PATCH 2/4] Apply IsolationKeyScopedAgentSessionStore to MapAGUI by default if not yet set and improve comments in samples --- .../AGUI/Step01_GettingStarted/Server/Program.cs | 5 +++++ .../AGUI/Step01_GettingStarted/Server/Server.csproj | 1 + .../AGUI/Step02_BackendTools/Server/Program.cs | 5 +++++ .../AGUI/Step02_BackendTools/Server/Server.csproj | 1 + .../AGUI/Step03_FrontendTools/Server/Program.cs | 5 +++++ .../AGUI/Step03_FrontendTools/Server/Server.csproj | 1 + .../AGUI/Step04_HumanInLoop/Server/Program.cs | 5 +++++ .../AGUI/Step04_HumanInLoop/Server/Server.csproj | 1 + .../AGUI/Step05_StateManagement/Server/Program.cs | 5 +++++ .../AGUI/Step05_StateManagement/Server/Server.csproj | 1 + .../AGUIDojoServer/AGUIDojoServer.csproj | 1 + .../AGUIClientServer/AGUIDojoServer/Program.cs | 5 +++++ .../AGUIClientServer/AGUIServer/Program.cs | 5 +++-- .../AGUIWebChat/Server/AGUIWebChatServer.csproj | 1 + .../05-end-to-end/AGUIWebChat/Server/Program.cs | 5 +++++ .../AGUIEndpointRouteBuilderExtensions.cs | 11 ++++++++++- 16 files changed, 55 insertions(+), 3 deletions(-) diff --git a/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Program.cs b/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Program.cs index 2c7333015d..0981ece789 100644 --- a/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Program.cs +++ b/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Program.cs @@ -10,6 +10,11 @@ builder.Services.AddHttpClient().AddLogging(); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); string endpoint = builder.Configuration["AZURE_OPENAI_ENDPOINT"] diff --git a/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Server.csproj b/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Server.csproj index 01c8663a7b..a551fed512 100644 --- a/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Server.csproj +++ b/dotnet/samples/02-agents/AGUI/Step01_GettingStarted/Server/Server.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Program.cs b/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Program.cs index 33a32410e2..53b680c861 100644 --- a/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Program.cs +++ b/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Program.cs @@ -16,6 +16,11 @@ options.SerializerOptions.TypeInfoResolverChain.Add(SampleJsonSerializerContext.Default)); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); string endpoint = builder.Configuration["AZURE_OPENAI_ENDPOINT"] diff --git a/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Server.csproj b/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Server.csproj index 01c8663a7b..a551fed512 100644 --- a/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Server.csproj +++ b/dotnet/samples/02-agents/AGUI/Step02_BackendTools/Server/Server.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Program.cs b/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Program.cs index 2c7333015d..0981ece789 100644 --- a/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Program.cs +++ b/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Program.cs @@ -10,6 +10,11 @@ builder.Services.AddHttpClient().AddLogging(); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); string endpoint = builder.Configuration["AZURE_OPENAI_ENDPOINT"] diff --git a/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Server.csproj b/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Server.csproj index 01c8663a7b..a551fed512 100644 --- a/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Server.csproj +++ b/dotnet/samples/02-agents/AGUI/Step03_FrontendTools/Server/Server.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Program.cs b/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Program.cs index edfcd03219..88967acb99 100644 --- a/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Program.cs +++ b/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Program.cs @@ -27,6 +27,11 @@ options.SerializerOptions.TypeInfoResolverChain.Add(ApprovalJsonContext.Default)); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); app.UseHttpLogging(); diff --git a/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Server.csproj b/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Server.csproj index 01c8663a7b..a551fed512 100644 --- a/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Server.csproj +++ b/dotnet/samples/02-agents/AGUI/Step04_HumanInLoop/Server/Server.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Program.cs b/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Program.cs index 1965cf55f7..67a6889fb1 100644 --- a/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Program.cs +++ b/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Program.cs @@ -17,6 +17,11 @@ // Configure to listen on port 8888 builder.WebHost.UseUrls("http://localhost:8888"); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); string endpoint = builder.Configuration["AZURE_OPENAI_ENDPOINT"] diff --git a/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Server.csproj b/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Server.csproj index 01c8663a7b..a551fed512 100644 --- a/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Server.csproj +++ b/dotnet/samples/02-agents/AGUI/Step05_StateManagement/Server/Server.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/AGUIDojoServer.csproj b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/AGUIDojoServer.csproj index 96a72d1109..03e2493623 100644 --- a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/AGUIDojoServer.csproj +++ b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/AGUIDojoServer.csproj @@ -15,6 +15,7 @@ + diff --git a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/Program.cs b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/Program.cs index e3b0020362..3f0032d4da 100644 --- a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/Program.cs +++ b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIDojoServer/Program.cs @@ -19,6 +19,11 @@ builder.Services.ConfigureHttpJsonOptions(options => options.SerializerOptions.TypeInfoResolverChain.Add(AGUIDojoServerSerializerContext.Default)); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); app.UseHttpLogging(); diff --git a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIServer/Program.cs b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIServer/Program.cs index e3b97d34e1..575924255a 100644 --- a/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIServer/Program.cs +++ b/dotnet/samples/05-end-to-end/AGUIClientServer/AGUIServer/Program.cs @@ -49,8 +49,9 @@ AGUIServerSerializerContext.Default.Options) ]); -// When running in production, make sure to use an SessionIsolationKeyProvider, e.g. ClaimsIdentity-based -// if using Claims-based Identity for Authentication/Authorization +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: // builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); // Register the agent with the host and configure it to use an in-memory session store diff --git a/dotnet/samples/05-end-to-end/AGUIWebChat/Server/AGUIWebChatServer.csproj b/dotnet/samples/05-end-to-end/AGUIWebChat/Server/AGUIWebChatServer.csproj index e798d23506..8d44079173 100644 --- a/dotnet/samples/05-end-to-end/AGUIWebChat/Server/AGUIWebChatServer.csproj +++ b/dotnet/samples/05-end-to-end/AGUIWebChat/Server/AGUIWebChatServer.csproj @@ -14,6 +14,7 @@ + diff --git a/dotnet/samples/05-end-to-end/AGUIWebChat/Server/Program.cs b/dotnet/samples/05-end-to-end/AGUIWebChat/Server/Program.cs index 185b7d6bbf..06a138b8c3 100644 --- a/dotnet/samples/05-end-to-end/AGUIWebChat/Server/Program.cs +++ b/dotnet/samples/05-end-to-end/AGUIWebChat/Server/Program.cs @@ -12,6 +12,11 @@ builder.Services.AddHttpClient().AddLogging(); builder.Services.AddAGUI(); +// WARNING: When adding session persistence (e.g., WithInMemorySessionStore), or running in production, +// make sure to also register a SessionIsolationKeyProvider to scope sessions by principal in multi-user +// deployments, e.g.: +// builder.Services.UseClaimsBasedSessionIsolation(new() { ClaimType = ClaimTypes.NameIdentifier }); + WebApplication app = builder.Build(); string endpoint = builder.Configuration["AZURE_OPENAI_ENDPOINT"] ?? throw new InvalidOperationException("AZURE_OPENAI_ENDPOINT is not set."); diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs index 85fd00fb8b..0d4c390bbb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore/AGUIEndpointRouteBuilderExtensions.cs @@ -103,7 +103,16 @@ public static IEndpointConventionBuilder MapAGUI( ArgumentNullException.ThrowIfNull(aiAgent); var agentSessionStore = endpoints.ServiceProvider.GetKeyedService(aiAgent.Name); - var hostAgent = new AIHostAgent(aiAgent, agentSessionStore ?? new NoopAgentSessionStore()); + + // Ensure that we have an IsolationKeyScopedAgentSessionStore registered. + var isolationKeyProvider = endpoints.ServiceProvider.GetService(); + if (agentSessionStore?.GetService() is null) + { + agentSessionStore ??= new NoopAgentSessionStore(); + agentSessionStore = new IsolationKeyScopedAgentSessionStore(agentSessionStore, isolationKeyProvider, new() { Strict = isolationKeyProvider != null }); + } + + var hostAgent = new AIHostAgent(aiAgent, agentSessionStore); return endpoints.MapPost(pattern, async ([FromBody] RunAgentInput? input, HttpContext context, CancellationToken cancellationToken) => { From 57058bd57dd8d065070d9247e990c369b3ef10f7 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:01:49 +0000 Subject: [PATCH 3/4] Address PR comments --- .../ObjectModel/InvokeMcpToolExecutor.cs | 53 ++++++-- .../ObjectModel/InvokeMcpToolExecutorTest.cs | 114 +++++++++++++++++- 2 files changed, 152 insertions(+), 15 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs index d99ef50a7a..27079104a6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/ObjectModel/InvokeMcpToolExecutor.cs @@ -27,6 +27,8 @@ internal sealed class InvokeMcpToolExecutor( WorkflowFormulaState state) : DeclarativeActionExecutor(model, state) { + private const string ApprovalSnapshotStateKey = nameof(_approvalSnapshot); + /// /// Snapshot of evaluated parameters at approval-request time. /// Used to prevent TOCTOU attacks where state mutates during the approval window. @@ -173,9 +175,33 @@ public async ValueTask CaptureResponseAsync( /// public async ValueTask CompleteAsync(IWorkflowContext context, ActionExecutorResult message, CancellationToken cancellationToken) { + // Clear the approval snapshot after successful completion. + this._approvalSnapshot = null; + await ClearSnapshotStateAsync(context, cancellationToken).ConfigureAwait(false); + await context.RaiseCompletionEventAsync(this.Model, cancellationToken).ConfigureAwait(false); } + /// + /// + /// Persists the approval snapshot to workflow state so it survives checkpoint/restore cycles. + /// + protected override async ValueTask OnCheckpointingAsync(IWorkflowContext context, CancellationToken cancellationToken = default) + { + await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, this._approvalSnapshot, null, cancellationToken).ConfigureAwait(false); + await base.OnCheckpointingAsync(context, cancellationToken).ConfigureAwait(false); + } + + /// + /// + /// Restores the approval snapshot from workflow state after a checkpoint restore. + /// + protected override async ValueTask OnCheckpointRestoredAsync(IWorkflowContext context, CancellationToken cancellationToken = default) + { + await base.OnCheckpointRestoredAsync(context, cancellationToken).ConfigureAwait(false); + this._approvalSnapshot = await context.ReadStateAsync(ApprovalSnapshotStateKey, null, cancellationToken).ConfigureAwait(false); + } + private async ValueTask ProcessResultAsync(IWorkflowContext context, McpServerToolResultContent resultContent, CancellationToken cancellationToken) { bool autoSend = this.GetAutoSendValue(); @@ -377,22 +403,23 @@ private bool GetAutoSendValue() return result; } + /// + /// Clears the persisted approval snapshot state after a successful tool invocation. + /// + private static async ValueTask ClearSnapshotStateAsync(IWorkflowContext context, CancellationToken cancellationToken) + { + await context.QueueStateUpdateAsync(ApprovalSnapshotStateKey, null, null, cancellationToken).ConfigureAwait(false); + } + /// /// Stores the evaluated parameters at approval-request time so that /// uses the values the user reviewed, /// even if mutates during the approval window. /// - private readonly struct ApprovalSnapshot( - string serverUrl, - string? serverLabel, - string toolName, - Dictionary? arguments, - string? connectionName) - { - public string ServerUrl { get; } = serverUrl; - public string? ServerLabel { get; } = serverLabel; - public string ToolName { get; } = toolName; - public Dictionary? Arguments { get; } = arguments; - public string? ConnectionName { get; } = connectionName; - } + internal sealed record ApprovalSnapshot( + string ServerUrl, + string? ServerLabel, + string ToolName, + Dictionary? Arguments, + string? ConnectionName); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs index 1209e62ebb..d48fe6a371 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs @@ -2,12 +2,14 @@ using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Workflows.Declarative.Events; using Microsoft.Agents.AI.Workflows.Declarative.Interpreter; using Microsoft.Agents.AI.Workflows.Declarative.Kit; using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel; +using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeMcpToolExecutor.ApprovalSnapshot; using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Agents.ObjectModel; using Microsoft.Extensions.AI; @@ -1030,16 +1032,124 @@ public async Task InvokeMcpToolCaptureResponseUsesApprovedServerUrlNotMutatedAsy Assert.Equal(ApprovedServerUrl, capturedServerUrl); } + /// + /// Verifies that the approval snapshot survives a checkpoint/restore cycle. + /// After restore, the originally-approved tool name must still be used even if state was mutated. + /// + [Fact] + public async Task InvokeMcpToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync() + { + // Arrange + const string ApprovedToolName = "safe_readonly_query"; + const string MutatedToolName = "dangerous_admin_tool"; + + this.State.Set("TargetTool", FormulaValue.New(ApprovedToolName)); + this.State.InitializeSystem(); + this.State.Bind(); + + InvokeMcpTool model = this.CreateModelWithVariableToolName( + displayName: nameof(InvokeMcpToolCaptureResponseUsesSnapshotAfterCheckpointRestoreAsync), + serverUrl: TestServerUrl, + variableName: "TargetTool"); + + string? capturedToolName = null; + Mock mockProvider = new(); + mockProvider.Setup(provider => provider.InvokeToolAsync( + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny?>(), + It.IsAny?>(), + It.IsAny(), + It.IsAny())) + .Callback?, IDictionary?, string?, CancellationToken>( + (_, _, toolName, _, _, _, _) => capturedToolName = toolName) + .ReturnsAsync(new McpServerToolResultContent("capture-call-id") + { + Outputs = [new TextContent("result")] + }); + MockAgentProvider mockAgentProvider = new(); + InvokeMcpToolExecutor action = new(model, mockProvider.Object, mockAgentProvider.Object, this.State); + + // Act - trigger ExecuteAsync to store the approval snapshot + Mock mockContext = CreateMockWorkflowContextWithStateStore(); + await action.HandleAsync(new ActionExecutorResult(action.Id), mockContext.Object, CancellationToken.None); + + // Simulate checkpoint: persist to state store + await InvokeProtectedMethodAsync(action, "OnCheckpointingAsync", mockContext.Object, CancellationToken.None); + + // Simulate restore on a "new" executor instance by clearing the in-memory field via reflection + // (In production, a new executor instance would be created with _approvalSnapshot == null) + typeof(InvokeMcpToolExecutor) + .GetField("_approvalSnapshot", BindingFlags.NonPublic | BindingFlags.Instance)! + .SetValue(action, null); + + // Restore from state store + await InvokeProtectedMethodAsync(action, "OnCheckpointRestoredAsync", mockContext.Object, CancellationToken.None); + + // Mutate state after restore (simulating parallel branch) + this.State.Set("TargetTool", FormulaValue.New(MutatedToolName)); + this.State.Bind(); + + // User clicks approve + McpServerToolCallContent toolCall = new(action.Id, ApprovedToolName, TestServerUrl); + ToolApprovalRequestContent approvalRequest = new(action.Id, toolCall); + ToolApprovalResponseContent approvalResponse = approvalRequest.CreateResponse(approved: true); + ExternalInputResponse response = new(new ChatMessage(ChatRole.User, [approvalResponse])); + + // Resume after approval + await action.CaptureResponseAsync(mockContext.Object, response, CancellationToken.None); + + // Assert - the originally-approved tool name must be used, not the mutated one + Assert.NotNull(capturedToolName); + Assert.Equal(ApprovedToolName, capturedToolName); + } + private static Mock CreateMockWorkflowContext() { Mock mockContext = new(); mockContext.Setup(c => c.AddEventAsync(It.IsAny(), It.IsAny())) - .Returns(ValueTask.CompletedTask); + .Returns(default(ValueTask)); mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(ValueTask.CompletedTask); + .Returns(default(ValueTask)); + mockContext.Setup(c => c.SendMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); return mockContext; } + /// + /// Creates a mock workflow context that actually stores state values (for checkpoint/restore tests). + /// + private static Mock CreateMockWorkflowContextWithStateStore() + { + Dictionary stateStore = new(); + Mock mockContext = new(); + mockContext.Setup(c => c.AddEventAsync(It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.QueueStateUpdateAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((key, value, _, _) => stateStore[key] = value) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.SendMessageAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); + mockContext.Setup(c => c.ReadStateAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((key, _, _) => + new ValueTask(stateStore.TryGetValue(key, out object? val) ? val as ApprovalSnapshot : null)); + mockContext.Setup(c => c.ReadStateKeysAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new HashSet()); + return mockContext; + } + + /// + /// Invokes a protected method on an executor via reflection (for testing checkpoint hooks). + /// + private static async ValueTask InvokeProtectedMethodAsync(InvokeMcpToolExecutor action, string methodName, IWorkflowContext context, CancellationToken cancellationToken) + { + MethodInfo method = typeof(InvokeMcpToolExecutor) + .GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance)!; + ValueTask result = (ValueTask)method.Invoke(action, [context, cancellationToken])!; + await result.ConfigureAwait(false); + } + #endregion #region CompleteAsync Tests From a99c348b324af54eeacb40fdab5aba2511018fb3 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:15:36 +0000 Subject: [PATCH 4/4] Fix formatting --- .../ObjectModel/InvokeMcpToolExecutorTest.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs index d48fe6a371..0f1ce950ff 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.Declarative.UnitTests/ObjectModel/InvokeMcpToolExecutorTest.cs @@ -9,12 +9,12 @@ using Microsoft.Agents.AI.Workflows.Declarative.Interpreter; using Microsoft.Agents.AI.Workflows.Declarative.Kit; using Microsoft.Agents.AI.Workflows.Declarative.ObjectModel; -using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeMcpToolExecutor.ApprovalSnapshot; using Microsoft.Agents.AI.Workflows.Declarative.PowerFx; using Microsoft.Agents.ObjectModel; using Microsoft.Extensions.AI; using Microsoft.PowerFx.Types; using Moq; +using ApprovalSnapshot = Microsoft.Agents.AI.Workflows.Declarative.ObjectModel.InvokeMcpToolExecutor.ApprovalSnapshot; namespace Microsoft.Agents.AI.Workflows.Declarative.UnitTests.ObjectModel;