From eff5aee5aa63adba37eced2c7fa918db88644d19 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:33:14 +0000 Subject: [PATCH 1/3] .NET: Add per run / thread feature collection support and improved custom ChatMessageStore support (#2345) * Add the ability to override services on an agent per run. * Remove Run from AgentFeatureCollection name. * Adding features param to GetNewThread. * Move feature collection. * Add features to DeserializeThread * Remove servicecollection based option * Add feature collection unit tests and fix bug identified in code review. * Add more unit tests for DelegatingAIAgent and AgentRunOptions * Fix formatting. * Address PR comments. * Switch to dedicated ConversationIdAgentFeature and improve 3rd party storage samples. * Fix bug in sample. --- .../Program.cs | 18 +- .../Program.cs | 243 ++++++++++++++---- .../src/Microsoft.Agents.AI.A2A/A2AAgent.cs | 6 +- .../AIAgent.cs | 30 +-- .../AgentRunOptions.cs | 6 + .../AgentThread.cs | 27 +- .../DelegatingAIAgent.cs | 6 +- .../Features/AgentFeatureCollection.cs | 192 ++++++++++++++ .../Features/ConversationIdAgentFeature.cs | 29 +++ .../Features/IAgentFeatureCollection.cs | 46 ++++ .../InMemoryAgentThread.cs | 6 - .../CopilotStudioAgent.cs | 6 +- .../DurableAIAgent.cs | 12 +- .../DurableAIAgentProxy.cs | 5 +- .../PurviewAgent.cs | 8 +- .../WorkflowHostAgent.cs | 4 +- .../WorkflowThread.cs | 3 - .../ChatClient/ChatClientAgent.cs | 102 ++++++-- .../ChatClient/ChatClientAgentOptions.cs | 10 + .../ChatClient/ChatClientAgentThread.cs | 31 --- .../A2AAgentTests.cs | 18 ++ .../AIAgentTests.cs | 23 +- .../AgentFeatureCollectionTests.cs | 119 +++++++++ .../AgentRunOptionsTests.cs | 4 +- .../AgentThreadTests.cs | 11 - .../DelegatingAIAgentTests.cs | 30 ++- .../BasicStreamingTests.cs | 8 +- .../SharedStateTests.cs | 4 +- ...AGUIEndpointRouteBuilderExtensionsTests.cs | 8 +- .../TestAgent.cs | 5 +- .../AgentExtensionsTests.cs | 4 +- .../ChatClient/ChatClientAgentTests.cs | 130 ++++++---- .../ChatClient/ChatClientAgentThreadTests.cs | 63 ----- .../ChatClientAgent_DeserializeThreadTests.cs | 143 +++++++++++ .../ChatClientAgent_GetNewThreadTests.cs | 176 +++++++++++++ .../TestAIAgent.cs | 4 +- .../AgentWorkflowBuilderTests.cs | 4 +- .../InProcessExecutionTests.cs | 8 +- .../RepresentationTests.cs | 4 +- .../Sample/06_GroupChat_Workflow.cs | 4 +- .../SpecializedExecutorSmokeTests.cs | 4 +- .../TestEchoAgent.cs | 4 +- 42 files changed, 1212 insertions(+), 356 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index fd00618f5f..66e3ea2a52 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -28,10 +28,10 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new CustomAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new CustomAgentThread(serializedThread, jsonSerializerOptions); public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -39,11 +39,16 @@ public override async Task RunAsync(IEnumerable m // Create a thread if the user didn't supply one. thread ??= this.GetNewThread(); + if (thread is not CustomAgentThread typedThread) + { + throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread)); + } + // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); // Notify the thread of the input and output messages. - await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken); + await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); return new AgentRunResponse { @@ -58,11 +63,16 @@ public override async IAsyncEnumerable RunStreamingAsync // Create a thread if the user didn't supply one. thread ??= this.GetNewThread(); + if (thread is not CustomAgentThread typedThread) + { + throw new ArgumentException($"The provided thread is not of type {nameof(CustomAgentThread)}.", nameof(thread)); + } + // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.DisplayName).ToList(); // Notify the thread of the input and output messages. - await NotifyThreadOfNewMessagesAsync(thread, messages.Concat(responseMessages), cancellationToken); + await typedThread.MessageStore.AddMessagesAsync(messages.Concat(responseMessages), cancellationToken); foreach (var message in responseMessages) { diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs index 8986734972..cacc237f4d 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs @@ -21,53 +21,200 @@ // Replace this with a vector store implementation of your choice if you want to persist the chat history to disk. VectorStore vectorStore = new InMemoryVectorStore(); -// Create the agent -AIAgent agent = new AzureOpenAIClient( - new Uri(endpoint), - new AzureCliCredential()) - .GetChatClient(deploymentName) - .CreateAIAgent(new ChatClientAgentOptions - { - Instructions = "You are good at telling jokes.", - Name = "Joker", - ChatMessageStoreFactory = ctx => +// Execute various samples showing how to use a custom ChatMessageStore with an agent. +await CustomChatMessageStore_UsingFactory_Async(); +await CustomChatMessageStore_UsingFactoryAndExistingExternalId_Async(); +await CustomChatMessageStore_PerThread_Async(); +await CustomChatMessageStore_PerRun_Async(); + +// Here we can see how to create a custom ChatMessageStore using a factory method +// provided to the agent via the ChatMessageStoreFactory option. +// This allows us to use a custom chat message store, where the consumer of the agent +// doesn't need to know anything about the storage mechanism used. +async Task CustomChatMessageStore_UsingFactory_Async() +{ + Console.WriteLine("\n--- With Factory ---\n"); + + // Create the agent + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions { - // Create a new chat message store for this agent that stores the messages in a vector store. - // Each thread must get its own copy of the VectorChatMessageStore, since the store - // also contains the id that the thread is stored under. - return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions); - } - }); + Instructions = "You are good at telling jokes.", + Name = "Joker", + ChatMessageStoreFactory = ctx => + { + // Create a new chat message store for this agent that stores the messages in a vector store. + // Each thread must get its own copy of the VectorChatMessageStore, since the store + // also contains the id that the thread is stored under. + return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions, ctx.Features); + } + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // Serialize the thread state, so it can be stored for later use. + // Since the chat history is stored in the vector store, the serialized thread + // only contains the guid that the messages are stored under in the vector store. + JsonElement serializedThread = thread.Serialize(); + + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); + + // The serialized thread can now be saved to a database, file, or any other storage mechanism + // and loaded again later. + + // Deserialize the thread state after loading from storage. + AgentThread resumedThread = agent.DeserializeThread(serializedThread); + + // Run the agent with the thread that stores conversation history in the vector store a second time. + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); +} + +// Here we can see how to create a custom ChatMessageStore using a factory method +// provided to the agent via the ChatMessageStoreFactory option. +// It also shows how we can pass a custom storage id at runtime to the message store using +// the VectorChatMessageStoreThreadDbKeyFeature. +// Note that not all agents or chat message stores may support this feature. +async Task CustomChatMessageStore_UsingFactoryAndExistingExternalId_Async() +{ + Console.WriteLine("\n--- With Factory and Existing External ID ---\n"); + + // Create the agent + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker", + ChatMessageStoreFactory = ctx => + { + // Create a new chat message store for this agent that stores the messages in a vector store. + // Each thread must get its own copy of the VectorChatMessageStore, since the store + // also contains the id that the thread is stored under. + return new VectorChatMessageStore(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions, ctx.Features); + } + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); + + // We can access the VectorChatMessageStore via the thread's GetService method if we need to read the key under which threads are stored. + var messageStoreFromFactory = thread.GetService()!; + Console.WriteLine($"\nThread is stored in vector store under key: {messageStoreFromFactory.ThreadDbKey}"); + + // It's possible to create a new thread that uses the same chat message store id by providing + // the VectorChatMessageStoreThreadDbKeyFeature in the feature collection when creating the new thread. + AgentFeatureCollection features = new(); + features.Set(new VectorChatMessageStoreThreadDbKeyFeature(messageStoreFromFactory.ThreadDbKey!)); + AgentThread resumedThread = agent.GetNewThread(features); -// Start a new thread for the agent conversation. -AgentThread thread = agent.GetNewThread(); + // Run the agent with the thread that stores conversation history in the vector store. + Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); +} -// Run the agent with the thread that stores conversation history in the vector store. -Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); +// Here we can see how to create a custom ChatMessageStore and pass it to the thread +// when creating a new thread. +async Task CustomChatMessageStore_PerThread_Async() +{ + Console.WriteLine("\n--- Per Thread ---\n"); -// Serialize the thread state, so it can be stored for later use. -// Since the chat history is stored in the vector store, the serialized thread -// only contains the guid that the messages are stored under in the vector store. -JsonElement serializedThread = thread.Serialize(); + // We can also create an agent without a factory that provides a ChatMessageStore. + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker" + }); -Console.WriteLine("\n--- Serialized thread ---\n"); -Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); + // Instead of using a factory on the agent to create the ChatMessageStore, we can + // create a VectorChatMessageStore ourselves and register it in a feature collection. + // We can then pass the feature collection when creating a new thread. + // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. + VectorChatMessageStore perThreadMessageStore = new(vectorStore, "chat-history-1"); + AgentFeatureCollection features = new(); + features.Set(perThreadMessageStore); + AgentThread thread = agent.GetNewThread(features); -// The serialized thread can now be saved to a database, file, or any other storage mechanism -// and loaded again later. + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); -// Deserialize the thread state after loading from storage. -AgentThread resumedThread = agent.DeserializeThread(serializedThread); + // When serializing this thread, we'll see that it has the id from the message store stored in its state. + JsonElement serializedThread = thread.Serialize(); -// Run the agent with the thread that stores conversation history in the vector store a second time. -Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); +} -// We can access the VectorChatMessageStore via the thread's GetService method if we need to read the key under which threads are stored. -var messageStore = resumedThread.GetService()!; -Console.WriteLine($"\nThread is stored in vector store under key: {messageStore.ThreadDbKey}"); +// Here we can see how to create a custom ChatMessageStore for a single run using the Features option +// passed when we run the agent. +// Note that if the agent doesn't support a chat message store, it would be ignored. +async Task CustomChatMessageStore_PerRun_Async() +{ + Console.WriteLine("\n--- Per Run ---\n"); + + // We can also create an agent without a factory that provides a ChatMessageStore. + AIAgent agent = new AzureOpenAIClient( + new Uri(endpoint), + new AzureCliCredential()) + // Use a service that doesn't require storage of chat history in the service itself. + .GetChatClient(deploymentName) + .CreateAIAgent(new ChatClientAgentOptions + { + Instructions = "You are good at telling jokes.", + Name = "Joker" + }); + + // Start a new thread for the agent conversation. + AgentThread thread = agent.GetNewThread(); + + // Instead of using a factory on the agent to create the ChatMessageStore, we can + // create a VectorChatMessageStore ourselves and register it in a feature collection. + // We can then pass the feature collection to the agent when running it by using the Features option. + // The message store would only be used for the run that it's passed to. + // If the agent doesn't support a message store, it would be ignored. + // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. + VectorChatMessageStore perRunMessageStore = new(vectorStore, "chat-history-1"); + AgentFeatureCollection features = new(); + features.Set(perRunMessageStore); + + Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread, options: new AgentRunOptions() { Features = features })); + + // When serializing this thread, we'll see that it has no messagestore state, since the messagestore was not attached to the thread, + // but just provided for the single run. Note that, depending on the circumstances, the thread may still contain other state, e.g. Memories, + // if an AIContextProvider is attached which adds memory to an agent. + JsonElement serializedThread = thread.Serialize(); + + Console.WriteLine("\n--- Serialized thread ---\n"); + Console.WriteLine(JsonSerializer.Serialize(serializedThread, new JsonSerializerOptions { WriteIndented = true })); +} namespace SampleApp { + /// + /// A feature that allows providing the thread database key for the . + /// + internal sealed class VectorChatMessageStoreThreadDbKeyFeature(string threadDbKey) + { + public string ThreadDbKey { get; } = threadDbKey; + } + /// /// A sample implementation of that stores chat messages in a vector store. /// @@ -75,29 +222,35 @@ internal sealed class VectorChatMessageStore : ChatMessageStore { private readonly VectorStore _vectorStore; - public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null) + public VectorChatMessageStore(VectorStore vectorStore, string threadDbKey) + { + this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); + this.ThreadDbKey = threadDbKey ?? throw new ArgumentNullException(nameof(threadDbKey)); + } + + public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedStoreState, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? features = null) { this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); - if (serializedStoreState.ValueKind is JsonValueKind.String) - { - // Here we can deserialize the thread id so that we can access the same messages as before the suspension. - this.ThreadDbKey = serializedStoreState.Deserialize(); - } + // Here we can deserialize the thread id so that we can access the same messages as before the suspension, or if + // a user provided a ConversationIdAgentFeature in the features collection, we can use that + // or finally we can generate one ourselves. + this.ThreadDbKey = serializedStoreState.ValueKind is JsonValueKind.String + ? serializedStoreState.Deserialize() + : features?.Get()?.ThreadDbKey + ?? Guid.NewGuid().ToString("N"); } - public string? ThreadDbKey { get; private set; } + public string? ThreadDbKey { get; } public override async Task AddMessagesAsync(IEnumerable messages, CancellationToken cancellationToken = default) { - this.ThreadDbKey ??= Guid.NewGuid().ToString("N"); - var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); await collection.UpsertAsync(messages.Select(x => new ChatHistoryItem() { - Key = this.ThreadDbKey + x.MessageId, + Key = this.ThreadDbKey + (string.IsNullOrWhiteSpace(x.MessageId) ? Guid.NewGuid().ToString("N") : x.MessageId), Timestamp = DateTimeOffset.UtcNow, ThreadId = this.ThreadDbKey, SerializedMessage = JsonSerializer.Serialize(x), diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index cafbf90b87..15df3ea5ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -54,8 +54,8 @@ public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, str } /// - public sealed override AgentThread GetNewThread() - => new A2AAgentThread(); + public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) + => new A2AAgentThread() { ContextId = featureCollection?.Get()?.ConversationId }; /// /// Get a new instance using an existing context id, to continue that conversation. @@ -66,7 +66,7 @@ public AgentThread GetNewThread(string contextId) => new A2AAgentThread() { ContextId = contextId }; /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new A2AAgentThread(serializedThread, jsonSerializerOptions); /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs index 35aa866552..2921b88724 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIAgent.cs @@ -108,6 +108,7 @@ public abstract class AIAgent /// /// Creates a new conversation thread that is compatible with this agent. /// + /// An optional feature collection to override or provide additional context or capabilities to the thread where the thread supports these features. /// A new instance ready for use with this agent. /// /// @@ -121,13 +122,14 @@ public abstract class AIAgent /// may be deferred until first use to optimize performance. /// /// - public abstract AgentThread GetNewThread(); + public abstract AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null); /// /// Deserializes an agent thread from its JSON serialized representation. /// /// A containing the serialized thread state. /// Optional settings to customize the deserialization process. + /// An optional feature collection to override or provide additional context or capabilities to the thread where the thread supports these features. /// A restored instance with the state from . /// The is not in the expected format. /// The serialized data is invalid or cannot be deserialized. @@ -136,7 +138,7 @@ public abstract class AIAgent /// allowing conversations to resume across application restarts or be migrated between /// different agent instances. /// - public abstract AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null); + public abstract AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null); /// /// Run the agent with no message assuming that all required instructions are already provided to the agent or on the thread. @@ -328,28 +330,4 @@ public abstract IAsyncEnumerable RunStreamingAsync( AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default); - - /// - /// Notifies the specified thread about new messages that have been added to the conversation. - /// - /// The conversation thread to notify about the new messages. - /// The collection of new messages to report to the thread. - /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous notification operation. - /// or is . - /// - /// - /// This method ensures that conversation threads are kept informed about message additions, which - /// is important for threads that manage their own state, memory components, or derived context. - /// While all agent implementations should notify their threads, the specific actions taken by - /// each thread type may vary. - /// - /// - protected static async Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IEnumerable messages, CancellationToken cancellationToken) - { - _ = Throw.IfNull(thread); - _ = Throw.IfNull(messages); - - await thread.MessagesReceivedAsync(messages, cancellationToken).ConfigureAwait(false); - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs index 7262979207..d12afae76b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRunOptions.cs @@ -34,6 +34,7 @@ public AgentRunOptions(AgentRunOptions options) this.ContinuationToken = options.ContinuationToken; this.AllowBackgroundResponses = options.AllowBackgroundResponses; this.AdditionalProperties = options.AdditionalProperties?.Clone(); + this.Features = options.Features; } /// @@ -90,4 +91,9 @@ public AgentRunOptions(AgentRunOptions options) /// preserving implementation-specific details or extending the options with custom data. /// public AdditionalPropertiesDictionary? AdditionalProperties { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware for this run. + /// + public IAgentFeatureCollection? Features { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs index fb5863a5c9..bfb52e021b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentThread.cs @@ -1,11 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -30,8 +26,8 @@ namespace Microsoft.Agents.AI; /// Chat history reduction, e.g. where messages needs to be summarized or truncated to reduce the size. /// /// An is always constructed by an so that the -/// can attach any necessary behaviors to the . See the -/// and methods for more information. +/// can attach any necessary behaviors to the . See the +/// and methods for more information. /// /// /// Because of these behaviors, an may not be reusable across different agents, since each agent @@ -41,13 +37,13 @@ namespace Microsoft.Agents.AI; /// To support conversations that may need to survive application restarts or separate service requests, an can be serialized /// and deserialized, so that it can be saved in a persistent store. /// The provides the method to serialize the thread to a -/// and the method +/// and the method /// can be used to deserialize the thread. /// /// /// -/// -/// +/// +/// public abstract class AgentThread { /// @@ -65,19 +61,6 @@ protected AgentThread() public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => default; - /// - /// This method is called when new messages have been contributed to the chat by any participant. - /// - /// - /// Inheritors can use this method to update their context based on the new message. - /// - /// The new messages. - /// The to monitor for cancellation requests. The default is . - /// A task that completes when the context has been updated. - /// The thread has been deleted. - protected internal virtual Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) - => Task.CompletedTask; - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs index 353c82c996..a542a841a3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/DelegatingAIAgent.cs @@ -74,11 +74,11 @@ protected DelegatingAIAgent(AIAgent innerAgent) } /// - public override AgentThread GetNewThread() => this.InnerAgent.GetNewThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => this.InnerAgent.GetNewThread(featureCollection); /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => this.InnerAgent.DeserializeThread(serializedThread, jsonSerializerOptions); + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) + => this.InnerAgent.DeserializeThread(serializedThread, jsonSerializerOptions, featureCollection); /// public override Task RunAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs new file mode 100644 index 0000000000..475f54c77e --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +#pragma warning disable CA1043 // Use Integral Or String Argument For Indexers + +/// +/// Default implementation for . +/// +[DebuggerDisplay("Count = {GetCount()}")] +[DebuggerTypeProxy(typeof(FeatureCollectionDebugView))] +public class AgentFeatureCollection : IAgentFeatureCollection +{ + private static readonly KeyComparer s_featureKeyComparer = new(); + private readonly IAgentFeatureCollection? _defaults; + private readonly int _initialCapacity; + private Dictionary? _features; + private volatile int _containerRevision; + + /// + /// Initializes a new instance of . + /// + public AgentFeatureCollection() + { + } + + /// + /// Initializes a new instance of with the specified initial capacity. + /// + /// The initial number of elements that the collection can contain. + /// is less than 0 + public AgentFeatureCollection(int initialCapacity) + { + Throw.IfLessThan(initialCapacity, 0); + + this._initialCapacity = initialCapacity; + } + + /// + /// Initializes a new instance of with the specified defaults. + /// + /// The feature defaults. + public AgentFeatureCollection(IAgentFeatureCollection defaults) + { + this._defaults = defaults; + } + + /// + public virtual int Revision + { + get { return this._containerRevision + (this._defaults?.Revision ?? 0); } + } + + /// + public bool IsReadOnly { get { return false; } } + + /// + public object? this[Type key] + { + get + { + Throw.IfNull(key); + + return this._features != null && this._features.TryGetValue(key, out var result) ? result : this._defaults?[key]; + } + set + { + Throw.IfNull(key); + + if (value == null) + { + if (this._features?.Remove(key) is true) + { + this._containerRevision++; + } + return; + } + + if (this._features == null) + { + this._features = new Dictionary(this._initialCapacity); + } + this._features[key] = value; + this._containerRevision++; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + + /// + public IEnumerator> GetEnumerator() + { + if (this._features != null) + { + foreach (var pair in this._features) + { + yield return pair; + } + } + + if (this._defaults != null) + { + // Don't return features masked by the wrapper. + foreach (var pair in this._features == null ? this._defaults : this._defaults.Except(this._features, s_featureKeyComparer)) + { + yield return pair; + } + } + } + + /// + public TFeature? Get() + { + if (typeof(TFeature).IsValueType) + { + var feature = this[typeof(TFeature)]; + if (feature is null && Nullable.GetUnderlyingType(typeof(TFeature)) is null) + { + throw new InvalidOperationException( + $"{typeof(TFeature).FullName} does not exist in the feature collection " + + $"and because it is a struct the method can't return null. Use 'AgentFeatureCollection[typeof({typeof(TFeature).FullName})] is not null' to check if the feature exists."); + } + return (TFeature?)feature; + } + return (TFeature?)this[typeof(TFeature)]; + } + + /// + public void Set(TFeature? instance) + { + this[typeof(TFeature)] = instance; + } + + // Used by the debugger. Count over enumerable is required to get the correct value. + private int GetCount() => this.Count(); + + private sealed class KeyComparer : IEqualityComparer> + { + public bool Equals(KeyValuePair x, KeyValuePair y) + { + return x.Key.Equals(y.Key); + } + + public int GetHashCode(KeyValuePair obj) + { + return obj.Key.GetHashCode(); + } + } + + private sealed class FeatureCollectionDebugView(AgentFeatureCollection features) + { + private readonly AgentFeatureCollection _features = features; + + [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] + public DictionaryItemDebugView[] Items => this._features.Select(pair => new DictionaryItemDebugView(pair)).ToArray(); + } + + /// + /// Defines a key/value pair for displaying an item of a dictionary by a debugger. + /// + [DebuggerDisplay("{Value}", Name = "[{Key}]")] + internal readonly struct DictionaryItemDebugView + { + public DictionaryItemDebugView(TKey key, TValue value) + { + this.Key = key; + this.Value = value; + } + + public DictionaryItemDebugView(KeyValuePair keyValue) + { + this.Key = keyValue.Key; + this.Value = keyValue.Value; + } + + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public TKey Key { get; } + + [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)] + public TValue Value { get; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs new file mode 100644 index 0000000000..6a17456c73 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// An agent feature that allows providing a conversation identifier. +/// +/// +/// This feature allows a user to provide a specific identifier for chat history whether stored in the underlying AI service or stored in a 3rd party store. +/// +public class ConversationIdAgentFeature +{ + /// + /// Initializes a new instance of the class with the specified thread + /// identifier. + /// + /// The unique identifier of the thread required by the underlying AI service or 3rd party store. Cannot be or empty. + public ConversationIdAgentFeature(string conversationId) + { + this.ConversationId = Throw.IfNullOrWhitespace(conversationId); + } + + /// + /// Gets the conversation identifier. + /// + public string ConversationId { get; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs new file mode 100644 index 0000000000..f2e7f38f86 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; + +namespace Microsoft.Agents.AI; + +#pragma warning disable CA1043 // Use Integral Or String Argument For Indexers +#pragma warning disable CA1716 // Identifiers should not match keywords + +/// +/// Represents a collection of Agent features. +/// +public interface IAgentFeatureCollection : IEnumerable> +{ + /// + /// Indicates if the collection can be modified. + /// + bool IsReadOnly { get; } + + /// + /// Incremented for each modification and can be used to verify cached results. + /// + int Revision { get; } + + /// + /// Gets or sets a given feature. Setting a null value removes the feature. + /// + /// + /// The requested feature, or null if it is not present. + object? this[Type key] { get; set; } + + /// + /// Retrieves the requested feature from the collection. + /// + /// The feature key. + /// The requested feature, or null if it is not present. + TFeature? Get(); + + /// + /// Sets the given feature in the collection. + /// + /// The feature key. + /// The feature value. + void Set(TFeature? instance); +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentThread.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentThread.cs index af6080a715..13fcc134f0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentThread.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentThread.cs @@ -4,8 +4,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -116,10 +114,6 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio public override object? GetService(Type serviceType, object? serviceKey = null) => base.GetService(serviceType, serviceKey) ?? this.MessageStore?.GetService(serviceType, serviceKey); - /// - protected internal override Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) - => this.MessageStore.AddMessagesAsync(newMessages, cancellationToken); - [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => $"Count = {this.MessageStore.Count}"; diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 6ca2f38d3d..634236269a 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -42,8 +42,8 @@ public CopilotStudioAgent(CopilotClient client, ILoggerFactory? loggerFactory = } /// - public sealed override AgentThread GetNewThread() - => new CopilotStudioAgentThread(); + public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) + => new CopilotStudioAgentThread() { ConversationId = featureCollection?.Get()?.ConversationId }; /// /// Get a new instance using an existing conversation id, to continue that conversation. @@ -54,7 +54,7 @@ public AgentThread GetNewThread(string conversationId) => new CopilotStudioAgentThread() { ConversationId = conversationId }; /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new CopilotStudioAgentThread(serializedThread, jsonSerializerOptions); /// diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs index 1a117aff14..fc8ca78682 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgent.cs @@ -33,21 +33,17 @@ internal DurableAIAgent(TaskOrchestrationContext context, string agentName) /// Creates a new agent thread for this agent using a random session ID. /// /// A new agent thread. - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { AgentSessionId sessionId = this._context.NewAgentSessionId(this._agentName); return new DurableAgentThread(sessionId); } - /// - /// Deserializes an agent thread from JSON. - /// - /// The serialized thread data. - /// Optional JSON serializer options. - /// The deserialized agent thread. + /// public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) { return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs index 58f9598a7e..0078266896 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAIAgentProxy.cs @@ -13,12 +13,13 @@ internal class DurableAIAgentProxy(string name, IDurableAgentClient agentClient) public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) { return DurableAgentThread.Deserialize(serializedThread, jsonSerializerOptions); } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new DurableAgentThread(AgentSessionId.WithRandomKey(this.Name!)); } diff --git a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs index fd2a1950e9..0fb0ee13f3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Purview/PurviewAgent.cs @@ -30,15 +30,15 @@ public PurviewAgent(AIAgent innerAgent, PurviewWrapper purviewWrapper) } /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { - return this._innerAgent.DeserializeThread(serializedThread, jsonSerializerOptions); + return this._innerAgent.DeserializeThread(serializedThread, jsonSerializerOptions, featureCollection); } /// - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { - return this._innerAgent.GetNewThread(); + return this._innerAgent.GetNewThread(featureCollection); } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 98dc5903bf..b1767b6ea4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -61,9 +61,9 @@ private async ValueTask ValidateWorkflowAsync() protocol.ThrowIfNotChatProtocol(); } - public override AgentThread GetNewThread() => new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new WorkflowThread(this._workflow, this.GenerateNewId(), this._executionEnvironment, this._checkpointManager); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new WorkflowThread(this._workflow, serializedThread, this._executionEnvironment, this._checkpointManager, jsonSerializerOptions); private async ValueTask UpdateThreadAsync(IEnumerable messages, AgentThread? thread = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs index ffa044791f..d27de6bd5c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowThread.cs @@ -68,9 +68,6 @@ public WorkflowThread(Workflow workflow, JsonElement serializedThread, IWorkflow public CheckpointInfo? LastCheckpoint { get; set; } - protected override Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) - => this.MessageStore.AddMessagesAsync(newMessages, cancellationToken); - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { JsonMarshaller marshaller = new(jsonSerializerOptions); diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index d04d9bb9fb..207492f482 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -270,7 +270,7 @@ public override async IAsyncEnumerable RunStreamingAsync this.UpdateThreadWithTypeAndConversationId(safeThread, chatResponse.ConversationId); // To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request. - await NotifyThreadOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), cancellationToken).ConfigureAwait(false); + await NotifyMessageStoreOfNewMessagesAsync(safeThread, inputMessages.Concat(aiContextProviderMessages ?? []).Concat(chatResponse.Messages), options, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. await NotifyAIContextProviderOfSuccessAsync(safeThread, inputMessages, aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); @@ -286,10 +286,16 @@ public override async IAsyncEnumerable RunStreamingAsync : this.ChatClient.GetService(serviceType, serviceKey)); /// - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new ChatClientAgentThread { - AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) + ConversationId = featureCollection?.Get()?.ConversationId, + MessageStore = + featureCollection?.Get() + ?? this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }), + AIContextProvider = + featureCollection?.Get() + ?? this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }) }; /// @@ -316,16 +322,52 @@ public AgentThread GetNewThread(string conversationId) AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) }; + /// + /// Creates a new agent thread instance using an existing to continue a conversation. + /// + /// The chat history of the existing conversation to continue. + /// + /// A new instance configured to work with the provided . + /// + /// + /// + /// This method creates threads that do not support server-side conversation storage. + /// Some AI services require server-side conversation storage to function properly, and creating a thread + /// with a may not be compatible with these services. + /// + /// + /// Where a service requires server-side conversation storage, use . + /// + /// + /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage, + /// the thread will throw an exception to indicate that it cannot continue using the provided . + /// + /// + public AgentThread GetNewThread(ChatMessageStore chatMessageStore) + => new ChatClientAgentThread() + { + MessageStore = chatMessageStore, + AIContextProvider = this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) + }; + /// - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { - Func? chatMessageStoreFactory = this._agentOptions?.ChatMessageStoreFactory is null ? - null : - (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }); - - Func? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? - null : - (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }); + var chatMessageStoreFeature = featureCollection?.Get(); + Func? chatMessageStoreFactory = + chatMessageStoreFeature is not null + ? (jse, jso) => chatMessageStoreFeature + : this._agentOptions?.ChatMessageStoreFactory is not null + ? (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) + : null; + + var aiContextProviderFeature = featureCollection?.Get(); + Func? aiContextProviderFactory = + aiContextProviderFeature is not null + ? (jse, jso) => aiContextProviderFeature + : this._agentOptions?.AIContextProviderFactory is not null + ? (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) + : null; return new ChatClientAgentThread( serializedThread, @@ -384,7 +426,7 @@ private async Task RunCoreAsync() is ChatMessageStore chatMessageStoreFeature) + { + messageStore = chatMessageStoreFeature; + } + // Add any existing messages from the thread to the messages to be sent to the chat client. - if (typedThread.MessageStore is not null) + if (messageStore is not null) { - inputMessagesForChatClient.AddRange(await typedThread.MessageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); + inputMessagesForChatClient.AddRange(await messageStore.GetMessagesAsync(cancellationToken).ConfigureAwait(false)); } // If we have an AIContextProvider, we should get context from it, and update our @@ -684,8 +735,29 @@ private void UpdateThreadWithTypeAndConversationId(ChatClientAgentThread thread, // If the service doesn't use service side thread storage (i.e. we got no id back from invocation), and // the thread has no MessageStore yet, and we have a custom messages store, we should update the thread // with the custom MessageStore so that it has somewhere to store the chat history. - thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }); + thread.MessageStore ??= this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }) ?? new InMemoryChatMessageStore(); + } + } + + private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread thread, IEnumerable newMessages, AgentRunOptions? runOptions, CancellationToken cancellationToken) + { + var messageStore = thread.MessageStore; + + // If the caller provided an override message store via run options, we should use that instead of the message store + // on the thread. + if (runOptions?.Features?.Get() is ChatMessageStore chatMessageStoreFeature) + { + messageStore = chatMessageStoreFeature; } + + // Only notify the message store if we have one. + // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. + if (messageStore is not null) + { + return messageStore.AddMessagesAsync(newMessages, cancellationToken); + } + + return Task.CompletedTask; } private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent"; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index f83e6912d5..2c0041ba7d 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -128,6 +128,11 @@ public class AIContextProviderFactoryContext /// Gets or sets the JSON serialization options to use when deserializing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware. + /// + public IAgentFeatureCollection? Features { get; set; } } /// @@ -145,5 +150,10 @@ public class ChatMessageStoreFactoryContext /// Gets or sets the JSON serialization options to use when deserializing the . /// public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Gets or sets the collection of features provided by the caller and middleware. + /// + public IAgentFeatureCollection? Features { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs index f0f51895b2..d125fa8ea2 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentThread.cs @@ -1,12 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Diagnostics; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -182,33 +178,6 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio ?? this.AIContextProvider?.GetService(serviceType, serviceKey) ?? this.MessageStore?.GetService(serviceType, serviceKey); - /// - protected override async Task MessagesReceivedAsync(IEnumerable newMessages, CancellationToken cancellationToken = default) - { - switch (this) - { - case { ConversationId: not null }: - // If the thread messages are stored in the service - // there is nothing to do here, since invoking the - // service should already update the thread. - break; - - case { MessageStore: null }: - // If there is no conversation id, and no store we can createa a default in memory store and add messages to it. - this._messageStore = new InMemoryChatMessageStore(); - await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false); - break; - - case { MessageStore: not null }: - // If a store has been provided, we need to add the messages to the store. - await this._messageStore!.AddMessagesAsync(newMessages, cancellationToken).ConfigureAwait(false); - break; - - default: - throw new UnreachableException(); - } - } - [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => this._conversationId is { } conversationId ? $"ConversationId = {conversationId}" : diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs index 9399d99528..c163bbcc1d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentTests.cs @@ -73,6 +73,24 @@ public void Constructor_WithDefaultParameters_UsesBaseProperties() Assert.Equal(agent.Id, agent.DisplayName); } + [Fact] + public void GetNewThread_WithStringFeature_UsesItForContextId() + { + // Arrange + var contextIdFeature = new ConversationIdAgentFeature("feature-context-id"); + var agentWithFeature = new A2AAgent(this._a2aClient); + + // Act + var features = new AgentFeatureCollection(); + features.Set(contextIdFeature); + var thread = agentWithFeature.GetNewThread(features); + + // Assert + Assert.IsType(thread); + var a2aThread = (A2AAgentThread)thread; + Assert.Equal(contextIdFeature.ConversationId, a2aThread.ContextId); + } + [Fact] public async Task RunAsync_AllowsNonUserRoleMessagesAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs index bfa14a89d4..0f265466a3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIAgentTests.cs @@ -8,7 +8,6 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; -using Moq.Protected; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -222,21 +221,6 @@ public void ValidateAgentIDIsIdempotent() Assert.Equal(id, agent.Id); } - [Fact] - public async Task NotifyThreadOfNewMessagesNotifiesThreadAsync() - { - var cancellationToken = default(CancellationToken); - - var messages = new[] { new ChatMessage(ChatRole.User, "msg1"), new ChatMessage(ChatRole.User, "msg2") }; - - var threadMock = new Mock { CallBase = true }; - threadMock.SetupAllProperties(); - - await MockAgent.NotifyThreadOfNewMessagesAsync(threadMock.Object, messages, cancellationToken); - - threadMock.Protected().Verify("MessagesReceivedAsync", Times.Once(), messages, cancellationToken); - } - #region GetService Method Tests /// @@ -360,13 +344,10 @@ public abstract class TestAgentThread : AgentThread; private sealed class MockAgent : AIAgent { - public static new Task NotifyThreadOfNewMessagesAsync(AgentThread thread, IEnumerable messages, CancellationToken cancellationToken) => - AIAgent.NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken); - - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override Task RunAsync( diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs new file mode 100644 index 0000000000..7b55cfd64f --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains unit tests for the class. +/// +public class AgentFeatureCollectionTests +{ + [Fact] + public void AddedInterfaceIsReturned() + { + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + interfaces[typeof(IThing)] = thing; + + var thing2 = interfaces[typeof(IThing)]; + Assert.Equal(thing2, thing); + } + + [Fact] + public void IndexerAlsoAddsItems() + { + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + interfaces[typeof(IThing)] = thing; + + Assert.Equal(interfaces[typeof(IThing)], thing); + } + + [Fact] + public void SetNullValueRemoves() + { + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + + interfaces[typeof(IThing)] = thing; + Assert.Equal(interfaces[typeof(IThing)], thing); + + interfaces[typeof(IThing)] = null; + + var thing2 = interfaces[typeof(IThing)]; + Assert.Null(thing2); + } + + [Fact] + public void GetMissingStructFeatureThrows() + { + var interfaces = new AgentFeatureCollection(); + + var ex = Assert.Throws(() => interfaces.Get()); + Assert.Equal("System.Int32 does not exist in the feature collection and because it is a struct the method can't return null. Use 'AgentFeatureCollection[typeof(System.Int32)] is not null' to check if the feature exists.", ex.Message); + } + + [Fact] + public void GetMissingFeatureReturnsNull() + { + var interfaces = new AgentFeatureCollection(); + + Assert.Null(interfaces.Get()); + } + + [Fact] + public void GetStructFeature() + { + var interfaces = new AgentFeatureCollection(); + const int Value = 20; + interfaces.Set(Value); + + Assert.Equal(Value, interfaces.Get()); + } + + [Fact] + public void GetNullableStructFeatureWhenSetWithNonNullableStruct() + { + var interfaces = new AgentFeatureCollection(); + const int Value = 20; + interfaces.Set(Value); + + Assert.Null(interfaces.Get()); + } + + [Fact] + public void GetNullableStructFeatureWhenSetWithNullableStruct() + { + var interfaces = new AgentFeatureCollection(); + const int Value = 20; + interfaces.Set(Value); + + Assert.Equal(Value, interfaces.Get()); + } + + [Fact] + public void GetFeature() + { + var interfaces = new AgentFeatureCollection(); + var thing = new Thing(); + interfaces.Set(thing); + + Assert.Equal(thing, interfaces.Get()); + } + + private interface IThing + { + string Hello(); + } + + private sealed class Thing : IThing + { + public string Hello() + { + return "World"; + } + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs index 32560949fb..9a7f51e3aa 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRunOptionsTests.cs @@ -23,7 +23,8 @@ public void CloningConstructorCopiesProperties() { ["key1"] = "value1", ["key2"] = 42 - } + }, + Features = new AgentFeatureCollection() }; // Act @@ -37,6 +38,7 @@ public void CloningConstructorCopiesProperties() Assert.NotSame(options.AdditionalProperties, clone.AdditionalProperties); Assert.Equal("value1", clone.AdditionalProperties["key1"]); Assert.Equal(42, clone.AdditionalProperties["key2"]); + Assert.Same(options.Features, clone.Features); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentThreadTests.cs index 4d7c4ad219..e75cb4caa1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentThreadTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; -using Microsoft.Extensions.AI; #pragma warning disable CA1861 // Avoid constant arrays as arguments @@ -21,15 +19,6 @@ public void Serialize_ReturnsDefaultJsonElement() Assert.Equal(default, result); } - [Fact] - public void MessagesReceivedAsync_ReturnsCompletedTask() - { - var thread = new TestAgentThread(); - var messages = new List { new(ChatRole.User, "hello") }; - var result = thread.MessagesReceivedAsync(messages); - Assert.True(result.IsCompleted); - } - #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs index 4dca99a77c..b6a72110fb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/DelegatingAIAgentTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -34,7 +35,12 @@ public DelegatingAIAgentTests() this._innerAgentMock.Setup(x => x.Id).Returns("test-agent-id"); this._innerAgentMock.Setup(x => x.Name).Returns("Test Agent"); this._innerAgentMock.Setup(x => x.Description).Returns("Test Description"); - this._innerAgentMock.Setup(x => x.GetNewThread()).Returns(this._testThread); + this._innerAgentMock.Setup(x => x.GetNewThread(It.IsAny())).Returns(this._testThread); + this._innerAgentMock.Setup(x => x.DeserializeThread( + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns(this._testThread); this._innerAgentMock .Setup(x => x.RunAsync( @@ -135,11 +141,29 @@ public void Description_DelegatesToInnerAgent() public void GetNewThread_DelegatesToInnerAgent() { // Act - var thread = this._delegatingAgent.GetNewThread(); + var featureCollection = new AgentFeatureCollection(); + var thread = this._delegatingAgent.GetNewThread(featureCollection); // Assert Assert.Same(this._testThread, thread); - this._innerAgentMock.Verify(x => x.GetNewThread(), Times.Once); + this._innerAgentMock.Verify(x => x.GetNewThread(featureCollection), Times.Once); + } + + /// + /// Verify that DeserializeThread delegates to inner agent. + /// + [Fact] + public void DeserializeThread_DelegatesToInnerAgent() + { + // Act + var featureCollection = new AgentFeatureCollection(); + var jsonElement = new JsonElement(); + var jso = new JsonSerializerOptions(); + var thread = this._delegatingAgent.DeserializeThread(jsonElement, jso, featureCollection); + + // Assert + Assert.Same(this._testThread, thread); + this._innerAgentMock.Verify(x => x.DeserializeThread(jsonElement, jso, featureCollection), Times.Once); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index 923eaa7752..2e81dbef64 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -289,12 +289,12 @@ public FakeChatClientAgent() public override string? Description => this._description; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(); } - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } @@ -366,12 +366,12 @@ public FakeMultiMessageAgent() public override string? Description => this._description; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(); } - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index 47d9e63520..d7b5a0da89 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -417,9 +417,9 @@ stateObj is JsonElement state && await Task.CompletedTask; } - public override AgentThread GetNewThread() => new FakeInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new FakeInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index e5fb206147..126f1fdea2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -425,9 +425,9 @@ private sealed class MultiResponseAgent : AIAgent public override string? Description => "Agent that produces multiple text chunks"; - public override AgentThread GetNewThread() => new TestInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -514,9 +514,9 @@ private sealed class TestAgent : AIAgent public override string? Description => "Test agent"; - public override AgentThread GetNewThread() => new TestInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestInMemoryAgentThread(serializedThread, jsonSerializerOptions); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs index b0ad7ec0fe..9af6bb4a31 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AzureFunctions.UnitTests/TestAgent.cs @@ -11,11 +11,12 @@ internal sealed class TestAgent(string name, string description) : AIAgent public override string? Description => description; - public override AgentThread GetNewThread() => new DummyAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new DummyAgentThread(); public override AgentThread DeserializeThread( JsonElement serializedThread, - JsonSerializerOptions? jsonSerializerOptions = null) => new DummyAgentThread(); + JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) => new DummyAgentThread(); public override Task RunAsync( IEnumerable messages, diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs index f2b2bcfd6a..5af498f809 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/AgentExtensionsTests.cs @@ -324,10 +324,10 @@ public TestAgent(string? name, string? description, Exception exceptionToThrow) this._exceptionToThrow = exceptionToThrow; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override string? Name { get; } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 862b9ef3b4..1ed34dc426 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -426,10 +426,10 @@ public async Task RunAsyncSetsConversationIdOnThreadWhenReturnedByChatClientAsyn } /// - /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the default InMemoryChatMessageStore when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesDefaultInMemoryChatMessageStoreWhenNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -438,12 +438,9 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); - Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", - ChatMessageStoreFactory = mockFactory.Object }); // Act @@ -455,14 +452,13 @@ public async Task RunAsyncUsesChatMessageStoreWhenNoConversationIdReturnedByChat Assert.Equal(2, messageStore.Count); Assert.Equal("test", messageStore[0].Text); Assert.Equal("response", messageStore[1].Text); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// - /// Verify that RunAsync doesn't use the ChatMessageStore factory when the chat client returns a conversation id. + /// Verify that RunAsync uses the ChatMessageStore factory when the chat client returns no conversation id. /// [Fact] - public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByChatClientAsync() + public async Task RunAsyncUsesChatMessageStoreFactoryWhenProvidedAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -470,9 +466,13 @@ public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByCha s => s.GetResponseAsync( It.IsAny>(), It.IsAny(), - It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + Mock mockChatMessageStore = new(); + Mock> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + mockFactory.Setup(f => f(It.IsAny())).Returns(mockChatMessageStore.Object); + ChatClientAgent agent = new(mockService.Object, options: new() { Instructions = "test instructions", @@ -484,8 +484,70 @@ public async Task RunAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByCha await agent.RunAsync([new(ChatRole.User, "test")], thread); // Assert - Assert.Equal("ConvId", thread!.ConversationId); - mockFactory.Verify(f => f(It.IsAny()), Times.Never); + Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync uses the ChatMessageStore provided via run params when the chat client returns no conversation id. + /// + [Fact] + public async Task RunAsyncUsesChatMessageStoreWhenProvidedViaFeaturesAndNoConversationIdReturnedByChatClientAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + Mock mockChatMessageStore = new(); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + Instructions = "test instructions", + }); + + AgentFeatureCollection features = new(); + features.Set(mockChatMessageStore.Object); + + // Act + ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + await agent.RunAsync([new(ChatRole.User, "test")], thread, options: new AgentRunOptions() { Features = features }); + + // Assert + Assert.IsType(thread!.MessageStore, exactMatch: false); + mockChatMessageStore.Verify(s => s.GetMessagesAsync(It.IsAny()), Times.Once); + mockChatMessageStore.Verify(s => s.AddMessagesAsync(It.Is>(x => x.Count() == 2), It.IsAny()), Times.Once); + } + + /// + /// Verify that RunAsync throws when a ChatMessageStore Factory is provided but when the chat client returns a conversation id. + /// + [Fact] + public async Task RunAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversationIdReturnedByChatClientAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); + Mock> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).Returns(new InMemoryChatMessageStore()); + ChatClientAgent agent = new(mockService.Object, options: new() + { + Instructions = "test instructions", + ChatMessageStoreFactory = mockFactory.Object + }); + + // Act & Assert + ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; + var exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], thread)); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } /// @@ -1914,10 +1976,10 @@ public async Task RunStreamingAsyncUsesChatMessageStoreWhenNoConversationIdRetur } /// - /// Verify that RunStreamingAsync doesn't use the ChatMessageStore factory when the chat client returns a conversation id. + /// Verify that RunStreamingAsync throws when a ChatMessageStore factory is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunStreamingAsyncIgnoresChatMessageStoreWhenConversationIdReturnedByChatClientAsync() + public async Task RunStreamingAsyncThrowsWhenChatMessageStoreFactoryProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -1939,13 +2001,10 @@ public async Task RunStreamingAsyncIgnoresChatMessageStoreWhenConversationIdRetu ChatMessageStoreFactory = mockFactory.Object }); - // Act + // Act & Assert ChatClientAgentThread? thread = agent.GetNewThread() as ChatClientAgentThread; - await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync(); - - // Assert - Assert.Equal("ConvId", thread!.ConversationId); - mockFactory.Verify(f => f(It.IsAny()), Times.Never); + var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], thread).ToListAsync()); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); } /// @@ -2074,37 +2133,6 @@ await Assert.ThrowsAsync(async () => #endregion - #region GetNewThread Tests - - [Fact] - public void GetNewThreadUsesAIContextProviderFactoryIfProvided() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - Instructions = "Test instructions", - AIContextProviderFactory = _ => - { - factoryCalled = true; - return mockContextProvider.Object; - } - }); - - // Act - var thread = agent.GetNewThread(); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(thread); - var typedThread = (ChatClientAgentThread)thread; - Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); - } - - #endregion - #region Background Responses Tests [Theory] diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs index 8226e697ca..48caef1b3d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentThreadTests.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; -using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -91,50 +90,6 @@ public void SetChatMessageStoreThrowsWhenConversationIdIsSet() #endregion Constructor and Property Tests - #region OnNewMessagesAsync Tests - - [Fact] - public async Task OnNewMessagesAsyncDoesNothingWhenAgentServiceIdAsync() - { - // Arrange - var thread = new ChatClientAgentThread { ConversationId = "thread-123" }; - var messages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.Assistant, "Hi there!") - }; - var agent = new MessageSendingAgent(); - - // Act - await agent.SendMessagesAsync(thread, messages, CancellationToken.None); - Assert.Equal("thread-123", thread.ConversationId); - Assert.Null(thread.MessageStore); - } - - [Fact] - public async Task OnNewMessagesAsyncAddsMessagesToStoreAsync() - { - // Arrange - var store = new InMemoryChatMessageStore(); - var thread = new ChatClientAgentThread { MessageStore = store }; - var messages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.Assistant, "Hi there!") - }; - var agent = new MessageSendingAgent(); - - // Act - await agent.SendMessagesAsync(thread, messages, CancellationToken.None); - - // Assert - Assert.Equal(2, store.Count); - Assert.Equal("Hello", store[0].Text); - Assert.Equal("Hi there!", store[1].Text); - } - - #endregion OnNewMessagesAsync Tests - #region Deserialize Tests [Fact] @@ -372,22 +327,4 @@ public void GetService_RequestingChatMessageStore_ReturnsChatMessageStore() } #endregion - - private sealed class MessageSendingAgent : AIAgent - { - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) - => throw new NotImplementedException(); - - public override AgentThread GetNewThread() - => throw new NotImplementedException(); - - public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) - => throw new NotImplementedException(); - - public override IAsyncEnumerable RunStreamingAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) - => throw new NotImplementedException(); - - public Task SendMessagesAsync(AgentThread thread, IEnumerable messages, CancellationToken cancellationToken = default) - => NotifyThreadOfNewMessagesAsync(thread, messages, cancellationToken); - } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs new file mode 100644 index 0000000000..7c3abdcbe7 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +/// +/// Contains unit tests for the ChatClientAgent.DeserializeThread methods. +/// +public class ChatClientAgent_DeserializeThreadTests +{ + [Fact] + public void DeserializeThread_UsesAIContextProviderFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + factoryCalled = true; + return mockContextProvider.Object; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + "aiContextProviderState": ["CP1"] + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var thread = agent.DeserializeThread(json); + + // Assert + Assert.True(factoryCalled, "AIContextProviderFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void DeserializeThread_UsesChatMessageStoreFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + factoryCalled = true; + return mockMessageStore.Object; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + "storeState": { } + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var thread = agent.DeserializeThread(json); + + // Assert + Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void DeserializeThread_UsesChatMessageStore_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + Assert.Fail("ChatMessageStoreFactory should not have been called."); + return null!; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + var thread = agent.DeserializeThread(json, null, agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void DeserializeThread_UsesAIContextProvider_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + Assert.Fail("AIContextProviderFactory should not have been called."); + return null!; + } + }); + + var json = JsonSerializer.Deserialize(""" + { + } + """, TestJsonSerializerContext.Default.JsonElement); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockContextProvider.Object); + var thread = agent.DeserializeThread(json, null, agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs new file mode 100644 index 0000000000..10af9bd9a5 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_GetNewThreadTests.cs @@ -0,0 +1,176 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using Microsoft.Extensions.AI; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.ChatClient; + +/// +/// Contains unit tests for the ChatClientAgent.GetNewThread methods. +/// +public class ChatClientAgent_GetNewThreadTests +{ + [Fact] + public void GetNewThread_UsesAIContextProviderFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + AIContextProviderFactory = _ => + { + factoryCalled = true; + return mockContextProvider.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "AIContextProviderFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void GetNewThread_UsesChatMessageStoreFactory_IfProvided() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var factoryCalled = false; + var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions + { + Instructions = "Test instructions", + ChatMessageStoreFactory = _ => + { + factoryCalled = true; + return mockMessageStore.Object; + } + }); + + // Act + var thread = agent.GetNewThread(); + + // Assert + Assert.True(factoryCalled, "ChatMessageStoreFactory was not called."); + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesChatMessageStore_FromTypedOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var thread = agent.GetNewThread(mockMessageStore.Object); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesConversationId_FromTypedOverload() + { + // Arrange + var mockChatClient = new Mock(); + const string TestConversationId = "test_conversation_id"; + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var thread = agent.GetNewThread(TestConversationId); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Equal(TestConversationId, typedThread.ConversationId); + } + + [Fact] + public void GetNewThread_UsesConversationId_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var testConversationId = new ConversationIdAgentFeature("test_conversation_id"); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(testConversationId); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Equal(testConversationId.ConversationId, typedThread.ConversationId); + } + + [Fact] + public void GetNewThread_UsesChatMessageStore_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockMessageStore.Object, typedThread.MessageStore); + } + + [Fact] + public void GetNewThread_UsesAIContextProvider_FromFeatureOverload() + { + // Arrange + var mockChatClient = new Mock(); + var mockContextProvider = new Mock(); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockContextProvider.Object); + var thread = agent.GetNewThread(agentFeatures); + + // Assert + Assert.IsType(thread); + var typedThread = (ChatClientAgentThread)thread; + Assert.Same(mockContextProvider.Object, typedThread.AIContextProvider); + } + + [Fact] + public void GetNewThread_Throws_IfBothConversationIdAndMessageStoreAreSet() + { + // Arrange + var mockChatClient = new Mock(); + var mockMessageStore = new Mock(); + var testConversationId = new ConversationIdAgentFeature("test_conversation_id"); + var agent = new ChatClientAgent(mockChatClient.Object); + + // Act & Assert + var agentFeatures = new AgentFeatureCollection(); + agentFeatures.Set(mockMessageStore.Object); + agentFeatures.Set(testConversationId); + + var exception = Assert.Throws(() => agent.GetNewThread(agentFeatures)); + Assert.Equal("Only the ConversationId or MessageStore may be set, but not both and switching from one to another is not supported.", exception.Message); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs index fb00973c78..65689c9d05 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestAIAgent.cs @@ -24,10 +24,10 @@ internal sealed class TestAIAgent : AIAgent public override string? Description => this.DescriptionFunc?.Invoke() ?? base.Description; - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) => + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => this.DeserializeThreadFunc(serializedThread, jsonSerializerOptions); - public override AgentThread GetNewThread() => + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => this.GetNewThreadFunc(); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index 0437fc7695..8dd7b438ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -135,10 +135,10 @@ private class DoubleEchoAgent(string name) : AIAgent { public override string Name => name; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new DoubleEchoAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new DoubleEchoAgentThread(); public override Task RunAsync( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs index cffdb8c73c..1f65bf0688 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs @@ -146,10 +146,12 @@ public SimpleTestAgent(string name) public override string Name => this._name; - public override AgentThread GetNewThread() => new SimpleTestAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new SimpleTestAgentThread(); - public override AgentThread DeserializeThread(System.Text.Json.JsonElement serializedThread, - System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null) => new SimpleTestAgentThread(); + public override AgentThread DeserializeThread( + System.Text.Json.JsonElement serializedThread, + System.Text.Json.JsonSerializerOptions? jsonSerializerOptions = null, + IAgentFeatureCollection? featureCollection = null) => new SimpleTestAgentThread(); public override Task RunAsync( IEnumerable messages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs index 9cf460e658..768ddbda73 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RepresentationTests.cs @@ -24,10 +24,10 @@ private sealed class TestExecutor() : Executor("TestExecutor") private sealed class TestAgent : AIAgent { - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => throw new NotImplementedException(); public override Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index a0e57006ed..2633e4bdf7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -60,10 +60,10 @@ internal sealed class HelloAgent(string id = nameof(HelloAgent)) : AIAgent public override string Id => id; public override string? Name => id; - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new HelloAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new HelloAgentThread(); public override async Task RunAsync(IEnumerable messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs index b93d7862d5..977dbd4ad7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/SpecializedExecutorSmokeTests.cs @@ -51,10 +51,10 @@ static ChatMessage ToMessage(string text) return result; } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new TestAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) => new TestAgentThread(); public static TestAIAgent FromStrings(params string[] messages) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index 369f08bd8b..84bb071516 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -16,12 +16,12 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre public override string Id => id ?? base.Id; public override string? Name => name ?? base.Name; - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return JsonSerializer.Deserialize(serializedThread, jsonSerializerOptions) ?? this.GetNewThread(); } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { return new EchoAgentThread(); } From 9d86adfcb221c9bec80f3370a0524db8c1b542eb Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 24 Nov 2025 10:57:49 +0000 Subject: [PATCH 2/3] .NET: Update AgentFeatureCollections with feedback (#2379) * Update AgentFeatureCollections with feedback * Address feedback. * Fix issue with sample. * Change generic type restriction to notnull * Remove revision * Update dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add revision back again and improve some formatting. * Remove virtual from revision. * Add overloads taking type as param and add unit tests. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../Program.cs | 25 +-- .../src/Microsoft.Agents.AI.A2A/A2AAgent.cs | 7 +- .../Features/AgentFeatureCollection.cs | 155 ++++++++++-------- .../AgentFeatureCollectionExtensions.cs | 23 +++ .../Features/ConversationIdAgentFeature.cs | 4 +- .../Features/IAgentFeatureCollection.cs | 42 ++++- .../CopilotStudioAgent.cs | 7 +- .../ChatClient/ChatClientAgent.cs | 24 +-- .../AgentFeatureCollectionTests.cs | 145 +++++++++++----- 9 files changed, 291 insertions(+), 141 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs index cacc237f4d..98d3a27245 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyThreadStorage/Program.cs @@ -118,9 +118,8 @@ async Task CustomChatMessageStore_UsingFactoryAndExistingExternalId_Async() // It's possible to create a new thread that uses the same chat message store id by providing // the VectorChatMessageStoreThreadDbKeyFeature in the feature collection when creating the new thread. - AgentFeatureCollection features = new(); - features.Set(new VectorChatMessageStoreThreadDbKeyFeature(messageStoreFromFactory.ThreadDbKey!)); - AgentThread resumedThread = agent.GetNewThread(features); + AgentThread resumedThread = agent.GetNewThread( + new AgentFeatureCollection().WithFeature(new VectorChatMessageStoreThreadDbKeyFeature(messageStoreFromFactory.ThreadDbKey!))); // Run the agent with the thread that stores conversation history in the vector store. Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedThread)); @@ -149,9 +148,7 @@ async Task CustomChatMessageStore_PerThread_Async() // We can then pass the feature collection when creating a new thread. // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. VectorChatMessageStore perThreadMessageStore = new(vectorStore, "chat-history-1"); - AgentFeatureCollection features = new(); - features.Set(perThreadMessageStore); - AgentThread thread = agent.GetNewThread(features); + AgentThread thread = agent.GetNewThread(new AgentFeatureCollection().WithFeature(perThreadMessageStore)); Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread)); @@ -191,10 +188,13 @@ async Task CustomChatMessageStore_PerRun_Async() // If the agent doesn't support a message store, it would be ignored. // We also have the opportunity here to pass any id that we want for storing the chat history in the vector store. VectorChatMessageStore perRunMessageStore = new(vectorStore, "chat-history-1"); - AgentFeatureCollection features = new(); - features.Set(perRunMessageStore); - - Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", thread, options: new AgentRunOptions() { Features = features })); + Console.WriteLine(await agent.RunAsync( + "Tell me a joke about a pirate.", + thread, + options: new AgentRunOptions() + { + Features = new AgentFeatureCollection().WithFeature(perRunMessageStore) + })); // When serializing this thread, we'll see that it has no messagestore state, since the messagestore was not attached to the thread, // but just provided for the single run. Note that, depending on the circumstances, the thread may still contain other state, e.g. Memories, @@ -237,8 +237,9 @@ public VectorChatMessageStore(VectorStore vectorStore, JsonElement serializedSto // or finally we can generate one ourselves. this.ThreadDbKey = serializedStoreState.ValueKind is JsonValueKind.String ? serializedStoreState.Deserialize() - : features?.Get()?.ThreadDbKey - ?? Guid.NewGuid().ToString("N"); + : features?.TryGet(out var threadDbKeyFeature) is true + ? threadDbKeyFeature.ThreadDbKey + : Guid.NewGuid().ToString("N"); } public string? ThreadDbKey { get; } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index 15df3ea5ae..3d1dc4282e 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -55,7 +55,12 @@ public A2AAgent(A2AClient a2aClient, string? id = null, string? name = null, str /// public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) - => new A2AAgentThread() { ContextId = featureCollection?.Get()?.ConversationId }; + => new A2AAgentThread() + { + ContextId = featureCollection?.TryGet(out var conversationIdFeature) is true + ? conversationIdFeature.ConversationId + : null + }; /// /// Get a new instance using an existing context id, to continue that conversation. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs index 475f54c77e..df157f454c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollection.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Linq; using Microsoft.Shared.Diagnostics; @@ -18,9 +19,7 @@ namespace Microsoft.Agents.AI; [DebuggerTypeProxy(typeof(FeatureCollectionDebugView))] public class AgentFeatureCollection : IAgentFeatureCollection { - private static readonly KeyComparer s_featureKeyComparer = new(); - private readonly IAgentFeatureCollection? _defaults; - private readonly int _initialCapacity; + private readonly IAgentFeatureCollection? _innerCollection; private Dictionary? _features; private volatile int _containerRevision; @@ -39,124 +38,142 @@ public AgentFeatureCollection() public AgentFeatureCollection(int initialCapacity) { Throw.IfLessThan(initialCapacity, 0); - - this._initialCapacity = initialCapacity; + this._features = new(initialCapacity); } /// - /// Initializes a new instance of with the specified defaults. + /// Initializes a new instance of with the specified inner collection. /// - /// The feature defaults. - public AgentFeatureCollection(IAgentFeatureCollection defaults) + /// The inner collection. + /// + /// + /// When providing an inner collection, and if a feature is not found in this collection, + /// an attempt will be made to retrieve it from the inner collection as a fallback. + /// + /// + /// The method will only remove features from this collection + /// and not from the inner collection. When removing a feature from this collection, and + /// it exists in the inner collection, it will still be retrievable from the inner collection. + /// + /// + public AgentFeatureCollection(IAgentFeatureCollection innerCollection) { - this._defaults = defaults; + this._innerCollection = Throw.IfNull(innerCollection); } /// - public virtual int Revision + public int Revision { - get { return this._containerRevision + (this._defaults?.Revision ?? 0); } + get { return this._containerRevision + (this._innerCollection?.Revision ?? 0); } } /// public bool IsReadOnly { get { return false; } } + IEnumerator IEnumerable.GetEnumerator() + { + return this.GetEnumerator(); + } + /// - public object? this[Type key] + public IEnumerator> GetEnumerator() { - get + if (this._features is not { Count: > 0 }) { - Throw.IfNull(key); + IEnumerable> e = ((IEnumerable>?)this._innerCollection) ?? []; + return e.GetEnumerator(); + } - return this._features != null && this._features.TryGetValue(key, out var result) ? result : this._defaults?[key]; + if (this._innerCollection is null) + { + return this._features.GetEnumerator(); } - set + + if (this._innerCollection is AgentFeatureCollection innerCollection && innerCollection._features is not { Count: > 0 }) { - Throw.IfNull(key); + return this._features.GetEnumerator(); + } + + return YieldAll(); - if (value == null) + IEnumerator> YieldAll() + { + HashSet set = []; + + foreach (var entry in this._features) { - if (this._features?.Remove(key) is true) - { - this._containerRevision++; - } - return; + set.Add(entry.Key); + yield return entry; } - if (this._features == null) + foreach (var entry in this._innerCollection.Where(x => !set.Contains(x.Key))) { - this._features = new Dictionary(this._initialCapacity); + yield return entry; } - this._features[key] = value; - this._containerRevision++; } } - IEnumerator IEnumerable.GetEnumerator() - { - return this.GetEnumerator(); - } - /// - public IEnumerator> GetEnumerator() + public bool TryGet([MaybeNullWhen(false)] out TFeature feature) + where TFeature : notnull { - if (this._features != null) + if (this.TryGet(typeof(TFeature), out var obj)) { - foreach (var pair in this._features) - { - yield return pair; - } + feature = (TFeature)obj; + return true; } - if (this._defaults != null) - { - // Don't return features masked by the wrapper. - foreach (var pair in this._features == null ? this._defaults : this._defaults.Except(this._features, s_featureKeyComparer)) - { - yield return pair; - } - } + feature = default; + return false; } /// - public TFeature? Get() + public bool TryGet(Type type, [MaybeNullWhen(false)] out object feature) { - if (typeof(TFeature).IsValueType) + if (this._features?.TryGetValue(type, out var obj) is true) { - var feature = this[typeof(TFeature)]; - if (feature is null && Nullable.GetUnderlyingType(typeof(TFeature)) is null) - { - throw new InvalidOperationException( - $"{typeof(TFeature).FullName} does not exist in the feature collection " + - $"and because it is a struct the method can't return null. Use 'AgentFeatureCollection[typeof({typeof(TFeature).FullName})] is not null' to check if the feature exists."); - } - return (TFeature?)feature; + feature = obj; + return true; } - return (TFeature?)this[typeof(TFeature)]; + + if (this._innerCollection?.TryGet(type, out var defaultFeature) is true) + { + feature = defaultFeature; + return true; + } + + feature = default; + return false; } /// - public void Set(TFeature? instance) + public void Set(TFeature instance) + where TFeature : notnull { - this[typeof(TFeature)] = instance; + Throw.IfNull(instance); + + this._features ??= new(); + this._features[typeof(TFeature)] = instance; + this._containerRevision++; } - // Used by the debugger. Count over enumerable is required to get the correct value. - private int GetCount() => this.Count(); + /// + public void Remove() + where TFeature : notnull + => this.Remove(typeof(TFeature)); - private sealed class KeyComparer : IEqualityComparer> + /// + public void Remove(Type type) { - public bool Equals(KeyValuePair x, KeyValuePair y) + if (this._features?.Remove(type) is true) { - return x.Key.Equals(y.Key); - } - - public int GetHashCode(KeyValuePair obj) - { - return obj.Key.GetHashCode(); + this._containerRevision++; } } + // Used by the debugger. Count over enumerable is required to get the correct value. + private int GetCount() => this.Count(); + private sealed class FeatureCollectionDebugView(AgentFeatureCollection features) { private readonly AgentFeatureCollection _features = features; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs new file mode 100644 index 0000000000..95641858b7 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/AgentFeatureCollectionExtensions.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.Agents.AI; + +/// +/// Extension methods for . +/// +public static class AgentFeatureCollectionExtensions +{ + /// + /// Adds the specified feature to the collection and returns the collection. + /// + /// The feature key. + /// The feature collection to add the new feature to. + /// The feature to add to the collection. + /// The updated collection. + public static IAgentFeatureCollection WithFeature(this IAgentFeatureCollection features, TFeature feature) + where TFeature : notnull + { + features.Set(feature); + return features; + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs index 6a17456c73..2cd267197f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/ConversationIdAgentFeature.cs @@ -8,7 +8,7 @@ namespace Microsoft.Agents.AI; /// An agent feature that allows providing a conversation identifier. /// /// -/// This feature allows a user to provide a specific identifier for chat history whether stored in the underlying AI service or stored in a 3rd party store. +/// This feature allows a user to provide a specific identifier for chat history when stored in the underlying AI service. /// public class ConversationIdAgentFeature { @@ -16,7 +16,7 @@ public class ConversationIdAgentFeature /// Initializes a new instance of the class with the specified thread /// identifier. /// - /// The unique identifier of the thread required by the underlying AI service or 3rd party store. Cannot be or empty. + /// The unique identifier of the thread required by the underlying AI service. Cannot be or empty. public ConversationIdAgentFeature(string conversationId) { this.ConversationId = Throw.IfNullOrWhitespace(conversationId); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs index f2e7f38f86..dca17dc668 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/Features/IAgentFeatureCollection.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; namespace Microsoft.Agents.AI; @@ -24,23 +25,48 @@ public interface IAgentFeatureCollection : IEnumerable - /// Gets or sets a given feature. Setting a null value removes the feature. + /// Attempts to retrieve a feature of the specified type. /// - /// - /// The requested feature, or null if it is not present. - object? this[Type key] { get; set; } + /// The type of the feature to retrieve. + /// When this method returns, contains the feature of type if found; otherwise, the + /// default value for the type. + /// + /// if the feature of type was successfully retrieved; + /// otherwise, . + /// + bool TryGet([MaybeNullWhen(false)] out TFeature feature) + where TFeature : notnull; /// - /// Retrieves the requested feature from the collection. + /// Attempts to retrieve a feature of the specified type. + /// + /// The type of the feature to get. + /// When this method returns, contains the feature of type if found; otherwise, the + /// default value for the type. + /// + /// if the feature of type was successfully retrieved; + /// otherwise, . + /// + bool TryGet(Type type, [MaybeNullWhen(false)] out object feature); + + /// + /// Remove a feature from the collection. /// /// The feature key. - /// The requested feature, or null if it is not present. - TFeature? Get(); + void Remove() + where TFeature : notnull; + + /// + /// Remove a feature from the collection. + /// + /// The type of the feature to remove. + void Remove(Type type); /// /// Sets the given feature in the collection. /// /// The feature key. /// The feature value. - void Set(TFeature? instance); + void Set(TFeature instance) + where TFeature : notnull; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 634236269a..c689984537 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -43,7 +43,12 @@ public CopilotStudioAgent(CopilotClient client, ILoggerFactory? loggerFactory = /// public sealed override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) - => new CopilotStudioAgentThread() { ConversationId = featureCollection?.Get()?.ConversationId }; + => new CopilotStudioAgentThread() + { + ConversationId = featureCollection?.TryGet(out var conversationIdFeature) is true + ? conversationIdFeature.ConversationId + : null + }; /// /// Get a new instance using an existing conversation id, to continue that conversation. diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 207492f482..4cce2e221a 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -289,13 +289,17 @@ public override async IAsyncEnumerable RunStreamingAsync public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new ChatClientAgentThread { - ConversationId = featureCollection?.Get()?.ConversationId, + ConversationId = featureCollection?.TryGet(out var conversationIdAgentFeature) is true + ? conversationIdAgentFeature.ConversationId + : null, MessageStore = - featureCollection?.Get() - ?? this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }), + featureCollection?.TryGet(out var chatMessageStoreFeature) is true + ? chatMessageStoreFeature + : this._agentOptions?.ChatMessageStoreFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }), AIContextProvider = - featureCollection?.Get() - ?? this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }) + featureCollection?.TryGet(out var aIContextProviderFeature) is true + ? aIContextProviderFeature + : this._agentOptions?.AIContextProviderFactory?.Invoke(new() { SerializedState = default, Features = featureCollection, JsonSerializerOptions = null }) }; /// @@ -353,17 +357,15 @@ public AgentThread GetNewThread(ChatMessageStore chatMessageStore) /// public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { - var chatMessageStoreFeature = featureCollection?.Get(); Func? chatMessageStoreFactory = - chatMessageStoreFeature is not null + featureCollection?.TryGet(out var chatMessageStoreFeature) is true ? (jse, jso) => chatMessageStoreFeature : this._agentOptions?.ChatMessageStoreFactory is not null ? (jse, jso) => this._agentOptions.ChatMessageStoreFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) : null; - var aiContextProviderFeature = featureCollection?.Get(); Func? aiContextProviderFactory = - aiContextProviderFeature is not null + featureCollection?.TryGet(out var aiContextProviderFeature) is true ? (jse, jso) => aiContextProviderFeature : this._agentOptions?.AIContextProviderFactory is not null ? (jse, jso) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, Features = featureCollection, JsonSerializerOptions = jso }) @@ -644,7 +646,7 @@ await thread.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvider // If the caller provided an override message store via run options, we should use that instead of the message store // on the thread. - if (runOptions?.Features?.Get() is ChatMessageStore chatMessageStoreFeature) + if (runOptions?.Features?.TryGet(out var chatMessageStoreFeature) is true) { messageStore = chatMessageStoreFeature; } @@ -745,7 +747,7 @@ private static Task NotifyMessageStoreOfNewMessagesAsync(ChatClientAgentThread t // If the caller provided an override message store via run options, we should use that instead of the message store // on the thread. - if (runOptions?.Features?.Get() is ChatMessageStore chatMessageStoreFeature) + if (runOptions?.Features?.TryGet(out var chatMessageStoreFeature) is true) { messageStore = chatMessageStoreFeature; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs index 7b55cfd64f..9d3a3e7c66 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentFeatureCollectionTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Linq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -10,98 +11,168 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class AgentFeatureCollectionTests { [Fact] - public void AddedInterfaceIsReturned() + public void Feature_RoundTrips() { + // Arrange. var interfaces = new AgentFeatureCollection(); var thing = new Thing(); - interfaces[typeof(IThing)] = thing; + // Act. + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out var actualThing)); - var thing2 = interfaces[typeof(IThing)]; - Assert.Equal(thing2, thing); + // Assert. + Assert.Same(actualThing, thing); + Assert.Equal(1, interfaces.Revision); } [Fact] - public void IndexerAlsoAddsItems() + public void RemoveOfT_Removes() { + // Arrange. var interfaces = new AgentFeatureCollection(); var thing = new Thing(); - interfaces[typeof(IThing)] = thing; + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out _)); - Assert.Equal(interfaces[typeof(IThing)], thing); + // Act. + interfaces.Remove(); + + // Assert. + Assert.False(interfaces.TryGet(out _)); + Assert.Equal(2, interfaces.Revision); } [Fact] - public void SetNullValueRemoves() + public void Remove_Removes() { + // Arrange. var interfaces = new AgentFeatureCollection(); var thing = new Thing(); - interfaces[typeof(IThing)] = thing; - Assert.Equal(interfaces[typeof(IThing)], thing); + interfaces.Set(thing); + Assert.True(interfaces.TryGet(out _)); - interfaces[typeof(IThing)] = null; + // Act. + interfaces.Remove(typeof(IThing)); - var thing2 = interfaces[typeof(IThing)]; - Assert.Null(thing2); + // Assert. + Assert.False(interfaces.TryGet(out _)); + Assert.Equal(2, interfaces.Revision); } [Fact] - public void GetMissingStructFeatureThrows() + public void TryGetMissingFeature_ReturnsFalse() { + // Arrange. var interfaces = new AgentFeatureCollection(); - var ex = Assert.Throws(() => interfaces.Get()); - Assert.Equal("System.Int32 does not exist in the feature collection and because it is a struct the method can't return null. Use 'AgentFeatureCollection[typeof(System.Int32)] is not null' to check if the feature exists.", ex.Message); + // Act & Assert. + Assert.False(interfaces.TryGet(out var actualThing)); + Assert.Null(actualThing); } [Fact] - public void GetMissingFeatureReturnsNull() + public void Set_Null_Throws() { + // Arrange. var interfaces = new AgentFeatureCollection(); - Assert.Null(interfaces.Get()); + // Act & Assert. + Assert.Throws(() => interfaces.Set(null!)); } [Fact] - public void GetStructFeature() + public void IsReadOnly_DefaultsToFalse() { + // Arrange. var interfaces = new AgentFeatureCollection(); - const int Value = 20; - interfaces.Set(Value); - Assert.Equal(Value, interfaces.Get()); + // Act & Assert. + Assert.False(interfaces.IsReadOnly); } [Fact] - public void GetNullableStructFeatureWhenSetWithNonNullableStruct() + public void TryGetOfT_FallsBackToInnerCollection() { - var interfaces = new AgentFeatureCollection(); - const int Value = 20; - interfaces.Set(Value); + // Arrange. + var inner = new AgentFeatureCollection(); + var thing = new Thing(); + inner.Set(thing); + var outer = new AgentFeatureCollection(inner); - Assert.Null(interfaces.Get()); + // Act & Assert. + Assert.True(outer.TryGet(out var actualThing)); + Assert.Same(actualThing, thing); } [Fact] - public void GetNullableStructFeatureWhenSetWithNullableStruct() + public void TryGetOfT_OverridesInnerWithOuterCollection() { - var interfaces = new AgentFeatureCollection(); - const int Value = 20; - interfaces.Set(Value); - - Assert.Equal(Value, interfaces.Get()); + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act & Assert. + Assert.True(outer.TryGet(out var actualThing)); + Assert.Same(outerThing, actualThing); } [Fact] - public void GetFeature() + public void TryGet_FallsBackToInnerCollection() { - var interfaces = new AgentFeatureCollection(); + // Arrange. + var inner = new AgentFeatureCollection(); var thing = new Thing(); - interfaces.Set(thing); + inner.Set(thing); + var outer = new AgentFeatureCollection(inner); + + // Act & Assert. + Assert.True(outer.TryGet(typeof(IThing), out var actualThing)); + Assert.Same(actualThing, thing); + } + + [Fact] + public void TryGet_OverridesInnerWithOuterCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act & Assert. + Assert.True(outer.TryGet(typeof(IThing), out var actualThing)); + Assert.Same(outerThing, actualThing); + } + + [Fact] + public void Enumerate_OverridesInnerWithOuterCollection() + { + // Arrange. + var inner = new AgentFeatureCollection(); + var innerThing = new Thing(); + inner.Set(innerThing); + + var outer = new AgentFeatureCollection(inner); + var outerThing = new Thing(); + outer.Set(outerThing); + + // Act. + var items = outer.ToList(); - Assert.Equal(thing, interfaces.Get()); + // Assert. + Assert.Single(items); + Assert.Same(outerThing, items.First().Value as IThing); } private interface IThing From f958bf06baa4db1aa205363f9872b7ac2631f293 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Wed, 31 Dec 2025 12:26:36 +0000 Subject: [PATCH 3/3] Fix breaks after merge from main --- .../AggregatorPromptAgentFactoryTests.cs | 4 ++-- .../ForwardedPropertiesTests.cs | 4 ++-- .../ChatClient/ChatClientAgent_DeserializeThreadTests.cs | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs b/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs index 09ee72504a..7c763243bc 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Declarative.UnitTests/AggregatorPromptAgentFactoryTests.cs @@ -66,12 +66,12 @@ public TestAgentFactory(AIAgent? agentToReturn = null) private sealed class TestAgent : AIAgent { - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { throw new NotImplementedException(); } - public override AgentThread GetNewThread() + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) { throw new NotImplementedException(); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs index 1777ff456a..ae40948a5b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs @@ -334,9 +334,9 @@ protected override async IAsyncEnumerable RunCoreStreami await Task.CompletedTask; } - public override AgentThread GetNewThread() => new FakeInMemoryAgentThread(); + public override AgentThread GetNewThread(IAgentFeatureCollection? featureCollection = null) => new FakeInMemoryAgentThread(); - public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + public override AgentThread DeserializeThread(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null, IAgentFeatureCollection? featureCollection = null) { return new FakeInMemoryAgentThread(serializedThread, jsonSerializerOptions); } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs index dc1407154e..8edbc2cadc 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeThreadTests.cs @@ -85,7 +85,7 @@ public void DeserializeThread_UsesChatMessageStore_FromFeatureOverload() var mockMessageStore = new Mock(); var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { - Instructions = "Test instructions", + ChatOptions = new ChatOptions { Instructions = "Test instructions" }, ChatMessageStoreFactory = _ => { Assert.Fail("ChatMessageStoreFactory should not have been called."); @@ -117,7 +117,7 @@ public void DeserializeThread_UsesAIContextProvider_FromFeatureOverload() var mockContextProvider = new Mock(); var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { - Instructions = "Test instructions", + ChatOptions = new ChatOptions { Instructions = "Test instructions" }, AIContextProviderFactory = _ => { Assert.Fail("AIContextProviderFactory should not have been called.");