diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs index d9acab96d5..576c749a90 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffAgentExecutor.cs @@ -426,7 +426,7 @@ await AddUpdateAsync( { AgentId = this._agent.Id, AuthorName = this._agent.Name ?? this._agent.Id, - Contents = [new FunctionResultContent(handoffRequest.CallId, "Transferred.")], + Contents = [CreateHandoffResult(handoffRequest.CallId)], CreatedAt = DateTimeOffset.UtcNow, MessageId = Guid.NewGuid().ToString("N"), Role = ChatRole.Tool, @@ -459,4 +459,6 @@ ValueTask AddUpdateAsync(AgentResponseUpdate update, CancellationToken cancellat ? this._handoffFunctionToAgentId.TryGetValue(requestedHandoff, out string? targetId) ? targetId : null : null; } + + internal static FunctionResultContent CreateHandoffResult(string requestCallId) => new(requestCallId, "Transferred."); } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs index 7bc178c2d4..61eebc0e2b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Specialized/HandoffMessagesFilter.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; -using System.Linq; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Workflows.Specialized; @@ -31,113 +30,78 @@ public IEnumerable FilterMessages(IEnumerable messages return messages; } - Dictionary filteringCandidates = new(); - List filteredMessages = []; - HashSet messagesToRemove = []; + HashSet filteredCallsWithoutResponses = new(); + List retainedMessages = []; + + bool filterAllToolCalls = this._filteringBehavior == HandoffToolCallFilteringBehavior.All; + + // The logic of filtering is fairly straightforward: We are only interested in FunctionCallContent and FunctionResponseContent. + // We are going to assume that Handoff operates as follows: + // * Each agent is only taking one turn at a time + // * Each agent is taking a turn alone + // + // In the case of certain providers, like Gemini (see microsoft/agent-framework #5244), we will see the function call name as the + // call id as well, so we may see multiple calls with the same call id, and assume that the call is terminated before another + // "CallId-less" FCC is issued. We also need to rely on the idea that FRC follows their corresponding FCC in the message stream. + // (This changes the previous behaviour where FRC could arrive earlier, and relies on strict ordering). + // + // The benefit of expecting all the AIContent to be strictly ordered is that we never need to reach back into a post-filtered + // content to retroactively remove it, or to try to inject it back into the middle of a Message that has already been processed. - bool filterHandoffOnly = this._filteringBehavior == HandoffToolCallFilteringBehavior.HandoffOnly; foreach (ChatMessage unfilteredMessage in messages) { - ChatMessage filteredMessage = unfilteredMessage.Clone(); + if (unfilteredMessage.Contents is null || unfilteredMessage.Contents.Count == 0) + { + retainedMessages.Add(unfilteredMessage); + continue; + } - // .Clone() is shallow, so we cannot modify the contents of the cloned message in place. - List contents = []; - contents.Capacity = unfilteredMessage.Contents?.Count ?? 0; - filteredMessage.Contents = contents; + // We may need to filter out a subset of the message's content, but we won't know until we iterate through it. Create a new list + // of AIContent which we will stuff into a clone of the message if we need to filter out any content. + List retainedContents = new(capacity: unfilteredMessage.Contents.Count); - // Because this runs after the role changes from assistant to user for the target agent, we cannot rely on tool calls - // originating only from messages with the Assistant role. Instead, we need to inspect the contents of all non-Tool (result) - // FunctionCallContent. - if (unfilteredMessage.Role != ChatRole.Tool) + foreach (AIContent content in unfilteredMessage.Contents) { - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + if (content is FunctionCallContent fcc + && (filterAllToolCalls || IsHandoffFunctionName(fcc.Name))) { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionCallContent fcc || (filterHandoffOnly && !IsHandoffFunctionName(fcc.Name))) - { - filteredMessage.Contents.Add(content); - - // Track non-handoff function calls so their tool results are preserved in HandoffOnly mode - if (filterHandoffOnly && content is FunctionCallContent nonHandoffFcc) - { - filteringCandidates[nonHandoffFcc.CallId] = new FilterCandidateState(nonHandoffFcc.CallId) - { - IsHandoffFunction = false, - }; - } - } - else if (filterHandoffOnly) + // If we already have an unmatched candidate with the same CallId, that means we have two FCCs in a row without an FRC, + // which violates our assumption of strict ordering. + if (!filteredCallsWithoutResponses.Add(fcc.CallId)) { - if (!filteringCandidates.TryGetValue(fcc.CallId, out FilterCandidateState? candidateState)) - { - filteringCandidates[fcc.CallId] = new FilterCandidateState(fcc.CallId) - { - IsHandoffFunction = true, - }; - } - else - { - candidateState.IsHandoffFunction = true; - (int messageIndex, int contentIndex) = candidateState.FunctionCallResultLocation!.Value; - ChatMessage messageToFilter = filteredMessages[messageIndex]; - messageToFilter.Contents.RemoveAt(contentIndex); - if (messageToFilter.Contents.Count == 0) - { - messagesToRemove.Add(messageIndex); - } - } + throw new InvalidOperationException($"Duplicate FunctionCallContent with CallId '{fcc.CallId}' without corresponding FunctionResultContent."); } - else - { - // All mode: strip all FunctionCallContent - } - } - } - else - { - if (!filterHandoffOnly) - { + + // If we are filtering all tool calls, or this is a handoff call (and we are not filtering None, already checked), then + // filter this FCC continue; } - - for (int i = 0; i < unfilteredMessage.Contents!.Count; i++) + else if (content is FunctionResultContent frc) { - AIContent content = unfilteredMessage.Contents[i]; - if (content is not FunctionResultContent frc - || (filteringCandidates.TryGetValue(frc.CallId, out FilterCandidateState? candidateState) - && candidateState.IsHandoffFunction is false)) - { - // Either this is not a function result content, so we should let it through, or it is a FRC that - // we know is not related to a handoff call. In either case, we should include it. - filteredMessage.Contents.Add(content); - } - else if (candidateState is null) + // We rely on the corresponding FCC to have already been processed, so check if it is in the candidate dictionary. + // If it is, we can filter out the FRC, but we need to remove the candidate from the dictionary, since a future FCC can + // come in with the same CallId, and should be considered a new call that may need to be filtered. + if (filteredCallsWithoutResponses.Remove(frc.CallId)) { - // We haven't seen the corresponding function call yet, so add it as a candidate to be filtered later - filteringCandidates[frc.CallId] = new FilterCandidateState(frc.CallId) - { - FunctionCallResultLocation = (filteredMessages.Count, filteredMessage.Contents.Count), - }; + continue; } - // else we have seen the corresponding function call and it is a handoff, so we should filter it out. } + + // FCC/FRC, but not filtered, or neither FCC nor FRC: this should not be filtered out + retainedContents.Add(content); } - if (filteredMessage.Contents.Count > 0) + if (retainedContents.Count == 0) { - filteredMessages.Add(filteredMessage); + // message was fully filtered, skip it + continue; } - } - return filteredMessages.Where((_, index) => !messagesToRemove.Contains(index)); - } - - private class FilterCandidateState(string callId) - { - public (int MessageIndex, int ContentIndex)? FunctionCallResultLocation { get; set; } - - public string CallId => callId; + ChatMessage filteredMessage = unfilteredMessage.Clone(); + filteredMessage.Contents = retainedContents; + retainedMessages.Add(filteredMessage); + } - public bool? IsHandoffFunction { get; set; } + return retainedMessages; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffMessageFilterTests.cs new file mode 100644 index 0000000000..bbb53c31fb --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/HandoffMessageFilterTests.cs @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using FluentAssertions; +using Microsoft.Agents.AI.Workflows.Specialized; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Workflows.UnitTests; + +public class HandoffMessageFilterTests +{ + private List CreateTestMessages(bool firstAgentUsesCallId, bool secondAgentUsesCallId, HandoffToolCallFilteringBehavior filter = HandoffToolCallFilteringBehavior.None) + { + FunctionCallContent handoffRequest1 = CreateHandoffCall(1, firstAgentUsesCallId); + FunctionResultContent handoffResponse1 = CreateHandoffResponse(handoffRequest1); + + FunctionCallContent toolCall = CreateToolCall(secondAgentUsesCallId); + FunctionResultContent toolResponse = CreateToolResponse(toolCall); + + // Approvals come from the function call middleware over ChatClient, so we can expect there to be a RequestId (not that we + // care, because we do not filter approval content) + ToolApprovalRequestContent toolApproval = new(Guid.NewGuid().ToString("N"), toolCall); + ToolApprovalResponseContent toolApprovalResponse = new(toolApproval.RequestId, true, toolCall); + + FunctionCallContent handoffRequest2 = CreateHandoffCall(1, secondAgentUsesCallId); + FunctionResultContent handoffResponse2 = CreateHandoffResponse(handoffRequest2); + + List result = [new(ChatRole.User, "Hello")]; + + // Agent 1 turn + result.Add(new(ChatRole.Assistant, "Hello! What do you want help with today?")); + result.Add(new(ChatRole.User, "Please explain temperature")); + + // Unless we are filtering none, we expect the handoff call to be filtered out, so we add it conditionally + if (filter == HandoffToolCallFilteringBehavior.None) + { + result.Add(new(ChatRole.Assistant, [handoffRequest1])); + result.Add(new(ChatRole.Tool, [handoffResponse1])); + } + + // Agent 2 turn + + // Tool approvals are never filtered, so we add them unconditionally + result.Add(new(ChatRole.Assistant, [toolApproval])); + result.Add(new(ChatRole.User, [toolApprovalResponse])); + + // Unless we are filtering all, we expect the tool call to be retained, so we add it conditionally + if (filter != HandoffToolCallFilteringBehavior.All) + { + result.Add(new(ChatRole.Assistant, [toolCall])); + result.Add(new(ChatRole.Tool, [toolResponse])); + } + + result.Add(new(ChatRole.Assistant, "Temperature is a measure of the average kinetic energy of the particles in a substance.")); + + if (filter == HandoffToolCallFilteringBehavior.None) + { + result.Add(new(ChatRole.Assistant, [handoffRequest2])); + result.Add(new(ChatRole.Tool, [handoffResponse2])); + } + + return result; + } + + private static FunctionCallContent CreateHandoffCall(int id, bool useCallId) + { + string callName = $"{HandoffWorkflowBuilder.FunctionPrefix}{id}"; + string callId = useCallId ? Guid.NewGuid().ToString("N") : callName; + + return new FunctionCallContent(callId, callName); + } + + private static FunctionResultContent CreateHandoffResponse(FunctionCallContent call) + => HandoffAgentExecutor.CreateHandoffResult(call.CallId); + + private static FunctionCallContent CreateToolCall(bool useCallId) + { + const string CallName = "ToolFunction"; + string callId = useCallId ? Guid.NewGuid().ToString("N") : CallName; + + return new FunctionCallContent(callId, CallName); + } + + private static FunctionResultContent CreateToolResponse(FunctionCallContent call) + => new(call.CallId, new object()); + + [Theory] + [InlineData(true, true, HandoffToolCallFilteringBehavior.None)] + [InlineData(true, false, HandoffToolCallFilteringBehavior.None)] + [InlineData(false, true, HandoffToolCallFilteringBehavior.None)] + [InlineData(false, false, HandoffToolCallFilteringBehavior.None)] + [InlineData(true, true, HandoffToolCallFilteringBehavior.HandoffOnly)] + [InlineData(true, false, HandoffToolCallFilteringBehavior.HandoffOnly)] + [InlineData(false, true, HandoffToolCallFilteringBehavior.HandoffOnly)] + [InlineData(false, false, HandoffToolCallFilteringBehavior.HandoffOnly)] + [InlineData(true, true, HandoffToolCallFilteringBehavior.All)] + [InlineData(true, false, HandoffToolCallFilteringBehavior.All)] + [InlineData(false, true, HandoffToolCallFilteringBehavior.All)] + [InlineData(false, false, HandoffToolCallFilteringBehavior.All)] + public void Test_HandoffMessageFilter_FiltersOnlyExpectedMessages(bool firstAgentUsesCallId, bool secondAgentUsesCallId, HandoffToolCallFilteringBehavior behavior) + { + // Arrange + List messages = this.CreateTestMessages(firstAgentUsesCallId, secondAgentUsesCallId); + List expected = this.CreateTestMessages(firstAgentUsesCallId, secondAgentUsesCallId, behavior); + + HandoffMessagesFilter filter = new(behavior); + + // Act + IEnumerable filteredMessages = filter.FilterMessages(messages); + + // Assert + filteredMessages.Should().BeEquivalentTo(expected); + } +}