From e878be55a3663b7864bc0ef8b9526e2f0be2f88f Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Sun, 5 May 2024 07:51:00 -0700 Subject: [PATCH] [.Net] refactor over streaming version api (#2461) * update * update * fix comment --- .../CodeSnippet/AgentCodeSnippet.cs | 2 +- .../CodeSnippet/BuildInMessageCodeSnippet.cs | 4 +- .../CodeSnippet/MistralAICodeSnippet.cs | 4 +- .../CodeSnippet/OpenAICodeSnippet.cs | 4 +- .../CodeSnippet/SemanticKernelCodeSnippet.cs | 2 +- .../Example02_TwoAgent_MathChat.cs | 7 +- ...7_Dynamic_GroupChat_Calculate_Fibonacci.cs | 21 +--- .../Example10_SemanticKernel.cs | 2 +- .../Example13_OpenAIAgent_JsonMode.cs | 3 +- dotnet/sample/AutoGen.BasicSamples/Program.cs | 3 +- .../AutoGen.Core/Agent/IMiddlewareAgent.cs | 8 +- .../src/AutoGen.Core/Agent/IStreamingAgent.cs | 3 +- .../src/AutoGen.Core/Agent/MiddlewareAgent.cs | 26 ++-- .../Agent/MiddlewareStreamingAgent.cs | 65 +++++----- .../Extension/MiddlewareExtension.cs | 7 ++ .../PrintMessageMiddlewareExtension.cs | 2 +- .../Extension/StreamingMiddlewareExtension.cs | 85 +------------ .../Middleware/DelegateStreamingMiddleware.cs | 38 ------ .../Middleware/FunctionCallMiddleware.cs | 13 +- .../AutoGen.Core/Middleware/IMiddleware.cs | 2 +- .../Middleware/IStreamingMiddleware.cs | 12 +- .../Middleware/PrintMessageMiddleware.cs | 115 +++++++++++------- .../Extension/AgentExtension.cs | 6 +- .../Agent/MistralClientAgent.cs | 10 +- .../Extension/MistralAgentExtension.cs | 7 +- .../Middleware/MistralChatMessageConnector.cs | 9 +- dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs | 12 +- .../AutoGen.OpenAI/Agent/OpenAIChatAgent.cs | 10 +- .../OpenAIChatRequestMessageConnector.cs | 20 ++- ...manticKernelChatMessageContentConnector.cs | 12 +- .../SemanticKernelAgent.cs | 28 ++--- dotnet/src/AutoGen/Agent/ConversableAgent.cs | 36 ++++-- dotnet/src/AutoGen/AutoGen.csproj | 1 - .../MistralClientAgentTests.cs | 12 +- .../test/AutoGen.Tests/OpenAIChatAgentTest.cs | 13 +- .../AutoGen.Tests/RegisterReplyAgentTest.cs | 27 ---- .../AutoGen.Tests/SemanticKernelAgentTest.cs | 9 +- dotnet/test/AutoGen.Tests/SingleAgentTest.cs | 4 +- dotnet/website/update.md | 5 +- 39 files changed, 255 insertions(+), 394 deletions(-) delete mode 100644 dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs delete mode 100644 dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs index df45e4bfe9f..abaf94cbd4f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/AgentCodeSnippet.cs @@ -19,7 +19,7 @@ public async Task ChatWithAnAgent(IStreamingAgent agent) #region ChatWithAnAgent_GenerateStreamingReplyAsync var textMessage = new TextMessage(Role.User, "Hello"); - await foreach (var streamingReply in await agent.GenerateStreamingReplyAsync([message])) + await foreach (var streamingReply in agent.GenerateStreamingReplyAsync([message])) { if (streamingReply is TextMessageUpdate update) { diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs index b272ba23a03..f26485116c8 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/BuildInMessageCodeSnippet.cs @@ -11,7 +11,7 @@ public async Task StreamingCallCodeSnippetAsync() IStreamingAgent agent = default; #region StreamingCallCodeSnippet var helloTextMessage = new TextMessage(Role.User, "Hello"); - var reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]); + var reply = agent.GenerateStreamingReplyAsync([helloTextMessage]); var finalTextMessage = new TextMessage(Role.Assistant, string.Empty, from: agent.Name); await foreach (var message in reply) { @@ -24,7 +24,7 @@ await foreach (var message in reply) #endregion StreamingCallCodeSnippet #region StreamingCallWithFinalMessage - reply = await agent.GenerateStreamingReplyAsync([helloTextMessage]); + reply = agent.GenerateStreamingReplyAsync([helloTextMessage]); TextMessage finalMessage = null; await foreach (var message in reply) { diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs index 6bb9e910730..cd49810dc6c 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/MistralAICodeSnippet.cs @@ -38,7 +38,7 @@ public async Task CreateMistralAIClientAsync() #endregion create_mistral_agent #region streaming_chat - var reply = await agent.GenerateStreamingReplyAsync( + var reply = agent.GenerateStreamingReplyAsync( messages: [new TextMessage(Role.User, "Hello, how are you?")] ); @@ -75,7 +75,7 @@ public async Task MistralAIChatAgentGetWeatherToolUsageAsync() #endregion create_get_weather_function_call_middleware #region register_function_call_middleware - agent = agent.RegisterMiddleware(functionCallMiddleware); + agent = agent.RegisterStreamingMiddleware(functionCallMiddleware); #endregion register_function_call_middleware #region send_message_with_function_call diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs index 8d129e75157..022f7e9f984 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/OpenAICodeSnippet.cs @@ -60,7 +60,7 @@ public async Task CreateOpenAIChatAgentAsync() #endregion create_openai_chat_agent #region create_openai_chat_agent_streaming - var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -123,7 +123,7 @@ public async Task OpenAIChatAgentGetWeatherFunctionCallAsync() { functions.GetWeatherFunctionContract.Name, functions.GetWeatherWrapper } // GetWeatherWrapper is a wrapper function for GetWeather, which is also auto-generated }); - openAIChatAgent = openAIChatAgent.RegisterMiddleware(functionCallMiddleware); + openAIChatAgent = openAIChatAgent.RegisterStreamingMiddleware(functionCallMiddleware); #endregion create_function_call_middleware #region chat_agent_send_function_call diff --git a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs index 77f93fdf4aa..b0366eb2b3f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs +++ b/dotnet/sample/AutoGen.BasicSamples/CodeSnippet/SemanticKernelCodeSnippet.cs @@ -49,7 +49,7 @@ public async Task CreateSemanticKernelAgentAsync() #endregion create_semantic_kernel_agent #region create_semantic_kernel_agent_streaming - var streamingReply = await semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = semanticKernelAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { diff --git a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs index 8d42b9d0504..f20b0848a3e 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example02_TwoAgent_MathChat.cs @@ -18,16 +18,17 @@ public static async Task RunAsync() var teacher = new AssistantAgent( name: "teacher", systemMessage: @"You are a teacher that create pre-school math question for student and check answer. - If the answer is correct, you terminate conversation by saying [TERMINATE]. + If the answer is correct, you stop the conversation by saying [COMPLETE]. If the answer is wrong, you ask student to fix it.", llmConfig: new ConversableAgentConfig { Temperature = 0, ConfigList = [gpt35], }) - .RegisterPostProcess(async (_, reply, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { - if (reply.GetContent()?.ToLower().Contains("terminate") is true) + var reply = await agent.GenerateReplyAsync(msgs, option); + if (reply.GetContent()?.ToLower().Contains("complete") is true) { return new TextMessage(Role.Assistant, GroupChatExtension.TERMINATE, from: reply.From); } diff --git a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs index 6b1dc0965ee..89e6f45f898 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example07_Dynamic_GroupChat_Calculate_Fibonacci.cs @@ -85,26 +85,16 @@ public static async Task CreateRunnerAgentAsync(InteractiveService servi systemMessage: "You run dotnet code", defaultReply: "No code available.") .RegisterDotnetCodeBlockExectionHook(interactiveService: service) - .RegisterReply(async (msgs, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { - if (msgs.Count() == 0) + if (msgs.Count() == 0 || msgs.All(msg => msg.From != "coder")) { return new TextMessage(Role.Assistant, "No code available. Coder please write code"); } - - return null; - }) - .RegisterPreProcess(async (msgs, _) => - { - // retrieve the most recent message from coder - var coderMsg = msgs.LastOrDefault(msg => msg.From == "coder"); - if (coderMsg is null) - { - return Enumerable.Empty(); - } else { - return new[] { coderMsg }; + var coderMsg = msgs.Last(msg => msg.From == "coder"); + return await agent.GenerateReplyAsync([coderMsg], option); } }) .RegisterPrintMessage(); @@ -122,8 +112,9 @@ public static async Task CreateAdminAsync() systemMessage: "You are group admin, terminate the group chat once task is completed by saying [TERMINATE] plus the final answer", temperature: 0, config: gpt3Config) - .RegisterPostProcess(async (_, reply, _) => + .RegisterMiddleware(async (msgs, option, agent, _) => { + var reply = await agent.GenerateReplyAsync(msgs, option); if (reply is TextMessage textMessage && textMessage.Content.Contains("TERMINATE") is true) { var content = $"{textMessage.Content}\n\n {GroupChatExtension.TERMINATE}"; diff --git a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs index e4ef7de9df7..61c341204ec 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example10_SemanticKernel.cs @@ -62,7 +62,7 @@ public static async Task RunAsync() Console.WriteLine((reply as IMessage).Content.Items[0].As().Text); var skAgentWithMiddleware = skAgent - .RegisterMessageConnector() + .RegisterMessageConnector() // Register the message connector to support more AutoGen built-in message types .RegisterPrintMessage(); // Now the skAgentWithMiddleware supports more IMessage types like TextMessage, ImageMessage or MultiModalMessage diff --git a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs index 2591ab23016..35b7b7d1d2f 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Example13_OpenAIAgent_JsonMode.cs @@ -28,7 +28,8 @@ public static async Task RunAsync() systemMessage: "You are a helpful assistant designed to output JSON.", seed: 0, // explicitly set a seed to enable deterministic output responseFormat: ChatCompletionsResponseFormat.JsonObject) // set response format to JSON object to enable JSON mode - .RegisterMessageConnector(); + .RegisterMessageConnector() + .RegisterPrintMessage(); #endregion create_agent #region chat_with_agent diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index bddbb68bf48..fb0bacbb5a1 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Program.cs -using AutoGen.BasicSample; -await Example14_MistralClientAgent_TokenCount.RunAsync(); +await Example02_TwoAgent_MathChat.RunAsync(); diff --git a/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs index 7b318183d52..a0b01e7c3e2 100644 --- a/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IMiddlewareAgent.cs @@ -23,7 +23,7 @@ public interface IMiddlewareAgent : IAgent void Use(IMiddleware middleware); } -public interface IMiddlewareStreamAgent : IMiddlewareAgent, IStreamingAgent +public interface IMiddlewareStreamAgent : IStreamingAgent { /// /// Get the inner agent. @@ -44,7 +44,11 @@ public interface IMiddlewareAgent : IMiddlewareAgent T TAgent { get; } } -public interface IMiddlewareStreamAgent : IMiddlewareStreamAgent, IMiddlewareAgent +public interface IMiddlewareStreamAgent : IMiddlewareStreamAgent where T : IStreamingAgent { + /// + /// Get the typed inner agent. + /// + T TStreamingAgent { get; } } diff --git a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs index f4004b1397b..665f18bac12 100644 --- a/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/IStreamingAgent.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace AutoGen.Core; @@ -12,7 +11,7 @@ namespace AutoGen.Core; /// public interface IStreamingAgent : IAgent { - public Task> GenerateStreamingReplyAsync( + public IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs index 307e0da79ae..84d0d4b59e6 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareAgent.cs @@ -14,7 +14,7 @@ namespace AutoGen.Core; /// public class MiddlewareAgent : IMiddlewareAgent { - private readonly IAgent _agent; + private IAgent _agent; private readonly List middlewares = new(); /// @@ -22,10 +22,17 @@ public class MiddlewareAgent : IMiddlewareAgent /// /// the inner agent where middleware will be added. /// the name of the agent if provided. Otherwise, the name of will be used. - public MiddlewareAgent(IAgent innerAgent, string? name = null) + public MiddlewareAgent(IAgent innerAgent, string? name = null, IEnumerable? middlewares = null) { this.Name = name ?? innerAgent.Name; this._agent = innerAgent; + if (middlewares != null && middlewares.Any()) + { + foreach (var middleware in middlewares) + { + this.Use(middleware); + } + } } /// @@ -55,13 +62,7 @@ public MiddlewareAgent(MiddlewareAgent other) GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - IAgent agent = this._agent; - foreach (var middleware in this.middlewares) - { - agent = new DelegateAgent(middleware, agent); - } - - return agent.GenerateReplyAsync(messages, options, cancellationToken); + return _agent.GenerateReplyAsync(messages, options, cancellationToken); } /// @@ -71,15 +72,18 @@ public MiddlewareAgent(MiddlewareAgent other) /// public void Use(Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, string? middlewareName = null) { - this.middlewares.Add(new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => + var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => { return await func(context.Messages, context.Options, agent, cancellationToken); - })); + }); + + this.Use(middleware); } public void Use(IMiddleware middleware) { this.middlewares.Add(middleware); + _agent = new DelegateAgent(middleware, _agent); } public override string ToString() diff --git a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs index b83922227b7..251d3c110f9 100644 --- a/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs +++ b/dotnet/src/AutoGen.Core/Agent/MiddlewareStreamingAgent.cs @@ -2,33 +2,31 @@ // MiddlewareStreamingAgent.cs using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; namespace AutoGen.Core; -public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent +public class MiddlewareStreamingAgent : IMiddlewareStreamAgent { - private readonly IStreamingAgent _agent; + private IStreamingAgent _agent; private readonly List _streamingMiddlewares = new(); - private readonly List _middlewares = new(); public MiddlewareStreamingAgent( IStreamingAgent agent, string? name = null, - IEnumerable? streamingMiddlewares = null, - IEnumerable? middlewares = null) - : base(agent, name) + IEnumerable? streamingMiddlewares = null) { + this.Name = name ?? agent.Name; _agent = agent; - if (streamingMiddlewares != null) - { - _streamingMiddlewares.AddRange(streamingMiddlewares); - } - if (middlewares != null) + if (streamingMiddlewares != null && streamingMiddlewares.Any()) { - _middlewares.AddRange(middlewares); + foreach (var middleware in streamingMiddlewares) + { + this.UseStreaming(middleware); + } } } @@ -42,26 +40,28 @@ public class MiddlewareStreamingAgent : MiddlewareAgent, IMiddlewareStreamAgent /// public IEnumerable StreamingMiddlewares => _streamingMiddlewares; - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public string Name { get; } + + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + return _agent.GenerateReplyAsync(messages, options, cancellationToken); + } + + public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var agent = _agent; - foreach (var middleware in _streamingMiddlewares) - { - agent = new DelegateStreamingAgent(middleware, agent); - } - return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + return _agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } public void UseStreaming(IStreamingMiddleware middleware) { _streamingMiddlewares.Add(middleware); + _agent = new DelegateStreamingAgent(middleware, _agent); } private class DelegateStreamingAgent : IStreamingAgent { private IStreamingMiddleware? streamingMiddleware; - private IMiddleware? middleware; private IStreamingAgent innerAgent; public string Name => innerAgent.Name; @@ -72,24 +72,19 @@ public DelegateStreamingAgent(IStreamingMiddleware middleware, IStreamingAgent n this.innerAgent = next; } - public DelegateStreamingAgent(IMiddleware middleware, IStreamingAgent next) - { - this.middleware = middleware; - this.innerAgent = next; - } - public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - if (middleware is null) + if (this.streamingMiddleware is null) { - return await innerAgent.GenerateReplyAsync(messages, options, cancellationToken); + return innerAgent.GenerateReplyAsync(messages, options, cancellationToken); } var context = new MiddlewareContext(messages, options); - return await middleware.InvokeAsync(context, innerAgent, cancellationToken); + return this.streamingMiddleware.InvokeAsync(context, (IAgent)innerAgent, cancellationToken); } - public Task> GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + public IAsyncEnumerable GenerateStreamingReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { if (streamingMiddleware is null) { @@ -105,20 +100,20 @@ public Task> GenerateStreamingReplyAsync(IEn public sealed class MiddlewareStreamingAgent : MiddlewareStreamingAgent, IMiddlewareStreamAgent where T : IStreamingAgent { - public MiddlewareStreamingAgent(T innerAgent, string? name = null) - : base(innerAgent, name) + public MiddlewareStreamingAgent(T innerAgent, string? name = null, IEnumerable? streamingMiddlewares = null) + : base(innerAgent, name, streamingMiddlewares) { - TAgent = innerAgent; + TStreamingAgent = innerAgent; } public MiddlewareStreamingAgent(MiddlewareStreamingAgent other) : base(other) { - TAgent = other.TAgent; + TStreamingAgent = other.TStreamingAgent; } /// /// Get the inner agent. /// - public T TAgent { get; } + public T TStreamingAgent { get; } } diff --git a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs index c522c78f506..5beed7fd815 100644 --- a/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/MiddlewareExtension.cs @@ -20,6 +20,7 @@ public static class MiddlewareExtension /// /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterReply( this TAgent agent, Func, CancellationToken, Task> replyFunc) @@ -45,6 +46,7 @@ public static class MiddlewareExtension /// One example is , which print the formatted message to console before the agent return the reply. /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterPostProcess( this TAgent agent, Func, IMessage, CancellationToken, Task> postprocessFunc) @@ -62,6 +64,7 @@ public static class MiddlewareExtension /// Register a pre process hook to an agent. The hook will be called before the agent generate the reply. This is useful when you want to modify the conversation history before the agent generate the reply. /// /// throw when agent name is null. + [Obsolete("Use RegisterMiddleware instead.")] public static MiddlewareAgent RegisterPreProcess( this TAgent agent, Func, CancellationToken, Task>> preprocessFunc) @@ -77,6 +80,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this TAgent agent, @@ -94,6 +98,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this TAgent agent, @@ -107,6 +112,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this MiddlewareAgent agent, @@ -124,6 +130,7 @@ public static class MiddlewareExtension /// /// Register a middleware to an existing agent and return a new agent with the middleware. + /// To register a streaming middleware, use . /// public static MiddlewareAgent RegisterMiddleware( this MiddlewareAgent agent, diff --git a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs index deb196ca324..262b50d125d 100644 --- a/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/PrintMessageMiddlewareExtension.cs @@ -62,7 +62,7 @@ public static MiddlewareStreamingAgent RegisterPrintMessage(this { var middleware = new PrintMessageMiddleware(); var middlewareAgent = new MiddlewareStreamingAgent(agent); - middlewareAgent.Use(middleware); + middlewareAgent.UseStreaming(middleware); return middlewareAgent; } diff --git a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs index 901d7f2492a..2ec7b3f9f3b 100644 --- a/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs +++ b/dotnet/src/AutoGen.Core/Extension/StreamingMiddlewareExtension.cs @@ -1,17 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // StreamingMiddlewareExtension.cs -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - namespace AutoGen.Core; public static class StreamingMiddlewareExtension { /// - /// Register a middleware to an existing agent and return a new agent with the middleware. + /// Register an to an existing and return a new agent with the registered middleware. + /// For registering an , please refer to /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this TStreamingAgent agent, @@ -21,16 +17,12 @@ public static class StreamingMiddlewareExtension var middlewareAgent = new MiddlewareStreamingAgent(agent); middlewareAgent.UseStreaming(middleware); - if (middleware is IMiddleware middlewareBase) - { - middlewareAgent.Use(middlewareBase); - } - return middlewareAgent; } /// - /// Register a middleware to an existing agent and return a new agent with the middleware. + /// Register an to an existing and return a new agent with the registered middleware. + /// For registering an , please refer to /// public static MiddlewareStreamingAgent RegisterStreamingMiddleware( this MiddlewareStreamingAgent agent, @@ -40,75 +32,6 @@ public static class StreamingMiddlewareExtension var copyAgent = new MiddlewareStreamingAgent(agent); copyAgent.UseStreaming(middleware); - if (middleware is IMiddleware middlewareBase) - { - copyAgent.Use(middlewareBase); - } - - return copyAgent; - } - - - /// - /// Register a middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this TAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func)); - - return agent.RegisterStreamingMiddleware(middleware); - } - - /// - /// Register a streaming middleware to an existing agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterStreamingMiddleware( - this MiddlewareStreamingAgent agent, - Func>> func, - string? middlewareName = null) - where TAgent : IStreamingAgent - { - var middleware = new DelegateStreamingMiddleware(middlewareName, new DelegateStreamingMiddleware.MiddlewareDelegate(func)); - - return agent.RegisterStreamingMiddleware(middleware); - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - Func, GenerateReplyOptions?, IAgent, CancellationToken, Task> func, - string? middlewareName = null) - where TStreamingAgent : IStreamingAgent - { - var middleware = new DelegateMiddleware(middlewareName, async (context, agent, cancellationToken) => - { - return await func(context.Messages, context.Options, agent, cancellationToken); - }); - - return streamingAgent.RegisterMiddleware(middleware); - } - - /// - /// Register a middleware to an existing streaming agent and return a new agent with the middleware. - /// - public static MiddlewareStreamingAgent RegisterMiddleware( - this MiddlewareStreamingAgent streamingAgent, - IMiddleware middleware) - where TStreamingAgent : IStreamingAgent - { - var copyAgent = new MiddlewareStreamingAgent(streamingAgent); - copyAgent.Use(middleware); - if (middleware is IStreamingMiddleware streamingMiddleware) - { - copyAgent.UseStreaming(streamingMiddleware); - } - return copyAgent; } } diff --git a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs deleted file mode 100644 index 5499abccf4c..00000000000 --- a/dotnet/src/AutoGen.Core/Middleware/DelegateStreamingMiddleware.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// DelegateStreamingMiddleware.cs - -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; - -namespace AutoGen.Core; - -internal class DelegateStreamingMiddleware : IStreamingMiddleware -{ - public delegate Task> MiddlewareDelegate( - MiddlewareContext context, - IStreamingAgent agent, - CancellationToken cancellationToken); - - private readonly MiddlewareDelegate middlewareDelegate; - - public DelegateStreamingMiddleware(string? name, MiddlewareDelegate middlewareDelegate) - { - this.Name = name; - this.middlewareDelegate = middlewareDelegate; - } - - public string? Name { get; } - - public Task> InvokeAsync( - MiddlewareContext context, - IStreamingAgent agent, - CancellationToken cancellationToken = default) - { - var messages = context.Messages; - var options = context.Options; - - return this.middlewareDelegate(context, agent, cancellationToken); - } -} - diff --git a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs index d00151b32a8..2bc02805538 100644 --- a/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/FunctionCallMiddleware.cs @@ -29,7 +29,7 @@ namespace AutoGen.Core; /// If the streaming reply from the inner agent is other types of message, the most recent message will be used to invoke the function. /// /// -public class FunctionCallMiddleware : IMiddleware, IStreamingMiddleware +public class FunctionCallMiddleware : IStreamingMiddleware { private readonly IEnumerable? functions; private readonly IDictionary>>? functionMap; @@ -71,15 +71,10 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return reply; } - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(this.StreamingInvokeAsync(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable StreamingInvokeAsync( + public async IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var lastMessage = context.Messages.Last(); if (lastMessage is ToolCallMessage toolCallMessage) @@ -93,7 +88,7 @@ public Task> InvokeAsync(MiddlewareContext c options.Functions = combinedFunctions?.ToArray(); IStreamingMessage? initMessage = default; - await foreach (var message in await agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) + await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, options, cancellationToken)) { if (message is ToolCallMessageUpdate toolCallMessageUpdate && this.functionMap != null) { diff --git a/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs index 2813ee9cdb4..00ec5a97fc2 100644 --- a/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IMiddleware.cs @@ -7,7 +7,7 @@ namespace AutoGen.Core; /// -/// The middleware interface +/// The middleware interface. For streaming-version middleware, check . /// public interface IMiddleware { diff --git a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs index b8965dcc41c..bc7aec57f52 100644 --- a/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/IStreamingMiddleware.cs @@ -3,18 +3,18 @@ using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace AutoGen.Core; /// -/// The streaming middleware interface +/// The streaming middleware interface. For non-streaming version middleware, check . /// -public interface IStreamingMiddleware +public interface IStreamingMiddleware : IMiddleware { - public string? Name { get; } - - public Task> InvokeAsync( + /// + /// The streaming version of . + /// + public IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default); diff --git a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs index 9461b697357..099f78e5f17 100644 --- a/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs +++ b/dotnet/src/AutoGen.Core/Middleware/PrintMessageMiddleware.cs @@ -2,6 +2,8 @@ // PrintMessageMiddleware.cs using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -10,7 +12,7 @@ namespace AutoGen.Core; /// /// The middleware that prints the reply from agent to the console. /// -public class PrintMessageMiddleware : IMiddleware +public class PrintMessageMiddleware : IStreamingMiddleware { public string? Name => nameof(PrintMessageMiddleware); @@ -19,51 +21,12 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, if (agent is IStreamingAgent streamingAgent) { IMessage? recentUpdate = null; - await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + await foreach (var message in this.InvokeAsync(context, streamingAgent, cancellationToken)) { - if (message is TextMessageUpdate textMessageUpdate) - { - if (recentUpdate is null) - { - // Print from: xxx - Console.WriteLine($"from: {textMessageUpdate.From}"); - recentUpdate = new TextMessage(textMessageUpdate); - Console.Write(textMessageUpdate.Content); - } - else if (recentUpdate is TextMessage recentTextMessage) - { - // Print the content of the message - Console.Write(textMessageUpdate.Content); - recentTextMessage.Update(textMessageUpdate); - } - else - { - throw new InvalidOperationException("The recent update is not a TextMessage"); - } - } - else if (message is ToolCallMessageUpdate toolCallUpdate) - { - if (recentUpdate is null) - { - recentUpdate = new ToolCallMessage(toolCallUpdate); - } - else if (recentUpdate is ToolCallMessage recentToolCallMessage) - { - recentToolCallMessage.Update(toolCallUpdate); - } - else - { - throw new InvalidOperationException("The recent update is not a ToolCallMessage"); - } - } - else if (message is IMessage imessage) + if (message is IMessage imessage) { recentUpdate = imessage; } - else - { - throw new InvalidOperationException("The message is not a valid message"); - } } Console.WriteLine(); if (recentUpdate is not null && recentUpdate is not TextMessage) @@ -84,4 +47,72 @@ await foreach (var message in await streamingAgent.GenerateStreamingReplyAsync(c return reply; } } + + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + IMessage? recentUpdate = null; + await foreach (var message in agent.GenerateStreamingReplyAsync(context.Messages, context.Options, cancellationToken)) + { + if (message is TextMessageUpdate textMessageUpdate) + { + if (recentUpdate is null) + { + // Print from: xxx + Console.WriteLine($"from: {textMessageUpdate.From}"); + recentUpdate = new TextMessage(textMessageUpdate); + Console.Write(textMessageUpdate.Content); + + yield return message; + } + else if (recentUpdate is TextMessage recentTextMessage) + { + // Print the content of the message + Console.Write(textMessageUpdate.Content); + recentTextMessage.Update(textMessageUpdate); + + yield return recentTextMessage; + } + else + { + throw new InvalidOperationException("The recent update is not a TextMessage"); + } + } + else if (message is ToolCallMessageUpdate toolCallUpdate) + { + if (recentUpdate is null) + { + recentUpdate = new ToolCallMessage(toolCallUpdate); + + yield return message; + } + else if (recentUpdate is ToolCallMessage recentToolCallMessage) + { + recentToolCallMessage.Update(toolCallUpdate); + + yield return message; + } + else + { + throw new InvalidOperationException("The recent update is not a ToolCallMessage"); + } + } + else if (message is IMessage imessage) + { + recentUpdate = imessage; + + yield return imessage; + } + else + { + throw new InvalidOperationException("The message is not a valid message"); + } + } + Console.WriteLine(); + if (recentUpdate is not null && recentUpdate is not TextMessage) + { + Console.WriteLine(recentUpdate.FormatMessage()); + } + + yield return recentUpdate ?? throw new InvalidOperationException("The message is not a valid message"); + } } diff --git a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs index 034ca170e3d..83955c53fa1 100644 --- a/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs +++ b/dotnet/src/AutoGen.DotnetInteractive/Extension/AgentExtension.cs @@ -28,19 +28,19 @@ public static class AgentExtension string codeBlockSuffix = "```", int maximumOutputToKeep = 500) { - return agent.RegisterReply(async (msgs, ct) => + return agent.RegisterMiddleware(async (msgs, option, innerAgent, ct) => { var lastMessage = msgs.LastOrDefault(); if (lastMessage == null || lastMessage.GetContent() is null) { - return null; + return await innerAgent.GenerateReplyAsync(msgs, option, ct); } // retrieve all code blocks from last message var codeBlocks = lastMessage.GetContent()!.Split(new[] { codeBlockPrefix }, StringSplitOptions.RemoveEmptyEntries); if (codeBlocks.Length <= 0) { - return null; + return await innerAgent.GenerateReplyAsync(msgs, option, ct); } // run code blocks diff --git a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs index 2ba28bbb701..cc2c7414550 100644 --- a/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs +++ b/dotnet/src/AutoGen.Mistral/Agent/MistralClientAgent.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using AutoGen.Core; @@ -77,19 +78,14 @@ public class MistralClientAgent : IStreamingAgent return new MessageEnvelope(response, from: this.Name); } - public async Task> GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var request = BuildChatRequest(messages, options); var response = _client.StreamingChatCompletionsAsync(request); - return ProcessMessage(response); - } - - private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) - { await foreach (var content in response) { yield return new MessageEnvelope(content, from: this.Name); diff --git a/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs b/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs index 5b3c998b6c0..787393d067f 100644 --- a/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs +++ b/dotnet/src/AutoGen.Mistral/Extension/MistralAgentExtension.cs @@ -18,9 +18,7 @@ public static class MistralAgentExtension connector = new MistralChatMessageConnector(); } - return agent.RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); - + return agent.RegisterStreamingMiddleware(connector); } /// @@ -34,7 +32,6 @@ public static class MistralAgentExtension connector = new MistralChatMessageConnector(); } - return agent.RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); + return agent.RegisterStreamingMiddleware(connector); } } diff --git a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs index 44f34401e1c..3ba910aa700 100644 --- a/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs +++ b/dotnet/src/AutoGen.Mistral/Middleware/MistralChatMessageConnector.cs @@ -15,17 +15,12 @@ public class MistralChatMessageConnector : IStreamingMiddleware, IMiddleware { public string? Name => nameof(MistralChatMessageConnector); - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(StreamingInvoke(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable StreamingInvoke(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var messages = context.Messages; var chatMessages = ProcessMessage(messages, agent); var chunks = new List(); - await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) + await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken)) { if (reply is IStreamingMessage chatMessage) { diff --git a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs index cb5a97c1310..52070788e34 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/GPTAgent.cs @@ -90,30 +90,28 @@ public class GPTAgent : IStreamingAgent GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var agent = this._innerAgent - .RegisterMessageConnector(); + var agent = this._innerAgent.RegisterMessageConnector(); if (this.functionMap is not null) { var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); - agent = agent.RegisterMiddleware(functionMapMiddleware); + agent = agent.RegisterStreamingMiddleware(functionMapMiddleware); } return await agent.GenerateReplyAsync(messages, options, cancellationToken); } - public async Task> GenerateStreamingReplyAsync( + public IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) { - var agent = this._innerAgent - .RegisterMessageConnector(); + var agent = this._innerAgent.RegisterMessageConnector(); if (this.functionMap is not null) { var functionMapMiddleware = new FunctionCallMiddleware(functionMap: this.functionMap); agent = agent.RegisterStreamingMiddleware(functionMapMiddleware); } - return await agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); + return agent.GenerateStreamingReplyAsync(messages, options, cancellationToken); } } diff --git a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs index 487a361d7de..37a4882f69e 100644 --- a/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs +++ b/dotnet/src/AutoGen.OpenAI/Agent/OpenAIChatAgent.cs @@ -87,15 +87,7 @@ public class OpenAIChatAgent : IStreamingAgent return new MessageEnvelope(reply, from: this.Name); } - public Task> GenerateStreamingReplyAsync( - IEnumerable messages, - GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) - { - return Task.FromResult(this.StreamingReplyAsync(messages, options, cancellationToken)); - } - - private async IAsyncEnumerable StreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) diff --git a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs index 118d99703ab..2bd9470ffa7 100644 --- a/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs +++ b/dotnet/src/AutoGen.OpenAI/Middleware/OpenAIChatRequestMessageConnector.cs @@ -44,22 +44,14 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return PostProcessMessage(reply); } - public async Task> InvokeAsync( + public async IAsyncEnumerable InvokeAsync( MiddlewareContext context, IStreamingAgent agent, - CancellationToken cancellationToken = default) - { - return InvokeStreamingAsync(context, agent, cancellationToken); - } - - private async IAsyncEnumerable InvokeStreamingAsync( - MiddlewareContext context, - IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatMessages = ProcessIncomingMessages(agent, context.Messages) .Select(m => new MessageEnvelope(m)); - var streamingReply = await agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); + var streamingReply = agent.GenerateStreamingReplyAsync(chatMessages, context.Options, cancellationToken); string? currentToolName = null; await foreach (var reply in streamingReply) { @@ -135,6 +127,12 @@ private IMessage PostProcessMessage(IMessage message) private IMessage PostProcessMessage(IMessage message) { + // throw exception if prompt filter results is not null + if (message.Content.Choices[0].FinishReason == CompletionsFinishReason.ContentFiltered) + { + throw new InvalidOperationException("The content is filtered because its potential risk. Please try another input."); + } + return PostProcessMessage(message.Content.Choices[0].Message, message.From); } diff --git a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs index 557683c9615..6a8395ef22e 100644 --- a/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs +++ b/dotnet/src/AutoGen.SemanticKernel/Middleware/SemanticKernelChatMessageContentConnector.cs @@ -47,20 +47,12 @@ public async Task InvokeAsync(MiddlewareContext context, IAgent agent, return PostProcessMessage(reply); } - public Task> InvokeAsync(MiddlewareContext context, IStreamingAgent agent, CancellationToken cancellationToken = default) - { - return Task.FromResult(InvokeStreamingAsync(context, agent, cancellationToken)); - } - - private async IAsyncEnumerable InvokeStreamingAsync( - MiddlewareContext context, - IStreamingAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + public async IAsyncEnumerable InvokeAsync(MiddlewareContext context, IStreamingAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatMessageContents = ProcessMessage(context.Messages, agent) .Select(m => new MessageEnvelope(m)); - await foreach (var reply in await agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken)) + await foreach (var reply in agent.GenerateStreamingReplyAsync(chatMessageContents, context.Options, cancellationToken)) { yield return PostProcessStreamingMessage(reply); } diff --git a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs index b887a6ef586..21f652f56c4 100644 --- a/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs +++ b/dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.SemanticKernel; @@ -64,17 +65,25 @@ public async Task GenerateReplyAsync(IEnumerable messages, G return new MessageEnvelope(reply.First(), from: this.Name); } - public async Task> GenerateStreamingReplyAsync( + public async IAsyncEnumerable GenerateStreamingReplyAsync( IEnumerable messages, GenerateReplyOptions? options = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var chatHistory = BuildChatHistory(messages); var option = BuildOption(options); var chatService = _kernel.GetRequiredService(); var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken); - return ProcessMessage(response); + await foreach (var content in response) + { + if (content.ChoiceIndex > 0) + { + throw new InvalidOperationException("Only one choice is supported in streaming response"); + } + + yield return new MessageEnvelope(content, from: this.Name); + } } private ChatHistory BuildChatHistory(IEnumerable messages) @@ -101,19 +110,6 @@ private PromptExecutionSettings BuildOption(GenerateReplyOptions? options) }; } - private async IAsyncEnumerable ProcessMessage(IAsyncEnumerable response) - { - await foreach (var content in response) - { - if (content.ChoiceIndex > 0) - { - throw new InvalidOperationException("Only one choice is supported in streaming response"); - } - - yield return new MessageEnvelope(content, from: this.Name); - } - } - private IEnumerable ProcessMessage(IEnumerable messages) { return messages.Select(m => m switch diff --git a/dotnet/src/AutoGen/Agent/ConversableAgent.cs b/dotnet/src/AutoGen/Agent/ConversableAgent.cs index e70a74a801c..d79d2519297 100644 --- a/dotnet/src/AutoGen/Agent/ConversableAgent.cs +++ b/dotnet/src/AutoGen/Agent/ConversableAgent.cs @@ -79,19 +79,33 @@ public class ConversableAgent : IAgent IAgent? agent = null; foreach (var llmConfig in config.ConfigList ?? Enumerable.Empty()) { - agent = agent switch + var nextAgent = llmConfig switch { - null => llmConfig switch - { - AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0), - OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0), - _ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"), - }, - IAgent innerAgent => innerAgent.RegisterReply(async (messages, cancellationToken) => - { - return await innerAgent.GenerateReplyAsync(messages, cancellationToken: cancellationToken); - }), + AzureOpenAIConfig azureConfig => new GPTAgent(this.Name!, this.systemMessage, azureConfig, temperature: config.Temperature ?? 0), + OpenAIConfig openAIConfig => new GPTAgent(this.Name!, this.systemMessage, openAIConfig, temperature: config.Temperature ?? 0), + _ => throw new ArgumentException($"Unsupported config type {llmConfig.GetType()}"), }; + + if (agent == null) + { + agent = nextAgent; + } + else + { + agent = agent.RegisterMiddleware(async (messages, option, agent, cancellationToken) => + { + var agentResponse = await nextAgent.GenerateReplyAsync(messages, option, cancellationToken: cancellationToken); + + if (agentResponse is null) + { + return await agent.GenerateReplyAsync(messages, option, cancellationToken); + } + else + { + return agentResponse; + } + }); + } } return agent; diff --git a/dotnet/src/AutoGen/AutoGen.csproj b/dotnet/src/AutoGen/AutoGen.csproj index 2b9aaed6dd5..8f4bbccb5d2 100644 --- a/dotnet/src/AutoGen/AutoGen.csproj +++ b/dotnet/src/AutoGen/AutoGen.csproj @@ -20,7 +20,6 @@ - diff --git a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs index 110e81fdb21..5a9d1f95c73 100644 --- a/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs +++ b/dotnet/test/AutoGen.Mistral.Tests/MistralClientAgentTests.cs @@ -114,7 +114,7 @@ public async Task MistralAgentTwoAgentFunctionCallTest() model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var functionCallMiddlewareExecutorMiddleware = new FunctionCallMiddleware( functionMap: new Dictionary>> @@ -127,7 +127,7 @@ public async Task MistralAgentTwoAgentFunctionCallTest() model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddlewareExecutorMiddleware); + .RegisterStreamingMiddleware(functionCallMiddlewareExecutorMiddleware); await twoAgentTest.TwoAgentGetWeatherFunctionCallTestAsync(executorAgent, functionCallAgent); } @@ -148,7 +148,7 @@ public async Task MistralAgentFunctionCallMiddlewareMessageTest() model: "mistral-small-latest", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = new TextMessage(Role.User, "what's the weather in Seattle?"); var reply = await functionCallAgent.SendAsync(question); @@ -193,7 +193,7 @@ public async Task MistralAgentFunctionCallAutoInvokeTestAsync() toolChoice: ToolChoiceEnum.Any, randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); await singleAgentTest.EchoFunctionCallExecutionTestAsync(agent); await singleAgentTest.EchoFunctionCallExecutionStreamingTestAsync(agent); } @@ -214,7 +214,7 @@ public async Task MistralAgentFunctionCallTestAsync() systemMessage: "You are a helpful assistant that can call functions", randomSeed: 0) .RegisterMessageConnector() - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); await singleAgentTest.EchoFunctionCallTestAsync(agent); @@ -222,7 +222,7 @@ public async Task MistralAgentFunctionCallTestAsync() var question = new TextMessage(Role.User, "what's the weather in Seattle?"); IMessage? finalReply = null; - await foreach (var reply in await agent.GenerateStreamingReplyAsync([question])) + await foreach (var reply in agent.GenerateStreamingReplyAsync([question])) { reply.From.Should().Be(agent.Name); if (reply is IMessage message) diff --git a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs index a4753b66871..c504eb06a18 100644 --- a/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/OpenAIChatAgentTest.cs @@ -47,7 +47,7 @@ public async Task BasicConversationTestAsync() reply.As>().Content.Usage.TotalTokens.Should().BeGreaterThan(0); // test streaming - var streamingReply = await openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = openAIChatAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -93,7 +93,7 @@ public async Task OpenAIChatMessageContentConnectorTestAsync() // test streaming foreach (var message in messages) { - var reply = await assistant.GenerateStreamingReplyAsync([message]); + var reply = assistant.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { @@ -119,10 +119,9 @@ public async Task OpenAIChatAgentToolCallTestAsync() MiddlewareStreamingAgent assistant = openAIChatAgent .RegisterMessageConnector(); - assistant.Middlewares.Count().Should().Be(1); assistant.StreamingMiddlewares.Count().Should().Be(1); var functionCallAgent = assistant - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = "What's the weather in Seattle"; var messages = new IMessage[] @@ -150,7 +149,7 @@ public async Task OpenAIChatAgentToolCallTestAsync() // test streaming foreach (var message in messages) { - var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); ToolCallMessage? toolCallMessage = null; await foreach (var streamingMessage in reply) { @@ -191,7 +190,7 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync() .RegisterMessageConnector(); var functionCallAgent = assistant - .RegisterMiddleware(functionCallMiddleware); + .RegisterStreamingMiddleware(functionCallMiddleware); var question = "What's the weather in Seattle"; var messages = new IMessage[] @@ -220,7 +219,7 @@ public async Task OpenAIChatAgentToolCallInvokingTestAsync() // test streaming foreach (var message in messages) { - var reply = await functionCallAgent.GenerateStreamingReplyAsync([message]); + var reply = functionCallAgent.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { if (streamingMessage is not IMessage) diff --git a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs b/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs deleted file mode 100644 index d4866ad8736..00000000000 --- a/dotnet/test/AutoGen.Tests/RegisterReplyAgentTest.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// RegisterReplyAgentTest.cs - -using System.Threading.Tasks; -using FluentAssertions; -using Xunit; - -namespace AutoGen.Tests -{ - public class RegisterReplyAgentTest - { - [Fact] - public async Task RegisterReplyTestAsync() - { - IAgent echoAgent = new EchoAgent("echo"); - echoAgent = echoAgent - .RegisterReply(async (conversations, ct) => new TextMessage(Role.Assistant, "I'm your father", from: echoAgent.Name)); - - var msg = new Message(Role.User, "hey"); - var reply = await echoAgent.SendAsync(msg); - reply.Should().BeOfType(); - reply.GetContent().Should().Be("I'm your father"); - reply.GetRole().Should().Be(Role.Assistant); - reply.From.Should().Be("echo"); - } - } -} diff --git a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs index 2e5b56f8091..dcb5cd47b0d 100644 --- a/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SemanticKernelAgentTest.cs @@ -44,7 +44,7 @@ public async Task BasicConversationTestAsync() reply.As>().From.Should().Be("assistant"); // test streaming - var streamingReply = await skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); + var streamingReply = skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent }); await foreach (var streamingMessage in streamingReply) { @@ -63,10 +63,8 @@ public async Task SemanticKernelChatMessageContentConnectorTestAsync() var kernel = builder.Build(); - var connector = new SemanticKernelChatMessageContentConnector(); var skAgent = new SemanticKernelAgent(kernel, "assistant") - .RegisterStreamingMiddleware(connector) - .RegisterMiddleware(connector); + .RegisterMessageConnector(); var messages = new IMessage[] { @@ -90,7 +88,7 @@ public async Task SemanticKernelChatMessageContentConnectorTestAsync() // test streaming foreach (var message in messages) { - var reply = await skAgent.GenerateStreamingReplyAsync([message]); + var reply = skAgent.GenerateStreamingReplyAsync([message]); await foreach (var streamingMessage in reply) { @@ -122,7 +120,6 @@ public async Task SemanticKernelPluginTestAsync() var skAgent = new SemanticKernelAgent(kernel, "assistant") .RegisterMessageConnector(); - skAgent.Middlewares.Count().Should().Be(1); skAgent.StreamingMiddlewares.Count().Should().Be(1); var question = "What is the weather in Seattle?"; diff --git a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs index 6dfb61761eb..ae566889bf5 100644 --- a/dotnet/test/AutoGen.Tests/SingleAgentTest.cs +++ b/dotnet/test/AutoGen.Tests/SingleAgentTest.cs @@ -261,7 +261,7 @@ public async Task EchoFunctionCallExecutionStreamingTestAsync(IStreamingAgent ag { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "[ECHO] Hello world"; IStreamingMessage? finalReply = default; await foreach (var reply in replyStream) @@ -302,7 +302,7 @@ public async Task UpperCaseStreamingTestAsync(IStreamingAgent agent) { Temperature = 0, }; - var replyStream = await agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); + var replyStream = agent.GenerateStreamingReplyAsync(messages: new[] { message, helloWorld }, option); var answer = "A B C D E F G H I J K L M N"; TextMessage? finalReply = default; await foreach (var reply in replyStream) diff --git a/dotnet/website/update.md b/dotnet/website/update.md index a97b9480514..b65ab128ef7 100644 --- a/dotnet/website/update.md +++ b/dotnet/website/update.md @@ -1,6 +1,9 @@ +##### Update +- [API Breaking Change] Update the return type of `IStreamingAgent.GenerateStreamingReplyAsync` from `Task>` to `IAsyncEnumerable` +- [API Breaking Change] Update the return type of `IStreamingMiddleware.InvokeAsync` from `Task>` to `IAsyncEnumerable` +- [API Breaking Change] Mark `RegisterReply`, `RegisterPreProcess` and `RegisterPostProcess` as obsolete. You can replace them with `RegisterMiddleware` ##### Update on 0.0.12 (2024-04-22) - Add AutoGen.Mistral package to support Mistral.AI models - ##### Update on 0.0.11 (2024-04-10) - Add link to Discord channel in nuget's readme.md - Document improvements